Merge pull request #9872 from rojer/tls_hs_defrag_in

Defragment incoming TLS handshake messages
diff --git a/ChangeLog.d/tls-hs-defrag-in.txt b/ChangeLog.d/tls-hs-defrag-in.txt
new file mode 100644
index 0000000..55103c9
--- /dev/null
+++ b/ChangeLog.d/tls-hs-defrag-in.txt
@@ -0,0 +1,5 @@
+Bugfix
+   * Support re-assembly of fragmented handshake messages in TLS, as mandated
+     by the spec. Lack of support was causing handshake failures with some
+     servers, especially with TLS 1.3 in practice (though both protocol
+     versions could be affected in principle, and both are fixed now).
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index e0c0eae..0c0c8bb 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -1795,6 +1795,8 @@
 
     size_t MBEDTLS_PRIVATE(in_hslen);            /*!< current handshake message length,
                                                     including the handshake header   */
+    size_t MBEDTLS_PRIVATE(in_hsfraglen);        /*!< accumulated length of hs fragments
+                                                    (up to in_hslen) */
     int MBEDTLS_PRIVATE(nb_zero);                /*!< # of 0-length encrypted messages */
 
     int MBEDTLS_PRIVATE(keep_current_message);   /*!< drop or reuse current message
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 9f91861..2140ac4 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -1756,10 +1756,11 @@
 MBEDTLS_CHECK_RETURN_CRITICAL
 int mbedtls_ssl_check_timer(mbedtls_ssl_context *ssl);
 
-void mbedtls_ssl_reset_in_out_pointers(mbedtls_ssl_context *ssl);
+void mbedtls_ssl_reset_in_pointers(mbedtls_ssl_context *ssl);
+void mbedtls_ssl_update_in_pointers(mbedtls_ssl_context *ssl);
+void mbedtls_ssl_reset_out_pointers(mbedtls_ssl_context *ssl);
 void mbedtls_ssl_update_out_pointers(mbedtls_ssl_context *ssl,
                                      mbedtls_ssl_transform *transform);
-void mbedtls_ssl_update_in_pointers(mbedtls_ssl_context *ssl);
 
 MBEDTLS_CHECK_RETURN_CRITICAL
 int mbedtls_ssl_session_reset_int(mbedtls_ssl_context *ssl, int partial);
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index 97c4866..a87785c 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -2962,13 +2962,17 @@
 
 int mbedtls_ssl_prepare_handshake_record(mbedtls_ssl_context *ssl)
 {
-    if (ssl->in_msglen < mbedtls_ssl_hs_hdr_len(ssl)) {
+    /* First handshake fragment must at least include the header. */
+    if (ssl->in_msglen < mbedtls_ssl_hs_hdr_len(ssl) && ssl->in_hslen == 0) {
         MBEDTLS_SSL_DEBUG_MSG(1, ("handshake message too short: %" MBEDTLS_PRINTF_SIZET,
                                   ssl->in_msglen));
         return MBEDTLS_ERR_SSL_INVALID_RECORD;
     }
 
-    ssl->in_hslen = mbedtls_ssl_hs_hdr_len(ssl) + ssl_get_hs_total_len(ssl);
+    if (ssl->in_hslen == 0) {
+        ssl->in_hslen = mbedtls_ssl_hs_hdr_len(ssl) + ssl_get_hs_total_len(ssl);
+        ssl->in_hsfraglen = 0;
+    }
 
     MBEDTLS_SSL_DEBUG_MSG(3, ("handshake message: msglen ="
                               " %" MBEDTLS_PRINTF_SIZET ", type = %u, hslen = %"
@@ -3034,10 +3038,60 @@
         }
     } else
 #endif /* MBEDTLS_SSL_PROTO_DTLS */
