Merge pull request #6390 from mpg/fix-ecjpake-psa-format

Fix ecjpake PSA format
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 6c2e06e..4f65398 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -1765,9 +1765,9 @@
       primitive == PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC,      \
                                       PSA_ECC_FAMILY_SECP_R1, 256) ?    \
       (                                                                 \
-        output_step == PSA_PAKE_STEP_KEY_SHARE ? 69 :                   \
-        output_step == PSA_PAKE_STEP_ZK_PUBLIC ? 66 :                   \
-        33                                                              \
+        output_step == PSA_PAKE_STEP_KEY_SHARE ? 65 :                   \
+        output_step == PSA_PAKE_STEP_ZK_PUBLIC ? 65 :                   \
+        32                                                              \
       ) :                                                               \
       0 )
 
@@ -1795,9 +1795,9 @@
       primitive == PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC,      \
                                       PSA_ECC_FAMILY_SECP_R1, 256) ?    \
       (                                                                 \
-        input_step == PSA_PAKE_STEP_KEY_SHARE ? 69 :                    \
-        input_step == PSA_PAKE_STEP_ZK_PUBLIC ? 66 :                    \
-        33                                                              \
+        input_step == PSA_PAKE_STEP_KEY_SHARE ? 65 :                    \
+        input_step == PSA_PAKE_STEP_ZK_PUBLIC ? 65 :                    \
+        32                                                              \
       ) :                                                               \
       0 )
 
@@ -1808,7 +1808,7 @@
  *
  * See also #PSA_PAKE_OUTPUT_SIZE(\p alg, \p primitive, \p step).
  */
-#define PSA_PAKE_OUTPUT_MAX_SIZE 69
+#define PSA_PAKE_OUTPUT_MAX_SIZE 65
 
 /** Input buffer size for psa_pake_input() for any of the supported PAKE
  * algorithm and primitive suites and input step.
@@ -1817,7 +1817,7 @@
  *
  * See also #PSA_PAKE_INPUT_SIZE(\p alg, \p primitive, \p step).
  */
-#define PSA_PAKE_INPUT_MAX_SIZE 69
+#define PSA_PAKE_INPUT_MAX_SIZE 65
 
 /** Returns a suitable initializer for a PAKE cipher suite object of type
  * psa_pake_cipher_suite_t.
@@ -1906,7 +1906,10 @@
 
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
 #include <mbedtls/ecjpake.h>
-#define PSA_PAKE_BUFFER_SIZE ( ( 69 + 66 + 33 ) * 2 )
+/* Note: the format used by the mbedtls_ecjpake_read/write functions has an
+ * extra length byte for each step, plus an extra 3 bytes for ECParameters in
+ * the server's 2nd round. */
+#define MBEDTLS_PSA_PAKE_BUFFER_SIZE ( ( 3 + 1 + 65 + 1 + 65 + 1 + 32 ) * 2 )
 #endif
 
 struct psa_pake_operation_s
@@ -1919,7 +1922,7 @@
     unsigned int MBEDTLS_PRIVATE(output_step);
     mbedtls_svc_key_id_t MBEDTLS_PRIVATE(password);
     psa_pake_role_t MBEDTLS_PRIVATE(role);
-    uint8_t MBEDTLS_PRIVATE(buffer[PSA_PAKE_BUFFER_SIZE]);
+    uint8_t MBEDTLS_PRIVATE(buffer[MBEDTLS_PSA_PAKE_BUFFER_SIZE]);
     size_t MBEDTLS_PRIVATE(buffer_length);
     size_t MBEDTLS_PRIVATE(buffer_offset);
 #endif
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index 10d3e4a..870b5b5 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -230,7 +230,7 @@
         operation->input_step = PSA_PAKE_STEP_X1_X2;
         operation->output_step = PSA_PAKE_STEP_X1_X2;
 
