Merge pull request #9281 from mpg/rsapub

[3.6] Reduce performance regression in RSA public operations
diff --git a/ChangeLog.d/fix-rsa-performance-regression.txt b/ChangeLog.d/fix-rsa-performance-regression.txt
new file mode 100644
index 0000000..603612a
--- /dev/null
+++ b/ChangeLog.d/fix-rsa-performance-regression.txt
@@ -0,0 +1,3 @@
+Bugfix
+   * Fix unintended performance regression when using short RSA public keys.
+     Fixes #9232.
diff --git a/include/mbedtls/bignum.h b/include/mbedtls/bignum.h
index 71d7b97..8367cd3 100644
--- a/include/mbedtls/bignum.h
+++ b/include/mbedtls/bignum.h
@@ -880,7 +880,7 @@
                         mbedtls_mpi_sint b);
 
 /**
- * \brief          Perform a sliding-window exponentiation: X = A^E mod N
+ * \brief          Perform a modular exponentiation: X = A^E mod N
  *
  * \param X        The destination MPI. This must point to an initialized MPI.
  *                 This must not alias E or N.
diff --git a/library/bignum.c b/library/bignum.c
index c45fd5b..4244909 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -27,6 +27,7 @@
 
 #include "mbedtls/bignum.h"
 #include "bignum_core.h"
+#include "bignum_internal.h"
 #include "bn_mul.h"
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
@@ -1610,9 +1611,13 @@
     return 0;
 }
 
-int mbedtls_mpi_exp_mod(mbedtls_mpi *X, const mbedtls_mpi *A,
-                        const mbedtls_mpi *E, const mbedtls_mpi *N,
-                        mbedtls_mpi *prec_RR)
+/*
+ * Warning! If the parameter E_public has MBEDTLS_MPI_IS_PUBLIC as its value,
+ * this function is not constant time with respect to the exponent (parameter E).
+ */
+static int mbedtls_mpi_exp_mod_optionally_safe(mbedtls_mpi *X, const mbedtls_mpi *A,
+                                               const mbedtls_mpi *E, int E_public,
+                                               const mbedtls_mpi *N, mbedtls_mpi *prec_RR)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
@@ -1695,7 +1700,11 @@
     {
         mbedtls_mpi_uint mm = mbedtls_mpi_core_montmul_init(N->p);
         mbedtls_mpi_core_to_mont_rep(X->p, X->p, N->p, N->n, mm, RR.p, T);
-        mbedtls_mpi_core_exp_mod(X->p, X->p, N->p, N->n, E->p, E->n, RR.p, T);
+        if (E_public == MBEDTLS_MPI_IS_PUBLIC) {
+            mbedtls_mpi_core_exp_mod_unsafe(X->p, X->p, N->p, N->n, E->p, E->n, RR.p, T);
+        } else {
+            mbedtls_mpi_core_exp_mod(X->p, X->p, N->p, N->n, E->p, E->n, RR.p, T);
+        }
         mbedtls_mpi_core_from_mont_rep(X->p, X->p, N->p, N->n, mm, T);
     }
 
@@ -1720,6 +1729,20 @@
     return ret;
 }
 
+int mbedtls_mpi_exp_mod(mbedtls_mpi *X, const mbedtls_mpi *A,
+                        const mbedtls_mpi *E, const mbedtls_mpi *N,
+                        mbedtls_mpi *prec_RR)
+{   /* E is passed as secret, which selects the constant-time exponentiation. */
+    return mbedtls_mpi_exp_mod_optionally_safe(X, A, E, MBEDTLS_MPI_IS_SECRET, N, prec_RR);
+}
+
+int mbedtls_mpi_exp_mod_unsafe(mbedtls_mpi *X, const mbedtls_mpi *A,
+                               const mbedtls_mpi *E, const mbedtls_mpi *N,
+                               mbedtls_mpi *prec_RR)
+{   /* E is passed as public: the computation is not constant time with respect to E. */
+    return mbedtls_mpi_exp_mod_optionally_safe(X, A, E, MBEDTLS_MPI_IS_PUBLIC, N, prec_RR);
+}
+
 /*
  * Greatest common divisor: G = gcd(A, B)  (HAC 14.54)
  */
