Prevent mpi_mod_write from corrupting the input

Allocate a working buffer to store the converted value needed for the
mpi_mod_write function.

Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/library/bignum_mod.c b/library/bignum_mod.c
index e986865..916d34a 100644
--- a/library/bignum_mod.c
+++ b/library/bignum_mod.c
@@ -383,38 +383,46 @@
                           size_t buflen,
                           mbedtls_mpi_mod_ext_rep ext_rep)
 {
-    int ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
-
     /* Do our best to check if r and m have been set up */
     if (r->limbs == 0 || N->limbs == 0) {
-        goto cleanup;
+        return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
     }
     if (r->limbs != N->limbs) {
-        goto cleanup;
+        return MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
     }
 
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    mbedtls_mpi_uint *working_memory = r->p;
+    size_t working_memory_len = sizeof(mbedtls_mpi_uint) * r->limbs;
+
     if (N->int_rep == MBEDTLS_MPI_MOD_REP_MONTGOMERY) {
-        ret = mbedtls_mpi_mod_raw_from_mont_rep(r->p, N);
+
+        working_memory = mbedtls_calloc(r->limbs, sizeof(mbedtls_mpi_uint));
+
+        if (working_memory == NULL) {
+            ret = MBEDTLS_ERR_MPI_ALLOC_FAILED;
+            goto cleanup;
+        }
+
+        memcpy(working_memory, r->p, working_memory_len);
+
+        ret = mbedtls_mpi_mod_raw_from_mont_rep(working_memory, N);
         if (ret != 0) {
             goto cleanup;
         }
     }
 
-    ret = mbedtls_mpi_mod_raw_write(r->p, N, buf, buflen, ext_rep);
-
-    if (N->int_rep == MBEDTLS_MPI_MOD_REP_MONTGOMERY) {
-        /* If this fails, the value of r is corrupted and we want to return
-         * this error (as opposed to the error code from the write above) to
-         * let the caller know. If it succeeds, we want to return the error
-         * code from write above. */
-        int conv_ret = mbedtls_mpi_mod_raw_to_mont_rep(r->p, N);
-        if (ret == 0) {
-            ret = conv_ret;
-        }
-    }
+    ret = mbedtls_mpi_mod_raw_write(working_memory, N, buf, buflen, ext_rep);
 
 cleanup:
 
+    if (N->int_rep == MBEDTLS_MPI_MOD_REP_MONTGOMERY &&
+        working_memory != NULL) {
+
+        mbedtls_platform_zeroize(working_memory, working_memory_len);
+        mbedtls_free(working_memory);
+    }
+
     return ret;
 }
 /* END MERGE SLOT 7 */