rsa: Add support for RSA-PSS

The PKCS#1 standards, which define RSA signatures, are currently at
version 2.2.  Starting in v2.1, the standard defines a new signature
method RSA-PSS, which has a stronger security proof than the signature
method used in earlier versions.  The standard recommends that RSA-PSS
be used in new designs, instead of the older algorithm.

This patch implements RSA-PSS verification for a specific set of
parameters:

    - RSA-2048
    - SHA256 for both the message digest and the internal hash
    - 32-byte salt
    - 2047-bit encoded message (emBits = modulus bits - 1)

Although RSA-PSS supports other parameters, due to size constraints,
this verification code only supports these specific parameters, and
signatures with other parameters will be considered invalid.

To encourage the use of the more secure algorithm, the default build
configuration is RSA-PSS.  MCUBOOT_RSA_PKCS1_15 must be defined in
order to use the older signature algorithm instead.
diff --git a/boot/bootutil/include/bootutil/image.h b/boot/bootutil/include/bootutil/image.h
index 70ce7fb..b2b636a 100644
--- a/boot/bootutil/include/bootutil/image.h
+++ b/boot/bootutil/include/bootutil/image.h
@@ -42,6 +42,7 @@
 #define IMAGE_F_ECDSA224_SHA256       0x00000008 /* ECDSA224 over SHA256 */
 #define IMAGE_F_NON_BOOTABLE          0x00000010 /* Split image app. */
 #define IMAGE_F_ECDSA256_SHA256       0x00000020 /* ECDSA256 over SHA256 */
+#define IMAGE_F_PKCS1_PSS_RSA2048_SHA256 0x00000040 /* RSA2048 PKCS1 PSS over SHA256 */
 
 /*
  * ECSDA224 is with NIST P-224
diff --git a/boot/bootutil/src/image_rsa.c b/boot/bootutil/src/image_rsa.c
index 408e7e6..fd2d913 100644
--- a/boot/bootutil/src/image_rsa.c
+++ b/boot/bootutil/src/image_rsa.c
@@ -25,17 +25,52 @@
 
 #ifdef MCUBOOT_SIGN_RSA
 #include "bootutil/sign_key.h"
+#include "bootutil/sha256.h"
 
 #include "mbedtls/rsa.h"
 #include "mbedtls/asn1.h"
 
 #include "bootutil_priv.h"
 
+#ifdef MCUBOOT_RSA_PKCS1_15 /* DigestInfo prefix, used only by PKCS#1 v1.5 */
 static const uint8_t sha256_oid[] = {
     0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86,
     0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05,
     0x00, 0x04, 0x20
 };
+#endif /* MCUBOOT_RSA_PKCS1_15 */
+
+#ifndef MCUBOOT_RSA_PKCS1_15
+/*
+ * Constants for this particular constrained implementation of
+ * RSA-PSS: RSA-2048, SHA-256 for both the message hash and MGF1, and
+ * a 32-byte salt.  A signature with different parameters will be
+ * rejected as invalid.
+ */
+
+/* The size, in octets, of the encoded message EM (emLen). */
+#define PSS_EMLEN 256
+
+/* The size of the hash function (hLen).  For SHA256, 32 bytes. */
+#define PSS_HLEN 32
+
+/* Size of the salt (sLen); only this length is accepted. */
+#define PSS_SLEN 32
+
+/* The length of maskedDB/DB and the MGF1 mask: emLen - hLen - 1. */
+#define PSS_MASK_LEN (PSS_EMLEN - PSS_HLEN - 1)
+
+#define PSS_HASH_OFFSET PSS_MASK_LEN /* H follows maskedDB in EM. */
+
+/* Leading zero bytes of DB (emLen - hLen - sLen - 2), then 0x01. */
+#define PSS_MASK_ZERO_COUNT (PSS_MASK_LEN - PSS_SLEN - 1)
+#define PSS_MASK_ONE_POS   PSS_MASK_ZERO_COUNT
+
+/* Where the salt starts: the last sLen bytes of DB. */
+#define PSS_MASK_SALT_POS   (PSS_MASK_ONE_POS + 1)
+
+static const uint8_t pss_zeros[8] = {0}; /* M' 8-byte zero prefix */
+#endif /* !MCUBOOT_RSA_PKCS1_15 */
 
 /*
  * Parse the public key used for signing. Simple RSA format.
@@ -73,6 +108,162 @@
     return 0;
 }
 
+#ifndef MCUBOOT_RSA_PKCS1_15
+/*
+ * MGF1 mask generation (PKCS #1 v2.2, appendix B.2.1), specialized
+ * for SHA-256: fills 'mask' with PSS_MASK_LEN bytes derived from the
+ * PSS_HLEN-byte seed 'hash'.  Each output chunk is
+ * SHA256(hash || C), where C is a 4-byte big-endian counter.  Since
+ * PSS_MASK_LEN < 256 * PSS_HLEN, only the low counter byte changes.
+ * The final chunk is truncated to fit.
+ */
+static void
+pss_mgf1(uint8_t *mask, const uint8_t *hash)
+{
+    bootutil_sha256_context sha;
+    uint8_t digest[PSS_HLEN];
+    uint8_t ctr[4] = { 0 };
+    int done;
+    int chunk;
+
+    for (done = 0; done < PSS_MASK_LEN; done += chunk) {
+        /* digest = SHA256(hash || ctr) */
+        bootutil_sha256_init(&sha);
+        bootutil_sha256_update(&sha, hash, PSS_HLEN);
+        bootutil_sha256_update(&sha, ctr, sizeof(ctr));
+        bootutil_sha256_finish(&sha, digest);
+        ctr[3]++;
+
+        /* Copy a full digest, or only what is left of the mask. */
+        chunk = PSS_MASK_LEN - done;
+        if (chunk > PSS_HLEN) {
+            chunk = PSS_HLEN;
+        }
+
+        memcpy(&mask[done], digest, chunk);
+    }
+}
+
+/*
+ * Verify 'sig' over the PSS_HLEN-byte digest 'hash' using RSA-PSS
+ * (PKCS #1 v2.2, section 9.1.2), with the parameters fixed by the
+ * PSS_* constants.  Returns 0 if consistent, -1 otherwise.
+ */
+static int
+bootutil_cmp_rsasig(mbedtls_rsa_context *ctx, uint8_t *hash, uint32_t hlen,
+  uint8_t *sig)
+{
+    bootutil_sha256_context shactx;
+    uint8_t em[MBEDTLS_MPI_MAX_SIZE];
+    uint8_t db_mask[PSS_MASK_LEN];
+    uint8_t h2[PSS_HLEN];
+    int i;
+
+    /* 'em' receives ctx->len bytes from mbedtls_rsa_public(), so it
+     * only has to be at least PSS_EMLEN bytes, not exactly that. */
+    if (ctx->len != PSS_EMLEN || PSS_EMLEN > MBEDTLS_MPI_MAX_SIZE) {
+        return -1;
+    }
+
+    if (hlen != PSS_HLEN) {
+        return -1;
+    }
+
+    /* EM = sig^e mod n. */
+    if (mbedtls_rsa_public(ctx, sig, em)) {
+        return -1;
+    }
+
+    /*
+     * PKCS #1 v2.2, 9.1.2 EMSA-PSS-Verify
+     *
+     * emBits = modBits - 1 = 2047
+     * emLen = ceil(emBits/8) = 256
+     *
+     * The salt length is fixed at sLen = PSS_SLEN.
+     */
+
+    /* Step 1.  The message is constrained by the address space of a
+     * 32-bit processor, which is far less than the 2^61-1 limit of
+     * SHA-256.
+     */
+
+    /* Step 2.  mHash is passed in as 'hash', with hLen the hlen
+     * argument. */
+
+    /* Step 3.  If emLen < hLen + sLen + 2, inconsistent and stop.
+     * With the fixed lengths, 256 >= 32 + 32 + 2 always holds. */
+
+    /* Step 4.  If the rightmost octet of EM does not have the value
+     * 0xbc, output inconsistent and stop.
+     */
+    if (em[PSS_EMLEN - 1] != 0xbc) {
+        return -1;
+    }
+
+    /* Step 5.  Let maskedDB be the leftmost emLen - hLen - 1 octets
+     * of EM, and H be the next hLen octets.
+     *
+     * maskedDB is then the first 256 - 32 - 1 = 223 bytes, 0-222
+     * H is 32 bytes 223-254
+     */
+
+    /* Step 6.  If the leftmost 8emLen - emBits bits (here, just the
+     * top bit) of the leftmost octet in maskedDB are not all equal
+     * to zero, output inconsistent and stop. */
+    if (em[0] & 0x80) {
+        return -1;
+    }
+
+    /* Step 7.  let dbMask = MGF(H, emLen - hLen - 1). */
+    pss_mgf1(db_mask, &em[PSS_HASH_OFFSET]);
+
+    /* Step 8.  let DB = maskedDB xor dbMask.
+     * To avoid needing an additional buffer, store the 'db' in the
+     * same buffer as db_mask.  From now, to the end of this function,
+     * db_mask refers to the unmasked 'db'. */
+    for (i = 0; i < PSS_MASK_LEN; i++) {
+        db_mask[i] ^= em[i];
+    }
+
+    /* Step 9.  Set the leftmost 8emLen - emBits bits of the leftmost
+     * octet in DB to zero.  emBits is 2047, so clear the top bit. */
+    db_mask[0] &= 0x7F;
+
+    /* Step 10.  If the emLen - hLen - sLen - 2 leftmost octets of DB
+     * are not zero or if the octet at position emLen - hLen - sLen -
+     * 1 (the leftmost position is "position 1") does not have
+     * hexadecimal value 0x01, output "inconsistent" and stop. */
+    for (i = 0; i < PSS_MASK_ZERO_COUNT; i++) {
+        if (db_mask[i] != 0) {
+            return -1;
+        }
+    }
+
+    if (db_mask[PSS_MASK_ONE_POS] != 1) {
+        return -1;
+    }
+
+    /* Step 11. Let salt be the last sLen octets of DB */
+
+    /* Step 12.  Let M' = 0x00 00 00 00 00 00 00 00 || mHash || salt; */
+
+    /* Step 13.  Let H' = Hash(M') */
+    bootutil_sha256_init(&shactx);
+    bootutil_sha256_update(&shactx, pss_zeros, 8);
+    bootutil_sha256_update(&shactx, hash, PSS_HLEN);
+    bootutil_sha256_update(&shactx, &db_mask[PSS_MASK_SALT_POS], PSS_SLEN);
+    bootutil_sha256_finish(&shactx, h2);
+
+    /* Step 14.  If H = H', output "consistent".  Otherwise, output
+     * "inconsistent". */
+    if (memcmp(h2, &em[PSS_HASH_OFFSET], PSS_HLEN) != 0) {
+        return -1;
+    }
+
+    return 0;
+}
+#else /* MCUBOOT_RSA_PKCS1_15 */
 /*
  * PKCS1.5 using RSA2048 computed over SHA256.
  */
@@ -120,6 +311,7 @@
 
     return 0;
 }
+#endif
 
 int
 bootutil_verify_sig(uint8_t *hash, uint32_t hlen, uint8_t *sig, int slen,
diff --git a/boot/bootutil/src/image_validate.c b/boot/bootutil/src/image_validate.c
index 00d43c6..2741c37 100644
--- a/boot/bootutil/src/image_validate.c
+++ b/boot/bootutil/src/image_validate.c
@@ -110,9 +110,15 @@
     int rc;
 
 #ifdef MCUBOOT_SIGN_RSA
+#ifdef MCUBOOT_RSA_PKCS1_15
     if ((hdr->ih_flags & IMAGE_F_PKCS15_RSA2048_SHA256) == 0) {
         return -1;
     }
+#else
+    if ((hdr->ih_flags & IMAGE_F_PKCS1_PSS_RSA2048_SHA256) == 0) {
+        return -1;
+    }
+#endif
 #endif
 #ifdef MCUBOOT_SIGN_EC
     if ((hdr->ih_flags & IMAGE_F_ECDSA224_SHA256) == 0) {