Merge branch 'development-restricted' into key_agreement_buffer_protection

Signed-off-by: tom-daubney-arm <74920390+tom-daubney-arm@users.noreply.github.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 77c6955..bff510d 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7581,12 +7581,13 @@
 psa_status_t psa_key_derivation_key_agreement(psa_key_derivation_operation_t *operation,
                                               psa_key_derivation_step_t step,
                                               mbedtls_svc_key_id_t private_key,
-                                              const uint8_t *peer_key,
+                                              const uint8_t *peer_key_external,
                                               size_t peer_key_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot;
+    LOCAL_INPUT_DECLARE(peer_key_external, peer_key);
 
     if (!PSA_ALG_IS_KEY_AGREEMENT(operation->alg)) {
         return PSA_ERROR_INVALID_ARGUMENT;
@@ -7596,9 +7597,15 @@
     if (status != PSA_SUCCESS) {
         return status;
     }
+
+    LOCAL_INPUT_ALLOC(peer_key_external, peer_key_length, peer_key);
     status = psa_key_agreement_internal(operation, step,
                                         slot,
                                         peer_key, peer_key_length);
+
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+exit:
+#endif
     if (status != PSA_SUCCESS) {
         psa_key_derivation_abort(operation);
     } else {
@@ -7610,15 +7617,15 @@
     }
 
     unlock_status = psa_unregister_read(slot);
-
+    LOCAL_INPUT_FREE(peer_key_external, peer_key);
     return (status == PSA_SUCCESS) ? unlock_status : status;
 }
 
 psa_status_t psa_raw_key_agreement(psa_algorithm_t alg,
                                    mbedtls_svc_key_id_t private_key,
-                                   const uint8_t *peer_key,
+                                   const uint8_t *peer_key_external,
                                    size_t peer_key_length,
-                                   uint8_t *output,
+                                   uint8_t *output_external,
                                    size_t output_size,
                                    size_t *output_length)
 {
@@ -7626,6 +7633,9 @@
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot = NULL;
     size_t expected_length;
+    LOCAL_INPUT_DECLARE(peer_key_external, peer_key);
+    LOCAL_OUTPUT_DECLARE(output_external, output);
+    LOCAL_OUTPUT_ALLOC(output_external, output_size, output);
 
     if (!PSA_ALG_IS_KEY_AGREEMENT(alg)) {
         status = PSA_ERROR_INVALID_ARGUMENT;
@@ -7652,13 +7662,16 @@
         goto exit;
     }
 
+    LOCAL_INPUT_ALLOC(peer_key_external, peer_key_length, peer_key);
     status = psa_key_agreement_raw_internal(alg, slot,
                                             peer_key, peer_key_length,
                                             output, output_size,
                                             output_length);
 
 exit:
-    if (status != PSA_SUCCESS) {
+    /* On failure, scramble the output, but only if an output buffer
+     * exists (output != NULL); the NULL case is handled below. */
+    if (output != NULL && status != PSA_SUCCESS) {
         /* If an error happens and is not handled properly, the output
          * may be used as a key to protect sensitive data. Arrange for such
          * a key to be random, which is likely to result in decryption or
@@ -7670,8 +7683,15 @@
         *output_length = output_size;
     }
 
+    if (output == NULL) {
+        /* No output buffer (e.g. its allocation failed): report no output. */
+        *output_length = 0;
+    }
+
     unlock_status = psa_unregister_read(slot);
 
+    LOCAL_INPUT_FREE(peer_key_external, peer_key);
+    LOCAL_OUTPUT_FREE(output_external, output);
     return (status == PSA_SUCCESS) ? unlock_status : status;
 }
 
diff --git a/tests/scripts/generate_psa_wrappers.py b/tests/scripts/generate_psa_wrappers.py
index fbe7cf1..eb80900 100755
--- a/tests/scripts/generate_psa_wrappers.py
+++ b/tests/scripts/generate_psa_wrappers.py
@@ -167,6 +167,9 @@
                              'psa_hash_compute',
                              'psa_hash_compare'):
             return True
+        if function_name in ('psa_key_derivation_key_agreement',
+                             'psa_raw_key_agreement'):
+            return True
         if function_name == 'psa_generate_random':
             return True
         if function_name in ('psa_mac_update',
diff --git a/tests/src/psa_test_wrappers.c b/tests/src/psa_test_wrappers.c
index 86ab1c3..b37378b 100644
--- a/tests/src/psa_test_wrappers.c
+++ b/tests/src/psa_test_wrappers.c
@@ -810,7 +810,13 @@
     const uint8_t *arg3_peer_key,
     size_t arg4_peer_key_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg3_peer_key, arg4_peer_key_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_key_derivation_key_agreement)(arg0_operation, arg1_step, arg2_private_key, arg3_peer_key, arg4_peer_key_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg3_peer_key, arg4_peer_key_length);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
@@ -1083,7 +1089,15 @@
     size_t arg5_output_size,
     size_t *arg6_output_length)
 {
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_POISON(arg2_peer_key, arg3_peer_key_length);
+    MBEDTLS_TEST_MEMORY_POISON(arg4_output, arg5_output_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     psa_status_t status = (psa_raw_key_agreement)(arg0_alg, arg1_private_key, arg2_peer_key, arg3_peer_key_length, arg4_output, arg5_output_size, arg6_output_length);
+#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg2_peer_key, arg3_peer_key_length);
+    MBEDTLS_TEST_MEMORY_UNPOISON(arg4_output, arg5_output_size);
+#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */
     return status;
 }
 
diff --git a/tests/suites/test_suite_psa_crypto_op_fail.function b/tests/suites/test_suite_psa_crypto_op_fail.function
index 20942bf..9878237 100644
--- a/tests/suites/test_suite_psa_crypto_op_fail.function
+++ b/tests/suites/test_suite_psa_crypto_op_fail.function
@@ -359,9 +359,9 @@
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
     uint8_t public_key[PSA_EXPORT_PUBLIC_KEY_MAX_SIZE] = { 0 };
-    size_t public_key_length = SIZE_MAX;
+    size_t public_key_length = 0;
     uint8_t output[PSA_RAW_KEY_AGREEMENT_OUTPUT_MAX_SIZE] = { 0 };
-    size_t length = SIZE_MAX;
+    size_t length = 0;
     psa_key_derivation_operation_t operation = PSA_KEY_DERIVATION_OPERATION_INIT;
 
     PSA_INIT();