Merge pull request #5679 from yuhaoth/pr/add-tls13-write-server-hello

diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 966764a..141c40a 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -1783,6 +1783,16 @@
 
 int mbedtls_ssl_reset_transcript_for_hrr( mbedtls_ssl_context *ssl );
 
+#if defined(MBEDTLS_ECDH_C)
+/* Generate an ECDH key for named_group; write its public key to [buf, end). */
+int mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange(
+                mbedtls_ssl_context *ssl,
+                uint16_t named_group,
+                unsigned char *buf,
+                unsigned char *end,
+                size_t *out_len );
+#endif /* MBEDTLS_ECDH_C */
+
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
 
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index cf5b382..d024abf 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -204,65 +204,6 @@
 /*
  * Functions for writing key_share extension.
  */
-#if defined(MBEDTLS_ECDH_C)
-static int ssl_tls13_generate_and_write_ecdh_key_exchange(
-                mbedtls_ssl_context *ssl,
-                uint16_t named_group,
-                unsigned char *buf,
-                unsigned char *end,
-                size_t *out_len )
-{
-    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
-    int ret = MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE;
-    psa_key_attributes_t key_attributes;
-    size_t own_pubkey_len;
-    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
-    size_t ecdh_bits = 0;
-
-    MBEDTLS_SSL_DEBUG_MSG( 1, ( "Perform PSA-based ECDH computation." ) );
-
-    /* Convert EC group to PSA key type. */
-    if( ( handshake->ecdh_psa_type =
-        mbedtls_psa_parse_tls_ecc_group( named_group, &ecdh_bits ) ) == 0 )
-            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-
-    ssl->handshake->ecdh_bits = ecdh_bits;
-
-    key_attributes = psa_key_attributes_init();
-    psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
-    psa_set_key_algorithm( &key_attributes, PSA_ALG_ECDH );
-    psa_set_key_type( &key_attributes, handshake->ecdh_psa_type );
-    psa_set_key_bits( &key_attributes, handshake->ecdh_bits );
-
-    /* Generate ECDH private key. */
-    status = psa_generate_key( &key_attributes,
-                                &handshake->ecdh_psa_privkey );
-    if( status != PSA_SUCCESS )
-    {
-        ret = psa_ssl_status_to_mbedtls( status );
-        MBEDTLS_SSL_DEBUG_RET( 1, "psa_generate_key", ret );
-        return( ret );
-
-    }
-
-    /* Export the public part of the ECDH private key from PSA. */
-    status = psa_export_public_key( handshake->ecdh_psa_privkey,
-                                    buf, (size_t)( end - buf ),
-                                    &own_pubkey_len );
-    if( status != PSA_SUCCESS )
-    {
-        ret = psa_ssl_status_to_mbedtls( status );
-        MBEDTLS_SSL_DEBUG_RET( 1, "psa_export_public_key", ret );
-        return( ret );
-
-    }
-
-    *out_len = own_pubkey_len;
-
-    return( 0 );
-}
-#endif /* MBEDTLS_ECDH_C */
-
 static int ssl_tls13_get_default_group_id( mbedtls_ssl_context *ssl,
                                            uint16_t *group_id )
 {
@@ -367,8 +308,8 @@
          */
         MBEDTLS_SSL_CHK_BUF_PTR( p, end, 4 );
         p += 4;
-        ret = ssl_tls13_generate_and_write_ecdh_key_exchange( ssl, group_id, p, end,
-                                                              &key_exchange_len );
+        ret = mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange(
+                                    ssl, group_id, p, end, &key_exchange_len );
         p += key_exchange_len;
         if( ret != 0 )
             return( ret );
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 4bee319..f5d791f 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -1535,6 +1535,63 @@
 
     return( 0 );
 }
+
+/* Generate an ECDH key pair for named_group via PSA, write the public key to
+ * [buf, end) and store its length in *out_len (0 if an error is returned). */
+int mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange(
+                mbedtls_ssl_context *ssl,
+                uint16_t named_group,
+                unsigned char *buf,
+                unsigned char *end,
+                size_t *out_len )
+{
+    psa_status_t status = PSA_ERROR_GENERIC_ERROR;
+    int ret = MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE;
+    psa_key_attributes_t key_attributes = psa_key_attributes_init();
+    size_t own_pubkey_len = 0;
+    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+    size_t ecdh_bits = 0;
+
+    /* Callers may read *out_len before checking the return value. */
+    *out_len = 0;
+
+    MBEDTLS_SSL_DEBUG_MSG( 1, ( "Perform PSA-based ECDH computation." ) );
+
+    /* Convert EC group to PSA key type. */
+    if( ( handshake->ecdh_psa_type =
+        mbedtls_psa_parse_tls_ecc_group( named_group, &ecdh_bits ) ) == 0 )
+            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+
+    handshake->ecdh_bits = ecdh_bits;
+    psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
+    psa_set_key_algorithm( &key_attributes, PSA_ALG_ECDH );
+    psa_set_key_type( &key_attributes, handshake->ecdh_psa_type );
+    psa_set_key_bits( &key_attributes, handshake->ecdh_bits );
+
+    /* Generate ECDH private key. */
+    status = psa_generate_key( &key_attributes,
+                                &handshake->ecdh_psa_privkey );
+    if( status != PSA_SUCCESS )
+    {
+        ret = psa_ssl_status_to_mbedtls( status );
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_generate_key", ret );
+        return( ret );
+    }
+
+    /* Export the public part of the ECDH private key from PSA. */
+    status = psa_export_public_key( handshake->ecdh_psa_privkey,
+                                    buf, (size_t)( end - buf ),
+                                    &own_pubkey_len );
+    if( status != PSA_SUCCESS )
+    {
+        ret = psa_ssl_status_to_mbedtls( status );
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_export_public_key", ret );
+        return( ret );
+    }
+
+    *out_len = own_pubkey_len;
+    return( 0 );
+}
 #endif /* MBEDTLS_ECDH_C */
 
 #endif /* MBEDTLS_SSL_TLS_C && MBEDTLS_SSL_PROTO_TLS1_3 */
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 8d1b1d8..d06b9a8 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -26,7 +26,7 @@
 #include "ssl_misc.h"
 #include "ssl_tls13_keys.h"
 #include "ssl_debug_helpers.h"
-#include <string.h>
+#include <string.h>
 #if defined(MBEDTLS_ECP_C)
 #include "mbedtls/ecp.h"
 #endif /* MBEDTLS_ECP_C */
