mpi_exp_mod: load the output variable to the table

This is done in preparation for constant time loading that will be added
in a later commit.

Signed-off-by: Janos Follath <janos.follath@arm.com>
diff --git a/library/bignum.c b/library/bignum.c
index 8717c8a..1c42ef2 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1974,7 +1974,7 @@
     size_t i, j, nblimbs;
     size_t bufsize, nbits;
     mbedtls_mpi_uint ei, mm, state;
-    mbedtls_mpi RR, T, W[ 1 << MBEDTLS_MPI_WINDOW_SIZE ], WW, Apos;
+    mbedtls_mpi RR, T, W[ ( 1 << MBEDTLS_MPI_WINDOW_SIZE ) + 1 ], WW, Apos;
     int neg;
 
     MPI_VALIDATE_RET( X != NULL );
@@ -2022,6 +2022,14 @@
     MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &T, j * 2 ) );
 
     /*
+     * Append the output variable to the end of the  table for constant time
+     * lookup. From this point on we need to use the table entry in each
+     * calculation, this makes it safe to use simple assignment.
+     */
+    const size_t x_index = sizeof( W ) / sizeof( W[0] ) - 1;
+    W[x_index] = *X;
+
+    /*
      * Compensate for negative A (and correct at the end)
      */
     neg = ( A->s == -1 );
@@ -2066,10 +2074,10 @@
     mpi_montmul( &W[1], &RR, N, mm, &T );
 
     /*
-     * X = R^2 * R^-1 mod N = R mod N
+     * W[x_index] = R^2 * R^-1 mod N = R mod N
      */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_copy( X, &RR ) );
-    mpi_montred( X, N, mm, &T );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[x_index], &RR ) );
+    mpi_montred( &W[x_index], N, mm, &T );
 
     if( wsize > 1 )
     {
@@ -2127,9 +2135,9 @@
         if( ei == 0 && state == 1 )
         {
             /*
-             * out of window, square X
+             * out of window, square W[x_index]
              */
-            mpi_montmul( X, X, N, mm, &T );
+            mpi_montmul( &W[x_index], &W[x_index], N, mm, &T );
             continue;
         }
 
@@ -2144,16 +2152,16 @@
         if( nbits == wsize )
         {
             /*
-             * X = X^wsize R^-1 mod N
+             * W[x_index] = W[x_index]^wsize R^-1 mod N
              */
             for( i = 0; i < wsize; i++ )
-                mpi_montmul( X, X, N, mm, &T );
+                mpi_montmul( &W[x_index], &W[x_index], N, mm, &T );
 
             /*
-             * X = X * W[wbits] R^-1 mod N
+             * W[x_index] = W[x_index] * W[wbits] R^-1 mod N
              */
             MBEDTLS_MPI_CHK( mpi_select( &WW, W, (size_t) 1 << wsize, wbits ) );
-            mpi_montmul( X, &WW, N, mm, &T );
+            mpi_montmul( &W[x_index], &WW, N, mm, &T );
 
             state--;
             nbits = 0;
@@ -2166,25 +2174,30 @@
      */
     for( i = 0; i < nbits; i++ )
     {
-        mpi_montmul( X, X, N, mm, &T );
+        mpi_montmul( &W[x_index], &W[x_index], N, mm, &T );
 
         wbits <<= 1;
 
         if( ( wbits & ( one << wsize ) ) != 0 )
-            mpi_montmul( X, &W[1], N, mm, &T );
+            mpi_montmul( &W[x_index], &W[1], N, mm, &T );
     }
 
     /*
-     * X = A^E * R * R^-1 mod N = A^E mod N
+     * W[x_index] = A^E * R * R^-1 mod N = A^E mod N
      */
-    mpi_montred( X, N, mm, &T );
+    mpi_montred( &W[x_index], N, mm, &T );
 
     if( neg && E->n != 0 && ( E->p[0] & 1 ) != 0 )
     {
-        X->s = -1;
-        MBEDTLS_MPI_CHK( mbedtls_mpi_add_mpi( X, N, X ) );
+        W[x_index].s = -1;
+        MBEDTLS_MPI_CHK( mbedtls_mpi_add_mpi( &W[x_index], N, &W[x_index] ) );
     }
 
+    /*
+     * Load the result in the output variable.
+     */
+    *X = W[x_index];
+
 cleanup:
 
     for( i = ( one << ( wsize - 1 ) ); i < ( one << wsize ); i++ )