Merge pull request #5456 from mpg/cleanup-ecdh-psa

Cleanup PSA-based ECDHE in TLS 1.2
diff --git a/ChangeLog.d/use-psa-ecdhe-curve.txt b/ChangeLog.d/use-psa-ecdhe-curve.txt
new file mode 100644
index 0000000..cc432bd
--- /dev/null
+++ b/ChangeLog.d/use-psa-ecdhe-curve.txt
@@ -0,0 +1,7 @@
+Bugfix
+   * Fix a bug in (D)TLS curve negotiation: when MBEDTLS_USE_PSA_CRYPTO was
+     enabled and an ECDHE-ECDSA or ECDHE-RSA key exchange was used, the
+     client would fail to check that the curve selected by the server for
+     ECDHE was indeed one that was offered. As a result, the client would
+     accept any curve that it supported, even if that curve was not allowed
+     according to its configuration.
diff --git a/include/mbedtls/psa_util.h b/include/mbedtls/psa_util.h
index c54c035..e38e2e3 100644
--- a/include/mbedtls/psa_util.h
+++ b/include/mbedtls/psa_util.h
@@ -258,85 +258,24 @@
     return( -1 );
 }
 
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH 1
+#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH \
+    PSA_KEY_EXPORT_ECC_PUBLIC_KEY_MAX_SIZE( PSA_VENDOR_ECC_MAX_CURVE_BITS )
 
-#if defined(MBEDTLS_ECP_DP_SECP192R1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 192 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 192 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_SECP192R1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_SECP224R1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 224 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 224 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_SECP224R1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_SECP256R1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_SECP256R1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_SECP384R1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 384 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 384 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_SECP384R1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_SECP521R1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 521 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 521 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_SECP521R1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_SECP192K1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 192 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 192 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_SECP192K1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_SECP224K1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 224 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 224 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_SECP224K1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_SECP256K1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_SECP256K1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_BP256R1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 256 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_BP256R1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_BP384R1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 384 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 384 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_BP384R1_ENABLED */
-
-#if defined(MBEDTLS_ECP_DP_BP512R1_ENABLED)
-#if MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH < ( 2 * ( ( 512 + 7 ) / 8 ) + 1 )
-#undef MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH
-#define MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH ( 2 * ( ( 512 + 7 ) / 8 ) + 1 )
-#endif
-#endif /* MBEDTLS_ECP_DP_BP512R1_ENABLED */
-
+/* Translate a TLS NamedCurve identifier, as registered at
+ * https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-8
+ * into the matching PSA ECC key-pair type; returns 0 if no curve info is found. */
+#if defined(MBEDTLS_ECP_C)
+static inline psa_key_type_t mbedtls_psa_parse_tls_ecc_group(
+    uint16_t tls_ecc_grp_reg_id, size_t *bits )
+{
+    const mbedtls_ecp_curve_info *curve_info =
+        mbedtls_ecp_curve_info_from_tls_id( tls_ecc_grp_reg_id );
+    if( curve_info == NULL )
+        return( 0 );
+    return( PSA_KEY_TYPE_ECC_KEY_PAIR(
+                mbedtls_ecc_group_to_psa( curve_info->grp_id, bits ) ) );
+}
+#endif /* MBEDTLS_ECP_C */
 
 /* Translations for PK layer */
 
@@ -366,63 +305,6 @@
     }
 }
 
