Add more wrappers for internal ECP coordinate operations

Signed-off-by: Hanno Becker <hanno.becker@arm.com>
diff --git a/library/ecp.c b/library/ecp.c
index c86d55d..501e5cb 100644
--- a/library/ecp.c
+++ b/library/ecp.c
@@ -1185,30 +1185,60 @@
 }
 #endif /* All functions referencing mbedtls_mpi_shift_l_mod() are alt-implemented without fallback */
 
-#define MPI_ECP_ADD( X, A, B )                                      \
+/*
+ * Macro wrappers around ECP modular arithmetic, implemented via bignum.
+ * Arithmetic wrappers read the group `grp` from the enclosing scope; those
+ * built on MBEDTLS_MPI_CHK also need `ret` and a `cleanup:` label in scope.
+ */
+
+#define MPI_ECP_ADD( X, A, B )                                                  \
     MBEDTLS_MPI_CHK( mbedtls_mpi_add_mod( grp, X, A, B ) )
 
-#define MPI_ECP_SUB( X, A, B )                                      \
+#define MPI_ECP_SUB( X, A, B )                                                  \
     MBEDTLS_MPI_CHK( mbedtls_mpi_sub_mod( grp, X, A, B ) )
 
-#define MPI_ECP_MUL( X, A, B )                                      \
+#define MPI_ECP_MUL( X, A, B )                                                  \
     MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mod( grp, X, A, B ) )
 
-#define MPI_ECP_SQR( X, A )                                         \
+#define MPI_ECP_SQR( X, A )                                                     \
     MBEDTLS_MPI_CHK( mbedtls_mpi_mul_mod( grp, X, A, A ) )
 
-#define MPI_ECP_MUL_INT( X, A, c )                                  \
+#define MPI_ECP_MUL_INT( X, A, c )                                              \
     MBEDTLS_MPI_CHK( mbedtls_mpi_mul_int_mod( grp, X, A, c ) )
 
-#define MPI_ECP_INV( dst, src )                                     \
+#define MPI_ECP_INV( dst, src )                                                 \
     MBEDTLS_MPI_CHK( mbedtls_mpi_inv_mod( (dst), (src), &grp->P ) )
 
-#define MPI_ECP_MOV( X, A )                                         \
+#define MPI_ECP_MOV( X, A )                                                     \
     MBEDTLS_MPI_CHK( mbedtls_mpi_copy( X, A ) )
 
-#define MPI_ECP_SHIFT_L( X, count )                                 \
+#define MPI_ECP_SHIFT_L( X, count )                                             \
     MBEDTLS_MPI_CHK( mbedtls_mpi_shift_l_mod( grp, X, count ) )
 
+#define MPI_ECP_LSET( X, c )                                                    \
+    MBEDTLS_MPI_CHK( mbedtls_mpi_lset( X, c ) )
+
+#define MPI_ECP_CMP_INT( X, c )                                                 \
+    mbedtls_mpi_cmp_int( X, c )
+
+#define MPI_ECP_CMP( X, Y )                                                     \
+    mbedtls_mpi_cmp_mpi( X, Y )
+
+/* Random X with 1 < X < P; needs grp, f_rng and p_rng to be defined. */
+#define MPI_ECP_RAND( X )                                                       \
+    MBEDTLS_MPI_CHK( mbedtls_mpi_random( (X), 2, &grp->P, f_rng, p_rng ) )
+
+/* Conditional negation mod P: if cond (0 or 1) is 1 and X != 0, X <- P - X,
+ * which is -X mod P when X is reduced. Needs grp and a temporary MPI tmp. */
+#define MPI_ECP_COND_NEG( X, cond )                                             \
+    do                                                                          \
+    {                                                                           \
+        unsigned char nonzero = mbedtls_mpi_cmp_int( (X), 0 ) != 0;             \
+        MBEDTLS_MPI_CHK( mbedtls_mpi_sub_mpi( &tmp, &grp->P, (X) ) );           \
+        MBEDTLS_MPI_CHK( mbedtls_mpi_safe_cond_assign( (X), &tmp,               \
+                                                       nonzero & (cond) ) );    \
+    } while( 0 )
+
 #if defined(MBEDTLS_ECP_SHORT_WEIERSTRASS_ENABLED)
 /*
  * For curves in short Weierstrass form, we do all the internal operations in
@@ -1224,7 +1254,7 @@
  */
 static int ecp_normalize_jac( const mbedtls_ecp_group *grp, mbedtls_ecp_point *pt )
 {
-    if( mbedtls_mpi_cmp_int( &pt->Z, 0 ) == 0 )
+    if( MPI_ECP_CMP_INT( &pt->Z, 0 ) == 0 )
         return( 0 );
 
 #if defined(MBEDTLS_ECP_NORMALIZE_JAC_ALT)
@@ -1245,7 +1275,7 @@
     MPI_ECP_MUL( &pt->X,   &pt->X,     &T );  /* X   <- X  * T = X / Z^2 */
     MPI_ECP_MUL( &pt->Y,   &pt->Y,     &T );  /* Y'' <- Y' * T = Y / Z^3 */
 
-    MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &pt->Z, 1 ) );
+    MPI_ECP_LSET( &pt->Z, 1 );
 
 cleanup:
 
