Move derivation of ECC private key to helper function and refactor code

Signed-off-by: Przemyslaw Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 9731932..07c0bbd 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -4824,100 +4824,143 @@
 }
 #endif /* MBEDTLS_PSA_BUILTIN_KEY_TYPE_DES */
 
+/*
+ * Derive one ECC private key candidate from a key derivation operation.
+ *
+ * ECC key types require a private key which is an integer in the range
+ * [1, N - 1], where N is the boundary of the private key domain: N is the
+ * prime p for Diffie-Hellman, or the order of the curve's base point for ECC.
+ *
+ * Let m be the bit size of N, such that 2^m > N >= 2^(m-1). Each call to
+ * this function performs one pass of the following process:
+ *
+ * 1. Draw a byte string of length ceiling(m/8) bytes.
+ * 2. If m is not a multiple of 8, set the most significant
+ *    (8 * ceiling(m/8) - m) bits of the first byte in the string to zero.
+ * 3. Convert the string to integer k by decoding it as a big-endian byte string.
+ * 4. If k > N - 2, discard the result and return to step 1. This is
+ *    reported through *error; the caller is expected to repeat the call.
+ * 5. Output k + 1 as the private key.
+ *
+ * This method allows compliance to NIST standards, specifically the methods
+ * titled Key-Pair Generation by Testing Candidates in:
+ * - NIST SP 800-56A §5.6.1.1.4 for Diffie-Hellman keys;
+ * - NIST SP 800-56A §5.6.1.2.2 or FIPS 186-4 §B.4.2 for elliptic curve keys.
+ *
+ * slot:      only slot->attr.type is read, to select the curve family.
+ * bits:      key size in bits, used with the family to pick the curve.
+ * operation: key derivation operation the candidate bytes are drawn from.
+ * data:      in/out; a ceiling(m/8)-byte buffer is allocated when *data is
+ *            NULL, otherwise *data is reused. Freed and set to NULL on error.
+ * error:     out; 1 when the candidate was rejected (k > N - 2), else 0.
+ *
+ * Returns 0 on success (also for a rejected candidate), or an error code.
+ * NOTE(review): failures of the mbedtls_* calls are returned as raw Mbed TLS
+ * codes, not PSA statuses - confirm whether the caller needs a translation.
+ */
+static psa_status_t psa_generate_derived_ecc_key_helper(
+    psa_key_slot_t *slot,
+    size_t bits,
+    psa_key_derivation_operation_t *operation,
+    uint8_t **data,
+    unsigned *error)
+{
+    mbedtls_ecp_group ecp_group;
+    mbedtls_mpi k;
+    mbedtls_mpi diff_N_2;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_status_t status;
+    size_t m;
+    size_t m_bytes;
+    psa_ecc_family_t curve = PSA_KEY_TYPE_ECC_GET_FAMILY( slot->attr.type );
+    mbedtls_ecp_group_id grp_id = mbedtls_ecc_group_of_psa( curve, bits, 0 );
+
+    mbedtls_ecp_group_init( &ecp_group );
+    mbedtls_mpi_init( &k );
+    mbedtls_mpi_init( &diff_N_2 );
+    *error = 0;
+
+    if( ( status = mbedtls_ecp_group_load( &ecp_group, grp_id ) ) != 0 )
+    {
+        ret = status;
+        goto cleanup;
+    }
+
+    /* N, the boundary of the private key domain, is used in place as
+     * ecp_group.N (no copy to free separately). Let m be its bit size. */
+    m = ecp_group.nbits;
+    m_bytes = PSA_BITS_TO_BYTES( m );
+
+    /* Allocate the output buffer only once; a repeated call reuses it. */
+    if( *data == NULL )
+        *data = mbedtls_calloc( 1, m_bytes );
+    if( *data == NULL )
+    {
+        ret = PSA_ERROR_INSUFFICIENT_MEMORY;
+        goto cleanup;
+    }
+
+    /* 1. Draw a byte string of length ceiling(m/8) bytes. */
+    status = psa_key_derivation_output_bytes( operation, *data, m_bytes );
+    if( status != PSA_SUCCESS )
+    {
+        ret = status;
+        goto cleanup;
+    }
+
+    /* 2. If m is not a multiple of 8, set the most significant
+     *    (8 * ceiling(m/8) - m) bits of the first byte to zero. */
+    if( m % 8 != 0 )
+        ( *data )[0] &= (uint8_t)( 0xFFu >> ( 8 * m_bytes - m ) );
+
+    /* 3. Convert the string to integer k (big-endian decoding). */
+    MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( &k, *data, m_bytes ) );
+
+    /* 4. If k > N - 2, discard the result and return to step 1.
+     *    The comparison result is handed back through *error; both operands
+     *    are grown to a common limb count before being compared. */
+    MBEDTLS_MPI_CHK( mbedtls_mpi_sub_int( &diff_N_2, &ecp_group.N, 2 ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &k, diff_N_2.n ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &diff_N_2, k.n ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_lt_mpi_ct( &diff_N_2, &k, error ) );
+
+    /* 5. Output k + 1 as the private key. */
+    MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( &k, &k, 1 ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_write_binary( &k, *data, m_bytes ) );
+
+    ret = 0;
+cleanup:
+    if( ret != 0 )
+    {
+        mbedtls_free( *data );
+        *data = NULL;
+    }
+    mbedtls_ecp_group_free( &ecp_group );
+    mbedtls_mpi_free( &k );
+    mbedtls_mpi_free( &diff_N_2 );
+    return( ret );
+}
+
 static psa_status_t psa_generate_derived_key_internal(
     psa_key_slot_t *slot,
     size_t bits,
     psa_key_derivation_operation_t *operation )
 {
     uint8_t *data = NULL;
+    unsigned key_err = 0;
     size_t bytes = PSA_BITS_TO_BYTES( bits );
     size_t storage_size = bytes;
     psa_status_t status;
 
-    /*
-    * ECC key types require the generation of a private key which is an integer
-    * in the range [1, N - 1], where N is the boundary of the private key domain:
-    * N is the prime p for Diffie-Hellman, or the order of the
-    * curve’s base point for ECC.
-    *
-    * Let m be the bit size of N, such that 2^m > N >= 2^(m-1).
-    * This function generates the private key using the following process:
-    *
-    * 1. Draw a byte string of length ceiling(m/8) bytes.
-    * 2. If m is not a multiple of 8, set the most significant
-    *    (8 * ceiling(m/8) - m) bits of the first byte in the string to zero.
-    * 3. Convert the string to integer k by decoding it as a big-endian byte string.
-    * 4. If k > N - 2, discard the result and return to step 1.
-    * 5. Output k + 1 as the private key.
-    *
-    * This method allows compliance to NIST standards
-    */
     if ( PSA_KEY_TYPE_IS_ECC( slot->attr.type ) )
     {
-        int cmp_result;
-        do {
-            int ret;
-            psa_ecc_family_t curve = PSA_KEY_TYPE_ECC_GET_FAMILY(
-                                        slot->attr.type );
-            mbedtls_ecp_group_id grp_id =
-                mbedtls_ecc_group_of_psa( curve, bits, 0 );
-
-            mbedtls_ecp_keypair ecp;
-            mbedtls_ecp_keypair_init( &ecp );
-
-            if( ( ret = mbedtls_ecp_group_load( &ecp.grp, grp_id ) ) != 0 )
-                return( ret );
-
-            /* N is the boundary of the private key domain */
-            mbedtls_mpi N = ecp.grp.N;
-            /* Let m be the bit size of N */
-            size_t m = ecp.grp.nbits;
-
-            size_t m_bytes = PSA_BITS_TO_BYTES( m );
-
-            /* Alloc buffer once */
-            if ( data == NULL )
-                data = mbedtls_calloc( 1, m_bytes );
-            if( data == NULL )
-                return( PSA_ERROR_INSUFFICIENT_MEMORY );
-
-            /* 1. Draw a byte string of length ceiling(m/8) bytes. */
-            status = psa_key_derivation_output_bytes( operation, data, m_bytes );
-            if( status != PSA_SUCCESS )
-                goto exit;
-
-            /* 2. If m is not a multiple of 8 */
-            if (m % 8)
-            {
-                /* set the most significant
-                 * (8 * ceiling(m/8) - m) bits of the first byte in
-                 * the string to zero.
-                 */
-                uint8_t clear_bit_count = ( 8 * m_bytes - m );
-                uint8_t clear_bit_mask = ( ( 1 << clear_bit_count ) - 1 );
-                clear_bit_mask = ~( clear_bit_mask << ( 8 - clear_bit_count ) );
-                data[0] = ( data[0] & clear_bit_mask );
-            }
-
-            /* 3. Convert the string to integer k by decoding it as a
-             *    big-endian byte string.
-             */
-            mbedtls_mpi k;
-            mbedtls_mpi_init( &k );
-            mbedtls_mpi_read_binary( &k, data, m_bytes);
-
-            /* 4. If k > N - 2, discard the result and return to step 1. */
-            mbedtls_mpi diff_N_2;
-            mbedtls_mpi_init( &diff_N_2 );
-            mbedtls_mpi_sub_int( &diff_N_2, &N, 2);
-            cmp_result = mbedtls_mpi_cmp_mpi( &k, &diff_N_2 );
-
-            /* 5. Output k + 1 as the private key. */
-            mbedtls_mpi sum_k_1;
-            mbedtls_mpi_init( &sum_k_1 );
-            mbedtls_mpi_add_int( &sum_k_1, &k, 1);
-            mbedtls_mpi_write_binary( &sum_k_1, data, m_bytes);
-        } while ( cmp_result == 1 );
+gen_ecc_key:
+        status = psa_generate_derived_ecc_key_helper(slot, bits, operation, &data, &key_err);
+        if( status != PSA_SUCCESS )
+            goto exit;
+        /* Key has been created, but it doesn't meet criteria. */
+        if (key_err)
+            goto gen_ecc_key;
     } else {
         if( ! key_type_is_raw_bytes( slot->attr.type ) )
             return( PSA_ERROR_INVALID_ARGUMENT );