Restrict the EC J-PAKE to PMS input type to secret

Signed-off-by: Andrzej Kurek <andrzej.kurek@arm.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 5c05f79..cbdc912 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -5148,7 +5148,7 @@
         return( PSA_ERROR_NOT_SUPPORTED );
 
     /* All currently supported key derivation algorithms (apart from
-     * ecjpake to pms are based on a hash algorithm. */
+     * ecjpake to pms) are based on a hash algorithm. */
     psa_algorithm_t hash_alg = PSA_ALG_HKDF_GET_HASH( kdf_alg );
     size_t hash_size = PSA_HASH_LENGTH( hash_alg );
     if( !PSA_ALG_IS_TLS12_ECJPAKE_TO_PMS( kdf_alg ) )
@@ -5570,10 +5570,12 @@
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_TLS12_ECJPAKE_TO_PMS)
 static psa_status_t psa_tls12_ecjpake_to_pms_input(
     psa_tls12_ecjpake_to_pms_t *ecjpake,
+    psa_key_derivation_step_t step,
     const uint8_t *data,
     size_t data_length )
 {
-    if( data_length != PSA_TLS12_ECJPAKE_TO_PMS_INPUT_SIZE )
+    if( data_length != PSA_TLS12_ECJPAKE_TO_PMS_INPUT_SIZE ||
+        step != PSA_KEY_DERIVATION_INPUT_SECRET )
         return( PSA_ERROR_INVALID_ARGUMENT );
 
     /* Check if the passed point is in an uncompressed form */
@@ -5668,7 +5670,7 @@
     if( PSA_ALG_IS_TLS12_ECJPAKE_TO_PMS( kdf_alg ) )
     {
         status = psa_tls12_ecjpake_to_pms_input(
-            &operation->ctx.tls12_ecjpake_to_pms, data, data_length );
+            &operation->ctx.tls12_ecjpake_to_pms, step, data, data_length );
     }
     else
 #endif /* MBEDTLS_PSA_BUILTIN_ALG_TLS12_ECJPAKE_TO_PMS */