Add more tests for keying material export

Signed-off-by: Max Fillinger <maximilian.fillinger@foxcrypto.com>
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index ab61e03..3301249 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -1695,7 +1695,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE depends_on:MBEDTLS_SSL_PROTO_TLS1_3 */
+/* BEGIN_CASE depends_on:MBEDTLS_SSL_PROTO_TLS1_3:MBEDTLS_SSL_KEYING_MATERIAL_EXPORT */
 void ssl_tls13_exporter(int hash_alg,
                         data_t *secret,
                         char *label,
@@ -5229,5 +5229,234 @@
     mbedtls_debug_set_threshold(0);
     mbedtls_free(first_frag);
     PSA_DONE();
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_SSL_KEYING_MATERIAL_EXPORT */
+void ssl_tls_exporter_consistent_result(int proto, int exported_key_length, int use_context)
+{
+    /* Test that the client and server generate the same key. */
+    int ret = -1;
+    uint8_t *key_buffer_server = NULL;
+    uint8_t *key_buffer_client = NULL;
+    mbedtls_test_ssl_endpoint client_ep, server_ep; /* Zeroed so exit can always free them. */
+    memset(&client_ep, 0, sizeof(client_ep));
+    memset(&server_ep, 0, sizeof(server_ep));
+    MD_OR_USE_PSA_INIT();
+    ret = mbedtls_test_ssl_do_handshake_with_endpoints(&server_ep, &client_ep, proto);
+    TEST_ASSERT(ret == 0);
+    TEST_ASSERT(exported_key_length > 0);
+    TEST_CALLOC(key_buffer_server, exported_key_length);
+    TEST_CALLOC(key_buffer_client, exported_key_length);
+    char label[] = "test-label";
+    unsigned char context[128] = { 0 };
+    ret = mbedtls_ssl_export_keying_material(&server_ep.ssl,
+                                             key_buffer_server, (size_t)exported_key_length,
+                                             label, sizeof(label),
+                                             context, sizeof(context), use_context);
+    TEST_ASSERT(ret == 0);
+    ret = mbedtls_ssl_export_keying_material(&client_ep.ssl,
+                                             key_buffer_client, (size_t)exported_key_length,
+                                             label, sizeof(label),
+                                             context, sizeof(context), use_context);
+    TEST_ASSERT(ret == 0);
+    TEST_ASSERT(memcmp(key_buffer_server, key_buffer_client, (size_t)exported_key_length) == 0);
+
+exit:
+    mbedtls_test_ssl_endpoint_free(&server_ep, NULL);
+    mbedtls_test_ssl_endpoint_free(&client_ep, NULL);
+    MD_OR_USE_PSA_DONE();
+    mbedtls_free(key_buffer_server);
+    mbedtls_free(key_buffer_client);
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_SSL_KEYING_MATERIAL_EXPORT */
+void ssl_tls_exporter_uses_label(int proto)
+{
+    /* Test that the client and server export different keys when using different labels. */
+    int ret = -1;
+    mbedtls_test_ssl_endpoint client_ep, server_ep; /* Zeroed so exit can always free them. */
+    memset(&client_ep, 0, sizeof(client_ep));
+    memset(&server_ep, 0, sizeof(server_ep));
+    MD_OR_USE_PSA_INIT();
+    ret = mbedtls_test_ssl_do_handshake_with_endpoints(&server_ep, &client_ep, proto);
+    TEST_ASSERT(ret == 0);
+    char label_server[] = "test-label-server";
+    char label_client[] = "test-label-client";
+    uint8_t key_buffer_server[24] = { 0 };
+    uint8_t key_buffer_client[24] = { 0 };
+    unsigned char context[128] = { 0 };
+    ret = mbedtls_ssl_export_keying_material(&server_ep.ssl,
+                                             key_buffer_server, sizeof(key_buffer_server),
+                                             label_server, sizeof(label_server),
+                                             context, sizeof(context), 1);
+    TEST_ASSERT(ret == 0);
+    ret = mbedtls_ssl_export_keying_material(&client_ep.ssl,
+                                             key_buffer_client, sizeof(key_buffer_client),
+                                             label_client, sizeof(label_client),
+                                             context, sizeof(context), 1);
+    TEST_ASSERT(ret == 0);
+    TEST_ASSERT(memcmp(key_buffer_server, key_buffer_client, sizeof(key_buffer_server)) != 0);
+
+exit:
+    mbedtls_test_ssl_endpoint_free(&server_ep, NULL);
+    mbedtls_test_ssl_endpoint_free(&client_ep, NULL);
+    MD_OR_USE_PSA_DONE();
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_SSL_KEYING_MATERIAL_EXPORT */
+void ssl_tls_exporter_uses_context(int proto)
+{
+    /* Test that the client and server export different keys when using different contexts. */
+    int ret = -1;
+    mbedtls_test_ssl_endpoint client_ep, server_ep; /* Zeroed so exit can always free them. */
+    memset(&client_ep, 0, sizeof(client_ep));
+    memset(&server_ep, 0, sizeof(server_ep));
+    MD_OR_USE_PSA_INIT();
+    ret = mbedtls_test_ssl_do_handshake_with_endpoints(&server_ep, &client_ep, proto);
+    TEST_ASSERT(ret == 0);
+    char label[] = "test-label";
+    uint8_t key_buffer_server[24] = { 0 };
+    uint8_t key_buffer_client[24] = { 0 };
+    unsigned char context_server[128] = { 0 };
+    unsigned char context_client[128] = { 23 };
+    ret = mbedtls_ssl_export_keying_material(&server_ep.ssl,
+                                             key_buffer_server, sizeof(key_buffer_server),
+                                             label, sizeof(label),
+                                             context_server, sizeof(context_server), 1);
+    TEST_ASSERT(ret == 0);
+    ret = mbedtls_ssl_export_keying_material(&client_ep.ssl,
+                                             key_buffer_client, sizeof(key_buffer_client),
+                                             label, sizeof(label),
+                                             context_client, sizeof(context_client), 1);
+    TEST_ASSERT(ret == 0);
+    TEST_ASSERT(memcmp(key_buffer_server, key_buffer_client, sizeof(key_buffer_server)) != 0);
+
+exit:
+    mbedtls_test_ssl_endpoint_free(&server_ep, NULL);
+    mbedtls_test_ssl_endpoint_free(&client_ep, NULL);
+    MD_OR_USE_PSA_DONE();
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_SSL_PROTO_TLS1_3:MBEDTLS_SSL_KEYING_MATERIAL_EXPORT */
+void ssl_tls13_exporter_uses_length(void)
+{
+    /* In TLS 1.3, when two keys are exported with the same parameters except one is shorter,
+     * the shorter key should NOT be a prefix of the longer one. */
+    int ret = -1;
+    mbedtls_test_ssl_endpoint client_ep, server_ep; /* Zeroed so exit can always free them. */
+    memset(&client_ep, 0, sizeof(client_ep));
+    memset(&server_ep, 0, sizeof(server_ep));
+    MD_OR_USE_PSA_INIT();
+    ret = mbedtls_test_ssl_do_handshake_with_endpoints(&server_ep, &client_ep, MBEDTLS_SSL_VERSION_TLS1_3);
+    TEST_ASSERT(ret == 0);
+    char label[] = "test-label";
+    uint8_t key_buffer_server[16] = { 0 };
+    uint8_t key_buffer_client[24] = { 0 };
+    unsigned char context[128] = { 0 };
+    ret = mbedtls_ssl_export_keying_material(&server_ep.ssl,
+                                             key_buffer_server, sizeof(key_buffer_server),
+                                             label, sizeof(label),
+                                             context, sizeof(context), 1);
+    TEST_ASSERT(ret == 0);
+    ret = mbedtls_ssl_export_keying_material(&client_ep.ssl,
+                                             key_buffer_client, sizeof(key_buffer_client),
+                                             label, sizeof(label),
+                                             context, sizeof(context), 1);
+    TEST_ASSERT(ret == 0);
+    TEST_ASSERT(memcmp(key_buffer_server, key_buffer_client, sizeof(key_buffer_server)) != 0);
+
+exit:
+    mbedtls_test_ssl_endpoint_free(&server_ep, NULL);
+    mbedtls_test_ssl_endpoint_free(&client_ep, NULL);
+    MD_OR_USE_PSA_DONE();
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_SSL_KEYING_MATERIAL_EXPORT */
+void ssl_tls_exporter_rejects_bad_parameters(
+    int proto, int exported_key_length, int label_length, int context_length)
+{
+    int ret = -1;
+    uint8_t *key_buffer = NULL;
+    char *label = NULL;
+    uint8_t *context = NULL;
+    mbedtls_test_ssl_endpoint client_ep, server_ep; /* Zeroed so exit can always free them. */
+    memset(&client_ep, 0, sizeof(client_ep));
+    memset(&server_ep, 0, sizeof(server_ep));
+    MD_OR_USE_PSA_INIT(); /* Only once everything released at exit is initialized. */
+    TEST_ASSERT(exported_key_length > 0);
+    TEST_ASSERT(label_length > 0);
+    TEST_ASSERT(context_length > 0);
+    TEST_CALLOC(key_buffer, exported_key_length);
+    TEST_CALLOC(label, label_length);
+    TEST_CALLOC(context, context_length);
+    ret = mbedtls_test_ssl_do_handshake_with_endpoints(&server_ep, &client_ep, proto);
+    TEST_ASSERT(ret == 0);
+    ret = mbedtls_ssl_export_keying_material(&client_ep.ssl,
+                                             key_buffer, exported_key_length,
+                                             label, label_length,
+                                             context, context_length, 1);
+    TEST_ASSERT(ret == MBEDTLS_ERR_SSL_BAD_INPUT_DATA);
+
+exit:
+    mbedtls_test_ssl_endpoint_free(&server_ep, NULL);
+    mbedtls_test_ssl_endpoint_free(&client_ep, NULL);
+    MD_OR_USE_PSA_DONE();
+    mbedtls_free(key_buffer);
+    mbedtls_free(label);
+    mbedtls_free(context);
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_SSL_KEYING_MATERIAL_EXPORT */
+void ssl_tls_exporter_too_early(int proto, int check_server, int state)
+{
+    /* Test that exporting fails after driving the handshake only up to `state`. */
+    enum { BUFFSIZE = 1024 };
+    int ret = -1;
+    mbedtls_test_ssl_endpoint server_ep, client_ep; /* Zeroed so exit can always free them. */
+    memset(&server_ep, 0, sizeof(server_ep));
+    memset(&client_ep, 0, sizeof(client_ep));
+    mbedtls_test_handshake_test_options options;
+    mbedtls_test_init_handshake_options(&options);
+    options.server_min_version = proto;
+    options.client_min_version = proto;
+    options.server_max_version = proto;
+    options.client_max_version = proto;
+
+    MD_OR_USE_PSA_INIT();
+    ret = mbedtls_test_ssl_endpoint_init(&server_ep, MBEDTLS_SSL_IS_SERVER, &options,
+                                         NULL, NULL, NULL);
+    TEST_ASSERT(ret == 0);
+    ret = mbedtls_test_ssl_endpoint_init(&client_ep, MBEDTLS_SSL_IS_CLIENT, &options,
+                                         NULL, NULL, NULL);
+    TEST_ASSERT(ret == 0);
+    ret = mbedtls_test_mock_socket_connect(&client_ep.socket, &server_ep.socket, BUFFSIZE);
+    TEST_ASSERT(ret == 0);
+
+    if (check_server) {
+        ret = mbedtls_test_move_handshake_to_state(&server_ep.ssl, &client_ep.ssl, state);
+    } else {
+        ret = mbedtls_test_move_handshake_to_state(&client_ep.ssl, &server_ep.ssl, state);
+    }
+    TEST_ASSERT(ret == 0 || ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE);
+    char label[] = "test-label";
+    uint8_t key_buffer[24] = { 0 };
+    ret = mbedtls_ssl_export_keying_material(check_server ? &server_ep.ssl : &client_ep.ssl,
+                                             key_buffer, sizeof(key_buffer),
+                                             label, sizeof(label),
+                                             NULL, 0, 0);
+    /* FIXME: A more appropriate error code should be created for this case. */
+    TEST_ASSERT(ret == MBEDTLS_ERR_SSL_BAD_INPUT_DATA);
+
+exit:
+    mbedtls_test_ssl_endpoint_free(&server_ep, NULL);
+    mbedtls_test_ssl_endpoint_free(&client_ep, NULL);
+    mbedtls_test_free_handshake_options(&options);
+    MD_OR_USE_PSA_DONE();
 }
 /* END_CASE */