Merge pull request #5611 from superna9999/5318-tls-ecdhe-psk

TLS ECDH 3a: ECDHE-PSK (both sides, 1.2)
diff --git a/library/ssl_tls12_client.c b/library/ssl_tls12_client.c
index 7771d38..8250260 100644
--- a/library/ssl_tls12_client.c
+++ b/library/ssl_tls12_client.c
@@ -1813,6 +1813,7 @@
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO) &&                           \
         ( defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||     \
+          defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED) ||     \
           defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED) )
 static int ssl_parse_server_ecdh_params_psa( mbedtls_ssl_context *ssl,
                                              unsigned char **p,
@@ -2346,8 +2347,10 @@
           MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED */
 #if defined(MBEDTLS_USE_PSA_CRYPTO) &&                           \
         ( defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||     \
+          defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED) ||     \
           defined(MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED) )
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_RSA ||
+        ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK ||
         ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA )
     {
         if( ssl_parse_server_ecdh_params_psa( ssl, &p, end ) != 0 )
@@ -2363,6 +2366,7 @@
     else
 #endif /* MBEDTLS_USE_PSA_CRYPTO &&
             ( MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED ||
+              MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED ||
               MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED ) */
 #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_RSA_ENABLED) ||                     \
     defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED) ||                     \
@@ -2996,6 +3000,168 @@
           MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA_ENABLED ||
           MBEDTLS_KEY_EXCHANGE_ECDH_RSA_ENABLED ||
           MBEDTLS_KEY_EXCHANGE_ECDH_ECDSA_ENABLED */
