Further PAKE code optimizations

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 115e994..917a9fa 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7609,6 +7609,7 @@
     size_t *output_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_crypto_driver_pake_step_t driver_step = PSA_JPAKE_STEP_INVALID;
     *output_length = 0;
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
@@ -7635,6 +7636,8 @@
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
+            driver_step = convert_jpake_computation_stage_to_driver_step(
+                &operation->computation_stage.jpake);
             break;
 #endif /* PSA_WANT_ALG_JPAKE */
         default:
@@ -7643,17 +7646,8 @@
             goto exit;
     }
 
-#if defined(PSA_WANT_ALG_JPAKE)
-    status = psa_driver_wrapper_pake_output(operation,
-                                            convert_jpake_computation_stage_to_driver_step(
-                                                &operation->computation_stage.jpake),
-                                            output,
-                                            output_size,
-                                            output_length);
-#else
-    (void) output;
-    status = PSA_ERROR_NOT_SUPPORTED;
-#endif /* PSA_WANT_ALG_JPAKE */
+    status = psa_driver_wrapper_pake_output(operation, driver_step,
+                                            output, output_size, output_length);
 
     if (status != PSA_SUCCESS) {
         goto exit;
@@ -7682,8 +7676,7 @@
 #if defined(PSA_WANT_ALG_JPAKE)
 static psa_status_t psa_jpake_input_prologue(
     psa_pake_operation_t *operation,
-    psa_pake_step_t step,
-    size_t input_length)
+    psa_pake_step_t step)
 {
     if (step != PSA_PAKE_STEP_KEY_SHARE &&
         step != PSA_PAKE_STEP_ZK_PUBLIC &&
@@ -7698,12 +7691,6 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    const psa_pake_primitive_t prim = PSA_PAKE_PRIMITIVE(
-        PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256);
-    if (input_length > (size_t) PSA_PAKE_INPUT_SIZE(PSA_ALG_JPAKE, prim, step)) {
-        return PSA_ERROR_INVALID_ARGUMENT;
-    }
-
     if (computation_stage->state != PSA_PAKE_STATE_READY &&
         computation_stage->state != PSA_PAKE_INPUT_X1_X2 &&
         computation_stage->state != PSA_PAKE_INPUT_X4S) {
@@ -7787,6 +7774,7 @@
     size_t input_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_crypto_driver_pake_step_t driver_step = PSA_JPAKE_STEP_INVALID;
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         status = psa_pake_complete_inputs(operation);
@@ -7800,7 +7788,7 @@
         goto exit;
     }
 
-    if (input_length == 0) {
+    if (input_length == 0 || input_length > PSA_PAKE_INPUT_MAX_SIZE) {
         status = PSA_ERROR_INVALID_ARGUMENT;
         goto exit;
     }
@@ -7808,10 +7796,12 @@
     switch (operation->alg) {
 #if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
-            status = psa_jpake_input_prologue(operation, step, input_length);
+            status = psa_jpake_input_prologue(operation, step);
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
+            driver_step = convert_jpake_computation_stage_to_driver_step(
+                &operation->computation_stage.jpake);
             break;
 #endif /* PSA_WANT_ALG_JPAKE */
         default:
@@ -7820,16 +7810,8 @@
             goto exit;
     }
 
-#if defined(PSA_WANT_ALG_JPAKE)
-    status = psa_driver_wrapper_pake_input(operation,
-                                           convert_jpake_computation_stage_to_driver_step(
-                                               &operation->computation_stage.jpake),
-                                           input,
-                                           input_length);
-#else
-    (void) input;
-    status = PSA_ERROR_NOT_SUPPORTED;
-#endif /* PSA_WANT_ALG_JPAKE */
+    status = psa_driver_wrapper_pake_input(operation, driver_step,
+                                           input, input_length);
 
     if (status != PSA_SUCCESS) {
         goto exit;
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index 538df87..a537184 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -431,7 +431,8 @@
                 0, 23 /* secp256r1 */
             };
 
-            if (operation->buffer_length + sizeof(ecparameters) > sizeof(operation->buffer)) {
+            if (operation->buffer_length + sizeof(ecparameters) >
+                sizeof(operation->buffer)) {
                 return PSA_ERROR_BUFFER_TOO_SMALL;
             }
 
@@ -441,10 +442,9 @@
         }
 
         /*
-         * The core has checked that input_length is smaller than
-         * PSA_PAKE_INPUT_SIZE(PSA_ALG_JPAKE, primitive, step)
-         * where primitive is the JPAKE algorithm primitive and step
-         * the PSA API level input step. Thus no risk of integer overflow here.
+         * The core checks that input_length does not exceed
+         * PSA_PAKE_INPUT_MAX_SIZE, so there is no risk of integer
+         * overflow in the buffer-length check below.
          */
         if (operation->buffer_length + input_length + 1 > sizeof(operation->buffer)) {
             return PSA_ERROR_BUFFER_TOO_SMALL;
diff --git a/library/psa_crypto_pake.h b/library/psa_crypto_pake.h
index eb30881..001c987 100644
--- a/library/psa_crypto_pake.h
+++ b/library/psa_crypto_pake.h
@@ -96,11 +96,7 @@
  *       entry point as defined in the PSA driver interface specification for
  *       transparent drivers.
  *
- * \note The core has checked that input_length is smaller than
-         PSA_PAKE_INPUT_SIZE(PSA_ALG_JPAKE, primitive, step)
-         where primitive is the JPAKE algorithm primitive and step
-         the PSA API level input step. Thus no risk of integer overflow while
-         checking operation buffer overflow.
+ * \note The core checks that input_length does not exceed PSA_PAKE_INPUT_MAX_SIZE.
  *
  * \param[in,out] operation    Active PAKE operation.
  * \param step                 The driver step for which the input is provided.