bignum: use CT modinv when A is odd (any range)

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/bignum.c b/library/bignum.c
index c742fc9..137afb0 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -2046,6 +2046,42 @@
 }
 
 /*
+ * Compute X = A^-1 mod N with N even and A odd (but in any range).
+ *
+ * Return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE if the inverse doesn't exist.
+ */
+static int mbedtls_mpi_inv_mod_even(mbedtls_mpi *X,
+                                    mbedtls_mpi const *A,
+                                    mbedtls_mpi const *N)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    mbedtls_mpi AA;
+
+    mbedtls_mpi_init(&AA);
+
+    /* Reduce A into AA in the range [0, N); AA stays odd since N is even. */
+    MBEDTLS_MPI_CHK(mbedtls_mpi_mod_mpi(&AA, A, N));
+
+    /* We know AA >= 0 but the next function wants AA > 1. */
+    int cmp = mbedtls_mpi_cmp_int(&AA, 1);
+    if (cmp < 0) { // AA == 0: no inverse (defensive check, as AA is odd)
+        ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
+        goto cleanup;
+    }
+    if (cmp == 0) { // AA == 1: 1 is its own inverse
+        MBEDTLS_MPI_CHK(mbedtls_mpi_lset(X, 1));
+        goto cleanup;
+    }
+
+    /* Now we know 1 < AA < N, N is even and AA is still odd. */
+    MBEDTLS_MPI_CHK(mbedtls_mpi_inv_mod_even_in_range(X, &AA, N));
+
+cleanup:
+    mbedtls_mpi_free(&AA);
+    return ret;
+}
+
+/*
  * Modular inverse: X = A^-1 mod N  (HAC 14.61 / 14.64)
  */
 int mbedtls_mpi_inv_mod(mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *N)
@@ -2061,10 +2097,8 @@
         return mbedtls_mpi_inv_mod_odd(X, A, N);
     }
 
-    if (mbedtls_mpi_get_bit(A, 0) == 1 &&
-        mbedtls_mpi_cmp_int(A, 1) > 0 &&
-        mbedtls_mpi_cmp_mpi(A, N) < 0) {
-        return mbedtls_mpi_inv_mod_even_in_range(X, A, N);
+    if (mbedtls_mpi_get_bit(A, 0) == 1) {
+        return mbedtls_mpi_inv_mod_even(X, A, N);
     }
 
     mbedtls_mpi_init(&TA); mbedtls_mpi_init(&TU); mbedtls_mpi_init(&U1); mbedtls_mpi_init(&U2);