Merge pull request #5454 from gstrauss/cert_cb-user_data

server certificate selection callback
diff --git a/ChangeLog.d/mbedtls_ssl_cert_cb.txt b/ChangeLog.d/mbedtls_ssl_cert_cb.txt
new file mode 100644
index 0000000..fcdc23c
--- /dev/null
+++ b/ChangeLog.d/mbedtls_ssl_cert_cb.txt
@@ -0,0 +1,7 @@
+Features
+   * Add server certificate selection callback near end of Client Hello.
+     Register callback with mbedtls_ssl_conf_cert_cb().
+   * Provide mechanism to reset handshake cert list by calling
+     mbedtls_ssl_set_hs_own_cert() with NULL value for own_cert param.
+   * Add accessor mbedtls_ssl_get_hs_sni() to retrieve SNI from within
+     cert callback (mbedtls_ssl_conf_cert_cb()) during handshake.
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index 7544f42..b819bba 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -1475,6 +1475,10 @@
      * access it afterwards.
      */
     mbedtls_ssl_user_data_t MBEDTLS_PRIVATE(user_data);
+
+#if defined(MBEDTLS_SSL_SRV_C)
+    int (*MBEDTLS_PRIVATE(f_cert_cb))(mbedtls_ssl_context *); /*!< certificate selection callback */
+#endif /* MBEDTLS_SSL_SRV_C */
 };
 
 struct mbedtls_ssl_context
@@ -2220,6 +2224,28 @@
                                mbedtls_ssl_set_timer_t *f_set_timer,
                                mbedtls_ssl_get_timer_t *f_get_timer );
 
+#if defined(MBEDTLS_SSL_SRV_C)
+/**
+ * \brief           Set the certificate selection callback (server-side only).
+ *
+ *                  If set, the callback is always called for each handshake,
+ *                  after `ClientHello` processing has finished.
+ *
+ *                  The callback has the following parameters:
+ *                  - \c mbedtls_ssl_context*: The SSL context to which
+ *                                             the operation applies.
+ *                  The return value of the callback is 0 if successful,
+ *                  or a specific MBEDTLS_ERR_XXX code, which will cause
+ *                  the handshake to be aborted.
+ *
+ * \param conf      The SSL configuration to register the callback with.
+ * \param f_cert_cb The callback for selecting server certificate after
+ *                  `ClientHello` processing has finished.
+ */
+void mbedtls_ssl_conf_cert_cb( mbedtls_ssl_config *conf,
+                               int (*f_cert_cb)(mbedtls_ssl_context *) );
+#endif /* MBEDTLS_SSL_SRV_C */
+
 /**
  * \brief           Callback type: generate and write session ticket
  *
@@ -3515,10 +3541,34 @@
 
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
 /**
+ * \brief          Retrieve SNI extension value for the current handshake.
+ *                 Available in \p f_cert_cb of \c mbedtls_ssl_conf_cert_cb(),
+ *                 this is the same value passed to \p f_sni callback of
+ *                 \c mbedtls_ssl_conf_sni() and may be used instead of
+ *                 \c mbedtls_ssl_conf_sni().
+ *
+ * \param ssl      SSL context
+ * \param name_len pointer into which to store length of returned value.
+ *                 0 if SNI extension is not present or not yet processed.
+ *
+ * \return         const pointer to SNI extension value.
+ *                 - value is valid only when called in \p f_cert_cb
+ *                   registered with \c mbedtls_ssl_conf_cert_cb().
+ *                 - value is NULL if SNI extension is not present.
+ *                 - value is not '\0'-terminated.  Use \c name_len for len.
+ *                 - value must not be freed.
+ */
+const unsigned char *mbedtls_ssl_get_hs_sni( mbedtls_ssl_context *ssl,
+                                             size_t *name_len );
+
+/**
  * \brief          Set own certificate and key for the current handshake
  *
  * \note           Same as \c mbedtls_ssl_conf_own_cert() but for use within
- *                 the SNI callback.
+ *                 the SNI callback or the certificate selection callback.
+ *
+ * \note           Passing null \c own_cert clears the certificate list for
+ *                 the current handshake.
  *
  * \param ssl      SSL context
  * \param own_cert own public certificate chain
@@ -3535,7 +3585,7 @@
  *                 current handshake
  *
  * \note           Same as \c mbedtls_ssl_conf_ca_chain() but for use within
- *                 the SNI callback.
+ *                 the SNI callback or the certificate selection callback.
  *
  * \param ssl      SSL context
  * \param ca_chain trusted CA chain (meaning all fully trusted top-level CAs)
@@ -3549,7 +3599,7 @@
  * \brief          Set authmode for the current handshake.
  *
  * \note           Same as \c mbedtls_ssl_conf_authmode() but for use within
- *                 the SNI callback.
+ *                 the SNI callback or the certificate selection callback.
  *
  * \param ssl      SSL context
  * \param authmode MBEDTLS_SSL_VERIFY_NONE, MBEDTLS_SSL_VERIFY_OPTIONAL or
@@ -3574,8 +3624,7 @@
  *                 mbedtls_ssl_set_hs_ca_chain() as well as the client
  *                 authentication mode with \c mbedtls_ssl_set_hs_authmode(),
  *                 then must return 0. If no matching name is found, the
- *                 callback must either set a default cert, or
- *                 return non-zero to abort the handshake at this point.
+ *                 callback may return non-zero to abort the handshake.
  *
  * \param conf     SSL configuration
  * \param f_sni    verification function
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 9ed5dfa..be01eba 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -850,6 +850,11 @@
      * The library does not use it internally. */
     void *user_async_ctx;
 #endif /* MBEDTLS_SSL_ASYNC_PRIVATE */
