tls13: srv: Factorize ciphersuite selection code

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index d0e8195..25279b3 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -39,6 +39,62 @@
     return ciphersuite_info;
 }
 
+static void ssl_tls13_select_ciphersuite(
+    mbedtls_ssl_context *ssl,
+    const unsigned char *cipher_suites,
+    const unsigned char *cipher_suites_end,
+    int psk_ciphersuite_id,
+    psa_algorithm_t psk_hash_alg,
+    const mbedtls_ssl_ciphersuite_t **selected_ciphersuite_info)
+{
+    *selected_ciphersuite_info = NULL;
+
+    /*
+     * This function assumes that the length of the list of ciphersuites is
+     * even. The main ClientHello parsing function should already have checked
+     * this; double check here.
+     */
+    if ((cipher_suites_end - cipher_suites) & 1) {
+        return;
+    }
+
+    for (const unsigned char *p = cipher_suites;
+         p < cipher_suites_end; p += 2) {
+        /*
+         * "cipher_suites_end - p is even" is an invariant of the loop. As
+         * cipher_suites_end - p > 0, we have cipher_suites_end - p >= 2 and it
+         * is thus safe to read two bytes.
+         */
+        uint16_t id = MBEDTLS_GET_UINT16_BE(p, 0);
+
+        const mbedtls_ssl_ciphersuite_t *info =
+            ssl_tls13_validate_peer_ciphersuite(ssl, id);
+        if (info == NULL) {
+            continue;
+        }
+
+        /*
+         * If a valid (non-zero) PSK ciphersuite identifier has been passed in,
+         * look for an exact match. Else, if a PSK hash is given, match on it.
+         */
+        if (psk_ciphersuite_id != 0) {
+            if (id != psk_ciphersuite_id) {
+                continue;
+            }
+        } else if (psk_hash_alg != PSA_ALG_NONE) {
+            if (mbedtls_md_psa_alg_from_type((mbedtls_md_type_t) info->mac) !=
+                psk_hash_alg) {
+                continue;
+            }
+        }
+
+        *selected_ciphersuite_info = info;
+        return;
+    }
+
+    MBEDTLS_SSL_DEBUG_MSG(2, ("No matched ciphersuite"));
+}
+
 #if defined(MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_SOME_PSK_ENABLED)
 /* From RFC 8446:
  *
@@ -386,92 +442,8 @@
     return SSL_TLS1_3_BINDER_DOES_NOT_MATCH;
 }
 
-MBEDTLS_CHECK_RETURN_CRITICAL
-static int ssl_tls13_select_ciphersuite_for_psk(
-    mbedtls_ssl_context *ssl,
-    const unsigned char *cipher_suites,
-    const unsigned char *cipher_suites_end,
-    uint16_t *selected_ciphersuite,
-    const mbedtls_ssl_ciphersuite_t **selected_ciphersuite_info)
-{
-    psa_algorithm_t psk_hash_alg = PSA_ALG_SHA_256;
-
-    *selected_ciphersuite = 0;
-    *selected_ciphersuite_info = NULL;
-
-    /* RFC 8446, page 55.
-     *
-     * For externally established PSKs, the Hash algorithm MUST be set when the
-     * PSK is established or default to SHA-256 if no such algorithm is defined.
-     *
-     */
-
-    /*
-     * Search for a matching ciphersuite
-     */
-    for (const unsigned char *p = cipher_suites;
-         p < cipher_suites_end; p += 2) {
-        uint16_t cipher_suite;
-        const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
-
-        cipher_suite = MBEDTLS_GET_UINT16_BE(p, 0);
-        ciphersuite_info = ssl_tls13_validate_peer_ciphersuite(ssl,
-                                                               cipher_suite);
-        if (ciphersuite_info == NULL) {
-            continue;
-        }
-
-        /* MAC of selected ciphersuite MUST be same with PSK binder if exist.
-         * Otherwise, client should reject.
-         */
-        if (psk_hash_alg ==
-            mbedtls_md_psa_alg_from_type((mbedtls_md_type_t) ciphersuite_info->mac)) {
-            *selected_ciphersuite = cipher_suite;
-            *selected_ciphersuite_info = ciphersuite_info;
-            return 0;
-        }
-    }
-    MBEDTLS_SSL_DEBUG_MSG(2, ("No matched ciphersuite"));
-    return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
-}
-
 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
 MBEDTLS_CHECK_RETURN_CRITICAL
-static int ssl_tls13_select_ciphersuite_for_resumption(
-    mbedtls_ssl_context *ssl,
-    const unsigned char *cipher_suites,
-    const unsigned char *cipher_suites_end,
-    mbedtls_ssl_session *session,
-    uint16_t *selected_ciphersuite,
-    const mbedtls_ssl_ciphersuite_t **selected_ciphersuite_info)
-{
-
-    *selected_ciphersuite = 0;
-    *selected_ciphersuite_info = NULL;
-    for (const unsigned char *p = cipher_suites; p < cipher_suites_end; p += 2) {
-        uint16_t cipher_suite = MBEDTLS_GET_UINT16_BE(p, 0);
-        const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
-
-        if (cipher_suite != session->ciphersuite) {
-            continue;
-        }
-
-        ciphersuite_info = ssl_tls13_validate_peer_ciphersuite(ssl,
-                                                               cipher_suite);
-        if (ciphersuite_info == NULL) {
-            continue;
-        }
-
-        *selected_ciphersuite = cipher_suite;
-        *selected_ciphersuite_info = ciphersuite_info;
-
-        return 0;
-    }
-
-    return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
-}
-
-MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_tls13_session_copy_ticket(mbedtls_ssl_session *dst,
                                          const mbedtls_ssl_session *src)
 {
@@ -569,7 +541,8 @@
         const unsigned char *binder;
         size_t binder_len;
         int psk_type;
-        uint16_t cipher_suite;
+        int psk_ciphersuite_id;
+        psa_algorithm_t psk_hash_alg;
         const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
         mbedtls_ssl_session session;
@@ -602,33 +575,39 @@
         }
 
         MBEDTLS_SSL_DEBUG_MSG(4, ("found matched identity"));
+
         switch (psk_type) {
             case MBEDTLS_SSL_TLS1_3_PSK_EXTERNAL:
-                ret = ssl_tls13_select_ciphersuite_for_psk(
-                    ssl, ciphersuites, ciphersuites_end,
-                    &cipher_suite, &ciphersuite_info);
+                psk_ciphersuite_id = 0;
+                psk_hash_alg = PSA_ALG_SHA_256;
                 break;
 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
             case MBEDTLS_SSL_TLS1_3_PSK_RESUMPTION:
-                ret = ssl_tls13_select_ciphersuite_for_resumption(
-                    ssl, ciphersuites, ciphersuites_end, &session,
-                    &cipher_suite, &ciphersuite_info);
-                if (ret != 0) {
-                    mbedtls_ssl_session_free(&session);
-                }
+                psk_ciphersuite_id = session.ciphersuite;
+                psk_hash_alg = PSA_ALG_NONE;
                 break;
 #endif
             default:
                 return MBEDTLS_ERR_SSL_INTERNAL_ERROR;
         }
-        if (ret != 0) {
-            /* See below, no cipher_suite available, abort handshake */
+
+        ssl_tls13_select_ciphersuite(ssl, ciphersuites, ciphersuites_end,
+                                     psk_ciphersuite_id, psk_hash_alg,
+                                     &ciphersuite_info);
+
+        if (ciphersuite_info == NULL) {
+#if defined(MBEDTLS_SSL_SESSION_TICKETS)
+            mbedtls_ssl_session_free(&session);
+#endif
+            /*
+             * We consider finding a ciphersuite suitable for the PSK as part
+             * of the validation of its binder. Thus if we do not find one, we
+             * abort the handshake with a decrypt_error alert.
+             */
             MBEDTLS_SSL_PEND_FATAL_ALERT(
                 MBEDTLS_SSL_ALERT_MSG_DECRYPT_ERROR,
                 MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE);
-            MBEDTLS_SSL_DEBUG_RET(
-                2, "ssl_tls13_select_ciphersuite", ret);
-            return ret;
+            return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
         }
 
         ret = ssl_tls13_offered_psks_check_binder_match(
@@ -654,9 +633,10 @@
 
         /* Update handshake parameters */
         ssl->handshake->ciphersuite_info = ciphersuite_info;
-        ssl->session_negotiate->ciphersuite = cipher_suite;
+        ssl->session_negotiate->ciphersuite = ciphersuite_info->id;
         MBEDTLS_SSL_DEBUG_MSG(2, ("overwrite ciphersuite: %04x - %s",
-                                  cipher_suite, ciphersuite_info->name));
+                                  ((unsigned) ciphersuite_info->id),
+                                  ciphersuite_info->name));
 #if defined(MBEDTLS_SSL_SESSION_TICKETS)
         if (psk_type == MBEDTLS_SSL_TLS1_3_PSK_RESUMPTION) {
             ret = ssl_tls13_session_copy_ticket(ssl->session_negotiate,
@@ -1472,37 +1452,20 @@
      */
     MBEDTLS_SSL_DEBUG_BUF(3, "client hello, list of cipher suites",
                           cipher_suites, cipher_suites_len);
-    for (const unsigned char *cipher_suites_p = cipher_suites;
-         cipher_suites_p < cipher_suites_end; cipher_suites_p += 2) {
-        uint16_t cipher_suite;
-        const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
 
-        /*
-         * "cipher_suites_end - cipher_suites_p is even" is an invariant of the
-         * loop. As cipher_suites_end - cipher_suites_p > 0, we have
-         * cipher_suites_end - cipher_suites_p >= 2 and it is thus safe to read
-         * two bytes.
-         */
-        cipher_suite = MBEDTLS_GET_UINT16_BE(cipher_suites_p, 0);
-        ciphersuite_info = ssl_tls13_validate_peer_ciphersuite(
-            ssl, cipher_suite);
-        if (ciphersuite_info == NULL) {
-            continue;
-        }
-
-        ssl->session_negotiate->ciphersuite = cipher_suite;
-        handshake->ciphersuite_info = ciphersuite_info;
-        MBEDTLS_SSL_DEBUG_MSG(2, ("selected ciphersuite: %04x - %s",
-                                  cipher_suite,
-                                  ciphersuite_info->name));
-        break;
-    }
+    ssl_tls13_select_ciphersuite(ssl, cipher_suites, cipher_suites_end,
+                                 0, PSA_ALG_NONE, &handshake->ciphersuite_info);
 
     if (handshake->ciphersuite_info == NULL) {
         MBEDTLS_SSL_PEND_FATAL_ALERT(MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE,
                                      MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE);
         return MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
     }
+    ssl->session_negotiate->ciphersuite = handshake->ciphersuite_info->id;
+
+    MBEDTLS_SSL_DEBUG_MSG(2, ("selected ciphersuite: %04x - %s",
+                              ((unsigned) handshake->ciphersuite_info->id),
+                              handshake->ciphersuite_info->name));
 
     /* ...
      * opaque legacy_compression_methods<1..2^8-1>;