Merge pull request #1151 from tom-daubney-arm/asymmetric_sign_buffer_protection
Implement safe buffer copying in asymmetric signature API
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 3b3c116..d21c13e 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -3165,15 +3165,27 @@
psa_status_t psa_sign_message(mbedtls_svc_key_id_t key,
psa_algorithm_t alg,
- const uint8_t *input,
+ const uint8_t *input_external,
size_t input_length,
- uint8_t *signature,
+ uint8_t *signature_external,
size_t signature_size,
size_t *signature_length)
{
- return psa_sign_internal(
- key, 1, alg, input, input_length,
- signature, signature_size, signature_length);
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+ LOCAL_INPUT_DECLARE(input_external, input);
+ LOCAL_OUTPUT_DECLARE(signature_external, signature);
+
+ LOCAL_INPUT_ALLOC(input_external, input_length, input);
+ LOCAL_OUTPUT_ALLOC(signature_external, signature_size, signature);
+ status = psa_sign_internal(key, 1, alg, input, input_length, signature,
+ signature_size, signature_length);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+ LOCAL_INPUT_FREE(input_external, input);
+ LOCAL_OUTPUT_FREE(signature_external, signature);
+ return status;
}
psa_status_t psa_verify_message_builtin(
@@ -3212,14 +3224,27 @@
psa_status_t psa_verify_message(mbedtls_svc_key_id_t key,
psa_algorithm_t alg,
- const uint8_t *input,
+ const uint8_t *input_external,
size_t input_length,
- const uint8_t *signature,
+ const uint8_t *signature_external,
size_t signature_length)
{
- return psa_verify_internal(
- key, 1, alg, input, input_length,
- signature, signature_length);
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+ LOCAL_INPUT_DECLARE(input_external, input);
+ LOCAL_INPUT_DECLARE(signature_external, signature);
+
+ LOCAL_INPUT_ALLOC(input_external, input_length, input);
+ LOCAL_INPUT_ALLOC(signature_external, signature_length, signature);
+ status = psa_verify_internal(key, 1, alg, input, input_length, signature,
+ signature_length);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+ LOCAL_INPUT_FREE(input_external, input);
+ LOCAL_INPUT_FREE(signature_external, signature);
+
+ return status;
}
psa_status_t psa_sign_hash_builtin(
@@ -3272,15 +3297,28 @@
psa_status_t psa_sign_hash(mbedtls_svc_key_id_t key,
psa_algorithm_t alg,
- const uint8_t *hash,
+ const uint8_t *hash_external,
size_t hash_length,
- uint8_t *signature,
+ uint8_t *signature_external,
size_t signature_size,
size_t *signature_length)
{
- return psa_sign_internal(
- key, 0, alg, hash, hash_length,
- signature, signature_size, signature_length);
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+ LOCAL_INPUT_DECLARE(hash_external, hash);
+ LOCAL_OUTPUT_DECLARE(signature_external, signature);
+
+ LOCAL_INPUT_ALLOC(hash_external, hash_length, hash);
+ LOCAL_OUTPUT_ALLOC(signature_external, signature_size, signature);
+ status = psa_sign_internal(key, 0, alg, hash, hash_length, signature,
+ signature_size, signature_length);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+ LOCAL_INPUT_FREE(hash_external, hash);
+ LOCAL_OUTPUT_FREE(signature_external, signature);
+
+ return status;
}
psa_status_t psa_verify_hash_builtin(
@@ -3332,14 +3370,27 @@
psa_status_t psa_verify_hash(mbedtls_svc_key_id_t key,
psa_algorithm_t alg,
- const uint8_t *hash,
+ const uint8_t *hash_external,
size_t hash_length,
- const uint8_t *signature,
+ const uint8_t *signature_external,
size_t signature_length)
{
- return psa_verify_internal(
- key, 0, alg, hash, hash_length,
- signature, signature_length);
+ psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+ LOCAL_INPUT_DECLARE(hash_external, hash);
+ LOCAL_INPUT_DECLARE(signature_external, signature);
+
+ LOCAL_INPUT_ALLOC(hash_external, hash_length, hash);
+ LOCAL_INPUT_ALLOC(signature_external, signature_length, signature);
+ status = psa_verify_internal(key, 0, alg, hash, hash_length, signature,
+ signature_length);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
+ LOCAL_INPUT_FREE(hash_external, hash);
+ LOCAL_INPUT_FREE(signature_external, signature);
+
+ return status;
}
psa_status_t psa_asymmetric_encrypt(mbedtls_svc_key_id_t key,
diff --git a/tests/scripts/generate_psa_wrappers.py b/tests/scripts/generate_psa_wrappers.py
index e5b4256..005a324 100755
--- a/tests/scripts/generate_psa_wrappers.py
+++ b/tests/scripts/generate_psa_wrappers.py
@@ -145,6 +145,11 @@
# Proof-of-concept: just instrument one function for now
if function_name == 'psa_cipher_encrypt':
return True
+ if function_name in ('psa_sign_message',
+ 'psa_verify_message',
+ 'psa_sign_hash',
+ 'psa_verify_hash'):
+ return True
return False
def _write_function_call(self, out: typing_util.Writable,
diff --git a/tests/src/psa_test_wrappers.c b/tests/src/psa_test_wrappers.c
index 3a3aaad..460d453 100644
--- a/tests/src/psa_test_wrappers.c
+++ b/tests/src/psa_test_wrappers.c
@@ -873,7 +873,15 @@
size_t arg5_signature_size,
size_t *arg6_signature_length)
{
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+ MBEDTLS_TEST_MEMORY_POISON(arg2_hash, arg3_hash_length);
+ MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
psa_status_t status = (psa_sign_hash)(arg0_key, arg1_alg, arg2_hash, arg3_hash_length, arg4_signature, arg5_signature_size, arg6_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+ MBEDTLS_TEST_MEMORY_UNPOISON(arg2_hash, arg3_hash_length);
+ MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
return status;
}
@@ -918,7 +926,15 @@
size_t arg5_signature_size,
size_t *arg6_signature_length)
{
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+ MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length);
+ MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
psa_status_t status = (psa_sign_message)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_signature, arg5_signature_size, arg6_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+ MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length);
+ MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
return status;
}
@@ -931,7 +947,15 @@
const uint8_t *arg4_signature,
size_t arg5_signature_length)
{
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+ MBEDTLS_TEST_MEMORY_POISON(arg2_hash, arg3_hash_length);
+ MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
psa_status_t status = (psa_verify_hash)(arg0_key, arg1_alg, arg2_hash, arg3_hash_length, arg4_signature, arg5_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+ MBEDTLS_TEST_MEMORY_UNPOISON(arg2_hash, arg3_hash_length);
+ MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
return status;
}
@@ -974,7 +998,15 @@
const uint8_t *arg4_signature,
size_t arg5_signature_length)
{
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+ MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length);
+ MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
psa_status_t status = (psa_verify_message)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_signature, arg5_signature_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+ MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length);
+ MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
return status;
}