Merge pull request #9536 from mpg/rsapub-perf-dev
[dev] Rsapub performance fix
diff --git a/ChangeLog.d/fix-rsa-performance-regression.txt b/ChangeLog.d/fix-rsa-performance-regression.txt
new file mode 100644
index 0000000..603612a
--- /dev/null
+++ b/ChangeLog.d/fix-rsa-performance-regression.txt
@@ -0,0 +1,3 @@
+Bugfix
+ * Fix unintended performance regression when using short RSA public keys.
+ Fixes #9232.
diff --git a/tests/include/test/bignum_codepath_check.h b/tests/include/test/bignum_codepath_check.h
new file mode 100644
index 0000000..3d72be1
--- /dev/null
+++ b/tests/include/test/bignum_codepath_check.h
@@ -0,0 +1,94 @@
+/** Support for path tracking in optionally safe bignum functions
+ *
+ * These functions are called when an optionally safe path is taken and record it in a single
+ * variable. This variable is at any time in one of three states:
+ * - MBEDTLS_MPI_IS_TEST: No optionally safe path has been taken since the last reset
+ * - MBEDTLS_MPI_IS_SECRET: Only safe paths were taken since the last reset
+ * - MBEDTLS_MPI_IS_PUBLIC: At least one unsafe path has been taken since the last reset
+ *
+ * Use a simple global variable to track execution path. Making it work with multithreading
+ * isn't worth the effort as multithreaded tests add little to no value here.
+ */
+/*
+ * Copyright The Mbed TLS Contributors
+ * SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+ */
+
+#ifndef BIGNUM_CODEPATH_CHECK_H
+#define BIGNUM_CODEPATH_CHECK_H
+
+#include "bignum_core.h"
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+
+extern int mbedtls_codepath_check;
+
+/**
+ * \brief Setup the codepath test hooks used by optionally safe bignum functions to signal
+ * the path taken.
+ */
+void mbedtls_codepath_test_hooks_setup(void);
+
+/**
+ * \brief Teardown the codepath test hooks used by optionally safe bignum functions to
+ * signal the path taken.
+ */
+void mbedtls_codepath_test_hooks_teardown(void);
+
+/**
+ * \brief Reset the state of the codepath to the initial state.
+ */
+static inline void mbedtls_codepath_reset(void)
+{
+ mbedtls_codepath_check = MBEDTLS_MPI_IS_TEST;
+}
+
+/** Check the codepath taken and fail if it doesn't match.
+ *
+ * When a function returns with an error, it can do so before reaching any interesting codepath. The
+ * same can happen if a parameter to the function is zero. In these cases we need to allow
+ * the codepath tracking variable to still have its initial "not set" value.
+ *
+ * This macro expands to an instruction, not an expression.
+ * It may jump to the \c exit label.
+ *
+ * \param path The expected codepath.
+ * This expression may be evaluated multiple times.
+ * \param ret The expected return value.
+ * \param E The MPI parameter that can cause shortcuts.
+ */
+#define ASSERT_BIGNUM_CODEPATH(path, ret, E) \
+ do { \
+ if ((ret) != 0 || (E).n == 0) { \
+ TEST_ASSERT(mbedtls_codepath_check == (path) || \
+ mbedtls_codepath_check == MBEDTLS_MPI_IS_TEST); \
+ } else { \
+ TEST_EQUAL(mbedtls_codepath_check, (path)); \
+ } \
+ } while (0)
+
+/** Check the codepath taken and fail if it doesn't match.
+ *
+ * When a function returns with an error, it can do so before reaching any interesting codepath. In
+ * this case we need to allow the codepath tracking variable to still have its
+ * initial "not set" value.
+ *
+ * This macro expands to an instruction, not an expression.
+ * It may jump to the \c exit label.
+ *
+ * \param path The expected codepath.
+ * This expression may be evaluated multiple times.
+ * \param ret The expected return value.
+ */
+#define ASSERT_RSA_CODEPATH(path, ret) \
+ do { \
+ if ((ret) != 0) { \
+ TEST_ASSERT(mbedtls_codepath_check == (path) || \
+ mbedtls_codepath_check == MBEDTLS_MPI_IS_TEST); \
+ } else { \
+ TEST_EQUAL(mbedtls_codepath_check, (path)); \
+ } \
+ } while (0)
+#endif /* MBEDTLS_TEST_HOOKS && !MBEDTLS_THREADING_C */
+
+#endif /* BIGNUM_CODEPATH_CHECK_H */
diff --git a/tests/src/bignum_codepath_check.c b/tests/src/bignum_codepath_check.c
new file mode 100644
index 0000000..b752d13
--- /dev/null
+++ b/tests/src/bignum_codepath_check.c
@@ -0,0 +1,38 @@
+/** Support for path tracking in optionally safe bignum functions
+ */
+/*
+ * Copyright The Mbed TLS Contributors
+ * SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+ */
+
+#include "test/bignum_codepath_check.h"
+#include "bignum_core_invasive.h"
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+int mbedtls_codepath_check = MBEDTLS_MPI_IS_TEST;
+
+void mbedtls_codepath_take_safe(void)
+{
+ if (mbedtls_codepath_check == MBEDTLS_MPI_IS_TEST) {
+ mbedtls_codepath_check = MBEDTLS_MPI_IS_SECRET;
+ }
+}
+
+void mbedtls_codepath_take_unsafe(void)
+{
+ mbedtls_codepath_check = MBEDTLS_MPI_IS_PUBLIC;
+}
+
+void mbedtls_codepath_test_hooks_setup(void)
+{
+ mbedtls_safe_codepath_hook = mbedtls_codepath_take_safe;
+ mbedtls_unsafe_codepath_hook = mbedtls_codepath_take_unsafe;
+}
+
+void mbedtls_codepath_test_hooks_teardown(void)
+{
+ mbedtls_safe_codepath_hook = NULL;
+ mbedtls_unsafe_codepath_hook = NULL;
+}
+
+#endif /* MBEDTLS_TEST_HOOKS && !MBEDTLS_THREADING_C */
diff --git a/tests/src/helpers.c b/tests/src/helpers.c
index 065d17d..db50296 100644
--- a/tests/src/helpers.c
+++ b/tests/src/helpers.c
@@ -16,6 +16,9 @@
#if defined(MBEDTLS_TEST_HOOKS) && defined(MBEDTLS_PSA_CRYPTO_C)
#include <test/psa_memory_poisoning_wrappers.h>
#endif
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+#include <test/bignum_codepath_check.h>
+#endif
#if defined(MBEDTLS_THREADING_C)
#include "mbedtls/threading.h"
#endif
@@ -342,6 +345,11 @@
mbedtls_mutex_init(&mbedtls_test_info_mutex);
#endif /* MBEDTLS_THREADING_C */
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_test_hooks_setup();
+#endif /* MBEDTLS_TEST_HOOKS && !MBEDTLS_THREADING_C */
+
return ret;
}
@@ -359,6 +367,10 @@
#if defined(MBEDTLS_PLATFORM_C)
mbedtls_platform_teardown(&platform_ctx);
#endif /* MBEDTLS_PLATFORM_C */
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_test_hooks_teardown();
+#endif /* MBEDTLS_TEST_HOOKS && !MBEDTLS_THREADING_C */
}
int mbedtls_test_ascii2uc(const char c, unsigned char *uc)
diff --git a/tf-psa-crypto/drivers/builtin/include/mbedtls/bignum.h b/tf-psa-crypto/drivers/builtin/include/mbedtls/bignum.h
index 71d7b97..8367cd3 100644
--- a/tf-psa-crypto/drivers/builtin/include/mbedtls/bignum.h
+++ b/tf-psa-crypto/drivers/builtin/include/mbedtls/bignum.h
@@ -880,7 +880,7 @@
mbedtls_mpi_sint b);
/**
- * \brief Perform a sliding-window exponentiation: X = A^E mod N
+ * \brief Perform a modular exponentiation: X = A^E mod N
*
* \param X The destination MPI. This must point to an initialized MPI.
* This must not alias E or N.
diff --git a/tf-psa-crypto/drivers/builtin/src/bignum.c b/tf-psa-crypto/drivers/builtin/src/bignum.c
index c45fd5b..4244909 100644
--- a/tf-psa-crypto/drivers/builtin/src/bignum.c
+++ b/tf-psa-crypto/drivers/builtin/src/bignum.c
@@ -27,6 +27,7 @@
#include "mbedtls/bignum.h"
#include "bignum_core.h"
+#include "bignum_internal.h"
#include "bn_mul.h"
#include "mbedtls/platform_util.h"
#include "mbedtls/error.h"
@@ -1610,9 +1611,13 @@
return 0;
}
-int mbedtls_mpi_exp_mod(mbedtls_mpi *X, const mbedtls_mpi *A,
- const mbedtls_mpi *E, const mbedtls_mpi *N,
- mbedtls_mpi *prec_RR)
+/*
+ * Warning! If the parameter E_public has MBEDTLS_MPI_IS_PUBLIC as its value,
+ * this function is not constant time with respect to the exponent (parameter E).
+ */
+static int mbedtls_mpi_exp_mod_optionally_safe(mbedtls_mpi *X, const mbedtls_mpi *A,
+ const mbedtls_mpi *E, int E_public,
+ const mbedtls_mpi *N, mbedtls_mpi *prec_RR)
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
@@ -1695,7 +1700,11 @@
{
mbedtls_mpi_uint mm = mbedtls_mpi_core_montmul_init(N->p);
mbedtls_mpi_core_to_mont_rep(X->p, X->p, N->p, N->n, mm, RR.p, T);
- mbedtls_mpi_core_exp_mod(X->p, X->p, N->p, N->n, E->p, E->n, RR.p, T);
+ if (E_public == MBEDTLS_MPI_IS_PUBLIC) {
+ mbedtls_mpi_core_exp_mod_unsafe(X->p, X->p, N->p, N->n, E->p, E->n, RR.p, T);
+ } else {
+ mbedtls_mpi_core_exp_mod(X->p, X->p, N->p, N->n, E->p, E->n, RR.p, T);
+ }
mbedtls_mpi_core_from_mont_rep(X->p, X->p, N->p, N->n, mm, T);
}
@@ -1720,6 +1729,20 @@
return ret;
}
+int mbedtls_mpi_exp_mod(mbedtls_mpi *X, const mbedtls_mpi *A,
+ const mbedtls_mpi *E, const mbedtls_mpi *N,
+ mbedtls_mpi *prec_RR)
+{
+ return mbedtls_mpi_exp_mod_optionally_safe(X, A, E, MBEDTLS_MPI_IS_SECRET, N, prec_RR);
+}
+
+int mbedtls_mpi_exp_mod_unsafe(mbedtls_mpi *X, const mbedtls_mpi *A,
+ const mbedtls_mpi *E, const mbedtls_mpi *N,
+ mbedtls_mpi *prec_RR)
+{
+ return mbedtls_mpi_exp_mod_optionally_safe(X, A, E, MBEDTLS_MPI_IS_PUBLIC, N, prec_RR);
+}
+
/*
* Greatest common divisor: G = gcd(A, B) (HAC 14.54)
*/
diff --git a/tf-psa-crypto/drivers/builtin/src/bignum_core.c b/tf-psa-crypto/drivers/builtin/src/bignum_core.c
index 39ab6e9..60f48f9 100644
--- a/tf-psa-crypto/drivers/builtin/src/bignum_core.c
+++ b/tf-psa-crypto/drivers/builtin/src/bignum_core.c
@@ -747,8 +747,96 @@
}
}
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+void (*mbedtls_safe_codepath_hook)(void) = NULL;
+void (*mbedtls_unsafe_codepath_hook)(void) = NULL;
+#endif
+
+/*
+ * This function calculates the indices of the exponent where the exponentiation algorithm should
+ * start processing.
+ *
+ * Warning! If the parameter E_public has MBEDTLS_MPI_IS_PUBLIC as its value,
+ * this function is not constant time with respect to the exponent (parameter E).
+ */
+static inline void exp_mod_calc_first_bit_optionally_safe(const mbedtls_mpi_uint *E,
+ size_t E_limbs,
+ int E_public,
+ size_t *E_limb_index,
+ size_t *E_bit_index)
+{
+ if (E_public == MBEDTLS_MPI_IS_PUBLIC) {
+ /*
+ * Skip leading zero bits.
+ */
+ size_t E_bits = mbedtls_mpi_core_bitlen(E, E_limbs);
+ if (E_bits == 0) {
+ /*
+ * If E is 0, mbedtls_mpi_core_bitlen() returns 0. In that case we still want to
+ * represent E as a single 0 bit, so the bit length is taken to be 1.
+ */
+ E_bits = 1;
+ }
+
+ *E_limb_index = E_bits / biL;
+ *E_bit_index = E_bits % biL;
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ if (mbedtls_unsafe_codepath_hook != NULL) {
+ mbedtls_unsafe_codepath_hook();
+ }
+#endif
+ } else {
+ /*
+ * Here we need to be constant time with respect to E and can't do anything better than
+ * start at the first allocated bit.
+ */
+ *E_limb_index = E_limbs;
+ *E_bit_index = 0;
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ if (mbedtls_safe_codepath_hook != NULL) {
+ mbedtls_safe_codepath_hook();
+ }
+#endif
+ }
+}
+
+/*
+ * Warning! If the parameter window_public has MBEDTLS_MPI_IS_PUBLIC as its value, this function is
+ * not constant time with respect to the window parameter and consequently the exponent of the
+ * exponentiation (parameter E of mbedtls_mpi_core_exp_mod_optionally_safe).
+ */
+static inline void exp_mod_table_lookup_optionally_safe(mbedtls_mpi_uint *Wselect,
+ mbedtls_mpi_uint *Wtable,
+ size_t AN_limbs, size_t welem,
+ mbedtls_mpi_uint window,
+ int window_public)
+{
+ if (window_public == MBEDTLS_MPI_IS_PUBLIC) {
+ memcpy(Wselect, Wtable + window * AN_limbs, AN_limbs * ciL);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ if (mbedtls_unsafe_codepath_hook != NULL) {
+ mbedtls_unsafe_codepath_hook();
+ }
+#endif
+ } else {
+ /* Select Wtable[window] without leaking window through
+ * memory access patterns. */
+ mbedtls_mpi_core_ct_uint_table_lookup(Wselect, Wtable,
+ AN_limbs, welem, window);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ if (mbedtls_safe_codepath_hook != NULL) {
+ mbedtls_safe_codepath_hook();
+ }
+#endif
+ }
+}
+
/* Exponentiation: X := A^E mod N.
*
+ * Warning! If the parameter E_public has MBEDTLS_MPI_IS_PUBLIC as its value,
+ * this function is not constant time with respect to the exponent (parameter E).
+ *
* A must already be in Montgomery form.
*
* As in other bignum functions, assume that AN_limbs and E_limbs are nonzero.
@@ -759,16 +847,25 @@
* (The difference is that the body in our loop processes a single bit instead
* of a full window.)
*/
-void mbedtls_mpi_core_exp_mod(mbedtls_mpi_uint *X,
- const mbedtls_mpi_uint *A,
- const mbedtls_mpi_uint *N,
- size_t AN_limbs,
- const mbedtls_mpi_uint *E,
- size_t E_limbs,
- const mbedtls_mpi_uint *RR,
- mbedtls_mpi_uint *T)
+static void mbedtls_mpi_core_exp_mod_optionally_safe(mbedtls_mpi_uint *X,
+ const mbedtls_mpi_uint *A,
+ const mbedtls_mpi_uint *N,
+ size_t AN_limbs,
+ const mbedtls_mpi_uint *E,
+ size_t E_limbs,
+ int E_public,
+ const mbedtls_mpi_uint *RR,
+ mbedtls_mpi_uint *T)
{
- const size_t wsize = exp_mod_get_window_size(E_limbs * biL);
+ /* We'll process the bits of E from most significant
+ * (E_limb_index=E_limbs-1, E_bit_index=biL-1) to least significant
+ * (E_limb_index=0, E_bit_index=0). */
+ size_t E_limb_index = E_limbs;
+ size_t E_bit_index = 0;
+ exp_mod_calc_first_bit_optionally_safe(E, E_limbs, E_public,
+ &E_limb_index, &E_bit_index);
+
+ const size_t wsize = exp_mod_get_window_size(E_limb_index * biL);
const size_t welem = ((size_t) 1) << wsize;
/* This is how we will use the temporary storage T, which must have space
@@ -787,7 +884,7 @@
const mbedtls_mpi_uint mm = mbedtls_mpi_core_montmul_init(N);
- /* Set Wtable[i] = A^(2^i) (in Montgomery representation) */
+ /* Set Wtable[i] = A^i (in Montgomery representation) */
exp_mod_precompute_window(A, N, AN_limbs,
mm, RR,
welem, Wtable, temp);
@@ -799,11 +896,6 @@
/* X = 1 (in Montgomery presentation) initially */
memcpy(X, Wtable, AN_limbs * ciL);
- /* We'll process the bits of E from most significant
- * (limb_index=E_limbs-1, E_bit_index=biL-1) to least significant
- * (limb_index=0, E_bit_index=0). */
- size_t E_limb_index = E_limbs;
- size_t E_bit_index = 0;
/* At any given time, window contains window_bits bits from E.
* window_bits can go up to wsize. */
size_t window_bits = 0;
@@ -829,10 +921,9 @@
* when we've finished processing the exponent. */
if (window_bits == wsize ||
(E_bit_index == 0 && E_limb_index == 0)) {
- /* Select Wtable[window] without leaking window through
- * memory access patterns. */
- mbedtls_mpi_core_ct_uint_table_lookup(Wselect, Wtable,
- AN_limbs, welem, window);
+
+ exp_mod_table_lookup_optionally_safe(Wselect, Wtable, AN_limbs, welem,
+ window, E_public);
/* Multiply X by the selected element. */
mbedtls_mpi_core_montmul(X, X, Wselect, AN_limbs, N, AN_limbs, mm,
temp);
@@ -842,6 +933,42 @@
} while (!(E_bit_index == 0 && E_limb_index == 0));
}
+void mbedtls_mpi_core_exp_mod(mbedtls_mpi_uint *X,
+ const mbedtls_mpi_uint *A,
+ const mbedtls_mpi_uint *N, size_t AN_limbs,
+ const mbedtls_mpi_uint *E, size_t E_limbs,
+ const mbedtls_mpi_uint *RR,
+ mbedtls_mpi_uint *T)
+{
+ mbedtls_mpi_core_exp_mod_optionally_safe(X,
+ A,
+ N,
+ AN_limbs,
+ E,
+ E_limbs,
+ MBEDTLS_MPI_IS_SECRET,
+ RR,
+ T);
+}
+
+void mbedtls_mpi_core_exp_mod_unsafe(mbedtls_mpi_uint *X,
+ const mbedtls_mpi_uint *A,
+ const mbedtls_mpi_uint *N, size_t AN_limbs,
+ const mbedtls_mpi_uint *E, size_t E_limbs,
+ const mbedtls_mpi_uint *RR,
+ mbedtls_mpi_uint *T)
+{
+ mbedtls_mpi_core_exp_mod_optionally_safe(X,
+ A,
+ N,
+ AN_limbs,
+ E,
+ E_limbs,
+ MBEDTLS_MPI_IS_PUBLIC,
+ RR,
+ T);
+}
+
mbedtls_mpi_uint mbedtls_mpi_core_sub_int(mbedtls_mpi_uint *X,
const mbedtls_mpi_uint *A,
mbedtls_mpi_uint c, /* doubles as carry */
diff --git a/tf-psa-crypto/drivers/builtin/src/bignum_core.h b/tf-psa-crypto/drivers/builtin/src/bignum_core.h
index 51ecca5..484fa8c 100644
--- a/tf-psa-crypto/drivers/builtin/src/bignum_core.h
+++ b/tf-psa-crypto/drivers/builtin/src/bignum_core.h
@@ -70,9 +70,7 @@
#include "common.h"
-#if defined(MBEDTLS_BIGNUM_C)
#include "mbedtls/bignum.h"
-#endif
#include "constant_time_internal.h"
@@ -90,6 +88,34 @@
#define GET_BYTE(X, i) \
(((X)[(i) / ciL] >> (((i) % ciL) * 8)) & 0xff)
+/* Constants to identify whether a value is public or secret. If a parameter is marked as secret by
+ * this constant, the function must be constant time with respect to the parameter.
+ *
+ * This is only needed for functions with the _optionally_safe suffix. All other functions have
+ * fixed behavior that can't be changed at runtime and are constant time with respect to their
+ * parameters as prescribed by their documentation or by conventions in their module's documentation.
+ *
+ * Parameters should be named X_public where X is the name of the
+ * corresponding input parameter.
+ *
+ * Implementation should always check using
+ * if (X_public == MBEDTLS_MPI_IS_PUBLIC) {
+ * // unsafe path
+ * } else {
+ * // safe path
+ * }
+ * not the other way round, in order to prevent misuse. (That is, if a value
+ * other than the two below is passed, default to the safe path.)
+ *
+ * The value of MBEDTLS_MPI_IS_PUBLIC is chosen in a way that is unlikely to happen by accident, but
+ * which can be used as an immediate value in a Thumb2 comparison (for code size). */
+#define MBEDTLS_MPI_IS_PUBLIC 0x2a2a2a2a
+#define MBEDTLS_MPI_IS_SECRET 0
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+// Default value for testing that is neither MBEDTLS_MPI_IS_PUBLIC nor MBEDTLS_MPI_IS_SECRET
+#define MBEDTLS_MPI_IS_TEST 1
+#endif
+
/** Count leading zero bits in a given integer.
*
* \warning The result is undefined if \p a == 0
@@ -615,6 +641,42 @@
size_t mbedtls_mpi_core_exp_mod_working_limbs(size_t AN_limbs, size_t E_limbs);
/**
+ * \brief Perform a modular exponentiation with public exponent:
+ * X = A^E mod N, where \p A is already in Montgomery form.
+ *
+ * \warning This function is not constant time with respect to \p E (the exponent).
+ *
+ * \p X may be aliased to \p A, but not to \p RR or \p E, even if \p E_limbs ==
+ * \p AN_limbs.
+ *
+ * \param[out] X The destination MPI, as a little endian array of length
+ * \p AN_limbs.
+ * \param[in] A The base MPI, as a little endian array of length \p AN_limbs.
+ * Must be in Montgomery form.
+ * \param[in] N The modulus, as a little endian array of length \p AN_limbs.
+ * \param AN_limbs The number of limbs in \p X, \p A, \p N, \p RR.
+ * \param[in] E The exponent, as a little endian array of length \p E_limbs.
+ * \param E_limbs The number of limbs in \p E.
+ * \param[in] RR The precomputed residue of 2^{2*biL} modulo N, as a little
+ * endian array of length \p AN_limbs.
+ * \param[in,out] T Temporary storage of at least the number of limbs returned
+ * by `mbedtls_mpi_core_exp_mod_working_limbs()`.
+ * Its initial content is unused and its final content is
+ * indeterminate.
+ * It must not alias or otherwise overlap any of the other
+ * parameters.
+ * It is up to the caller to zeroize \p T when it is no
+ * longer needed, and before freeing it if it was dynamically
+ * allocated.
+ */
+void mbedtls_mpi_core_exp_mod_unsafe(mbedtls_mpi_uint *X,
+ const mbedtls_mpi_uint *A,
+ const mbedtls_mpi_uint *N, size_t AN_limbs,
+ const mbedtls_mpi_uint *E, size_t E_limbs,
+ const mbedtls_mpi_uint *RR,
+ mbedtls_mpi_uint *T);
+
+/**
* \brief Perform a modular exponentiation with secret exponent:
* X = A^E mod N, where \p A is already in Montgomery form.
*
diff --git a/tf-psa-crypto/drivers/builtin/src/bignum_core_invasive.h b/tf-psa-crypto/drivers/builtin/src/bignum_core_invasive.h
new file mode 100644
index 0000000..167099d
--- /dev/null
+++ b/tf-psa-crypto/drivers/builtin/src/bignum_core_invasive.h
@@ -0,0 +1,23 @@
+/**
+ * \file bignum_core_invasive.h
+ *
+ * \brief Function declarations for invasive functions of bignum core.
+ */
+/**
+ * Copyright The Mbed TLS Contributors
+ * SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+ */
+
+#ifndef MBEDTLS_BIGNUM_CORE_INVASIVE_H
+#define MBEDTLS_BIGNUM_CORE_INVASIVE_H
+
+#include "bignum_core.h"
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+
+extern void (*mbedtls_safe_codepath_hook)(void);
+extern void (*mbedtls_unsafe_codepath_hook)(void);
+
+#endif /* MBEDTLS_TEST_HOOKS && !MBEDTLS_THREADING_C */
+
+#endif /* MBEDTLS_BIGNUM_CORE_INVASIVE_H */
diff --git a/tf-psa-crypto/drivers/builtin/src/bignum_internal.h b/tf-psa-crypto/drivers/builtin/src/bignum_internal.h
new file mode 100644
index 0000000..aceaf55
--- /dev/null
+++ b/tf-psa-crypto/drivers/builtin/src/bignum_internal.h
@@ -0,0 +1,50 @@
+/**
+ * \file bignum_internal.h
+ *
+ * \brief Internal-only bignum API.
+ *
+ * This file declares bignum-related functions that are to be used
+ * only from within the Mbed TLS library itself.
+ *
+ */
+/*
+ * Copyright The Mbed TLS Contributors
+ * SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+ */
+#ifndef MBEDTLS_BIGNUM_INTERNAL_H
+#define MBEDTLS_BIGNUM_INTERNAL_H
+
+/**
+ * \brief Perform a modular exponentiation: X = A^E mod N
+ *
+ * \warning This function is not constant time with respect to \p E (the exponent).
+ *
+ * \param X The destination MPI. This must point to an initialized MPI.
+ * This must not alias E or N.
+ * \param A The base of the exponentiation.
+ * This must point to an initialized MPI.
+ * \param E The exponent MPI. This must point to an initialized MPI.
+ * \param N The base for the modular reduction. This must point to an
+ * initialized MPI.
+ * \param prec_RR A helper MPI depending solely on \p N which can be used to
+ * speed-up multiple modular exponentiations for the same value
+ * of \p N. This may be \c NULL. If it is not \c NULL, it must
+ * point to an initialized MPI. If it hasn't been used after
+ * the call to mbedtls_mpi_init(), this function will compute
+ * the helper value and store it in \p prec_RR for reuse on
+ * subsequent calls to this function. Otherwise, the function
+ * will assume that \p prec_RR holds the helper value set by a
+ * previous call to mbedtls_mpi_exp_mod(), and reuse it.
+ *
+ * \return \c 0 if successful.
+ * \return #MBEDTLS_ERR_MPI_ALLOC_FAILED if a memory allocation failed.
+ * \return #MBEDTLS_ERR_MPI_BAD_INPUT_DATA if \c N is negative or
+ * even, or if \c E is negative.
+ * \return Another negative error code on different kinds of failures.
+ *
+ */
+int mbedtls_mpi_exp_mod_unsafe(mbedtls_mpi *X, const mbedtls_mpi *A,
+ const mbedtls_mpi *E, const mbedtls_mpi *N,
+ mbedtls_mpi *prec_RR);
+
+#endif /* bignum_internal.h */
diff --git a/tf-psa-crypto/drivers/builtin/src/rsa.c b/tf-psa-crypto/drivers/builtin/src/rsa.c
index 5e7e762..33bb1d3 100644
--- a/tf-psa-crypto/drivers/builtin/src/rsa.c
+++ b/tf-psa-crypto/drivers/builtin/src/rsa.c
@@ -29,6 +29,7 @@
#include "mbedtls/rsa.h"
#include "bignum_core.h"
+#include "bignum_internal.h"
#include "rsa_alt_helpers.h"
#include "rsa_internal.h"
#include "mbedtls/oid.h"
@@ -1257,7 +1258,7 @@
}
olen = ctx->len;
- MBEDTLS_MPI_CHK(mbedtls_mpi_exp_mod(&T, &T, &ctx->E, &ctx->N, &ctx->RN));
+ MBEDTLS_MPI_CHK(mbedtls_mpi_exp_mod_unsafe(&T, &T, &ctx->E, &ctx->N, &ctx->RN));
MBEDTLS_MPI_CHK(mbedtls_mpi_write_binary(&T, output, olen));
cleanup:
diff --git a/tf-psa-crypto/tests/suites/test_suite_bignum.function b/tf-psa-crypto/tests/suites/test_suite_bignum.function
index 1830e5a..3d2b8a1 100644
--- a/tf-psa-crypto/tests/suites/test_suite_bignum.function
+++ b/tf-psa-crypto/tests/suites/test_suite_bignum.function
@@ -3,7 +3,9 @@
#include "mbedtls/entropy.h"
#include "constant_time_internal.h"
#include "bignum_core.h"
+#include "bignum_internal.h"
#include "test/constant_flow.h"
+#include "test/bignum_codepath_check.h"
#if MBEDTLS_MPI_MAX_BITS > 792
#define MPI_MAX_BITS_LARGER_THAN_792
@@ -988,7 +990,13 @@
* against a smaller RR. */
TEST_LE_U(RR.n, N.n - 1);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_reset();
+#endif
res = mbedtls_mpi_exp_mod(&Z, &A, &E, &N, &RR);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, res, E);
+#endif
/* We know that exp_mod internally needs RR to be as large as N.
* Validate that it is the case now, otherwise there was probably
* a buffer overread. */
@@ -1021,7 +1029,26 @@
TEST_ASSERT(mbedtls_test_read_mpi(&N, input_N) == 0);
TEST_ASSERT(mbedtls_test_read_mpi(&X, input_X) == 0);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_reset();
+#endif
res = mbedtls_mpi_exp_mod(&Z, &A, &E, &N, NULL);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, res, E);
+#endif
+ TEST_ASSERT(res == exp_result);
+ if (res == 0) {
+ TEST_ASSERT(sign_is_valid(&Z));
+ TEST_ASSERT(mbedtls_mpi_cmp_mpi(&Z, &X) == 0);
+ }
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_reset();
+#endif
+ res = mbedtls_mpi_exp_mod_unsafe(&Z, &A, &E, &N, NULL);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_PUBLIC, res, E);
+#endif
TEST_ASSERT(res == exp_result);
if (res == 0) {
TEST_ASSERT(sign_is_valid(&Z));
@@ -1029,7 +1056,13 @@
}
/* Now test again with the speed-up parameter supplied as an output. */
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_reset();
+#endif
res = mbedtls_mpi_exp_mod(&Z, &A, &E, &N, &RR);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, res, E);
+#endif
TEST_ASSERT(res == exp_result);
if (res == 0) {
TEST_ASSERT(sign_is_valid(&Z));
@@ -1037,7 +1070,13 @@
}
/* Now test again with the speed-up parameter supplied in calculated form. */
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_reset();
+#endif
res = mbedtls_mpi_exp_mod(&Z, &A, &E, &N, &RR);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, res, E);
+#endif
TEST_ASSERT(res == exp_result);
if (res == 0) {
TEST_ASSERT(sign_is_valid(&Z));
@@ -1077,7 +1116,21 @@
TEST_ASSERT(mbedtls_test_read_mpi(&RR, input_RR) == 0);
}
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_reset();
+#endif
TEST_ASSERT(mbedtls_mpi_exp_mod(&Z, &A, &E, &N, &RR) == exp_result);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_SECRET, exp_result, E);
+#endif
+
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_reset();
+#endif
+ TEST_ASSERT(mbedtls_mpi_exp_mod_unsafe(&Z, &A, &E, &N, &RR) == exp_result);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ ASSERT_BIGNUM_CODEPATH(MBEDTLS_MPI_IS_PUBLIC, exp_result, E);
+#endif
exit:
mbedtls_mpi_free(&A); mbedtls_mpi_free(&E); mbedtls_mpi_free(&N);
diff --git a/tf-psa-crypto/tests/suites/test_suite_bignum_core.function b/tf-psa-crypto/tests/suites/test_suite_bignum_core.function
index 6c0bd1e..c755287 100644
--- a/tf-psa-crypto/tests/suites/test_suite_bignum_core.function
+++ b/tf-psa-crypto/tests/suites/test_suite_bignum_core.function
@@ -4,6 +4,7 @@
#include "bignum_core.h"
#include "constant_time_internal.h"
#include "test/constant_flow.h"
+#include "test/bignum_codepath_check.h"
/** Verifies mbedtls_mpi_core_add().
*
@@ -1246,6 +1247,7 @@
char *input_E, char *input_X)
{
mbedtls_mpi_uint *A = NULL;
+ mbedtls_mpi_uint *A_copy = NULL;
mbedtls_mpi_uint *E = NULL;
mbedtls_mpi_uint *N = NULL;
mbedtls_mpi_uint *X = NULL;
@@ -1297,6 +1299,7 @@
TEST_CALLOC(T, working_limbs);
+ /* Test the safe variant */
TEST_CF_SECRET(A, A_limbs * sizeof(mbedtls_mpi_uint));
TEST_CF_SECRET(N, N_limbs * sizeof(mbedtls_mpi_uint));
TEST_CF_SECRET(E, E_limbs * sizeof(mbedtls_mpi_uint));
@@ -1304,22 +1307,48 @@
mbedtls_mpi_core_exp_mod(Y, A, N, N_limbs, E, E_limbs, R2, T);
TEST_CF_PUBLIC(Y, N_limbs * sizeof(mbedtls_mpi_uint));
+ TEST_EQUAL(0, memcmp(X, Y, N_limbs * sizeof(mbedtls_mpi_uint)));
+
+ /* Test the unsafe variant */
+ TEST_CF_PUBLIC(A, A_limbs * sizeof(mbedtls_mpi_uint));
+ TEST_CF_PUBLIC(N, N_limbs * sizeof(mbedtls_mpi_uint));
+ TEST_CF_PUBLIC(E, E_limbs * sizeof(mbedtls_mpi_uint));
+
+ mbedtls_mpi_core_exp_mod_unsafe(Y, A, N, N_limbs, E, E_limbs, R2, T);
TEST_EQUAL(0, memcmp(X, Y, N_limbs * sizeof(mbedtls_mpi_uint)));
+ /*
+ * Check both with output aliased to input
+ */
+
+ TEST_CALLOC(A_copy, A_limbs);
+ memcpy(A_copy, A, sizeof(*A_copy) * A_limbs); // save A
+
+ /* Safe */
TEST_CF_SECRET(A, A_limbs * sizeof(mbedtls_mpi_uint));
TEST_CF_SECRET(N, N_limbs * sizeof(mbedtls_mpi_uint));
TEST_CF_SECRET(E, E_limbs * sizeof(mbedtls_mpi_uint));
- /* Check when output aliased to input */
mbedtls_mpi_core_exp_mod(A, A, N, N_limbs, E, E_limbs, R2, T);
TEST_CF_PUBLIC(A, A_limbs * sizeof(mbedtls_mpi_uint));
TEST_EQUAL(0, memcmp(X, A, N_limbs * sizeof(mbedtls_mpi_uint)));
+ /* Unsafe */
+ memcpy(A, A_copy, sizeof(*A) * A_limbs); // restore A
+ TEST_CF_PUBLIC(A, A_limbs * sizeof(mbedtls_mpi_uint));
+ TEST_CF_PUBLIC(N, N_limbs * sizeof(mbedtls_mpi_uint));
+ TEST_CF_PUBLIC(E, E_limbs * sizeof(mbedtls_mpi_uint));
+
+ mbedtls_mpi_core_exp_mod_unsafe(A, A, N, N_limbs, E, E_limbs, R2, T);
+
+ TEST_EQUAL(0, memcmp(X, A, N_limbs * sizeof(mbedtls_mpi_uint)));
+
exit:
mbedtls_free(T);
mbedtls_free(A);
+ mbedtls_free(A_copy);
mbedtls_free(E);
mbedtls_free(N);
mbedtls_free(X);
diff --git a/tf-psa-crypto/tests/suites/test_suite_rsa.function b/tf-psa-crypto/tests/suites/test_suite_rsa.function
index e824529..98ea9ef 100644
--- a/tf-psa-crypto/tests/suites/test_suite_rsa.function
+++ b/tf-psa-crypto/tests/suites/test_suite_rsa.function
@@ -1,7 +1,9 @@
/* BEGIN_HEADER */
#include "mbedtls/rsa.h"
+#include "bignum_core.h"
#include "rsa_alt_helpers.h"
#include "rsa_internal.h"
+#include "test/bignum_codepath_check.h"
/* END_HEADER */
/* BEGIN_DEPENDENCIES
@@ -489,7 +491,13 @@
TEST_EQUAL(mbedtls_rsa_get_bitlen(&ctx), (size_t) mod);
TEST_ASSERT(mbedtls_rsa_check_pubkey(&ctx) == 0);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_reset();
+#endif
TEST_ASSERT(mbedtls_rsa_public(&ctx, message_str->x, output) == result);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ ASSERT_RSA_CODEPATH(MBEDTLS_MPI_IS_PUBLIC, result);
+#endif
if (result == 0) {
TEST_ASSERT(mbedtls_test_hexcmp(output, result_str->x,
@@ -554,9 +562,15 @@
/* repeat three times to test updating of blinding values */
for (i = 0; i < 3; i++) {
memset(output, 0x00, sizeof(output));
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ mbedtls_codepath_reset();
+#endif
TEST_ASSERT(mbedtls_rsa_private(&ctx, mbedtls_test_rnd_pseudo_rand,
&rnd_info, message_str->x,
output) == result);
+#if defined(MBEDTLS_TEST_HOOKS) && !defined(MBEDTLS_THREADING_C)
+ ASSERT_RSA_CODEPATH(MBEDTLS_MPI_IS_SECRET, result);
+#endif
if (result == 0) {
TEST_ASSERT(mbedtls_test_hexcmp(output, result_str->x,