Merge pull request #3412 from gilles-peskine-arm/montmul-cmp-branch-2.7

Backport 2.7: Remove a secret-dependent branch in Montgomery multiplication
diff --git a/ChangeLog.d/montmul-cmp-branch.txt b/ChangeLog.d/montmul-cmp-branch.txt
new file mode 100644
index 0000000..5994518
--- /dev/null
+++ b/ChangeLog.d/montmul-cmp-branch.txt
@@ -0,0 +1,6 @@
+Security
+   * Fix a side channel vulnerability in modular exponentiation that could
+     reveal an RSA private key used in a secure enclave. Noticed by Sangho Lee,
+     Ming-Wei Shih, Prasun Gera, Taesoo Kim and Hyesoon Kim (Georgia Institute
+     of Technology); and Marcus Peinado (Microsoft Research). Reported by Raoul
+     Strackx (Fortanix) in #3394.
diff --git a/library/bignum.c b/library/bignum.c
index 827a3cb..68460df 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -227,6 +227,22 @@
 }
 
 /*
+ * Conditionally assign dest = src (assign == 1) or keep dest (assign == 0),
+ * without branching on assign so as not to leak which case was taken.
+ * dest and src must be arrays of limbs of size n.
+ * assign must be 0 or 1; every limb is rewritten in both cases.
+ */
+static void mpi_safe_cond_assign( size_t n,
+                                  mbedtls_mpi_uint *dest,
+                                  const mbedtls_mpi_uint *src,
+                                  unsigned char assign )
+{
+    size_t i;
+    for( i = 0; i < n; i++ )
+        dest[i] = dest[i] * ( 1 - assign ) + src[i] * assign;
+}
+
+/*
  * Conditionally assign X = Y, without leaking information
  * about whether the assignment was made or not.
  * (Leaking information about the respective sizes of X and Y is ok however.)
@@ -243,10 +259,9 @@
 
     X->s = X->s * ( 1 - assign ) + Y->s * assign;
 
-    for( i = 0; i < Y->n; i++ )
-        X->p[i] = X->p[i] * ( 1 - assign ) + Y->p[i] * assign;
+    mpi_safe_cond_assign( Y->n, X->p, Y->p, assign );
 
-    for( ; i < X->n; i++ )
+    for( i = Y->n; i < X->n; i++ )
         X->p[i] *= ( 1 - assign );
 
 cleanup:
@@ -1089,10 +1104,24 @@
     return( ret );
 }
 
-/*
- * Helper for mbedtls_mpi subtraction
+/**
+ * Helper for mbedtls_mpi subtraction.
+ *
+ * Calculate d - s where d and s have the same size.
+ * This function operates modulo (2^biL)^n and returns the carry
+ * (1 if there was a wraparound, i.e. if `d < s`, and 0 otherwise).
+ *
+ * \param n             Number of limbs of \p d and \p s.
+ * \param[in,out] d     On input, the left operand.
+ *                      On output, the result of the subtraction.
+ * \param[in] s         The right operand.
+ *
+ * \return              1 if `d < s`.
+ *                      0 if `d >= s`.
  */
-static void mpi_sub_hlp( size_t n, mbedtls_mpi_uint *s, mbedtls_mpi_uint *d )
+static mbedtls_mpi_uint mpi_sub_hlp( size_t n,
+                                     mbedtls_mpi_uint *d,
+                                     const mbedtls_mpi_uint *s )
 {
     size_t i;
     mbedtls_mpi_uint c, z;
@@ -1103,24 +1132,18 @@
         c = ( *d < *s ) + z; *d -= *s;
     }
 
