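Add mbedtls_ssl_session_set_alpn() function

Introduce a helper that replaces the ALPN protocol name stored in a
session for ticket/early-data use, rejecting names longer than
MBEDTLS_SSL_MAX_ALPN_NAME_LEN. Use it when loading a serialized session,
when copying a session on the TLS 1.3 server and when preparing a new
session ticket, instead of open-coding calloc + memcpy at each site.
The session field is now referred to as ticket_alpn.

Signed-off-by: Waleed Elmelegy <waleed.elmelegy@arm.com>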
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 2ec898b..948c802 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2852,6 +2852,13 @@
const char *hostname);
#endif
+#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_EARLY_DATA) && \
+ defined(MBEDTLS_SSL_ALPN)
+/*
+ * Store a heap copy of the NUL-terminated string `alpn` as the session's
+ * ticket_alpn, zeroizing and freeing any previous value; passing NULL
+ * clears it.
+ *
+ * Returns 0 on success, MBEDTLS_ERR_SSL_BAD_INPUT_DATA if `alpn` is longer
+ * than MBEDTLS_SSL_MAX_ALPN_NAME_LEN (the session is then not modified), or
+ * MBEDTLS_ERR_SSL_ALLOC_FAILED if the copy cannot be allocated.
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+int mbedtls_ssl_session_set_alpn(mbedtls_ssl_session *session,
+ const char *alpn);
+#endif
+
#if defined(MBEDTLS_SSL_PROTO_TLS1_3) && defined(MBEDTLS_SSL_SESSION_TICKETS)
#define MBEDTLS_SSL_TLS1_3_MAX_ALLOWED_TICKET_LIFETIME (604800)
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index d7d26ab..f78b97d 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -2450,6 +2450,7 @@
#if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
+
psa_status_t mbedtls_ssl_cipher_to_psa(mbedtls_cipher_type_t mbedtls_cipher_type,
size_t taglen,
psa_algorithm_t *alg,
@@ -3771,8 +3772,8 @@
#if defined(MBEDTLS_SSL_SRV_C) && \
defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
- const uint8_t alpn_len = (session->alpn == NULL) ?
- 0 : (uint8_t) strlen(session->alpn) + 1;
+ const uint8_t alpn_len = (session->ticket_alpn == NULL) ?
+ 0 : (uint8_t) strlen(session->ticket_alpn) + 1;
#endif
size_t needed = 4 /* ticket_age_add */
+ 1 /* ticket_flags */
@@ -3858,7 +3859,7 @@
*p++ = alpn_len;
if (alpn_len > 0) {
/* save chosen alpn */
- memcpy(p, session->alpn, alpn_len);
+ memcpy(p, session->ticket_alpn, alpn_len);
p += alpn_len;
}
#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
@@ -3951,6 +3952,7 @@
#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
uint8_t alpn_len;
+ int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
if (end - p < 1) {
return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
@@ -3960,12 +3962,12 @@
if (end - p < alpn_len) {
return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
}
+
if (alpn_len > 0) {
- session->alpn = mbedtls_calloc(alpn_len, sizeof(char));
- if (session->alpn == NULL) {
- return MBEDTLS_ERR_SSL_ALLOC_FAILED;
+ ret = mbedtls_ssl_session_set_alpn(session, (const char *) p);
+ if (ret != 0) {
+ return ret;
}
- memcpy(session->alpn, p, alpn_len);
p += alpn_len;
}
#endif /* MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
@@ -4917,11 +4919,12 @@
defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
mbedtls_free(session->hostname);
#endif
+ mbedtls_free(session->ticket);
+#endif
+
#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN) && \
defined(MBEDTLS_SSL_SRV_C)
- mbedtls_free(session->alpn);
-#endif
- mbedtls_free(session->ticket);
+ mbedtls_free(session->ticket_alpn);
#endif
mbedtls_platform_zeroize(session, sizeof(mbedtls_ssl_session));
@@ -9870,4 +9873,37 @@
MBEDTLS_SSL_SERVER_NAME_INDICATION &&
MBEDTLS_SSL_CLI_C */
+#if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_EARLY_DATA) && \
+    defined(MBEDTLS_SSL_ALPN)
+/*
+ * Replace the ALPN protocol name that a server-side session records for
+ * ticket-based early data with a private heap copy of `alpn`.
+ *
+ * session  Session to update. Any previous ticket_alpn is zeroized and
+ *          freed once the new value has been copied.
+ * alpn     NUL-terminated protocol name (strlen() is applied to it), or
+ *          NULL to clear the stored name. It may alias
+ *          session->ticket_alpn.
+ *
+ * Returns 0 on success, MBEDTLS_ERR_SSL_BAD_INPUT_DATA if `alpn` is longer
+ * than MBEDTLS_SSL_MAX_ALPN_NAME_LEN, or MBEDTLS_ERR_SSL_ALLOC_FAILED if the
+ * copy cannot be allocated. On error the session is left unchanged.
+ *
+ * NOTE(review): the ticket serializer stores `(uint8_t) strlen(...) + 1`,
+ * which wraps to 0 for a 255-byte name; if MBEDTLS_SSL_MAX_ALPN_NAME_LEN is
+ * 255 such a name would be silently dropped from tickets - confirm and
+ * tighten either this limit or the serializer.
+ */
+int mbedtls_ssl_session_set_alpn(mbedtls_ssl_session *session,
+                                 const char *alpn)
+{
+    char *copy = NULL;
+
+    if (alpn != NULL) {
+        size_t alpn_len = strlen(alpn);
+
+        if (alpn_len > MBEDTLS_SSL_MAX_ALPN_NAME_LEN) {
+            return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
+        }
+
+        /* Copy before releasing the old name: a caller passing
+         * session->ticket_alpn itself must not read freed memory, and an
+         * allocation failure must not destroy the current value. */
+        copy = mbedtls_calloc(alpn_len + 1, sizeof(char));
+        if (copy == NULL) {
+            return MBEDTLS_ERR_SSL_ALLOC_FAILED;
+        }
+        memcpy(copy, alpn, alpn_len + 1);
+    }
+
+    if (session->ticket_alpn != NULL) {
+        mbedtls_zeroize_and_free(session->ticket_alpn,
+                                 strlen(session->ticket_alpn));
+    }
+    session->ticket_alpn = copy;
+
+    return 0;
+}
+#endif /* MBEDTLS_SSL_SRV_C && MBEDTLS_SSL_EARLY_DATA && MBEDTLS_SSL_ALPN */
#endif /* MBEDTLS_SSL_TLS_C */
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 291d645..9c73c7a 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -469,12 +469,10 @@
dst->max_early_data_size = src->max_early_data_size;
#if defined(MBEDTLS_SSL_ALPN)
- if (src->alpn != NULL) {
- dst->alpn = mbedtls_calloc(strlen(src->alpn) + 1, sizeof(char));
- if (dst->alpn == NULL) {
- return MBEDTLS_ERR_SSL_ALLOC_FAILED;
- }
- memcpy(dst->alpn, src->alpn, strlen(src->alpn) + 1);
+ int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+ ret = mbedtls_ssl_session_set_alpn(dst, src->ticket_alpn);
+ if (ret != 0) {
+ return ret;
}
#endif /* MBEDTLS_SSL_ALPN */
#endif /* MBEDTLS_SSL_EARLY_DATA*/
@@ -3148,12 +3146,9 @@
MBEDTLS_SSL_PRINT_TICKET_FLAGS(4, session->ticket_flags);
#if defined(MBEDTLS_SSL_EARLY_DATA) && defined(MBEDTLS_SSL_ALPN)
- if (ssl->alpn_chosen != NULL) {
- session->alpn = mbedtls_calloc(strlen(ssl->alpn_chosen) + 1, sizeof(char));
- if (session->alpn == NULL) {
- return MBEDTLS_ERR_SSL_ALLOC_FAILED;
- }
- memcpy(session->alpn, ssl->alpn_chosen, strlen(ssl->alpn_chosen) + 1);
+ ret = mbedtls_ssl_session_set_alpn(session, ssl->alpn_chosen);
+ if (ret != 0) {
+ return ret;
}
#endif