Share read_public_ecdhe_share() between TLS 1.3 client and server

Change-Id: I3fd2604296dc0e1e8380f5405429a6b0feb6e981
Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
diff --git a/library/ecp_internal.h b/library/ecp_internal.h
deleted file mode 100644
index ccd860f..0000000
--- a/library/ecp_internal.h
+++ /dev/null
@@ -1,45 +0,0 @@
-
-/**
- * \file ecp_internal.h
- *
- * \brief ECC-related functions with external linkage but which are
- *        not part of the public API.
- */
-/*
- *  Copyright The Mbed TLS Contributors
- *  SPDX-License-Identifier: Apache-2.0
- *
- *  Licensed under the Apache License, Version 2.0 (the "License"); you may
- *  not use this file except in compliance with the License.
- *  You may obtain a copy of the License at
- *
- *  http://www.apache.org/licenses/LICENSE-2.0
- *
- *  Unless required by applicable law or agreed to in writing, software
- *  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- *  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- *  See the License for the specific language governing permissions and
- *  limitations under the License.
- */
-#ifndef MBEDTLS_ECP_INTERNAL_H
-#define MBEDTLS_ECP_INTERNAL_H
-
-#include "common.h"
-#include "mbedtls/ecp.h"
-#include "mbedtls/ecdh.h"
-
-static inline mbedtls_ecp_group_id mbedtls_ecp_named_group_to_id(
-    uint16_t named_curve )
-{
-    const mbedtls_ecp_curve_info *curve_info;
-    curve_info = mbedtls_ecp_curve_info_from_tls_id( named_curve );
-    if( curve_info == NULL )
-        return( MBEDTLS_ECP_DP_NONE );
-    return( curve_info->grp_id );
-}
-
-int mbedtls_ecdh_import_public_raw( mbedtls_ecdh_context *ctx,
-                                    const unsigned char *buf,
-                                    const unsigned char *end );
-
-#endif /* MBEDTLS_ECP_INTERNAL_H */
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index f39f78d..b23fc1d 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2176,4 +2176,25 @@
 }
 #endif /* MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_SSL_PROTO_TLS1_3 */
 
+#if defined(MBEDTLS_ECDH_C)
+
+/**
+ * Parse the key_exchange field of a TLS 1.3 KeyShareEntry that carries an
+ * ECDHE share and store the peer's public key in ssl->handshake.
+ *
+ * \param ssl      SSL context whose handshake parameters receive the key.
+ * \param buf      Start of the key_exchange field, i.e. its 2-byte length
+ *                 prefix (not the key bytes that follow it).
+ * \param buf_len  Number of bytes readable from \p buf.
+ *
+ * \return         0 on success, a negative error code otherwise.
+ *
+ * Shared by the TLS 1.3 client and server.
+ */
+int mbedtls_ssl_tls13_read_public_ecdhe_share( mbedtls_ssl_context *ssl,
+                                               const unsigned char *buf,
+                                               size_t buf_len );
+
+#endif /* MBEDTLS_ECDH_C */
+
 #endif /* ssl_misc.h */
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index b05d2f2..198c20a 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -417,30 +417,5 @@
     return( ret );
 }
 
-#if defined(MBEDTLS_ECDH_C)
-
-static int ssl_tls13_read_public_ecdhe_share( mbedtls_ssl_context *ssl,
-                                              const unsigned char *buf,
-                                              size_t buf_len )
-{
-    uint8_t *p = (uint8_t*)buf;
-    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
-
-    /* Get size of the TLS opaque key_exchange field of the KeyShareEntry struct. */
-    uint16_t peerkey_len = MBEDTLS_GET_UINT16_BE( p, 0 );
-    p += 2;
-
-    /* Check if key size is consistent with given buffer length. */
-    if ( peerkey_len > ( buf_len - 2 ) )
-        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
-
-    /* Store peer's ECDH public key. */
-    memcpy( handshake->ecdh_psa_peerkey, p, peerkey_len );
-    handshake->ecdh_psa_peerkey_len = peerkey_len;
-
-    return( 0 );
-}
-#endif /* MBEDTLS_ECDH_C */
-
 /*
  * ssl_tls13_parse_hrr_key_share_ext()
@@ -565,7 +540,7 @@
 
         MBEDTLS_SSL_DEBUG_MSG( 2, ( "ECDH curve: %s", curve_info->name ) );
 
-        ret = ssl_tls13_read_public_ecdhe_share( ssl, p, end - p );
+        ret = mbedtls_ssl_tls13_read_public_ecdhe_share( ssl, p, end - p );
         if( ret != 0 )
             return( ret );
     }
diff --git a/library/ssl_tls13_generic.c b/library/ssl_tls13_generic.c
index 18a66ec..a6bcac3 100644
--- a/library/ssl_tls13_generic.c
+++ b/library/ssl_tls13_generic.c
@@ -30,7 +30,6 @@
 #include "mbedtls/constant_time.h"
 #include <string.h>
 
-#include "ecp_internal.h"
 #include "ssl_misc.h"
 #include "ssl_tls13_keys.h"
 #include "ssl_debug_helpers.h"
@@ -1512,59 +1511,53 @@
     return( ret );
 }
 
-#define ECDH_VALIDATE_RET( cond )    \
-    MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_ECP_BAD_INPUT_DATA )
+#if defined(MBEDTLS_ECDH_C)
 
-#if !defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
-static int ecdh_import_public_raw( mbedtls_ecdh_context_mbed *ctx,
-                                   const unsigned char *buf,
-                                   const unsigned char *end )
+/*
+ * Parse the key_exchange field of a TLS 1.3 KeyShareEntry carrying an
+ * ECDHE share and store the peer's public key in ssl->handshake:
+ *
+ *     opaque key_exchange<1..2^16-1>;
+ *
+ * buf points at the 2-byte length prefix of key_exchange and buf_len is
+ * the number of bytes readable from buf. The length prefix and the key
+ * must both fit in buf_len, and the key must fit in the fixed-size
+ * handshake->ecdh_psa_peerkey array it is copied into.
+ *
+ * Returns 0 on success, or a negative MBEDTLS_ERR_SSL_XXX error code if
+ * the input is truncated or the key is too long.
+ */
+int mbedtls_ssl_tls13_read_public_ecdhe_share( mbedtls_ssl_context *ssl,
+                                               const unsigned char *buf,
+                                               size_t buf_len )
 {
-    return( mbedtls_ecp_point_read_binary( &ctx->grp, &ctx->Qp,
-                                           buf, end - buf ) );
-}
-#endif /* MBEDTLS_ECDH_LEGACY_CONTEXT */
+    const unsigned char *p = buf;
+    const unsigned char *end = buf + buf_len;
+    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
+    uint16_t peerkey_len;
 
-#if defined(MBEDTLS_ECDH_VARIANT_EVEREST_ENABLED)
-static int everest_import_public_raw( mbedtls_x25519_context *ctx,
-                                      const unsigned char *buf,
-                                      const unsigned char *end )
-{
-    if( end - buf != MBEDTLS_X25519_KEY_SIZE_BYTES )
-        return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
+    /* Get size of the TLS opaque key_exchange field of the KeyShareEntry struct. */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 2 );
+    peerkey_len = MBEDTLS_GET_UINT16_BE( p, 0 );
+    p += 2;
 
-    memcpy( ctx->peer_point, buf, MBEDTLS_X25519_KEY_SIZE_BYTES );
+    /* Check if key size is consistent with given buffer length. */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, peerkey_len );
+
+    /* The key is copied into a fixed-size array: reject longer keys
+     * instead of overflowing it. */
+    if( peerkey_len > sizeof( handshake->ecdh_psa_peerkey ) )
+    {
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Invalid public key length." ) );
+        return( MBEDTLS_ERR_SSL_DECODE_ERROR );
+    }
+
+    /* Store peer's ECDH public key. */
+    memcpy( handshake->ecdh_psa_peerkey, p, peerkey_len );
+    handshake->ecdh_psa_peerkey_len = peerkey_len;
+
     return( 0 );
 }
-#endif /* MBEDTLS_ECDH_VARIANT_EVEREST_ENABLED */
-
-int mbedtls_ecdh_import_public_raw( mbedtls_ecdh_context *ctx,
-                                    const unsigned char *buf,
-                                    const unsigned char *end )
-{
-    ECDH_VALIDATE_RET( ctx != NULL );
-    ECDH_VALIDATE_RET( buf != NULL );
-    ECDH_VALIDATE_RET( end != NULL );
-#if defined(MBEDTLS_ECDH_LEGACY_CONTEXT)
-    ((void) ctx);
-    ((void) buf);
-    ((void) end);
-    return ( 0 );
-#else
-    switch( ctx->var )
-    {
-#if defined(MBEDTLS_ECDH_VARIANT_EVEREST_ENABLED)
-        case MBEDTLS_ECDH_VARIANT_EVEREST:
-            return( everest_import_public_raw( &ctx->ctx.everest_ecdh.ctx,
-                                               buf, end) );
-#endif /* MBEDTLS_ECDH_VARIANT_EVEREST_ENABLED */
-        case MBEDTLS_ECDH_VARIANT_MBEDTLS_2_0:
-            return( ecdh_import_public_raw( &ctx->ctx.mbed_ecdh,
-                                            buf, end ) );
-        default:
-            return MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
-    }
-#endif /* MBEDTLS_ECDH_LEGACY_CONTEXT */
-}
+#endif /* MBEDTLS_ECDH_C */
 
 #endif /* MBEDTLS_SSL_TLS_C && MBEDTLS_SSL_PROTO_TLS1_3 */
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 9f55fe7..f002595 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -29,7 +29,6 @@
 #include <string.h>
 #if defined(MBEDTLS_ECP_C)
 #include "mbedtls/ecp.h"
