tls13: srv: Deprotect and discard early data records

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index 20501c9..bf9a8ca 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -3985,6 +3985,31 @@
                                            rec)) != 0) {
             MBEDTLS_SSL_DEBUG_RET(1, "ssl_decrypt_buf", ret);
 
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_SRV_C)
+            /*
+             * Although the server rejected early data, it might receive early
+             * data as long as it has not received the client Finished message.
+             * It is encrypted with early keys and should be ignored as stated
+             * in section 4.2.10 of RFC 8446:
+             *
+             * "Ignore the extension and return a regular 1-RTT response. The
+             * server then skips past early data by attempting to deprotect
+             * received records using the handshake traffic key, discarding
+             * records which fail deprotection (up to the configured
+             * max_early_data_size). Once a record is deprotected successfully,
+             * it is treated as the start of the client's second flight and the
+             * server proceeds as with an ordinary 1-RTT handshake."
+             */
+            if ((old_msg_type == MBEDTLS_SSL_MSG_APPLICATION_DATA) &&
+                (ssl->discard_early_data_record ==
+                 MBEDTLS_SSL_EARLY_DATA_TRY_TO_DEPROTECT_AND_DISCARD)) {
+                MBEDTLS_SSL_DEBUG_MSG(
+                    3, ("EarlyData: deprotect and discard app data records."));
+                /* TODO: Add max_early_data_size check here. */
+                ret = MBEDTLS_ERR_SSL_CONTINUE_PROCESSING;
+            }
+#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_SRV_C */
+
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
             if (ret == MBEDTLS_ERR_SSL_UNEXPECTED_CID &&
                 ssl->conf->ignore_unexpected_cid
@@ -3997,6 +4022,20 @@
             return ret;
         }
 
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_SRV_C)
+        /*
+         * If the server were discarding protected records that it fails to
+         * deprotect because it has rejected early data, as we have just
+         * deprotected successfully a record, the server has to resume normal
+         * operation and fail the connection if the deprotection of a record
+         * fails.
+         */
+        if (ssl->discard_early_data_record ==
+            MBEDTLS_SSL_EARLY_DATA_TRY_TO_DEPROTECT_AND_DISCARD) {
+            ssl->discard_early_data_record = MBEDTLS_SSL_EARLY_DATA_NO_DISCARD;
+        }
+#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_SRV_C */
+
         if (old_msg_type != rec->type) {
             MBEDTLS_SSL_DEBUG_MSG(4, ("record type after decrypt (before %d): %d",
                                       old_msg_type, rec->type));
diff --git a/tests/suites/test_suite_ssl.data b/tests/suites/test_suite_ssl.data
index c06c0a7..404818d 100644
--- a/tests/suites/test_suite_ssl.data
+++ b/tests/suites/test_suite_ssl.data
@@ -3274,5 +3274,8 @@
 TLS 1.3 resume session with ticket
 tls13_resume_session_with_ticket
 
-TLS 1.3 early data
-tls13_early_data
+TLS 1.3 early data, reference
+tls13_early_data:"reference"
+
+TLS 1.3 early data, deprotect and discard
+tls13_early_data:"deprotect and discard"
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 2d1a757..31a973b 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -3662,9 +3662,10 @@
 /* END_CASE */
 
 /* BEGIN_CASE depends_on:MBEDTLS_SSL_EARLY_DATA:MBEDTLS_SSL_CLI_C:MBEDTLS_SSL_SRV_C:MBEDTLS_DEBUG_C:MBEDTLS_TEST_AT_LEAST_ONE_TLS1_3_CIPHERSUITE:MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_EPHEMERAL_ENABLED:MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_EPHEMERAL_ENABLED:MBEDTLS_MD_CAN_SHA256:MBEDTLS_ECP_HAVE_SECP256R1:MBEDTLS_ECP_HAVE_SECP384R1:MBEDTLS_PK_CAN_ECDSA_VERIFY:MBEDTLS_SSL_SESSION_TICKETS */
-void tls13_early_data()
+void tls13_early_data(char *scenario_string)
 {
     int ret = -1;
+    int scenario = 0;
     unsigned char buf[64];
     const char *early_data = "This is early data.";
     size_t early_data_len = strlen(early_data);
@@ -3672,6 +3673,18 @@
     mbedtls_test_handshake_test_options client_options;
     mbedtls_test_handshake_test_options server_options;
     mbedtls_ssl_session saved_session;
+    mbedtls_test_ssl_log_pattern server_pattern = { NULL, 0 };
+
+    /*
+     * Determine scenario.
+     */
+    if (strcmp(scenario_string, "reference") == 0) {
+        scenario = 0;
+    } else if (strcmp(scenario_string, "deprotect and discard") == 0) {
+        scenario = 1;
+    } else {
+        TEST_FAIL("Unknown scenario.");
+    }
 
     /*
      * Test set-up
@@ -3692,15 +3705,17 @@
     mbedtls_ssl_conf_early_data(&client_ep.conf, MBEDTLS_SSL_EARLY_DATA_ENABLED);
 
     server_options.pk_alg = MBEDTLS_PK_ECDSA;
+    server_options.srv_log_fun = mbedtls_test_ssl_log_analyzer;
+    server_options.srv_log_obj = &server_pattern;
     ret = mbedtls_test_ssl_endpoint_init(&server_ep, MBEDTLS_SSL_IS_SERVER,
                                          &server_options, NULL, NULL, NULL,
                                          NULL);
     TEST_EQUAL(ret, 0);
+    mbedtls_ssl_conf_early_data(&server_ep.conf, MBEDTLS_SSL_EARLY_DATA_ENABLED);
     mbedtls_ssl_conf_session_tickets_cb(&server_ep.conf,
                                         mbedtls_test_ticket_write,
                                         mbedtls_test_ticket_parse,
                                         NULL);
-    mbedtls_ssl_conf_early_data(&server_ep.conf, MBEDTLS_SSL_EARLY_DATA_ENABLED);
 
     ret = mbedtls_test_mock_socket_connect(&(client_ep.socket),
                                            &(server_ep.socket), 1024);
@@ -3740,6 +3755,16 @@
     ret = mbedtls_ssl_set_session(&(client_ep.ssl), &saved_session);
     TEST_EQUAL(ret, 0);
 
+    switch (scenario) {
+        case 1: /* deprotect and discard */
+            mbedtls_debug_set_threshold(3);
+            server_pattern.pattern =
+                "EarlyData: deprotect and discard app data records.";
+            mbedtls_ssl_conf_early_data(&server_ep.conf,
+                                        MBEDTLS_SSL_EARLY_DATA_DISABLED);
+            break;
+    }
+
     TEST_EQUAL(mbedtls_test_move_handshake_to_state(
                    &(client_ep.ssl), &(server_ep.ssl),
                    MBEDTLS_SSL_SERVER_HELLO), 0);
@@ -3751,18 +3776,29 @@
                            early_data_len);
     TEST_EQUAL(ret, early_data_len);
 
-    TEST_EQUAL(mbedtls_test_move_handshake_to_state(
-                   &(server_ep.ssl), &(client_ep.ssl),
-                   MBEDTLS_SSL_CLIENT_FINISHED), MBEDTLS_ERR_SSL_RECEIVED_EARLY_DATA);
+    ret = mbedtls_test_move_handshake_to_state(
+        &(server_ep.ssl), &(client_ep.ssl),
+        MBEDTLS_SSL_HANDSHAKE_WRAPUP);
 
-    TEST_EQUAL(server_ep.ssl.handshake->early_data_accepted, 1);
-    TEST_EQUAL(mbedtls_ssl_read_early_data(&(server_ep.ssl), buf, sizeof(buf)),
-               early_data_len);
-    TEST_MEMORY_COMPARE(buf, early_data_len, early_data, early_data_len);
+    switch (scenario) {
+        case 0:
+            TEST_EQUAL(ret, MBEDTLS_ERR_SSL_RECEIVED_EARLY_DATA);
+            TEST_EQUAL(server_ep.ssl.handshake->early_data_accepted, 1);
+            TEST_EQUAL(mbedtls_ssl_read_early_data(&(server_ep.ssl),
+                                                   buf, sizeof(buf)), early_data_len);
+            TEST_MEMORY_COMPARE(buf, early_data_len, early_data, early_data_len);
 
-    TEST_EQUAL(mbedtls_test_move_handshake_to_state(
-                   &(server_ep.ssl), &(client_ep.ssl),
-                   MBEDTLS_SSL_HANDSHAKE_OVER), 0);
+            TEST_EQUAL(mbedtls_test_move_handshake_to_state(
+                           &(server_ep.ssl), &(client_ep.ssl),
+                           MBEDTLS_SSL_HANDSHAKE_WRAPUP), 0);
+            break;
+
+        case 1:
+            TEST_EQUAL(ret, 0);
+            TEST_EQUAL(server_ep.ssl.handshake->early_data_accepted, 0);
+            TEST_EQUAL(server_pattern.counter, 1);
+            break;
+    }
 
 exit:
     mbedtls_test_ssl_endpoint_free(&client_ep, NULL);
@@ -3770,6 +3806,7 @@
     mbedtls_test_free_handshake_options(&client_options);
     mbedtls_test_free_handshake_options(&server_options);
     mbedtls_ssl_session_free(&saved_session);
+    mbedtls_debug_set_threshold(0);
     PSA_DONE();
 }
 /* END_CASE */