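RSA: refactor: avoid code duplication

In key generation, compute D with mbedtls_rsa_deduce_private_exponent()
instead of open-coding the gcd / LCM / modular-inverse steps.

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>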
diff --git a/library/rsa.c b/library/rsa.c
index b7df690..08267db 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -1047,7 +1047,7 @@
unsigned int nbits, int exponent)
{
int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
- mbedtls_mpi H, G, L;
+ mbedtls_mpi H;
int prime_quality = 0;
/*
@@ -1060,8 +1060,6 @@
}
mbedtls_mpi_init(&H);
- mbedtls_mpi_init(&G);
- mbedtls_mpi_init(&L);
if (exponent < 3 || nbits % 2 != 0) {
ret = MBEDTLS_ERR_RSA_BAD_INPUT_DATA;
@@ -1099,35 +1097,28 @@
mbedtls_mpi_swap(&ctx->P, &ctx->Q);
}
- /* Temporarily replace P,Q by P-1, Q-1 */
- MBEDTLS_MPI_CHK(mbedtls_mpi_sub_int(&ctx->P, &ctx->P, 1));
- MBEDTLS_MPI_CHK(mbedtls_mpi_sub_int(&ctx->Q, &ctx->Q, 1));
- MBEDTLS_MPI_CHK(mbedtls_mpi_mul_mpi(&H, &ctx->P, &ctx->Q));
-
- /* check GCD( E, (P-1)*(Q-1) ) == 1 (FIPS 186-4 §B.3.1 criterion 2(a)) */
- MBEDTLS_MPI_CHK(mbedtls_mpi_gcd(&G, &ctx->E, &H));
- if (mbedtls_mpi_cmp_int(&G, 1) != 0) {
+ /* Compute D = E^-1 mod LCM(P-1, Q-1) (FIPS 186-4 §B.3.1 criterion 3(b))
+ * if it exists (FIPS 186-4 §B.3.1 criterion 2(a)) */
+ ret = mbedtls_rsa_deduce_private_exponent(&ctx->P, &ctx->Q, &ctx->E, &ctx->D);
+ if (ret == MBEDTLS_ERR_MPI_NOT_ACCEPTABLE) {
+ mbedtls_mpi_lset(&ctx->D, 0); /* needed for the next call */
continue;
}
+ if (ret != 0) {
+ goto cleanup;
+ }
- /* compute smallest possible D = E^-1 mod LCM(P-1, Q-1) (FIPS 186-4 §B.3.1 criterion 3(b)) */
- MBEDTLS_MPI_CHK(mbedtls_mpi_gcd(&G, &ctx->P, &ctx->Q));
- MBEDTLS_MPI_CHK(mbedtls_mpi_div_mpi(&L, NULL, &H, &G));
- MBEDTLS_MPI_CHK(mbedtls_mpi_inv_mod(&ctx->D, &ctx->E, &L));
-
- if (mbedtls_mpi_bitlen(&ctx->D) <= ((nbits + 1) / 2)) { // (FIPS 186-4 §B.3.1 criterion 3(a))
+ /* (FIPS 186-4 §B.3.1 criterion 3(a)) */
+ if (mbedtls_mpi_bitlen(&ctx->D) <= ((nbits + 1) / 2)) {
continue;
}
break;
} while (1);
- /* Restore P,Q */
- MBEDTLS_MPI_CHK(mbedtls_mpi_add_int(&ctx->P, &ctx->P, 1));
- MBEDTLS_MPI_CHK(mbedtls_mpi_add_int(&ctx->Q, &ctx->Q, 1));
+ /* N = P * Q */
MBEDTLS_MPI_CHK(mbedtls_mpi_mul_mpi(&ctx->N, &ctx->P, &ctx->Q));
-
ctx->len = mbedtls_mpi_size(&ctx->N);
#if !defined(MBEDTLS_RSA_NO_CRT)
@@ -1146,8 +1137,6 @@
cleanup:
mbedtls_mpi_free(&H);
- mbedtls_mpi_free(&G);
- mbedtls_mpi_free(&L);
if (ret != 0) {
mbedtls_rsa_free(ctx);