Import PSK as opaque PSA key for mbedtls_ssl_conf_psk() & mbedtls_ssl_set_hs_psk()

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 250bae9..85cfa48 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1541,8 +1541,13 @@
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ! mbedtls_svc_key_id_is_null( conf->psk_opaque ) )
     {
-        /* The maintenance of the PSK key slot is the
+        /* The maintenance of the external PSK key slot is the
          * user's responsibility. */
+        if( conf->psk_opaque_is_internal )
+        {
+            psa_destroy_key( conf->psk_opaque );
+            conf->psk_opaque_is_internal = 0;
+        }
         conf->psk_opaque = MBEDTLS_SVC_KEY_ID_INIT;
     }
     /* This and the following branch should never
@@ -1600,6 +1605,11 @@
                 const unsigned char *psk_identity, size_t psk_identity_len )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_key_attributes_t key_attributes = psa_key_attributes_init();
+    psa_status_t status;
+    mbedtls_svc_key_id_t key;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     /* We currently only support one PSK, raw or opaque. */
     if( ssl_conf_psk_is_configured( conf ) )
@@ -1613,6 +1623,23 @@
     if( psk_len > MBEDTLS_PSK_MAX_LEN )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
+    psa_set_key_algorithm( &key_attributes,
+                           PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256) );
+    psa_set_key_enrollment_algorithm( &key_attributes,
+                           PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_384) );
+    psa_set_key_type( &key_attributes, PSA_KEY_TYPE_DERIVE );
+
+    status = psa_import_key( &key_attributes, psk, psk_len, &key );
+    if( status != PSA_SUCCESS )
+        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+    /* Allow calling psa_destroy_key() on config psk remove/free */
+    conf->psk_opaque_is_internal = 1;
+    ret = mbedtls_ssl_conf_psk_opaque( conf, key,
+                                       psk_identity, psk_identity_len );
+#else
     if( ( conf->psk = mbedtls_calloc( 1, psk_len ) ) == NULL )
         return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
     conf->psk_len = psk_len;
@@ -1622,6 +1649,7 @@
     ret = ssl_conf_set_psk_identity( conf, psk_identity, psk_identity_len );
     if( ret != 0 )
         ssl_conf_remove_psk( conf );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     return( ret );
 }
@@ -1631,6 +1659,13 @@
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ! mbedtls_svc_key_id_is_null( ssl->handshake->psk_opaque ) )
     {
+        /* An externally supplied PSK key slot remains the user's
+         * responsibility; only an internally imported key is destroyed. */
+        if( ssl->handshake->psk_opaque_is_internal )
+        {
+            psa_destroy_key( ssl->handshake->psk_opaque );
+            ssl->handshake->psk_opaque_is_internal = 0;
+        }
         ssl->handshake->psk_opaque = MBEDTLS_SVC_KEY_ID_INIT;
     }
     else
@@ -1647,6 +1682,13 @@
 int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
                             const unsigned char *psk, size_t psk_len )
 {
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_key_attributes_t key_attributes = psa_key_attributes_init();
+    psa_status_t status;
+    psa_algorithm_t alg;
+    mbedtls_svc_key_id_t key;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
     if( psk == NULL || ssl->handshake == NULL )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
@@ -1655,6 +1697,24 @@
 
     ssl_remove_psk( ssl );
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if( ssl->handshake->ciphersuite_info->mac == MBEDTLS_MD_SHA384)
+        alg = PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_384);
+    else
+        alg = PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256);
+
+    psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
+    psa_set_key_algorithm( &key_attributes, alg );
+    psa_set_key_type( &key_attributes, PSA_KEY_TYPE_DERIVE );
+
+    status = psa_import_key( &key_attributes, psk, psk_len, &key );
+    if( status != PSA_SUCCESS )
+        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
+
+    /* Allow calling psa_destroy_key() on psk remove */
+    ssl->handshake->psk_opaque_is_internal = 1;
+    return mbedtls_ssl_set_hs_psk_opaque( ssl, key );
+#else
     if( ( ssl->handshake->psk = mbedtls_calloc( 1, psk_len ) ) == NULL )
         return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
 
@@ -1662,6 +1722,7 @@
     memcpy( ssl->handshake->psk, psk, ssl->handshake->psk_len );
 
     return( 0 );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 }
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
@@ -3231,6 +3292,19 @@
 #endif
 
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if( ! mbedtls_svc_key_id_is_null( ssl->handshake->psk_opaque ) )
+    {
+        /* An externally supplied PSK key slot remains the user's
+         * responsibility; only an internally imported key is destroyed. */
+        if( ssl->handshake->psk_opaque_is_internal )
+        {
+            psa_destroy_key( ssl->handshake->psk_opaque );
+            ssl->handshake->psk_opaque_is_internal = 0;
+        }
+        ssl->handshake->psk_opaque = MBEDTLS_SVC_KEY_ID_INIT;
+    }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     if( handshake->psk != NULL )
     {
         mbedtls_platform_zeroize( handshake->psk, handshake->psk_len );
@@ -4424,6 +4498,17 @@
 #endif
 
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if( ! mbedtls_svc_key_id_is_null( conf->psk_opaque ) )
+    {
+        if( conf->psk_opaque_is_internal )
+        {
+            psa_destroy_key( conf->psk_opaque );
+            conf->psk_opaque_is_internal = 0;
+        }
+        conf->psk_opaque = MBEDTLS_SVC_KEY_ID_INIT;
+    }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     if( conf->psk != NULL )
     {
         mbedtls_platform_zeroize( conf->psk, conf->psk_len );