Prepare codepath tests for early termination

Signed-off-by: Janos Follath <janos.follath@arm.com>
diff --git a/tests/include/test/bignum_codepath_check.h b/tests/include/test/bignum_codepath_check.h
index 6ab68bb..34dfc56 100644
--- a/tests/include/test/bignum_codepath_check.h
+++ b/tests/include/test/bignum_codepath_check.h
@@ -43,6 +43,49 @@
     mbedtls_codepath_check = MBEDTLS_MPI_IS_TEST;
 }
 
+/** Check the codepath taken and fail if it doesn't match.
+ *
+ * A function may return before reaching any interesting codepath, either because it fails or
+ * because the MPI \p E is empty (`(E).n == 0`). In these cases the codepath tracking variable
+ * is also allowed to still hold its reset value #MBEDTLS_MPI_IS_TEST.
+ *
+ * This macro expands to an instruction, not an expression. It may jump to the \c exit label.
+ *
+ * \param path      The expected codepath.
+ *                  This expression may be evaluated multiple times.
+ * \param ret       The expected return value.
+ * \param E         The MPI parameter that can cause shortcuts.
+ */
+#define ASSERT_BIGNUM_CODEPATH(path, ret, E)                            \
+    do {                                                                \
+        if ((ret) != 0 || (E).n == 0) {                                 \
+            TEST_ASSERT(mbedtls_codepath_check == (path) ||             \
+                        mbedtls_codepath_check == MBEDTLS_MPI_IS_TEST); \
+        } else {                                                        \
+            TEST_EQUAL(mbedtls_codepath_check, (path));                 \
+        }                                                               \
+    } while (0)
+
+/** Check the codepath taken and fail if it doesn't match.
+ *
+ * When a function fails, it may return before reaching any interesting codepath. In that case
+ * the codepath tracking variable is also allowed to still hold its reset value #MBEDTLS_MPI_IS_TEST.
+ *
+ * This macro expands to an instruction, not an expression. It may jump to the \c exit label.
+ *
+ * \param path      The expected codepath.
+ *                  This expression may be evaluated multiple times.
+ * \param ret       The expected return value.
+ */
+#define ASSERT_RSA_CODEPATH(path, ret)                                  \
+    do {                                                                \
+        if ((ret) != 0) {                                               \
+            TEST_ASSERT(mbedtls_codepath_check == (path) ||             \
+                        mbedtls_codepath_check == MBEDTLS_MPI_IS_TEST); \
+        } else {                                                        \
+            TEST_EQUAL(mbedtls_codepath_check, (path));                 \
+        }                                                               \
+    } while (0)
 #endif /* MBEDTLS_TEST_HOOKS && !MBEDTLS_THREADING_C */
 
 #endif /* BIGNUM_CODEPATH_CHECK_H */
diff --git a/tests/suites/test_suite_bignum.function b/tests/suites/test_suite_bignum.function
index 1102e18..3d2b8a1 100644
--- a/tests/suites/test_suite_bignum.function
+++ b/tests/suites/test_suite_bignum.function
@@ -995,7 +995,7 @@
 #endif
     res = mbedtls_mpi_exp_mod(&Z, &A, &E, &N, &RR);
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
-    TEST_EQUAL(mbedtls_codepath_check, MBEDTLS_MPI_IS_SECRET);
+    ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, res, E);
 #endif
     /* We know that exp_mod internally needs RR to be as large as N.
      * Validate that it is the case now, otherwise there was probably
@@ -1034,7 +1034,7 @@
 #endif
     res = mbedtls_mpi_exp_mod(&Z, &A, &E, &N, NULL);
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
-    TEST_EQUAL(mbedtls_codepath_check, MBEDTLS_MPI_IS_SECRET);
+    ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, res, E);
 #endif
     TEST_ASSERT(res == exp_result);
     if (res == 0) {
@@ -1047,7 +1047,7 @@
 #endif
     res = mbedtls_mpi_exp_mod_unsafe(&Z, &A, &E, &N, NULL);
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
-    TEST_EQUAL(mbedtls_codepath_check, MBEDTLS_MPI_IS_PUBLIC);
+    ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_PUBLIC, res, E);
 #endif
     TEST_ASSERT(res == exp_result);
     if (res == 0) {
@@ -1061,7 +1061,7 @@
 #endif
     res = mbedtls_mpi_exp_mod(&Z, &A, &E, &N, &RR);
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
-    TEST_EQUAL(mbedtls_codepath_check, MBEDTLS_MPI_IS_SECRET);
+    ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, res, E);
 #endif
     TEST_ASSERT(res == exp_result);
     if (res == 0) {
@@ -1075,7 +1075,7 @@
 #endif
     res = mbedtls_mpi_exp_mod(&Z, &A, &E, &N, &RR);
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
-    TEST_EQUAL(mbedtls_codepath_check, MBEDTLS_MPI_IS_SECRET);
+    ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, res, E);
 #endif
     TEST_ASSERT(res == exp_result);
     if (res == 0) {
@@ -1121,7 +1121,7 @@
 #endif
     TEST_ASSERT(mbedtls_mpi_exp_mod(&Z, &A, &E, &N, &RR) == exp_result);
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
-    TEST_EQUAL(mbedtls_codepath_check, MBEDTLS_MPI_IS_SECRET);
+    ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, exp_result, E);
 #endif
 
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
@@ -1129,7 +1129,7 @@
 #endif
     TEST_ASSERT(mbedtls_mpi_exp_mod_unsafe(&Z, &A, &E, &N, &RR) == exp_result);
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
-    TEST_EQUAL(mbedtls_codepath_check, MBEDTLS_MPI_IS_PUBLIC);
+    ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_PUBLIC, exp_result, E);
 #endif
 
 exit:
diff --git a/tests/suites/test_suite_rsa.function b/tests/suites/test_suite_rsa.function
index 75f3f42..98ea9ef 100644
--- a/tests/suites/test_suite_rsa.function
+++ b/tests/suites/test_suite_rsa.function
@@ -496,7 +496,7 @@
 #endif
     TEST_ASSERT(mbedtls_rsa_public(&ctx, message_str->x, output) == result);
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
-    TEST_EQUAL(mbedtls_codepath_check, MBEDTLS_MPI_IS_PUBLIC);
+    ASSERT_RSA_CODEPATH(MBEDTLS_MPI_IS_PUBLIC, result);
 #endif
     if (result == 0) {
 
@@ -569,7 +569,7 @@
                                         &rnd_info, message_str->x,
                                         output) == result);
 #if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
-        TEST_EQUAL(mbedtls_codepath_check, MBEDTLS_MPI_IS_SECRET);
+        ASSERT_RSA_CODEPATH(MBEDTLS_MPI_IS_SECRET, result);
 #endif
         if (result == 0) {