Merge pull request #1188 from davidhorstmann-arm/interruptible-sign-hash-buffer-protection

Add buffer protection for interruptible sign/verify
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 6538175..7473aef 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3556,13 +3556,15 @@
 psa_status_t psa_sign_hash_start(
     psa_sign_hash_interruptible_operation_t *operation,
     mbedtls_svc_key_id_t key, psa_algorithm_t alg,
-    const uint8_t *hash, size_t hash_length)
+    const uint8_t *hash_external, size_t hash_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot;
     psa_key_attributes_t attributes;
 
+    LOCAL_INPUT_DECLARE(hash_external, hash);
+
     /* Check that start has not been previously called, or operation has not
      * previously errored. */
     if (operation->id != 0 || operation->error_occurred) {
@@ -3588,6 +3590,8 @@
         goto exit;
     }
 
+    LOCAL_INPUT_ALLOC(hash_external, hash_length, hash);
+
     attributes = (psa_key_attributes_t) {
         .core = slot->attr
     };
@@ -3612,17 +3616,21 @@
         operation->error_occurred = 1;
     }
 
+    LOCAL_INPUT_FREE(hash_external, hash);
+
     return (status == PSA_SUCCESS) ? unlock_status : status;
 }
 
 
 psa_status_t psa_sign_hash_complete(
     psa_sign_hash_interruptible_operation_t *operation,
-    uint8_t *signature, size_t signature_size,
+    uint8_t *signature_external, size_t signature_size,
     size_t *signature_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
+    LOCAL_OUTPUT_DECLARE(signature_external, signature);
+
     *signature_length = 0;
 
     /* Check that start has been called first, and that operation has not
@@ -3639,6 +3647,8 @@
         goto exit;
     }
 
+    LOCAL_OUTPUT_ALLOC(signature_external, signature_size, signature);
+
     status = psa_driver_wrapper_sign_hash_complete(operation, signature,
                                                    signature_size,
                                                    signature_length);
@@ -3648,8 +3658,10 @@
 
 exit:
 
-    psa_wipe_tag_output_buffer(signature, status, signature_size,
-                               *signature_length);
+    if (signature != NULL) {
+        psa_wipe_tag_output_buffer(signature, status, signature_size,
+                                   *signature_length);
+    }
 
     if (status != PSA_OPERATION_INCOMPLETE) {
         if (status != PSA_SUCCESS) {
@@ -3659,6 +3671,8 @@
         psa_sign_hash_abort_internal(operation);
     }
 
+    LOCAL_OUTPUT_FREE(signature_external, signature);
+
     return status;
 }
 
@@ -3705,13 +3719,16 @@
 psa_status_t psa_verify_hash_start(
     psa_verify_hash_interruptible_operation_t *operation,
     mbedtls_svc_key_id_t key, psa_algorithm_t alg,
-    const uint8_t *hash, size_t hash_length,
-    const uint8_t *signature, size_t signature_length)
+    const uint8_t *hash_external, size_t hash_length,
+    const uint8_t *signature_external, size_t signature_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot;
 
+    LOCAL_INPUT_DECLARE(hash_external, hash);
+    LOCAL_INPUT_DECLARE(signature_external, signature);
+
     /* Check that start has not been previously called, or operation has not
      * previously errored. */
     if (operation->id != 0 || operation->error_occurred) {
@@ -3733,6 +3750,9 @@
         return status;
     }
 
+    LOCAL_INPUT_ALLOC(hash_external, hash_length, hash);
+    LOCAL_INPUT_ALLOC(signature_external, signature_length, signature);
+
     psa_key_attributes_t attributes = {
         .core = slot->attr
     };
@@ -3745,6 +3765,9 @@
                                                   slot->key.bytes,
                                                   alg, hash, hash_length,
                                                   signature, signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
 
     if (status != PSA_SUCCESS) {
         operation->error_occurred = 1;
@@ -3757,6 +3780,9 @@
         operation->error_occurred = 1;
     }
 
+    LOCAL_INPUT_FREE(hash_external, hash);
+    LOCAL_INPUT_FREE(signature_external, signature);
+
     return (status == PSA_SUCCESS) ? unlock_status : status;
 }
 