-        mbedtls_platform_zeroize( operation->buffer, PSA_PAKE_BUFFER_SIZE );
+        mbedtls_platform_zeroize( operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE );
         operation->buffer_length = 0;
         operation->buffer_offset = 0;
 
@@ -385,7 +385,8 @@
 }
 #endif
 
-psa_status_t psa_pake_output( psa_pake_operation_t *operation,
+static psa_status_t psa_pake_output_internal(
+                              psa_pake_operation_t *operation,
                               psa_pake_step_t step,
                               uint8_t *output,
                               size_t output_size,
@@ -427,10 +428,7 @@
         if( operation->state == PSA_PAKE_STATE_SETUP ) {
             status = psa_pake_ecjpake_setup( operation );
             if( status != PSA_SUCCESS )
-            {
-                psa_pake_abort( operation );
                 return( status );
-            }
         }
 
         if( operation->state != PSA_PAKE_STATE_READY &&
@@ -491,15 +489,12 @@
         {
             ret = mbedtls_ecjpake_write_round_one( &operation->ctx.ecjpake,
                                                    operation->buffer,
-                                                   PSA_PAKE_BUFFER_SIZE,
+                                                   MBEDTLS_PSA_PAKE_BUFFER_SIZE,
                                                    &operation->buffer_length,
                                                    mbedtls_psa_get_random,
                                                    MBEDTLS_PSA_RANDOM_STATE );
             if( ret != 0 )
-            {
-                psa_pake_abort( operation );
                 return( mbedtls_ecjpake_to_psa_error( ret ) );
-            }
 
             operation->buffer_offset = 0;
         }
@@ -508,68 +503,47 @@
         {
             ret = mbedtls_ecjpake_write_round_two( &operation->ctx.ecjpake,
                                                    operation->buffer,
-                                                   PSA_PAKE_BUFFER_SIZE,
+                                                   MBEDTLS_PSA_PAKE_BUFFER_SIZE,
                                                    &operation->buffer_length,
                                                    mbedtls_psa_get_random,
                                                    MBEDTLS_PSA_RANDOM_STATE );
             if( ret != 0 )
-            {
-                psa_pake_abort( operation );
                 return( mbedtls_ecjpake_to_psa_error( ret ) );
-            }
 
             operation->buffer_offset = 0;
         }
 
         /*
-         * Steps sequences are stored as:
-         * struct {
-         *     opaque point <1..2^8-1>;
-         * } ECPoint;
+         * mbedtls_ecjpake_write_round_xxx() outputs data in the format
+         * defined by draft-cragie-tls-ecjpake-01 section 7. The summary is
+         * that the data for each step is prepended with a length byte, and
+         * then they're concatenated. Additionally, the server's second round
+         * output is prepended with a 3-byte ECParameters structure.
          *
-         * Where byte 0 stores the ECPoint curve point length.
-         *
-         * The sequence length is equal to:
-         * - data length extracted from byte 0
-         * - byte 0 size (1)
+         * In PSA, we output each step separately, and don't prepend the
+         * output with a length byte, let alone a curve identifier, as that
+         * information is already available.
          */
         if( operation->state == PSA_PAKE_OUTPUT_X2S &&
-            operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE )
+            operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+            operation->role == PSA_PAKE_ROLE_SERVER )
         {
-            if( operation->role == PSA_PAKE_ROLE_SERVER )
-                /*
-                 * The X2S KEY SHARE Server steps sequence is stored as:
-                 * struct {
-                 *     ECPoint X;
-                 *    opaque r <1..2^8-1>;
-                 * } ECSchnorrZKP;
-                 *
-                 * And MbedTLS uses a 3 bytes Ephemeral public key ECPoint,
-                 * so byte 3 stores the r Schnorr signature length.
-                 *
-                 * The sequence length is equal to:
-                 * - curve storage size (3)
-                 * - data length extracted from byte 3
-                 * - byte 3 size (1)
-                 */
-                length = 3 + operation->buffer[3] + 1;
-            else
-                length = operation->buffer[0] + 1;
+            /* Skip ECParameters, which is 3 bytes (RFC 8422) */
+            operation->buffer_offset += 3;
         }
-        else
-            length = operation->buffer[operation->buffer_offset] + 1;
 
-        if( length > operation->buffer_length )
+        /* Read the length byte then move past it to the data */
+        length = operation->buffer[operation->buffer_offset];
+        operation->buffer_offset += 1;
+
+        if( operation->buffer_offset + length > operation->buffer_length )
             return( PSA_ERROR_DATA_CORRUPT );
 
         if( output_size < length )
-        {
-            psa_pake_abort( operation );
             return( PSA_ERROR_BUFFER_TOO_SMALL );
-        }
 
         memcpy( output,
-                operation->buffer +  operation->buffer_offset,
+                operation->buffer + operation->buffer_offset,
                 length );
         *output_length = length;
 
@@ -581,7 +555,7 @@
             ( operation->state == PSA_PAKE_OUTPUT_X2S &&
               operation->sequence == PSA_PAKE_X1_STEP_ZK_PROOF ) )
         {
-            mbedtls_platform_zeroize( operation->buffer, PSA_PAKE_BUFFER_SIZE );
+            mbedtls_platform_zeroize( operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE );
             operation->buffer_length = 0;
             operation->buffer_offset = 0;
 
@@ -599,14 +573,29 @@
     return( PSA_ERROR_NOT_SUPPORTED );
 }
 
-psa_status_t psa_pake_input( psa_pake_operation_t *operation,
+psa_status_t psa_pake_output( psa_pake_operation_t *operation,
+                              psa_pake_step_t step,
+                              uint8_t *output,
+                              size_t output_size,
+                              size_t *output_length )
+{
+    psa_status_t status = psa_pake_output_internal(
+            operation, step, output, output_size, output_length );
+
+    if( status != PSA_SUCCESS )
+        psa_pake_abort( operation );
+
+    return( status );
+}
+
+static psa_status_t psa_pake_input_internal(
+                             psa_pake_operation_t *operation,
                              psa_pake_step_t step,
                              const uint8_t *input,
                              size_t input_length )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    size_t buffer_remain;
 
     if( operation->alg == PSA_ALG_NONE ||
         operation->state == PSA_PAKE_STATE_INVALID )
@@ -638,14 +627,16 @@
             step != PSA_PAKE_STEP_ZK_PROOF )
             return( PSA_ERROR_INVALID_ARGUMENT );
 
+        const psa_pake_primitive_t prim = PSA_PAKE_PRIMITIVE(
+                PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256 );
+        if( input_length > (size_t) PSA_PAKE_INPUT_SIZE( PSA_ALG_JPAKE, prim, step ) )
+            return( PSA_ERROR_INVALID_ARGUMENT );
+
         if( operation->state == PSA_PAKE_STATE_SETUP )
         {
             status = psa_pake_ecjpake_setup( operation );
             if( status != PSA_SUCCESS )
-            {
-                psa_pake_abort( operation );
                 return( status );
-            }
         }
 
         if( operation->state != PSA_PAKE_STATE_READY &&
@@ -675,15 +666,6 @@
             operation->sequence = PSA_PAKE_X1_STEP_KEY_SHARE;
         }
 
-        buffer_remain = PSA_PAKE_BUFFER_SIZE - operation->buffer_length;
-
-        if( input_length == 0 ||
-            input_length > buffer_remain )
-        {
-            psa_pake_abort( operation );
-            return( PSA_ERROR_INSUFFICIENT_MEMORY );
-        }
-
         /* Check if step matches current sequence */
         switch( operation->sequence )
         {
@@ -709,7 +691,35 @@
                 return( PSA_ERROR_BAD_STATE );
         }
 
-        /* Copy input to local buffer */
+        /*
+         * Copy input to local buffer and format it as the Mbed TLS API
+         * expects, i.e. as defined by draft-cragie-tls-ecjpake-01 section 7.
+         * The summary is that the data for each step is prepended with a
+         * length byte, and then they're concatenated. Additionally, the
+         * server's second round output is prepended with a 3-byte
+         * ECParameters structure - which means we have to prepend that when
+         * we're a client.
+         */
+        if( operation->state == PSA_PAKE_INPUT_X4S &&
+            operation->sequence == PSA_PAKE_X1_STEP_KEY_SHARE &&
+            operation->role == PSA_PAKE_ROLE_CLIENT )
+        {
+            /* We only support secp256r1. */
+            /* This is the ECParameters structure defined by RFC 8422. */
+            unsigned char ecparameters[3] = {
+                3, /* named_curve */
+                0, 23 /* secp256r1 */
+            };
+            memcpy( operation->buffer + operation->buffer_length,
+                    ecparameters, sizeof( ecparameters ) );
+            operation->buffer_length += sizeof( ecparameters );
+        }
+
+        /* Write the length byte */
+        operation->buffer[operation->buffer_length] = (uint8_t) input_length;
+        operation->buffer_length += 1;
+
+        /* Finally copy the data */
         memcpy( operation->buffer + operation->buffer_length,
                 input, input_length );
         operation->buffer_length += input_length;
@@ -722,14 +732,11 @@
                                                   operation->buffer,
                                                   operation->buffer_length );
 
