Move early keys generation into mbedtls_ssl_tls13_finalize_write_client_hello

Signed-off-by: Xiaokang Qian <xiaokang.qian@arm.com>
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 8f5e0fc..7a0f6b8 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -898,11 +898,6 @@
     size_t identity_len;
     size_t l_binders_len = 0;
     size_t output_len;
-#if defined(MBEDTLS_SSL_EARLY_DATA)
-    const unsigned char *psk;
-    size_t psk_len;
-    const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
-#endif
 
     *out_len = 0;
     *binders_len = 0;
@@ -968,29 +963,6 @@
         p += output_len;
         l_binders_len += 1 + PSA_HASH_LENGTH(hash_alg);
 
-#if defined(MBEDTLS_SSL_EARLY_DATA)
-        MBEDTLS_SSL_DEBUG_MSG(
-            1, ("Set hs psk for early data when writing the first psk"));
-
-        ret = ssl_tls13_ticket_get_psk(ssl, &hash_alg, &psk, &psk_len);
-        if (ret != 0) {
-            MBEDTLS_SSL_DEBUG_RET(
-                1, "ssl_tls13_ticket_get_psk", ret);
-            return ret;
-        }
-
-        ret = mbedtls_ssl_set_hs_psk(ssl, psk, psk_len);
-        if (ret  != 0) {
-            MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret);
-            return ret;
-        }
-
-        ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
-            ssl->session_negotiate->ciphersuite);
-        ssl->handshake->ciphersuite_info = ciphersuite_info;
-        ssl->handshake->key_exchange_mode =
-            MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL;
-#endif /* MBEDTLS_SSL_EARLY_DATA */
     }
 #endif /* MBEDTLS_SSL_SESSION_TICKETS */
 
@@ -1240,6 +1212,66 @@
     return 0;
 }
 
+/*
+ * Finalize the TLS 1.3 ClientHello: move the handshake to SERVER_HELLO and,
+ * with MBEDTLS_SSL_EARLY_DATA, when early_data_status is
+ * MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED, set the handshake PSK returned by
+ * ssl_tls13_ticket_get_psk, derive the early secret and early data transform
+ * and switch outbound traffic to the early data keys.
+ * NOTE(review): REJECTED is presumably the status the client sets once the
+ * early_data indication has been sent; confirm against the extension writer.
+ * Returns 0 on success, otherwise the error of the first failing step.
+ */
+int mbedtls_ssl_tls13_finalize_write_client_hello(mbedtls_ssl_context *ssl)
+{
+    mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_SERVER_HELLO);
+
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+    if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED) {
+        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+        psa_algorithm_t psk_hash_alg = PSA_ALG_NONE;
+        const unsigned char *ticket_psk = NULL;
+        size_t ticket_psk_len = 0;
+
+        MBEDTLS_SSL_DEBUG_MSG(
+            1, ("Set hs psk for early data when writing the first psk"));
+        ret = ssl_tls13_ticket_get_psk(ssl, &psk_hash_alg,
+                                       &ticket_psk, &ticket_psk_len);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(1, "ssl_tls13_ticket_get_psk", ret);
+            return ret;
+        }
+        ret = mbedtls_ssl_set_hs_psk(ssl, ticket_psk, ticket_psk_len);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret);
+            return ret;
+        }
+
+        /* Use session_negotiate's ciphersuite; key exchange mode PSK_ALL. */
+        ssl->handshake->ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
+            ssl->session_negotiate->ciphersuite);
+        ssl->handshake->key_exchange_mode =
+            MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL;
+
+        /* Key schedule: derive the early secret, then early data keys. */
+        ret = mbedtls_ssl_tls13_key_schedule_stage_early(ssl);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_tls13_key_schedule_stage_early", ret);
+            return ret;
+        }
+        ret = mbedtls_ssl_tls13_compute_early_transform(ssl);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_tls13_compute_early_transform", ret);
+            return ret;
+        }
+
+        MBEDTLS_SSL_DEBUG_MSG(1, ("Switch to early data keys for outbound traffic"));
+        mbedtls_ssl_set_outbound_transform(ssl, ssl->handshake->transform_earlydata);
+    }
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+    return 0;
+}
+
 /*
  * Functions for parsing and processing Server Hello
  */