Introduce mbedtls_ssl_get_ciphersuite_sig_pk_ext_alg() and use it in ssl_pick_cert()

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/include/mbedtls/ssl_ciphersuites.h b/include/mbedtls/ssl_ciphersuites.h
index c770528..5352d75 100644
--- a/include/mbedtls/ssl_ciphersuites.h
+++ b/include/mbedtls/ssl_ciphersuites.h
@@ -389,6 +389,10 @@
 
 #if defined(MBEDTLS_PK_C)
 mbedtls_pk_type_t mbedtls_ssl_get_ciphersuite_sig_pk_alg( const mbedtls_ssl_ciphersuite_t *info );
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+psa_algorithm_t mbedtls_ssl_get_ciphersuite_sig_pk_ext_alg( const mbedtls_ssl_ciphersuite_t *info );
+psa_key_usage_t mbedtls_ssl_get_ciphersuite_sig_pk_ext_usage( const mbedtls_ssl_ciphersuite_t *info );
+#endif
 mbedtls_pk_type_t mbedtls_ssl_get_ciphersuite_sig_alg( const mbedtls_ssl_ciphersuite_t *info );
 #endif
 
diff --git a/library/ssl_ciphersuites.c b/library/ssl_ciphersuites.c
index 7deb57a..73a99e3 100644
--- a/library/ssl_ciphersuites.c
+++ b/library/ssl_ciphersuites.c
@@ -1921,6 +1921,53 @@
     }
 }
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+psa_algorithm_t mbedtls_ssl_get_ciphersuite_sig_pk_ext_alg( const mbedtls_ssl_ciphersuite_t *info )
+{
+    switch( info->key_exchange )
+    {
+        case MBEDTLS_KEY_EXCHANGE_RSA:
+        case MBEDTLS_KEY_EXCHANGE_RSA_PSK:
+            return( PSA_ALG_RSA_PKCS1V15_CRYPT );
+        case MBEDTLS_KEY_EXCHANGE_DHE_RSA:
+        case MBEDTLS_KEY_EXCHANGE_ECDHE_RSA:
+            return( PSA_ALG_RSA_PKCS1V15_SIGN(
+                        mbedtls_psa_translate_md( info->mac ) ) );
+
+        case MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA:
+            return( PSA_ALG_ECDSA( mbedtls_psa_translate_md( info->mac ) ) );
+
+        case MBEDTLS_KEY_EXCHANGE_ECDH_RSA:
+        case MBEDTLS_KEY_EXCHANGE_ECDH_ECDSA:
+            return( PSA_ALG_ECDH );
+
+        default:
+            return( PSA_ALG_NONE );
+    }
+}
+
+psa_key_usage_t mbedtls_ssl_get_ciphersuite_sig_pk_ext_usage( const mbedtls_ssl_ciphersuite_t *info )
+{
+    switch( info->key_exchange )
+    {
+        case MBEDTLS_KEY_EXCHANGE_RSA:
+        case MBEDTLS_KEY_EXCHANGE_RSA_PSK:
+            return( PSA_KEY_USAGE_DECRYPT );
+        case MBEDTLS_KEY_EXCHANGE_DHE_RSA:
+        case MBEDTLS_KEY_EXCHANGE_ECDHE_RSA:
+        case MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA:
+            return( PSA_KEY_USAGE_SIGN_HASH );
+
+        case MBEDTLS_KEY_EXCHANGE_ECDH_RSA:
+        case MBEDTLS_KEY_EXCHANGE_ECDH_ECDSA:
+            return( PSA_KEY_USAGE_DERIVE );
+
+        default:
+            return( 0 );
+    }
+}
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
 mbedtls_pk_type_t mbedtls_ssl_get_ciphersuite_sig_alg( const mbedtls_ssl_ciphersuite_t *info )
 {
     switch( info->key_exchange )
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index 21e5cda..161a5eb 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -682,8 +682,15 @@
                           const mbedtls_ssl_ciphersuite_t * ciphersuite_info )
 {
     mbedtls_ssl_key_cert *cur, *list;
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_algorithm_t pk_alg =
+        mbedtls_ssl_get_ciphersuite_sig_pk_ext_alg( ciphersuite_info );
+    psa_key_usage_t pk_usage =
+        mbedtls_ssl_get_ciphersuite_sig_pk_ext_usage( ciphersuite_info );
+#else
     mbedtls_pk_type_t pk_alg =
         mbedtls_ssl_get_ciphersuite_sig_pk_alg( ciphersuite_info );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     uint32_t flags;
 
 #if defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
@@ -693,7 +700,11 @@
 #endif
         list = ssl->conf->key_cert;
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if( pk_alg == PSA_ALG_NONE )
+#else
     if( pk_alg == MBEDTLS_PK_NONE )
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
         return( 0 );
 
     MBEDTLS_SSL_DEBUG_MSG( 3, ( "ciphersuite requires certificate" ) );
@@ -710,7 +721,18 @@
         MBEDTLS_SSL_DEBUG_CRT( 3, "candidate certificate chain, certificate",
                           cur->cert );
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE)
+        if( ( ssl->conf->f_async_sign_start == NULL &&
+              ssl->conf->f_async_decrypt_start == NULL &&
+              ! mbedtls_pk_can_do_ext( cur->key, pk_alg, pk_usage ) ) ||
+            ! mbedtls_pk_can_do_ext( &cur->cert->pk, pk_alg, pk_usage ) )
+#else
+        if( ! mbedtls_pk_can_do_ext( cur->key, pk_alg, pk_usage ) )
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE */
+#else
         if( ! mbedtls_pk_can_do( &cur->cert->pk, pk_alg ) )
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
         {
             MBEDTLS_SSL_DEBUG_MSG( 3, ( "certificate mismatch: key type" ) );
             continue;