pk_wrap: optimize code for ECDSA verify

Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 24d531b..75904c4 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -717,36 +717,19 @@
     return 0;
 }
 
-static int ecdsa_verify_wrap(mbedtls_pk_context *pk,
-                             mbedtls_md_type_t md_alg,
-                             const unsigned char *hash, size_t hash_len,
-                             const unsigned char *sig, size_t sig_len)
+static int ecdsa_verify_psa(unsigned char *key, size_t key_len,
+                            psa_ecc_family_t curve, size_t curve_bits,
+                            const unsigned char *hash, size_t hash_len,
+                            const unsigned char *sig, size_t sig_len)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
-    psa_status_t status;
-    unsigned char *p;
     psa_algorithm_t psa_sig_md = PSA_ALG_ECDSA_ANY;
-    size_t signature_len;
-    ((void) md_alg);
-#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
-    unsigned char buf[PSA_VENDOR_ECDSA_SIGNATURE_MAX_SIZE];
-    psa_ecc_family_t curve = pk->ec_family;
-    size_t curve_bits = pk->ec_bits;
-#else
-    mbedtls_ecp_keypair *ctx = pk->pk_ctx;
-    size_t key_len;
-    /* This buffer will initially contain the public key and then the signature
-     * but at different points in time. For all curves except secp224k1, which
-     * is not currently supported in PSA, the public key is one byte longer
-     * (header byte + 2 numbers, while the signature is only 2 numbers),
-     * so use that as the buffer size. */
-    unsigned char buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
-    size_t curve_bits;
-    psa_ecc_family_t curve =
-        mbedtls_ecc_group_to_psa(ctx->grp.id, &curve_bits);
-#endif
+    size_t signature_len = PSA_ECDSA_SIGNATURE_SIZE(curve_bits);
+    unsigned char extracted_sig[PSA_VENDOR_ECDSA_SIGNATURE_MAX_SIZE];
+    unsigned char *p;
+    psa_status_t status;
 
     if (curve == 0) {
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
@@ -756,29 +739,13 @@
     psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_VERIFY_HASH);
     psa_set_key_algorithm(&attributes, psa_sig_md);
 
-#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
-    status = psa_import_key(&attributes,
-                            pk->pub_raw, pk->pub_raw_len,
-                            &key_id);
-#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
-    ret = mbedtls_ecp_point_write_binary(&ctx->grp, &ctx->Q,
-                                         MBEDTLS_ECP_PF_UNCOMPRESSED,
-                                         &key_len, buf, sizeof(buf));
-    if (ret != 0) {
-        goto cleanup;
-    }
-
-    status = psa_import_key(&attributes,
-                            buf, key_len,
-                            &key_id);
-#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
+    status = psa_import_key(&attributes, key, key_len, &key_id);
     if (status != PSA_SUCCESS) {
         ret = PSA_PK_TO_MBEDTLS_ERR(status);
         goto cleanup;
     }
 
-    signature_len = PSA_ECDSA_SIGNATURE_SIZE(curve_bits);
-    if (signature_len > sizeof(buf)) {
+    if (signature_len > sizeof(extracted_sig)) {
         ret = MBEDTLS_ERR_PK_BAD_INPUT_DATA;
         goto cleanup;
     }
@@ -787,14 +754,13 @@
     /* extract_ecdsa_sig's last parameter is the size
      * of each integer to be parsed, so it's actually half
      * the size of the signature. */
-    if ((ret = extract_ecdsa_sig(&p, sig + sig_len, buf,
+    if ((ret = extract_ecdsa_sig(&p, sig + sig_len, extracted_sig,
                                  signature_len/2)) != 0) {
         goto cleanup;
     }
 
-    status = psa_verify_hash(key_id, psa_sig_md,
-                             hash, hash_len,
-                             buf, signature_len);
+    status = psa_verify_hash(key_id, psa_sig_md, hash, hash_len,
+                             extracted_sig, signature_len);
     if (status != PSA_SUCCESS) {
         ret = PSA_PK_ECDSA_TO_MBEDTLS_ERR(status);
         goto cleanup;
@@ -814,6 +780,45 @@
 
     return ret;
 }
+
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+static int ecdsa_verify_wrap(mbedtls_pk_context *pk,
+                             mbedtls_md_type_t md_alg,
+                             const unsigned char *hash, size_t hash_len,
+                             const unsigned char *sig, size_t sig_len)
+{
+    (void) md_alg;
+    psa_ecc_family_t curve = pk->ec_family;
+    size_t curve_bits = pk->ec_bits;
+
+    return ecdsa_verify_psa(pk->pub_raw, pk->pub_raw_len, curve, curve_bits,
+                            hash, hash_len, sig, sig_len);
+}
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
+static int ecdsa_verify_wrap(mbedtls_pk_context *pk,
+                             mbedtls_md_type_t md_alg,
+                             const unsigned char *hash, size_t hash_len,
+                             const unsigned char *sig, size_t sig_len)
+{
+    (void) md_alg;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    mbedtls_ecp_keypair *ctx = pk->pk_ctx;
+    unsigned char key[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
+    size_t key_len;
+    size_t curve_bits;
+    psa_ecc_family_t curve = mbedtls_ecc_group_to_psa(ctx->grp.id, &curve_bits);
+
+    ret = mbedtls_ecp_point_write_binary(&ctx->grp, &ctx->Q,
+                                         MBEDTLS_ECP_PF_UNCOMPRESSED,
+                                         &key_len, key, sizeof(key));
+    if (ret != 0) {
+        return ret;
+    }
+
+    return ecdsa_verify_psa(key, key_len, curve, curve_bits,
+                            hash, hash_len, sig, sig_len);
+}
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
 #else /* MBEDTLS_USE_PSA_CRYPTO */
 static int ecdsa_verify_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                              const unsigned char *hash, size_t hash_len,