mbedtls_pk_get_psa_attributes: RSA support

Add code and unit tests for MBEDTLS_PK_RSA in mbedtls_pk_get_psa_attributes().

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/library/pk.c b/library/pk.c
index 706d5d3..1485bd7 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -379,6 +379,27 @@
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_PSA_CRYPTO_CLIENT)
+#if defined(MBEDTLS_RSA_C)
+static psa_algorithm_t psa_algorithm_for_rsa(const mbedtls_rsa_context *rsa,
+                                             int want_crypt)
+{
+    if (mbedtls_rsa_get_padding_mode(rsa) == MBEDTLS_RSA_PKCS_V21) {
+        if (want_crypt) {
+            mbedtls_md_type_t md_type = mbedtls_rsa_get_md_alg(rsa);
+            return PSA_ALG_RSA_OAEP(mbedtls_md_psa_alg_from_type(md_type));
+        } else {
+            return PSA_ALG_RSA_PSS_ANY_SALT(PSA_ALG_ANY_HASH);
+        }
+    } else {
+        if (want_crypt) {
+            return PSA_ALG_RSA_PKCS1V15_CRYPT;
+        } else {
+            return PSA_ALG_RSA_PKCS1V15_SIGN(PSA_ALG_ANY_HASH);
+        }
+    }
+}
+#endif /* MBEDTLS_RSA_C */
+
 int mbedtls_pk_get_psa_attributes(const mbedtls_pk_context *pk,
                                   psa_key_usage_t usage,
                                   psa_key_attributes_t *attributes)
