Merge pull request #1127 from davidhorstmann-arm/prototype-single-fn-copytesting

Prototype poisoning testing with a single function
diff --git a/include/mbedtls/mbedtls_config.h b/include/mbedtls/mbedtls_config.h
index 96a3e43..ed33828 100644
--- a/include/mbedtls/mbedtls_config.h
+++ b/include/mbedtls/mbedtls_config.h
@@ -1469,6 +1469,22 @@
 //#define MBEDTLS_PSA_INJECT_ENTROPY
 
 /**
+ * \def MBEDTLS_PSA_COPY_CALLER_BUFFERS
+ *
+ * Make local copies of buffers supplied by the callers of PSA functions.
+ *
+ * This should be enabled whenever caller-supplied buffers are owned by
+ * an untrusted party, for example where arguments to PSA calls are passed
+ * across a trust boundary.
+ *
+ * \note Enabling this option increases memory usage and code size.
+ *
+ * \note Disabling this option causes overlap of input and output buffers
+ *       not to be supported by PSA functions.
+ */
+#define MBEDTLS_PSA_COPY_CALLER_BUFFERS
+
+/**
  * \def MBEDTLS_RSA_NO_CRT
  *
  * Do not use the Chinese Remainder Theorem
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index e3187d8..8c9f9de 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -110,6 +110,118 @@
     if (global_data.initialized == 0)  \
     return PSA_ERROR_BAD_STATE;
 
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+
+/* Declare a local copy of an input buffer and a variable that will be used
+ * to store a pointer to the start of the buffer.
+ *
+ * Note: This macro must be called before any operations which may jump to
+ * the exit label, so that the local input copy object is safe to be freed.
+ *
+ * Assumptions:
+ * - input is the name of a pointer to the buffer to be copied
+ * - The name LOCAL_INPUT_COPY_OF_input is unused in the current scope
+ * - input_copy_name is a name that is unused in the current scope
+ */
+#define LOCAL_INPUT_DECLARE(input, input_copy_name) \
+    psa_crypto_local_input_t LOCAL_INPUT_COPY_OF_##input = PSA_CRYPTO_LOCAL_INPUT_INIT; \
+    const uint8_t *input_copy_name = NULL;
+
+/* Allocate a copy of the buffer input and set the pointer input_copy to
+ * point to the start of the copy.
+ *
+ * Assumptions:
+ * - psa_status_t status exists
+ * - An exit label is declared
+ * - input is the name of a pointer to the buffer to be copied
+ * - LOCAL_INPUT_DECLARE(input, input_copy) has previously been called
+ */
+#define LOCAL_INPUT_ALLOC(input, length, input_copy) \
+    status = psa_crypto_local_input_alloc(input, length, \
+                                          &LOCAL_INPUT_COPY_OF_##input); \
+    if (status != PSA_SUCCESS) { \
+        goto exit; \
+    } \
+    input_copy = LOCAL_INPUT_COPY_OF_##input.buffer;
+
+/* Free the local input copy allocated previously by LOCAL_INPUT_ALLOC()
+ *
+ * Assumptions:
+ * - input_copy is the name of the input copy pointer set by LOCAL_INPUT_ALLOC()
+ * - input is the name of the original buffer that was copied
+ */
+#define LOCAL_INPUT_FREE(input, input_copy) \
+    input_copy = NULL; \
+    psa_crypto_local_input_free(&LOCAL_INPUT_COPY_OF_##input);
+
+/* Declare a local copy of an output buffer and a variable that will be used
+ * to store a pointer to the start of the buffer.
+ *
+ * Note: This macro must be called before any operations which may jump to
+ * the exit label, so that the local output copy object is safe to be freed.
+ *
+ * Assumptions:
+ * - output is the name of a pointer to the buffer to be copied
+ * - The name LOCAL_OUTPUT_COPY_OF_output is unused in the current scope
+ * - output_copy_name is a name that is unused in the current scope
+ */
+#define LOCAL_OUTPUT_DECLARE(output, output_copy_name) \
+    psa_crypto_local_output_t LOCAL_OUTPUT_COPY_OF_##output = PSA_CRYPTO_LOCAL_OUTPUT_INIT; \
+    uint8_t *output_copy_name = NULL;
+
+/* Allocate a copy of the buffer output and set the pointer output_copy to
+ * point to the start of the copy.
+ *
+ * Assumptions:
+ * - psa_status_t status exists
+ * - An exit label is declared
+ * - output is the name of a pointer to the buffer to be copied
+ * - LOCAL_OUTPUT_DECLARE(output, output_copy) has previously been called
+ */
+#define LOCAL_OUTPUT_ALLOC(output, length, output_copy) \
+    status = psa_crypto_local_output_alloc(output, length, \
+                                           &LOCAL_OUTPUT_COPY_OF_##output); \
+    if (status != PSA_SUCCESS) { \
+        goto exit; \
+    } \
+    output_copy = LOCAL_OUTPUT_COPY_OF_##output.buffer;
+
+/* Free the local output copy allocated previously by LOCAL_OUTPUT_ALLOC()
+ * after first copying back its contents to the original buffer.
+ *
+ * Assumptions:
+ * - psa_status_t status exists
+ * - output_copy is the name of the output copy pointer set by LOCAL_OUTPUT_ALLOC()
+ * - output is the name of the original buffer that was copied
+ */
+#define LOCAL_OUTPUT_FREE(output, output_copy) \
+    output_copy = NULL; \
+    do { \
+        psa_status_t local_output_status; \
+        local_output_status = psa_crypto_local_output_free(&LOCAL_OUTPUT_COPY_OF_##output); \
+        if (local_output_status != PSA_SUCCESS) { \
+            /* Since this error case is an internal error, it's more serious than \
+             * any existing error code and so it's fine to overwrite the existing \
+             * status. */ \
+            status = local_output_status; \
+        } \
+    } while (0)
+#else /* MBEDTLS_PSA_COPY_CALLER_BUFFERS */
+#define LOCAL_INPUT_DECLARE(input, input_copy_name) \
+    const uint8_t *input_copy_name = NULL;
+#define LOCAL_INPUT_ALLOC(input, length, input_copy) \
+    input_copy = input;
+#define LOCAL_INPUT_FREE(input, input_copy) \
+    input_copy = NULL;
+#define LOCAL_OUTPUT_DECLARE(output, output_copy_name) \
+    uint8_t *output_copy_name = NULL;
+#define LOCAL_OUTPUT_ALLOC(output, length, output_copy) \
+    output_copy = output;
+#define LOCAL_OUTPUT_FREE(output, output_copy) \
+    output_copy = NULL;
+#endif /* MBEDTLS_PSA_COPY_CALLER_BUFFERS */
+
+
 int psa_can_do_hash(psa_algorithm_t hash_alg)
 {
     (void) hash_alg;
@@ -4329,9 +4441,9 @@
 
 psa_status_t psa_cipher_encrypt(mbedtls_svc_key_id_t key,
                                 psa_algorithm_t alg,
-                                const uint8_t *input,
+                                const uint8_t *input_external,
                                 size_t input_length,
-                                uint8_t *output,
+                                uint8_t *output_external,
                                 size_t output_size,
                                 size_t *output_length)
 {
@@ -4342,6 +4454,12 @@
     size_t default_iv_length = 0;
     psa_key_attributes_t attributes;
 
+    LOCAL_INPUT_DECLARE(input_external, input);
+    LOCAL_OUTPUT_DECLARE(output_external, output);
+
+    LOCAL_INPUT_ALLOC(input_external, input_length, input);
+    LOCAL_OUTPUT_ALLOC(output_external, output_size, output);
+
     if (!PSA_ALG_IS_CIPHER(alg)) {
         status = PSA_ERROR_INVALID_ARGUMENT;
         goto exit;
@@ -4397,6 +4515,9 @@
         *output_length = 0;
     }
 
+    LOCAL_INPUT_FREE(input_external, input);
+    LOCAL_OUTPUT_FREE(output_external, output);
+
     return status;
 }
 
@@ -8430,6 +8551,16 @@
 }
 #endif /* PSA_WANT_ALG_SOME_PAKE */
 
+/* Memory copying test hooks. These are called before input copy, after input
+ * copy, before output copy and after output copy, respectively.
+ * They are used by memory-poisoning tests to temporarily unpoison buffers
+ * while they are copied. */
+#if defined(MBEDTLS_TEST_HOOKS)
+void (*psa_input_pre_copy_hook)(const uint8_t *input, size_t input_len) = NULL;
+void (*psa_input_post_copy_hook)(const uint8_t *input, size_t input_len) = NULL;
+void (*psa_output_pre_copy_hook)(const uint8_t *output, size_t output_len) = NULL;
+void (*psa_output_post_copy_hook)(const uint8_t *output, size_t output_len) = NULL;
+#endif
 
 /** Copy from an input buffer to a local copy.
  *
@@ -8451,10 +8582,22 @@
         return PSA_ERROR_CORRUPTION_DETECTED;
     }
 
+#if defined(MBEDTLS_TEST_HOOKS)
+    if (psa_input_pre_copy_hook != NULL) {
+        psa_input_pre_copy_hook(input, input_len);
+    }
+#endif
+
     if (input_len > 0) {
         memcpy(input_copy, input, input_len);
     }
 
+#if defined(MBEDTLS_TEST_HOOKS)
+    if (psa_input_post_copy_hook != NULL) {
+        psa_input_post_copy_hook(input, input_len);
+    }
+#endif
+
     return PSA_SUCCESS;
 }
 
@@ -8478,10 +8621,22 @@
         return PSA_ERROR_BUFFER_TOO_SMALL;
     }
 
+#if defined(MBEDTLS_TEST_HOOKS)
+    if (psa_output_pre_copy_hook != NULL) {
+        psa_output_pre_copy_hook(output, output_len);
+    }
+#endif
+
     if (output_copy_len > 0) {
         memcpy(output, output_copy, output_copy_len);
     }
 
+#if defined(MBEDTLS_TEST_HOOKS)
+    if (psa_output_post_copy_hook != NULL) {
+        psa_output_post_copy_hook(output, output_len);
+    }
+#endif
+
     return PSA_SUCCESS;
 }
 
diff --git a/library/psa_crypto_invasive.h b/library/psa_crypto_invasive.h
index 6a1181f..51c90c6 100644
--- a/library/psa_crypto_invasive.h
+++ b/library/psa_crypto_invasive.h
@@ -79,6 +79,14 @@
 psa_status_t psa_crypto_copy_output(const uint8_t *output_copy, size_t output_copy_len,
                                     uint8_t *output, size_t output_len);
 
+/*
+ * Test hooks to use for memory unpoisoning/poisoning in copy functions.
+ */
+extern void (*psa_input_pre_copy_hook)(const uint8_t *input, size_t input_len);
+extern void (*psa_input_post_copy_hook)(const uint8_t *input, size_t input_len);
+extern void (*psa_output_pre_copy_hook)(const uint8_t *output, size_t output_len);
+extern void (*psa_output_post_copy_hook)(const uint8_t *output, size_t output_len);
+
 #endif /* MBEDTLS_TEST_HOOKS && MBEDTLS_PSA_CRYPTO_C */
 
 #endif /* PSA_CRYPTO_INVASIVE_H */
diff --git a/tests/include/test/psa_crypto_helpers.h b/tests/include/test/psa_crypto_helpers.h
index 04b90b9..41b7752 100644
--- a/tests/include/test/psa_crypto_helpers.h
+++ b/tests/include/test/psa_crypto_helpers.h
@@ -16,6 +16,10 @@
 #include <psa/crypto.h>
 #endif
 
+#if defined(MBEDTLS_TEST_HOOKS) && defined(MBEDTLS_PSA_CRYPTO_C) \
+    && defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+#include "test/psa_memory_poisoning_wrappers.h"
+#endif
 
 #if defined(MBEDTLS_PSA_CRYPTO_C)
 /** Initialize the PSA Crypto subsystem. */
diff --git a/tests/include/test/psa_memory_poisoning_wrappers.h b/tests/include/test/psa_memory_poisoning_wrappers.h
new file mode 100644
index 0000000..0052c2f
--- /dev/null
+++ b/tests/include/test/psa_memory_poisoning_wrappers.h
@@ -0,0 +1,47 @@
+/** Memory poisoning wrappers for PSA functions.
+ *
+ *  These wrappers poison the input and output buffers of each function
+ *  before calling it, to ensure that it does not access the buffers
+ *  except by calling the approved buffer-copying functions.
+ */
+/*
+ *  Copyright The Mbed TLS Contributors
+ *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+ */
+
+#ifndef PSA_MEMORY_POISONING_WRAPPERS_H
+#define PSA_MEMORY_POISONING_WRAPPERS_H
+
+#include "psa/crypto.h"
+
+#include "test/memory.h"
+
+#if defined(MBEDTLS_TEST_HOOKS) && defined(MBEDTLS_TEST_MEMORY_CAN_POISON)
+
+/**
+ * \brief         Setup the memory poisoning test hooks used by
+ *                psa_crypto_copy_input() and psa_crypto_copy_output() for
+ *                memory poisoning.
+ */
+void mbedtls_poison_test_hooks_setup(void);
+
+/**
+ * \brief         Teardown the memory poisoning test hooks used by
+ *                psa_crypto_copy_input() and psa_crypto_copy_output() for
+ *                memory poisoning.
+ */
+void mbedtls_poison_test_hooks_teardown(void);
+
+psa_status_t wrap_psa_cipher_encrypt(mbedtls_svc_key_id_t key,
+                                     psa_algorithm_t alg,
+                                     const uint8_t *input,
+                                     size_t input_length,
+                                     uint8_t *output,
+                                     size_t output_size,
+                                     size_t *output_length);
+
+#define psa_cipher_encrypt(...) wrap_psa_cipher_encrypt(__VA_ARGS__)
+
+#endif /* MBEDTLS_TEST_HOOKS && MBEDTLS_TEST_MEMORY_CAN_POISON */
+
+#endif /* PSA_MEMORY_POISONING_WRAPPERS_H */
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index c3d2fb8..4465d05 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -1216,6 +1216,17 @@
     make test
 }
 
+component_test_no_psa_copy_caller_buffers () {
+    msg "build: full config - MBEDTLS_PSA_COPY_CALLER_BUFFERS, cmake, gcc, ASan"
+    scripts/config.py full
+    scripts/config.py unset MBEDTLS_PSA_COPY_CALLER_BUFFERS
+    CC=gcc cmake -D CMAKE_BUILD_TYPE:String=Asan .
+    make
+
+    msg "test: full config - MBEDTLS_PSA_COPY_CALLER_BUFFERS, cmake, gcc, ASan"
+    make test
+}
+
 # check_renamed_symbols HEADER LIB
 # Check that if HEADER contains '#define MACRO ...' then MACRO is not a symbol
 # name is LIB.
diff --git a/tests/src/helpers.c b/tests/src/helpers.c
index eb28919..36564fe 100644
--- a/tests/src/helpers.c
+++ b/tests/src/helpers.c
@@ -13,6 +13,10 @@
 #include <test/psa_crypto_helpers.h>
 #endif
 
+#if defined(MBEDTLS_TEST_HOOKS) && defined(MBEDTLS_PSA_CRYPTO_C)
+#include <test/psa_memory_poisoning_wrappers.h>
+#endif
+
 /*----------------------------------------------------------------------------*/
 /* Static global variables */
 
@@ -29,6 +33,12 @@
 {
     int ret = 0;
 
+#if defined(MBEDTLS_TEST_HOOKS) && defined(MBEDTLS_PSA_CRYPTO_C) \
+    && defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) \
+    && defined(MBEDTLS_TEST_MEMORY_CAN_POISON)
+    mbedtls_poison_test_hooks_setup();
+#endif
+
 #if defined(MBEDTLS_PSA_INJECT_ENTROPY)
     /* Make sure that injected entropy is present. Otherwise
      * psa_crypto_init() will fail. This is not necessary for test suites
@@ -49,6 +59,12 @@
 
 void mbedtls_test_platform_teardown(void)
 {
+#if defined(MBEDTLS_TEST_HOOKS) && defined(MBEDTLS_PSA_CRYPTO_C) \
+    && defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) \
+    &&  defined(MBEDTLS_TEST_MEMORY_CAN_POISON)
+    mbedtls_poison_test_hooks_teardown();
+#endif
+
 #if defined(MBEDTLS_PLATFORM_C)
     mbedtls_platform_teardown(&platform_ctx);
 #endif /* MBEDTLS_PLATFORM_C */
diff --git a/tests/src/psa_memory_poisoning_wrappers.c b/tests/src/psa_memory_poisoning_wrappers.c
new file mode 100644
index 0000000..a53e875
--- /dev/null
+++ b/tests/src/psa_memory_poisoning_wrappers.c
@@ -0,0 +1,53 @@
+/** Helper functions for memory poisoning in tests.
+ */
+/*
+ *  Copyright The Mbed TLS Contributors
+ *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+ */
+#include "test/memory.h"
+
+#include "psa_crypto_invasive.h"
+
+#if defined(MBEDTLS_TEST_HOOKS)  && defined(MBEDTLS_PSA_CRYPTO_C) \
+    && defined(MBEDTLS_TEST_MEMORY_CAN_POISON)
+
+void mbedtls_poison_test_hooks_setup(void)
+{
+    psa_input_pre_copy_hook = mbedtls_test_memory_unpoison;
+    psa_input_post_copy_hook = mbedtls_test_memory_poison;
+    psa_output_pre_copy_hook = mbedtls_test_memory_unpoison;
+    psa_output_post_copy_hook = mbedtls_test_memory_poison;
+}
+
+void mbedtls_poison_test_hooks_teardown(void)
+{
+    psa_input_pre_copy_hook = NULL;
+    psa_input_post_copy_hook = NULL;
+    psa_output_pre_copy_hook = NULL;
+    psa_output_post_copy_hook = NULL;
+}
+
+psa_status_t wrap_psa_cipher_encrypt(mbedtls_svc_key_id_t key,
+                                     psa_algorithm_t alg,
+                                     const uint8_t *input,
+                                     size_t input_length,
+                                     uint8_t *output,
+                                     size_t output_size,
+                                     size_t *output_length)
+{
+    MBEDTLS_TEST_MEMORY_POISON(input, input_length);
+    MBEDTLS_TEST_MEMORY_POISON(output, output_size);
+    psa_status_t status = psa_cipher_encrypt(key,
+                                             alg,
+                                             input,
+                                             input_length,
+                                             output,
+                                             output_size,
+                                             output_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(input, input_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(output, output_size);
+    return status;
+}
+
+#endif /* MBEDTLS_TEST_HOOKS && MBEDTLS_PSA_CRYPTO_C &&
+          MBEDTLS_TEST_MEMORY_CAN_POISON */
diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
index 1d96f72..69dda35 100644
--- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function
+++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
@@ -1486,14 +1486,7 @@
         output, output_buffer_size, &function_output_length);
     TEST_EQUAL(mbedtls_test_driver_cipher_hooks.hits, 1);
     TEST_EQUAL(status, PSA_ERROR_GENERIC_ERROR);
-    /*
-     * Check that the output buffer is still in the same state.
-     * This will fail if the output buffer is used by the core to pass the IV
-     * it generated to the driver (and is not restored).
-     */
-    for (size_t i = 0; i < output_buffer_size; i++) {
-        TEST_EQUAL(output[i], 0xa5);
-    }
+
     mbedtls_test_driver_cipher_hooks.hits = 0;
 
     /* Test setup call, encrypt */
diff --git a/tests/suites/test_suite_psa_crypto_memory.function b/tests/suites/test_suite_psa_crypto_memory.function
index 2bb0f0d..55c0092 100644
--- a/tests/suites/test_suite_psa_crypto_memory.function
+++ b/tests/suites/test_suite_psa_crypto_memory.function
@@ -9,6 +9,7 @@
 #include "psa_crypto_invasive.h"
 
 #include "test/psa_crypto_helpers.h"
+#include "test/memory.h"
 
 /* Helper to fill a buffer with a data pattern. The pattern is not
  * important, it just allows a basic check that the correct thing has
@@ -42,6 +43,7 @@
     TEST_EQUAL(status, exp_status);
 
     if (exp_status == PSA_SUCCESS) {
+        MBEDTLS_TEST_MEMORY_UNPOISON(src_buffer, src_len);
         /* Note: We compare the first src_len bytes of each buffer, as this is what was copied. */
         TEST_MEMORY_COMPARE(src_buffer, src_len, dst_buffer, src_len);
     }
@@ -68,6 +70,7 @@
     TEST_EQUAL(status, exp_status);
 
     if (exp_status == PSA_SUCCESS) {
+        MBEDTLS_TEST_MEMORY_UNPOISON(dst_buffer, dst_len);
         /* Note: We compare the first src_len bytes of each buffer, as this is what was copied. */
         TEST_MEMORY_COMPARE(src_buffer, src_len, dst_buffer, src_len);
     }
@@ -94,6 +97,7 @@
     TEST_EQUAL(status, exp_status);
 
     if (exp_status == PSA_SUCCESS) {
+        MBEDTLS_TEST_MEMORY_UNPOISON(input, input_len);
         if (input_len != 0) {
             TEST_ASSERT(local_input.buffer != input);
         }
@@ -139,6 +143,8 @@
 
     status = psa_crypto_local_input_alloc(input, sizeof(input), &local_input);
     TEST_EQUAL(status, PSA_SUCCESS);
+
+    MBEDTLS_TEST_MEMORY_UNPOISON(input, sizeof(input));
     TEST_MEMORY_COMPARE(local_input.buffer, local_input.length,
                         input, sizeof(input));
     TEST_ASSERT(local_input.buffer != input);
@@ -204,6 +210,7 @@
     TEST_EQUAL(status, exp_status);
 
     if (exp_status == PSA_SUCCESS) {
+        MBEDTLS_TEST_MEMORY_UNPOISON(output, output_len);
         TEST_ASSERT(local_output.buffer == NULL);
         TEST_EQUAL(local_output.length, 0);
         TEST_MEMORY_COMPARE(buffer_copy_for_comparison, output_len,
@@ -240,6 +247,7 @@
     TEST_ASSERT(local_output.buffer == NULL);
     TEST_EQUAL(local_output.length, 0);
 
+    MBEDTLS_TEST_MEMORY_UNPOISON(output, sizeof(output));
     /* Check that the buffer was correctly copied back */
     TEST_MEMORY_COMPARE(output, sizeof(output),
                         buffer_copy_for_comparison, sizeof(output));