Add and use pk_encrypt(), pk_decrypt()
diff --git a/library/pk.c b/library/pk.c
index 6f68c73..6e60574 100644
--- a/library/pk.c
+++ b/library/pk.c
@@ -153,6 +153,60 @@
 }
 
 /*
+ * Decrypt message
+ *
+ * Forwards to the decrypt_func registered in ctx->pk_info.
+ * Returns POLARSSL_ERR_PK_BAD_INPUT_DATA if ctx or ctx->pk_info is NULL,
+ * POLARSSL_ERR_PK_TYPE_MISMATCH if the key type registers no decrypt_func
+ * (e.g. the EC key types in pk_wrap.c), otherwise whatever decrypt_func
+ * returns.
+ */
+int pk_decrypt( pk_context *ctx,
+                const unsigned char *input, size_t ilen,
+                unsigned char *output, size_t *olen, size_t osize,
+                int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
+{
+    /* Reject a missing context, or one with no key type attached */
+    if( ctx == NULL )
+        return( POLARSSL_ERR_PK_BAD_INPUT_DATA );
+    if( ctx->pk_info == NULL )
+        return( POLARSSL_ERR_PK_BAD_INPUT_DATA );
+
+    /* Not every key type can decrypt */
+    return( ctx->pk_info->decrypt_func != NULL
+            ? ctx->pk_info->decrypt_func( ctx->pk_ctx, input, ilen,
+                                          output, olen, osize, f_rng, p_rng )
+            : POLARSSL_ERR_PK_TYPE_MISMATCH );
+}
+
+/*
+ * Encrypt message
+ *
+ * Forwards to the encrypt_func registered in ctx->pk_info.
+ * Returns POLARSSL_ERR_PK_BAD_INPUT_DATA if ctx or ctx->pk_info is NULL,
+ * POLARSSL_ERR_PK_TYPE_MISMATCH if the key type registers no encrypt_func
+ * (e.g. the EC key types in pk_wrap.c), otherwise whatever encrypt_func
+ * returns.
+ */
+int pk_encrypt( pk_context *ctx,
+                const unsigned char *input, size_t ilen,
+                unsigned char *output, size_t *olen, size_t osize,
+                int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
+{
+    /* Reject a missing context, or one with no key type attached */
+    if( ctx == NULL )
+        return( POLARSSL_ERR_PK_BAD_INPUT_DATA );
+    if( ctx->pk_info == NULL )
+        return( POLARSSL_ERR_PK_BAD_INPUT_DATA );
+
+    /* Not every key type can encrypt */
+    return( ctx->pk_info->encrypt_func != NULL
+            ? ctx->pk_info->encrypt_func( ctx->pk_ctx, input, ilen,
+                                          output, olen, osize, f_rng, p_rng )
+            : POLARSSL_ERR_PK_TYPE_MISMATCH );
+}
+
+/*
  * Get key size in bits
  */
 size_t pk_get_size( const pk_context *ctx )
diff --git a/library/pk_wrap.c b/library/pk_wrap.c
index eb91d89..2c55ce0 100644
--- a/library/pk_wrap.c
+++ b/library/pk_wrap.c
@@ -80,6 +80,47 @@
                 md_alg, hash_len, hash, sig ) );
 }
 
