Make RSA unblinding constant flow

Signed-off-by: Janos Follath <janos.follath@arm.com>
diff --git a/library/rsa.c b/library/rsa.c
index 84403c4..937d4aa 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -34,6 +34,7 @@
 #include "mbedtls/error.h"
 #include "constant_time_internal.h"
 #include "mbedtls/constant_time.h"
+#include "bignum_internal.h"
 
 #include <string.h>
 
@@ -805,6 +806,47 @@
 }
 
 /*
+ * Unblind: T = T * Vf mod N, using two Montgomery multiplications.
+ * T is updated in place; T and Vf are both passed to mbedtls_mpi_grow().
+ * Returns the status left in ret by MBEDTLS_MPI_CHK (presumably 0 if OK). */
+static int rsa_unblind(mbedtls_mpi *T, mbedtls_mpi *Vf, const mbedtls_mpi *N)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    const size_t nlimbs = N->n;
+    const size_t tlimbs = 2 * (nlimbs + 1); /* limb count requested for M_T */
+
+    mbedtls_mpi_uint mm; /* filled in from N just below, then fed to montmul */
+    mbedtls_mpi_montg_init(&mm, N);
+
+    mbedtls_mpi RR, M_T; /* RR: presumably R^2 mod N; M_T: montmul temporary */
+
+    mbedtls_mpi_init(&RR);
+    mbedtls_mpi_init(&M_T);
+
+    MBEDTLS_MPI_CHK(mbedtls_mpi_get_mont_r2_unsafe(&RR, N));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&M_T, tlimbs));
+    /* Both operands are grown to N's limb count; hence Vf is not const. */
+    MBEDTLS_MPI_CHK(mbedtls_mpi_grow(T, nlimbs));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_grow(Vf, nlimbs));
+
+    /* T = T * Vf mod N
+     * Reminder: montmul(A, B, N) = A * B * R^-1 mod N
+     * Usually both operands are multiplied by R mod N beforehand, yielding a
+     * result that's also * R mod N (aka "in the Montgomery domain"). Here we
+     * only multiply one operand (T) by R mod N, so the second montmul already
+     * yields the plain product - no need to call `mpi_montred()` on it. */
+    mbedtls_mpi_montmul(T, &RR, N, mm, &M_T); /* T = T * RR * R^-1 mod N */
+    mbedtls_mpi_montmul(T, Vf, N, mm, &M_T); /* T = T * Vf * R^-1 mod N */
+
+cleanup:
+    /* Temporaries are released on every path that reaches this label. */
+    mbedtls_mpi_free(&RR);
+    mbedtls_mpi_free(&M_T);
+
+    return ret;
+}
+
+/*
  * Exponent blinding supposed to prevent side-channel attacks using multiple
  * traces of measurements to recover the RSA key. The more collisions are there,
  * the more bits of the key can be recovered. See [3].
@@ -1000,8 +1042,7 @@
          * Unblind
          * T = T * Vf mod N
          */
-        MBEDTLS_MPI_CHK(mbedtls_mpi_mul_mpi(&T, &T, &ctx->Vf));
-        MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(&T, &T, &ctx->N));
+        MBEDTLS_MPI_CHK(rsa_unblind(&T, &ctx->Vf, &ctx->N));
     }
 
     /* Verify the result to prevent glitching attacks. */