mbedtls_psa_pake_get_implicit_key: move psa_key_derivation_input_bytes call to upper layer

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 8dc1a21..4e0f5f5 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7295,11 +7295,34 @@
     psa_pake_operation_t *operation,
     psa_key_derivation_operation_t *output)
 {
+    psa_status_t status = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    uint8_t shared_key[MBEDTLS_PSA_PAKE_BUFFER_SIZE];
+    size_t shared_key_len = 0;
+
     if (operation->id == 0) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    return psa_driver_wrapper_pake_get_implicit_key(operation, output);
+    status = psa_driver_wrapper_pake_get_implicit_key(operation,
+                                                      shared_key,
+                                                      &shared_key_len);
+
+    if (status != PSA_SUCCESS) {
+        return status;
+    }
+
+    status = psa_key_derivation_input_bytes(output,
+                                            PSA_KEY_DERIVATION_INPUT_SECRET,
+                                            shared_key,
+                                            shared_key_len);
+
+    if (status != PSA_SUCCESS) {
+        psa_key_derivation_abort(output);
+    }
+
+    mbedtls_platform_zeroize(shared_key, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+
+    return status;
 }
 
 psa_status_t psa_pake_abort(
diff --git a/library/psa_crypto_driver_wrappers.h b/library/psa_crypto_driver_wrappers.h
index a3755d3..78f2f9a 100644
--- a/library/psa_crypto_driver_wrappers.h
+++ b/library/psa_crypto_driver_wrappers.h
@@ -454,7 +454,7 @@
 
 psa_status_t psa_driver_wrapper_pake_get_implicit_key(
     psa_pake_operation_t *operation,
-    psa_key_derivation_operation_t *output);
+    uint8_t *output, size_t *output_size);
 
 psa_status_t psa_driver_wrapper_pake_abort(
     psa_pake_operation_t *operation);
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index 6c4db6f..1e5dca4 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -835,7 +835,7 @@
 
 psa_status_t mbedtls_psa_pake_get_implicit_key(
     mbedtls_psa_pake_operation_t *operation,
-    psa_key_derivation_operation_t *output)
+    uint8_t *output, size_t *output_size)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
@@ -863,16 +863,14 @@
             return mbedtls_ecjpake_to_psa_error(ret);
         }
 
-        status = psa_key_derivation_input_bytes(output,
-                                                PSA_KEY_DERIVATION_INPUT_SECRET,
-                                                operation->buffer,
-                                                operation->buffer_length);
+        memcpy(output, operation->buffer, operation->buffer_length);
+        *output_size = operation->buffer_length;
 
         mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
 
         mbedtls_psa_pake_abort(operation);
 
-        return status;
+        return PSA_SUCCESS;
     } else
 #else
     (void) output;
@@ -880,7 +878,6 @@
     { status = PSA_ERROR_NOT_SUPPORTED; }
 
 error:
-    psa_key_derivation_abort(output);
     mbedtls_psa_pake_abort(operation);
 
     return status;
diff --git a/library/psa_crypto_pake.h b/library/psa_crypto_pake.h
index c7bf270..9256f5a 100644
--- a/library/psa_crypto_pake.h
+++ b/library/psa_crypto_pake.h
@@ -442,7 +442,7 @@
  */
 psa_status_t mbedtls_psa_pake_get_implicit_key(
     mbedtls_psa_pake_operation_t *operation,
-    psa_key_derivation_operation_t *output);
+    uint8_t *output, size_t *output_size);
 
 /** Abort a PAKE operation.
  *
diff --git a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
index b3e40f0..cea7948 100644
--- a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
+++ b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
@@ -3079,13 +3079,13 @@
 
 psa_status_t psa_driver_wrapper_pake_get_implicit_key(
     psa_pake_operation_t *operation,
-    psa_key_derivation_operation_t *output )
+    uint8_t *output, size_t *output_size )
 {
     switch( operation->id )
     {
 #if defined(MBEDTLS_PSA_BUILTIN_PAKE)
         case PSA_CRYPTO_MBED_TLS_DRIVER_ID:
-            return( mbedtls_psa_pake_get_implicit_key( &operation->ctx.mbedtls_ctx, output ) );
+            return( mbedtls_psa_pake_get_implicit_key( &operation->ctx.mbedtls_ctx, output, output_size ) );
 #endif /* MBEDTLS_PSA_BUILTIN_PAKE */
 
 #if defined(PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT)
@@ -3093,15 +3093,16 @@
         case MBEDTLS_TEST_TRANSPARENT_DRIVER_ID:
             return( mbedtls_test_transparent_pake_get_implicit_key(
                         &operation->ctx.transparent_test_driver_ctx,
-                        (psa_key_derivation_operation_t*) output ) );
+                        output, output_size ) );
         case MBEDTLS_TEST_OPAQUE_DRIVER_ID:
             return( mbedtls_test_opaque_pake_get_implicit_key(
                         &operation->ctx.opaque_test_driver_ctx,
-                        (psa_key_derivation_operation_t*) output ) );
+                        output, output_size ) );
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 #endif /* PSA_CRYPTO_ACCELERATOR_DRIVER_PRESENT */
         default:
             (void) output;
+            (void) output_size;
             return( PSA_ERROR_INVALID_ARGUMENT );
     }
 }
diff --git a/tests/include/test/drivers/pake.h b/tests/include/test/drivers/pake.h
index 81e8711..5ee401b 100644
--- a/tests/include/test/drivers/pake.h
+++ b/tests/include/test/drivers/pake.h
@@ -87,7 +87,7 @@
 
 psa_status_t mbedtls_test_transparent_pake_get_implicit_key(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_key_derivation_operation_t *output);
+    uint8_t *output, size_t *output_size);
 
 psa_status_t mbedtls_test_transparent_pake_abort(
     mbedtls_transparent_test_driver_pake_operation_t *operation);
@@ -131,7 +131,7 @@
 
 psa_status_t mbedtls_test_opaque_pake_get_implicit_key(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_key_derivation_operation_t *output);
+    uint8_t *output, size_t *output_size);
 
 psa_status_t mbedtls_test_opaque_pake_abort(
     mbedtls_opaque_test_driver_pake_operation_t *operation);
diff --git a/tests/src/drivers/test_driver_pake.c b/tests/src/drivers/test_driver_pake.c
index 1ced559..3495705 100644
--- a/tests/src/drivers/test_driver_pake.c
+++ b/tests/src/drivers/test_driver_pake.c
@@ -270,7 +270,7 @@
 
 psa_status_t mbedtls_test_transparent_pake_get_implicit_key(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_key_derivation_operation_t *output)
+    uint8_t *output, size_t *output_size)
 {
     mbedtls_test_driver_pake_hooks.hits++;
 
@@ -282,11 +282,11 @@
         defined(LIBTESTDRIVER1_MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             libtestdriver1_mbedtls_psa_pake_get_implicit_key(
-                operation, (libtestdriver1_psa_key_derivation_operation_t *) output);
+                operation,  output, output_size);
 #elif defined(MBEDTLS_PSA_BUILTIN_PAKE)
         mbedtls_test_driver_pake_hooks.driver_status =
             mbedtls_psa_pake_get_implicit_key(
-                operation, output);
+                operation, output, output_size);
 #else
         (void) operation;
         (void) output;
@@ -411,10 +411,11 @@
 
 psa_status_t mbedtls_test_opaque_pake_get_implicit_key(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_key_derivation_operation_t *output)
+    uint8_t *output, size_t *output_size)
 {
     (void) operation;
     (void) output;
+    (void) output_size;
     return PSA_ERROR_NOT_SUPPORTED;
 }