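Use the new PK RSA-alt interface

Wrap the RSA-alt callbacks passed to ssl_set_own_cert_alt_rsa() in a
pk_context (via pk_init_ctx_rsa_alt()). The client and server handshake
code can then always call pk_sign() / pk_decrypt() on ssl->pk_key,
instead of branching on ssl->rsa_use_alt.

ssl_set_own_cert_alt_rsa() now returns int, so that allocation and
PK-context initialisation failures can be reported to the caller.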
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index cd77eb8..babd60a 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -2042,27 +2042,12 @@
if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
ssl->out_msg[5] = SSL_SIG_RSA;
- if( ssl->rsa_use_alt )
+ if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
+ ssl->out_msg + 6 + offset, &n,
+ ssl->f_rng, ssl->p_rng ) ) != 0 )
{
- if( ( ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
- RSA_PRIVATE, md_alg,
- hashlen, hash, ssl->out_msg + 6 + offset ) ) != 0 )
- {
- SSL_DEBUG_RET( 1, "rsa_sign", ret );
- return( ret );
- }
-
- n = ssl->rsa_key_len ( ssl->rsa_key );
- }
- else
- {
- if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
- ssl->out_msg + 6 + offset, &n,
- ssl->f_rng, ssl->p_rng ) ) != 0 )
- {
- SSL_DEBUG_RET( 1, "pk_sign", ret );
- return( ret );
- }
+ SSL_DEBUG_RET( 1, "pk_sign", ret );
+ return( ret );
}
}
else
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 6fb16ec..0fa4f66 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -2080,27 +2080,12 @@
n += 2;
}
- if( ssl->rsa_use_alt )
+ if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
+ p + 2 , &signature_len,
+ ssl->f_rng, ssl->p_rng ) ) != 0 )
{
- if( ( ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng,
- ssl->p_rng, RSA_PRIVATE, md_alg, hashlen,
- hash, p + 2 ) ) != 0 )
- {
- SSL_DEBUG_RET( 1, "rsa_sign", ret );
- return( ret );
- }
-
- signature_len = ssl->rsa_key_len( ssl->rsa_key );
- }
- else
- {
- if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
- p + 2 , &signature_len,
- ssl->f_rng, ssl->p_rng ) ) != 0 )
- {
- SSL_DEBUG_RET( 1, "pk_sign", ret );
- return( ret );
- }
+ SSL_DEBUG_RET( 1, "pk_sign", ret );
+ return( ret );
}
}
else
@@ -2289,21 +2274,11 @@
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_KEY_EXCHANGE );
}
- if( ssl->rsa_use_alt ) {
- ret = ssl->rsa_decrypt( ssl->rsa_key, RSA_PRIVATE,
- &ssl->handshake->pmslen,
- ssl->in_msg + i,
- ssl->handshake->premaster,
- sizeof(ssl->handshake->premaster) );
- }
- else
- {
- ret = pk_decrypt( ssl->pk_key,
- ssl->in_msg + i, n,
- ssl->handshake->premaster, &ssl->handshake->pmslen,
- sizeof(ssl->handshake->premaster),
- ssl->f_rng, ssl->p_rng );
- }
+ ret = pk_decrypt( ssl->pk_key,
+ ssl->in_msg + i, n,
+ ssl->handshake->premaster, &ssl->handshake->pmslen,
+ sizeof(ssl->handshake->premaster),
+ ssl->f_rng, ssl->p_rng );
if( ret != 0 || ssl->handshake->pmslen != 48 ||
ssl->handshake->premaster[0] != ssl->handshake->max_major_ver ||
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index d4723d7..9e446f6 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -3162,18 +3162,47 @@
}
#endif /* POLARSSL_RSA_C */
-void ssl_set_own_cert_alt_rsa( ssl_context *ssl, x509_cert *own_cert,
+int ssl_set_own_cert_alt_rsa( ssl_context *ssl, x509_cert *own_cert,
void *rsa_key,
rsa_decrypt_func rsa_decrypt,
rsa_sign_func rsa_sign,
rsa_key_len_func rsa_key_len )
{
+ /*
+ * Wrap the caller's RSA-alt callbacks in a PK context owned by ssl, so the
+ * handshake code can use pk_sign() / pk_decrypt() unconditionally.
+ * Returns 0, POLARSSL_ERR_SSL_MALLOC_FAILED, or an error from
+ * pk_init_ctx_rsa_alt().
+ */
+ int ret;
+
ssl->own_cert = own_cert;
ssl->rsa_use_alt = 1;
ssl->rsa_key = rsa_key;
ssl->rsa_decrypt = rsa_decrypt;
ssl->rsa_sign = rsa_sign;
ssl->rsa_key_len = rsa_key_len;
+
+ /* Release a PK context we allocated earlier, otherwise it would leak */
+ if( ssl->pk_key_own_alloc )
+ {
+ pk_free( ssl->pk_key );
+ polarssl_free( ssl->pk_key );
+ ssl->pk_key_own_alloc = 0;
+ }
+
+ if( ( ssl->pk_key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL )
+ return( POLARSSL_ERR_SSL_MALLOC_FAILED );
+
+ ssl->pk_key_own_alloc = 1;
+
+ pk_init( ssl->pk_key );
+
+ if( ( ret = pk_init_ctx_rsa_alt( ssl->pk_key, rsa_key,
+ rsa_decrypt, rsa_sign, rsa_key_len ) ) != 0 )
+ return( ret );
+
+ return( 0 );
}
#endif /* POLARSSL_X509_PARSE_C */
@@ -3780,6 +3792,12 @@
ssl->hostname_len = 0;
}
+ if( ssl->pk_key_own_alloc )
+ {
+ pk_free( ssl->pk_key );
+ polarssl_free( ssl->pk_key );
+ }
+
#if defined(POLARSSL_SSL_HW_RECORD_ACCEL)
if( ssl_hw_record_finish != NULL )
{