Merge pull request #7637 from paul-elliott-arm/fixed_ecp_mod_p448

[Bignum] Fixed width for ecp mod p448
diff --git a/library/ecp_curves.c b/library/ecp_curves.c
index 57ce39a..0d5dc50 100644
--- a/library/ecp_curves.c
+++ b/library/ecp_curves.c
@@ -28,6 +28,8 @@
 
 #include "mbedtls/platform.h"
 
+#include "constant_time_internal.h"
+
 #include "bn_mul.h"
 #include "bignum_core.h"
 #include "ecp_invasive.h"
@@ -5482,8 +5484,9 @@
 
-/* Number of limbs fully occupied by 2^224 (max), and limbs used by it (min) */
+/* Number of limbs fully occupied by 224 bits (min), and limbs used by them (max) */
 #define DIV_ROUND_UP(X, Y) (((X) + (Y) -1) / (Y))
-#define P224_WIDTH_MIN   (28 / sizeof(mbedtls_mpi_uint))
-#define P224_WIDTH_MAX   DIV_ROUND_UP(28, sizeof(mbedtls_mpi_uint))
+#define P224_SIZE        (224 / 8)
+#define P224_WIDTH_MIN   (P224_SIZE / sizeof(mbedtls_mpi_uint))
+#define P224_WIDTH_MAX   DIV_ROUND_UP(P224_SIZE, sizeof(mbedtls_mpi_uint))
 #define P224_UNUSED_BITS ((P224_WIDTH_MAX * sizeof(mbedtls_mpi_uint) * 8) - 224)
 
 static int ecp_mod_p448(mbedtls_mpi *N)
