Merge branch 'development-restricted' into update-development-r

Signed-off-by: Dave Rodgman <dave.rodgman@arm.com>
diff --git a/library/cmac.c b/library/cmac.c
index f40cae2..56a9c71 100644
--- a/library/cmac.c
+++ b/library/cmac.c
@@ -34,6 +34,7 @@
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
 #include "mbedtls/platform.h"
+#include "constant_time_internal.h"
 
 #include <string.h>
 
@@ -57,7 +58,7 @@
 {
     const unsigned char R_128 = 0x87;
     const unsigned char R_64 = 0x1B;
-    unsigned char R_n, mask;
+    unsigned char R_n;
     unsigned char overflow = 0x00;
     int i;
 
@@ -74,21 +75,8 @@
         overflow = input[i] >> 7;
     }
 
-    /* mask = ( input[0] >> 7 ) ? 0xff : 0x00
-     * using bit operations to avoid branches */
-
-    /* MSVC has a warning about unary minus on unsigned, but this is
-     * well-defined and precisely what we want to do here */
-#if defined(_MSC_VER)
-#pragma warning( push )
-#pragma warning( disable : 4146 )
-#endif
-    mask = -(input[0] >> 7);
-#if defined(_MSC_VER)
-#pragma warning( pop )
-#endif
-
-    output[blocksize - 1] ^= R_n & mask;
+    R_n = (unsigned char) mbedtls_ct_uint_if_else_0(mbedtls_ct_bool(input[0] >> 7), R_n);
+    output[blocksize - 1] ^= R_n;
 
     return 0;
 }
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 70500aa..57844c5 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -186,6 +186,23 @@
     } \
     output_copy = LOCAL_OUTPUT_COPY_OF_##output.buffer;
 
