Add support for the server name extension on the server side

Change-Id: Iccf5017e306ba6ead2e1026a29f397ead084cc4d
Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 9912d6c..f0f4465 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2270,4 +2270,10 @@
 int mbedtls_ssl_write_sig_alg_ext( mbedtls_ssl_context *ssl, unsigned char *buf,
                                    const unsigned char *end, size_t *out_len );
 
+#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+int mbedtls_ssl_parse_servername_ext( mbedtls_ssl_context *ssl,
+                                      const unsigned char *buf,
+                                      const unsigned char *end );
+#endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
+
 #endif /* ssl_misc.h */
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 5331865..29a33f4 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -8210,4 +8210,55 @@
 }
 #endif /* MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED */
 
+#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+/* Parse ServerName extension data in [buf, end). The first host_name entry
+ * is recorded in ssl->handshake->sni_name (pointing into buf, not copied)
+ * and passed to f_sni when a callback is set. Returns 0 on success, or
+ * MBEDTLS_ERR_SSL_UNRECOGNIZED_NAME after requesting a fatal alert if f_sni
+ * fails; the MBEDTLS_SSL_CHK_BUF_READ_PTR bounds checks may return early. */
+int mbedtls_ssl_parse_servername_ext( mbedtls_ssl_context *ssl,
+                                      const unsigned char *buf,
+                                      const unsigned char *end )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    const unsigned char *p = buf;
+    size_t servername_list_size, hostname_len;
+    const unsigned char *servername_end;
+
+    MBEDTLS_SSL_DEBUG_MSG( 3, ( "Parse ServerName extension" ) );
+
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 2 );
+    servername_list_size = MBEDTLS_GET_UINT16_BE( p, 0 );
+    p += 2;
+
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, servername_list_size );
+    servername_end = p + servername_list_size;
+    while( p < servername_end )
+    {
+        MBEDTLS_SSL_CHK_BUF_READ_PTR( p, servername_end, 3 );
+        hostname_len = MBEDTLS_GET_UINT16_BE( p, 1 );
+        MBEDTLS_SSL_CHK_BUF_READ_PTR( p, servername_end, hostname_len + 3 );
+
+        if( p[0] == MBEDTLS_TLS_EXT_SERVERNAME_HOSTNAME )
+        {
+            ssl->handshake->sni_name = p + 3;
+            ssl->handshake->sni_name_len = hostname_len;
+            /* Test f_sni, not p_sni: the context pointer may be NULL. */
+            if( ssl->conf->f_sni == NULL )
+                return( 0 );
+            ret = ssl->conf->f_sni( ssl->conf->p_sni,
+                                    ssl, p + 3, hostname_len );
+            if( ret == 0 )
+                return( 0 );
+            MBEDTLS_SSL_DEBUG_RET( 1, "sni_wrapper", ret );
+            mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
+                                            MBEDTLS_SSL_ALERT_MSG_UNRECOGNIZED_NAME );
+            return( MBEDTLS_ERR_SSL_UNRECOGNIZED_NAME );
+        }
+        p += hostname_len + 3;
+    }
+    return( 0 );
+}
+#endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
+
 #endif /* MBEDTLS_SSL_TLS_C */
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index d092633..e48a8ca 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -77,80 +77,6 @@
 }
 #endif /* MBEDTLS_SSL_DTLS_HELLO_VERIFY */
 
-#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
-static int ssl_parse_servername_ext( mbedtls_ssl_context *ssl,
-                                     const unsigned char *buf,
-                                     size_t len )
-{
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    size_t servername_list_size, hostname_len;
-    const unsigned char *p;
-
-    MBEDTLS_SSL_DEBUG_MSG( 3, ( "parse ServerName extension" ) );
-
-    if( len < 2 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad client hello message" ) );
-        mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                                       MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR );
-        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-    }
-    servername_list_size = ( ( buf[0] << 8 ) | ( buf[1] ) );
-    if( servername_list_size + 2 != len )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad client hello message" ) );
-        mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                                        MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR );
-        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-    }
-
-    p = buf + 2;
-    while( servername_list_size > 2 )
-    {
-        hostname_len = ( ( p[1] << 8 ) | p[2] );
-        if( hostname_len + 3 > servername_list_size )
-        {
-            MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad client hello message" ) );
-            mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                                            MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR );
-            return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-        }
-
-        if( p[0] == MBEDTLS_TLS_EXT_SERVERNAME_HOSTNAME )
-        {
-            ssl->handshake->sni_name = p + 3;
-            ssl->handshake->sni_name_len = hostname_len;
-            if( ssl->conf->f_sni == NULL )
-                return( 0 );
-
-            ret = ssl->conf->f_sni( ssl->conf->p_sni,
-                                    ssl, p + 3, hostname_len );
-            if( ret != 0 )
-            {
-                MBEDTLS_SSL_DEBUG_RET( 1, "ssl_sni_wrapper", ret );
-                mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                        MBEDTLS_SSL_ALERT_MSG_UNRECOGNIZED_NAME );
-                return( MBEDTLS_ERR_SSL_UNRECOGNIZED_NAME );
-            }
-            return( 0 );
-        }
-
-        servername_list_size -= hostname_len + 3;
-        p += hostname_len + 3;
-    }
-
-    if( servername_list_size != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad client hello message" ) );
-        mbedtls_ssl_send_alert_message( ssl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-                                        MBEDTLS_SSL_ALERT_MSG_DECODE_ERROR );
-        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-    }
-
-    return( 0 );
-}
-#endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
-
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
 static int ssl_conf_has_psk_or_cb( mbedtls_ssl_config const *conf )
 {
@@ -1483,7 +1409,8 @@
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
             case MBEDTLS_TLS_EXT_SERVERNAME:
                 MBEDTLS_SSL_DEBUG_MSG( 3, ( "found ServerName extension" ) );
-                ret = ssl_parse_servername_ext( ssl, ext + 4, ext_size );
+                ret = mbedtls_ssl_parse_servername_ext( ssl, ext + 4,
+                                                        ext + 4 + ext_size );
                 if( ret != 0 )
                     return( ret );
                 break;
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index f3843b1..9d2c8ec 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -580,6 +580,21 @@
 
         switch( extension_type )
         {
+#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+            case MBEDTLS_TLS_EXT_SERVERNAME:
+                MBEDTLS_SSL_DEBUG_MSG( 3, ( "found ServerName extension" ) );
+                ret = mbedtls_ssl_parse_servername_ext( ssl, p,
+                                                        extension_data_end );
+                if( ret != 0 )
+                {
+                    MBEDTLS_SSL_DEBUG_RET(
+                            1, "mbedtls_ssl_parse_servername_ext", ret );
+                    return( ret );
+                }
+                ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_SERVERNAME;
+                break;
+#endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
+
 #if defined(MBEDTLS_ECDH_C)
             case MBEDTLS_TLS_EXT_SUPPORTED_GROUPS:
                 MBEDTLS_SSL_DEBUG_MSG( 3, ( "found supported group extension" ) );
@@ -1337,6 +1352,11 @@
 {
     int authmode;
 
+#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+    if( ssl->handshake->sni_authmode != MBEDTLS_SSL_VERIFY_UNSET )
+        authmode = ssl->handshake->sni_authmode;
+    else
+#endif
     authmode = ssl->conf->authmode;
 
     if( authmode == MBEDTLS_SSL_VERIFY_NONE )