diff --git a/library/bignum_core.c b/library/bignum_core.c
index 1a3e0b9..4231554 100644
--- a/library/bignum_core.c
+++ b/library/bignum_core.c
@@ -746,8 +746,94 @@
     }
 }
 
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+// Test-only tracker, initially neither MBEDTLS_MPI_IS_PUBLIC nor MBEDTLS_MPI_IS_SECRET
+int mbedtls_mpi_optionally_safe_codepath = MBEDTLS_MPI_IS_PUBLIC + MBEDTLS_MPI_IS_SECRET + 1;
+#endif
+
+/*
+ * Compute the position in the exponent E (E_limbs limbs) at which the exponentiation loop starts
+ * processing, and store it in *E_limb_index and *E_bit_index.
+ *
+ * Warning! If the parameter E_public has MBEDTLS_MPI_IS_PUBLIC as its value,
+ * this function is not constant time with respect to the exponent (parameter E).
+ */
+static inline void exp_mod_calc_first_bit_optionally_safe(const mbedtls_mpi_uint *E,
+                                                          size_t E_limbs,
+                                                          int E_public,
+                                                          size_t *E_limb_index,
+                                                          size_t *E_bit_index)
+{
+    if (E_public == MBEDTLS_MPI_IS_PUBLIC) {
+        /*
+         * E is public: skip its leading zero bits and start at its bit length.
+         */
+        size_t E_bits = mbedtls_mpi_core_bitlen(E, E_limbs);
+        if (E_bits == 0) {
+            /*
+             * If E is 0, mbedtls_mpi_core_bitlen() returns 0. Even then, we want to represent E
+             * as a single 0 bit, so we use a bit length of 1.
+             */
+            E_bits = 1;
+        }
+
+        *E_limb_index = E_bits / biL;
+        *E_bit_index = E_bits % biL;
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+        mbedtls_mpi_optionally_safe_codepath = MBEDTLS_MPI_IS_PUBLIC;
+#endif
+    } else {
+        /*
+         * E is secret: we need to be constant time with respect to E, so we can't do better
+         * than start from the full allocated size of E (E_limbs limbs).
+         */
+        *E_limb_index = E_limbs;
+        *E_bit_index = 0;
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+        // Test hook: a recorded unsafe (public) codepath is never overwritten by a safe one
+        if (mbedtls_mpi_optionally_safe_codepath != MBEDTLS_MPI_IS_PUBLIC) {
+            mbedtls_mpi_optionally_safe_codepath = MBEDTLS_MPI_IS_SECRET;
+        }
+#endif
+    }
+}
+
+/*
+ * Copy Wtable[window] (an entry of AN_limbs limbs) into Wselect.
+ * Warning! If window_public is MBEDTLS_MPI_IS_PUBLIC, this function is not constant time with
+ * respect to window, and hence to E in mbedtls_mpi_core_exp_mod_optionally_safe().
+ */
+static inline void exp_mod_table_lookup_optionally_safe(mbedtls_mpi_uint *Wselect,
+                                                        mbedtls_mpi_uint *Wtable,
+                                                        size_t AN_limbs, size_t welem,
+                                                        mbedtls_mpi_uint window,
+                                                        int window_public)
+{
+    if (window_public == MBEDTLS_MPI_IS_PUBLIC) {
+        memcpy(Wselect, Wtable + window * AN_limbs, AN_limbs * ciL); /* plain indexed copy */
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+        mbedtls_mpi_optionally_safe_codepath = MBEDTLS_MPI_IS_PUBLIC;
+#endif
+    } else {
+        /* Select Wtable[window] without leaking window through
+         * memory access patterns. */
+        mbedtls_mpi_core_ct_uint_table_lookup(Wselect, Wtable,
+                                              AN_limbs, welem, window);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+        // Test hook: a recorded unsafe (public) codepath is never overwritten by a safe one
+        if (mbedtls_mpi_optionally_safe_codepath != MBEDTLS_MPI_IS_PUBLIC) {
+            mbedtls_mpi_optionally_safe_codepath = MBEDTLS_MPI_IS_SECRET;
+        }
+#endif
+    }
+}
+
 /* Exponentiation: X := A^E mod N.
  *
+ * Warning! If the parameter E_public has MBEDTLS_MPI_IS_PUBLIC as its value,
+ * this function is not constant time with respect to the exponent (parameter E).
+ *
  * A must already be in Montgomery form.
  *
  * As in other bignum functions, assume that AN_limbs and E_limbs are nonzero.
@@ -758,16 +844,25 @@
  * (The difference is that the body in our loop processes a single bit instead
  * of a full window.)
  */