+
+#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+    const unsigned char *sni_name;      /*!< raw SNI                        */
+    size_t sni_name_len;                /*!< raw SNI len                    */
+#endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
 };
 
 typedef struct mbedtls_ssl_hs_buffer mbedtls_ssl_hs_buffer;
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index e9febfd..c757ac8 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -118,6 +118,11 @@
 
         if( p[0] == MBEDTLS_TLS_EXT_SERVERNAME_HOSTNAME )
         {
+            ssl->handshake->sni_name = p + 3;
+            ssl->handshake->sni_name_len = hostname_len;
+            if( ssl->conf->f_sni == NULL )
+                return( 0 );
+
             ret = ssl->conf->f_sni( ssl->conf->p_sni,
                                     ssl, p + 3, hostname_len );
             if( ret != 0 )
@@ -1643,9 +1648,6 @@
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
             case MBEDTLS_TLS_EXT_SERVERNAME:
                 MBEDTLS_SSL_DEBUG_MSG( 3, ( "found ServerName extension" ) );
-                if( ssl->conf->f_sni == NULL )
-                    break;
-
                 ret = ssl_parse_servername_ext( ssl, ext + 4, ext_size );
                 if( ret != 0 )
                     return( ret );
@@ -1871,9 +1873,23 @@
     }
 
     /*
+     * Server certificate selection (after processing TLS extensions)
+     */
+    if( ssl->conf->f_cert_cb && ( ret = ssl->conf->f_cert_cb( ssl ) ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "f_cert_cb", ret );
+        return( ret );
+    }
+#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+    ssl->handshake->sni_name = NULL;
+    ssl->handshake->sni_name_len = 0;
+#endif
+
+    /*
      * Search for a matching ciphersuite
      * (At the end because we need information from the EC-based extensions
-     * and certificate from the SNI callback triggered by the SNI extension.)
+     * and certificate from the SNI callback triggered by the SNI extension
+     * or certificate from server certificate selection callback.)
      */
     got_common_suite = 0;
     ciphersuites = ssl->conf->ciphersuite_list;
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index adb18ab..2220721 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1233,6 +1233,14 @@
 }
 
 #if defined(MBEDTLS_SSL_SRV_C)
+void mbedtls_ssl_conf_cert_cb( mbedtls_ssl_config *conf,
+                               int (*f_cert_cb)(mbedtls_ssl_context *) )
+{
+    conf->f_cert_cb = f_cert_cb;
+}
+#endif /* MBEDTLS_SSL_SRV_C */
+
+#if defined(MBEDTLS_SSL_SRV_C)
 void mbedtls_ssl_conf_session_cache( mbedtls_ssl_config *conf,
                                      void *p_cache,
                                      mbedtls_ssl_cache_get_t *f_get_cache,
@@ -1291,6 +1299,18 @@
     conf->cert_profile = profile;
 }
 
