New function mbedtls_dhm_get_value to copy a field of a DHM context
Reduce the need to break the DHM abstraction by accessing the context directly.
Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/ChangeLog.d/dhm-fields.txt b/ChangeLog.d/dhm-fields.txt
index 620e3dc..4d5c751 100644
--- a/ChangeLog.d/dhm-fields.txt
+++ b/ChangeLog.d/dhm-fields.txt
@@ -1,6 +1,8 @@
 Features
    * The new functions mbedtls_dhm_get_len() and mbedtls_dhm_get_bitlen()
      query the size of the modulus in a Diffie-Hellman context.
+   * The new function mbedtls_dhm_get_value() copy a field out of a
+     Diffie-Hellman context.
 
 API changes
    * Instead of accessing the len field of a DHM context, which is no longer
diff --git a/include/mbedtls/dhm.h b/include/mbedtls/dhm.h
index 3f7206e..6c319f8 100644
--- a/include/mbedtls/dhm.h
+++ b/include/mbedtls/dhm.h
@@ -85,6 +85,17 @@
 #define MBEDTLS_ERR_DHM_FILE_IO_ERROR                     -0x3480  /**< Read or write of file failed. */
 #define MBEDTLS_ERR_DHM_SET_GROUP_FAILED                  -0x3580  /**< Setting the modulus and generator failed. */
 
+/** Which parameter to access in mbedtls_dhm_get_value(). */
+typedef enum
+{
+    MBEDTLS_DHM_PARAM_P,  /*!<  The prime modulus. */
+    MBEDTLS_DHM_PARAM_G,  /*!<  The generator. */
+    MBEDTLS_DHM_PARAM_X,  /*!<  Our secret value. */
+    MBEDTLS_DHM_PARAM_GX, /*!<  Our public key = \c G^X mod \c P. */
+    MBEDTLS_DHM_PARAM_GY, /*!<  The public key of the peer = \c G^Y mod \c P. */
+    MBEDTLS_DHM_PARAM_K,  /*!<  The shared secret = \c G^(XY) mod \c P. */
+} mbedtls_dhm_parameter;
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -302,6 +313,22 @@
 size_t mbedtls_dhm_get_len( const mbedtls_dhm_context *ctx );
 
 /**
+ * \brief          This function copies a parameter of a DHM key.
+ *
+ * \param dest     The MPI object to copy the value into. It must be
+ *                 initialized.
+ * \param ctx      The DHM context to query.
+ * \param param    The parameter to copy.
+ *
+ * \return         \c 0 on success.
+ * \return         #MBEDTLS_ERR_DHM_BAD_INPUT_DATA if \p field is invalid.
+ * \return         An \c MBEDTLS_ERR_MPI_XXX error code if the copy fails.
+ */
+int mbedtls_dhm_get_value( mbedtls_mpi *dest,
+                           const mbedtls_dhm_context *ctx,
+                           mbedtls_dhm_parameter param );
+
+/**
  * \brief          This function frees and clears the components
  *                 of a DHM context.
  *
diff --git a/library/dhm.c b/library/dhm.c
index 2543be1..cb9299f 100644
--- a/library/dhm.c
+++ b/library/dhm.c
@@ -134,6 +134,37 @@
     return( mbedtls_mpi_size( &ctx->P ) );
 }
 
+int mbedtls_dhm_get_value( mbedtls_mpi *dest,
+                           const mbedtls_dhm_context *ctx,
+                           mbedtls_dhm_parameter param )
+{
+    const mbedtls_mpi *src = NULL;
+    switch( param )
+    {
+        case MBEDTLS_DHM_PARAM_P:
+            src = &ctx->P;
+            break;
+        case MBEDTLS_DHM_PARAM_G:
+            src = &ctx->G;
+            break;
+        case MBEDTLS_DHM_PARAM_X:
+            src = &ctx->X;
+            break;
+        case MBEDTLS_DHM_PARAM_GX:
+            src = &ctx->GX;
+            break;
+        case MBEDTLS_DHM_PARAM_GY:
+            src = &ctx->GY;
+            break;
+        case MBEDTLS_DHM_PARAM_K:
+            src = &ctx->K;
+            break;
+        default:
+            return( MBEDTLS_ERR_DHM_BAD_INPUT_DATA );
+    }
+    return( mbedtls_mpi_copy( dest, src ) );
+}
+
 /*
  * Parse the ServerKeyExchange parameters
  */
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 3bdc1cf..bef6864 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -3871,8 +3871,10 @@
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-    if( ( ret = mbedtls_mpi_copy( &conf->dhm_P, &dhm_ctx->P ) ) != 0 ||
-        ( ret = mbedtls_mpi_copy( &conf->dhm_G, &dhm_ctx->G ) ) != 0 )
+    if( ( ret = mbedtls_dhm_get_value( &conf->dhm_P, dhm_ctx,
+                                       MBEDTLS_DHM_PARAM_P ) ) != 0 ||
+        ( ret = mbedtls_dhm_get_value( &conf->dhm_G, dhm_ctx,
+                                       MBEDTLS_DHM_PARAM_G ) ) != 0 )
     {
         mbedtls_mpi_free( &conf->dhm_P );
         mbedtls_mpi_free( &conf->dhm_G );
diff --git a/tests/suites/test_suite_dhm.function b/tests/suites/test_suite_dhm.function
index 7e01eb7..d48c4e3 100644
--- a/tests/suites/test_suite_dhm.function
+++ b/tests/suites/test_suite_dhm.function
@@ -1,6 +1,23 @@
 /* BEGIN_HEADER */
 #include "mbedtls/dhm.h"
 
+int check_get_value( const mbedtls_dhm_context *ctx,
+                     mbedtls_dhm_parameter param,
+                     const mbedtls_mpi *expected )
+{
+    mbedtls_mpi actual;
+    int ok = 0;
+    mbedtls_mpi_init( &actual );
+
+    TEST_ASSERT( mbedtls_dhm_get_value( &actual, ctx, param ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &actual, expected ) == 0 );
+    ok = 1;
+
+exit:
+    mbedtls_mpi_free( &actual );
+    return( ok );
+}
+
 /* Sanity checks on a Diffie-Hellman parameter: check the length-value
  * syntax and check that the value is the expected one (taken from the
  * DHM context by the caller). */
@@ -102,6 +119,8 @@
     TEST_ASSERT( mbedtls_mpi_read_string( &ctx_srv.P, radix_P, input_P ) == 0 );
     TEST_ASSERT( mbedtls_mpi_read_string( &ctx_srv.G, radix_G, input_G ) == 0 );
     pub_cli_len = mbedtls_mpi_size( &ctx_srv.P );
+    TEST_ASSERT( check_get_value( &ctx_srv, MBEDTLS_DHM_PARAM_P, &ctx_srv.P ) );
+    TEST_ASSERT( check_get_value( &ctx_srv, MBEDTLS_DHM_PARAM_G, &ctx_srv.G ) );
 
     /*
      * First key exchange
@@ -118,6 +137,9 @@
     ske[ske_len++] = 0;
     ske[ske_len++] = 0;
     TEST_ASSERT( mbedtls_dhm_read_params( &ctx_cli, &p, ske + ske_len ) == 0 );
+    /* The domain parameters must be the same on both side. */
+    TEST_ASSERT( check_get_value( &ctx_cli, MBEDTLS_DHM_PARAM_P, &ctx_srv.P ) );
+    TEST_ASSERT( check_get_value( &ctx_cli, MBEDTLS_DHM_PARAM_G, &ctx_srv.G ) );
 
     TEST_ASSERT( mbedtls_dhm_make_public( &ctx_cli, x_size, pub_cli, pub_cli_len,
                                           &mbedtls_test_rnd_pseudo_rand,
@@ -134,6 +156,17 @@
     TEST_ASSERT( sec_srv_len != 0 );
     TEST_ASSERT( memcmp( sec_srv, sec_cli, sec_srv_len ) == 0 );
 
+    /* Internal value checks */
+    TEST_ASSERT( check_get_value( &ctx_cli, MBEDTLS_DHM_PARAM_X, &ctx_cli.X ) );
+    TEST_ASSERT( check_get_value( &ctx_srv, MBEDTLS_DHM_PARAM_X, &ctx_srv.X ) );
+    /* Cross-checks */
+    TEST_ASSERT( check_get_value( &ctx_cli, MBEDTLS_DHM_PARAM_GX, &ctx_srv.GY ) );
+    TEST_ASSERT( check_get_value( &ctx_cli, MBEDTLS_DHM_PARAM_GY, &ctx_srv.GX ) );
+    TEST_ASSERT( check_get_value( &ctx_cli, MBEDTLS_DHM_PARAM_K, &ctx_srv.K ) );
+    TEST_ASSERT( check_get_value( &ctx_srv, MBEDTLS_DHM_PARAM_GX, &ctx_cli.GY ) );
+    TEST_ASSERT( check_get_value( &ctx_srv, MBEDTLS_DHM_PARAM_GY, &ctx_cli.GX ) );
+    TEST_ASSERT( check_get_value( &ctx_srv, MBEDTLS_DHM_PARAM_K, &ctx_cli.K ) );
+
     /* Re-do calc_secret on server a few times to test update of blinding values */
     for( i = 0; i < 3; i++ )
     {
@@ -231,8 +264,8 @@
 
     TEST_EQUAL( mbedtls_dhm_get_len( &ctx ), (size_t) len );
     TEST_EQUAL( mbedtls_dhm_get_bitlen( &ctx ), mbedtls_mpi_bitlen( &P ) );
-    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &ctx.P, &P ) == 0 );
-    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &ctx.G, &G ) == 0 );
+    TEST_ASSERT( check_get_value( &ctx, MBEDTLS_DHM_PARAM_P, &P ) );
+    TEST_ASSERT( check_get_value( &ctx, MBEDTLS_DHM_PARAM_G, &G ) );
 
 exit:
     mbedtls_mpi_free( &P ); mbedtls_mpi_free( &G );