Merge pull request #6499 from xkqian/tls13_write_end_of_early_data

TLS 1.3: write the EndOfEarlyData message
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index dbc37e8..8bc8fd0 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -533,7 +533,8 @@
 #define MBEDTLS_SSL_HS_SERVER_HELLO             2
 #define MBEDTLS_SSL_HS_HELLO_VERIFY_REQUEST     3
 #define MBEDTLS_SSL_HS_NEW_SESSION_TICKET       4
-#define MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS     8 // NEW IN TLS 1.3
+#define MBEDTLS_SSL_HS_END_OF_EARLY_DATA        5
+#define MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS     8
 #define MBEDTLS_SSL_HS_CERTIFICATE             11
 #define MBEDTLS_SSL_HS_SERVER_KEY_EXCHANGE     12
 #define MBEDTLS_SSL_HS_CERTIFICATE_REQUEST     13
@@ -671,10 +672,12 @@
     MBEDTLS_SSL_SERVER_HELLO_VERIFY_REQUEST_SENT,
     MBEDTLS_SSL_HELLO_RETRY_REQUEST,
     MBEDTLS_SSL_ENCRYPTED_EXTENSIONS,
+    MBEDTLS_SSL_END_OF_EARLY_DATA,
     MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY,
     MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED,
     MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO,
     MBEDTLS_SSL_SERVER_CCS_AFTER_SERVER_HELLO,
+    MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO,
     MBEDTLS_SSL_SERVER_CCS_AFTER_HELLO_RETRY_REQUEST,
     MBEDTLS_SSL_HANDSHAKE_OVER,
     MBEDTLS_SSL_TLS1_3_NEW_SESSION_TICKET,
diff --git a/library/ssl_client.c b/library/ssl_client.c
index ab897c4..963f8bb 100644
--- a/library/ssl_client.c
+++ b/library/ssl_client.c
@@ -961,9 +961,20 @@
         MBEDTLS_SSL_PROC_CHK(mbedtls_ssl_finish_handshake_msg(ssl,
                                                               buf_len,
                                                               msg_len));
-        mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_SERVER_HELLO);
-    }
 
+        /*
+         * Set next state. Note that if TLS 1.3 is proposed, this may be
+         * overwritten by mbedtls_ssl_tls13_finalize_client_hello().
+         */
+        mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_SERVER_HELLO);
+
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+        if (ssl->handshake->min_tls_version <=  MBEDTLS_SSL_VERSION_TLS1_3 &&
+            MBEDTLS_SSL_VERSION_TLS1_3 <= ssl->tls_version) {
+            ret = mbedtls_ssl_tls13_finalize_client_hello(ssl);
+        }
+#endif
+    }
 
 cleanup:
 
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 146dae0..ef05dca 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2740,4 +2740,15 @@
 }
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 && MBEDTLS_SSL_SESSION_TICKETS */
 
+#if defined(MBEDTLS_SSL_CLI_C) && defined(MBEDTLS_SSL_PROTO_TLS1_3)
+/*
+ * Called by the client once the ClientHello has been written, when TLS 1.3
+ * is among the proposed versions. With early data support enabled, this may
+ * derive the early data key material.
+ * Returns 0 on success, a non-zero error code otherwise.
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+int mbedtls_ssl_tls13_finalize_client_hello(mbedtls_ssl_context *ssl);
+#endif
+
 #endif /* ssl_misc.h */
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 4aea61c..1e79afa 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -893,7 +893,7 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     int configured_psk_count = 0;
     unsigned char *p = buf;
-    psa_algorithm_t hash_alg;
+    psa_algorithm_t hash_alg = PSA_ALG_NONE;
     const unsigned char *identity;
     size_t identity_len;
     size_t l_binders_len = 0;
@@ -1092,6 +1092,7 @@
 
     MBEDTLS_SSL_CHK_BUF_READ_PTR(buf, end, 2);
     selected_identity = MBEDTLS_GET_UINT16_BE(buf, 0);
+    ssl->handshake->selected_identity = (uint16_t) selected_identity;
 
     MBEDTLS_SSL_DEBUG_MSG(3, ("selected_identity = %d", selected_identity));
 
