Test that overly large Diffie-Hellman keys are rejected

Adds test cases to ensure that `mbedtls_mpi_exp_mod` will return an error with
an exponent or modulus that is greater than `MBEDTLS_MPI_MAX_SIZE` in size.

Adds test cases to ensure that Diffie-Hellman will fail to make a key pair
(using `mbedtls_dhm_make_public`) when the prime modulus is greater than
`MBEDTLS_MPI_MAX_SIZE` in size.

Signed-off-by: Chris Jones <christopher.jones@arm.com>
diff --git a/tests/suites/test_suite_mpi.function b/tests/suites/test_suite_mpi.function
index f2702f1..c61d8b3 100644
--- a/tests/suites/test_suite_mpi.function
+++ b/tests/suites/test_suite_mpi.function
@@ -1111,6 +1111,37 @@
 /* END_CASE */
 
 /* BEGIN_CASE */
+void mbedtls_mpi_exp_mod_size( int A_bytes, int E_bytes, int N_bytes,
+                               int radix_RR, char * input_RR, int div_result )
+{
+    mbedtls_mpi A, E, N, RR, Z;
+    mbedtls_mpi_init( &A ); mbedtls_mpi_init( &E ); mbedtls_mpi_init( &N );
+    mbedtls_mpi_init( &RR ); mbedtls_mpi_init( &Z );
+
+    /* Build each operand as 1 << ( bytes * 8 - 1 ), then also set bit 0. */
+    TEST_ASSERT( mbedtls_mpi_lset( &A, 1 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_shift_l( &A, ( A_bytes * 8 ) - 1 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_set_bit( &A, 0, 1 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_lset( &E, 1 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_shift_l( &E, ( E_bytes * 8 ) - 1 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_set_bit( &E, 0, 1 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_lset( &N, 1 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_shift_l( &N, ( N_bytes * 8 ) - 1 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_set_bit( &N, 0, 1 ) == 0 );
+
+    /* RR is parsed only when a non-empty string was supplied for it. */
+    if( input_RR[0] != '\0' )
+        TEST_ASSERT( mbedtls_mpi_read_string( &RR, radix_RR, input_RR ) == 0 );
+
+    TEST_ASSERT( mbedtls_mpi_exp_mod( &Z, &A, &E, &N, &RR ) == div_result );
+
+exit:
+    mbedtls_mpi_free( &Z ); mbedtls_mpi_free( &RR ); mbedtls_mpi_free( &N );
+    mbedtls_mpi_free( &E ); mbedtls_mpi_free( &A );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
 void mbedtls_mpi_inv_mod( int radix_X, char * input_X, int radix_Y,
                           char * input_Y, int radix_A, char * input_A,
                           int div_result )