+/* Like LOCAL_OUTPUT_ALLOC() (and likewise released by LOCAL_OUTPUT_FREE()),
+ * but the local copy starts out holding the current contents of output.
+ *
+ * Assumptions:
+ * - psa_status_t status exists
+ * - An exit label is declared
+ * - output is the name of a pointer to the buffer to be copied
+ * - LOCAL_OUTPUT_DECLARE(output, output_copy) has previously been called
+ */
+#define LOCAL_OUTPUT_ALLOC_WITH_COPY(output, length, output_copy) \
+    status = psa_crypto_local_output_alloc_with_copy(output, length, \
+                                                     &LOCAL_OUTPUT_COPY_OF_##output); \
+    if (status != PSA_SUCCESS) { \
+        goto exit; \
+    } \
+    output_copy = LOCAL_OUTPUT_COPY_OF_##output.buffer;
+
 /* Free the local output copy allocated previously by LOCAL_OUTPUT_ALLOC()
  * after first copying back its contents to the original buffer.
  *
@@ -1455,13 +1472,14 @@
 }
 
 psa_status_t psa_export_key(mbedtls_svc_key_id_t key,
-                            uint8_t *data,
+                            uint8_t *data_external,
                             size_t data_size,
                             size_t *data_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot;
+    LOCAL_OUTPUT_DECLARE(data_external, data);
 
     /* Reject a zero-length output buffer now, since this can never be a
      * valid key representation. This way we know that data must be a valid
@@ -1486,6 +1504,8 @@
         return status;
     }
 
+    LOCAL_OUTPUT_ALLOC(data_external, data_size, data);
+
     psa_key_attributes_t attributes = {
         .core = slot->attr
     };
@@ -1493,8 +1513,12 @@
                                            slot->key.data, slot->key.bytes,
                                            data, data_size, data_length);
 
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
     unlock_status = psa_unregister_read(slot);
 
+    LOCAL_OUTPUT_FREE(data_external, data);
     return (status == PSA_SUCCESS) ? unlock_status : status;
 }
 
@@ -1566,7 +1590,7 @@
 }
 
 psa_status_t psa_export_public_key(mbedtls_svc_key_id_t key,
-                                   uint8_t *data,
+                                   uint8_t *data_external,
                                    size_t data_size,
                                    size_t *data_length)
 {
@@ -1574,6 +1598,7 @@
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot;
     psa_key_attributes_t attributes;
+    LOCAL_OUTPUT_DECLARE(data_external, data);
 
     /* Reject a zero-length output buffer now, since this can never be a
      * valid key representation. This way we know that data must be a valid
@@ -1594,6 +1619,8 @@
         return status;
     }
 
+    LOCAL_OUTPUT_ALLOC(data_external, data_size, data);
+
     if (!PSA_KEY_TYPE_IS_ASYMMETRIC(slot->attr.type)) {
         status = PSA_ERROR_INVALID_ARGUMENT;
         goto exit;
@@ -1609,6 +1636,7 @@
 exit:
     unlock_status = psa_unregister_read(slot);
 
+    LOCAL_OUTPUT_FREE(data_external, data);
     return (status == PSA_SUCCESS) ? unlock_status : status;
 }
 
@@ -2055,11 +2083,12 @@
 }
 
 psa_status_t psa_import_key(const psa_key_attributes_t *attributes,
-                            const uint8_t *data,
+                            const uint8_t *data_external,
                             size_t data_length,
                             mbedtls_svc_key_id_t *key)
 {
     psa_status_t status;
+    LOCAL_INPUT_DECLARE(data_external, data);
     psa_key_slot_t *slot = NULL;
     psa_se_drv_table_entry_t *driver = NULL;
     size_t bits;
@@ -2079,6 +2108,8 @@
         return PSA_ERROR_NOT_SUPPORTED;
     }
 
+    LOCAL_INPUT_ALLOC(data_external, data_length, data);
+
     status = psa_start_key_creation(PSA_KEY_CREATION_IMPORT, attributes,
                                     &slot, &driver);
     if (status != PSA_SUCCESS) {
@@ -2133,6 +2164,7 @@
 
     status = psa_finish_key_creation(slot, driver, key);
 exit:
+    LOCAL_INPUT_FREE(data_external, data);
     if (status != PSA_SUCCESS) {
         psa_fail_key_creation(slot, driver);
     }
@@ -3021,15 +3053,27 @@
 
 psa_status_t psa_sign_message(mbedtls_svc_key_id_t key,
                               psa_algorithm_t alg,
-                              const uint8_t *input,
+                              const uint8_t *input_external,
                               size_t input_length,
-                              uint8_t *signature,
+                              uint8_t *signature_external,
                               size_t signature_size,
                               size_t *signature_length)
 {
-    return psa_sign_internal(
-        key, 1, alg, input, input_length,
-        signature, signature_size, signature_length);
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(input_external, input);
+    LOCAL_OUTPUT_DECLARE(signature_external, signature);
+    *signature_length = 0; /* Left at 0 if a buffer copy below fails. */
+    LOCAL_INPUT_ALLOC(input_external, input_length, input);
+    LOCAL_OUTPUT_ALLOC(signature_external, signature_size, signature);
+    status = psa_sign_internal(key, 1, alg, input, input_length, signature,
+                               signature_size, signature_length);
+    /* Only the copying LOCAL_*_ALLOC macros jump to exit, hence the guard. */
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(input_external, input);
+    LOCAL_OUTPUT_FREE(signature_external, signature);
+    return status;
 }
 
 psa_status_t psa_verify_message_builtin(
@@ -3068,14 +3112,27 @@
 
 psa_status_t psa_verify_message(mbedtls_svc_key_id_t key,
                                 psa_algorithm_t alg,
-                                const uint8_t *input,
+                                const uint8_t *input_external,
                                 size_t input_length,
-                                const uint8_t *signature,
+                                const uint8_t *signature_external,
                                 size_t signature_length)
 {
-    return psa_verify_internal(
-        key, 1, alg, input, input_length,
-        signature, signature_length);
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(input_external, input);
+    LOCAL_INPUT_DECLARE(signature_external, signature);
+    /* Verify using only the LOCAL_INPUT_ copies of the caller's buffers. */
+    LOCAL_INPUT_ALLOC(input_external, input_length, input);
+    LOCAL_INPUT_ALLOC(signature_external, signature_length, signature);
+    status = psa_verify_internal(key, 1, alg, input, input_length, signature,
+                                 signature_length);
+    /* Only the copying LOCAL_*_ALLOC macros jump to exit, hence the guard. */
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(input_external, input);
+    LOCAL_INPUT_FREE(signature_external, signature);
+
+    return status;
 }
 
 psa_status_t psa_sign_hash_builtin(
@@ -3128,15 +3185,28 @@
 
 psa_status_t psa_sign_hash(mbedtls_svc_key_id_t key,
                            psa_algorithm_t alg,
-                           const uint8_t *hash,
+                           const uint8_t *hash_external,
                            size_t hash_length,
-                           uint8_t *signature,
+                           uint8_t *signature_external,
                            size_t signature_size,
                            size_t *signature_length)
 {
-    return psa_sign_internal(
-        key, 0, alg, hash, hash_length,
-        signature, signature_size, signature_length);
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(hash_external, hash);
+    LOCAL_OUTPUT_DECLARE(signature_external, signature);
+    *signature_length = 0; /* Left at 0 if a buffer copy below fails. */
+    LOCAL_INPUT_ALLOC(hash_external, hash_length, hash);
+    LOCAL_OUTPUT_ALLOC(signature_external, signature_size, signature);
+    status = psa_sign_internal(key, 0, alg, hash, hash_length, signature,
+                               signature_size, signature_length);
+    /* Only the copying LOCAL_*_ALLOC macros jump to exit, hence the guard. */
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(hash_external, hash);
+    LOCAL_OUTPUT_FREE(signature_external, signature);
+
+    return status;
 }
 
 psa_status_t psa_verify_hash_builtin(
@@ -3188,14 +3258,27 @@
 
 psa_status_t psa_verify_hash(mbedtls_svc_key_id_t key,
                              psa_algorithm_t alg,
-                             const uint8_t *hash,
+                             const uint8_t *hash_external,
                              size_t hash_length,
-                             const uint8_t *signature,
+                             const uint8_t *signature_external,
                              size_t signature_length)
 {
-    return psa_verify_internal(
-        key, 0, alg, hash, hash_length,
-        signature, signature_length);
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(hash_external, hash);
+    LOCAL_INPUT_DECLARE(signature_external, signature);
+    /* Verify using only the LOCAL_INPUT_ copies of the caller's buffers. */
+    LOCAL_INPUT_ALLOC(hash_external, hash_length, hash);
+    LOCAL_INPUT_ALLOC(signature_external, signature_length, signature);
+    status = psa_verify_internal(key, 0, alg, hash, hash_length, signature,
+                                 signature_length);
+    /* Only the copying LOCAL_*_ALLOC macros jump to exit, hence the guard. */
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(hash_external, hash);
+    LOCAL_INPUT_FREE(signature_external, signature);
+
+    return status;
 }
 
 psa_status_t psa_asymmetric_encrypt(mbedtls_svc_key_id_t key,
@@ -8576,6 +8659,39 @@
     return PSA_SUCCESS;
 }
 
+/* Allocate a local output copy pre-filled with the current contents of
+ * output (read through psa_crypto_copy_input()); see psa_crypto_core.h. */
+psa_status_t psa_crypto_local_output_alloc_with_copy(uint8_t *output, size_t output_len,
+                                                     psa_crypto_local_output_t *local_output)
+{
+    psa_status_t status;
+    *local_output = PSA_CRYPTO_LOCAL_OUTPUT_INIT;
+
+    if (output_len == 0) {
+        return PSA_SUCCESS;
+    }
+    local_output->buffer = mbedtls_calloc(output_len, 1);
+    if (local_output->buffer == NULL) {
+        /* Since we dealt with the zero-length case above, we know that
+         * a NULL return value means a failure of allocation. */
+        return PSA_ERROR_INSUFFICIENT_MEMORY;
+    }
+    local_output->length = output_len;
+    local_output->original = output;
+
+    status = psa_crypto_copy_input(output, output_len,
+                                   local_output->buffer, local_output->length);
+    if (status != PSA_SUCCESS) {
+        /* Reset the whole struct, not just buffer/length, so no stale
+         * pointer to the caller's buffer is left in original. */
+        mbedtls_free(local_output->buffer);
+        *local_output = PSA_CRYPTO_LOCAL_OUTPUT_INIT;
+        return status;
+    }
+
+    return PSA_SUCCESS;
+}
+
 psa_status_t psa_crypto_local_output_free(psa_crypto_local_output_t *local_output)
 {
     psa_status_t status;
diff --git a/library/psa_crypto_core.h b/library/psa_crypto_core.h
index 8d8f153..67966f9 100644
--- a/library/psa_crypto_core.h
+++ b/library/psa_crypto_core.h
@@ -931,6 +931,25 @@
 psa_status_t psa_crypto_local_output_alloc(uint8_t *output, size_t output_len,
                                            psa_crypto_local_output_t *local_output);
 
+/** Allocate a local copy of an output buffer, initialised with its contents.
+ *
+ * \note                        The local copy is copied back to the original
+ *                              buffer and then freed by a later call to
+ *                              psa_crypto_local_output_free().
+ *
+ * \param[in] output            Pointer to output buffer (read by this call).
+ * \param[in] output_len        Length of the output buffer. If this is 0,
+ *                              nothing is allocated.
+ * \param[out] local_output     Pointer to a psa_crypto_local_output_t struct to
+ *                              populate with the local output copy.
+ * \return                      #PSA_SUCCESS, if the buffer was successfully
+ *                              copied, or if \p output_len is 0.
+ * \return                      #PSA_ERROR_INSUFFICIENT_MEMORY, if a copy of
+ *                              the buffer cannot be allocated.
+ */
+psa_status_t psa_crypto_local_output_alloc_with_copy(uint8_t *output, size_t output_len,
+                                                     psa_crypto_local_output_t *local_output);
+
 /** Copy from a local copy of an output buffer back to the original, then
  *  free the local copy.
  *
diff --git a/tests/include/test/psa_crypto_helpers.h b/tests/include/test/psa_crypto_helpers.h
index bf1a594..7306d8e 100644
--- a/tests/include/test/psa_crypto_helpers.h
+++ b/tests/include/test/psa_crypto_helpers.h
@@ -16,13 +16,6 @@
 #include <psa/crypto.h>
 #endif
 
-#include "test/psa_test_wrappers.h"
-
-#if defined(MBEDTLS_TEST_HOOKS) && defined(MBEDTLS_PSA_CRYPTO_C) \
-    && defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
-#include "test/psa_memory_poisoning_wrappers.h"
-#endif
-
 #if defined(MBEDTLS_PSA_CRYPTO_C)
 /** Initialize the PSA Crypto subsystem. */
 #define PSA_INIT() PSA_ASSERT(psa_crypto_init())
diff --git a/tests/scripts/generate_psa_wrappers.py b/tests/scripts/generate_psa_wrappers.py
index e5b4256..3cdafed 100755
--- a/tests/scripts/generate_psa_wrappers.py
+++ b/tests/scripts/generate_psa_wrappers.py
@@ -145,6 +145,15 @@
-        # Proof-of-concept: just instrument one function for now
+        # Functions whose buffer arguments the generated wrappers poison
         if function_name == 'psa_cipher_encrypt':
             return True
+        if function_name in ('psa_import_key',
+                             'psa_export_key',
+                             'psa_export_public_key'):
+            return True
+        if function_name in ('psa_sign_message',
+                             'psa_verify_message',
+                             'psa_sign_hash',
+                             'psa_verify_hash'):
+            return True
         return False
 
     def _write_function_call(self, out: typing_util.Writable,
diff --git a/tests/src/psa_test_wrappers.c b/tests/src/psa_test_wrappers.c
index 3a3aaad..bb1409e 100644
--- a/tests/src/psa_test_wrappers.c
+++ b/tests/src/psa_test_wrappers.c
@@ -435,7 +435,13 @@
     size_t arg2_data_size,
     size_t *arg3_data_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg1_data, arg2_data_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_export_key)(arg0_key, arg1_data, arg2_data_size, arg3_data_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg1_data, arg2_data_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -446,7 +452,13 @@
     size_t arg2_data_size,
     size_t *arg3_data_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg1_data, arg2_data_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_export_public_key)(arg0_key, arg1_data, arg2_data_size, arg3_data_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg1_data, arg2_data_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -566,7 +578,13 @@
     size_t arg2_data_length,
     mbedtls_svc_key_id_t *arg3_key)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg1_data, arg2_data_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_import_key)(arg0_attributes, arg1_data, arg2_data_length, arg3_key);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg1_data, arg2_data_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -873,7 +891,15 @@
     size_t arg5_signature_size,
     size_t *arg6_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_hash, arg3_hash_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_sign_hash)(arg0_key, arg1_alg, arg2_hash, arg3_hash_length, arg4_signature, arg5_signature_size, arg6_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_hash, arg3_hash_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -918,7 +944,15 @@
     size_t arg5_signature_size,
     size_t *arg6_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_sign_message)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_signature, arg5_signature_size, arg6_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -931,7 +965,15 @@
     const uint8_t *arg4_signature,
     size_t arg5_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_hash, arg3_hash_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_verify_hash)(arg0_key, arg1_alg, arg2_hash, arg3_hash_length, arg4_signature, arg5_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_hash, arg3_hash_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -974,7 +1016,15 @@
     const uint8_t *arg4_signature,
     size_t arg5_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_verify_message)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_signature, arg5_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }