Add mbedtls_ssl_session_set_alpn() function

Signed-off-by: Waleed Elmelegy <waleed.elmelegy@arm.com>
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index a6ee9a4..b05bfe1 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -1304,6 +1304,10 @@
     char *MBEDTLS_PRIVATE(hostname);             /*!< host name binded with tickets */
 #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION && MBEDTLS_SSL_CLI_C */
 
+#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C)
+    char *ticket_alpn;                      /*!< ALPN negotiated in the session */
+#endif
+
 #if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_CLI_C)
     /*! Time in milliseconds when the last ticket was received. */
     mbedtls_ms_time_t MBEDTLS_PRIVATE(ticket_reception_time);
@@ -1312,9 +1316,6 @@
 
 #if defined(MBEDTLS_SSL_EARLY_DATA)
     uint32_t MBEDTLS_PRIVATE(max_early_data_size);          /*!< maximum amount of early data in tickets */
-#if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C)
-    char *alpn;                      /*!< ALPN negotiated in the session */
-#endif
 #endif
 
 #if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 2ec898b..948c802 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2852,6 +2852,13 @@
                                      const char *hostname);
 #endif
 
+#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_EARLY_DATA) && \
+    defined(MBEDTLS_SSL_ALPN)
+MBEDTLS_CHECK_RETURN_CRITICAL
+int mbedtls_ssl_session_set_alpn(mbedtls_ssl_session *session,
+                                 const char *alpn);
+#endif
+
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3) && defined(MBEDTLS_SSL_SESSION_TICKETS)
 
 #define MBEDTLS_SSL_TLS1_3_MAX_ALLOWED_TICKET_LIFETIME (604800)
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index d7d26ab..f78b97d 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -2450,6 +2450,7 @@
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
 
+
 psa_status_t mbedtls_ssl_cipher_to_psa(mbedtls_cipher_type_t mbedtls_cipher_type,
                                        size_t taglen,
                                        psa_algorithm_t *alg,
@@ -3771,8 +3772,8 @@
 
 #if defined(MBEDTLS_SSL_SRV_C) && \
     defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
-    const uint8_t alpn_len = (session->alpn == NULL) ?
-                             0 : (uint8_t) strlen(session->alpn) + 1;
+    const uint8_t alpn_len = (session->ticket_alpn == NULL) ?
+                             0 : (uint8_t) strlen(session->ticket_alpn) + 1;
 #endif
     size_t needed =   4  /* ticket_age_add */
                     + 1  /* ticket_flags */
@@ -3858,7 +3859,7 @@
         *p++ = alpn_len;
         if (alpn_len > 0) {
             /* save chosen alpn */
-            memcpy(p, session->alpn, alpn_len);
+            memcpy(p, session->ticket_alpn, alpn_len);
             p += alpn_len;
         }
 #endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
@@ -3951,6 +3952,7 @@
 
 #if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
         uint8_t alpn_len;
+        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
         if (end - p < 1) {
             return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
@@ -3960,12 +3962,12 @@
         if (end - p < alpn_len) {
             return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
         }
+
         if (alpn_len > 0) {
-            session->alpn = mbedtls_calloc(alpn_len, sizeof(char));
-            if (session->alpn == NULL) {
-                return MBEDTLS_ERR_SSL_ALLOC_FAILED;
+            ret = mbedtls_ssl_session_set_alpn(session, (const char *) p);
+            if (ret != 0) {
+                return ret;
             }
-            memcpy(session->alpn, p, alpn_len);
             p += alpn_len;
         }
 #endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
@@ -4917,11 +4919,12 @@
     defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
     mbedtls_free(session->hostname);
 #endif
+    mbedtls_free(session->ticket);
+#endif
+
 #if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) && \
     defined(MBEDTLS_SSL_SRV_C)
-    mbedtls_free(session->alpn);
-#endif
-    mbedtls_free(session->ticket);
+    mbedtls_free(session->ticket_alpn);
 #endif
 
     mbedtls_platform_zeroize(session, sizeof(mbedtls_ssl_session));
@@ -9870,4 +9873,37 @@
           MBEDTLS_SSL_SERVER_NAME_INDICATION &&
           MBEDTLS_SSL_CLI_C */
 
+#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_EARLY_DATA) && \
+    defined(MBEDTLS_SSL_ALPN)
+int mbedtls_ssl_session_set_alpn(mbedtls_ssl_session *session,
+                                 const char *alpn)
+{
+    size_t alpn_len = 0;
+
+    if (alpn != NULL) {
+        alpn_len = strlen(alpn);
+
+        if (alpn_len > MBEDTLS_SSL_MAX_ALPN_NAME_LEN) {
+            return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
+        }
+    }
+
+    if (session->ticket_alpn != NULL) {
+        mbedtls_zeroize_and_free(session->ticket_alpn,
+                                 strlen(session->ticket_alpn));
+    }
+
+    if (alpn == NULL) {
+        session->ticket_alpn = NULL;
+    } else {
+        session->ticket_alpn = mbedtls_calloc(strlen(alpn) + 1, sizeof(char));
+        if (session->ticket_alpn == NULL) {
+            return MBEDTLS_ERR_SSL_ALLOC_FAILED;
+        }
+        memcpy(session->ticket_alpn, alpn, strlen(alpn) + 1);
+    }
+
+    return 0;
+}
+#endif /* MBEDTLS_SSL_SRV_C && MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
 #endif /* MBEDTLS_SSL_TLS_C */
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 291d645..9c73c7a 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -469,12 +469,10 @@
     dst->max_early_data_size = src->max_early_data_size;
 
 #if defined(MBEDTLS_SSL_ALPN)
-    if (src->alpn != NULL) {
-        dst->alpn = mbedtls_calloc(strlen(src->alpn) + 1, sizeof(char));
-        if (dst->alpn == NULL) {
-            return MBEDTLS_ERR_SSL_ALLOC_FAILED;
-        }
-        memcpy(dst->alpn, src->alpn, strlen(src->alpn) + 1);
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    ret = mbedtls_ssl_session_set_alpn(dst, src->ticket_alpn);
+    if (ret != 0) {
+        return ret;
     }
 #endif /* MBEDTLS_SSL_ALPN */
 #endif /* MBEDTLS_SSL_EARLY_DATA*/
@@ -3148,12 +3146,9 @@
     MBEDTLS_SSL_PRINT_TICKET_FLAGS(4, session->ticket_flags);
 
 #if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
-    if (ssl->alpn_chosen != NULL) {
-        session->alpn = mbedtls_calloc(strlen(ssl->alpn_chosen) + 1, sizeof(char));
-        if (session->alpn == NULL) {
-            return MBEDTLS_ERR_SSL_ALLOC_FAILED;
-        }
-        memcpy(session->alpn, ssl->alpn_chosen, strlen(ssl->alpn_chosen) + 1);
+    ret = mbedtls_ssl_session_set_alpn(session, ssl->alpn_chosen);
+    if (ret != 0) {
+        return ret;
     }
 #endif
 
diff --git a/tests/src/test_helpers/ssl_helpers.c b/tests/src/test_helpers/ssl_helpers.c
index 89c1bbf..9c1676f 100644
--- a/tests/src/test_helpers/ssl_helpers.c
+++ b/tests/src/test_helpers/ssl_helpers.c
@@ -1794,11 +1794,11 @@
 #if defined(MBEDTLS_SSL_EARLY_DATA)
     session->max_early_data_size = 0x87654321;
 #if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C)
-    session->alpn = mbedtls_calloc(strlen("ALPNExample")+1, sizeof(char));
-    if (session->alpn == NULL) {
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    ret = mbedtls_ssl_session_set_alpn(session, "ALPNExample");
+    if (ret != 0) {
         return -1;
     }
-    strcpy(session->alpn, "ALPNExample");
 #endif /* MBEDTLS_SSL_ALPN && MBEDTLS_SSL_SRV_C */
 #endif /* MBEDTLS_SSL_EARLY_DATA */
 
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index da07f2c..e29667d 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -2106,11 +2106,10 @@
             original.max_early_data_size == restored.max_early_data_size);
 #if defined(MBEDTLS_SSL_ALPN) && defined(MBEDTLS_SSL_SRV_C)
         if (endpoint_type == MBEDTLS_SSL_IS_SERVER) {
-            TEST_ASSERT(original.alpn != NULL);
-            TEST_ASSERT(restored.alpn != NULL);
-            TEST_ASSERT(memcmp(original.alpn,
-                               restored.alpn,
-                               strlen(original.alpn)) == 0);
+            TEST_ASSERT(original.ticket_alpn != NULL);
+            TEST_ASSERT(restored.ticket_alpn != NULL);
+            TEST_MEMORY_COMPARE(original.ticket_alpn, strlen(original.ticket_alpn),
+                                restored.ticket_alpn, strlen(restored.ticket_alpn));
         }
 #endif
 #endif