Refactor PSA key agreement API implementation

Signed-off-by: Waleed Elmelegy <waleed.elmelegy@arm.com>
diff --git a/tf-psa-crypto/core/psa_crypto.c b/tf-psa-crypto/core/psa_crypto.c
index 1ab2dbb..9accc7f 100644
--- a/tf-psa-crypto/core/psa_crypto.c
+++ b/tf-psa-crypto/core/psa_crypto.c
@@ -7764,12 +7764,25 @@
     return status;
 }
 
+#if defined(MBEDTLS_ECP_RESTARTABLE) && \
+    defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH)
+
+static psa_status_t psa_key_agreement_iop_abort_internal(psa_key_agreement_iop_t *operation)
+{
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    status = psa_driver_wrapper_key_agreement_abort(operation);
+
+    return status;
+}
+#endif
+
 uint32_t psa_key_agreement_iop_get_num_ops(
     psa_key_agreement_iop_t *operation)
 {
 #if defined(MBEDTLS_ECP_RESTARTABLE) && \
     defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH)
-    return psa_driver_wrapper_key_agreement_get_num_ops(operation);
+    return operation->num_ops;
 #else
     (void) operation;
     return 0;
@@ -7786,60 +7799,52 @@
 {
 #if defined(MBEDTLS_ECP_RESTARTABLE) && \
     defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH)
-    psa_status_t status = PSA_SUCCESS;
-    uint8_t *private_key_buffer = NULL;
-    size_t key_size = 0;
-    size_t key_len = 0;
-    psa_key_attributes_t private_key_attributes;
-    psa_key_type_t private_key_type;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_key_type_t key_type;
+    psa_key_slot_t *slot = NULL;
 
     if (operation->id != 0 || operation->error_occurred) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    status = psa_get_key_attributes(private_key, &private_key_attributes);
-    if (status != PSA_SUCCESS) {
-        operation->error_occurred = 1;
-        return status;
-    }
-
-    private_key_type = psa_get_key_type(&private_key_attributes);
-    if (!PSA_KEY_TYPE_IS_ECC_KEY_PAIR(private_key_type) ||
-        !PSA_ALG_IS_ECDH(alg)) {
+    if (!PSA_ALG_IS_RAW_KEY_AGREEMENT(alg)) {
         operation->error_occurred = 1;
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    key_size = PSA_EXPORT_KEY_OUTPUT_SIZE(private_key_type,
-                                          psa_get_key_bits(&private_key_attributes));
-    if (key_size == 0) {
+    key_type = psa_get_key_type(attributes);
+    if (key_type != PSA_KEY_TYPE_DERIVE &&
+        key_type != PSA_KEY_TYPE_RAW_DATA) {
         operation->error_occurred = 1;
         return PSA_ERROR_INVALID_ARGUMENT;
     }
-    private_key_buffer = mbedtls_calloc(key_size, 1);
-    if (private_key_buffer == NULL) {
-        operation->error_occurred = 1;
-        return PSA_ERROR_INSUFFICIENT_MEMORY;
-    }
 
-    status = psa_export_key(private_key, private_key_buffer, key_size, &key_len);
+    status = psa_get_and_lock_transparent_key_slot_with_policy(
+        private_key, &slot, PSA_KEY_USAGE_DERIVE, alg);
     if (status != PSA_SUCCESS) {
         goto exit;
     }
 
-    operation->ctx.mbedtls_ctx.attributes = attributes;
+    operation->attributes = *attributes;
 
-    status = psa_driver_wrapper_key_agreement_setup(operation, private_key_buffer,
-                                                    key_len, peer_key,
+    operation->num_ops = 0;
+
+    status = psa_driver_wrapper_key_agreement_setup(operation, slot->key.data,
+                                                    slot->key.bytes, peer_key,
                                                     peer_key_length,
-                                                    &private_key_attributes);
+                                                    &slot->attr);
+
+    operation->num_ops = psa_driver_wrapper_key_agreement_get_num_ops(operation);
 
 exit:
-    mbedtls_free(private_key_buffer);
+    unlock_status = psa_unregister_read_under_mutex(slot);
     if (status != PSA_SUCCESS) {
         operation->error_occurred = 1;
+        psa_key_agreement_iop_abort_internal(operation);
+        return status;
     }
-    return status;
+    return unlock_status;
 #else
     (void) operation;
     (void) private_key;
@@ -7873,19 +7878,20 @@
     operation->num_ops = psa_driver_wrapper_key_agreement_get_num_ops(operation);
 
     if (status == PSA_SUCCESS) {
-        status = psa_import_key(operation->ctx.mbedtls_ctx.attributes, intermediate_key,
+        status = psa_import_key(&operation->attributes, intermediate_key,
                                 key_len, key);
     }
 
     if (status != PSA_SUCCESS && status != PSA_OPERATION_INCOMPLETE) {
         operation->error_occurred = 1;
+        psa_key_agreement_iop_abort_internal(operation);
     }
     mbedtls_platform_zeroize(intermediate_key, PSA_RAW_KEY_AGREEMENT_OUTPUT_MAX_SIZE);
     return status;
 #else
     (void) operation;
     (void) key;
-    return PSA_ERROR_NOT_SUPPORTED;
+    return PSA_ERROR_BAD_STATE;
 #endif
 }
 
@@ -7896,12 +7902,16 @@
     defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH)
     psa_status_t status;
 
-    status = psa_driver_wrapper_key_agreement_abort(operation);
+    status = psa_key_agreement_iop_abort_internal(operation);
+
+    operation->num_ops = 0;
     operation->error_occurred = 0;
+    operation->id = 0;
+
     return status;
 #else
     (void) operation;
-    return PSA_ERROR_NOT_SUPPORTED;
+    return PSA_SUCCESS;
 #endif
 }
 
diff --git a/tf-psa-crypto/core/psa_crypto_core.h b/tf-psa-crypto/core/psa_crypto_core.h
index 33950e0..c23b0b2 100644
--- a/tf-psa-crypto/core/psa_crypto_core.h
+++ b/tf-psa-crypto/core/psa_crypto_core.h
@@ -702,7 +702,7 @@
 
 /**
  * \brief Get the total number of ops that a key agreement operation has taken
- *        Since it's start.
+ *        Since its start.
  *
  * \note The signature of this function is that of a PSA driver
  *       key_agreement_get_num_ops entry point. This function behaves as an
@@ -718,7 +718,7 @@
     mbedtls_psa_key_agreement_interruptible_operation_t *operation);
 
 /**
- * \brief  Setup a new interruptible key agreement operation.
+ * \brief  Set up a new interruptible key agreement operation.
  *
  * \note The signature of this function is that of a PSA driver
  *       key_agreement_setup entry point. This function behaves as a
@@ -775,8 +775,8 @@
  * \retval #PSA_SUCCESS
  *         The shared secret was calculated successfully.
  * \retval #PSA_ERROR_INVALID_ARGUMENT \emptydescription
- * \retval #PSA_ERROR_NOT_SUPPORTED
- *         Internal interruptible operations are currently supported.
+ * \retval #PSA_ERROR_CORRUPTION_DETECTED
+ *         Internal interruptible operations are currently not supported.
  * \retval #PSA_ERROR_BUFFER_TOO_SMALL
  *         \p shared_secret_size is too small
  */
diff --git a/tf-psa-crypto/drivers/builtin/src/psa_crypto_ecp.c b/tf-psa-crypto/drivers/builtin/src/psa_crypto_ecp.c
index 6556c6e..949bf3b 100644
--- a/tf-psa-crypto/drivers/builtin/src/psa_crypto_ecp.c
+++ b/tf-psa-crypto/drivers/builtin/src/psa_crypto_ecp.c
@@ -630,17 +630,12 @@
 /* Interruptible ECC Key Agreement */
 /****************************************************************/
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH)
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH) && defined(MBEDTLS_ECP_RESTARTABLE)
 
 uint32_t mbedtls_psa_key_agreement_get_num_ops(
     mbedtls_psa_key_agreement_interruptible_operation_t *operation)
 {
-#if defined(MBEDTLS_ECP_RESTARTABLE)
     return operation->num_ops;
-#else
-    (void) operation;
-    return 0;
-#endif
 }
 
 psa_status_t mbedtls_psa_key_agreement_setup(
@@ -651,29 +646,40 @@
     size_t peer_key_length,
     const psa_key_attributes_t *attributes)
 {
-#if defined(MBEDTLS_ECP_RESTARTABLE)
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    mbedtls_ecp_keypair *ecp = NULL;
+    mbedtls_ecp_keypair *our_key = NULL;
     mbedtls_ecp_keypair *their_key = NULL;
-    size_t bits = 0;
+
+    mbedtls_ecdh_init(&operation->ctx);
+    mbedtls_ecdh_enable_restart(&operation->ctx);
+
+    /* We need to clear number of ops here in case there was a previous
+       complete operation which doesn't reset it after finsishing. */
+    operation->num_ops = 0;
 
     status = mbedtls_psa_ecp_load_representation(
         psa_get_key_type(attributes),
         psa_get_key_bits(attributes),
         private_key_buffer,
         private_key_buffer_len,
-        &ecp);
+        &our_key);
     if (status != PSA_SUCCESS) {
         goto exit;
     }
 
-    psa_ecc_family_t curve = mbedtls_ecc_group_to_psa(ecp->grp.id, &bits);
+    status = mbedtls_to_psa_error(
+        mbedtls_ecdh_get_params(&operation->ctx, our_key, MBEDTLS_ECDH_OURS));
+    if (status != PSA_SUCCESS) {
+        goto exit;
+    }
 
-    mbedtls_ecdh_init(&operation->ctx);
+    mbedtls_ecp_keypair_free(our_key);
+    mbedtls_free(our_key);
+    our_key = NULL;
 
     status = mbedtls_psa_ecp_load_representation(
-        PSA_KEY_TYPE_ECC_PUBLIC_KEY(curve),
-        bits,
+        PSA_KEY_TYPE_PUBLIC_KEY_OF_KEY_PAIR(psa_get_key_type(attributes)),
+        psa_get_key_bits(attributes),
         peer_key,
         peer_key_length,
         &their_key);
@@ -681,37 +687,22 @@
         goto exit;
     }
 
+    /* mbedtls_psa_ecp_load_representation() calls mbedtls_ecp_check_pubkey() which
+       takes MBEDTLS_ECP_OPS_CHK amount of ops. */
+    operation->num_ops += MBEDTLS_ECP_OPS_CHK;
+
     status = mbedtls_to_psa_error(
         mbedtls_ecdh_get_params(&operation->ctx, their_key, MBEDTLS_ECDH_THEIRS));
     if (status != PSA_SUCCESS) {
         goto exit;
     }
 
-    status = mbedtls_to_psa_error(
-        mbedtls_ecdh_get_params(&operation->ctx, ecp, MBEDTLS_ECDH_OURS));
-    if (status != PSA_SUCCESS) {
-        goto exit;
-    }
-
-    mbedtls_ecdh_enable_restart(&operation->ctx);
-    operation->num_ops = 0;
-
 exit:
+    mbedtls_ecp_keypair_free(our_key);
+    mbedtls_free(our_key);
     mbedtls_ecp_keypair_free(their_key);
-    mbedtls_ecp_keypair_free(ecp);
-    mbedtls_free(ecp);
     mbedtls_free(their_key);
     return status;
