Merge pull request #6607 from gilles-peskine-arm/negative-zero-from-add-development

Fix negative zero from bignum add/subtract
diff --git a/ChangeLog.d/negative-zero-from-add.txt b/ChangeLog.d/negative-zero-from-add.txt
new file mode 100644
index 0000000..107d858
--- /dev/null
+++ b/ChangeLog.d/negative-zero-from-add.txt
@@ -0,0 +1,6 @@
+Bugfix
+   * In the bignum module, operations of the form (-A) + (+A) or (-A) - (-A)
+     with A > 0 created an unintended representation of the value 0 which was
+     not processed correctly by some bignum operations. Fix this. This had no
+     consequence on cryptography code, but might affect applications that call
+     bignum directly and use negative numbers.
diff --git a/include/mbedtls/bignum.h b/include/mbedtls/bignum.h
index 9d15955..3bd1ca0 100644
--- a/include/mbedtls/bignum.h
+++ b/include/mbedtls/bignum.h
@@ -188,9 +188,27 @@
  */
 typedef struct mbedtls_mpi
 {
-    int MBEDTLS_PRIVATE(s);              /*!<  Sign: -1 if the mpi is negative, 1 otherwise */
-    size_t MBEDTLS_PRIVATE(n);           /*!<  total # of limbs  */
-    mbedtls_mpi_uint *MBEDTLS_PRIVATE(p);          /*!<  pointer to limbs  */
+    /** Sign: -1 if the mpi is negative, 1 otherwise.
+     *
+     * The number 0 must be represented with `s = +1`. Although many library
+     * functions treat all-limbs-zero as equivalent to a valid representation
+     * of 0 regardless of the sign bit, there are exceptions, so bignum
+     * functions and external callers must always set \c s to +1 for the
+     * number zero.
+     *
+     * Note that this implies that calloc() or `... = {0}` does not create
+     * a valid MPI representation. You must call mbedtls_mpi_init().
+     */
+    int MBEDTLS_PRIVATE(s);
+
+    /** Total number of limbs in \c p.  */
+    size_t MBEDTLS_PRIVATE(n);
+
+    /** Pointer to limbs.
+     *
+     * This may be \c NULL if \c n is 0.
+     */
+    mbedtls_mpi_uint *MBEDTLS_PRIVATE(p);
 }
 mbedtls_mpi;
 
diff --git a/library/bignum.c b/library/bignum.c
index 521787d..42be815 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -972,10 +972,12 @@
     return( ret );
 }
 
-/*
- * Signed addition: X = A + B
+/* Common function for signed addition and subtraction.
+ * Calculate A + B * flip_B where flip_B is 1 or -1.
  */