-            mbedtls_platform_zeroize( operation->buffer, PSA_PAKE_BUFFER_SIZE );
+            mbedtls_platform_zeroize( operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE );
             operation->buffer_length = 0;
 
             if( ret != 0 )
-            {
-                psa_pake_abort( operation );
                 return( mbedtls_ecjpake_to_psa_error( ret ) );
-            }
         }
         else if( operation->state == PSA_PAKE_INPUT_X4S &&
                  operation->sequence == PSA_PAKE_X1_STEP_ZK_PROOF )
@@ -738,14 +745,11 @@
                                                   operation->buffer,
                                                   operation->buffer_length );
 
-            mbedtls_platform_zeroize( operation->buffer, PSA_PAKE_BUFFER_SIZE );
+            mbedtls_platform_zeroize( operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE );
             operation->buffer_length = 0;
 
             if( ret != 0 )
-            {
-                psa_pake_abort( operation );
                 return( mbedtls_ecjpake_to_psa_error( ret ) );
-            }
         }
 
         if( ( operation->state == PSA_PAKE_INPUT_X1_X2 &&
@@ -767,6 +771,20 @@
     return( PSA_ERROR_NOT_SUPPORTED );
 }
 
+psa_status_t psa_pake_input( psa_pake_operation_t *operation,
+                             psa_pake_step_t step,
+                             const uint8_t *input,
+                             size_t input_length )
+{
+    psa_status_t status = psa_pake_input_internal(
+            operation, step, input, input_length );
+
+    if( status != PSA_SUCCESS )
+        psa_pake_abort( operation );
+
+    return( status );
+}
+
 psa_status_t psa_pake_get_implicit_key(psa_pake_operation_t *operation,
                                        psa_key_derivation_operation_t *output)
 {
@@ -784,7 +802,7 @@
     {
         ret = mbedtls_ecjpake_write_shared_key( &operation->ctx.ecjpake,
                                                 operation->buffer,
-                                                PSA_PAKE_BUFFER_SIZE,
+                                                MBEDTLS_PSA_PAKE_BUFFER_SIZE,
                                                 &operation->buffer_length,
                                                 mbedtls_psa_get_random,
                                                 MBEDTLS_PSA_RANDOM_STATE );
@@ -799,7 +817,7 @@
                                                  operation->buffer,
                                                  operation->buffer_length );
 
-        mbedtls_platform_zeroize( operation->buffer, PSA_PAKE_BUFFER_SIZE );
+        mbedtls_platform_zeroize( operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE );
 
         psa_pake_abort( operation );
 
@@ -824,7 +842,7 @@
         operation->output_step = PSA_PAKE_STEP_INVALID;
         operation->password = MBEDTLS_SVC_KEY_ID_INIT;
         operation->role = PSA_PAKE_ROLE_NONE;
-        mbedtls_platform_zeroize( operation->buffer, PSA_PAKE_BUFFER_SIZE );
+        mbedtls_platform_zeroize( operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE );
         operation->buffer_length = 0;
         operation->buffer_offset = 0;
         mbedtls_ecjpake_free( &operation->ctx.ecjpake );
diff --git a/tests/suites/test_suite_psa_crypto.data b/tests/suites/test_suite_psa_crypto.data
index 4448bc4..cce3fd0 100644
--- a/tests/suites/test_suite_psa_crypto.data
+++ b/tests/suites/test_suite_psa_crypto.data
@@ -6594,3 +6594,7 @@
 PSA PAKE: ecjpake inject input errors, second round server, client input first
 depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256:PSA_WANT_ALG_SHA_256
 ecjpake_rounds_inject:PSA_ALG_JPAKE:PSA_PAKE_PRIMITIVE(PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, 256):PSA_ALG_SHA_256:1:4:"abcdef"
+
+PSA PAKE: ecjpake size macros
+depends_on:PSA_WANT_KEY_TYPE_ECC_KEY_PAIR:PSA_WANT_ECC_SECP_R1_256
+ecjpake_size_macros:
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index 6c95c2a..36a8efa 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -718,6 +718,15 @@
         PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_KEY_SHARE) +
         PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PUBLIC) +
         PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PROOF)) * 2;
+    /* The output should be exactly this size according to the spec */
+    const size_t expected_size_key_share =
+        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_KEY_SHARE);
+    /* The output should be exactly this size according to the spec */
+    const size_t expected_size_zk_public =
+        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PUBLIC);
+    /* The output can be smaller: the spec allows stripping leading zeroes */
+    const size_t max_expected_size_zk_proof =
+        PSA_PAKE_OUTPUT_SIZE(alg, primitive, PSA_PAKE_STEP_ZK_PROOF);
     size_t buffer0_off = 0;
     size_t buffer1_off = 0;
     size_t s_g1_len, s_g2_len, s_a_len;
@@ -745,31 +754,37 @@
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_g1_len ) );
+            TEST_EQUAL( s_g1_len, expected_size_key_share );
             s_g1_off = buffer0_off;
             buffer0_off += s_g1_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x1_pk_len ) );
+            TEST_EQUAL( s_x1_pk_len, expected_size_zk_public );
             s_x1_pk_off = buffer0_off;
             buffer0_off += s_x1_pk_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x1_pr_len ) );
+            TEST_LE_U( s_x1_pr_len, max_expected_size_zk_proof );
             s_x1_pr_off = buffer0_off;
             buffer0_off += s_x1_pr_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_g2_len ) );
+            TEST_EQUAL( s_g2_len, expected_size_key_share );
             s_g2_off = buffer0_off;
             buffer0_off += s_g2_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x2_pk_len ) );
