Add tests for mbedtls_mpi_mod_raw read/write functions

Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/tests/suites/test_suite_mpi.function b/tests/suites/test_suite_mpi.function
index ff87ea4..de66fdd 100644
--- a/tests/suites/test_suite_mpi.function
+++ b/tests/suites/test_suite_mpi.function
@@ -2,6 +2,8 @@
 #include "mbedtls/bignum.h"
 #include "mbedtls/entropy.h"
 #include "bignum_core.h"
+#include "bignum_mod.h"
+#include "bignum_mod_raw.h"
 #include "constant_time_internal.h"
 #include "test/constant_flow.h"
 
@@ -351,6 +353,134 @@
 /* END_CASE */
 
 /* BEGIN_CASE */
+void mbedtls_mpi_mod_raw_io( data_t *input, int nb_int, int nx_64_int,
+                             int iendian, int iret, int oret )
+{
+    /*
+     * Round-trip test: mbedtls_mpi_mod_raw_read() followed by
+     * mbedtls_mpi_mod_raw_write().
+     *
+     * input     - bytes passed to mbedtls_mpi_mod_raw_read().
+     * nb_int    - size in bytes of the buffer passed to
+     *             mbedtls_mpi_mod_raw_write(); must be at most BMAX.
+     * nx_64_int - modulus size in 64-bit limbs (doubled below when
+     *             mbedtls_mpi_uint is 32 bits wide).
+     * iendian   - external representation the modulus is set up with.
+     *             MBEDTLS_MPI_MOD_EXT_REP_INVALID means: set up as little
+     *             endian, then overwrite m.ext_rep with INVALID just before
+     *             whichever call (read or write) is expected to fail.
+     * iret      - expected return value of mbedtls_mpi_mod_raw_read().
+     * oret      - expected return value of mbedtls_mpi_mod_raw_write();
+     *             the write is only attempted when iret is 0.
+     *
+     * When both calls succeed, the bytes common to buf and input must
+     * match, and any surplus bytes of the longer one must be zero (the
+     * leading bytes for big endian, the trailing bytes for little endian).
+     */
+    #define BMAX 1024
+    unsigned char buf[BMAX];
+    #define XMAX ( BMAX / sizeof( mbedtls_mpi_uint ) )
+    mbedtls_mpi_uint X[XMAX];
+    mbedtls_mpi_uint init[XMAX];
+    mbedtls_mpi_mod_modulus m;
+    size_t nx, nb;
+    int ret;
+    int endian;
+
+    /* Initialize m before any TEST_ASSERT: a failed assertion jumps to
+     * exit, which passes m to mbedtls_mpi_mod_modulus_free(). */
+    mbedtls_mpi_mod_modulus_init( &m );
+
+    /* The write is skipped when the read fails, so oret must then be 0. */
+    if( iret != 0 )
+        TEST_ASSERT( oret == 0 );
+
+    TEST_ASSERT( 0 <= nb_int );
+    nb = nb_int;
+    TEST_ASSERT( nb <= BMAX );
+
+    TEST_ASSERT( 0 <= nx_64_int );
+    nx = nx_64_int;
+    /* nx_64_int is the number of 64 bit limbs, if we have 32 bit limbs we need
+     * to double the number of limbs to have the same size. */
+    if( sizeof( mbedtls_mpi_uint ) == 4 )
+        nx *= 2;
+    TEST_ASSERT( nx <= XMAX );
+
+    /* For an INVALID iendian, set up as little endian; the invalid
+     * representation is injected later, directly into m.ext_rep. */
+    if( iendian == MBEDTLS_MPI_MOD_EXT_REP_INVALID )
+        endian = MBEDTLS_MPI_MOD_EXT_REP_LE;
+    else
+        endian = iendian;
+
+    /* The modulus is the first nx limbs of init, every byte 0xFF. */
+    memset( init, 0xFF, sizeof( init ) );
+
+    ret = mbedtls_mpi_mod_modulus_setup( &m, init, nx, endian,
+                                         MBEDTLS_MPI_MOD_REP_MONTGOMERY );
+    TEST_ASSERT( ret == 0 );
+
+    if( iendian == MBEDTLS_MPI_MOD_EXT_REP_INVALID && iret != 0 )
+        m.ext_rep = MBEDTLS_MPI_MOD_EXT_REP_INVALID;
+
+    ret = mbedtls_mpi_mod_raw_read( X, &m, input->x, input->len );
+    TEST_ASSERT( ret == iret );
+
+    if( iret == 0 )
+    {
+        if( iendian == MBEDTLS_MPI_MOD_EXT_REP_INVALID && oret != 0 )
+            m.ext_rep = MBEDTLS_MPI_MOD_EXT_REP_INVALID;
+
+        ret = mbedtls_mpi_mod_raw_write( X, &m, buf, nb );
+        TEST_ASSERT( ret == oret );
+    }
+
+    if( ( iret == 0 ) && ( oret == 0 ) )
+    {
+        if( nb > input->len )
+        {
+            if( endian == MBEDTLS_MPI_MOD_EXT_REP_BE )
+            {
+                size_t leading_zeroes = nb - input->len;
+                TEST_ASSERT( memcmp( buf + nb - input->len, input->x, input->len ) == 0 );
+                for( size_t i = 0; i < leading_zeroes; i++ )
+                    TEST_ASSERT( buf[i] == 0 );
+            }
+            else
+            {
+                TEST_ASSERT( memcmp( buf, input->x, input->len ) == 0 );
+                for( size_t i = input->len; i < nb; i++ )
+                    TEST_ASSERT( buf[i] == 0 );
+            }
+        }
+        else
+        {
+            if( endian == MBEDTLS_MPI_MOD_EXT_REP_BE )
+            {
+                size_t leading_zeroes = input->len - nb;
+                TEST_ASSERT( memcmp( input->x + input->len - nb, buf, nb ) == 0 );
+                for( size_t i = 0; i < leading_zeroes; i++ )
+                    TEST_ASSERT( input->x[i] == 0 );
+            }
+            else
+            {
+                TEST_ASSERT( memcmp( input->x, buf, nb ) == 0 );
+                for( size_t i = nb; i < input->len; i++ )
+                    TEST_ASSERT( input->x[i] == 0 );
+            }
+        }
+    }
+
+exit:
+    mbedtls_mpi_mod_modulus_free( &m );
+
+    #undef BMAX
+    #undef XMAX
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
 void mbedtls_mpi_read_binary_le( data_t * buf, char * input_A )
 {
     mbedtls_mpi X;