Merge pull request #5523 from ronald-cron-arm/one-flush-output-development

TLS 1.3: One flush output
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index be01eba..5d0331e 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -1236,13 +1236,14 @@
 int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want );
 
 int mbedtls_ssl_write_handshake_msg_ext( mbedtls_ssl_context *ssl,
-                                         int update_checksum );
+                                         int update_checksum,
+                                         int force_flush );
 static inline int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl )
 {
-    return( mbedtls_ssl_write_handshake_msg_ext( ssl, 1 /* update checksum */ ) );
+    return( mbedtls_ssl_write_handshake_msg_ext( ssl, 1 /* update checksum */, 1 /* force flush */ ) );
 }
 
-int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, uint8_t force_flush );
+int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, int force_flush );
 int mbedtls_ssl_flush_output( mbedtls_ssl_context *ssl );
 
 int mbedtls_ssl_parse_certificate( mbedtls_ssl_context *ssl );
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index 5f80ed5..ffb1346 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -2157,7 +2157,7 @@
             ( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE &&
               cur->p[0] == MBEDTLS_SSL_HS_FINISHED );
 
-        uint8_t const force_flush = ssl->disable_datagram_packing == 1 ?
+        int const force_flush = ssl->disable_datagram_packing == 1 ?
             SSL_FORCE_FLUSH : SSL_DONT_FORCE_FLUSH;
 
         /* Swap epochs before sending Finished: we can't do it after
@@ -2368,7 +2368,8 @@
  *   - ssl->out_msg: the record contents (handshake headers + content)
  */
 int mbedtls_ssl_write_handshake_msg_ext( mbedtls_ssl_context *ssl,
-                                         int update_checksum )
+                                         int update_checksum,
+                                         int force_flush )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     const size_t hs_len = ssl->out_msglen - 4;
@@ -2495,7 +2496,7 @@
     else
 #endif
     {
-        if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
+        if( ( ret = mbedtls_ssl_write_record( ssl, force_flush ) ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "ssl_write_record", ret );
             return( ret );
@@ -2519,11 +2520,11 @@
  *  - ssl->out_msglen: length of the record content (excl headers)
  *  - ssl->out_msg: record content
  */
-int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, uint8_t force_flush )
+int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, int force_flush )
 {
     int ret, done = 0;
     size_t len = ssl->out_msglen;
-    uint8_t flush = force_flush;
+    int flush = force_flush;
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write record" ) );
 
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 2220721..6c7f84f 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -2728,6 +2728,21 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
+    /*
+     * We may not have been able to send to the peer all the handshake data
+     * that were written into the output buffer by the previous handshake step,
+     * if the write to the network callback returned with the
+     * #MBEDTLS_ERR_SSL_WANT_WRITE error code.
+     * We proceed to the next handshake step only when all data from the
+     * previous one have been sent to the peer, thus we make sure that this is
+     * the case here by calling `mbedtls_ssl_flush_output()`. The function may
+     * return with the #MBEDTLS_ERR_SSL_WANT_WRITE error code, in which case
+     * we have to wait before going ahead.
+     * In the case of TLS 1.3, handshake step handlers do not send data to the
+     * peer. Data are only sent here and through
+     * `mbedtls_ssl_handle_pending_alert` in case an error that triggered an
+     * alert occurred.
+     */
     if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
         return( ret );
 
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index cd1baa1..05b7941 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -919,12 +919,6 @@
     return( 0 );
 }
 
-static int ssl_tls13_finalize_client_hello( mbedtls_ssl_context *ssl )
-{
-    mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_SERVER_HELLO );
-    return( 0 );
-}
-
 static int ssl_tls13_prepare_client_hello( mbedtls_ssl_context *ssl )
 {
     int ret;
@@ -991,11 +985,12 @@
                                               msg_len );
     ssl->handshake->update_checksum( ssl, buf, msg_len );
 
-    MBEDTLS_SSL_PROC_CHK( ssl_tls13_finalize_client_hello( ssl ) );
     MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_finish_handshake_msg( ssl,
                                                                   buf_len,
                                                                   msg_len ) );
 
+    mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_SERVER_HELLO );
+
 cleanup:
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write client hello" ) );
@@ -2049,52 +2044,62 @@
         ssl,
         MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED );
 #else
