Merge pull request #1151 from tom-daubney-arm/asymmetric_sign_buffer_protection

Implement safe buffer copying in asymmetric signature API
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 3b3c116..d21c13e 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3165,15 +3165,27 @@
 
 psa_status_t psa_sign_message(mbedtls_svc_key_id_t key,
                               psa_algorithm_t alg,
-                              const uint8_t *input,
+                              const uint8_t *input_external,
                               size_t input_length,
-                              uint8_t *signature,
+                              uint8_t *signature_external,
                               size_t signature_size,
                               size_t *signature_length)
 {
-    return psa_sign_internal(
-        key, 1, alg, input, input_length,
-        signature, signature_size, signature_length);
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(input_external, input);
+    LOCAL_OUTPUT_DECLARE(signature_external, signature);
+
+    LOCAL_INPUT_ALLOC(input_external, input_length, input);
+    LOCAL_OUTPUT_ALLOC(signature_external, signature_size, signature);
+    status = psa_sign_internal(key, 1, alg, input, input_length, signature,
+                               signature_size, signature_length);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(input_external, input);
+    LOCAL_OUTPUT_FREE(signature_external, signature);
+    return status;
 }
 
 psa_status_t psa_verify_message_builtin(
@@ -3212,14 +3224,27 @@
 
 psa_status_t psa_verify_message(mbedtls_svc_key_id_t key,
                                 psa_algorithm_t alg,
-                                const uint8_t *input,
+                                const uint8_t *input_external,
                                 size_t input_length,
-                                const uint8_t *signature,
+                                const uint8_t *signature_external,
                                 size_t signature_length)
 {
-    return psa_verify_internal(
-        key, 1, alg, input, input_length,
-        signature, signature_length);
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(input_external, input);
+    LOCAL_INPUT_DECLARE(signature_external, signature);
+
+    LOCAL_INPUT_ALLOC(input_external, input_length, input);
+    LOCAL_INPUT_ALLOC(signature_external, signature_length, signature);
+    status = psa_verify_internal(key, 1, alg, input, input_length, signature,
+                                 signature_length);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(input_external, input);
+    LOCAL_INPUT_FREE(signature_external, signature);
+
+    return status;
 }
 
 psa_status_t psa_sign_hash_builtin(
@@ -3272,15 +3297,28 @@
 
 psa_status_t psa_sign_hash(mbedtls_svc_key_id_t key,
                            psa_algorithm_t alg,
-                           const uint8_t *hash,
+                           const uint8_t *hash_external,
                            size_t hash_length,
-                           uint8_t *signature,
+                           uint8_t *signature_external,
                            size_t signature_size,
                            size_t *signature_length)
 {
-    return psa_sign_internal(
-        key, 0, alg, hash, hash_length,
-        signature, signature_size, signature_length);
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(hash_external, hash);
+    LOCAL_OUTPUT_DECLARE(signature_external, signature);
+
+    LOCAL_INPUT_ALLOC(hash_external, hash_length, hash);
+    LOCAL_OUTPUT_ALLOC(signature_external, signature_size, signature);
+    status = psa_sign_internal(key, 0, alg, hash, hash_length, signature,
+                               signature_size, signature_length);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(hash_external, hash);
+    LOCAL_OUTPUT_FREE(signature_external, signature);
+
+    return status;
 }
 
 psa_status_t psa_verify_hash_builtin(
@@ -3332,14 +3370,27 @@
 
 psa_status_t psa_verify_hash(mbedtls_svc_key_id_t key,
                              psa_algorithm_t alg,
-                             const uint8_t *hash,
+                             const uint8_t *hash_external,
                              size_t hash_length,
-                             const uint8_t *signature,
+                             const uint8_t *signature_external,
                              size_t signature_length)
 {
-    return psa_verify_internal(
-        key, 0, alg, hash, hash_length,
-        signature, signature_length);
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    LOCAL_INPUT_DECLARE(hash_external, hash);
+    LOCAL_INPUT_DECLARE(signature_external, signature);
+
+    LOCAL_INPUT_ALLOC(hash_external, hash_length, hash);
+    LOCAL_INPUT_ALLOC(signature_external, signature_length, signature);
+    status = psa_verify_internal(key, 0, alg, hash, hash_length, signature,
+                                 signature_length);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+    LOCAL_INPUT_FREE(hash_external, hash);
+    LOCAL_INPUT_FREE(signature_external, signature);
+
+    return status;
 }
 
 psa_status_t psa_asymmetric_encrypt(mbedtls_svc_key_id_t key,
diff --git a/tests/scripts/generate_psa_wrappers.py b/tests/scripts/generate_psa_wrappers.py
index e5b4256..005a324 100755
--- a/tests/scripts/generate_psa_wrappers.py
+++ b/tests/scripts/generate_psa_wrappers.py
@@ -145,6 +145,11 @@
         # Proof-of-concept: just instrument one function for now
         if function_name == 'psa_cipher_encrypt':
             return True
+        if function_name in ('psa_sign_message',
+                             'psa_verify_message',
+                             'psa_sign_hash',
+                             'psa_verify_hash'):
+            return True
         return False
 
     def _write_function_call(self, out: typing_util.Writable,
diff --git a/tests/src/psa_test_wrappers.c b/tests/src/psa_test_wrappers.c
index 3a3aaad..460d453 100644
--- a/tests/src/psa_test_wrappers.c
+++ b/tests/src/psa_test_wrappers.c
@@ -873,7 +873,15 @@
     size_t arg5_signature_size,
     size_t *arg6_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_hash, arg3_hash_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_sign_hash)(arg0_key, arg1_alg, arg2_hash, arg3_hash_length, arg4_signature, arg5_signature_size, arg6_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_hash, arg3_hash_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -918,7 +926,15 @@
     size_t arg5_signature_size,
     size_t *arg6_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_sign_message)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_signature, arg5_signature_size, arg6_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -931,7 +947,15 @@
     const uint8_t *arg4_signature,
     size_t arg5_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_hash, arg3_hash_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_verify_hash)(arg0_key, arg1_alg, arg2_hash, arg3_hash_length, arg4_signature, arg5_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_hash, arg3_hash_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -974,7 +998,15 @@
     const uint8_t *arg4_signature,
     size_t arg5_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_verify_message)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_signature, arg5_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }