SSL asynchronous private key operation callbacks: test server

Add options to ssl_server2 to exercise the asynchronous private key
operation feature.

Features: a configurable resume delay, so that the resume callback is
called more than once; error injection at each stage; renegotiation support.
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 1285abc..d75338f 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -108,6 +108,9 @@
 #define DFL_KEY_FILE            ""
 #define DFL_CRT_FILE2           ""
 #define DFL_KEY_FILE2           ""
+#define DFL_ASYNC_PRIVATE_DELAY1 ( -1 )  /* -1 = key 1 not asynchronous */
+#define DFL_ASYNC_PRIVATE_DELAY2 ( -1 )  /* -1 = key 2 not asynchronous */
+#define DFL_ASYNC_PRIVATE_ERROR  ( 0 )   /* 0 = no error injection */
 #define DFL_PSK                 ""
 #define DFL_PSK_IDENTITY        "Client_identity"
 #define DFL_ECJPAKE_PW          NULL
@@ -195,6 +198,16 @@
 #define USAGE_IO ""
 #endif /* MBEDTLS_X509_CRT_PARSE_C */
 
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+#define USAGE_SSL_ASYNC \
+    "    async_private_delay1=%%d  Asynchronous delay for key_file or preloaded key\n" \
+    "    async_private_delay2=%%d  Asynchronous delay for key_file2\n" \
+    "                              default: -1 (not asynchronous)\n" \
+    "    async_private_error=%%d   Async callback error injection (default=0=none, 1=start, 2=cancel, 3=resume, 4=pk)\n"
+#else
+#define USAGE_SSL_ASYNC ""
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+
 #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
 #define USAGE_PSK                                                   \
     "    psk=%%s              default: \"\" (in hex, without 0x)\n" \
@@ -343,6 +356,7 @@
     "    cert_req_ca_list=%%d default: 1 (send ca list)\n"  \
     "                        options: 1 (send ca list), 0 (don't send)\n" \
     USAGE_IO                                                \
+    USAGE_SSL_ASYNC                                         \
     USAGE_SNI                                               \
     "\n"                                                    \
     USAGE_PSK                                               \
@@ -406,6 +420,9 @@
     const char *key_file;       /* the file with the server key             */
     const char *crt_file2;      /* the file with the 2nd server certificate */
     const char *key_file2;      /* the file with the 2nd server key         */
+    int async_private_delay1;   /* number of times f_async_resume needs to be called for key 1, or -1 for no async */
+    int async_private_delay2;   /* number of times f_async_resume needs to be called for key 2, or -1 for no async */
+    int async_private_error;    /* inject error in async private callback */
     const char *psk;            /* the pre-shared key                       */
     const char *psk_identity;   /* the pre-shared key identity              */
     char *psk_list;             /* list of PSK id/key pairs for callback    */
