Fix psa_pake_get_implicit_key() state & add corresponding tests in ecjpake_rounds()

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index f7fb384..8ceacd9 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -660,8 +660,8 @@
 
     if( operation->alg == 0 ||
         operation->state != PSA_PAKE_STATE_READY ||
-        ( operation->input_step != PSA_PAKE_STEP_DERIVE  &&
-          operation->output_step != PSA_PAKE_STEP_DERIVE ) )
+        operation->input_step != PSA_PAKE_STEP_DERIVE ||
+        operation->output_step != PSA_PAKE_STEP_DERIVE )
         return( PSA_ERROR_BAD_STATE );
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index 727784f..6d4f2a8 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -8316,6 +8316,21 @@
     psa_pake_cs_set_primitive( &cipher_suite, primitive_arg );
     psa_pake_cs_set_hash( &cipher_suite, hash_alg );
 
+    /* Get shared key */
+    PSA_ASSERT( psa_key_derivation_setup( &server_derive, derive_alg ) );
+    PSA_ASSERT( psa_key_derivation_setup( &client_derive, derive_alg ) );
+
+    if( PSA_ALG_IS_TLS12_PRF( derive_alg ) ||
+        PSA_ALG_IS_TLS12_PSK_TO_MS( derive_alg ) )
+    {
+        PSA_ASSERT( psa_key_derivation_input_bytes( &server_derive,
+                                                PSA_KEY_DERIVATION_INPUT_SEED,
+                                                (const uint8_t*) "", 0) );
+        PSA_ASSERT( psa_key_derivation_input_bytes( &client_derive,
+                                                PSA_KEY_DERIVATION_INPUT_SEED,
+                                                (const uint8_t*) "", 0) );
+    }
+
     PSA_ASSERT( psa_pake_setup( &server, &cipher_suite ) );
     PSA_ASSERT( psa_pake_setup( &client, &cipher_suite ) );
 
@@ -8325,6 +8340,11 @@
     PSA_ASSERT( psa_pake_set_password_key( &server, key ) );
     PSA_ASSERT( psa_pake_set_password_key( &client, key ) );
 
+    TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
+                PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
+                PSA_ERROR_BAD_STATE );
+
     /* Server first round Output */
     PSA_ASSERT( psa_pake_output( &server, PSA_PAKE_STEP_KEY_SHARE,
                                  buffer0 + buffer0_off,
@@ -8389,6 +8409,11 @@
     c_x2_pr_off = buffer1_off;
     buffer1_off += c_x2_pr_len;
 
+    TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
+                PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
+                PSA_ERROR_BAD_STATE );
+
     /* Client first round Input */
     PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_KEY_SHARE,
                                 buffer0 + s_g1_off, s_g1_len ) );
@@ -8417,6 +8442,11 @@
     PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PROOF,
                                 buffer1 + c_x2_pr_off, c_x2_pr_len ) );
 
+    TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
+                PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
+                PSA_ERROR_BAD_STATE );
+
     /* Server second round Output */
     buffer0_off = 0;
 
@@ -8455,6 +8485,11 @@
     c_x2s_pr_off = buffer1_off;
     buffer1_off += c_x2s_pr_len;
 
+    TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
+                PSA_ERROR_BAD_STATE );
+    TEST_EQUAL( psa_pake_get_implicit_key( &client, &client_derive ),
+                PSA_ERROR_BAD_STATE );
+
     /* Client second round Input */
     PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_KEY_SHARE,
                                 buffer0 + s_a_off, s_a_len ) );
@@ -8463,6 +8498,9 @@
     PSA_ASSERT( psa_pake_input( &client, PSA_PAKE_STEP_ZK_PROOF,
                                 buffer0 + s_x2s_pr_off, s_x2s_pr_len ) );
 
+    TEST_EQUAL( psa_pake_get_implicit_key( &server, &server_derive ),
+                PSA_ERROR_BAD_STATE );
+
     /* Server second round Input */
     PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_KEY_SHARE,
                                 buffer1 + c_a_off, c_a_len ) );
@@ -8471,22 +8509,6 @@
     PSA_ASSERT( psa_pake_input( &server, PSA_PAKE_STEP_ZK_PROOF,
                                 buffer1 + c_x2s_pr_off, c_x2s_pr_len ) );
 
-
-    /* Get shared key */
-    PSA_ASSERT( psa_key_derivation_setup( &server_derive, derive_alg ) );
-    PSA_ASSERT( psa_key_derivation_setup( &client_derive, derive_alg ) );
-
-    if( PSA_ALG_IS_TLS12_PRF( derive_alg ) ||
-        PSA_ALG_IS_TLS12_PSK_TO_MS( derive_alg ) )
-    {
-        PSA_ASSERT( psa_key_derivation_input_bytes( &server_derive,
-                                                PSA_KEY_DERIVATION_INPUT_SEED,
-                                                (const uint8_t*) "", 0) );
-        PSA_ASSERT( psa_key_derivation_input_bytes( &client_derive,
-                                                PSA_KEY_DERIVATION_INPUT_SEED,
-                                                (const uint8_t*) "", 0) );
-    }
-
     PSA_ASSERT( psa_pake_get_implicit_key( &server, &server_derive ) );
     PSA_ASSERT( psa_pake_get_implicit_key( &client, &client_derive ) );