Create handshake record coalescing tests

Create tests that coalesce the handshake records in the first flight from
the server. This lets us test the behavior of the library when a handshake
record contains multiple handshake messages.

Only non-protected (non-encrypted, non-authenticated) handshake records are
supported.

The test code works for all TLS protocol versions (DTLS is not supported),
but it is only effective in TLS 1.2. In TLS 1.3, the server's first flight
contains only a single non-encrypted handshake record (the one carrying
ServerHello), so we can't test records containing more than one handshake
message without a lot more work.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 2dabaea..99fbef3 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -63,6 +63,98 @@
 }
 #endif
 
+typedef enum {
+    RECOMBINE_NOMINAL,          /* leave the records as emitted; param: ignored */
+    RECOMBINE_COALESCE,         /* merge the first param records into one (INT_MAX: all) */
+} recombine_records_instruction_t;
+
+/* Coalesce handshake records at the start of buf into the first record:
+ * erase the headers of up to (max - 1) following handshake records and
+ * grow the first record's length field. Stop early at the end of the data
+ * or at a non-handshake record. DTLS and encrypted or authenticated records
+ * are not supported. Assume the buffer holds a valid sequence of records.
+ * Return the number of records now forming the first record, 0 if the
+ * first record is not a handshake record, or -1 if an assertion failed. */
+static int recombine_coalesce_handshake_records(mbedtls_test_ssl_buffer *buf, int max)
+{
+    const size_t header_length = 5;
+    TEST_LE_U(header_length, buf->content_length);
+    if (buf->buffer[0] != MBEDTLS_SSL_MSG_HANDSHAKE) {
+        return 0;
+    }
+
+    size_t record_length = MBEDTLS_GET_UINT16_BE(buf->buffer, header_length - 2);
+    TEST_LE_U(header_length + record_length, buf->content_length);
+
+    int count;
+    for (count = 1; count < max; count++) {
+        size_t next_start = header_length + record_length;
+        if (next_start >= buf->content_length) {
+            /* We've already reached the last record. */
+            break;
+        }
+
+        TEST_LE_U(next_start + header_length, buf->content_length);
+        if (buf->buffer[next_start] != MBEDTLS_SSL_MSG_HANDSHAKE) {
+            /* There's another record, but it isn't a handshake record. */
+            break;
+        }
+        size_t next_length = MBEDTLS_GET_UINT16_BE(buf->buffer, next_start + header_length - 2);
+        TEST_LE_U(next_start + header_length + next_length, buf->content_length);
+        /* Check the merged length fits in 16 bits before touching the buffer. */
+        TEST_LE_U(record_length + next_length, 0xffff);
+
+        /* Erase the next record header by shifting the bytes after it down. */
+        memmove(buf->buffer + next_start,
+                buf->buffer + next_start + header_length,
+                buf->content_length - next_start - header_length);
+        buf->content_length -= header_length;
+        record_length += next_length;
+        MBEDTLS_PUT_UINT16_BE(record_length, buf->buffer, header_length - 2);
+    }
+
+    return count;
+
+exit:
+    return -1;
+}
+
+/* Rewrite the records that the server endpoint has queued for the client.
+ *
+ * instruction: what to do (see recombine_records_instruction_t).
+ * param:       argument of the instruction; for RECOMBINE_COALESCE,
+ *              INT_MAX means "coalesce all" (at least one record).
+ *
+ * Return 1 on success, 0 if a test assertion failed.
+ */
+static int recombine_records(mbedtls_test_ssl_endpoint *server,
+                             recombine_records_instruction_t instruction,
+                             int param)
+{
+    /* The server's socket output is a circular buffer. Only data starting
+     * at offset 0 is handled, which should hold for the first flight
+     * emitted by a server, the only intended use of this function. */
+    mbedtls_test_ssl_buffer *buf = server->socket.output;
+    TEST_EQUAL(buf->start, 0);
+
+    if (instruction == RECOMBINE_COALESCE) {
+        int ret = recombine_coalesce_handshake_records(buf, param);
+        if (param != INT_MAX) {
+            TEST_EQUAL(ret, param);
+        } else {
+            TEST_LE_S(1, ret);
+        }
+    } else if (instruction != RECOMBINE_NOMINAL) {
+        /* RECOMBINE_NOMINAL leaves the records untouched; reject other values. */
+        TEST_FAIL("Instructions not understood");
+    }
+
+    return 1;
+
+exit:
+    return 0;
+}
+
 /* END_HEADER */
 
 /* BEGIN_DEPENDENCIES
@@ -2802,6 +2894,146 @@
 }
 /* END_CASE */
 
