Add low-level Montgomery conversion functions to bignum_core

Signed-off-by: Tom Cosgrove <tom.cosgrove@arm.com>
diff --git a/library/bignum_core.c b/library/bignum_core.c
index 75cce05..74efb38 100644
--- a/library/bignum_core.c
+++ b/library/bignum_core.c
@@ -753,6 +753,29 @@
     return( bits );
 }
 
+void mbedtls_mpi_core_to_mont_rep( mbedtls_mpi_uint *X,
+                                   const mbedtls_mpi_uint *A,
+                                   const mbedtls_mpi_uint *N,
+                                   size_t AN_limbs,
+                                   mbedtls_mpi_uint mm,
+                                   const mbedtls_mpi_uint *rr,
+                                   mbedtls_mpi_uint *T )
+{
+    mbedtls_mpi_core_montmul( X, A, rr, AN_limbs, N, AN_limbs, mm, T );
+}
+
+void mbedtls_mpi_core_from_mont_rep( mbedtls_mpi_uint *X,
+                                     const mbedtls_mpi_uint *A,
+                                     const mbedtls_mpi_uint *N,
+                                     size_t AN_limbs,
+                                     mbedtls_mpi_uint mm,
+                                     mbedtls_mpi_uint *T )
+{
+    const mbedtls_mpi_uint Rinv = 1;    /* 1/R in Mont. rep => 1 */
+
+    mbedtls_mpi_core_montmul( X, A, &Rinv, 1, N, AN_limbs, mm, T );
+}
+
 /* END MERGE SLOT 3 */
 
 /* BEGIN MERGE SLOT 4 */
diff --git a/library/bignum_core.h b/library/bignum_core.h
index 7b5787c..b898527 100644
--- a/library/bignum_core.h
+++ b/library/bignum_core.h
@@ -606,6 +606,81 @@
     return( 2 * AN_limbs + 1 );
 }
 
+/** Convert an MPI into Montgomery form.
+ *
+ * \p X may be aliased to \p A, but may not otherwise overlap it.
+ *
+ * \p X may not alias \p N (it is in canonical form, so must be stricly less
+ * than \p N). Nor may it alias or overlap \p rr (this is unlikely to be
+ * required in practice.)
+ *
+ * This function is a thin wrapper around `mbedtls_mpi_core_montmul()` that is
+ * an alternative to calling `mbedtls_mpi_mod_raw_to_mont_rep()` when we
+ * don't want to allocate memory.
+ *
+ * \param[out]    X         The result of the conversion.
+ *                          Must have the same number of limbs as \p A.
+ * \param[in]     A         The MPI to convert into Montgomery form.
+ *                          Must have the same number of limbs as the modulus.
+ * \param[in]     N         The address of the modulus, which gives the size of
+ *                          the base `R` = 2^(biL*m->limbs).
+ * \param[in]     AN_limbs  The number of limbs in \p X, \p A, \p N and \p rr.
+ * \param         mm        The Montgomery constant for \p N: -N^-1 mod 2^biL.
+ *                          This can be determined  by calling
+ *                          `mbedtls_mpi_core_montmul_init()`.
+ * \param[in]     rr        The residue for `2^{2*n*biL} mod N`.
+ * \param[in,out] T         Temporary storage of size at least
+ *                          `mbedtls_mpi_core_montmul_working_limbs(AN_limbs)`
+ *                          limbs.
+ *                          Its initial content is unused and
+ *                          its final content is indeterminate.
+ *                          It must not alias or otherwise overlap any of the
+ *                          other parameters.
+ */
+void mbedtls_mpi_core_to_mont_rep( mbedtls_mpi_uint *X,
+                                   const mbedtls_mpi_uint *A,
+                                   const mbedtls_mpi_uint *N,
+                                   size_t AN_limbs,
+                                   mbedtls_mpi_uint mm,
+                                   const mbedtls_mpi_uint *rr,
+                                   mbedtls_mpi_uint *T );
+
+/** Convert an MPI from Montgomery form.
+ *
+ * \p X may be aliased to \p A, but may not otherwise overlap it.
+ *
+ * \p X may not alias \p N (it is in canonical form, so must be stricly less
+ * than \p N).
+ *
+ * This function is a thin wrapper around `mbedtls_mpi_core_montmul()` that is
+ * an alternative to calling `mbedtls_mpi_mod_raw_from_mont_rep()` when we
+ * don't want to allocate memory.
+ *
+ * \param[out]    X         The result of the conversion.
+ *                          Must have the same number of limbs as \p A.
+ * \param[in]     A         The MPI to convert from Montgomery form.
+ *                          Must have the same number of limbs as the modulus.
+ * \param[in]     N         The address of the modulus, which gives the size of
+ *                          the base `R` = 2^(biL*m->limbs).
+ * \param[in]     AN_limbs  The number of limbs in \p X, \p A and \p N.
+ * \param         mm        The Montgomery constant for \p N: -N^-1 mod 2^biL.
+ *                          This can be determined  by calling
+ *                          `mbedtls_mpi_core_montmul_init()`.
+ * \param[in,out] T         Temporary storage of size at least
+ *                          `mbedtls_mpi_core_montmul_working_limbs(AN_limbs)`
+ *                          limbs.
+ *                          Its initial content is unused and
+ *                          its final content is indeterminate.
+ *                          It must not alias or otherwise overlap any of the
+ *                          other parameters.
+ */
+void mbedtls_mpi_core_from_mont_rep( mbedtls_mpi_uint *X,
+                                     const mbedtls_mpi_uint *A,
+                                     const mbedtls_mpi_uint *N,
+                                     size_t AN_limbs,
+                                     mbedtls_mpi_uint mm,
+                                     mbedtls_mpi_uint *T );
+
 /* END MERGE SLOT 3 */
 
 /* BEGIN MERGE SLOT 4 */
