pk_wrap: optimize code for ECDSA sign

Signed-off-by: Valerio Setti <valerio.setti@nordicsemi.no>
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index 54a4d5d..24d531b 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -920,6 +920,80 @@
     return 0;
 }
 
+/* Common PSA signing path shared by the two ecdsa_sign_wrap() variants below
+ * (one built with PK_USE_PSA_EC_DATA, one without): produce the signature
+ * with psa_sign_hash(), then hand the buffer to pk_ecdsa_sig_asn1_from_psa()
+ * and return that call's result. */
+static int ecdsa_sign_psa(mbedtls_svc_key_id_t key_id, psa_algorithm_t psa_sig_md,
+                          const unsigned char *hash, size_t hash_len,
+                          unsigned char *sig, size_t sig_size, size_t *sig_len)
+{
+    psa_status_t status = psa_sign_hash(key_id, psa_sig_md, hash, hash_len,
+                                        sig, sig_size, sig_len);
+
+    if (status != PSA_SUCCESS) {
+        /* Signing failed: translate the PSA status and stop here. */
+        return PSA_PK_ECDSA_TO_MBEDTLS_ERR(status);
+    }
+
+    /* psa_sign_hash() succeeded; the outcome of the ASN.1 conversion of the
+     * signature is the overall result. */
+    return pk_ecdsa_sig_asn1_from_psa(sig, sig_len, sig_size);
+}
+
+/* ECDSA signing entry point for opaque keys. It is kept separate from
+ * ecdsa_sign_wrap() below because:
+ * - opaque keys are available whenever USE_PSA_CRYPTO is defined, including
+ *   builds where PK_USE_PSA_EC_DATA is not defined;
+ * - opaque keys do not support PSA_ALG_DETERMINISTIC_ECDSA(). */
+static int ecdsa_sign_wrap_opaque(mbedtls_pk_context *pk,
+                                  mbedtls_md_type_t md_alg,
+                                  const unsigned char *hash, size_t hash_len,
+                                  unsigned char *sig, size_t sig_size,
+                                  size_t *sig_len,
+                                  int (*f_rng)(void *, unsigned char *, size_t),
+                                  void *p_rng)
+{
+    const psa_algorithm_t alg =
+        PSA_ALG_ECDSA(mbedtls_md_psa_alg_from_type(md_alg));
+
+    (void) f_rng;
+    (void) p_rng;
+
+    if (MBEDTLS_SVC_KEY_ID_GET_KEY_ID(pk->priv_id) == PSA_KEY_ID_NULL) {
+        return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+    }
+
+    return ecdsa_sign_psa(pk->priv_id, alg, hash, hash_len, sig, sig_size, sig_len);
+}
+
+#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
+/* ECDSA signing through PSA using the key referenced by pk->priv_id. */
+static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
+                           const unsigned char *hash, size_t hash_len,
+                           unsigned char *sig, size_t sig_size, size_t *sig_len,
+                           int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
+{
+    const psa_algorithm_t hash_alg = mbedtls_md_psa_alg_from_type(md_alg);
+#if defined(MBEDTLS_ECDSA_DETERMINISTIC)
+    const psa_algorithm_t sign_alg = PSA_ALG_DETERMINISTIC_ECDSA(hash_alg);
+#else
+    const psa_algorithm_t sign_alg = PSA_ALG_ECDSA(hash_alg);
+#endif
+
+    (void) f_rng;
+    (void) p_rng;
+
+    /* Reject a context whose ec_family is 0 or whose priv_id is the null ID. */
+    if ((pk->ec_family == 0) ||
+        (MBEDTLS_SVC_KEY_ID_GET_KEY_ID(pk->priv_id) == PSA_KEY_ID_NULL)) {
+        return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
+    }
+
+    return ecdsa_sign_psa(pk->priv_id, sign_alg, hash, hash_len,
+                          sig, sig_size, sig_len);
+}
+#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
 static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                            const unsigned char *hash, size_t hash_len,
                            unsigned char *sig, size_t sig_size, size_t *sig_len,
