Merge pull request #5908 from ronald-cron-arm/tls13-fixes-doc
TLS 1.3: Fixes and add documentation
Validated by the internal CI, no need to wait for the Open CI.
diff --git a/docs/architecture/tls13-support.md b/docs/architecture/tls13-support.md
index 2cf2a48..6c39bc5 100644
--- a/docs/architecture/tls13-support.md
+++ b/docs/architecture/tls13-support.md
@@ -409,3 +409,101 @@
                                      buf_len );
     ```
     even if it fits.
+
+
+Overview of handshake code organization
+---------------------------------------
+
+The TLS 1.3 handshake protocol is implemented as a state machine. The
+functions `mbedtls_ssl_tls13_handshake_{client,server}_step` are the top-level
+functions of that implementation. They are implemented as a switch over all the
+possible states of the state machine.
+
+Most of the states are dedicated either to the processing or to the writing of
+a handshake message.
+
+The implementation does not systematically go through all states, as this
+would require duplicating, across several state handlers, the checks of
+whether something needs to be done in a given state. For example, on the
+client side, the states related to certificate parsing and validation are
+bypassed if the handshake is based on a pre-shared key and thus does not
+involve certificates.
+
+Conversely, the implementation systematically goes through some states even
+though they could be bypassed, when doing so helps minimize when and where
+inbound and outbound keys are updated. The `MBEDTLS_SSL_CLIENT_CERTIFICATE`
+state on the client side is an example of that.
+
+The names of the handlers processing/writing a handshake message are prefixed
+with `(mbedtls_)ssl_tls13_{process,write}`. To ease maintenance and reduce the
+risk of bugs, the code of the message processing and writing handlers is split
+into a sequence of stages.
+
+Data is sent to the peer only in `mbedtls_ssl_handshake_step`, between the
+calls to the handlers; as a consequence, handlers do not have to care about the
+`MBEDTLS_ERR_SSL_WANT_WRITE` error code. Furthermore, all pending data is
+flushed before the next handler is called. That way, handlers do not have to
+worry about pending data when changing outbound keys.
+
+### Message processing handlers
+For message processing handlers, the stages are:
+
+* coordination stage: check if the state should be bypassed. This stage is
+optional. The check is based either purely on the values of some fields of
+the SSL context or on the type of the next message. The latter occurs when it
+is not known what the next handshake message will be: for example, on the
+client side, whether a CertificateRequest message will be received or not.
+Apart from reading the next record, this stage should not modify the SSL
+context, as it may be repeated if the next handshake message has not been
+received yet.
+
+* fetching stage: at this stage we are sure of the type of the handshake
+message we must receive next, and we try to fetch it. If we did not go through
+a coordination stage that read the type of the next message, the next handshake
+message may not have been received yet. In that case, the handler returns
+`MBEDTLS_ERR_SSL_WANT_READ` without changing the current state, and it will be
+called again later.
+
+* pre-processing stage: prepare the SSL context for the message parsing. This
+stage is optional. It covers any processing that must be done before parsing
+the message, or that can be done to simplify the parsing code. Some simple,
+partial parsing of the handshake message may happen at this stage, as in the
+ServerHello message pre-processing.
+
+* parsing stage: parse the message, restricting updates of the SSL context as
+much as possible. The point of the pre-processing/parsing/post-processing
+organization is to keep the parsing function focused solely on parsing, so
+that its code stays small and simple.
+
+* post-processing stage: following the parsing, further update the SSL
+context to prepare for the next incoming and outgoing messages. This stage is
+optional. For example, secret and key computations occur at this stage, as
+does the update of the handshake messages checksum.
+
+* state change: the state change is done in the main state handler rather
+than in its stages, so that the state machine transitions are easy to follow.
+
+
+### Message writing handlers
+For message writing handlers, the stages are:
+
+* coordination stage: check if the state should be bypassed. This stage is
+optional. The check is based on the value of some fields of the SSL context.
+
+* preparation stage: prepare for the message writing. This stage is optional.
+It covers any processing that must be done before writing the message, or that
+can be done to simplify the writing code.
+
+* writing stage: write the message, restricting updates of the SSL context as
+much as possible. The point of the preparation/writing/finalization
+organization is to keep the writing function focused solely on writing, so
+that its code stays small and simple.
+
+* finalization stage: following the writing, further update the SSL context
+to prepare for the next incoming and outgoing messages. This stage is
+optional. For example, the handshake secret and key computations occur at
+this stage (ServerHello writing finalization), as does the switch to
+handshake keys for outbound messages on the server side.
+
+* state change: the state change is done in the main state handler rather
+than in its stages, so that the state machine transitions are easy to follow.
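The processing stages described above can be illustrated with a small, self-contained toy that mirrors the shape of the refactored `ssl_tls13_process_server_hello()`: fetch, pre-process (classify ServerHello vs HelloRetryRequest), parse, post-process, then change state in the handler itself. This is a sketch under stated assumptions, not Mbed TLS code: `toy_ssl`, the `toy_*` functions, the error codes and the two-byte message format are all hypothetical, and the optional coordination stage is omitted because a ServerHello is always expected at that point.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical toy definitions; they mimic the stage layout of
 * ssl_tls13_process_server_hello() but are not Mbed TLS definitions. */
#define TOY_ERR_WANT_READ     (-1)
#define TOY_ERR_DECODE        (-2)
#define TOY_SERVER_HELLO      0   /* pre-processing result: ServerHello       */
#define TOY_SERVER_HELLO_HRR  1   /* pre-processing result: HelloRetryRequest */

enum toy_state { TOY_STATE_SERVER_HELLO, TOY_STATE_CLIENT_HELLO,
                 TOY_STATE_ENCRYPTED_EXTENSIONS };

typedef struct {
    enum toy_state state;
    const unsigned char *msg;    /* next handshake message, NULL if not received */
    size_t msg_len;
    unsigned char cipher_suite;  /* only updated by the post-processing stage */
    int hrr_count;
} toy_ssl;

/* Fetching stage: without the message, report WANT_READ and change nothing. */
static int toy_fetch( toy_ssl *ssl, const unsigned char **buf, size_t *len )
{
    if( ssl->msg == NULL )
        return( TOY_ERR_WANT_READ );
    *buf = ssl->msg;
    *len = ssl->msg_len;
    return( 0 );
}

/* Pre-processing stage: a minimal partial parse that classifies the message.
 * In this toy format, a first byte of 0xFF marks a HelloRetryRequest. */
static int toy_preprocess( const unsigned char *buf, const unsigned char *end )
{
    if( end - buf < 1 )
        return( TOY_ERR_DECODE );
    return( buf[0] == 0xFF ? TOY_SERVER_HELLO_HRR : TOY_SERVER_HELLO );
}

/* Parsing stage: reads the two-byte toy message [type, suite] and writes its
 * result to an output parameter instead of the context. */
static int toy_parse( const unsigned char *buf, const unsigned char *end,
                      unsigned char *suite )
{
    if( end - buf < 2 )
        return( TOY_ERR_DECODE );
    *suite = buf[1];
    return( 0 );
}

/* Post-processing stage: the context updates that follow the parsing. */
static void toy_postprocess( toy_ssl *ssl, unsigned char suite, int is_hrr )
{
    ssl->cipher_suite = suite;
    if( is_hrr )
        ssl->hrr_count++;
}

/* State handler: runs the stages in order and performs the state change. */
static int toy_process_server_hello( toy_ssl *ssl )
{
    const unsigned char *buf;
    size_t len;
    unsigned char suite;
    int ret, is_hrr;

    assert( ssl != NULL );
    if( ( ret = toy_fetch( ssl, &buf, &len ) ) != 0 )
        return( ret );               /* called again later, same state */
    if( ( ret = toy_preprocess( buf, buf + len ) ) < 0 )
        return( ret );
    is_hrr = ( ret == TOY_SERVER_HELLO_HRR );
    if( ( ret = toy_parse( buf, buf + len, &suite ) ) != 0 )
        return( ret );
    toy_postprocess( ssl, suite, is_hrr );

    ssl->state = is_hrr ? TOY_STATE_CLIENT_HELLO
                        : TOY_STATE_ENCRYPTED_EXTENSIONS;
    return( 0 );
}
```

A real handler also adds the message to the handshake checksum after parsing, as `ssl_tls13_process_encrypted_extensions()` does with `mbedtls_ssl_add_hs_msg_to_checksum()`.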
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index ead0db8..416316b 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -121,11 +121,12 @@
 
 #if defined(MBEDTLS_SSL_ALPN)
 static int ssl_tls13_parse_alpn_ext( mbedtls_ssl_context *ssl,
-                               const unsigned char *buf, size_t len )
+                                     const unsigned char *buf, size_t len )
 {
-    size_t list_len, name_len;
     const unsigned char *p = buf;
     const unsigned char *end = buf + len;
+    size_t protocol_name_list_len, protocol_name_len;
+    const unsigned char *protocol_name_list_end;
 
     /* If we didn't send it, the server shouldn't send it */
     if( ssl->conf->alpn_list == NULL )
@@ -141,21 +142,22 @@
      * the "ProtocolNameList" MUST contain exactly one "ProtocolName"
      */
 
-    /* Min length is 2 ( list_len ) + 1 ( name_len ) + 1 ( name ) */
-    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 4 );
-
-    list_len = MBEDTLS_GET_UINT16_BE( p, 0 );
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 2 );
+    protocol_name_list_len = MBEDTLS_GET_UINT16_BE( p, 0 );
     p += 2;
-    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, list_len );
 
-    name_len = *p++;
-    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, list_len - 1 );
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, protocol_name_list_len );
+    protocol_name_list_end = p + protocol_name_list_len;
+
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, protocol_name_list_end, 1 );
+    protocol_name_len = *p++;
 
     /* Check that the server chosen protocol was in our list and save it */
-    for ( const char **alpn = ssl->conf->alpn_list; *alpn != NULL; alpn++ )
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, protocol_name_list_end, protocol_name_len );
+    for( const char **alpn = ssl->conf->alpn_list; *alpn != NULL; alpn++ )
     {
-        if( name_len == strlen( *alpn ) &&
-            memcmp( buf + 3, *alpn, name_len ) == 0 )
+        if( protocol_name_len == strlen( *alpn ) &&
+            memcmp( p, *alpn, protocol_name_len ) == 0 )
         {
             ssl->alpn_chosen = *alpn;
             return( 0 );
@@ -667,6 +669,7 @@
      * - cipher_suite               2 bytes
      * - legacy_compression_method  1 byte
      */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, legacy_session_id_echo_len + 4 );
      p += legacy_session_id_echo_len + 4;
 
     /* Case of no extension */
@@ -740,12 +743,12 @@
 }
 
 /* Returns a negative value on failure, and otherwise
- * - SSL_SERVER_HELLO_COORDINATE_HELLO or
- * - SSL_SERVER_HELLO_COORDINATE_HRR
+ * - SSL_SERVER_HELLO or
+ * - SSL_SERVER_HELLO_HRR
  * to indicate which message is expected and to be parsed next.
  */
-#define SSL_SERVER_HELLO_COORDINATE_HELLO 0
-#define SSL_SERVER_HELLO_COORDINATE_HRR 1
+#define SSL_SERVER_HELLO 0
+#define SSL_SERVER_HELLO_HRR 1
 static int ssl_server_hello_is_hrr( mbedtls_ssl_context *ssl,
                                     const unsigned char *buf,
                                     const unsigned char *end )
@@ -772,37 +775,32 @@
     if( memcmp( buf + 2, mbedtls_ssl_tls13_hello_retry_request_magic,
                 sizeof( mbedtls_ssl_tls13_hello_retry_request_magic ) ) == 0 )
     {
-        return( SSL_SERVER_HELLO_COORDINATE_HRR );
+        return( SSL_SERVER_HELLO_HRR );
     }
 
-    return( SSL_SERVER_HELLO_COORDINATE_HELLO );
+    return( SSL_SERVER_HELLO );
 }
 
-/* Fetch and preprocess
+/*
  * Returns a negative value on failure, and otherwise
- * - SSL_SERVER_HELLO_COORDINATE_HELLO or
- * - SSL_SERVER_HELLO_COORDINATE_HRR or
- * - SSL_SERVER_HELLO_COORDINATE_TLS1_2
+ * - SSL_SERVER_HELLO or
+ * - SSL_SERVER_HELLO_HRR or
+ * - SSL_SERVER_HELLO_TLS1_2
  */
-#define SSL_SERVER_HELLO_COORDINATE_TLS1_2 2
-static int ssl_tls13_server_hello_coordinate( mbedtls_ssl_context *ssl,
-                                              unsigned char **buf,
-                                              size_t *buf_len )
+#define SSL_SERVER_HELLO_TLS1_2 2
+static int ssl_tls13_preprocess_server_hello( mbedtls_ssl_context *ssl,
+                                              const unsigned char *buf,
+                                              const unsigned char *end )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    const unsigned char *end;
-
-    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_fetch_handshake_msg( ssl,
-                                             MBEDTLS_SSL_HS_SERVER_HELLO,
-                                             buf, buf_len ) );
-    end = *buf + *buf_len;
+    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
 
     MBEDTLS_SSL_PROC_CHK_NEG( ssl_tls13_is_supported_versions_ext_present(
-                                  ssl, *buf, end ) );
+                                  ssl, buf, end ) );
     if( ret == 0 )
     {
         MBEDTLS_SSL_PROC_CHK_NEG(
-            ssl_tls13_is_downgrade_negotiation( ssl, *buf, end ) );
+            ssl_tls13_is_downgrade_negotiation( ssl, buf, end ) );
 
         /* If the server is negotiating TLS 1.2 or below and:
          * . we did not propose TLS 1.2 or
@@ -810,7 +808,7 @@
          *   version of the protocol and thus we are under downgrade attack
          * abort the handshake with an "illegal parameter" alert.
          */
-        if( ssl->handshake->min_tls_version > MBEDTLS_SSL_VERSION_TLS1_2 || ret )
+        if( handshake->min_tls_version > MBEDTLS_SSL_VERSION_TLS1_2 || ret )
         {
             MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
                                           MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
@@ -820,7 +818,7 @@
         ssl->keep_current_message = 1;
         ssl->tls_version = MBEDTLS_SSL_VERSION_TLS1_2;
         mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_SERVER_HELLO,
-                                            *buf, *buf_len );
+                                            buf, (size_t)(end - buf) );
 
         if( mbedtls_ssl_conf_tls13_some_ephemeral_enabled( ssl ) )
         {
@@ -829,23 +827,25 @@
                 return( ret );
         }
 
-        return( SSL_SERVER_HELLO_COORDINATE_TLS1_2 );
+        return( SSL_SERVER_HELLO_TLS1_2 );
     }
 
-    ret = ssl_server_hello_is_hrr( ssl, *buf, end );
+    handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
+
+    ret = ssl_server_hello_is_hrr( ssl, buf, end );
     switch( ret )
     {
-        case SSL_SERVER_HELLO_COORDINATE_HELLO:
+        case SSL_SERVER_HELLO:
             MBEDTLS_SSL_DEBUG_MSG( 2, ( "received ServerHello message" ) );
             break;
-        case SSL_SERVER_HELLO_COORDINATE_HRR:
+        case SSL_SERVER_HELLO_HRR:
             MBEDTLS_SSL_DEBUG_MSG( 2, ( "received HelloRetryRequest message" ) );
              /* If a client receives a second
               * HelloRetryRequest in the same connection (i.e., where the ClientHello
               * was itself in response to a HelloRetryRequest), it MUST abort the
               * handshake with an "unexpected_message" alert.
               */
-            if( ssl->handshake->hello_retry_request_count > 0 )
+            if( handshake->hello_retry_request_count > 0 )
             {
                 MBEDTLS_SSL_DEBUG_MSG( 1, ( "Multiple HRRs received" ) );
                 MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_UNEXPECTED_MESSAGE,
@@ -868,7 +868,7 @@
                 return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
             }
 
-            ssl->handshake->hello_retry_request_count++;
+            handshake->hello_retry_request_count++;
 
             break;
     }
@@ -1247,11 +1247,6 @@
     MBEDTLS_SSL_DEBUG_MSG( 1, ( "Switch to handshake keys for inbound traffic" ) );
     ssl->session_in = ssl->session_negotiate;
 
-    /*
-     * State machine update
-     */
-    mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_ENCRYPTED_EXTENSIONS );
-
 cleanup:
     if( ret != 0 )
     {
@@ -1267,17 +1262,6 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-#if defined(MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE)
-    /* If not offering early data, the client sends a dummy CCS record
-     * immediately before its second flight. This may either be before
-     * its second ClientHello or before its encrypted handshake flight.
-     */
-    mbedtls_ssl_handshake_set_state( ssl,
-            MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO );
-#else
-    mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_HELLO );
-#endif /* MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE */
-
     mbedtls_ssl_session_reset_msg_layer( ssl, 0 );
 
     /*
@@ -1306,20 +1290,17 @@
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> %s", __func__ ) );
 
-    /* Coordination step
-     * - Fetch record
-     * - Make sure it's either a ServerHello or a HRR.
-     * - Switch processing routine in case of HRR
-     */
-    ssl->handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
+    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_fetch_handshake_msg( ssl,
+                                             MBEDTLS_SSL_HS_SERVER_HELLO,
+                                             &buf, &buf_len ) );
 
-    ret = ssl_tls13_server_hello_coordinate( ssl, &buf, &buf_len );
+    ret = ssl_tls13_preprocess_server_hello( ssl, buf, buf + buf_len );
     if( ret < 0 )
         goto cleanup;
     else
-        is_hrr = ( ret == SSL_SERVER_HELLO_COORDINATE_HRR );
+        is_hrr = ( ret == SSL_SERVER_HELLO_HRR );
 
-    if( ret == SSL_SERVER_HELLO_COORDINATE_TLS1_2 )
+    if( ret == SSL_SERVER_HELLO_TLS1_2 )
     {
         ret = 0;
         goto cleanup;
@@ -1335,9 +1316,24 @@
                                         buf, buf_len );
 
     if( is_hrr )
+    {
         MBEDTLS_SSL_PROC_CHK( ssl_tls13_postprocess_hrr( ssl ) );
+#if defined(MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE)
+        /* If not offering early data, the client sends a dummy CCS record
+         * immediately before its second flight. This may either be before
+         * its second ClientHello or before its encrypted handshake flight.
+         */
+        mbedtls_ssl_handshake_set_state( ssl,
+            MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO );
+#else
+        mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_HELLO );
+#endif /* MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE */
+    }
     else
+    {
         MBEDTLS_SSL_PROC_CHK( ssl_tls13_postprocess_server_hello( ssl ) );
+        mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_ENCRYPTED_EXTENSIONS );
+    }
 
 cleanup:
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= %s ( %s )", __func__,
@@ -1347,56 +1343,13 @@
 
 /*
  *
- * EncryptedExtensions message
+ * Handler for MBEDTLS_SSL_ENCRYPTED_EXTENSIONS
  *
  * The EncryptedExtensions message contains any extensions which
  * should be protected, i.e., any which are not needed to establish
  * the cryptographic context.
  */
 
-/*
- * Overview
- */
-
-/* Main entry point; orchestrates the other functions */
-static int ssl_tls13_process_encrypted_extensions( mbedtls_ssl_context *ssl );
-
-static int ssl_tls13_parse_encrypted_extensions( mbedtls_ssl_context *ssl,
-                                                 const unsigned char *buf,
-                                                 const unsigned char *end );
-static int ssl_tls13_postprocess_encrypted_extensions( mbedtls_ssl_context *ssl );
-
-/*
- * Handler for  MBEDTLS_SSL_ENCRYPTED_EXTENSIONS
- */
-static int ssl_tls13_process_encrypted_extensions( mbedtls_ssl_context *ssl )
-{
-    int ret;
-    unsigned char *buf;
-    size_t buf_len;
-
-    MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> parse encrypted extensions" ) );
-
-    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_fetch_handshake_msg( ssl,
-                                             MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS,
-                                             &buf, &buf_len ) );
-
-    /* Process the message contents */
-    MBEDTLS_SSL_PROC_CHK(
-        ssl_tls13_parse_encrypted_extensions( ssl, buf, buf + buf_len ) );
-
-    mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS,
-                                        buf, buf_len );
-
-    MBEDTLS_SSL_PROC_CHK( ssl_tls13_postprocess_encrypted_extensions( ssl ) );
-
-cleanup:
-
-    MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= parse encrypted extensions" ) );
-    return( ret );
-
-}
-
 /* Parse EncryptedExtensions message
  * struct {
  *     Extension extensions<0..2^16-1>;
@@ -1416,8 +1369,8 @@
     p += 2;
 
     MBEDTLS_SSL_DEBUG_BUF( 3, "encrypted extensions", p, extensions_len );
-    extensions_end = p + extensions_len;
     MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, extensions_len );
+    extensions_end = p + extensions_len;
 
     while( p < extensions_end )
     {
@@ -1483,8 +1436,25 @@
     return( ret );
 }
 
-static int ssl_tls13_postprocess_encrypted_extensions( mbedtls_ssl_context *ssl )
+static int ssl_tls13_process_encrypted_extensions( mbedtls_ssl_context *ssl )
 {
+    int ret;
+    unsigned char *buf;
+    size_t buf_len;
+
+    MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> parse encrypted extensions" ) );
+
+    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_fetch_handshake_msg( ssl,
+                                             MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS,
+                                             &buf, &buf_len ) );
+
+    /* Process the message contents */
+    MBEDTLS_SSL_PROC_CHK(
+        ssl_tls13_parse_encrypted_extensions( ssl, buf, buf + buf_len ) );
+
+    mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_ENCRYPTED_EXTENSIONS,
+                                        buf, buf_len );
+
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
     if( mbedtls_ssl_tls13_some_psk_enabled( ssl ) )
         mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_SERVER_FINISHED );
@@ -1494,12 +1464,16 @@
     ((void) ssl);
     mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_SERVER_FINISHED );
 #endif
-    return( 0 );
+
+cleanup:
+
+    MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= parse encrypted extensions" ) );
+    return( ret );
+
 }
 
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
 /*
- *
  * STATE HANDLING: CertificateRequest
  *
  */
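The bounds-checking pattern of the fixed `ssl_tls13_parse_alpn_ext()` can be sketched in isolation: the end of the ProtocolNameList is computed only after the list length has been checked against the buffer, and both the name length byte and the name itself are then checked against that list end, with the comparison done at `p` rather than at a fixed offset. `toy_parse_alpn()` and `TOY_CHK_BUF_READ_PTR` are hypothetical stand-ins (the latter for `MBEDTLS_SSL_CHK_BUF_READ_PTR`, without its alert handling), and the sketch returns the index of the matching protocol instead of setting `ssl->alpn_chosen`.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

#define TOY_ERR_DECODE    (-1)
#define TOY_ERR_NO_MATCH  (-2)

/* Hypothetical stand-in for MBEDTLS_SSL_CHK_BUF_READ_PTR: return a decode
 * error unless at least `need` bytes are available in [cur, end). */
#define TOY_CHK_BUF_READ_PTR( cur, end, need )                           \
    do {                                                                 \
        if( ( cur ) > ( end ) ||                                         \
            (size_t) ( need ) > (size_t) ( ( end ) - ( cur ) ) )         \
            return( TOY_ERR_DECODE );                                    \
    } while( 0 )

/* Parses the ALPN extension_data of a server reply (RFC 7301):
 *   opaque ProtocolName<1..2^8-1>;
 *   struct { ProtocolName protocol_name_list<2..2^16-1> } ProtocolNameList;
 * Returns the index of the matching entry of the NULL-terminated
 * `alpn_list`, or a negative error code. */
static int toy_parse_alpn( const unsigned char *buf, size_t len,
                           const char **alpn_list )
{
    const unsigned char *p = buf;
    const unsigned char *end = buf + len;
    const unsigned char *protocol_name_list_end;
    size_t protocol_name_list_len, protocol_name_len;

    TOY_CHK_BUF_READ_PTR( p, end, 2 );
    protocol_name_list_len = ( (size_t) p[0] << 8 ) | p[1];
    p += 2;

    /* Compute the list end only once the list is known to fit in the buffer. */
    TOY_CHK_BUF_READ_PTR( p, end, protocol_name_list_len );
    protocol_name_list_end = p + protocol_name_list_len;

    /* From here on, check against the list end rather than the buffer end. */
    TOY_CHK_BUF_READ_PTR( p, protocol_name_list_end, 1 );
    protocol_name_len = *p++;
    TOY_CHK_BUF_READ_PTR( p, protocol_name_list_end, protocol_name_len );

    for( int i = 0; alpn_list[i] != NULL; i++ )
    {
        if( protocol_name_len == strlen( alpn_list[i] ) &&
            memcmp( p, alpn_list[i], protocol_name_len ) == 0 )
            return( i );
    }
    return( TOY_ERR_NO_MATCH );
}
```

In the lines removed by this change, the name was checked against `list_len - 1` rather than against its own length, and compared at `buf + 3`; the sketch checks `protocol_name_len` against the list end before any `memcmp()`.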
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index acd227d..893de43 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -812,7 +812,7 @@
         /* Currently, we don't have any certificate extensions defined.
          * Hence, we are sending an empty extension with length zero.
          */
-        MBEDTLS_PUT_UINT24_BE( 0, p, 0 );
+        MBEDTLS_PUT_UINT16_BE( 0, p, 0 );
         p += 2;
     }
 
@@ -1437,12 +1437,12 @@
     mbedtls_ssl_handshake_params *handshake = ssl->handshake;
 
     /* Get size of the TLS opaque key_exchange field of the KeyShareEntry struct. */
-    MBEDTLS_SSL_CHK_BUF_PTR( p, end, 2 );
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 2 );
     uint16_t peerkey_len = MBEDTLS_GET_UINT16_BE( p, 0 );
     p += 2;
 
     /* Check if key size is consistent with given buffer length. */
-    MBEDTLS_SSL_CHK_BUF_PTR( p, end, peerkey_len );
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, peerkey_len );
 
     /* Store peer's ECDH public key. */
     memcpy( handshake->ecdh_psa_peerkey, p, peerkey_len );
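Two of the fixes in `ssl_tls13_generic.c` concern field widths and error reporting. The `extensions` field of a TLS 1.3 CertificateEntry is `Extension extensions<0..2^16-1>`, so its length prefix is 2 bytes: that is what `MBEDTLS_PUT_UINT16_BE` writes, matching the following `p += 2`, whereas `MBEDTLS_PUT_UINT24_BE` wrote 3 bytes. As far as we understand the macros in `ssl_misc.h`, `MBEDTLS_SSL_CHK_BUF_PTR` reports `MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL` and suits output buffers, while `MBEDTLS_SSL_CHK_BUF_READ_PTR` reports a decode error and suits peer input such as a KeyShareEntry. The following sketch illustrates that split with hypothetical `toy_*` names and error codes; the destination-capacity check in `toy_read_key_exchange()` is an extra safeguard of the sketch, not a description of the library code.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Hypothetical error codes mirroring the read/write split. */
#define TOY_ERR_DECODE            (-1)  /* peer input is malformed       */
#define TOY_ERR_BUFFER_TOO_SMALL  (-2)  /* our own output buffer is full */

/* Returns nonzero if fewer than `need` bytes are available in [cur, end). */
static int toy_chk_buf_ptr( const unsigned char *cur, const unsigned char *end,
                            size_t need )
{
    return( cur > end || need > (size_t) ( end - cur ) );
}

/* Write side: an empty extensions<0..2^16-1> field is a 2-byte zero length. */
static int toy_write_empty_extensions( unsigned char *p, const unsigned char *end,
                                       size_t *olen )
{
    if( toy_chk_buf_ptr( p, end, 2 ) )
        return( TOY_ERR_BUFFER_TOO_SMALL );
    p[0] = 0;
    p[1] = 0;
    *olen = 2;
    return( 0 );
}

/* Read side: opaque key_exchange<1..2^16-1>, i.e. a 2-byte length then data. */
static int toy_read_key_exchange( const unsigned char *p, const unsigned char *end,
                                  unsigned char *dst, size_t dst_size,
                                  size_t *key_len )
{
    size_t len;

    if( toy_chk_buf_ptr( p, end, 2 ) )
        return( TOY_ERR_DECODE );
    len = ( (size_t) p[0] << 8 ) | p[1];
    p += 2;
    if( toy_chk_buf_ptr( p, end, len ) )
        return( TOY_ERR_DECODE );
    /* Extra safeguard in this sketch: never copy more than dst can hold.
     * An oversized key is treated as malformed input here. */
    if( len > dst_size )
        return( TOY_ERR_DECODE );
    memcpy( dst, p, len );
    *key_len = len;
    return( 0 );
}
```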
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 719bf05..bfb2204 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -1095,7 +1095,7 @@
      * - extension_data_length  (2 bytes)
      * - selected_group         (2 bytes)
      */
-    MBEDTLS_SSL_CHK_BUF_READ_PTR( buf, end, 6 );
+    MBEDTLS_SSL_CHK_BUF_PTR( buf, end, 6 );
 
     MBEDTLS_PUT_UINT16_BE( MBEDTLS_TLS_EXT_KEY_SHARE, buf, 0 );
     MBEDTLS_PUT_UINT16_BE( 2, buf, 2 );
@@ -1307,8 +1307,7 @@
 /*
  * Handler for MBEDTLS_SSL_HELLO_RETRY_REQUEST
  */
-static int ssl_tls13_write_hello_retry_request_coordinate(
-                                                    mbedtls_ssl_context *ssl )
+static int ssl_tls13_prepare_hello_retry_request( mbedtls_ssl_context *ssl )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     if( ssl->handshake->hello_retry_request_count > 0 )
@@ -1342,7 +1341,7 @@
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write hello retry request" ) );
 
-    MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_hello_retry_request_coordinate( ssl ) );
+    MBEDTLS_SSL_PROC_CHK( ssl_tls13_prepare_hello_retry_request( ssl ) );
 
     MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_start_handshake_msg(
                               ssl, MBEDTLS_SSL_HS_SERVER_HELLO,
@@ -1636,19 +1635,18 @@
         return( ret );
     }
 
-    if( ssl->handshake->certificate_request_sent )
-    {
-        mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_CERTIFICATE );
+    MBEDTLS_SSL_DEBUG_MSG( 1, ( "Switch to handshake keys for inbound traffic" ) );
+    mbedtls_ssl_set_inbound_transform( ssl, ssl->handshake->transform_handshake );
 
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Switch to handshake keys for inbound traffic" ) );
-        mbedtls_ssl_set_inbound_transform( ssl, ssl->handshake->transform_handshake );
-    }
+    if( ssl->handshake->certificate_request_sent )
+        mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_CERTIFICATE );
     else
     {
         MBEDTLS_SSL_DEBUG_MSG( 2, ( "skip parse certificate" ) );
         MBEDTLS_SSL_DEBUG_MSG( 2, ( "skip parse certificate verify" ) );
         mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_FINISHED );
     }
+
     return( 0 );
 }
 
@@ -1659,12 +1657,6 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-    if( ! ssl->handshake->certificate_request_sent )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1,
-            ( "Switch to handshake traffic keys for inbound traffic" ) );
-        mbedtls_ssl_set_inbound_transform( ssl, ssl->handshake->transform_handshake );
-    }
     ret = mbedtls_ssl_tls13_process_finished_message( ssl );
     if( ret != 0 )
         return( ret );
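The writing-handler stages can be sketched in the same toy style, loosely following the HelloRetryRequest handler whose preparation function this change renames from `ssl_tls13_write_hello_retry_request_coordinate()` to `ssl_tls13_prepare_hello_retry_request()`. The preparation stage refuses a second HelloRetryRequest, mirroring the `hello_retry_request_count > 0` check visible at the top of that function; the writing stage only serializes; the finalization stage updates the context; and the state change happens in the handler. The `toy_*` names, error codes, next state and two-byte message are hypothetical, and updating `hello_retry_request_count` during finalization is an assumption of the sketch.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical toy definitions, not Mbed TLS ones. */
#define TOY_ERR_BUFFER_TOO_SMALL  (-1)
#define TOY_ERR_BAD_STATE         (-2)

enum toy_state { TOY_STATE_HELLO_RETRY_REQUEST, TOY_STATE_CLIENT_HELLO };

typedef struct {
    enum toy_state state;
    int hello_retry_request_count;
    unsigned char selected_group;   /* hypothetical one-byte group id */
} toy_srv;

/* Preparation stage: checks done before writing; a second HRR is refused. */
static int toy_prepare_hrr( const toy_srv *srv )
{
    if( srv->hello_retry_request_count > 0 )
        return( TOY_ERR_BAD_STATE );
    return( 0 );
}

/* Writing stage: serializes the toy message [0xFF, group] and leaves the
 * context untouched (note the const). */
static int toy_write_hrr_body( const toy_srv *srv, unsigned char *buf,
                               size_t buf_size, size_t *olen )
{
    if( buf_size < 2 )
        return( TOY_ERR_BUFFER_TOO_SMALL );
    buf[0] = 0xFF;
    buf[1] = srv->selected_group;
    *olen = 2;
    return( 0 );
}

/* Finalization stage: context updates that follow the writing.
 * Assumption of this sketch: the HRR count is updated here. */
static void toy_finalize_hrr( toy_srv *srv )
{
    srv->hello_retry_request_count++;
}

/* State handler: runs the stages in order, then changes state. */
static int toy_write_hello_retry_request( toy_srv *srv, unsigned char *buf,
                                          size_t buf_size, size_t *olen )
{
    int ret;

    assert( srv != NULL && buf != NULL && olen != NULL );
    if( ( ret = toy_prepare_hrr( srv ) ) != 0 )
        return( ret );
    if( ( ret = toy_write_hrr_body( srv, buf, buf_size, olen ) ) != 0 )
        return( ret );
    toy_finalize_hrr( srv );
    srv->state = TOY_STATE_CLIENT_HELLO;
    return( 0 );
}
```

Because the preparation and writing stages take the context as `const`, a failure in either of them (here, a too-small buffer) leaves both the count and the state unchanged.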
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 9be1bff..b8caca3 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -2101,14 +2101,9 @@
     TEST_ASSERT( mbedtls_ssl_is_handshake_over( &client.ssl ) == 1 );
 
     /* Make sure server state is moved to HANDSHAKE_OVER also. */
-    TEST_ASSERT( mbedtls_move_handshake_to_state( &(server.ssl),
-                                                  &(client.ssl),
-                                                  MBEDTLS_SSL_HANDSHAKE_OVER )
-                 ==  expected_handshake_result );
-    if( expected_handshake_result != 0 )
-    {
-        goto exit;
-    }
+    TEST_EQUAL( mbedtls_move_handshake_to_state( &(server.ssl),
+                                                 &(client.ssl),
+                                                 MBEDTLS_SSL_HANDSHAKE_OVER ), 0 );
 
     TEST_ASSERT( mbedtls_ssl_is_handshake_over( &server.ssl ) == 1 );