-#else
-    (void) operation;
-    (void) private_key_buffer;
-    (void) private_key_buffer;
-    (void) private_key_buffer_len;
-    (void) peer_key;
-    (void) peer_key_length;
-    (void) attributes;
-    return PSA_ERROR_NOT_SUPPORTED;
-#endif
 }
 
 psa_status_t mbedtls_psa_key_agreement_complete(
@@ -720,7 +711,6 @@
     size_t shared_secret_size,
     size_t *shared_secret_length)
 {
-#if defined(MBEDTLS_ECP_RESTARTABLE)
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
     mbedtls_psa_interruptible_set_max_ops(psa_interruptible_get_max_ops());
@@ -734,27 +724,14 @@
     operation->num_ops += operation->ctx.rs.ops_done;
 
     return status;
-#else
-    (void) operation;
-    (void) shared_secret;
-    (void) shared_secret_size;
-    (void) shared_secret_length;
-    return PSA_ERROR_NOT_SUPPORTED;
-#endif
 }
 
 psa_status_t mbedtls_psa_key_agreement_abort(
     mbedtls_psa_key_agreement_interruptible_operation_t *operation)
 {
-#if defined(MBEDTLS_ECP_RESTARTABLE)
     operation->num_ops = 0;
-    operation->attributes = NULL;
     mbedtls_ecdh_free(&operation->ctx);
     return PSA_SUCCESS;
-#else
-    (void) operation;
-    return PSA_ERROR_NOT_SUPPORTED;
-#endif
 }
 
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_ECDH */
diff --git a/tf-psa-crypto/include/psa/crypto_builtin_composites.h b/tf-psa-crypto/include/psa/crypto_builtin_composites.h
index c5dde0d..47493b2 100644
--- a/tf-psa-crypto/include/psa/crypto_builtin_composites.h
+++ b/tf-psa-crypto/include/psa/crypto_builtin_composites.h
@@ -235,7 +235,6 @@
 typedef struct {
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH) && defined(MBEDTLS_ECP_RESTARTABLE)
     mbedtls_ecdh_context MBEDTLS_PRIVATE(ctx);
-    const psa_key_attributes_t *MBEDTLS_PRIVATE(attributes);
     uint32_t MBEDTLS_PRIVATE(num_ops);
 #else
     /* Make the struct non-empty if algs not supported. */
@@ -244,7 +243,7 @@
 } mbedtls_psa_key_agreement_interruptible_operation_t;
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_ECDH) && defined(MBEDTLS_ECP_RESTARTABLE)
-#define MBEDTLS_PSA_KEY_AGREEMENT_INTERRUPTIBLE_OPERATION_INIT { { 0 }, { 0 }, 0 }
+#define MBEDTLS_PSA_KEY_AGREEMENT_INTERRUPTIBLE_OPERATION_INIT { { 0 }, 0 }
 #else
 #define MBEDTLS_PSA_KEY_AGREEMENT_INTERRUPTIBLE_OPERATION_INIT { 0 }
 #endif
diff --git a/tf-psa-crypto/include/psa/crypto_struct.h b/tf-psa-crypto/include/psa/crypto_struct.h
index b9e0f4e..8fc542b 100644
--- a/tf-psa-crypto/include/psa/crypto_struct.h
+++ b/tf-psa-crypto/include/psa/crypto_struct.h
@@ -511,13 +511,14 @@
     psa_driver_key_agreement_interruptible_context_t MBEDTLS_PRIVATE(ctx);
     uint32_t MBEDTLS_PRIVATE(num_ops);
     unsigned int MBEDTLS_PRIVATE(error_occurred) : 1;
+    psa_key_attributes_t MBEDTLS_PRIVATE(attributes);
 #endif
 };
 
 #if defined(MBEDTLS_PSA_CRYPTO_CLIENT) && !defined(MBEDTLS_PSA_CRYPTO_C)
 #define PSA_KEY_AGREEMENT_IOP_INIT { 0 }
 #else
-#define PSA_KEY_AGREEMENT_IOP_INIT { 0, { 0 }, 0, 0 }
+#define PSA_KEY_AGREEMENT_IOP_INIT { 0, { 0 }, 0, 0, { 0 } }
 #endif
 
 static inline struct psa_key_agreement_iop_s