+/*
+ * decrypt_func for RSA keys: PKCS#1 decryption in RSA_PRIVATE mode.
+ * The input must be exactly ctx->len bytes; osize is handed to
+ * rsa_pkcs1_decrypt() as the output bound and *olen receives the length.
+ * f_rng / p_rng are part of the pk interface but unused here.
+ */
+static int rsa_decrypt_wrap( void *ctx,
+                    const unsigned char *input, size_t ilen,
+                    unsigned char *output, size_t *olen, size_t osize,
+                    int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
+{
+    ((void) f_rng);
+    ((void) p_rng);
+
+    if( ilen != ((rsa_context *) ctx)->len )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    return( rsa_pkcs1_decrypt( (rsa_context *) ctx,
+                RSA_PRIVATE, olen, input, output, osize ) );
+}
+
+/*
+ * encrypt_func for RSA keys: PKCS#1 encryption in RSA_PUBLIC mode.
+ * The output is always ctx->len bytes, so refuse (instead of overflowing
+ * output) when osize is smaller than that.
+ */
+static int rsa_encrypt_wrap( void *ctx,
+                    const unsigned char *input, size_t ilen,
+                    unsigned char *output, size_t *olen, size_t osize,
+                    int (*f_rng)(void *, unsigned char *, size_t), void *p_rng )
+{
+    *olen = ((rsa_context *) ctx)->len;
+
+    /* rsa_pkcs1_encrypt() takes no output size, so bound-check here */
+    if( *olen > osize )
+        return( POLARSSL_ERR_RSA_OUTPUT_TOO_LARGE );
+
+    return( rsa_pkcs1_encrypt( (rsa_context *) ctx,
+                f_rng, p_rng, RSA_PUBLIC, ilen, input, output ) );
+}
+
 static void *rsa_alloc_wrap( void )
 {
     void *ctx = polarssl_malloc( sizeof( rsa_context ) );
@@ -116,6 +144,8 @@
     rsa_can_do,
     rsa_verify_wrap,
     rsa_sign_wrap,
+    rsa_decrypt_wrap,       /* decrypt_func */
+    rsa_encrypt_wrap,       /* encrypt_func */
     rsa_alloc_wrap,
     rsa_free_wrap,
     rsa_debug,
@@ -222,6 +252,8 @@
     NULL,
     NULL,
 #endif
+    NULL,                   /* decrypt_func: not supported */
+    NULL,                   /* encrypt_func: not supported */
     eckey_alloc_wrap,
     eckey_free_wrap,
     eckey_debug,
@@ -243,6 +275,8 @@
     eckeydh_can_do,
     NULL,
     NULL,
+    NULL,                   /* decrypt_func: not supported */
+    NULL,                   /* encrypt_func: not supported */
     eckey_alloc_wrap,       /* Same underlying key structure */
     eckey_free_wrap,        /* Same underlying key structure */
     eckey_debug,            /* Same underlying key structure */
@@ -299,6 +333,8 @@
     ecdsa_can_do,
     ecdsa_verify_wrap,
     ecdsa_sign_wrap,
+    NULL,                   /* decrypt_func: not supported */
+    NULL,                   /* encrypt_func: not supported */
     ecdsa_alloc_wrap,
     ecdsa_free_wrap,
     eckey_debug,        /* Compatible key structures */
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index 829e46b..cd77eb8 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -1870,26 +1870,30 @@
             return( POLARSSL_ERR_SSL_PK_TYPE_MISMATCH );
         }
 
-        i = 4;
-        n = pk_get_size( &ssl->session_negotiate->peer_cert->pk ) / 8;
+        /*
+         * Minor version 0 (SSL 3.0) sends the bare ciphertext; later
+         * versions prefix it with a 2-byte length, written below once
+         * pk_encrypt() has reported it.
+         */
+        i = ssl->minor_ver == SSL_MINOR_VERSION_0 ? 4 : 6;
 
-        if( ssl->minor_ver != SSL_MINOR_VERSION_0 )
-        {
-            i += 2;
-            ssl->out_msg[4] = (unsigned char)( n >> 8 );
-            ssl->out_msg[5] = (unsigned char)( n      );
-        }
-
-        ret = rsa_pkcs1_encrypt(
-                pk_rsa( ssl->session_negotiate->peer_cert->pk ),
-                ssl->f_rng, ssl->p_rng, RSA_PUBLIC,
-                ssl->handshake->pmslen, ssl->handshake->premaster,
-                ssl->out_msg + i );
+        /* Bound the output by what is left of out_msg after the i header
+         * bytes, not by the size of the whole record buffer. */
+        ret = pk_encrypt( &ssl->session_negotiate->peer_cert->pk,
+                ssl->handshake->premaster, ssl->handshake->pmslen,
+                ssl->out_msg + i, &n, SSL_MAX_CONTENT_LEN - i,
+                ssl->f_rng, ssl->p_rng );
         if( ret != 0 )
         {
-            SSL_DEBUG_RET( 1, "rsa_pkcs1_encrypt", ret );
+            SSL_DEBUG_RET( 1, "pk_encrypt", ret );
             return( ret );
         }
+
+        if( ssl->minor_ver != SSL_MINOR_VERSION_0 )
+        {
+            ssl->out_msg[4] = (unsigned char)( n >> 8 );
+            ssl->out_msg[5] = (unsigned char)( n      );
+        }
     }
     else
 #endif /* POLARSSL_KEY_EXCHANGE_RSA_ENABLED */
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index ffd754e..6fb16ec 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -2259,9 +2259,9 @@
     int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
     size_t i, n = 0;
 
-    if( ssl->rsa_key == NULL )
+    if( ! pk_can_do( ssl->pk_key, POLARSSL_PK_RSA ) ) /* NOTE(review): also gates the rsa_use_alt path; confirm pk_key is set then */
     {
-        SSL_DEBUG_MSG( 1, ( "got no private key" ) );
+        SSL_DEBUG_MSG( 1, ( "got no RSA private key" ) );
         return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
     }
 
@@ -2269,8 +2269,10 @@
      * Decrypt the premaster using own private RSA key
      */
     i = 4;
-    if( ssl->rsa_key )
-        n = ssl->rsa_key_len( ssl->rsa_key );
+    /* Take the expected ciphertext length from the key that decrypts it
+     * below: the alt callback, or pk_key (pk_get_size() is in bits). */
+    n = ssl->rsa_use_alt ? ssl->rsa_key_len( ssl->rsa_key )
+                         : ( pk_get_size( ssl->pk_key ) + 7 ) / 8;
     ssl->handshake->pmslen = 48;
 
     if( ssl->minor_ver != SSL_MINOR_VERSION_0 )
@@ -2290,13 +2289,24 @@
         return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_KEY_EXCHANGE );
     }
 
-    if( ssl->rsa_key ) {
+    /* rsa_use_alt: decrypt via the rsa_decrypt callback on rsa_key */
+    if( ssl->rsa_use_alt )
+    {
         ret = ssl->rsa_decrypt( ssl->rsa_key, RSA_PRIVATE,
                                &ssl->handshake->pmslen,
                                 ssl->in_msg + i,
                                 ssl->handshake->premaster,
                                 sizeof(ssl->handshake->premaster) );
     }
+    else
+    {
+        /* Otherwise go through the generic pk layer with pk_key */
+        ret = pk_decrypt( ssl->pk_key, ssl->in_msg + i, n,
+                          ssl->handshake->premaster,
+                          &ssl->handshake->pmslen,
+                          sizeof( ssl->handshake->premaster ),
+                          ssl->f_rng, ssl->p_rng );
+    }
 
     if( ret != 0 || ssl->handshake->pmslen != 48 ||
         ssl->handshake->premaster[0] != ssl->handshake->max_major_ver ||