bignum: use CT gcd for mbedtls_mpi_gcd()
The overall function is still not constant-time, but it just got a lot
less leaky.
Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/bignum.c b/library/bignum.c
index 432ecb9..7d5103e 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1819,103 +1819,47 @@
}
/*
- * Greatest common divisor: G = gcd(A, B) (HAC 14.54)
+ * Greatest common divisor: G = gcd(A, B)
+ * Wrapper around mbedtls_mpi_gcd_modinv_odd() that removes its restrictions.
*/
int mbedtls_mpi_gcd(mbedtls_mpi *G, const mbedtls_mpi *A, const mbedtls_mpi *B)
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
- size_t lz, lzt;
mbedtls_mpi TA, TB;
mbedtls_mpi_init(&TA); mbedtls_mpi_init(&TB);
+ /* Make copies and take absolute values */
MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&TA, A));
MBEDTLS_MPI_CHK(mbedtls_mpi_copy(&TB, B));
-
- lz = mbedtls_mpi_lsb(&TA);
- lzt = mbedtls_mpi_lsb(&TB);
-
- /* The loop below gives the correct result when A==0 but not when B==0.
- * So have a special case for B==0. Leverage the fact that we just
- * calculated the lsb and lsb(B)==0 iff B is odd or 0 to make the test
- * slightly more efficient than cmp_int(). */
- if (lzt == 0 && mbedtls_mpi_get_bit(&TB, 0) == 0) {
- ret = mbedtls_mpi_copy(G, A);
- goto cleanup;
- }
-
- if (lzt < lz) {
- lz = lzt;
- }
-
TA.s = TB.s = 1;
- /* We mostly follow the procedure described in HAC 14.54, but with some
- * minor differences:
- * - Sequences of multiplications or divisions by 2 are grouped into a
- * single shift operation.
- * - The procedure in HAC assumes that 0 < TB <= TA.
- * - The condition TB <= TA is not actually necessary for correctness.
- * TA and TB have symmetric roles except for the loop termination
- * condition, and the shifts at the beginning of the loop body
- * remove any significance from the ordering of TA vs TB before
- * the shifts.
- * - If TA = 0, the loop goes through 0 iterations and the result is
- * correctly TB.
- * - The case TB = 0 was short-circuited above.
- *
- * For the correctness proof below, decompose the original values of
- * A and B as
- * A = sa * 2^a * A' with A'=0 or A' odd, and sa = +-1
- * B = sb * 2^b * B' with B'=0 or B' odd, and sb = +-1
- * Then gcd(A, B) = 2^{min(a,b)} * gcd(A',B'),
- * and gcd(A',B') is odd or 0.
- *
- * At the beginning, we have TA = |A| and TB = |B| so gcd(A,B) = gcd(TA,TB).
- * The code maintains the following invariant:
- * gcd(A,B) = 2^k * gcd(TA,TB) for some k (I)
- */
-
- /* Proof that the loop terminates:
- * At each iteration, either the right-shift by 1 is made on a nonzero
- * value and the nonnegative integer bitlen(TA) + bitlen(TB) decreases
- * by at least 1, or the right-shift by 1 is made on zero and then
- * TA becomes 0 which ends the loop (TB cannot be 0 if it is right-shifted
- * since in that case TB is calculated from TB-TA with the condition TB>TA).
- */
- while (mbedtls_mpi_cmp_int(&TA, 0) != 0) {
- /* Divisions by 2 preserve the invariant (I). */
- MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TA, mbedtls_mpi_lsb(&TA)));
- MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TB, mbedtls_mpi_lsb(&TB)));
-
- /* Set either TA or TB to |TA-TB|/2. Since TA and TB are both odd,
- * TA-TB is even so the division by 2 has an integer result.
- * Invariant (I) is preserved since any odd divisor of both TA and TB
- * also divides |TA-TB|/2, and any odd divisor of both TA and |TA-TB|/2
- * also divides TB, and any odd divisor of both TB and |TA-TB|/2 also
- * divides TA.
- */
- if (mbedtls_mpi_cmp_mpi(&TA, &TB) >= 0) {
- MBEDTLS_MPI_CHK(mbedtls_mpi_sub_abs(&TA, &TA, &TB));
- MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TA, 1));
- } else {
- MBEDTLS_MPI_CHK(mbedtls_mpi_sub_abs(&TB, &TB, &TA));
- MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TB, 1));
- }
- /* Note that one of TA or TB is still odd. */
+ /* Handle special cases (that don't happen in crypto usage) */
+ if (mbedtls_mpi_core_check_zero_ct(A.p, A.n) == MBEDTLS_CT_FALSE) {
+ return mbedtls_mpi_copy(G, TB); // GCD(0, B) = abs(B)
+ }
+ if (mbedtls_mpi_core_check_zero_ct(B.p, B.n) == MBEDTLS_CT_FALSE) {
+ return mbedtls_mpi_copy(G, A); // GCD(A, 0) = A (for now)
}
- /* By invariant (I), gcd(A,B) = 2^k * gcd(TA,TB) for some k.
- * At the loop exit, TA = 0, so gcd(TA,TB) = TB.
- * - If there was at least one loop iteration, then one of TA or TB is odd,
- * and TA = 0, so TB is odd and gcd(TA,TB) = gcd(A',B'). In this case,
- * lz = min(a,b) so gcd(A,B) = 2^lz * TB.
- * - If there was no loop iteration, then A was 0, and gcd(A,B) = B.
- * In this case, lz = 0 and B = TB so gcd(A,B) = B = 2^lz * TB as well.
- */
+ /* Make the two values the same (non-zero) number of limbs */
+ MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&TA, TB.n != 0 ? TB.n : 1));
+ MBEDTLS_MPI_CHK(mbedtls_mpi_grow(&TB, TA.n)); // non-zero from above
- MBEDTLS_MPI_CHK(mbedtls_mpi_shift_l(&TB, lz));
- MBEDTLS_MPI_CHK(mbedtls_mpi_copy(G, &TB));
+ const size_t za = mbedtls_mpi_lsb(&TA);
+ const size_t zb = mbedtls_mpi_lsb(&TB);
+
+ MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TA, za));
+ MBEDTLS_MPI_CHK(mbedtls_mpi_shift_r(&TB, zb));
+
+ /* Ensure A <= B: if B < A, swap them */
+ mbedtls_ct_condition_t swap = mbedtls_mpi_core_lt_ct(TB.p, TA.p, TA.n);
+ mbedtls_mpi_core_cond_swap(TA.p, TB.p, TA.n, swap);
+
+ MBEDTLS_MPI_CHK(mbedtls_mpi_gcd_modinv_odd(G, NULL, &TA, &TB));
+
+ size_t zg = za > zb ? zb : za; // zg = min(za, zb)
+ MBEDTLS_MPI_CHK(mbedtls_mpi_shift_l(G, zg));
cleanup: