Add ALPN checking when accepting early data

Signed-off-by: Waleed Elmelegy <waleed.elmelegy@arm.com>
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index ffca53e..ac53853 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -238,6 +238,11 @@
 #endif
 #endif /* MBEDTLS_SSL_SESSION_TICKETS && MBEDTLS_SSL_CLI_C */
 
+#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_ALPN) && \
+    defined(MBEDTLS_SSL_EARLY_DATA)
+    dst->ticket_alpn = NULL;
+#endif
+
 #if defined(MBEDTLS_X509_CRT_PARSE_C)
 
 #if defined(MBEDTLS_SSL_KEEP_PEER_CERTIFICATE)
@@ -275,6 +280,16 @@
 
 #endif /* MBEDTLS_X509_CRT_PARSE_C */
 
+#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_ALPN) && \
+    defined(MBEDTLS_SSL_EARLY_DATA)
+    {
+        int ret = mbedtls_ssl_session_set_ticket_alpn(dst, src->ticket_alpn);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+#endif /* MBEDTLS_SSL_SRV_C && MBEDTLS_SSL_ALPN && MBEDTLS_SSL_EARLY_DATA */
+
 #if defined(MBEDTLS_SSL_SESSION_TICKETS) && defined(MBEDTLS_SSL_CLI_C)
     if (src->ticket != NULL) {
         dst->ticket = mbedtls_calloc(1, src->ticket_len);
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 2c30da8..e8afe45 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -1819,7 +1819,6 @@
      * NOTE:
      *  - The TLS version number is checked in
      *    ssl_tls13_offered_psks_check_identity_match_ticket().
-     *  - ALPN is not checked for the time being (TODO).
      */
 
     if (handshake->selected_identity != 0) {
@@ -1846,6 +1845,28 @@
         return -1;
     }
 
+#if defined(MBEDTLS_SSL_ALPN)
+    const char *alpn = mbedtls_ssl_get_alpn_protocol(ssl);
+    size_t alpn_len;
+
+    if (alpn == NULL && ssl->session_negotiate->ticket_alpn == NULL) {
+        return 0;
+    }
+
+    if (alpn != NULL) {
+        alpn_len = strlen(alpn);
+    }
+
+    if (alpn == NULL ||
+        ssl->session_negotiate->ticket_alpn == NULL ||
+        alpn_len != strlen(ssl->session_negotiate->ticket_alpn) ||
+        (memcmp(alpn, ssl->session_negotiate->ticket_alpn, alpn_len) != 0)) {
+        MBEDTLS_SSL_DEBUG_MSG(1, ("EarlyData: rejected, the selected ALPN is different "
+                                  "from the one associated with the pre-shared key."));
+        return -1;
+    }
+#endif
+
     return 0;
 }
 #endif /* MBEDTLS_SSL_EARLY_DATA */
diff --git a/tests/include/test/ssl_helpers.h b/tests/include/test/ssl_helpers.h
index 335386b..77f85c4 100644
--- a/tests/include/test/ssl_helpers.h
+++ b/tests/include/test/ssl_helpers.h
@@ -78,6 +78,10 @@
 #undef MBEDTLS_SSL_TLS1_3_LABEL
 };
 
+#if defined(MBEDTLS_SSL_ALPN)
+#define MBEDTLS_TEST_MAX_ALPN_LIST_SIZE 10
+#endif
+
 typedef struct mbedtls_test_ssl_log_pattern {
     const char *pattern;
     size_t counter;
@@ -118,6 +122,9 @@
 #if defined(MBEDTLS_SSL_CACHE_C)
     mbedtls_ssl_cache_context *cache;
 #endif
+#if defined(MBEDTLS_SSL_ALPN)
+    const char *alpn_list[MBEDTLS_TEST_MAX_ALPN_LIST_SIZE];
+#endif
 } mbedtls_test_handshake_test_options;
 
 /*
diff --git a/tests/src/test_helpers/ssl_helpers.c b/tests/src/test_helpers/ssl_helpers.c
index 963938f..55201c0 100644
--- a/tests/src/test_helpers/ssl_helpers.c
+++ b/tests/src/test_helpers/ssl_helpers.c
@@ -833,6 +833,12 @@
                                              options->max_early_data_size);
     }
 #endif
+#if defined(MBEDTLS_SSL_ALPN)
+    /* check that alpn_list contains at least one valid entry */
+    if (options->alpn_list[0] != NULL) {
+        mbedtls_ssl_conf_alpn_protocols(&(ep->conf), options->alpn_list);
+    }
+#endif
 #endif
 
 #if defined(MBEDTLS_SSL_CACHE_C) && defined(MBEDTLS_SSL_SRV_C)
diff --git a/tests/suites/test_suite_ssl.data b/tests/suites/test_suite_ssl.data
index 0ecf65c..734b945 100644
--- a/tests/suites/test_suite_ssl.data
+++ b/tests/suites/test_suite_ssl.data
@@ -3294,6 +3294,22 @@
 TLS 1.3 read early data, discard after HRR
 tls13_read_early_data:TEST_EARLY_DATA_HRR
 
