Make mbedtls_ssl_psk_derive_premaster() only for when MBEDTLS_USE_PSA_CRYPTO is not selected

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index b9f456f..b0424bf 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -1307,8 +1307,10 @@
                                          size_t msg_len );
 
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
 int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl,
                                       mbedtls_key_exchange_type_t key_ex );
+#endif /* !MBEDTLS_USE_PSA_CRYPTO */
 #if defined(MBEDTLS_SSL_CLI_C) && defined(MBEDTLS_SSL_PROTO_TLS1_2)
 int mbedtls_ssl_conf_has_static_psk( mbedtls_ssl_config const *conf );
 #endif
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index dd8ebeb..f2a3e1a 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -5530,18 +5530,12 @@
 }
 #endif /* MBEDTLS_SHA384_C */
 
-#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+#if !defined(MBEDTLS_USE_PSA_CRYPTO) &&                      \
+    defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
 int mbedtls_ssl_psk_derive_premaster( mbedtls_ssl_context *ssl, mbedtls_key_exchange_type_t key_ex )
 {
-#if !defined(MBEDTLS_USE_PSA_CRYPTO) ||                 \
-    defined(MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED)
     unsigned char *p = ssl->handshake->premaster;
     unsigned char *end = p + sizeof( ssl->handshake->premaster );
-#else
-    (void)ssl;
-    (void)key_ex;
-#endif /* !MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED */
-#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     const unsigned char *psk = NULL;
     size_t psk_len = 0;
     int psk_ret = mbedtls_ssl_get_psk( ssl, &psk, &psk_len );
@@ -5602,7 +5596,6 @@
     }
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_RSA_PSK_ENABLED */
-#endif /* !MBEDTLS_USE_PSA_CRYPTO */
 #if defined(MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED)
     if( key_ex == MBEDTLS_KEY_EXCHANGE_DHE_PSK )
     {
@@ -5624,8 +5617,7 @@
     }
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED */
-#if !defined(MBEDTLS_USE_PSA_CRYPTO) &&                 \
-    defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED)
+#if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED)
     if( key_ex == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK )
     {
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
@@ -5646,13 +5638,12 @@
                                 MBEDTLS_DEBUG_ECDH_Z );
     }
     else
-#endif /* !MBEDTLS_USE_PSA_CRYPTO && MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED */
+#endif /* MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED */
     {
         MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
         return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
     }
 
-#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     /* opaque psk<0..2^16-1>; */
     if( end - p < 2 )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
@@ -5667,11 +5658,10 @@
     p += psk_len;
 
     ssl->handshake->pmslen = p - ssl->handshake->premaster;
-#endif /* !MBEDTLS_USE_PSA_CRYPTO */
 
     return( 0 );
 }
-#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+#endif /* !MBEDTLS_USE_PSA_CRYPTO && MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
 
 #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_RENEGOTIATION)
 static int ssl_write_hello_request( mbedtls_ssl_context *ssl );
diff --git a/library/ssl_tls12_client.c b/library/ssl_tls12_client.c
index e154442..a22d97f 100644
--- a/library/ssl_tls12_client.c
+++ b/library/ssl_tls12_client.c
@@ -3126,6 +3126,25 @@
                 MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_dhm_make_public", ret );
                 return( ret );
             }
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+            unsigned char *pms = ssl->handshake->premaster;
+            unsigned char *pms_end = pms + sizeof( ssl->handshake->premaster );
+            size_t pms_len;
+
+            /* Write length only when we know the actual value */
+            if( ( ret = mbedtls_dhm_calc_secret( &ssl->handshake->dhm_ctx,
+                                          pms + 2, pms_end - ( pms + 2 ), &pms_len,
+                                          ssl->conf->f_rng, ssl->conf->p_rng ) ) != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_dhm_calc_secret", ret );
+                return( ret );
+            }
+            MBEDTLS_PUT_UINT16_BE( pms_len, pms, 0 );
+            pms += 2 + pms_len;
+
+            MBEDTLS_SSL_DEBUG_MPI( 3, "DHM: K ", &ssl->handshake->dhm_ctx.K  );
+#endif
         }
         else
 #endif /* MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED */
@@ -3157,21 +3176,15 @@
             return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
         }
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO) &&          \
-    defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED)
-        if( ciphersuite_info->key_exchange != MBEDTLS_KEY_EXCHANGE_PSK &&
-            ciphersuite_info->key_exchange != MBEDTLS_KEY_EXCHANGE_RSA_PSK )
-#endif /* MBEDTLS_USE_PSA_CRYPTO &&
-          MBEDTLS_KEY_EXCHANGE_PSK_ENABLED */
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
+        if( ( ret = mbedtls_ssl_psk_derive_premaster( ssl,
+                        ciphersuite_info->key_exchange ) ) != 0 )
         {
-            if( ( ret = mbedtls_ssl_psk_derive_premaster( ssl,
-                            ciphersuite_info->key_exchange ) ) != 0 )
-            {
-                MBEDTLS_SSL_DEBUG_RET( 1,
+            MBEDTLS_SSL_DEBUG_RET( 1,
                     "mbedtls_ssl_psk_derive_premaster", ret );
-                return( ret );
-            }
+            return( ret );
         }
+#endif /* !MBEDTLS_USE_PSA_CRYPTO */
     }
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index 0c8f0f5..29cbe75 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -4053,12 +4053,31 @@
             return( MBEDTLS_ERR_SSL_DECODE_ERROR );
         }
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        unsigned char *pms = ssl->handshake->premaster;
+        unsigned char *pms_end = pms + sizeof( ssl->handshake->premaster );
+        size_t pms_len;
+
+        /* Write length only when we know the actual value */
+        if( ( ret = mbedtls_dhm_calc_secret( &ssl->handshake->dhm_ctx,
+                                      pms + 2, pms_end - ( pms + 2 ), &pms_len,
+                                      ssl->conf->f_rng, ssl->conf->p_rng ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_dhm_calc_secret", ret );
+            return( ret );
+        }
+        MBEDTLS_PUT_UINT16_BE( pms_len, pms, 0 );
+        pms += 2 + pms_len;
+
+        MBEDTLS_SSL_DEBUG_MPI( 3, "DHM: K ", &ssl->handshake->dhm_ctx.K  );
+#else
         if( ( ret = mbedtls_ssl_psk_derive_premaster( ssl,
                         ciphersuite_info->key_exchange ) ) != 0 )
         {
             MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_psk_derive_premaster", ret );
             return( ret );
         }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     }
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED */