Update mbedtls_mpi_safe_cond_(assign|swap) to use new CT interface
Signed-off-by: Dave Rodgman <dave.rodgman@arm.com>
diff --git a/library/bignum.c b/library/bignum.c
index b62f3f2..76910b1 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -141,6 +141,8 @@
     MPI_VALIDATE_RET(X != NULL);
     MPI_VALIDATE_RET(Y != NULL);
 
+    mbedtls_ct_condition_t do_assign = mbedtls_ct_bool(assign);
+
     /* all-bits 1 if assign is 1, all-bits 0 if assign is 0 */
     mbedtls_mpi_uint limb_mask = mbedtls_ct_mpi_uint_mask(assign);
 
@@ -148,7 +150,7 @@
 
     X->s = (int) mbedtls_ct_uint_if(assign, Y->s, X->s);
 
-    mbedtls_mpi_core_cond_assign(X->p, Y->p, Y->n, assign);
+    mbedtls_mpi_core_cond_assign(X->p, Y->p, Y->n, do_assign);
 
     for (size_t i = Y->n; i < X->n; i++) {
         X->p[i] &= ~limb_mask;
@@ -177,6 +179,8 @@
         return 0;
     }
 
+    mbedtls_ct_condition_t do_swap = mbedtls_ct_bool(swap);
+
     MBEDTLS_MPI_CHK(mbedtls_mpi_grow(X, Y->n));
     MBEDTLS_MPI_CHK(mbedtls_mpi_grow(Y, X->n));
 
@@ -184,7 +188,7 @@
     X->s = (int) mbedtls_ct_uint_if(swap, Y->s, X->s);
     Y->s = (int) mbedtls_ct_uint_if(swap, s, Y->s);
 
-    mbedtls_mpi_core_cond_swap(X->p, Y->p, X->n, swap);
+    mbedtls_mpi_core_cond_swap(X->p, Y->p, X->n, do_swap);
 
 cleanup:
     return ret;
diff --git a/library/bignum_core.c b/library/bignum_core.c
index a51b3f4..75806cf 100644
--- a/library/bignum_core.c
+++ b/library/bignum_core.c
@@ -211,31 +211,29 @@
 void mbedtls_mpi_core_cond_assign(mbedtls_mpi_uint *X,
                                   const mbedtls_mpi_uint *A,
                                   size_t limbs,
-                                  unsigned char assign)
+                                  mbedtls_ct_condition_t assign)
 {
     if (X == A) {
         return;
     }
 
-    mbedtls_ct_mpi_uint_cond_assign(limbs, X, A, assign);
+    mbedtls_ct_memcpy_if(assign, (unsigned char *) X, (unsigned char *) A, NULL,
+                         limbs * sizeof(mbedtls_mpi_uint));
 }
 
 void mbedtls_mpi_core_cond_swap(mbedtls_mpi_uint *X,
                                 mbedtls_mpi_uint *Y,
                                 size_t limbs,
-                                unsigned char swap)
+                                mbedtls_ct_condition_t swap)
 {
     if (X == Y) {
         return;
     }
 
-    /* all-bits 1 if swap is 1, all-bits 0 if swap is 0 */
-    mbedtls_mpi_uint limb_mask = mbedtls_ct_mpi_uint_mask(swap);
-
     for (size_t i = 0; i < limbs; i++) {
         mbedtls_mpi_uint tmp = X[i];
-        X[i] = (X[i] & ~limb_mask) | (Y[i] & limb_mask);
-        Y[i] = (Y[i] & ~limb_mask) | (tmp & limb_mask);
+        X[i] = mbedtls_ct_mpi_uint_if(swap, Y[i], X[i]);
+        Y[i] = mbedtls_ct_mpi_uint_if(swap, tmp, Y[i]);
     }
 }
 