+#if defined(MBEDTLS_USE_PSA_CRYPTO) &&                           \
+    defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED)
+    if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK )
+    {
+        psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+        psa_status_t destruction_status = PSA_ERROR_CORRUPTION_DETECTED;
+        psa_key_attributes_t key_attributes;
+
+        mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+
+        /*
+         * opaque psk_identity<0..2^16-1>;
+         */
+        if( mbedtls_ssl_conf_has_static_psk( ssl->conf ) == 0 )
+            /* We don't offer PSK suites if we don't have a PSK,
+             * and we check that the server's choice is among the
+             * ciphersuites we offered, so this should never happen. */
+            return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+        /* Opaque PSKs are currently only supported for PSK-only suites. */
+        if( ssl_conf_has_static_raw_psk( ssl->conf ) == 0 )
+            return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE );
+
+        /* uint16 length prefix of the opaque psk_identity field */
+        const size_t content_len_size = 2;
+        /* uint8 length prefix of the ECPoint that follows the identity */
+        const size_t ecpoint_len_size = 1;
+
+        header_len = 4;
+
+        /* The ECPoint length byte is written right after the identity, so
+         * it has to fit as well. Otherwise an identity that exactly fills
+         * the buffer causes an out-of-bounds write and makes the space left
+         * for the public key (end - own_pubkey) underflow below. */
+        if( header_len + content_len_size + ssl->conf->psk_identity_len
+                    + ecpoint_len_size > MBEDTLS_SSL_OUT_CONTENT_LEN )
+        {
+            MBEDTLS_SSL_DEBUG_MSG( 1,
+                ( "psk identity too long or SSL buffer too short" ) );
+            return( MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL );
+        }
+
+        unsigned char *p = ssl->out_msg + header_len;
+
+        *p++ = MBEDTLS_BYTE_1( ssl->conf->psk_identity_len );
+        *p++ = MBEDTLS_BYTE_0( ssl->conf->psk_identity_len );
+        header_len += content_len_size;
+
+        memcpy( p, ssl->conf->psk_identity,
+                ssl->conf->psk_identity_len );
+        p += ssl->conf->psk_identity_len;
+
+        header_len += ssl->conf->psk_identity_len;
+
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Perform PSA-based ECDH computation." ) );
+
+        /*
+         * Generate EC private key for ECDHE exchange.
+         */
+
+        /* The master secret is obtained from the shared ECDH secret by
+         * applying the TLS 1.2 PRF with a specific salt and label. While
+         * the PSA Crypto API encourages combining key agreement schemes
+         * such as ECDH with fixed KDFs such as TLS 1.2 PRF, it does not
+         * yet support the provisioning of salt + label to the KDF.
+         * For the time being, we therefore need to split the computation
+         * of the ECDH secret and the application of the TLS 1.2 PRF. */
+        key_attributes = psa_key_attributes_init();
+        psa_set_key_usage_flags( &key_attributes, PSA_KEY_USAGE_DERIVE );
+        psa_set_key_algorithm( &key_attributes, PSA_ALG_ECDH );
+        psa_set_key_type( &key_attributes, handshake->ecdh_psa_type );
+        psa_set_key_bits( &key_attributes, handshake->ecdh_bits );
+
+        /* Generate ECDH private key. */
+        status = psa_generate_key( &key_attributes,
+                                   &handshake->ecdh_psa_privkey );
+        if( status != PSA_SUCCESS )
+            return( psa_ssl_status_to_mbedtls( status ) );
+
+        /* Export the public part of the ECDH private key from PSA.
+         * The export format is an ECPoint structure as expected by TLS,
+         * but we just need to add a length byte before that. */
+        unsigned char *own_pubkey = p + ecpoint_len_size;
+        unsigned char *end = ssl->out_msg + MBEDTLS_SSL_OUT_CONTENT_LEN;
+        size_t own_pubkey_max_len = (size_t)( end - own_pubkey );
+        size_t own_pubkey_len = 0;
+
+        status = psa_export_public_key( handshake->ecdh_psa_privkey,
+                                        own_pubkey, own_pubkey_max_len,
+                                        &own_pubkey_len );
+        if( status != PSA_SUCCESS )
+        {
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( psa_ssl_status_to_mbedtls( status ) );
+        }
+
+        *p = (unsigned char) own_pubkey_len;
+        content_len = own_pubkey_len + ecpoint_len_size;
+
+        /* As RFC 5489 section 2, the premaster secret is formed as follows:
+         * - a uint16 containing the length (in octets) of the ECDH computation
+         * - the octet string produced by the ECDH computation
+         * - a uint16 containing the length (in octets) of the PSK
+         * - the PSK itself
+         */
+        unsigned char *pms = ssl->handshake->premaster;
+        const unsigned char* const pms_end = pms +
+                                sizeof( ssl->handshake->premaster );
+        /* uint16 to store length (in octets) of the ECDH computation */
+        const size_t zlen_size = 2;
+        size_t zlen = 0;
+
+        /* Perform ECDH computation after the uint16 reserved for the length */
+        status = psa_raw_key_agreement( PSA_ALG_ECDH,
+                                        handshake->ecdh_psa_privkey,
+                                        handshake->ecdh_psa_peerkey,
+                                        handshake->ecdh_psa_peerkey_len,
+                                        pms + zlen_size,
+                                        pms_end - ( pms + zlen_size ),
+                                        &zlen );
+
+        destruction_status = psa_destroy_key( handshake->ecdh_psa_privkey );
+        handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+
+        if( status != PSA_SUCCESS )
+            return( psa_ssl_status_to_mbedtls( status ) );
+        else if( destruction_status != PSA_SUCCESS )
+            return( psa_ssl_status_to_mbedtls( destruction_status ) );
+
+        /* Write the ECDH computation length before the ECDH computation */
+        MBEDTLS_PUT_UINT16_BE( zlen, pms, 0 );
+        pms += zlen_size + zlen;
+
+        const unsigned char *psk = NULL;
+        size_t psk_len = 0;
+
+        if( mbedtls_ssl_get_psk( ssl, &psk, &psk_len )
+                == MBEDTLS_ERR_SSL_PRIVATE_KEY_REQUIRED )
+            /*
+             * This should never happen because the existence of a PSK is always
+             * checked before calling this function
+             */
+            return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+        /* opaque psk<0..2^16-1>; */
+        if( (size_t)( pms_end - pms ) < ( 2 + psk_len ) )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+        /* Write the PSK length as uint16 */
+        MBEDTLS_PUT_UINT16_BE( psk_len, pms, 0 );
+        pms += 2;
+
+        /* Write the PSK itself */
+        memcpy( pms, psk, psk_len );
+        pms += psk_len;
+
+        ssl->handshake->pmslen = pms - ssl->handshake->premaster;
+    }
+    else
+#endif /* MBEDTLS_USE_PSA_CRYPTO &&
+          MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED */
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
     if( mbedtls_ssl_ciphersuite_uses_psk( ciphersuite_info ) )
     {
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index 9ecfdd2..93cd0a5 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -3163,7 +3163,8 @@
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
         if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_RSA ||
-            ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA )
+            ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_ECDSA ||
+            ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK )
         {
             psa_status_t status = PSA_ERROR_GENERIC_ERROR;
             psa_key_attributes_t key_attributes;
@@ -4142,6 +4143,123 @@
     }
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_DHE_PSK_ENABLED */
+#if defined(MBEDTLS_USE_PSA_CRYPTO) &&                           \
+        defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED)
+    if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK )
+    {
+        psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+        psa_status_t destruction_status = PSA_ERROR_CORRUPTION_DETECTED;
+        uint8_t ecpoint_len;
+
+        mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+
+        /* Opaque PSKs are currently only supported for PSK-only suites.
+         * Release the ephemeral ECDH private key here, as every other
+         * error path of this branch does. */
+        if( ssl_use_opaque_psk( ssl ) == 1 )
+        {
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE );
+        }
+
+        if( ( ret = ssl_parse_client_psk_identity( ssl, &p, end ) ) != 0 )
+        {
+            MBEDTLS_SSL_DEBUG_RET( 1, ( "ssl_parse_client_psk_identity" ), ret );
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( ret );
+        }
+
+        /* Keep a copy of the peer's public key */
+        if( p >= end )
+        {
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( MBEDTLS_ERR_SSL_DECODE_ERROR );
+        }
+
+        ecpoint_len = *(p++);
+        if( (size_t)( end - p ) < ecpoint_len )
+        {
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( MBEDTLS_ERR_SSL_DECODE_ERROR );
+        }
+
+        if( ecpoint_len > sizeof( handshake->ecdh_psa_peerkey ) )
+        {
+            psa_destroy_key( handshake->ecdh_psa_privkey );
+            handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+        }
+
+        memcpy( handshake->ecdh_psa_peerkey, p, ecpoint_len );
+        handshake->ecdh_psa_peerkey_len = ecpoint_len;
+        p += ecpoint_len;
+
+        /* As RFC 5489 section 2, the premaster secret is formed as follows:
+         * - a uint16 containing the length (in octets) of the ECDH computation
+         * - the octet string produced by the ECDH computation
+         * - a uint16 containing the length (in octets) of the PSK
+         * - the PSK itself
+         */
+        unsigned char *pms = ssl->handshake->premaster;
+        const unsigned char* const pms_end =
+                    pms + sizeof( ssl->handshake->premaster );
+        /* uint16 to store length (in octets) of the ECDH computation */
+        const size_t zlen_size = 2;
+        size_t zlen = 0;
+
+        /* Compute ECDH shared secret. */
+        status = psa_raw_key_agreement( PSA_ALG_ECDH,
+                                        handshake->ecdh_psa_privkey,
+                                        handshake->ecdh_psa_peerkey,
+                                        handshake->ecdh_psa_peerkey_len,
+                                        pms + zlen_size,
+                                        pms_end - ( pms + zlen_size ),
+                                        &zlen );
+
+        destruction_status = psa_destroy_key( handshake->ecdh_psa_privkey );
+        handshake->ecdh_psa_privkey = MBEDTLS_SVC_KEY_ID_INIT;
+
+        if( status != PSA_SUCCESS )
+            return( psa_ssl_status_to_mbedtls( status ) );
+        else if( destruction_status != PSA_SUCCESS )
+            return( psa_ssl_status_to_mbedtls( destruction_status ) );
+
+        /* Write the ECDH computation length before the ECDH computation */
+        MBEDTLS_PUT_UINT16_BE( zlen, pms, 0 );
+        pms += zlen_size + zlen;
+
+        const unsigned char *psk = NULL;
+        size_t psk_len = 0;
+
+        if( mbedtls_ssl_get_psk( ssl, &psk, &psk_len )
+                == MBEDTLS_ERR_SSL_PRIVATE_KEY_REQUIRED )
+            /*
+             * This should never happen because the existence of a PSK is always
+             * checked before calling this function
+             */
+            return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+        /* opaque psk<0..2^16-1>; */
+        if( (size_t)( pms_end - pms ) < ( 2 + psk_len ) )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+        /* Write the PSK length as uint16 */
+        MBEDTLS_PUT_UINT16_BE( psk_len, pms, 0 );
+        pms += 2;
+
+        /* Write the PSK itself */
+        memcpy( pms, psk, psk_len );
+        pms += psk_len;
+
+        ssl->handshake->pmslen = pms - ssl->handshake->premaster;
+    }
+    else
+#endif /* MBEDTLS_USE_PSA_CRYPTO &&
+            MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED */
 #if defined(MBEDTLS_KEY_EXCHANGE_ECDHE_PSK_ENABLED)
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECDHE_PSK )
     {