Bignum: Implement mbedtls_mpi_mod_raw_inv_prime() and tests

Fixes #6023.

Signed-off-by: Tom Cosgrove <tom.cosgrove@arm.com>
diff --git a/library/bignum_mod_raw.c b/library/bignum_mod_raw.c
index 22e56b7..03924d2 100644
--- a/library/bignum_mod_raw.c
+++ b/library/bignum_mod_raw.c
@@ -124,6 +124,34 @@
 
 /* BEGIN MERGE SLOT 3 */
 
+size_t mbedtls_mpi_mod_raw_inv_prime_working_limbs( size_t AN_limbs )
+{
+    /* mbedtls_mpi_mod_raw_inv_prime() needs a temporary for the exponent,
+     * which will be the same size as the modulus and input (AN_limbs),
+     * and additional space to pass to mbedtls_mpi_core_exp_mod(). */
+    return( AN_limbs +
+            mbedtls_mpi_core_exp_mod_working_limbs( AN_limbs, AN_limbs ) );
+}
+
+void mbedtls_mpi_mod_raw_inv_prime( mbedtls_mpi_uint *X,
+                                    const mbedtls_mpi_uint *A,
+                                    const mbedtls_mpi_uint *N,
+                                    size_t AN_limbs,
+                                    const mbedtls_mpi_uint *RR,
+                                    mbedtls_mpi_uint *T )
+{
+    /* Inversion by power: g^|G| = 1 => g^(-1) = g^(|G|-1), and
+     *                       |G| = N - 1, so we want
+     *                 g^(|G|-1) = g^(N - 2)
+     */
+    mbedtls_mpi_uint *Nminus2 = T;
+    (void) mbedtls_mpi_core_sub_int( Nminus2, N, 2, AN_limbs );
+
+    mbedtls_mpi_core_exp_mod( X,
+                              A, N, AN_limbs, Nminus2, AN_limbs,
+                              RR, T + AN_limbs );
+}
+
 /* END MERGE SLOT 3 */
 
 /* BEGIN MERGE SLOT 4 */
diff --git a/library/bignum_mod_raw.h b/library/bignum_mod_raw.h
index d7b6dd1..698119e 100644
--- a/library/bignum_mod_raw.h
+++ b/library/bignum_mod_raw.h
@@ -174,6 +174,51 @@
 
 /* BEGIN MERGE SLOT 3 */
 
+/**
+ * \brief          Returns the number of limbs of working memory required for
+ *                 a call to `mbedtls_mpi_mod_raw_inv_prime()`.
+ *
+ * \param AN_limbs The number of limbs in the input `A` and the modulus `N`
+ *                 (they must be the same size) that will be given to
+ *                 `mbedtls_mpi_mod_raw_inv_prime()`.
+ *
+ * \return         The number of limbs of working memory required by
+ *                 `mbedtls_mpi_mod_raw_inv_prime()`.
+ */
+size_t mbedtls_mpi_mod_raw_inv_prime_working_limbs( size_t AN_limbs );
+
+/**
+ * \brief Perform fixed-width modular inversion of a Montgomery-form MPI with
+ *        respect to a modulus \p N that must be prime.
+ *
+ * \p X may be aliased to \p A, but not to \p N or \p RR.
+ *
+ * \param[out] X     The modular inverse of \p A with respect to \p N.
+ *                   Will be in Montgomery form.
+ * \param[in] A      The number to calculate the modular inverse of.
+ *                   Must be in Montgomery form. Must not be 0.
+ * \param[in] N      The modulus, as a little-endian array of length \p AN_limbs.
+ *                   Must be prime.
+ * \param AN_limbs   The number of limbs in \p A, \p N and \p RR.
+ * \param[in] RR     The precomputed residue of 2^{2*biL} modulo N, as a little-
+ *                   endian array of length \p AN_limbs.
+ * \param[in,out] T  Temporary storage of at least the number of limbs returned
+ *                   by `mbedtls_mpi_mod_raw_inv_prime_working_limbs()`.
+ *                   Its initial content is unused and its final content is
+ *                   indeterminate.
+ *                   It must not alias or otherwise overlap any of the other
+ *                   parameters.
+ *                   It is up to the caller to zeroize \p T when it is no
+ *                   longer needed, and before freeing it if it was dynamically
+ *                   allocated.
+ */
+void mbedtls_mpi_mod_raw_inv_prime( mbedtls_mpi_uint *X,
+                                    const mbedtls_mpi_uint *A,
+                                    const mbedtls_mpi_uint *N,
+                                    size_t AN_limbs,
+                                    const mbedtls_mpi_uint *RR,
+                                    mbedtls_mpi_uint *T );
+
 /* END MERGE SLOT 3 */
 
 /* BEGIN MERGE SLOT 4 */
diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py
index 3ff8b2f..0339b1a 100644
--- a/scripts/mbedtls_dev/bignum_common.py
+++ b/scripts/mbedtls_dev/bignum_common.py
@@ -99,6 +99,7 @@
     limb_sizes = [32, 64] # type: List[int]
     arities = [1, 2]
     arity = 2
+    suffix = False   # for arity = 1, symbol can be prefix (default) or suffix
 
     def __init__(self, val_a: str, val_b: str = "0", bits_in_limb: int = 32) -> None:
         self.val_a = val_a
@@ -170,7 +171,8 @@
         """
         if not self.case_description:
             if self.arity == 1:
-                self.case_description = "{} {:x}".format(
+                format_string = "{1:x} {0}" if self.suffix else "{0} {1:x}"
+                self.case_description = format_string.format(
                     self.symbol, self.int_a
                 )
             elif self.arity == 2:
diff --git a/scripts/mbedtls_dev/bignum_data.py b/scripts/mbedtls_dev/bignum_data.py
index e6ed300..9658933 100644
--- a/scripts/mbedtls_dev/bignum_data.py
+++ b/scripts/mbedtls_dev/bignum_data.py
@@ -90,8 +90,8 @@
                               "4708d9893a973000b54a23020fc5b043d6e4a51519d9c9cc"
                               "52d32377e78131c1")
 
-# Adding 192 bit and 1024 bit numbers because these are the shortest required
-# for ECC and RSA respectively.
+# Adding 192 bit and 1024 bit numbers because these are the shortest required
+# for ECC and RSA respectively.
 INPUTS_DEFAULT = [
         "0", "1", # corner cases
         "2", "3", # small primes
@@ -110,13 +110,21 @@
 # supported for now.
 MODULI_DEFAULT = [
         "53", # safe prime
-        "45", # non-prime
+        "45", # non-prime
         SAFE_PRIME_192_BIT_SEED_1,  # safe prime
         RANDOM_192_BIT_SEED_2_NO4,  # not a prime
         SAFE_PRIME_1024_BIT_SEED_3, # safe prime
         RANDOM_1024_BIT_SEED_4_NO5, # not a prime
         ]
 
+# Some functions, e.g. mbedtls_mpi_mod_raw_inv_prime(), only support prime moduli.
+ONLY_PRIME_MODULI = [
+        "53", # safe prime
+        "8ac72304057392b5",     # 9999999997777777333 (longer, not safe, prime)
+        SAFE_PRIME_192_BIT_SEED_1,  # safe prime
+        SAFE_PRIME_1024_BIT_SEED_3, # safe prime
+        ]
+
 def __gen_safe_prime(bits, seed):
     '''
     Generate a safe prime.
diff --git a/scripts/mbedtls_dev/bignum_mod_raw.py b/scripts/mbedtls_dev/bignum_mod_raw.py
index d05479a..1a23a60 100644
--- a/scripts/mbedtls_dev/bignum_mod_raw.py
+++ b/scripts/mbedtls_dev/bignum_mod_raw.py
@@ -18,6 +18,7 @@
 
 from . import test_data_generation
 from . import bignum_common
+from .bignum_data import ONLY_PRIME_MODULI
 
 class BignumModRawTarget(test_data_generation.BaseTarget):
     #pylint: disable=abstract-method, too-few-public-methods
@@ -53,6 +54,36 @@
 
 # BEGIN MERGE SLOT 3
 
+class BignumModRawInvPrime(bignum_common.ModOperationCommon,
+                           BignumModRawTarget):
+    """Test cases for bignum mpi_mod_raw_inv_prime()."""
+    moduli = ONLY_PRIME_MODULI
+    symbol = "^ -1"
+    test_function = "mpi_mod_raw_inv_prime"
+    test_name = "mbedtls_mpi_mod_raw_inv_prime (Montgomery form only)"
+    input_style = "fixed"
+    arity = 1
+    suffix = True
+
+    @property
+    def is_valid(self) -> bool:
+        return self.int_a > 0 and self.int_a < self.int_n
+
+    def arguments(self) -> List[str]:
+        # Input has to be given in Montgomery form
+        mont_a = self.to_montgomery(self.int_a)
+        arg_mont_a = self.format_arg('{:x}'.format(mont_a))
+        return [bignum_common.quote_str(n) for n in [self.arg_n,
+                                                     arg_mont_a]
+               ] + self.result()
+
+    def result(self) -> List[str]:
+        result = bignum_common.invmod(self.int_a, self.int_n)
+        if result < 0:
+            result += self.int_n
+        mont_result = self.to_montgomery(result)
+        return [self.format_result(mont_result)]
+
 # END MERGE SLOT 3
 
 # BEGIN MERGE SLOT 4
diff --git a/tests/suites/test_suite_bignum_mod_raw.function b/tests/suites/test_suite_bignum_mod_raw.function
index c7decf0..5d23707 100644
--- a/tests/suites/test_suite_bignum_mod_raw.function
+++ b/tests/suites/test_suite_bignum_mod_raw.function
@@ -349,6 +349,75 @@
 
 /* BEGIN MERGE SLOT 3 */
 
+/* BEGIN_CASE */
+void mpi_mod_raw_inv_prime( char * input_N, char * input_A, char * input_X )
+{
+    mbedtls_mpi_uint *A = NULL;
+    mbedtls_mpi_uint *N = NULL;
+    mbedtls_mpi_uint *X = NULL;
+    size_t A_limbs, N_limbs, X_limbs;
+    mbedtls_mpi_uint *Y = NULL;
+    mbedtls_mpi_uint *T = NULL;
+    const mbedtls_mpi_uint *R2 = NULL;
+
+    /* Legacy MPIs for computing R2 */
+    mbedtls_mpi N_mpi;  /* gets set up manually, aliasing N, so no need to free */
+    mbedtls_mpi R2_mpi;
+    mbedtls_mpi_init( &R2_mpi );
+
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &A, &A_limbs, input_A ) );
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &N, &N_limbs, input_N ) );
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &X, &X_limbs, input_X ) );
+    ASSERT_ALLOC( Y, N_limbs );
+
+    TEST_EQUAL( A_limbs, N_limbs );
+    TEST_EQUAL( X_limbs, N_limbs );
+
+    N_mpi.s = 1;
+    N_mpi.p = N;
+    N_mpi.n = N_limbs;
+    TEST_EQUAL( 0, mbedtls_mpi_core_get_mont_r2_unsafe( &R2_mpi, &N_mpi ) );
+    TEST_EQUAL( 0, mbedtls_mpi_grow( &R2_mpi, N_limbs ) );
+    R2 = R2_mpi.p;
+
+    size_t working_limbs = mbedtls_mpi_mod_raw_inv_prime_working_limbs( N_limbs );
+
+    /* No point exactly duplicating the code in mbedtls_mpi_mod_raw_inv_prime_working_limbs()
+     * to see if the output is correct, but we can check that it's in a
+     * reasonable range.  The current calculation works out as
+     * `1 + N_limbs * (welem + 4)`, where welem is the number of elements in
+     * the window (1 << 1 up to 1 << 6).
+     */
+    size_t min_expected_working_limbs = 1 + N_limbs * 5;
+    size_t max_expected_working_limbs = 1 + N_limbs * 68;
+
+    TEST_LE_U( min_expected_working_limbs, working_limbs );
+    TEST_LE_U( working_limbs, max_expected_working_limbs );
+
+    ASSERT_ALLOC( T, working_limbs );
+
+    mbedtls_mpi_mod_raw_inv_prime( Y, A, N, N_limbs, R2, T );
+
+    TEST_EQUAL( 0, memcmp( X, Y, N_limbs * sizeof( mbedtls_mpi_uint ) ) );
+
+    /* Check when output aliased to input */
+
+    mbedtls_mpi_mod_raw_inv_prime( A, A, N, N_limbs, R2, T );
+
+    TEST_EQUAL( 0, memcmp( X, A, N_limbs * sizeof( mbedtls_mpi_uint ) ) );
+
+exit:
+    mbedtls_free( T );
+    mbedtls_free( A );
+    mbedtls_free( N );
+    mbedtls_free( X );
+    mbedtls_free( Y );
+    mbedtls_mpi_free( &R2_mpi );
+    // R2 doesn't need to be freed as it is only aliasing R2_mpi
+    // N_mpi doesn't need to be freed as it is only aliasing N
+}
+/* END_CASE */
+
 /* END MERGE SLOT 3 */
 
 /* BEGIN MERGE SLOT 4 */