-/* Translations for ECC */
-
-/* This function transforms an ECC group identifier from
- * https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-8
- * into a PSA ECC group identifier. */
-#if defined(MBEDTLS_ECP_C)
-static inline psa_key_type_t mbedtls_psa_parse_tls_ecc_group(
-    uint16_t tls_ecc_grp_reg_id, size_t *bits )
-{
-    const mbedtls_ecp_curve_info *curve_info =
-        mbedtls_ecp_curve_info_from_tls_id( tls_ecc_grp_reg_id );
-    if( curve_info == NULL )
-        return( 0 );
-    return( PSA_KEY_TYPE_ECC_KEY_PAIR(
-                mbedtls_ecc_group_to_psa( curve_info->grp_id, bits ) ) );
-}
-#endif /* MBEDTLS_ECP_C */
-
-/* This function takes a buffer holding an EC public key
- * exported through psa_export_public_key(), and converts
- * it into an ECPoint structure to be put into a ClientKeyExchange
- * message in an ECDHE exchange.
- *
- * Both the present and the foreseeable future format of EC public keys
- * used by PSA have the ECPoint structure contained in the exported key
- * as a subbuffer, and the function merely selects this subbuffer instead
- * of making a copy.
- */
-static inline int mbedtls_psa_tls_psa_ec_to_ecpoint( unsigned char *src,
-                                                     size_t srclen,
-                                                     unsigned char **dst,
-                                                     size_t *dstlen )
-{
-    *dst = src;
-    *dstlen = srclen;
-    return( 0 );
-}
-
-/* This function takes a buffer holding an ECPoint structure
- * (as contained in a TLS ServerKeyExchange message for ECDHE
- * exchanges) and converts it into a format that the PSA key
- * agreement API understands.
- */
-static inline int mbedtls_psa_tls_ecpoint_to_psa_ec( unsigned char const *src,
-                                                     size_t srclen,
-                                                     unsigned char *dst,
-                                                     size_t dstlen,
-                                                     size_t *olen )
-{
-    if( srclen > dstlen )
-        return( MBEDTLS_ERR_ECP_BUFFER_TOO_SMALL );
-
-    memcpy( dst, src, srclen );
-    *olen = srclen;
-    return( 0 );
-}
-
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 /* Expose whatever RNG the PSA subsystem uses to applications using the
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index 825034a..694473f 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -2334,12 +2334,7 @@
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "ECDH curve: %s", curve_info->name ) );
 
-#if defined(MBEDTLS_ECP_C)
     if( mbedtls_ssl_check_curve( ssl, grp_id ) != 0 )
-#else
-    if( ssl->handshake->ecdh_ctx.grp.nbits < 163 ||
-        ssl->handshake->ecdh_ctx.grp.nbits > 521 )
-#endif
         return( -1 );
 
     MBEDTLS_SSL_DEBUG_ECDH( 3, &ssl->handshake->ecdh_ctx,
@@ -2366,9 +2361,16 @@
     mbedtls_ssl_handshake_params *handshake = ssl->handshake;
 
     /*
-     * Parse ECC group
+     * struct {
+     *     ECParameters curve_params;
+     *     ECPoint      public;
+     * } ServerECDHParams;
+     *
+     *  1       curve_type (must be "named_curve")
+     *  2..3    NamedCurve
+     *  4       ECPoint.len
+     *  5+      ECPoint contents
      */
-
     if( end - *p < 4 )
         return( MBEDTLS_ERR_SSL_DECODE_ERROR );
 
@@ -2381,6 +2383,15 @@
     tls_id <<= 8;
     tls_id |= *(*p)++;
 
+    /* Check it's a curve we offered */
+    if( mbedtls_ssl_check_curve_tls_id( ssl, tls_id ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 2,
+            ( "bad server key exchange message (ECDHE curve): %u",
+              (unsigned) tls_id ) );
+        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    }
+
     /* Convert EC group to PSA key type. */
     if( ( handshake->ecdh_psa_type =
           mbedtls_psa_parse_tls_ecc_group( tls_id, &ecdh_bits ) ) == 0 )
@@ -2391,24 +2402,18 @@
         return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
     handshake->ecdh_bits = (uint16_t) ecdh_bits;
 
-    /*
-     * Put peer's ECDH public key in the format understood by PSA.
-     */
-
+    /* Keep a copy of the peer's public key */
     ecpoint_len = *(*p)++;
     if( (size_t)( end - *p ) < ecpoint_len )
         return( MBEDTLS_ERR_SSL_DECODE_ERROR );
 
-    if( mbedtls_psa_tls_ecpoint_to_psa_ec(
-                                    *p, ecpoint_len,
-                                    handshake->ecdh_psa_peerkey,
-                                    sizeof( handshake->ecdh_psa_peerkey ),
-                                    &handshake->ecdh_psa_peerkey_len ) != 0 )
-    {
-        return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
-    }
+    if( ecpoint_len > sizeof( handshake->ecdh_psa_peerkey ) )
+        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
 
+    memcpy( handshake->ecdh_psa_peerkey, *p, ecpoint_len );
+    handshake->ecdh_psa_peerkey_len = ecpoint_len;
     *p += ecpoint_len;
+
     return( 0 );
 }
 #endif /* MBEDTLS_USE_PSA_CRYPTO &&
@@ -3373,11 +3378,6 @@
 
         mbedtls_ssl_handshake_params *handshake = ssl->handshake;
 