@@ -728,6 +728,333 @@
 }
 
 /*
+ * Handler for MBEDTLS_SSL_SERVER_HELLO
+ */
+static int ssl_tls13_prepare_server_hello( mbedtls_ssl_context *ssl )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    unsigned char *server_randbytes =
+                    ssl->handshake->randbytes + MBEDTLS_CLIENT_HELLO_RANDOM_LEN;
+    if( ssl->conf->f_rng == NULL )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "no RNG provided" ) );
+        return( MBEDTLS_ERR_SSL_NO_RNG );
+    }
+
+    /* Draw the server random; it is stored after the client random. */
+    if( ( ret = ssl->conf->f_rng( ssl->conf->p_rng, server_randbytes,
+                                  MBEDTLS_SERVER_HELLO_RANDOM_LEN ) ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "f_rng", ret );
+        return( ret );
+    }
+    MBEDTLS_SSL_DEBUG_BUF( 3, "server hello, random bytes", server_randbytes,
+                           MBEDTLS_SERVER_HELLO_RANDOM_LEN );
+
+#if defined(MBEDTLS_HAVE_TIME)
+    /* Use the platform abstraction rather than calling time() directly. */
+    ssl->session_negotiate->start = mbedtls_time( NULL );
+#endif /* MBEDTLS_HAVE_TIME */
+    return( 0 );
+}
+
+/*
+ * Write the supported_versions extension of the ServerHello:
+ *
+ * struct {
+ *      ProtocolVersion selected_version;
+ * } SupportedVersions;
+ */
+static int ssl_tls13_write_server_hello_supported_versions_ext(
+                                                mbedtls_ssl_context *ssl,
+                                                unsigned char *buf,
+                                                unsigned char *end,
+                                                size_t *out_len )
+{
+    unsigned char *p = buf;
+
+    *out_len = 0;
+
+    MBEDTLS_SSL_DEBUG_MSG( 3, ( "server hello, write selected version" ) );
+
+    /* Need 6 bytes: type (2) + data length (2) + selected_version (2). */
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 6 );
+
+    /* extension_type */
+    MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_SUPPORTED_VERSIONS, p, 0 );
+    p += 2;
+
+    /* extension_data_length: a single ProtocolVersion */
+    MBEDTLS_PUT_UINT16_BE( 2, p, 0 );
+    p += 2;
+
+    /* selected_version */
+    mbedtls_ssl_write_version( p, ssl->conf->transport, ssl->tls_version );
+    p += 2;
+    MBEDTLS_SSL_DEBUG_MSG( 3, ( "supported version: [%04x]",
+                                ssl->tls_version ) );
+
+    *out_len = p - buf;
+    return( 0 );
+}
+
+
+
+/* Generate and export a single key share for named_group into [buf, end),
+ * reporting its size through *out_len. For hybrid KEMs, this can be called
+ * multiple times with the different components of the hybrid. */
+static int ssl_tls13_generate_and_write_key_share( mbedtls_ssl_context *ssl,
+                                                   uint16_t named_group,
+                                                   unsigned char *buf,
+                                                   unsigned char *end,
+                                                   size_t *out_len )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
+    *out_len = 0;
+
+#if defined(MBEDTLS_ECDH_C)
+    if( mbedtls_ssl_tls13_named_group_is_ecdhe( named_group ) )
+    {
+        /* ECDHE groups are handled by the shared ECDH helper. */
+        ret = mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange(
+                                        ssl, named_group, buf, end, out_len );
+        if( ret != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET(
+                1, "mbedtls_ssl_tls13_generate_and_write_ecdh_key_exchange",
+                ret );
+        }
+        return( ret );
+    }
+#endif /* MBEDTLS_ECDH_C */
+
+    /* No other kind of KEM is supported yet: reaching this point means the
+     * group cannot be handled. The casts keep the parameters referenced
+     * when ECDH support is compiled out. */
+    ((void) ssl);
+    ((void) named_group);
+    ((void) buf);
+    ((void) end);
+
+    ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
+
+    return( ret );
+}
+
+/*
+ * ssl_tls13_write_key_share_ext
+ *
+ * Writes the key_share extension of the ServerHello, whose payload is:
+ *
+ * struct {
+ *     NamedGroup group;
+ *     opaque key_exchange<1..2^16-1>;
+ * } KeyShareEntry;
+ * struct {
+ *     KeyShareEntry server_share;
+ * } KeyShareServerHello;
+ */
+static int ssl_tls13_write_key_share_ext( mbedtls_ssl_context *ssl,
+                                          unsigned char *buf,
+                                          unsigned char *end,
+                                          size_t *out_len )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    uint16_t named_group = ssl->handshake->offered_group_id;
+    unsigned char *entry = buf + 4;      /* start of the KeyShareEntry */
+    unsigned char *cursor = buf;
+    size_t share_len;
+
+    *out_len = 0;
+
+    MBEDTLS_SSL_DEBUG_MSG( 3, ( "server hello, adding key share extension" ) );
+
+    /* Room is needed for the four fixed 2-byte fields that precede the key
+     * exchange data: extension_type, extension_data_length, group and
+     * key_exchange_length. */
+    MBEDTLS_SSL_CHK_BUF_PTR( cursor, end, 8 );
+    MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_KEY_SHARE, buf, 0 );
+    MBEDTLS_PUT_UINT16_BE( named_group, entry, 0 );
+    cursor += 8;
+
+    /* The key exchange data goes right after the fixed fields. Once
+     * PQC-ECDHE hybrids exist, this call will be repeated per component. */
+    ret = ssl_tls13_generate_and_write_key_share(
+              ssl, named_group, cursor, end, &share_len );
+    if( ret != 0 )
+        return( ret );
+    cursor += share_len;
+
+    /* Both length fields can only be filled in now that the size of the
+     * key exchange data is known. */
+    MBEDTLS_PUT_UINT16_BE( share_len, entry, 2 );
+    MBEDTLS_PUT_UINT16_BE( cursor - entry, buf, 2 );
+
+    *out_len = cursor - buf;
+
+    return( 0 );
+}
+
+
+/*
+ * Write the body of the ServerHello message:
+ *
+ *     struct {
+ *        ProtocolVersion legacy_version = 0x0303;    // TLS v1.2
+ *        Random random;
+ *        opaque legacy_session_id_echo<0..32>;
+ *        CipherSuite cipher_suite;
+ *        uint8 legacy_compression_method = 0;
+ *        Extension extensions<6..2^16-1>;
+ *    } ServerHello;
+ */
+static int ssl_tls13_write_server_hello_body( mbedtls_ssl_context *ssl,
+                                              unsigned char *buf,
+                                              unsigned char *end,
+                                              size_t *out_len )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    const size_t session_id_len = ssl->session_negotiate->id_len;
+    const int ciphersuite = ssl->session_negotiate->ciphersuite;
+    unsigned char *p = buf;
+    unsigned char *extensions_len_field;
+    size_t ext_len;               /* Bytes written by an extension writer */
+
+    *out_len = 0;
+
+    /*
+     * legacy_version
+     *
+     * Always 0x0303 (TLS 1.2), written as a uint16 ProtocolVersion.
+     */
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 2 );
+    MBEDTLS_PUT_UINT16_BE( 0x0303, p, 0 );
+    p += 2;
+
+    /*
+     * random
+     *
+     * The MBEDTLS_SERVER_HELLO_RANDOM_LEN bytes stored after the client
+     * random in the handshake's randbytes buffer.
+     */
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, MBEDTLS_SERVER_HELLO_RANDOM_LEN );
+    memcpy( p, ssl->handshake->randbytes + MBEDTLS_CLIENT_HELLO_RANDOM_LEN,
+            MBEDTLS_SERVER_HELLO_RANDOM_LEN );
+    MBEDTLS_SSL_DEBUG_BUF( 3, "server hello, random bytes",
+                           p, MBEDTLS_SERVER_HELLO_RANDOM_LEN );
+    p += MBEDTLS_SERVER_HELLO_RANDOM_LEN;
+
+    /*
+     * legacy_session_id_echo<0..32>
+     *
+     * One length byte followed by the session id held in
+     * session_negotiate.
+     */
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 1 + session_id_len );
+    *p++ = (unsigned char) session_id_len;
+    if( session_id_len > 0 )
+    {
+        memcpy( p, ssl->session_negotiate->id, session_id_len );
+        p += session_id_len;
+
+        MBEDTLS_SSL_DEBUG_BUF( 3, "session id", ssl->session_negotiate->id,
+                               session_id_len );
+    }
+
+    /*
+     * cipher_suite
+     *
+     * The selected ciphersuite, written as a 2-byte identifier.
+     */
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 2 );
+    MBEDTLS_PUT_UINT16_BE( ciphersuite, p, 0 );
+    p += 2;
+    MBEDTLS_SSL_DEBUG_MSG( 3,
+        ( "server hello, chosen ciphersuite: %s ( id=%d )",
+          mbedtls_ssl_get_ciphersuite_name( ciphersuite ), ciphersuite ) );
+
+    /*
+     * legacy_compression_method: always 0.
+     */
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 1 );
+    *p++ = 0x0;
+
+    /*
+     * extensions<6..2^16-1>: a 2-byte total length followed by the
+     * extensions. The length is only known at the end, so its slot is
+     * reserved here and filled in once every extension is written.
+     */
+    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 2 );
+    extensions_len_field = p;
+    p += 2;
+
+    /* supported_versions */
+    ret = ssl_tls13_write_server_hello_supported_versions_ext(
+              ssl, p, end, &ext_len );
+    if( ret != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET(
+            1, "ssl_tls13_write_server_hello_supported_versions_ext", ret );
+        return( ret );
+    }
+    p += ext_len;
+
+    /* key_share, written only when the configuration check below passes */
+    if( mbedtls_ssl_conf_tls13_some_ephemeral_enabled( ssl ) )
+    {
+        ret = ssl_tls13_write_key_share_ext( ssl, p, end, &ext_len );
+        if( ret != 0 )
+            return( ret );
+        p += ext_len;
+    }
+
+    /* Back-fill the extensions length, excluding the length field itself. */
+    MBEDTLS_PUT_UINT16_BE( p - extensions_len_field - 2,
+                           extensions_len_field, 0 );
+
+    MBEDTLS_SSL_DEBUG_BUF( 4, "server hello extensions",
+                           extensions_len_field, p - extensions_len_field );
+
+    *out_len = p - buf;
+
+    MBEDTLS_SSL_DEBUG_BUF( 3, "server hello", buf, *out_len );
+
+    return( 0 );
+}
+
+/* Write the ServerHello and, on success, move on to ENCRYPTED_EXTENSIONS. */
+static int ssl_tls13_write_server_hello( mbedtls_ssl_context *ssl )
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    unsigned char *body;        /* where the message body is written */
+    size_t body_capacity, body_len;
+
+    MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write server hello" ) );
+
+    /* Generate the server random before anything is written. */
+    MBEDTLS_SSL_PROC_CHK( ssl_tls13_prepare_server_hello( ssl ) );
+
+    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_start_handshake_msg(
+        ssl, MBEDTLS_SSL_HS_SERVER_HELLO, &body, &body_capacity ) );
+
+    MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_server_hello_body(
+        ssl, body, body + body_capacity, &body_len ) );
+
+    mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_SERVER_HELLO,
+                                        body, body_len );
+
+    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_finish_handshake_msg(
+        ssl, body_capacity, body_len ) );
+
+    mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_ENCRYPTED_EXTENSIONS );
+cleanup:
+    MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write server hello" ) );
+    return( ret );
+}
+
+/*
  * TLS 1.3 State Machine -- server side
  */
 int mbedtls_ssl_tls13_handshake_server_step( mbedtls_ssl_context *ssl )