-void mbedtls_mpi_core_exp_mod(mbedtls_mpi_uint *X,
-                              const mbedtls_mpi_uint *A,
-                              const mbedtls_mpi_uint *N,
-                              size_t AN_limbs,
-                              const mbedtls_mpi_uint *E,
-                              size_t E_limbs,
-                              const mbedtls_mpi_uint *RR,
-                              mbedtls_mpi_uint *T)
+static void mbedtls_mpi_core_exp_mod_optionally_safe(mbedtls_mpi_uint *X,
+                                                     const mbedtls_mpi_uint *A,
+                                                     const mbedtls_mpi_uint *N,
+                                                     size_t AN_limbs,
+                                                     const mbedtls_mpi_uint *E,
+                                                     size_t E_limbs,
+                                                     int E_public,
+                                                     const mbedtls_mpi_uint *RR,
+                                                     mbedtls_mpi_uint *T)
 {
-    const size_t wsize = exp_mod_get_window_size(E_limbs * biL);
+    /* We'll process the bits of E from most significant
+     * (E_limb_index=E_limbs-1, E_bit_index=biL-1) to least significant
+     * (E_limb_index=0, E_bit_index=0). */
+    size_t E_limb_index;
+    size_t E_bit_index;
+    exp_mod_calc_first_bit_optionally_safe(E, E_limbs, E_public,
+                                           &E_limb_index, &E_bit_index);
+
+    const size_t wsize = exp_mod_get_window_size(E_limb_index * biL);
     const size_t welem = ((size_t) 1) << wsize;
 
     /* This is how we will use the temporary storage T, which must have space
@@ -786,7 +881,7 @@
 
     const mbedtls_mpi_uint mm = mbedtls_mpi_core_montmul_init(N);
 
-    /* Set Wtable[i] = A^(2^i) (in Montgomery representation) */
+    /* Set Wtable[i] = A^i (in Montgomery representation) */
     exp_mod_precompute_window(A, N, AN_limbs,
                               mm, RR,
                               welem, Wtable, temp);
@@ -798,11 +893,6 @@
     /* X = 1 (in Montgomery presentation) initially */
     memcpy(X, Wtable, AN_limbs * ciL);
 
-    /* We'll process the bits of E from most significant
-     * (limb_index=E_limbs-1, E_bit_index=biL-1) to least significant
-     * (limb_index=0, E_bit_index=0). */
-    size_t E_limb_index = E_limbs;
-    size_t E_bit_index = 0;
     /* At any given time, window contains window_bits bits from E.
      * window_bits can go up to wsize. */
     size_t window_bits = 0;
@@ -828,10 +918,9 @@
          * when we've finished processing the exponent. */
         if (window_bits == wsize ||
             (E_bit_index == 0 && E_limb_index == 0)) {
-            /* Select Wtable[window] without leaking window through
-             * memory access patterns. */
-            mbedtls_mpi_core_ct_uint_table_lookup(Wselect, Wtable,
-                                                  AN_limbs, welem, window);
+
+            exp_mod_table_lookup_optionally_safe(Wselect, Wtable, AN_limbs, welem,
+                                                 window, E_public);
             /* Multiply X by the selected element. */
             mbedtls_mpi_core_montmul(X, X, Wselect, AN_limbs, N, AN_limbs, mm,
                                      temp);
@@ -841,6 +930,42 @@
     } while (!(E_bit_index == 0 && E_limb_index == 0));
 }
 
