Add DHE-PSK based ciphersuite support and cleaner key-exchange-based
code selection

The base DHE-PSK ciphersuites defined in RFC 4279 are now supported.

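The SSL code now compiles out code that is not relevant to the enabled
key exchange methods. It also fixes parsing of the PSK identity hint
(its two-byte length prefix is now consumed and success is returned),
the inverted length check in ssl_parse_signature_algorithm(), and the
PSK client key exchange, which now checks for a configured PSK instead
of a hostname.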
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index ff467f3..377d12b 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -727,26 +727,12 @@
     return( 0 );
 }
 
-#if !defined(POLARSSL_DHM_C) && !defined(POLARSSL_ECDH_C) &&                \
-    !defined(POLARSSL_KEY_EXCHANGE_PSK_ENABLED)
-static int ssl_parse_server_key_exchange( ssl_context *ssl )
-{
-    SSL_DEBUG_MSG( 2, ( "=> parse server key exchange" ) );
-    SSL_DEBUG_MSG( 2, ( "<= skip parse server key exchange" ) );
-    ssl->state++;
-    return( 0 );
-}
-#else
+#if defined(POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED) || defined(POLARSSL_KEY_EXCHANGE_DHE_PSK_ENABLED)
 static int ssl_parse_server_dh_params( ssl_context *ssl, unsigned char **p,
                                        unsigned char *end )
 {
     int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
 
-#if !defined(POLARSSL_DHM_C)
-    ((void) ssl);
-    ((void) p);
-    ((void) end);
-#else
     /*
      * Ephemeral DH parameters:
      *
@@ -772,22 +758,18 @@
     SSL_DEBUG_MPI( 3, "DHM: P ", &ssl->handshake->dhm_ctx.P  );
     SSL_DEBUG_MPI( 3, "DHM: G ", &ssl->handshake->dhm_ctx.G  );
     SSL_DEBUG_MPI( 3, "DHM: GY", &ssl->handshake->dhm_ctx.GY );
-#endif /* POLARSSL_DHM_C */
 
     return( ret );
 }
+#endif /* POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED || POLARSSL_KEY_EXCHANGE_DHE_PSK_ENABLED */
 
+#if defined(POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED)
 static int ssl_parse_server_ecdh_params( ssl_context *ssl,
                                          unsigned char **p,
                                          unsigned char *end )
 {
     int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
 
-#if !defined(POLARSSL_ECDH_C)
-    ((void) ssl);
-    ((void) p);
-    ((void) end);
-#else
     /*
      * Ephemeral ECDH parameters:
      *
@@ -813,22 +795,18 @@
     }
 
     SSL_DEBUG_ECP( 3, "ECDH: Qp", &ssl->handshake->ecdh_ctx.Qp );
-#endif /* POLARSSL_ECDH_C */
 
     return( ret );
 }
+#endif /* POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED */
 
+#if defined(POLARSSL_KEY_EXCHANGE_PSK_ENABLED) ||                           \
+    defined(POLARSSL_KEY_EXCHANGE_DHE_PSK_ENABLED)
 static int ssl_parse_server_psk_hint( ssl_context *ssl,
                                       unsigned char **p,
                                       unsigned char *end )
 {
     int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
-
-#if !defined(POLARSSL_KEY_EXCHANGE_PSK_ENABLED)
-    ((void) ssl);
-    ((void) p);
-    ((void) end);
-#else
     size_t  len;
 
     /*
@@ -837,6 +815,7 @@
      * opaque psk_identity_hint<0..2^16-1>;
      */
-    len = (*p)[1] << 8 | (*p)[0];
+    len = ( (*p)[0] << 8 ) | (*p)[1]; /* uint16 length, network byte order */
+    *p += 2;
 
     if( (*p) + len > end )
     {
@@ -847,12 +826,15 @@
     // TODO: Retrieve PSK identity hint and callback to app
     //
     *p += len;
-#endif /* POLARSSL_KEY_EXCHANGE_PSK_ENABLED */
+    ret = 0;
 
     return( ret );
 }
+#endif /* POLARSSL_KEY_EXCHANGE_PSK_ENABLED ||
+          POLARSSL_KEY_EXCHANGE_DHE_PSK_ENABLED */
 
-#if defined(POLARSSL_RSA_C)
+#if defined(POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED) ||                       \
+    defined(POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED)
 static int ssl_parse_signature_algorithm( ssl_context *ssl,
                                           unsigned char **p,
                                           unsigned char *end,
@@ -860,7 +842,7 @@
 {
     *md_alg = POLARSSL_MD_NONE;
 
-    if( (*p) + 2 < end )
+    if( (*p) + 2 > end )
         return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
 
     if( (*p)[1] != SSL_SIG_RSA )
@@ -908,26 +890,28 @@
 
     return( 0 );
 }
-#endif /* POLARSSL_RSA_C */
+#endif /* POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED ||
+          POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED */
 
 static int ssl_parse_server_key_exchange( ssl_context *ssl )
 {
     int ret;
+    const ssl_ciphersuite_t *ciphersuite_info = ssl->transform_negotiate->ciphersuite_info;
     unsigned char *p, *end;
-#if defined(POLARSSL_RSA_C)
+#if defined(POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED) ||                       \
+    defined(POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED)
     size_t n;
     unsigned char hash[64];
     md_type_t md_alg = POLARSSL_MD_NONE;
     unsigned int hashlen = 0;
 #endif
-
-    const ssl_ciphersuite_t *ciphersuite_info = ssl->transform_negotiate->ciphersuite_info;
 
     SSL_DEBUG_MSG( 2, ( "=> parse server key exchange" ) );
 
     if( ciphersuite_info->key_exchange != POLARSSL_KEY_EXCHANGE_DHE_RSA &&
         ciphersuite_info->key_exchange != POLARSSL_KEY_EXCHANGE_ECDHE_RSA &&
-        ciphersuite_info->key_exchange != POLARSSL_KEY_EXCHANGE_PSK )
+        ciphersuite_info->key_exchange != POLARSSL_KEY_EXCHANGE_PSK &&
+        ciphersuite_info->key_exchange != POLARSSL_KEY_EXCHANGE_DHE_PSK )
     {
         SSL_DEBUG_MSG( 2, ( "<= skip parse server key exchange" ) );
         ssl->state++;
@@ -963,6 +947,7 @@
     p   = ssl->in_msg + 4;
     end = ssl->in_msg + ssl->in_hslen;
 
+#if defined(POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED)
     if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_DHE_RSA )
     {
         if( ssl_parse_server_dh_params( ssl, &p, end ) != 0 )
@@ -971,7 +956,10 @@
             return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
         }
     }
-    else if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_ECDHE_RSA )
+    else
+#endif /* POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED */
+#if defined(POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED)
+    if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_ECDHE_RSA )
     {
         if( ssl_parse_server_ecdh_params( ssl, &p, end ) != 0 )
         {
@@ -979,7 +967,10 @@
             return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
         }
     }
-    else if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_PSK )
+    else
+#endif /* POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED */
+#if defined(POLARSSL_KEY_EXCHANGE_PSK_ENABLED)
+    if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_PSK )
     {
         if( ssl_parse_server_psk_hint( ssl, &p, end ) != 0 )
         {
@@ -987,8 +978,30 @@
             return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
         }
     }
+    else
+#endif /* POLARSSL_KEY_EXCHANGE_PSK_ENABLED */
+#if defined(POLARSSL_KEY_EXCHANGE_DHE_PSK_ENABLED)
+    if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_DHE_PSK )
+    {   /* psk_identity_hint first, then the ServerDHParams */
+        if( ssl_parse_server_psk_hint( ssl, &p, end ) != 0 )
+        {
+            SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) );
+            return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+        }
+        if( ssl_parse_server_dh_params( ssl, &p, end ) != 0 )
+        {
+            SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) );
+            return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+        }
+    }
+    else
+#endif /* POLARSSL_KEY_EXCHANGE_DHE_PSK_ENABLED */
+    {
+        return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE );
+    }
 
-#if defined(POLARSSL_RSA_C)
+#if defined(POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED) ||                       \
+    defined(POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED)
     if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_DHE_RSA ||
         ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_ECDHE_RSA )
     {
@@ -1088,7 +1101,8 @@
             return( ret );
         }
     }
-#endif /* POLARSSL_RSA_C */
+#endif /* POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED ||
+          POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED */
 
 exit:
     ssl->state++;
@@ -1097,7 +1111,6 @@
 
     return( 0 );
 }
-#endif /* POLARSSL_DHM_C || POLARSSL_ECDH_C */
 
 static int ssl_parse_certificate_request( ssl_context *ssl )
 {
@@ -1262,7 +1275,7 @@
 
     SSL_DEBUG_MSG( 2, ( "=> write client key exchange" ) );
 
-#if defined(POLARSSL_DHM_C)
+#if defined(POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED)
     if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_DHE_RSA )
     {
         /*
@@ -1300,8 +1313,8 @@
         SSL_DEBUG_MPI( 3, "DHM: K ", &ssl->handshake->dhm_ctx.K  );
     }
     else
-#endif /* POLARSSL_DHM_C */
-#if defined(POLARSSL_ECDH_C)
+#endif /* POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED */
+#if defined(POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED)
     if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_ECDHE_RSA )
     {
         /*
@@ -1333,7 +1346,7 @@
         SSL_DEBUG_MPI( 3, "ECDH: z", &ssl->handshake->ecdh_ctx.z );
     }
     else
-#endif /* POLARSSL_ECDH_C */
+#endif /* POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED */
 #if defined(POLARSSL_KEY_EXCHANGE_PSK_ENABLED)
     if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_PSK )
     {
@@ -1344,7 +1357,7 @@
          *
          * opaque psk_identity<0..2^16-1>;
          */
-        if( ssl->hostname == NULL )
+        if( ssl->psk == NULL )
             return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
 
         if( sizeof(ssl->handshake->premaster) < 4 + 2 * ssl->psk_len )
@@ -1371,7 +1384,78 @@
     }
     else
 #endif /* POLARSSL_KEY_EXCHANGE_PSK_ENABLED */
-#if defined(POLARSSL_X509_PARSE_C)
+#if defined(POLARSSL_KEY_EXCHANGE_DHE_PSK_ENABLED)
+    if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_DHE_PSK )
+    {
+        unsigned char *p = ssl->handshake->premaster;
+
+        /*
+         * DHE_PSK key exchange
+         *
+         * opaque psk_identity<0..2^16-1>;
+         * ClientDiffieHellmanPublic public (DHM send G^X mod P)
+         */
+        if( ssl->psk == NULL )
+            return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
+        /* premaster = uint16 len || DH secret || uint16 len || PSK */
+        if( sizeof(ssl->handshake->premaster) < 4 + ssl->handshake->dhm_ctx.len +
+                                                ssl->psk_len )
+            return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
+
+        i = 4;
+        n = ssl->psk_identity_len;
+        ssl->out_msg[4] = (unsigned char)( n >> 8 );
+        ssl->out_msg[5] = (unsigned char)( n      );
+
+        memcpy( ssl->out_msg + 6, ssl->psk_identity, ssl->psk_identity_len );
+
+        n = ssl->handshake->dhm_ctx.len;
+        ssl->out_msg[6 + ssl->psk_identity_len] = (unsigned char)( n >> 8 );
+        ssl->out_msg[7 + ssl->psk_identity_len] = (unsigned char)( n      );
+
+        ret = dhm_make_public( &ssl->handshake->dhm_ctx,
+                                mpi_size( &ssl->handshake->dhm_ctx.P ),
+                               &ssl->out_msg[8 + ssl->psk_identity_len], n,
+                                ssl->f_rng, ssl->p_rng );
+        if( ret != 0 )
+        {
+            SSL_DEBUG_RET( 1, "dhm_make_public", ret );
+            return( ret );
+        }
+
+        SSL_DEBUG_MPI( 3, "DHM: X ", &ssl->handshake->dhm_ctx.X  );
+        SSL_DEBUG_MPI( 3, "DHM: GX", &ssl->handshake->dhm_ctx.GX );
+
+        *(p++) = (unsigned char)( ssl->handshake->dhm_ctx.len >> 8 );
+        *(p++) = (unsigned char)( ssl->handshake->dhm_ctx.len      );
+        if( ( ret = dhm_calc_secret( &ssl->handshake->dhm_ctx,
+                                      p, &n ) ) != 0 )
+        {
+            SSL_DEBUG_RET( 1, "dhm_calc_secret", ret );
+            return( ret );
+        }
+
+        if( n != ssl->handshake->dhm_ctx.len )
+        {
+            SSL_DEBUG_MSG( 1, ( "dhm_calc_secret result smaller than DHM" ) );
+            return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
+        }
+
+        SSL_DEBUG_MPI( 3, "DHM: K ", &ssl->handshake->dhm_ctx.K  );
+
+        p += ssl->handshake->dhm_ctx.len;
+
+        *(p++) = (unsigned char)( ssl->psk_len >> 8 );
+        *(p++) = (unsigned char)( ssl->psk_len      );
+        memcpy( p, ssl->psk, ssl->psk_len );
+        p += ssl->psk_len;
+        /* n = bytes written from out_msg[i]: 2 + identity + 2 + DH public value */
+        n = 4 + ssl->psk_identity_len + ssl->handshake->dhm_ctx.len;
+        ssl->handshake->pmslen = 4 + ssl->handshake->dhm_ctx.len + ssl->psk_len;
+    }
+    else
+#endif /* POLARSSL_KEY_EXCHANGE_DHE_PSK_ENABLED */
+#if defined(POLARSSL_KEY_EXCHANGE_RSA_ENABLED)
     if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_RSA )
     {
         /*
@@ -1409,7 +1493,7 @@
         }
     }
     else
-#endif /* POLARSSL_X509_PARSE_C */
+#endif /* POLARSSL_KEY_EXCHANGE_RSA_ENABLED */
     {
         ((void) ciphersuite_info);
         return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE );
@@ -1438,27 +1522,46 @@
     return( 0 );
 }
 
+#if !defined(POLARSSL_KEY_EXCHANGE_RSA_ENABLED)       && \
+    !defined(POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED)   && \
+    !defined(POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED)
 static int ssl_write_certificate_verify( ssl_context *ssl )
 {
     int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
     const ssl_ciphersuite_t *ciphersuite_info = ssl->transform_negotiate->ciphersuite_info;
-#if defined(POLARSSL_X509_PARSE_C)
-    size_t n = 0, offset = 0;
-    unsigned char hash[48];
-    md_type_t md_alg = POLARSSL_MD_NONE;
-    unsigned int hashlen = 0;
-#endif
 
     SSL_DEBUG_MSG( 2, ( "=> write certificate verify" ) );
 
-    if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_PSK )
+    if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_PSK ||
+        ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_DHE_PSK )
     {
         SSL_DEBUG_MSG( 2, ( "<= skip write certificate verify" ) );
         ssl->state++;
         return( 0 );
     }
 
-#if defined(POLARSSL_X509_PARSE_C)
+    return( ret ); /* RSA-based key exchanges are compiled out */
+}
+#else /* RSA, DHE-RSA or ECDHE-RSA enabled: full CertificateVerify support */
+static int ssl_write_certificate_verify( ssl_context *ssl )
+{
+    int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
+    const ssl_ciphersuite_t *ciphersuite_info = ssl->transform_negotiate->ciphersuite_info;
+    size_t n = 0, offset = 0;
+    unsigned char hash[48];
+    md_type_t md_alg = POLARSSL_MD_NONE;
+    unsigned int hashlen = 0;
+
+    SSL_DEBUG_MSG( 2, ( "=> write certificate verify" ) );
+
+    if( ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_PSK ||
+        ciphersuite_info->key_exchange == POLARSSL_KEY_EXCHANGE_DHE_PSK )
+    {
+        SSL_DEBUG_MSG( 2, ( "<= skip write certificate verify" ) );
+        ssl->state++;
+        return( 0 );
+    }
+
     if( ssl->client_auth == 0 || ssl->own_cert == NULL )
     {
         SSL_DEBUG_MSG( 2, ( "<= skip write certificate verify" ) );
@@ -1558,12 +1661,14 @@
         SSL_DEBUG_RET( 1, "ssl_write_record", ret );
         return( ret );
     }
-#endif /* POLARSSL_X509_PARSE_C */
 
     SSL_DEBUG_MSG( 2, ( "<= write certificate verify" ) );
 
     return( ret );
 }
+#endif /* !POLARSSL_KEY_EXCHANGE_RSA_ENABLED &&
+          !POLARSSL_KEY_EXCHANGE_DHE_RSA_ENABLED &&
+          !POLARSSL_KEY_EXCHANGE_ECDHE_RSA_ENABLED */
 
 /*
  * SSL handshake -- client side -- single step