@@ -758,6 +1085,10 @@
 
             break;
 
+        case MBEDTLS_SSL_SERVER_HELLO:
+            ret = ssl_tls13_write_server_hello( ssl );
+            break;
+
         default:
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "invalid state %d", ssl->state ) );
             return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE );
diff --git a/tests/ssl-opt.sh b/tests/ssl-opt.sh
index d207e54..bcbc0a0 100755
--- a/tests/ssl-opt.sh
+++ b/tests/ssl-opt.sh
@@ -10479,11 +10479,12 @@
 requires_openssl_tls1_3
 run_test    "TLS 1.3: Server side check - openssl" \
             "$P_SRV debug_level=4 crt_file=data_files/server5.crt key_file=data_files/server5.key force_version=tls13 tickets=0" \
-            "$O_NEXT_CLI -msg -tls1_3" \
+            "$O_NEXT_CLI -msg -debug -tls1_3" \
             1 \
-            -s " tls13 server state: MBEDTLS_SSL_CLIENT_HELLO" \
-            -s " tls13 server state: MBEDTLS_SSL_SERVER_HELLO" \
-            -s " SSL - The requested feature is not available" \
+            -s "tls13 server state: MBEDTLS_SSL_CLIENT_HELLO" \
+            -s "tls13 server state: MBEDTLS_SSL_SERVER_HELLO" \
+            -s "tls13 server state: MBEDTLS_SSL_ENCRYPTED_EXTENSIONS" \
+            -s "SSL - The requested feature is not available" \
             -s "=> parse client hello" \
             -s "<= parse client hello"
 