@@ -386,6 +407,49 @@
     mbedtls_pk_type_t pk_type = mbedtls_pk_get_type(pk);
 
     switch (pk_type) {
+#if defined(MBEDTLS_RSA_C)
+        case MBEDTLS_PK_RSA:
+            int want_crypt = 0;
+            int want_private = 0;
+            switch (usage) {
+                case PSA_KEY_USAGE_SIGN_MESSAGE:
+                    usage |= PSA_KEY_USAGE_VERIFY_MESSAGE;
+                    want_private = 1;
+                    break;
+                case PSA_KEY_USAGE_SIGN_HASH:
+                    usage |= PSA_KEY_USAGE_VERIFY_HASH;
+                    want_private = 1;
+                    break;
+                case PSA_KEY_USAGE_DECRYPT:
+                    usage |= PSA_KEY_USAGE_ENCRYPT;
+                    want_private = 1;
+                    want_crypt = 1;
+                    break;
+                case PSA_KEY_USAGE_VERIFY_MESSAGE:
+                case PSA_KEY_USAGE_VERIFY_HASH:
+                    break;
+                case PSA_KEY_USAGE_ENCRYPT:
+                    want_crypt = 1;
+                    break;
+                default:
+                    return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+            /* Detect the presence of a private key in a way that works both
+             * in CRT and non-CRT configurations. */
+            mbedtls_rsa_context *rsa = mbedtls_pk_rsa(*pk);
+            int has_private = (mbedtls_rsa_check_privkey(rsa) == 0);
+            if (want_private && !has_private) {
+                return MBEDTLS_ERR_PK_TYPE_MISMATCH;
+            }
+            psa_set_key_type(attributes, (want_private ?
+                                          PSA_KEY_TYPE_RSA_KEY_PAIR :
+                                          PSA_KEY_TYPE_RSA_PUBLIC_KEY));
+            psa_set_key_bits(attributes, mbedtls_mpi_bitlen(&rsa->N));
+            psa_set_key_algorithm(attributes,
+                                  psa_algorithm_for_rsa(rsa, want_crypt));
+            break;
+#endif /* MBEDTLS_RSA_C */
+
 #if defined(MBEDTLS_PK_RSA_ALT_SUPPORT)
         case MBEDTLS_PK_RSA_ALT:
             return MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE;
diff --git a/tests/suites/test_suite_pk.data b/tests/suites/test_suite_pk.data
index e8ffff4..bc0de71 100644
--- a/tests/suites/test_suite_pk.data
+++ b/tests/suites/test_suite_pk.data
@@ -685,3 +685,107 @@
 pk_get_psa_attributes_fail:MBEDTLS_PK_NONE:0:PSA_KEY_USAGE_SIGN_MESSAGE:MBEDTLS_ERR_PK_BAD_INPUT_DATA
 
 # There is a (negative) test for pk_type=MBEDTLS_PK_RSA_ALT in pk_rsa_alt().
+
+PSA attributes for pk: RSA v15 pair DECRYPT
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes:MBEDTLS_PK_RSA:1:PSA_KEY_USAGE_DECRYPT:1:PSA_ALG_RSA_PKCS1V15_CRYPT
+
+PSA attributes for pk: RSA v21 SHA-256 pair DECRYPT
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21:MBEDTLS_MD_CAN_SHA256
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_SHA256:1:PSA_KEY_USAGE_DECRYPT:1:PSA_ALG_RSA_OAEP(PSA_ALG_SHA_256)
+
+PSA attributes for pk: RSA v21 SHA-512 pair DECRYPT
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21:MBEDTLS_MD_CAN_SHA512
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_SHA512:1:PSA_KEY_USAGE_DECRYPT:1:PSA_ALG_RSA_OAEP(PSA_ALG_SHA_512)
+
+PSA attributes for pk: RSA v15 pair->public ENCRYPT
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes:MBEDTLS_PK_RSA:1:PSA_KEY_USAGE_ENCRYPT:0:PSA_ALG_RSA_PKCS1V15_CRYPT
+
+PSA attributes for pk: RSA v21 SHA-256 pair->public ENCRYPT
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21:MBEDTLS_MD_CAN_SHA256
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_SHA256:1:PSA_KEY_USAGE_ENCRYPT:0:PSA_ALG_RSA_OAEP(PSA_ALG_SHA_256)
+
+PSA attributes for pk: RSA v21 SHA-512 pair->public ENCRYPT
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21:MBEDTLS_MD_CAN_SHA512
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_SHA512:1:PSA_KEY_USAGE_ENCRYPT:0:PSA_ALG_RSA_OAEP(PSA_ALG_SHA_512)
+
+PSA attributes for pk: RSA v15 public ENCRYPT
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes:MBEDTLS_PK_RSA:0:PSA_KEY_USAGE_ENCRYPT:0:PSA_ALG_RSA_PKCS1V15_CRYPT
+
+PSA attributes for pk: RSA v21 SHA-256 public ENCRYPT
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21:MBEDTLS_MD_CAN_SHA256
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_SHA256:0:PSA_KEY_USAGE_ENCRYPT:0:PSA_ALG_RSA_OAEP(PSA_ALG_SHA_256)
+
+PSA attributes for pk: RSA v21 SHA-512 public ENCRYPT
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21:MBEDTLS_MD_CAN_SHA512
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_SHA512:0:PSA_KEY_USAGE_ENCRYPT:0:PSA_ALG_RSA_OAEP(PSA_ALG_SHA_512)
+
+PSA attributes for pk: RSA v15 public DECRYPT (bad)
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes_fail:MBEDTLS_PK_RSA:0:PSA_KEY_USAGE_DECRYPT:MBEDTLS_ERR_PK_TYPE_MISMATCH
+
+PSA attributes for pk: RSA v15 pair SIGN_MESSAGE
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes:MBEDTLS_PK_RSA:1:PSA_KEY_USAGE_SIGN_MESSAGE:1:PSA_ALG_RSA_PKCS1V15_SIGN(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v21 SHA-256 pair SIGN_MESSAGE
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_NONE:1:PSA_KEY_USAGE_SIGN_MESSAGE:1:PSA_ALG_RSA_PSS_ANY_SALT(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v15 pair SIGN_HASH
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes:MBEDTLS_PK_RSA:1:PSA_KEY_USAGE_SIGN_HASH:1:PSA_ALG_RSA_PKCS1V15_SIGN(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v21 SHA-256 pair SIGN_HASH
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_NONE:1:PSA_KEY_USAGE_SIGN_HASH:1:PSA_ALG_RSA_PSS_ANY_SALT(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v15 pair->public VERIFY_MESSAGE
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes:MBEDTLS_PK_RSA:1:PSA_KEY_USAGE_VERIFY_MESSAGE:0:PSA_ALG_RSA_PKCS1V15_SIGN(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v21 SHA-256 pair->public VERIFY_MESSAGE
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_NONE:1:PSA_KEY_USAGE_VERIFY_MESSAGE:0:PSA_ALG_RSA_PSS_ANY_SALT(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v15 pair->public VERIFY_HASH
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes:MBEDTLS_PK_RSA:1:PSA_KEY_USAGE_VERIFY_HASH:0:PSA_ALG_RSA_PKCS1V15_SIGN(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v21 SHA-256 pair->public VERIFY_HASH
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_NONE:1:PSA_KEY_USAGE_VERIFY_HASH:0:PSA_ALG_RSA_PSS_ANY_SALT(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v15 public VERIFY_MESSAGE
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes:MBEDTLS_PK_RSA:0:PSA_KEY_USAGE_VERIFY_MESSAGE:0:PSA_ALG_RSA_PKCS1V15_SIGN(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v21 SHA-256 public VERIFY_MESSAGE
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_NONE:0:PSA_KEY_USAGE_VERIFY_MESSAGE:0:PSA_ALG_RSA_PSS_ANY_SALT(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v15 public VERIFY_HASH
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes:MBEDTLS_PK_RSA:0:PSA_KEY_USAGE_VERIFY_HASH:0:PSA_ALG_RSA_PKCS1V15_SIGN(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v21 SHA-256 public VERIFY_HASH
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21
+pk_rsa_v21_get_psa_attributes:MBEDTLS_MD_NONE:0:PSA_KEY_USAGE_VERIFY_HASH:0:PSA_ALG_RSA_PSS_ANY_SALT(PSA_ALG_ANY_HASH)
+
+PSA attributes for pk: RSA v15 public SIGN_MESSAGE (bad)
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes_fail:MBEDTLS_PK_RSA:0:PSA_KEY_USAGE_SIGN_MESSAGE:MBEDTLS_ERR_PK_TYPE_MISMATCH
+
+PSA attributes for pk: RSA v15 public SIGN_HASH (bad)
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes_fail:MBEDTLS_PK_RSA:0:PSA_KEY_USAGE_SIGN_HASH:MBEDTLS_ERR_PK_TYPE_MISMATCH
+
+PSA attributes for pk: RSA v15 pair DERIVE (bad)
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes_fail:MBEDTLS_PK_RSA:1:PSA_KEY_USAGE_DERIVE:MBEDTLS_ERR_PK_TYPE_MISMATCH
+
+PSA attributes for pk: RSA v15 public DERIVE (bad)
+depends_on:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V15
+pk_get_psa_attributes_fail:MBEDTLS_PK_RSA:0:PSA_KEY_USAGE_DERIVE:MBEDTLS_ERR_PK_TYPE_MISMATCH
diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function
index 0ac84a2..d6902b4 100644
--- a/tests/suites/test_suite_pk.function
+++ b/tests/suites/test_suite_pk.function
@@ -174,6 +174,30 @@
     TEST_EQUAL(mbedtls_pk_setup(pk, mbedtls_pk_info_from_type(pk_type)), 0);
 
     switch (pk_type) {
+#if defined(MBEDTLS_RSA_C)
+        case MBEDTLS_PK_RSA:
+        {
+            *psa_type = PSA_KEY_TYPE_RSA_KEY_PAIR;
+            mbedtls_rsa_context *rsa = mbedtls_pk_rsa(*pk);
+            if (want_pair) {
+                TEST_EQUAL(mbedtls_rsa_gen_key(
+                               rsa,
+                               mbedtls_test_rnd_std_rand, NULL,
+                               MBEDTLS_RSA_GEN_KEY_MIN_BITS, 65537), 0);
+            } else {
+                unsigned char N[PSA_BITS_TO_BYTES(MBEDTLS_RSA_GEN_KEY_MIN_BITS)] = { 0xff };
+                N[sizeof(N) - 1] = 0x03;
+                const unsigned char E[1] = {0x03};
+                TEST_EQUAL(mbedtls_rsa_import_raw(rsa,
+                                                  N, sizeof(N),
+                                                  NULL, 0, NULL, 0, NULL, 0,
+                                                  E, sizeof(E)), 0);
+                TEST_EQUAL(mbedtls_rsa_complete(rsa), 0);
+            }
+            break;
+        }
+#endif /* MBEDTLS_RSA_C */
+
         default:
             TEST_FAIL("Unknown PK type in test data");
             break;
@@ -1606,6 +1630,131 @@
 /* END_CASE */
 
 /* BEGIN_CASE depends_on:MBEDTLS_PSA_CRYPTO_C */
+void pk_get_psa_attributes(int pk_type, int from_pair,
+                           int usage_arg,
+                           int to_pair, int expected_alg)
+{
+    mbedtls_pk_context pk;
+    mbedtls_pk_init(&pk);
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_usage_t usage = usage_arg;
+
+    MD_OR_USE_PSA_INIT();
+
+    psa_key_type_t expected_psa_type = 0;
+    if (!pk_setup_for_type(pk_type, from_pair, &pk, &expected_psa_type)) {
+        goto exit;
+    }
+    if (!to_pair) {
+        expected_psa_type = PSA_KEY_TYPE_PUBLIC_KEY_OF_KEY_PAIR(expected_psa_type);
+    }
+
+    psa_key_lifetime_t lifetime = PSA_KEY_LIFETIME_VOLATILE; //TODO: diversity
+    mbedtls_svc_key_id_t key_id = MBEDTLS_SVC_KEY_ID_INIT; //TODO: diversity
+    psa_set_key_id(&attributes, key_id);
+    psa_set_key_lifetime(&attributes, lifetime);
+
+    psa_key_usage_t expected_usage = usage;
+    /* Usage implied universally */
+    if (expected_usage & PSA_KEY_USAGE_SIGN_HASH) {
+        expected_usage |= PSA_KEY_USAGE_SIGN_MESSAGE;
+    }
+    if (expected_usage & PSA_KEY_USAGE_VERIFY_HASH) {
+        expected_usage |= PSA_KEY_USAGE_VERIFY_MESSAGE;
+    }
+    /* Usage implied by mbedtls_pk_get_psa_attributes() */
+    if (expected_usage & PSA_KEY_USAGE_SIGN_HASH) {
+        expected_usage |= PSA_KEY_USAGE_VERIFY_HASH;
+    }
+    if (expected_usage & PSA_KEY_USAGE_SIGN_MESSAGE) {
+        expected_usage |= PSA_KEY_USAGE_VERIFY_MESSAGE;
+    }
+    if (expected_usage & PSA_KEY_USAGE_DECRYPT) {
+        expected_usage |= PSA_KEY_USAGE_ENCRYPT;
+    }
+    expected_usage |= PSA_KEY_USAGE_EXPORT | PSA_KEY_USAGE_COPY;
+
+    TEST_EQUAL(mbedtls_pk_get_psa_attributes(&pk, usage, &attributes), 0);
+
+    TEST_EQUAL(psa_get_key_lifetime(&attributes), lifetime);
+    TEST_ASSERT(mbedtls_svc_key_id_equal(psa_get_key_id(&attributes),
+                                         key_id));
+    TEST_EQUAL(psa_get_key_type(&attributes), expected_psa_type);
+    TEST_EQUAL(psa_get_key_bits(&attributes),
+               mbedtls_pk_get_bitlen(&pk));
+    TEST_EQUAL(psa_get_key_usage_flags(&attributes), expected_usage);
+    TEST_EQUAL(psa_get_key_algorithm(&attributes), expected_alg);
+    TEST_EQUAL(psa_get_key_enrollment_algorithm(&attributes), PSA_ALG_NONE);
+
+exit:
+    mbedtls_pk_free(&pk);
+    psa_reset_key_attributes(&attributes);
+    MD_OR_USE_PSA_DONE();
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_PSA_CRYPTO_C:MBEDTLS_RSA_C:MBEDTLS_PKCS1_V21 */
+void pk_rsa_v21_get_psa_attributes(int md_type, int from_pair,
+                                   int usage_arg,
+                                   int to_pair, int expected_alg)
+{
+    mbedtls_pk_context pk;
+    mbedtls_pk_init(&pk);
+    psa_key_usage_t usage = usage_arg;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+
+    MD_OR_USE_PSA_INIT();
+
+    psa_key_type_t expected_psa_type = 0;
+    if (!pk_setup_for_type(MBEDTLS_PK_RSA, from_pair, &pk, &expected_psa_type)) {
+        goto exit;
+    }
+    mbedtls_rsa_context *rsa = mbedtls_pk_rsa(pk);
+    TEST_EQUAL(mbedtls_rsa_set_padding(rsa, MBEDTLS_RSA_PKCS_V21, md_type), 0);
+    if (!to_pair) {
+        expected_psa_type = PSA_KEY_TYPE_PUBLIC_KEY_OF_KEY_PAIR(expected_psa_type);
+    }
+
+    psa_key_usage_t expected_usage = usage;
+    /* Usage implied universally */
+    if (expected_usage & PSA_KEY_USAGE_SIGN_HASH) {
+        expected_usage |= PSA_KEY_USAGE_SIGN_MESSAGE;
+    }
+    if (expected_usage & PSA_KEY_USAGE_VERIFY_HASH) {
+        expected_usage |= PSA_KEY_USAGE_VERIFY_MESSAGE;
+    }
+    /* Usage implied by mbedtls_pk_get_psa_attributes() */
+    if (expected_usage & PSA_KEY_USAGE_SIGN_HASH) {
+        expected_usage |= PSA_KEY_USAGE_VERIFY_HASH;
+    }
+    if (expected_usage & PSA_KEY_USAGE_SIGN_MESSAGE) {
+        expected_usage |= PSA_KEY_USAGE_VERIFY_MESSAGE;
+    }
+    if (expected_usage & PSA_KEY_USAGE_DECRYPT) {
+        expected_usage |= PSA_KEY_USAGE_ENCRYPT;
+    }
+    expected_usage |= PSA_KEY_USAGE_EXPORT | PSA_KEY_USAGE_COPY;
+
+    TEST_EQUAL(mbedtls_pk_get_psa_attributes(&pk, usage, &attributes), 0);
+
+    TEST_EQUAL(psa_get_key_lifetime(&attributes), PSA_KEY_LIFETIME_VOLATILE);
+    TEST_ASSERT(mbedtls_svc_key_id_equal(psa_get_key_id(&attributes),
+                                         MBEDTLS_SVC_KEY_ID_INIT));
+    TEST_EQUAL(psa_get_key_type(&attributes), expected_psa_type);
+    TEST_EQUAL(psa_get_key_bits(&attributes),
+               mbedtls_pk_get_bitlen(&pk));
+    TEST_EQUAL(psa_get_key_usage_flags(&attributes), expected_usage);
+    TEST_EQUAL(psa_get_key_algorithm(&attributes), expected_alg);
+    TEST_EQUAL(psa_get_key_enrollment_algorithm(&attributes), PSA_ALG_NONE);
+
+exit:
+    mbedtls_pk_free(&pk);
+    psa_reset_key_attributes(&attributes);
+    MD_OR_USE_PSA_DONE();
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_PSA_CRYPTO_C */
 void pk_get_psa_attributes_fail(int pk_type, int from_pair,
                                 int usage_arg,
                                 int expected_ret)