Add psk_key_exchange_modes parser

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls13_server.c b/library/ssl_tls13_server.c
index 7d99433..fc5ceeb 100644
--- a/library/ssl_tls13_server.c
+++ b/library/ssl_tls13_server.c
@@ -45,6 +45,53 @@
 #include "ssl_tls13_keys.h"
 #include "ssl_debug_helpers.h"
 
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+/* Parse the psk_key_exchange_modes extension data found between buf and end.
+ * From RFC 8446:
+ *   enum { psk_ke(0), psk_dhe_ke(1), (255) } PskKeyExchangeMode;
+ *   struct { PskKeyExchangeMode ke_modes<1..255>; } PskKeyExchangeModes;
+ * Each listed mode sets its KEY_EXCHANGE_MODE flag; the combined flags are
+ * stored in ssl->handshake->tls13_kex_modes and 0 is returned. Errors: list
+ * longer than 2 -> HANDSHAKE_FAILURE; unknown mode -> ILLEGAL_PARAMETER. */
+static int ssl_tls13_parse_key_exchange_modes_ext( mbedtls_ssl_context *ssl,
+                                                   const unsigned char *buf,
+                                                   const unsigned char *end)
+{
+    size_t ke_modes_len;
+    int ke_modes = 0;
+
+    /* 1-byte ke_modes length; the CHK macro presumably bounds-checks buf */
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( buf, end, 1 );
+    ke_modes_len = *buf++;
+    /* Only two PSK modes exist, so a list of more than 2 items is rejected early.
+     * NOTE(review): an empty list (RFC minimum is 1) seems to be accepted. */
+    if( ke_modes_len > 2 )
+        return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+
+    MBEDTLS_SSL_CHK_BUF_READ_PTR( buf, end, ke_modes_len );
+
+    while( ke_modes_len-- != 0 )
+    {
+        switch( *buf++ )
+        {
+        case MBEDTLS_SSL_TLS1_3_PSK_MODE_PURE:
+            ke_modes |= MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK;
+            MBEDTLS_SSL_DEBUG_MSG( 3, ( "Found PSK KEX MODE" ) );
+            break;
+        case MBEDTLS_SSL_TLS1_3_PSK_MODE_ECDHE:
+            ke_modes |= MBEDTLS_SSL_TLS1_3_KEY_EXCHANGE_MODE_PSK_EPHEMERAL;
+            MBEDTLS_SSL_DEBUG_MSG( 3, ( "Found PSK_EPHEMERAL KEX MODE" ) );
+            break;
+        default: /* NOTE(review): also rejects RFC 8701 GREASE modes */
+            return( MBEDTLS_ERR_SSL_ILLEGAL_PARAMETER );
+        }
+    }
+
+    ssl->handshake->tls13_kex_modes = ke_modes;
+    return( 0 );
+}
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
 /* From RFC 8446:
  *   struct {
  *          ProtocolVersion versions<2..254>;
@@ -754,6 +801,23 @@
                 ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_SUPPORTED_VERSIONS;
                 break;
 
+#if defined(MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED)
+            case MBEDTLS_TLS_EXT_PSK_KEY_EXCHANGE_MODES:
+                MBEDTLS_SSL_DEBUG_MSG( 3, ( "found psk key exchange modes extension" ) );
+
+                ret = ssl_tls13_parse_key_exchange_modes_ext(
+                          ssl, p, extension_data_end );
+                if( ret != 0 )
+                {
+                    MBEDTLS_SSL_DEBUG_RET(
+                        1, "ssl_tls13_parse_key_exchange_modes_ext", ret );
+                    return( ret );
+                }
+
+                ssl->handshake->extensions_present |= MBEDTLS_SSL_EXT_PSK_KEY_EXCHANGE_MODES;
+                break;
+#endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
+
 #if defined(MBEDTLS_SSL_ALPN)
             case MBEDTLS_TLS_EXT_ALPN:
                 MBEDTLS_SSL_DEBUG_MSG( 3, ( "found alpn extension" ) );