+            TEST_EQUAL( s_x2_pk_len, expected_size_zk_public );
             s_x2_pk_off = buffer0_off;
             buffer0_off += s_x2_pk_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x2_pr_len ) );
+            TEST_LE_U( s_x2_pr_len, max_expected_size_zk_proof );
             s_x2_pr_off = buffer0_off;
             buffer0_off += s_x2_pr_len;
 
@@ -877,31 +892,37 @@
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_g1_len ) );
+            TEST_EQUAL( c_g1_len, expected_size_key_share );
             c_g1_off = buffer1_off;
             buffer1_off += c_g1_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x1_pk_len ) );
+            TEST_EQUAL( c_x1_pk_len, expected_size_zk_public );
             c_x1_pk_off = buffer1_off;
             buffer1_off += c_x1_pk_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x1_pr_len ) );
+            TEST_LE_U( c_x1_pr_len, max_expected_size_zk_proof );
             c_x1_pr_off = buffer1_off;
             buffer1_off += c_x1_pr_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_g2_len ) );
+            TEST_EQUAL( c_g2_len, expected_size_key_share );
             c_g2_off = buffer1_off;
             buffer1_off += c_g2_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x2_pk_len ) );
+            TEST_EQUAL( c_x2_pk_len, expected_size_zk_public );
             c_x2_pk_off = buffer1_off;
             buffer1_off += c_x2_pk_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x2_pr_len ) );
+            TEST_LE_U( c_x2_pr_len, max_expected_size_zk_proof );
             c_x2_pr_off = buffer1_off;
             buffer1_off += c_x2_pr_len;
 
@@ -1083,16 +1104,19 @@
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_a_len ) );
+            TEST_EQUAL( s_a_len, expected_size_key_share );
             s_a_off = buffer0_off;
             buffer0_off += s_a_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x2s_pk_len ) );
+            TEST_EQUAL( s_x2s_pk_len, expected_size_zk_public );
             s_x2s_pk_off = buffer0_off;
             buffer0_off += s_x2s_pk_len;
             PSA_ASSERT( psa_pake_output( server, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer0 + buffer0_off,
                                          512 - buffer0_off, &s_x2s_pr_len ) );
+            TEST_LE_U( s_x2s_pr_len, max_expected_size_zk_proof );
             s_x2s_pr_off = buffer0_off;
             buffer0_off += s_x2s_pr_len;
 
@@ -1154,16 +1178,19 @@
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_KEY_SHARE,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_a_len ) );
+            TEST_EQUAL( c_a_len, expected_size_key_share );
             c_a_off = buffer1_off;
             buffer1_off += c_a_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PUBLIC,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x2s_pk_len ) );
+            TEST_EQUAL( c_x2s_pk_len, expected_size_zk_public );
             c_x2s_pk_off = buffer1_off;
             buffer1_off += c_x2s_pk_len;
             PSA_ASSERT( psa_pake_output( client, PSA_PAKE_STEP_ZK_PROOF,
                                          buffer1 + buffer1_off,
                                          512 - buffer1_off, &c_x2s_pr_len ) );
+            TEST_LE_U( c_x2s_pr_len, max_expected_size_zk_proof );
             c_x2s_pr_off = buffer1_off;
             buffer1_off += c_x2s_pr_len;
 
