Merge pull request #7719 from davidhorstmann-arm/second-jpake-state-machine-rework

Change J-PAKE internal state machine
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 5529dd1..94def5c 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -1984,34 +1984,6 @@
     psa_pake_cipher_suite_t MBEDTLS_PRIVATE(cipher_suite);
 };
 
-typedef enum psa_jpake_step {
-    PSA_PAKE_STEP_INVALID       = 0,
-    PSA_PAKE_STEP_X1_X2         = 1,
-    PSA_PAKE_STEP_X2S           = 2,
-    PSA_PAKE_STEP_DERIVE        = 3,
-} psa_jpake_step_t;
-
-typedef enum psa_jpake_state {
-    PSA_PAKE_STATE_INVALID      = 0,
-    PSA_PAKE_STATE_SETUP        = 1,
-    PSA_PAKE_STATE_READY        = 2,
-    PSA_PAKE_OUTPUT_X1_X2       = 3,
-    PSA_PAKE_OUTPUT_X2S         = 4,
-    PSA_PAKE_INPUT_X1_X2        = 5,
-    PSA_PAKE_INPUT_X4S          = 6,
-} psa_jpake_state_t;
-
-typedef enum psa_jpake_sequence {
-    PSA_PAKE_SEQ_INVALID        = 0,
-    PSA_PAKE_X1_STEP_KEY_SHARE  = 1,    /* also X2S & X4S KEY_SHARE */
-    PSA_PAKE_X1_STEP_ZK_PUBLIC  = 2,    /* also X2S & X4S ZK_PUBLIC */
-    PSA_PAKE_X1_STEP_ZK_PROOF   = 3,    /* also X2S & X4S ZK_PROOF */
-    PSA_PAKE_X2_STEP_KEY_SHARE  = 4,
-    PSA_PAKE_X2_STEP_ZK_PUBLIC  = 5,
-    PSA_PAKE_X2_STEP_ZK_PROOF   = 6,
-    PSA_PAKE_SEQ_END            = 7,
-} psa_jpake_sequence_t;
-
 typedef enum psa_crypto_driver_pake_step {
     PSA_JPAKE_STEP_INVALID        = 0,  /* Invalid step */
     PSA_JPAKE_X1_STEP_KEY_SHARE   = 1,  /* Round 1: input/output key share (for ephemeral private key X1).*/
@@ -2028,14 +2000,35 @@
     PSA_JPAKE_X4S_STEP_ZK_PROOF   = 12  /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */
 } psa_crypto_driver_pake_step_t;
 
+typedef enum psa_jpake_round {
+    PSA_JPAKE_FIRST = 0,
+    PSA_JPAKE_SECOND = 1,
+    PSA_JPAKE_FINISHED = 2
+} psa_jpake_round_t;
+
+typedef enum psa_jpake_io_mode {
+    PSA_JPAKE_INPUT = 0,
+    PSA_JPAKE_OUTPUT = 1
+} psa_jpake_io_mode_t;
 
 struct psa_jpake_computation_stage_s {
-    psa_jpake_state_t MBEDTLS_PRIVATE(state);
-    psa_jpake_sequence_t MBEDTLS_PRIVATE(sequence);
-    psa_jpake_step_t MBEDTLS_PRIVATE(input_step);
-    psa_jpake_step_t MBEDTLS_PRIVATE(output_step);
+    /* The J-PAKE round we are currently on */
+    psa_jpake_round_t MBEDTLS_PRIVATE(round);
+    /* The 'mode' we are currently in (inputting or outputting) */
+    psa_jpake_io_mode_t MBEDTLS_PRIVATE(io_mode);
+    /* The number of completed inputs so far this round */
+    uint8_t MBEDTLS_PRIVATE(inputs);
+    /* The number of completed outputs so far this round */
+    uint8_t MBEDTLS_PRIVATE(outputs);
+    /* The next expected step (KEY_SHARE, ZK_PUBLIC or ZK_PROOF) */
+    psa_pake_step_t MBEDTLS_PRIVATE(step);
 };
 
