Fix output width of mbedtls_ecp_mod_p448() to 448 bits

Signed-off-by: Paul Elliott <paul.elliott@arm.com>
diff --git a/library/ecp_curves.c b/library/ecp_curves.c
index 85c889f..782a66a 100644
--- a/library/ecp_curves.c
+++ b/library/ecp_curves.c
@@ -28,6 +28,8 @@
 
 #include "mbedtls/platform.h"
 
+#include "constant_time_internal.h"
+
 #include "bn_mul.h"
 #include "bignum_core.h"
 #include "ecp_invasive.h"
@@ -5502,13 +5504,18 @@
     /* Extra limb for carry below. */
     M_limbs++;
 
-    mbedtls_mpi_uint *M = mbedtls_calloc(M_limbs, ciL);
+    mbedtls_mpi_uint *M = NULL;
+    mbedtls_mpi_uint *Q = NULL;
+    const mbedtls_mpi_uint *P = (mbedtls_mpi_uint *) curve448_p;
+    const size_t P_limbs = CHARS_TO_LIMBS(sizeof(curve448_p));
+
+    M = mbedtls_calloc(M_limbs, ciL);
 
     if (M == NULL) {
         return MBEDTLS_ERR_ECP_ALLOC_FAILED;
     }
 
-    mbedtls_mpi_uint *Q = mbedtls_calloc(Q_limbs, ciL);
+    Q = mbedtls_calloc(Q_limbs, ciL);
 
     if (Q == NULL) {
         ret =  MBEDTLS_ERR_ECP_ALLOC_FAILED;
@@ -5527,9 +5534,15 @@
         X[i] = 0;
     }
 
-    /* X += A1 - Carry here dealt with by oversize M and X. */
+    /* X += A1 - Carry here fits in oversize X. Oversize M means it will get
+     * added in, not returned as carry. */
     (void) mbedtls_mpi_core_add(X, X, M, M_limbs);
 
+    /* Deal with carry bit from add by subtracting P if necessary. */
+    if (X[P448_WIDTH] != 0) {
+        mbedtls_mpi_core_sub(X, X, P, P_limbs);
+    }
+
     /* Q = B1, X += B1 */
     memcpy(Q, M, (Q_limbs * ciL));
 
@@ -5548,10 +5561,22 @@
 
     (void) mbedtls_mpi_core_add(M, M, Q, Q_limbs);
 
-    /* Shifted carry bit from the addition is dealt with by oversize M */
+    /* Shifted carry bit from the addition fits in oversize M */
     mbedtls_mpi_core_shift_l(M, M_limbs, 224);
     (void) mbedtls_mpi_core_add(X, X, M, M_limbs);
 
+    /* Deal with carry bit by subtracting P if necessary. */
+    if (X[P448_WIDTH] != 0) {
+        mbedtls_mpi_core_sub(X, X, P, P_limbs);
+    }
+
+    /* Returned result should be 0 < X < P. Although we have controlled bit
+     * width, we may still have a result which is greater than P. Subtract P
+     * if this is the case. */
+    if (mbedtls_mpi_core_lt_ct(P, X, P_limbs)) {
+        mbedtls_mpi_core_sub(X, X, P, P_limbs);
+    }
+
     ret = 0;
 
 cleanup:
diff --git a/tests/suites/test_suite_ecp.function b/tests/suites/test_suite_ecp.function
index 53da2fc..0b9ce6b 100644
--- a/tests/suites/test_suite_ecp.function
+++ b/tests/suites/test_suite_ecp.function
@@ -1404,7 +1404,6 @@
     TEST_EQUAL(res.n, limbs);
 
     TEST_EQUAL(mbedtls_ecp_mod_p448(X.p, X.n), 0);
-    TEST_EQUAL(mbedtls_mpi_mod_mpi(&X, &X, &N), 0);
     TEST_LE_U(mbedtls_mpi_core_bitlen(X.p, X.n), 448);
     ASSERT_COMPARE(X.p, bytes, res.p, bytes);