Align ciphersuite with psk key

With OpenSSL and GnuTLS clients, if the MAC of the selected ciphersuite
does not match that of the selected binder, the client rejects the
connection. This change selects the ciphersuite based on the algorithm
of the PSK binder.

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 36a8119..91e6f4e 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -136,7 +136,8 @@
 MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_tls13_offered_psks_check_binder_match( mbedtls_ssl_context *ssl,
                                                       const unsigned char *binder,
-                                                      size_t binder_len )
+                                                      size_t binder_len,
+                                                      mbedtls_md_type_t *psk_alg )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     int psk_type;
@@ -149,6 +150,7 @@
     size_t psk_len;
     unsigned char server_computed_binder[PSA_HASH_MAX_SIZE];
 
+    *psk_alg = MBEDTLS_MD_NONE;
     psk_type = MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL;
     switch( binder_len )
     {
@@ -192,6 +194,7 @@
 
     if( mbedtls_ct_memcmp( server_computed_binder, binder, binder_len ) == 0 )
     {
+        *psk_alg = md_alg;
         return( SSL_TLS1_3_OFFERED_PSK_MATCH );
     }
 
@@ -223,7 +226,8 @@
 MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_tls13_parse_pre_shared_key_ext( mbedtls_ssl_context *ssl,
                                                const unsigned char *buf,
-                                               const unsigned char *end )
+                                               const unsigned char *end,
+                                               mbedtls_md_type_t *psk_alg )
 {
     const unsigned char *identities = buf;
     const unsigned char *p_identity_len;
@@ -236,6 +240,8 @@
     int matched_identity = -1;
     int identity_id = -1;
 
+    *psk_alg = MBEDTLS_MD_NONE;
+
     MBEDTLS_SSL_DEBUG_BUF( 3, "pre_shared_key extension", buf, end - buf );
 
     /* identities_len       2 bytes
@@ -266,6 +272,7 @@
         const unsigned char *binder;
         size_t binder_len;
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+        mbedtls_md_type_t alg;
 
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p_identity_len, identities_end, 2 + 1 + 4 );
         identity_len = MBEDTLS_GET_UINT16_BE( p_identity_len, 0 );
@@ -286,11 +293,11 @@
 
         ret = ssl_tls13_offered_psks_check_identity_match(
                                             ssl, identity, identity_len );
-        if( SSL_TLS1_3_OFFERED_PSK_NOT_MATCH == ret )
+        if( ret != SSL_TLS1_3_OFFERED_PSK_MATCH )
             continue;
 
         ret = ssl_tls13_offered_psks_check_binder_match(
-                                            ssl, binder, binder_len );
+                                            ssl, binder, binder_len, &alg );
         if( ret != SSL_TLS1_3_OFFERED_PSK_MATCH )
         {
             MBEDTLS_SSL_DEBUG_RET( 1,
@@ -300,10 +307,9 @@
                 MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
             return( ret );
         }
-        if( SSL_TLS1_3_OFFERED_PSK_NOT_MATCH == ret )
-            continue;
 
         matched_identity = identity_id;
+        *psk_alg = alg;
     }
 
     if( p_identity_len != identities_end || p_binder_len != binders_end )
@@ -914,10 +920,10 @@
     const unsigned char *extensions_end;
     int hrr_required = 0;
 
-    const mbedtls_ssl_ciphersuite_t* ciphersuite_info;
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
     const unsigned char *pre_shared_key_ext_start = NULL;
     const unsigned char *pre_shared_key_ext_end = NULL;
+    mbedtls_md_type_t psk_alg = MBEDTLS_MD_NONE;
 #endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
 
     ssl->handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
@@ -1000,7 +1006,7 @@
                            p, legacy_session_id_len );
     /*
      * Check we have enough data for the legacy session identifier
-     * and the ciphersuite list  length.
+     * and the ciphersuite list length.
      */
     MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, legacy_session_id_len + 2 );
 
@@ -1012,59 +1018,42 @@
 
     /* Check we have enough data for the ciphersuite list, the legacy
      * compression methods and the length of the extensions.
+     *
+     * cipher_suites                cipher_suites_len bytes
+     * legacy_compression_methods                   2 bytes
+     * extensions_len                               2 bytes
      */
     MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, cipher_suites_len + 2 + 2 );
 
-   /* ...
-    * CipherSuite cipher_suites<2..2^16-2>;
-    * ...
-    * with CipherSuite defined as:
-    * uint8 CipherSuite[2];
+   /*
+    * uint8 CipherSuite[2];    // Cryptographic suite selector
+    *
+    * struct {
+    *     ...
+    *     CipherSuite cipher_suites<2..2^16-2>;
+    *     ...
+    * } ClientHello;
     */
     cipher_suites = p;
     cipher_suites_end = p + cipher_suites_len;
     MBEDTLS_SSL_DEBUG_BUF( 3, "client hello, ciphersuitelist",
                           p, cipher_suites_len );
-    /*
-     * Search for a matching ciphersuite
-     */
-    int ciphersuite_match = 0;
+#if defined(MBEDTLS_DEBUG_C)
     for ( ; p < cipher_suites_end; p += 2 )
     {
         uint16_t cipher_suite;
+        const mbedtls_ssl_ciphersuite_t* ciphersuite_info;
         MBEDTLS_SSL_CHK_BUF_READ_PTR( p, cipher_suites_end, 2 );
         cipher_suite = MBEDTLS_GET_UINT16_BE( p, 0 );
         ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite );
-        /*
-         * Check whether this ciphersuite is valid and offered.
-         */
-        if( ( mbedtls_ssl_validate_ciphersuite(
-                ssl, ciphersuite_info, ssl->tls_version,
-                ssl->tls_version ) != 0 ) ||
-            ! mbedtls_ssl_tls13_cipher_suite_is_offered( ssl, cipher_suite ) )
-        {
-            continue;
-        }
-
-        ssl->session_negotiate->ciphersuite = cipher_suite;
-        ssl->handshake->ciphersuite_info = ciphersuite_info;
-        ciphersuite_match = 1;
-
-        break;
-
+        MBEDTLS_SSL_DEBUG_MSG( 2, ( "client hello, received ciphersuite: %04x - %s",
+                                    cipher_suite,
+                                    ciphersuite_info == NULL ?
+                                        "Unkown": ciphersuite_info->name ) );
     }
-
-    if( ! ciphersuite_match )
-    {
-        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE,
-                                      MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-        return ( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
-    }
-
-    MBEDTLS_SSL_DEBUG_MSG( 2, ( "selected ciphersuite: %s",
-                                ciphersuite_info->name ) );
-
-    p = cipher_suites + cipher_suites_len;
+#else
+    p = cipher_suites_end;
+#endif /* MBEDTLS_DEBUG_C */
 
     /* ...
      * opaque legacy_compression_methods<1..2^8-1>;
@@ -1298,6 +1287,7 @@
                                         MBEDTLS_SSL_HS_CLIENT_HELLO,
                                         p - buf );
 
+/* TODO: move later */
 #if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
     /* Update checksum with either
      * - The entire content of the CH message, if no PSK extension is present
@@ -1311,7 +1301,8 @@
                                          pre_shared_key_ext_start - buf );
         ret = ssl_tls13_parse_pre_shared_key_ext( ssl,
                                                   pre_shared_key_ext_start,
-                                                  pre_shared_key_ext_end );
+                                                  pre_shared_key_ext_end,
+                                                  &psk_alg );
         if( ret == MBEDTLS_ERR_SSL_UNKNOWN_IDENTITY)
         {
             ssl->handshake->extensions_present &= ~MBEDTLS_SSL_EXT_PRE_SHARED_KEY;
@@ -1329,6 +1320,57 @@
         ssl->handshake->update_checksum( ssl, buf, p - buf );
     }
 
+    /*
+     * Search for a matching ciphersuite: walk the client's list in order
+     * and stop at the first entry that passes every check below.
+     */
+    for ( const unsigned char * p_cipher_suite = cipher_suites ;
+            p_cipher_suite < cipher_suites_end; p_cipher_suite += 2 )
+    {
+        uint16_t cipher_suite;
+        const mbedtls_ssl_ciphersuite_t* ciphersuite_info;
+
+        MBEDTLS_SSL_CHK_BUF_READ_PTR( p_cipher_suite, cipher_suites_end, 2 );
+
+        cipher_suite = MBEDTLS_GET_UINT16_BE( p_cipher_suite, 0 );
+        if( ! mbedtls_ssl_tls13_cipher_suite_is_offered( ssl, cipher_suite ) )
+            continue;
+
+        ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( cipher_suite );
+        if( ( mbedtls_ssl_validate_ciphersuite(
+                ssl, ciphersuite_info, ssl->tls_version,
+                ssl->tls_version ) != 0 ) )
+        {
+            continue;
+        }
+
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+        /* When a PSK binder matched, the MAC of the selected ciphersuite MUST
+         * be the same as the algorithm of that binder; otherwise the client
+         * is expected to reject the connection.
+         */
+        if( psk_alg != MBEDTLS_MD_NONE && psk_alg != ciphersuite_info->mac )
+            continue;
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
+        ssl->session_negotiate->ciphersuite = cipher_suite;
+        ssl->handshake->ciphersuite_info = ciphersuite_info;
+        MBEDTLS_SSL_DEBUG_MSG( 2, ( "selected ciphersuite: %04x - %s",
+                                    cipher_suite,
+                                    ciphersuite_info->name ) );
+
+        /* Keep the first acceptable entry; without this a later match would
+         * overwrite the selection made here. */
+        break;
+    }
+
+    if( ssl->handshake->ciphersuite_info == NULL )
+    {
+        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE,
+                                      MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+        return ( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+    }
+
     ret = ssl_tls13_determine_key_exchange_mode( ssl );
     if( ret < 0 )
         return( ret );