-    /* With TLS we don't handle fragmentation (for now) */
-    if (ssl->in_msglen < ssl->in_hslen) {
-        MBEDTLS_SSL_DEBUG_MSG(1, ("TLS handshake fragmentation not supported"));
-        return MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE;
+    if (ssl->in_hsfraglen <= ssl->in_hslen) {
+        int ret;
+        const size_t hs_remain = ssl->in_hslen - ssl->in_hsfraglen;
+        MBEDTLS_SSL_DEBUG_MSG(3,
+                              ("handshake fragment: %" MBEDTLS_PRINTF_SIZET " .. %"
+                               MBEDTLS_PRINTF_SIZET " of %"
+                               MBEDTLS_PRINTF_SIZET " msglen %" MBEDTLS_PRINTF_SIZET,
+                               ssl->in_hsfraglen,
+                               ssl->in_hsfraglen +
+                               (hs_remain <= ssl->in_msglen ? hs_remain : ssl->in_msglen),
+                               ssl->in_hslen, ssl->in_msglen));
+        if (ssl->in_msglen < hs_remain) {
+            ssl->in_hsfraglen += ssl->in_msglen;
+            ssl->in_hdr = ssl->in_msg + ssl->in_msglen;
+            ssl->in_msglen = 0;
+            mbedtls_ssl_update_in_pointers(ssl);
+            return MBEDTLS_ERR_SSL_CONTINUE_PROCESSING;
+        }
+        if (ssl->in_hsfraglen > 0) {
+            /*
+             * At in_first_hdr we have a sequence of records that cover the next handshake
+             * message, each with its own record header that we need to remove.
+             * Note that the reassembled record size may not equal the size of the message:
+             * there may be more messages after it, complete or partial.
+             */
+            unsigned char *in_first_hdr = ssl->in_buf + MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
+            unsigned char *p = in_first_hdr, *q = NULL;
+            size_t merged_rec_len = 0;
+            do {
+                mbedtls_record rec;
+                ret = ssl_parse_record_header(ssl, p, mbedtls_ssl_in_hdr_len(ssl), &rec);
+                if (ret != 0) {
+                    return ret;
+                }
+                merged_rec_len += rec.data_len;
+                p = rec.buf + rec.buf_len;
+                if (q != NULL) {
+                    memmove(q, rec.buf + rec.data_offset, rec.data_len);
+                    q += rec.data_len;
+                } else {
+                    q = p;
+                }
+            } while (merged_rec_len < ssl->in_hslen);
+            ssl->in_hdr = in_first_hdr;
+            mbedtls_ssl_update_in_pointers(ssl);
+            ssl->in_msglen = merged_rec_len;
+            /* Adjust message length. */
+            MBEDTLS_PUT_UINT16_BE(merged_rec_len, ssl->in_len, 0);
+            ssl->in_hsfraglen = 0;
+            MBEDTLS_SSL_DEBUG_BUF(4, "reassembled record",
+                                  ssl->in_hdr, mbedtls_ssl_in_hdr_len(ssl) + merged_rec_len);
+        }
+    } else {
+        return MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     }
 
     return 0;
@@ -4382,6 +4436,16 @@
             return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
         }
 
+        if (ssl->in_hsfraglen != 0) {
+            /* Not all handshake fragments have arrived, do not consume. */
+            MBEDTLS_SSL_DEBUG_MSG(3,
+                                  ("waiting for more fragments (%" MBEDTLS_PRINTF_SIZET " of %"
+                                   MBEDTLS_PRINTF_SIZET ", %" MBEDTLS_PRINTF_SIZET " left)",
+                                   ssl->in_hsfraglen, ssl->in_hslen,
+                                   ssl->in_hslen - ssl->in_hsfraglen));
+            return 0;
+        }
+
         /*
          * Get next Handshake message in the current record
          */
@@ -4407,6 +4471,7 @@
             ssl->in_msglen -= ssl->in_hslen;
             memmove(ssl->in_msg, ssl->in_msg + ssl->in_hslen,
                     ssl->in_msglen);
+            MBEDTLS_PUT_UINT16_BE(ssl->in_msglen, ssl->in_len, 0);
 
             MBEDTLS_SSL_DEBUG_BUF(4, "remaining content in record",
                                   ssl->in_msg, ssl->in_msglen);