+/* BEGIN_CASE */
+void recombine_server_first_flight(int version,
+                                   int instruction, int param,
+                                   char *client_log, char *server_log,
+                                   int goal_state, int expected_ret)
+{
+    /* Run a handshake in which the server's first flight is rewritten by
+     * recombine_records(instruction, param) before the client parses it.
+     * If the client stops in goal_state, its error must be expected_ret;
+     * otherwise the handshake must complete. With MBEDTLS_DEBUG_C, the
+     * client_log and server_log patterns must each have matched at least once. */
+    enum { BUFFSIZE = 17000 };
+    mbedtls_test_ssl_endpoint client = { 0 };
+    mbedtls_test_ssl_endpoint server = { 0 };
+    mbedtls_test_handshake_test_options client_options;
+    mbedtls_test_init_handshake_options(&client_options);
+    mbedtls_test_handshake_test_options server_options;
+    mbedtls_test_init_handshake_options(&server_options);
+#if defined(MBEDTLS_DEBUG_C)
+    mbedtls_test_ssl_log_pattern cli_pattern = { .pattern = client_log };
+    mbedtls_test_ssl_log_pattern srv_pattern = { .pattern = server_log };
+#endif
+    int ret = 0;
+
+    MD_OR_USE_PSA_INIT();
+#if defined(MBEDTLS_DEBUG_C)
+    mbedtls_debug_set_threshold(3);
+#endif
+
+    client_options.client_min_version = version;
+    client_options.client_max_version = version;
+#if defined(MBEDTLS_DEBUG_C)
+    client_options.cli_log_obj = &cli_pattern;
+    client_options.cli_log_fun = mbedtls_test_ssl_log_analyzer;
+#else
+    (void) client_log;
+#endif
+    TEST_EQUAL(mbedtls_test_ssl_endpoint_init(&client, MBEDTLS_SSL_IS_CLIENT,
+                                              &client_options, NULL, NULL, NULL), 0);
+#if defined(MBEDTLS_DEBUG_C)
+    mbedtls_ssl_conf_dbg(&client.conf, client_options.cli_log_fun, client_options.cli_log_obj);
+#endif
+
+    server_options.server_min_version = version;
+    server_options.server_max_version = version;
+#if defined(MBEDTLS_DEBUG_C)
+    server_options.srv_log_obj = &srv_pattern;
+    server_options.srv_log_fun = mbedtls_test_ssl_log_analyzer;
+#else
+    (void) server_log;
+#endif
+    TEST_EQUAL(mbedtls_test_ssl_endpoint_init(&server, MBEDTLS_SSL_IS_SERVER,
+                                              &server_options, NULL, NULL, NULL), 0);
+#if defined(MBEDTLS_DEBUG_C)
+    mbedtls_ssl_conf_dbg(&server.conf, server_options.srv_log_fun, server_options.srv_log_obj);
+#endif
+
+    TEST_EQUAL(mbedtls_test_mock_socket_connect(&client.socket, &server.socket,
+                                                BUFFSIZE), 0);
+
+    /* Client: emit the first flight from the client */
+    while (ret == 0) {
+        mbedtls_test_set_step(client.ssl.state);
+        ret = mbedtls_ssl_handshake_step(&client.ssl);
+    }
+    TEST_EQUAL(ret, MBEDTLS_ERR_SSL_WANT_READ);
+    ret = 0;
+    TEST_EQUAL(client.ssl.state, MBEDTLS_SSL_SERVER_HELLO);
+
+    /* Server: parse the first flight from the client
+     * and emit the first flight from the server */
+    while (ret == 0) {
+        mbedtls_test_set_step(1000 + server.ssl.state);
+        ret = mbedtls_ssl_handshake_step(&server.ssl);
+    }
+    TEST_EQUAL(ret, MBEDTLS_ERR_SSL_WANT_READ);
+    ret = 0;
+    TEST_EQUAL(server.ssl.state, MBEDTLS_SSL_SERVER_HELLO_DONE + 1);
+
+    /* Recombine the first flight from the server */
+    TEST_ASSERT(recombine_records(&server, instruction, param));
+
+    /* Client: parse the first flight from the server
+     * and emit the second flight from the client */
+    while (ret == 0 && !mbedtls_ssl_is_handshake_over(&client.ssl)) {
+        mbedtls_test_set_step(client.ssl.state);
+        ret = mbedtls_ssl_handshake_step(&client.ssl);
+        if (client.ssl.state == goal_state && ret != 0) {
+            TEST_EQUAL(ret, expected_ret);
+            goto goal_reached;
+        }
+    }
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+    if (version >= MBEDTLS_SSL_VERSION_TLS1_3 &&
+        goal_state >= MBEDTLS_SSL_HANDSHAKE_OVER) {
+        TEST_EQUAL(ret, 0);
+    } else
+#endif
+    {
+        TEST_EQUAL(ret, MBEDTLS_ERR_SSL_WANT_READ);
+    }
+    ret = 0;
+
+    /* Server: parse the second flight from the client
+     * and emit the second flight from the server */
+    while (ret == 0 && !mbedtls_ssl_is_handshake_over(&server.ssl)) {
+        mbedtls_test_set_step(1000 + server.ssl.state);
+        ret = mbedtls_ssl_handshake_step(&server.ssl);
+    }
+    TEST_EQUAL(ret, 0);
+
+    /* Client: parse the second flight from the server */
+    while (ret == 0 && !mbedtls_ssl_is_handshake_over(&client.ssl)) {
+        mbedtls_test_set_step(client.ssl.state);
+        ret = mbedtls_ssl_handshake_step(&client.ssl);
+    }
+    if (client.ssl.state == goal_state) {
+        TEST_EQUAL(ret, expected_ret);
+    } else {
+        TEST_EQUAL(ret, 0);
+    }
+
+goal_reached:
+#if defined(MBEDTLS_DEBUG_C)
+    TEST_ASSERT(cli_pattern.counter >= 1);
+    TEST_ASSERT(srv_pattern.counter >= 1);
+#endif
+
+exit:
+    mbedtls_test_ssl_endpoint_free(&client, NULL);
+    mbedtls_test_ssl_endpoint_free(&server, NULL);
+    mbedtls_test_free_handshake_options(&client_options);
+    mbedtls_test_free_handshake_options(&server_options);
+    MD_OR_USE_PSA_DONE();
+#if defined(MBEDTLS_DEBUG_C)
+    mbedtls_debug_set_threshold(0);
+#endif
+}
+/* END_CASE */
+
 /* BEGIN_CASE depends_on:MBEDTLS_SSL_HANDSHAKE_WITH_CERT_ENABLED:!MBEDTLS_SSL_PROTO_TLS1_3:MBEDTLS_PKCS1_V15:MBEDTLS_SSL_PROTO_TLS1_2:MBEDTLS_RSA_C:MBEDTLS_ECP_HAVE_SECP384R1:MBEDTLS_SSL_PROTO_DTLS:MBEDTLS_SSL_RENEGOTIATION:MBEDTLS_MD_CAN_SHA256:MBEDTLS_CAN_HANDLE_RSA_TEST_KEY */
 void renegotiation(int legacy_renegotiation)
 {