Import mbedtls-2.16.0

Imports mbedTLS 2.16.0 from https://github.com/ARMmbed/mbedtls.git
commit fb1972db23da ("Merge pull request #544 from ARMmbed/version-2.16")
(tag mbedtls-2.16.0).

Certain files will never be needed and are thus removed (reducing the
number of lines to almost 50%):
rm -f circle.yml CMakeLists.txt DartConfiguration.tcl Makefile
rm -f .gitignore .travis.yml .pylint
rm -f include/.gitignore include/CMakeLists.txt library/.gitignore
rm -f library/CMakeLists.txt library/Makefile
rm -rf .git .github doxygen configs programs scripts tests visualc yotta

This is a complete overwrite of the previous code, so earlier changes in
the branch import/mbedtls-2.6.1 will be added on top of this commit to
bring them forward.

Acked-by: Jerome Forissier <jerome.forissier@linaro.org>
Signed-off-by: Jens Wiklander <jens.wiklander@linaro.org>
diff --git a/lib/libmbedtls/mbedtls/library/bignum.c b/lib/libmbedtls/mbedtls/library/bignum.c
index 405cf52..f968a0a 100644
--- a/lib/libmbedtls/mbedtls/library/bignum.c
+++ b/lib/libmbedtls/mbedtls/library/bignum.c
@@ -1,8 +1,8 @@
-// SPDX-License-Identifier: Apache-2.0
 /*
  *  Multi-precision integer library
  *
  *  Copyright (C) 2006-2015, ARM Limited, All Rights Reserved
+ *  SPDX-License-Identifier: Apache-2.0
  *
  *  Licensed under the Apache License, Version 2.0 (the "License"); you may
  *  not use this file except in compliance with the License.
@@ -45,6 +45,7 @@
 
 #include "mbedtls/bignum.h"
 #include "mbedtls/bn_mul.h"
+#include "mbedtls/platform_util.h"
 
 #include <string.h>
 
@@ -58,12 +59,10 @@
 #define mbedtls_free       free
 #endif
 
-#include <mempool.h>
-
-/* Implementation that should never be optimized out by the compiler */
-static void mbedtls_mpi_zeroize( mbedtls_mpi_uint *v, size_t n ) {
-    volatile mbedtls_mpi_uint *p = v; while( n-- ) *p++ = 0;
-}
+#define MPI_VALIDATE_RET( cond )                                       \
+    MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_MPI_BAD_INPUT_DATA )
+#define MPI_VALIDATE( cond )                                           \
+    MBEDTLS_INTERNAL_VALIDATE( cond )
 
 #define ciL    (sizeof(mbedtls_mpi_uint))         /* chars in limb  */
 #define biL    (ciL << 3)               /* bits  in limb  */
@@ -78,42 +77,22 @@
 #define BITS_TO_LIMBS(i)  ( (i) / biL + ( (i) % biL != 0 ) )
 #define CHARS_TO_LIMBS(i) ( (i) / ciL + ( (i) % ciL != 0 ) )
 
-void *mbedtls_mpi_mempool;
+/* Implementation that should never be optimized out by the compiler */
+static void mbedtls_mpi_zeroize( mbedtls_mpi_uint *v, size_t n )
+{
+    mbedtls_platform_zeroize( v, ciL * n );
+}
 
 /*
  * Initialize one MPI
  */
-static void mpi_init( mbedtls_mpi *X, enum mbedtls_mpi_alloc_type alloc_type,
-		      mbedtls_mpi_uint *p, int sign, size_t alloc_size,
-                      size_t nblimbs)
-{
-    if( X == NULL )
-        return;
-
-    X->s = sign;
-    X->alloc_type = alloc_type;
-    X->alloc_size = alloc_size;
-    X->n = nblimbs;
-    X->p = p;
-}
-
 void mbedtls_mpi_init( mbedtls_mpi *X )
 {
-    mpi_init(X, MBEDTLS_MPI_ALLOC_TYPE_MALLOC, NULL, 1, 0, 0);
-}
+    MPI_VALIDATE( X != NULL );
 
-void mbedtls_mpi_init_mempool( mbedtls_mpi *X )
-{
-    if( mbedtls_mpi_mempool )
-        mpi_init(X, MBEDTLS_MPI_ALLOC_TYPE_MEMPOOL, NULL, 1, 0, 0);
-    else
-        mbedtls_mpi_init( X );
-}
-
-void mbedtls_mpi_init_static( mbedtls_mpi *X , mbedtls_mpi_uint *p,
-                              int sign, size_t alloc_size, size_t nblimbs)
-{
-    mpi_init(X, MBEDTLS_MPI_ALLOC_TYPE_STATIC, p, sign, alloc_size, nblimbs);
+    X->s = 1;
+    X->n = 0;
+    X->p = NULL;
 }
 
 /*
@@ -127,22 +106,12 @@
     if( X->p != NULL )
     {
         mbedtls_mpi_zeroize( X->p, X->n );
-        switch (X->alloc_type) {
-        case MBEDTLS_MPI_ALLOC_TYPE_MALLOC:
-            mbedtls_free( X->p );
-            X->p = NULL;
-            break;
-        case MBEDTLS_MPI_ALLOC_TYPE_MEMPOOL:
-            mempool_free( mbedtls_mpi_mempool, X->p );
-            X->p = NULL;
-            break;
-        default:
-            break;
-        }
+        mbedtls_free( X->p );
     }
 
     X->s = 1;
     X->n = 0;
+    X->p = NULL;
 }
 
 /*
@@ -151,47 +120,27 @@
 int mbedtls_mpi_grow( mbedtls_mpi *X, size_t nblimbs )
 {
     mbedtls_mpi_uint *p;
+    MPI_VALIDATE_RET( X != NULL );
 
     if( nblimbs > MBEDTLS_MPI_MAX_LIMBS )
         return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
 
-    if( X->n >= nblimbs )
-        return( 0 );
-
-    switch( X->alloc_type ) {
-    case MBEDTLS_MPI_ALLOC_TYPE_MALLOC:
-        p = (mbedtls_mpi_uint*)mbedtls_calloc( nblimbs, ciL );
-        break;
-    case MBEDTLS_MPI_ALLOC_TYPE_MEMPOOL:
-        p = mempool_calloc( mbedtls_mpi_mempool, nblimbs, ciL );
-        break;
-    case MBEDTLS_MPI_ALLOC_TYPE_STATIC:
-        if( nblimbs > X->alloc_size )
+    if( X->n < nblimbs )
+    {
+        if( ( p = (mbedtls_mpi_uint*)mbedtls_calloc( nblimbs, ciL ) ) == NULL )
             return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
-        memset( X->p + X->n, 0, (nblimbs - X->n) * ciL );
-	goto out;
-    default:
-        return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
+
+        if( X->p != NULL )
+        {
+            memcpy( p, X->p, X->n * ciL );
+            mbedtls_mpi_zeroize( X->p, X->n );
+            mbedtls_free( X->p );
+        }
+
+        X->n = nblimbs;
+        X->p = p;
     }
 
-    if( p == NULL )
-        return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
-
-    if( X->p != NULL ) {
-        memcpy( p, X->p, X->n * ciL );
-        mbedtls_mpi_zeroize( X->p, X->n );
-    }
-
-
-    if( X->alloc_type == MBEDTLS_MPI_ALLOC_TYPE_MALLOC)
-        mbedtls_free( X->p );
-    else
-        mempool_free( mbedtls_mpi_mempool, X->p );
-
-    X->p = p;
-out:
-    X->n = nblimbs;
-
     return( 0 );
 }
 
@@ -203,6 +152,10 @@
 {
     mbedtls_mpi_uint *p;
     size_t i;
+    MPI_VALIDATE_RET( X != NULL );
+
+    if( nblimbs > MBEDTLS_MPI_MAX_LIMBS )
+        return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
 
     /* Actually resize up in this case */
     if( X->n <= nblimbs )
