psa: Change psa_import_key_into_slot() signature

Change psa_import_key_into_slot() signature to the signature
of an import_key driver entry point.

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 411010c..4cdfd99 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -583,49 +583,48 @@
     return( PSA_SUCCESS );
 }
 
-/** Import key data into a slot.
+/** Import a key in binary format.
  *
- * `slot->type` must have been set previously.
- * This function assumes that the slot does not contain any key material yet.
- * On failure, the slot content is unchanged.
+ * \note The signature of this function is that of a PSA driver
+ *       import_key entry point. This function behaves as an import_key
+ *       entry point as defined in the PSA driver interface specification for
+ *       transparent drivers.
  *
- * Persistent storage is not affected.
- *
- * \param[in,out] slot     The key slot to import data into.
- *                         Its `type` field must have previously been set to
- *                         the desired key type.
- *                         It must not contain any key material yet.
- * \param[in] data         Buffer containing the key material to parse and
- *                         import.
- * \param data_length      Size of \p data in bytes.
- * \param[out] key_buffer  The buffer containing the export representation.
- * \param[in]  key_buffer_size    The size of \p key_buffer in bytes. The size
- *                                is greater or equal to \p data_length.
+ * \param[in]  attributes       The attributes for the key to import.
+ * \param[in]  data             The buffer containing the key data in import
+ *                              format.
+ * \param[in]  data_length      Size of the \p data buffer in bytes.
+ * \param[out] key_buffer       The buffer containing the key data in output
+ *                              format.
+ * \param[in]  key_buffer_size  Size of the \p key_buffer buffer in bytes. This
+ *                              size is greater or equal to \p data_length.
  * \param[out] key_buffer_length  The length of the data written in \p
  *                                key_buffer in bytes.
+ * \param[out] bits             The key size in number of bits.
  *
- * \retval #PSA_SUCCESS
+ * \retval #PSA_SUCCESS  The key was imported successfully.
  * \retval #PSA_ERROR_INVALID_ARGUMENT
+ *         The key data is not correctly formatted.
  * \retval #PSA_ERROR_NOT_SUPPORTED
  * \retval #PSA_ERROR_INSUFFICIENT_MEMORY
+ * \retval #PSA_ERROR_CORRUPTION_DETECTED
  */
-static psa_status_t psa_import_key_into_slot( psa_key_slot_t *slot,
-                                              const uint8_t *data,
-                                              size_t data_length,
-                                              uint8_t *key_buffer,
-                                              size_t key_buffer_size,
-                                              size_t *key_buffer_length )
+static psa_status_t psa_import_key_into_slot(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *data, size_t data_length,
+    uint8_t *key_buffer, size_t key_buffer_size,
+    size_t *key_buffer_length, size_t *bits )
 {
-    psa_status_t status = PSA_SUCCESS;
-    size_t bit_size;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_key_type_t type = attributes->core.type;
 
     /* zero-length keys are never supported. */
     if( data_length == 0 )
         return( PSA_ERROR_NOT_SUPPORTED );
 
-    if( key_type_is_raw_bytes( slot->attr.type ) )
+    if( key_type_is_raw_bytes( type ) )
     {
-        bit_size = PSA_BYTES_TO_BITS( data_length );
+        *bits = PSA_BYTES_TO_BITS( data_length );
 
         /* Ensure that the bytes-to-bits conversion hasn't overflown. */
         if( data_length > SIZE_MAX / 8 )
@@ -633,10 +632,10 @@
 
         /* Enforce a size limit, and in particular ensure that the bit
          * size fits in its representation type. */
-        if( bit_size > PSA_MAX_KEY_BITS )
+        if( ( *bits ) > PSA_MAX_KEY_BITS )
             return( PSA_ERROR_NOT_SUPPORTED );
 
-        status = validate_unstructured_key_bit_size( slot->attr.type, bit_size );
+        status = validate_unstructured_key_bit_size( type, *bits );
         if( status != PSA_SUCCESS )
             return( status );
 
@@ -645,41 +644,18 @@
         *key_buffer_length = data_length;
         (void)key_buffer_size;
 
-        /* Write the actual key size to the slot.
-         * psa_start_key_creation() wrote the size declared by the
-         * caller, which may be 0 (meaning unspecified) or wrong. */
-        slot->attr.bits = (psa_key_bits_t) bit_size;
-
         return( PSA_SUCCESS );
     }
-    else if( PSA_KEY_TYPE_IS_ASYMMETRIC( slot->attr.type ) )
+    else if( PSA_KEY_TYPE_IS_ASYMMETRIC( type ) )
     {
-        /* Try validation through accelerators first. */
-        psa_key_attributes_t attributes = {
-          .core = slot->attr
-        };
-
-        bit_size = slot->attr.bits;
-        status = psa_driver_wrapper_import_key( &attributes,
+        status = psa_driver_wrapper_import_key( attributes,
                                                 data, data_length,
                                                 key_buffer,
                                                 key_buffer_size,
                                                 key_buffer_length,
-                                                &bit_size );
-        if( status == PSA_SUCCESS )
-        {
-            if( slot->attr.bits == 0 )
-                slot->attr.bits = (psa_key_bits_t) bit_size;
-            else if( bit_size != slot->attr.bits )
-                return( PSA_ERROR_INVALID_ARGUMENT );
-
-            return( PSA_SUCCESS );
-        }
-        else
-        {
-            if( status != PSA_ERROR_NOT_SUPPORTED )
-                return( status );
-        }
+                                                bits );
+        if( status != PSA_ERROR_NOT_SUPPORTED )
+            return( status );
 
         mbedtls_platform_zeroize( key_buffer, key_buffer_size );
 
@@ -687,41 +663,31 @@
          * if present. */
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_KEY_PAIR) || \
     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_PUBLIC_KEY)
-        if( PSA_KEY_TYPE_IS_ECC( slot->attr.type ) )
+        if( PSA_KEY_TYPE_IS_ECC( type ) )
         {
-            status = mbedtls_psa_ecp_import_key( &attributes,
-                                                 data, data_length,
-                                                 key_buffer, key_buffer_size,
-                                                 key_buffer_length,
-                                                 &bit_size );
-            slot->attr.bits = (psa_key_bits_t) bit_size;
-            return( status );
+            return( mbedtls_psa_ecp_import_key( attributes,
+                                                data, data_length,
+                                                key_buffer, key_buffer_size,
+                                                key_buffer_length,
+                                                bits ) );
         }
 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_KEY_PAIR) ||
         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_PUBLIC_KEY) */
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \
     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
-        if( PSA_KEY_TYPE_IS_RSA( slot->attr.type ) )
+        if( PSA_KEY_TYPE_IS_RSA( type ) )
         {
-            status = mbedtls_psa_rsa_import_key( &attributes,
-                                                 data, data_length,
-                                                 key_buffer, key_buffer_size,
-                                                 key_buffer_length,
-                                                 &bit_size );
-            slot->attr.bits = (psa_key_bits_t) bit_size;
-            return( status );
+            return( mbedtls_psa_rsa_import_key( attributes,
+                                                data, data_length,
+                                                key_buffer, key_buffer_size,
+                                                key_buffer_length,
+                                                bits ) );
         }
 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) ||
         * defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
+    }
 
-        /* Fell through the fallback as well, so have nothing else to try. */
-        return( PSA_ERROR_NOT_SUPPORTED );
-    }
-    else
-    {
-        /* Unknown key type */
-        return( PSA_ERROR_NOT_SUPPORTED );
-    }
+    return( PSA_ERROR_NOT_SUPPORTED );
 }
 
 /** Calculate the intersection of two algorithm usage policies.
@@ -1929,13 +1895,24 @@
         if( status != PSA_SUCCESS )
             goto exit;
 
-        status = psa_import_key_into_slot( slot, data, data_length,
+        size_t bits = slot->attr.bits;
+        status = psa_import_key_into_slot( attributes,
+                                           data, data_length,
                                            slot->key.data,
                                            slot->key.bytes,
-                                           &slot->key.bytes );
+                                           &slot->key.bytes, &bits );
         if( status != PSA_SUCCESS )
             goto exit;
+
+        if( slot->attr.bits == 0 )
+            slot->attr.bits = (psa_key_bits_t) bits;
+        else if( bits != slot->attr.bits )
+        {
+            status = PSA_ERROR_INVALID_ARGUMENT;
+            goto exit;
+        }
     }
+
     status = psa_validate_optional_attributes( slot, attributes );
     if( status != PSA_SUCCESS )
         goto exit;
@@ -5240,9 +5217,16 @@
     if( status != PSA_SUCCESS )
         return( status );
 
-    status = psa_import_key_into_slot( slot, data, bytes,
+    psa_key_attributes_t attributes = {
+      .core = slot->attr
+    };
+
+    status = psa_import_key_into_slot( &attributes,
+                                       data, bytes,
                                        slot->key.data, slot->key.bytes,
-                                       &slot->key.bytes );
+                                       &slot->key.bytes,
+                                       &bits );
+    slot->attr.bits = (psa_key_bits_t) bits;
 
 exit:
     mbedtls_free( data );