@@ -1118,6 +1119,16 @@
         return ret;
     }
 
+    if (mbedtls_psa_translate_md(ssl->handshake->ciphersuite_info->mac)
+        != hash_alg) {
+        MBEDTLS_SSL_DEBUG_MSG(
+            1, ("Invalid ciphersuite for external psk."));
+
+        MBEDTLS_SSL_PEND_FATAL_ALERT(MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
+                                     MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER);
+        return MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER;
+    }
+
     ret = mbedtls_ssl_set_hs_psk(ssl, psk, psk_len);
     if (ret != 0) {
         MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret);
@@ -1211,6 +1222,108 @@
     return 0;
 }
 
+/*
+ * Finalize the ClientHello once it has been written.
+ *
+ * With MBEDTLS_SSL_EARLY_DATA enabled and ssl->early_data_status equal to
+ * MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED, this function:
+ *  - in compatibility mode, sets the next handshake state to
+ *    MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO;
+ *  - sets the handshake PSK from the session ticket;
+ *  - sets handshake->ciphersuite_info from session_negotiate->ciphersuite;
+ *  - runs the early stage of the key schedule and derives the early
+ *    data transform.
+ * In all other cases it does nothing.
+ *
+ * NOTE(review): early_data_status is presumably REJECTED here when the
+ * ClientHello offered early data that the server has not accepted yet;
+ * confirm against the code writing the early_data extension.
+ *
+ * Returns 0 on success, MBEDTLS_ERR_SSL_INTERNAL_ERROR if the ciphersuite
+ * of the session is unknown, or the error returned by a failing helper.
+ */
+int mbedtls_ssl_tls13_finalize_client_hello(mbedtls_ssl_context *ssl)
+{
+    ((void) ssl);
+
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_algorithm_t hash_alg = PSA_ALG_NONE;
+    const unsigned char *psk;
+    size_t psk_len;
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
+
+    if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED) {
+#if defined(MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE)
+        mbedtls_ssl_handshake_set_state(
+            ssl, MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO);
+#endif
+        MBEDTLS_SSL_DEBUG_MSG(
+            1, ("Set hs psk for early data when writing the first psk"));
+
+        ret = ssl_tls13_ticket_get_psk(ssl, &hash_alg, &psk, &psk_len);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(
+                1, "ssl_tls13_ticket_get_psk", ret);
+            return ret;
+        }
+
+        ret = mbedtls_ssl_set_hs_psk(ssl, psk, psk_len);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_set_hs_psk", ret);
+            return ret;
+        }
+
+        /*
+         * Early data are going to be encrypted using the ciphersuite
+         * associated with the pre-shared key used for the handshake.
+         * Note that if the server rejects early data, the handshake
+         * based on the pre-shared key may complete successfully
+         * with a selected ciphersuite different from the ciphersuite
+         * associated with the pre-shared key. Only the hashes of the
+         * two ciphersuites have to be the same. In that case, the
+         * encrypted handshake data and application data are
+         * encrypted using a different ciphersuite than the one used for
+         * the rejected early data.
+         */
+        ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(
+            ssl->session_negotiate->ciphersuite);
+        if (ciphersuite_info == NULL) {
+            /* Never install a NULL ciphersuite_info: the early transform
+             * derivation dereferences handshake->ciphersuite_info. */
+            MBEDTLS_SSL_DEBUG_MSG(
+                1, ("Unknown ciphersuite in the session to resume"));
+            return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
+        }
+        ssl->handshake->ciphersuite_info = ciphersuite_info;
+
+        /* Enable the psk and psk_ephemeral key exchange modes, as expected
+         * by mbedtls_ssl_tls13_key_schedule_stage_early(). */
+        ssl->handshake->key_exchange_mode =
+            MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_ALL;
+
+        /* Start the TLS 1.3 key schedule:
+         *     Set the PSK and derive early secret.
+         */
+        ret = mbedtls_ssl_tls13_key_schedule_stage_early(ssl);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(
+                1, "mbedtls_ssl_tls13_key_schedule_stage_early", ret);
+            return ret;
+        }
+
+        /* Derive early data key material */
+        ret = mbedtls_ssl_tls13_compute_early_transform(ssl);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(
+                1, "mbedtls_ssl_tls13_compute_early_transform", ret);
+            return ret;
+        }
+    }
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+    return 0;
+}
+
 /*
  * Functions for parsing and processing Server Hello
  */
@@ -1627,8 +1712,6 @@
     mbedtls_ssl_optimize_checksum(ssl, ciphersuite_info);
 
     handshake->ciphersuite_info = ciphersuite_info;
-    ssl->session_negotiate->ciphersuite = cipher_suite;
-
     MBEDTLS_SSL_DEBUG_MSG(3, ("server hello, chosen ciphersuite: ( %04x ) - %s",
                               cipher_suite, ciphersuite_info->name));
 
@@ -1824,8 +1907,39 @@
             ret = MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
             goto cleanup;
     }
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+    if (handshake->received_extensions & MBEDTLS_SSL_EXT_MASK(EARLY_DATA) &&
+        (handshake->selected_identity != 0 ||
+         handshake->ciphersuite_info->id !=
+         ssl->session_negotiate->ciphersuite)) {
+        /* RFC8446 4.2.11
+         * If the server supplies an "early_data" extension, the
+         * client MUST verify that the server's selected_identity
+         * is 0. If any other value is returned, the client MUST
+         * abort the handshake with an "illegal_parameter" alert.
+         *
+         * RFC 8446 4.2.10
+         * In order to accept early data, the server MUST have accepted a PSK
+         * cipher suite and selected the first key offered in the client's
+         * "pre_shared_key" extension. In addition, it MUST verify that the
+         * following values are the same as those associated with the
+         * selected PSK:
+         * - The TLS version number
+         * - The selected cipher suite
+         * - The selected ALPN [RFC7301] protocol, if any
+         *
+         * We check here that when early data is involved the server
+         * selected the cipher suite associated to the pre-shared key
+         * as it must have.
+         */
+        MBEDTLS_SSL_PEND_FATAL_ALERT(MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
+                                     MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER);
+        return MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER;
+    }
+#endif
 
-    if (!mbedtls_ssl_conf_tls13_check_kex_modes(ssl, handshake->key_exchange_mode)) {
+    if (!mbedtls_ssl_conf_tls13_check_kex_modes(
+            ssl, handshake->key_exchange_mode)) {
         ret = MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
         MBEDTLS_SSL_DEBUG_MSG(2,
                               ("Key exchange mode(%s) is not supported.",
@@ -1837,16 +1951,27 @@
                           ("Selected key exchange mode: %s",
                            ssl_tls13_get_kex_mode_str(handshake->key_exchange_mode)));
 
-    /* Start the TLS 1.3 key schedule: Set the PSK and derive early secret.
+    /* Start the TLS 1.3 key scheduling if not already done.
      *
-     * TODO: We don't have to do this in case we offered 0-RTT and the
-     *       server accepted it. In this case, we could skip generating
-     *       the early secret. */
-    ret = mbedtls_ssl_tls13_key_schedule_stage_early(ssl);
-    if (ret != 0) {
-        MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_tls13_key_schedule_stage_early",
-                              ret);
-        goto cleanup;
+     * If we proposed early data then we have already derived an
+     * early secret using the selected PSK and its associated hash.
+     * It means that if the negotiated key exchange mode is psk or
+     * psk_ephemeral, we have already correctly computed the
+     * early secret and thus we do not do it again. In all other
+     * cases we compute it here.
+     */
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+    if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_NOT_SENT ||
+        handshake->key_exchange_mode ==
+        MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_EPHEMERAL)
+#endif
+    {
+        ret = mbedtls_ssl_tls13_key_schedule_stage_early(ssl);
+        if (ret != 0) {
+            MBEDTLS_SSL_DEBUG_RET(
+                1, "mbedtls_ssl_tls13_key_schedule_stage_early", ret);
+            goto cleanup;
+        }
     }
 
     ret = mbedtls_ssl_tls13_compute_handshake_transform(ssl);
@@ -1859,6 +1984,7 @@
 
     mbedtls_ssl_set_inbound_transform(ssl, handshake->transform_handshake);
     MBEDTLS_SSL_DEBUG_MSG(1, ("Switch to handshake keys for inbound traffic"));
+    ssl->session_negotiate->ciphersuite = handshake->ciphersuite_info->id;
     ssl->session_in = ssl->session_negotiate;
 
 cleanup:
@@ -1889,6 +2015,7 @@
         return ret;
     }
 
+    ssl->session_negotiate->ciphersuite = ssl->handshake->ciphersuite_info->id;
     return 0;
 }
 
