Remove RSA internal representation from key slot

Change to on-demand loading of the internal representation when required
in order to call an mbed TLS cryptography API.

Signed-off-by: Steven Cooreman <steven.cooreman@silabs.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index e2e99d7..6f374b1 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -492,7 +492,9 @@
     return( PSA_SUCCESS );
 }
 
-#if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PK_PARSE_C)
+#if defined(MBEDTLS_RSA_C)
+
+#if defined(MBEDTLS_PK_PARSE_C)
 /* Mbed TLS doesn't support non-byte-aligned key sizes (i.e. key sizes
  * that are not a multiple of 8) well. For example, there is only
  * mbedtls_rsa_get_len(), which returns a number of bytes, and no
@@ -514,63 +516,201 @@
     mbedtls_mpi_free( &n );
     return( status );
 }
+#endif /* MBEDTLS_PK_PARSE_C */
 
-static psa_status_t psa_import_rsa_key( psa_key_type_t type,
-                                        const uint8_t *data,
-                                        size_t data_length,
-                                        mbedtls_rsa_context **p_rsa )
+/** Load the contents of a key slot into an internal RSA representation
+ *
+ * \param[in] slot  The slot from which to load the representation
+ * \param[out] rsa  The internal RSA representation to hold the key. Must be
+ *                  allocated and initialized. If it already holds a
+ *                  different key, it will be overwritten and cause a memory
+ *                  leak.
+ */
+static psa_status_t psa_load_rsa_representation( const psa_key_slot_t *slot,
+                                                 mbedtls_rsa_context *rsa )
 {
+#if defined(MBEDTLS_PK_PARSE_C)
     psa_status_t status;
-    mbedtls_pk_context pk;
-    mbedtls_rsa_context *rsa;
+    mbedtls_pk_context ctx;
     size_t bits;
-
-    mbedtls_pk_init( &pk );
+    mbedtls_pk_init( &ctx );
 
     /* Parse the data. */
-    if( PSA_KEY_TYPE_IS_KEY_PAIR( type ) )
+    if( PSA_KEY_TYPE_IS_KEY_PAIR( slot->attr.type ) )
         status = mbedtls_to_psa_error(
-            mbedtls_pk_parse_key( &pk, data, data_length, NULL, 0 ) );
+            mbedtls_pk_parse_key( &ctx, slot->data.key.data, slot->data.key.bytes, NULL, 0 ) );
     else
         status = mbedtls_to_psa_error(
-            mbedtls_pk_parse_public_key( &pk, data, data_length ) );
+            mbedtls_pk_parse_public_key( &ctx, slot->data.key.data, slot->data.key.bytes ) );
     if( status != PSA_SUCCESS )
         goto exit;
 
     /* We have something that the pkparse module recognizes. If it is a
      * valid RSA key, store it. */
-    if( mbedtls_pk_get_type( &pk ) != MBEDTLS_PK_RSA )
+    if( mbedtls_pk_get_type( &ctx ) != MBEDTLS_PK_RSA )
     {
         status = PSA_ERROR_INVALID_ARGUMENT;
         goto exit;
     }
 
-    rsa = mbedtls_pk_rsa( pk );
     /* The size of an RSA key doesn't have to be a multiple of 8. Mbed TLS
      * supports non-byte-aligned key sizes, but not well. For example,
      * mbedtls_rsa_get_len() returns the key size in bytes, not in bits. */
-    bits = PSA_BYTES_TO_BITS( mbedtls_rsa_get_len( rsa ) );
+    bits = PSA_BYTES_TO_BITS( mbedtls_rsa_get_len( mbedtls_pk_rsa( ctx ) ) );
     if( bits > PSA_VENDOR_RSA_MAX_KEY_BITS )
     {
         status = PSA_ERROR_NOT_SUPPORTED;
         goto exit;
     }
-    status = psa_check_rsa_key_byte_aligned( rsa );
+    status = psa_check_rsa_key_byte_aligned( mbedtls_pk_rsa( ctx ) );
+
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    /* Copy the PK-contained RSA context to the one provided as function input */
+    status = mbedtls_to_psa_error(
+                mbedtls_rsa_copy( rsa, mbedtls_pk_rsa( ctx ) ) );
 
 exit:
-    /* Free the content of the pk object only on error. */
+    mbedtls_pk_free( &ctx );
+    return( status );
+#else
+    (void) slot;
+    (void) rsa;
+    return( PSA_ERROR_NOT_SUPPORTED );
+#endif /* MBEDTLS_PK_PARSE_C */
+}
+
+/** Export an RSA key to export representation
+ *
+ * \param[in] type          The type of key (public/private) to export
+ * \param[in] rsa           The internal RSA representation from which to export
+ * \param[out] data         The buffer to export to
+ * \param[in] data_size     The length of the buffer to export to
+ * \param[out] data_length  The amount of bytes written to \p data
+ */
+static psa_status_t psa_export_rsa_key( psa_key_type_t type,
+                                        mbedtls_rsa_context *rsa,
+                                        uint8_t *data,
+                                        size_t data_size,
+                                        size_t *data_length )
+{
+#if defined(MBEDTLS_PK_WRITE_C)
+    int ret;
+    mbedtls_pk_context pk;
+    uint8_t *pos = data + data_size;
+
+    mbedtls_pk_init( &pk );
+    pk.pk_info = &mbedtls_rsa_info;
+    pk.pk_ctx = rsa;
+
+    /* PSA Crypto API defines the format of an RSA key as a DER-encoded
+     * representation of respectively the non-encrypted PKCS#1 RSAPrivateKey
+     * or the RFC3279 RSAPublicKey for a private key or a public key. */
+    if( PSA_KEY_TYPE_IS_KEY_PAIR( type ) )
+        ret = mbedtls_pk_write_key_der( &pk, data, data_size );
+    else
+        ret = mbedtls_pk_write_pubkey( &pos, data, &pk );
+
+    if( ret < 0 )
+        return mbedtls_to_psa_error( ret );
+
+    /* The mbedtls_pk_xxx functions write to the end of the buffer.
+     * Move the data to the beginning and erase remaining data
+     * at the original location. */
+    if( 2 * (size_t) ret <= data_size )
+    {
+        memcpy( data, data + data_size - ret, ret );
+        memset( data + data_size - ret, 0, ret );
+    }
+    else if( (size_t) ret < data_size )
+    {
+        memmove( data, data + data_size - ret, ret );
+        memset( data + ret, 0, data_size - ret );
+    }
+
+    *data_length = ret;
+    return( PSA_SUCCESS );
+#else
+    (void) type;
+    (void) rsa;
+    (void) data;
+    (void) data_size;
+    (void) data_length;
+    return( PSA_ERROR_NOT_SUPPORTED );
+#endif /* MBEDTLS_PK_WRITE_C */
+}
+
+/** Import an RSA key from import representation to a slot
+ *
+ * \param[in,out] slot      The slot where to store the export representation to
+ * \param[in] data          The buffer containing the import representation
+ * \param[in] data_length   The amount of bytes in \p data
+ */
+static psa_status_t psa_import_rsa_key( psa_key_slot_t *slot,
+                                        const uint8_t *data,
+                                        size_t data_length )
+{
+    psa_status_t status;
+    uint8_t* output = NULL;
+    mbedtls_rsa_context rsa;
+    mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+
+    /* Temporarily load input into slot. The cast here is safe since it'll
+     * only be used for load_rsa_representation, which doesn't modify the
+     * buffer. */
+    slot->data.key.data = (uint8_t *)data;
+    slot->data.key.bytes = data_length;
+
+    /* Parse input */
+    status = psa_load_rsa_representation( slot, &rsa );
+    if( status != PSA_SUCCESS )
+        goto exit;
+
+    slot->attr.bits = (psa_key_bits_t) PSA_BYTES_TO_BITS(
+        mbedtls_rsa_get_len( &rsa ) );
+
+    /* Re-export the data to PSA export format, which in case of RSA is the
+     * smallest representation we can parse. */
+    output = mbedtls_calloc( 1, data_length );
+
+    if( output == NULL )
+    {
+        status = PSA_ERROR_INSUFFICIENT_MEMORY;
+        goto exit;
+    }
+
+    /* PSA Crypto API defines the format of an RSA key as a DER-encoded
+     * representation of respectively the non-encrypted PKCS#1 RSAPrivateKey
+     * or the RFC3279 RSAPublicKey for a private key or a public key. That
+     * means we have no other choice then to run an import to verify the key
+     * size. */
+    status = psa_export_rsa_key( slot->attr.type,
+                                 &rsa,
+                                 output,
+                                 data_length,
+                                 &data_length);
+
+exit:
+    /* Always free the RSA object */
+    mbedtls_rsa_free( &rsa );
+
+    /* Free the allocated buffer only on error. */
     if( status != PSA_SUCCESS )
     {
-        mbedtls_pk_free( &pk );
+        mbedtls_free( output );
+        slot->data.key.data = NULL;
+        slot->data.key.bytes = 0;
         return( status );
     }
 
-    /* On success, store the content of the object in the RSA context. */
-    *p_rsa = rsa;
+    /* On success, store the allocated export-formatted key. */
+    slot->data.key.data = output;
+    slot->data.key.bytes = data_length;
 
     return( PSA_SUCCESS );
 }
