Add a 'sni' option to ssl_server2
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 81027a6..052e56d 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -25,6 +25,17 @@
 
 #include "polarssl/config.h"
 
+#if defined(POLARSSL_SSL_SERVER_NAME_INDICATION) && defined(POLARSSL_FS_IO)
+#define POLARSSL_SNI
+#endif
+
+#if defined(POLARSSL_PLATFORM_C)
+#include "polarssl/platform.h"
+#else
+#define polarssl_malloc     malloc
+#define polarssl_free       free
+#endif
+
 #if defined(_WIN32)
 #include <windows.h>
 #endif
@@ -72,6 +83,7 @@
 #define DFL_TICKET_TIMEOUT      -1
 #define DFL_CACHE_MAX           -1
 #define DFL_CACHE_TIMEOUT       -1
+#define DFL_SNI                 NULL
 
 #define LONG_RESPONSE "<p>01-blah-blah-blah-blah-blah-blah-blah-blah-blah\r\n" \
     "02-blah-blah-blah-blah-blah-blah-blah-blah-blah-blah-blah-blah-blah\r\n"  \
@@ -116,6 +128,7 @@
     int ticket_timeout;         /* session ticket lifetime                  */
     int cache_max;              /* max number of session cache entries      */
     int cache_timeout;          /* expiration delay of session cache entries */
+    char *sni;                  /* string describing sni information        */
 } opt;
 
 static void my_debug( void *ctx, int level, const char *str )
@@ -177,6 +190,14 @@
 #define USAGE_CACHE ""
 #endif /* POLARSSL_SSL_CACHE_C */
 
+#if defined(POLARSSL_SNI)
+#define USAGE_SNI                                                           \
+    "    sni=%%s              name1,cert1,key1[,name2,cert2,key2[,...]]\n"  \
+    "                         default: disabled\n"
+#else
+#define USAGE_SNI ""
+#endif /* POLARSSL_SNI */
+
 #if defined(POLARSSL_SSL_MAX_FRAGMENT_LENGTH)
 #define USAGE_MAX_FRAG_LEN                                      \
     "    max_frag_len=%%d     default: 16384 (tls default)\n"   \
@@ -195,6 +216,7 @@
     "    auth_mode=%%s        default: \"optional\"\n"      \
     "                        options: none, optional, required\n" \
     USAGE_IO                                                \
+    USAGE_SNI                                               \
     "\n"                                                    \
     USAGE_PSK                                               \
     "\n"                                                    \
@@ -227,6 +249,160 @@
     return( 0 );
 }
 #else