+void mbedtls_mpi_core_exp_mod(mbedtls_mpi_uint *X,
+                              const mbedtls_mpi_uint *A,
+                              const mbedtls_mpi_uint *N, size_t AN_limbs,
+                              const mbedtls_mpi_uint *E, size_t E_limbs,
+                              const mbedtls_mpi_uint *RR,
+                              mbedtls_mpi_uint *T)
+{   /* Always takes the constant-time path: E is flagged as secret. */
+    mbedtls_mpi_core_exp_mod_optionally_safe(X,
+                                             A,
+                                             N,
+                                             AN_limbs,
+                                             E,
+                                             E_limbs,
+                                             MBEDTLS_MPI_IS_SECRET,
+                                             RR,
+                                             T);
+}
+
+void mbedtls_mpi_core_exp_mod_unsafe(mbedtls_mpi_uint *X,
+                                     const mbedtls_mpi_uint *A,
+                                     const mbedtls_mpi_uint *N, size_t AN_limbs,
+                                     const mbedtls_mpi_uint *E, size_t E_limbs,
+                                     const mbedtls_mpi_uint *RR,
+                                     mbedtls_mpi_uint *T)
+{   /* Always takes the variable-time path: E is flagged as public. */
+    mbedtls_mpi_core_exp_mod_optionally_safe(X,
+                                             A,
+                                             N,
+                                             AN_limbs,
+                                             E,
+                                             E_limbs,
+                                             MBEDTLS_MPI_IS_PUBLIC,
+                                             RR,
+                                             T);
+}
+
 mbedtls_mpi_uint mbedtls_mpi_core_sub_int(mbedtls_mpi_uint *X,
                                           const mbedtls_mpi_uint *A,
                                           mbedtls_mpi_uint c,  /* doubles as carry */
diff --git a/library/bignum_core.h b/library/bignum_core.h
index 92c8d47..cf6485a 100644
--- a/library/bignum_core.h
+++ b/library/bignum_core.h
@@ -90,6 +90,27 @@
 #define GET_BYTE(X, i)                                \
     (((X)[(i) / ciL] >> (((i) % ciL) * 8)) & 0xff)
 
+/* Constants to identify whether a value is public or secret. If a parameter is marked as secret by
+ * this constant, the function must be constant time with respect to the parameter.
+ *
+ * This is only needed for functions with the _optionally_safe suffix. All other functions have
+ * fixed behavior that can't be changed at runtime and are constant time with respect to their
+ * parameters as prescribed by their documentation or by conventions in their module's documentation.
+ *
+ * Parameters should be named X_public where X is the name of the
+ * corresponding input parameter.
+ *
+ * Implementations should always check using
+ *  if (X_public == MBEDTLS_MPI_IS_PUBLIC) {
+ *      // unsafe path
+ *  } else {
+ *      // safe path
+ *  }
+ * not the other way round, in order to prevent misuse. (That is, if a value
+ * other than the two below is passed, default to the safe path.) */
+#define MBEDTLS_MPI_IS_PUBLIC  0x2a2a2a2a
+#define MBEDTLS_MPI_IS_SECRET  0
+
 /** Count leading zero bits in a given integer.
  *
  * \warning     The result is undefined if \p a == 0
@@ -605,6 +626,42 @@
 size_t mbedtls_mpi_core_exp_mod_working_limbs(size_t AN_limbs, size_t E_limbs);
 
 /**
+ * \brief            Perform a modular exponentiation with public exponent:
+ *                   X = A^E mod N, where \p A is already in Montgomery form.
+ *
+ * \warning          This function is not constant time with respect to \p E (the exponent).
+ *
+ * \p X may be aliased to \p A, but not to \p RR or \p E, even if \p E_limbs ==
+ * \p AN_limbs.
+ *
+ * \param[out] X     The destination MPI, as a little endian array of length
+ *                   \p AN_limbs.
+ * \param[in] A      The base MPI, as a little endian array of length \p AN_limbs.
+ *                   Must be in Montgomery form.
+ * \param[in] N      The modulus, as a little endian array of length \p AN_limbs.
+ * \param AN_limbs   The number of limbs in \p X, \p A, \p N, \p RR.
+ * \param[in] E      The exponent, as a little endian array of length \p E_limbs.
+ * \param E_limbs    The number of limbs in \p E.
+ * \param[in] RR     The precomputed residue of 2^{2*biL} modulo N, as a little
+ *                   endian array of length \p AN_limbs.
+ * \param[in,out] T  Temporary storage of at least the number of limbs returned
+ *                   by `mbedtls_mpi_core_exp_mod_working_limbs()`.
+ *                   Its initial content is unused and its final content is
+ *                   indeterminate.
+ *                   It must not alias or otherwise overlap any of the other
+ *                   parameters.
+ *                   It is up to the caller to zeroize \p T when it is no
+ *                   longer needed, and before freeing it if it was dynamically
+ *                   allocated.
+ */
+void mbedtls_mpi_core_exp_mod_unsafe(mbedtls_mpi_uint *X,
+                                     const mbedtls_mpi_uint *A,
+                                     const mbedtls_mpi_uint *N, size_t AN_limbs,
+                                     const mbedtls_mpi_uint *E, size_t E_limbs,
+                                     const mbedtls_mpi_uint *RR,
+                                     mbedtls_mpi_uint *T);
+
+/**
  * \brief            Perform a modular exponentiation with secret exponent:
  *                   X = A^E mod N, where \p A is already in Montgomery form.
  *
@@ -760,4 +817,17 @@
                                     mbedtls_mpi_uint mm,
                                     mbedtls_mpi_uint *T);
 
+/*
+ * Can't define thread-local variables with our abstraction layer: do nothing if threading is on.
+ */
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+extern int mbedtls_mpi_optionally_safe_codepath;
+
+static inline void mbedtls_mpi_optionally_safe_codepath_reset(void)
+{
+    // Restore the initial value: neither MBEDTLS_MPI_IS_PUBLIC nor MBEDTLS_MPI_IS_SECRET
+    mbedtls_mpi_optionally_safe_codepath = MBEDTLS_MPI_IS_PUBLIC + MBEDTLS_MPI_IS_SECRET + 1;
+}
+#endif
+
 #endif /* MBEDTLS_BIGNUM_CORE_H */
diff --git a/library/bignum_internal.h b/library/bignum_internal.h
new file mode 100644
index 0000000..aceaf55
--- /dev/null
+++ b/library/bignum_internal.h
@@ -0,0 +1,50 @@
+/**
+ * \file bignum_internal.h
+ *
+ * \brief Internal-only bignum API.
+ *
+ * This file declares bignum-related functions that are to be used
+ * only from within the Mbed TLS library itself.
+ *
+ */
+/*
+ *  Copyright The Mbed TLS Contributors
+ *  SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+ */
+#ifndef MBEDTLS_BIGNUM_INTERNAL_H
+#define MBEDTLS_BIGNUM_INTERNAL_H
+
+/**
+ * \brief          Perform a modular exponentiation with a public exponent: X = A^E mod N
+ *
+ * \warning        This function is not constant time with respect to \p E (the exponent).
+ *
+ * \param X        The destination MPI. This must point to an initialized MPI.
+ *                 This must not alias E or N.
+ * \param A        The base of the exponentiation.
+ *                 This must point to an initialized MPI.
+ * \param E        The exponent MPI. This must point to an initialized MPI.
+ * \param N        The base for the modular reduction. This must point to an
+ *                 initialized MPI.
+ * \param prec_RR  A helper MPI depending solely on \p N which can be used to
+ *                 speed-up multiple modular exponentiations for the same value
+ *                 of \p N. This may be \c NULL. If it is not \c NULL, it must
+ *                 point to an initialized MPI. If it hasn't been used after
+ *                 the call to mbedtls_mpi_init(), this function will compute
+ *                 the helper value and store it in \p prec_RR for reuse on
+ *                 subsequent calls to this function. Otherwise, the function
+ *                 will assume that \p prec_RR holds the helper value set by a
+ *                 previous call to this function or mbedtls_mpi_exp_mod(), and reuse it.
+ *
+ * \return         \c 0 if successful.
+ * \return         #MBEDTLS_ERR_MPI_ALLOC_FAILED if a memory allocation failed.
+ * \return         #MBEDTLS_ERR_MPI_BAD_INPUT_DATA if \c N is negative or
+ *                 even, or if \c E is negative.
+ * \return         Another negative error code on different kinds of failures.
+ *
+ */
+int mbedtls_mpi_exp_mod_unsafe(mbedtls_mpi *X, const mbedtls_mpi *A,
+                               const mbedtls_mpi *E, const mbedtls_mpi *N,
+                               mbedtls_mpi *prec_RR);
+
+#endif /* MBEDTLS_BIGNUM_INTERNAL_H */
diff --git a/library/rsa.c b/library/rsa.c
index 7eb4a25..557faaf 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -29,6 +29,7 @@
 
 #include "mbedtls/rsa.h"
 #include "bignum_core.h"
+#include "bignum_internal.h"
 #include "rsa_alt_helpers.h"
 #include "rsa_internal.h"
 #include "mbedtls/oid.h"
@@ -1259,7 +1260,7 @@
     }
 
     olen = ctx->len;