@@ -8713,7 +8740,9 @@
 {
     psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
     psa_pake_operation_t operation = psa_pake_operation_init();
+    psa_pake_operation_t op_copy = psa_pake_operation_init();
     psa_algorithm_t alg = alg_arg;
+    psa_pake_primitive_t primitive = primitive_arg;
     psa_key_type_t key_type_pw = key_type_pw_arg;
     psa_key_usage_t key_usage_pw = key_usage_pw_arg;
     psa_algorithm_t hash_alg = hash_arg;
@@ -8731,9 +8760,9 @@
 
     PSA_INIT( );
 
-    ASSERT_ALLOC( output_buffer,
-                  PSA_PAKE_OUTPUT_SIZE(alg, primitive_arg,
-                                       PSA_PAKE_STEP_KEY_SHARE) );
+    size_t buf_size = PSA_PAKE_OUTPUT_SIZE(alg, primitive_arg,
+                                       PSA_PAKE_STEP_KEY_SHARE);
+    ASSERT_ALLOC( output_buffer, buf_size );
 
     if( pw_data->len > 0 )
     {
@@ -8745,7 +8774,7 @@
     }
 
     psa_pake_cs_set_algorithm( &cipher_suite, alg );
-    psa_pake_cs_set_primitive( &cipher_suite, primitive_arg );
+    psa_pake_cs_set_primitive( &cipher_suite, primitive );
     psa_pake_cs_set_hash( &cipher_suite, hash_alg );
 
     PSA_ASSERT( psa_pake_abort( &operation ) );
@@ -8799,54 +8828,71 @@
     TEST_EQUAL( psa_pake_set_peer( &operation, unsupported_id, 4 ),
                 PSA_ERROR_NOT_SUPPORTED );
 
+    const size_t size_key_share = PSA_PAKE_INPUT_SIZE( alg, primitive,
+                                                PSA_PAKE_STEP_KEY_SHARE );
+    const size_t size_zk_public = PSA_PAKE_INPUT_SIZE( alg, primitive,
+                                                PSA_PAKE_STEP_ZK_PUBLIC );
+    const size_t size_zk_proof = PSA_PAKE_INPUT_SIZE( alg, primitive,
+                                                PSA_PAKE_STEP_ZK_PROOF );
+
     /* First round */
     if( input_first )
     {
-        /* Invalid parameters */
-        TEST_EQUAL( psa_pake_input( &operation, PSA_PAKE_STEP_ZK_PROOF,
+        /* Invalid parameters (input) */
+        op_copy = operation;
+        TEST_EQUAL( psa_pake_input( &op_copy, PSA_PAKE_STEP_ZK_PROOF,
                                     NULL, 0 ),
                     PSA_ERROR_INVALID_ARGUMENT );
-        TEST_EQUAL( psa_pake_input( &operation, PSA_PAKE_STEP_ZK_PROOF + 10,
-                                    output_buffer, 66 ),
+        /* Invalid parameters (step) */
+        op_copy = operation;
+        TEST_EQUAL( psa_pake_input( &op_copy, PSA_PAKE_STEP_ZK_PROOF + 10,
+                                    output_buffer, size_zk_proof ),
                     PSA_ERROR_INVALID_ARGUMENT );
         /* Invalid first step */
-        TEST_EQUAL( psa_pake_input( &operation, PSA_PAKE_STEP_ZK_PROOF,
-                                    output_buffer, 66 ),
+        op_copy = operation;
+        TEST_EQUAL( psa_pake_input( &op_copy, PSA_PAKE_STEP_ZK_PROOF,
+                                    output_buffer, size_zk_proof ),
                     PSA_ERROR_BAD_STATE );
 
+        /* Possibly valid */
         TEST_EQUAL( psa_pake_input( &operation, PSA_PAKE_STEP_KEY_SHARE,
-                                    output_buffer, 66 ),
+                                    output_buffer, size_key_share ),
                     expected_status_input_output);
 
         if( expected_status_input_output == PSA_SUCCESS )
         {
             /* Buffer too large */
             TEST_EQUAL( psa_pake_input( &operation, PSA_PAKE_STEP_ZK_PUBLIC,
-                                    output_buffer, 512 ),
-                        PSA_ERROR_INSUFFICIENT_MEMORY );
+                                    output_buffer, size_zk_public + 1 ),
+                        PSA_ERROR_INVALID_ARGUMENT );
 
-            /* The operation should be aborted at this point */
+            /* The operation's state should be invalidated at this point */
             TEST_EQUAL( psa_pake_input( &operation, PSA_PAKE_STEP_ZK_PUBLIC,
-                                        output_buffer, 66 ),
+                                        output_buffer, size_zk_public ),
                         PSA_ERROR_BAD_STATE );
         }
     }
     else
     {
-        /* Invalid parameters */
-        TEST_EQUAL( psa_pake_output( &operation, PSA_PAKE_STEP_ZK_PROOF,
+        /* Invalid parameters (output) */
+        op_copy = operation;
+        TEST_EQUAL( psa_pake_output( &op_copy, PSA_PAKE_STEP_ZK_PROOF,
                                      NULL, 0, NULL ),
                     PSA_ERROR_INVALID_ARGUMENT );
-        TEST_EQUAL( psa_pake_output( &operation, PSA_PAKE_STEP_ZK_PROOF + 10,
-                                     output_buffer, 512, &output_len ),
+        op_copy = operation;
+        /* Invalid parameters (step) */
+        TEST_EQUAL( psa_pake_output( &op_copy, PSA_PAKE_STEP_ZK_PROOF + 10,
+                                     output_buffer, buf_size, &output_len ),
                     PSA_ERROR_INVALID_ARGUMENT );
         /* Invalid first step */
-        TEST_EQUAL( psa_pake_output( &operation, PSA_PAKE_STEP_ZK_PROOF,
-                                     output_buffer, 512, &output_len ),
+        op_copy = operation;
+        TEST_EQUAL( psa_pake_output( &op_copy, PSA_PAKE_STEP_ZK_PROOF,
+                                     output_buffer, buf_size, &output_len ),
                     PSA_ERROR_BAD_STATE );
 
+        /* Possibly valid */
         TEST_EQUAL( psa_pake_output( &operation, PSA_PAKE_STEP_KEY_SHARE,
-                                     output_buffer, 512, &output_len ),
+                                     output_buffer, buf_size, &output_len ),
                     expected_status_input_output );
 
         if( expected_status_input_output == PSA_SUCCESS )
@@ -8855,12 +8901,12 @@
 
             /* Buffer too small */
             TEST_EQUAL( psa_pake_output( &operation, PSA_PAKE_STEP_ZK_PUBLIC,
-                                         output_buffer, 5, &output_len ),
+                                         output_buffer, size_zk_public - 1, &output_len ),
                         PSA_ERROR_BUFFER_TOO_SMALL );
 
-            /* The operation should be aborted at this point */
+            /* The operation's state should be invalidated at this point */
             TEST_EQUAL( psa_pake_output( &operation, PSA_PAKE_STEP_ZK_PUBLIC,
-                                         output_buffer, 512, &output_len ),
+                                         output_buffer, buf_size, &output_len ),
                         PSA_ERROR_BAD_STATE );
         }
     }
@@ -9009,3 +9055,47 @@
     PSA_DONE( );
 }
 /* END_CASE */
