Add support for RSA in mbedtls_pk_wrap_as_opaque()

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/library/pk.c b/library/pk.c
index abed468..d6ea912 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -22,6 +22,7 @@
 #if defined(MBEDTLS_PK_C)
 #include "mbedtls/pk.h"
 #include "pk_wrap.h"
+#include "pkwrite.h"
 
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
@@ -708,51 +709,103 @@
                                mbedtls_svc_key_id_t *key,
                                psa_algorithm_t hash_alg )
 {
-#if !defined(MBEDTLS_ECP_C)
+#if !defined(MBEDTLS_ECP_C) && !defined(MBEDTLS_RSA_C)
     ((void) pk);
     ((void) key);
     ((void) hash_alg);
-    return( MBEDTLS_ERR_PK_TYPE_MISMATCH );
 #else
-    const mbedtls_ecp_keypair *ec;
-    unsigned char d[MBEDTLS_ECP_MAX_BYTES];
-    size_t d_len;
-    psa_ecc_family_t curve_id;
-    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
-    psa_key_type_t key_type;
-    size_t bits;
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+#if defined(MBEDTLS_ECP_C)
+    if( mbedtls_pk_get_type( pk ) == MBEDTLS_PK_ECKEY )
+    {
+        const mbedtls_ecp_keypair *ec;
+        unsigned char d[MBEDTLS_ECP_MAX_BYTES];
+        size_t d_len;
+        psa_ecc_family_t curve_id;
+        psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+        psa_key_type_t key_type;
+        size_t bits;
+        psa_status_t status;
+        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-    /* export the private key material in the format PSA wants */
-    if( mbedtls_pk_get_type( pk ) != MBEDTLS_PK_ECKEY )
-        return( MBEDTLS_ERR_PK_TYPE_MISMATCH );
+        /* export the private key material in the format PSA wants */
+        ec = mbedtls_pk_ec( *pk );
+        d_len = ( ec->grp.nbits + 7 ) / 8;
+        if( ( ret = mbedtls_mpi_write_binary( &ec->d, d, d_len ) ) != 0 )
+            return( ret );
 
-    ec = mbedtls_pk_ec( *pk );
-    d_len = ( ec->grp.nbits + 7 ) / 8;
-    if( ( ret = mbedtls_mpi_write_binary( &ec->d, d, d_len ) ) != 0 )
-        return( ret );
+        curve_id = mbedtls_ecc_group_to_psa( ec->grp.id, &bits );
+        key_type = PSA_KEY_TYPE_ECC_KEY_PAIR( curve_id );
 
-    curve_id = mbedtls_ecc_group_to_psa( ec->grp.id, &bits );
-    key_type = PSA_KEY_TYPE_ECC_KEY_PAIR( curve_id );
+        /* prepare the key attributes */
+        psa_set_key_type( &attributes, key_type );
+        psa_set_key_bits( &attributes, bits );
+        psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_SIGN_HASH |
+                                              PSA_KEY_USAGE_DERIVE);
+        psa_set_key_algorithm( &attributes, PSA_ALG_ECDSA( hash_alg ) );
+        psa_set_key_enrollment_algorithm( &attributes, PSA_ALG_ECDH );
 
-    /* prepare the key attributes */
-    psa_set_key_type( &attributes, key_type );
-    psa_set_key_bits( &attributes, bits );
-    psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_SIGN_HASH |
-                                          PSA_KEY_USAGE_DERIVE);
-    psa_set_key_algorithm( &attributes, PSA_ALG_ECDSA(hash_alg) );
-    psa_set_key_enrollment_algorithm( &attributes, PSA_ALG_ECDH );
+        /* import private key into PSA, then wipe the local copy of the
+         * private key whether or not the import succeeded */
+        status = psa_import_key( &attributes, d, d_len, key );
+
+        mbedtls_platform_zeroize( d, sizeof( d ) );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
 
-    /* import private key into PSA */
-    if( PSA_SUCCESS != psa_import_key( &attributes, d, d_len, key ) )
-        return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+        /* make PK context wrap the key slot */
+        mbedtls_pk_free( pk );
+        mbedtls_pk_init( pk );
 
-    /* make PK context wrap the key slot */
-    mbedtls_pk_free( pk );
-    mbedtls_pk_init( pk );
-
-    return( mbedtls_pk_setup_opaque( pk, *key ) );
+        return( mbedtls_pk_setup_opaque( pk, *key ) );
+    }
+    else
 #endif /* MBEDTLS_ECP_C */
+#if defined(MBEDTLS_RSA_C)
+    if( mbedtls_pk_get_type( pk ) == MBEDTLS_PK_RSA )
+    {
+        unsigned char buf[MBEDTLS_PK_RSA_PRV_DER_MAX_BYTES];
+        psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+        int key_len;
+        psa_status_t status;
+
+        /* export the private key material in the format PSA wants
+         * (the DER writer fills buf from its end, hence the offset below) */
+        key_len = mbedtls_pk_write_key_der( pk, buf, sizeof( buf ) );
+        if( key_len <= 0 )
+        {
+            /* the writer may have emitted part of the key before failing */
+            mbedtls_platform_zeroize( buf, sizeof( buf ) );
+            return( MBEDTLS_ERR_PK_BAD_INPUT_DATA );
+        }
+
+        /* prepare the key attributes */
+        psa_set_key_type( &attributes, PSA_KEY_TYPE_RSA_KEY_PAIR );
+        psa_set_key_bits( &attributes, mbedtls_pk_get_bitlen( pk ) );
+        psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_SIGN_HASH );
+        psa_set_key_algorithm( &attributes,
+                               PSA_ALG_RSA_PKCS1V15_SIGN( hash_alg ) );
+
+        /* import private key into PSA */
+        status = psa_import_key( &attributes,
+                                 buf + sizeof( buf ) - key_len,
+                                 key_len, key);
+
+        mbedtls_platform_zeroize( buf, sizeof( buf ) );
+
+        if( status != PSA_SUCCESS )
+            return( MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED );
+
+        /* make PK context wrap the key slot */
+        mbedtls_pk_free( pk );
+        mbedtls_pk_init( pk );
+
+        return( mbedtls_pk_setup_opaque( pk, *key ) );
+    }
+    else
+#endif /* MBEDTLS_RSA_C */
+#endif /* !MBEDTLS_ECP_C && !MBEDTLS_RSA_C */
+    return( MBEDTLS_ERR_PK_TYPE_MISMATCH );
 }
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 #endif /* MBEDTLS_PK_C */