Change mbedtls_rsa_set_padding() signature

mbedtls_rsa_set_padding() now returns the error
code MBEDTLS_ERR_RSA_INVALID_PADDING when
padding parameters are invalid.

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/include/mbedtls/rsa.h b/include/mbedtls/rsa.h
index ba00bff..eeb846e 100644
--- a/include/mbedtls/rsa.h
+++ b/include/mbedtls/rsa.h
@@ -399,9 +399,15 @@
  * \param padding  The padding mode to use. This must be either
  *                 #MBEDTLS_RSA_PKCS_V15 or #MBEDTLS_RSA_PKCS_V21.
  * \param hash_id  The #MBEDTLS_RSA_PKCS_V21 hash identifier.
+ *                 #MBEDTLS_MD_NONE is accepted by this function but may be
+ *                 not suitable for some operations.
+ *
+ * \return         \c 0 on success.
+ * \return         #MBEDTLS_ERR_RSA_INVALID_PADDING failure:
+ *                 \p padding or \p hash_id is invalid.
  */
-void mbedtls_rsa_set_padding( mbedtls_rsa_context *ctx, int padding,
-                              int hash_id );
+int mbedtls_rsa_set_padding( mbedtls_rsa_context *ctx, int padding,
+                             mbedtls_md_type_t hash_id );
 
 /**
  * \brief          This function retrieves the length of RSA modulus in Bytes.
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 214c405..7921eb2 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -2838,13 +2838,14 @@
 }
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
-static void psa_rsa_oaep_set_padding_mode( psa_algorithm_t alg,
-                                           mbedtls_rsa_context *rsa )
+static int psa_rsa_oaep_set_padding_mode( psa_algorithm_t alg,
+                                          mbedtls_rsa_context *rsa )
 {
     psa_algorithm_t hash_alg = PSA_ALG_RSA_OAEP_GET_HASH( alg );
     const mbedtls_md_info_t *md_info = mbedtls_md_info_from_psa( hash_alg );
     mbedtls_md_type_t md_alg = mbedtls_md_get_type( md_info );
-    mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
+
+    return( mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg ) );
 }
 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) */
 