+
+#if defined(POLARSSL_SNI)
+/*
+ * One entry of the SNI table: a server name and the certificate/key pair
+ * to present when a client asks for that name.
+ */
+typedef struct _sni_entry sni_entry;
+
+struct _sni_entry {
+    const char *name;       /* points into the string given to sni_parse() */
+    x509_crt *cert;
+    pk_context *key;
+    sni_entry *next;
+};
+
+/*
+ * Free a list of SNI entries as built by sni_parse().
+ * Accepts NULL, and entries whose cert/key were never allocated.
+ */
+void sni_free( sni_entry *head )
+{
+    sni_entry *cur = head, *next;
+
+    while( cur != NULL )
+    {
+        if( cur->cert != NULL )
+        {
+            x509_crt_free( cur->cert );
+            polarssl_free( cur->cert );
+        }
+
+        if( cur->key != NULL )
+        {
+            pk_free( cur->key );
+            polarssl_free( cur->key );
+        }
+
+        next = cur->next;
+        polarssl_free( cur );
+        cur = next;
+    }
+}
+
+/*
+ * Cut the next comma-separated field out of *cursor, NUL-terminating it
+ * in place, and advance *cursor past the comma (or set it to NULL after
+ * the last field). Returns NULL when no field is left.
+ */
+static char *sni_next_field( char **cursor )
+{
+    char *field = *cursor;
+    char *comma;
+
+    if( field == NULL )
+        return( NULL );
+
+    if( ( comma = strchr( field, ',' ) ) != NULL )
+    {
+        *comma = '\0';
+        *cursor = comma + 1;
+    }
+    else
+        *cursor = NULL;
+
+    return( field );
+}
+
+/*
+ * Parse a string of triplets name1,crt1,key1[,name2,crt2,key2[,...]]
+ * into a usable sni_entry list.
+ *
+ * The string is modified in place (commas become NUL bytes) and the
+ * entries' names point into it, so it must outlive the returned list.
+ * Entries are prepended, so the list is in reverse order of the string.
+ *
+ * Returns NULL if the string is not a sequence of complete triplets, if
+ * an allocation fails, or if a certificate or key file cannot be parsed;
+ * everything allocated so far is released in that case. Error reporting
+ * is still minimal: the caller is not told what went wrong.
+ */
+sni_entry *sni_parse( char *sni_string )
+{
+    sni_entry *head = NULL, *entry;
+    char *cursor = sni_string;
+    char *name, *crt_file, *key_file;
+
+    while( cursor != NULL )
+    {
+        name     = sni_next_field( &cursor );
+        crt_file = sni_next_field( &cursor );
+        key_file = sni_next_field( &cursor );
+
+        /* Fewer than three fields left: empty string, trailing comma... */
+        if( key_file == NULL )
+            goto fail;
+
+        if( ( entry = polarssl_malloc( sizeof( sni_entry ) ) ) == NULL )
+            goto fail;
+
+        memset( entry, 0, sizeof( sni_entry ) );
+        entry->name = name;
+
+        /* Link right away so that sni_free() also releases a partial entry */
+        entry->next = head;
+        head = entry;
+
+        if( ( entry->cert = polarssl_malloc( sizeof( x509_crt ) ) ) == NULL )
+            goto fail;
+        x509_crt_init( entry->cert );
+
+        if( ( entry->key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL )
+            goto fail;
+        pk_init( entry->key );
+
+        if( x509_crt_parse_file( entry->cert, crt_file ) != 0 ||
+            pk_parse_keyfile( entry->key, key_file, "" ) != 0 )
+            goto fail;
+    }
+
+    return( head );
+
+fail:
+    sni_free( head );
+    return( NULL );
+}
+
+/*
+ * SNI callback: search the sni_entry list passed as p_info for an exact
+ * (byte-for-byte) match of the requested name and, if found, install
+ * that entry's certificate and key on ssl with ssl_set_own_cert().
+ * Returns 0 on a match, -1 if the name is not in the list.
+ */
+int sni_callback( void *p_info, ssl_context *ssl,
+                  const unsigned char *name, size_t name_len )
+{
+    sni_entry *cur = (sni_entry *) p_info;
+
+    while( cur != NULL )
+    {
+        if( name_len == strlen( cur->name ) &&
+            memcmp( name, cur->name, name_len ) == 0 )
+        {
+            ssl_set_own_cert( ssl, cur->cert, cur->key );
+            return( 0 );
+        }
+
+        cur = cur->next;
+    }
+
+    return( -1 );
+}
+
+#endif /* POLARSSL_SNI */
+
 int main( int argc, char *argv[] )
 {
     int ret = 0, len, written, frags;
@@ -253,6 +385,9 @@
 #if defined(POLARSSL_SSL_CACHE_C)
     ssl_cache_context cache;
 #endif
+#if defined(POLARSSL_SNI)
+    sni_entry *sni_info = NULL;
+#endif
 #if defined(POLARSSL_MEMORY_BUFFER_ALLOC_C)
     unsigned char alloc_buf[100000];
 #endif
@@ -326,6 +461,7 @@
     opt.ticket_timeout      = DFL_TICKET_TIMEOUT;
     opt.cache_max           = DFL_CACHE_MAX;
     opt.cache_timeout       = DFL_CACHE_TIMEOUT;
+    opt.sni                 = DFL_SNI;
 
     for( i = 1; i < argc; i++ )
     {
@@ -493,6 +629,10 @@
             if( opt.cache_timeout < 0 )
                 goto usage;
         }
+        else if( strcmp( p, "sni" ) == 0 )
+        {
+            opt.sni = q;
+        }
         else
             goto usage;
     }
@@ -725,6 +865,22 @@
     printf( " ok\n" );
 #endif /* POLARSSL_X509_CRT_PARSE_C */
 
+#if defined(POLARSSL_SNI)
+    if( opt.sni != NULL )
+    {
+        printf( "  . Setting up SNI information..." );
+        fflush( stdout );
+
+        if( ( sni_info = sni_parse( opt.sni ) ) == NULL )
+        {
+            printf( " failed\n" );
+            goto exit;
+        }
+
+        printf( " ok\n" );
+    }
+#endif /* POLARSSL_SNI */
+
     /*
      * 2. Setup the listening TCP socket
      */
@@ -794,6 +950,11 @@
         ssl_set_own_cert( &ssl, &srvcert2, &pkey2 );
 #endif
 
+#if defined(POLARSSL_SNI)
+    if( opt.sni != NULL )
+        ssl_set_sni( &ssl, sni_callback, sni_info );
+#endif
+
 #if defined(POLARSSL_KEY_EXCHANGE__SOME__PSK_ENABLED)
     ssl_set_psk( &ssl, psk, psk_len, (const unsigned char *) opt.psk_identity,
                  strlen( opt.psk_identity ) );
@@ -1047,6 +1208,9 @@
     x509_crt_free( &srvcert2 );
     pk_free( &pkey2 );
 #endif
+#if defined(POLARSSL_SNI)
+    sni_free( sni_info );
+#endif
 
     ssl_free( &ssl );
     entropy_free( &entropy );