Test mbedtls_mpi_safe_cond_{assign,swap} with the basic functions

Test mbedtls_mpi_safe_cond_assign() and mbedtls_mpi_safe_cond_swap()
in the same test functions as their "unsafe" counterparts
mbedtls_mpi_copy() and mbedtls_mpi_swap(). This way the existing test
cases cover both variants and we don't need to repeat their coverage.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/tests/suites/test_suite_mpi.function b/tests/suites/test_suite_mpi.function
index 17c485e..83df698 100644
--- a/tests/suites/test_suite_mpi.function
+++ b/tests/suites/test_suite_mpi.function
@@ -621,21 +621,38 @@
 /* BEGIN_CASE */
 void mbedtls_mpi_copy( char *src_hex, char *dst_hex )
 {
-    mbedtls_mpi src, dst;
+    mbedtls_mpi src, dst, ref; /* ref holds the dst_hex value, never modified */
     mbedtls_mpi_init( &src );
     mbedtls_mpi_init( &dst );
+    mbedtls_mpi_init( &ref );
 
     TEST_ASSERT( mbedtls_test_read_mpi( &src, 16, src_hex ) == 0 );
+    TEST_ASSERT( mbedtls_test_read_mpi( &ref, 16, dst_hex ) == 0 );
+
+    /* mbedtls_mpi_copy(): dst must end up equal to src */
     TEST_ASSERT( mbedtls_test_read_mpi( &dst, 16, dst_hex ) == 0 );
-
     TEST_ASSERT( mbedtls_mpi_copy( &dst, &src ) == 0 );
-
     TEST_ASSERT( sign_is_valid( &dst ) );
     TEST_ASSERT( mbedtls_mpi_cmp_mpi( &dst, &src ) == 0 );
 
+    /* mbedtls_mpi_safe_cond_assign(), flag 1: dst must equal src */
+    mbedtls_mpi_free( &dst );
+    TEST_ASSERT( mbedtls_test_read_mpi( &dst, 16, dst_hex ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_safe_cond_assign( &dst, &src, 1 ) == 0 );
+    TEST_ASSERT( sign_is_valid( &dst ) );
+    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &dst, &src ) == 0 );
+
+    /* mbedtls_mpi_safe_cond_assign(), flag 0: dst must still equal ref */
+    mbedtls_mpi_free( &dst );
+    TEST_ASSERT( mbedtls_test_read_mpi( &dst, 16, dst_hex ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_safe_cond_assign( &dst, &src, 0 ) == 0 );
+    TEST_ASSERT( sign_is_valid( &dst ) );
+    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &dst, &ref ) == 0 );
+
 exit:
     mbedtls_mpi_free( &src );
     mbedtls_mpi_free( &dst );
+    mbedtls_mpi_free( &ref );
 }
 /* END_CASE */
 
@@ -666,17 +683,40 @@
     mbedtls_mpi_init( &X ); mbedtls_mpi_init( &Y );
     mbedtls_mpi_init( &X0 ); mbedtls_mpi_init( &Y0 );
 
-    TEST_ASSERT( mbedtls_test_read_mpi( &X,  16, X_hex ) == 0 );
-    TEST_ASSERT( mbedtls_test_read_mpi( &Y,  16, Y_hex ) == 0 );
     TEST_ASSERT( mbedtls_test_read_mpi( &X0, 16, X_hex ) == 0 );
     TEST_ASSERT( mbedtls_test_read_mpi( &Y0, 16, Y_hex ) == 0 );
 
+    /* mbedtls_mpi_swap() */
+    TEST_ASSERT( mbedtls_test_read_mpi( &X,  16, X_hex ) == 0 );
+    TEST_ASSERT( mbedtls_test_read_mpi( &Y,  16, Y_hex ) == 0 );
     mbedtls_mpi_swap( &X, &Y );
     TEST_ASSERT( sign_is_valid( &X ) );
     TEST_ASSERT( sign_is_valid( &Y ) );
     TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &Y0 ) == 0 );
     TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &X0 ) == 0 );
 
