- Generalized external private key implementation handling (like PKCS#11) in SSL/TLS

diff --git a/ChangeLog b/ChangeLog
index d662bcf..bfe3421 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -40,6 +40,8 @@
      POLARSSL_MODE_CFB, to also handle different block size CFB modes.
    * Removed handling for SSLv2 Client Hello (as per RFC 5246 recommendation)
    * Revamped session resumption handling
+   * Generalized external private key implementation handling (like PKCS#11)
+     in SSL/TLS
 
 Bugfix
    * Fixed handling error in mpi_cmp_mpi() on longer B values (found by
diff --git a/include/polarssl/config.h b/include/polarssl/config.h
index 538ef81..543b96c 100644
--- a/include/polarssl/config.h
+++ b/include/polarssl/config.h
@@ -612,7 +612,7 @@
 /**
  * \def POLARSSL_PKCS11_C
  *
- * Enable support for PKCS#11 smartcard support.
+ * Enable wrapper for PKCS#11 smartcard support.
  *
  * Module:  library/ssl_srv.c
  * Caller:  library/ssl_cli.c
@@ -620,7 +620,7 @@
  *
  * Requires: POLARSSL_SSL_TLS_C
  *
- * This module is required for SSL/TLS PKCS #11 smartcard support.
+ * This module enables SSL/TLS PKCS #11 smartcard support.
  * Requires the presence of the PKCS#11 helper library (libpkcs11-helper)
 #define POLARSSL_PKCS11_C
  */
diff --git a/include/polarssl/pkcs11.h b/include/polarssl/pkcs11.h
index a65a72e..ddfae30 100644
--- a/include/polarssl/pkcs11.h
+++ b/include/polarssl/pkcs11.h
@@ -37,6 +37,14 @@
 
 #include <pkcs11-helper-1.0/pkcs11h-certificate.h>
 
+#if defined(_MSC_VER) && !defined(inline)
+#define inline _inline
+#else
+#if defined(__ARMCC_VERSION) && !defined(inline)
+#define inline __inline
+#endif /* __ARMCC_VERSION */
+#endif /*_MSC_VER */
+
 /**
  * Context for PKCS #11 private keys.
  */
@@ -121,6 +129,33 @@
                     const unsigned char *hash,
                     unsigned char *sig );
 
+/**
+ * SSL/TLS wrappers for PKCS#11 functions
+ */
+static inline int ssl_pkcs11_decrypt( void *ctx, int mode, size_t *olen,
+                        const unsigned char *input, unsigned char *output,
+                        unsigned int output_max_len )
+{
+    return pkcs11_decrypt( (pkcs11_context *) ctx, mode, olen, input, output,
+                           output_max_len );
+}
+
+static inline int ssl_pkcs11_sign( void *ctx, 
+                     int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
+                     int mode, int hash_id, unsigned int hashlen,
+                     const unsigned char *hash, unsigned char *sig )
+{
+    ((void) f_rng);
+    ((void) p_rng);
+    return pkcs11_sign( (pkcs11_context *) ctx, mode, hash_id,
+                        hashlen, hash, sig );
+}
+
+static inline size_t ssl_pkcs11_key_len( void *ctx )
+{
+    return ( (pkcs11_context *) ctx )->len;
+}
+
 #endif /* POLARSSL_PKCS11_C */
 
 #endif /* POLARSSL_PKCS11_H */
diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h
index fcf8a8f..62ffba2 100644
--- a/include/polarssl/ssl.h
+++ b/include/polarssl/ssl.h
@@ -42,10 +42,6 @@
 #include "dhm.h"
 #endif
 
-#if defined(POLARSSL_PKCS11_C)
-#include "pkcs11.h"
-#endif
-
 #if defined(POLARSSL_ZLIB_SUPPORT)
 #include "zlib.h"
 #endif
@@ -253,6 +249,20 @@
 
 #define TLS_EXT_RENEGOTIATION_INFO 0xFF01
 
+
+/*
+ * Generic function pointers for allowing external RSA private key
+ * implementations.
+ */
+typedef int (*rsa_decrypt_func)( void *ctx, int mode, size_t *olen,
+                        const unsigned char *input, unsigned char *output,
+                        size_t output_max_len ); 
+typedef int (*rsa_sign_func)( void *ctx,
+                     int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
+                     int mode, int hash_id, unsigned int hashlen,
+                     const unsigned char *hash, unsigned char *sig );
+typedef size_t (*rsa_key_len_func)( void *ctx );
+
 /*
  * SSL state machine
  */
@@ -446,10 +456,11 @@
     /*
      * PKI layer
      */
-    rsa_context *rsa_key;               /*!<  own RSA private key     */
-#if defined(POLARSSL_PKCS11_C)
-    pkcs11_context *pkcs11_key;         /*!<  own PKCS#11 RSA private key */
-#endif
+    void *rsa_key;                      /*!<  own RSA private key     */
+    rsa_decrypt_func rsa_decrypt;       /*!<  function for RSA decrypt*/
+    rsa_sign_func rsa_sign;             /*!<  function for RSA sign   */
+    rsa_key_len_func rsa_key_len;       /*!<  function for RSA key len*/
+
     x509_cert *own_cert;                /*!<  own X.509 certificate   */
     x509_cert *ca_chain;                /*!<  own trusted CA chain    */
     x509_crl *ca_crl;                   /*!<  trusted CA CRLs         */
@@ -722,17 +733,26 @@
 void ssl_set_own_cert( ssl_context *ssl, x509_cert *own_cert,
                        rsa_context *rsa_key );
 
-#if defined(POLARSSL_PKCS11_C)
 /**
- * \brief          Set own certificate and PKCS#11 private key
+ * \brief          Set own certificate and alternate non-PolarSSL private
+ *                 key and handling callbacks, such as the PKCS#11 wrappers
+ *                 or any other external private key handler.
+ *                 (see the respective RSA functions in rsa.h for documentation
+ *                 of the callback parameters, with the only change being
+ *                 that the rsa_context * is a void * in the callbacks)
  *
  * \param ssl      SSL context
  * \param own_cert own public certificate
- * \param pkcs11_key    own PKCS#11 RSA key
+ * \param rsa_key  alternate implementation private RSA key
+ * \param rsa_decrypt_func  alternate implementation of \c rsa_pkcs1_decrypt()
+ * \param rsa_sign_func     alternate implementation of \c rsa_pkcs1_sign()
+ * \param rsa_key_len_func  function returning length of RSA key in bytes
  */
-void ssl_set_own_cert_pkcs11( ssl_context *ssl, x509_cert *own_cert,
-                       pkcs11_context *pkcs11_key );
-#endif
+void ssl_set_own_cert_alt( ssl_context *ssl, x509_cert *own_cert,
+                           void *rsa_key,
+                           rsa_decrypt_func rsa_decrypt,
+                           rsa_sign_func rsa_sign,
+                           rsa_key_len_func rsa_key_len );
 
 #if defined(POLARSSL_DHM_C)
 /**
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index b44af2b..3e1b056 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -30,10 +30,6 @@
 #include "polarssl/debug.h"
 #include "polarssl/ssl.h"
 
-#if defined(POLARSSL_PKCS11_C)
-#include "polarssl/pkcs11.h"
-#endif /* defined(POLARSSL_PKCS11_C) */
-
 #include <stdlib.h>
 #include <stdio.h>
 #include <time.h>
@@ -1115,15 +1111,8 @@
 
     if( ssl->rsa_key == NULL )
     {
-#if defined(POLARSSL_PKCS11_C)
-        if( ssl->pkcs11_key == NULL )
-        {
-#endif /* defined(POLARSSL_PKCS11_C) */
-            SSL_DEBUG_MSG( 1, ( "got no private key" ) );
-            return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
-#if defined(POLARSSL_PKCS11_C)
-        }
-#endif /* defined(POLARSSL_PKCS11_C) */
+        SSL_DEBUG_MSG( 1, ( "got no private key" ) );
+        return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
     }
 
     /*
@@ -1132,11 +1121,7 @@
     ssl->handshake->calc_verify( ssl, hash );
 
     if ( ssl->rsa_key )
-        n = ssl->rsa_key->len;
-#if defined(POLARSSL_PKCS11_C)
-    else
-        n = ssl->pkcs11_key->len;
-#endif  /* defined(POLARSSL_PKCS11_C) */
+        n = ssl->rsa_key_len ( ssl->rsa_key );
 
     if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
     {
@@ -1164,14 +1149,9 @@
 
     if( ssl->rsa_key )
     {
-        ret = rsa_pkcs1_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
-                              RSA_PRIVATE, hash_id,
-                              hashlen, hash, ssl->out_msg + 6 + offset );
-    } else {
-#if defined(POLARSSL_PKCS11_C)
-        ret = pkcs11_sign( ssl->pkcs11_key, RSA_PRIVATE, hash_id,
-                           hashlen, hash, ssl->out_msg + 6 + offset );
-#endif  /* defined(POLARSSL_PKCS11_C) */
+        ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
+                             RSA_PRIVATE, hash_id,
+                             hashlen, hash, ssl->out_msg + 6 + offset );
     }
 
     if (ret != 0)
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 64b0d2d..e311458 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -30,10 +30,6 @@
 #include "polarssl/debug.h"
 #include "polarssl/ssl.h"
 
-#if defined(POLARSSL_PKCS11_C)
-#include "polarssl/pkcs11.h"
-#endif /* defined(POLARSSL_PKCS11_C) */
-
 #include <stdlib.h>
 #include <stdio.h>
 #include <time.h>
@@ -644,15 +640,8 @@
 
     if( ssl->rsa_key == NULL )
     {
-#if defined(POLARSSL_PKCS11_C)
-        if( ssl->pkcs11_key == NULL )
-        {
-#endif /* defined(POLARSSL_PKCS11_C) */
-            SSL_DEBUG_MSG( 1, ( "got no private key" ) );
-            return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
-#if defined(POLARSSL_PKCS11_C)
-        }
-#endif /* defined(POLARSSL_PKCS11_C) */
+        SSL_DEBUG_MSG( 1, ( "got no private key" ) );
+        return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
     }
 
     /*
@@ -738,11 +727,7 @@
     SSL_DEBUG_BUF( 3, "parameters hash", hash, hashlen );
 
     if ( ssl->rsa_key )
-        rsa_key_len = ssl->rsa_key->len;
-#if defined(POLARSSL_PKCS11_C)
-    else
-        rsa_key_len = ssl->pkcs11_key->len;
-#endif /* defined(POLARSSL_PKCS11_C) */
+        rsa_key_len = ssl->rsa_key_len( ssl->rsa_key );
 
     if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
     {
@@ -758,16 +743,11 @@
 
     if ( ssl->rsa_key )
     {
-        ret = rsa_pkcs1_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
-                              RSA_PRIVATE,
-                              hash_id, hashlen, hash, ssl->out_msg + 6 + n );
+        ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
+                             RSA_PRIVATE,
+                             hash_id, hashlen, hash,
+                             ssl->out_msg + 6 + n );
     }