@@ -2917,7 +2918,11 @@
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
         if( PSA_ALG_IS_RSA_OAEP( alg ) )
         {
-            psa_rsa_oaep_set_padding_mode( alg, rsa );
+            status = mbedtls_to_psa_error(
+                         psa_rsa_oaep_set_padding_mode( alg, rsa ) );
+            if( status != PSA_SUCCESS )
+                goto rsa_exit;
+
             status = mbedtls_to_psa_error(
                 mbedtls_rsa_rsaes_oaep_encrypt( rsa,
                                                 mbedtls_psa_get_random,
@@ -3023,7 +3028,11 @@
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP)
         if( PSA_ALG_IS_RSA_OAEP( alg ) )
         {
-            psa_rsa_oaep_set_padding_mode( alg, rsa );
+            status = mbedtls_to_psa_error(
+                         psa_rsa_oaep_set_padding_mode( alg, rsa ) );
+            if( status != PSA_SUCCESS )
+                goto rsa_exit;
+
             status = mbedtls_to_psa_error(
                 mbedtls_rsa_rsaes_oaep_decrypt( rsa,
                                                 mbedtls_psa_get_random,
diff --git a/library/psa_crypto_rsa.c b/library/psa_crypto_rsa.c
index b5aec20..33e22e7 100644
--- a/library/psa_crypto_rsa.c
+++ b/library/psa_crypto_rsa.c
@@ -416,29 +416,36 @@
 #if defined(BUILTIN_ALG_RSA_PKCS1V15_SIGN)
     if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) )
     {
-        mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15,
-                                 MBEDTLS_MD_NONE );
-        ret = mbedtls_rsa_pkcs1_sign( rsa,
-                                      mbedtls_psa_get_random,
-                                      MBEDTLS_PSA_RANDOM_STATE,
-                                      md_alg,
-                                      (unsigned int) hash_length,
-                                      hash,
-                                      signature );
+        ret = mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15,
+                                       MBEDTLS_MD_NONE );
+        if( ret == 0 )
+        {
+            ret = mbedtls_rsa_pkcs1_sign( rsa,
+                                          mbedtls_psa_get_random,
+                                          MBEDTLS_PSA_RANDOM_STATE,
+                                          md_alg,
+                                          (unsigned int) hash_length,
+                                          hash,
+                                          signature );
+        }
     }
     else
 #endif /* BUILTIN_ALG_RSA_PKCS1V15_SIGN */
 #if defined(BUILTIN_ALG_RSA_PSS)
     if( PSA_ALG_IS_RSA_PSS( alg ) )
     {
-        mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
-        ret = mbedtls_rsa_rsassa_pss_sign( rsa,
-                                           mbedtls_psa_get_random,
-                                           MBEDTLS_PSA_RANDOM_STATE,
-                                           MBEDTLS_MD_NONE,
-                                           (unsigned int) hash_length,
-                                           hash,
-                                           signature );
+        ret = mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
+
+        if( ret == 0 )
+        {
+            ret = mbedtls_rsa_rsassa_pss_sign( rsa,
+                                               mbedtls_psa_get_random,
+                                               MBEDTLS_PSA_RANDOM_STATE,
+                                               MBEDTLS_MD_NONE,
+                                               (unsigned int) hash_length,
+                                               hash,
+                                               signature );
+        }
     }
     else
 #endif /* BUILTIN_ALG_RSA_PSS */
@@ -489,25 +496,31 @@
 #if defined(BUILTIN_ALG_RSA_PKCS1V15_SIGN)
     if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) )
     {
-        mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15,
-                                 MBEDTLS_MD_NONE );
-        ret = mbedtls_rsa_pkcs1_verify( rsa,
-                                        md_alg,
-                                        (unsigned int) hash_length,
-                                        hash,
-                                        signature );
+        ret = mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15,
+                                       MBEDTLS_MD_NONE );
+        if( ret == 0 )
+        {
+            ret = mbedtls_rsa_pkcs1_verify( rsa,
+                                            md_alg,
+                                            (unsigned int) hash_length,
+                                            hash,
+                                            signature );
+        }
     }
     else
 #endif /* BUILTIN_ALG_RSA_PKCS1V15_SIGN */
 #if defined(BUILTIN_ALG_RSA_PSS)
     if( PSA_ALG_IS_RSA_PSS( alg ) )
     {
-        mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
-        ret = mbedtls_rsa_rsassa_pss_verify( rsa,
-                                             MBEDTLS_MD_NONE,
-                                             (unsigned int) hash_length,
-                                             hash,
-                                             signature );
+        ret = mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
+        if( ret == 0 )
+        {
+            ret = mbedtls_rsa_rsassa_pss_verify( rsa,
+                                                 MBEDTLS_MD_NONE,
+                                                 (unsigned int) hash_length,
+                                                 hash,
+                                                 signature );
+        }
     }
     else
 #endif /* BUILTIN_ALG_RSA_PSS */
diff --git a/library/rsa.c b/library/rsa.c
index 36424bd..5a1ae79 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -500,15 +500,27 @@
 /*
  * Set padding for an existing RSA context
  */
