Set hs_psk, ciphersuite_info and key exchange mode when writing the pre-shared key

Signed-off-by: Xiaokang Qian <xiaokang.qian@arm.com>
diff --git a/library/ssl_client.c b/library/ssl_client.c
index 4e42d00..08b3de8 100644
--- a/library/ssl_client.c
+++ b/library/ssl_client.c
@@ -966,24 +966,6 @@
 
 #if defined(MBEDTLS_SSL_EARLY_DATA)
         if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED) {
-            psa_algorithm_t hash_alg = PSA_ALG_NONE;
-            const unsigned char *psk;
-            size_t psk_len;
-            MBEDTLS_SSL_DEBUG_MSG(1, ("in generate early keys"));
-
-            if ((ret = mbedtls_ssl_tls13_ticket_get_psk(
-                     ssl, &hash_alg, &psk, &psk_len))
-                != 0) {
-                MBEDTLS_SSL_DEBUG_RET(
-                    1, "mbedtls_ssl_tls13_ticket_get_psk", ret);
-                goto cleanup;
-            }
-
-            if ((ret = mbedtls_ssl_set_hs_psk(ssl, psk, psk_len)) != 0) {
-                MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret);
-                goto cleanup;
-            }
-
             /* Start the TLS 1.3 key schedule:
              * Set the PSK and derive early secret.
              */
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 376f6cf..86f5c0b 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1676,8 +1676,6 @@
                                       session->ciphersuite));
             return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
         }
-        ssl->handshake->ciphersuite_info = ciphersuite_info;
-        ssl->handshake->key_exchange_mode = MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL;
     }
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
 
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 6f91fb2..874f243 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -893,11 +893,16 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     int configured_psk_count = 0;
     unsigned char *p = buf;
-    psa_algorithm_t hash_alg;
+    psa_algorithm_t hash_alg = PSA_ALG_NONE;
     const unsigned char *identity;
     size_t identity_len;
     size_t l_binders_len = 0;
     size_t output_len;
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+    const unsigned char *psk;
+    size_t psk_len;
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
+#endif
 
     *out_len = 0;
     *binders_len = 0;
@@ -962,6 +967,30 @@
 
         p += output_len;
         l_binders_len += 1 + PSA_HASH_LENGTH(hash_alg);
+
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+        MBEDTLS_SSL_DEBUG_MSG(
+            1, ("Set hs psk for early data when writing the first psk"));
+
+        ret = ssl_tls13_ticket_get_psk(ssl, &hash_alg, &psk, &psk_len);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(
+                1, "mbedtls_ssl_tls13_ticket_get_psk", ret);
+            return ret;
+        }
+
+        ret = mbedtls_ssl_set_hs_psk(ssl, psk, psk_len);
+        if (ret  != 0) {
+            MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret);
+            return ret;
+        }
+
+        ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
+            ssl->session_negotiate->ciphersuite);
+        ssl->handshake->ciphersuite_info = ciphersuite_info;
+        ssl->handshake->key_exchange_mode =
+            MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL;
+#endif /* MBEDTLS_SSL_EARLY_DATA */
     }
 #endif /* MBEDTLS_SSL_SESSION_TICKETS */