@@ -2108,6 +2235,46 @@
 
 }
 
+/*
+ * Handler for MBEDTLS_SSL_END_OF_EARLY_DATA
+ *
+ * RFC 8446 section 4.5
+ *
+ * struct {} EndOfEarlyData;
+ *
+ * If the server sent an "early_data" extension in EncryptedExtensions, the
+ * client MUST send an EndOfEarlyData message after receiving the server
+ * Finished. Otherwise, the client MUST NOT send an EndOfEarlyData message.
+ *
+ * The message has an empty body. On success the handshake moves on to
+ * MBEDTLS_SSL_CLIENT_CERTIFICATE.
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_write_end_of_early_data(mbedtls_ssl_context *ssl)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    unsigned char *msg = NULL;
+    size_t msg_len;
+
+    MBEDTLS_SSL_DEBUG_MSG(2, ("=> write EndOfEarlyData"));
+
+    MBEDTLS_SSL_PROC_CHK(mbedtls_ssl_start_handshake_msg(
+                             ssl, MBEDTLS_SSL_HS_END_OF_EARLY_DATA,
+                             &msg, &msg_len));
+
+    /* Empty body: only the handshake header enters the checksum. */
+    mbedtls_ssl_add_hs_hdr_to_checksum(ssl,
+                                       MBEDTLS_SSL_HS_END_OF_EARLY_DATA, 0);
+
+    MBEDTLS_SSL_PROC_CHK(mbedtls_ssl_finish_handshake_msg(ssl, msg_len, 0));
+
+    mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_CLIENT_CERTIFICATE);
+
+cleanup:
+    MBEDTLS_SSL_DEBUG_MSG(2, ("<= write EndOfEarlyData"));
+    return ret;
+}
+
 #if defined(MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_EPHEMERAL_ENABLED)
 /*
  * STATE HANDLING: CertificateRequest
@@ -2367,13 +2532,21 @@
         return ret;
     }
 
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+    if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_ACCEPTED) {
+        mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_END_OF_EARLY_DATA);
+    } else if (ssl->early_data_status == MBEDTLS_SSL_EARLY_DATA_STATUS_REJECTED) {
+        mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_CLIENT_CERTIFICATE);
+    } else
+#endif /* MBEDTLS_SSL_EARLY_DATA */
+    {
 #if defined(MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE)
-    mbedtls_ssl_handshake_set_state(
-        ssl,
-        MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED);
+        mbedtls_ssl_handshake_set_state(
+            ssl, MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED);
 #else
-    mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_CLIENT_CERTIFICATE);
+        mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_CLIENT_CERTIFICATE);
 #endif /* MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE */
+    }
 
     return 0;
 }
@@ -2789,6 +2962,10 @@
             ret = ssl_tls13_process_server_finished(ssl);
             break;
 
+        case MBEDTLS_SSL_END_OF_EARLY_DATA:
+            ret = ssl_tls13_write_end_of_early_data(ssl);
+            break;
+
         case MBEDTLS_SSL_CLIENT_CERTIFICATE:
             ret = ssl_tls13_write_client_certificate(ssl);
             break;
@@ -2828,6 +3005,20 @@
                 mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_CLIENT_CERTIFICATE);
             }
             break;
+
+        case MBEDTLS_SSL_CLIENT_CCS_AFTER_CLIENT_HELLO:
+            ret = mbedtls_ssl_tls13_write_change_cipher_spec(ssl);
+            if (ret == 0) {
+                mbedtls_ssl_handshake_set_state(ssl, MBEDTLS_SSL_SERVER_HELLO);
+
+#if defined(MBEDTLS_SSL_EARLY_DATA)
+                MBEDTLS_SSL_DEBUG_MSG(
+                    1, ("Switch to early data keys for outbound traffic"));
+                mbedtls_ssl_set_outbound_transform(
+                    ssl, ssl->handshake->transform_earlydata);
+#endif
+            }
+            break;
 #endif /* MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE */
 
 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 513937e..4fb73f9 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -1378,9 +1378,8 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char hash_transcript[PSA_HASH_MAX_SIZE + 4];
     size_t hash_len;
-    const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
-    uint16_t cipher_suite = ssl->session_negotiate->ciphersuite;
-    ciphersuite_info = mbedtls_ssl_ciphersuite_from_id(cipher_suite);
+    const mbedtls_ssl_ciphersuite_t *ciphersuite_info =
+        ssl->handshake->ciphersuite_info;
 
     MBEDTLS_SSL_DEBUG_MSG(3, ("Reset SSL session for HRR"));
 
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index b92f12e..2e34ee8 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -1238,7 +1238,7 @@
     ret = mbedtls_ssl_tls13_populate_transform(
         transform_earlydata,
         ssl->conf->endpoint,
-        ssl->session_negotiate->ciphersuite,
+        handshake->ciphersuite_info->id,
         &traffic_keys,
         ssl);
     if (ret != 0) {
@@ -1699,7 +1699,7 @@
     ret = mbedtls_ssl_tls13_populate_transform(
         transform_handshake,
         ssl->conf->endpoint,
-        ssl->session_negotiate->ciphersuite,
+        handshake->ciphersuite_info->id,
         &traffic_keys,
         ssl);
     if (ret != 0) {
@@ -1789,7 +1789,7 @@
     ret = mbedtls_ssl_tls13_populate_transform(
         transform_application,
         ssl->conf->endpoint,
-        ssl->session_negotiate->ciphersuite,
+        ssl->handshake->ciphersuite_info->id,
         &traffic_keys,
         ssl);
     if (ret != 0) {
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index ef90f69..81c289a 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -2100,7 +2100,7 @@
 }
 
 MBEDTLS_CHECK_RETURN_CRITICAL
-static int ssl_tls13_finalize_write_server_hello(mbedtls_ssl_context *ssl)
+static int ssl_tls13_finalize_server_hello(mbedtls_ssl_context *ssl)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     ret = mbedtls_ssl_tls13_compute_handshake_transform(ssl);
@@ -2140,7 +2140,7 @@
     MBEDTLS_SSL_PROC_CHK(mbedtls_ssl_finish_handshake_msg(
                              ssl, buf_len, msg_len));
 
-    MBEDTLS_SSL_PROC_CHK(ssl_tls13_finalize_write_server_hello(ssl));
+    MBEDTLS_SSL_PROC_CHK(ssl_tls13_finalize_server_hello(ssl));
 
 #if defined(MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE)
     /* The server sends a dummy change_cipher_spec record immediately
diff --git a/tests/opt-testcases/tls13-misc.sh b/tests/opt-testcases/tls13-misc.sh
index 821a37b..46c371f 100755
--- a/tests/opt-testcases/tls13-misc.sh
+++ b/tests/opt-testcases/tls13-misc.sh
@@ -275,14 +275,16 @@
 run_test    "TLS 1.3 m->G: EarlyData: basic check, good" \
             "$G_NEXT_SRV -d 10 --priority=NORMAL:-VERS-ALL:+VERS-TLS1.3:+CIPHER-ALL:+ECDHE-PSK:+PSK --earlydata --disable-client-cert" \
             "$P_CLI debug_level=4 early_data=1 reco_mode=1 reconnect=1 reco_delay=900" \
-            1 \
+            0 \
             -c "Reconnecting with saved session" \
             -c "NewSessionTicket: early_data(42) extension received." \
             -c "ClientHello: early_data(42) extension exists." \
             -c "EncryptedExtensions: early_data(42) extension received." \
             -c "EncryptedExtensions: early_data(42) extension exists." \
+            -c "<= write EndOfEarlyData" \
             -s "Parsing extension 'Early Data/42' (0 bytes)" \
             -s "Sending extension Early Data/42 (0 bytes)" \
+            -s "END OF EARLY DATA (5) was received." \
             -s "early data accepted"
 
 requires_gnutls_tls1_3