-#include "ecp_internal.h"
 #endif /* MBEDTLS_ECP_C */
 
 #if defined(MBEDTLS_PLATFORM_C)
@@ -238,9 +237,7 @@
 
     for( ; p < extentions_end; p += cur_share_len )
     {
-        uint16_t their_group;
-        mbedtls_ecp_group_id their_curve;
-        unsigned char const *end_of_share;
+        uint16_t group;
 
         /*
          * struct {
@@ -250,13 +247,14 @@
          */
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extentions_end, 4 );
 
-        their_group = MBEDTLS_GET_UINT16_BE( p, 0 );
-        p   += 2;
+        group = MBEDTLS_GET_UINT16_BE( p, 0 );
+        p += 2;
 
         cur_share_len = MBEDTLS_GET_UINT16_BE( p, 0 );
-        p   += 2;
-
-        end_of_share = p + cur_share_len;
+        p += 2;
+
+        /* The whole key_exchange field must lie within the extension. */
+        MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extentions_end, cur_share_len );
 
         /* Continue parsing even if we have already found a match,
          * for input validation purposes.
@@ -268,60 +266,51 @@
          * NamedGroup matching
          *
          * For now, we only support ECDHE groups, but e.g.
          * PQC KEMs will need to be added at a later stage.
          */
 
         /* Type 1: ECDHE shares
          *
          * - Check if we recognize the group
          * - Check if it's supported
          */
 
-        their_curve = mbedtls_ecp_named_group_to_id( their_group );
-        if( mbedtls_ssl_check_curve( ssl, their_curve ) != 0 )
-            continue;
+        if( mbedtls_ssl_tls13_named_group_is_ecdhe( group ) )
+        {
+            const mbedtls_ecp_curve_info *curve_info =
+                mbedtls_ecp_curve_info_from_tls_id( group );
 
-        /* Skip if we no match succeeded. */
-        if( their_curve == MBEDTLS_ECP_DP_NONE )
+            /* The group is chosen by the peer: skip groups that do not map
+             * to a curve we know or that are not enabled in our configuration
+             * instead of aborting the handshake. */
+            if( curve_info == NULL ||
+                mbedtls_ssl_check_curve( ssl, curve_info->grp_id ) != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 4, ( "Unsupported NamedGroup %u",
+                                            (unsigned) group ) );
+                continue;
+            }
+
+            match_found = 1;
+
+            MBEDTLS_SSL_DEBUG_MSG( 2, ( "ECDH curve: %s", curve_info->name ) );
+
+            /* p already points past the 2-byte key_exchange length, but the
+             * shared parser reads that length itself: hand it the whole
+             * field, prefix included, bounded by this share only. */
+            ret = mbedtls_ssl_tls13_read_public_ecdhe_share(
+                      ssl, p - 2, cur_share_len + 2 );
+            if( ret != 0 )
+                return( ret );
+        }
+        else
         {
             MBEDTLS_SSL_DEBUG_MSG( 4, ( "Unrecognized NamedGroup %u",
-                                        (unsigned) their_group ) );
+                                        (unsigned) group ) );
             continue;
         }
 
-        match_found = 1;
-
-        /* KeyShare parsing
-         *
-         * Once we add more key share types, this needs to be a switch
-         * over the (type of) the named curve
-         */
-
-        /* Type 1: ECDHE shares
-         *
-         * - Setup ECDHE context
-         * - Import client's public key
-         * - Apply further curve checks
-         */
-
-        MBEDTLS_SSL_DEBUG_MSG( 2, ( "ECDH curve: %ud", their_curve ) );
-
-        ret = mbedtls_ecdh_setup( &ssl->handshake->ecdh_ctx, their_curve );
-        if( ret != 0 )
-        {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecdh_setup()", ret );
-            return( ret );
-        }
-
-        ret = mbedtls_ecdh_import_public_raw( &ssl->handshake->ecdh_ctx,
-                                              p, end_of_share );
-        if( ret != 0 )
-        {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ecdh_import_public_raw()", ret );
-            return( ret );
-        }
-
-        ssl->handshake->offered_group_id = their_group;
+        ssl->handshake->offered_group_id = group;
     }
 
     if( match_found == 0 )