-#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
     mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_CERTIFICATE );
-#else
-    mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_FINISHED );
-#endif /* MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED */
-
 #endif /* MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE */
 
     return( 0 );
 }
 
 /*
- * Handler for MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED
- */
-#if defined(MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE)
-static int ssl_tls13_write_change_cipher_spec( mbedtls_ssl_context *ssl )
-{
-    int ret;
-
-    ret = mbedtls_ssl_tls13_write_change_cipher_spec( ssl );
-    if( ret != 0 )
-        return( ret );
-
-    return( 0 );
-}
-#endif /* MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE */
-
-#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
-/*
  * Handler for MBEDTLS_SSL_CLIENT_CERTIFICATE
  */
 static int ssl_tls13_write_client_certificate( mbedtls_ssl_context *ssl )
 {
+    int non_empty_certificate_msg = 0;
+
     MBEDTLS_SSL_DEBUG_MSG( 1,
                   ( "Switch to handshake traffic keys for outbound traffic" ) );
     mbedtls_ssl_set_outbound_transform( ssl, ssl->handshake->transform_handshake );

-    return( mbedtls_ssl_tls13_write_certificate( ssl ) );
+#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
+    if( ssl->handshake->client_auth )
+    {
+        int ret = mbedtls_ssl_tls13_write_certificate( ssl );
+        if( ret != 0 )
+            return( ret );
+
+        if( mbedtls_ssl_own_cert( ssl ) != NULL )
+            non_empty_certificate_msg = 1;
+    }
+    else
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 2, ( "No certificate message to send." ) );
+    }
+#endif /* MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED */
+    /* Next: CertificateVerify if we sent a certificate, else Finished. */
+    if( non_empty_certificate_msg )
+    {
+        mbedtls_ssl_handshake_set_state( ssl,
+                                         MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY );
+    }
+    else
+        mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_FINISHED );
+
+    return( 0 );
 }
 
+#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
 /*
  * Handler for MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY
  */
 static int ssl_tls13_write_client_certificate_verify( mbedtls_ssl_context *ssl )
 {
-    return( mbedtls_ssl_tls13_write_certificate_verify( ssl ) );
+    int ret = mbedtls_ssl_tls13_write_certificate_verify( ssl );
+    if( ret != 0 )
+        return( ret );
+
+    mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_FINISHED );
+    return( 0 );
 }
 #endif /* MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED */
 
@@ -2105,13 +2110,6 @@
 {
     int ret;
 
-    if( !ssl->handshake->client_auth )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1,
-                  ( "Switch to handshake traffic keys for outbound traffic" ) );
-        mbedtls_ssl_set_outbound_transform( ssl,
-                                        ssl->handshake->transform_handshake );
-    }
     ret = mbedtls_ssl_tls13_write_finished_message( ssl );
     if( ret != 0 )
         return( ret );
@@ -2192,11 +2190,11 @@
             ret = ssl_tls13_process_server_finished( ssl );
             break;
 
-#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
         case MBEDTLS_SSL_CLIENT_CERTIFICATE:
             ret = ssl_tls13_write_client_certificate( ssl );
             break;
 
+#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
         case MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY:
             ret = ssl_tls13_write_client_certificate_verify( ssl );
             break;
@@ -2218,9 +2216,16 @@
          * Injection of dummy-CCS's for middlebox compatibility
          */
 #if defined(MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE)
-        case MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED:
         case MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO:
-            ret = ssl_tls13_write_change_cipher_spec( ssl );
+            ret = mbedtls_ssl_tls13_write_change_cipher_spec( ssl );
+            if( ret == 0 )
+                mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_HELLO );
+            break;
+
+        case MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED:
+            ret = mbedtls_ssl_tls13_write_change_cipher_spec( ssl );
+            if( ret == 0 )
+                mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_CERTIFICATE );
             break;
 #endif /* MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE */
 
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index f006438..24a3d9d 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -103,7 +103,7 @@
     /* Add reserved 4 bytes for handshake header */
     msg_with_header_len = msg_len + 4;
     ssl->out_msglen = msg_with_header_len;
-    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_write_handshake_msg_ext( ssl, 0 ) );
+    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_write_handshake_msg_ext( ssl, 0, 0 ) );
 
 cleanup:
     return( ret );