-#endif /* defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PK_PARSE_C) */
+#endif /* defined(MBEDTLS_RSA_C) */
 
 #if defined(MBEDTLS_ECP_C)
 static psa_status_t psa_prepare_import_ec_key( psa_ecc_family_t curve,
@@ -708,10 +848,6 @@
 
     if( key_type_is_raw_bytes( slot->attr.type ) )
         bits = PSA_BYTES_TO_BITS( slot->data.key.bytes );
-#if defined(MBEDTLS_RSA_C)
-    else if( PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
-        bits = PSA_BYTES_TO_BITS( mbedtls_rsa_get_len( slot->data.rsa ) );
-#endif /* defined(MBEDTLS_RSA_C) */
 #if defined(MBEDTLS_ECP_C)
     else if( PSA_KEY_TYPE_IS_ECC( slot->attr.type ) )
         bits = slot->data.ecp->grp.pbits;
@@ -788,9 +924,7 @@
 #if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PK_PARSE_C)
     if( PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
     {
-        status = psa_import_rsa_key( slot->attr.type,
-            data, data_length,
-            &slot->data.rsa );
+        status = psa_import_rsa_key( slot, data, data_length );
     }
     else
 #endif /* defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PK_PARSE_C) */
@@ -800,10 +934,13 @@
 
     if( status == PSA_SUCCESS )
     {
-        /* Write the actual key size to the slot.
-         * psa_start_key_creation() wrote the size declared by the
-         * caller, which may be 0 (meaning unspecified) or wrong. */
-        slot->attr.bits = psa_calculate_key_bits( slot );
+        if( !PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
+        {
+            /* Write the actual key size to the slot.
+             * psa_start_key_creation() wrote the size declared by the
+             * caller, which may be 0 (meaning unspecified) or wrong. */
+            slot->attr.bits = psa_calculate_key_bits( slot );
+        }
     }
     return( status );
 }
@@ -980,8 +1117,9 @@
 #if defined(MBEDTLS_RSA_C)
     if( PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
     {
-        mbedtls_rsa_free( slot->data.rsa );
-        mbedtls_free( slot->data.rsa );
+        mbedtls_free( slot->data.key.data );
+        slot->data.key.data = NULL;
+        slot->data.key.bytes = 0;
     }
     else
 #endif /* defined(MBEDTLS_RSA_C) */
@@ -1232,7 +1370,18 @@
             if( psa_key_slot_is_external( slot ) )
                 break;
 #endif /* MBEDTLS_PSA_CRYPTO_SE_C */
-            status = psa_get_rsa_public_exponent( slot->data.rsa, attributes );
+            {
+                mbedtls_rsa_context rsa;
+                mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+
+                status = psa_load_rsa_representation( slot, &rsa );
+                if( status != PSA_SUCCESS )
+                    break;
+
+                status = psa_get_rsa_public_exponent( &rsa,
+                                                      attributes );
+                mbedtls_rsa_free( &rsa );
+            }
             break;
 #endif /* MBEDTLS_RSA_C */
         default:
@@ -1276,6 +1425,20 @@
 }
 #endif /* defined(MBEDTLS_RSA_C) || defined(MBEDTLS_ECP_C) */
 
+static psa_status_t psa_internal_export_key_buffer( const psa_key_slot_t *slot,
+                                                    uint8_t *data,
+                                                    size_t data_size,
+                                                    size_t *data_length )
+{
+    if( slot->data.key.bytes > data_size )
+        return( PSA_ERROR_BUFFER_TOO_SMALL );
+    memcpy( data, slot->data.key.data, slot->data.key.bytes );
+    memset( data + slot->data.key.bytes, 0,
+            data_size - slot->data.key.bytes );
+    *data_length = slot->data.key.bytes;
+    return( PSA_SUCCESS );
+}
+
 static psa_status_t psa_internal_export_key( const psa_key_slot_t *slot,
                                              uint8_t *data,
                                              size_t data_size,
@@ -1354,10 +1517,36 @@
             if( PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
             {
 #if defined(MBEDTLS_RSA_C)
-                mbedtls_pk_init( &pk );
-                pk.pk_info = &mbedtls_rsa_info;
-                pk.pk_ctx = slot->data.rsa;
+                if( PSA_KEY_TYPE_IS_PUBLIC_KEY( slot->attr.type ) )
+                {
+                    /* Exporting public -> public */
+                    return( psa_internal_export_key_buffer( slot, data, data_size, data_length ) );
+                }
+                else if( !export_public_key )
+                {
+                    /* Exporting private -> private */
+                    return( psa_internal_export_key_buffer( slot, data, data_size, data_length ) );
+                }
+
+                /* Exporting private -> public */
+                mbedtls_rsa_context rsa;
+                mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+
+                psa_status_t status = psa_load_rsa_representation( slot, &rsa );
+                if( status != PSA_SUCCESS )
+                    return status;
+
+                status = psa_export_rsa_key( PSA_KEY_TYPE_RSA_PUBLIC_KEY,
+                                             &rsa,
+                                             data,
+                                             data_size,
+                                             data_length );
+
+                mbedtls_rsa_free( &rsa );
+
+                return( status );
 #else
+                /* We don't know how to convert a private RSA key to public. */
                 return( PSA_ERROR_NOT_SUPPORTED );
 #endif
             }
@@ -1805,12 +1994,19 @@
 #if defined(MBEDTLS_RSA_C)
         if( PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
         {
+            mbedtls_rsa_context rsa;
+            mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+
+            psa_status_t status = psa_load_rsa_representation( slot, &rsa );
+            if( status != PSA_SUCCESS )
+                return status;
             mbedtls_mpi actual, required;
             int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
             mbedtls_mpi_init( &actual );
             mbedtls_mpi_init( &required );
-            ret = mbedtls_rsa_export( slot->data.rsa,
+            ret = mbedtls_rsa_export( &rsa,
                                       NULL, NULL, NULL, NULL, &actual );
+            mbedtls_rsa_free( &rsa );
             if( ret != 0 )
                 goto rsa_exit;
             ret = mbedtls_mpi_read_binary( &required,
@@ -3447,11 +3643,21 @@
 #if defined(MBEDTLS_RSA_C)
     if( slot->attr.type == PSA_KEY_TYPE_RSA_KEY_PAIR )
     {
-        status = psa_rsa_sign( slot->data.rsa,
+        mbedtls_rsa_context rsa;
+        mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+
+        status = psa_load_rsa_representation( slot,
+                                              &rsa );
+        if( status != PSA_SUCCESS )
+            goto exit;
+
+        status = psa_rsa_sign( &rsa,
                                alg,
                                hash, hash_length,
                                signature, signature_size,
                                signature_length );
+
+        mbedtls_rsa_free( &rsa );
     }
     else
 #endif /* defined(MBEDTLS_RSA_C) */
@@ -3533,10 +3739,19 @@
 #if defined(MBEDTLS_RSA_C)
     if( PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
     {
-        return( psa_rsa_verify( slot->data.rsa,
-                                alg,
-                                hash, hash_length,
-                                signature, signature_length ) );
+        mbedtls_rsa_context rsa;
+        mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+
+        status = psa_load_rsa_representation( slot, &rsa );
+        if( status != PSA_SUCCESS )
+            return status;
+
+        status = psa_rsa_verify( &rsa,
+                                 alg,
+                                 hash, hash_length,
+                                 signature, signature_length );
+        mbedtls_rsa_free( &rsa );
+        return( status );
     }
     else
 #endif /* defined(MBEDTLS_RSA_C) */
@@ -3606,14 +3821,22 @@
 #if defined(MBEDTLS_RSA_C)
     if( PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
     {
-        mbedtls_rsa_context *rsa = slot->data.rsa;
+        mbedtls_rsa_context rsa;
+        mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+
+        status = psa_load_rsa_representation( slot, &rsa );
+        if( status != PSA_SUCCESS )
+            return status;
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-        if( output_size < mbedtls_rsa_get_len( rsa ) )
+        if( output_size < mbedtls_rsa_get_len( &rsa ) )
+        {
+            mbedtls_rsa_free( &rsa );
             return( PSA_ERROR_BUFFER_TOO_SMALL );
+        }
 #if defined(MBEDTLS_PKCS1_V15)
         if( alg == PSA_ALG_RSA_PKCS1V15_CRYPT )
         {
-            ret = mbedtls_rsa_pkcs1_encrypt( rsa,
+            ret = mbedtls_rsa_pkcs1_encrypt( &rsa,
                                              mbedtls_ctr_drbg_random,
                                              &global_data.ctr_drbg,
                                              MBEDTLS_RSA_PUBLIC,
@@ -3626,8 +3849,8 @@
 #if defined(MBEDTLS_PKCS1_V21)
         if( PSA_ALG_IS_RSA_OAEP( alg ) )
         {
-            psa_rsa_oaep_set_padding_mode( alg, rsa );
-            ret = mbedtls_rsa_rsaes_oaep_encrypt( rsa,
+            psa_rsa_oaep_set_padding_mode( alg, &rsa );
+            ret = mbedtls_rsa_rsaes_oaep_encrypt( &rsa,
                                                   mbedtls_ctr_drbg_random,
                                                   &global_data.ctr_drbg,
                                                   MBEDTLS_RSA_PUBLIC,
@@ -3639,10 +3862,13 @@
         else
 #endif /* MBEDTLS_PKCS1_V21 */
         {
+            mbedtls_rsa_free( &rsa );
             return( PSA_ERROR_INVALID_ARGUMENT );
         }
         if( ret == 0 )
-            *output_length = mbedtls_rsa_get_len( rsa );
+            *output_length = mbedtls_rsa_get_len( &rsa );
+
+        mbedtls_rsa_free( &rsa );
         return( mbedtls_to_psa_error( ret ) );
     }
     else
@@ -3685,16 +3911,24 @@
 #if defined(MBEDTLS_RSA_C)
     if( slot->attr.type == PSA_KEY_TYPE_RSA_KEY_PAIR )
     {
-        mbedtls_rsa_context *rsa = slot->data.rsa;
+        mbedtls_rsa_context rsa;
+        mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+
+        status = psa_load_rsa_representation( slot, &rsa );
+        if( status != PSA_SUCCESS )
+            return status;
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-        if( input_length != mbedtls_rsa_get_len( rsa ) )
+        if( input_length != mbedtls_rsa_get_len( &rsa ) )
+        {
+            mbedtls_rsa_free( &rsa );
             return( PSA_ERROR_INVALID_ARGUMENT );
+        }
 
 #if defined(MBEDTLS_PKCS1_V15)
         if( alg == PSA_ALG_RSA_PKCS1V15_CRYPT )
         {
-            ret = mbedtls_rsa_pkcs1_decrypt( rsa,
+            ret = mbedtls_rsa_pkcs1_decrypt( &rsa,
                                              mbedtls_ctr_drbg_random,
                                              &global_data.ctr_drbg,
                                              MBEDTLS_RSA_PRIVATE,
@@ -3708,8 +3942,8 @@
 #if defined(MBEDTLS_PKCS1_V21)
         if( PSA_ALG_IS_RSA_OAEP( alg ) )
         {
-            psa_rsa_oaep_set_padding_mode( alg, rsa );
-            ret = mbedtls_rsa_rsaes_oaep_decrypt( rsa,
+            psa_rsa_oaep_set_padding_mode( alg, &rsa );
+            ret = mbedtls_rsa_rsaes_oaep_decrypt( &rsa,
                                                   mbedtls_ctr_drbg_random,
                                                   &global_data.ctr_drbg,
                                                   MBEDTLS_RSA_PRIVATE,
@@ -3722,9 +3956,11 @@
         else
 #endif /* MBEDTLS_PKCS1_V21 */
         {
+            mbedtls_rsa_free( &rsa );
             return( PSA_ERROR_INVALID_ARGUMENT );
         }
 
+        mbedtls_rsa_free( &rsa );
         return( mbedtls_to_psa_error( ret ) );
     }
     else
@@ -5567,7 +5803,7 @@
 #if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_GENPRIME)
     if ( type == PSA_KEY_TYPE_RSA_KEY_PAIR )
     {
-        mbedtls_rsa_context *rsa;
+        mbedtls_rsa_context rsa;
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
         int exponent;
         psa_status_t status;
@@ -5582,22 +5818,36 @@
                                         &exponent );
         if( status != PSA_SUCCESS )
             return( status );
-        rsa = mbedtls_calloc( 1, sizeof( *rsa ) );
-        if( rsa == NULL )
-            return( PSA_ERROR_INSUFFICIENT_MEMORY );
-        mbedtls_rsa_init( rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
-        ret = mbedtls_rsa_gen_key( rsa,
+        mbedtls_rsa_init( &rsa, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+        ret = mbedtls_rsa_gen_key( &rsa,
                                    mbedtls_ctr_drbg_random,
                                    &global_data.ctr_drbg,
                                    (unsigned int) bits,
                                    exponent );
         if( ret != 0 )
-        {
-            mbedtls_rsa_free( rsa );
-            mbedtls_free( rsa );
             return( mbedtls_to_psa_error( ret ) );
+
+        /* Make sure to always have an export representation available */
+        size_t bytes = PSA_KEY_EXPORT_RSA_KEY_PAIR_MAX_SIZE( bits );
+
+        slot->data.key.data = mbedtls_calloc( 1, bytes );
+        if( slot->data.key.data == NULL )
+        {
+            mbedtls_rsa_free( &rsa );
+            return( PSA_ERROR_INSUFFICIENT_MEMORY );
         }
-        slot->data.rsa = rsa;
+
+        status = psa_export_rsa_key( type,
+                                     &rsa,
+                                     slot->data.key.data,
+                                     bytes,
+                                     &slot->data.key.bytes );
+        mbedtls_rsa_free( &rsa );
+        if( status != PSA_SUCCESS )
+        {
+            psa_remove_key_data_from_memory( slot );
+            return( status );
+        }
     }
     else
 #endif /* MBEDTLS_RSA_C && MBEDTLS_GENPRIME */
diff --git a/library/psa_crypto_core.h b/library/psa_crypto_core.h
index 8af45a1..c90d737 100644
--- a/library/psa_crypto_core.h
+++ b/library/psa_crypto_core.h
@@ -33,7 +33,6 @@
 #include "psa/crypto_se_driver.h"
 
 #include "mbedtls/ecp.h"
-#include "mbedtls/rsa.h"
 
 /** The data structure representing a key slot, containing key material
  * and metadata for one key.
@@ -50,10 +49,6 @@
             uint8_t *data;
             size_t bytes;
         } key;
-#if defined(MBEDTLS_RSA_C)
-        /* RSA public key or key pair */
-        mbedtls_rsa_context *rsa;
-#endif /* MBEDTLS_RSA_C */
 #if defined(MBEDTLS_ECP_C)
         /* EC public key or key pair */
         mbedtls_ecp_keypair *ecp;