diff --git a/library/bignum_mod_raw.c b/library/bignum_mod_raw.c
index be8fc86..d2d93d3 100644
--- a/library/bignum_mod_raw.c
+++ b/library/bignum_mod_raw.c
@@ -188,8 +188,8 @@
     if( ( T = (mbedtls_mpi_uint *) mbedtls_calloc( t_limbs, ciL ) ) == NULL )
         return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
 
-    mbedtls_mpi_core_montmul( X, X, m->rep.mont.rr, m->limbs, m->p, m->limbs,
-                              m->rep.mont.mm, T );
+    mbedtls_mpi_core_to_mont_rep( X, X, m->p, m->limbs,
+                                  m->rep.mont.mm, m->rep.mont.rr, T );
 
     mbedtls_platform_zeroize( T, t_limbs * ciL );
     mbedtls_free( T );
@@ -199,15 +199,13 @@
 int mbedtls_mpi_mod_raw_from_mont_rep( mbedtls_mpi_uint *X,
                                        const mbedtls_mpi_mod_modulus *m )
 {
-    const mbedtls_mpi_uint one = 1;
     const size_t t_limbs = mbedtls_mpi_core_montmul_working_limbs( m->limbs );
     mbedtls_mpi_uint *T;
 
     if( ( T = (mbedtls_mpi_uint *) mbedtls_calloc( t_limbs, ciL ) ) == NULL )
         return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
 
-    mbedtls_mpi_core_montmul( X, X, &one, 1, m->p, m->limbs,
-                              m->rep.mont.mm, T );
+    mbedtls_mpi_core_from_mont_rep( X, X, m->p, m->limbs, m->rep.mont.mm, T );
 
     mbedtls_platform_zeroize( T, t_limbs * ciL );
     mbedtls_free( T );
diff --git a/tests/suites/test_suite_bignum_mod_raw.function b/tests/suites/test_suite_bignum_mod_raw.function
index ef0f712..50fdac3 100644
--- a/tests/suites/test_suite_bignum_mod_raw.function
+++ b/tests/suites/test_suite_bignum_mod_raw.function
@@ -533,8 +533,10 @@
 {
     mbedtls_mpi_uint *N = NULL;
     mbedtls_mpi_uint *A = NULL;
+    mbedtls_mpi_uint *R = NULL; /* for result of low-level conversion */
     mbedtls_mpi_uint *X = NULL;
-    size_t n_limbs, a_limbs, x_limbs, x_bytes;
+    mbedtls_mpi_uint *T = NULL;
+    size_t n_limbs, a_limbs, x_limbs;
 
     mbedtls_mpi_mod_modulus m;
     mbedtls_mpi_mod_modulus_init( &m );
@@ -543,23 +545,50 @@
     TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &N, &n_limbs, input_N ) );
     TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &A, &a_limbs, input_A ) );
     TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &X, &x_limbs, input_X ) );
-    x_bytes = x_limbs * sizeof(mbedtls_mpi_uint);
 
-    /* Test that input does not require more limbs than modulo */
-    TEST_LE_U(a_limbs, n_limbs);
+    /* Number to convert must have same number of limbs as modulus */
+    TEST_EQUAL(a_limbs, n_limbs);
+
+    /* Higher-level conversion is in-place, so expected result must have the
+     * same number of limbs too */
+    TEST_EQUAL(x_limbs, n_limbs);
+
+    size_t limbs = n_limbs;
+    size_t bytes = limbs * sizeof(mbedtls_mpi_uint);
 
     TEST_EQUAL( 0, mbedtls_mpi_mod_modulus_setup( &m, N, n_limbs,
-                MBEDTLS_MPI_MOD_REP_MONTGOMERY ) );
+                                          MBEDTLS_MPI_MOD_REP_MONTGOMERY ) );
 
-    /* Convert from cannonical into Montgomery representation */
+    /* 1. Test low-level function first */
+
+    /* It has separate output, and requires temporary working storage */
+    size_t temp_limbs = mbedtls_mpi_core_montmul_working_limbs( limbs );
+    ASSERT_ALLOC( T, temp_limbs );
+    ASSERT_ALLOC( R, limbs );
+    mbedtls_mpi_core_to_mont_rep( R, A, N, n_limbs,
+                                  m.rep.mont.mm, m.rep.mont.rr, T );
+    /* Test that the low-level function gives the required value */
+    ASSERT_COMPARE( R, bytes, X, bytes );
+
+    /* Test when output is aliased to input */
+    memcpy( R, A, bytes );
+    mbedtls_mpi_core_to_mont_rep( R, R, N, n_limbs,
+                                  m.rep.mont.mm, m.rep.mont.rr, T );
+    ASSERT_COMPARE( R, bytes, X, bytes );
+
+    /* 2. Test higher-level cannonical to Montgomery conversion */
+
     TEST_EQUAL(0, mbedtls_mpi_mod_raw_to_mont_rep( A, &m ) );
 
     /* The result matches expected value */
-    ASSERT_COMPARE( A, x_bytes, X, x_bytes );
+    ASSERT_COMPARE( A, bytes, X, bytes );
+
 exit:
     mbedtls_mpi_mod_modulus_free( &m );
+    mbedtls_free( T );
     mbedtls_free( N );
     mbedtls_free( A );
+    mbedtls_free( R );
     mbedtls_free( X );
 }
 /* END_CASE */
@@ -569,8 +598,10 @@
 {
     mbedtls_mpi_uint *N = NULL;
     mbedtls_mpi_uint *A = NULL;
+    mbedtls_mpi_uint *R = NULL; /* for result of low-level conversion */
     mbedtls_mpi_uint *X = NULL;
-    size_t n_limbs, a_limbs, x_limbs, x_bytes;
+    mbedtls_mpi_uint *T = NULL;
+    size_t n_limbs, a_limbs, x_limbs;
 
     mbedtls_mpi_mod_modulus m;
     mbedtls_mpi_mod_modulus_init( &m );
@@ -579,23 +610,50 @@
     TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &N, &n_limbs, input_N ) );
     TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &A, &a_limbs, input_A ) );
     TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &X, &x_limbs, input_X ) );
-    x_bytes = x_limbs * sizeof(mbedtls_mpi_uint);
 
-    /* Test that input does not require more limbs than modulo */
-    TEST_LE_U(a_limbs, n_limbs);
+    /* Number to convert must have same number of limbs as modulus */
+    TEST_EQUAL(a_limbs, n_limbs);
+
+    /* Higher-level conversion is in-place, so expected result must have the
+     * same number of limbs too */
+    TEST_EQUAL(x_limbs, n_limbs);
+
+    size_t limbs = n_limbs;
+    size_t bytes = limbs * sizeof(mbedtls_mpi_uint);
 
     TEST_EQUAL( 0, mbedtls_mpi_mod_modulus_setup( &m, N, n_limbs,
-                MBEDTLS_MPI_MOD_REP_MONTGOMERY ) );
+                                          MBEDTLS_MPI_MOD_REP_MONTGOMERY ) );
 
-    /* Convert from Montgomery into cannonical representation */
+    /* 1. Test low-level function first */
+
+    /* It has separate output, and requires temporary working storage */
+    size_t temp_limbs = mbedtls_mpi_core_montmul_working_limbs( limbs );
+    ASSERT_ALLOC( T, temp_limbs );
+    ASSERT_ALLOC( R, limbs );
+    mbedtls_mpi_core_from_mont_rep( R, A, N, n_limbs,
+                                    m.rep.mont.mm, T );
+    /* Test that the low-level function gives the required value */
+    ASSERT_COMPARE( R, bytes, X, bytes );
+
+    /* Test when output is aliased to input */
+    memcpy( R, A, bytes );
+    mbedtls_mpi_core_from_mont_rep( R, R, N, n_limbs,
+                                    m.rep.mont.mm, T );
+    ASSERT_COMPARE( R, bytes, X, bytes );
+
+    /* 2. Test higher-level Montgomery to cannonical conversion */
+
     TEST_EQUAL(0, mbedtls_mpi_mod_raw_from_mont_rep( A, &m ) );
 
     /* The result matches expected value */
-    ASSERT_COMPARE( A, x_bytes, X, x_bytes );
+    ASSERT_COMPARE( A, bytes, X, bytes );
+
 exit:
     mbedtls_mpi_mod_modulus_free( &m );
+    mbedtls_free( T );
     mbedtls_free( N );
     mbedtls_free( A );
+    mbedtls_free( R );
     mbedtls_free( X );
 }
 /* END_CASE */