+#define PSA_JPAKE_EXPECTED_INPUTS(round) ((round) == PSA_JPAKE_FINISHED ? 0 : \
+                                          ((round) == PSA_JPAKE_FIRST ? 2 : 1))
+#define PSA_JPAKE_EXPECTED_OUTPUTS(round) ((round) == PSA_JPAKE_FINISHED ? 0 : \
+                                           ((round) == PSA_JPAKE_FIRST ? 2 : 1))
+
 struct psa_pake_operation_s {
     /** Unique ID indicating which driver got assigned to do the
      * operation. Since driver contexts are driver-specific, swapping
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 08a2b19..ac6bd5b 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7765,10 +7765,8 @@
         psa_jpake_computation_stage_t *computation_stage =
             &operation->computation_stage.jpake;
 
-        computation_stage->state = PSA_PAKE_STATE_SETUP;
-        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
-        computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
-        computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+        memset(computation_stage, 0, sizeof(*computation_stage));
+        computation_stage->step = PSA_PAKE_STEP_KEY_SHARE;
     } else
 #endif /* PSA_WANT_ALG_JPAKE */
     {
@@ -7937,59 +7935,32 @@
     return status;
 }
 
-/* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */
+/* Auxiliary function to convert the core computation stage to a single driver step. */
 #if defined(PSA_WANT_ALG_JPAKE)
 static psa_crypto_driver_pake_step_t convert_jpake_computation_stage_to_driver_step(
     psa_jpake_computation_stage_t *stage)
 {
-    switch (stage->state) {
-        case PSA_PAKE_OUTPUT_X1_X2:
-        case PSA_PAKE_INPUT_X1_X2:
-            switch (stage->sequence) {
-                case PSA_PAKE_X1_STEP_KEY_SHARE:
-                    return PSA_JPAKE_X1_STEP_KEY_SHARE;
-                case PSA_PAKE_X1_STEP_ZK_PUBLIC:
-                    return PSA_JPAKE_X1_STEP_ZK_PUBLIC;
-                case PSA_PAKE_X1_STEP_ZK_PROOF:
-                    return PSA_JPAKE_X1_STEP_ZK_PROOF;
-                case PSA_PAKE_X2_STEP_KEY_SHARE:
-                    return PSA_JPAKE_X2_STEP_KEY_SHARE;
-                case PSA_PAKE_X2_STEP_ZK_PUBLIC:
-                    return PSA_JPAKE_X2_STEP_ZK_PUBLIC;
-                case PSA_PAKE_X2_STEP_ZK_PROOF:
-                    return PSA_JPAKE_X2_STEP_ZK_PROOF;
-                default:
-                    return PSA_JPAKE_STEP_INVALID;
-            }
-            break;
-        case PSA_PAKE_OUTPUT_X2S:
-            switch (stage->sequence) {
-                case PSA_PAKE_X1_STEP_KEY_SHARE:
-                    return PSA_JPAKE_X2S_STEP_KEY_SHARE;
-                case PSA_PAKE_X1_STEP_ZK_PUBLIC:
-                    return PSA_JPAKE_X2S_STEP_ZK_PUBLIC;
-                case PSA_PAKE_X1_STEP_ZK_PROOF:
-                    return PSA_JPAKE_X2S_STEP_ZK_PROOF;
-                default:
-                    return PSA_JPAKE_STEP_INVALID;
-            }
-            break;
-        case PSA_PAKE_INPUT_X4S:
-            switch (stage->sequence) {
-                case PSA_PAKE_X1_STEP_KEY_SHARE:
-                    return PSA_JPAKE_X4S_STEP_KEY_SHARE;
-                case PSA_PAKE_X1_STEP_ZK_PUBLIC:
-                    return PSA_JPAKE_X4S_STEP_ZK_PUBLIC;
-                case PSA_PAKE_X1_STEP_ZK_PROOF:
-                    return PSA_JPAKE_X4S_STEP_ZK_PROOF;
-                default:
-                    return PSA_JPAKE_STEP_INVALID;
-            }
-            break;
-        default:
-            return PSA_JPAKE_STEP_INVALID;
+    psa_crypto_driver_pake_step_t key_share_step;
+    if (stage->round == PSA_JPAKE_FIRST) {
+        int is_x1;
+
+        if (stage->io_mode == PSA_JPAKE_OUTPUT) {
+            is_x1 = (stage->outputs < 1);
+        } else {
+            is_x1 = (stage->inputs < 1);
+        }
+
+        key_share_step = is_x1 ?
+                         PSA_JPAKE_X1_STEP_KEY_SHARE :
+                         PSA_JPAKE_X2_STEP_KEY_SHARE;
+    } else if (stage->round == PSA_JPAKE_SECOND) {
+        key_share_step = (stage->io_mode == PSA_JPAKE_OUTPUT) ?
+                         PSA_JPAKE_X2S_STEP_KEY_SHARE :
+                         PSA_JPAKE_X4S_STEP_KEY_SHARE;
+    } else {
+        return PSA_JPAKE_STEP_INVALID;
     }
-    return PSA_JPAKE_STEP_INVALID;
+    return key_share_step + stage->step - PSA_PAKE_STEP_KEY_SHARE;
 }
 #endif /* PSA_WANT_ALG_JPAKE */
 
@@ -8028,12 +7999,6 @@
 #if defined(PSA_WANT_ALG_JPAKE)
         if (operation->alg == PSA_ALG_JPAKE) {
             operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
-            psa_jpake_computation_stage_t *computation_stage =
-                &operation->computation_stage.jpake;
-            computation_stage->state = PSA_PAKE_STATE_READY;
-            computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
-            computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
-            computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
         } else
 #endif /* PSA_WANT_ALG_JPAKE */
         {
@@ -8044,9 +8009,10 @@
 }
 
 #if defined(PSA_WANT_ALG_JPAKE)
-static psa_status_t psa_jpake_output_prologue(
+static psa_status_t psa_jpake_prologue(
     psa_pake_operation_t *operation,
-    psa_pake_step_t step)
+    psa_pake_step_t step,
+    psa_jpake_io_mode_t io_mode)
 {
     if (step != PSA_PAKE_STEP_KEY_SHARE &&
         step != PSA_PAKE_STEP_ZK_PUBLIC &&
@@ -8057,84 +8023,66 @@
     psa_jpake_computation_stage_t *computation_stage =
         &operation->computation_stage.jpake;
 
-    if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
+    if (computation_stage->round != PSA_JPAKE_FIRST &&
+        computation_stage->round != PSA_JPAKE_SECOND) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (computation_stage->state != PSA_PAKE_STATE_READY &&
-        computation_stage->state != PSA_PAKE_OUTPUT_X1_X2 &&
-        computation_stage->state != PSA_PAKE_OUTPUT_X2S) {
+    /* Check that the step we are given is the one we were expecting */
+    if (step != computation_stage->step) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (computation_stage->state == PSA_PAKE_STATE_READY) {
-        if (step != PSA_PAKE_STEP_KEY_SHARE) {
-            return PSA_ERROR_BAD_STATE;
-        }
-
-        switch (computation_stage->output_step) {
-            case PSA_PAKE_STEP_X1_X2:
-                computation_stage->state = PSA_PAKE_OUTPUT_X1_X2;
-                break;
-            case PSA_PAKE_STEP_X2S:
-                computation_stage->state = PSA_PAKE_OUTPUT_X2S;
-                break;
-            default:
-                return PSA_ERROR_BAD_STATE;
-        }
-
-        computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
-    }
-
-    /* Check if step matches current sequence */
-    switch (computation_stage->sequence) {
-        case PSA_PAKE_X1_STEP_KEY_SHARE:
-        case PSA_PAKE_X2_STEP_KEY_SHARE:
-            if (step != PSA_PAKE_STEP_KEY_SHARE) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        case PSA_PAKE_X1_STEP_ZK_PUBLIC:
-        case PSA_PAKE_X2_STEP_ZK_PUBLIC:
-            if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        case PSA_PAKE_X1_STEP_ZK_PROOF:
-        case PSA_PAKE_X2_STEP_ZK_PROOF:
-            if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        default:
-            return PSA_ERROR_BAD_STATE;
+    if (step == PSA_PAKE_STEP_KEY_SHARE &&
+        computation_stage->inputs == 0 &&
+        computation_stage->outputs == 0) {
+        /* Start of the round, so the function being called decides whether
+         * we are inputting or outputting */
+        computation_stage->io_mode = io_mode;
+    } else if (computation_stage->io_mode != io_mode) {
+        /* Middle of the round so the mode we are in must match the function
+         * called by the user */
+        return PSA_ERROR_BAD_STATE;
     }
 
     return PSA_SUCCESS;
 }
 
-static psa_status_t psa_jpake_output_epilogue(
-    psa_pake_operation_t *operation)
+static psa_status_t psa_jpake_epilogue(
+    psa_pake_operation_t *operation,
+    psa_jpake_io_mode_t io_mode)
 {
-    psa_jpake_computation_stage_t *computation_stage =
+    psa_jpake_computation_stage_t *stage =
         &operation->computation_stage.jpake;
 
-    if ((computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
-         computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
-        (computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
-         computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
-        computation_stage->state = PSA_PAKE_STATE_READY;
-        computation_stage->output_step++;
-        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+    if (stage->step == PSA_PAKE_STEP_ZK_PROOF) {
+        /* End of an input/output */
+        if (io_mode == PSA_JPAKE_INPUT) {
+            stage->inputs++;
+            if (stage->inputs == PSA_JPAKE_EXPECTED_INPUTS(stage->round)) {
+                stage->io_mode = PSA_JPAKE_OUTPUT;
+            }
+        }
+        if (io_mode == PSA_JPAKE_OUTPUT) {
+            stage->outputs++;
+            if (stage->outputs == PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) {
+                stage->io_mode = PSA_JPAKE_INPUT;
+            }
+        }
+        if (stage->inputs == PSA_JPAKE_EXPECTED_INPUTS(stage->round) &&
+            stage->outputs == PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) {
+            /* End of a round, move to the next round */
+            stage->inputs = 0;
+            stage->outputs = 0;
+            stage->round++;
+        }
+        stage->step = PSA_PAKE_STEP_KEY_SHARE;
     } else {
-        computation_stage->sequence++;
+        stage->step++;
     }
-
     return PSA_SUCCESS;
 }
+
 #endif /* PSA_WANT_ALG_JPAKE */
 
 psa_status_t psa_pake_output(
@@ -8168,7 +8116,7 @@
     switch (operation->alg) {
 #if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
-            status = psa_jpake_output_prologue(operation, step);
+            status = psa_jpake_prologue(operation, step, PSA_JPAKE_OUTPUT);
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
@@ -8192,7 +8140,7 @@
     switch (operation->alg) {
 #if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
-            status = psa_jpake_output_epilogue(operation);
+            status = psa_jpake_epilogue(operation, PSA_JPAKE_OUTPUT);
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
@@ -8209,100 +8157,6 @@
     return status;
 }
 
-#if defined(PSA_WANT_ALG_JPAKE)
-static psa_status_t psa_jpake_input_prologue(
-    psa_pake_operation_t *operation,
-    psa_pake_step_t step)
-{
-    if (step != PSA_PAKE_STEP_KEY_SHARE &&
-        step != PSA_PAKE_STEP_ZK_PUBLIC &&
-        step != PSA_PAKE_STEP_ZK_PROOF) {
-        return PSA_ERROR_INVALID_ARGUMENT;
-    }
-
-    psa_jpake_computation_stage_t *computation_stage =
-        &operation->computation_stage.jpake;
-
-    if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
-        return PSA_ERROR_BAD_STATE;
-    }
-
-    if (computation_stage->state != PSA_PAKE_STATE_READY &&
-        computation_stage->state != PSA_PAKE_INPUT_X1_X2 &&
-        computation_stage->state != PSA_PAKE_INPUT_X4S) {
-        return PSA_ERROR_BAD_STATE;
-    }
-
-    if (computation_stage->state == PSA_PAKE_STATE_READY) {
-        if (step != PSA_PAKE_STEP_KEY_SHARE) {
-            return PSA_ERROR_BAD_STATE;
-        }
-
-        switch (computation_stage->input_step) {
-            case PSA_PAKE_STEP_X1_X2:
-                computation_stage->state = PSA_PAKE_INPUT_X1_X2;
-                break;
-            case PSA_PAKE_STEP_X2S:
-                computation_stage->state = PSA_PAKE_INPUT_X4S;
-                break;
-            default:
-                return PSA_ERROR_BAD_STATE;
-        }
-
-        computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
-    }
-
-    /* Check if step matches current sequence */
-    switch (computation_stage->sequence) {
-        case PSA_PAKE_X1_STEP_KEY_SHARE:
-        case PSA_PAKE_X2_STEP_KEY_SHARE:
-            if (step != PSA_PAKE_STEP_KEY_SHARE) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        case PSA_PAKE_X1_STEP_ZK_PUBLIC:
-        case PSA_PAKE_X2_STEP_ZK_PUBLIC:
-            if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        case PSA_PAKE_X1_STEP_ZK_PROOF:
-        case PSA_PAKE_X2_STEP_ZK_PROOF:
-            if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        default:
-            return PSA_ERROR_BAD_STATE;
-    }
-
-    return PSA_SUCCESS;
-}
-
-static psa_status_t psa_jpake_input_epilogue(
-    psa_pake_operation_t *operation)
-{
-    psa_jpake_computation_stage_t *computation_stage =
-        &operation->computation_stage.jpake;
-
-    if ((computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
-         computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
-        (computation_stage->state == PSA_PAKE_INPUT_X4S &&
-         computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
-        computation_stage->state = PSA_PAKE_STATE_READY;
-        computation_stage->input_step++;
-        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
-    } else {
-        computation_stage->sequence++;
-    }
-
-    return PSA_SUCCESS;
-}
-#endif /* PSA_WANT_ALG_JPAKE */
-
 psa_status_t psa_pake_input(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
@@ -8335,7 +8189,7 @@
     switch (operation->alg) {
 #if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
-            status = psa_jpake_input_prologue(operation, step);
+            status = psa_jpake_prologue(operation, step, PSA_JPAKE_INPUT);
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
@@ -8359,7 +8213,7 @@
     switch (operation->alg) {
 #if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
-            status = psa_jpake_input_epilogue(operation);
+            status = psa_jpake_epilogue(operation, PSA_JPAKE_INPUT);
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
@@ -8394,8 +8248,7 @@
     if (operation->alg == PSA_ALG_JPAKE) {
         psa_jpake_computation_stage_t *computation_stage =
             &operation->computation_stage.jpake;
-        if (computation_stage->input_step != PSA_PAKE_STEP_DERIVE ||
-            computation_stage->output_step != PSA_PAKE_STEP_DERIVE) {
+        if (computation_stage->round != PSA_JPAKE_FINISHED) {
             status = PSA_ERROR_BAD_STATE;
             goto exit;
         }
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index 4136614..e22bcf8 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -80,65 +80,37 @@
  */
 
 /*
- * The first PAKE step shares the same sequences of the second PAKE step
- * but with a second set of KEY_SHARE/ZK_PUBLIC/ZK_PROOF outputs/inputs.
- * It's simpler to share the same sequences numbers of the first
- * set of KEY_SHARE/ZK_PUBLIC/ZK_PROOF outputs/inputs in both PAKE steps.
+ * Possible sequence of calls to the implementation:
  *
- * State sequence with step, state & sequence enums:
- *   => Input & Output Step = PSA_PAKE_STEP_INVALID
- *   => state = PSA_PAKE_STATE_INVALID
- *   psa_pake_setup()
- *   => Input & Output Step = PSA_PAKE_STEP_X1_X2
- *   => state = PSA_PAKE_STATE_SETUP
- *   => sequence = PSA_PAKE_SEQ_INVALID
- *   |
- *   |--- In any order: (First round input before or after first round output)
- *   |   | First call of psa_pake_output() or psa_pake_input() sets
- *   |   | state = PSA_PAKE_STATE_READY
- *   |   |
- *   |   |------ In Order: => state = PSA_PAKE_OUTPUT_X1_X2
- *   |   |       | psa_pake_output() => sequence = PSA_PAKE_X1_STEP_KEY_SHARE
- *   |   |       | psa_pake_output() => sequence = PSA_PAKE_X1_STEP_ZK_PUBLIC
- *   |   |       | psa_pake_output() => sequence = PSA_PAKE_X1_STEP_ZK_PROOF
- *   |   |       | psa_pake_output() => sequence = PSA_PAKE_X2_STEP_KEY_SHARE
- *   |   |       | psa_pake_output() => sequence = PSA_PAKE_X2_STEP_ZK_PUBLIC
- *   |   |       | psa_pake_output() => sequence = PSA_PAKE_X2_STEP_ZK_PROOF
- *   |   |       | => state = PSA_PAKE_STATE_READY
- *   |   |       | => sequence = PSA_PAKE_SEQ_INVALID
- *   |   |       | => Output Step = PSA_PAKE_STEP_X2S
- *   |   |
- *   |   |------ In Order: => state = PSA_PAKE_INPUT_X1_X2
- *   |   |       | psa_pake_input() => sequence = PSA_PAKE_X1_STEP_KEY_SHARE
- *   |   |       | psa_pake_input() => sequence = PSA_PAKE_X1_STEP_ZK_PUBLIC
- *   |   |       | psa_pake_input() => sequence = PSA_PAKE_X1_STEP_ZK_PROOF
- *   |   |       | psa_pake_input() => sequence = PSA_PAKE_X2_STEP_KEY_SHARE
- *   |   |       | psa_pake_input() => sequence = PSA_PAKE_X2_STEP_ZK_PUBLIC
- *   |   |       | psa_pake_input() => sequence = PSA_PAKE_X2_STEP_ZK_PROOF
- *   |   |       | => state = PSA_PAKE_STATE_READY
- *   |   |       | => sequence = PSA_PAKE_SEQ_INVALID
- *   |   |       | => Output Step = PSA_PAKE_INPUT_X4S
- *   |
- *   |--- In any order: (Second round input before or after second round output)
- *   |   |
- *   |   |------ In Order: => state = PSA_PAKE_OUTPUT_X2S
- *   |   |       | psa_pake_output() => sequence = PSA_PAKE_X1_STEP_KEY_SHARE
- *   |   |       | psa_pake_output() => sequence = PSA_PAKE_X1_STEP_ZK_PUBLIC
- *   |   |       | psa_pake_output() => sequence = PSA_PAKE_X1_STEP_ZK_PROOF
- *   |   |       | => state = PSA_PAKE_STATE_READY
- *   |   |       | => sequence = PSA_PAKE_SEQ_INVALID
- *   |   |       | => Output Step = PSA_PAKE_STEP_DERIVE
- *   |   |
- *   |   |------ In Order: => state = PSA_PAKE_INPUT_X4S
- *   |   |       | psa_pake_input() => sequence = PSA_PAKE_X1_STEP_KEY_SHARE
- *   |   |       | psa_pake_input() => sequence = PSA_PAKE_X1_STEP_ZK_PUBLIC
- *   |   |       | psa_pake_input() => sequence = PSA_PAKE_X1_STEP_ZK_PROOF
- *   |   |       | => state = PSA_PAKE_STATE_READY
- *   |   |       | => sequence = PSA_PAKE_SEQ_INVALID
- *   |   |       | => Output Step = PSA_PAKE_STEP_DERIVE
- *   |
- *   psa_pake_get_implicit_key()
- *   => Input & Output Step = PSA_PAKE_STEP_INVALID
+ * |--- Round 1, in any order:
+ * |   |
+ * |   |------ In Order:
+ * |   |       | mbedtls_psa_pake_output(PSA_JPAKE_X1_STEP_KEY_SHARE)
+ * |   |       | mbedtls_psa_pake_output(PSA_JPAKE_X1_STEP_ZK_PUBLIC)
+ * |   |       | mbedtls_psa_pake_output(PSA_JPAKE_X1_STEP_ZK_PROOF)
+ * |   |       | mbedtls_psa_pake_output(PSA_JPAKE_X2_STEP_KEY_SHARE)
+ * |   |       | mbedtls_psa_pake_output(PSA_JPAKE_X2_STEP_ZK_PUBLIC)
+ * |   |       | mbedtls_psa_pake_output(PSA_JPAKE_X2_STEP_ZK_PROOF)
+ * |   |
+ * |   |------ In Order:
+ * |           | mbedtls_psa_pake_input(PSA_JPAKE_X1_STEP_KEY_SHARE)
+ * |           | mbedtls_psa_pake_input(PSA_JPAKE_X1_STEP_ZK_PUBLIC)
+ * |           | mbedtls_psa_pake_input(PSA_JPAKE_X1_STEP_ZK_PROOF)
+ * |           | mbedtls_psa_pake_input(PSA_JPAKE_X2_STEP_KEY_SHARE)
+ * |           | mbedtls_psa_pake_input(PSA_JPAKE_X2_STEP_ZK_PUBLIC)
+ * |           | mbedtls_psa_pake_input(PSA_JPAKE_X2_STEP_ZK_PROOF)
+ * |
+ * |--- Round 2, in any order:
+ * |   |
+ * |   |------ In Order:
+ * |   |       | mbedtls_psa_pake_output(PSA_JPAKE_X2S_STEP_KEY_SHARE)
+ * |   |       | mbedtls_psa_pake_output(PSA_JPAKE_X2S_STEP_ZK_PUBLIC)
+ * |   |       | mbedtls_psa_pake_output(PSA_JPAKE_X2S_STEP_ZK_PROOF)
+ * |   |
+ * |   |------ In Order:
+ * |           | mbedtls_psa_pake_input(PSA_JPAKE_X4S_STEP_KEY_SHARE)
+ * |           | mbedtls_psa_pake_input(PSA_JPAKE_X4S_STEP_ZK_PUBLIC)
+ * |           | mbedtls_psa_pake_input(PSA_JPAKE_X4S_STEP_ZK_PROOF)
  */
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
index b971f81..87f7b37 100644
--- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function
+++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
@@ -3127,8 +3127,10 @@
                        PSA_SUCCESS);
 
             /* Simulate that we are ready to get implicit key. */
-            operation.computation_stage.jpake.input_step = PSA_PAKE_STEP_DERIVE;
-            operation.computation_stage.jpake.output_step = PSA_PAKE_STEP_DERIVE;
+            operation.computation_stage.jpake.round = PSA_JPAKE_FINISHED;
+            operation.computation_stage.jpake.inputs = 0;
+            operation.computation_stage.jpake.outputs = 0;
+            operation.computation_stage.jpake.step = PSA_PAKE_STEP_KEY_SHARE;
 
             /* --- psa_pake_get_implicit_key --- */
             mbedtls_test_driver_pake_hooks.forced_status = forced_status;
diff --git a/tests/suites/test_suite_psa_crypto_pake.data b/tests/suites/test_suite_psa_crypto_pake.data
index 9e1cc63..ea39ea4 100644
--- a/tests/suites/test_suite_psa_crypto_pake.data
+++ b/tests/suites/test_suite_psa_crypto_pake.data
@@ -132,83 +132,99 @@
 
 PSA PAKE: no injected errors
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_NONE:PSA_SUCCESS
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_NONE:PSA_SUCCESS:0
 
 PSA PAKE: no injected errors, client input first
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:1:"abcdef":ERR_NONE:PSA_SUCCESS
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:1:"abcdef":ERR_NONE:PSA_SUCCESS:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART1
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART1:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART1:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART1
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART1:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART1:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART1
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART1:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART1:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART2
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART2:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_KEY_SHARE_PART2:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART2
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART2:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_ZK_PUBLIC_PART2:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART2
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART2:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_CLIENT_ZK_PROOF_PART2:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART1
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART1:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART1:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART1
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART1:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART1:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART1
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART1:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART1:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART2
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART2:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART2:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART2
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART2:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_ZK_PUBLIC_PART2:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART2
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART2:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART2:PSA_ERROR_DATA_INVALID:0
 
 PSA PAKE: inject ERR_INJECT_ROUND2_CLIENT_KEY_SHARE
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_CLIENT_KEY_SHARE:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_CLIENT_KEY_SHARE:PSA_ERROR_DATA_INVALID:1
 
 PSA PAKE: inject ERR_INJECT_ROUND2_CLIENT_ZK_PUBLIC
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_CLIENT_ZK_PUBLIC:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_CLIENT_ZK_PUBLIC:PSA_ERROR_DATA_INVALID:1
 
 PSA PAKE: inject ERR_INJECT_ROUND2_CLIENT_ZK_PROOF
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_CLIENT_ZK_PROOF:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_CLIENT_ZK_PROOF:PSA_ERROR_DATA_INVALID:1
 
 PSA PAKE: inject ERR_INJECT_ROUND2_SERVER_KEY_SHARE
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_SERVER_KEY_SHARE:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_SERVER_KEY_SHARE:PSA_ERROR_DATA_INVALID:1
 
 PSA PAKE: inject ERR_INJECT_ROUND2_SERVER_ZK_PUBLIC
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_SERVER_ZK_PUBLIC:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_SERVER_ZK_PUBLIC:PSA_ERROR_DATA_INVALID:1
 
 PSA PAKE: inject ERR_INJECT_ROUND2_SERVER_ZK_PROOF
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
-ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_SERVER_ZK_PROOF:PSA_ERROR_DATA_INVALID
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_ROUND2_SERVER_ZK_PROOF:PSA_ERROR_DATA_INVALID:1
+
+PSA PAKE: inject ERR_INJECT_EXTRA_OUTPUT
+depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_EXTRA_OUTPUT:PSA_ERROR_BAD_STATE:0
+
+PSA PAKE: inject ERR_INJECT_EXTRA_INPUT
+depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:1:"abcdef":ERR_INJECT_EXTRA_INPUT:PSA_ERROR_BAD_STATE:0
+
+PSA PAKE: inject ERR_INJECT_EXTRA_OUTPUT_AT_END
+depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:1:"abcdef":ERR_INJECT_EXTRA_OUTPUT_AT_END:PSA_ERROR_BAD_STATE:1
+
+PSA PAKE: inject ERR_INJECT_EXTRA_INPUT_AT_END
+depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
+ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_EXTRA_INPUT_AT_END:PSA_ERROR_BAD_STATE:1
 
 PSA PAKE: ecjpake size macros
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256
diff --git a/tests/suites/test_suite_psa_crypto_pake.function b/tests/suites/test_suite_psa_crypto_pake.function
index 52380de..f04d56f 100644
--- a/tests/suites/test_suite_psa_crypto_pake.function
+++ b/tests/suites/test_suite_psa_crypto_pake.function
@@ -2,6 +2,7 @@
 #include <stdint.h>
 
 #include "psa/crypto.h"
+#include "psa/crypto_extra.h"
 
 typedef enum {
     ERR_NONE = 0,
@@ -39,6 +40,10 @@
     ERR_INJECT_ROUND2_SERVER_KEY_SHARE,
     ERR_INJECT_ROUND2_SERVER_ZK_PUBLIC,
     ERR_INJECT_ROUND2_SERVER_ZK_PROOF,
+    ERR_INJECT_EXTRA_OUTPUT,
+    ERR_INJECT_EXTRA_INPUT,
+    ERR_INJECT_EXTRA_OUTPUT_AT_END,
+    ERR_INJECT_EXTRA_INPUT_AT_END,
     /* erros issued from the .data file */
     ERR_IN_SETUP,
     ERR_IN_SET_USER,
@@ -69,6 +74,23 @@
         *(buf + 7) ^= 1;                           \
     }
 
+/*
+ * Injection helper for error stages that need an extra (out-of-sequence)
+ * call: when `this_stage` equals err_stage, evaluate `function` and check
+ * that it returns expected_error_arg, then `break`.
+ *
+ * Uses err_stage and expected_error_arg from the calling scope. The `break`
+ * targets the innermost enclosing switch/loop (presumably the round switch
+ * in ecjpake_do_round - confirm), so the macro is deliberately not wrapped in
+ * do { } while (0). Do not use it as the body of an unbraced if/else.
+ */
+#define DO_ROUND_CONDITIONAL_CHECK_FAILURE(this_stage, function) \
+    if ((this_stage) == err_stage)                               \
+    {                                                            \
+        TEST_EQUAL(function, expected_error_arg);                \
+        break;                                                   \
+    }
+
 #define DO_ROUND_UPDATE_OFFSETS(main_buf_offset, step_offset, step_size) \
     {                                       \
         step_offset = main_buf_offset;      \
@@ -185,6 +197,12 @@
                 buffer0 + buffer0_off);
             DO_ROUND_UPDATE_OFFSETS(buffer0_off, s_x2_pr_off, s_x2_pr_len);
 
+            size_t extra_output_len;
+            DO_ROUND_CONDITIONAL_CHECK_FAILURE(
+                ERR_INJECT_EXTRA_OUTPUT,
+                psa_pake_output(server, PSA_PAKE_STEP_KEY_SHARE,
+                                buffer0 + s_g2_off, 512 - s_g2_off, &extra_output_len));
+            (void) extra_output_len;
             /*
              * When injecting errors in inputs, the implementation is
              * free to detect it right away of with a delay.
@@ -223,6 +241,12 @@
                                         s_x2_pr_len);
                 DO_ROUND_CHECK_FAILURE();
 
+                /* Note: Must have client_input_first == 1 to inject extra input */
+                DO_ROUND_CONDITIONAL_CHECK_FAILURE(
+                    ERR_INJECT_EXTRA_INPUT,
+                    psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
+                                   buffer0 + s_g2_off, s_g2_len));
+
                 /* Error didn't trigger, make test fail */
                 if ((err_stage >= ERR_INJECT_ROUND1_SERVER_KEY_SHARE_PART1) &&
                     (err_stage <= ERR_INJECT_ROUND1_SERVER_ZK_PROOF_PART2)) {
@@ -444,6 +468,16 @@
                 buffer1 + buffer1_off);
             DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_x2s_pr_off, c_x2s_pr_len);
 
+            if (client_input_first == 1) {
+                size_t extra_output_at_end_len;
+                DO_ROUND_CONDITIONAL_CHECK_FAILURE(
+                    ERR_INJECT_EXTRA_OUTPUT_AT_END,
+                    psa_pake_output(client, PSA_PAKE_STEP_KEY_SHARE,
+                                    buffer1 + c_a_off, 512 - c_a_off,
+                                    &extra_output_at_end_len));
+                (void) extra_output_at_end_len;
+            }
+
             if (client_input_first == 0) {
                 /* Client second round Input */
                 status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
@@ -481,6 +515,12 @@
                                     buffer1 + c_x2s_pr_off, c_x2s_pr_len);
             DO_ROUND_CHECK_FAILURE();
 
+            DO_ROUND_CONDITIONAL_CHECK_FAILURE(
+                ERR_INJECT_EXTRA_INPUT_AT_END,
+                psa_pake_input(server, PSA_PAKE_STEP_KEY_SHARE,
+                               buffer1 + c_a_off, c_a_len));
+
+
             /* Error didn't trigger, make test fail */
             if ((err_stage >= ERR_INJECT_ROUND2_CLIENT_KEY_SHARE) &&
                 (err_stage <= ERR_INJECT_ROUND2_CLIENT_ZK_PROOF)) {
@@ -733,7 +773,8 @@
                            int client_input_first,
                            data_t *pw_data,
                            int err_stage_arg,
-                           int expected_error_arg)
+                           int expected_error_arg,
+                           int inject_in_second_round)
 {
     psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
     psa_pake_operation_t server = psa_pake_operation_init();
@@ -770,9 +811,10 @@
 
     ecjpake_do_round(alg, primitive_arg, &server, &client,
                      client_input_first, PAKE_ROUND_ONE,
-                     err_stage, expected_error_arg);
+                     inject_in_second_round ? ERR_NONE : err_stage,
+                     expected_error_arg);
 
-    if (err_stage != ERR_NONE) {
+    if (!inject_in_second_round && err_stage != ERR_NONE) {
         goto exit;
     }