+TLS 1.3 cli, early data, same ALPN
+depends_on:MBEDTLS_SSL_ALPN
+tls13_read_early_data:TEST_EARLY_DATA_SAME_ALPN
+
+TLS 1.3 cli, early data, different ALPN
+depends_on:MBEDTLS_SSL_ALPN
+tls13_read_early_data:TEST_EARLY_DATA_DIFF_ALPN
+
+TLS 1.3 cli, early data, no initial ALPN
+depends_on:MBEDTLS_SSL_ALPN
+tls13_read_early_data:TEST_EARLY_DATA_NO_INITIAL_ALPN
+
+TLS 1.3 cli, early data, no later ALPN
+depends_on:MBEDTLS_SSL_ALPN
+tls13_read_early_data:TEST_EARLY_DATA_NO_LATER_ALPN
+
 TLS 1.3 cli, early data state, early data accepted
 tls13_cli_early_data_state:TEST_EARLY_DATA_ACCEPTED
 
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 2fe4997..67d97e4 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -17,6 +17,10 @@
 #define TEST_EARLY_DATA_NO_INDICATION_SENT 1
 #define TEST_EARLY_DATA_SERVER_REJECTS 2
 #define TEST_EARLY_DATA_HRR 3
+#define TEST_EARLY_DATA_SAME_ALPN 4
+#define TEST_EARLY_DATA_DIFF_ALPN 5
+#define TEST_EARLY_DATA_NO_INITIAL_ALPN 6
+#define TEST_EARLY_DATA_NO_LATER_ALPN 7
 
 #if (!defined(MBEDTLS_SSL_PROTO_TLS1_2)) && \
     defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_CLI_C) && \
@@ -3728,6 +3732,19 @@
     server_options.group_list = group_list;
     server_options.early_data = MBEDTLS_SSL_EARLY_DATA_ENABLED;
 
+#if defined(MBEDTLS_SSL_ALPN)
+    switch (scenario) {
+        case TEST_EARLY_DATA_SAME_ALPN:
+        case TEST_EARLY_DATA_DIFF_ALPN:
+        case TEST_EARLY_DATA_NO_LATER_ALPN:
+            client_options.alpn_list[0] = "ALPNExample";
+            client_options.alpn_list[1] = NULL;
+            server_options.alpn_list[0] = "ALPNExample";
+            server_options.alpn_list[1] = NULL;
+            break;
+    }
+#endif
+
     ret = mbedtls_test_get_tls13_ticket(&client_options, &server_options,
                                         &saved_session);
     TEST_EQUAL(ret, 0);
@@ -3756,6 +3773,33 @@
                 "EarlyData: Ignore application message before 2nd ClientHello";
             server_options.group_list = group_list + 1;
             break;
+#if defined(MBEDTLS_SSL_ALPN)
+        case TEST_EARLY_DATA_SAME_ALPN:
+            client_options.alpn_list[0] = "ALPNExample";
+            client_options.alpn_list[1] = NULL;
+            server_options.alpn_list[0] = "ALPNExample";
+            server_options.alpn_list[1] = NULL;
+            break;
+        case TEST_EARLY_DATA_DIFF_ALPN:
+        case TEST_EARLY_DATA_NO_INITIAL_ALPN:
+            client_options.alpn_list[0] = "ALPNExample2";
+            client_options.alpn_list[1] = NULL;
+            server_options.alpn_list[0] = "ALPNExample2";
+            server_options.alpn_list[1] = NULL;
+            mbedtls_debug_set_threshold(3);
+            server_pattern.pattern =
+                "EarlyData: rejected, the selected ALPN is different "
+                "from the one associated with the pre-shared key.";
+            break;
+        case TEST_EARLY_DATA_NO_LATER_ALPN:
+            client_options.alpn_list[0] = NULL;
+            server_options.alpn_list[0] = NULL;
+            mbedtls_debug_set_threshold(3);
+            server_pattern.pattern =
+                "EarlyData: rejected, the selected ALPN is different "
+                "from the one associated with the pre-shared key.";
+            break;
+#endif
 
         default:
             TEST_FAIL("Unknown scenario.");
@@ -3807,6 +3851,9 @@
 
     switch (scenario) {
         case TEST_EARLY_DATA_ACCEPTED:
+#if defined(MBEDTLS_SSL_ALPN)
+        case TEST_EARLY_DATA_SAME_ALPN:
+#endif
             TEST_EQUAL(ret, MBEDTLS_ERR_SSL_RECEIVED_EARLY_DATA);
             TEST_EQUAL(server_ep.ssl.handshake->early_data_accepted, 1);
             TEST_EQUAL(mbedtls_ssl_read_early_data(&(server_ep.ssl),
@@ -3821,6 +3868,11 @@
 
         case TEST_EARLY_DATA_SERVER_REJECTS: /* Intentional fallthrough */
         case TEST_EARLY_DATA_HRR:
+#if defined(MBEDTLS_SSL_ALPN)
+        case TEST_EARLY_DATA_DIFF_ALPN:
+        case TEST_EARLY_DATA_NO_INITIAL_ALPN:
+        case TEST_EARLY_DATA_NO_LATER_ALPN:
+#endif
             TEST_EQUAL(ret, 0);
             TEST_EQUAL(server_ep.ssl.handshake->early_data_accepted, 0);
             TEST_EQUAL(server_pattern.counter, 1);