Add more validation to modulus life cycle

Signed-off-by: Janos Follath <janos.follath@arm.com>
diff --git a/library/bignum_new.c b/library/bignum_new.c
index 5f10ead..1c5cb8c 100644
--- a/library/bignum_new.c
+++ b/library/bignum_new.c
@@ -30,6 +30,16 @@
 #include "bignum_mod.h"
 #include "bignum_mod_raw.h"
 
+#if defined(MBEDTLS_PLATFORM_C)
+#include "mbedtls/platform.h"
+#else
+#include <stdio.h>
+#include <stdlib.h>
+#define mbedtls_printf     printf
+#define mbedtls_calloc    calloc
+#define mbedtls_free       free
+#endif
+
 #define MPI_VALIDATE_RET( cond )                                       \
     MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_MPI_BAD_INPUT_DATA )
 #define MPI_VALIDATE( cond )                                           \
@@ -126,6 +136,16 @@
     if ( m == NULL )
         return;
 
+    /* Free the rep member that belongs to the internal representation. */
+    switch( m->int_rep )
+    {
+        case MBEDTLS_MPI_MOD_REP_MONTGOMERY:
+            mbedtls_free( m->rep.mont ); break;
+        case MBEDTLS_MPI_MOD_REP_OPT_RED:
+            mbedtls_free( m->rep.ored ); break;
+        default: break;
+    }
+
     m->p = NULL;
     m->n = 0;
     m->plen = 0;
@@ -139,16 +159,54 @@
                                    int ext_rep,
                                    int int_rep )
 {
+    int ret = 0;
+
     if ( X == NULL || m == NULL )
         return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
 
     m->p = X;
     m->n = nx;
-    m->ext_rep = ext_rep;
-    m->int_rep = int_rep;
     m->plen = mpi_bitlen( X, nx );
 
-    return( 0 );
+    /*
+     * Only the representation selectors handled below are accepted. Any
+     * other value makes the error path pass m to
+     * mbedtls_mpi_mod_modulus_free(), which switches on m->int_rep; if it is
+     * ext_rep that is rejected, int_rep has not been written here yet.
+     * NOTE(review): this relies on the caller having initialised m - confirm.
+     */
+    if( ext_rep != MBEDTLS_MPI_MOD_EXT_REP_LE &&
+        ext_rep != MBEDTLS_MPI_MOD_EXT_REP_BE )
+    {
+        ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
+        goto exit;
+    }
+    m->ext_rep = ext_rep;
+
+    if( int_rep == MBEDTLS_MPI_MOD_REP_MONTGOMERY )
+    {
+        m->int_rep = int_rep;
+        m->rep.mont = NULL;
+    }
+    else if( int_rep == MBEDTLS_MPI_MOD_REP_OPT_RED )
+    {
+        m->int_rep = int_rep;
+        m->rep.ored = NULL;
+    }
+    else
+    {
+        ret = MBEDTLS_ERR_MPI_BAD_INPUT_DATA;
+        goto exit;
+    }
+
+exit:
+
+    if( ret != 0 )
+    {
+        mbedtls_mpi_mod_modulus_free( m );
+    }
+
+    return( ret );
 }
 
 /* Check X to have at least n limbs and set it to 0. */