Merge pull request #7331 from mprse/ec-jpake-fix2

PSA PAKE: Check input_length against PSA_PAKE_INPUT_SIZE() in psa_pake_input
diff --git a/docs/proposed/psa-driver-interface.md b/docs/proposed/psa-driver-interface.md
index 0027ec7..cd1b9fc 100644
--- a/docs/proposed/psa-driver-interface.md
+++ b/docs/proposed/psa-driver-interface.md
@@ -474,7 +474,8 @@
 * `PSA_JPAKE_X4S_STEP_ZK_PUBLIC`    Round 2: input Schnorr NIZKP public key for the X4S key
 * `PSA_JPAKE_X4S_STEP_ZK_PROOF`     Round 2: input Schnorr NIZKP proof for the X4S key
 
-The core checks that input_length is smaller than PSA_PAKE_INPUT_MAX_SIZE.
+The core checks that `input_length` is not greater than `PSA_PAKE_INPUT_SIZE(alg, prim, step)`,
+so the driver can rely on this bound.
 
 ### PAKE driver get implicit key
 
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index e8cecf7..b858180 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -1937,6 +1937,9 @@
  *
  * This macro must expand to a compile-time constant integer.
  *
+ * The value of this macro must be at least as large as the largest value
+ * returned by PSA_PAKE_OUTPUT_SIZE().
+ *
  * See also #PSA_PAKE_OUTPUT_SIZE(\p alg, \p primitive, \p step).
  */
 #define PSA_PAKE_OUTPUT_MAX_SIZE 65
@@ -1946,6 +1949,9 @@
  *
  * This macro must expand to a compile-time constant integer.
  *
+ * The value of this macro must be at least as large as the largest value
+ * returned by PSA_PAKE_INPUT_SIZE().
+ *
  * See also #PSA_PAKE_INPUT_SIZE(\p alg, \p primitive, \p step).
  */
 #define PSA_PAKE_INPUT_MAX_SIZE 65
@@ -1958,7 +1964,7 @@
 /** Returns a suitable initializer for a PAKE operation object of type
  * psa_pake_operation_t.
  */
-#define PSA_PAKE_OPERATION_INIT { 0, PSA_ALG_NONE, PSA_PAKE_OPERATION_STAGE_SETUP, \
+#define PSA_PAKE_OPERATION_INIT { 0, PSA_ALG_NONE, 0, PSA_PAKE_OPERATION_STAGE_SETUP, \
                                   { 0 }, { { 0 } } }
 
 struct psa_pake_cipher_suite_s {
@@ -2104,6 +2110,8 @@
     unsigned int MBEDTLS_PRIVATE(id);
     /* Algorithm of the PAKE operation */
     psa_algorithm_t MBEDTLS_PRIVATE(alg);
+    /* Primitive of the PAKE operation; its type is compatible with the algorithm */
+    psa_pake_primitive_t MBEDTLS_PRIVATE(primitive);
     /* Stage of the PAKE operation: waiting for the setup, collecting inputs
      * or computing. */
     uint8_t MBEDTLS_PRIVATE(stage);
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index a89b5ff..20918bc 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7316,6 +7316,8 @@
     memset(&operation->data.inputs, 0, sizeof(operation->data.inputs));
 
     operation->alg = cipher_suite->algorithm;
+    operation->primitive = PSA_PAKE_PRIMITIVE(cipher_suite->type,
+                                              cipher_suite->family, cipher_suite->bits);
     operation->data.inputs.cipher_suite = *cipher_suite;
 
 #if defined(PSA_WANT_ALG_JPAKE)
@@ -7900,6 +7902,9 @@
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_crypto_driver_pake_step_t driver_step = PSA_JPAKE_STEP_INVALID;
+    const size_t max_input_length = (size_t) PSA_PAKE_INPUT_SIZE(operation->alg,
+                                                                 operation->primitive,
+                                                                 step);
 
     if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         status = psa_pake_complete_inputs(operation);
@@ -7913,7 +7918,7 @@
         goto exit;
     }
 
-    if (input_length == 0 || input_length > PSA_PAKE_INPUT_MAX_SIZE) {
+    if (input_length == 0 || input_length > max_input_length) {
         status = PSA_ERROR_INVALID_ARGUMENT;
         goto exit;
     }
diff --git a/tests/suites/test_suite_psa_crypto_pake.data b/tests/suites/test_suite_psa_crypto_pake.data
index 6215703..c467d01 100644
--- a/tests/suites/test_suite_psa_crypto_pake.data
+++ b/tests/suites/test_suite_psa_crypto_pake.data
@@ -82,10 +82,14 @@
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
 ecjpake_setup:PSA_ALG_JPAKE:PSA_KEY_TYPE_PASSWORD:PSA_KEY_USAGE_DERIVE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:"client":"server":1:ERR_INJECT_INVALID_FIRST_STEP:PSA_ERROR_BAD_STATE
 
-PSA PAKE: input buffer too large
+PSA PAKE: input buffer too large #1
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
 ecjpake_setup:PSA_ALG_JPAKE:PSA_KEY_TYPE_PASSWORD:PSA_KEY_USAGE_DERIVE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:"client":"server":1:ERR_INJECT_WRONG_BUFFER_SIZE:PSA_ERROR_INVALID_ARGUMENT
 
+PSA PAKE: input buffer too large #2
+depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
+ecjpake_setup:PSA_ALG_JPAKE:PSA_KEY_TYPE_PASSWORD:PSA_KEY_USAGE_DERIVE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:"client":"server":1:ERR_INJECT_WRONG_BUFFER_SIZE_2:PSA_ERROR_INVALID_ARGUMENT
+
 PSA PAKE: invalid output
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
 ecjpake_setup:PSA_ALG_JPAKE:PSA_KEY_TYPE_PASSWORD:PSA_KEY_USAGE_DERIVE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:"client":"server":0:ERR_INJECT_EMPTY_IO_BUFFER:PSA_ERROR_INVALID_ARGUMENT
diff --git a/tests/suites/test_suite_psa_crypto_pake.function b/tests/suites/test_suite_psa_crypto_pake.function
index 88f24dd..ecbd363 100644
--- a/tests/suites/test_suite_psa_crypto_pake.function
+++ b/tests/suites/test_suite_psa_crypto_pake.function
@@ -17,6 +17,7 @@
     ERR_INJECT_UNKNOWN_STEP,
     ERR_INJECT_INVALID_FIRST_STEP,
     ERR_INJECT_WRONG_BUFFER_SIZE,
+    ERR_INJECT_WRONG_BUFFER_SIZE_2,
     ERR_INJECT_VALID_OPERATION_AFTER_FAILURE,
     ERR_INJECT_ANTICIPATE_KEY_DERIVATION_1,
     ERR_INJECT_ANTICIPATE_KEY_DERIVATION_2,
@@ -670,6 +671,11 @@
                                                     output_buffer, size_zk_public + 1),
                                      ERR_INJECT_WRONG_BUFFER_SIZE);
 
+        SETUP_CONDITIONAL_CHECK_STEP(psa_pake_input(&operation,
+                                                    PSA_PAKE_STEP_ZK_PROOF,
+                                                    output_buffer, size_zk_proof + 1),
+                                     ERR_INJECT_WRONG_BUFFER_SIZE_2);
+
         SETUP_CONDITIONAL_CHECK_STEP(
             (psa_pake_input(&operation, PSA_PAKE_STEP_ZK_PUBLIC,
                             output_buffer, size_zk_public + 1),