-    while( c != 0 )
-    {
-        z = ( *d < c ); *d -= c;
-        c = z; i++; d++;
-    }
+    return( c );
 }
 
 /*
- * Unsigned subtraction: X = |A| - |B|  (HAC 14.9)
+ * Unsigned subtraction: X = |A| - |B|  (HAC 14.9, 14.10)
  */
 int mbedtls_mpi_sub_abs( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
 {
     mbedtls_mpi TB;
     int ret;
     size_t n;
-
-    if( mbedtls_mpi_cmp_abs( A, B ) < 0 )
-        return( MBEDTLS_ERR_MPI_NEGATIVE_VALUE );
+    mbedtls_mpi_uint carry;
 
     mbedtls_mpi_init( &TB );
 
@@ -1144,7 +1167,18 @@
         if( B->p[n - 1] != 0 )
             break;
 
-    mpi_sub_hlp( n, B->p, X->p );
+    carry = mpi_sub_hlp( n, X->p, B->p );
+    if( carry != 0 )
+    {
+        /* Propagate the carry to the first nonzero limb of X. */
+        for( ; n < X->n && X->p[n] == 0; n++ )
+            --X->p[n];
+        /* Out of limbs: result is negative. Fall through to cleanup with ret. */
+        if( n == X->n )
+            ret = MBEDTLS_ERR_MPI_NEGATIVE_VALUE;
+        else
+            --X->p[n];
+    }
 
 cleanup:
 
@@ -1696,18 +1730,34 @@
     *mm = ~x + 1;
 }
 
-/*
- * Montgomery multiplication: A = A * B * R^-1 mod N  (HAC 14.36)
+/** Montgomery multiplication: A = A * B * R^-1 mod N  (HAC 14.36)
+ *
+ * \param[in,out]   A   One of the numbers to multiply.
+ *                      It must have at least as many limbs as N
+ *                      (A->n >= N->n), and any limbs beyond N->n are ignored.
+ *                      On successful completion, A contains the result of
+ *                      the multiplication A * B * R^-1 mod N where
+ *                      R = (2^biL)^n, with n = N->n.
+ * \param[in]       B   One of the numbers to multiply.
+ *                      It must be nonzero and must not have more limbs than N
+ *                      (B->n <= N->n).
+ * \param[in]       N   The modulus. N must be odd.
+ * \param           mm  The value calculated by `mpi_montg_init(&mm, N)`.
+ *                      This is -N^-1 mod 2^ciL.
+ * \param[in,out]   T   A bignum for temporary storage.
+ *                      It must have at least twice as many limbs as N, plus 2
+ *                      (T->n >= 2 * (N->n + 1)).
+ *                      Its initial content is unused and
+ *                      its final content is indeterminate.
+ *                      Note that unlike the usual convention in the library
+ *                      for `const mbedtls_mpi*`, the content of T can change.
  */
-static int mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *N, mbedtls_mpi_uint mm,
+static void mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *N, mbedtls_mpi_uint mm,
                          const mbedtls_mpi *T )
 {
     size_t i, n, m;
     mbedtls_mpi_uint u0, u1, *d;
 
-    if( T->n < N->n + 1 || T->p == NULL )
-        return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
-
     memset( T->p, 0, T->n * ciL );
 
     d = T->p;
@@ -1728,21 +1778,33 @@
         *d++ = u0; d[n + 1] = 0;
     }
 
-    memcpy( A->p, d, ( n + 1 ) * ciL );
+    /* At this point, d is either the desired result or the desired result
+     * plus N. We now potentially subtract N, avoiding leaking whether the
+     * subtraction is performed through side channels. */
 
-    if( mbedtls_mpi_cmp_abs( A, N ) >= 0 )
-        mpi_sub_hlp( n, N->p, A->p );
-    else
-        /* prevent timing attacks */
-        mpi_sub_hlp( n, A->p, T->p );
-
-    return( 0 );
+    /* Copy the n least significant limbs of d to A, so that
+     * A = d if d < N (recall that N has n limbs). */
+    memcpy( A->p, d, n * ciL );
+    /* If d >= N then we want to set A to d - N. To prevent timing attacks,
+     * do the calculation without using conditional tests. */
+    /* Set d to d0 + (2^biL)^n - N where d0 is the current value of d. */
+    d[n] += 1;
+    d[n] -= mpi_sub_hlp( n, d, N->p );
+    /* If d0 < N then d < (2^biL)^n
+     * so d[n] == 0 and we want to keep A as it is.
+     * If d0 >= N then d >= (2^biL)^n, and d <= (2^biL)^n + N < 2 * (2^biL)^n
+     * so d[n] == 1 and we want to set A to the result of the subtraction
+     * which is d - (2^biL)^n, i.e. the n least significant limbs of d.
+     * This exactly corresponds to a conditional assignment. */
+    mpi_safe_cond_assign( n, A->p, d, (unsigned char) d[n] );
 }
 
 /*
  * Montgomery reduction: A = A * R^-1 mod N
+ *
+ * See mpi_montmul() regarding constraints and guarantees on the parameters.
  */
