Combine core pake computation stage(step,sequence,state) into single driver step

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index fa6ef4e..83c7e04 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -1292,12 +1292,15 @@
 /** The type of input values for PAKE operations. */
 typedef struct psa_crypto_driver_pake_inputs_s psa_crypto_driver_pake_inputs_t;
 
-/** The type of compuatation stage for PAKE operations. */
+/** The type of computation stage for PAKE operations. */
 typedef struct psa_pake_computation_stage_s psa_pake_computation_stage_t;
 
-/** The type of compuatation stage for J-PAKE operations. */
+/** The type of computation stage for J-PAKE operations. */
 typedef struct psa_jpake_computation_stage_s psa_jpake_computation_stage_t;
 
+/** The type of driver step for PAKE operation. */
+typedef enum psa_pake_driver_step psa_pake_driver_step_t;
+
 /** Return an initial value for a PAKE operation object.
  */
 static psa_pake_operation_t psa_pake_operation_init(void);
@@ -1946,21 +1949,23 @@
     PSA_PAKE_SEQ_END            = 7,
 };
 
-enum psa_jpake_computation_state {
-    PSA_PAKE_X1_STEP_KEY_SHARE   = 1,  /* Round 1: input/output key share (for ephemeral private key X1).*/
-    PSA_PAKE_X1_STEP_ZK_PUBLIC   = 2,  /* Round 1: input/output Schnorr NIZKP public key for the X1 key */
-    PSA_PAKE_X1_STEP_ZK_PROOF    = 3,  /* Round 1: input/output Schnorr NIZKP proof for the X1 key */
-    PSA_PAKE_X2_STEP_KEY_SHARE   = 4,  /* Round 1: input/output key share (for ephemeral private key X2).*/
-    PSA_PAKE_X2_STEP_ZK_PUBLIC   = 5,  /* Round 1: input/output Schnorr NIZKP public key for the X2 key */
-    PSA_PAKE_X2_STEP_ZK_PROOF    = 6,  /* Round 1: input/output Schnorr NIZKP proof for the X2 key */
-    PSA_PAKE_X2S_STEP_KEY_SHARE  = 7,  /* Round 2: output X2S key (our key) */
-    PSA_PAKE_X2S_STEP_ZK_PUBLIC  = 8,  /* Round 2: output Schnorr NIZKP public key for the X2S key (our key) */
-    PSA_PAKE_X2S_STEP_ZK_PROOF   = 9,  /* Round 2: output Schnorr NIZKP proof for the X2S key (our key) */
-    PSA_PAKE_X4S_STEP_KEY_SHARE  = 10, /* Round 2: input X4S key (from peer) */
-    PSA_PAKE_X4S_STEP_ZK_PUBLIC  = 11, /* Round 2: input Schnorr NIZKP public key for the X4S key (from peer) */
-    PSA_PAKE_X4S_STEP_ZK_PROOF   = 12  /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */
+enum psa_pake_driver_step {
+    PSA_JPAKE_STEP_INVALID        = 0,  /* Invalid step */
+    PSA_JPAKE_X1_STEP_KEY_SHARE   = 1,  /* Round 1: input/output key share (for ephemeral private key X1).*/
+    PSA_JPAKE_X1_STEP_ZK_PUBLIC   = 2,  /* Round 1: input/output Schnorr NIZKP public key for the X1 key */
+    PSA_JPAKE_X1_STEP_ZK_PROOF    = 3,  /* Round 1: input/output Schnorr NIZKP proof for the X1 key */
+    PSA_JPAKE_X2_STEP_KEY_SHARE   = 4,  /* Round 1: input/output key share (for ephemeral private key X2).*/
+    PSA_JPAKE_X2_STEP_ZK_PUBLIC   = 5,  /* Round 1: input/output Schnorr NIZKP public key for the X2 key */
+    PSA_JPAKE_X2_STEP_ZK_PROOF    = 6,  /* Round 1: input/output Schnorr NIZKP proof for the X2 key */
+    PSA_JPAKE_X2S_STEP_KEY_SHARE  = 7,  /* Round 2: output X2S key (our key) */
+    PSA_JPAKE_X2S_STEP_ZK_PUBLIC  = 8,  /* Round 2: output Schnorr NIZKP public key for the X2S key (our key) */
+    PSA_JPAKE_X2S_STEP_ZK_PROOF   = 9,  /* Round 2: output Schnorr NIZKP proof for the X2S key (our key) */
+    PSA_JPAKE_X4S_STEP_KEY_SHARE  = 10, /* Round 2: input X4S key (from peer) */
+    PSA_JPAKE_X4S_STEP_ZK_PUBLIC  = 11, /* Round 2: input Schnorr NIZKP public key for the X4S key (from peer) */
+    PSA_JPAKE_X4S_STEP_ZK_PROOF   = 12  /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */
 };
 
+
 struct psa_jpake_computation_stage_s {
     unsigned int MBEDTLS_PRIVATE(state);
     unsigned int MBEDTLS_PRIVATE(sequence);
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index f7b0270..09d46ed 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7332,6 +7332,70 @@
     return PSA_SUCCESS;
 }
 
+/* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */
+static psa_pake_driver_step_t convert_jpake_computation_stage_to_driver_step(
+    psa_pake_computation_stage_t *stage)
+{
+    switch (stage->data.jpake_computation_stage.state) {
+        case PSA_PAKE_OUTPUT_X1_X2:
+        case PSA_PAKE_INPUT_X1_X2:
+            switch (stage->data.jpake_computation_stage.sequence) {
+                case PSA_PAKE_X1_STEP_KEY_SHARE:
+                    return PSA_JPAKE_X1_STEP_KEY_SHARE;
+                    break;
+                case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+                    return PSA_JPAKE_X1_STEP_ZK_PUBLIC;
+                    break;
+                case PSA_PAKE_X1_STEP_ZK_PROOF:
+                    return PSA_JPAKE_X1_STEP_ZK_PROOF;
+                    break;
+                case PSA_PAKE_X2_STEP_KEY_SHARE:
+                    return PSA_JPAKE_X2_STEP_KEY_SHARE;
+                    break;
+                case PSA_PAKE_X2_STEP_ZK_PUBLIC:
+                    return PSA_JPAKE_X2_STEP_ZK_PUBLIC;
+                    break;
+                case PSA_PAKE_X2_STEP_ZK_PROOF:
+                    return PSA_JPAKE_X2_STEP_ZK_PROOF;
+                    break;
+                default:
+                    return PSA_JPAKE_STEP_INVALID;
+            }
+            break;
+        case PSA_PAKE_OUTPUT_X2S:
+            switch (stage->data.jpake_computation_stage.sequence) {
+                case PSA_PAKE_X1_STEP_KEY_SHARE:
+                    return PSA_JPAKE_X2S_STEP_KEY_SHARE;
+                    break;
+                case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+                    return PSA_JPAKE_X2S_STEP_ZK_PUBLIC;
+                    break;
+                case PSA_PAKE_X1_STEP_ZK_PROOF:
+                    return PSA_JPAKE_X2S_STEP_ZK_PROOF;
+                    break;
+                    return PSA_JPAKE_STEP_INVALID;
+            }
+            break;
+        case PSA_PAKE_INPUT_X4S:
+            switch (stage->data.jpake_computation_stage.sequence) {
+                case PSA_PAKE_X1_STEP_KEY_SHARE:
+                    return PSA_JPAKE_X4S_STEP_KEY_SHARE;
+                    break;
+                case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+                    return PSA_JPAKE_X4S_STEP_ZK_PUBLIC;
+                    break;
+                case PSA_PAKE_X1_STEP_ZK_PROOF:
+                    return PSA_JPAKE_X4S_STEP_ZK_PROOF;
+                    break;
+                    return PSA_JPAKE_STEP_INVALID;
+            }
+            break;
+        default:
+            return PSA_JPAKE_STEP_INVALID;
+    }
+    return PSA_JPAKE_STEP_INVALID;
+}
+
 static psa_status_t psa_pake_complete_inputs(
     psa_pake_operation_t *operation)
 {
@@ -7501,9 +7565,14 @@
             return PSA_ERROR_NOT_SUPPORTED;
     }
 
-    status = psa_driver_wrapper_pake_output(operation, step,
-                                            &operation->computation_stage,
-                                            output, output_size, output_length);
+    status = psa_driver_wrapper_pake_output(operation,
+                                            convert_jpake_computation_stage_to_driver_step(&
+                                                                                           operation
+                                                                                           ->
+                                                                                           computation_stage),
+                                            output,
+                                            output_size,
+                                            output_length);
 
     if (status != PSA_SUCCESS) {
         return status;
@@ -7660,9 +7729,12 @@
             return PSA_ERROR_NOT_SUPPORTED;
     }
 
-    status = psa_driver_wrapper_pake_input(operation, step,
-                                           &operation->computation_stage,
-                                           input, input_length);
+    status = psa_driver_wrapper_pake_input(operation,
+                                           convert_jpake_computation_stage_to_driver_step(&operation
+                                                                                          ->
+                                                                                          computation_stage),
+                                           input,
+                                           input_length);
 
     if (status != PSA_SUCCESS) {
         return status;
diff --git a/library/psa_crypto_driver_wrappers.h b/library/psa_crypto_driver_wrappers.h
index ac17be4..11a95e3 100644
--- a/library/psa_crypto_driver_wrappers.h
+++ b/library/psa_crypto_driver_wrappers.h
@@ -421,16 +421,14 @@
 
 psa_status_t psa_driver_wrapper_pake_output(
     psa_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length);
 
 psa_status_t psa_driver_wrapper_pake_input(
     psa_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     const uint8_t *input,
     size_t input_length);
 
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index a238147..da10cdd 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -266,8 +266,7 @@
 
 static psa_status_t mbedtls_psa_pake_output_internal(
     mbedtls_psa_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
@@ -292,12 +291,8 @@
      * to return the right parts on each step.
      */
     if (operation->alg == PSA_ALG_JPAKE) {
-        const psa_jpake_computation_stage_t *jpake_computation_stage =
-            &computation_stage->data.jpake_computation_stage;
-
         /* Initialize & write round on KEY_SHARE sequences */
-        if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
-            jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
+        if (step == PSA_JPAKE_X1_STEP_KEY_SHARE) {
             ret = mbedtls_ecjpake_write_round_one(&operation->ctx.pake,
                                                   operation->buffer,
                                                   MBEDTLS_PSA_PAKE_BUFFER_SIZE,
@@ -309,8 +304,7 @@
             }
 
             operation->buffer_offset = 0;
-        } else if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
-                   jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE) {
+        } else if (step == PSA_JPAKE_X2S_STEP_KEY_SHARE) {
             ret = mbedtls_ecjpake_write_round_two(&operation->ctx.pake,
                                                   operation->buffer,
                                                   MBEDTLS_PSA_PAKE_BUFFER_SIZE,
@@ -335,8 +329,7 @@
          * output with a length byte, even less a curve identifier, as that
          * information is already available.
          */
-        if (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
-            jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+        if (step == PSA_JPAKE_X2S_STEP_KEY_SHARE &&
             operation->role == PSA_PAKE_ROLE_SERVER) {
             /* Skip ECParameters, with is 3 bytes (RFC 8422) */
             operation->buffer_offset += 3;
@@ -362,10 +355,8 @@
         operation->buffer_offset += length;
 
         /* Reset buffer after ZK_PROOF sequence */
-        if ((jpake_computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
-             jpake_computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
-            (jpake_computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
-             jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
+        if ((step == PSA_JPAKE_X2_STEP_ZK_PROOF) ||
+            (step == PSA_JPAKE_X2S_STEP_ZK_PROOF)) {
             mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
             operation->buffer_length = 0;
             operation->buffer_offset = 0;
@@ -375,7 +366,6 @@
     } else
 #else
     (void) step;
-    (void) computation_stage;
     (void) output;
     (void) output_size;
     (void) output_length;
@@ -384,14 +374,13 @@
 }
 
 psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
-                                     psa_pake_step_t step,
-                                     const psa_pake_computation_stage_t *computation_stage,
+                                     psa_pake_driver_step_t step,
                                      uint8_t *output,
                                      size_t output_size,
                                      size_t *output_length)
 {
     psa_status_t status = mbedtls_psa_pake_output_internal(
-        operation, step, computation_stage, output, output_size, output_length);
+        operation, step, output, output_size, output_length);
 
     if (status != PSA_SUCCESS) {
         mbedtls_psa_pake_abort(operation);
@@ -402,8 +391,7 @@
 
 static psa_status_t mbedtls_psa_pake_input_internal(
     mbedtls_psa_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     const uint8_t *input,
     size_t input_length)
 {
@@ -427,8 +415,6 @@
      * This causes any input error to be only detected on the last step.
      */
     if (operation->alg == PSA_ALG_JPAKE) {
-        const psa_jpake_computation_stage_t *jpake_computation_stage =
-            &computation_stage->data.jpake_computation_stage;
         /*
          * Copy input to local buffer and format it as the Mbed TLS API
          * expects, i.e. as defined by draft-cragie-tls-ecjpake-01 section 7.
@@ -438,8 +424,7 @@
          * ECParameters structure - which means we have to prepend that when
          * we're a client.
          */
-        if (jpake_computation_stage->state == PSA_PAKE_INPUT_X4S &&
-            jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+        if (step == PSA_JPAKE_X4S_STEP_KEY_SHARE &&
             operation->role == PSA_PAKE_ROLE_CLIENT) {
             /* We only support secp256r1. */
             /* This is the ECParameters structure defined by RFC 8422. */
@@ -462,8 +447,7 @@
         operation->buffer_length += input_length;
 
         /* Load buffer at each last round ZK_PROOF */
-        if (jpake_computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
-            jpake_computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) {
+        if (step == PSA_JPAKE_X2_STEP_ZK_PROOF) {
             ret = mbedtls_ecjpake_read_round_one(&operation->ctx.pake,
                                                  operation->buffer,
                                                  operation->buffer_length);
@@ -474,8 +458,7 @@
             if (ret != 0) {
                 return mbedtls_ecjpake_to_psa_error(ret);
             }
-        } else if (jpake_computation_stage->state == PSA_PAKE_INPUT_X4S &&
-                   jpake_computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF) {
+        } else if (step == PSA_JPAKE_X4S_STEP_ZK_PROOF) {
             ret = mbedtls_ecjpake_read_round_two(&operation->ctx.pake,
                                                  operation->buffer,
                                                  operation->buffer_length);
@@ -492,7 +475,6 @@
     } else
 #else
     (void) step;
-    (void) computation_stage;
     (void) input;
     (void) input_length;
 #endif
@@ -500,13 +482,12 @@
 }
 
 psa_status_t mbedtls_psa_pake_input(mbedtls_psa_pake_operation_t *operation,
-                                    psa_pake_step_t step,
-                                    const psa_pake_computation_stage_t *computation_stage,
+                                    psa_pake_driver_step_t step,
                                     const uint8_t *input,
                                     size_t input_length)
 {
     psa_status_t status = mbedtls_psa_pake_input_internal(
-        operation, step, computation_stage, input, input_length);
+        operation, step, input, input_length);
 
     if (status != PSA_SUCCESS) {
         mbedtls_psa_pake_abort(operation);
diff --git a/library/psa_crypto_pake.h b/library/psa_crypto_pake.h
index 485c93a..dc6ad7b 100644
--- a/library/psa_crypto_pake.h
+++ b/library/psa_crypto_pake.h
@@ -58,7 +58,6 @@
  * \param[in,out] operation    Active PAKE operation.
  * \param step                 The step of the algorithm for which the output is
  *                             requested.
- * \param computation_stage    The structure that holds PAKE computation stage.
  * \param[out] output          Buffer where the output is to be written in the
  *                             format appropriate for this \p step. Refer to
  *                             the documentation of the individual
@@ -97,8 +96,7 @@
  *         results in this error code.
  */
 psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
-                                     psa_pake_step_t step,
-                                     const psa_pake_computation_stage_t *computation_stage,
+                                     psa_pake_driver_step_t step,
                                      uint8_t *output,
                                      size_t output_size,
                                      size_t *output_length);
@@ -112,7 +110,6 @@
  *
  * \param[in,out] operation    Active PAKE operation.
  * \param step                 The step for which the input is provided.
- * \param computation_stage    The structure that holds PAKE computation stage.
  * \param[in] input            Buffer containing the input in the format
  *                             appropriate for this \p step. Refer to the
  *                             documentation of the individual
@@ -146,8 +143,7 @@
  *         results in this error code.
  */
 psa_status_t mbedtls_psa_pake_input(mbedtls_psa_pake_operation_t *operation,
-                                    psa_pake_step_t step,
-                                    const psa_pake_computation_stage_t *computation_stage,
+                                    psa_pake_driver_step_t step,
                                     const uint8_t *input,
                                     size_t input_length);
 
diff --git a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
index e1a4c9c..d7dabed 100644
--- a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
+++ b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
@@ -2865,8 +2865,7 @@
 }
 psa_status_t psa_driver_wrapper_pake_output(
     psa_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length )
@@ -2876,8 +2875,7 @@
 #if defined(MBEDTLS_PSA_BUILTIN_PAKE)
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
             return( mbedtls_psa_pake_output( &operation->data.ctx.mbedtls_ctx, step,
-                                             computation_stage, output,
-                                             output_size, output_length ) );
+                                             output, output_size, output_length ) );
 #endif /* MBEDTLS_PSA_BUILTIN_PAKE */
 
 #if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
@@ -2885,16 +2883,15 @@
         case MBEDTLS_TEST_TRANSPARENT_DRIVER_ID:
             return( mbedtls_test_transparent_pake_output(
                         &operation->data.ctx.transparent_test_driver_ctx,
-                        step, computation_stage, output, output_size, output_length ) );
+                        step, output, output_size, output_length ) );
         case MBEDTLS_TEST_OPAQUE_DRIVER_ID:
             return( mbedtls_test_opaque_pake_output(
                         &operation->data.ctx.opaque_test_driver_ctx,
-                        step, computation_stage, output, output_size, output_length ) );
+                        step, output, output_size, output_length ) );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
         default:
             (void) step;
-            (void) computation_stage;
             (void) output;
             (void) output_size;
             (void) output_length;
@@ -2904,8 +2901,7 @@
 
 psa_status_t psa_driver_wrapper_pake_input(
     psa_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     const uint8_t *input,
     size_t input_length )
 {
@@ -2914,7 +2910,7 @@
 #if defined(MBEDTLS_PSA_BUILTIN_PAKE)
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
             return( mbedtls_psa_pake_input( &operation->data.ctx.mbedtls_ctx,
-                                            step, computation_stage, input,
+                                            step, input,
                                             input_length ) );
 #endif /* MBEDTLS_PSA_BUILTIN_PAKE */
 
@@ -2923,18 +2919,17 @@
         case MBEDTLS_TEST_TRANSPARENT_DRIVER_ID:
             return( mbedtls_test_transparent_pake_input(
                         &operation->data.ctx.transparent_test_driver_ctx,
-                        step, computation_stage,
+                        step,
                         input, input_length ) );
         case MBEDTLS_TEST_OPAQUE_DRIVER_ID:
             return( mbedtls_test_opaque_pake_input(
                         &operation->data.ctx.opaque_test_driver_ctx,
-                        step, computation_stage,
+                        step,
                         input, input_length ) );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
         default:
             (void) step;
-            (void) computation_stage;
             (void) input;
             (void) input_length;
             return( PSA_ERROR_INVALID_ARGUMENT );
diff --git a/tests/include/test/drivers/pake.h b/tests/include/test/drivers/pake.h
index 1f53008..23cb98a 100644
--- a/tests/include/test/drivers/pake.h
+++ b/tests/include/test/drivers/pake.h
@@ -57,16 +57,14 @@
 
 psa_status_t mbedtls_test_transparent_pake_output(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length);
 
 psa_status_t mbedtls_test_transparent_pake_input(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     const uint8_t *input,
     size_t input_length);
 
@@ -103,16 +101,14 @@
 
 psa_status_t mbedtls_test_opaque_pake_output(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length);
 
 psa_status_t mbedtls_test_opaque_pake_input(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     const uint8_t *input,
     size_t input_length);
 
diff --git a/tests/src/drivers/test_driver_pake.c b/tests/src/drivers/test_driver_pake.c
index 21719e6..e0be17d 100644
--- a/tests/src/drivers/test_driver_pake.c
+++ b/tests/src/drivers/test_driver_pake.c
@@ -64,8 +64,7 @@
 
 psa_status_t mbedtls_test_transparent_pake_output(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
@@ -93,20 +92,14 @@
         defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             libtestdriver1_mbedtls_psa_pake_output(
-                operation,
-                step,
-                (libtestdriver1_psa_pake_computation_stage_t *) computation_stage,
-                output,
-                output_size,
-                output_length);
+                operation, step, output, output_size, output_length);
 #elif defined(MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             mbedtls_psa_pake_output(
-                operation, step, computation_stage, output, output_size, output_length);
+                operation, step, output, output_size, output_length);
 #else
         (void) operation;
         (void) step;
-        (void) computation_stage;
         (void) output;
         (void) output_size;
         (void) output_length;
@@ -119,8 +112,7 @@
 
 psa_status_t mbedtls_test_transparent_pake_input(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     const uint8_t *input,
     size_t input_length)
 {
@@ -134,19 +126,14 @@
         defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             libtestdriver1_mbedtls_psa_pake_input(
-                operation,
-                step,
-                (libtestdriver1_psa_pake_computation_stage_t *) computation_stage,
-                input,
-                input_length);
+                operation, step, input, input_length);
 #elif defined(MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             mbedtls_psa_pake_input(
-                operation, step, computation_stage, input, input_length);
+                operation, step, input, input_length);
 #else
         (void) operation;
         (void) step;
-        (void) computation_stage;
         (void) input;
         (void) input_length;
         mbedtls_test_driver_pake_hooks.driver_status = PSA_ERROR_NOT_SUPPORTED;
@@ -270,15 +257,13 @@
 
 psa_status_t mbedtls_test_opaque_pake_output(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
 {
     (void) operation;
     (void) step;
-    (void) computation_stage;
     (void) output;
     (void) output_size;
     (void) output_length;
@@ -288,14 +273,12 @@
 
 psa_status_t mbedtls_test_opaque_pake_input(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_pake_step_t step,
-    const psa_pake_computation_stage_t *computation_stage,
+    psa_pake_driver_step_t step,
     const uint8_t *input,
     size_t input_length)
 {
     (void) operation;
     (void) step;
-    (void) computation_stage;
     (void) input;
     (void) input_length;
     return PSA_ERROR_NOT_SUPPORTED;