Overwrite the selected ciphersuite with one compatible with the PSK.

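The selected ciphersuite MUST use the same hash algorithm as the PSK.
Select a ciphersuite from the ClientHello list first; then, once an
external PSK identity matches, overwrite that choice with the first
offered ciphersuite whose hash matches the PSK binder length.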

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index b37fe5a..e4ff3b1 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -106,8 +106,10 @@
 static int ssl_tls13_offered_psks_check_identity_match(
                mbedtls_ssl_context *ssl,
                const unsigned char *identity,
-               size_t identity_len )
+               size_t identity_len,
+               int *psk_type ) /* [out] type of the PSK behind the identity */
 {
+    *psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL; /* NOTE(review): only EXTERNAL is set here; caller rejects other types as unsupported */
     /* Check identity with external configured function */
     if( ssl->conf->f_psk != NULL )
     {
@@ -137,12 +139,11 @@
 static int ssl_tls13_offered_psks_check_binder_match( mbedtls_ssl_context *ssl,
                                                       const unsigned char *binder,
                                                       size_t binder_len,
-                                                      mbedtls_md_type_t *psk_alg )
+                                                      int psk_type, /* e.g. MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL */
+                                                      mbedtls_md_type_t psk_alg ) /* hash of the ciphersuite chosen for the PSK */
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    int psk_type;
 
-    mbedtls_md_type_t md_alg;
     psa_algorithm_t psa_md_alg;
     unsigned char transcript[PSA_HASH_MAX_SIZE];
     size_t transcript_len;
@@ -150,22 +151,9 @@
     size_t psk_len;
     unsigned char server_computed_binder[PSA_HASH_MAX_SIZE];
 
-    *psk_alg = MBEDTLS_MD_NONE;
-    psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL;
-    switch( binder_len )
-    {
-        case 32:
-            md_alg = MBEDTLS_MD_SHA256;
-            break;
-        case 48:
-            md_alg = MBEDTLS_MD_SHA384;
-            break;
-        default:
-            return( MBEDTLS_SSL_ALERT_MSG_DECRYPT_ERROR );
-    }
-    psa_md_alg = mbedtls_psa_translate_md( md_alg );
+    psa_md_alg = mbedtls_psa_translate_md( psk_alg ); /* NOTE(review): binder_len is no longer validated here; the caller must make it equal psk_alg's digest size, since the memcmp below reads binder_len bytes of server_computed_binder */
     /* Get current state of handshake transcript. */
-    ret = mbedtls_ssl_get_handshake_transcript( ssl, md_alg,
+    ret = mbedtls_ssl_get_handshake_transcript( ssl, psk_alg,
                                                 transcript, sizeof( transcript ),
                                                 &transcript_len );
     if( ret != 0 )
@@ -194,7 +182,6 @@
 
     if( mbedtls_ct_memcmp( server_computed_binder, binder, binder_len ) == 0 )
     {
-        *psk_alg = md_alg;
         return( SSL_TLS1_3_OFFERED_PSK_MATCH );
     }
 
@@ -203,6 +190,70 @@
     return( SSL_TLS1_3_OFFERED_PSK_NOT_MATCH );
 }
 
+/* Select the first ClientHello ciphersuite in [buf, end) that we offer, that
+ * is valid for ssl->tls_version and whose hash matches the PSK binder length.
+ * Returns 0 and sets *selected_cipher_suite, or a NEGATIVE error: the caller
+ * tests `ret < 0`, so a positive alert code must never be returned here. */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_psk_external_check_ciphersuites( mbedtls_ssl_context *ssl,
+                                                      const unsigned char *buf,
+                                                      const unsigned char *end,
+                                                      size_t binder_len,
+                                                      uint16_t *selected_cipher_suite )
+{
+    mbedtls_md_type_t psk_alg;
+
+    *selected_cipher_suite = 0;
+    /* The binder is an HMAC computed with the PSK hash: its length gives it. */
+    switch( binder_len )
+    {
+#if defined(MBEDTLS_SHA256_C)
+        case 32:
+            psk_alg = MBEDTLS_MD_SHA256;
+            break;
+#endif
+#if defined(MBEDTLS_SHA384_C)
+        case 48:
+            psk_alg = MBEDTLS_MD_SHA384;
+            break;
+#endif
+        default:
+            /* Unsupported binder length: no ciphersuite can match. */
+            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    }
+
+    /* Search for the first acceptable ciphersuite whose hash is psk_alg. */
+    for ( const unsigned char *p = buf ; p < end ; p += 2 )
+    {
+        uint16_t cipher_suite;
+        const mbedtls_ssl_ciphersuite_t* ciphersuite_info;
+
+        MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 2 );
+        cipher_suite = MBEDTLS_GET_UINT16_BE( p, 0 );
+        if( ! mbedtls_ssl_tls13_cipher_suite_is_offered( ssl, cipher_suite ) )
+            continue;
+
+        ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite );
+        if( ( mbedtls_ssl_validate_ciphersuite(
+                ssl, ciphersuite_info, ssl->tls_version,
+                ssl->tls_version ) != 0 ) )
+            continue;
+
+        /* The ciphersuite hash MUST equal the PSK hash (RFC 8446, 4.2.11).
+         * psk_alg is always set by the switch above, never MBEDTLS_MD_NONE. */
+        if( psk_alg != ciphersuite_info->mac )
+            continue;
+
+        *selected_cipher_suite = cipher_suite;
+        MBEDTLS_SSL_DEBUG_MSG( 5, ( "PSK matched ciphersuite: %04x - %s",
+                                    cipher_suite,
+                                    ciphersuite_info->name ) );
+        return( 0 );
+    }
+
+    MBEDTLS_SSL_DEBUG_MSG( 2, ( "No ciphersuite compatible with the PSK" ) );
+    return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+}
 /* Parser for pre_shared_key extension in client hello
  *    struct {
  *        opaque identity<1..2^16-1>;
@@ -227,7 +278,8 @@
 static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl,
                                                const unsigned char *buf,
                                                const unsigned char *end,
-                                               mbedtls_md_type_t *psk_alg )
+                                               const unsigned char *ciphersuites,
+                                               const unsigned char *ciphersuites_end )
 {
     const unsigned char *identities = buf;
     const unsigned char *p_identity_len;
@@ -240,8 +292,6 @@
     int matched_identity = -1;
     int identity_id = -1;
 
-    *psk_alg = MBEDTLS_MD_NONE;
-
     MBEDTLS_SSL_DEBUG_BUF( 3, "pre_shared_key extension", buf, end - buf );
 
     /* identities_len       2 bytes
@@ -272,7 +322,9 @@
         const unsigned char *binder;
         size_t binder_len;
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-        mbedtls_md_type_t alg;
+        int psk_type;
+        uint16_t cipher_suite;
+        const mbedtls_ssl_ciphersuite_t* ciphersuite_info;
 
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p_identity_len, identities_end, 2 + 1 + 4 );
         identity_len = MBEDTLS_GET_UINT16_BE( p_identity_len, 0 );
@@ -291,12 +343,34 @@
             continue;
 
         ret = ssl_tls13_offered_psks_check_identity_match(
-                                            ssl, identity, identity_len );
+                                    ssl, identity, identity_len, &psk_type );
         if( ret != SSL_TLS1_3_OFFERED_PSK_MATCH )
             continue;
 
+        if( psk_type == MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL )
+        {
+            ret = ssl_tls13_psk_external_check_ciphersuites(
+                            ssl, ciphersuites, ciphersuites_end,
+                            binder_len, &cipher_suite );
+            if( ret != 0 ) /* catch positive codes too, not only ret < 0 */
+            {
+                /* No PSK-compatible ciphersuite: abort, as for a binder mismatch */
+                MBEDTLS_SSL_PEND_FATAL_ALERT(
+                    MBEDTLS_SSL_ALERT_MSG_DECRYPT_ERROR,
+                    MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+                return( ret < 0 ? ret : MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+            }
+            ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite );
+        }
+        else
+        {
+            MBEDTLS_SSL_DEBUG_MSG( 4, ( "`psk_type = %d` not support yet",
+                                        psk_type ) );
+            return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE );
+        }
+
         ret = ssl_tls13_offered_psks_check_binder_match(
-                                            ssl, binder, binder_len, &alg );
+                ssl, binder, binder_len, psk_type, ciphersuite_info->mac );
         /* For the security rationale, handshake should be abort when binder
          * value mismatch. See RFC 8446 section 4.2.11.2 and appendix E.6. */
         if( ret != SSL_TLS1_3_OFFERED_PSK_MATCH )
@@ -311,7 +385,16 @@
         }
 
         matched_identity = identity_id;
-        *psk_alg = alg;
+
+        /* Overwrite the earlier ciphersuite choice with the PSK-compatible one */
+        if( psk_type == MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL )
+        {
+            ssl->session_negotiate->ciphersuite = cipher_suite;
+            ssl->handshake->ciphersuite_info = ciphersuite_info;
+            MBEDTLS_SSL_DEBUG_MSG( 2, ( "overwrite ciphersuite: %04x - %s",
+                                        cipher_suite,
+                                        ciphersuite_info->name ) );
+        }
     }
 
     if( p_identity_len != identities_end || p_binder_len != binders_end )
@@ -915,7 +998,6 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     const unsigned char *p = buf;
     size_t legacy_session_id_len;
-    const unsigned char *cipher_suites;
     size_t cipher_suites_len;
     const unsigned char *cipher_suites_end;
     size_t extensions_len;
@@ -923,9 +1005,9 @@
     int hrr_required = 0;
 
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+    const unsigned char *cipher_suites; /* kept to rescan the list for a PSK-compatible suite */
     const unsigned char *pre_shared_key_ext_start = NULL;
     const unsigned char *pre_shared_key_ext_end = NULL;
-    mbedtls_md_type_t psk_alg = MBEDTLS_MD_NONE;
 #endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
 
     ssl->handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
@@ -1033,26 +1115,48 @@
     * with CipherSuite defined as:
     * uint8 CipherSuite[2];
     */
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
     cipher_suites = p;
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
     cipher_suites_end = p + cipher_suites_len;
     MBEDTLS_SSL_DEBUG_BUF( 3, "client hello, ciphersuitelist",
                           p, cipher_suites_len );
-#if defined(MBEDTLS_DEBUG_C)
+
+    /*
+     * Select the last acceptable ciphersuite (no break); an external PSK may replace it.
+     */
     for ( ; p < cipher_suites_end; p += 2 )
     {
         uint16_t cipher_suite;
         const mbedtls_ssl_ciphersuite_t* ciphersuite_info;
+
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, cipher_suites_end, 2 );
+
         cipher_suite = MBEDTLS_GET_UINT16_BE( p, 0 );
+        if( ! mbedtls_ssl_tls13_cipher_suite_is_offered( ssl, cipher_suite ) )
+            continue;
+
         ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite );
-        MBEDTLS_SSL_DEBUG_MSG( 2, ( "client hello, received ciphersuite: %04x - %s",
+        if( ( mbedtls_ssl_validate_ciphersuite(
+                ssl, ciphersuite_info, ssl->tls_version,
+                ssl->tls_version ) != 0 ) )
+        {
+            continue;
+        }
+
+        ssl->session_negotiate->ciphersuite = cipher_suite;
+        ssl->handshake->ciphersuite_info = ciphersuite_info;
+        MBEDTLS_SSL_DEBUG_MSG( 2, ( "selected ciphersuite: %04x - %s",
                                     cipher_suite,
-                                    ciphersuite_info == NULL ?
-                                        "Unkown": ciphersuite_info->name ) );
+                                    ciphersuite_info->name ) );
     }
-#else /* MBEDTLS_DEBUG_C */
-    p = cipher_suites_end;
-#endif /* !MBEDTLS_DEBUG_C */
+
+    if( ssl->handshake->ciphersuite_info == NULL )
+    {
+        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE,
+                                      MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+        return ( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    }
 
     /* ...
      * opaque legacy_compression_methods<1..2^8-1>;
@@ -1301,7 +1405,8 @@
         ret = ssl_tls13_parse_pre_shared_key_ext( ssl,
                                                   pre_shared_key_ext_start,
                                                   pre_shared_key_ext_end,
-                                                  &psk_alg );
+                                                  cipher_suites,
+                                                  cipher_suites_end );
         if( ret == MBEDTLS_ERR_SSL_UNKNOWN_IDENTITY )
         {
             ssl->handshake->extensions_present &= ~MBEDTLS_SSL_EXT_PRE_SHARED_KEY;
@@ -1319,51 +1424,6 @@
         ssl->handshake->update_checksum( ssl, buf, p - buf );
     }
 
-    /*
-     * Search for a matching ciphersuite
-     */
-    for ( const unsigned char * p_chiper_suite = cipher_suites ;
-          p_chiper_suite < cipher_suites_end; p_chiper_suite += 2 )
-    {
-        uint16_t cipher_suite;
-        const mbedtls_ssl_ciphersuite_t* ciphersuite_info;
-
-        MBEDTLS_SSL_CHK_BUF_READ_PTR( p_chiper_suite, cipher_suites_end, 2 );
-
-        cipher_suite = MBEDTLS_GET_UINT16_BE( p_chiper_suite, 0 );
-        if( ! mbedtls_ssl_tls13_cipher_suite_is_offered( ssl, cipher_suite ) )
-            continue;
-
-        ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite );
-        if( ( mbedtls_ssl_validate_ciphersuite(
-                ssl, ciphersuite_info, ssl->tls_version,
-                ssl->tls_version ) != 0 ) )
-        {
-            continue;
-        }
-
-#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
-        /* MAC of selected ciphersuite MUST be same with PSK binder if exist.
-         * Otherwise, client should reject.
-         */
-        if( psk_alg != MBEDTLS_MD_NONE && psk_alg != ciphersuite_info->mac )
-            continue;
-#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
-
-        ssl->session_negotiate->ciphersuite = cipher_suite;
-        ssl->handshake->ciphersuite_info = ciphersuite_info;
-        MBEDTLS_SSL_DEBUG_MSG( 2, ( "selected ciphersuite: %04x - %s",
-                                    cipher_suite,
-                                    ciphersuite_info->name ) );
-    }
-
-    if( ssl->handshake->ciphersuite_info == NULL )
-    {
-        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE,
-                                      MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-        return ( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-    }
-
     ret = ssl_tls13_determine_key_exchange_mode( ssl );
     if( ret < 0 )
         return( ret );