@@ -5516,7 +5519,7 @@
 MBEDTLS_STATIC_TESTABLE
 int mbedtls_ecp_mod_p448(mbedtls_mpi_uint *X, size_t X_limbs)
 {
-    size_t i;
+    size_t round;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
     if (X_limbs <= P448_WIDTH) {
@@ -5524,23 +5527,26 @@
     }
 
     size_t M_limbs = X_limbs - (P448_WIDTH);
-    const size_t Q_limbs = M_limbs;
 
     if (M_limbs > P448_WIDTH) {
         /* Shouldn't be called with X larger than 2^896! */
         return MBEDTLS_ERR_ECP_BAD_INPUT_DATA;
     }
 
-    /* Extra limb for carry below. */
+    /* Both M and Q require an extra limb to catch carries. */
     M_limbs++;
 
-    mbedtls_mpi_uint *M = mbedtls_calloc(M_limbs, ciL);
+    const size_t Q_limbs = M_limbs;
+    mbedtls_mpi_uint *M = NULL;
+    mbedtls_mpi_uint *Q = NULL;
+
+    M = mbedtls_calloc(M_limbs, ciL);
 
     if (M == NULL) {
         return MBEDTLS_ERR_ECP_ALLOC_FAILED;
     }
 
-    mbedtls_mpi_uint *Q = mbedtls_calloc(Q_limbs, ciL);
+    Q = mbedtls_calloc(Q_limbs, ciL);
 
     if (Q == NULL) {
         ret =  MBEDTLS_ERR_ECP_ALLOC_FAILED;
@@ -5549,41 +5555,72 @@
 
     /* M = A1 */
     memset(M, 0, (M_limbs * ciL));
-
     /* Do not copy into the overflow limb, as this would read past the end of
      * X. */
     memcpy(M, X + P448_WIDTH, ((M_limbs - 1) * ciL));
 
     /* X = A0 */
-    for (i = P448_WIDTH; i < X_limbs; i++) {
-        X[i] = 0;
-    }
+    memset(X + P448_WIDTH, 0, ((M_limbs - 1) * ciL));
 
-    /* X += A1 - Carry here dealt with by oversize M and X. */
+    /* X = X + M = A0 + A1 */
+    /* Carry here fits in oversize X. Oversize M means it will get
+     * added in, not returned as carry. */
     (void) mbedtls_mpi_core_add(X, X, M, M_limbs);
 
-    /* Q = B1, X += B1 */
-    memcpy(Q, M, (Q_limbs * ciL));
+    /* Q = B1 = M >> 224 */
+    memcpy(Q, (char *) M + P224_SIZE, P224_SIZE);
+    memset((char *) Q + P224_SIZE, 0, P224_SIZE);
 
-    mbedtls_mpi_core_shift_r(Q, Q_limbs, 224);
-
-    /* No carry here - only max 224 bits */
+    /* X = X + Q = (A0 + A1) + B1
+     * Oversize Q catches potential carry here when X is already max 448 bits.
+     */
     (void) mbedtls_mpi_core_add(X, X, Q, Q_limbs);
 
-    /* M = (B0 + B1) * 2^224, X += M */
-    if (sizeof(mbedtls_mpi_uint) > 4) {
+    /* M = B0 */
+    if (ciL > 4) {
         M[P224_WIDTH_MIN] &= ((mbedtls_mpi_uint)-1) >> (P224_UNUSED_BITS);
     }
-    for (i = P224_WIDTH_MAX; i < M_limbs; ++i) {
-        M[i] = 0;
-    }
+    memset(M + P224_WIDTH_MAX, 0, ((M_limbs - P224_WIDTH_MAX) * ciL));
 
+    /* M = M + Q = B0 + B1 */
     (void) mbedtls_mpi_core_add(M, M, Q, Q_limbs);
 
-    /* Shifted carry bit from the addition is dealt with by oversize M */
-    mbedtls_mpi_core_shift_l(M, M_limbs, 224);
+    /* M = (B0 + B1) * 2^224 */
+    /* Shifted carry bit from the addition fits in oversize M. */
+    memmove((char *) M + P224_SIZE, M, P224_SIZE + ciL);
+    memset(M, 0, P224_SIZE);
+
+    /* X = X + M = (A0 + A1 + B1) + (B0 + B1) * 2^224 */
     (void) mbedtls_mpi_core_add(X, X, M, M_limbs);
 
+    /* In the second and third rounds A1 and B0 have at most 1 non-zero limb and
+     * B1=0.
+     * Using this we need to calculate:
+     * A0 + A1 + B1 + (B0 + B1) * 2^224 = A0 + A1 + B0 * 2^224. */
+    for (round = 0; round < 2; ++round) {
+
+        /* M = A1 */
+        memset(M, 0, (M_limbs * ciL));
+        memcpy(M, X + P448_WIDTH, ((M_limbs - 1) * ciL));
+
+        /* X = A0 */
+        memset(X + P448_WIDTH, 0, ((M_limbs - 1) * ciL));
+
+        /* M = A1 + B0 * 2^224
+         * We know that only one limb of A1 will be non-zero and that it will be
+         * limb 0. We also know that B0 is the bottom 224 bits of A1 (which is
+         * then shifted up 224 bits), so, given M is currently A1 this turns
+         * into:
+         * M = M + (M << 224)
+         * As the single non-zero limb in B0 will be A1 limb 0 shifted up by 224
+         * bits, we can just move that into the right place, shifted up
+         * accordingly.*/
+        M[P224_WIDTH_MIN] = M[0] << (224 & (biL - 1));
+
+        /* X = A0 + (A1 + B0 * 2^224) */
+        (void) mbedtls_mpi_core_add(X, X, M, M_limbs);
+    }
+
     ret = 0;
 
 cleanup:
diff --git a/library/ecp_invasive.h b/library/ecp_invasive.h
index 587b173..75714f9 100644
--- a/library/ecp_invasive.h
+++ b/library/ecp_invasive.h
@@ -264,6 +264,25 @@
 
 #if defined(MBEDTLS_ECP_DP_CURVE448_ENABLED)
 
+/** Fast quasi-reduction modulo p448 = 2^448 - 2^224 - 1
+ * Write X as A0 + 2^448 A1 and A1 as B0 + 2^224 B1, and return A0 + A1 + B1 +
+ * (B0 + B1) * 2^224.
+ *
+ * \param[in,out]   X       The address of the MPI to be converted.
+ *                          Must have exact limb size that stores a 896-bit MPI
+ *                          (double the bitlength of the modulus). Upon return
+ *                          holds a value congruent to the input modulo N (the
+ *                          modulus) that fits in 448 bits. NOTE(review): this
+ *                          is a quasi-reduction, so the result is presumably
+ *                          not guaranteed to be `< N`; confirm before use.
+ * \param[in]       X_limbs The length of \p X in limbs.
+ *
+ * \return          \c 0 on success.
+ * \return          #MBEDTLS_ERR_ECP_BAD_INPUT_DATA if \p X has more than
+ *                  twice as many limbs as the modulus. NOTE(review): the check
+ *                  does not reject undersized X; callers must size it exactly.
+ * \return          #MBEDTLS_ERR_ECP_ALLOC_FAILED if memory allocation failed.
+ */
 MBEDTLS_STATIC_TESTABLE
 int mbedtls_ecp_mod_p448(mbedtls_mpi_uint *X, size_t X_limbs);
 
diff --git a/scripts/mbedtls_dev/ecp.py b/scripts/mbedtls_dev/ecp.py
index 02db438..bed4d56 100644
--- a/scripts/mbedtls_dev/ecp.py
+++ b/scripts/mbedtls_dev/ecp.py
@@ -848,6 +848,12 @@
          "167b75dfb948f82a8317cba01c75f67e290535d868a24b7f627f2855"
          "09167d4126af8090013c3273c02c6b9586b4625b475b51096c4ad652"),
 
+        # Corner case which causes maximum overflow
+        ("f4ae65e920a63ac1f2b64df6dff07870c9d531ae72a47403063238da1"
+         "a1fe3f9d6a179fa50f96cd4aff9261aa92c0e6f17ec940639bc2ccd0B"
+         "519A16DF59C53E0D49B209200F878F362ACE518D5B8BFCF9CDC725E5E"
+         "01C06295E8605AF06932B5006D9E556D3F190E8136BF9C643D332"),
+
         # Next 2 number generated by random.getrandbits(448)
         ("8f54f8ceacaab39e83844b40ffa9b9f15c14bc4a829e07b0829a48d4"
          "22fe99a22c70501e533c91352d3d854e061b90303b08c6e33c729578"),