@@ -5081,7 +5146,7 @@
     } else
 #endif
     {
-        ssl->in_ctr = ssl->in_hdr - MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
+        ssl->in_ctr = ssl->in_buf;
         ssl->in_len = ssl->in_hdr + 3;
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
         ssl->in_cid = ssl->in_len;
@@ -5097,24 +5162,35 @@
  * Setup an SSL context
  */
 
-void mbedtls_ssl_reset_in_out_pointers(mbedtls_ssl_context *ssl)
+void mbedtls_ssl_reset_in_pointers(mbedtls_ssl_context *ssl)
+{
+#if defined(MBEDTLS_SSL_PROTO_DTLS)
+    if (ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
+        ssl->in_hdr = ssl->in_buf;
+    } else
+#endif  /* MBEDTLS_SSL_PROTO_DTLS */
+    {
+        ssl->in_hdr = ssl->in_buf + MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
+    }
+
+    /* in_hdr is back at its initial offset in in_buf; derive the other in_* pointers. */
+    mbedtls_ssl_update_in_pointers(ssl);
+}
+
+void mbedtls_ssl_reset_out_pointers(mbedtls_ssl_context *ssl)
 {
     /* Set the incoming and outgoing record pointers. */
 #if defined(MBEDTLS_SSL_PROTO_DTLS)
     if (ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {
         ssl->out_hdr = ssl->out_buf;
-        ssl->in_hdr  = ssl->in_buf;
     } else
 #endif /* MBEDTLS_SSL_PROTO_DTLS */
     {
         ssl->out_ctr = ssl->out_buf;
-        ssl->out_hdr = ssl->out_buf + 8;
-        ssl->in_hdr  = ssl->in_buf  + 8;
+        ssl->out_hdr = ssl->out_buf + MBEDTLS_SSL_SEQUENCE_NUMBER_LEN;
     }
-
     /* Derive other internal pointers. */
     mbedtls_ssl_update_out_pointers(ssl, NULL /* no transform enabled */);
-    mbedtls_ssl_update_in_pointers(ssl);
 }
 
 /*
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 60f2e1c..7d20b3c 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -339,12 +339,13 @@
                                    size_t out_buf_new_len)
 {
     int modified = 0;
-    size_t written_in = 0, iv_offset_in = 0, len_offset_in = 0;
+    size_t written_in = 0, iv_offset_in = 0, len_offset_in = 0, hdr_in = 0;
     size_t written_out = 0, iv_offset_out = 0, len_offset_out = 0;
     if (ssl->in_buf != NULL) {
         written_in = ssl->in_msg - ssl->in_buf;
         iv_offset_in = ssl->in_iv - ssl->in_buf;
         len_offset_in = ssl->in_len - ssl->in_buf;
+        hdr_in = ssl->in_hdr - ssl->in_buf;
         if (downsizing ?
             ssl->in_buf_len > in_buf_new_len && ssl->in_left < in_buf_new_len :
             ssl->in_buf_len < in_buf_new_len) {
@@ -376,7 +377,10 @@
     }
     if (modified) {
         /* Update pointers here to avoid doing it twice. */
-        mbedtls_ssl_reset_in_out_pointers(ssl);
+        ssl->in_hdr = ssl->in_buf + hdr_in;
+        mbedtls_ssl_update_in_pointers(ssl);
+        mbedtls_ssl_reset_out_pointers(ssl);
+
         /* Fields below might not be properly updated with record
          * splitting or with CID, so they are manually updated here. */
         ssl->out_msg = ssl->out_buf + written_out;
@@ -1277,7 +1281,8 @@
         goto error;
     }
 
-    mbedtls_ssl_reset_in_out_pointers(ssl);
+    mbedtls_ssl_reset_in_pointers(ssl);
+    mbedtls_ssl_reset_out_pointers(ssl);
 
 #if defined(MBEDTLS_SSL_DTLS_SRTP)
     memset(&ssl->dtls_srtp_info, 0, sizeof(ssl->dtls_srtp_info));
@@ -1342,7 +1347,8 @@
     /* Cancel any possibly running timer */
     mbedtls_ssl_set_timer(ssl, 0);
 
-    mbedtls_ssl_reset_in_out_pointers(ssl);
+    mbedtls_ssl_reset_in_pointers(ssl);
+    mbedtls_ssl_reset_out_pointers(ssl);
 
     /* Reset incoming message parsing */
     ssl->in_offt    = NULL;
@@ -1350,6 +1356,7 @@
     ssl->in_msgtype = 0;
     ssl->in_msglen  = 0;
     ssl->in_hslen   = 0;
+    ssl->in_hsfraglen = 0;
     ssl->keep_current_message = 0;
     ssl->transform_in  = NULL;
 
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index fc9b860..a302af4 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -1015,28 +1015,6 @@
         MBEDTLS_SSL_DEBUG_MSG(1, ("bad client hello message"));
         return MBEDTLS_ERR_SSL_UNEXPECTED_MESSAGE;
     }
-    {
-        size_t handshake_len = MBEDTLS_GET_UINT24_BE(buf, 1);
-        MBEDTLS_SSL_DEBUG_MSG(3, ("client hello v3, handshake len.: %u",
-                                  (unsigned) handshake_len));
-
-        /* The record layer has a record size limit of 2^14 - 1 and
-         * fragmentation is not supported, so buf[1] should be zero. */
-        if (buf[1] != 0) {
-            MBEDTLS_SSL_DEBUG_MSG(1, ("bad client hello message: %u != 0",
-                                      (unsigned) buf[1]));
-            return MBEDTLS_ERR_SSL_DECODE_ERROR;
-        }
-
-        /* We don't support fragmentation of ClientHello (yet?) */
-        if (msg_len != mbedtls_ssl_hs_hdr_len(ssl) + handshake_len) {
-            MBEDTLS_SSL_DEBUG_MSG(1, ("bad client hello message: %u != %u + %u",
-                                      (unsigned) msg_len,
-                                      (unsigned) mbedtls_ssl_hs_hdr_len(ssl),
-                                      (unsigned) handshake_len));
-            return MBEDTLS_ERR_SSL_DECODE_ERROR;
-        }
-    }
 
 #if defined(MBEDTLS_SSL_PROTO_DTLS)
     if (ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM) {