Change J-PAKE internal state machine

Keep track of the J-PAKE internal state in a more intuitive way.
Specifically, replace the current state with a struct of 5 fields:

* The round of J-PAKE we are currently in, FIRST or SECOND
* The 'mode' we are currently working in, INPUT or OUTPUT
* The number of inputs so far this round
* The number of outputs so far this round
* The PAKE step we are expecting, KEY_SHARE, ZK_PUBLIC or ZK_PROOF

This should improve the readability of the state-transformation code.

Signed-off-by: David Horstmann <david.horstmann@arm.com>
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 5529dd1..a3351a6 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -2028,14 +2028,33 @@
     PSA_JPAKE_X4S_STEP_ZK_PROOF   = 12  /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */
 } psa_crypto_driver_pake_step_t;
 
+typedef enum psa_jpake_round {
+    FIRST = 0,
+    SECOND = 1,
+    FINISHED = 2
+} psa_jpake_round_t;
+
+typedef enum psa_jpake_io_mode {
+    INPUT = 0,
+    OUTPUT = 1
+} psa_jpake_io_mode_t;
 
 struct psa_jpake_computation_stage_s {
-    psa_jpake_state_t MBEDTLS_PRIVATE(state);
-    psa_jpake_sequence_t MBEDTLS_PRIVATE(sequence);
-    psa_jpake_step_t MBEDTLS_PRIVATE(input_step);
-    psa_jpake_step_t MBEDTLS_PRIVATE(output_step);
+    /* The J-PAKE round we are currently on */
+    psa_jpake_round_t MBEDTLS_PRIVATE(round);
+    /* The 'mode' we are currently in (inputting or outputting) */
+    psa_jpake_io_mode_t MBEDTLS_PRIVATE(mode);
+    /* The number of inputs so far this round */
+    uint8_t MBEDTLS_PRIVATE(inputs);
+    /* The number of outputs so far this round */
+    uint8_t MBEDTLS_PRIVATE(outputs);
+    /* The next expected step (KEY_SHARE, ZK_PUBLIC or ZK_PROOF) */
+    psa_pake_step_t MBEDTLS_PRIVATE(step);
 };
 
