Adapt x509write_csr prototypes for PK
diff --git a/library/x509write.c b/library/x509write.c
index 7c4ca33..e1f68dc 100644
--- a/library/x509write.c
+++ b/library/x509write.c
@@ -156,16 +156,9 @@
     ctx->md_alg = md_alg;
 }
 
-// TODO: take a pk_context
-// TODO: return int
-void x509write_csr_set_rsa_key( x509write_csr *ctx, rsa_context *rsa )
+void x509write_csr_set_key( x509write_csr *ctx, pk_context *key ) /* Set the signing key; only the pointer is stored, no copy is made. */
 {
-    // temporary
-    ctx->key = polarssl_malloc( sizeof( pk_context ) );
-
-    // TODO: check errors
-    pk_init_ctx( ctx->key, pk_info_from_type( POLARSSL_PK_RSA ) );
-    rsa_copy( pk_rsa( *ctx->key ), rsa );
+    ctx->key = key; /* Borrowed: key must stay valid while ctx is used. NOTE(review): the removed code heap-allocated ctx->key; confirm the ctx cleanup code no longer frees it, or it will free the caller's key. */
 }
 
 int x509write_csr_set_subject_name( x509write_csr *ctx, char *subject_name )
@@ -698,7 +691,9 @@
     return( len );
 }
 
-int x509write_csr_der( x509write_csr *ctx, unsigned char *buf, size_t size )
+int x509write_csr_der( x509write_csr *ctx, unsigned char *buf, size_t size,
+                       int (*f_rng)(void *, unsigned char *, size_t),
+                       void *p_rng )
 {
     int ret;
     const char *sig_oid;
@@ -761,7 +756,7 @@
     md( md_info_from_type( ctx->md_alg ), c, len, hash );
 
     if( ( ret = pk_sign( ctx->key, ctx->md_alg, hash, 0, sig, &sig_len,
-                         NULL, NULL ) ) != 0 ||
+                         f_rng, p_rng ) ) != 0 ||
         ( ret = oid_get_oid_by_sig_alg( pk_get_type( ctx->key ), ctx->md_alg,
                                         &sig_oid, &sig_oid_len ) ) != 0 )
     {
@@ -1006,13 +1001,15 @@
     return( 0 );
 }
 
-int x509write_csr_pem( x509write_csr *ctx, unsigned char *buf, size_t size )
+int x509write_csr_pem( x509write_csr *ctx, unsigned char *buf, size_t size,
+                       int (*f_rng)(void *, unsigned char *, size_t),
+                       void *p_rng )
 {
     int ret;
     unsigned char output_buf[4096];
 
-    if( ( ret = x509write_csr_der( ctx, output_buf,
-                                      sizeof(output_buf) ) ) < 0 )
+    if( ( ret = x509write_csr_der( ctx, output_buf, sizeof(output_buf),
+                                   f_rng, p_rng ) ) < 0 )
     {
         return( ret );
     }