Add psk_list option to ssl_server2: PSK callback
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index c932d14..e80038e 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -77,6 +77,7 @@
 #define DFL_KEY_FILE2           ""
 #define DFL_PSK                 ""
 #define DFL_PSK_IDENTITY        "Client_identity"
+#define DFL_PSK_LIST            NULL
 #define DFL_FORCE_CIPHER        0
 #define DFL_RENEGOTIATION       SSL_RENEGOTIATION_DISABLED
 #define DFL_ALLOW_LEGACY        SSL_LEGACY_NO_RENEGOTIATION
@@ -127,6 +128,7 @@
     const char *key_file2;      /* the file with the 2nd server key         */
     const char *psk;            /* the pre-shared key                       */
     const char *psk_identity;   /* the pre-shared key identity              */
+    char *psk_list;             /* list of PSK id/key pairs for callback    */
     int force_ciphersuite[2];   /* protocol/ciphersuite to use, or all      */
     int renegotiation;          /* enable / disable renegotiation           */
     int allow_legacy;           /* allow legacy renegotiation               */
@@ -474,6 +476,97 @@
 
     return( 0 );
 }
+
+typedef struct _psk_entry psk_entry;
+
+struct _psk_entry
+{
+    const char *name;
+    size_t key_len;
+    unsigned char key[MAX_PSK_LEN];
+    psk_entry *next;
+};
+
+/*
+ * Parse a string of pairs name1,key1[,name2,key2[,...]]
+ * into a usable psk_entry list.
+ *
+ * Modifies the input string! This is not production quality!
+ * (leaks memory if parsing fails, no error reporting, ...)
+ */
+psk_entry *psk_parse( char *psk_string )
+{
+    psk_entry *cur = NULL, *new = NULL;
+    char *p = psk_string;
+    char *end = p;
+    char *key_hex;
+
+    while( *end != '\0' )
+        ++end;
+    *end = ',';
+
+    while( p <= end )
+    {
+        if( ( new = polarssl_malloc( sizeof( psk_entry ) ) ) == NULL )
+            return( NULL );
+
+        memset( new, 0, sizeof( psk_entry ) );
+
+        new->name = p;
+        while( *p != ',' ) if( ++p > end ) return( NULL );
+        *p++ = '\0';
+
+        key_hex = p;
+        while( *p != ',' ) if( ++p > end ) return( NULL );
+        *p++ = '\0';
+
+        if( unhexify( new->key, key_hex, &new->key_len ) != 0 )
+            return( NULL );
+
+        new->next = cur;
+        cur = new;
+    }
+
+    return( cur );
+}
+
+/*
+ * Free a list of psk_entry's
+ */
+void psk_free( psk_entry *head )
+{
+    psk_entry *next;
+
+    while( head != NULL )
+    {
+        next = head->next;
+        polarssl_free( head );
+        head = next;
+    }
+}
+
+/*
+ * PSK callback
+ */
+int psk_callback( void *p_info, ssl_context *ssl,
+                  const unsigned char *name, size_t name_len )
+{
+    psk_entry *cur = (psk_entry *) p_info;
+
+    while( cur != NULL )
+    {
+        if( name_len == strlen( cur->name ) &&
+            memcmp( name, cur->name, name_len ) == 0 )
+        {
+            return( ssl_set_psk( ssl, cur->key, cur->key_len,
+                                 name, name_len ) );
+        }
+
+        cur = cur->next;
+    }
+
+    return( -1 );
+}
 #endif /* POLARSSL_KEY_EXCHANGE__SOME__PSK_ENABLED */
 
 int main( int argc, char *argv[] )
@@ -485,6 +578,7 @@
 #if defined(POLARSSL_KEY_EXCHANGE__SOME__PSK_ENABLED)
     unsigned char psk[MAX_PSK_LEN];
     size_t psk_len = 0;
+    psk_entry *psk_info;
 #endif
     const char *pers = "ssl_server2";
 
@@ -579,6 +673,7 @@
     opt.key_file2           = DFL_KEY_FILE2;
     opt.psk                 = DFL_PSK;
     opt.psk_identity        = DFL_PSK_IDENTITY;
+    opt.psk_list            = DFL_PSK_LIST;
     opt.force_ciphersuite[0]= DFL_FORCE_CIPHER;
     opt.renegotiation       = DFL_RENEGOTIATION;
     opt.allow_legacy        = DFL_ALLOW_LEGACY;
@@ -640,6 +735,8 @@
             opt.psk = q;
         else if( strcmp( p, "psk_identity" ) == 0 )
             opt.psk_identity = q;
+        else if( strcmp( p, "psk_list" ) == 0 )
+            opt.psk_list = q;
         else if( strcmp( p, "force_ciphersuite" ) == 0 )
         {
             opt.force_ciphersuite[0] = -1;
@@ -812,13 +909,19 @@
 
 #if defined(POLARSSL_KEY_EXCHANGE__SOME__PSK_ENABLED)
     /*
-     * Unhexify the pre-shared key if any is given
+     * Unhexify the pre-shared key and parse the list if any given
      */
-    if( opt.psk != NULL )
+    if( unhexify( psk, opt.psk, &psk_len ) != 0 )
     {
-        if( unhexify( psk, opt.psk, &psk_len ) != 0 )
+        printf( "pre-shared key not valid hex\n" );
+        goto exit;
+    }
+
+    if( opt.psk_list != NULL )
+    {
+        if( ( psk_info = psk_parse( opt.psk_list ) ) == NULL )
         {
-            printf("pre-shared key not valid hex\n");
+            printf( "psk_list invalid" );
             goto exit;
         }
     }
@@ -1127,6 +1230,8 @@
 #if defined(POLARSSL_KEY_EXCHANGE__SOME__PSK_ENABLED)
     ssl_set_psk( &ssl, psk, psk_len, (const unsigned char *) opt.psk_identity,
                  strlen( opt.psk_identity ) );
+    if( opt.psk_list != NULL )
+        ssl_set_psk_cb( &ssl, psk_callback, psk_info );
 #endif
 
 #if defined(POLARSSL_DHM_C)