+
+/* BEGIN_CASE */
+void ecjpake_size_macros( )
+{
+    const psa_algorithm_t alg = PSA_ALG_JPAKE;
+    const size_t bits = 256;
+    const psa_pake_primitive_t prim = PSA_PAKE_PRIMITIVE(
+            PSA_PAKE_PRIMITIVE_TYPE_ECC, PSA_ECC_FAMILY_SECP_R1, bits );
+    const psa_key_type_t key_type = PSA_KEY_TYPE_ECC_KEY_PAIR(
+            PSA_ECC_FAMILY_SECP_R1 );
+
+    // https://armmbed.github.io/mbed-crypto/1.1_PAKE_Extension.0-bet.0/html/pake.html#pake-step-types
+    /* The output for KEY_SHARE and ZK_PUBLIC is the same as a public key */
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
+                PSA_EXPORT_PUBLIC_KEY_OUTPUT_SIZE( key_type, bits ) );
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
+                PSA_EXPORT_PUBLIC_KEY_OUTPUT_SIZE( key_type, bits ) );
+    /* The output for ZK_PROOF is the same bitsize as the curve */
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
+                PSA_BITS_TO_BYTES( bits ) );
+
+    /* Input sizes are the same as output sizes */
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
+                PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE) );
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
+                PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC) );
+    TEST_EQUAL( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
+                PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF) );
+
+    /* These inequalities will always hold even when other PAKEs are added */
+    TEST_LE_U( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
+               PSA_PAKE_OUTPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
+               PSA_PAKE_OUTPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_OUTPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
+               PSA_PAKE_OUTPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_KEY_SHARE),
+               PSA_PAKE_INPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PUBLIC),
+               PSA_PAKE_INPUT_MAX_SIZE );
+    TEST_LE_U( PSA_PAKE_INPUT_SIZE(alg, prim, PSA_PAKE_STEP_ZK_PROOF),
+               PSA_PAKE_INPUT_MAX_SIZE );
+}
+/* END_CASE */