-    MBEDTLS_MPI_CHK(mbedtls_mpi_exp_mod(&T, &T, &ctx->E, &ctx->N, &ctx->RN));
+    MBEDTLS_MPI_CHK(mbedtls_mpi_exp_mod_unsafe(&T, &T, &ctx->E, &ctx->N, &ctx->RN));
     MBEDTLS_MPI_CHK(mbedtls_mpi_write_binary(&T, output, olen));
 
 cleanup:
diff --git a/tests/suites/test_suite_bignum.function b/tests/suites/test_suite_bignum.function
index 1830e5a..3ac4e10 100644
--- a/tests/suites/test_suite_bignum.function
+++ b/tests/suites/test_suite_bignum.function
@@ -3,6 +3,7 @@
 #include "mbedtls/entropy.h"
 #include "constant_time_internal.h"
 #include "bignum_core.h"
+#include "bignum_internal.h"
 #include "test/constant_flow.h"
 
 #if MBEDTLS_MPI_MAX_BITS > 792
diff --git a/tests/suites/test_suite_bignum_core.function b/tests/suites/test_suite_bignum_core.function
index db84d62..08dac2e 100644
--- a/tests/suites/test_suite_bignum_core.function
+++ b/tests/suites/test_suite_bignum_core.function
@@ -1178,6 +1178,7 @@
                       char *input_E, char *input_X)
 {
     mbedtls_mpi_uint *A = NULL;
+    mbedtls_mpi_uint *A_copy = NULL;
     mbedtls_mpi_uint *E = NULL;
     mbedtls_mpi_uint *N = NULL;
     mbedtls_mpi_uint *X = NULL;
@@ -1229,19 +1230,56 @@
 
     TEST_CALLOC(T, working_limbs);
 
-    mbedtls_mpi_core_exp_mod(Y, A, N, N_limbs, E, E_limbs, R2, T);
+    /* Test the safe variant */
 
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+    mbedtls_mpi_optionally_safe_codepath_reset();
+#endif
+    mbedtls_mpi_core_exp_mod(Y, A, N, N_limbs, E, E_limbs, R2, T);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+    TEST_EQUAL(mbedtls_mpi_optionally_safe_codepath, MBEDTLS_MPI_IS_SECRET);
+#endif
     TEST_EQUAL(0, memcmp(X, Y, N_limbs * sizeof(mbedtls_mpi_uint)));
 
-    /* Check when output aliased to input */
+    /* Test the unsafe variant */
 
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+    mbedtls_mpi_optionally_safe_codepath_reset();
+#endif
+    mbedtls_mpi_core_exp_mod_unsafe(Y, A, N, N_limbs, E, E_limbs, R2, T);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+    TEST_EQUAL(mbedtls_mpi_optionally_safe_codepath, MBEDTLS_MPI_IS_PUBLIC);
+#endif
+    TEST_EQUAL(0, memcmp(X, Y, N_limbs * sizeof(mbedtls_mpi_uint)));
+
+    /* Check both with output aliased to input */
+
+    TEST_CALLOC(A_copy, A_limbs);
+    memcpy(A_copy, A, sizeof(*A_copy) * A_limbs);
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+    mbedtls_mpi_optionally_safe_codepath_reset();
+#endif
     mbedtls_mpi_core_exp_mod(A, A, N, N_limbs, E, E_limbs, R2, T);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+    TEST_EQUAL(mbedtls_mpi_optionally_safe_codepath, MBEDTLS_MPI_IS_SECRET);
+#endif
+    TEST_EQUAL(0, memcmp(X, A, N_limbs * sizeof(mbedtls_mpi_uint)));
 
+    memcpy(A, A_copy, sizeof(*A) * A_limbs);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+    mbedtls_mpi_optionally_safe_codepath_reset();
+#endif
+    mbedtls_mpi_core_exp_mod_unsafe(A, A, N, N_limbs, E, E_limbs, R2, T);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+    TEST_EQUAL(mbedtls_mpi_optionally_safe_codepath, MBEDTLS_MPI_IS_PUBLIC);
+#endif
     TEST_EQUAL(0, memcmp(X, A, N_limbs * sizeof(mbedtls_mpi_uint)));
 
 exit:
     mbedtls_free(T);
     mbedtls_free(A);
+    mbedtls_free(A_copy);
     mbedtls_free(E);
     mbedtls_free(N);
     mbedtls_free(X);