Change J-PAKE internal state machine
Keep track of the J-PAKE internal state in a more intuitive way.
Specifically, replace the current state with a struct of 5 fields:
* The round of J-PAKE we are currently in, FIRST or SECOND
* The 'mode' we are currently working in, INPUT or OUTPUT
* The number of inputs so far this round
* The number of outputs so far this round
* The PAKE step we are expecting, KEY_SHARE, ZK_PUBLIC or ZK_PROOF
This should improve the readability of the state-transformation code.
Signed-off-by: David Horstmann <david.horstmann@arm.com>
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 5529dd1..a3351a6 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -2028,14 +2028,33 @@
PSA_JPAKE_X4S_STEP_ZK_PROOF = 12 /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */
} psa_crypto_driver_pake_step_t;
+typedef enum psa_jpake_round {
+ FIRST = 0,
+ SECOND = 1,
+ FINISHED = 2
+} psa_jpake_round_t;
+
+typedef enum psa_jpake_io_mode {
+ INPUT = 0,
+ OUTPUT = 1
+} psa_jpake_io_mode_t;
struct psa_jpake_computation_stage_s {
- psa_jpake_state_t MBEDTLS_PRIVATE(state);
- psa_jpake_sequence_t MBEDTLS_PRIVATE(sequence);
- psa_jpake_step_t MBEDTLS_PRIVATE(input_step);
- psa_jpake_step_t MBEDTLS_PRIVATE(output_step);
+ /* The J-PAKE round we are currently on */
+ psa_jpake_round_t MBEDTLS_PRIVATE(round);
+ /* The 'mode' we are currently in (inputting or outputting) */
+ psa_jpake_io_mode_t MBEDTLS_PRIVATE(mode);
+ /* The number of inputs so far this round */
+ uint8_t MBEDTLS_PRIVATE(inputs);
+ /* The number of outputs so far this round */
+ uint8_t MBEDTLS_PRIVATE(outputs);
+ /* The next expected step (KEY_SHARE, ZK_PUBLIC or ZK_PROOF) */
+ psa_pake_step_t MBEDTLS_PRIVATE(step);
};
+#define PSA_JPAKE_EXPECTED_INPUTS(round) (((round) == FIRST) ? 2 : 1)
+#define PSA_JPAKE_EXPECTED_OUTPUTS(round) (((round) == FIRST) ? 2 : 1)
+
struct psa_pake_operation_s {
/** Unique ID indicating which driver got assigned to do the
* operation. Since driver contexts are driver-specific, swapping
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 2173483..f86ea3e 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7767,10 +7767,11 @@
psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake;
- computation_stage->state = PSA_PAKE_STATE_SETUP;
- computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
- computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
- computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+ computation_stage->round = FIRST;
+ computation_stage->mode = INPUT;
+ computation_stage->inputs = 0;
+ computation_stage->outputs = 0;
+ computation_stage->step = PSA_PAKE_STEP_KEY_SHARE;
} else
#endif /* PSA_WANT_ALG_JPAKE */
{
@@ -7939,57 +7940,66 @@
return status;
}
-/* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */
+/* Auxiliary function to convert core computation stage to single driver step. */
#if defined(PSA_WANT_ALG_JPAKE)
static psa_crypto_driver_pake_step_t convert_jpake_computation_stage_to_driver_step(
psa_jpake_computation_stage_t *stage)
{
- switch (stage->state) {
- case PSA_PAKE_OUTPUT_X1_X2:
- case PSA_PAKE_INPUT_X1_X2:
- switch (stage->sequence) {
- case PSA_PAKE_X1_STEP_KEY_SHARE:
+ if (stage->round == FIRST) {
+ int is_x1;
+ if (stage->mode == OUTPUT) {
+ is_x1 = (stage->outputs < 1);
+ } else {
+ is_x1 = (stage->inputs < 1);
+ }
+
+ if (is_x1) {
+ switch (stage->step) {
+ case PSA_PAKE_STEP_KEY_SHARE:
return PSA_JPAKE_X1_STEP_KEY_SHARE;
- case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+ case PSA_PAKE_STEP_ZK_PUBLIC:
return PSA_JPAKE_X1_STEP_ZK_PUBLIC;
- case PSA_PAKE_X1_STEP_ZK_PROOF:
+ case PSA_PAKE_STEP_ZK_PROOF:
return PSA_JPAKE_X1_STEP_ZK_PROOF;
- case PSA_PAKE_X2_STEP_KEY_SHARE:
+ default:
+ return PSA_JPAKE_STEP_INVALID;
+ }
+ } else {
+ switch (stage->step) {
+ case PSA_PAKE_STEP_KEY_SHARE:
return PSA_JPAKE_X2_STEP_KEY_SHARE;
- case PSA_PAKE_X2_STEP_ZK_PUBLIC:
+ case PSA_PAKE_STEP_ZK_PUBLIC:
return PSA_JPAKE_X2_STEP_ZK_PUBLIC;
- case PSA_PAKE_X2_STEP_ZK_PROOF:
+ case PSA_PAKE_STEP_ZK_PROOF:
return PSA_JPAKE_X2_STEP_ZK_PROOF;
default:
return PSA_JPAKE_STEP_INVALID;
}
- break;
- case PSA_PAKE_OUTPUT_X2S:
- switch (stage->sequence) {
- case PSA_PAKE_X1_STEP_KEY_SHARE:
+ }
+ } else if (stage->round == SECOND) {
+ if (stage->mode == OUTPUT) {
+ switch (stage->step) {
+ case PSA_PAKE_STEP_KEY_SHARE:
return PSA_JPAKE_X2S_STEP_KEY_SHARE;
- case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+ case PSA_PAKE_STEP_ZK_PUBLIC:
return PSA_JPAKE_X2S_STEP_ZK_PUBLIC;
- case PSA_PAKE_X1_STEP_ZK_PROOF:
+ case PSA_PAKE_STEP_ZK_PROOF:
return PSA_JPAKE_X2S_STEP_ZK_PROOF;
default:
return PSA_JPAKE_STEP_INVALID;
}
- break;
- case PSA_PAKE_INPUT_X4S:
- switch (stage->sequence) {
- case PSA_PAKE_X1_STEP_KEY_SHARE:
+ } else {
+ switch (stage->step) {
+ case PSA_PAKE_STEP_KEY_SHARE:
return PSA_JPAKE_X4S_STEP_KEY_SHARE;
- case PSA_PAKE_X1_STEP_ZK_PUBLIC:
+ case PSA_PAKE_STEP_ZK_PUBLIC:
return PSA_JPAKE_X4S_STEP_ZK_PUBLIC;
- case PSA_PAKE_X1_STEP_ZK_PROOF:
+ case PSA_PAKE_STEP_ZK_PROOF:
return PSA_JPAKE_X4S_STEP_ZK_PROOF;
default:
return PSA_JPAKE_STEP_INVALID;
}
- break;
- default:
- return PSA_JPAKE_STEP_INVALID;
+ }
}
return PSA_JPAKE_STEP_INVALID;
}
@@ -8032,10 +8042,11 @@
operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake;
- computation_stage->state = PSA_PAKE_STATE_READY;
- computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
- computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
- computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
+ computation_stage->round = FIRST;
+ computation_stage->mode = INPUT;
+ computation_stage->inputs = 0;
+ computation_stage->outputs = 0;
+ computation_stage->step = PSA_PAKE_STEP_KEY_SHARE;
} else
#endif /* PSA_WANT_ALG_JPAKE */
{
@@ -8046,9 +8057,10 @@
}
#if defined(PSA_WANT_ALG_JPAKE)
-static psa_status_t psa_jpake_output_prologue(
+static psa_status_t psa_jpake_prologue(
psa_pake_operation_t *operation,
- psa_pake_step_t step)
+ psa_pake_step_t step,
+ psa_jpake_io_mode_t function_mode)
{
if (step != PSA_PAKE_STEP_KEY_SHARE &&
step != PSA_PAKE_STEP_ZK_PUBLIC &&
@@ -8059,84 +8071,79 @@
psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake;
- if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
+ if (computation_stage->round != FIRST &&
+ computation_stage->round != SECOND) {
return PSA_ERROR_BAD_STATE;
}
- if (computation_stage->state != PSA_PAKE_STATE_READY &&
- computation_stage->state != PSA_PAKE_OUTPUT_X1_X2 &&
- computation_stage->state != PSA_PAKE_OUTPUT_X2S) {
+ /* Check that the step we are given is the one we were expecting */
+ if (step != computation_stage->step) {
return PSA_ERROR_BAD_STATE;
}
- if (computation_stage->state == PSA_PAKE_STATE_READY) {
- if (step != PSA_PAKE_STEP_KEY_SHARE) {
- return PSA_ERROR_BAD_STATE;
- }
-
- switch (computation_stage->output_step) {
- case PSA_PAKE_STEP_X1_X2:
- computation_stage->state = PSA_PAKE_OUTPUT_X1_X2;
- break;
- case PSA_PAKE_STEP_X2S:
- computation_stage->state = PSA_PAKE_OUTPUT_X2S;
- break;
- default:
- return PSA_ERROR_BAD_STATE;
- }
-
- computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
+ if (step == PSA_PAKE_STEP_KEY_SHARE &&
+ computation_stage->inputs == 0 &&
+ computation_stage->outputs == 0) {
+ /* Start of the round, so function decides whether we are inputting
+ * or outputting */
+ computation_stage->mode = function_mode;
+ } else if (computation_stage->mode != function_mode) {
+ /* Middle of the round so the mode we are in must match the function
+ * called by the user */
+ return PSA_ERROR_BAD_STATE;
}
- /* Check if step matches current sequence */
- switch (computation_stage->sequence) {
- case PSA_PAKE_X1_STEP_KEY_SHARE:
- case PSA_PAKE_X2_STEP_KEY_SHARE:
- if (step != PSA_PAKE_STEP_KEY_SHARE) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- case PSA_PAKE_X1_STEP_ZK_PUBLIC:
- case PSA_PAKE_X2_STEP_ZK_PUBLIC:
- if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- case PSA_PAKE_X1_STEP_ZK_PROOF:
- case PSA_PAKE_X2_STEP_ZK_PROOF:
- if (step != PSA_PAKE_STEP_ZK_PROOF) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- default:
+ /* Check that we do not already have enough inputs/outputs
+ * this round */
+ if (function_mode == INPUT) {
+ if (computation_stage->inputs >=
+ PSA_JPAKE_EXPECTED_INPUTS(computation_stage->round)) {
return PSA_ERROR_BAD_STATE;
+ }
+ } else {
+ if (computation_stage->outputs >=
+ PSA_JPAKE_EXPECTED_OUTPUTS(computation_stage->round)) {
+ return PSA_ERROR_BAD_STATE;
+ }
}
-
return PSA_SUCCESS;
}
-static psa_status_t psa_jpake_output_epilogue(
- psa_pake_operation_t *operation)
+static psa_status_t psa_jpake_epilogue(
+ psa_pake_operation_t *operation,
+ psa_jpake_io_mode_t function_mode)
{
- psa_jpake_computation_stage_t *computation_stage =
+ psa_jpake_computation_stage_t *stage =
&operation->computation_stage.jpake;
- if ((computation_stage->state == PSA_PAKE_OUTPUT_X1_X2 &&
- computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
- (computation_stage->state == PSA_PAKE_OUTPUT_X2S &&
- computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
- computation_stage->state = PSA_PAKE_STATE_READY;
- computation_stage->output_step++;
- computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
+ if (stage->step == PSA_PAKE_STEP_ZK_PROOF) {
+ /* End of an input/output */
+ if (function_mode == INPUT) {
+ stage->inputs++;
+ if (stage->inputs >= PSA_JPAKE_EXPECTED_INPUTS(stage->round)) {
+ stage->mode = OUTPUT;
+ }
+ }
+ if (function_mode == OUTPUT) {
+ stage->outputs++;
+ if (stage->outputs >= PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) {
+ stage->mode = INPUT;
+ }
+ }
+ if (stage->inputs >= PSA_JPAKE_EXPECTED_INPUTS(stage->round) &&
+ stage->outputs >= PSA_JPAKE_EXPECTED_OUTPUTS(stage->round)) {
+ /* End of a round, move to the next round */
+ stage->inputs = 0;
+ stage->outputs = 0;
+ stage->round++;
+ }
+ stage->step = PSA_PAKE_STEP_KEY_SHARE;
} else {
- computation_stage->sequence++;
+ stage->step++;
}
-
return PSA_SUCCESS;
}
+
#endif /* PSA_WANT_ALG_JPAKE */
psa_status_t psa_pake_output(
@@ -8170,7 +8177,7 @@
switch (operation->alg) {
#if defined(PSA_WANT_ALG_JPAKE)
case PSA_ALG_JPAKE:
- status = psa_jpake_output_prologue(operation, step);
+ status = psa_jpake_prologue(operation, step, OUTPUT);
if (status != PSA_SUCCESS) {
goto exit;
}
@@ -8194,7 +8201,7 @@
switch (operation->alg) {
#if defined(PSA_WANT_ALG_JPAKE)
case PSA_ALG_JPAKE:
- status = psa_jpake_output_epilogue(operation);
+ status = psa_jpake_epilogue(operation, OUTPUT);
if (status != PSA_SUCCESS) {
goto exit;
}
@@ -8211,100 +8218,6 @@
return status;
}
-#if defined(PSA_WANT_ALG_JPAKE)
-static psa_status_t psa_jpake_input_prologue(
- psa_pake_operation_t *operation,
- psa_pake_step_t step)
-{
- if (step != PSA_PAKE_STEP_KEY_SHARE &&
- step != PSA_PAKE_STEP_ZK_PUBLIC &&
- step != PSA_PAKE_STEP_ZK_PROOF) {
- return PSA_ERROR_INVALID_ARGUMENT;
- }
-
- psa_jpake_computation_stage_t *computation_stage =
- &operation->computation_stage.jpake;
-
- if (computation_stage->state == PSA_PAKE_STATE_INVALID) {
- return PSA_ERROR_BAD_STATE;
- }
-
- if (computation_stage->state != PSA_PAKE_STATE_READY &&
- computation_stage->state != PSA_PAKE_INPUT_X1_X2 &&
- computation_stage->state != PSA_PAKE_INPUT_X4S) {
- return PSA_ERROR_BAD_STATE;
- }
-
- if (computation_stage->state == PSA_PAKE_STATE_READY) {
- if (step != PSA_PAKE_STEP_KEY_SHARE) {
- return PSA_ERROR_BAD_STATE;
- }
-
- switch (computation_stage->input_step) {
- case PSA_PAKE_STEP_X1_X2:
- computation_stage->state = PSA_PAKE_INPUT_X1_X2;
- break;
- case PSA_PAKE_STEP_X2S:
- computation_stage->state = PSA_PAKE_INPUT_X4S;
- break;
- default:
- return PSA_ERROR_BAD_STATE;
- }
-
- computation_stage->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
- }
-
- /* Check if step matches current sequence */
- switch (computation_stage->sequence) {
- case PSA_PAKE_X1_STEP_KEY_SHARE:
- case PSA_PAKE_X2_STEP_KEY_SHARE:
- if (step != PSA_PAKE_STEP_KEY_SHARE) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- case PSA_PAKE_X1_STEP_ZK_PUBLIC:
- case PSA_PAKE_X2_STEP_ZK_PUBLIC:
- if (step != PSA_PAKE_STEP_ZK_PUBLIC) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- case PSA_PAKE_X1_STEP_ZK_PROOF:
- case PSA_PAKE_X2_STEP_ZK_PROOF:
- if (step != PSA_PAKE_STEP_ZK_PROOF) {
- return PSA_ERROR_BAD_STATE;
- }
- break;
-
- default:
- return PSA_ERROR_BAD_STATE;
- }
-
- return PSA_SUCCESS;
-}
-
-static psa_status_t psa_jpake_input_epilogue(
- psa_pake_operation_t *operation)
-{
- psa_jpake_computation_stage_t *computation_stage =
- &operation->computation_stage.jpake;
-
- if ((computation_stage->state == PSA_PAKE_INPUT_X1_X2 &&
- computation_stage->sequence == PSA_PAKE_X2_STEP_ZK_PROOF) ||
- (computation_stage->state == PSA_PAKE_INPUT_X4S &&
- computation_stage->sequence == PSA_PAKE_X1_STEP_ZK_PROOF)) {
- computation_stage->state = PSA_PAKE_STATE_READY;
- computation_stage->input_step++;
- computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
- } else {
- computation_stage->sequence++;
- }
-
- return PSA_SUCCESS;
-}
-#endif /* PSA_WANT_ALG_JPAKE */
-
psa_status_t psa_pake_input(
psa_pake_operation_t *operation,
psa_pake_step_t step,
@@ -8337,7 +8250,7 @@
switch (operation->alg) {
#if defined(PSA_WANT_ALG_JPAKE)
case PSA_ALG_JPAKE:
- status = psa_jpake_input_prologue(operation, step);
+ status = psa_jpake_prologue(operation, step, INPUT);
if (status != PSA_SUCCESS) {
goto exit;
}
@@ -8361,7 +8274,7 @@
switch (operation->alg) {
#if defined(PSA_WANT_ALG_JPAKE)
case PSA_ALG_JPAKE:
- status = psa_jpake_input_epilogue(operation);
+ status = psa_jpake_epilogue(operation, INPUT);
if (status != PSA_SUCCESS) {
goto exit;
}
@@ -8396,8 +8309,7 @@
if (operation->alg == PSA_ALG_JPAKE) {
psa_jpake_computation_stage_t *computation_stage =
&operation->computation_stage.jpake;
- if (computation_stage->input_step != PSA_PAKE_STEP_DERIVE ||
- computation_stage->output_step != PSA_PAKE_STEP_DERIVE) {
+ if (computation_stage->round != FINISHED) {
status = PSA_ERROR_BAD_STATE;
goto exit;
}
diff --git a/tests/suites/test_suite_psa_crypto_driver_wrappers.function b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
index b971f81..87f7b37 100644
--- a/tests/suites/test_suite_psa_crypto_driver_wrappers.function
+++ b/tests/suites/test_suite_psa_crypto_driver_wrappers.function
@@ -3127,8 +3127,10 @@
PSA_SUCCESS);
/* Simulate that we are ready to get implicit key. */
- operation.computation_stage.jpake.input_step = PSA_PAKE_STEP_DERIVE;
- operation.computation_stage.jpake.output_step = PSA_PAKE_STEP_DERIVE;
+ operation.computation_stage.jpake.round = PSA_JPAKE_FINISHED;
+ operation.computation_stage.jpake.inputs = 0;
+ operation.computation_stage.jpake.outputs = 0;
+ operation.computation_stage.jpake.step = PSA_PAKE_STEP_KEY_SHARE;
/* --- psa_pake_get_implicit_key --- */
mbedtls_test_driver_pake_hooks.forced_status = forced_status;