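Move early data key generation into mbedtls_ssl_tls13_finalize_write_client_hello()

Setting the handshake PSK from the ticket (previously done while writing
the pre_shared_key identities) and deriving the early secret and early
data transform (previously done in mbedtls_ssl_write_client_hello()) now
both happen in mbedtls_ssl_tls13_finalize_write_client_hello(), which runs
once the ClientHello has been written.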

Signed-off-by: Xiaokang Qian <xiaokang.qian@arm.com>
diff --git a/library/ssl_client.c b/library/ssl_client.c
index 08b3de8..a975d6a 100644
--- a/library/ssl_client.c
+++ b/library/ssl_client.c
@@ -962,34 +962,15 @@
         MBEDTLS_SSL_PROC_CHK(mbedtls_ssl_finish_handshake_msg(ssl,
                                                               buf_len,
                                                               msg_len));
-        mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_SERVER_HELLO);
 
-#if defined(MBEDTLS_SSL_EARLY_DATA)
-        if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED) {
-            /* Start the TLS 1.3 key schedule:
-             * Set the PSK and derive early secret.
-             */
-            ret = mbedtls_ssl_tls13_key_schedule_stage_early(ssl);
-            if (ret != 0) {
-                MBEDTLS_SSL_DEBUG_RET(1,
-                                      "mbedtls_ssl_tls13_key_schedule_stage_early", ret);
-                goto cleanup;
-            }
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+        /* Moves the handshake to MBEDTLS_SSL_SERVER_HELLO and may switch
+         * outbound traffic to the early data keys. Any error it reports
+         * must abort writing the ClientHello, not be silently dropped. */
+        MBEDTLS_SSL_PROC_CHK(mbedtls_ssl_tls13_finalize_write_client_hello(ssl));
+#else
+        mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_SERVER_HELLO);
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
-
-            /* Derive early data key material */
-            ret = mbedtls_ssl_tls13_compute_early_transform(ssl);
-            if (ret != 0) {
-                MBEDTLS_SSL_DEBUG_RET(1,
-                                      "mbedtls_ssl_tls13_compute_early_transform", ret);
-                goto cleanup;
-            }
-
-            MBEDTLS_SSL_DEBUG_MSG(
-                1, ("Switch to early data keys for outbound traffic"));
-            mbedtls_ssl_set_outbound_transform(
-                ssl, ssl->handshake->transform_earlydata);
-        }
-#endif /* MBEDTLS_SSL_EARLY_DATA */
     }
 
 cleanup:
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 146dae0..e2efabd 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2740,4 +2740,22 @@
 }
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 && MBEDTLS_SSL_SESSION_TICKETS */
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+/**
+ * \brief          Finish the client side of writing a ClientHello.
+ *
+ *                 Moves the handshake to MBEDTLS_SSL_SERVER_HELLO. With
+ *                 MBEDTLS_SSL_EARLY_DATA enabled and ssl->early_data_status
+ *                 equal to MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED, it also
+ *                 derives the early data keys and switches outbound traffic
+ *                 to them.
+ *
+ * \param ssl      SSL context.
+ *
+ * \return         0 on success, or the error code of the failing step.
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+int mbedtls_ssl_tls13_finalize_write_client_hello(mbedtls_ssl_context *ssl);
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
+
 #endif /* ssl_misc.h */
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 8f5e0fc..7a0f6b8 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -898,11 +898,6 @@
     size_t identity_len;
     size_t l_binders_len = 0;
     size_t output_len;
-#if defined(MBEDTLS_SSL_EARLY_DATA)
-    const unsigned char *psk;
-    size_t psk_len;
-    const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
-#endif
 
     *out_len = 0;
     *binders_len = 0;
@@ -968,29 +963,6 @@
         p += output_len;
         l_binders_len += 1 + PSA_HASH_LENGTH(hash_alg);
 
-#if defined(MBEDTLS_SSL_EARLY_DATA)
-        MBEDTLS_SSL_DEBUG_MSG(
-            1, ("Set hs psk for early data when writing the first psk"));
-
-        ret = ssl_tls13_ticket_get_psk(ssl, &hash_alg, &psk, &psk_len);
-        if (ret != 0) {
-            MBEDTLS_SSL_DEBUG_RET(
-                1, "ssl_tls13_ticket_get_psk", ret);
-            return ret;
-        }
-
-        ret = mbedtls_ssl_set_hs_psk(ssl, psk, psk_len);
-        if (ret  != 0) {
-            MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret);
-            return ret;
-        }
-
-        ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
-            ssl->session_negotiate->ciphersuite);
-        ssl->handshake->ciphersuite_info = ciphersuite_info;
-        ssl->handshake->key_exchange_mode =
-            MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL;
-#endif /* MBEDTLS_SSL_EARLY_DATA */
     }
 #endif /* MBEDTLS_SSL_SESSION_TICKETS */
 
@@ -1240,6 +1212,100 @@
     return 0;
 }
 
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+/*
+ * Make the PSK of the resumption ticket the handshake PSK, run the early
+ * stage of the key schedule, derive the early data transform and switch
+ * outbound records to it.
+ *
+ * Returns 0 on success, or the error code of the first step that fails.
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_switch_to_early_data_keys(mbedtls_ssl_context *ssl)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_algorithm_t hash_alg = PSA_ALG_NONE;
+    const unsigned char *psk;
+    size_t psk_len;
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
+
+    MBEDTLS_SSL_DEBUG_MSG(
+        1, ("Set hs psk for early data when writing the first psk"));
+
+    ret = ssl_tls13_ticket_get_psk(ssl, &hash_alg, &psk, &psk_len);
+    if (ret != 0) {
+        MBEDTLS_SSL_DEBUG_RET(
+            1, "ssl_tls13_ticket_get_psk", ret);
+        return ret;
+    }
+
+    ret = mbedtls_ssl_set_hs_psk(ssl, psk, psk_len);
+    if (ret != 0) {
+        MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret);
+        return ret;
+    }
+
+    /* mbedtls_ssl_ciphersuite_from_id() returns NULL for an id it does not
+     * know; fail cleanly instead of storing NULL in
+     * handshake->ciphersuite_info. */
+    ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
+        ssl->session_negotiate->ciphersuite);
+    if (ciphersuite_info == NULL) {
+        MBEDTLS_SSL_DEBUG_MSG(1, ("should never happen"));
+        return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
+    }
+    ssl->handshake->ciphersuite_info = ciphersuite_info;
+    ssl->handshake->key_exchange_mode =
+        MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL;
+
+    /* Start the TLS 1.3 key schedule:
+     * Set the PSK and derive early secret.
+     */
+    ret = mbedtls_ssl_tls13_key_schedule_stage_early(ssl);
+    if (ret != 0) {
+        MBEDTLS_SSL_DEBUG_RET(1,
+                              "mbedtls_ssl_tls13_key_schedule_stage_early", ret);
+        return ret;
+    }
+
+    /* Derive early data key material */
+    ret = mbedtls_ssl_tls13_compute_early_transform(ssl);
+    if (ret != 0) {
+        MBEDTLS_SSL_DEBUG_RET(1,
+                              "mbedtls_ssl_tls13_compute_early_transform", ret);
+        return ret;
+    }
+
+    MBEDTLS_SSL_DEBUG_MSG(
+        1, ("Switch to early data keys for outbound traffic"));
+    mbedtls_ssl_set_outbound_transform(
+        ssl, ssl->handshake->transform_earlydata);
+
+    return 0;
+}
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+
+/*
+ * Called once the ClientHello has been written: moves the handshake to
+ * MBEDTLS_SSL_SERVER_HELLO and, when MBEDTLS_SSL_EARLY_DATA is enabled and
+ * ssl->early_data_status is MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED, switches
+ * outbound traffic to the early data keys.
+ *
+ * Returns 0 on success, or the error code of the failing step.
+ */
+int mbedtls_ssl_tls13_finalize_write_client_hello(mbedtls_ssl_context *ssl)
+{
+    mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_SERVER_HELLO);
+
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+    if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED) {
+        return ssl_tls13_switch_to_early_data_keys(ssl);
+    }
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+
+    return 0;
+}
+
 /*
  * Functions for parsing and processing Server Hello
  */