Determine special cases in-place in the common Koblitz function

Replace the parameters used only for the special cases (P_limbs, adjust, shift and mask) with a single bit-size parameter, and determine the special cases in-place from it.

Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/library/ecp_curves.c b/library/ecp_curves.c
index e7ccd41..3f2b798 100644
--- a/library/ecp_curves.c
+++ b/library/ecp_curves.c
@@ -5519,6 +5519,7 @@
 #if defined(MBEDTLS_ECP_DP_SECP192K1_ENABLED) ||   \
     defined(MBEDTLS_ECP_DP_SECP224K1_ENABLED) ||   \
     defined(MBEDTLS_ECP_DP_SECP256K1_ENABLED)
+
 /*
  * Fast quasi-reduction modulo P = 2^s - R,
  * with R about 33 bits, used by the Koblitz curves.
@@ -5531,50 +5532,61 @@
 static inline int ecp_mod_koblitz(mbedtls_mpi_uint *X,
                                   size_t X_limbs,
                                   mbedtls_mpi_uint *R,
-                                  size_t P_limbs,
-                                  size_t adjust,
-                                  size_t shift,
-                                  mbedtls_mpi_uint mask)
+                                  size_t bits)
 {
     int ret = 0;
 
-    size_t A1_limbs = X_limbs - (P_limbs - adjust);
-    if (A1_limbs > P_limbs + adjust) {
-        A1_limbs = P_limbs + adjust;
-    }
-    mbedtls_mpi_uint *A1 = mbedtls_calloc(A1_limbs, ciL);
+    /* Determine whether A1 is aligned to the limb bit size. If it is not, the
+     * number of limbs used for P, A0 and A1 must be adjusted accordingly, and
+     * there is a middle limb shared by A0 and A1 that needs special handling.
+     */
+    size_t shift   = bits % biL;
+    size_t adjust  = (shift + biL - 1) / biL;
+    size_t P_limbs = bits / biL + adjust;
+
+    mbedtls_mpi_uint *A1 = mbedtls_calloc(P_limbs, ciL);
     if (A1 == NULL) {
         return MBEDTLS_ERR_ECP_ALLOC_FAILED;
     }
 
+    /* Create a buffer to store the value of `R * A1` */
     size_t R_limbs = P_KOBLITZ_R;
-    size_t M_limbs = A1_limbs + R_limbs;
+    size_t M_limbs = P_limbs + R_limbs;
     mbedtls_mpi_uint *M = mbedtls_calloc(M_limbs, ciL);
     if (M == NULL) {
         ret = MBEDTLS_ERR_ECP_ALLOC_FAILED;
         goto cleanup;
     }
 
+    mbedtls_mpi_uint mask = 0;
+    if (adjust != 0) {
+        mask  = ((mbedtls_mpi_uint) 1 << shift) - 1;
+    }
+
     for (size_t pass = 0; pass < 2; pass++) {
         /* Copy A1 */
-        memcpy(A1, X + P_limbs - adjust, A1_limbs * ciL);
+        memcpy(A1, X + P_limbs - adjust, P_limbs * ciL);
+
+        /* Shift A1 to be aligned */
         if (shift != 0) {
-            mbedtls_mpi_core_shift_r(A1, A1_limbs, shift);
+            mbedtls_mpi_core_shift_r(A1, P_limbs, shift);
         }
 
-        /* X = A0 */
+        /* Zeroize the A1 part of the shared limb */
         if (mask != 0) {
             X[P_limbs - 1] &= mask;
         }
 
-        /* Zeroize the A1 part of X to keep only the A0 part */
+        /* X = A0
+         * Zeroize the A1 part of X to keep only the A0 part.
+         */
         for (size_t i = P_limbs; i < X_limbs; i++) {
             X[i] = 0;
         }
 
         /* X = A0 + R * A1 */
-        mbedtls_mpi_core_mul(M, A1, A1_limbs, R, R_limbs);
-        (void) mbedtls_mpi_core_add(X, X, M, A1_limbs + R_limbs);
+        mbedtls_mpi_core_mul(M, A1, P_limbs, R, R_limbs);
+        (void) mbedtls_mpi_core_add(X, X, M, P_limbs + R_limbs);
 
         /* Carry can not be generated since R is a 33-bit value and stored in
          * 64 bits. The result value of the multiplication is at most
@@ -5620,8 +5632,7 @@
                                   0x00)
     };
 
-    return ecp_mod_koblitz(N->p, N->n, Rp,
-                           192 / 8 / sizeof(mbedtls_mpi_uint), 0, 0, 0);
+    return ecp_mod_koblitz(N->p, N->n, Rp, 192);
 }
 
 #endif /* MBEDTLS_ECP_DP_SECP192K1_ENABLED */
@@ -5651,12 +5662,7 @@
                                   0x00)
     };
 
-#if defined(MBEDTLS_HAVE_INT64)
-    return ecp_mod_koblitz(N->p, N->n, Rp, 4, 1, 32, 0xFFFFFFFF);
-#else
-    return ecp_mod_koblitz(N->p, N->n, Rp,
-                           224 / 8 / sizeof(mbedtls_mpi_uint), 0, 0, 0);
-#endif
+    return ecp_mod_koblitz(N->p, N->n, Rp, 224);
 }
 
 #endif /* MBEDTLS_ECP_DP_SECP224K1_ENABLED */
@@ -5685,8 +5691,7 @@
         MBEDTLS_BYTES_TO_T_UINT_8(0xD1, 0x03, 0x00, 0x00, 0x01, 0x00, 0x00,
                                   0x00)
     };
-    return ecp_mod_koblitz(N->p, N->n, Rp,
-                           256 / 8 / sizeof(mbedtls_mpi_uint), 0, 0, 0);
+    return ecp_mod_koblitz(N->p, N->n, Rp, 256);
 }
 #endif /* MBEDTLS_ECP_DP_SECP256K1_ENABLED */