@@ -216,37 +169,18 @@
     if( i < nblimbs )
         i = nblimbs;
 
-    switch (X->alloc_type) {
-    case MBEDTLS_MPI_ALLOC_TYPE_MALLOC:
-            p = (mbedtls_mpi_uint*)mbedtls_calloc( nblimbs, ciL );
-            break;
-    case MBEDTLS_MPI_ALLOC_TYPE_MEMPOOL:
-            p = mempool_calloc(mbedtls_mpi_mempool, nblimbs, ciL);
-            break;
-    case MBEDTLS_MPI_ALLOC_TYPE_STATIC:
-        if (nblimbs > X->alloc_size)
-            return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
-        mbedtls_mpi_zeroize(X->p + i, X->n - i);
-        goto out;
-
-    default:
-        return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
-    }
-
+    if( ( p = (mbedtls_mpi_uint*)mbedtls_calloc( i, ciL ) ) == NULL )
+        return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
 
     if( X->p != NULL )
     {
         memcpy( p, X->p, i * ciL );
         mbedtls_mpi_zeroize( X->p, X->n );
-        if (X->alloc_type == MBEDTLS_MPI_ALLOC_TYPE_MALLOC)
-            mbedtls_free( X->p );
-        else
-            mempool_free( mbedtls_mpi_mempool, X->p );
+        mbedtls_free( X->p );
     }
 
-    X->p = p;
-out:
     X->n = i;
+    X->p = p;
 
     return( 0 );
 }
@@ -256,8 +190,10 @@
  */
 int mbedtls_mpi_copy( mbedtls_mpi *X, const mbedtls_mpi *Y )
 {
-    int ret;
+    int ret = 0;
     size_t i;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( Y != NULL );
 
     if( X == Y )
         return( 0 );
@@ -275,9 +211,15 @@
 
     X->s = Y->s;
 
-    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, i ) );
+    if( X->n < i )
+    {
+        MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, i ) );
+    }
+    else
+    {
+        memset( X->p + i, 0, ( X->n - i ) * ciL );
+    }
 
-    memset( X->p, 0, X->n * ciL );
     memcpy( X->p, Y->p, i * ciL );
 
 cleanup:
@@ -291,6 +233,8 @@
 void mbedtls_mpi_swap( mbedtls_mpi *X, mbedtls_mpi *Y )
 {
     mbedtls_mpi T;
+    MPI_VALIDATE( X != NULL );
+    MPI_VALIDATE( Y != NULL );
 
     memcpy( &T,  X, sizeof( mbedtls_mpi ) );
     memcpy(  X,  Y, sizeof( mbedtls_mpi ) );
@@ -306,6 +250,8 @@
 {
     int ret = 0;
     size_t i;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( Y != NULL );
 
     /* make sure assign is 0 or 1 in a time-constant manner */
     assign = (assign | (unsigned char)-assign) >> 7;
@@ -335,6 +281,8 @@
     int ret, s;
     size_t i;
     mbedtls_mpi_uint tmp;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( Y != NULL );
 
     if( X == Y )
         return( 0 );
@@ -367,6 +315,7 @@
 int mbedtls_mpi_lset( mbedtls_mpi *X, mbedtls_mpi_sint z )
 {
     int ret;
+    MPI_VALIDATE_RET( X != NULL );
 
     MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, 1 ) );
     memset( X->p, 0, X->n * ciL );
@@ -384,12 +333,18 @@
  */
 int mbedtls_mpi_get_bit( const mbedtls_mpi *X, size_t pos )
 {
+    MPI_VALIDATE_RET( X != NULL );
+
     if( X->n * biL <= pos )
         return( 0 );
 
     return( ( X->p[pos / biL] >> ( pos % biL ) ) & 0x01 );
 }
 
+/* Get a specific byte, without range checks. */
+#define GET_BYTE( X, i )                                \
+    ( ( ( X )->p[( i ) / ciL] >> ( ( ( i ) % ciL ) * 8 ) ) & 0xff )
+
 /*
  * Set a bit to a specific value of 0 or 1
  */
@@ -398,6 +353,7 @@
     int ret = 0;
     size_t off = pos / biL;
     size_t idx = pos % biL;
+    MPI_VALIDATE_RET( X != NULL );
 
     if( val != 0 && val != 1 )
         return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
@@ -424,6 +380,7 @@
 size_t mbedtls_mpi_lsb( const mbedtls_mpi *X )
 {
     size_t i, j, count = 0;
+    MBEDTLS_INTERNAL_VALIDATE_RET( X != NULL, 0 );
 
     for( i = 0; i < X->n; i++ )
         for( j = 0; j < biL; j++, count++ )
@@ -504,11 +461,13 @@
     size_t i, j, slen, n;
     mbedtls_mpi_uint d;
     mbedtls_mpi T;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( s != NULL );
 
     if( radix < 2 || radix > 16 )
         return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
 
-    mbedtls_mpi_init_mempool( &T );
+    mbedtls_mpi_init( &T );
 
     slen = strlen( s );
 
@@ -604,6 +563,9 @@
     size_t n;
     char *p;
     mbedtls_mpi T;
+    MPI_VALIDATE_RET( X    != NULL );
+    MPI_VALIDATE_RET( olen != NULL );
+    MPI_VALIDATE_RET( buflen == 0 || buf != NULL );
 
     if( radix < 2 || radix > 16 )
         return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
@@ -625,7 +587,7 @@
     }
 
     p = buf;
-    mbedtls_mpi_init_mempool( &T );
+    mbedtls_mpi_init( &T );
 
     if( X->s == -1 )
         *p++ = '-';
@@ -685,6 +647,12 @@
      */
     char s[ MBEDTLS_MPI_RW_BUFFER_SIZE ];
 
+    MPI_VALIDATE_RET( X   != NULL );
+    MPI_VALIDATE_RET( fin != NULL );
+
+    if( radix < 2 || radix > 16 )
+        return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
+
     memset( s, 0, sizeof( s ) );
     if( fgets( s, sizeof( s ) - 1, fin ) == NULL )
         return( MBEDTLS_ERR_MPI_FILE_IO_ERROR );
@@ -716,6 +684,10 @@
      * newline characters and '\0'
      */
     char s[ MBEDTLS_MPI_RW_BUFFER_SIZE ];
+    MPI_VALIDATE_RET( X != NULL );
+
+    if( radix < 2 || radix > 16 )
+        return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
 
     memset( s, 0, sizeof( s ) );
 
@@ -749,16 +721,23 @@
 int mbedtls_mpi_read_binary( mbedtls_mpi *X, const unsigned char *buf, size_t buflen )
 {
     int ret;
-    size_t i, j, n;
+    size_t i, j;
+    size_t const limbs = CHARS_TO_LIMBS( buflen );
 
-    for( n = 0; n < buflen; n++ )
-        if( buf[n] != 0 )
-            break;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( buflen == 0 || buf != NULL );
 
-    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, CHARS_TO_LIMBS( buflen - n ) ) );
+    /* Ensure that target MPI has exactly the necessary number of limbs */
+    if( X->n != limbs )
+    {
+        mbedtls_mpi_free( X );
+        mbedtls_mpi_init( X );
+        MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, limbs ) );
+    }
+
     MBEDTLS_MPI_CHK( mbedtls_mpi_lset( X, 0 ) );
 
-    for( i = buflen, j = 0; i > n; i--, j++ )
+    for( i = buflen, j = 0; i > 0; i--, j++ )
         X->p[j / ciL] |= ((mbedtls_mpi_uint) buf[i - 1]) << ((j % ciL) << 3);
 
 cleanup:
@@ -769,19 +748,45 @@
 /*
  * Export X into unsigned binary data, big endian
  */
-int mbedtls_mpi_write_binary( const mbedtls_mpi *X, unsigned char *buf, size_t buflen )
+int mbedtls_mpi_write_binary( const mbedtls_mpi *X,
+                              unsigned char *buf, size_t buflen )
 {
-    size_t i, j, n;
+    size_t stored_bytes;
+    size_t bytes_to_copy;
+    unsigned char *p;
+    size_t i;
 
-    n = mbedtls_mpi_size( X );
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( buflen == 0 || buf != NULL );
 
-    if( buflen < n )
-        return( MBEDTLS_ERR_MPI_BUFFER_TOO_SMALL );
+    stored_bytes = X->n * ciL;
 
-    memset( buf, 0, buflen );
+    if( stored_bytes < buflen )
+    {
+        /* There is enough space in the output buffer. Write initial
+         * null bytes and record the position at which to start
+         * writing the significant bytes. In this case, the execution
+         * trace of this function does not depend on the value of the
+         * number. */
+        bytes_to_copy = stored_bytes;
+        p = buf + buflen - stored_bytes;
+        memset( buf, 0, buflen - stored_bytes );
+    }
+    else
+    {
+        /* The output buffer is smaller than the allocated size of X.
+         * However X may fit if its leading bytes are zero. */
+        bytes_to_copy = buflen;
+        p = buf;
+        for( i = bytes_to_copy; i < stored_bytes; i++ )
+        {
+            if( GET_BYTE( X, i ) != 0 )
+                return( MBEDTLS_ERR_MPI_BUFFER_TOO_SMALL );
+        }
+    }
 
-    for( i = buflen - 1, j = 0; n > 0; i--, j++, n-- )
-        buf[i] = (unsigned char)( X->p[j / ciL] >> ((j % ciL) << 3) );
+    for( i = 0; i < bytes_to_copy; i++ )
+        p[bytes_to_copy - i - 1] = GET_BYTE( X, i );
 
     return( 0 );
 }
@@ -794,6 +799,7 @@
     int ret;
     size_t i, v0, t1;
     mbedtls_mpi_uint r0 = 0, r1;
+    MPI_VALIDATE_RET( X != NULL );
 
     v0 = count / (biL    );
     t1 = count & (biL - 1);
@@ -843,6 +849,7 @@
 {
     size_t i, v0, v1;
     mbedtls_mpi_uint r0 = 0, r1;
+    MPI_VALIDATE_RET( X != NULL );
 
     v0 = count /  biL;
     v1 = count & (biL - 1);
@@ -885,6 +892,8 @@
 int mbedtls_mpi_cmp_abs( const mbedtls_mpi *X, const mbedtls_mpi *Y )
 {
     size_t i, j;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( Y != NULL );
 
     for( i = X->n; i > 0; i-- )
         if( X->p[i - 1] != 0 )
@@ -915,6 +924,8 @@
 int mbedtls_mpi_cmp_mpi( const mbedtls_mpi *X, const mbedtls_mpi *Y )
 {
     size_t i, j;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( Y != NULL );
 
     for( i = X->n; i > 0; i-- )
         if( X->p[i - 1] != 0 )
@@ -949,6 +960,7 @@
 {
     mbedtls_mpi Y;
     mbedtls_mpi_uint p[1];
+    MPI_VALIDATE_RET( X != NULL );
 
     *p  = ( z < 0 ) ? -z : z;
     Y.s = ( z < 0 ) ? -1 : 1;
@@ -966,6 +978,9 @@
     int ret;
     size_t i, j;
     mbedtls_mpi_uint *o, *p, c, tmp;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( B != NULL );
 
     if( X == B )
     {
@@ -1031,7 +1046,7 @@
     while( c != 0 )
     {
         z = ( *d < c ); *d -= c;
-        c = z; i++; d++;
+        c = z; d++;
     }
 }
 
@@ -1043,11 +1058,14 @@
     mbedtls_mpi TB;
     int ret;
     size_t n;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( B != NULL );
 
     if( mbedtls_mpi_cmp_abs( A, B ) < 0 )
         return( MBEDTLS_ERR_MPI_NEGATIVE_VALUE );
 
-    mbedtls_mpi_init_mempool( &TB );
+    mbedtls_mpi_init( &TB );
 
     if( X == B )
     {
@@ -1083,8 +1101,12 @@
  */
 int mbedtls_mpi_add_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
 {
-    int ret, s = A->s;
+    int ret, s;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( B != NULL );
 
+    s = A->s;
     if( A->s * B->s < 0 )
     {
         if( mbedtls_mpi_cmp_abs( A, B ) >= 0 )
@@ -1114,8 +1136,12 @@
  */
 int mbedtls_mpi_sub_mpi( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *B )
 {
-    int ret, s = A->s;
+    int ret, s;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( B != NULL );
 
+    s = A->s;
     if( A->s * B->s > 0 )
     {
         if( mbedtls_mpi_cmp_abs( A, B ) >= 0 )
@@ -1147,6 +1173,8 @@
 {
     mbedtls_mpi _B;
     mbedtls_mpi_uint p[1];
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
 
     p[0] = ( b < 0 ) ? -b : b;
     _B.s = ( b < 0 ) ? -1 : 1;
@@ -1163,6 +1191,8 @@
 {
     mbedtls_mpi _B;
     mbedtls_mpi_uint p[1];
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
 
     p[0] = ( b < 0 ) ? -b : b;
     _B.s = ( b < 0 ) ? -1 : 1;
@@ -1252,8 +1282,11 @@
     int ret;
     size_t i, j;
     mbedtls_mpi TA, TB;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( B != NULL );
 
-    mbedtls_mpi_init_mempool( &TA ); mbedtls_mpi_init_mempool( &TB );
+    mbedtls_mpi_init( &TA ); mbedtls_mpi_init( &TB );
 
     if( X == A ) { MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TA, A ) ); A = &TA; }
     if( X == B ) { MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TB, B ) ); B = &TB; }
@@ -1269,8 +1302,8 @@
     MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, i + j ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_lset( X, 0 ) );
 
-    for( i++; j > 0; j-- )
-        mpi_mul_hlp( i - 1, A->p, X->p + j - 1, B->p[j - 1] );
+    for( ; j > 0; j-- )
+        mpi_mul_hlp( i, A->p, X->p + j - 1, B->p[j - 1] );
 
     X->s = A->s * B->s;
 
@@ -1288,6 +1321,8 @@
 {
     mbedtls_mpi _B;
     mbedtls_mpi_uint p[1];
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
 
     _B.s = 1;
     _B.n = 1;
@@ -1396,18 +1431,20 @@
 /*
  * Division by mbedtls_mpi: A = Q * B + R  (HAC 14.20)
  */
-int mbedtls_mpi_div_mpi( mbedtls_mpi *Q, mbedtls_mpi *R, const mbedtls_mpi *A, const mbedtls_mpi *B )
+int mbedtls_mpi_div_mpi( mbedtls_mpi *Q, mbedtls_mpi *R, const mbedtls_mpi *A,
+                         const mbedtls_mpi *B )
 {
     int ret;
     size_t i, n, t, k;
     mbedtls_mpi X, Y, Z, T1, T2;
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( B != NULL );
 
     if( mbedtls_mpi_cmp_int( B, 0 ) == 0 )
         return( MBEDTLS_ERR_MPI_DIVISION_BY_ZERO );
 
-    mbedtls_mpi_init_mempool( &X ); mbedtls_mpi_init_mempool( &Y );
-    mbedtls_mpi_init_mempool( &Z ); mbedtls_mpi_init_mempool( &T1 );
-    mbedtls_mpi_init_mempool( &T2 );
+    mbedtls_mpi_init( &X ); mbedtls_mpi_init( &Y ); mbedtls_mpi_init( &Z );
+    mbedtls_mpi_init( &T1 ); mbedtls_mpi_init( &T2 );
 
     if( mbedtls_mpi_cmp_abs( A, B ) < 0 )
     {
@@ -1512,10 +1549,13 @@
 /*
  * Division by int: A = Q * b + R
  */
-int mbedtls_mpi_div_int( mbedtls_mpi *Q, mbedtls_mpi *R, const mbedtls_mpi *A, mbedtls_mpi_sint b )
+int mbedtls_mpi_div_int( mbedtls_mpi *Q, mbedtls_mpi *R,
+                         const mbedtls_mpi *A,
+                         mbedtls_mpi_sint b )
 {
     mbedtls_mpi _B;
     mbedtls_mpi_uint p[1];
+    MPI_VALIDATE_RET( A != NULL );
 
     p[0] = ( b < 0 ) ? -b : b;
     _B.s = ( b < 0 ) ? -1 : 1;
@@ -1531,6 +1571,9 @@
 int mbedtls_mpi_mod_mpi( mbedtls_mpi *R, const mbedtls_mpi *A, const mbedtls_mpi *B )
 {
     int ret;
+    MPI_VALIDATE_RET( R != NULL );
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( B != NULL );
 
     if( mbedtls_mpi_cmp_int( B, 0 ) < 0 )
         return( MBEDTLS_ERR_MPI_NEGATIVE_VALUE );
@@ -1555,6 +1598,8 @@
 {
     size_t i;
     mbedtls_mpi_uint x, y, z;
+    MPI_VALIDATE_RET( r != NULL );
+    MPI_VALIDATE_RET( A != NULL );
 
     if( b == 0 )
         return( MBEDTLS_ERR_MPI_DIVISION_BY_ZERO );
@@ -1608,7 +1653,7 @@
 /*
  * Fast Montgomery initialization (thanks to Tom St Denis)
  */
-void mbedtls_mpi_montg_init( mbedtls_mpi_uint *mm, const mbedtls_mpi *N )
+static void mpi_montg_init( mbedtls_mpi_uint *mm, const mbedtls_mpi *N )
 {
     mbedtls_mpi_uint x, m0 = N->p[0];
     unsigned int i;
@@ -1625,8 +1670,7 @@
 /*
  * Montgomery multiplication: A = A * B * R^-1 mod N  (HAC 14.36)
  */
-int mbedtls_mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B,
-			 const mbedtls_mpi *N, mbedtls_mpi_uint mm,
+static int mpi_montmul( mbedtls_mpi *A, const mbedtls_mpi *B, const mbedtls_mpi *N, mbedtls_mpi_uint mm,
                          const mbedtls_mpi *T )
 {
     size_t i, n, m;
@@ -1669,8 +1713,8 @@
 /*
  * Montgomery reduction: A = A * R^-1 mod N
  */
-int mbedtls_mpi_montred( mbedtls_mpi *A, const mbedtls_mpi *N,
-			 mbedtls_mpi_uint mm, const mbedtls_mpi *T )
+static int mpi_montred( mbedtls_mpi *A, const mbedtls_mpi *N,
+                        mbedtls_mpi_uint mm, const mbedtls_mpi *T )
 {
     mbedtls_mpi_uint z = 1;
     mbedtls_mpi U;
@@ -1678,13 +1722,15 @@
     U.n = U.s = (int) z;
     U.p = &z;
 
-    return( mbedtls_mpi_montmul( A, &U, N, mm, T ) );
+    return( mpi_montmul( A, &U, N, mm, T ) );
 }
 
 /*
  * Sliding-window exponentiation: X = A^E mod N  (HAC 14.85)
  */
-int mbedtls_mpi_exp_mod( mbedtls_mpi *X, const mbedtls_mpi *A, const mbedtls_mpi *E, const mbedtls_mpi *N, mbedtls_mpi *_RR )
+int mbedtls_mpi_exp_mod( mbedtls_mpi *X, const mbedtls_mpi *A,
+                         const mbedtls_mpi *E, const mbedtls_mpi *N,
+                         mbedtls_mpi *_RR )
 {
     int ret;
     size_t wbits, wsize, one = 1;
@@ -1694,7 +1740,12 @@
     mbedtls_mpi RR, T, W[ 2 << MBEDTLS_MPI_WINDOW_SIZE ], Apos;
     int neg;
 
-    if( mbedtls_mpi_cmp_int( N, 0 ) < 0 || ( N->p[0] & 1 ) == 0 )
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( E != NULL );
+    MPI_VALIDATE_RET( N != NULL );
+
+    if( mbedtls_mpi_cmp_int( N, 0 ) <= 0 || ( N->p[0] & 1 ) == 0 )
         return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
 
     if( mbedtls_mpi_cmp_int( E, 0 ) < 0 )
@@ -1703,9 +1754,9 @@
     /*
      * Init temps and window size
      */
-    mbedtls_mpi_montg_init( &mm, N );
-    mbedtls_mpi_init_mempool( &RR ); mbedtls_mpi_init_mempool( &T );
-    mbedtls_mpi_init_mempool( &Apos );
+    mpi_montg_init( &mm, N );
+    mbedtls_mpi_init( &RR ); mbedtls_mpi_init( &T );
+    mbedtls_mpi_init( &Apos );
     memset( W, 0, sizeof( W ) );
 
     i = mbedtls_mpi_bitlen( E );
@@ -1755,13 +1806,13 @@
     else
         MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[1], A ) );
 
-    MBEDTLS_MPI_CHK( mbedtls_mpi_montmul( &W[1], &RR, N, mm, &T ) );
+    MBEDTLS_MPI_CHK( mpi_montmul( &W[1], &RR, N, mm, &T ) );
 
     /*
      * X = R^2 * R^-1 mod N = R mod N
      */
     MBEDTLS_MPI_CHK( mbedtls_mpi_copy( X, &RR ) );
-    MBEDTLS_MPI_CHK( mbedtls_mpi_montred( X, N, mm, &T ) );
+    MBEDTLS_MPI_CHK( mpi_montred( X, N, mm, &T ) );
 
     if( wsize > 1 )
     {
@@ -1774,7 +1825,7 @@
         MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[j], &W[1]    ) );
 
         for( i = 0; i < wsize - 1; i++ )
-            MBEDTLS_MPI_CHK( mbedtls_mpi_montmul( &W[j], &W[j], N, mm, &T ) );
+            MBEDTLS_MPI_CHK( mpi_montmul( &W[j], &W[j], N, mm, &T ) );
 
         /*
          * W[i] = W[i - 1] * W[1]
@@ -1784,7 +1835,7 @@
             MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &W[i], N->n + 1 ) );
             MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[i], &W[i - 1] ) );
 
-            MBEDTLS_MPI_CHK( mbedtls_mpi_montmul( &W[i], &W[1], N, mm, &T ) );
+            MBEDTLS_MPI_CHK( mpi_montmul( &W[i], &W[1], N, mm, &T ) );
         }
     }
 
@@ -1821,7 +1872,7 @@
             /*
              * out of window, square X
              */
-            MBEDTLS_MPI_CHK( mbedtls_mpi_montmul( X, X, N, mm, &T ) );
+            MBEDTLS_MPI_CHK( mpi_montmul( X, X, N, mm, &T ) );
             continue;
         }
 
@@ -1839,12 +1890,12 @@
              * X = X^wsize R^-1 mod N
              */
             for( i = 0; i < wsize; i++ )
-                MBEDTLS_MPI_CHK( mbedtls_mpi_montmul( X, X, N, mm, &T ) );
+                MBEDTLS_MPI_CHK( mpi_montmul( X, X, N, mm, &T ) );
 
             /*
              * X = X * W[wbits] R^-1 mod N
              */
-            MBEDTLS_MPI_CHK( mbedtls_mpi_montmul( X, &W[wbits], N, mm, &T ) );
+            MBEDTLS_MPI_CHK( mpi_montmul( X, &W[wbits], N, mm, &T ) );
 
             state--;
             nbits = 0;
@@ -1857,18 +1908,18 @@
      */
     for( i = 0; i < nbits; i++ )
     {
-        MBEDTLS_MPI_CHK( mbedtls_mpi_montmul( X, X, N, mm, &T ) );
+        MBEDTLS_MPI_CHK( mpi_montmul( X, X, N, mm, &T ) );
 
         wbits <<= 1;
 
         if( ( wbits & ( one << wsize ) ) != 0 )
-            MBEDTLS_MPI_CHK( mbedtls_mpi_montmul( X, &W[1], N, mm, &T ) );
+            MBEDTLS_MPI_CHK( mpi_montmul( X, &W[1], N, mm, &T ) );
     }
 
     /*
      * X = A^E * R * R^-1 mod N = A^E mod N
      */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_montred( X, N, mm, &T ) );
+    MBEDTLS_MPI_CHK( mpi_montred( X, N, mm, &T ) );
 
     if( neg && E->n != 0 && ( E->p[0] & 1 ) != 0 )
     {
@@ -1898,8 +1949,11 @@
     size_t lz, lzt;
     mbedtls_mpi TG, TA, TB;
 
-    mbedtls_mpi_init_mempool( &TG ); mbedtls_mpi_init_mempool( &TA );
-    mbedtls_mpi_init_mempool( &TB );
+    MPI_VALIDATE_RET( G != NULL );
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( B != NULL );
+
+    mbedtls_mpi_init( &TG ); mbedtls_mpi_init( &TA ); mbedtls_mpi_init( &TB );
 
     MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TA, A ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &TB, B ) );
@@ -1955,6 +2009,8 @@
 {
     int ret;
     unsigned char buf[MBEDTLS_MPI_MAX_SIZE];
+    MPI_VALIDATE_RET( X     != NULL );
+    MPI_VALIDATE_RET( f_rng != NULL );
 
     if( size > MBEDTLS_MPI_MAX_SIZE )
         return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
@@ -1963,6 +2019,7 @@
     MBEDTLS_MPI_CHK( mbedtls_mpi_read_binary( X, buf, size ) );
 
 cleanup:
+    mbedtls_platform_zeroize( buf, sizeof( buf ) );
     return( ret );
 }
 
@@ -1973,15 +2030,16 @@
 {
     int ret;
     mbedtls_mpi G, TA, TU, U1, U2, TB, TV, V1, V2;
+    MPI_VALIDATE_RET( X != NULL );
+    MPI_VALIDATE_RET( A != NULL );
+    MPI_VALIDATE_RET( N != NULL );
 
     if( mbedtls_mpi_cmp_int( N, 1 ) <= 0 )
         return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
 
-    mbedtls_mpi_init_mempool( &TA ); mbedtls_mpi_init_mempool( &TU );
-    mbedtls_mpi_init_mempool( &U1 ); mbedtls_mpi_init_mempool( &U2 );
-    mbedtls_mpi_init_mempool( &G ); mbedtls_mpi_init_mempool( &TB );
-    mbedtls_mpi_init_mempool( &TV ); mbedtls_mpi_init_mempool( &V1 );
-    mbedtls_mpi_init_mempool( &V2 );
+    mbedtls_mpi_init( &TA ); mbedtls_mpi_init( &TU ); mbedtls_mpi_init( &U1 ); mbedtls_mpi_init( &U2 );
+    mbedtls_mpi_init( &G ); mbedtls_mpi_init( &TB ); mbedtls_mpi_init( &TV );
+    mbedtls_mpi_init( &V1 ); mbedtls_mpi_init( &V2 );
 
     MBEDTLS_MPI_CHK( mbedtls_mpi_gcd( &G, A, N ) );
 
@@ -2126,17 +2184,20 @@
 /*
  * Miller-Rabin pseudo-primality test  (HAC 4.24)
  */
-static int mpi_miller_rabin( const mbedtls_mpi *X,
+static int mpi_miller_rabin( const mbedtls_mpi *X, size_t rounds,
                              int (*f_rng)(void *, unsigned char *, size_t),
                              void *p_rng )
 {
     int ret, count;
-    size_t i, j, k, n, s;
+    size_t i, j, k, s;
     mbedtls_mpi W, R, T, A, RR;
 
-    mbedtls_mpi_init_mempool( &W ); mbedtls_mpi_init_mempool( &R );
-    mbedtls_mpi_init_mempool( &T ); mbedtls_mpi_init_mempool( &A );
-    mbedtls_mpi_init_mempool( &RR );
+    MPI_VALIDATE_RET( X     != NULL );
+    MPI_VALIDATE_RET( f_rng != NULL );
+
+    mbedtls_mpi_init( &W ); mbedtls_mpi_init( &R );
+    mbedtls_mpi_init( &T ); mbedtls_mpi_init( &A );
+    mbedtls_mpi_init( &RR );
 
     /*
      * W = |X| - 1
@@ -2148,27 +2209,12 @@
     MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( &R, s ) );
 
     i = mbedtls_mpi_bitlen( X );
-    /*
-     * HAC, table 4.4
-     */
-    n = ( ( i >= 1300 ) ?  2 : ( i >=  850 ) ?  3 :
-          ( i >=  650 ) ?  4 : ( i >=  350 ) ?  8 :
-          ( i >=  250 ) ? 12 : ( i >=  150 ) ? 18 : 27 );
 
-    for( i = 0; i < n; i++ )
+    for( i = 0; i < rounds; i++ )
     {
         /*
          * pick a random A, 1 < A < |X| - 1
          */
-        MBEDTLS_MPI_CHK( mbedtls_mpi_fill_random( &A, X->n * ciL, f_rng, p_rng ) );
-
-        if( mbedtls_mpi_cmp_mpi( &A, &W ) >= 0 )
-        {
-            j = mbedtls_mpi_bitlen( &A ) - mbedtls_mpi_bitlen( &W );
-            MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( &A, j + 1 ) );
-        }
-        A.p[0] |= 3;
-
         count = 0;
         do {
             MBEDTLS_MPI_CHK( mbedtls_mpi_fill_random( &A, X->n * ciL, f_rng, p_rng ) );
@@ -2176,12 +2222,11 @@
             j = mbedtls_mpi_bitlen( &A );
             k = mbedtls_mpi_bitlen( &W );
             if (j > k) {
-                MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( &A, j - k ) );
+                A.p[A.n - 1] &= ( (mbedtls_mpi_uint) 1 << ( k - ( A.n - 1 ) * biL - 1 ) ) - 1;
             }
 
-            if (count++ > 300) {
-                ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
-                goto cleanup;
+            if (count++ > 30) {
+                return MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
             }
 
         } while ( mbedtls_mpi_cmp_mpi( &A, &W ) >= 0 ||
@@ -2223,7 +2268,8 @@
     }
 
 cleanup:
-    mbedtls_mpi_free( &W ); mbedtls_mpi_free( &R ); mbedtls_mpi_free( &T ); mbedtls_mpi_free( &A );
+    mbedtls_mpi_free( &W ); mbedtls_mpi_free( &R );
+    mbedtls_mpi_free( &T ); mbedtls_mpi_free( &A );
     mbedtls_mpi_free( &RR );
 
     return( ret );
@@ -2232,12 +2278,14 @@
 /*
  * Pseudo-primality test: small factors, then Miller-Rabin
  */
-int mbedtls_mpi_is_prime( const mbedtls_mpi *X,
-                  int (*f_rng)(void *, unsigned char *, size_t),
-                  void *p_rng )
+int mbedtls_mpi_is_prime_ext( const mbedtls_mpi *X, int rounds,
+                              int (*f_rng)(void *, unsigned char *, size_t),
+                              void *p_rng )
 {
     int ret;
     mbedtls_mpi XX;
+    MPI_VALIDATE_RET( X     != NULL );
+    MPI_VALIDATE_RET( f_rng != NULL );
 
     XX.s = 1;
     XX.n = X->n;
@@ -2258,91 +2306,146 @@
         return( ret );
     }
 
-    return( mpi_miller_rabin( &XX, f_rng, p_rng ) );
+    return( mpi_miller_rabin( &XX, rounds, f_rng, p_rng ) );
 }
 
+#if !defined(MBEDTLS_DEPRECATED_REMOVED)
+/*
+ * Deprecated pseudo-primality test: fixed round count, error bound 2^-80
+ */
+int mbedtls_mpi_is_prime( const mbedtls_mpi *X,
+                  int (*f_rng)(void *, unsigned char *, size_t),
+                  void *p_rng )
+{
+    /*
+     * Key generation historically targeted an error rate no worse than
+     * 2^-80; this deprecated entry point keeps that same certainty by
+     * always requesting 40 rounds from mbedtls_mpi_is_prime_ext().
+     */
+    const int legacy_rounds = 40;
+    MPI_VALIDATE_RET( X     != NULL );
+    MPI_VALIDATE_RET( f_rng != NULL );
+    return( mbedtls_mpi_is_prime_ext( X, legacy_rounds, f_rng, p_rng ) );
+}
+#endif
+
 /*
  * Prime number generation
+ *
+ * To generate an RSA key in a way recommended by FIPS 186-4, both primes must
+ * be either 1024 bits or 1536 bits long, and flags must contain
+ * MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR.
  */
-int mbedtls_mpi_gen_prime( mbedtls_mpi *X, size_t nbits, int dh_flag,
+int mbedtls_mpi_gen_prime( mbedtls_mpi *X, size_t nbits, int flags,
                    int (*f_rng)(void *, unsigned char *, size_t),
                    void *p_rng )
 {
-    int ret;
+#ifdef MBEDTLS_HAVE_INT64
+// ceil(2^63.5)
+#define CEIL_MAXUINT_DIV_SQRT2 0xb504f333f9de6485ULL
+#else
+// ceil(2^31.5)
+#define CEIL_MAXUINT_DIV_SQRT2 0xb504f334U
+#endif
+    int ret = MBEDTLS_ERR_MPI_NOT_ACCEPTABLE;
     size_t k, n;
+    int rounds;
     mbedtls_mpi_uint r;
     mbedtls_mpi Y;
 
+    MPI_VALIDATE_RET( X     != NULL );
+    MPI_VALIDATE_RET( f_rng != NULL );
+
     if( nbits < 3 || nbits > MBEDTLS_MPI_MAX_BITS )
         return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
 
-    mbedtls_mpi_init_mempool( &Y );
+    mbedtls_mpi_init( &Y );
 
     n = BITS_TO_LIMBS( nbits );
 
-    MBEDTLS_MPI_CHK( mbedtls_mpi_fill_random( X, n * ciL, f_rng, p_rng ) );
-
-    k = mbedtls_mpi_bitlen( X );
-    if( k > nbits ) MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( X, k - nbits + 1 ) );
-
-    mbedtls_mpi_set_bit( X, nbits-1, 1 );
-
-    X->p[0] |= 1;
-
-    if( dh_flag == 0 )
+    if( ( flags & MBEDTLS_MPI_GEN_PRIME_FLAG_LOW_ERR ) == 0 )
     {
-        while( ( ret = mbedtls_mpi_is_prime( X, f_rng, p_rng ) ) != 0 )
-        {
-            if( ret != MBEDTLS_ERR_MPI_NOT_ACCEPTABLE )
-                goto cleanup;
-
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 2 ) );
-        }
+        /*
+         * 2^-80 error probability, number of rounds chosen per HAC, table 4.4
+         */
+        rounds = ( ( nbits >= 1300 ) ?  2 : ( nbits >=  850 ) ?  3 :
+                   ( nbits >=  650 ) ?  4 : ( nbits >=  350 ) ?  8 :
+                   ( nbits >=  250 ) ? 12 : ( nbits >=  150 ) ? 18 : 27 );
     }
     else
     {
         /*
-         * An necessary condition for Y and X = 2Y + 1 to be prime
-         * is X = 2 mod 3 (which is equivalent to Y = 2 mod 3).
-         * Make sure it is satisfied, while keeping X = 3 mod 4
+         * 2^-100 error probability, number of rounds computed based on HAC,
+         * fact 4.48
          */
+        rounds = ( ( nbits >= 1450 ) ?  4 : ( nbits >=  1150 ) ?  5 :
+                   ( nbits >= 1000 ) ?  6 : ( nbits >=   850 ) ?  7 :
+                   ( nbits >=  750 ) ?  8 : ( nbits >=   500 ) ? 13 :
+                   ( nbits >=  250 ) ? 28 : ( nbits >=   150 ) ? 40 : 51 );
+    }
 
-        X->p[0] |= 2;
+    while( 1 )
+    {
+        MBEDTLS_MPI_CHK( mbedtls_mpi_fill_random( X, n * ciL, f_rng, p_rng ) );
+        /* make sure generated number is at least (nbits-1)+0.5 bits (FIPS 186-4 §B.3.3 steps 4.4, 5.5) */
+        if( X->p[n-1] < CEIL_MAXUINT_DIV_SQRT2 ) continue;
 
-        MBEDTLS_MPI_CHK( mbedtls_mpi_mod_int( &r, X, 3 ) );
-        if( r == 0 )
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 8 ) );
-        else if( r == 1 )
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 4 ) );
+        k = n * biL;
+        if( k > nbits ) MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( X, k - nbits ) );
+        X->p[0] |= 1;
 
-        /* Set Y = (X-1) / 2, which is X / 2 because X is odd */
-        MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &Y, X ) );
-        MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( &Y, 1 ) );
-
-        while( 1 )
+        if( ( flags & MBEDTLS_MPI_GEN_PRIME_FLAG_DH ) == 0 )
         {
-            /*
-             * First, check small factors for X and Y
-             * before doing Miller-Rabin on any of them
-             */
-            if( ( ret = mpi_check_small_factors(  X         ) ) == 0 &&
-                ( ret = mpi_check_small_factors( &Y         ) ) == 0 &&
-                ( ret = mpi_miller_rabin(  X, f_rng, p_rng  ) ) == 0 &&
-                ( ret = mpi_miller_rabin( &Y, f_rng, p_rng  ) ) == 0 )
-            {
-                break;
-            }
+            ret = mbedtls_mpi_is_prime_ext( X, rounds, f_rng, p_rng );
 
             if( ret != MBEDTLS_ERR_MPI_NOT_ACCEPTABLE )
                 goto cleanup;
-
+        }
+        else
+        {
             /*
-             * Next candidates. We want to preserve Y = (X-1) / 2 and
-             * Y = 1 mod 2 and Y = 2 mod 3 (eq X = 3 mod 4 and X = 2 mod 3)
-             * so up Y by 6 and X by 12.
+             * A necessary condition for Y and X = 2Y + 1 to be prime
+             * is X = 2 mod 3 (which is equivalent to Y = 2 mod 3).
+             * Make sure it is satisfied, while keeping X = 3 mod 4
              */
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int(  X,  X, 12 ) );
-            MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( &Y, &Y, 6  ) );
+
+            X->p[0] |= 2;
+
+            MBEDTLS_MPI_CHK( mbedtls_mpi_mod_int( &r, X, 3 ) );
+            if( r == 0 )
+                MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 8 ) );
+            else if( r == 1 )
+                MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( X, X, 4 ) );
+
+            /* Set Y = (X-1) / 2, which is X / 2 because X is odd */
+            MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &Y, X ) );
+            MBEDTLS_MPI_CHK( mbedtls_mpi_shift_r( &Y, 1 ) );
+
+            while( 1 )
+            {
+                /*
+                 * First, check small factors for X and Y
+                 * before doing Miller-Rabin on any of them
+                 */
+                if( ( ret = mpi_check_small_factors(  X         ) ) == 0 &&
+                    ( ret = mpi_check_small_factors( &Y         ) ) == 0 &&
+                    ( ret = mpi_miller_rabin(  X, rounds, f_rng, p_rng  ) )
+                                                                    == 0 &&
+                    ( ret = mpi_miller_rabin( &Y, rounds, f_rng, p_rng  ) )
+                                                                    == 0 )
+                    goto cleanup;
+
+                if( ret != MBEDTLS_ERR_MPI_NOT_ACCEPTABLE )
+                    goto cleanup;
+
+                /*
+                 * Next candidates. We want to preserve Y = (X-1) / 2 and
+                 * Y = 1 mod 2 and Y = 2 mod 3 (eq X = 3 mod 4 and X = 2 mod 3)
+                 * so up Y by 6 and X by 12.
+                 */
+                MBEDTLS_MPI_CHK( mbedtls_mpi_add_int(  X,  X, 12 ) );
+                MBEDTLS_MPI_CHK( mbedtls_mpi_add_int( &Y, &Y, 6  ) );
+            }
         }
     }
 
@@ -2374,10 +2477,8 @@
     int ret, i;
     mbedtls_mpi A, E, N, X, Y, U, V;
 
-    mbedtls_mpi_init_mempool( &A ); mbedtls_mpi_init_mempool( &E );
-    mbedtls_mpi_init_mempool( &N ); mbedtls_mpi_init_mempool( &X );
-    mbedtls_mpi_init_mempool( &Y ); mbedtls_mpi_init_mempool( &U );
-    mbedtls_mpi_init_mempool( &V );
+    mbedtls_mpi_init( &A ); mbedtls_mpi_init( &E ); mbedtls_mpi_init( &N ); mbedtls_mpi_init( &X );
+    mbedtls_mpi_init( &Y ); mbedtls_mpi_init( &U ); mbedtls_mpi_init( &V );
 
     MBEDTLS_MPI_CHK( mbedtls_mpi_read_string( &A, 16,
         "EFE021C2645FD1DC586E69184AF4A31E" \