+    /* mbedtls_mpi_safe_cond_swap(), swap done */
+    mbedtls_mpi_free( &X );
+    mbedtls_mpi_free( &Y );
+    TEST_ASSERT( mbedtls_test_read_mpi( &X,  16, X_hex ) == 0 );
+    TEST_ASSERT( mbedtls_test_read_mpi( &Y,  16, Y_hex ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_safe_cond_swap( &X, &Y, 1 ) == 0 );
+    TEST_ASSERT( sign_is_valid( &X ) );
+    TEST_ASSERT( sign_is_valid( &Y ) );
+    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &Y0 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &X0 ) == 0 );
+
+    /* mbedtls_mpi_safe_cond_swap(), swap not done */
+    mbedtls_mpi_free( &X );
+    mbedtls_mpi_free( &Y );
+    TEST_ASSERT( mbedtls_test_read_mpi( &X,  16, X_hex ) == 0 );
+    TEST_ASSERT( mbedtls_test_read_mpi( &Y,  16, Y_hex ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_safe_cond_swap( &X, &Y, 0 ) == 0 );
+    TEST_ASSERT( sign_is_valid( &X ) );
+    TEST_ASSERT( sign_is_valid( &Y ) );
+    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &X0 ) == 0 );
+    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &Y0 ) == 0 );
+
 exit:
     mbedtls_mpi_free( &X ); mbedtls_mpi_free( &Y );
     mbedtls_mpi_free( &X0 ); mbedtls_mpi_free( &Y0 );
@@ -719,67 +759,6 @@
 /* END_CASE */
 
 /* BEGIN_CASE */
-void mbedtls_mpi_safe_cond_assign( int x_sign, char * x_str, int y_sign,
-                                   char * y_str )
-{
-    mbedtls_mpi X, Y, XX;
-    mbedtls_mpi_init( &X ); mbedtls_mpi_init( &Y ); mbedtls_mpi_init( &XX );
-
-    TEST_ASSERT( mbedtls_test_read_mpi( &X, 16, x_str ) == 0 );
-    X.s = x_sign;
-    TEST_ASSERT( mbedtls_test_read_mpi( &Y, 16, y_str ) == 0 );
-    Y.s = y_sign;
-    TEST_ASSERT( mbedtls_mpi_copy( &XX, &X ) == 0 );
-
-    TEST_ASSERT( mbedtls_mpi_safe_cond_assign( &X, &Y, 0 ) == 0 );
-    TEST_ASSERT( sign_is_valid( &X ) );
-    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &XX ) == 0 );
-
-    TEST_ASSERT( mbedtls_mpi_safe_cond_assign( &X, &Y, 1 ) == 0 );
-    TEST_ASSERT( sign_is_valid( &X ) );
-    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &Y ) == 0 );
-
-exit:
-    mbedtls_mpi_free( &X ); mbedtls_mpi_free( &Y ); mbedtls_mpi_free( &XX );
-}
-/* END_CASE */
-
-/* BEGIN_CASE */
-void mbedtls_mpi_safe_cond_swap( int x_sign, char * x_str, int y_sign,
-                                 char * y_str )
-{
-    mbedtls_mpi X, Y, XX, YY;
-
-    mbedtls_mpi_init( &X ); mbedtls_mpi_init( &Y );
-    mbedtls_mpi_init( &XX ); mbedtls_mpi_init( &YY );
-
-    TEST_ASSERT( mbedtls_test_read_mpi( &X, 16, x_str ) == 0 );
-    X.s = x_sign;
-    TEST_ASSERT( mbedtls_test_read_mpi( &Y, 16, y_str ) == 0 );
-    Y.s = y_sign;
-
-    TEST_ASSERT( mbedtls_mpi_copy( &XX, &X ) == 0 );
-    TEST_ASSERT( mbedtls_mpi_copy( &YY, &Y ) == 0 );
-
-    TEST_ASSERT( mbedtls_mpi_safe_cond_swap( &X, &Y, 0 ) == 0 );
-    TEST_ASSERT( sign_is_valid( &X ) );
-    TEST_ASSERT( sign_is_valid( &Y ) );
-    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &XX ) == 0 );
-    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &YY ) == 0 );
-
-    TEST_ASSERT( mbedtls_mpi_safe_cond_swap( &X, &Y, 1 ) == 0 );
-    TEST_ASSERT( sign_is_valid( &X ) );
-    TEST_ASSERT( sign_is_valid( &Y ) );
-    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &Y, &XX ) == 0 );
-    TEST_ASSERT( mbedtls_mpi_cmp_mpi( &X, &YY ) == 0 );
-
-exit:
-    mbedtls_mpi_free( &X ); mbedtls_mpi_free( &Y );
-    mbedtls_mpi_free( &XX ); mbedtls_mpi_free( &YY );
-}
-/* END_CASE */
-
-/* BEGIN_CASE */
 void mbedtls_mpi_add_mpi( int radix_X, char * input_X, int radix_Y,
                           char * input_Y, int radix_A, char * input_A )
 {