-int mbedtls_mpi_add_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
+static int add_sub_mpi( mbedtls_mpi *X,
+                        const mbedtls_mpi *A, const mbedtls_mpi *B,
+                        int flip_B )
 {
     int ret, s;
     MPI_VALIDATE_RET( X != NULL );
@@ -983,16 +985,21 @@
     MPI_VALIDATE_RET( B != NULL );
 
     s = A->s;
-    if( A->s * B->s < 0 )
+    if( A->s * B->s * flip_B < 0 )
     {
-        if( mbedtls_mpi_cmp_abs( A, B ) >= 0 )
+        int cmp = mbedtls_mpi_cmp_abs( A, B );
+        if( cmp >= 0 )
         {
             MBEDTLS_MPI_CHK( mbedtls_mpi_sub_abs( X, A, B ) );
-            X->s =  s;
+            /* If |A| = |B|, the result is 0 and we must set the sign bit
+             * to +1 regardless of which of A or B was negative. Otherwise,
+             * since |A| > |B|, the sign is the sign of A. */
+            X->s = cmp == 0 ? 1 : s;
         }
         else
         {
             MBEDTLS_MPI_CHK( mbedtls_mpi_sub_abs( X, B, A ) );
+            /* Since |A| < |B|, the sign is the opposite of A. */
             X->s = -s;
         }
     }
@@ -1008,38 +1015,19 @@
 }
 
 /*
+ * Signed addition: X = A + B
+ */
+int mbedtls_mpi_add_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
+{
+    return( add_sub_mpi( X, A, B, 1 ) );
+}
+
+/*
  * Signed subtraction: X = A - B
  */
 int mbedtls_mpi_sub_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
 {
-    int ret, s;
-    MPI_VALIDATE_RET( X != NULL );
-    MPI_VALIDATE_RET( A != NULL );
-    MPI_VALIDATE_RET( B != NULL );
-
-    s = A->s;
-    if( A->s * B->s > 0 )
-    {
-        if( mbedtls_mpi_cmp_abs( A, B ) >= 0 )
-        {
-            MBEDTLS_MPI_CHK( mbedtls_mpi_sub_abs( X, A, B ) );
-            X->s =  s;
-        }
-        else
-        {
-            MBEDTLS_MPI_CHK( mbedtls_mpi_sub_abs( X, B, A ) );
-            X->s = -s;
-        }
-    }
-    else
-    {
-        MBEDTLS_MPI_CHK( mbedtls_mpi_add_abs( X, A, B ) );
-        X->s = s;
-    }
-
-cleanup:
-
-    return( ret );
+    return( add_sub_mpi( X, A, B, -1 ) );
 }
 
 /*
diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py
index 279668f..8b11bc2 100644
--- a/scripts/mbedtls_dev/bignum_common.py
+++ b/scripts/mbedtls_dev/bignum_common.py
@@ -14,9 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
-import typing
-
 from abc import abstractmethod
 from typing import Iterator, List, Tuple, TypeVar
 
@@ -38,7 +35,13 @@
     raise ValueError("Not invertible")
 
 def hex_to_int(val: str) -> int:
-    return int(val, 16) if val else 0
+    """Implement the syntax accepted by mbedtls_test_read_mpi().
+
+    This is a superset of what is accepted by mbedtls_test_read_mpi_core().
+    """
+    if val in ['', '-']:
+        return 0
+    return int(val, 16)
 
 def quote_str(val) -> str:
     return "\"{}\"".format(val)
@@ -57,15 +60,8 @@
     return (val.bit_length() + bits_in_limb - 1) // bits_in_limb
 
 def combination_pairs(values: List[T]) -> List[Tuple[T, T]]:
-    """Return all pair combinations from input values.
-
-    The return value is cast, as older versions of mypy are unable to derive
-    the specific type returned by itertools.combinations_with_replacement.
-    """
-    return typing.cast(
-        List[Tuple[T, T]],
-        list(itertools.combinations_with_replacement(values, 2))
-    )
+    """Return all ordered pairs (x, y) of input values, including x == y."""
+    return [(x, y) for x in values for y in values]
 
 
 class OperationCommon:
diff --git a/tests/include/test/helpers.h b/tests/include/test/helpers.h
index e0e6fd2..5f9bde6 100644
--- a/tests/include/test/helpers.h
+++ b/tests/include/test/helpers.h
@@ -295,13 +295,19 @@
 
 /** Read an MPI from a hexadecimal string.
  *
- * Like mbedtls_mpi_read_string(), but size the resulting bignum based
- * on the number of digits in the string. In particular, construct a
- * bignum with 0 limbs for an empty string, and a bignum with leading 0
- * limbs if the string has sufficiently many leading 0 digits.
+ * Like mbedtls_mpi_read_string(), but with tighter guarantees around
+ * edge cases.
  *
- * This is important so that the "0 (null)" and "0 (1 limb)" and
- * "leading zeros" test cases do what they claim.
+ * - This function guarantees that if \p s begins with '-' then the sign
+ *   bit of the result will be negative, even if the value is 0.
+ *   When this function encounters such a "negative 0", it
+ *   increments #mbedtls_test_case_uses_negative_0.
+ * - The size of the result is exactly the minimum number of limbs needed
+ *   to fit the digits in the input. In particular, this function constructs
+ *   a bignum with 0 limbs for an empty string, and a bignum with leading 0
+ *   limbs if the string has sufficiently many leading 0 digits.
+ *   This is important so that the "0 (null)" and "0 (1 limb)" and
+ *   "leading zeros" test cases do what they claim.
  *
  * \param[out] X        The MPI object to populate. It must be initialized.
  * \param[in] s         The null-terminated hexadecimal string to read from.
@@ -309,6 +315,14 @@
  * \return \c 0 on success, an \c MBEDTLS_ERR_MPI_xxx error code otherwise.
  */
 int mbedtls_test_read_mpi( mbedtls_mpi *X, const char *s );
