Merge code for RSA and ECDSA in SSL
diff --git a/include/polarssl/pk.h b/include/polarssl/pk.h
index 6135bb0..7c674ab 100644
--- a/include/polarssl/pk.h
+++ b/include/polarssl/pk.h
@@ -245,6 +245,17 @@
size_t pk_get_size( const pk_context *ctx );
/**
+ * \brief Get the length in bytes of the underlying key
+ * \param ctx Context to use
+ *
+ * \return Key lenght in bytes, or 0 on error
+ */
+static size_t pk_get_len( const pk_context *ctx )
+{
+ return( ( pk_get_size( ctx ) + 7 ) / 8 );
+}
+
+/**
* \brief Tell if a context can do the operation given by type
*
* \param ctx Context to test
diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h
index 442dba2..dfd6490 100644
--- a/include/polarssl/ssl.h
+++ b/include/polarssl/ssl.h
@@ -202,6 +202,7 @@
#define SSL_HASH_SHA384 5
#define SSL_HASH_SHA512 6
+#define SSL_SIG_ANON 0
#define SSL_SIG_RSA 1
#define SSL_SIG_ECDSA 3
@@ -580,13 +581,6 @@
*/
pk_context *pk_key; /*!< own private key */
int pk_key_own_alloc; /*!< did we allocate pk_key? */
-#if defined(POLARSSL_RSA_C)
- int rsa_use_alt; /*<! flag for alt (temporary) */
- void *rsa_key; /*!< own RSA private key */
- rsa_decrypt_func rsa_decrypt; /*!< function for RSA decrypt*/
- rsa_sign_func rsa_sign; /*!< function for RSA sign */
- rsa_key_len_func rsa_key_len; /*!< function for RSA key len*/
-#endif /* POLARSSL_RSA_C */
#if defined(POLARSSL_X509_PARSE_C)
x509_cert *own_cert; /*!< own X.509 certificate */
@@ -909,7 +903,7 @@
* \param pk_key own private key
*/
void ssl_set_own_cert( ssl_context *ssl, x509_cert *own_cert,
- pk_context *rsa_key );
+ pk_context *pk_key );
#if defined(POLARSSL_RSA_C)
/**
@@ -922,9 +916,11 @@
* \param ssl SSL context
* \param own_cert own public certificate chain
* \param rsa_key own private RSA key
+ *
+ * \return 0 on success, or a specific error code.
*/
-void ssl_set_own_cert_rsa( ssl_context *ssl, x509_cert *own_cert,
- rsa_context *rsa_key );
+int ssl_set_own_cert_rsa( ssl_context *ssl, x509_cert *own_cert,
+ rsa_context *rsa_key );
#endif /* POLARSSL_RSA_C */
/**
@@ -1388,6 +1384,8 @@
void ssl_optimize_checksum( ssl_context *ssl, const ssl_ciphersuite_t *ciphersuite_info );
+unsigned char ssl_sig_from_pk( pk_context *pk );
+
#ifdef __cplusplus
}
#endif
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index babd60a..ba2c68c 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -2023,7 +2023,7 @@
md_alg = POLARSSL_MD_SHA256;
ssl->out_msg[4] = SSL_HASH_SHA256;
}
- /* SIG added later */
+ ssl->out_msg[5] = ssl_sig_from_pk( ssl->pk_key );
if( ( md_info = md_info_from_type( md_alg ) ) == NULL )
{
@@ -2036,40 +2036,13 @@
offset = 2;
}
-#if defined(POLARSSL_RSA_C)
- if( ssl->rsa_key != NULL )
+ if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
+ ssl->out_msg + 6 + offset, &n,
+ ssl->f_rng, ssl->p_rng ) ) != 0 )
{
- if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
- ssl->out_msg[5] = SSL_SIG_RSA;
-
- if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
- ssl->out_msg + 6 + offset, &n,
- ssl->f_rng, ssl->p_rng ) ) != 0 )
- {
- SSL_DEBUG_RET( 1, "pk_sign", ret );
- return( ret );
- }
+ SSL_DEBUG_RET( 1, "pk_sign", ret );
+ return( ret );
}
- else
-#endif /* POLARSSL_RSA_C */
-#if defined(POLARSSL_ECDSA_C)
- if( pk_can_do( ssl->pk_key, POLARSSL_PK_ECDSA ) )
- {
- if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
- ssl->out_msg[5] = SSL_SIG_ECDSA;
-
- if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
- ssl->out_msg + 6 + offset, &n,
- ssl->f_rng, ssl->p_rng ) ) != 0 )
- {
- SSL_DEBUG_RET( 1, "pk_sign", ret );
- return( ret );
- }
- }
- else
-#endif /* POLARSSL_ECDSA_C */
- /* should never happen */
- return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE );
ssl->out_msg[4 + offset] = (unsigned char)( n >> 8 );
ssl->out_msg[5 + offset] = (unsigned char)( n );
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 0fa4f66..6c4bdf0 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -2069,50 +2069,21 @@
return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
}
-#if defined(POLARSSL_RSA_C)
- if( ssl->rsa_key != NULL )
+ if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
{
- if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
- {
- *(p++) = ssl->handshake->sig_alg;
- *(p++) = SSL_SIG_RSA;
+ *(p++) = ssl->handshake->sig_alg;
+ *(p++) = ssl_sig_from_pk( ssl->pk_key );
- n += 2;
- }
-
- if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
- p + 2 , &signature_len,
- ssl->f_rng, ssl->p_rng ) ) != 0 )
- {
- SSL_DEBUG_RET( 1, "pk_sign", ret );
- return( ret );
- }
+ n += 2;
}
- else
-#endif /* POLARSSL_RSA_C */
-#if defined(POLARSSL_ECDSA_C)
- if( pk_can_do( ssl->pk_key, POLARSSL_PK_ECDSA ) )
+
+ if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
+ p + 2 , &signature_len,
+ ssl->f_rng, ssl->p_rng ) ) != 0 )
{
- if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
- {
- *(p++) = ssl->handshake->sig_alg;
- *(p++) = SSL_SIG_ECDSA;
-
- n += 2;
- }
-
- if( ( ret = pk_sign( ssl->pk_key, md_alg, hash, hashlen,
- p + 2 , &signature_len,
- ssl->f_rng, ssl->p_rng ) ) != 0 )
- {
- SSL_DEBUG_RET( 1, "pk_sign", ret );
- return( ret );
- }
+ SSL_DEBUG_RET( 1, "pk_sign", ret );
+ return( ret );
}
- else
-#endif /* POLARSSL_ECDSA_C */
- /* should never happen */
- return( POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE );
*(p++) = (unsigned char)( signature_len >> 8 );
*(p++) = (unsigned char)( signature_len );
@@ -2254,7 +2225,7 @@
* Decrypt the premaster using own private RSA key
*/
i = 4;
- n = ssl->rsa_key_len( ssl->rsa_key );
+ n = pk_get_len( ssl->pk_key );
ssl->handshake->pmslen = 48;
if( ssl->minor_ver != SSL_MINOR_VERSION_0 )
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 9e446f6..527b333 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -131,30 +131,6 @@
int (*ssl_hw_record_finish)(ssl_context *ssl) = NULL;
#endif
-#if defined(POLARSSL_RSA_C)
-static int ssl_rsa_decrypt( void *ctx, int mode, size_t *olen,
- const unsigned char *input, unsigned char *output,
- size_t output_max_len )
-{
- return rsa_pkcs1_decrypt( (rsa_context *) ctx, mode, olen, input, output,
- output_max_len );
-}
-
-static int ssl_rsa_sign( void *ctx,
- int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
- int mode, int hash_id, unsigned int hashlen,
- const unsigned char *hash, unsigned char *sig )
-{
- return rsa_pkcs1_sign( (rsa_context *) ctx, f_rng, p_rng, mode, hash_id,
- hashlen, hash, sig );
-}
-
-static size_t ssl_rsa_key_len( void *ctx )
-{
- return ( (rsa_context *) ctx )->len;
-}
-#endif /* POLARSSL_RSA_C */
-
/*
* Key material generation
*/
@@ -2858,12 +2834,6 @@
/*
* Sane defaults
*/
-#if defined(POLARSSL_RSA_C)
- ssl->rsa_decrypt = ssl_rsa_decrypt;
- ssl->rsa_sign = ssl_rsa_sign;
- ssl->rsa_key_len = ssl_rsa_key_len;
-#endif
-
ssl->min_major_ver = SSL_MAJOR_VERSION_3;
ssl->min_minor_ver = SSL_MINOR_VERSION_0;
ssl->max_major_ver = SSL_MAJOR_VERSION_3;
@@ -3147,18 +3117,31 @@
{
ssl->own_cert = own_cert;
ssl->pk_key = pk_key;
-
- /* Temporary, until everything is moved to PK */
- if( pk_key->pk_info->type == POLARSSL_PK_RSA )
- ssl->rsa_key = pk_key->pk_ctx;
}
#if defined(POLARSSL_RSA_C)
-void ssl_set_own_cert_rsa( ssl_context *ssl, x509_cert *own_cert,
+int ssl_set_own_cert_rsa( ssl_context *ssl, x509_cert *own_cert,
rsa_context *rsa_key )
{
+ int ret;
+
ssl->own_cert = own_cert;
- ssl->rsa_key = rsa_key;
+
+ if( ( ssl->pk_key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL )
+ return( POLARSSL_ERR_SSL_MALLOC_FAILED );
+
+ ssl->pk_key_own_alloc = 1;
+
+ pk_init( ssl->pk_key );
+
+ ret = pk_init_ctx( ssl->pk_key, pk_info_from_type( POLARSSL_PK_RSA ) );
+ if( ret != 0 )
+ return( ret );
+
+ if( ( ret = rsa_copy( ssl->pk_key->pk_ctx, rsa_key ) ) != 0 )
+ return( ret );
+
+ return( 0 );
}
#endif /* POLARSSL_RSA_C */
@@ -3168,14 +3151,7 @@
rsa_sign_func rsa_sign,
rsa_key_len_func rsa_key_len )
{
- int ret;
-
ssl->own_cert = own_cert;
- ssl->rsa_use_alt = 1;
- ssl->rsa_key = rsa_key;
- ssl->rsa_decrypt = rsa_decrypt;
- ssl->rsa_sign = rsa_sign;
- ssl->rsa_key_len = rsa_key_len;
if( ( ssl->pk_key = polarssl_malloc( sizeof( pk_context ) ) ) == NULL )
return( POLARSSL_ERR_SSL_MALLOC_FAILED );
@@ -3812,4 +3788,20 @@
memset( ssl, 0, sizeof( ssl_context ) );
}
+/*
+ * Get the SSL_SIG_* constant corresponding to a public key
+ */
+unsigned char ssl_sig_from_pk( pk_context *pk )
+{
+#if defined(POLARSSL_RSA_C)
+ if( pk_can_do( pk, POLARSSL_PK_RSA ) )
+ return( SSL_SIG_RSA );
+#endif
+#if defined(POLARSSL_ECDSA_C)
+ if( pk_can_do( pk, POLARSSL_PK_ECDSA ) )
+ return( SSL_SIG_ECDSA );
+#endif
+ return( SSL_SIG_ANON );
+}
+
#endif