Address various issues

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_client.c b/library/ssl_tls13_client.c
index 768caed..15bf43b 100644
--- a/library/ssl_tls13_client.c
+++ b/library/ssl_tls13_client.c
@@ -97,12 +97,12 @@
 
 static int ssl_tls13_parse_supported_versions_ext( mbedtls_ssl_context *ssl,
                                                    const unsigned char *buf,
-                                                   size_t buf_len )
+                                                   const unsigned char *end )
 {
     ((void) ssl);
 
-    if( buf_len != 2 ||
-        buf[0] != MBEDTLS_SSL_MAJOR_VERSION_3 ||
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( buf, end, 2);
+    if( buf[0] != MBEDTLS_SSL_MAJOR_VERSION_3 ||
         buf[1] != MBEDTLS_SSL_MINOR_VERSION_4 )
     {
         MBEDTLS_SSL_DEBUG_MSG( 1, ( "unexpected version" ) );
@@ -497,7 +497,7 @@
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "ECDH curve: %s", curve_info->name ) );
 
     if( mbedtls_ssl_check_curve( ssl, grp_id ) != 0 )
-            return( -1 );
+        return( -1 );
 
     MBEDTLS_SSL_DEBUG_ECDH( 3, &ssl->handshake->ecdh_ctx,
                             MBEDTLS_DEBUG_ECDH_QP );
@@ -505,12 +505,6 @@
     return( 0 );
 }
 
-/* The ssl_tls13_parse_key_share_ext() function is used
- *  by the client to parse a KeyShare extension in
- *  a Server Hello message.
- *
- *  The server only provides a single KeyShareEntry.
- */
 static int ssl_tls13_read_public_ecdhe_share( mbedtls_ssl_context *ssl,
                                               const unsigned char *buf,
                                               size_t buf_len )
@@ -522,12 +516,16 @@
     if( ret != 0 )
     {
         MBEDTLS_SSL_DEBUG_RET( 1, ( "mbedtls_ecdh_tls13_read_public" ), ret );
-        return( ret );
+
+        MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
+                                      MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+        return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
     }
 
     if( ssl_tls13_check_ecdh_params( ssl ) != 0 )
     {
         MBEDTLS_SSL_DEBUG_MSG( 1, ( "ssl_tls13_check_ecdh_params() failed!" ) );
+
         MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
                                       MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
         return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
@@ -538,7 +536,9 @@
 #endif /* MBEDTLS_ECDH_C */
 
 /*
- * Parse key_share extension in Server Hello
+ * ssl_tls13_parse_key_share_ext()
+ *      Parse key_share extension in Server Hello
+ *
  * struct {
  *        KeyShareEntry server_share;
  * } KeyShareServerHello;
@@ -551,7 +551,7 @@
                                           const unsigned char *buf,
                                           const unsigned char *end )
 {
-    int ret = 0;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     const unsigned char *p = buf;
     uint16_t group, offered_group;
 
@@ -583,8 +583,9 @@
         if( ret != 0 )
             return( ret );
     }
+    else
 #endif /* MBEDTLS_ECDH_C */
-    else if( 0 /* other KEMs? */ )
+    if( 0 /* other KEMs? */ )
     {
         /* Do something */
     }
@@ -883,9 +884,18 @@
 /*
  * Functions for parsing and processing Server Hello
  */
-static int ssl_server_hello_is_hrr( unsigned const char *buf, size_t blen )
+/* Check whether the received handshake message is a HelloRetryRequest.
+ * Returns a negative value on failure, and otherwise
+ * - SSL_SERVER_HELLO_COORDINATE_HELLO or
+ * - SSL_SERVER_HELLO_COORDINATE_HRR
+ * to indicate which message is expected and to be parsed next. */
+#define SSL_SERVER_HELLO_COORDINATE_HELLO 0
+#define SSL_SERVER_HELLO_COORDINATE_HRR 1
+static int ssl_server_hello_is_hrr( mbedtls_ssl_context *ssl,
+                                    const unsigned char *buf,
+                                    const unsigned char *end )
 {
-    static const unsigned char magic_hrr_string[32] =
+    static const unsigned char magic_hrr_string[SERVER_HELLO_RANDOM_LEN] =
         { 0xCF, 0x21, 0xAD, 0x74, 0xE5, 0x9A, 0x61, 0x11,
           0xBE, 0x1D, 0x8C, 0x02, 0x1E, 0x65, 0xB8, 0x91,
           0xC2, 0xA2, 0x11, 0x16, 0x7A, 0xBB, 0x8C, 0x5E,
@@ -902,31 +912,23 @@
      *    opaque legacy_session_id_echo<0..32>;
      *    CipherSuite cipher_suite;
      *    uint8 legacy_compression_method = 0;
-     *    Extension extensions<6..2 ^ 16 - 1>;
+     *    Extension extensions<6..2^16-1>;
      * } ServerHello;
      *
      */
-    if( blen < 2 + sizeof( magic_hrr_string ) )
-        return (MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL );
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( buf, end, 2 + sizeof( magic_hrr_string ) );
 
     if( memcmp( buf + 2, magic_hrr_string, sizeof( magic_hrr_string ) ) == 0 )
     {
-        return( 1 );
+        return( SSL_SERVER_HELLO_COORDINATE_HRR );
     }
 
-    return( 0 );
+    return( SSL_SERVER_HELLO_COORDINATE_HELLO );
 }
 
-/* Fetch and preprocess
- * Returns a negative value on failure, and otherwise
- * - SSL_SERVER_HELLO_COORDINATE_HELLO or
- * - SSL_SERVER_HELLO_COORDINATE_HRR
- * to indicate which message is expected and to be parsed next. */
-#define SSL_SERVER_HELLO_COORDINATE_HELLO 0
-#define SSL_SERVER_HELLO_COORDINATE_HRR 1
-static int ssl_server_hello_coordinate( mbedtls_ssl_context *ssl,
-                                        unsigned char **buf,
-                                        size_t *buf_len )
+static int ssl_tls13_server_hello_coordinate( mbedtls_ssl_context *ssl,
+                                              unsigned char **buf,
+                                              size_t *buf_len )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
@@ -945,15 +947,15 @@
     *buf = ssl->in_msg + 4;
     *buf_len = ssl->in_hslen - 4;
 
-    if( ssl_server_hello_is_hrr( *buf, *buf_len ) )
+    ret = ssl_server_hello_is_hrr( ssl, *buf, *buf + *buf_len );
+    switch( ret )
     {
-        MBEDTLS_SSL_DEBUG_MSG( 2, ( "received HelloRetryRequest message" ) );
-        ret = SSL_SERVER_HELLO_COORDINATE_HRR;
-    }
-    else
-    {
+    case SSL_SERVER_HELLO_COORDINATE_HELLO:
         MBEDTLS_SSL_DEBUG_MSG( 2, ( "received ServerHello message" ) );
-        ret = SSL_SERVER_HELLO_COORDINATE_HELLO;
+        break;
+    case SSL_SERVER_HELLO_COORDINATE_HRR:
+        MBEDTLS_SSL_DEBUG_MSG( 2, ( "received HelloRetryRequest message" ) );
+        break;
     }
 
 cleanup:
@@ -977,10 +979,6 @@
     if( ssl->session_negotiate->id_len != legacy_session_id_echo_len ||
         memcmp( ssl->session_negotiate->id, p , legacy_session_id_echo_len ) != 0 )
     {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Mismatch of session id length:"
-            " id_len = %" MBEDTLS_PRINTF_SIZET
-            " , legacy_session_id_echo_len = %" MBEDTLS_PRINTF_SIZET,
-            ssl->session_negotiate->id_len, legacy_session_id_echo_len ) );
         MBEDTLS_SSL_DEBUG_BUF( 3, "Expected Session ID",
                                ssl->session_negotiate->id,
                                ssl->session_negotiate->id_len );
@@ -1025,17 +1023,17 @@
  *    opaque legacy_session_id_echo<0..32>;
  *    CipherSuite cipher_suite;
  *    uint8 legacy_compression_method = 0;
- *    Extension extensions<6..2 ^ 16 - 1>;
+ *    Extension extensions<6..2^16-1>;
  * } ServerHello;
  */
 static int ssl_tls13_parse_server_hello( mbedtls_ssl_context *ssl,
                                          const unsigned char *buf,
                                          const unsigned char *end )
 {
-    int ret;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     const unsigned char *p = buf;
-    size_t extensions_len; /* Length of field */
-    const unsigned char *extensions_end; /* Pointer to end of individual extension */
+    size_t extensions_len;
+    const unsigned char *extensions_end;
     uint16_t cipher_suite;
     const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
 
@@ -1054,7 +1052,7 @@
     MBEDTLS_SSL_DEBUG_BUF( 3, "server hello, version", p, 2 );
 
     /* ...
-     * ProtocaolVersion legacy_version = 0x0303; // TLS 1.2
+     * ProtocolVersion legacy_version = 0x0303; // TLS 1.2
      * ...
      * with ProtocolVersion defined as:
      * uint16 ProtocolVersion;
@@ -1112,9 +1110,8 @@
     if( ciphersuite_info == NULL ||
         ssl_tls13_cipher_suite_is_offered( ssl, cipher_suite ) == 0 )
     {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "ciphersuite info for %04x not found",
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "ciphersuite(%04x) not found or not offered",
                                     cipher_suite ) );
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad server hello message" ) );
 
         MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
                                       MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
@@ -1142,17 +1139,16 @@
     MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, 1 );
     if( p[0] != 0 )
     {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad server hello message" ) );
+        MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad legacy compression method" ) );
         MBEDTLS_SSL_PEND_FATAL_ALERT( MBEDTLS_SSL_ALERT_MSG_ILLEGAL_PARAMETER,
                                       MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
         return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
     }
     p++;
 
-    /*
-     *    ....
-     *    Extension extensions<6..2 ^ 16 - 1>;
-     *    ....
+    /* ...
+     * Extension extensions<6..2^16-1>;
+     * ...
      * struct {
      *      ExtensionType extension_type; (2 bytes)
      *      opaque extension_data<0..2^16-1>;
@@ -1166,9 +1162,6 @@
     MBEDTLS_SSL_CHK_BUF_READ_PTR( p, end, extensions_len );
     extensions_end = p + extensions_len;
 
-    MBEDTLS_SSL_DEBUG_MSG( 3,
-        ( "server hello, total extension length: %" MBEDTLS_PRINTF_SIZET ,
-          extensions_len ) );
     MBEDTLS_SSL_DEBUG_BUF( 3, "server hello extensions", p, extensions_len );
 
     while( p < extensions_end )
@@ -1190,7 +1183,8 @@
                             ( "found supported_versions extension" ) );
 
                 ret = ssl_tls13_parse_supported_versions_ext( ssl,
-                                                        p, extension_data_len );
+                                                              p,
+                                                              p + extension_data_len );
                 if( ret != 0 )
                     return( ret );
                 break;
@@ -1238,45 +1232,39 @@
 
 static int ssl_tls13_finalize_server_hello( mbedtls_ssl_context *ssl )
 {
-    int ret;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     mbedtls_ssl_key_set traffic_keys;
     mbedtls_ssl_transform *transform_handshake = NULL;
+    mbedtls_ssl_handshake_params *handshake = ssl->handshake;
 
-    /* We need to set the key exchange algorithm based on the
-     * following rules:
-     *
-     *   1) IF PRE_SHARED_KEY extension was received
-     *          THEN set KEY_EXCHANGE_MODE_PSK_EPHEMERAL;
-     *   2) IF PRE_SHARED_KEY extension && KEY_SHARE was received
-     *          THEN set KEY_EXCHANGE_MODE_PSK;
-     *   3) IF KEY_SHARES extension was received && SIG_ALG extension received
-     *      THEN set KEY_EXCHANGE_MODE_EPHEMERAL
-     *   ELSE unknown key exchange mechanism.
+    /* Determine the key exchange mode:
+     * 1) If both the pre_shared_key and key_share extensions were received
+     *    then the key exchange mode is PSK with EPHEMERAL.
+     * 2) If only the pre_shared_key extension was received then the key
+     *    exchange mode is PSK-only.
+     * 3) If only the key_share extension was received then the key
+     *    exchange mode is EPHEMERAL-only.
      */
-    if( ssl->handshake->extensions_present & MBEDTLS_SSL_EXT_PRE_SHARED_KEY )
+    switch( handshake->extensions_present &
+            ( MBEDTLS_SSL_EXT_PRE_SHARED_KEY | MBEDTLS_SSL_EXT_KEY_SHARE ) )
     {
-        if( ssl->handshake->extensions_present & MBEDTLS_SSL_EXT_KEY_SHARE )
-        {
-            /* Condition 2) */
-            ssl->handshake->tls1_3_kex_modes =
-                MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK_EPHEMERAL;
-        }
-        else
-        {
-            /* Condition 1) */
-            ssl->handshake->tls1_3_kex_modes =
-                MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK;
-        }
-    }
-    else if( ( ssl->handshake->extensions_present & MBEDTLS_SSL_EXT_KEY_SHARE ) )
-    {
-        /* Condition 3) */
-        ssl->handshake->tls1_3_kex_modes =
-            MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_EPHEMERAL;
-    }
-    else
-    {
-        /* ELSE case */
+    /* Only the pre_shared_key extension was received */
+    case MBEDTLS_SSL_EXT_PRE_SHARED_KEY:
+        handshake->tls1_3_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK;
+        break;
+
+    /* Only the key_share extension was received */
+    case MBEDTLS_SSL_EXT_KEY_SHARE:
+        handshake->tls1_3_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_EPHEMERAL;
+        break;
+
+    /* Both the pre_shared_key and key_share extensions were received */
+    case ( MBEDTLS_SSL_EXT_PRE_SHARED_KEY | MBEDTLS_SSL_EXT_KEY_SHARE ):
+        handshake->tls1_3_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK_EPHEMERAL;
+        break;
+
+    /* Neither pre_shared_key nor key_share extension was received */
+    default:
         MBEDTLS_SSL_DEBUG_MSG( 1, ( "Unknown key exchange." ) );
         ret = MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
         goto cleanup;
@@ -1313,8 +1301,7 @@
         goto cleanup;
     }
 
-    transform_handshake =
-        mbedtls_calloc( 1, sizeof( mbedtls_ssl_transform ) );
+    transform_handshake = mbedtls_calloc( 1, sizeof( mbedtls_ssl_transform ) );
     if( transform_handshake == NULL )
     {
         ret = MBEDTLS_ERR_SSL_ALLOC_FAILED;
@@ -1332,8 +1319,8 @@
         goto cleanup;
     }
 
-    ssl->handshake->transform_handshake = transform_handshake;
-    mbedtls_ssl_set_inbound_transform( ssl, ssl->handshake->transform_handshake );
+    handshake->transform_handshake = transform_handshake;
+    mbedtls_ssl_set_inbound_transform( ssl, transform_handshake );
 
     MBEDTLS_SSL_DEBUG_MSG( 1, ( "Switch to handshake keys for inbound traffic" ) );
     ssl->session_in = ssl->session_negotiate;
@@ -1348,8 +1335,7 @@
     mbedtls_platform_zeroize( &traffic_keys, sizeof( traffic_keys ) );
     if( ret != 0 )
     {
-        if( transform_handshake != NULL )
-            mbedtls_free( transform_handshake );
+        mbedtls_free( transform_handshake );
 
         MBEDTLS_SSL_PEND_FATAL_ALERT(
             MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE,
@@ -1375,11 +1361,10 @@
      * - Make sure it's either a ServerHello or a HRR.
      * - Switch processing routine in case of HRR
      */
-
     ssl->major_ver = MBEDTLS_SSL_MAJOR_VERSION_3;
     ssl->handshake->extensions_present = MBEDTLS_SSL_EXT_NONE;
 
-    ret = ssl_server_hello_coordinate( ssl, &buf, &buf_len );
+    ret = ssl_tls13_server_hello_coordinate( ssl, &buf, &buf_len );
     /* Parsing step
      * We know what message to expect by now and call
      * the respective parsing function.