psa_generate_key_ext: RSA: support custom public exponent

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 6263be9..98823de 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7501,11 +7501,16 @@
 
 psa_status_t psa_generate_key_internal(
     const psa_key_attributes_t *attributes,
+    const psa_key_generation_method_t *method, size_t method_length,
     uint8_t *key_buffer, size_t key_buffer_size, size_t *key_buffer_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_type_t type = attributes->core.type;
 
+    /* Only used for RSA */
+    (void) method;
+    (void) method_length;
+
     if ((attributes->domain_parameters == NULL) &&
         (attributes->domain_parameters_size != 0)) {
         return PSA_ERROR_INVALID_ARGUMENT;
@@ -7526,7 +7531,17 @@
 
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
     if (type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
-        return mbedtls_psa_rsa_generate_key(attributes,
+        /* Hack: if the method specifies a non-default e, pass it
+         * via the domain parameters. TODO: refactor this code so
+         * that mbedtls_psa_rsa_generate_key() gets e via a new
+         * parameter instead. */
+        psa_key_attributes_t override_attributes = *attributes;
+        if (method_length > sizeof(*method)) {
+            override_attributes.domain_parameters_size =
+                method_length - offsetof(psa_key_generation_method_t, data);
+            override_attributes.domain_parameters = (uint8_t *) &method->data;
+        }
+        return mbedtls_psa_rsa_generate_key(&override_attributes,
                                             key_buffer,
                                             key_buffer_size,
                                             key_buffer_length);
@@ -7584,6 +7599,14 @@
     if (method_length < sizeof(*method)) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
+
+#if defined(PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_GENERATE)
+    if (attributes->core.type == PSA_KEY_TYPE_RSA_KEY_PAIR) {
+        if (method->flags != 0) {
+            return PSA_ERROR_INVALID_ARGUMENT;
+        }
+    } else
+#endif
     if (!psa_key_generation_method_is_default(method, method_length)) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
@@ -7625,8 +7648,9 @@
     }
 
     status = psa_driver_wrapper_generate_key(attributes,
-                                             slot->key.data, slot->key.bytes, &slot->key.bytes);
-
+                                             method, method_length,
+                                             slot->key.data, slot->key.bytes,
+                                             &slot->key.bytes);
     if (status != PSA_SUCCESS) {
         psa_remove_key_data_from_memory(slot);
     }