Relax checks on RSA mode for public key operations
diff --git a/include/polarssl/rsa.h b/include/polarssl/rsa.h
index 1c697fb..c57ff97 100644
--- a/include/polarssl/rsa.h
+++ b/include/polarssl/rsa.h
@@ -126,6 +126,17 @@
  *
  * \note           The hash_id parameter is actually ignored
  *                 when using RSA_PKCS_V15 padding.
+ *
+ * \note           Choice of padding mode is strictly enforced for private key
+ *                 operations, since there might be security concerns in
+ *                 mixing padding modes. For public key operations it's merely
+ *                 a default value, which can be overriden by calling specific
+ *                 rsa_rsaes_xxx or rsa_rsassa_xxx functions.
+ *
+ * \note           The chosen hash is always used for OEAP encryption.
+ *                 For PSS signatures, it's always used for making signatures,
+ *                 but can be overriden (and always is, if set to
+ *                 POLARSSL_MD_NONE) for verifying them.
  */
 void rsa_init( rsa_context *ctx,
                int padding,
@@ -133,16 +144,11 @@
 
 /**
  * \brief          Set padding for an already initialized RSA context
- *
- *                 Note: Set padding to RSA_PKCS_V21 for the RSAES-OAEP
- *                 encryption scheme and the RSASSA-PSS signature scheme.
+ *                 See \c rsa_init() for details.
  *
  * \param ctx      RSA context to be set
  * \param padding  RSA_PKCS_V15 or RSA_PKCS_V21
  * \param hash_id  RSA_PKCS_V21 hash identifier
- *
- * \note           The hash_id parameter is actually ignored
- *                 when using RSA_PKCS_V15 padding.
  */
 void rsa_set_padding( rsa_context *ctx, int padding, int hash_id);
 
@@ -405,11 +411,8 @@
  * \note           The "sig" buffer must be as large as the size
  *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
  *
- * \note           In case of PKCS#1 v2.1 encoding keep in mind that
- *                 the hash_id in the RSA context is the one used for the
- *                 encoding. hash_id in the function call is the type of hash
- *                 that is encoded. According to RFC 3447 it is advised to
- *                 keep both hashes the same.
+ * \note           In case of PKCS#1 v2.1 encoding, see comments on
+ * \note           \c rsa_rsassa_pss_sign() for details on md_alg and hash_id.
  */
 int rsa_pkcs1_sign( rsa_context *ctx,
                     int (*f_rng)(void *, unsigned char *, size_t),
@@ -466,9 +469,8 @@
  * \note           The "sig" buffer must be as large as the size
  *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
  *
- * \note           In case of PKCS#1 v2.1 encoding keep in mind that
- *                 the hash_id in the RSA context is the one used for the
- *                 encoding. hash_id in the function call is the type of hash
+ * \note           The hash_id in the RSA context is the one used for the
+ *                 encoding. md_alg in the function call is the type of hash
  *                 that is encoded. According to RFC 3447 it is advised to
  *                 keep both hashes the same.
  */
@@ -501,11 +503,8 @@
  * \note           The "sig" buffer must be as large as the size
  *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
  *
- * \note           In case of PKCS#1 v2.1 encoding keep in mind that
- *                 the hash_id in the RSA context is the one used for the
- *                 verification. hash_id in the function call is the type of
- *                 hash that is verified. According to RFC 3447 it is advised to
- *                 keep both hashes the same.
+ * \note           In case of PKCS#1 v2.1 encoding, see comments on
+ *                 \c rsa_rsassa_pss_verify() about md_alg and hash_id.
  */
 int rsa_pkcs1_verify( rsa_context *ctx,
                       int (*f_rng)(void *, unsigned char *, size_t),
@@ -561,11 +560,11 @@
  * \note           The "sig" buffer must be as large as the size
  *                 of ctx->N (eg. 128 bytes if RSA-1024 is used).
  *
- * \note           In case of PKCS#1 v2.1 encoding keep in mind that
- *                 the hash_id in the RSA context is the one used for the
- *                 verification. hash_id in the function call is the type of
+ * \note           The hash_id in the RSA context is the one used for the
+ *                 verification. md_alg in the function call is the type of
  *                 hash that is verified. According to RFC 3447 it is advised to
- *                 keep both hashes the same.
+ *                 keep both hashes the same. If hash_id in the RSA context is
+ *                 unset, the md_alg from the function call is used.
  */
 int rsa_rsassa_pss_verify( rsa_context *ctx,
                            int (*f_rng)(void *, unsigned char *, size_t),
diff --git a/library/rsa.c b/library/rsa.c
index e3cac12..1e84d9f 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -505,7 +505,10 @@
     const md_info_t *md_info;
     md_context_t md_ctx;
 
-    if( ctx->padding != RSA_PKCS_V21 || f_rng == NULL )
+    if( mode == RSA_PRIVATE && ctx->padding != RSA_PKCS_V21 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    if( f_rng == NULL )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     md_info = md_info_from_type( ctx->hash_id );
@@ -515,7 +518,7 @@
     olen = ctx->len;
     hlen = md_get_size( md_info );
 
-    if( olen < ilen + 2 * hlen + 2 || f_rng == NULL )
+    if( olen < ilen + 2 * hlen + 2 )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     memset( output, 0, olen );
@@ -572,7 +575,10 @@
     int ret;
     unsigned char *p = output;
 
-    if( ctx->padding != RSA_PKCS_V15 || f_rng == NULL )
+    if( mode == RSA_PRIVATE && ctx->padding != RSA_PKCS_V15 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    if( f_rng == NULL )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     olen = ctx->len;
@@ -675,7 +681,7 @@
     /*
      * Parameters sanity checks
      */
-    if( ctx->padding != RSA_PKCS_V21 )
+    if( mode == RSA_PRIVATE && ctx->padding != RSA_PKCS_V21 )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     ilen = ctx->len;
@@ -780,7 +786,7 @@
     unsigned char *p, bad, pad_done = 0;
     unsigned char buf[POLARSSL_MPI_MAX_SIZE];
 
-    if( ctx->padding != RSA_PKCS_V15 )
+    if( mode == RSA_PRIVATE && ctx->padding != RSA_PKCS_V15 )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     ilen = ctx->len;
@@ -901,7 +907,10 @@
     const md_info_t *md_info;
     md_context_t md_ctx;
 
-    if( ctx->padding != RSA_PKCS_V21 || f_rng == NULL )
+    if( mode == RSA_PRIVATE && ctx->padding != RSA_PKCS_V21 )
+        return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
+
+    if( f_rng == NULL )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     olen = ctx->len;
@@ -995,7 +1004,7 @@
     unsigned char *p = sig;
     const char *oid;
 
-    if( ctx->padding != RSA_PKCS_V15 )
+    if( mode == RSA_PRIVATE && ctx->padding != RSA_PKCS_V15 )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     olen = ctx->len;
@@ -1117,7 +1126,7 @@
     const md_info_t *md_info;
     md_context_t md_ctx;
 
-    if( ctx->padding != RSA_PKCS_V21 )
+    if( mode == RSA_PRIVATE && ctx->padding != RSA_PKCS_V21 )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     siglen = ctx->len;
@@ -1148,7 +1157,8 @@
         hashlen = md_get_size( md_info );
     }
 
-    md_info = md_info_from_type( ctx->hash_id );
+    md_info = md_info_from_type( ctx->hash_id != POLARSSL_MD_NONE ?
+                                 ctx->hash_id : md_alg );
     if( md_info == NULL )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
@@ -1227,7 +1237,7 @@
     const md_info_t *md_info;
     asn1_buf oid;
 
-    if( ctx->padding != RSA_PKCS_V15 )
+    if( mode == RSA_PRIVATE && ctx->padding != RSA_PKCS_V15 )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     siglen = ctx->len;