introduce sent/received extensions fields

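Replace `extensions_present` with the new `sent_extensions` and
`received_extensions` handshake fields, and move the per-message check of
received extensions into `mbedtls_tls13_check_received_extensions()`.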

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_client.c b/library/ssl_client.c
index d9c6781..b0d2dcf 100644
--- a/library/ssl_client.c
+++ b/library/ssl_client.c
@@ -106,6 +106,14 @@
 
     *olen = hostname_len + 9;
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+    mbedtls_tls13_set_sent_ext_mask( ssl,
+                                     MBEDTLS_TLS_EXT_SERVERNAME );
+    MBEDTLS_SSL_DEBUG_MSG(
+        4, ( "sent %s extension",
+             mbedtls_tls13_get_extension_name(
+                 MBEDTLS_TLS_EXT_SERVERNAME ) ) );
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
     return( 0 );
 }
 #endif /* MBEDTLS_SSL_SERVER_NAME_INDICATION */
@@ -177,6 +185,14 @@
     /* Extension length = *out_len - 2 (ext_type) - 2 (ext_len) */
     MBEDTLS_PUT_UINT16_BE( *out_len - 4, buf, 2 );
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+    mbedtls_tls13_set_sent_ext_mask( ssl,
+                                     MBEDTLS_TLS_EXT_ALPN );
+    MBEDTLS_SSL_DEBUG_MSG(
+        4, ( "sent %s extension",
+             mbedtls_tls13_get_extension_name(
+                 MBEDTLS_TLS_EXT_ALPN ) ) );
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
     return( 0 );
 }
 #endif /* MBEDTLS_SSL_ALPN */
@@ -296,7 +312,11 @@
     *out_len = p - buf;
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
-    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_SUPPORTED_GROUPS;
+    mbedtls_tls13_set_sent_ext_mask( ssl,
+                                     MBEDTLS_TLS_EXT_SUPPORTED_GROUPS );
+    MBEDTLS_SSL_DEBUG_MSG( 4, ( "sent %s extension",
+                                mbedtls_tls13_get_extension_name(
+                                    MBEDTLS_TLS_EXT_SUPPORTED_GROUPS ) ) );
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
 
     return( 0 );
@@ -557,7 +577,7 @@
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
     /* Keeping track of the included extensions */
-    handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
+    handshake->sent_extensions = MBEDTLS_SSL_EXT_NONE;
 #endif
 
     /* First write extensions, then the total length */
diff --git a/library/ssl_debug_helpers.h b/library/ssl_debug_helpers.h
index 07e8c71..6b97bc6 100644
--- a/library/ssl_debug_helpers.h
+++ b/library/ssl_debug_helpers.h
@@ -52,12 +52,12 @@
 
 void mbedtls_ssl_tls13_print_extensions( const mbedtls_ssl_context *ssl,
                                          int level, const char *file, int line,
-                                         const char *hs_msg_name,
+                                         int hs_msg_type,
                                          uint32_t extensions_present );
 
-#define MBEDTLS_SSL_TLS1_3_PRINT_EXTS( level, hs_msg_name, extensions_present ) \
+#define MBEDTLS_SSL_TLS1_3_PRINT_EXTS( level, hs_msg_type, extensions_present ) \
             mbedtls_ssl_tls13_print_extensions( \
-                ssl, level, __FILE__, __LINE__, hs_msg_name, extensions_present )
+                ssl, level, __FILE__, __LINE__, hs_msg_type, extensions_present )
 #else
 
 #define MBEDTLS_SSL_TLS1_3_PRINT_EXTS( level, hs_msg_name, extensions_present )
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 10ebfff..b7f1440 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -946,9 +946,8 @@
 #endif
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
-    uint32_t extensions_present;        /*!< extension presence; Each bitfield
-                                             represents an extension and defined
-                                             as \c MBEDTLS_SSL_EXT_XXX */
+    uint32_t sent_extensions;       /*!< extensions sent by endpoint */
+    uint32_t received_extensions;   /*!< extensions received by endpoint */
 
 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED)
     unsigned char certificate_request_context_len;
@@ -1932,6 +1931,18 @@
 
 uint32_t mbedtls_tls13_get_extension_mask( uint16_t extension_type );
 
+MBEDTLS_CHECK_RETURN_CRITICAL
+int mbedtls_tls13_check_received_extensions( mbedtls_ssl_context *ssl,
+                                             int hs_msg_type,
+                                             uint32_t extension_type,
+                                             uint32_t allowed_mask );
+
+static inline void mbedtls_tls13_set_sent_ext_mask( mbedtls_ssl_context *ssl,
+                                                    uint16_t extension_type )
+{
+    const uint32_t mask = mbedtls_tls13_get_extension_mask( extension_type );
+    ssl->handshake->sent_extensions |= mask;
+}
 
 /*
  * Helper functions to check the selected key exchange mode.
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index a49f774..9947d39 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -8713,8 +8713,14 @@
     *out_len = p - buf;
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
-    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_SIG_ALG;
+    mbedtls_tls13_set_sent_ext_mask( ssl,
+                                     MBEDTLS_TLS_EXT_SIG_ALG );
+    MBEDTLS_SSL_DEBUG_MSG(
+        4, ( "sent %s extension",
+             mbedtls_tls13_get_extension_name(
+                MBEDTLS_TLS_EXT_SIG_ALG ) ) );
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
+
     return( 0 );
 }
 #endif /* MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED */
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 2e0599d..c29b90e 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -89,7 +89,12 @@
     }
 
     *out_len = 5 + versions_len;
-
+    mbedtls_tls13_set_sent_ext_mask( ssl,
+                                     MBEDTLS_TLS_EXT_SUPPORTED_VERSIONS );
+    MBEDTLS_SSL_DEBUG_MSG(
+        4, ( "sent %s extension",
+             mbedtls_tls13_get_extension_name(
+                 MBEDTLS_TLS_EXT_SUPPORTED_VERSIONS ) ) );
     return( 0 );
 }
 
@@ -360,7 +365,13 @@
 
     MBEDTLS_SSL_DEBUG_BUF( 3, "client hello, key_share extension", buf, *out_len );
 
-    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_KEY_SHARE;
+    mbedtls_tls13_set_sent_ext_mask( ssl,
+                                     MBEDTLS_TLS_EXT_KEY_SHARE );
+    MBEDTLS_SSL_DEBUG_MSG(
+        4, ( "sent %s extension",
+
+             mbedtls_tls13_get_extension_name(
+                 MBEDTLS_TLS_EXT_KEY_SHARE ) ) );
 
 cleanup:
 
@@ -513,7 +524,6 @@
     else
         return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
 
-    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_KEY_SHARE;
     return( ret );
 }
 
@@ -601,6 +611,13 @@
 
     *out_len = handshake->hrr_cookie_len + 6;
 
+
+    mbedtls_tls13_set_sent_ext_mask( ssl,
+                                     MBEDTLS_TLS_EXT_COOKIE );
+    MBEDTLS_SSL_DEBUG_MSG(
+        4, ( "sent %s extension",
+             mbedtls_tls13_get_extension_name(
+                 MBEDTLS_TLS_EXT_COOKIE ) ) );
     return( 0 );
 }
 
@@ -670,7 +687,13 @@
     buf[4] = ke_modes_len;
 
     *out_len = p - buf;
-    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_PSK_KEY_EXCHANGE_MODES;
+
+    mbedtls_tls13_set_sent_ext_mask( ssl,
+                                     MBEDTLS_TLS_EXT_PSK_KEY_EXCHANGE_MODES );
+    MBEDTLS_SSL_DEBUG_MSG(
+        4, ( "sent %s extension",
+             mbedtls_tls13_get_extension_name(
+                 MBEDTLS_TLS_EXT_PSK_KEY_EXCHANGE_MODES ) ) );
     return ( 0 );
 }
 
@@ -982,8 +1005,6 @@
 
     MBEDTLS_SSL_DEBUG_BUF( 3, "pre_shared_key identities", buf, p - buf );
 
-    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_PRE_SHARED_KEY;
-
     return( 0 );
 }
 
@@ -1038,6 +1059,13 @@
 
     MBEDTLS_SSL_DEBUG_BUF( 3, "pre_shared_key binders", buf, p - buf );
 
+    mbedtls_tls13_set_sent_ext_mask( ssl,
+                                     MBEDTLS_TLS_EXT_PRE_SHARED_KEY );
+    MBEDTLS_SSL_DEBUG_MSG(
+        4, ( "sent %s extension",
+             mbedtls_tls13_get_extension_name(
+                 MBEDTLS_TLS_EXT_PRE_SHARED_KEY ) ) );
+
     return( 0 );
 }
 
@@ -1110,8 +1138,6 @@
         return( ret );
     }
 
-    ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_PRE_SHARED_KEY;
-
     return( 0 );
 }
 #endif /* MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_SOME_PSK_ENABLED */
@@ -1389,7 +1415,7 @@
     ssl->session_negotiate->tls_version = ssl->tls_version;
 #endif /* MBEDTLS_SSL_SESSION_TICKETS */
 
-    handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
+    handshake->received_extensions = MBEDTLS_SSL_EXT_NONE;
 
     ret = ssl_server_hello_is_hrr( ssl, buf, end );
     switch( ret )
@@ -1496,10 +1522,10 @@
     mbedtls_ssl_handshake_params *handshake = ssl->handshake;
     size_t extensions_len;
     const unsigned char *extensions_end;
-    uint32_t extensions_present, allowed_extension_mask;
     uint16_t cipher_suite;
     const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
     int fatal_alert = 0;
+    uint32_t allowed_extensions_mask;
 
     /*
      * Check there is space for minimal fields
@@ -1642,8 +1668,8 @@
 
     MBEDTLS_SSL_DEBUG_BUF( 3, "server hello extensions", p, extensions_len );
 
-    extensions_present = MBEDTLS_SSL_EXT_NONE;
-    allowed_extension_mask = is_hrr ?
+    ssl->handshake->received_extensions = MBEDTLS_SSL_EXT_NONE;
+    allowed_extensions_mask = is_hrr ?
                                   MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_HRR :
                                   MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_SH;
 
@@ -1661,23 +1687,14 @@
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, extension_data_len );
         extension_data_end = p + extension_data_len;
 
-        /* RFC 8446 page 35
-         *
-         * If an implementation receives an extension which it recognizes and which
-         * is not specified for the message in which it appears, it MUST abort the
-         * handshake with an "illegal_parameter" alert.
-         */
-        extensions_present |= mbedtls_tls13_get_extension_mask( extension_type );
-        MBEDTLS_SSL_DEBUG_MSG( 3,
-                    ( "%s: received %s(%u) extension",
-                      is_hrr ? "hello retry request" : "server hello",
-                      mbedtls_tls13_get_extension_name( extension_type ),
-                      extension_type ) );
-        if( ( extensions_present & allowed_extension_mask ) == 0 )
-        {
-            fatal_alert = MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER;
-            goto cleanup;
-        }
+        ret = mbedtls_tls13_check_received_extensions(
+                  ssl,
+                  is_hrr ?
+                      -MBEDTLS_SSL_HS_SERVER_HELLO : MBEDTLS_SSL_HS_SERVER_HELLO,
+                  extension_type,
+                  allowed_extensions_mask );
+        if( ret != 0 )
+            return( ret );
 
         switch( extension_type )
         {
@@ -1740,11 +1757,6 @@
                 break;
 
             default:
-                MBEDTLS_SSL_DEBUG_MSG( 2,
-                    ( "%s: unexpected extension (%s(%u)) received .",
-                      is_hrr ? "hello retry request" : "server hello",
-                      mbedtls_tls13_get_extension_name( extension_type ),
-                      extension_type ) );
                 ret = MBEDTLS_ERR_SSL_INTERNAL_ERROR;
                 goto cleanup;
         }
@@ -1753,7 +1765,8 @@
     }
 
     MBEDTLS_SSL_TLS1_3_PRINT_EXTS(
-        3, is_hrr ? "HelloRetryRequest" : "ServerHello", extensions_present );
+        3, is_hrr ? -MBEDTLS_SSL_HS_SERVER_HELLO : MBEDTLS_SSL_HS_SERVER_HELLO,
+        ssl->handshake->received_extensions );
 
 cleanup:
 
@@ -1803,7 +1816,7 @@
      * 3) If only the key_share extension was received then the key
      *    exchange mode is EPHEMERAL-only.
      */
-    switch( handshake->extensions_present &
+    switch( handshake->received_extensions &
             ( MBEDTLS_SSL_EXT_PRE_SHARED_KEY | MBEDTLS_SSL_EXT_KEY_SHARE ) )
     {
         /* Only the pre_shared_key extension was received */
@@ -1986,7 +1999,6 @@
     size_t extensions_len;
     const unsigned char *p = buf;
     const unsigned char *extensions_end;
-    uint32_t extensions_present;
 
     MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 2 );
     extensions_len = MBEDTLS_GET_UINT16_BE( p, 0 );
@@ -1996,7 +2008,7 @@
     MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, extensions_len );
     extensions_end = p + extensions_len;
 
-    extensions_present = MBEDTLS_SSL_EXT_NONE;
+    ssl->handshake->received_extensions = MBEDTLS_SSL_EXT_NONE;
 
     while( p < extensions_end )
     {
@@ -2016,26 +2028,11 @@
 
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, extension_data_len );
 
-        /* RFC 8446 page 35
-         *
-         * If an implementation receives an extension which it recognizes and which
-         * is not specified for the message in which it appears, it MUST abort the
-         * handshake with an "illegal_parameter" alert.
-         */
-        extensions_present |= mbedtls_tls13_get_extension_mask( extension_type );
-        MBEDTLS_SSL_DEBUG_MSG( 3,
-                    ( "encrypted extensions : received %s(%u) extension",
-                      mbedtls_tls13_get_extension_name( extension_type ),
-                      extension_type ) );
-        if( ( extensions_present & MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_EE ) == 0 )
-        {
-            MBEDTLS_SSL_DEBUG_MSG(
-                3, ( "forbidden extension received." ) );
-            MBEDTLS_SSL_PEND_FATAL_ALERT(
-                MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-        }
+        ret = mbedtls_tls13_check_received_extensions(
+                  ssl, MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS, extension_type,
+                  MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_EE );
+        if( ret != 0 )
+            return( ret );
 
         switch( extension_type )
         {
@@ -2071,7 +2068,8 @@
         p += extension_data_len;
     }
 
-    MBEDTLS_SSL_TLS1_3_PRINT_EXTS( 3, "EncrypedExtensions", extensions_present );
+    MBEDTLS_SSL_TLS1_3_PRINT_EXTS( 3, MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS,
+                                   ssl->handshake->received_extensions );
 
     /* Check that we consumed all the message. */
     if( p != end )
@@ -2178,7 +2176,6 @@
     size_t certificate_request_context_len = 0;
     size_t extensions_len = 0;
     const unsigned char *extensions_end;
-    uint32_t extensions_present;
 
     /* ...
      * opaque certificate_request_context<0..2^8-1>
@@ -2218,13 +2215,12 @@
     MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, extensions_len );
     extensions_end = p + extensions_len;
 
-    extensions_present = MBEDTLS_SSL_EXT_NONE;
+    ssl->handshake->received_extensions = MBEDTLS_SSL_EXT_NONE;
 
     while( p < extensions_end )
     {
         unsigned int extension_type;
         size_t extension_data_len;
-        uint32_t extension_mask;
 
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, 4 );
         extension_type = MBEDTLS_GET_UINT16_BE( p, 0 );
@@ -2233,29 +2229,11 @@
 
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, extension_data_len );
 
-        /* RFC 8446 page 35
-         *
-         * If an implementation receives an extension which it recognizes and which
-         * is not specified for the message in which it appears, it MUST abort the
-         * handshake with an "illegal_parameter" alert.
-         */
-        extension_mask = mbedtls_tls13_get_extension_mask( extension_type );
-
-        MBEDTLS_SSL_DEBUG_MSG( 3,
-                    ( "encrypted extensions : received %s(%u) extension",
-                      mbedtls_tls13_get_extension_name( extension_type ),
-                      extension_type ) );
-        if( ( extension_mask & MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_CR ) == 0 )
-        {
-            MBEDTLS_SSL_DEBUG_MSG(
-                3, ( "forbidden extension received." ) );
-            MBEDTLS_SSL_PEND_FATAL_ALERT(
-                MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-        }
-
-        extensions_present |= extension_mask;
+        ret = mbedtls_tls13_check_received_extensions(
+                  ssl, MBEDTLS_SSL_HS_CERTIFICATE_REQUEST, extension_type,
+                  MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_CR );
+        if( ret != 0 )
+            return( ret );
 
         switch( extension_type )
         {
@@ -2280,7 +2258,9 @@
         p += extension_data_len;
     }
 
-    MBEDTLS_SSL_TLS1_3_PRINT_EXTS( 3, "CertificateRequest", extensions_present );
+    MBEDTLS_SSL_TLS1_3_PRINT_EXTS( 3,
+                                   MBEDTLS_SSL_HS_CERTIFICATE_REQUEST,
+                                   ssl->handshake->received_extensions );
 
     /* Check that we consumed all the message. */
     if( p != end )
@@ -2290,11 +2270,11 @@
         goto decode_error;
     }
 
-    /* RFC 8446 page 60
+    /* RFC 8446 section 4.3.2
      *
      * The "signature_algorithms" extension MUST be specified
      */
-    if( ( extensions_present & MBEDTLS_SSL_EXT_SIG_ALG ) == 0 )
+    if( ( ssl->handshake->received_extensions & MBEDTLS_SSL_EXT_SIG_ALG ) == 0 )
     {
         MBEDTLS_SSL_DEBUG_MSG( 3,
             ( "no signature algorithms extension found" ) );
@@ -2535,16 +2515,15 @@
                                                     const unsigned char *end )
 {
     const unsigned char *p = buf;
-    uint32_t extensions_present;
 
-    ((void) ssl);
 
-    extensions_present = MBEDTLS_SSL_EXT_NONE;
+    ssl->handshake->received_extensions = MBEDTLS_SSL_EXT_NONE;
 
     while( p < end )
     {
         unsigned int extension_type;
         size_t extension_data_len;
+        int ret;
 
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 4 );
         extension_type = MBEDTLS_GET_UINT16_BE( p, 0 );
@@ -2553,26 +2532,11 @@
 
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, extension_data_len );
 
-        /* RFC 8446 page 35
-         *
-         * If an implementation receives an extension which it recognizes and which
-         * is not specified for the message in which it appears, it MUST abort the
-         * handshake with an "illegal_parameter" alert.
-         */
-        extensions_present |= mbedtls_tls13_get_extension_mask( extension_type );
-        MBEDTLS_SSL_DEBUG_MSG( 3,
-                    ( "NewSessionTicket : received %s(%u) extension",
-                      mbedtls_tls13_get_extension_name( extension_type ),
-                      extension_type ) );
-        if( ( extensions_present & MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_NST ) == 0 )
-        {
-            MBEDTLS_SSL_DEBUG_MSG(
-                3, ( "forbidden extension received." ) );
-            MBEDTLS_SSL_PEND_FATAL_ALERT(
-                MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-        }
+        ret = mbedtls_tls13_check_received_extensions(
+                  ssl, MBEDTLS_SSL_HS_NEW_SESSION_TICKET, extension_type,
+                  MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_NST );
+        if( ret != 0 )
+            return( ret );
 
         switch( extension_type )
         {
@@ -2591,7 +2555,8 @@
         p +=  extension_data_len;
     }
 
-    MBEDTLS_SSL_TLS1_3_PRINT_EXTS( 3, "NewSessionTicket", extensions_present );
+    MBEDTLS_SSL_TLS1_3_PRINT_EXTS(
+        3, MBEDTLS_SSL_HS_NEW_SESSION_TICKET, ssl->handshake->received_extensions );
 
     return( 0 );
 }
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 5eac1f1..7b66be1 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -448,7 +448,6 @@
     {
         size_t cert_data_len, extensions_len;
         const unsigned char *extensions_end;
-        uint32_t extensions_present;
 
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, certificate_list_end, 3 );
         cert_data_len = MBEDTLS_GET_UINT24_BE( p, 0 );
@@ -508,7 +507,7 @@
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, certificate_list_end, extensions_len );
 
         extensions_end = p + extensions_len;
-        extensions_present = MBEDTLS_SSL_EXT_NONE;
+        ssl->handshake->received_extensions = MBEDTLS_SSL_EXT_NONE;
 
         while( p < extensions_end )
         {
@@ -528,26 +527,12 @@
 
             MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, extension_data_len );
 
-            /* RFC 8446 page 35
-             *
-             * If an implementation receives an extension which it recognizes and
-             * which is not specified for the message in which it appears, it MUST
-             * abort the handshake with an "illegal_parameter" alert.
-             */
-            extensions_present |= mbedtls_tls13_get_extension_mask( extension_type );
-            MBEDTLS_SSL_DEBUG_MSG( 3,
-                        ( "encrypted extensions : received %s(%u) extension",
-                        mbedtls_tls13_get_extension_name( extension_type ),
-                        extension_type ) );
-            if( ( extensions_present & MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_CT ) == 0 )
-            {
-                MBEDTLS_SSL_DEBUG_MSG(
-                    3, ( "forbidden extension received." ) );
-                MBEDTLS_SSL_PEND_FATAL_ALERT(
-                    MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                    MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-                return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-            }
+            ret = mbedtls_tls13_check_received_extensions(
+                  ssl, MBEDTLS_SSL_HS_CERTIFICATE, extension_type,
+                  MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_CT );
+            if( ret != 0 )
+                return( ret );
+
             switch( extension_type )
             {
                 default:
@@ -561,7 +546,8 @@
             p += extension_data_len;
         }
 
-        MBEDTLS_SSL_TLS1_3_PRINT_EXTS( 3, "Certificate", extensions_present );
+        MBEDTLS_SSL_TLS1_3_PRINT_EXTS(
+            3, MBEDTLS_SSL_HS_CERTIFICATE, ssl->handshake->received_extensions );
     }
 
 exit:
@@ -1691,9 +1677,31 @@
     return( "unknown" );
 }
 
+static const char *ssl_tls13_get_hs_msg_name( int hs_msg_type )
+{
+    switch( hs_msg_type )
+    {
+        case MBEDTLS_SSL_HS_CLIENT_HELLO:
+            return( "ClientHello" );
+        case MBEDTLS_SSL_HS_SERVER_HELLO:
+            return( "ServerHello" );
+        case -MBEDTLS_SSL_HS_SERVER_HELLO: /* HRR does not have IANA value. */
+            return( "HelloRetryRequest" );
+        case MBEDTLS_SSL_HS_NEW_SESSION_TICKET:
+            return( "NewSessionTicket" );
+        case MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS:
+            return( "EncryptedExtensions" );
+        case MBEDTLS_SSL_HS_CERTIFICATE:
+            return( "Certificate" );
+        case MBEDTLS_SSL_HS_CERTIFICATE_REQUEST:
+            return( "CertificateRequest" );
+    }
+    return( "Unknown" ); /* never NULL: the result is printed with %s */
+}
+
 void mbedtls_ssl_tls13_print_extensions( const mbedtls_ssl_context *ssl,
                                          int level, const char *file, int line,
-                                         const char *hs_msg_name,
+                                         int hs_msg_type,
                                          uint32_t extensions_present )
 {
     static const struct{
@@ -1724,7 +1732,8 @@
             { MBEDTLS_SSL_EXT_KEY_SHARE, "key_share" } };
 
     mbedtls_debug_print_msg( ssl, level, file, line,
-                             "extension list of %s:", hs_msg_name );
+                             "extension list of %s:",
+                             ssl_tls13_get_hs_msg_name( hs_msg_type ) );
 
     for( unsigned i = 0;
          i < sizeof( mask_to_str_table ) / sizeof( mask_to_str_table[0] );
@@ -1742,4 +1751,83 @@

 #endif /* MBEDTLS_DEBUG_C */

+/* RFC 8446 section 4.2
+ *
+ * If an implementation receives an extension which it recognizes and which is
+ * not specified for the message in which it appears, it MUST abort the handshake
+ * with an "illegal_parameter" alert.
+ *
+ * Implementations MUST NOT send extension responses if the remote endpoint did
+ * not send the corresponding extension requests, with the exception of the
+ * "cookie" extension in the HelloRetryRequest. Upon receiving such an
+ * extension, an endpoint MUST abort the handshake with an
+ * "unsupported_extension" alert.
+ *
+ * \p hs_msg_type is the type of the handshake message carrying the extension,
+ * or -MBEDTLS_SSL_HS_SERVER_HELLO for a HelloRetryRequest. Extensions that
+ * pass the \p allowed_mask check are recorded in
+ * ssl->handshake->received_extensions.
+ *
+ * Returns 0 on success, or MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE with a fatal
+ * alert pended on the context.
+ */
+int mbedtls_tls13_check_received_extensions( mbedtls_ssl_context *ssl,
+                                             int hs_msg_type,
+                                             uint32_t extension_type,
+                                             uint32_t allowed_mask )
+{
+    uint32_t extension_mask;
+
+#if defined(MBEDTLS_DEBUG_C)
+    const char *hs_msg_name = ssl_tls13_get_hs_msg_name( hs_msg_type );
+#endif
+
+    extension_mask = mbedtls_tls13_get_extension_mask( extension_type );
+
+    MBEDTLS_SSL_DEBUG_MSG( 3,
+                ( "%s : received %s(%u) extension",
+                  hs_msg_name,
+                  mbedtls_tls13_get_extension_name( extension_type ),
+                  (unsigned int)extension_type ) );
+
+    if( ( extension_mask & allowed_mask ) == 0 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG(
+            3, ( "%s : forbidden extension received.", hs_msg_name ) );
+        MBEDTLS_SSL_PEND_FATAL_ALERT(
+            MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
+            MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    }
+
+    ssl->handshake->received_extensions |= extension_mask;
+
+    /* Only extension responses are checked against the extensions we sent;
+     * for any other message type the extension is accepted here. */
+    switch( hs_msg_type )
+    {
+        case -MBEDTLS_SSL_HS_SERVER_HELLO: /* HelloRetryRequest */
+            /* The cookie is the only extension a HelloRetryRequest may carry
+             * without a matching request in the ClientHello. */
+            if( extension_type == MBEDTLS_TLS_EXT_COOKIE )
+                return( 0 );
+            /* Fall through */
+        case MBEDTLS_SSL_HS_SERVER_HELLO:
+        case MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS:
+        case MBEDTLS_SSL_HS_CERTIFICATE:
+            if( ( ~ssl->handshake->sent_extensions & extension_mask ) == 0 )
+                return( 0 );
+            break;
+        default:
+            return( 0 );
+    }
+
+    MBEDTLS_SSL_DEBUG_MSG(
+            3, ( "%s : unsupported extension received.", hs_msg_name ) );
+    MBEDTLS_SSL_PEND_FATAL_ALERT(
+        MBEDTLS_SSL_ALERT_MSG_UNSUPPORTED_EXT,
+        MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+}
+
 #endif /* MBEDTLS_SSL_TLS_C && MBEDTLS_SSL_PROTO_TLS1_3 */
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 32f64d7..4fdd6ad 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -930,7 +930,7 @@
 static int ssl_tls13_client_hello_has_exts( mbedtls_ssl_context *ssl,
                                             int exts_mask )
 {
-    int masked = ssl->handshake->extensions_present & exts_mask;
+    int masked = ssl->handshake->received_extensions & exts_mask;
     return( masked == exts_mask );
 }
 
@@ -1239,7 +1239,6 @@
     const unsigned char *cipher_suites_end;
     size_t extensions_len;
     const unsigned char *extensions_end;
-    uint32_t extensions_present;
     int hrr_required = 0;
 
 #if defined(MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_SOME_PSK_ENABLED)
@@ -1248,8 +1247,6 @@
     const unsigned char *pre_shared_key_ext_end = NULL;
 #endif
 
-    extensions_present = MBEDTLS_SSL_EXT_NONE;
-
     /*
      * ClientHello layout:
      *     0  .   1   protocol version
@@ -1419,20 +1416,23 @@
 
     MBEDTLS_SSL_DEBUG_BUF( 3, "client hello extensions", p, extensions_len );
 
+    ssl->handshake->received_extensions = MBEDTLS_SSL_EXT_NONE;
+
     while( p < extensions_end )
     {
         unsigned int extension_type;
         size_t extension_data_len;
         const unsigned char *extension_data_end;
 
-        /* RFC 8446, page 57
+        /* RFC 8446, section 4.2.11
          *
          * The "pre_shared_key" extension MUST be the last extension in the
          * ClientHello (this facilitates implementation as described below).
          * Servers MUST check that it is the last extension and otherwise fail
          * the handshake with an "illegal_parameter" alert.
          */
-        if( extensions_present & MBEDTLS_SSL_EXT_PRE_SHARED_KEY )
+        if( ssl->handshake->received_extensions &
+            mbedtls_tls13_get_extension_mask( MBEDTLS_TLS_EXT_PRE_SHARED_KEY ) )
         {
             MBEDTLS_SSL_DEBUG_MSG(
                 3, ( "pre_shared_key is not last extension." ) );
@@ -1450,26 +1450,11 @@
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, extension_data_len );
         extension_data_end = p + extension_data_len;
 
-        /* RFC 8446 page 35
-         *
-         * If an implementation receives an extension which it recognizes and which
-         * is not specified for the message in which it appears, it MUST abort the
-         * handshake with an "illegal_parameter" alert.
-         */
-        extensions_present |= mbedtls_tls13_get_extension_mask( extension_type );
-        MBEDTLS_SSL_DEBUG_MSG( 3,
-                    ( "client hello : received %s(%u) extension",
-                      mbedtls_tls13_get_extension_name( extension_type ),
-                      extension_type ) );
-        if( ( extensions_present & MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_CH ) == 0 )
-        {
-            MBEDTLS_SSL_DEBUG_MSG(
-                3, ( "forbidden extension received." ) );
-            MBEDTLS_SSL_PEND_FATAL_ALERT(
-                MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
-                MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-        }
+        ret = mbedtls_tls13_check_received_extensions(
+                  ssl, MBEDTLS_SSL_HS_CLIENT_HELLO, extension_type,
+                  MBEDTLS_SSL_TLS1_3_ALLOWED_EXTS_OF_CH );
+        if( ret != 0 )
+            return( ret );
 
         switch( extension_type )
         {
@@ -1569,7 +1554,7 @@
 
             case MBEDTLS_TLS_EXT_PRE_SHARED_KEY:
                 MBEDTLS_SSL_DEBUG_MSG( 3, ( "found pre_shared_key extension" ) );
-                if( ( extensions_present &
+                if( ( ssl->handshake->received_extensions &
                       MBEDTLS_SSL_EXT_PSK_KEY_EXCHANGE_MODES ) == 0 )
                 {
                     MBEDTLS_SSL_PEND_FATAL_ALERT(
@@ -1622,26 +1607,14 @@
                     ( "client hello: received %s(%u) extension ( ignored )",
                       mbedtls_tls13_get_extension_name( extension_type ),
                       extension_type ) );
+                break;
         }
 
         p += extension_data_len;
     }
 
-    MBEDTLS_SSL_TLS1_3_PRINT_EXTS( 3, "ClientHello", extensions_present );
-
-    /* RFC 8446 page 102
-     * -  "supported_versions" is REQUIRED for all ClientHello, ServerHello, and
-     *    HelloRetryRequest messages.
-     */
-    if( ( extensions_present & MBEDTLS_SSL_EXT_SUPPORTED_VERSIONS ) == 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1,
-                    ( "client hello: supported_versions not found" ) );
-        MBEDTLS_SSL_PEND_FATAL_ALERT(
-                MBEDTLS_SSL_ALERT_MSG_MISSING_EXTENSION,
-                MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-    }
+    MBEDTLS_SSL_TLS1_3_PRINT_EXTS(
+        3, MBEDTLS_SSL_HS_CLIENT_HELLO, ssl->handshake->received_extensions );
 
     mbedtls_ssl_add_hs_hdr_to_checksum( ssl,
                                         MBEDTLS_SSL_HS_CLIENT_HELLO,
@@ -1655,7 +1628,8 @@
     /* If we've settled on a PSK-based exchange, parse PSK identity ext */
     if( mbedtls_ssl_tls13_some_psk_enabled( ssl ) &&
         mbedtls_ssl_conf_tls13_some_psk_enabled( ssl ) &&
-        ( ssl->handshake->extensions_present & MBEDTLS_SSL_EXT_PRE_SHARED_KEY ) )
+        ( ssl->handshake->received_extensions &
+          MBEDTLS_SSL_EXT_PRE_SHARED_KEY ) )
     {
         ssl->handshake->update_checksum( ssl, buf,
                                          pre_shared_key_ext - buf );
@@ -1666,7 +1640,8 @@
                                                   cipher_suites_end );
         if( ret == MBEDTLS_ERR_SSL_UNKNOWN_IDENTITY )
         {
-            extensions_present &= ~MBEDTLS_SSL_EXT_PRE_SHARED_KEY;
+            ssl->handshake->received_extensions &=
+                                ~MBEDTLS_SSL_EXT_PRE_SHARED_KEY;
         }
         else if( ret != 0 )
         {
@@ -1681,7 +1656,6 @@
         ssl->handshake->update_checksum( ssl, buf, p - buf );
     }
 
-    ssl->handshake->extensions_present = extensions_present;
     ret = ssl_tls13_determine_key_exchange_mode( ssl );
     if( ret < 0 )
         return( ret );