mod_p521: document reduction algorithm

Signed-off-by: Janos Follath <janos.follath@arm.com>
Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/library/ecp_curves.c b/library/ecp_curves.c
index 186dabe..7439266 100644
--- a/library/ecp_curves.c
+++ b/library/ecp_curves.c
@@ -5222,6 +5222,7 @@
         return 0;
     }
 
+    /* Step 1: Reduction to P521_WIDTH limbs */
     if (X_limbs > P521_WIDTH) {
         /* Helper references for bottom part of X */
         mbedtls_mpi_uint *X0 = X;
@@ -5230,20 +5231,43 @@
         mbedtls_mpi_uint *X1 = X + X0_limbs;
         size_t X1_limbs = X_limbs - X0_limbs;
 
-        /* Split X as X0 + 2^(512 + biL) X1 and compute X0 + 2^(biL - 9) * X1.
-         * This can be done in place. */
+        /* Split X as X0 + 2^P521_WIDTH X1 and compute X0 + 2^(biL - 9) X1.
+         * (We are using that 2^P521_WIDTH = 2^(512 + biL) and that
+         * 2^(512 + biL) X1 = 2^(biL - 9) X1 mod P521.)
+         * The high order limb of the result will be held in carry and the rest
+         * in X0 (that is the result will be represented as
+         * 2^P521_WIDTH carry + X0).
+         *
+         * Also, note that the resulting carry is either 0 or 1:
+         * X0 < 2^P521_WIDTH = 2^(512 + biL) and X1 < 2^(P521_WIDTH-biL) = 2^512
+         * therefore
+         * X0 + 2^(biL - 9) X1 < 2^(512 + biL) + 2^(512 + biL - 9)
+         * which in turn is less than 2 * 2^(512 + biL).
+         */
         mbedtls_mpi_uint shift = ((mbedtls_mpi_uint) 1u) << (biL - 9);
         carry = mbedtls_mpi_core_mla(X0, X0_limbs, X1, X1_limbs, shift);
 
-        /* Clear top part */
+        /* Set X to X0 (by clearing the top part). */
         memset(X1, 0, X1_limbs * sizeof(mbedtls_mpi_uint));
     }
 
-    mbedtls_mpi_uint addend[P521_WIDTH] = { 0 };
-    addend[0] = carry << (biL - 9);
-    addend[0] += (X[P521_WIDTH - 1] >> 9);
+    /* Step 2: Reduction modulo P521
+     *
+     * At this point X is reduced to P521_WIDTH limbs. What remains is to add
+     * the carry (that is 2^P521_WIDTH carry) and to reduce mod P521. */
+
+    /* 2^P521_WIDTH carry = 2^(512 + biL) carry = 2^(biL - 9) carry mod P521.
+     * Also, recall that carry is either 0 or 1. */
+    mbedtls_mpi_uint addend = carry << (biL - 9);
+    /* Keep the top 9 bits and reduce the rest, using 2^521 = 1 mod P521. */
+    addend += (X[P521_WIDTH - 1] >> 9);
     X[P521_WIDTH - 1] &= P521_MASK;
-    (void) mbedtls_mpi_core_add(X, X, addend, P521_WIDTH);
+    /* Declare a helper array for carrying out the addition. */
+    mbedtls_mpi_uint addend_arr[P521_WIDTH] = { 0 };
+    addend_arr[0] = addend;
+    (void) mbedtls_mpi_core_add(X, X, addend_arr, P521_WIDTH);
+    /* Both addends were less than P521, therefore X < 2 P521. (This also means
+     * that the result fits in P521_WIDTH limbs and there won't be a carry.) */
 
     return 0;
 }