@@ -10496,9 +10497,26 @@
             "$P_SRV debug_level=4 crt_file=data_files/server5.crt key_file=data_files/server5.key force_version=tls13 tickets=0" \
             "$G_NEXT_CLI localhost -d 4 --priority=NORMAL:-VERS-ALL:+VERS-TLS1.3:%NO_TICKETS:%DISABLE_TLS13_COMPAT_MODE -V" \
             1 \
-            -s " tls13 server state: MBEDTLS_SSL_CLIENT_HELLO" \
-            -s " tls13 server state: MBEDTLS_SSL_SERVER_HELLO" \
-            -s " SSL - The requested feature is not available" \
+            -s "tls13 server state: MBEDTLS_SSL_CLIENT_HELLO" \
+            -s "tls13 server state: MBEDTLS_SSL_SERVER_HELLO" \
+            -s "tls13 server state: MBEDTLS_SSL_ENCRYPTED_EXTENSIONS" \
+            -s "SSL - The requested feature is not available" \
+            -s "=> parse client hello" \
+            -s "<= parse client hello"
+
+requires_config_enabled MBEDTLS_SSL_PROTO_TLS1_3
+requires_config_enabled MBEDTLS_DEBUG_C
+requires_config_enabled MBEDTLS_SSL_SRV_C
+requires_config_enabled MBEDTLS_SSL_CLI_C
+run_test    "TLS 1.3: Server side check - mbedtls" \
+            "$P_SRV debug_level=4 crt_file=data_files/server5.crt key_file=data_files/server5.key force_version=tls13 tickets=0" \
+            "$P_CLI debug_level=4 force_version=tls13" \
+            1 \
+            -s "tls13 server state: MBEDTLS_SSL_CLIENT_HELLO" \
+            -s "tls13 server state: MBEDTLS_SSL_SERVER_HELLO" \
+            -s "tls13 server state: MBEDTLS_SSL_ENCRYPTED_EXTENSIONS" \
+            -c "client state: MBEDTLS_SSL_ENCRYPTED_EXTENSIONS" \
+            -s "SSL - The requested feature is not available" \
             -s "=> parse client hello" \
             -s "<= parse client hello"