@@ -928,16 +1002,6 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT;
     psa_status_t status;
-#if defined(MBEDTLS_ECDSA_DETERMINISTIC)
-    psa_algorithm_t psa_sig_md =
-        PSA_ALG_DETERMINISTIC_ECDSA(mbedtls_md_psa_alg_from_type(md_alg));
-#else
-    psa_algorithm_t psa_sig_md =
-        PSA_ALG_ECDSA(mbedtls_md_psa_alg_from_type(md_alg));
-#endif
-#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
-    psa_ecc_family_t curve = pk->ec_family;
-#else /* MBEDTLS_PK_USE_PSA_EC_DATA */
     mbedtls_ecp_keypair *ctx = pk->pk_ctx;
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     unsigned char buf[MBEDTLS_PSA_MAX_EC_KEY_PAIR_LENGTH];
@@ -945,9 +1009,13 @@
     psa_ecc_family_t curve =
         mbedtls_ecc_group_to_psa(ctx->grp.id, &curve_bits);
     size_t key_len = PSA_BITS_TO_BYTES(curve_bits);
-#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
-
-    /* PSA has its own RNG */
+#if defined(MBEDTLS_ECDSA_DETERMINISTIC)
+    psa_algorithm_t psa_sig_md =
+        PSA_ALG_DETERMINISTIC_ECDSA(mbedtls_md_psa_alg_from_type(md_alg));
+#else
+    psa_algorithm_t psa_sig_md =
+        PSA_ALG_ECDSA(mbedtls_md_psa_alg_from_type(md_alg));
+#endif
     ((void) f_rng);
     ((void) p_rng);
 
@@ -955,12 +1023,6 @@
         return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
     }
 
-#if defined(MBEDTLS_PK_USE_PSA_EC_DATA)
-    if (MBEDTLS_SVC_KEY_ID_GET_KEY_ID(pk->priv_id) == PSA_KEY_ID_NULL) {
-        return MBEDTLS_ERR_PK_BAD_INPUT_DATA;
-    }
-    key_id = pk->priv_id;
-#else
     if (key_len > sizeof(buf)) {
         return MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     }
@@ -973,36 +1035,24 @@
     psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_SIGN_HASH);
     psa_set_key_algorithm(&attributes, psa_sig_md);
 
-    status = psa_import_key(&attributes,
-                            buf, key_len,
-                            &key_id);
+    status = psa_import_key(&attributes, buf, key_len, &key_id);
     if (status != PSA_SUCCESS) {
         ret = PSA_PK_TO_MBEDTLS_ERR(status);
         goto cleanup;
     }
-#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
 
-    status = psa_sign_hash(key_id, psa_sig_md, hash, hash_len,
-                           sig, sig_size, sig_len);
-    if (status != PSA_SUCCESS) {
-        ret = PSA_PK_ECDSA_TO_MBEDTLS_ERR(status);
-        goto cleanup;
-    }
-
-    ret = pk_ecdsa_sig_asn1_from_psa(sig, sig_len, sig_size);
+    ret = ecdsa_sign_psa(key_id, psa_sig_md, hash, hash_len, sig, sig_size, sig_len);
 
 cleanup:
-
-#if !defined(MBEDTLS_PK_USE_PSA_EC_DATA)
     mbedtls_platform_zeroize(buf, sizeof(buf));
     status = psa_destroy_key(key_id);
-#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
     if (ret == 0 && status != PSA_SUCCESS) {
         ret = PSA_PK_TO_MBEDTLS_ERR(status);
     }
 
     return ret;
 }
+#endif /* MBEDTLS_PK_USE_PSA_EC_DATA */
 #else /* MBEDTLS_USE_PSA_CRYPTO */
 static int ecdsa_sign_wrap(mbedtls_pk_context *pk, mbedtls_md_type_t md_alg,
                            const unsigned char *hash, size_t hash_len,
@@ -1600,7 +1650,7 @@
                                unsigned char *sig, size_t sig_size, size_t *sig_len,
                                int (*f_rng)(void *, unsigned char *, size_t), void *p_rng)
 {
-#if !defined(MBEDTLS_PK_CAN_ECDSA_SIGN) && !defined(MBEDTLS_RSA_C)
+#if !defined(MBEDTLS_RSA_C)
     ((void) pk);
     ((void) md_alg);
     ((void) hash);
@@ -1611,7 +1661,7 @@
     ((void) f_rng);
     ((void) p_rng);
     return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
-#else /* !MBEDTLS_PK_CAN_ECDSA_SIGN && !MBEDTLS_RSA_C */
+#else /* !MBEDTLS_RSA_C */
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     psa_algorithm_t alg;
     psa_key_type_t type;
@@ -1629,44 +1679,25 @@
     type = psa_get_key_type(&attributes);
     psa_reset_key_attributes(&attributes);
 
-#if defined(MBEDTLS_PK_CAN_ECDSA_SIGN)
-    if (PSA_KEY_TYPE_IS_ECC_KEY_PAIR(type)) {
-        alg = PSA_ALG_ECDSA(mbedtls_md_psa_alg_from_type(md_alg));
-    } else
-#endif /* MBEDTLS_PK_CAN_ECDSA_SIGN */
-#if defined(MBEDTLS_RSA_C)
     if (PSA_KEY_TYPE_IS_RSA(type)) {
         alg = PSA_ALG_RSA_PKCS1V15_SIGN(mbedtls_md_psa_alg_from_type(md_alg));
-    } else
-#endif /* MBEDTLS_RSA_C */
-    return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
+    } else {
+        return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
+    }
 
     /* make the signature */
     status = psa_sign_hash(pk->priv_id, alg, hash, hash_len,
                            sig, sig_size, sig_len);
     if (status != PSA_SUCCESS) {
-#if defined(MBEDTLS_PK_CAN_ECDSA_SIGN)
-        if (PSA_KEY_TYPE_IS_ECC_KEY_PAIR(type)) {
-            return PSA_PK_ECDSA_TO_MBEDTLS_ERR(status);
-        } else
-#endif /* MBEDTLS_PK_CAN_ECDSA_SIGN */
-#if defined(MBEDTLS_RSA_C)
         if (PSA_KEY_TYPE_IS_RSA(type)) {
             return PSA_PK_RSA_TO_MBEDTLS_ERR(status);
-        } else
-#endif /* MBEDTLS_RSA_C */
-        return PSA_PK_TO_MBEDTLS_ERR(status);
+        } else {
+            return PSA_PK_TO_MBEDTLS_ERR(status);
+        }
     }
 
-#if defined(MBEDTLS_PK_CAN_ECDSA_SIGN)
-    if (PSA_KEY_TYPE_IS_ECC_KEY_PAIR(type)) {
-        /* transcode it to ASN.1 sequence */
-        return pk_ecdsa_sig_asn1_from_psa(sig, sig_len, sig_size);
-    }
-#endif /* MBEDTLS_PK_CAN_ECDSA_SIGN */
-
     return 0;
-#endif /* !MBEDTLS_PK_CAN_ECDSA_SIGN && !MBEDTLS_RSA_C */
+#endif /* !MBEDTLS_RSA_C */
 }
 
 static int pk_opaque_ec_check_pair(mbedtls_pk_context *pub, mbedtls_pk_context *prv,
@@ -1722,7 +1753,7 @@
     pk_opaque_get_bitlen,
     pk_opaque_ecdsa_can_do,
     NULL, /* verify - will be done later */
-    pk_opaque_sign_wrap,
+    ecdsa_sign_wrap_opaque,
 #if defined(MBEDTLS_ECDSA_C) && defined(MBEDTLS_ECP_RESTARTABLE)
     NULL, /* restartable verify - not relevant */
     NULL, /* restartable sign - not relevant */