Add mbedtls_mpi_mod_inv()

Signed-off-by: Tom Cosgrove <tom.cosgrove@arm.com>
diff --git a/library/bignum_mod.c b/library/bignum_mod.c
index ac0be99..216b20f 100644
--- a/library/bignum_mod.c
+++ b/library/bignum_mod.c
@@ -191,6 +191,107 @@
 
     return( 0 );
 }
+
+/* Compute the modular inverse of A with respect to N and store it in X.
+ * A must be non-zero, and X, A and N must all have the same number of limbs.
+ * Returns 0 on success, MBEDTLS_ERR_MPI_BAD_INPUT_DATA on invalid input or
+ * an unsupported N->int_rep, or another error if allocation/setup fails. */
+int mbedtls_mpi_mod_inv( mbedtls_mpi_mod_residue *X,
+                         const mbedtls_mpi_mod_residue *A,
+                         const mbedtls_mpi_mod_modulus *N )
+{
+    if( X->limbs != N->limbs || A->limbs != N->limbs )
+        return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
+
+    /* Zero has the same value regardless of Montgomery form or not */
+    if( mbedtls_mpi_core_check_zero_ct( A->p, A->limbs ) == 0 )
+        return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
+
+    /* Will we need to do Montgomery conversion? */
+    int mont_conv_needed;
+    switch( N->int_rep )
+    {
+        case MBEDTLS_MPI_MOD_REP_MONTGOMERY:
+            mont_conv_needed = 0;
+            break;
+        case MBEDTLS_MPI_MOD_REP_OPT_RED:
+            mont_conv_needed = 1;
+            break;
+        default:
+            return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
+    }
+
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
+    /* If the input is already in Montgomery form, we have little to do but
+     * allocate working memory and call mbedtls_mpi_mod_raw_inv_prime().
+     *
+     * If it's not, we need to
+     * 1. Create a Montgomery version of the modulus;
+     * 2. Convert the input into Mont. form, writing the result to X->p
+     *    (the conversion call takes separate output and input buffers);
+     * 3. (allocate and invert, same as if already in Mont. form);
+     * 4. Convert the inverted output back from Mont. form.
+     */
+
+    /* Montgomery version of modulus (if not already in Mont. form).
+     * Setup is only called when conversion is needed. Nmont borrows N->p,
+     * which relies on mbedtls_mpi_mod_modulus_free() not freeing it. */
+    mbedtls_mpi_mod_modulus Nmont;
+    mbedtls_mpi_mod_modulus_init( &Nmont );
+
+    size_t working_limbs =
+                    mbedtls_mpi_mod_raw_inv_prime_working_limbs( N->limbs );
+
+    mbedtls_mpi_uint *working_memory = mbedtls_calloc( working_limbs,
+                                                     sizeof(mbedtls_mpi_uint) );
+    if( working_memory == NULL )
+    {
+        ret = MBEDTLS_ERR_MPI_ALLOC_FAILED;
+        goto cleanup;
+    }
+
+    const mbedtls_mpi_uint *to_invert;   /* Will alias A->p or X->p */
+    const mbedtls_mpi_mod_modulus *Nuse; /* Which of N and Nmont to use */
+
+    if( mont_conv_needed )
+    {
+        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_modulus_setup( &Nmont, N->p, N->limbs,
+                                             MBEDTLS_MPI_MOD_REP_MONTGOMERY ) );
+
+        mbedtls_mpi_core_to_mont_rep( X->p, A->p, Nmont.p, Nmont.limbs,
+                                      Nmont.rep.mont.mm, Nmont.rep.mont.rr,
+                                      working_memory );
+        to_invert = X->p;
+        Nuse = &Nmont;
+    }
+    else
+    {
+        to_invert = A->p;
+        Nuse = N;
+    }
+
+    mbedtls_mpi_mod_raw_inv_prime( X->p, to_invert, Nuse->p, Nuse->limbs,
+                                   Nuse->rep.mont.rr, working_memory );
+
+    if( mont_conv_needed )
+        mbedtls_mpi_core_from_mont_rep( X->p, X->p, Nmont.p, Nmont.limbs,
+                                        Nmont.rep.mont.mm, working_memory );
+
+    ret = 0; /* Success: the no-conversion path never assigns ret above */
+
+cleanup:
+    mbedtls_mpi_mod_modulus_free( &Nmont );
+
+    if( working_memory != NULL )
+    {
+        mbedtls_platform_zeroize( working_memory,
+                                  working_limbs * sizeof(mbedtls_mpi_uint) );
+        mbedtls_free( working_memory );
+    }
+
+    return( ret );
+}
 /* END MERGE SLOT 3 */
 
 /* BEGIN MERGE SLOT 4 */