@@ -847,54 +847,6 @@
     return( ret );
 }
 #if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
-
-/*
- * STATE HANDLING: Output Certificate
- */
-/* Check if a certificate should be written, and if yes,
- * if it is available.
- * Returns a negative error code on failure ( such as no certificate
- * being available on the server ), and otherwise
- * SSL_WRITE_CERTIFICATE_SEND or
- * SSL_WRITE_CERTIFICATE_SKIP
- * indicating that a Certificate message should be written based
- * on the configured certificate, or whether it should be silently skipped.
- */
-#define SSL_WRITE_CERTIFICATE_SEND  0
-#define SSL_WRITE_CERTIFICATE_SKIP  1
-
-static int ssl_tls13_write_certificate_coordinate( mbedtls_ssl_context *ssl )
-{
-
-    /* For PSK and ECDHE-PSK ciphersuites there is no certificate to exchange. */
-    if( mbedtls_ssl_tls13_some_psk_enabled( ssl ) )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= skip write certificate" ) );
-        return( SSL_WRITE_CERTIFICATE_SKIP );
-    }
-
-#if defined(MBEDTLS_SSL_CLI_C)
-    if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT )
-    {
-        /* The client MUST send a Certificate message if and only
-         * if the server has requested client authentication via a
-         * CertificateRequest message.
-         *
-         * client_auth indicates whether the server had requested
-         * client authentication.
-         */
-        if( ssl->handshake->client_auth == 0 )
-        {
-            MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= skip write certificate" ) );
-            return( SSL_WRITE_CERTIFICATE_SKIP );
-        }
-    }
-#endif /* MBEDTLS_SSL_CLI_C */
-
-    return( SSL_WRITE_CERTIFICATE_SEND );
-
-}
-
 /*
  *  enum {
  *        X509(0),
@@ -982,63 +934,29 @@
     return( 0 );
 }
 
-static int ssl_tls13_finalize_write_certificate( mbedtls_ssl_context *ssl )
-{
-#if defined(MBEDTLS_SSL_CLI_C)
-    if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT )
-    {
-        const mbedtls_x509_crt *crt = mbedtls_ssl_own_cert( ssl );
-        if( ssl->handshake->client_auth && crt != NULL )
-        {
-            mbedtls_ssl_handshake_set_state( ssl,
-                                        MBEDTLS_SSL_CLIENT_CERTIFICATE_VERIFY );
-        }
-        else
-            mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_FINISHED );
-        return( 0 );
-    }
-    else
-#endif /* MBEDTLS_SSL_CLI_C */
-    ((void) ssl);
-    return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
-}
-
 int mbedtls_ssl_tls13_write_certificate( mbedtls_ssl_context *ssl )
 {
     int ret;
+    unsigned char *buf;
+    size_t buf_len, msg_len;
+
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write certificate" ) );
 
-    /* Coordination: Check if we need to send a certificate. */
-    MBEDTLS_SSL_PROC_CHK_NEG( ssl_tls13_write_certificate_coordinate( ssl ) );
+    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_start_handshake_msg( ssl,
+                          MBEDTLS_SSL_HS_CERTIFICATE, &buf, &buf_len ) );
 
-    if( ret == SSL_WRITE_CERTIFICATE_SEND )
-    {
-        unsigned char *buf;
-        size_t buf_len, msg_len;
+    MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_certificate_body( ssl,
+                                                            buf,
+                                                            buf + buf_len,
+                                                            &msg_len ) );
 
-        MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_start_handshake_msg( ssl,
-                   MBEDTLS_SSL_HS_CERTIFICATE, &buf, &buf_len ) );
+    mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl,
+                                              MBEDTLS_SSL_HS_CERTIFICATE,
+                                              buf,
+                                              msg_len );
 
