SSL asynchronous private key operation callbacks: test server
New options in ssl_server2 to use the asynchronous private key
operation feature.
Features: a configurable resume delay, so that the resume callback is
called more than once; error injection at each stage; renegotiation
support.
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 1285abc..d75338f 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -108,6 +108,9 @@
#define DFL_KEY_FILE ""
#define DFL_CRT_FILE2 ""
#define DFL_KEY_FILE2 ""
+#define DFL_ASYNC_PRIVATE_DELAY1 ( -1 )    /* -1: key_file is not used asynchronously */
+#define DFL_ASYNC_PRIVATE_DELAY2 ( -1 )    /* -1: key_file2 is not used asynchronously */
+#define DFL_ASYNC_PRIVATE_ERROR ( 0 )      /* 0: no error injection (the documented default) */
#define DFL_PSK ""
#define DFL_PSK_IDENTITY "Client_identity"
#define DFL_ECJPAKE_PW NULL
@@ -195,6 +198,16 @@
#define USAGE_IO ""
#endif /* MBEDTLS_X509_CRT_PARSE_C */
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C) /* usage text for the async private key test options */
+#define USAGE_SSL_ASYNC \
+    " async_private_delay1=%%d Asynchronous delay for key_file or preloaded key\n" \
+    " async_private_delay2=%%d Asynchronous delay for key_file2\n" \
+    " default: -1 (not asynchronous)\n" \
+    " async_private_error=%%d Async callback error injection (default=0=none, 1=start, 2=cancel, 3=resume, 4=pk)\n"
+#else /* async private key support not compiled in: no extra usage text */
+#define USAGE_SSL_ASYNC ""
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+
#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
#define USAGE_PSK \
" psk=%%s default: \"\" (in hex, without 0x)\n" \
@@ -343,6 +356,7 @@
" cert_req_ca_list=%%d default: 1 (send ca list)\n" \
" options: 1 (send ca list), 0 (don't send)\n" \
USAGE_IO \
+ USAGE_SSL_ASYNC \
USAGE_SNI \
"\n" \
USAGE_PSK \
@@ -406,6 +420,9 @@
const char *key_file; /* the file with the server key */
const char *crt_file2; /* the file with the 2nd server certificate */
const char *key_file2; /* the file with the 2nd server key */
+ int async_private_delay1; /* number of times f_async_resume needs to be called for key 1, or -1 for no async */
+ int async_private_delay2; /* number of times f_async_resume needs to be called for key 2, or -1 for no async */
+ int async_private_error; /* inject error in async private callback */
const char *psk; /* the pre-shared key */
const char *psk_identity; /* the pre-shared key identity */
char *psk_list; /* list of PSK id/key pairs for callback */
@@ -837,6 +854,202 @@
};
#endif /* MBEDTLS_X509_CRT_PARSE_C */
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+/*
+ * Test implementation of the asynchronous private key callbacks.
+ *
+ * Nothing is actually asynchronous: ssl_async_sign() records the request
+ * and ssl_async_resume() signs with mbedtls_pk_sign() once it has
+ * returned MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS "delay" times. Errors can
+ * be injected at each stage (see ssl_async_inject_error_t).
+ */
+
+/* A private key used by the async callbacks, looked up by its certificate. */
+typedef struct
+{
+    mbedtls_x509_crt *cert;     /* matched against the cert passed to ssl_async_sign() */
+    mbedtls_pk_context *pk;     /* key passed to mbedtls_pk_sign() */
+    unsigned delay;             /* resume calls returning "in progress" before signing */
+} ssl_async_key_slot_t;
+
+/* Stage at which an error is injected (value of option async_private_error). */
+typedef enum
+{
+    SSL_ASYNC_INJECT_ERROR_NONE = 0,    /* no error injection */
+    SSL_ASYNC_INJECT_ERROR_START,       /* ssl_async_sign() fails */
+    SSL_ASYNC_INJECT_ERROR_CANCEL,      /* server resets while an operation is in progress */
+    SSL_ASYNC_INJECT_ERROR_RESUME,      /* ssl_async_resume() fails */
+    SSL_ASYNC_INJECT_ERROR_PK           /* ssl_async_resume() fails after signing */
+} ssl_async_inject_error_t;
+
+/* Largest value accepted for the async_private_error option. */
+#define SSL_ASYNC_INJECT_ERROR_MAX SSL_ASYNC_INJECT_ERROR_PK
+
+/*
+ * Data shared by all async operations; the callbacks receive it as
+ * connection_ctx_arg. Must be zero-initialized before the first call to
+ * ssl_async_set_key(), which reads slots_used.
+ */
+typedef struct
+{
+    ssl_async_key_slot_t slots[2];      /* keys handled asynchronously */
+    size_t slots_used;                  /* number of filled entries in slots */
+    ssl_async_inject_error_t inject_error;
+    int (*f_rng)(void *, unsigned char *, size_t);  /* RNG for mbedtls_pk_sign() */
+    void *p_rng;
+} ssl_async_key_context_t;
+
+/*
+ * Store cert/pk in the next free slot of ctx, with the given delay.
+ * Return 0 on success, or -1 (ctx unchanged) if every slot is in use.
+ */
+static int ssl_async_set_key( ssl_async_key_context_t *ctx,
+                              mbedtls_x509_crt *cert,
+                              mbedtls_pk_context *pk,
+                              unsigned delay )
+{
+    if( ctx->slots_used >= sizeof( ctx->slots ) / sizeof( ctx->slots[0] ) )
+        return( -1 );
+    ctx->slots[ctx->slots_used].cert = cert;
+    ctx->slots[ctx->slots_used].pk = pk;
+    ctx->slots[ctx->slots_used].delay = delay;
+    ++ctx->slots_used;
+    return( 0 );
+}
+
+/* State of one pending signature. Allocated by ssl_async_sign(); freed by
+ * ssl_async_resume() on any final status, or by ssl_async_cancel(). */
+typedef struct
+{
+    size_t slot;                        /* index into ssl_async_key_context_t.slots */
+    mbedtls_md_type_t md_alg;
+    unsigned char hash[MBEDTLS_MD_MAX_SIZE];
+    size_t hash_len;
+    unsigned delay;                     /* remaining "in progress" resume calls */
+} ssl_async_operation_context_t;
+
+/*
+ * Start callback: find the slot whose certificate is cert and save the
+ * hash in a new operation context stored in *p_operation_ctx.
+ *
+ * Return MBEDTLS_ERR_SSL_HW_ACCEL_FALLTHROUGH if no slot matches cert,
+ * 0 if the slot's delay is 0, MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS otherwise,
+ * or an error code (injected error, oversized hash, allocation failure).
+ */
+static int ssl_async_sign( void *connection_ctx_arg,
+                           void **p_operation_ctx,
+                           mbedtls_x509_crt *cert,
+                           mbedtls_md_type_t md_alg,
+                           const unsigned char *hash,
+                           size_t hash_len )
+{
+    ssl_async_key_context_t *key_ctx = connection_ctx_arg;
+    size_t slot;
+    ssl_async_operation_context_t *ctx = NULL;
+    {
+        char dn[100];
+        /* Print dn only if mbedtls_x509_dn_gets() reports that it wrote
+         * something; otherwise the buffer may hold no string at all. */
+        if( mbedtls_x509_dn_gets( dn, sizeof( dn ), &cert->subject ) > 0 )
+            mbedtls_printf( "Async sign callback: looking for DN=%s\n", dn );
+    }
+    for( slot = 0; slot < key_ctx->slots_used; slot++ )
+    {
+        if( key_ctx->slots[slot].cert == cert )
+            break;
+    }
+    if( slot == key_ctx->slots_used )
+    {
+        mbedtls_printf( "Async sign callback: no key matches this certificate.\n" );
+        return( MBEDTLS_ERR_SSL_HW_ACCEL_FALLTHROUGH );
+    }
+    mbedtls_printf( "Async sign callback: using key slot %u, delay=%u.\n",
+                    (unsigned) slot, key_ctx->slots[slot].delay );
+    if( key_ctx->inject_error == SSL_ASYNC_INJECT_ERROR_START )
+    {
+        mbedtls_printf( "Async sign callback: injected error\n" );
+        return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
+    }
+    /* The hash is copied into a fixed-size buffer in the context. */
+    if( hash_len > MBEDTLS_MD_MAX_SIZE )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+    ctx = mbedtls_calloc( 1, sizeof( *ctx ) );
+    if( ctx == NULL )
+        return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
+    ctx->slot = slot;
+    ctx->md_alg = md_alg;
+    memcpy( ctx->hash, hash, hash_len );
+    ctx->hash_len = hash_len;
+    ctx->delay = key_ctx->slots[slot].delay;
+    *p_operation_ctx = ctx;
+    if( ctx->delay == 0 )
+        return( 0 );
+    else
+        return( MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
+}
+
+/*
+ * Resume callback: while the operation's delay is nonzero, decrement it
+ * and return MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS; then sign the saved hash
+ * with the slot's key, writing the signature to output/output_len.
+ *
+ * The operation context is freed on every other return, since the
+ * operation is over. NOTE(review): this assumes the library does not call
+ * ssl_async_cancel() after a final resume status, as documented for the
+ * mbedtls async private key API -- confirm for this library version.
+ */
+static int ssl_async_resume( void *connection_ctx_arg,
+                             void *operation_ctx_arg,
+                             unsigned char *output,
+                             size_t *output_len,
+                             size_t output_size )
+{
+    ssl_async_operation_context_t *ctx = operation_ctx_arg;
+    ssl_async_key_context_t *connection_ctx = connection_ctx_arg;
+    ssl_async_key_slot_t *key_slot = &connection_ctx->slots[ctx->slot];
+    int ret;
+    if( connection_ctx->inject_error == SSL_ASYNC_INJECT_ERROR_RESUME )
+    {
+        mbedtls_printf( "Async resume callback: injected error\n" );
+        mbedtls_free( ctx );
+        return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
+    }
+    if( ctx->delay > 0 )
+    {
+        --ctx->delay;
+        mbedtls_printf( "Async resume (slot %u): call %u more times.\n",
+                        (unsigned) ctx->slot, ctx->delay );
+        return( MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
+    }
+    (void) output_size; /* mbedtls_pk_sign lacks this parameter */
+    ret = mbedtls_pk_sign( key_slot->pk,
+                           ctx->md_alg,
+                           ctx->hash, ctx->hash_len,
+                           output, output_len,
+                           connection_ctx->f_rng, connection_ctx->p_rng );
+    if( connection_ctx->inject_error == SSL_ASYNC_INJECT_ERROR_PK )
+    {
+        mbedtls_printf( "Async resume callback: done but injected error\n" );
+        mbedtls_free( ctx );
+        return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE );
+    }
+    mbedtls_printf( "Async resume (slot %u): done, status=%d.\n",
+                    (unsigned) ctx->slot, ret );
+    mbedtls_free( ctx );
+    return( ret );
+}
+
+/* Cancel callback: free a pending operation context. */
+static void ssl_async_cancel( void *connection_ctx_arg,
+                              void *operation_ctx_arg )
+{
+    ssl_async_operation_context_t *ctx = operation_ctx_arg;
+    (void) connection_ctx_arg;
+    mbedtls_printf( "Async cancel callback.\n" );
+    mbedtls_free( ctx );
+}
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+
int main( int argc, char *argv[] )
{
int ret = 0, len, written, frags, exchanges_left;
@@ -875,7 +1036,10 @@
mbedtls_x509_crt srvcert2;
mbedtls_pk_context pkey2;
int key_cert_init = 0, key_cert_init2 = 0;
-#endif
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+ ssl_async_key_context_t ssl_async_keys;
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+#endif /* MBEDTLS_X509_CRT_PARSE_C */
#if defined(MBEDTLS_DHM_C) && defined(MBEDTLS_FS_IO)
mbedtls_dhm_context dhm;
#endif
@@ -977,6 +1141,9 @@
opt.key_file = DFL_KEY_FILE;
opt.crt_file2 = DFL_CRT_FILE2;
opt.key_file2 = DFL_KEY_FILE2;
+ opt.async_private_delay1 = DFL_ASYNC_PRIVATE_DELAY1;
+ opt.async_private_delay2 = DFL_ASYNC_PRIVATE_DELAY2;
+ opt.async_private_error = DFL_ASYNC_PRIVATE_ERROR;
opt.psk = DFL_PSK;
opt.psk_identity = DFL_PSK_IDENTITY;
opt.psk_list = DFL_PSK_LIST;
@@ -1063,6 +1230,22 @@
opt.key_file2 = q;
else if( strcmp( p, "dhm_file" ) == 0 )
opt.dhm_file = q;
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+ else if( strcmp( p, "async_private_delay1" ) == 0 )
+ opt.async_private_delay1 = atoi( q );
+ else if( strcmp( p, "async_private_delay2" ) == 0 )
+ opt.async_private_delay2 = atoi( q );
+ else if( strcmp( p, "async_private_error" ) == 0 )
+ {
+ int n = atoi( q );
+ if( n < 0 || n > SSL_ASYNC_INJECT_ERROR_MAX )
+ {
+ ret = 2;
+ goto usage;
+ }
+ opt.async_private_error = n;
+ }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
else if( strcmp( p, "psk" ) == 0 )
opt.psk = q;
else if( strcmp( p, "psk_identity" ) == 0 )
@@ -1932,18 +2115,55 @@
mbedtls_ssl_conf_ca_chain( &conf, &cacert, NULL );
}
if( key_cert_init )
- if( ( ret = mbedtls_ssl_conf_own_cert( &conf, &srvcert, &pkey ) ) != 0 )
+ {
+ mbedtls_pk_context *pk = &pkey;
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+ if( opt.async_private_delay1 >= 0 )
+ {
+ ssl_async_set_key( &ssl_async_keys, &srvcert, pk,
+ opt.async_private_delay1 );
+ pk = NULL;
+ }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+ if( ( ret = mbedtls_ssl_conf_own_cert( &conf, &srvcert, pk ) ) != 0 )
{
mbedtls_printf( " failed\n ! mbedtls_ssl_conf_own_cert returned %d\n\n", ret );
goto exit;
}
+ }
if( key_cert_init2 )
- if( ( ret = mbedtls_ssl_conf_own_cert( &conf, &srvcert2, &pkey2 ) ) != 0 )
+ {
+ mbedtls_pk_context *pk = &pkey2;
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+ if( opt.async_private_delay2 >= 0 )
+ {
+ ssl_async_set_key( &ssl_async_keys, &srvcert2, pk,
+ opt.async_private_delay2 );
+ pk = NULL;
+ }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+ if( ( ret = mbedtls_ssl_conf_own_cert( &conf, &srvcert2, pk ) ) != 0 )
{
mbedtls_printf( " failed\n ! mbedtls_ssl_conf_own_cert returned %d\n\n", ret );
goto exit;
}
-#endif
+ }
+
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+ if( opt.async_private_delay1 >= 0 || opt.async_private_delay2 >= 0 )
+ {
+ ssl_async_keys.inject_error = opt.async_private_error;
+ ssl_async_keys.f_rng = mbedtls_ctr_drbg_random;
+ ssl_async_keys.p_rng = &ctr_drbg;
+ mbedtls_ssl_conf_async_private_cb( &conf,
+ ssl_async_sign,
+ NULL,
+ ssl_async_resume,
+ ssl_async_cancel,
+ &ssl_async_keys );
+ }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+#endif /* MBEDTLS_X509_CRT_PARSE_C */
#if defined(SNI_OPTION)
if( opt.sni != NULL )
@@ -2113,9 +2333,21 @@
mbedtls_printf( " . Performing the SSL/TLS handshake..." );
fflush( stdout );
- do ret = mbedtls_ssl_handshake( &ssl );
+ do
+ {
+ ret = mbedtls_ssl_handshake( &ssl );
+#if defined(MBEDTLS_SSL_ASYNC_PRIVATE_C)
+ if( ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS &&
+ opt.async_private_error == SSL_ASYNC_INJECT_ERROR_CANCEL )
+ {
+ mbedtls_printf( " cancelling on injected error\n" );
+ goto reset;
+ }
+#endif /* MBEDTLS_SSL_ASYNC_PRIVATE_C */
+ }
while( ret == MBEDTLS_ERR_SSL_WANT_READ ||
- ret == MBEDTLS_ERR_SSL_WANT_WRITE );
+ ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+ ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
if( ret == MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED )
{
@@ -2220,7 +2452,8 @@
ret = mbedtls_ssl_read( &ssl, buf, len );
if( ret == MBEDTLS_ERR_SSL_WANT_READ ||
- ret == MBEDTLS_ERR_SSL_WANT_WRITE )
+ ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+ ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS )
continue;
if( ret <= 0 )
@@ -2311,7 +2544,8 @@
do ret = mbedtls_ssl_read( &ssl, buf, len );
while( ret == MBEDTLS_ERR_SSL_WANT_READ ||
- ret == MBEDTLS_ERR_SSL_WANT_WRITE );
+ ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+ ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
if( ret <= 0 )
{
@@ -2347,7 +2581,8 @@
while( ( ret = mbedtls_ssl_renegotiate( &ssl ) ) != 0 )
{
if( ret != MBEDTLS_ERR_SSL_WANT_READ &&
- ret != MBEDTLS_ERR_SSL_WANT_WRITE )
+ ret != MBEDTLS_ERR_SSL_WANT_WRITE &&
+ ret != MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS )
{
mbedtls_printf( " failed\n ! mbedtls_ssl_renegotiate returned %d\n\n", ret );
goto reset;
@@ -2381,7 +2616,8 @@
}
if( ret != MBEDTLS_ERR_SSL_WANT_READ &&
- ret != MBEDTLS_ERR_SSL_WANT_WRITE )
+ ret != MBEDTLS_ERR_SSL_WANT_WRITE &&
+ ret != MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS )
{
mbedtls_printf( " failed\n ! mbedtls_ssl_write returned %d\n\n", ret );
goto reset;
@@ -2393,7 +2629,8 @@
{
do ret = mbedtls_ssl_write( &ssl, buf, len );
while( ret == MBEDTLS_ERR_SSL_WANT_READ ||
- ret == MBEDTLS_ERR_SSL_WANT_WRITE );
+ ret == MBEDTLS_ERR_SSL_WANT_WRITE ||
+ ret == MBEDTLS_ERR_SSL_ASYNC_IN_PROGRESS );
if( ret < 0 )
{