Test extra inputs and outputs at the end of J-PAKE

Add tests for supplying inputs or requesting outputs when a J-PAKE
computation has already completed

Signed-off-by: David Horstmann <david.horstmann@arm.com>
diff --git a/tests/suites/test_suite_psa_crypto_pake.data b/tests/suites/test_suite_psa_crypto_pake.data
index 89f1562..da54ad1 100644
--- a/tests/suites/test_suite_psa_crypto_pake.data
+++ b/tests/suites/test_suite_psa_crypto_pake.data
@@ -218,6 +218,14 @@
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
 ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:1:"abcdef":ERR_INJECT_EXTRA_INPUT:PSA_ERROR_BAD_STATE
 
+PSA PAKE: inject ERR_INJECT_EXTRA_OUTPUT_AT_END
+depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
+ecjpake_rounds_inject_second:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:1:"abcdef":ERR_INJECT_EXTRA_OUTPUT_AT_END:PSA_ERROR_BAD_STATE
+
+PSA PAKE: inject ERR_INJECT_EXTRA_INPUT_AT_END
+depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
+ecjpake_rounds_inject_second:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:0:"abcdef":ERR_INJECT_EXTRA_INPUT_AT_END:PSA_ERROR_BAD_STATE
+
 PSA PAKE: ecjpake size macros
 depends_on:MBEDTLS_PSA_WANT_KEY_TYPE_ECC_KEY_PAIR_LEGACY:PSA_WANT_ECC_SECP_R1_256
 ecjpake_size_macros:
diff --git a/tests/suites/test_suite_psa_crypto_pake.function b/tests/suites/test_suite_psa_crypto_pake.function
index 87c40f5..49ca361 100644
--- a/tests/suites/test_suite_psa_crypto_pake.function
+++ b/tests/suites/test_suite_psa_crypto_pake.function
@@ -42,6 +42,8 @@
     ERR_INJECT_ROUND2_SERVER_ZK_PROOF,
     ERR_INJECT_EXTRA_OUTPUT,
     ERR_INJECT_EXTRA_INPUT,
+    ERR_INJECT_EXTRA_OUTPUT_AT_END,
+    ERR_INJECT_EXTRA_INPUT_AT_END,
     /* erros issued from the .data file */
     ERR_IN_SETUP,
     ERR_IN_SET_USER,
@@ -466,6 +468,16 @@
                 buffer1 + buffer1_off);
             DO_ROUND_UPDATE_OFFSETS(buffer1_off, c_x2s_pr_off, c_x2s_pr_len);
 
+            if (client_input_first == 1) {
+                size_t extra_output_at_end_len;
+                DO_ROUND_CONDITIONAL_CHECK_FAILURE(
+                    ERR_INJECT_EXTRA_OUTPUT_AT_END,
+                    psa_pake_output(client, PSA_PAKE_STEP_KEY_SHARE,
+                                    buffer1 + c_a_off, 512 - c_a_off,
+                                    &extra_output_at_end_len));
+                (void) extra_output_at_end_len;
+            }
+
             if (client_input_first == 0) {
                 /* Client second round Input */
                 status = psa_pake_input(client, PSA_PAKE_STEP_KEY_SHARE,
@@ -503,6 +515,12 @@
                                     buffer1 + c_x2s_pr_off, c_x2s_pr_len);
             DO_ROUND_CHECK_FAILURE();
 
+            DO_ROUND_CONDITIONAL_CHECK_FAILURE(
+                ERR_INJECT_EXTRA_INPUT_AT_END,
+                psa_pake_input(server, PSA_PAKE_STEP_KEY_SHARE,
+                               buffer1 + c_a_off, c_a_len));
+
+
             /* Error didn't trigger, make test fail */
             if ((err_stage >= ERR_INJECT_ROUND2_CLIENT_KEY_SHARE) &&
                 (err_stage <= ERR_INJECT_ROUND2_CLIENT_ZK_PROOF)) {
@@ -811,6 +829,63 @@
 /* END_CASE */
 
 /* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
+/* Inject errors during the second round of J-PAKE */
+void ecjpake_rounds_inject_second(int alg_arg, int primitive_arg, int hash_arg,
+                                  int client_input_first,
+                                  data_t *pw_data,
+                                  int err_stage_arg,
+                                  int expected_error_arg)
+{
+    psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
+    psa_pake_operation_t server = psa_pake_operation_init();
+    psa_pake_operation_t client = psa_pake_operation_init();
+    psa_algorithm_t alg = alg_arg;
+    psa_algorithm_t hash_alg = hash_arg;
+    mbedtls_svc_key_id_t key = MBEDTLS_SVC_KEY_ID_INIT;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    ecjpake_error_stage_t err_stage = err_stage_arg;
+
+    PSA_INIT();
+
+    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DERIVE);
+    psa_set_key_algorithm(&attributes, alg);
+    psa_set_key_type(&attributes, PSA_KEY_TYPE_PASSWORD);
+
+    PSA_ASSERT(psa_import_key(&attributes, pw_data->x, pw_data->len,
+                              &key));
+
+    psa_pake_cs_set_algorithm(&cipher_suite, alg);
+    psa_pake_cs_set_primitive(&cipher_suite, primitive_arg);
+    psa_pake_cs_set_hash(&cipher_suite, hash_alg);
+
+    PSA_ASSERT(psa_pake_setup(&server, &cipher_suite));
+    PSA_ASSERT(psa_pake_setup(&client, &cipher_suite));
+
+    PSA_ASSERT(psa_pake_set_user(&server, jpake_server_id, sizeof(jpake_server_id)));
+    PSA_ASSERT(psa_pake_set_peer(&server, jpake_client_id, sizeof(jpake_client_id)));
+    PSA_ASSERT(psa_pake_set_user(&client, jpake_client_id, sizeof(jpake_client_id)));
+    PSA_ASSERT(psa_pake_set_peer(&client, jpake_server_id, sizeof(jpake_server_id)));
+
+    PSA_ASSERT(psa_pake_set_password_key(&server, key));
+    PSA_ASSERT(psa_pake_set_password_key(&client, key));
+
+    ecjpake_do_round(alg, primitive_arg, &server, &client,
+                     client_input_first, PAKE_ROUND_ONE,
+                     ERR_NONE, expected_error_arg);
+
+    ecjpake_do_round(alg, primitive_arg, &server, &client,
+                     client_input_first, PAKE_ROUND_TWO,
+                     err_stage, expected_error_arg);
+
+exit:
+    psa_destroy_key(key);
+    psa_pake_abort(&server);
+    psa_pake_abort(&client);
+    PSA_DONE();
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:PSA_WANT_ALG_JPAKE */
 void ecjpake_rounds(int alg_arg, int primitive_arg, int hash_arg,
                     int derive_alg_arg, data_t *pw_data,
                     int client_input_first, int destroy_key,