-        MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_certificate_body( ssl,
-                                                                buf,
-                                                                buf + buf_len,
-                                                                &msg_len ) );
-
-        mbedtls_ssl_tls13_add_hs_msg_to_checksum( ssl,
-                                                  MBEDTLS_SSL_HS_CERTIFICATE,
-                                                  buf,
-                                                  msg_len );
-
-        MBEDTLS_SSL_PROC_CHK( ssl_tls13_finalize_write_certificate( ssl ) );
-        MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_finish_handshake_msg(
-                                  ssl, buf_len, msg_len ) );
-    }
-    else
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= skip write certificate" ) );
-        MBEDTLS_SSL_PROC_CHK( ssl_tls13_finalize_write_certificate( ssl ) );
-    }
-
+    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_finish_handshake_msg(
+                              ssl, buf_len, msg_len ) );
 cleanup:
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write certificate" ) );
@@ -1188,22 +1106,6 @@
     return( ret );
 }
 
-static int ssl_tls13_finalize_certificate_verify( mbedtls_ssl_context *ssl )
-{
-#if defined(MBEDTLS_SSL_CLI_C)
-    if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT )
-    {
-        mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_FINISHED );
-    }
-    else
-#endif /* MBEDTLS_SSL_CLI_C */
-    {
-        mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_SERVER_FINISHED );
-    }
-
-    return( 0 );
-}
-
 int mbedtls_ssl_tls13_write_certificate_verify( mbedtls_ssl_context *ssl )
 {
     int ret = 0;
@@ -1220,8 +1122,6 @@
 
     mbedtls_ssl_tls13_add_hs_msg_to_checksum(
         ssl, MBEDTLS_SSL_HS_CERTIFICATE_VERIFY, buf, msg_len );
-    /* Update state */
-    MBEDTLS_SSL_PROC_CHK( ssl_tls13_finalize_certificate_verify( ssl ) );
 
     MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_finish_handshake_msg(
                                 ssl, buf_len, msg_len ) );
@@ -1483,7 +1383,6 @@
     MBEDTLS_SSL_PROC_CHK( ssl_tls13_finalize_finished_message( ssl ) );
     MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_tls13_finish_handshake_msg( ssl,
                                               buf_len, msg_len ) );
-    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_flush_output( ssl ) );
 
 cleanup:
 
@@ -1516,40 +1415,6 @@
  *
  */
 #if defined(MBEDTLS_SSL_TLS1_3_COMPATIBILITY_MODE)
-
-static int ssl_tls13_finalize_change_cipher_spec( mbedtls_ssl_context* ssl )
-{
-
-#if defined(MBEDTLS_SSL_CLI_C)
-    if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT )
-    {
-        switch( ssl->state )
-        {
-            case MBEDTLS_SSL_CLIENT_CCS_BEFORE_2ND_CLIENT_HELLO:
-                mbedtls_ssl_handshake_set_state( ssl, MBEDTLS_SSL_CLIENT_HELLO );
-                break;
-            case MBEDTLS_SSL_CLIENT_CCS_AFTER_SERVER_FINISHED:
-#if defined(MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED)
-                mbedtls_ssl_handshake_set_state( ssl,
-                                            MBEDTLS_SSL_CLIENT_CERTIFICATE );
-#else
-                mbedtls_ssl_handshake_set_state( ssl,
-                                                 MBEDTLS_SSL_CLIENT_FINISHED );
-#endif /* MBEDTLS_KEY_EXCHANGE_WITH_CERT_ENABLED */
-
-                break;
-            default:
-                MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
-                return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
-        }
-    }
-#else
-    ((void) ssl);
-#endif /* MBEDTLS_SSL_CLI_C */
-
-    return( 0 );
-}
-
 static int ssl_tls13_write_change_cipher_spec_body( mbedtls_ssl_context *ssl,
                                                     unsigned char *buf,
                                                     unsigned char *end,
@@ -1570,8 +1435,6 @@
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write change cipher spec" ) );
 
-    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_flush_output( ssl ) );
-
     /* Write CCS message */
     MBEDTLS_SSL_PROC_CHK( ssl_tls13_write_change_cipher_spec_body(
                               ssl, ssl->out_msg,
@@ -1580,11 +1443,8 @@
 
     ssl->out_msgtype = MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC;
 
-    /* Update state */
-    MBEDTLS_SSL_PROC_CHK( ssl_tls13_finalize_change_cipher_spec( ssl ) );
-
     /* Dispatch message */
-    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_write_record( ssl, 1 ) );
+    MBEDTLS_SSL_PROC_CHK( mbedtls_ssl_write_record( ssl, 0 ) );
 
 cleanup: