Support TLS 1.3 key exchange config in ssl_client2/ssl_server2

Signed-off-by: Hanno Becker <hanno.becker@arm.com>
diff --git a/programs/ssl/ssl_client2.c b/programs/ssl/ssl_client2.c
index 86c314c..f408973 100644
--- a/programs/ssl/ssl_client2.c
+++ b/programs/ssl/ssl_client2.c
@@ -65,6 +65,7 @@
 #define DFL_ECJPAKE_PW          NULL
 #define DFL_EC_MAX_OPS          -1
 #define DFL_FORCE_CIPHER        0
+#define DFL_TLS13_KEX_MODES     MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_ALL
 #define DFL_RENEGOTIATION       MBEDTLS_SSL_RENEGOTIATION_DISABLED
 #define DFL_ALLOW_LEGACY        -2
 #define DFL_RENEGOTIATE         0
@@ -335,6 +336,14 @@
 #define USAGE_SERIALIZATION ""
 #endif
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+#define USAGE_TLS13_KEY_EXCHANGE_MODES \
+    "    tls13_kex_modes=%%s   default: all\n"     \
+    "                          options: psk_pure, psk_ephemeral, ephemeral_pure, ephemeral_all, psk_all, all\n"
+#else
+#define USAGE_TLS13_KEY_EXCHANGE_MODES ""
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
+
 /* USAGE is arbitrarily split to stay under the portable string literal
  * length limit: 4095 bytes in C99. */
 #define USAGE1 \
@@ -403,18 +412,19 @@
 #endif /* !MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
 
 #define USAGE4 \
-    "    allow_sha1=%%d       default: 0\n"                             \
-    "    min_version=%%s      default: (library default: tls1_2)\n"       \
-    "    max_version=%%s      default: (library default: tls1_2)\n"     \
-    "    force_version=%%s    default: \"\" (none)\n"       \
+    "    allow_sha1=%%d       default: 0\n"                                   \
+    "    min_version=%%s      default: (library default: tls1_2)\n"           \
+    "    max_version=%%s      default: (library default: tls1_2)\n"           \
+    "    force_version=%%s    default: \"\" (none)\n"                         \
     "                        options: tls1_2, dtls1_2" TLS1_3_VERSION_OPTIONS \
-    "\n\n"                                                    \
-    "    force_ciphersuite=<name>    default: all enabled\n"\
-    "    query_config=<name>         return 0 if the specified\n"       \
+    "\n\n"                                                                    \
+    "    force_ciphersuite=<name>    default: all enabled\n"                  \
+    USAGE_TLS13_KEY_EXCHANGE_MODES                                            \
+    "    query_config=<name>         return 0 if the specified\n"             \
     "                                configuration macro is defined and 1\n"  \
     "                                otherwise. The expansion of the macro\n" \
-    "                                is printed if it is defined\n"     \
-    USAGE_SERIALIZATION                                     \
+    "                                is printed if it is defined\n"           \
+    USAGE_SERIALIZATION                                                       \
     " acceptable ciphersuite names:\n"
 
 #define ALPN_LIST_SIZE  10
@@ -453,6 +463,9 @@
     const char *ecjpake_pw;     /* the EC J-PAKE password                   */
     int ec_max_ops;             /* EC consecutive operations limit          */
     int force_ciphersuite[2];   /* protocol/ciphersuite to use, or all      */
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+    int tls13_kex_modes;        /* supported TLS 1.3 key exchange modes     */
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
     int renegotiation;          /* enable / disable renegotiation           */
     int allow_legacy;           /* allow legacy renegotiation               */
     int renegotiate;            /* attempt renegotiation?                   */
@@ -814,6 +827,9 @@
     opt.ecjpake_pw          = DFL_ECJPAKE_PW;
     opt.ec_max_ops          = DFL_EC_MAX_OPS;
     opt.force_ciphersuite[0]= DFL_FORCE_CIPHER;
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+    opt.tls13_kex_modes     = DFL_TLS13_KEX_MODES;
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
     opt.renegotiation       = DFL_RENEGOTIATION;
     opt.allow_legacy        = DFL_ALLOW_LEGACY;
     opt.renegotiate         = DFL_RENEGOTIATE;
@@ -1072,6 +1088,24 @@
                 default: goto usage;
             }
         }
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+        else if( strcmp( p, "tls13_kex_modes" ) == 0 )
+        {
+            if( strcmp( q, "psk_pure" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK;
+            else if( strcmp( q, "psk_ephemeral" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK_EPHEMERAL;
+            else if( strcmp( q, "ephemeral_pure" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_EPHEMERAL;
+            else if( strcmp( q, "ephemeral_all" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_EPHEMERAL_ALL;
+            else if( strcmp( q, "psk_all" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK_ALL;
+            else if( strcmp( q, "all" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_ALL;
+            else goto usage; /* unrecognized mode name */
+        }
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
         else if( strcmp( p, "min_version" ) == 0 )
         {
             if( strcmp( q, "tls1_2" ) == 0 ||
@@ -1748,6 +1782,10 @@
     if( opt.force_ciphersuite[0] != DFL_FORCE_CIPHER )
         mbedtls_ssl_conf_ciphersuites( &conf, opt.force_ciphersuite );
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+    mbedtls_ssl_conf_tls13_key_exchange_modes( &conf, opt.tls13_kex_modes );
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
+
     if( opt.allow_legacy != DFL_ALLOW_LEGACY )
         mbedtls_ssl_conf_legacy_renegotiation( &conf, opt.allow_legacy );
 #if defined(MBEDTLS_SSL_RENEGOTIATION)
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 83bd617..25cdb40 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -95,6 +95,7 @@
 #define DFL_ECJPAKE_PW          NULL
 #define DFL_PSK_LIST            NULL
 #define DFL_FORCE_CIPHER        0
+#define DFL_TLS13_KEX_MODES     MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_ALL
 #define DFL_RENEGOTIATION       MBEDTLS_SSL_RENEGOTIATION_DISABLED
 #define DFL_ALLOW_LEGACY        -2
 #define DFL_RENEGOTIATE         0
@@ -564,6 +565,9 @@
     char *psk_list;             /* list of PSK id/key pairs for callback    */
     const char *ecjpake_pw;     /* the EC J-PAKE password                   */
     int force_ciphersuite[2];   /* protocol/ciphersuite to use, or all      */
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+    int tls13_kex_modes;        /* supported TLS 1.3 key exchange modes     */
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
     int renegotiation;          /* enable / disable renegotiation           */
     int allow_legacy;           /* allow legacy renegotiation               */
     int renegotiate;            /* attempt renegotiation?                   */
@@ -1478,6 +1482,9 @@
     opt.psk_list            = DFL_PSK_LIST;
     opt.ecjpake_pw          = DFL_ECJPAKE_PW;
     opt.force_ciphersuite[0]= DFL_FORCE_CIPHER;
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+    opt.tls13_kex_modes     = DFL_TLS13_KEX_MODES;
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
     opt.renegotiation       = DFL_RENEGOTIATION;
     opt.allow_legacy        = DFL_ALLOW_LEGACY;
     opt.renegotiate         = DFL_RENEGOTIATE;
@@ -1714,6 +1721,25 @@
             if( opt.exchanges < 0 )
                 goto usage;
         }
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+        else if( strcmp( p, "tls13_kex_modes" ) == 0 )
+        {
+            if( strcmp( q, "psk_pure" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK;
+            else if( strcmp( q, "psk_ephemeral" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK_EPHEMERAL;
+            else if( strcmp( q, "ephemeral_pure" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_EPHEMERAL;
+            else if( strcmp( q, "ephemeral_all" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_EPHEMERAL_ALL;
+            else if( strcmp( q, "psk_all" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_PSK_ALL;
+            else if( strcmp( q, "all" ) == 0 )
+                opt.tls13_kex_modes = MBEDTLS_SSL_TLS13_KEY_EXCHANGE_MODE_ALL;
+            else goto usage; /* unrecognized mode name */
+        }
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
+
         else if( strcmp( p, "min_version" ) == 0 )
         {
             if( strcmp( q, "tls1_2" ) == 0 ||
@@ -2610,6 +2636,10 @@
     if( opt.force_ciphersuite[0] != DFL_FORCE_CIPHER )
         mbedtls_ssl_conf_ciphersuites( &conf, opt.force_ciphersuite );
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL)
+    mbedtls_ssl_conf_tls13_key_exchange_modes( &conf, opt.tls13_kex_modes );
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3_EXPERIMENTAL */
+
     if( opt.allow_legacy != DFL_ALLOW_LEGACY )
         mbedtls_ssl_conf_legacy_renegotiation( &conf, opt.allow_legacy );
 #if defined(MBEDTLS_SSL_RENEGOTIATION)