+#define PSA_JPAKE_EXPECTED_INPUTS(round) (((round) == FIRST) ? 2 : 1)
+#define PSA_JPAKE_EXPECTED_OUTPUTS(round) (((round) == FIRST) ? 2 : 1)
+
 struct psa_pake_operation_s {
     /** Unique ID indicating which driver got assigned to do the
      * operation. Since driver contexts are driver-specific, swapping
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 2173483..f86ea3e 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7767,10 +7767,11 @@
         psa_jpake_computation_stage_t *computation_stage =
             &operation->computation_stage.jpake;
 
-        computation_stage->state = PSA_PAKE_STATE_SETUP;
-        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
-        computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
-        computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+        computation_stage->round = FIRST;
+        computation_stage->mode = INPUT;
+        computation_stage->inputs = 0;
+        computation_stage->outputs = 0;
+        computation_stage->step = PSA_PAKE_STEP_KEY_SHARE;
     } else
 #endif /* PSA_WANT_ALG_JPAKE */
     {
@@ -7939,57 +7940,66 @@
     return status;
 }
 
-/* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */
+/* Auxiliary function to convert core computation stage to single driver step. */
 #if defined(PSA_WANT_ALG_JPAKE)
 static psa_crypto_driver_pake_step_t convert_jpake_computation_stage_to_driver_step(
     psa_jpake_computation_stage_t *stage)
 {
-    switch (stage->state) {
-        case PSA_PAKE_OUTPUT_X1_X2:
-        case PSA_PAKE_INPUT_X1_X2:
-            switch (stage->sequence) {
-                case PSA_PAKE_X1_STEP_KEY_SHARE:
+    if (stage->round == FIRST) {
+        int is_x1;
+        if (stage->mode == OUTPUT) {
+            is_x1 = (stage->outputs < 1);
+        } else {
+            is_x1 = (stage->inputs < 1);
+        }
+
+        if (is_x1) {
+            switch (stage->step) {
+                case PSA_PAKE_STEP_KEY_SHARE:
                     return PSA_JPAKE_X1_STEP_KEY_SHARE;
-                case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+                case PSA_PAKE_STEP_ZK_PUBLIC:
                     return PSA_JPAKE_X1_STEP_ZK_PUBLIC;
-                case PSA_PAKE_X1_STEP_ZK_PROOF:
+                case PSA_PAKE_STEP_ZK_PROOF:
                     return PSA_JPAKE_X1_STEP_ZK_PROOF;
-                case PSA_PAKE_X2_STEP_KEY_SHARE:
+                default:
+                    return PSA_JPAKE_STEP_INVALID;
+            }
+        } else {
+            switch (stage->step) {
+                case PSA_PAKE_STEP_KEY_SHARE:
                     return PSA_JPAKE_X2_STEP_KEY_SHARE;
-                case PSA_PAKE_X2_STEP_ZK_PUBLIC:
+                case PSA_PAKE_STEP_ZK_PUBLIC:
                     return PSA_JPAKE_X2_STEP_ZK_PUBLIC;
-                case PSA_PAKE_X2_STEP_ZK_PROOF:
+                case PSA_PAKE_STEP_ZK_PROOF:
                     return PSA_JPAKE_X2_STEP_ZK_PROOF;
                 default:
                     return PSA_JPAKE_STEP_INVALID;
             }
-            break;
-        case PSA_PAKE_OUTPUT_X2S:
-            switch (stage->sequence) {
-                case PSA_PAKE_X1_STEP_KEY_SHARE:
+        }
+    } else if (stage->round == SECOND) {
+        if (stage->mode == OUTPUT) {
+            switch (stage->step) {
+                case PSA_PAKE_STEP_KEY_SHARE:
                     return PSA_JPAKE_X2S_STEP_KEY_SHARE;
-                case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+                case PSA_PAKE_STEP_ZK_PUBLIC:
                     return PSA_JPAKE_X2S_STEP_ZK_PUBLIC;
-                case PSA_PAKE_X1_STEP_ZK_PROOF:
+                case PSA_PAKE_STEP_ZK_PROOF:
                     return PSA_JPAKE_X2S_STEP_ZK_PROOF;
                 default:
                     return PSA_JPAKE_STEP_INVALID;
             }
-            break;
-        case PSA_PAKE_INPUT_X4S:
-            switch (stage->sequence) {
-                case PSA_PAKE_X1_STEP_KEY_SHARE:
+        } else {
+            switch (stage->step) {
+                case PSA_PAKE_STEP_KEY_SHARE:
                     return PSA_JPAKE_X4S_STEP_KEY_SHARE;
-                case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+                case PSA_PAKE_STEP_ZK_PUBLIC:
                     return PSA_JPAKE_X4S_STEP_ZK_PUBLIC;
-                case PSA_PAKE_X1_STEP_ZK_PROOF:
+                case PSA_PAKE_STEP_ZK_PROOF:
                     return PSA_JPAKE_X4S_STEP_ZK_PROOF;
                 default:
                     return PSA_JPAKE_STEP_INVALID;
             }
-            break;
-        default:
-            return PSA_JPAKE_STEP_INVALID;
+        }
     }
     return PSA_JPAKE_STEP_INVALID;
 }
@@ -8032,10 +8042,11 @@
             operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
             psa_jpake_computation_stage_t *computation_stage =
                 &operation->computation_stage.jpake;
-            computation_stage->state = PSA_PAKE_STATE_READY;
-            computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
-            computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
-            computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+            computation_stage->round = FIRST;
+            computation_stage->mode = INPUT;
+            computation_stage->inputs = 0;
+            computation_stage->outputs = 0;
+            computation_stage->step = PSA_PAKE_STEP_KEY_SHARE;
         } else
 #endif /* PSA_WANT_ALG_JPAKE */
         {
@@ -8046,9 +8057,10 @@
 }
 
 #if defined(PSA_WANT_ALG_JPAKE)
-static psa_status_t psa_jpake_output_prologue(
+static psa_status_t psa_jpake_prologue(
     psa_pake_operation_t *operation,
-    psa_pake_step_t step)
+    psa_pake_step_t step,
+    psa_jpake_io_mode_t function_mode)
 {
     if (step != PSA_PAKE_STEP_KEY_SHARE &&
         step != PSA_PAKE_STEP_ZK_PUBLIC &&
@@ -8059,84 +8071,79 @@
     psa_jpake_computation_stage_t *computation_stage =
         &operation->computation_stage.jpake;
 
-    if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
+    if (computation_stage->round != FIRST &&
+        computation_stage->round != SECOND) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (computation_stage->state != PSA_PAKE_STATE_READY &&
-        computation_stage->state != PSA_PAKE_OUTPUT_X1_X2 &&
-        computation_stage->state != PSA_PAKE_OUTPUT_X2S) {
+    /* Check that the step we are given is the one we were expecting */
+    if (step != computation_stage->step) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (computation_stage->state == PSA_PAKE_STATE_READY) {
-        if (step != PSA_PAKE_STEP_KEY_SHARE) {
-            return PSA_ERROR_BAD_STATE;
-        }
-
-        switch (computation_stage->output_step) {
-            case PSA_PAKE_STEP_X1_X2:
-                computation_stage->state = PSA_PAKE_OUTPUT_X1_X2;
-                break;
-            case PSA_PAKE_STEP_X2S:
-                computation_stage->state = PSA_PAKE_OUTPUT_X2S;
-                break;
-            default:
-                return PSA_ERROR_BAD_STATE;
-        }
-
-        computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
+    if (step == PSA_PAKE_STEP_KEY_SHARE &&
+        computation_stage->inputs == 0 &&
+        computation_stage->outputs == 0) {
+        /* Start of the round, so function decides whether we are inputting
+         * or outputting */
+        computation_stage->mode = function_mode;
+    } else if (computation_stage->mode != function_mode) {
+        /* Middle of the round so the mode we are in must match the function
+         * called by the user */
+        return PSA_ERROR_BAD_STATE;
     }
 
-    /* Check if step matches current sequence */
-    switch (computation_stage->sequence) {
-        case PSA_PAKE_X1_STEP_KEY_SHARE:
-        case PSA_PAKE_X2_STEP_KEY_SHARE:
-            if (step != PSA_PAKE_STEP_KEY_SHARE) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        case PSA_PAKE_X1_STEP_ZK_PUBLIC:
-        case PSA_PAKE_X2_STEP_ZK_PUBLIC:
-            if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        case PSA_PAKE_X1_STEP_ZK_PROOF:
-        case PSA_PAKE_X2_STEP_ZK_PROOF:
-            if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        default:
+    /* Check that we do not already have enough inputs/outputs
+     * this round */
+    if (function_mode == INPUT) {
+        if (computation_stage->inputs >=
+            PSA_JPAKE_EXPECTED_INPUTS(computation_stage->round)) {
             return PSA_ERROR_BAD_STATE;
+        }
+    } else {
+        if (computation_stage->outputs >=
+            PSA_JPAKE_EXPECTED_OUTPUTS(computation_stage->round)) {
+            return PSA_ERROR_BAD_STATE;
+        }
     }
-
     return PSA_SUCCESS;
 }
 
-static psa_status_t psa_jpake_output_epilogue(
-    psa_pake_operation_t *operation)
+static psa_status_t psa_jpake_epilogue(
+    psa_pake_operation_t *operation,
+    psa_jpake_io_mode_t function_mode)
 {
-    psa_jpake_computation_stage_t *computation_stage =
+    psa_jpake_computation_stage_t *stage =
         &operation->computation_stage.jpake;
 
-    if ((computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
-         computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
-        (computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
-         computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
-        computation_stage->state = PSA_PAKE_STATE_READY;
-        computation_stage->output_step++;
-        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+    if (stage->step == PSA_PAKE_STEP_ZK_PROOF) {
+        /* End of an input/output */
+        if (function_mode == INPUT) {
+            stage->inputs++;
+            if (stage->inputs >= PSA_JPAKE_EXPECTED_INPUTS(stage->round)) {
+                stage->mode = OUTPUT;
+            }
+        }
+        if (function_mode == OUTPUT) {
+            stage->outputs++;
+            if (stage->outputs >= PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) {
+                stage->mode = INPUT;
+            }
+        }
+        if (stage->inputs >= PSA_JPAKE_EXPECTED_INPUTS(stage->round) &&
+            stage->outputs >= PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) {
+            /* End of a round, move to the next round */
+            stage->inputs = 0;
+            stage->outputs = 0;
+            stage->round++;
+        }
+        stage->step = PSA_PAKE_STEP_KEY_SHARE;
     } else {
-        computation_stage->sequence++;
+        stage->step++;
     }
-
     return PSA_SUCCESS;
 }
+
 #endif /* PSA_WANT_ALG_JPAKE */
 
 psa_status_t psa_pake_output(
@@ -8170,7 +8177,7 @@
     switch (operation->alg) {
 #if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
-            status = psa_jpake_output_prologue(operation, step);
+            status = psa_jpake_prologue(operation, step, OUTPUT);
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
@@ -8194,7 +8201,7 @@
     switch (operation->alg) {
 #if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
-            status = psa_jpake_output_epilogue(operation);
+            status = psa_jpake_epilogue(operation, OUTPUT);
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
@@ -8211,100 +8218,6 @@
     return status;
 }
 
-#if defined(PSA_WANT_ALG_JPAKE)
-static psa_status_t psa_jpake_input_prologue(
-    psa_pake_operation_t *operation,
-    psa_pake_step_t step)
-{
-    if (step != PSA_PAKE_STEP_KEY_SHARE &&
-        step != PSA_PAKE_STEP_ZK_PUBLIC &&
-        step != PSA_PAKE_STEP_ZK_PROOF) {
-        return PSA_ERROR_INVALID_ARGUMENT;
-    }
-
-    psa_jpake_computation_stage_t *computation_stage =
-        &operation->computation_stage.jpake;
-
-    if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
-        return PSA_ERROR_BAD_STATE;
-    }
-
-    if (computation_stage->state != PSA_PAKE_STATE_READY &&
-        computation_stage->state != PSA_PAKE_INPUT_X1_X2 &&
-        computation_stage->state != PSA_PAKE_INPUT_X4S) {
-        return PSA_ERROR_BAD_STATE;
-    }
-
-    if (computation_stage->state == PSA_PAKE_STATE_READY) {
-        if (step != PSA_PAKE_STEP_KEY_SHARE) {
-            return PSA_ERROR_BAD_STATE;
-        }
-
-        switch (computation_stage->input_step) {
-            case PSA_PAKE_STEP_X1_X2:
-                computation_stage->state = PSA_PAKE_INPUT_X1_X2;
-                break;
-            case PSA_PAKE_STEP_X2S:
-                computation_stage->state = PSA_PAKE_INPUT_X4S;
-                break;
-            default:
-                return PSA_ERROR_BAD_STATE;
-        }
-
-        computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
-    }
-
-    /* Check if step matches current sequence */
-    switch (computation_stage->sequence) {
-        case PSA_PAKE_X1_STEP_KEY_SHARE:
-        case PSA_PAKE_X2_STEP_KEY_SHARE:
-            if (step != PSA_PAKE_STEP_KEY_SHARE) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        case PSA_PAKE_X1_STEP_ZK_PUBLIC:
-        case PSA_PAKE_X2_STEP_ZK_PUBLIC:
-            if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        case PSA_PAKE_X1_STEP_ZK_PROOF:
-        case PSA_PAKE_X2_STEP_ZK_PROOF:
-            if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                return PSA_ERROR_BAD_STATE;
-            }
-            break;
-
-        default:
-            return PSA_ERROR_BAD_STATE;
-    }
-
-    return PSA_SUCCESS;
-}
-
-static psa_status_t psa_jpake_input_epilogue(
-    psa_pake_operation_t *operation)
-{
-    psa_jpake_computation_stage_t *computation_stage =
-        &operation->computation_stage.jpake;
-
-    if ((computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
-         computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
-        (computation_stage->state == PSA_PAKE_INPUT_X4S &&
-         computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
-        computation_stage->state = PSA_PAKE_STATE_READY;
-        computation_stage->input_step++;
-        computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
-    } else {
-        computation_stage->sequence++;
-    }
-
-    return PSA_SUCCESS;
-}
-#endif /* PSA_WANT_ALG_JPAKE */
-
 psa_status_t psa_pake_input(
     psa_pake_operation_t *operation,
     psa_pake_step_t step,
@@ -8337,7 +8250,7 @@
     switch (operation->alg) {
 #if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
-            status = psa_jpake_input_prologue(operation, step);
+            status = psa_jpake_prologue(operation, step, INPUT);
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
@@ -8361,7 +8274,7 @@
     switch (operation->alg) {
 #if defined(PSA_WANT_ALG_JPAKE)
         case PSA_ALG_JPAKE:
-            status = psa_jpake_input_epilogue(operation);
+            status = psa_jpake_epilogue(operation, INPUT);
             if (status != PSA_SUCCESS) {
                 goto exit;
             }
@@ -8396,8 +8309,7 @@
     if (operation->alg == PSA_ALG_JPAKE) {
         psa_jpake_computation_stage_t *computation_stage =
             &operation->computation_stage.jpake;
-        if (computation_stage->input_step != PSA_PAKE_STEP_DERIVE ||
-            computation_stage->output_step != PSA_PAKE_STEP_DERIVE) {
+        if (computation_stage->round != FINISHED) {
             status = PSA_ERROR_BAD_STATE;
             goto exit;
         }
diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
index b971f81..87f7b37 100644
--- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function
+++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
@@ -3127,8 +3127,10 @@
                        PSA_SUCCESS);
 
             /* Simulate that we are ready to get implicit key. */
-            operation.computation_stage.jpake.input_step = PSA_PAKE_STEP_DERIVE;
-            operation.computation_stage.jpake.output_step = PSA_PAKE_STEP_DERIVE;
+            operation.computation_stage.jpake.round = PSA_JPAKE_FINISHED;
+            operation.computation_stage.jpake.inputs = 0;
+            operation.computation_stage.jpake.outputs = 0;
+            operation.computation_stage.jpake.step = PSA_PAKE_STEP_KEY_SHARE;
 
             /* --- psa_pake_get_implicit_key --- */
             mbedtls_test_driver_pake_hooks.forced_status = forced_status;