diff --git a/tests/scripts/generate_psa_wrappers.py b/tests/scripts/generate_psa_wrappers.py
index a2d8787..1a05c75 100755
--- a/tests/scripts/generate_psa_wrappers.py
+++ b/tests/scripts/generate_psa_wrappers.py
@@ -142,48 +142,14 @@
                                     _buffer_name: Optional[str]) -> bool:
         """Whether the specified buffer argument to a PSA function should be copied.
         """
-        #pylint: disable=too-many-return-statements
-        if function_name.startswith('psa_pake'):
-            return True
-        if function_name.startswith('psa_aead'):
-            return True
-        if function_name in {'psa_cipher_encrypt', 'psa_cipher_decrypt',
-                             'psa_cipher_update', 'psa_cipher_finish',
-                             'psa_cipher_generate_iv', 'psa_cipher_set_iv'}:
-            return True
-        if function_name in ('psa_key_derivation_output_bytes',
-                             'psa_key_derivation_input_bytes'):
-            return True
-        if function_name in ('psa_import_key',
-                             'psa_export_key',
-                             'psa_export_public_key'):
-            return True
-        if function_name in ('psa_sign_message',
-                             'psa_verify_message',
-                             'psa_sign_hash',
-                             'psa_verify_hash'):
-            return True
-        if function_name in ('psa_hash_update',
-                             'psa_hash_finish',
-                             'psa_hash_verify',
-                             'psa_hash_compute',
-                             'psa_hash_compare'):
-            return True
-        if function_name in ('psa_key_derivation_key_agreement',
-                             'psa_raw_key_agreement'):
-            return True
-        if function_name == 'psa_generate_random':
-            return True
-        if function_name in ('psa_mac_update',
-                             'psa_mac_sign_finish',
-                             'psa_mac_verify_finish',
-                             'psa_mac_compute',
-                             'psa_mac_verify'):
-            return True
-        if function_name in ('psa_asymmetric_encrypt',
-                             'psa_asymmetric_decrypt'):
-            return True
-        return False
+        # False-positives that do not need buffer copying
+        if function_name in ('mbedtls_psa_inject_entropy',
+                             'psa_crypto_driver_pake_get_password',
+                             'psa_crypto_driver_pake_get_user',
+                             'psa_crypto_driver_pake_get_peer'):
+            return False
+
+        return True
 
     def _write_function_call(self, out: typing_util.Writable,
                              function: c_wrapper_generator.FunctionInfo,
diff --git a/tests/src/psa_test_wrappers.c b/tests/src/psa_test_wrappers.c
index 71ea09c..f303af8 100644
--- a/tests/src/psa_test_wrappers.c
+++ b/tests/src/psa_test_wrappers.c
@@ -1162,7 +1162,13 @@
     size_t arg2_signature_size,
     size_t *arg3_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg1_signature, arg2_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_sign_hash_complete)(arg0_operation, arg1_signature, arg2_signature_size, arg3_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg1_signature, arg2_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -1174,7 +1180,13 @@
     const uint8_t *arg3_hash,
     size_t arg4_hash_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg3_hash, arg4_hash_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_sign_hash_start)(arg0_operation, arg1_key, arg2_alg, arg3_hash, arg4_hash_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg3_hash, arg4_hash_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -1247,7 +1259,15 @@
     const uint8_t *arg5_signature,
     size_t arg6_signature_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg3_hash, arg4_hash_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg5_signature, arg6_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_verify_hash_start)(arg0_operation, arg1_key, arg2_alg, arg3_hash, arg4_hash_length, arg5_signature, arg6_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg3_hash, arg4_hash_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg5_signature, arg6_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }