pk: add an alternative function for checking private/public key pairs

Instead of using the legacy mbedtls_ecp_check_pub_priv() function which
was based on ECP math, we add a new option named eckey_check_pair_psa()
which takes advantage of PSA.
Of course, this is available when MBEDTLS_USE_PSA_CRYPTO in enabled.

Tests were also fixed accordingly.

Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/include/mbedtls/ecp.h b/include/mbedtls/ecp.h
index b6144d9..0b72a83 100644
--- a/include/mbedtls/ecp.h
+++ b/include/mbedtls/ecp.h
@@ -1296,9 +1296,12 @@
  * \return          An \c MBEDTLS_ERR_ECP_XXX or an \c MBEDTLS_ERR_MPI_XXX
  *                  error code on calculation failure.
  */
+
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
 int mbedtls_ecp_check_pub_priv(
     const mbedtls_ecp_keypair *pub, const mbedtls_ecp_keypair *prv,
     int (*f_rng)(void *, unsigned char *, size_t), void *p_rng);
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 /**
  * \brief           This function exports generic key-pair parameters.
diff --git a/library/ecp.c b/library/ecp.c
index 08fbe86..a794b3b 100644
--- a/library/ecp.c
+++ b/library/ecp.c
@@ -3316,7 +3316,7 @@
     return ret;
 }
 
-
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
 /*
  * Check a public-private key pair
  */
@@ -3357,6 +3357,7 @@
 
     return ret;
 }
+#endif /* !MBEDTLS_USE_PSA_CRYPTO */
 
 /*
  * Export generic key-pair parameters.
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 4d91f22..2d5a0b7 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -1095,13 +1095,92 @@
 }
 #endif /* MBEDTLS_ECDSA_C && MBEDTLS_ECP_RESTARTABLE */
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+/*
+ * Alternative function used to verify that the EC private/public key pair
+ * is valid using PSA functions instead of ECP ones.
+ * The flow is:
+ * - import the private key "prv" to PSA and export its public part
+ * - write the raw content of public key "pub" to a local buffer
+ * - compare the two buffers
+ */
+static int eckey_check_pair_psa(const void *pub, const void *prv)
+{
+    psa_status_t status;
+    psa_key_attributes_t key_attr = PSA_KEY_ATTRIBUTES_INIT;
+    mbedtls_ecp_keypair *prv_ctx = (mbedtls_ecp_keypair *) prv;
+    mbedtls_ecp_keypair *pub_ctx = (mbedtls_ecp_keypair *) pub;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    uint8_t prv_key_buf[MBEDTLS_PSA_MAX_EC_KEY_PAIR_LENGTH];
+    size_t prv_key_len;
+    uint8_t pub_key_buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
+    size_t pub_key_len;
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
+    size_t curve_bits;
+    psa_ecc_family_t curve =
+        mbedtls_ecc_group_to_psa(prv_ctx->grp.id, &curve_bits);
+    size_t curve_bytes = PSA_BITS_TO_BYTES(curve_bits);
+
+    psa_set_key_type(&key_attr, PSA_KEY_TYPE_ECC_KEY_PAIR(curve));
+    psa_set_key_usage_flags(&key_attr, PSA_KEY_USAGE_EXPORT);
+
+    ret = mbedtls_mpi_write_binary(&prv_ctx->d, prv_key_buf, curve_bytes);
+    if (ret != 0) {
+        return ret;
+    }
+
+    status = psa_import_key(&key_attr, prv_key_buf, curve_bytes, &key_id);
+    if (status != PSA_SUCCESS) {
+        ret = PSA_PK_TO_MBEDTLS_ERR(status);
+        return ret;
+    }
+
+    mbedtls_platform_zeroize(prv_key_buf, sizeof(prv_key_buf));
+
+    status = psa_export_public_key(key_id, prv_key_buf, sizeof(prv_key_buf),
+                                   &prv_key_len);
+    if (status != PSA_SUCCESS) {
+        ret = PSA_PK_TO_MBEDTLS_ERR(status);
+        status = psa_destroy_key(key_id);
+        return (status != PSA_SUCCESS) ? PSA_PK_TO_MBEDTLS_ERR(status) : ret;
+    }
+
+    status = psa_destroy_key(key_id);
+    if (status != PSA_SUCCESS) {
+        return PSA_PK_TO_MBEDTLS_ERR(status);
+    }
+
+    ret = mbedtls_ecp_point_write_binary(&pub_ctx->grp, &pub_ctx->Q,
+                                         MBEDTLS_ECP_PF_UNCOMPRESSED,
+                                         &pub_key_len, pub_key_buf,
+                                         sizeof(pub_key_buf));
+    if (ret != 0) {
+        return ret;
+    }
+
+    if (memcmp(prv_key_buf, pub_key_buf, curve_bytes) != 0) {
+        return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+    }
+
+    return 0;
+}
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
 static int eckey_check_pair(const void *pub, const void *prv,
                             int (*f_rng)(void *, unsigned char *, size_t),
                             void *p_rng)
 {
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    (void) f_rng;
+    (void) p_rng;
+    return eckey_check_pair_psa((const mbedtls_ecp_keypair *) pub,
+                                (const mbedtls_ecp_keypair *) prv);
+#else /* MBEDTLS_USE_PSA_CRYPTO */
     return mbedtls_ecp_check_pub_priv((const mbedtls_ecp_keypair *) pub,
                                       (const mbedtls_ecp_keypair *) prv,
                                       f_rng, p_rng);
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
 }
 
 static void *eckey_alloc_wrap(void)
diff --git a/tests/suites/test_suite_ecp.function b/tests/suites/test_suite_ecp.function
index 408fe5d..a772608 100644
--- a/tests/suites/test_suite_ecp.function
+++ b/tests/suites/test_suite_ecp.function
@@ -955,7 +955,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE */
+/* BEGIN_CASE depends_on:!MBEDTLS_USE_PSA_CRYPTO */
 void mbedtls_ecp_check_pub_priv(int id_pub, char *Qx_pub, char *Qy_pub,
                                 int id, char *d, char *Qx, char *Qy,
                                 int ret)
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index 20f61fc..de531d3 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -489,6 +489,15 @@
     mbedtls_pk_init(&prv);
     mbedtls_pk_init(&alt);
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    /* mbedtls_pk_check_pair() returns either PK or ECP error codes depending
+       on MBEDTLS_USE_PSA_CRYPTO so here we dynamically translate between the
+       two */
+    if (ret == MBEDTLS_ERR_ECP_BAD_INPUT_DATA) {
+        ret = MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+    }
+#endif
+
     TEST_ASSERT(mbedtls_pk_parse_public_keyfile(&pub, pub_file) == 0);
     TEST_ASSERT(mbedtls_pk_parse_keyfile(&prv, prv_file, NULL,
                                          mbedtls_test_rnd_std_rand, NULL)