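Replace pk_set_type() with pk_init_ctx() for consistency

pk_info_from_type() is no longer static: callers now look up the
pk_info_t themselves and pass it to pk_init_ctx(). Unlike pk_set_type(),
pk_init_ctx() rejects a context that is already set up instead of
accepting one of the same type, so the x509parse_*_rsa() wrappers no
longer pre-set the type and instead check pk_can_do() after parsing.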
diff --git a/library/pk.c b/library/pk.c
index 61544eb..4c16de8 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -67,7 +67,7 @@
/*
* Get pk_info structure from type
*/
-static const pk_info_t * pk_info_from_type( pk_type_t pk_type )
+const pk_info_t * pk_info_from_type( pk_type_t pk_type )
{
switch( pk_type ) {
#if defined(POLARSSL_RSA_C)
@@ -90,21 +90,11 @@
}
/*
- * Set a pk_context to a given type
+ * Initialise context
*/
-int pk_set_type( pk_context *ctx, pk_type_t type )
+int pk_init_ctx( pk_context *ctx, const pk_info_t *info )
{
- const pk_info_t *info;
-
- if( ctx->pk_info != NULL )
- {
- if( ctx->pk_info->type == type )
- return 0;
-
- return( POLARSSL_ERR_PK_TYPE_MISMATCH );
- }
-
- if( ( info = pk_info_from_type( type ) ) == NULL )
+ if( ctx == NULL || info == NULL || ctx->pk_info != NULL )
return( POLARSSL_ERR_PK_BAD_INPUT_DATA );
if( ( ctx->pk_ctx = info->ctx_alloc_func() ) == NULL )
diff --git a/library/x509parse.c b/library/x509parse.c
index e080174..4da4e75 100644
--- a/library/x509parse.c
+++ b/library/x509parse.c
@@ -570,6 +570,7 @@
size_t len;
x509_buf alg_params;
pk_type_t pk_alg = POLARSSL_PK_NONE;
+ const pk_info_t *pk_info;
if( ( ret = asn1_get_tag( p, end, &len,
ASN1_CONSTRUCTED | ASN1_SEQUENCE ) ) != 0 )
@@ -589,7 +590,10 @@
return( POLARSSL_ERR_X509_CERT_INVALID_PUBKEY +
POLARSSL_ERR_ASN1_LENGTH_MISMATCH );
- if( ( ret = pk_set_type( pk, pk_alg ) ) != 0 )
+ if( ( pk_info = pk_info_from_type( pk_alg ) ) == NULL )
+ return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+ if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 )
return( ret );
#if defined(POLARSSL_RSA_C)
@@ -2142,10 +2146,12 @@
pk_context pk;
pk_init( &pk );
- pk_set_type( &pk, POLARSSL_PK_RSA );
ret = x509parse_keyfile( &pk, path, pwd );
+ if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) )
+ ret = POLARSSL_ERR_PK_TYPE_MISMATCH;
+
if( ret == 0 )
rsa_copy( rsa, pk_rsa( pk ) );
else
@@ -2165,10 +2171,12 @@
pk_context pk;
pk_init( &pk );
- pk_set_type( &pk, POLARSSL_PK_RSA );
ret = x509parse_public_keyfile( &pk, path );
+ if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) )
+ ret = POLARSSL_ERR_PK_TYPE_MISMATCH;
+
if( ret == 0 )
rsa_copy( rsa, pk_rsa( pk ) );
else
@@ -2380,6 +2388,7 @@
unsigned char *p = (unsigned char *) key;
unsigned char *end = p + keylen;
pk_type_t pk_alg = POLARSSL_PK_NONE;
+ const pk_info_t *pk_info;
/*
* This function parses the PrivatKeyInfo object (PKCS#8 v1.2 = RFC 5208)
@@ -2421,7 +2430,10 @@
return( POLARSSL_ERR_X509_KEY_INVALID_FORMAT +
POLARSSL_ERR_ASN1_OUT_OF_DATA );
- if( ( ret = pk_set_type( pk, pk_alg ) ) != 0 )
+ if( ( pk_info = pk_info_from_type( pk_alg ) ) == NULL )
+ return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+ if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 )
return( ret );
#if defined(POLARSSL_RSA_C)
@@ -2568,6 +2580,7 @@
const unsigned char *pwd, size_t pwdlen )
{
int ret;
+ const pk_info_t *pk_info;
#if defined(POLARSSL_PEM_C)
size_t len;
@@ -2582,7 +2595,10 @@
key, pwd, pwdlen, &len );
if( ret == 0 )
{
- if( ( ret = pk_set_type( pk, POLARSSL_PK_RSA ) ) != 0 ||
+ if( ( pk_info = pk_info_from_type( POLARSSL_PK_RSA ) ) == NULL )
+ return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+ if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 ||
( ret = x509parse_key_pkcs1_der( pk_rsa( *pk ),
pem.buf, pem.buflen ) ) != 0 )
{
@@ -2607,7 +2623,10 @@
key, pwd, pwdlen, &len );
if( ret == 0 )
{
- if( ( ret = pk_set_type( pk, POLARSSL_PK_ECKEY ) ) != 0 ||
+ if( ( pk_info = pk_info_from_type( POLARSSL_PK_ECKEY ) ) == NULL )
+ return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+ if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 ||
( ret = x509parse_key_sec1_der( pk_ec( *pk ),
pem.buf, pem.buflen ) ) != 0 )
{
@@ -2692,7 +2711,10 @@
pk_free( pk );
#if defined(POLARSSL_RSA_C)
- if( ( ret = pk_set_type( pk, POLARSSL_PK_RSA ) ) == 0 &&
+ if( ( pk_info = pk_info_from_type( POLARSSL_PK_RSA ) ) == NULL )
+ return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+ if( ( ret = pk_init_ctx( pk, pk_info ) ) == 0 &&
( ret = x509parse_key_pkcs1_der( pk_rsa( *pk ), key, keylen ) ) == 0 )
{
return( 0 );
@@ -2702,7 +2724,10 @@
#endif /* POLARSSL_RSA_C */
#if defined(POLARSSL_ECP_C)
- if( ( ret = pk_set_type( pk, POLARSSL_PK_ECKEY ) ) == 0 &&
+ if( ( pk_info = pk_info_from_type( POLARSSL_PK_ECKEY ) ) == NULL )
+ return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG );
+
+ if( ( ret = pk_init_ctx( pk, pk_info ) ) == 0 &&
( ret = x509parse_key_sec1_der( pk_ec( *pk ), key, keylen ) ) == 0 )
{
return( 0 );
@@ -2769,10 +2794,12 @@
pk_context pk;
pk_init( &pk );
- pk_set_type( &pk, POLARSSL_PK_RSA );
ret = x509parse_key( &pk, key, keylen, pwd, pwdlen );
+ if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) )
+ ret = POLARSSL_ERR_PK_TYPE_MISMATCH;
+
if( ret == 0 )
rsa_copy( rsa, pk_rsa( pk ) );
else
@@ -2793,10 +2820,12 @@
pk_context pk;
pk_init( &pk );
- pk_set_type( &pk, POLARSSL_PK_RSA );
ret = x509parse_public_key( &pk, key, keylen );
+ if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) )
+ ret = POLARSSL_ERR_PK_TYPE_MISMATCH;
+
if( ret == 0 )
rsa_copy( rsa, pk_rsa( pk ) );
else