Move JPAKE state machine logic from driver to core

- Add `alg` and `computation_stage` to `psa_pake_operation_s`.
  Now when logic is moved to core information about `alg` is required.
  `computation_stage` is a structure that provides a union of computation stages for pake algorithms.
- Move the jpake operation logic from driver to core. This requires changing driver entry points for `psa_pake_output`/`psa_pake_input` functions and adding a `computation_stage` parameter. I'm not sure if this solution is correct. Now the driver can check the current computation stage and perform some action. For jpake drivers `step` parameter is now not used, but I think it needs to stay as it might be needed for other pake algorithms.
- Removed test that seems to be redundant as we can't be sure that operation is aborted after failure.

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/include/psa/crypto_builtin_composites.h b/include/psa/crypto_builtin_composites.h
index 295452c..3221a64 100644
--- a/include/psa/crypto_builtin_composites.h
+++ b/include/psa/crypto_builtin_composites.h
@@ -195,11 +195,8 @@
 
 typedef struct {
     psa_algorithm_t MBEDTLS_PRIVATE(alg);
-    unsigned int MBEDTLS_PRIVATE(state);
-    unsigned int MBEDTLS_PRIVATE(sequence);
+
 #if defined(MBEDTLS_PSA_BUILTIN_PAKE)
-    unsigned int MBEDTLS_PRIVATE(input_step);
-    unsigned int MBEDTLS_PRIVATE(output_step);
     uint8_t *MBEDTLS_PRIVATE(password);
     size_t MBEDTLS_PRIVATE(password_len);
     uint8_t MBEDTLS_PRIVATE(role);
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 4fa273d..1678228 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -1292,6 +1292,12 @@
 /** The type of input values for PAKE operations. */
 typedef struct psa_crypto_driver_pake_inputs_s psa_crypto_driver_pake_inputs_t;
 
+/** The type of compuatation stage for PAKE operations. */
+typedef struct psa_pake_computation_stage_s psa_pake_computation_stage_t;
+
+/** The type of compuatation stage for J-PAKE operations. */
+typedef struct psa_jpake_computation_stage_s psa_jpake_computation_stage_t;
+
 /** Return an initial value for a PAKE operation object.
  */
 static psa_pake_operation_t psa_pake_operation_init(void);
@@ -1832,7 +1838,8 @@
 /** Returns a suitable initializer for a PAKE operation object of type
  * psa_pake_operation_t.
  */
-#define PSA_PAKE_OPERATION_INIT { 0, PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS, { 0 } }
+#define PSA_PAKE_OPERATION_INIT { 0, PSA_ALG_NONE, PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS, \
+                                  { { 0 } }, { 0 } }
 
 struct psa_pake_cipher_suite_s {
     psa_algorithm_t algorithm;
@@ -1904,7 +1911,6 @@
 }
 
 struct psa_crypto_driver_pake_inputs_s {
-    psa_algorithm_t MBEDTLS_PRIVATE(alg);
     uint8_t *MBEDTLS_PRIVATE(password);
     size_t MBEDTLS_PRIVATE(password_len);
     psa_pake_role_t MBEDTLS_PRIVATE(role);
@@ -1912,6 +1918,48 @@
     psa_pake_cipher_suite_t MBEDTLS_PRIVATE(cipher_suite);
 };
 