+
+/** Nonzero if the current test case had an input parsed with
+ * mbedtls_test_read_mpi() that is a negative 0 (`"-"`, `"-0"`, `"-00"`, etc.,
+ * constructing a result with the sign bit set to -1 and the value being
+ * all-limbs-0, which is not a valid representation in #mbedtls_mpi but is
+ * tested for robustness).
+ */
+extern unsigned mbedtls_test_case_uses_negative_0;
 #endif /* MBEDTLS_BIGNUM_C */
 
 #endif /* TEST_HELPERS_H */
diff --git a/tests/scripts/generate_bignum_tests.py b/tests/scripts/generate_bignum_tests.py
index a105203..eee2f65 100755
--- a/tests/scripts/generate_bignum_tests.py
+++ b/tests/scripts/generate_bignum_tests.py
@@ -78,11 +78,17 @@
     #pylint: disable=abstract-method
     """Common features for bignum operations in legacy tests."""
     input_values = [
-        "", "0", "7b", "-7b",
+        "", "0", "-", "-0",
+        "7b", "-7b",
         "0000000000000000123", "-0000000000000000123",
         "1230000000000000000", "-1230000000000000000"
     ]
 
+    def description_suffix(self) -> str:
+        #pylint: disable=no-self-use # derived classes need self
+        """Text to add at the end of the test case description."""
+        return ""
+
     def description(self) -> str:
         """Generate a description for the test case.
 
@@ -96,6 +102,9 @@
                 self.symbol,
                 self.value_description(self.arg_b)
             )
+            description_suffix = self.description_suffix()
+            if description_suffix:
+                self.case_description += " " + description_suffix
         return super().description()
 
     @staticmethod
@@ -107,6 +116,8 @@
         """
         if val == "":
             return "0 (null)"
+        if val == "-":
+            return "negative 0 (null)"
         if val == "0":
             return "0 (1 limb)"
 
@@ -171,9 +182,21 @@
         ]
     )
 
-    def result(self) -> List[str]:
-        return [bignum_common.quote_str("{:x}").format(self.int_a + self.int_b)]
+    def __init__(self, val_a: str, val_b: str) -> None:
+        super().__init__(val_a, val_b)
+        self._result = self.int_a + self.int_b
 
+    def description_suffix(self) -> str:
+        if (self.int_a >= 0 and self.int_b >= 0):
+            return "" # obviously positive result or 0
+        if (self.int_a <= 0 and self.int_b <= 0):
+            return "" # obviously negative result or 0
+        # The sign of the result is not obvious, so indicate it
+        return ", result{}0".format('>' if self._result > 0 else
+                                    '<' if self._result < 0 else '=')
+
+    def result(self) -> List[str]:
+        return [bignum_common.quote_str("{:x}".format(self._result))]
 
 if __name__ == '__main__':
     # Use the section of the docstring relevant to the CLI as description
diff --git a/tests/src/helpers.c b/tests/src/helpers.c
index cc23fd7..7c83714 100644
--- a/tests/src/helpers.c
+++ b/tests/src/helpers.c
@@ -89,6 +89,10 @@
     mbedtls_test_info.step = step;
 }
 
+#if defined(MBEDTLS_BIGNUM_C)
+unsigned mbedtls_test_case_uses_negative_0 = 0;
+#endif
+
 void mbedtls_test_info_reset( void )
 {
     mbedtls_test_info.result = MBEDTLS_TEST_RESULT_SUCCESS;
@@ -98,6 +102,9 @@
     mbedtls_test_info.filename = 0;
     memset( mbedtls_test_info.line1, 0, sizeof( mbedtls_test_info.line1 ) );
     memset( mbedtls_test_info.line2, 0, sizeof( mbedtls_test_info.line2 ) );
+#if defined(MBEDTLS_BIGNUM_C)
+    mbedtls_test_case_uses_negative_0 = 0;
+#endif
 }
 
 int mbedtls_test_equal( const char *test, int line_no, const char* filename,
@@ -396,6 +403,15 @@
 
 int mbedtls_test_read_mpi( mbedtls_mpi *X, const char *s )
 {
+    int negative = 0;
+    /* Always set the sign bit to -1 if the input has a minus sign, even for 0.
+     * This creates an invalid representation, which mbedtls_mpi_read_string()
+     * avoids but we want to be able to create that in test data. */
+    if( s[0] == '-' )
+    {
+        ++s;
+        negative = 1;
+    }
     /* mbedtls_mpi_read_string() currently retains leading zeros.
      * It always allocates at least one limb for the value 0. */
     if( s[0] == 0 )
@@ -403,7 +419,22 @@
         mbedtls_mpi_free( X );
+        if( negative )
+        {
+            /* A bare "-" denotes a negative zero with no limbs: keep the
+             * sign bit that the documentation promises and record it. */
+            ++mbedtls_test_case_uses_negative_0;
+            X->s = -1;
+        }
         return( 0 );
     }
-    else
-        return( mbedtls_mpi_read_string( X, 16, s ) );
+    int ret = mbedtls_mpi_read_string( X, 16, s );
+    if( ret != 0 )
+        return( ret );
+    if( negative )
+    {
+        if( mbedtls_mpi_cmp_int( X, 0 ) == 0 )
+            ++mbedtls_test_case_uses_negative_0;
+        X->s = -1;
+    }
+    return( 0 );
 }
 #endif
diff --git a/tests/suites/test_suite_bignum.function b/tests/suites/test_suite_bignum.function
index 5c3d776..b75f534 100644
--- a/tests/suites/test_suite_bignum.function
+++ b/tests/suites/test_suite_bignum.function
@@ -13,10 +13,21 @@
  * constructing the value. */
 static int sign_is_valid( const mbedtls_mpi *X )
 {
+    /* Only +1 and -1 are valid sign bits, not e.g. 0 */
     if( X->s != 1 && X->s != -1 )
-        return( 0 ); // invalid sign bit, e.g. 0
-    if( mbedtls_mpi_bitlen( X ) == 0 && X->s != 1 )
-        return( 0 ); // negative zero
+        return( 0 );
+
+    /* The value 0 must be represented with the sign +1. A "negative zero"
+     * with s=-1 is an invalid representation. Forbid that. As an exception,
+     * we sometimes test the robustness of library functions when given
+     * a negative zero input. If a test case has a negative zero as input,
+     * we don't mind if the function has a negative zero output. */
+    if( ! mbedtls_test_case_uses_negative_0 &&
+        mbedtls_mpi_bitlen( X ) == 0 && X->s != 1 )
+    {
+        return( 0 );
+    }
+
     return( 1 );
 }
 
diff --git a/tests/suites/test_suite_bignum.misc.data b/tests/suites/test_suite_bignum.misc.data
index 0b8aa33..818f361 100644
--- a/tests/suites/test_suite_bignum.misc.data
+++ b/tests/suites/test_suite_bignum.misc.data
@@ -1144,6 +1144,18 @@
 Test mbedtls_mpi_div_mpi: 0 (null) / -1
 mpi_div_mpi:"":"-1":"":"":0
 
+Test mbedtls_mpi_div_mpi: -0 (null) / 1
+mpi_div_mpi:"-":"1":"":"":0
+
+Test mbedtls_mpi_div_mpi: -0 (null) / -1
+mpi_div_mpi:"-":"-1":"":"":0
+
+Test mbedtls_mpi_div_mpi: -0 (null) / 42
+mpi_div_mpi:"-":"2a":"":"":0
+
+Test mbedtls_mpi_div_mpi: -0 (null) / -42
+mpi_div_mpi:"-":"-2a":"":"":0
+
 Test mbedtls_mpi_div_mpi #1
 mpi_div_mpi:"9e22d6da18a33d1ef28d2a82242b3f6e9c9742f63e5d440f58a190bfaf23a7866e67589adb80":"22":"4a6abf75b13dc268ea9cc8b5b6aaf0ac85ecd437a4e0987fb13cf8d2acc57c0306c738c1583":"1a":0
 
@@ -1204,6 +1216,18 @@
 Test mbedtls_mpi_mod_mpi: 0 (null) % -1
 mpi_mod_mpi:"":"-1":"":MBEDTLS_ERR_MPI_NEGATIVE_VALUE
 
+Test mbedtls_mpi_mod_mpi: -0 (null) % 1
+mpi_mod_mpi:"-":"1":"":0
+
+Test mbedtls_mpi_mod_mpi: -0 (null) % -1
+mpi_mod_mpi:"-":"-1":"":MBEDTLS_ERR_MPI_NEGATIVE_VALUE
+
+Test mbedtls_mpi_mod_mpi: -0 (null) % 42
+mpi_mod_mpi:"-":"2a":"":0
+
+Test mbedtls_mpi_mod_mpi: -0 (null) % -42
+mpi_mod_mpi:"-":"-2a":"":MBEDTLS_ERR_MPI_NEGATIVE_VALUE
+
 Base test mbedtls_mpi_mod_int #1
 mpi_mod_int:"3e8":"d":"c":0