Extract common code for computing X^3 + AX + B

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/library/ecp.c b/library/ecp.c
index 5f814f3..70b2283 100644
--- a/library/ecp.c
+++ b/library/ecp.c
@@ -1224,6 +1224,39 @@
     MBEDTLS_MPI_CHK( mbedtls_mpi_safe_cond_swap( (X), (Y), (cond) ) )
 
 #if defined(MBEDTLS_ECP_SHORT_WEIERSTRASS_ENABLED)
+/*
+ * Computes the right-hand side of the Short Weierstrass equation
+ * RHS = X^3 + A X + B
+ */
+static int ecp_sw_rhs( const mbedtls_ecp_group *grp,
+                       mbedtls_mpi *rhs,
+                       const mbedtls_mpi *X )
+{
+    int ret;
+
+    /* Compute X^3 + A X + B as X (X^2 + A) + B */
+    MPI_ECP_SQR( rhs, X );
+
+    /* Special case for A = -3 */
+    if( grp->A.p == NULL )
+    {
+        MPI_ECP_SUB_INT( rhs, rhs, 3 );
+    }
+    else
+    {
+        MPI_ECP_ADD( rhs, rhs, &grp->A );
+    }
+
+    MPI_ECP_MUL( rhs, rhs, X  );
+    MPI_ECP_ADD( rhs, rhs, &grp->B );
+
+cleanup:
+    return( ret );
+}
+
+/*
+ * Derive Y from X and a parity bit
+ */
 static int mbedtls_ecp_sw_derive_y( const mbedtls_ecp_group *grp,
                                     const mbedtls_mpi *X,
                                     mbedtls_mpi *Y,
@@ -1246,18 +1279,8 @@
     mbedtls_mpi exp;
     mbedtls_mpi_init( &exp );
 
-    /* use Y to store intermediate results */
-    /* y^2 = x^3 + ax + b = (x^2 + a)x + b */
-    /* x^2 */
-    MPI_ECP_MUL( Y, X, X );
-    /* x^2 + a */
-    if( !grp->A.p ) /* special case for A = -3; temporarily set exp = -3 */
-        MPI_ECP_LSET( &exp, -3 );
-    MPI_ECP_ADD( Y, Y, grp->A.p ? &grp->A : &exp );
-    /* (x^2 + a)x */
-    MPI_ECP_MUL( Y, Y, X );
-    /* (x^2 + a)x + b */
-    MPI_ECP_ADD( Y, Y, &grp->B );
+    /* use Y to store intermediate result, actually w above */
+    MBEDTLS_MPI_CHK( ecp_sw_rhs( grp, Y, X ) );
 
     /* w = y^2 */ /* Y contains y^2 intermediate result */
     /* exp = ((p+1)/4) */
@@ -2698,23 +2721,10 @@
 
     /*
      * YY = Y^2
-     * RHS = X (X^2 + A) + B = X^3 + A X + B
+     * RHS = X^3 + A X + B
      */
     MPI_ECP_SQR( &YY,  &pt->Y );
-    MPI_ECP_SQR( &RHS, &pt->X );
-
-    /* Special case for A = -3 */
-    if( grp->A.p == NULL )
-    {
-        MPI_ECP_SUB_INT( &RHS, &RHS, 3 );
-    }
-    else
-    {
-        MPI_ECP_ADD( &RHS, &RHS, &grp->A );
-    }
-
-    MPI_ECP_MUL( &RHS, &RHS, &pt->X  );
-    MPI_ECP_ADD( &RHS, &RHS, &grp->B );
+    MBEDTLS_MPI_CHK( ecp_sw_rhs( grp, &RHS, &pt->X ) );
 
     if( MPI_ECP_CMP( &YY, &RHS ) != 0 )
         ret = MBEDTLS_ERR_ECP_INVALID_KEY;