-void mbedtls_rsa_set_padding( mbedtls_rsa_context *ctx, int padding,
-                              int hash_id )
+int mbedtls_rsa_set_padding( mbedtls_rsa_context *ctx, int padding,
+                             mbedtls_md_type_t hash_id )
 {
-    RSA_VALIDATE( ctx != NULL );
-    RSA_VALIDATE( padding == MBEDTLS_RSA_PKCS_V15 ||
-                  padding == MBEDTLS_RSA_PKCS_V21 );
+    if( ( padding != MBEDTLS_RSA_PKCS_V15 ) &&
+        ( padding != MBEDTLS_RSA_PKCS_V21 ) )
+        return( MBEDTLS_ERR_RSA_INVALID_PADDING );
+
+    if( ( padding == MBEDTLS_RSA_PKCS_V21 ) &&
+        ( hash_id != MBEDTLS_MD_NONE ) )
+    {
+        const mbedtls_md_info_t *md_info;
+
+        md_info = mbedtls_md_info_from_type( hash_id );
+        if( md_info == NULL )
+            return( MBEDTLS_ERR_RSA_INVALID_PADDING );
+    }
 
     ctx->padding = padding;
     ctx->hash_id = hash_id;
+
+    return( 0 );
 }
 
 /*
diff --git a/programs/pkey/rsa_sign_pss.c b/programs/pkey/rsa_sign_pss.c
index 9d5053a..e7fcf51 100644
--- a/programs/pkey/rsa_sign_pss.c
+++ b/programs/pkey/rsa_sign_pss.c
@@ -115,7 +115,13 @@
         goto exit;
     }
 
-    mbedtls_rsa_set_padding( mbedtls_pk_rsa( pk ), MBEDTLS_RSA_PKCS_V21, MBEDTLS_MD_SHA256 );
+    if( ( ret = mbedtls_rsa_set_padding( mbedtls_pk_rsa( pk ),
+                                         MBEDTLS_RSA_PKCS_V21,
+                                         MBEDTLS_MD_SHA256 ) ) != 0 )
+    {
+        mbedtls_printf( " failed\n  ! Invalid padding\n" );
+        goto exit;
+    }
 
     /*
      * Compute the SHA-256 hash of the input file,
diff --git a/programs/pkey/rsa_verify_pss.c b/programs/pkey/rsa_verify_pss.c
index 81b0fd6..527d799 100644
--- a/programs/pkey/rsa_verify_pss.c
+++ b/programs/pkey/rsa_verify_pss.c
@@ -98,7 +98,13 @@
         goto exit;
     }
 
-    mbedtls_rsa_set_padding( mbedtls_pk_rsa( pk ), MBEDTLS_RSA_PKCS_V21, MBEDTLS_MD_SHA256 );
+    if( ( ret = mbedtls_rsa_set_padding( mbedtls_pk_rsa( pk ),
+                                         MBEDTLS_RSA_PKCS_V21,
+                                         MBEDTLS_MD_SHA256 ) ) != 0 )
+    {
+        mbedtls_printf( " failed\n  ! Invalid padding\n" );
+        goto exit;
+    }
 
     /*
      * Extract the RSA signature from the file
diff --git a/tests/suites/test_suite_rsa.data b/tests/suites/test_suite_rsa.data
index 2512ef2..cc5a047 100644
--- a/tests/suites/test_suite_rsa.data
+++ b/tests/suites/test_suite_rsa.data
@@ -1,3 +1,6 @@
+RSA parameter validation
+rsa_invalid_param:
+
 RSA init-free-free
 rsa_init_free:0
 
diff --git a/tests/suites/test_suite_rsa.function b/tests/suites/test_suite_rsa.function
index 9cf2fcf..e057dfb 100644
--- a/tests/suites/test_suite_rsa.function
+++ b/tests/suites/test_suite_rsa.function
@@ -18,6 +18,30 @@
  */
 
 /* BEGIN_CASE */
+void rsa_invalid_param( )
+{
+    mbedtls_rsa_context ctx;
+    const int invalid_padding = 42;
+    const int invalid_hash_id = 0xff;
+
+    mbedtls_rsa_init( &ctx, MBEDTLS_RSA_PKCS_V15, MBEDTLS_MD_NONE );
+
+    TEST_EQUAL( mbedtls_rsa_set_padding( &ctx,
+                                         invalid_padding,
+                                         MBEDTLS_MD_NONE ),
+                MBEDTLS_ERR_RSA_INVALID_PADDING );
+
+    TEST_EQUAL( mbedtls_rsa_set_padding( &ctx,
+                                         MBEDTLS_RSA_PKCS_V21,
+                                         invalid_hash_id ),
+                MBEDTLS_ERR_RSA_INVALID_PADDING );
+
+exit:
+    mbedtls_rsa_free( &ctx );
+}
+/* END_CASE */
+
+/* BEGIN_CASE */
 void rsa_init_free( int reinit )
 {
     mbedtls_rsa_context ctx;