-        unsigned char own_pubkey[MBEDTLS_PSA_MAX_EC_PUBKEY_LENGTH];
-        size_t own_pubkey_len;
-        unsigned char *own_pubkey_ecpoint;
-        size_t own_pubkey_ecpoint_len;
-
         header_len = 4;
 
         MBEDTLS_SSL_DEBUG_MSG( 1, ( "Perform PSA-based ECDH computation." ) );
@@ -3405,27 +3405,22 @@
         if( status != PSA_SUCCESS )
             return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
 
-        /* Export the public part of the ECDH private key from PSA
-         * and convert it to ECPoint format used in ClientKeyExchange. */
+        /* Export the public part of the ECDH private key from PSA.
+         * The export format is an ECPoint structure as expected by TLS,
+         * but we just need to add a length byte before that. */
+        unsigned char *own_pubkey = ssl->out_msg + header_len + 1;
+        unsigned char *end = ssl->out_msg + MBEDTLS_SSL_OUT_CONTENT_LEN;
+        size_t own_pubkey_max_len = (size_t)( end - own_pubkey );
+        size_t own_pubkey_len;
+
         status = psa_export_public_key( handshake->ecdh_psa_privkey,
-                                        own_pubkey, sizeof( own_pubkey ),
+                                        own_pubkey, own_pubkey_max_len,
                                         &own_pubkey_len );
         if( status != PSA_SUCCESS )
             return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
 
-        if( mbedtls_psa_tls_psa_ec_to_ecpoint( own_pubkey,
-                                               own_pubkey_len,
-                                               &own_pubkey_ecpoint,
-                                               &own_pubkey_ecpoint_len ) != 0 )
-        {
-            return( MBEDTLS_ERR_SSL_HW_ACCEL_FAILED );
-        }
-
-        /* Copy ECPoint structure to outgoing message buffer. */
-        ssl->out_msg[header_len] = (unsigned char) own_pubkey_ecpoint_len;
-        memcpy( ssl->out_msg + header_len + 1,
-                own_pubkey_ecpoint, own_pubkey_ecpoint_len );
-        content_len = own_pubkey_ecpoint_len + 1;
+        ssl->out_msg[header_len] = (unsigned char) own_pubkey_len;
+        content_len = own_pubkey_len + 1;
 
         /* The ECDH secret is the premaster secret used for key derivation. */
 
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 8cb9576..4d753c8 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -1314,6 +1314,7 @@
 unsigned char mbedtls_ssl_hash_from_md_alg( int md );
 int mbedtls_ssl_set_calc_verify_md( mbedtls_ssl_context *ssl, int md );
 
+int mbedtls_ssl_check_curve_tls_id( const mbedtls_ssl_context *ssl, uint16_t tls_id );
 #if defined(MBEDTLS_ECP_C)
 int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_id grp_id );
 #endif
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 988fafb..df17ca5 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -7216,18 +7216,16 @@
     }
 }
 
-#if defined(MBEDTLS_ECP_C)
 /*
  * Check if a curve proposed by the peer is in our list.
  * Return 0 if we're willing to use it, -1 otherwise.
  */
-int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_id grp_id )
+int mbedtls_ssl_check_curve_tls_id( const mbedtls_ssl_context *ssl, uint16_t tls_id )
 {
     const uint16_t *group_list = mbedtls_ssl_get_groups( ssl );
 
     if( group_list == NULL )
         return( -1 );
-    uint16_t tls_id = mbedtls_ecp_curve_info_from_grp_id(grp_id)->tls_id;
 
     for( ; *group_list != 0; group_list++ )
     {
@@ -7237,6 +7235,16 @@
 
     return( -1 );
 }
+
+#if defined(MBEDTLS_ECP_C)
+/*
+ * Same as mbedtls_ssl_check_curve_tls_id() but takes an mbedtls_ecp_group_id;
+ * returns -1 when grp_id has no curve info (rather than dereferencing NULL). */
+int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_id grp_id )
+{
+    const mbedtls_ecp_curve_info *info = mbedtls_ecp_curve_info_from_grp_id( grp_id );
+    return( info == NULL ? -1 : mbedtls_ssl_check_curve_tls_id( ssl, info->tls_id ) );
+}
 #endif /* MBEDTLS_ECP_C */
 
 #if defined(MBEDTLS_X509_CRT_PARSE_C)