bignum: use CT gcd for mbedtls_mpi_gcd()

mbedtls_mpi_gcd() as a whole is still not constant-time, but its core GCD
computation now is, so the function leaks a lot less than before.

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/bignum.c b/library/bignum.c
index 432ecb9..7d5103e 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1819,103 +1819,47 @@
 }
 
 /*
- * Greatest common divisor: G = gcd(A, B)  (HAC 14.54)
+ * Greatest common divisor: G = gcd(A, B)
+ * Wrapper around mbedtls_mpi_gcd_modinv_odd() that removes its restrictions.
  */
 int mbedtls_mpi_gcd(mbedtls_mpi *G, const mbedtls_mpi *A, const mbedtls_mpi *B)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    size_t lz, lzt;
     mbedtls_mpi TA, TB;
 
     mbedtls_mpi_init(&TA); mbedtls_mpi_init(&TB);
 
+    /* Make copies and take absolute values */
     MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&TA, A));
     MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&TB, B));
-
-    lz = mbedtls_mpi_lsb(&TA);
-    lzt = mbedtls_mpi_lsb(&TB);
-
-    /* The loop below gives the correct result when A==0 but not when B==0.
-     * So have a special case for B==0. Leverage the fact that we just
-     * calculated the lsb and lsb(B)==0 iff B is odd or 0 to make the test
-     * slightly more efficient than cmp_int(). */
-    if (lzt == 0 && mbedtls_mpi_get_bit(&TB, 0) == 0) {
-        ret = mbedtls_mpi_copy(G, A);
-        goto cleanup;
-    }
-
-    if (lzt < lz) {
-        lz = lzt;
-    }
-
     TA.s = TB.s = 1;
 
-    /* We mostly follow the procedure described in HAC 14.54, but with some
-     * minor differences:
-     * - Sequences of multiplications or divisions by 2 are grouped into a
-     *   single shift operation.
-     * - The procedure in HAC assumes that 0 < TB <= TA.
-     *     - The condition TB <= TA is not actually necessary for correctness.
-     *       TA and TB have symmetric roles except for the loop termination
-     *       condition, and the shifts at the beginning of the loop body
-     *       remove any significance from the ordering of TA vs TB before
-     *       the shifts.
-     *     - If TA = 0, the loop goes through 0 iterations and the result is
-     *       correctly TB.
-     *     - The case TB = 0 was short-circuited above.
-     *
-     * For the correctness proof below, decompose the original values of
-     * A and B as
-     *   A = sa * 2^a * A' with A'=0 or A' odd, and sa = +-1
-     *   B = sb * 2^b * B' with B'=0 or B' odd, and sb = +-1
-     * Then gcd(A, B) = 2^{min(a,b)} * gcd(A',B'),
-     * and gcd(A',B') is odd or 0.
-     *
-     * At the beginning, we have TA = |A| and TB = |B| so gcd(A,B) = gcd(TA,TB).
-     * The code maintains the following invariant:
-     *     gcd(A,B) = 2^k * gcd(TA,TB) for some k   (I)
-     */
-
-    /* Proof that the loop terminates:
-     * At each iteration, either the right-shift by 1 is made on a nonzero
-     * value and the nonnegative integer bitlen(TA) + bitlen(TB) decreases
-     * by at least 1, or the right-shift by 1 is made on zero and then
-     * TA becomes 0 which ends the loop (TB cannot be 0 if it is right-shifted
-     * since in that case TB is calculated from TB-TA with the condition TB>TA).
-     */
-    while (mbedtls_mpi_cmp_int(&TA, 0) != 0) {
-        /* Divisions by 2 preserve the invariant (I). */
-        MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TA, mbedtls_mpi_lsb(&TA)));
-        MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TB, mbedtls_mpi_lsb(&TB)));
-
-        /* Set either TA or TB to |TA-TB|/2. Since TA and TB are both odd,
-         * TA-TB is even so the division by 2 has an integer result.
-         * Invariant (I) is preserved since any odd divisor of both TA and TB
-         * also divides |TA-TB|/2, and any odd divisor of both TA and |TA-TB|/2
-         * also divides TB, and any odd divisor of both TB and |TA-TB|/2 also
-         * divides TA.
-         */
-        if (mbedtls_mpi_cmp_mpi(&TA, &TB) >= 0) {
-            MBEDTLS_MPI_CHK(mbedtls_mpi_sub_abs(&TA, &TA, &TB));
-            MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TA, 1));
-        } else {
-            MBEDTLS_MPI_CHK(mbedtls_mpi_sub_abs(&TB, &TB, &TA));
-            MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TB, 1));
-        }
-        /* Note that one of TA or TB is still odd. */
+    /* Handle special cases (that don't happen in crypto usage) */
+    if (mbedtls_mpi_core_check_zero_ct(A.p, A.n) == MBEDTLS_CT_FALSE) {
+        return mbedtls_mpi_copy(G, TB); // GCD(0, B) = abs(B)
+    }
+    if (mbedtls_mpi_core_check_zero_ct(B.p, B.n) == MBEDTLS_CT_FALSE) {
+        return mbedtls_mpi_copy(G, A); // GCD(A, 0) = A (for now)
     }
 
-    /* By invariant (I), gcd(A,B) = 2^k * gcd(TA,TB) for some k.
-     * At the loop exit, TA = 0, so gcd(TA,TB) = TB.
-     * - If there was at least one loop iteration, then one of TA or TB is odd,
-     *   and TA = 0, so TB is odd and gcd(TA,TB) = gcd(A',B'). In this case,
-     *   lz = min(a,b) so gcd(A,B) = 2^lz * TB.
-     * - If there was no loop iteration, then A was 0, and gcd(A,B) = B.
-     *   In this case, lz = 0 and B = TB so gcd(A,B) = B = 2^lz * TB as well.
-     */
+    /* Make the two values the same (non-zero) number of limbs */
+    MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&TA, TB.n != 0 ? TB.n : 1));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&TB, TA.n)); // non-zero from above
 
-    MBEDTLS_MPI_CHK(mbedtls_mpi_shift_l(&TB, lz));
-    MBEDTLS_MPI_CHK(mbedtls_mpi_copy(G, &TB));
+    const size_t za = mbedtls_mpi_lsb(&TA);
+    const size_t zb = mbedtls_mpi_lsb(&TB);
+
+    MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TA, za));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TB, zb));
+
+    /* Ensure A <= B: if B < A, swap them */
+    mbedtls_ct_condition_t swap = mbedtls_mpi_core_lt_ct(TB.p, TA.p, TA.n);
+    mbedtls_mpi_core_cond_swap(TA.p, TB.p, TA.n, swap);
+
+    MBEDTLS_MPI_CHK(mbedtls_mpi_gcd_modinv_odd(G, NULL, &TA, &TB));
+
+    size_t zg = za > zb ? zb : za; // zg = min(za, zb)
+    MBEDTLS_MPI_CHK(mbedtls_mpi_shift_l(G, zg));
 
 cleanup: