Merge pull request #7392 from valeriosetti/issue7388

PK: use PSA to complete public key when USE_PSA is enabled
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 45d743f..4e5293d 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -1104,9 +1104,10 @@
  * - write the raw content of public key "pub" to a local buffer
  * - compare the two buffers
  */
-static int eckey_check_pair_psa(const void *pub, const void *prv)
+static int eckey_check_pair_psa(const mbedtls_ecp_keypair *pub,
+                                const mbedtls_ecp_keypair *prv)
 {
-    psa_status_t status;
+    psa_status_t status, destruction_status;
     psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_ecp_keypair *prv_ctx = (mbedtls_ecp_keypair *) prv;
     mbedtls_ecp_keypair *pub_ctx = (mbedtls_ecp_keypair *) pub;
@@ -1133,20 +1134,21 @@
     }
 
     status = psa_import_key(&key_attr, prv_key_buf, curve_bytes, &key_id);
-    if (status != PSA_SUCCESS) {
-        ret = PSA_PK_TO_MBEDTLS_ERR(status);
+    ret = PSA_PK_TO_MBEDTLS_ERR(status);
+    if (ret != 0) {
         return ret;
     }
 
     mbedtls_platform_zeroize(prv_key_buf, sizeof(prv_key_buf));
 
-    ret = PSA_PK_TO_MBEDTLS_ERR(psa_export_public_key(key_id,
-                                                      prv_key_buf,
-                                                      sizeof(prv_key_buf),
-                                                      &prv_key_len));
-    status = psa_destroy_key(key_id);
-    if (ret != 0 || status != PSA_SUCCESS) {
-        return (ret != 0) ? ret : PSA_PK_TO_MBEDTLS_ERR(status);
+    status = psa_export_public_key(key_id, prv_key_buf, sizeof(prv_key_buf),
+                                   &prv_key_len);
+    ret = PSA_PK_TO_MBEDTLS_ERR(status);
+    destruction_status = psa_destroy_key(key_id);
+    if (ret != 0) {
+        return ret;
+    } else if (destruction_status != PSA_SUCCESS) {
+        return PSA_PK_TO_MBEDTLS_ERR(destruction_status);
     }
 
     ret = mbedtls_ecp_point_write_binary(&pub_ctx->grp, &pub_ctx->Q,
@@ -1172,8 +1174,7 @@
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     (void) f_rng;
     (void) p_rng;
-    return eckey_check_pair_psa((const mbedtls_ecp_keypair *) pub,
-                                (const mbedtls_ecp_keypair *) prv);
+    return eckey_check_pair_psa(pub, prv);
 #else /* MBEDTLS_USE_PSA_CRYPTO */
     return mbedtls_ecp_check_pub_priv((const mbedtls_ecp_keypair *) pub,
                                       (const mbedtls_ecp_keypair *) prv,
diff --git a/library/pkparse.c b/library/pkparse.c
index ccca692..fa61a06 100644
--- a/library/pkparse.c
+++ b/library/pkparse.c
@@ -48,6 +48,14 @@
 #include "mbedtls/pkcs12.h"
 #endif
 
+#if defined(MBEDTLS_PSA_CRYPTO_C)
+#include "mbedtls/psa_util.h"
+#endif
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#include "psa/crypto.h"
+#endif
+
 #include "mbedtls/platform.h"
 
 #if defined(MBEDTLS_FS_IO)
@@ -869,6 +877,57 @@
 #endif /* MBEDTLS_RSA_C */
 
 #if defined(MBEDTLS_ECP_C)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+/*
+ * Derive the public point Q from the private value d using PSA: d is
+ * imported as a temporary key pair whose public part is exported into Q.
+ * Returns 0 on success, or the error code of the first step that failed.
+ */
+static int pk_derive_public_key(mbedtls_ecp_group *grp, mbedtls_ecp_point *Q,
+                                const mbedtls_mpi *d)
+{
+    psa_status_t status, destruction_status;
+    psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
+    size_t curve_bits;
+    psa_ecc_family_t curve = mbedtls_ecc_group_to_psa(grp->id, &curve_bits);
+    /* This buffer is used to store the private key at first and then the
+     * public one (but not at the same time). Therefore we size it for the
+     * latter since it's bigger. */
+    unsigned char key_buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
+    size_t key_len = PSA_BITS_TO_BYTES(curve_bits);
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    int ret;
+
+    psa_set_key_type(&key_attr, PSA_KEY_TYPE_ECC_KEY_PAIR(curve));
+    psa_set_key_usage_flags(&key_attr, PSA_KEY_USAGE_EXPORT);
+
+    ret = mbedtls_mpi_write_binary(d, key_buf, key_len);
+    if (ret != 0) {
+        return ret;
+    }
+
+    status = psa_import_key(&key_attr, key_buf, key_len, &key_id);
+    /* Wipe the private key copy even if the import failed. */
+    mbedtls_platform_zeroize(key_buf, sizeof(key_buf));
+    ret = psa_pk_status_to_mbedtls(status);
+    if (ret != 0) {
+        return ret;
+    }
+
+    status = psa_export_public_key(key_id, key_buf, sizeof(key_buf), &key_len);
+    ret = psa_pk_status_to_mbedtls(status);
+    /* Always destroy the temporary key; an export error takes precedence. */
+    destruction_status = psa_destroy_key(key_id);
+    if (ret != 0) {
+        return ret;
+    } else if (destruction_status != PSA_SUCCESS) {
+        return psa_pk_status_to_mbedtls(destruction_status);
+    }
+
+    return mbedtls_ecp_point_read_binary(grp, Q, key_buf, key_len);
+}
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
 /*
  * Parse a SEC1 encoded private EC key
  */
@@ -975,11 +1034,21 @@
         }
     }
 
-    if (!pubkey_done &&
-        (ret = mbedtls_ecp_mul(&eck->grp, &eck->Q, &eck->d, &eck->grp.G,
-                               f_rng, p_rng)) != 0) {
-        mbedtls_ecp_keypair_free(eck);
-        return MBEDTLS_ERROR_ADD(MBEDTLS_ERR_PK_KEY_INVALID_FORMAT, ret);
+    if (!pubkey_done) {
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        (void) f_rng;
+        (void) p_rng;
+        if ((ret = pk_derive_public_key(&eck->grp, &eck->Q, &eck->d)) != 0) {
+            mbedtls_ecp_keypair_free(eck);
+            return ret;
+        }
+#else /* MBEDTLS_USE_PSA_CRYPTO */
+        if ((ret = mbedtls_ecp_mul(&eck->grp, &eck->Q, &eck->d, &eck->grp.G,
+                                   f_rng, p_rng)) != 0) {
+            mbedtls_ecp_keypair_free(eck);
+            return MBEDTLS_ERROR_ADD(MBEDTLS_ERR_PK_KEY_INVALID_FORMAT, ret);
+        }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     }
 
     if ((ret = mbedtls_ecp_check_privkey(&eck->grp, &eck->d)) != 0) {
diff --git a/tests/suites/test_suite_pkparse.function b/tests/suites/test_suite_pkparse.function
index 1a6858f..d7c2d0b 100644
--- a/tests/suites/test_suite_pkparse.function
+++ b/tests/suites/test_suite_pkparse.function
@@ -101,6 +101,7 @@
     mbedtls_pk_context ctx;
     int res;
 
+    USE_PSA_INIT();
     mbedtls_pk_init(&ctx);
 
     res = mbedtls_pk_parse_keyfile(&ctx, key_file, password,
@@ -117,6 +118,7 @@
 
 exit:
     mbedtls_pk_free(&ctx);
+    USE_PSA_DONE();
 }
 /* END_CASE */