Simplify usage of DHM blinding
diff --git a/include/polarssl/dhm.h b/include/polarssl/dhm.h
index 0152dc9..75dff19 100644
--- a/include/polarssl/dhm.h
+++ b/include/polarssl/dhm.h
@@ -230,13 +230,11 @@
  *
  * \return         0 if successful, or an POLARSSL_ERR_DHM_XXX error code
  *
- * \note           If f_rng is not NULL, it is used to blind the input as
- *                 countermeasure against timing attacks. This is only useful
- *                 when this function is called repeatedly with the same
- *                 secret value (X field), eg when using DH key exchange as
- *                 opposed to DHE. It is recommended to use a non-NULL f_rng
- *                 only when needed, since otherwise this countermeasure has
- *                 high overhead.
+ * \note           If non-NULL, f_rng is used to blind the input as
+ *                 countermeasure against timing attacks. Blinding is
+ *                 automatically used if and only if our secret value X is
+ *                 re-used and costs nothing otherwise, so it is recommended
+ *                 to always pass a non-NULL f_rng argument.
  */
 int dhm_calc_secret( dhm_context *ctx,
                      unsigned char *output, size_t *olen,
diff --git a/library/dhm.c b/library/dhm.c
index 625837e..dc815d9 100644
--- a/library/dhm.c
+++ b/library/dhm.c
@@ -273,51 +273,55 @@
     int ret, count;
 
     /*
-     * If Vi is initialized, update it by squaring it
+     * Don't use any blinding the first time a particular X is used,
+     * but remember it to use blinding next time.
      */
-    if( ctx->Vi.p != NULL )
+    if( mpi_cmp_mpi( &ctx->X, &ctx->_X ) != 0 )
     {
-        MPI_CHK( mpi_mul_mpi( &ctx->Vi, &ctx->Vi, &ctx->Vi ) );
-        MPI_CHK( mpi_mod_mpi( &ctx->Vi, &ctx->Vi, &ctx->P ) );
-    }
-    else
-    {
-        /* Vi = random( 2, P-1 ) */
-        count = 0;
-        do
-        {
-            mpi_fill_random( &ctx->Vi, mpi_size( &ctx->P ), f_rng, p_rng );
+        MPI_CHK( mpi_copy( &ctx->_X, &ctx->X ) );
+        MPI_CHK( mpi_lset( &ctx->Vi, 1 ) );
+        MPI_CHK( mpi_lset( &ctx->Vf, 1 ) );
 
-            while( mpi_cmp_mpi( &ctx->Vi, &ctx->P ) >= 0 )
-                mpi_shift_r( &ctx->Vi, 1 );
-
-            if( count++ > 10 )
-                return( POLARSSL_ERR_MPI_NOT_ACCEPTABLE );
-        }
-        while( mpi_cmp_int( &ctx->Vi, 1 ) <= 0 );
-    }
-
-    /*
-     * If X did not change, update Vf by squaring it too
-     */
-    if( mpi_cmp_mpi( &ctx->X, &ctx->_X ) == 0 )
-    {
-        MPI_CHK( mpi_mul_mpi( &ctx->Vf, &ctx->Vf, &ctx->Vf ) );
-        MPI_CHK( mpi_mod_mpi( &ctx->Vf, &ctx->Vf, &ctx->P ) );
         return( 0 );
     }
 
     /*
-     * Otherwise, compute Vf from scratch
+     * Ok, we need blinding. Can we re-use existing values?
+     * If yes, just update them by squaring them.
      */
+    if( mpi_cmp_int( &ctx->Vi, 1 ) != 0 )
+    {
+        MPI_CHK( mpi_mul_mpi( &ctx->Vi, &ctx->Vi, &ctx->Vi ) );
+        MPI_CHK( mpi_mod_mpi( &ctx->Vi, &ctx->Vi, &ctx->P ) );
+
+        MPI_CHK( mpi_mul_mpi( &ctx->Vf, &ctx->Vf, &ctx->Vf ) );
+        MPI_CHK( mpi_mod_mpi( &ctx->Vf, &ctx->Vf, &ctx->P ) );
+
+        return( 0 );
+    }
+
+    /*
+     * We need to generate blinding values from scratch
+     */
+
+    /* Vi = random( 2, P-1 ) */
+    count = 0;
+    do
+    {
+        mpi_fill_random( &ctx->Vi, mpi_size( &ctx->P ), f_rng, p_rng );
+
+        while( mpi_cmp_mpi( &ctx->Vi, &ctx->P ) >= 0 )
+            mpi_shift_r( &ctx->Vi, 1 );
+
+        if( count++ > 10 )
+            return( POLARSSL_ERR_MPI_NOT_ACCEPTABLE );
+    }
+    while( mpi_cmp_int( &ctx->Vi, 1 ) <= 0 );
 
     /* Vf = Vi^-X mod P */
     MPI_CHK( mpi_inv_mod( &ctx->Vf, &ctx->Vi, &ctx->P ) );
     MPI_CHK( mpi_exp_mod( &ctx->Vf, &ctx->Vf, &ctx->X, &ctx->P, &ctx->RP ) );
 
-    /* Remember secret associated with Vi and Vf */
-    MPI_CHK( mpi_copy( &ctx->_X, &ctx->X ) );;
-
 cleanup:
     return( ret );
 }
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index 300001e..4cc28c3 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -1702,11 +1702,10 @@
 
         ssl->handshake->pmslen = ssl->handshake->dhm_ctx.len;
 