-static int mpi_montred( mbedtls_mpi *A, const mbedtls_mpi *N, mbedtls_mpi_uint mm, const mbedtls_mpi *T )
+static void mpi_montred( mbedtls_mpi *A, const mbedtls_mpi *N, mbedtls_mpi_uint mm, const mbedtls_mpi *T )
 {
     mbedtls_mpi_uint z = 1;
     mbedtls_mpi U;
@@ -1750,7 +1812,7 @@
     U.n = U.s = (int) z;
     U.p = &z;
 
-    return( mpi_montmul( A, &U, N, mm, T ) );
+    mpi_montmul( A, &U, N, mm, T );
 }
 
 /*
@@ -1829,13 +1891,13 @@
     else
         MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[1], A ) );
 
-    MBEDTLS_MPI_CHK( mpi_montmul( &W[1], &RR, N, mm, &T ) );
+    mpi_montmul( &W[1], &RR, N, mm, &T );
 
     /*
      * X = R^2 * R^-1 mod N = R mod N
      */
     MBEDTLS_MPI_CHK( mbedtls_mpi_copy( X, &RR ) );
-    MBEDTLS_MPI_CHK( mpi_montred( X, N, mm, &T ) );
+    mpi_montred( X, N, mm, &T );
 
     if( wsize > 1 )
     {
@@ -1848,7 +1910,7 @@
         MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[j], &W[1]    ) );
 
         for( i = 0; i < wsize - 1; i++ )
-            MBEDTLS_MPI_CHK( mpi_montmul( &W[j], &W[j], N, mm, &T ) );
+            mpi_montmul( &W[j], &W[j], N, mm, &T );
 
         /*
          * W[i] = W[i - 1] * W[1]
@@ -1858,7 +1920,7 @@
             MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &W[i], N->n + 1 ) );
             MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[i], &W[i - 1] ) );
 
-            MBEDTLS_MPI_CHK( mpi_montmul( &W[i], &W[1], N, mm, &T ) );
+            mpi_montmul( &W[i], &W[1], N, mm, &T );
         }
     }
 
@@ -1895,7 +1957,7 @@
             /*
              * out of window, square X
              */
-            MBEDTLS_MPI_CHK( mpi_montmul( X, X, N, mm, &T ) );
+            mpi_montmul( X, X, N, mm, &T );
             continue;
         }
 
@@ -1913,12 +1975,12 @@
              * X = X^wsize R^-1 mod N
              */
             for( i = 0; i < wsize; i++ )
-                MBEDTLS_MPI_CHK( mpi_montmul( X, X, N, mm, &T ) );
+                mpi_montmul( X, X, N, mm, &T );
 
             /*
              * X = X * W[wbits] R^-1 mod N
              */
-            MBEDTLS_MPI_CHK( mpi_montmul( X, &W[wbits], N, mm, &T ) );
+            mpi_montmul( X, &W[wbits], N, mm, &T );
 
             state--;
             nbits = 0;
@@ -1931,18 +1993,18 @@
      */
     for( i = 0; i < nbits; i++ )
     {
-        MBEDTLS_MPI_CHK( mpi_montmul( X, X, N, mm, &T ) );
+        mpi_montmul( X, X, N, mm, &T );
 
         wbits <<= 1;
 
         if( ( wbits & ( one << wsize ) ) != 0 )
-            MBEDTLS_MPI_CHK( mpi_montmul( X, &W[1], N, mm, &T ) );
+            mpi_montmul( X, &W[1], N, mm, &T );
     }
 
     /*
      * X = A^E * R * R^-1 mod N = A^E mod N
      */
-    MBEDTLS_MPI_CHK( mpi_montred( X, N, mm, &T ) );
+    mpi_montred( X, N, mm, &T );
 
     if( neg && E->n != 0 && ( E->p[0] & 1 ) != 0 )
     {