-#if defined(POLARSSL_PKCS11_C)
-    else {
-        ret = pkcs11_sign( ssl->pkcs11_key, RSA_PRIVATE,
-                              hash_id, hashlen, hash, ssl->out_msg + 6 + n );
-    }
-#endif  /* defined(POLARSSL_PKCS11_C) */
 
     if( ret != 0 )
     {
@@ -898,15 +878,8 @@
     {
         if( ssl->rsa_key == NULL )
         {
-#if defined(POLARSSL_PKCS11_C)
-                if( ssl->pkcs11_key == NULL )
-                {
-#endif
-                    SSL_DEBUG_MSG( 1, ( "got no private key" ) );
-                    return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
-#if defined(POLARSSL_PKCS11_C)
-                }
-#endif
+            SSL_DEBUG_MSG( 1, ( "got no private key" ) );
+            return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
         }
 
         /*
@@ -914,11 +887,7 @@
          */
         i = 4;
         if( ssl->rsa_key )
-            n = ssl->rsa_key->len;
-#if defined(POLARSSL_PKCS11_C)
-        else
-            n = ssl->pkcs11_key->len;
-#endif
+            n = ssl->rsa_key_len( ssl->rsa_key );
         ssl->handshake->pmslen = 48;
 
         if( ssl->minor_ver != SSL_MINOR_VERSION_0 )
@@ -939,21 +908,12 @@
         }
 
         if( ssl->rsa_key ) {
-            ret = rsa_pkcs1_decrypt(  ssl->rsa_key, RSA_PRIVATE,
-                                     &ssl->handshake->pmslen,
-                                      ssl->in_msg + i,
-                                      ssl->handshake->premaster,
-                                      sizeof(ssl->handshake->premaster) );
+            ret = ssl->rsa_decrypt( ssl->rsa_key, RSA_PRIVATE,
+                                   &ssl->handshake->pmslen,
+                                    ssl->in_msg + i,
+                                    ssl->handshake->premaster,
+                                    sizeof(ssl->handshake->premaster) );
         }
-#if defined(POLARSSL_PKCS11_C)
-        else {
-            ret = pkcs11_decrypt(  ssl->pkcs11_key, RSA_PRIVATE,
-                                  &ssl->handshake->pmslen,
-                                   ssl->in_msg + i,
-                                   ssl->handshake->premaster,
-                                   sizeof(ssl->handshake->premaster) );
-        }
-#endif  /* defined(POLARSSL_PKCS11_C) */
 
         if( ret != 0 || ssl->handshake->pmslen != 48 ||
             ssl->handshake->premaster[0] != ssl->max_major_ver ||
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 6192004..cc0f65c 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -65,6 +65,28 @@
 int (*ssl_hw_record_finish)(ssl_context *ssl) = NULL;
 #endif
 
+static int ssl_rsa_decrypt( void *ctx, int mode, size_t *olen,
+                        const unsigned char *input, unsigned char *output,
+                        size_t output_max_len )
+{
+    return rsa_pkcs1_decrypt( (rsa_context *) ctx, mode, olen, input, output,
+                              output_max_len );
+}
+
+static int ssl_rsa_sign( void *ctx,
+                    int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
+                    int mode, int hash_id, unsigned int hashlen,
+                    const unsigned char *hash, unsigned char *sig )
+{
+    return rsa_pkcs1_sign( (rsa_context *) ctx, f_rng, p_rng, mode, hash_id,
+                           hashlen, hash, sig );
+}
+
+static size_t ssl_rsa_key_len( void *ctx )
+{
+    return ( (rsa_context *) ctx )->len;
+}
+
 /*
  * Key material generation
  */
@@ -2826,6 +2848,10 @@
 
     memset( ssl, 0, sizeof( ssl_context ) );
 
+    ssl->rsa_decrypt = ssl_rsa_decrypt;
+    ssl->rsa_sign = ssl_rsa_sign;
+    ssl->rsa_key_len = ssl_rsa_key_len;
+
     ssl->in_ctr = (unsigned char *) malloc( len );
     ssl->in_hdr = ssl->in_ctr +  8;
     ssl->in_msg = ssl->in_ctr + 13;
@@ -3002,14 +3028,19 @@
     ssl->rsa_key    = rsa_key;
 }
 
-#if defined(POLARSSL_PKCS11_C)
-void ssl_set_own_cert_pkcs11( ssl_context *ssl, x509_cert *own_cert,
-                       pkcs11_context *pkcs11_key )
+void ssl_set_own_cert_alt( ssl_context *ssl, x509_cert *own_cert,
+                           void *rsa_key,
+                           rsa_decrypt_func rsa_decrypt,
+                           rsa_sign_func rsa_sign,
+                           rsa_key_len_func rsa_key_len )
 {
     ssl->own_cert   = own_cert;
-    ssl->pkcs11_key = pkcs11_key;
+    ssl->rsa_key    = rsa_key;
+    ssl->rsa_decrypt = rsa_decrypt;
+    ssl->rsa_sign = rsa_sign;
+    ssl->rsa_key_len = rsa_key_len;
 }
-#endif
+
 
 #if defined(POLARSSL_DHM_C)
 int ssl_set_dh_param( ssl_context *ssl, const char *dhm_P, const char *dhm_G )