Refine code based on review comments

Change code layout: move the session hostname check out of the SNI extension writer
Change hostname_len type to size_t
Fix various issues: NUL-terminate stored session hostnames and
serialize hostname_len as a 16-bit field

Signed-off-by: Xiaokang Qian <xiaokang.qian@arm.com>
diff --git a/library/ssl_client.c b/library/ssl_client.c
index 2c5f664..8080e3e 100644
--- a/library/ssl_client.c
+++ b/library/ssl_client.c
@@ -54,7 +54,6 @@
 {
     unsigned char *p = buf;
     size_t hostname_len;
-    size_t cmp_hostname_len;
 
     *olen = 0;
 
@@ -65,25 +64,8 @@
         ( "client hello, adding server name extension: %s",
           ssl->hostname ) );
 
-    ssl->session_negotiate->hostname_mismatch = 0;
     hostname_len = strlen( ssl->hostname );
 
-    cmp_hostname_len = hostname_len < ssl->session_negotiate->hostname_len ?
-                       hostname_len : ssl->session_negotiate->hostname_len;
-
-    if( hostname_len != ssl->session_negotiate->hostname_len ||
-        memcmp( ssl->hostname, ssl->session_negotiate->hostname, cmp_hostname_len ) )
-        ssl->session_negotiate->hostname_mismatch = 1;
-
-    if( ssl->session_negotiate->hostname == NULL )
-    {
-        ssl->session_negotiate->hostname = mbedtls_calloc( 1, hostname_len );
-        if( ssl->session_negotiate->hostname == NULL )
-            return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
-        memcpy(ssl->session_negotiate->hostname, ssl->hostname, hostname_len);
-    }
-    ssl->session_negotiate->hostname_len = hostname_len;
-
     MBEDTLS_SSL_CHK_BUF_PTR( p, end, hostname_len + 9 );
 
     /*
@@ -888,6 +870,34 @@
         }
     }
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3) && \
+    defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+    if( ssl->handshake->resume )
+    {
+        if( ssl->hostname != NULL && ssl->session_negotiate->hostname != NULL )
+        {
+            if( strcmp( ssl->hostname, ssl->session_negotiate->hostname ) )
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 1,
+                ( "hostname mismatch the session ticket, should not resume " ) );
+                return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+            }
+        }
+        else if( ssl->session_negotiate->hostname != NULL )
+        {
+                MBEDTLS_SSL_DEBUG_MSG( 1,
+                ( "hostname missed, should not resume " ) );
+                return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+        }
+    }
+    else
+    {
+        mbedtls_ssl_session_set_hostname( ssl->session_negotiate,
+                                          ssl->hostname );
+    }
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3 &&
+          MBEDTLS_SSL_SERVER_NAME_INDICATION */
+
     return( 0 );
 }
 /*
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index afacb76..f92a4db 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2201,6 +2201,10 @@
     return( 1 );
 }
 
+#if defined(MBEDTLS_X509_CRT_PARSE_C)
+int mbedtls_ssl_session_set_hostname( mbedtls_ssl_session *ssl,
+                                      const char *hostname ); /* NULL clears */
+#endif /* MBEDTLS_X509_CRT_PARSE_C */
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_2)
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index abadc80..959d015 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -297,14 +297,16 @@
     }
 #endif /* MBEDTLS_SSL_SESSION_TICKETS && MBEDTLS_SSL_CLI_C */
 
-#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) && defined(MBEDTLS_SSL_CLI_C)
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3) && \
+    defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) && \
+    defined(MBEDTLS_SSL_CLI_C)
     if( src->endpoint == MBEDTLS_SSL_IS_CLIENT && src->hostname != NULL )
     {
-        dst->hostname = mbedtls_calloc( 1, src->hostname_len );
+        dst->hostname = mbedtls_calloc( 1, src->hostname_len + 1 );
         if( dst->hostname == NULL )
             return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
 
-        memcpy( dst->hostname, src->hostname, src->hostname_len );
+        strcpy( dst->hostname, src->hostname );
         dst->hostname_len = src->hostname_len;
     }
 #endif
@@ -1958,7 +1960,6 @@
  *       uint32 ticket_age_add;
  *       uint8 ticket_flags;
  *       opaque resumption_key<0..255>;