-        /* No blinding needed for DHE, but will be needed for fixed DH! */
         if( ( ret = dhm_calc_secret( &ssl->handshake->dhm_ctx,
                                       ssl->handshake->premaster,
                                      &ssl->handshake->pmslen,
-                                     NULL, NULL ) ) != 0 )
+                                      ssl->f_rng, ssl->p_rng ) ) != 0 )
         {
             SSL_DEBUG_RET( 1, "dhm_calc_secret", ret );
             return( ret );
@@ -1834,9 +1833,8 @@
 
         *(p++) = (unsigned char)( ssl->handshake->dhm_ctx.len >> 8 );
         *(p++) = (unsigned char)( ssl->handshake->dhm_ctx.len      );
-        /* No blinding needed since this is ephemeral DHM */
         if( ( ret = dhm_calc_secret( &ssl->handshake->dhm_ctx,
-                                      p, &n, NULL, NULL ) ) != 0 )
+                                      p, &n, ssl->f_rng, ssl->p_rng ) ) != 0 )
         {
             SSL_DEBUG_RET( 1, "dhm_calc_secret", ret );
             return( ret );
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 88afc84..0ef3423 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -2373,7 +2373,7 @@
         if( ( ret = dhm_calc_secret( &ssl->handshake->dhm_ctx,
                                       ssl->handshake->premaster,
                                      &ssl->handshake->pmslen,
-                                      NULL, NULL ) ) != 0 )
+                                      ssl->f_rng, ssl->p_rng ) ) != 0 )
         {
             SSL_DEBUG_RET( 1, "dhm_calc_secret", ret );
             return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_KEY_EXCHANGE_CS );
@@ -2460,7 +2460,7 @@
 
         /* No blinding needed since this is ephemeral DHM */
         if( ( ret = dhm_calc_secret( &ssl->handshake->dhm_ctx,
-                                      p, &n, NULL, NULL ) ) != 0 )
+                                      p, &n, ssl->f_rng, ssl->p_rng ) ) != 0 )
         {
             SSL_DEBUG_RET( 1, "dhm_calc_secret", ret );
             return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_KEY_EXCHANGE_CS );
diff --git a/programs/pkey/dh_client.c b/programs/pkey/dh_client.c
index c5c6f75..4e78d6f 100644
--- a/programs/pkey/dh_client.c
+++ b/programs/pkey/dh_client.c
@@ -239,7 +239,8 @@
     fflush( stdout );
 
     n = dhm.len;
-    if( ( ret = dhm_calc_secret( &dhm, buf, &n, NULL, NULL ) ) != 0 )
+    if( ( ret = dhm_calc_secret( &dhm, buf, &n,
+                                 ctr_drbg_random, &ctr_drbg ) ) != 0 )
     {
         printf( " failed\n  ! dhm_calc_secret returned %d\n\n", ret );
         goto exit;
diff --git a/programs/pkey/dh_server.c b/programs/pkey/dh_server.c
index 3382307..a657362 100644
--- a/programs/pkey/dh_server.c
+++ b/programs/pkey/dh_server.c
@@ -242,7 +242,8 @@
     printf( "\n  . Shared secret: " );
     fflush( stdout );
 
-    if( ( ret = dhm_calc_secret( &dhm, buf, &n, NULL, NULL ) ) != 0 )
+    if( ( ret = dhm_calc_secret( &dhm, buf, &n,
+                                 ctr_drbg_random, &ctr_drbg ) ) != 0 )
     {
         printf( " failed\n  ! dhm_calc_secret returned %d\n\n", ret );
         goto exit;
diff --git a/programs/test/benchmark.c b/programs/test/benchmark.c
index 52aecf2..436912a 100644
--- a/programs/test/benchmark.c
+++ b/programs/test/benchmark.c
@@ -558,7 +558,7 @@
     {
         olen = sizeof( buf );
         ret |= dhm_make_public( &dhm, dhm.len, buf, dhm.len, myrand, NULL );
-        ret |= dhm_calc_secret( &dhm, buf, &olen, NULL, NULL );
+        ret |= dhm_calc_secret( &dhm, buf, &olen, myrand, NULL );
     }
 
     if( ret != 0 )
@@ -617,7 +617,7 @@
     for( i = 1; ! alarmed && ! ret ; i++ )
     {
         olen = sizeof( buf );
-        ret |= dhm_calc_secret( &dhm, buf, &olen, NULL, NULL );
+        ret |= dhm_calc_secret( &dhm, buf, &olen, myrand, NULL );
     }
 
     if( ret != 0 )
@@ -643,7 +643,7 @@
     {
         olen = sizeof( buf );
         ret |= dhm_make_public( &dhm, dhm.len, buf, dhm.len, myrand, NULL );
-        ret |= dhm_calc_secret( &dhm, buf, &olen, NULL, NULL );
+        ret |= dhm_calc_secret( &dhm, buf, &olen, myrand, NULL );
     }
 
     if( ret != 0 )
diff --git a/tests/suites/test_suite_dhm.function b/tests/suites/test_suite_dhm.function
index dcf2363..24e7b08 100644
--- a/tests/suites/test_suite_dhm.function
+++ b/tests/suites/test_suite_dhm.function
@@ -22,7 +22,7 @@
     size_t pub_cli_len = 0;
     size_t sec_srv_len = 1000;
     size_t sec_cli_len = 1000;
-    int x_size;
+    int x_size, i;
     rnd_pseudo_info rnd_info;
 
     memset( &ctx_srv, 0x00, sizeof( dhm_context ) );
@@ -59,13 +59,16 @@
     TEST_ASSERT( sec_srv_len != 0 );
     TEST_ASSERT( memcmp( sec_srv, sec_cli, sec_srv_len ) == 0 );
 
-    /* Re-do calc_secret on server to test update of blinding values */
-    sec_srv_len = 1000;
-    TEST_ASSERT( dhm_calc_secret( &ctx_srv, sec_srv, &sec_srv_len, &rnd_pseudo_rand, &rnd_info ) == 0 );
+    /* Re-do calc_secret on server a few times to test update of blinding values */
+    for( i = 0; i < 3; i++ )
+    {
+        sec_srv_len = 1000;
+        TEST_ASSERT( dhm_calc_secret( &ctx_srv, sec_srv, &sec_srv_len, &rnd_pseudo_rand, &rnd_info ) == 0 );
 
-    TEST_ASSERT( sec_srv_len == sec_cli_len );
-    TEST_ASSERT( sec_srv_len != 0 );
-    TEST_ASSERT( memcmp( sec_srv, sec_cli, sec_srv_len ) == 0 );
+        TEST_ASSERT( sec_srv_len == sec_cli_len );
+        TEST_ASSERT( sec_srv_len != 0 );
+        TEST_ASSERT( memcmp( sec_srv, sec_cli, sec_srv_len ) == 0 );
+    }
 
     /*
      * Second key exchange to test change of blinding values on server