Add prepare function to switch outbound transform to early data keys

Signed-off-by: Xiaokang Qian <xiaokang.qian@arm.com>
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index e469174..b58cc29 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -2150,6 +2150,55 @@
 }
 
+/*
+ * Switch the outbound transform to the early data (0-RTT) keys so that the
+ * EndOfEarlyData message is protected with them (RFC 8446, Section 4.5).
+ *
+ * Steps: derive the early secret from the PSK, compute the early data
+ * transform, then activate ssl->handshake->transform_earlydata for outbound
+ * records. Only the outbound transform is switched by this helper.
+ *
+ * NOTE(review): RFC 8446, Section 7.1 derives client_early_traffic_secret
+ * from the ClientHello transcript only, but this helper runs when
+ * EndOfEarlyData is written, i.e. after the server Finished. Confirm that
+ * mbedtls_ssl_tls13_compute_early_transform() does not hash the current
+ * (longer) transcript here, and that it does not leak a transform_earlydata
+ * allocated earlier in the handshake.
+ *
+ * \param ssl   The SSL context (TLS 1.3 client handshake in progress).
+ *
+ * \return      \c 0 on success, or the non-zero error code returned by
+ *              mbedtls_ssl_tls13_key_schedule_stage_early() or
+ *              mbedtls_ssl_tls13_compute_early_transform().
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_prepare_end_of_early_data(mbedtls_ssl_context *ssl)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
+    /* Start the TLS 1.3 key schedule: set the PSK and derive the early secret. */
+    ret = mbedtls_ssl_tls13_key_schedule_stage_early(ssl);
+    if (ret != 0) {
+        MBEDTLS_SSL_DEBUG_RET(1,
+                              "mbedtls_ssl_tls13_key_schedule_stage_early", ret);
+        return ret;
+    }
+
+    /* Derive the 0-RTT (client early traffic) key material. */
+    ret = mbedtls_ssl_tls13_compute_early_transform(ssl);
+    if (ret != 0) {
+        MBEDTLS_SSL_DEBUG_RET(1,
+                              "mbedtls_ssl_tls13_compute_early_transform", ret);
+        return ret;
+    }
+
+    /* Activate the early data keys for outbound records only. */
+    MBEDTLS_SSL_DEBUG_MSG(1, ("Switch to early data keys for outbound traffic"));
+    mbedtls_ssl_set_outbound_transform(ssl, ssl->handshake->transform_earlydata);
+
+    return 0;
+}
+
 MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_tls13_finalize_write_end_of_early_data(
     mbedtls_ssl_context *ssl)
 {
@@ -2175,11 +2224,12 @@
         unsigned char *buf = NULL;
         size_t buf_len;
 
+        /* EndOfEarlyData is sent under the early data keys (RFC 8446, 4.5). */
+        MBEDTLS_SSL_PROC_CHK(ssl_tls13_prepare_end_of_early_data(ssl));
         MBEDTLS_SSL_DEBUG_MSG(2, ("Client write EndOfEarlyData"));
 
-        MBEDTLS_SSL_PROC_CHK(mbedtls_ssl_start_handshake_msg(ssl,
-                                                             MBEDTLS_SSL_HS_END_OF_EARLY_DATA, &buf,
-                                                             &buf_len));
+        MBEDTLS_SSL_PROC_CHK(mbedtls_ssl_start_handshake_msg(
+                                 ssl, MBEDTLS_SSL_HS_END_OF_EARLY_DATA, &buf, &buf_len));
 
         mbedtls_ssl_add_hs_hdr_to_checksum(
             ssl, MBEDTLS_SSL_HS_END_OF_EARLY_DATA, 0);