- *
  *       select ( endpoint ) {
  *            case client: ClientOnlyData;
  *            case server: uint64 start_time;
@@ -1993,7 +1994,7 @@
     if( session->endpoint == MBEDTLS_SSL_IS_CLIENT )
     {
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
-        needed +=  1                        /* hostname_len */
+        needed +=  2                        /* hostname_len */
                  + session->hostname_len;   /* hostname */
 #endif
 
@@ -2026,13 +2027,15 @@
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION) && defined(MBEDTLS_SSL_CLI_C)
     if( session->endpoint == MBEDTLS_SSL_IS_CLIENT )
     {
-        p[0] = session->hostname_len;
-        p++;
+        MBEDTLS_PUT_UINT16_BE( session->hostname_len, p, 0 );
+        p += 2;
         if ( session->hostname_len > 0 &&
              session->hostname != NULL )
-        /* save host name */
-        memcpy( p, session->hostname, session->hostname_len );
-        p += session->hostname_len;
+        {
+            /* save host name */
+            memcpy( p, session->hostname, session->hostname_len );
+            p += session->hostname_len;
+        }
     }
 #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION && MBEDTLS_SSL_CLI_C */
 
@@ -2098,19 +2101,20 @@
     if( session->endpoint == MBEDTLS_SSL_IS_CLIENT )
     {
         /* load host name */
-        if( end - p < 1 )
+        if( end - p < 2 )
             return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
-        session->hostname_len = p[0];
-        p += 1;
+        session->hostname_len = MBEDTLS_GET_UINT16_BE( p, 0);
+        p += 2;
 
-        if( end - p < session->hostname_len )
+        if( end - p < ( long int )session->hostname_len )
             return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
         if( session->hostname_len > 0 )
         {
-            session->hostname = mbedtls_calloc( 1, session->hostname_len );
+            session->hostname = mbedtls_calloc( 1, session->hostname_len + 1 );
             if( session->hostname == NULL )
                 return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
             memcpy( session->hostname, p, session->hostname_len );
+            session->hostname[session->hostname_len] = '\0';
             p += session->hostname_len;
         }
     }
@@ -3733,7 +3737,8 @@
     mbedtls_free( session->ticket );
 #endif
 
-#if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3) && \
+    defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
     mbedtls_free( session->hostname );
 #endif
 
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index abb7a14..1b827ac 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -1485,4 +1485,51 @@
 }
 #endif /* MBEDTLS_ECDH_C */
 
+#if defined(MBEDTLS_X509_CRT_PARSE_C)
+/*
+ * Store a private, NUL-terminated copy of `hostname` in the session
+ * pointed to by `ssl`. Passing NULL clears any stored hostname.
+ *
+ * Returns 0 on success, MBEDTLS_ERR_SSL_BAD_INPUT_DATA if the name is
+ * longer than MBEDTLS_SSL_MAX_HOST_NAME_LEN (session left unchanged), or
+ * MBEDTLS_ERR_SSL_ALLOC_FAILED (session left with no hostname).
+ */
+int mbedtls_ssl_session_set_hostname( mbedtls_ssl_session *ssl,
+                                      const char *hostname )
+{
+    size_t hostname_len = 0;
+
+    /* Validate the new name before modifying the stored one. */
+    if( hostname != NULL )
+    {
+        hostname_len = strlen( hostname );
+        if( hostname_len > MBEDTLS_SSL_MAX_HOST_NAME_LEN )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+    }
+
+    /* Wipe and release the old name. Reset hostname and hostname_len
+     * together: session serialization writes hostname_len even when
+     * hostname is NULL, so a stale length would corrupt the output. */
+    if( ssl->hostname != NULL )
+    {
+        mbedtls_platform_zeroize( ssl->hostname, strlen( ssl->hostname ) );
+        mbedtls_free( ssl->hostname );
+    }
+    ssl->hostname = NULL;
+    ssl->hostname_len = 0;
+
+    if( hostname == NULL )
+        return( 0 );
+
+    ssl->hostname = mbedtls_calloc( 1, hostname_len + 1 );
+    if( ssl->hostname == NULL )
+        return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
+
+    memcpy( ssl->hostname, hostname, hostname_len );
+    ssl->hostname[hostname_len] = '\0';
+    ssl->hostname_len = hostname_len;
+
+    return( 0 );
+}
+#endif /* MBEDTLS_X509_CRT_PARSE_C */
 #endif /* MBEDTLS_SSL_TLS_C && MBEDTLS_SSL_PROTO_TLS1_3 */