@@ -837,6 +854,182 @@
 };
 #endif /* MBEDTLS_X509_CRT_PARSE_C */
 
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+/* A certificate and private key that the test server handles through the
+ * asynchronous private key callbacks rather than through the SSL stack. */
+typedef struct
+{
+    mbedtls_x509_crt *cert;     /* certificate whose key this slot serves */
+    mbedtls_pk_context *pk;     /* key used by the resume callback to sign */
+    unsigned delay;             /* resume calls returning IN_PROGRESS before signing */
+} ssl_async_key_slot_t;
+
+/* Stage at which to inject an error (command line: async_private_error). */
+typedef enum {
+    SSL_ASYNC_INJECT_ERROR_NONE = 0,    /* no error injection */
+    SSL_ASYNC_INJECT_ERROR_START,       /* fail in the sign (start) callback */
+    SSL_ASYNC_INJECT_ERROR_CANCEL,      /* reset the connection while in progress */
+    SSL_ASYNC_INJECT_ERROR_RESUME,      /* fail in the resume callback */
+    SSL_ASYNC_INJECT_ERROR_PK           /* sign, then report an error anyway */
+#define SSL_ASYNC_INJECT_ERROR_MAX SSL_ASYNC_INJECT_ERROR_PK
+} ssl_async_inject_error_t;
+
+/* Data shared by all asynchronous callbacks; registered with
+ * mbedtls_ssl_conf_async_private_cb() and passed as connection_ctx_arg. */
+typedef struct
+{
+    ssl_async_key_slot_t slots[2];  /* one slot per server key */
+    size_t slots_used;              /* number of filled entries in slots[] */
+    ssl_async_inject_error_t inject_error;
+    int (*f_rng)(void *, unsigned char *, size_t);  /* RNG for mbedtls_pk_sign */
+    void *p_rng;
+} ssl_async_key_context_t;
+
+/* Register a certificate/key pair to be served asynchronously.
+ * The resume callback returns MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS `delay` times
+ * before it signs. ctx->slots_used must be initialized before the first call.
+ * Returns 0 on success or MBEDTLS_ERR_SSL_BAD_INPUT_DATA if all slots are
+ * already in use. */
+static int ssl_async_set_key( ssl_async_key_context_t *ctx,
+                              mbedtls_x509_crt *cert,
+                              mbedtls_pk_context *pk,
+                              unsigned delay )
+{
+    /* Never write past the end of slots[]. */
+    if( ctx->slots_used >= sizeof( ctx->slots ) / sizeof( ctx->slots[0] ) )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+    ctx->slots[ctx->slots_used].cert = cert;
+    ctx->slots[ctx->slots_used].pk = pk;
+    ctx->slots[ctx->slots_used].delay = delay;
+    ++ctx->slots_used;
+    return( 0 );
+}
+
+/* State of one pending asynchronous signature. */
+typedef struct
+{
+    size_t slot;                    /* index into ssl_async_key_context_t.slots */
+    mbedtls_md_type_t md_alg;
+    unsigned char hash[MBEDTLS_MD_MAX_SIZE];    /* private copy of the hash */
+    size_t hash_len;
+    unsigned delay;                 /* remaining IN_PROGRESS resume calls */
+} ssl_async_operation_context_t;
+
+/* Start callback for a signature. Looks up the slot whose certificate is
+ * `cert`, copies the hash into a newly allocated operation context and stores
+ * it in *p_operation_ctx.
+ * Returns MBEDTLS_ERR_SSL_HW_ACCEL_FALLTHROUGH if no slot matches, an error
+ * code on failure (or injected error), 0 if the slot has no delay, and
+ * MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS otherwise. */
+static int ssl_async_sign( void *connection_ctx_arg,
+                           void **p_operation_ctx,
+                           mbedtls_x509_crt *cert,
+                           mbedtls_md_type_t md_alg,
+                           const unsigned char *hash,
+                           size_t hash_len )
+{
+    ssl_async_key_context_t *key_ctx = connection_ctx_arg;
+    size_t slot;
+    ssl_async_operation_context_t *ctx = NULL;
+    {
+        char dn[100];
+        /* Only print dn if mbedtls_x509_dn_gets() wrote a non-empty string;
+         * otherwise its contents may be undefined. */
+        if( mbedtls_x509_dn_gets( dn, sizeof( dn ), &cert->subject ) > 0 )
+            mbedtls_printf( "Async sign callback: looking for DN=%s\n", dn );
+    }
+    for( slot = 0; slot < key_ctx->slots_used; slot++ )
+    {
+        if( key_ctx->slots[slot].cert == cert )
+            break;
+    }
+    if( slot == key_ctx->slots_used )
+    {
+        mbedtls_printf( "Async sign callback: no key matches this certificate.\n" );
+        return( MBEDTLS_ERR_SSL_HW_ACCEL_FALLTHROUGH );
+    }
+    mbedtls_printf( "Async sign callback: using key slot %u, delay=%u.\n",
+                    (unsigned) slot, key_ctx->slots[slot].delay );
+    if( key_ctx->inject_error == SSL_ASYNC_INJECT_ERROR_START )
+    {
+        mbedtls_printf( "Async sign callback: injected error\n" );
+        return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
+    }
+    if( hash_len > MBEDTLS_MD_MAX_SIZE )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+    ctx = mbedtls_calloc( 1, sizeof( *ctx ) );
+    if( ctx == NULL )
+        return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
+    ctx->slot = slot;
+    ctx->md_alg = md_alg;
+    memcpy( ctx->hash, hash, hash_len );
+    ctx->hash_len = hash_len;
+    ctx->delay = key_ctx->slots[slot].delay;
+    *p_operation_ctx = ctx;
+    if( ctx->delay == 0 )
+        return( 0 );
+    else
+        return( MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
+}
+
+/* Resume callback. Returns MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS until the
+ * operation's delay is used up, then signs the saved hash into output.
+ * Any other return value is treated here as the end of the operation: the
+ * operation context is freed on every such path, including the injected
+ * errors, just as on the normal completion path. */
+static int ssl_async_resume( void *connection_ctx_arg,
+                             void *operation_ctx_arg,
+                             unsigned char *output,
+                             size_t *output_len,
+                             size_t output_size )
+{
+    ssl_async_operation_context_t *ctx = operation_ctx_arg;
+    ssl_async_key_context_t *connection_ctx = connection_ctx_arg;
+    ssl_async_key_slot_t *key_slot = &connection_ctx->slots[ctx->slot];
+    int ret;
+    if( connection_ctx->inject_error == SSL_ASYNC_INJECT_ERROR_RESUME )
+    {
+        mbedtls_printf( "Async resume callback: injected error\n" );
+        mbedtls_free( ctx );
+        return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
+    }
+    if( ctx->delay > 0 )
+    {
+        --ctx->delay;
+        mbedtls_printf( "Async resume (slot %u): call %u more times.\n",
+                        (unsigned) ctx->slot, ctx->delay );
+        return( MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
+    }
+    (void) output_size; /* mbedtls_pk_sign lacks this parameter */
+    ret = mbedtls_pk_sign( key_slot->pk,
+                           ctx->md_alg,
+                           ctx->hash, ctx->hash_len,
+                           output, output_len,
+                           connection_ctx->f_rng, connection_ctx->p_rng );
+    if( connection_ctx->inject_error == SSL_ASYNC_INJECT_ERROR_PK )
+    {
+        mbedtls_printf( "Async resume callback: done but injected error\n" );
+        mbedtls_free( ctx );
+        return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
+    }
+    mbedtls_printf( "Async resume (slot %u): done, status=%d.\n",
+                    (unsigned) ctx->slot, ret );
+    mbedtls_free( ctx );
+    return( ret );
+}
+
+/* Cancel callback: free the context of an operation that is abandoned
+ * before it completes. */
+static void ssl_async_cancel( void *connection_ctx_arg,
+                              void *operation_ctx_arg )
+{
+    ssl_async_operation_context_t *ctx = operation_ctx_arg;
+    (void) connection_ctx_arg;
+    mbedtls_printf( "Async cancel callback.\n" );
+    mbedtls_free( ctx );
+}
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+
 int main( int argc, char *argv[] )
 {
     int ret = 0, len, written, frags, exchanges_left;
@@ -875,7 +1036,10 @@
     mbedtls_x509_crt srvcert2;
     mbedtls_pk_context pkey2;
     int key_cert_init = 0, key_cert_init2 = 0;
-#endif
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+    ssl_async_key_context_t ssl_async_keys;
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+#endif /* MBEDTLS_X509_CRT_PARSE_C */
 #if defined(MBEDTLS_DHM_C) && defined(MBEDTLS_FS_IO)
     mbedtls_dhm_context dhm;
 #endif
@@ -977,6 +1141,9 @@
     opt.key_file            = DFL_KEY_FILE;
     opt.crt_file2           = DFL_CRT_FILE2;
     opt.key_file2           = DFL_KEY_FILE2;
+    opt.async_private_delay1 = DFL_ASYNC_PRIVATE_DELAY1;
+    opt.async_private_delay2 = DFL_ASYNC_PRIVATE_DELAY2;
+    opt.async_private_error = DFL_ASYNC_PRIVATE_ERROR;
     opt.psk                 = DFL_PSK;
     opt.psk_identity        = DFL_PSK_IDENTITY;
     opt.psk_list            = DFL_PSK_LIST;
@@ -1063,6 +1230,22 @@
             opt.key_file2 = q;
         else if( strcmp( p, "dhm_file" ) == 0 )
             opt.dhm_file = q;
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+        else if( strcmp( p, "async_private_delay1" ) == 0 )
+            opt.async_private_delay1 = atoi( q );
+        else if( strcmp( p, "async_private_delay2" ) == 0 )
+            opt.async_private_delay2 = atoi( q );
+        else if( strcmp( p, "async_private_error" ) == 0 )
+        {
+            int n = atoi( q );
+            if( n < 0 || n > SSL_ASYNC_INJECT_ERROR_MAX )
+            {
+                ret = 2;
+                goto usage;
+            }
+            opt.async_private_error = n;
+        }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
         else if( strcmp( p, "psk" ) == 0 )
             opt.psk = q;
         else if( strcmp( p, "psk_identity" ) == 0 )
@@ -1932,18 +2115,55 @@
         mbedtls_ssl_conf_ca_chain( &conf, &cacert, NULL );
     }
     if( key_cert_init )
-        if( ( ret = mbedtls_ssl_conf_own_cert( &conf, &srvcert, &pkey ) ) != 0 )
+    {
+        mbedtls_pk_context *pk = &pkey;
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+        if( opt.async_private_delay1 >= 0 )
+        {
+            ssl_async_set_key( &ssl_async_keys, &srvcert, pk,
+                               opt.async_private_delay1 );
+            pk = NULL;
+        }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+        if( ( ret = mbedtls_ssl_conf_own_cert( &conf, &srvcert, pk ) ) != 0 )
         {
             mbedtls_printf( " failed\n  ! mbedtls_ssl_conf_own_cert returned %d\n\n", ret );
             goto exit;
         }
+    }
     if( key_cert_init2 )
-        if( ( ret = mbedtls_ssl_conf_own_cert( &conf, &srvcert2, &pkey2 ) ) != 0 )
+    {
+        mbedtls_pk_context *pk = &pkey2;
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+        if( opt.async_private_delay2 >= 0 )
+        {
+            ssl_async_set_key( &ssl_async_keys, &srvcert2, pk,
+                               opt.async_private_delay2 );
+            pk = NULL;
+        }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+        if( ( ret = mbedtls_ssl_conf_own_cert( &conf, &srvcert2, pk ) ) != 0 )
         {
             mbedtls_printf( " failed\n  ! mbedtls_ssl_conf_own_cert returned %d\n\n", ret );
             goto exit;
         }
-#endif
+    }
+
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+    if( opt.async_private_delay1 >= 0 || opt.async_private_delay2 >= 0 )
+    {
+        ssl_async_keys.inject_error = opt.async_private_error;
+        ssl_async_keys.f_rng = mbedtls_ctr_drbg_random;
+        ssl_async_keys.p_rng = &ctr_drbg;
+        mbedtls_ssl_conf_async_private_cb( &conf,
+                                           ssl_async_sign,
+                                           NULL,
+                                           ssl_async_resume,
+                                           ssl_async_cancel,
+                                           &ssl_async_keys );
+    }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+#endif /* MBEDTLS_X509_CRT_PARSE_C */
 
 #if defined(SNI_OPTION)
     if( opt.sni != NULL )
@@ -2113,9 +2333,21 @@
     mbedtls_printf( "  . Performing the SSL/TLS handshake..." );
     fflush( stdout );
 
-    do ret = mbedtls_ssl_handshake( &ssl );
+    do
+    {
+        ret = mbedtls_ssl_handshake( &ssl );
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+        if( ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS &&
+            opt.async_private_error == SSL_ASYNC_INJECT_ERROR_CANCEL )
+        {
+            mbedtls_printf( " cancelling on injected error\n" );
+            goto reset;
+        }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+    }
     while( ret == MBEDTLS_ERR_SSL_WANT_READ ||
-           ret == MBEDTLS_ERR_SSL_WANT_WRITE );
+           ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+           ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
 
     if( ret == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED )
     {
@@ -2220,7 +2452,8 @@
             ret = mbedtls_ssl_read( &ssl, buf, len );
 
             if( ret == MBEDTLS_ERR_SSL_WANT_READ ||
-                ret == MBEDTLS_ERR_SSL_WANT_WRITE )
+                ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+                ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS )
                 continue;
 
             if( ret <= 0 )
@@ -2311,7 +2544,8 @@
 
         do ret = mbedtls_ssl_read( &ssl, buf, len );
         while( ret == MBEDTLS_ERR_SSL_WANT_READ ||
-               ret == MBEDTLS_ERR_SSL_WANT_WRITE );
+               ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+               ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
 
         if( ret <= 0 )
         {
@@ -2347,7 +2581,8 @@
         while( ( ret = mbedtls_ssl_renegotiate( &ssl ) ) != 0 )
         {
             if( ret != MBEDTLS_ERR_SSL_WANT_READ &&
-                ret != MBEDTLS_ERR_SSL_WANT_WRITE )
+                ret != MBEDTLS_ERR_SSL_WANT_WRITE &&
+                ret != MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS )
             {
                 mbedtls_printf( " failed\n  ! mbedtls_ssl_renegotiate returned %d\n\n", ret );
                 goto reset;
@@ -2381,7 +2616,8 @@
                 }
 
                 if( ret != MBEDTLS_ERR_SSL_WANT_READ &&
-                    ret != MBEDTLS_ERR_SSL_WANT_WRITE )
+                    ret != MBEDTLS_ERR_SSL_WANT_WRITE &&
+                    ret != MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS )
                 {
                     mbedtls_printf( " failed\n  ! mbedtls_ssl_write returned %d\n\n", ret );
                     goto reset;
@@ -2393,7 +2629,8 @@
     {
         do ret = mbedtls_ssl_write( &ssl, buf, len );
         while( ret == MBEDTLS_ERR_SSL_WANT_READ ||
-               ret == MBEDTLS_ERR_SSL_WANT_WRITE );
+               ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+               ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
 
         if( ret < 0 )
         {