Add logic to switch to TLS 1.2

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 211d23a..7c4e6fc 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -38,6 +38,7 @@
 #endif /* !MBEDTLS_PLATFORM_C */
 
 #include "mbedtls/ssl.h"
+#include "ssl_client.h"
 #include "ssl_debug_helpers.h"
 #include "ssl_misc.h"
 #include "mbedtls/debug.h"
@@ -2828,15 +2829,28 @@
         MBEDTLS_SSL_DEBUG_MSG( 2, ( "client state: %s",
                                     mbedtls_ssl_states_str( ssl->state ) ) );
 
-#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
-        if( mbedtls_ssl_conf_is_tls13_only( ssl->conf ) )
-            ret = mbedtls_ssl_tls13_handshake_client_step( ssl );
-#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
+        switch( ssl->state )
+        {
+            case MBEDTLS_SSL_HELLO_REQUEST:
+                ssl->state = MBEDTLS_SSL_CLIENT_HELLO;
+                break;
 
-#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
-        if( mbedtls_ssl_conf_is_tls12_only( ssl->conf ) )
-            ret = mbedtls_ssl_handshake_client_step( ssl );
-#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
+            case MBEDTLS_SSL_CLIENT_HELLO:
+                ret = mbedtls_ssl_write_client_hello( ssl );
+                break;
+
+            default:
+#if defined(MBEDTLS_SSL_PROTO_TLS1_2) && defined(MBEDTLS_SSL_PROTO_TLS1_3)
+                if( ssl->minor_ver == MBEDTLS_SSL_MINOR_VERSION_4 )
+                    ret = mbedtls_ssl_tls13_handshake_client_step( ssl );
+                else
+                    ret = mbedtls_ssl_handshake_client_step( ssl );
+#elif defined(MBEDTLS_SSL_PROTO_TLS1_2)
+                ret = mbedtls_ssl_handshake_client_step( ssl );
+#else
+                ret = mbedtls_ssl_tls13_handshake_client_step( ssl );
+#endif
+        }
     }
 #endif
 #if defined(MBEDTLS_SSL_SRV_C)
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index fc8ffa7..0f7448f 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -710,6 +710,75 @@
 /*
  * Functions for parsing and processing Server Hello
  */
+static int ssl_tls13_is_supported_versions_ext_present(
+    mbedtls_ssl_context *ssl,
+    const unsigned char *buf,
+    const unsigned char *end )
+{
+    const unsigned char *p = buf;
+    size_t legacy_session_id_echo_len;
+    size_t extensions_len;
+    const unsigned char *extensions_end;
+
+    /*
+     * Check there is enough data to access the legacy_session_id_echo vector
+     * length.
+     * - legacy_version,         2 bytes
+     * - random                  MBEDTLS_SERVER_HELLO_RANDOM_LEN bytes
+     * - legacy_session_id_echo  1 byte
+     */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, MBEDTLS_SERVER_HELLO_RANDOM_LEN + 3 );
+    p += MBEDTLS_SERVER_HELLO_RANDOM_LEN + 2;
+    legacy_session_id_echo_len = *p;
+
+    /*
+     * Jump to the extensions, jumping over:
+     * - legacy_session_id_echo     (legacy_session_id_echo_len + 1) bytes
+     * - cipher_suite               2 bytes
+     * - legacy_compression_method  1 byte
+     */
+     p += legacy_session_id_echo_len + 4;
+
+    /* Case of no extension */
+    if( p == end )
+        return( 0 );
+
+    /* ...
+     * Extension extensions<6..2^16-1>;
+     * ...
+     * struct {
+     *      ExtensionType extension_type; (2 bytes)
+     *      opaque extension_data<0..2^16-1>;
+     * } Extension;
+     */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 2 );
+    extensions_len = MBEDTLS_GET_UINT16_BE( p, 0 );
+    p += 2;
+
+    /* Check extensions do not go beyond the buffer of data. */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, extensions_len );
+    extensions_end = p + extensions_len;
+
+    while( p < extensions_end )
+    {
+        unsigned int extension_type;
+        size_t extension_data_len;
+
+        MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, 4 );
+        extension_type = MBEDTLS_GET_UINT16_BE( p, 0 );
+        extension_data_len = MBEDTLS_GET_UINT16_BE( p, 2 );
+        p += 4;
+
+        if( extension_type == MBEDTLS_TLS_EXT_SUPPORTED_VERSIONS )
+            return( 1 );
+
+        MBEDTLS_SSL_CHK_BUF_READ_PTR( p, extensions_end, extension_data_len );
+        p += extension_data_len;
+    }
+
+    return( 0 );
+}
+
 /* Returns a negative value on failure, and otherwise
  * - SSL_SERVER_HELLO_COORDINATE_HELLO or
  * - SSL_SERVER_HELLO_COORDINATE_HRR
@@ -755,8 +824,10 @@
 /* Fetch and preprocess
  * Returns a negative value on failure, and otherwise
  * - SSL_SERVER_HELLO_COORDINATE_HELLO or
- * - SSL_SERVER_HELLO_COORDINATE_HRR
+ * - SSL_SERVER_HELLO_COORDINATE_HRR or
+ * - SSL_SERVER_HELLO_COORDINATE_TLS1_2
  */
+#define SSL_SERVER_HELLO_COORDINATE_TLS1_2 2
 static int ssl_tls13_server_hello_coordinate( mbedtls_ssl_context *ssl,
                                               unsigned char **buf,
                                               size_t *buf_len )
@@ -767,6 +838,36 @@
                                              MBEDTLS_SSL_HS_SERVER_HELLO,
                                              buf, buf_len ) );
 
+    MBEDTLS_SSL_PROC_CHK_NEG( ssl_tls13_is_supported_versions_ext_present(
+                                  ssl, *buf, *buf + *buf_len ) );
+    if( ret == 0 )
+    {
+        /* If the supported versions extension is not present but we were
+         * expecting it, abort the handshake. Otherwise, switch to TLS 1.2
+         * handshake.
+         */
+        if( ssl->handshake->min_minor_ver > MBEDTLS_SSL_MINOR_VERSION_3 )
+        {
+            MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
+                                          MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+            return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+        }
+
+        ssl->keep_current_message = 1;
+        ssl->minor_ver = MBEDTLS_SSL_MINOR_VERSION_3;
+        mbedtls_ssl_add_hs_msg_to_checksum( ssl, MBEDTLS_SSL_HS_SERVER_HELLO,
+                                            *buf, *buf_len );
+
+        if( mbedtls_ssl_conf_tls13_some_ephemeral_enabled( ssl ) )
+        {
+            ret = ssl_tls13_reset_key_share( ssl );
+            if( ret != 0 )
+                return( ret );
+        }
+
+        return( SSL_SERVER_HELLO_COORDINATE_TLS1_2 );
+    }
+
     ret = ssl_server_hello_is_hrr( ssl, *buf, *buf + *buf_len );
     switch( ret )
     {
@@ -888,7 +989,6 @@
     const unsigned char *extensions_end;
     uint16_t cipher_suite;
     const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
-    int supported_versions_ext_found = 0;
     int fatal_alert = 0;
 
     /*
@@ -1068,10 +1168,6 @@
                 break;
 
             case MBEDTLS_TLS_EXT_SUPPORTED_VERSIONS:
-                supported_versions_ext_found = 1;
-                MBEDTLS_SSL_DEBUG_MSG( 3,
-                            ( "found supported_versions extension" ) );
-
                 ret = ssl_tls13_parse_supported_versions_ext( ssl,
                                                               p,
                                                               extension_data_end );
@@ -1122,13 +1218,6 @@
         p += extension_data_len;
     }
 
-    if( !supported_versions_ext_found )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "supported_versions not found" ) );
-        fatal_alert = MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER;
-        goto cleanup;
-    }
-
 cleanup:
 
     if( fatal_alert == MBEDTLS_SSL_ALERT_MSG_UNSUPPORTED_EXT )
@@ -1317,6 +1406,12 @@
     else
         is_hrr = ( ret == SSL_SERVER_HELLO_COORDINATE_HRR );
 
+    if( ret == SSL_SERVER_HELLO_COORDINATE_TLS1_2 )
+    {
+        ret = 0;
+        goto cleanup;
+    }
+
     MBEDTLS_SSL_PROC_CHK( ssl_tls13_parse_server_hello( ssl, buf,
                                                         buf + buf_len,
                                                         is_hrr ) );