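Unify 32/64-bit limb helpers via a bits_in_limb parameter

Replace the duplicated 4-byte and 8-byte limb helpers in
bignum_common.py (bound_mpi4/8, bound_mpi4/8_limbs, limbs_mpi4/8) with
bound_mpi, bound_mpi_limbs and limbs_mpi, which take the limb size as a
bits_in_limb argument. Update the callers in bignum_core.py to pass 32
or 64 explicitly.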

Signed-off-by: Werner Lewis <werner.lewis@arm.com>
diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py
index 88fa4df..c81770e 100644
--- a/scripts/mbedtls_dev/bignum_common.py
+++ b/scripts/mbedtls_dev/bignum_common.py
@@ -43,31 +43,18 @@
 def quote_str(val) -> str:
     return "\"{}\"".format(val)
 
-def bound_mpi8(val: int) -> int:
-    """First number exceeding 8-byte limbs needed for given input value."""
-    return bound_mpi8_limbs(limbs_mpi8(val))
+def bound_mpi(val: int, bits_in_limb: int) -> int:
+    """First number exceeding number of limbs needed for given input value."""
+    return bound_mpi_limbs(limbs_mpi(val, bits_in_limb), bits_in_limb)
 
-def bound_mpi4(val: int) -> int:
-    """First number exceeding 4-byte limbs needed for given input value."""
-    return bound_mpi4_limbs(limbs_mpi4(val))
-
-def bound_mpi8_limbs(limbs: int) -> int:
-    """First number exceeding maximum of given 8-byte limbs."""
-    bits = 64 * limbs
+def bound_mpi_limbs(limbs: int, bits_in_limb: int) -> int:
+    """First number exceeding maximum of given number of limbs."""
+    bits = bits_in_limb * limbs
     return 1 << bits
 
-def bound_mpi4_limbs(limbs: int) -> int:
-    """First number exceeding maximum of given 4-byte limbs."""
-    bits = 32 * limbs
-    return 1 << bits
-
-def limbs_mpi8(val: int) -> int:
-    """Return the number of 8-byte limbs required to store value."""
-    return (val.bit_length() + 63) // 64
-
-def limbs_mpi4(val: int) -> int:
-    """Return the number of 4-byte limbs required to store value."""
-    return (val.bit_length() + 31) // 32
+def limbs_mpi(val: int, bits_in_limb: int) -> int:
+    """Return the number of limbs required to store value."""
+    return (val.bit_length() + bits_in_limb - 1) // bits_in_limb
 
 def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
     """Return all pair combinations from input values.
diff --git a/scripts/mbedtls_dev/bignum_core.py b/scripts/mbedtls_dev/bignum_core.py
index 1bd2482..3652ac2 100644
--- a/scripts/mbedtls_dev/bignum_core.py
+++ b/scripts/mbedtls_dev/bignum_core.py
@@ -83,8 +83,8 @@
     def result(self) -> List[str]:
         tmp = self.int_a + self.int_b
         bound_val = max(self.int_a, self.int_b)
-        bound_4 = bignum_common.bound_mpi4(bound_val)
-        bound_8 = bignum_common.bound_mpi8(bound_val)
+        bound_4 = bignum_common.bound_mpi(bound_val, 32)
+        bound_8 = bignum_common.bound_mpi(bound_val, 64)
         carry_4, remainder_4 = divmod(tmp, bound_4)
         carry_8, remainder_8 = divmod(tmp, bound_8)
         return [
@@ -109,9 +109,9 @@
             carry = 0
         else:
             bound_val = max(self.int_a, self.int_b)
-            bound_4 = bignum_common.bound_mpi4(bound_val)
+            bound_4 = bignum_common.bound_mpi(bound_val, 32)
             result_4 = bound_4 + self.int_a - self.int_b
-            bound_8 = bignum_common.bound_mpi8(bound_val)
+            bound_8 = bignum_common.bound_mpi(bound_val, 64)
             result_8 = bound_8 + self.int_a - self.int_b
             carry = 1
         return [
@@ -153,7 +153,7 @@
         super().__init__(val_a, val_b)
         self.arg_scalar = val_s
         self.int_scalar = bignum_common.hex_to_int(val_s)
-        if bignum_common.limbs_mpi4(self.int_scalar) > 1:
+        if bignum_common.limbs_mpi(self.int_scalar, 32) > 1:
             self.dependencies = ["MBEDTLS_HAVE_INT64"]
 
     def arguments(self) -> List[str]:
@@ -174,8 +174,8 @@
     def result(self) -> List[str]:
         result = self.int_a + (self.int_b * self.int_scalar)
         bound_val = max(self.int_a, self.int_b)
-        bound_4 = bignum_common.bound_mpi4(bound_val)
-        bound_8 = bignum_common.bound_mpi8(bound_val)
+        bound_4 = bignum_common.bound_mpi(bound_val, 32)
+        bound_8 = bignum_common.bound_mpi(bound_val, 64)
         carry_4, remainder_4 = divmod(result, bound_4)
         carry_8, remainder_8 = divmod(result, bound_8)
         return [
@@ -548,12 +548,12 @@
         self.arg_n = val_n
         self.int_n = bignum_common.hex_to_int(val_n)
 
-        limbs_a4 = bignum_common.limbs_mpi4(self.int_a)
-        limbs_a8 = bignum_common.limbs_mpi8(self.int_a)
-        self.limbs_b4 = bignum_common.limbs_mpi4(self.int_b)
-        self.limbs_b8 = bignum_common.limbs_mpi8(self.int_b)
-        self.limbs_an4 = bignum_common.limbs_mpi4(self.int_n)
-        self.limbs_an8 = bignum_common.limbs_mpi8(self.int_n)
+        limbs_a4 = bignum_common.limbs_mpi(self.int_a, 32)
+        limbs_a8 = bignum_common.limbs_mpi(self.int_a, 64)
+        self.limbs_b4 = bignum_common.limbs_mpi(self.int_b, 32)
+        self.limbs_b8 = bignum_common.limbs_mpi(self.int_b, 64)
+        self.limbs_an4 = bignum_common.limbs_mpi(self.int_n, 32)
+        self.limbs_an8 = bignum_common.limbs_mpi(self.int_n, 64)
 
         if limbs_a4 > self.limbs_an4 or limbs_a8 > self.limbs_an8:
             raise Exception("Limbs of input A ({}) exceeds N ({})".format(
@@ -584,12 +584,12 @@
 
     def result(self) -> List[str]:
         """Get the result of the operation."""
-        r4 = bignum_common.bound_mpi4_limbs(self.limbs_an4)
+        r4 = bignum_common.bound_mpi_limbs(self.limbs_an4, 32)
         i4 = bignum_common.invmod(r4, self.int_n)
         x4 = self.int_a * self.int_b * i4
         x4 = x4 % self.int_n
 
-        r8 = bignum_common.bound_mpi8_limbs(self.limbs_an8)
+        r8 = bignum_common.bound_mpi_limbs(self.limbs_an8, 64)
         i8 = bignum_common.invmod(r8, self.int_n)
         x8 = self.int_a * self.int_b * i8
         x8 = x8 % self.int_n