+static void ssl_key_cert_free( mbedtls_ssl_key_cert *key_cert )
+{
+    mbedtls_ssl_key_cert *cur = key_cert, *next;
+
+    while( cur != NULL )
+    {
+        next = cur->next;
+        mbedtls_free( cur );
+        cur = next;
+    }
+}
+
 /* Append a new keycert entry to a (possibly empty) list */
 static int ssl_append_key_cert( mbedtls_ssl_key_cert **head,
                                 mbedtls_x509_crt *cert,
@@ -1298,6 +1318,14 @@
 {
     mbedtls_ssl_key_cert *new_cert;
 
+    if( cert == NULL )
+    {
+        /* Free list if cert is null */
+        ssl_key_cert_free( *head );
+        *head = NULL;
+        return( 0 );
+    }
+
     new_cert = mbedtls_calloc( 1, sizeof( mbedtls_ssl_key_cert ) );
     if( new_cert == NULL )
         return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
@@ -1306,7 +1334,7 @@
     new_cert->key  = key;
     new_cert->next = NULL;
 
-    /* Update head is the list was null, else add to the end */
+    /* Update head if the list was null, else add to the end */
     if( *head == NULL )
     {
         *head = new_cert;
@@ -1361,6 +1389,13 @@
 #endif /* MBEDTLS_X509_CRT_PARSE_C */
 
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+const unsigned char *mbedtls_ssl_get_hs_sni( mbedtls_ssl_context *ssl,
+                                             size_t *name_len )
+{
+    *name_len = ssl->handshake->sni_name_len;
+    return( ssl->handshake->sni_name );
+}
+
 int mbedtls_ssl_set_hs_own_cert( mbedtls_ssl_context *ssl,
                                  mbedtls_x509_crt *own_cert,
                                  mbedtls_pk_context *pk_key )
@@ -2941,20 +2976,6 @@
 }
 #endif /* MBEDTLS_SSL_RENEGOTIATION */
 
-#if defined(MBEDTLS_X509_CRT_PARSE_C)
-static void ssl_key_cert_free( mbedtls_ssl_key_cert *key_cert )
-{
-    mbedtls_ssl_key_cert *cur = key_cert, *next;
-
-    while( cur != NULL )
-    {
-        next = cur->next;
-        mbedtls_free( cur );
-        cur = next;
-    }
-}
-#endif /* MBEDTLS_X509_CRT_PARSE_C */
-
 void mbedtls_ssl_handshake_free( mbedtls_ssl_context *ssl )
 {
     mbedtls_ssl_handshake_params *handshake = ssl->handshake;
@@ -3042,17 +3063,7 @@
      * Free only the linked list wrapper, not the keys themselves
      * since the belong to the SNI callback
      */
-    if( handshake->sni_key_cert != NULL )
-    {
-        mbedtls_ssl_key_cert *cur = handshake->sni_key_cert, *next;
-
-        while( cur != NULL )
-        {
-            next = cur->next;
-            mbedtls_free( cur );
-            cur = next;
-        }
-    }
+    ssl_key_cert_free( handshake->sni_key_cert );
 #endif /* MBEDTLS_X509_CRT_PARSE_C && MBEDTLS_SSL_SERVER_NAME_INDICATION */
 
 #if defined(MBEDTLS_SSL_ECP_RESTARTABLE_ENABLED)
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 595300e..802078c 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -823,18 +823,23 @@
 {
     const sni_entry *cur = (const sni_entry *) p_info;
 
+    /* Preserve the behavior of checking for an SNI match in sni_callback()
+     * for the benefit of tests using sni_callback(), even though the actual
+     * certificate assignment has moved to the certificate selection callback
+     * in this application.  This exercises sni_callback and cert_callback
+     * even though real applications might choose to do this differently.
+     * An application might choose to save name and name_len in user_data
+     * for later use in the certificate selection callback.
+     */
     while( cur != NULL )
     {
         if( name_len == strlen( cur->name ) &&
             memcmp( name, cur->name, name_len ) == 0 )
         {
-            if( cur->ca != NULL )
-                mbedtls_ssl_set_hs_ca_chain( ssl, cur->ca, cur->crl );
-
-            if( cur->authmode != DFL_AUTH_MODE )
-                mbedtls_ssl_set_hs_authmode( ssl, cur->authmode );
-
-            return( mbedtls_ssl_set_hs_own_cert( ssl, cur->cert, cur->key ) );
+            void *p;
+            *(const void **)&p = cur;
+            mbedtls_ssl_set_user_data_p( ssl, p );
+            return( 0 );
         }
 
         cur = cur->next;
@@ -843,6 +848,33 @@
     return( -1 );
 }
 
+/*
+ * server certificate selection callback.
+ */
+int cert_callback( mbedtls_ssl_context *ssl )
+{
+    const sni_entry *cur = (sni_entry *) mbedtls_ssl_get_user_data_p( ssl );
+    if( cur != NULL )
+    {
+        /*(exercise mbedtls_ssl_get_hs_sni(); not otherwise used here)*/
+        size_t name_len;
+        const unsigned char *name = mbedtls_ssl_get_hs_sni( ssl, &name_len );
+        if( strlen( cur->name ) != name_len ||
+            memcmp( cur->name, name, name_len ) != 0 )
+            return( MBEDTLS_ERR_SSL_DECODE_ERROR );
+
+        if( cur->ca != NULL )
+            mbedtls_ssl_set_hs_ca_chain( ssl, cur->ca, cur->crl );
+
+        if( cur->authmode != DFL_AUTH_MODE )
+            mbedtls_ssl_set_hs_authmode( ssl, cur->authmode );
+
+        return( mbedtls_ssl_set_hs_own_cert( ssl, cur->cert, cur->key ) );
+    }
+
+    return( 0 );
+}
+
 #endif /* SNI_OPTION */
 
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
@@ -2923,6 +2955,7 @@
     if( opt.sni != NULL )
     {
         mbedtls_ssl_conf_sni( &conf, sni_callback, sni_info );
+        mbedtls_ssl_conf_cert_cb( &conf, cert_callback );
 #if defined(MBEDTLS_SSL_ASYNC_PRIVATE)
         if( opt.async_private_delay2 >= 0 )
         {
diff --git a/tests/ssl-opt.sh b/tests/ssl-opt.sh
index e092309..9e99c1f 100755
--- a/tests/ssl-opt.sh
+++ b/tests/ssl-opt.sh
@@ -5025,7 +5025,6 @@
              crt_file=data_files/server5.crt key_file=data_files/server5.key" \
             "$P_CLI server_name=localhost" \
             0 \
-            -S "parse ServerName extension" \
             -c "issuer name *: C=NL, O=PolarSSL, CN=Polarssl Test EC CA" \
             -c "subject name *: C=NL, O=PolarSSL, CN=localhost"
 
@@ -5175,7 +5174,6 @@
              crt_file=data_files/server5.crt key_file=data_files/server5.key" \
             "$P_CLI server_name=localhost dtls=1" \
             0 \
-            -S "parse ServerName extension" \
             -c "issuer name *: C=NL, O=PolarSSL, CN=Polarssl Test EC CA" \
             -c "subject name *: C=NL, O=PolarSSL, CN=localhost"
 
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 3831d4a..67f4d6e 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -854,6 +854,15 @@
     ret = mbedtls_ssl_conf_own_cert( &( ep->conf ), &( cert->cert ),
                                      &( cert->pkey ) );
     TEST_ASSERT( ret == 0 );
+    TEST_ASSERT( ep->conf.key_cert != NULL );
+
+    ret = mbedtls_ssl_conf_own_cert( &( ep->conf ), NULL, NULL );
+    TEST_ASSERT( ret == 0 );
+    TEST_ASSERT( ep->conf.key_cert == NULL );
+
+    ret = mbedtls_ssl_conf_own_cert( &( ep->conf ), &( cert->cert ),
+                                     &( cert->pkey ) );
+    TEST_ASSERT( ret == 0 );
 
 exit:
     if( ret != 0 )