@@ -1371,19 +1401,13 @@
                             unsigned char inv )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    unsigned char nonzero;
-    mbedtls_mpi mQY;
+    mbedtls_mpi tmp;
+    mbedtls_mpi_init( &tmp );
 
-    mbedtls_mpi_init( &mQY );
-
-    /* Use the fact that -Q.Y mod P = P - Q.Y unless Q.Y == 0 */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_sub_mpi( &mQY, &grp->P, &Q->Y ) );
-    nonzero = mbedtls_mpi_cmp_int( &Q->Y, 0 ) != 0;
-    MBEDTLS_MPI_CHK( mbedtls_mpi_safe_cond_assign( &Q->Y, &mQY, inv & nonzero ) );
+    MPI_ECP_COND_NEG( &Q->Y, inv );
 
 cleanup:
-    mbedtls_mpi_free( &mQY );
-
+    mbedtls_mpi_free( &tmp );
     return( ret );
 }
 
@@ -1436,7 +1460,7 @@
         MPI_ECP_MUL_INT( &tmp[0],  &tmp[1],  3 );
 
         /* Optimize away for "koblitz" curves with A = 0 */
-        if( mbedtls_mpi_cmp_int( &grp->A, 0 ) != 0 )
+        if( MPI_ECP_CMP_INT( &grp->A, 0 ) != 0 )
         {
             /* M += A.Z^4 */
             MPI_ECP_SQR( &tmp[1],  &P->Z                );
@@ -1470,9 +1494,9 @@
     MPI_ECP_MUL(     &tmp[3],  &P->Y,  &P->Z   );
     MPI_ECP_SHIFT_L( &tmp[3],  1               );
 
-    MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &R->X, &tmp[2] ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &R->Y, &tmp[1] ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &R->Z, &tmp[3] ) );
+    MPI_ECP_MOV( &R->X, &tmp[2] );
+    MPI_ECP_MOV( &R->Y, &tmp[1] );
+    MPI_ECP_MOV( &R->Z, &tmp[3] );
 
 cleanup:
 
@@ -1546,9 +1570,9 @@
     MPI_ECP_SUB( &tmp[1], &tmp[1], &P->Y );
 
     /* Special cases (2) and (3) */
-    if( mbedtls_mpi_cmp_int( &tmp[0], 0 ) == 0 )
+    if( MPI_ECP_CMP_INT( &tmp[0], 0 ) == 0 )
     {
-        if( mbedtls_mpi_cmp_int( &tmp[1], 0 ) == 0 )
+        if( MPI_ECP_CMP_INT( &tmp[1], 0 ) == 0 )
         {
             ret = ecp_double_jac( grp, R, P, tmp );
             goto cleanup;
@@ -1609,7 +1633,7 @@
     mbedtls_mpi_init( &l );
 
     /* Generate l such that 1 < l < p */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_random( &l, 2, &grp->P, f_rng, p_rng ) );
+    MPI_ECP_RAND( &l );
 
     /* Z = l * Z */
     MPI_ECP_MUL( &pt->Z,   &pt->Z,     &l );
@@ -1927,7 +1951,7 @@
     /* Safely invert result if i is "negative" */
     MBEDTLS_MPI_CHK( ecp_safe_invert_jac( grp, R, i >> 7 ) );
 
-    MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &R->Z, 1 ) );
+    MPI_ECP_LSET( &R->Z, 1 );
 
 cleanup:
     return( ret );
@@ -2338,7 +2362,7 @@
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     MPI_ECP_INV( &P->Z, &P->Z );
     MPI_ECP_MUL( &P->X, &P->X, &P->Z );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &P->Z, 1 ) );
+    MPI_ECP_LSET( &P->Z, 1 );
 
 cleanup:
     return( ret );
@@ -2369,7 +2393,7 @@
     mbedtls_mpi_init( &l );
 
     /* Generate l such that 1 < l < p */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_random( &l, 2, &grp->P, f_rng, p_rng ) );
+    MPI_ECP_RAND( &l );
 
     MPI_ECP_MUL( &P->X, &P->X, &l );
     MPI_ECP_MUL( &P->Z, &P->Z, &l );
@@ -2465,12 +2489,12 @@
         return( MBEDTLS_ERR_ECP_BAD_INPUT_DATA );
 
     /* Save PX and read from P before writing to R, in case P == R */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &PX, &P->X ) );
+    MPI_ECP_MOV( &PX, &P->X );
     MBEDTLS_MPI_CHK( mbedtls_ecp_copy( &RP, P ) );
 
     /* Set R to zero in modified x/z coordinates */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &R->X, 1 ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_lset( &R->Z, 0 ) );
+    MPI_ECP_LSET( &R->X, 1 );
+    MPI_ECP_LSET( &R->Z, 0 );
     mbedtls_mpi_free( &R->Y );
 
     /* RP.X might be sligtly larger than P, so reduce it */
@@ -2664,7 +2688,7 @@
     MPI_ECP_MUL( &RHS, &RHS, &pt->X  );
     MPI_ECP_ADD( &RHS, &RHS, &grp->B );
 
-    if( mbedtls_mpi_cmp_mpi( &YY, &RHS ) != 0 )
+    if( MPI_ECP_CMP( &YY, &RHS ) != 0 )
         ret = MBEDTLS_ERR_ECP_INVALID_KEY;
 
 cleanup: