Optimization of pake core functions

Adapt pake test (passing NULL buffers is not allowed).
Passing the null buffer to psa_pake_output results in a hard fault.

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index b4fad33..4f3d774 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7241,8 +7241,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (cipher_suite == NULL ||
-        PSA_ALG_IS_PAKE(cipher_suite->algorithm) == 0 ||
+    if (PSA_ALG_IS_PAKE(cipher_suite->algorithm) == 0 ||
         PSA_ALG_IS_HASH(cipher_suite->hash) == 0) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
@@ -7436,17 +7435,12 @@
 static psa_status_t psa_pake_complete_inputs(
     psa_pake_operation_t *operation)
 {
-    psa_jpake_computation_stage_t *computation_stage =
-        &operation->computation_stage.jpake;
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    uint8_t *password = operation->data.inputs.password;
-    size_t password_len = operation->data.inputs.password_len;
     /* Create copy of the inputs on stack as inputs share memory
        with the driver context which will be setup by the driver. */
     psa_crypto_driver_pake_inputs_t inputs = operation->data.inputs;
 
-    if (operation->alg == PSA_ALG_NONE ||
-        operation->data.inputs.password_len == 0 ||
+    if (operation->data.inputs.password_len == 0 ||
         operation->data.inputs.role == PSA_PAKE_ROLE_NONE) {
         return PSA_ERROR_BAD_STATE;
     }
@@ -7457,12 +7451,14 @@
     status = psa_driver_wrapper_pake_setup(operation, &inputs);
 
     /* Driver is responsible for creating its own copy of the password. */
-    mbedtls_platform_zeroize(password, password_len);
-    mbedtls_free(password);
+    mbedtls_platform_zeroize(inputs.password, inputs.password_len);
+    mbedtls_free(inputs.password);
 
     if (status == PSA_SUCCESS) {
         operation->stage = PSA_PAKE_OPERATION_STAGE_COMPUTATION;
         if (operation->alg == PSA_ALG_JPAKE) {
+            psa_jpake_computation_stage_t *computation_stage =
+                &operation->computation_stage.jpake;
             computation_stage->state = PSA_PAKE_STATE_READY;
             computation_stage->sequence = PSA_PAKE_SEQ_INVALID;
             computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
@@ -7576,6 +7572,7 @@
     size_t *output_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    *output_length = 0;
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         status = psa_pake_complete_inputs(operation);
@@ -7588,11 +7585,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->id == 0) {
-        return PSA_ERROR_BAD_STATE;
-    }
-
-    if (output == NULL || output_size == 0) {
+    if (output_size == 0) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
@@ -7750,11 +7743,7 @@
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->id == 0) {
-        return PSA_ERROR_BAD_STATE;
-    }
-
-    if (input == NULL || input_length == 0) {
+    if (input_length == 0) {
         return PSA_ERROR_INVALID_ARGUMENT;
     }
 
@@ -7797,13 +7786,13 @@
     psa_pake_operation_t *operation,
     psa_key_derivation_operation_t *output)
 {
-    psa_status_t status = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     uint8_t shared_key[MBEDTLS_PSA_PAKE_BUFFER_SIZE];
     size_t shared_key_len = 0;
     psa_jpake_computation_stage_t *computation_stage =
         &operation->computation_stage.jpake;
 
-    if (operation->id == 0) {
+    if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) {
         return PSA_ERROR_BAD_STATE;
     }
 
diff --git a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
index d52ed59..cf08794 100644
--- a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
+++ b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
@@ -2816,7 +2816,7 @@
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
     psa_key_location_t location =
-            PSA_KEY_LIFETIME_GET_LOCATION( inputs->attributes.core.lifetime );
+            PSA_KEY_LIFETIME_GET_LOCATION( psa_get_key_lifetime( &inputs->attributes ) );
 
     switch( location )
     {
diff --git a/tests/suites/test_suite_psa_crypto_pake.function b/tests/suites/test_suite_psa_crypto_pake.function
index 5af41f7..d77dfdc 100644
--- a/tests/suites/test_suite_psa_crypto_pake.function
+++ b/tests/suites/test_suite_psa_crypto_pake.function
@@ -590,10 +590,10 @@
         TEST_EQUAL(psa_pake_set_role(&operation, role),
                    expected_error);
         TEST_EQUAL(psa_pake_output(&operation, PSA_PAKE_STEP_KEY_SHARE,
-                                   NULL, 0, NULL),
+                                   output_buffer, 0, &output_len),
                    expected_error);
         TEST_EQUAL(psa_pake_input(&operation, PSA_PAKE_STEP_KEY_SHARE,
-                                  NULL, 0),
+                                  output_buffer, 0),
                    expected_error);
         TEST_EQUAL(psa_pake_get_implicit_key(&operation, &key_derivation),
                    expected_error);
@@ -633,7 +633,8 @@
 
     if (test_input) {
         SETUP_CONDITIONAL_CHECK_STEP(psa_pake_input(&operation,
-                                                    PSA_PAKE_STEP_ZK_PROOF,  NULL, 0),
+                                                    PSA_PAKE_STEP_ZK_PROOF,
+                                                    output_buffer, 0),
                                      ERR_INJECT_EMPTY_IO_BUFFER);
 
         SETUP_CONDITIONAL_CHECK_STEP(psa_pake_input(&operation,
@@ -665,7 +666,8 @@
     } else {
         SETUP_CONDITIONAL_CHECK_STEP(psa_pake_output(&operation,
                                                      PSA_PAKE_STEP_ZK_PROOF,
-                                                     NULL, 0, NULL),
+                                                     output_buffer, 0,
+                                                     &output_len),
                                      ERR_INJECT_EMPTY_IO_BUFFER);
 
         SETUP_CONDITIONAL_CHECK_STEP(psa_pake_output(&operation,