+enum psa_jpake_step {
+    PSA_PAKE_STEP_INVALID       = 0,
+    PSA_PAKE_STEP_X1_X2         = 1,
+    PSA_PAKE_STEP_X2S           = 2,
+    PSA_PAKE_STEP_DERIVE        = 3,
+};
+
+enum psa_jpake_state {
+    PSA_PAKE_STATE_INVALID      = 0,
+    PSA_PAKE_STATE_SETUP        = 1,
+    PSA_PAKE_STATE_READY        = 2,
+    PSA_PAKE_OUTPUT_X1_X2       = 3,
+    PSA_PAKE_OUTPUT_X2S         = 4,
+    PSA_PAKE_INPUT_X1_X2        = 5,
+    PSA_PAKE_INPUT_X4S          = 6,
+};
+
+enum psa_jpake_sequence {
+    PSA_PAKE_SEQ_INVALID        = 0,
+    PSA_PAKE_X1_STEP_KEY_SHARE  = 1,    /* also X2S & X4S KEY_SHARE */
+    PSA_PAKE_X1_STEP_ZK_PUBLIC  = 2,    /* also X2S & X4S ZK_PUBLIC */
+    PSA_PAKE_X1_STEP_ZK_PROOF   = 3,    /* also X2S & X4S ZK_PROOF */
+    PSA_PAKE_X2_STEP_KEY_SHARE  = 4,
+    PSA_PAKE_X2_STEP_ZK_PUBLIC  = 5,
+    PSA_PAKE_X2_STEP_ZK_PROOF   = 6,
+    PSA_PAKE_SEQ_END            = 7,
+};
+
+struct psa_jpake_computation_stage_s {
+    unsigned int MBEDTLS_PRIVATE(state);
+    unsigned int MBEDTLS_PRIVATE(sequence);
+    unsigned int MBEDTLS_PRIVATE(input_step);
+    unsigned int MBEDTLS_PRIVATE(output_step);
+};
+
+struct psa_pake_computation_stage_s {
+    union {
+        unsigned dummy;
+        psa_jpake_computation_stage_t MBEDTLS_PRIVATE(jpake_computation_stage);
+    } MBEDTLS_PRIVATE(data);
+};
+
 struct psa_pake_operation_s {
     /** Unique ID indicating which driver got assigned to do the
      * operation. Since driver contexts are driver-specific, swapping
@@ -1920,10 +1968,14 @@
      * ID value zero means the context is not valid or not assigned to
      * any driver (i.e. none of the driver contexts are active). */
     unsigned int MBEDTLS_PRIVATE(id);
+    /* Algorithm used for PAKE operation */
+    psa_algorithm_t MBEDTLS_PRIVATE(alg);
     /* Based on stage (collecting inputs/computation) we select active structure of data union.
      * While switching stage (when driver setup is called) collected inputs
        are copied to the corresponding operation context. */
     uint8_t MBEDTLS_PRIVATE(stage);
+    /* Holds computation stage of the PAKE algorithms. */
+    psa_pake_computation_stage_t MBEDTLS_PRIVATE(computation_stage);
     union {
         unsigned dummy;
         psa_crypto_driver_pake_inputs_t MBEDTLS_PRIVATE(inputs);
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 273d248..66ecc06 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7180,11 +7180,14 @@
     psa_pake_operation_t *operation,
     const psa_pake_cipher_suite_t *cipher_suite)
 {
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg != PSA_ALG_NONE) {
+    if (operation->alg != PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7198,9 +7201,16 @@
 
     memset(&operation->data.inputs, 0, sizeof(operation->data.inputs));
 
-    operation->data.inputs.alg = cipher_suite->algorithm;
+    operation->alg = cipher_suite->algorithm;
     operation->data.inputs.cipher_suite = *cipher_suite;
 
+    if (operation->alg == PSA_ALG_JPAKE) {
+        computation_stage->state = PSA_PAKE_STATE_SETUP;
+        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+        computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
+        computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+    }
+
     return PSA_SUCCESS;
 }
 
@@ -7216,7 +7226,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+    if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7241,7 +7251,8 @@
 
     operation->data.inputs.password = mbedtls_calloc(1, slot->key.bytes);
     if (operation->data.inputs.password == NULL) {
-        return PSA_ERROR_INSUFFICIENT_MEMORY;
+        status = PSA_ERROR_INSUFFICIENT_MEMORY;
+        goto error;
     }
 
     memcpy(operation->data.inputs.password, slot->key.data, slot->key.bytes);
@@ -7264,7 +7275,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+    if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7286,7 +7297,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+    if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7305,7 +7316,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->data.inputs.alg == PSA_ALG_NONE) {
+    if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
@@ -7322,6 +7333,98 @@
     return PSA_SUCCESS;
 }
 
+static psa_status_t psa_jpake_output_prologue(
+    psa_pake_operation_t *operation,
+    psa_pake_step_t step)
+{
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
+    if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (step != PSA_PAKE_STEP_KEY_SHARE &&
+        step != PSA_PAKE_STEP_ZK_PUBLIC &&
+        step != PSA_PAKE_STEP_ZK_PROOF) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
+    if (computation_stage->state != PSA_PAKE_STATE_READY &&
+        computation_stage->state != PSA_PAKE_OUTPUT_X1_X2 &&
+        computation_stage->state != PSA_PAKE_OUTPUT_X2S) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (computation_stage->state == PSA_PAKE_STATE_READY) {
+        if (step != PSA_PAKE_STEP_KEY_SHARE) {
+            return PSA_ERROR_BAD_STATE;
+        }
+
+        switch (computation_stage->output_step) {
+            case PSA_PAKE_STEP_X1_X2:
+                computation_stage->state = PSA_PAKE_OUTPUT_X1_X2;
+                break;
+            case PSA_PAKE_STEP_X2S:
+                computation_stage->state = PSA_PAKE_OUTPUT_X2S;
+                break;
+            default:
+                return PSA_ERROR_BAD_STATE;
+        }
+
+        computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
+    }
+
+    /* Check if step matches current sequence */
+    switch (computation_stage->sequence) {
+        case PSA_PAKE_X1_STEP_KEY_SHARE:
+        case PSA_PAKE_X2_STEP_KEY_SHARE:
+            if (step != PSA_PAKE_STEP_KEY_SHARE) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+        case PSA_PAKE_X2_STEP_ZK_PUBLIC:
+            if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        case PSA_PAKE_X1_STEP_ZK_PROOF:
+        case PSA_PAKE_X2_STEP_ZK_PROOF:
+            if (step != PSA_PAKE_STEP_ZK_PROOF) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        default:
+            return PSA_ERROR_BAD_STATE;
+    }
+
+    return PSA_SUCCESS;
+}
+
+static psa_status_t psa_jpake_output_epilogue(
+    psa_pake_operation_t *operation)
+{
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
+    if ((computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
+         computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
+        (computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
+         computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+        computation_stage->state = PSA_PAKE_STATE_READY;
+        computation_stage->output_step++;
+        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+    } else {
+        computation_stage->sequence++;
+    }
+
+    return PSA_SUCCESS;
+}
+
 psa_status_t psa_pake_output(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
@@ -7330,9 +7433,11 @@
     size_t *output_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
-        if (operation->data.inputs.alg == PSA_ALG_NONE ||
+        if (operation->alg == PSA_ALG_NONE ||
             operation->data.inputs.password_len == 0 ||
             operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
             return PSA_ERROR_BAD_STATE;
@@ -7343,6 +7448,12 @@
 
         if (status == PSA_SUCCESS) {
             operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+            if (operation->alg == PSA_ALG_JPAKE) {
+                computation_stage->state = PSA_PAKE_STATE_READY;
+                computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+                computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
+                computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+            }
         } else {
             return status;
         }
@@ -7360,10 +7471,140 @@
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    return psa_driver_wrapper_pake_output(operation, step, output,
-                                          output_size, output_length);
+    switch (operation->alg) {
+        case PSA_ALG_JPAKE:
+            status = psa_jpake_output_prologue(operation, step);
+            if (status != PSA_SUCCESS) {
+                return status;
+            }
+            break;
+        default:
+            return PSA_ERROR_NOT_SUPPORTED;
+    }
+
+    status = psa_driver_wrapper_pake_output(operation, step,
+                                            &operation->computation_stage,
+                                            output, output_size, output_length);
+
+    if (status != PSA_SUCCESS) {
+        return status;
+    }
+
+    switch (operation->alg) {
+        case PSA_ALG_JPAKE:
+            status = psa_jpake_output_epilogue(operation);
+            if (status != PSA_SUCCESS) {
+                return status;
+            }
+            break;
+        default:
+            return PSA_ERROR_NOT_SUPPORTED;
+    }
+
+    return status;
 }
 
+static psa_status_t psa_jpake_input_prologue(
+    psa_pake_operation_t *operation,
+    psa_pake_step_t step,
+    size_t input_length)
+{
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
+    if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (step != PSA_PAKE_STEP_KEY_SHARE &&
+        step != PSA_PAKE_STEP_ZK_PUBLIC &&
+        step != PSA_PAKE_STEP_ZK_PROOF) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
+    const psa_pake_primitive_t prim = PSA_PAKE_PRIMITIVE(
+        PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256);
+    if (input_length > (size_t) PSA_PAKE_INPUT_SIZE(PSA_ALG_JPAKE, prim, step)) {
+        return PSA_ERROR_INVALID_ARGUMENT;
+    }
+
+    if (computation_stage->state != PSA_PAKE_STATE_READY &&
+        computation_stage->state != PSA_PAKE_INPUT_X1_X2 &&
+        computation_stage->state != PSA_PAKE_INPUT_X4S) {
+        return PSA_ERROR_BAD_STATE;
+    }
+
+    if (computation_stage->state == PSA_PAKE_STATE_READY) {
+        if (step != PSA_PAKE_STEP_KEY_SHARE) {
+            return PSA_ERROR_BAD_STATE;
+        }
+
+        switch (computation_stage->input_step) {
+            case PSA_PAKE_STEP_X1_X2:
+                computation_stage->state = PSA_PAKE_INPUT_X1_X2;
+                break;
+            case PSA_PAKE_STEP_X2S:
+                computation_stage->state = PSA_PAKE_INPUT_X4S;
+                break;
+            default:
+                return PSA_ERROR_BAD_STATE;
+        }
+
+        computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
+    }
+
+    /* Check if step matches current sequence */
+    switch (computation_stage->sequence) {
+        case PSA_PAKE_X1_STEP_KEY_SHARE:
+        case PSA_PAKE_X2_STEP_KEY_SHARE:
+            if (step != PSA_PAKE_STEP_KEY_SHARE) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+        case PSA_PAKE_X2_STEP_ZK_PUBLIC:
+            if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        case PSA_PAKE_X1_STEP_ZK_PROOF:
+        case PSA_PAKE_X2_STEP_ZK_PROOF:
+            if (step != PSA_PAKE_STEP_ZK_PROOF) {
+                return PSA_ERROR_BAD_STATE;
+            }
+            break;
+
+        default:
+            return PSA_ERROR_BAD_STATE;
+    }
+
+    return PSA_SUCCESS;
+}
+
+
+static psa_status_t psa_jpake_input_epilogue(
+    psa_pake_operation_t *operation)
+{
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
+    if ((computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
+         computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
+        (computation_stage->state == PSA_PAKE_INPUT_X4S &&
+         computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+        computation_stage->state = PSA_PAKE_STATE_READY;
+        computation_stage->input_step++;
+        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+    } else {
+        computation_stage->sequence++;
+    }
+
+    return PSA_SUCCESS;
+}
+
+
 psa_status_t psa_pake_input(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
@@ -7371,9 +7612,11 @@
     size_t input_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
-        if (operation->data.inputs.alg == PSA_ALG_NONE ||
+        if (operation->alg == PSA_ALG_NONE ||
             operation->data.inputs.password_len == 0 ||
             operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
             return PSA_ERROR_BAD_STATE;
@@ -7384,6 +7627,12 @@
 
         if (status == PSA_SUCCESS) {
             operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
+            if (operation->alg == PSA_ALG_JPAKE) {
+                computation_stage->state = PSA_PAKE_STATE_READY;
+                computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+                computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
+                computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+            }
         } else {
             return status;
         }
@@ -7401,8 +7650,37 @@
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
-    return psa_driver_wrapper_pake_input(operation, step, input,
-                                         input_length);
+    switch (operation->alg) {
+        case PSA_ALG_JPAKE:
+            status = psa_jpake_input_prologue(operation, step, input_length);
+            if (status != PSA_SUCCESS) {
+                return status;
+            }
+            break;
+        default:
+            return PSA_ERROR_NOT_SUPPORTED;
+    }
+
+    status = psa_driver_wrapper_pake_input(operation, step,
+                                           &operation->computation_stage,
+                                           input, input_length);
+
+    if (status != PSA_SUCCESS) {
+        return status;
+    }
+
+    switch (operation->alg) {
+        case PSA_ALG_JPAKE:
+            status = psa_jpake_input_epilogue(operation);
+            if (status != PSA_SUCCESS) {
+                return status;
+            }
+            break;
+        default:
+            return PSA_ERROR_NOT_SUPPORTED;
+    }
+
+    return status;
 }
 
 psa_status_t psa_pake_get_implicit_key(
@@ -7412,11 +7690,20 @@
     psa_status_t status = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     uint8_t shared_key[MBEDTLS_PSA_PAKE_BUFFER_SIZE];
     size_t shared_key_len = 0;
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
 
     if (operation->id == 0) {
         return PSA_ERROR_BAD_STATE;
     }
 
+    if (operation->alg == PSA_ALG_JPAKE) {
+        if (computation_stage->input_step != PSA_PAKE_STEP_DERIVE ||
+            computation_stage->output_step != PSA_PAKE_STEP_DERIVE) {
+            return PSA_ERROR_BAD_STATE;
+        }
+    }
+
     status = psa_driver_wrapper_pake_get_implicit_key(operation,
                                                       shared_key,
                                                       &shared_key_len);
@@ -7436,18 +7723,29 @@
 
     mbedtls_platform_zeroize(shared_key, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
 
+    psa_pake_abort(operation);
+
     return status;
 }
 
 psa_status_t psa_pake_abort(
     psa_pake_operation_t *operation)
 {
+    psa_jpake_computation_stage_t *computation_stage =
+        &operation->computation_stage.data.jpake_computation_stage;
+
     /* If we are in collecting inputs stage clear inputs. */
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         mbedtls_free(operation->data.inputs.password);
         memset(&operation->data.inputs, 0, sizeof(psa_crypto_driver_pake_inputs_t));
         return PSA_SUCCESS;
     }
+    if (operation->alg == PSA_ALG_JPAKE) {
+        computation_stage->input_step = PSA_PAKE_STEP_INVALID;
+        computation_stage->output_step = PSA_PAKE_STEP_INVALID;
+        computation_stage->state = PSA_PAKE_STATE_INVALID;
+        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+    }
 
     return psa_driver_wrapper_pake_abort(operation);
 }
diff --git a/library/psa_crypto_driver_wrappers.h b/library/psa_crypto_driver_wrappers.h
index abaabb5..ac17be4 100644
--- a/library/psa_crypto_driver_wrappers.h
+++ b/library/psa_crypto_driver_wrappers.h
@@ -422,6 +422,7 @@
 psa_status_t psa_driver_wrapper_pake_output(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     uint8_t *output,
     size_t output_size,
     size_t *output_length);
@@ -429,6 +430,7 @@
 psa_status_t psa_driver_wrapper_pake_input(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     const uint8_t *input,
     size_t input_length);
 
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index 3a710dc..3d5b57d 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -79,23 +79,6 @@
  *   psa_pake_abort()
  */
 
-enum psa_pake_step {
-    PSA_PAKE_STEP_INVALID       = 0,
-    PSA_PAKE_STEP_X1_X2         = 1,
-    PSA_PAKE_STEP_X2S           = 2,
-    PSA_PAKE_STEP_DERIVE        = 3,
-};
-
-enum psa_pake_state {
-    PSA_PAKE_STATE_INVALID      = 0,
-    PSA_PAKE_STATE_SETUP        = 1,
-    PSA_PAKE_STATE_READY        = 2,
-    PSA_PAKE_OUTPUT_X1_X2       = 3,
-    PSA_PAKE_OUTPUT_X2S         = 4,
-    PSA_PAKE_INPUT_X1_X2        = 5,
-    PSA_PAKE_INPUT_X4S          = 6,
-};
-
 /*
  * The first PAKE step shares the same sequences of the second PAKE step
  * but with a second set of KEY_SHARE/ZK_PUBLIC/ZK_PROOF outputs/inputs.
@@ -157,16 +140,6 @@
  *   psa_pake_get_implicit_key()
  *   => Input & Output Step = PSA_PAKE_STEP_INVALID
  */
-enum psa_pake_sequence {
-    PSA_PAKE_SEQ_INVALID        = 0,
-    PSA_PAKE_X1_STEP_KEY_SHARE  = 1,    /* also X2S & X4S KEY_SHARE */
-    PSA_PAKE_X1_STEP_ZK_PUBLIC  = 2,    /* also X2S & X4S ZK_PUBLIC */
-    PSA_PAKE_X1_STEP_ZK_PROOF   = 3,    /* also X2S & X4S ZK_PROOF */
-    PSA_PAKE_X2_STEP_KEY_SHARE  = 4,
-    PSA_PAKE_X2_STEP_ZK_PUBLIC  = 5,
-    PSA_PAKE_X2_STEP_ZK_PROOF   = 6,
-    PSA_PAKE_SEQ_END            = 7,
-};
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
 static psa_status_t mbedtls_ecjpake_to_psa_error(int ret)
@@ -190,65 +163,6 @@
 }
 #endif
 
-#if defined(MBEDTLS_PSA_BUILTIN_PAKE)
-psa_status_t mbedtls_psa_pake_setup(mbedtls_psa_pake_operation_t *operation,
-                                    const psa_crypto_driver_pake_inputs_t *inputs)
-{
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-
-    uint8_t *password = inputs->password;
-    size_t password_len = inputs->password_len;
-    psa_pake_role_t role = inputs->role;
-    psa_pake_cipher_suite_t cipher_suite = inputs->cipher_suite;
-
-    memset(operation, 0, sizeof(mbedtls_psa_pake_operation_t));
-
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
-    if (cipher_suite.algorithm == PSA_ALG_JPAKE) {
-        if (cipher_suite.type != PSA_PAKE_PRIMITIVE_TYPE_ECC ||
-            cipher_suite.family != PSA_ECC_FAMILY_SECP_R1 ||
-            cipher_suite.bits != 256 ||
-            cipher_suite.hash != PSA_ALG_SHA_256) {
-            status = PSA_ERROR_NOT_SUPPORTED;
-            goto error;
-        }
-
-        if (role != PSA_PAKE_ROLE_CLIENT &&
-            role != PSA_PAKE_ROLE_SERVER) {
-            status = PSA_ERROR_NOT_SUPPORTED;
-            goto error;
-        }
-
-        mbedtls_ecjpake_init(&operation->ctx.pake);
-
-        operation->state = PSA_PAKE_STATE_SETUP;
-        operation->sequence = PSA_PAKE_SEQ_INVALID;
-        operation->input_step = PSA_PAKE_STEP_X1_X2;
-        operation->output_step = PSA_PAKE_STEP_X1_X2;
-        operation->password_len = password_len;
-        operation->password = password;
-        operation->role = role;
-        operation->alg = cipher_suite.algorithm;
-
-        mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
-        operation->buffer_length = 0;
-        operation->buffer_offset = 0;
-
-        return PSA_SUCCESS;
-    } else
-#else
-    (void) operation;
-    (void) inputs;
-#endif
-    { status = PSA_ERROR_NOT_SUPPORTED; }
-
-error:
-    mbedtls_free(password);
-    mbedtls_psa_pake_abort(operation);
-    return status;
-}
-
-
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
 static psa_status_t psa_pake_ecjpake_setup(mbedtls_psa_pake_operation_t *operation)
 {
@@ -283,31 +197,84 @@
         return mbedtls_ecjpake_to_psa_error(ret);
     }
 
-    operation->state = PSA_PAKE_STATE_READY;
-
     return PSA_SUCCESS;
 }
+
+psa_status_t mbedtls_psa_pake_setup(mbedtls_psa_pake_operation_t *operation,
+                                    const psa_crypto_driver_pake_inputs_t *inputs)
+{
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    uint8_t *password = inputs->password;
+    size_t password_len = inputs->password_len;
+    psa_pake_role_t role = inputs->role;
+    psa_pake_cipher_suite_t cipher_suite = inputs->cipher_suite;
+
+    memset(operation, 0, sizeof(mbedtls_psa_pake_operation_t));
+
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
+    if (cipher_suite.algorithm == PSA_ALG_JPAKE) {
+        if (cipher_suite.type != PSA_PAKE_PRIMITIVE_TYPE_ECC ||
+            cipher_suite.family != PSA_ECC_FAMILY_SECP_R1 ||
+            cipher_suite.bits != 256 ||
+            cipher_suite.hash != PSA_ALG_SHA_256) {
+            status = PSA_ERROR_NOT_SUPPORTED;
+            goto error;
+        }
+
+        if (role != PSA_PAKE_ROLE_CLIENT &&
+            role != PSA_PAKE_ROLE_SERVER) {
+            status = PSA_ERROR_NOT_SUPPORTED;
+            goto error;
+        }
+
+        mbedtls_ecjpake_init(&operation->ctx.pake);
+
+        operation->password_len = password_len;
+        operation->password = password;
+        operation->role = role;
+        operation->alg = cipher_suite.algorithm;
+
+        mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+        operation->buffer_length = 0;
+        operation->buffer_offset = 0;
+
+        status = psa_pake_ecjpake_setup(operation);
+
+        if (status != PSA_SUCCESS) {
+            goto error;
+        }
+
+        return PSA_SUCCESS;
+    } else
+#else
+    (void) operation;
+    (void) inputs;
 #endif
+    { status = PSA_ERROR_NOT_SUPPORTED; }
+
+error:
+    mbedtls_free(password);
+    mbedtls_psa_pake_abort(operation);
+    return status;
+}
 
 static psa_status_t mbedtls_psa_pake_output_internal(
     mbedtls_psa_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     size_t length;
+    (void) step;
 
     if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->state == PSA_PAKE_STATE_INVALID) {
-        return PSA_ERROR_BAD_STATE;
-    }
-
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     /*
      * The PSA CRYPTO PAKE and MbedTLS JPAKE API have a different
@@ -324,74 +291,12 @@
      * to return the right parts on each step.
      */
     if (operation->alg == PSA_ALG_JPAKE) {
-        if (step != PSA_PAKE_STEP_KEY_SHARE &&
-            step != PSA_PAKE_STEP_ZK_PUBLIC &&
-            step != PSA_PAKE_STEP_ZK_PROOF) {
-            return PSA_ERROR_INVALID_ARGUMENT;
-        }
-
-        if (operation->state == PSA_PAKE_STATE_SETUP) {
-            status = psa_pake_ecjpake_setup(operation);
-            if (status != PSA_SUCCESS) {
-                return status;
-            }
-        }
-
-        if (operation->state != PSA_PAKE_STATE_READY &&
-            operation->state != PSA_PAKE_OUTPUT_X1_X2 &&
-            operation->state != PSA_PAKE_OUTPUT_X2S) {
-            return PSA_ERROR_BAD_STATE;
-        }
-
-        if (operation->state == PSA_PAKE_STATE_READY) {
-            if (step != PSA_PAKE_STEP_KEY_SHARE) {
-                return PSA_ERROR_BAD_STATE;
-            }
-
-            switch (operation->output_step) {
-                case PSA_PAKE_STEP_X1_X2:
-                    operation->state = PSA_PAKE_OUTPUT_X1_X2;
-                    break;
-                case PSA_PAKE_STEP_X2S:
-                    operation->state = PSA_PAKE_OUTPUT_X2S;
-                    break;
-                default:
-                    return PSA_ERROR_BAD_STATE;
-            }
-
-            operation->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
-        }
-
-        /* Check if step matches current sequence */
-        switch (operation->sequence) {
-            case PSA_PAKE_X1_STEP_KEY_SHARE:
-            case PSA_PAKE_X2_STEP_KEY_SHARE:
-                if (step != PSA_PAKE_STEP_KEY_SHARE) {
-                    return PSA_ERROR_BAD_STATE;
-                }
-                break;
-
-            case PSA_PAKE_X1_STEP_ZK_PUBLIC:
-            case PSA_PAKE_X2_STEP_ZK_PUBLIC:
-                if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
-                    return PSA_ERROR_BAD_STATE;
-                }
-                break;
-
-            case PSA_PAKE_X1_STEP_ZK_PROOF:
-            case PSA_PAKE_X2_STEP_ZK_PROOF:
-                if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                    return PSA_ERROR_BAD_STATE;
-                }
-                break;
-
-            default:
-                return PSA_ERROR_BAD_STATE;
-        }
+        const psa_jpake_computation_stage_t *jpake_computation_stage =
+            &computation_stage->data.jpake_computation_stage;
 
         /* Initialize & write round on KEY_SHARE sequences */
-        if (operation->state == PSA_PAKE_OUTPUT_X1_X2 &&
-            operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
+        if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
+            jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
             ret = mbedtls_ecjpake_write_round_one(&operation->ctx.pake,
                                                   operation->buffer,
                                                   MBEDTLS_PSA_PAKE_BUFFER_SIZE,
@@ -403,8 +308,8 @@
             }
 
             operation->buffer_offset = 0;
-        } else if (operation->state == PSA_PAKE_OUTPUT_X2S &&
-                   operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
+        } else if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
+                   jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
             ret = mbedtls_ecjpake_write_round_two(&operation->ctx.pake,
                                                   operation->buffer,
                                                   MBEDTLS_PSA_PAKE_BUFFER_SIZE,
@@ -429,8 +334,8 @@
          * output with a length byte, even less a curve identifier, as that
          * information is already available.
          */
-        if (operation->state == PSA_PAKE_OUTPUT_X2S &&
-            operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+        if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
+            jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
             operation->role == PSA_PAKE_ROLE_SERVER) {
             /* Skip ECParameters, with is 3 bytes (RFC 8422) */
             operation->buffer_offset += 3;
@@ -456,25 +361,20 @@
         operation->buffer_offset += length;
 
         /* Reset buffer after ZK_PROOF sequence */
-        if ((operation->state == PSA_PAKE_OUTPUT_X1_X2 &&
-             operation->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
-            (operation->state == PSA_PAKE_OUTPUT_X2S &&
-             operation->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+        if ((jpake_computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
+             jpake_computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
+            (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
+             jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
             mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
             operation->buffer_length = 0;
             operation->buffer_offset = 0;
-
-            operation->state = PSA_PAKE_STATE_READY;
-            operation->output_step++;
-            operation->sequence = PSA_PAKE_SEQ_INVALID;
-        } else {
-            operation->sequence++;
         }
 
         return PSA_SUCCESS;
     } else
 #else
     (void) step;
+    (void) computation_stage;
     (void) output;
     (void) output_size;
     (void) output_length;
@@ -484,12 +384,13 @@
 
 psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
                                      psa_pake_step_t step,
+                                     const psa_pake_computation_stage_t *computation_stage,
                                      uint8_t *output,
                                      size_t output_size,
                                      size_t *output_length)
 {
     psa_status_t status = mbedtls_psa_pake_output_internal(
-        operation, step, output, output_size, output_length);
+        operation, step, computation_stage, output, output_size, output_length);
 
     if (status != PSA_SUCCESS) {
         mbedtls_psa_pake_abort(operation);
@@ -501,20 +402,16 @@
 static psa_status_t mbedtls_psa_pake_input_internal(
     mbedtls_psa_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     const uint8_t *input,
     size_t input_length)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-
+    (void) step;
     if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->state == PSA_PAKE_STATE_INVALID) {
-        return PSA_ERROR_BAD_STATE;
-    }
-
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     /*
      * The PSA CRYPTO PAKE and MbedTLS JPAKE API have a different
@@ -532,77 +429,8 @@
      * This causes any input error to be only detected on the last step.
      */
     if (operation->alg == PSA_ALG_JPAKE) {
-        if (step != PSA_PAKE_STEP_KEY_SHARE &&
-            step != PSA_PAKE_STEP_ZK_PUBLIC &&
-            step != PSA_PAKE_STEP_ZK_PROOF) {
-            return PSA_ERROR_INVALID_ARGUMENT;
-        }
-
-        const psa_pake_primitive_t prim = PSA_PAKE_PRIMITIVE(
-            PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256);
-        if (input_length > (size_t) PSA_PAKE_INPUT_SIZE(PSA_ALG_JPAKE, prim, step)) {
-            return PSA_ERROR_INVALID_ARGUMENT;
-        }
-
-        if (operation->state == PSA_PAKE_STATE_SETUP) {
-            status = psa_pake_ecjpake_setup(operation);
-            if (status != PSA_SUCCESS) {
-                return status;
-            }
-        }
-
-        if (operation->state != PSA_PAKE_STATE_READY &&
-            operation->state != PSA_PAKE_INPUT_X1_X2 &&
-            operation->state != PSA_PAKE_INPUT_X4S) {
-            return PSA_ERROR_BAD_STATE;
-        }
-
-        if (operation->state == PSA_PAKE_STATE_READY) {
-            if (step != PSA_PAKE_STEP_KEY_SHARE) {
-                return PSA_ERROR_BAD_STATE;
-            }
-
-            switch (operation->input_step) {
-                case PSA_PAKE_STEP_X1_X2:
-                    operation->state = PSA_PAKE_INPUT_X1_X2;
-                    break;
-                case PSA_PAKE_STEP_X2S:
-                    operation->state = PSA_PAKE_INPUT_X4S;
-                    break;
-                default:
-                    return PSA_ERROR_BAD_STATE;
-            }
-
-            operation->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
-        }
-
-        /* Check if step matches current sequence */
-        switch (operation->sequence) {
-            case PSA_PAKE_X1_STEP_KEY_SHARE:
-            case PSA_PAKE_X2_STEP_KEY_SHARE:
-                if (step != PSA_PAKE_STEP_KEY_SHARE) {
-                    return PSA_ERROR_BAD_STATE;
-                }
-                break;
-
-            case PSA_PAKE_X1_STEP_ZK_PUBLIC:
-            case PSA_PAKE_X2_STEP_ZK_PUBLIC:
-                if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
-                    return PSA_ERROR_BAD_STATE;
-                }
-                break;
-
-            case PSA_PAKE_X1_STEP_ZK_PROOF:
-            case PSA_PAKE_X2_STEP_ZK_PROOF:
-                if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                    return PSA_ERROR_BAD_STATE;
-                }
-                break;
-
-            default:
-                return PSA_ERROR_BAD_STATE;
-        }
-
+        const psa_jpake_computation_stage_t *jpake_computation_stage =
+            &computation_stage->data.jpake_computation_stage;
         /*
          * Copy input to local buffer and format it as the Mbed TLS API
          * expects, i.e. as defined by draft-cragie-tls-ecjpake-01 section 7.
@@ -612,8 +440,8 @@
          * ECParameters structure - which means we have to prepend that when
          * we're a client.
          */
-        if (operation->state == PSA_PAKE_INPUT_X4S &&
-            operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+        if (jpake_computation_stage->state == PSA_PAKE_INPUT_X4S &&
+            jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
             operation->role == PSA_PAKE_ROLE_CLIENT) {
             /* We only support secp256r1. */
             /* This is the ECParameters structure defined by RFC 8422. */
@@ -636,8 +464,8 @@
         operation->buffer_length += input_length;
 
         /* Load buffer at each last round ZK_PROOF */
-        if (operation->state == PSA_PAKE_INPUT_X1_X2 &&
-            operation->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) {
+        if (jpake_computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
+            jpake_computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) {
             ret = mbedtls_ecjpake_read_round_one(&operation->ctx.pake,
                                                  operation->buffer,
                                                  operation->buffer_length);
@@ -648,8 +476,8 @@
             if (ret != 0) {
                 return mbedtls_ecjpake_to_psa_error(ret);
             }
-        } else if (operation->state == PSA_PAKE_INPUT_X4S &&
-                   operation->sequence == PSA_PAKE_X1_STEP_ZK_PROOF) {
+        } else if (jpake_computation_stage->state == PSA_PAKE_INPUT_X4S &&
+                   jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF) {
             ret = mbedtls_ecjpake_read_round_two(&operation->ctx.pake,
                                                  operation->buffer,
                                                  operation->buffer_length);
@@ -662,21 +490,11 @@
             }
         }
 
-        if ((operation->state == PSA_PAKE_INPUT_X1_X2 &&
-             operation->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
-            (operation->state == PSA_PAKE_INPUT_X4S &&
-             operation->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
-            operation->state = PSA_PAKE_STATE_READY;
-            operation->input_step++;
-            operation->sequence = PSA_PAKE_SEQ_INVALID;
-        } else {
-            operation->sequence++;
-        }
-
         return PSA_SUCCESS;
     } else
 #else
     (void) step;
+    (void) computation_stage;
     (void) input;
     (void) input_length;
 #endif
@@ -685,11 +503,12 @@
 
 psa_status_t mbedtls_psa_pake_input(mbedtls_psa_pake_operation_t *operation,
                                     psa_pake_step_t step,
+                                    const psa_pake_computation_stage_t *computation_stage,
                                     const uint8_t *input,
                                     size_t input_length)
 {
     psa_status_t status = mbedtls_psa_pake_input_internal(
-        operation, step, input, input_length);
+        operation, step, computation_stage, input, input_length);
 
     if (status != PSA_SUCCESS) {
         mbedtls_psa_pake_abort(operation);
@@ -703,18 +522,11 @@
     uint8_t *output, size_t *output_size)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
     if (operation->alg == PSA_ALG_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->input_step != PSA_PAKE_STEP_DERIVE ||
-        operation->output_step != PSA_PAKE_STEP_DERIVE) {
-        status = PSA_ERROR_BAD_STATE;
-        goto error;
-    }
-
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     if (operation->alg == PSA_ALG_JPAKE) {
         ret = mbedtls_ecjpake_write_shared_key(&operation->ctx.pake,
@@ -740,12 +552,7 @@
 #else
     (void) output;
 #endif
-    { status = PSA_ERROR_NOT_SUPPORTED; }
-
-error:
-    mbedtls_psa_pake_abort(operation);
-
-    return status;
+    { return PSA_ERROR_NOT_SUPPORTED; }
 }
 
 psa_status_t mbedtls_psa_pake_abort(mbedtls_psa_pake_operation_t *operation)
@@ -757,8 +564,6 @@
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
 
     if (operation->alg == PSA_ALG_JPAKE) {
-        operation->input_step = PSA_PAKE_STEP_INVALID;
-        operation->output_step = PSA_PAKE_STEP_INVALID;
         if (operation->password_len > 0) {
             mbedtls_platform_zeroize(operation->password, operation->password_len);
         }
@@ -774,8 +579,6 @@
 #endif
 
     operation->alg = PSA_ALG_NONE;
-    operation->state = PSA_PAKE_STATE_INVALID;
-    operation->sequence = PSA_PAKE_SEQ_INVALID;
 
     return PSA_SUCCESS;
 }
diff --git a/library/psa_crypto_pake.h b/library/psa_crypto_pake.h
index 608d76a..485c93a 100644
--- a/library/psa_crypto_pake.h
+++ b/library/psa_crypto_pake.h
@@ -58,6 +58,7 @@
  * \param[in,out] operation    Active PAKE operation.
  * \param step                 The step of the algorithm for which the output is
  *                             requested.
+ * \param computation_stage    The structure that holds PAKE computation stage.
  * \param[out] output          Buffer where the output is to be written in the
  *                             format appropriate for this \p step. Refer to
  *                             the documentation of the individual
@@ -97,6 +98,7 @@
  */
 psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
                                      psa_pake_step_t step,
+                                     const psa_pake_computation_stage_t *computation_stage,
                                      uint8_t *output,
                                      size_t output_size,
                                      size_t *output_length);
@@ -110,6 +112,7 @@
  *
  * \param[in,out] operation    Active PAKE operation.
  * \param step                 The step for which the input is provided.
+ * \param computation_stage    The structure that holds PAKE computation stage.
  * \param[in] input            Buffer containing the input in the format
  *                             appropriate for this \p step. Refer to the
  *                             documentation of the individual
@@ -144,6 +147,7 @@
  */
 psa_status_t mbedtls_psa_pake_input(mbedtls_psa_pake_operation_t *operation,
                                     psa_pake_step_t step,
+                                    const psa_pake_computation_stage_t *computation_stage,
                                     const uint8_t *input,
                                     size_t input_length);
 
diff --git a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
index 21a3b5f..e1a4c9c 100644
--- a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
+++ b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
@@ -2866,6 +2866,7 @@
 psa_status_t psa_driver_wrapper_pake_output(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     uint8_t *output,
     size_t output_size,
     size_t *output_length )
@@ -2874,7 +2875,8 @@
     {
 #if defined(MBEDTLS_PSA_BUILTIN_PAKE)
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
-            return( mbedtls_psa_pake_output( &operation->data.ctx.mbedtls_ctx, step, output,
+            return( mbedtls_psa_pake_output( &operation->data.ctx.mbedtls_ctx, step,
+                                             computation_stage, output,
                                              output_size, output_length ) );
 #endif /* MBEDTLS_PSA_BUILTIN_PAKE */
 
@@ -2883,15 +2885,16 @@
         case MBEDTLS_TEST_TRANSPARENT_DRIVER_ID:
             return( mbedtls_test_transparent_pake_output(
                         &operation->data.ctx.transparent_test_driver_ctx,
-                        step, output, output_size, output_length ) );
+                        step, computation_stage, output, output_size, output_length ) );
         case MBEDTLS_TEST_OPAQUE_DRIVER_ID:
             return( mbedtls_test_opaque_pake_output(
                         &operation->data.ctx.opaque_test_driver_ctx,
-                        step, output, output_size, output_length ) );
+                        step, computation_stage, output, output_size, output_length ) );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
         default:
             (void) step;
+            (void) computation_stage;
             (void) output;
             (void) output_size;
             (void) output_length;
@@ -2902,6 +2905,7 @@
 psa_status_t psa_driver_wrapper_pake_input(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     const uint8_t *input,
     size_t input_length )
 {
@@ -2910,7 +2914,8 @@
 #if defined(MBEDTLS_PSA_BUILTIN_PAKE)
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
             return( mbedtls_psa_pake_input( &operation->data.ctx.mbedtls_ctx,
-                                            step, input, input_length ) );
+                                            step, computation_stage, input,
+                                            input_length ) );
 #endif /* MBEDTLS_PSA_BUILTIN_PAKE */
 
 #if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
@@ -2918,15 +2923,18 @@
         case MBEDTLS_TEST_TRANSPARENT_DRIVER_ID:
             return( mbedtls_test_transparent_pake_input(
                         &operation->data.ctx.transparent_test_driver_ctx,
-                        step, input, input_length ) );
+                        step, computation_stage,
+                        input, input_length ) );
         case MBEDTLS_TEST_OPAQUE_DRIVER_ID:
             return( mbedtls_test_opaque_pake_input(
                         &operation->data.ctx.opaque_test_driver_ctx,
-                        step, input, input_length ) );
+                        step, computation_stage,
+                        input, input_length ) );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
         default:
             (void) step;
+            (void) computation_stage;
             (void) input;
             (void) input_length;
             return( PSA_ERROR_INVALID_ARGUMENT );
diff --git a/tests/include/test/drivers/pake.h b/tests/include/test/drivers/pake.h
index 0412296..1f53008 100644
--- a/tests/include/test/drivers/pake.h
+++ b/tests/include/test/drivers/pake.h
@@ -58,6 +58,7 @@
 psa_status_t mbedtls_test_transparent_pake_output(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     uint8_t *output,
     size_t output_size,
     size_t *output_length);
@@ -65,6 +66,7 @@
 psa_status_t mbedtls_test_transparent_pake_input(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     const uint8_t *input,
     size_t input_length);
 
@@ -102,6 +104,7 @@
 psa_status_t mbedtls_test_opaque_pake_output(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     uint8_t *output,
     size_t output_size,
     size_t *output_length);
@@ -109,6 +112,7 @@
 psa_status_t mbedtls_test_opaque_pake_input(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     const uint8_t *input,
     size_t input_length);
 
diff --git a/tests/src/drivers/test_driver_pake.c b/tests/src/drivers/test_driver_pake.c
index 437c499..21719e6 100644
--- a/tests/src/drivers/test_driver_pake.c
+++ b/tests/src/drivers/test_driver_pake.c
@@ -65,6 +65,7 @@
 psa_status_t mbedtls_test_transparent_pake_output(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
@@ -92,14 +93,20 @@
         defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             libtestdriver1_mbedtls_psa_pake_output(
-                operation, step, output, output_size, output_length);
+                operation,
+                step,
+                (libtestdriver1_psa_pake_computation_stage_t *) computation_stage,
+                output,
+                output_size,
+                output_length);
 #elif defined(MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             mbedtls_psa_pake_output(
-                operation, step, output, output_size, output_length);
+                operation, step, computation_stage, output, output_size, output_length);
 #else
         (void) operation;
         (void) step;
+        (void) computation_stage;
         (void) output;
         (void) output_size;
         (void) output_length;
@@ -113,6 +120,7 @@
 psa_status_t mbedtls_test_transparent_pake_input(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     const uint8_t *input,
     size_t input_length)
 {
@@ -126,14 +134,19 @@
         defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             libtestdriver1_mbedtls_psa_pake_input(
-                operation, step, input, input_length);
+                operation,
+                step,
+                (libtestdriver1_psa_pake_computation_stage_t *) computation_stage,
+                input,
+                input_length);
 #elif defined(MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             mbedtls_psa_pake_input(
-                operation, step, input, input_length);
+                operation, step, computation_stage, input, input_length);
 #else
         (void) operation;
         (void) step;
+        (void) computation_stage;
         (void) input;
         (void) input_length;
         mbedtls_test_driver_pake_hooks.driver_status = PSA_ERROR_NOT_SUPPORTED;
@@ -258,12 +271,14 @@
 psa_status_t mbedtls_test_opaque_pake_output(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
 {
     (void) operation;
     (void) step;
+    (void) computation_stage;
     (void) output;
     (void) output_size;
     (void) output_length;
@@ -274,11 +289,13 @@
 psa_status_t mbedtls_test_opaque_pake_input(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
     psa_pake_step_t step,
+    const psa_pake_computation_stage_t *computation_stage,
     const uint8_t *input,
     size_t input_length)
 {
     (void) operation;
     (void) step;
+    (void) computation_stage;
     (void) input;
     (void) input_length;
     return PSA_ERROR_NOT_SUPPORTED;
diff --git a/tests/suites/test_suite_psa_crypto_pake.data b/tests/suites/test_suite_psa_crypto_pake.data
index 0ec16f0..e4bb92b 100644
--- a/tests/suites/test_suite_psa_crypto_pake.data
+++ b/tests/suites/test_suite_psa_crypto_pake.data
@@ -70,10 +70,6 @@
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
 ecjpake_setup:PSA_ALG_JPAKE:PSA_KEY_TYPE_PASSWORD:PSA_KEY_USAGE_DERIVE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:PSA_PAKE_ROLE_SERVER:1:ERR_INJECT_WRONG_BUFFER_SIZE:PSA_ERROR_INVALID_ARGUMENT
 
-PSA PAKE: valid input operation after a failure
-depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_setup:PSA_ALG_JPAKE:PSA_KEY_TYPE_PASSWORD:PSA_KEY_USAGE_DERIVE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:PSA_PAKE_ROLE_SERVER:1:ERR_INJECT_VALID_OPERATION_AFTER_FAILURE:PSA_ERROR_BAD_STATE
-
 PSA PAKE: invalid output
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
 ecjpake_setup:PSA_ALG_JPAKE:PSA_KEY_TYPE_PASSWORD:PSA_KEY_USAGE_DERIVE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:PSA_PAKE_ROLE_SERVER:0:ERR_INJECT_EMPTY_IO_BUFFER:PSA_ERROR_INVALID_ARGUMENT
@@ -90,10 +86,6 @@
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
 ecjpake_setup:PSA_ALG_JPAKE:PSA_KEY_TYPE_PASSWORD:PSA_KEY_USAGE_DERIVE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:PSA_PAKE_ROLE_SERVER:0:ERR_INJECT_WRONG_BUFFER_SIZE:PSA_ERROR_BUFFER_TOO_SMALL
 
-PSA PAKE: valid output operation after a failure
-depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_setup:PSA_ALG_JPAKE:PSA_KEY_TYPE_PASSWORD:PSA_KEY_USAGE_DERIVE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:PSA_PAKE_ROLE_SERVER:0:ERR_INJECT_VALID_OPERATION_AFTER_FAILURE:PSA_ERROR_BAD_STATE
-
 PSA PAKE: check rounds w/o forced errors
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256:PSA_WANT_ALG_TLS12_PSK_TO_MS
 ecjpake_rounds:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:PSA_ALG_TLS12_PSK_TO_MS(PSA_ALG_SHA_256):"abcdef":0:0:ERR_NONE