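Further PAKE code optimizations

Simplify the PAKE error paths and tidy the J-PAKE code:
- On failure, call psa_pake_abort() and return the original error instead
  of tracking a separate abort_status (setup, set_user, set_peer,
  set_role, set_password_key, input and output).
- psa_pake_input(): route an unsupported step through the exit path so the
  operation is aborted, as for the other error cases.
- Use sizeof() instead of MBEDTLS_PSA_JPAKE_BUFFER_SIZE when passing or
  zeroizing fixed-size buffers.
- Rename the EC J-PAKE context member ctx.pake to ctx.jpake.
- Write the J-PAKE shared key directly into the caller's output buffer
  instead of staging it in operation->buffer.
- Drop empty #else branches, merge adjacent
  MBEDTLS_PSA_BUILTIN_ALG_JPAKE blocks and annotate #endif lines.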

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 8752bff..1611fc9 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7238,7 +7238,6 @@
     const psa_pake_cipher_suite_t *cipher_suite)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
 
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_SETUP) {
         status = PSA_ERROR_BAD_STATE;
@@ -7266,8 +7265,7 @@
         computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
         computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
     } else
-#else
-#endif
+#endif /* PSA_WANT_ALG_JPAKE */
     {
         status = PSA_ERROR_NOT_SUPPORTED;
         goto exit;
@@ -7277,8 +7275,8 @@
 
     return PSA_SUCCESS;
 exit:
-    abort_status = psa_pake_abort(operation);
-    return status == PSA_SUCCESS ? abort_status : status;
+    psa_pake_abort(operation);
+    return status;
 }
 
 psa_status_t psa_pake_set_password_key(
@@ -7287,7 +7285,6 @@
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t unlock_status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_key_slot_t *slot = NULL;
 
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
@@ -7323,15 +7320,12 @@
     memcpy(operation->data.inputs.password, slot->key.data, slot->key.bytes);
     operation->data.inputs.password_len = slot->key.bytes;
     operation->data.inputs.attributes = attributes;
-
-    unlock_status = psa_unlock_key_slot(slot);
-
-    return unlock_status;
 exit:
+    if (status != PSA_SUCCESS) {
+        psa_pake_abort(operation);
+    }
     unlock_status = psa_unlock_key_slot(slot);
-    abort_status = psa_pake_abort(operation);
-    status = (status == PSA_SUCCESS) ? unlock_status : status;
-    return (status == PSA_SUCCESS) ? abort_status : status;
+    return (status == PSA_SUCCESS) ? unlock_status : status;
 }
 
 psa_status_t psa_pake_set_user(
@@ -7340,7 +7334,6 @@
     size_t user_id_len)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
     (void) user_id;
 
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
@@ -7355,8 +7348,8 @@
 
     return PSA_ERROR_NOT_SUPPORTED;
 exit:
-    abort_status = psa_pake_abort(operation);
-    return status == PSA_SUCCESS ? abort_status : status;
+    psa_pake_abort(operation);
+    return status;
 }
 
 psa_status_t psa_pake_set_peer(
@@ -7365,7 +7358,6 @@
     size_t peer_id_len)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
     (void) peer_id;
 
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
@@ -7380,8 +7372,8 @@
 
     return PSA_ERROR_NOT_SUPPORTED;
 exit:
-    abort_status = psa_pake_abort(operation);
-    return status == PSA_SUCCESS ? abort_status : status;
+    psa_pake_abort(operation);
+    return status;
 }
 
 psa_status_t psa_pake_set_role(
@@ -7389,7 +7381,6 @@
     psa_pake_role_t role)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
 
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         status =  PSA_ERROR_BAD_STATE;
@@ -7409,8 +7400,8 @@
 
     return PSA_SUCCESS;
 exit:
-    abort_status = psa_pake_abort(operation);
-    return status == PSA_SUCCESS ? abort_status : status;
+    psa_pake_abort(operation);
+    return status;
 }
 
 /* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */
@@ -7477,7 +7468,7 @@
     }
     return PSA_JPAKE_STEP_INVALID;
 }
-#endif
+#endif /* PSA_WANT_ALG_JPAKE */
 
 static psa_status_t psa_pake_complete_inputs(
     psa_pake_operation_t *operation)
@@ -7518,7 +7509,7 @@
             computation_stage->input_step = PSA_PAKE_STEP_X1_X2;
             computation_stage->output_step = PSA_PAKE_STEP_X1_X2;
         } else
-#endif
+#endif /* PSA_WANT_ALG_JPAKE */
         {
             status = PSA_ERROR_NOT_SUPPORTED;
         }
@@ -7598,9 +7589,7 @@
 
     return PSA_SUCCESS;
 }
-#endif
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
 static psa_status_t psa_jpake_output_epilogue(
     psa_pake_operation_t *operation)
 {
@@ -7620,7 +7609,7 @@
 
     return PSA_SUCCESS;
 }
-#endif
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_JPAKE */
 
 psa_status_t psa_pake_output(
     psa_pake_operation_t *operation,
@@ -7630,7 +7619,6 @@
     size_t *output_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
     *output_length = 0;
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
@@ -7658,7 +7646,7 @@
                 goto exit;
             }
             break;
-#endif
+#endif /* PSA_WANT_ALG_JPAKE */
         default:
             (void) step;
             status = PSA_ERROR_NOT_SUPPORTED;
@@ -7675,7 +7663,7 @@
 #else
     (void) output;
     status = PSA_ERROR_NOT_SUPPORTED;
-#endif
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_JPAKE */
 
     if (status != PSA_SUCCESS) {
         goto exit;
@@ -7689,7 +7677,7 @@
                 goto exit;
             }
             break;
-#endif
+#endif /* PSA_WANT_ALG_JPAKE */
         default:
             status = PSA_ERROR_NOT_SUPPORTED;
             goto exit;
@@ -7697,8 +7685,8 @@
 
     return PSA_SUCCESS;
 exit:
-    abort_status = psa_pake_abort(operation);
-    return status == PSA_SUCCESS ? abort_status : status;
+    psa_pake_abort(operation);
+    return status;
 }
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
@@ -7780,9 +7768,7 @@
 
     return PSA_SUCCESS;
 }
-#endif
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
 static psa_status_t psa_jpake_input_epilogue(
     psa_pake_operation_t *operation)
 {
@@ -7802,7 +7788,7 @@
 
     return PSA_SUCCESS;
 }
-#endif
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_JPAKE */
 
 psa_status_t psa_pake_input(
     psa_pake_operation_t *operation,
@@ -7811,7 +7797,6 @@
     size_t input_length)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         status = psa_pake_complete_inputs(operation);
@@ -7838,10 +7823,11 @@
                 goto exit;
             }
             break;
-#endif
+#endif /* PSA_WANT_ALG_JPAKE */
         default:
             (void) step;
-            return PSA_ERROR_NOT_SUPPORTED;
+            status = PSA_ERROR_NOT_SUPPORTED;
+            goto exit;
     }
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
@@ -7853,7 +7839,7 @@
 #else
     (void) input;
     status = PSA_ERROR_NOT_SUPPORTED;
-#endif
+#endif /* MBEDTLS_PSA_BUILTIN_ALG_JPAKE */
 
     if (status != PSA_SUCCESS) {
         goto exit;
@@ -7867,7 +7853,7 @@
                 goto exit;
             }
             break;
-#endif
+#endif /* PSA_WANT_ALG_JPAKE */
         default:
             status = PSA_ERROR_NOT_SUPPORTED;
             goto exit;
@@ -7875,8 +7861,8 @@
 
     return PSA_SUCCESS;
 exit:
-    abort_status = psa_pake_abort(operation);
-    return status == PSA_SUCCESS ? abort_status : status;
+    psa_pake_abort(operation);
+    return status;
 }
 
 psa_status_t psa_pake_get_implicit_key(
@@ -7903,9 +7889,7 @@
             goto exit;
         }
     } else
-#else
-
-#endif
+#endif /* PSA_WANT_ALG_JPAKE */
     {
         status = PSA_ERROR_NOT_SUPPORTED;
         goto exit;
@@ -7925,7 +7909,7 @@
                                             shared_key,
                                             shared_key_len);
 
-    mbedtls_platform_zeroize(shared_key, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
+    mbedtls_platform_zeroize(shared_key, sizeof(shared_key));
 exit:
     abort_status = psa_pake_abort(operation);
     return status == PSA_SUCCESS ? abort_status : status;
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index 63d0830..c6f9e89 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -171,9 +171,9 @@
     mbedtls_ecjpake_role role = (operation->role == PSA_PAKE_ROLE_CLIENT) ?
                                 MBEDTLS_ECJPAKE_CLIENT : MBEDTLS_ECJPAKE_SERVER;
 
-    mbedtls_ecjpake_init(&operation->ctx.pake);
+    mbedtls_ecjpake_init(&operation->ctx.jpake);
 
-    ret = mbedtls_ecjpake_setup(&operation->ctx.pake,
+    ret = mbedtls_ecjpake_setup(&operation->ctx.jpake,
                                 role,
                                 MBEDTLS_MD_SHA256,
                                 MBEDTLS_ECP_DP_SECP256R1,
@@ -295,9 +295,9 @@
     if (operation->alg == PSA_ALG_JPAKE) {
         /* Initialize & write round on KEY_SHARE sequences */
         if (step == PSA_JPAKE_X1_STEP_KEY_SHARE) {
-            ret = mbedtls_ecjpake_write_round_one(&operation->ctx.pake,
+            ret = mbedtls_ecjpake_write_round_one(&operation->ctx.jpake,
                                                   operation->buffer,
-                                                  MBEDTLS_PSA_JPAKE_BUFFER_SIZE,
+                                                  sizeof(operation->buffer),
                                                   &operation->buffer_length,
                                                   mbedtls_psa_get_random,
                                                   MBEDTLS_PSA_RANDOM_STATE);
@@ -307,9 +307,9 @@
 
             operation->buffer_offset = 0;
         } else if (step == PSA_JPAKE_X2S_STEP_KEY_SHARE) {
-            ret = mbedtls_ecjpake_write_round_two(&operation->ctx.pake,
+            ret = mbedtls_ecjpake_write_round_two(&operation->ctx.jpake,
                                                   operation->buffer,
-                                                  MBEDTLS_PSA_JPAKE_BUFFER_SIZE,
+                                                  sizeof(operation->buffer),
                                                   &operation->buffer_length,
                                                   mbedtls_psa_get_random,
                                                   MBEDTLS_PSA_RANDOM_STATE);
@@ -359,7 +359,7 @@
         /* Reset buffer after ZK_PROOF sequence */
         if ((step == PSA_JPAKE_X2_STEP_ZK_PROOF) ||
             (step == PSA_JPAKE_X2S_STEP_ZK_PROOF)) {
-            mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
+            mbedtls_platform_zeroize(operation->buffer, sizeof(operation->buffer));
             operation->buffer_length = 0;
             operation->buffer_offset = 0;
         }
@@ -446,22 +446,22 @@
 
         /* Load buffer at each last round ZK_PROOF */
         if (step == PSA_JPAKE_X2_STEP_ZK_PROOF) {
-            ret = mbedtls_ecjpake_read_round_one(&operation->ctx.pake,
+            ret = mbedtls_ecjpake_read_round_one(&operation->ctx.jpake,
                                                  operation->buffer,
                                                  operation->buffer_length);
 
-            mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
+            mbedtls_platform_zeroize(operation->buffer, sizeof(operation->buffer));
             operation->buffer_length = 0;
 
             if (ret != 0) {
                 return mbedtls_ecjpake_to_psa_error(ret);
             }
         } else if (step == PSA_JPAKE_X4S_STEP_ZK_PROOF) {
-            ret = mbedtls_ecjpake_read_round_two(&operation->ctx.pake,
+            ret = mbedtls_ecjpake_read_round_two(&operation->ctx.jpake,
                                                  operation->buffer,
                                                  operation->buffer_length);
 
-            mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
+            mbedtls_platform_zeroize(operation->buffer, sizeof(operation->buffer));
             operation->buffer_length = 0;
 
             if (ret != 0) {
@@ -499,19 +499,16 @@
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     if (operation->alg == PSA_ALG_JPAKE) {
-        ret = mbedtls_ecjpake_write_shared_key(&operation->ctx.pake,
-                                               operation->buffer,
+        ret = mbedtls_ecjpake_write_shared_key(&operation->ctx.jpake,
+                                               output,
                                                output_size,
-                                               &operation->buffer_length,
+                                               output_length,
                                                mbedtls_psa_get_random,
                                                MBEDTLS_PSA_RANDOM_STATE);
         if (ret != 0) {
             return mbedtls_ecjpake_to_psa_error(ret);
         }
 
-        memcpy(output, operation->buffer, operation->buffer_length);
-        *output_length = operation->buffer_length;
-
         return PSA_SUCCESS;
     } else
 #else
@@ -530,10 +527,10 @@
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     if (operation->alg == PSA_ALG_JPAKE) {
         operation->role = PSA_PAKE_ROLE_NONE;
-        mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
+        mbedtls_platform_zeroize(operation->buffer, sizeof(operation->buffer));
         operation->buffer_length = 0;
         operation->buffer_offset = 0;
-        mbedtls_ecjpake_free(&operation->ctx.pake);
+        mbedtls_ecjpake_free(&operation->ctx.jpake);
     }
 #endif