pk_wrap: optimize code for ECDSA verify
Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 24d531b..75904c4 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -717,36 +717,19 @@
return 0;
}
-static int ecdsa_verify_wrap(mbedtls_pk_context *pk,
- mbedtls_md_type_t md_alg,
- const unsigned char *hash, size_t hash_len,
- const unsigned char *sig, size_t sig_len)
+static int ecdsa_verify_psa(unsigned char *key, size_t key_len,
+ psa_ecc_family_t curve, size_t curve_bits,
+ const unsigned char *hash, size_t hash_len,
+ const unsigned char *sig, size_t sig_len)
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
- psa_status_t status;
- unsigned char *p;
psa_algorithm_t psa_sig_md = PSA_ALG_ECDSA_ANY;
- size_t signature_len;
- ((void) md_alg);
-#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
- unsigned char buf[PSA_VENDOR_ECDSA_SIGNATURE_MAX_SIZE];
- psa_ecc_family_t curve = pk->ec_family;
- size_t curve_bits = pk->ec_bits;
-#else
- mbedtls_ecp_keypair *ctx = pk->pk_ctx;
- size_t key_len;
- /* This buffer will initially contain the public key and then the signature
- * but at different points in time. For all curves except secp224k1, which
- * is not currently supported in PSA, the public key is one byte longer
- * (header byte + 2 numbers, while the signature is only 2 numbers),
- * so use that as the buffer size. */
- unsigned char buf[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
- size_t curve_bits;
- psa_ecc_family_t curve =
- mbedtls_ecc_group_to_psa(ctx->grp.id, &curve_bits);
-#endif
+ size_t signature_len = PSA_ECDSA_SIGNATURE_SIZE(curve_bits);
+ unsigned char extracted_sig[PSA_VENDOR_ECDSA_SIGNATURE_MAX_SIZE];
+ unsigned char *p;
+ psa_status_t status;
if (curve == 0) {
return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
@@ -756,29 +739,13 @@
psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_VERIFY_HASH);
psa_set_key_algorithm(&attributes, psa_sig_md);
-#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
- status = psa_import_key(&attributes,
- pk->pub_raw, pk->pub_raw_len,
- &key_id);
-#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
- ret = mbedtls_ecp_point_write_binary(&ctx->grp, &ctx->Q,
- MBEDTLS_ECP_PF_UNCOMPRESSED,
- &key_len, buf, sizeof(buf));
- if (ret != 0) {
- goto cleanup;
- }
-
- status = psa_import_key(&attributes,
- buf, key_len,
- &key_id);
-#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
+ status = psa_import_key(&attributes, key, key_len, &key_id);
if (status != PSA_SUCCESS) {
ret = PSA_PK_TO_MBEDTLS_ERR(status);
goto cleanup;
}
- signature_len = PSA_ECDSA_SIGNATURE_SIZE(curve_bits);
- if (signature_len > sizeof(buf)) {
+ if (signature_len > sizeof(extracted_sig)) {
ret = MBEDTLS_ERR_PK_BAD_INPUT_DATA;
goto cleanup;
}
@@ -787,14 +754,13 @@
/* extract_ecdsa_sig's last parameter is the size
* of each integer to be parsed, so it's actually half
* the size of the signature. */
- if ((ret = extract_ecdsa_sig(&p, sig + sig_len, buf,
+ if ((ret = extract_ecdsa_sig(&p, sig + sig_len, extracted_sig,
signature_len/2)) != 0) {
goto cleanup;
}
- status = psa_verify_hash(key_id, psa_sig_md,
- hash, hash_len,
- buf, signature_len);
+ status = psa_verify_hash(key_id, psa_sig_md, hash, hash_len,
+ extracted_sig, signature_len);
if (status != PSA_SUCCESS) {
ret = PSA_PK_ECDSA_TO_MBEDTLS_ERR(status);
goto cleanup;
@@ -814,6 +780,45 @@
return ret;
}
+
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+static int ecdsa_verify_wrap(mbedtls_pk_context *pk,
+ mbedtls_md_type_t md_alg,
+ const unsigned char *hash, size_t hash_len,
+ const unsigned char *sig, size_t sig_len)
+{
+ (void) md_alg;
+ psa_ecc_family_t curve = pk->ec_family;
+ size_t curve_bits = pk->ec_bits;
+
+ return ecdsa_verify_psa(pk->pub_raw, pk->pub_raw_len, curve, curve_bits,
+ hash, hash_len, sig, sig_len);
+}
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
+static int ecdsa_verify_wrap(mbedtls_pk_context *pk,
+ mbedtls_md_type_t md_alg,
+ const unsigned char *hash, size_t hash_len,
+ const unsigned char *sig, size_t sig_len)
+{
+ (void) md_alg;
+ int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+ mbedtls_ecp_keypair *ctx = pk->pk_ctx;
+ unsigned char key[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
+ size_t key_len;
+ size_t curve_bits;
+ psa_ecc_family_t curve = mbedtls_ecc_group_to_psa(ctx->grp.id, &curve_bits);
+
+ ret = mbedtls_ecp_point_write_binary(&ctx->grp, &ctx->Q,
+ MBEDTLS_ECP_PF_UNCOMPRESSED,
+ &key_len, key, sizeof(key));
+ if (ret != 0) {
+ return ret;
+ }
+
+ return ecdsa_verify_psa(key, key_len, curve, curve_bits,
+ hash, hash_len, sig, sig_len);
+}
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
#else /* MBEDTLS_USE_PSA_CRYPTO */
static int ecdsa_verify_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
const unsigned char *hash, size_t hash_len,