@@ -637,7 +635,7 @@
                                            size_t index)
 {
     for (size_t i = 0; i < count; i++, table += limbs) {
-        unsigned char assign = mbedtls_ct_size_bool_eq(i, index);
+        mbedtls_ct_condition_t assign = mbedtls_ct_bool_eq(i, index);
         mbedtls_mpi_core_cond_assign(dest, table, limbs, assign);
     }
 }
diff --git a/library/bignum_core.h b/library/bignum_core.h
index 1fc5375..5432c80 100644
--- a/library/bignum_core.h
+++ b/library/bignum_core.h
@@ -86,6 +86,8 @@
 #include "mbedtls/bignum.h"
 #endif
 
+#include "constant_time_internal.h"
+
 #define ciL    (sizeof(mbedtls_mpi_uint))     /** chars in limb  */
 #define biL    (ciL << 3)                     /** bits  in limb  */
 #define biH    (ciL << 2)                     /** half limb size */
@@ -176,21 +178,15 @@
  * \param[in]  A        The address of the source MPI. This must be initialized.
  * \param      limbs    The number of limbs of \p A.
  * \param      assign   The condition deciding whether to perform the
- *                      assignment or not. Must be either 0 or 1:
- *                      * \c 1: Perform the assignment `X = A`.
- *                      * \c 0: Keep the original value of \p X.
+ *                      assignment or not.
  *
  * \note           This function avoids leaking any information about whether
  *                 the assignment was done or not.
- *
- * \warning        If \p assign is neither 0 nor 1, the result of this function
- *                 is indeterminate, and the resulting value in \p X might be
- *                 neither its original value nor the value in \p A.
  */
 void mbedtls_mpi_core_cond_assign(mbedtls_mpi_uint *X,
                                   const mbedtls_mpi_uint *A,
                                   size_t limbs,
-                                  unsigned char assign);
+                                  mbedtls_ct_condition_t assign);
 
 /**
  * \brief   Perform a safe conditional swap of two MPIs which doesn't reveal
@@ -202,21 +198,15 @@
  *                          This must be initialized.
  * \param         limbs     The number of limbs of \p X and \p Y.
  * \param         swap      The condition deciding whether to perform
- *                          the swap or not. Must be either 0 or 1:
- *                          * \c 1: Swap the values of \p X and \p Y.
- *                          * \c 0: Keep the original values of \p X and \p Y.
+ *                          the swap or not.
  *
  * \note           This function avoids leaking any information about whether
  *                 the swap was done or not.
- *
- * \warning        If \p swap is neither 0 nor 1, the result of this function
- *                 is indeterminate, and both \p X and \p Y might end up with
- *                 values different to either of the original ones.
  */
 void mbedtls_mpi_core_cond_swap(mbedtls_mpi_uint *X,
                                 mbedtls_mpi_uint *Y,
                                 size_t limbs,
-                                unsigned char swap);
+                                mbedtls_ct_condition_t swap);
 
 /** Import X from unsigned binary data, little-endian.
  *
diff --git a/library/bignum_mod_raw.c b/library/bignum_mod_raw.c
index 7919211..ef8c2b3 100644
--- a/library/bignum_mod_raw.c
+++ b/library/bignum_mod_raw.c
@@ -40,7 +40,7 @@
                                      const mbedtls_mpi_mod_modulus *N,
                                      unsigned char assign)
 {
-    mbedtls_mpi_core_cond_assign(X, A, N->limbs, assign);
+    mbedtls_mpi_core_cond_assign(X, A, N->limbs, mbedtls_ct_bool(assign));
 }
 
 void mbedtls_mpi_mod_raw_cond_swap(mbedtls_mpi_uint *X,
@@ -48,7 +48,7 @@
                                    const mbedtls_mpi_mod_modulus *N,
                                    unsigned char swap)
 {
-    mbedtls_mpi_core_cond_swap(X, Y, N->limbs, swap);
+    mbedtls_mpi_core_cond_swap(X, Y, N->limbs, mbedtls_ct_bool(swap));
 }
 
 int mbedtls_mpi_mod_raw_read(mbedtls_mpi_uint *X,