tls12: psa_pake: use common code for parsing/writing round one and round two data

Share common code between server and client for parsing and writing
the EC J-PAKE round one and round two data.

Signed-off-by: Valerio Setti <vsetti@baylibre.com>
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 8b96243..d4ce35c 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2364,6 +2364,260 @@
 }
 #endif /* MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_SSL_PROTO_TLS1_3 */
 
+#if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED) && \
+    defined(MBEDTLS_USE_PSA_CRYPTO)
+/**
+ * \brief       Parse the provided input buffer for getting the first round
+ *              of key exchange. This code is common between server and client
+ *
+ *              The input holds two repetitions of the KEY_SHARE, ZK_PUBLIC,
+ *              ZK_PROOF sequence; each element is prefixed by one byte
+ *              holding its length.
+ *
+ * \param  pake_ctx [in/out] the PAKE's operation/context structure
+ * \param  buf      [in] input buffer to parse
+ * \param  len      [in] length of the input buffer
+ *
+ * \return               0 on success, MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE if
+ *                       the input is truncated, or another negative error
+ *                       code if the PSA PAKE operation rejects the data
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static inline int psa_tls12_parse_ecjpake_round_one(
+                                    psa_pake_operation_t *pake_ctx,
+                                    const unsigned char *buf,
+                                    size_t len )
+{
+    psa_status_t status;
+    size_t input_offset = 0;
+
+    /* Repeat the KEY_SHARE, ZK_PUBLIC & ZK_PROOF twice */
+    for( unsigned int x = 1; x <= 2; ++x )
+    {
+        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE;
+             step <= PSA_PAKE_STEP_ZK_PROOF;
+             ++step )
+        {
+            size_t length;
+
+            /* The length byte itself must lie inside the buffer */
+            if( input_offset >= len )
+                return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+
+            /* Length is stored at the first byte */
+            length = buf[input_offset];
+            input_offset += 1;
+
+            /* input_offset <= len here, so this cannot wrap around */
+            if( length > len - input_offset )
+                return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+
+            status = psa_pake_input( pake_ctx, step,
+                                     buf + input_offset, length );
+            if( status != PSA_SUCCESS )
+                return( psa_ssl_status_to_mbedtls( status ) );
+
+            input_offset += length;
+        }
+    }
+
+    return( 0 );
+}
+
+/**
+ * \brief       Parse the provided input buffer for getting the second round
+ *              of key exchange. This code is common between server and client
+ *
+ *              The input holds one KEY_SHARE, ZK_PUBLIC, ZK_PROOF sequence;
+ *              each element is prefixed by one byte holding its length.
+ *              On the client side, the KEY_SHARE length byte is additionally
+ *              preceded by the 3 curve identification bytes sent by the
+ *              server, which are skipped (not validated) here.
+ *
+ * \param  pake_ctx [in/out] the PAKE's operation/context structure
+ * \param  buf      [in] input buffer to parse
+ * \param  len      [in] length of the input buffer
+ * \param  role     [in] the local endpoint: MBEDTLS_SSL_IS_CLIENT or
+ *                       MBEDTLS_SSL_IS_SERVER
+ *
+ * \return               0 on success, MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE if
+ *                       the input is truncated, or another negative error
+ *                       code if the PSA PAKE operation rejects the data
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static inline int psa_tls12_parse_ecjpake_round_two(
+                                    psa_pake_operation_t *pake_ctx,
+                                    const unsigned char *buf,
+                                    size_t len, int role )
+{
+    psa_status_t status;
+    size_t input_offset = 0;
+
+    for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE;
+         step <= PSA_PAKE_STEP_ZK_PROOF;
+         ++step )
+    {
+        /* Bytes preceding the data: normally only the length byte */
+        size_t header_len = 1;
+        size_t length;
+
+        /*
+         * On its 2nd round, the server sends 3 extra bytes which identify the
+         * curve. Therefore we should skip them only on the client side
+         */
+        if( ( step == PSA_PAKE_STEP_KEY_SHARE ) &&
+            ( role == MBEDTLS_SSL_IS_CLIENT ) )
+        {
+            header_len = 3 + 1;
+        }
+
+        /* The whole header, length byte included, must be in the buffer */
+        if( len - input_offset < header_len )
+            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+
+        /* Length is stored at the last byte of the header */
+        length = buf[input_offset + header_len - 1];
+        input_offset += header_len;
+
+        if( length > len - input_offset )
+            return( MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE );
+
+        status = psa_pake_input( pake_ctx, step,
+                                 buf + input_offset, length );
+        if( status != PSA_SUCCESS )
+            return( psa_ssl_status_to_mbedtls( status ) );
+
+        input_offset += length;
+    }
+
+    return( 0 );
+}
+
+/**
+ * \brief       Write the first round of key exchange into the provided output
+ *              buffer. This code is common between server and client
+ *
+ *              Two repetitions of the KEY_SHARE, ZK_PUBLIC, ZK_PROOF
+ *              sequence are written; each element is prefixed by one byte
+ *              holding the length reported by psa_pake_output().
+ *
+ * \param  pake_ctx [in/out] the PAKE's operation/context structure
+ * \param  buf      [out] the output buffer in which data will be written to
+ * \param  len      [in] length of the output buffer
+ * \param  olen     [out] the length of the data really written on the buffer
+ *
+ * \return               0 on success, MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL if
+ *                       \p buf is too short, or another negative error code
+ *                       in case of failure
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static inline int psa_tls12_write_ecjpake_round_one(
+                                    psa_pake_operation_t *pake_ctx,
+                                    unsigned char *buf,
+                                    size_t len, size_t *olen )
+{
+    psa_status_t status;
+    size_t output_offset = 0;
+    size_t output_len;
+
+    /* Repeat the KEY_SHARE, ZK_PUBLIC & ZK_PROOF twice */
+    for( unsigned int x = 1; x <= 2; ++x )
+    {
+        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE;
+             step <= PSA_PAKE_STEP_ZK_PROOF;
+             ++step )
+        {
+            /* Keep room for the 1-byte length prefix */
+            if( output_offset >= len )
+                return( MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL );
+
+            status = psa_pake_output( pake_ctx, step,
+                                      buf + output_offset + 1,
+                                      len - output_offset - 1,
+                                      &output_len );
+            if( status != PSA_SUCCESS )
+                return( psa_ssl_status_to_mbedtls( status ) );
+
+            /* The length prefix is a single byte */
+            if( output_len > 0xFF )
+                return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+            /*
+             * Prefix the data with the length psa_pake_output() actually
+             * reported rather than an assumed fixed size (65 or 32), so
+             * that the prefix always matches the data written.
+             */
+            buf[output_offset] = (unsigned char) output_len;
+            output_offset += 1 + output_len;
+        }
+    }
+
+    *olen = output_offset;
+
+    return( 0 );
+}
+
+/**
+ * \brief       Write the second round of key exchange into the provided output
+ *              buffer. This code is common between server and client
+ *
+ *              One KEY_SHARE, ZK_PUBLIC, ZK_PROOF sequence is written; each
+ *              element is prefixed by one byte holding the length reported
+ *              by psa_pake_output().
+ *
+ * \param  pake_ctx [in/out] the PAKE's operation/context structure
+ * \param  buf      [out] the output buffer in which data will be written to
+ * \param  len      [in] length of the output buffer
+ * \param  olen     [out] the length of the data really written on the buffer
+ *
+ * \return               0 on success, MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL if
+ *                       \p buf is too short, or another negative error code
+ *                       in case of failure
+ */
+MBEDTLS_CHECK_RETURN_CRITICAL
+static inline int psa_tls12_write_ecjpake_round_two(
+                                    psa_pake_operation_t *pake_ctx,
+                                    unsigned char *buf,
+                                    size_t len, size_t *olen )
+{
+    psa_status_t status;
+    size_t output_offset = 0;
+    size_t output_len;
+
+    for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE;
+         step <= PSA_PAKE_STEP_ZK_PROOF;
+         ++step )
+    {
+        /* Keep room for the 1-byte length prefix */
+        if( output_offset >= len )
+            return( MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL );
+
+        status = psa_pake_output( pake_ctx, step,
+                                  buf + output_offset + 1,
+                                  len - output_offset - 1,
+                                  &output_len );
+        if( status != PSA_SUCCESS )
+            return( psa_ssl_status_to_mbedtls( status ) );
+
+        /* The length prefix is a single byte */
+        if( output_len > 0xFF )
+            return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
+        /*
+         * Prefix the data with the length psa_pake_output() actually
+         * reported rather than an assumed fixed size (65 or 32), so
+         * that the prefix always matches the data written.
+         */
+        buf[output_offset] = (unsigned char) output_len;
+        output_offset += 1 + output_len;
+    }
+
+    *olen = output_offset;
+
+    return( 0 );
+}
+#endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED && MBEDTLS_USE_PSA_CRYPTO */
+
 /**
  * \brief       TLS record protection modes
  */
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index ebada7a..8771c59 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1616,23 +1616,19 @@
 /*
  * Set EC J-PAKE password for current handshake
  */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
 int mbedtls_ssl_set_hs_ecjpake_password( mbedtls_ssl_context *ssl,
                                          const unsigned char *pw,
                                          size_t pw_len )
 {
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
     psa_pake_cipher_suite_t cipher_suite = psa_pake_cipher_suite_init();
     psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
     psa_pake_role_t psa_role;
     psa_status_t status;
-#else
-    mbedtls_ecjpake_role role;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     if( ssl->handshake == NULL || ssl->conf == NULL )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER )
         psa_role = PSA_PAKE_ROLE_SERVER;
     else
@@ -1688,7 +1684,17 @@
     ssl->handshake->psa_pake_ctx_is_ok = 1;
 
     return( 0 );
-#else
+}
+#else /* MBEDTLS_USE_PSA_CRYPTO */
+int mbedtls_ssl_set_hs_ecjpake_password( mbedtls_ssl_context *ssl,
+                                         const unsigned char *pw,
+                                         size_t pw_len )
+{
+    mbedtls_ecjpake_role role;
+
+    if( ssl->handshake == NULL || ssl->conf == NULL )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
     if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER )
         role = MBEDTLS_ECJPAKE_SERVER;
     else
@@ -1699,8 +1705,8 @@
                                    MBEDTLS_MD_SHA256,
                                    MBEDTLS_ECP_DP_SECP256R1,
                                    pw, pw_len ) );
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
 }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */
 
 #if defined(MBEDTLS_SSL_HANDSHAKE_WITH_PSK_ENABLED)
diff --git a/library/ssl_tls12_client.c b/library/ssl_tls12_client.c
index 3d25e40..c90ed2e 100644
--- a/library/ssl_tls12_client.c
+++ b/library/ssl_tls12_client.c
@@ -130,13 +130,9 @@
                                        const unsigned char *end,
                                        size_t *olen )
 {
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    psa_status_t status;
-#else
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
     unsigned char *p = buf;
-    size_t kkpp_len;
+    size_t kkpp_len = 0;
 
     *olen = 0;
 
@@ -168,41 +164,15 @@
         MBEDTLS_SSL_DEBUG_MSG( 3, ( "generating new ecjpake parameters" ) );
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-        size_t output_offset = 0;
-        size_t output_len;
-
-        /* Repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice */
-        for( unsigned int x = 1 ; x <= 2 ; ++x )
+        ret = psa_tls12_write_ecjpake_round_one(
+                &ssl->handshake->psa_pake_ctx, p + 2, end - p - 2, &kkpp_len );
+        if( ret != 0 )
         {
-            for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-                 step <= PSA_PAKE_STEP_ZK_PROOF ;
-                 ++step )
-            {
-                /* For each step, prepend 1 byte with the length of the data */
-                if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                    *(p + 2 + output_offset) = 65;
-                } else {
-                    *(p + 2 + output_offset) = 32;
-                }
-                output_offset += 1;
-
-                status = psa_pake_output( &ssl->handshake->psa_pake_ctx,
-                                          step, p + 2 + output_offset,
-                                          end - p - output_offset - 2,
-                                          &output_len );
-                if( status != PSA_SUCCESS )
-                {
-                    psa_destroy_key( ssl->handshake->psa_pake_password );
-                    psa_pake_abort( &ssl->handshake->psa_pake_ctx );
-                    MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", status );
-                    return( psa_ssl_status_to_mbedtls( status ) );
-                }
-
-                output_offset += output_len;
-            }
+            psa_destroy_key( ssl->handshake->psa_pake_password );
+            psa_pake_abort( &ssl->handshake->psa_pake_ctx );
+            MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", ret );
+            return( ret );
         }
-
-        kkpp_len = output_offset;
 #else
         ret = mbedtls_ecjpake_write_round_one( &ssl->handshake->ecjpake_ctx,
                                                p + 2, end - p - 2, &kkpp_len,
@@ -924,9 +894,6 @@
                                    size_t len )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    psa_status_t status;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     if( ssl->handshake->ciphersuite_info->key_exchange !=
         MBEDTLS_KEY_EXCHANGE_ECJPAKE )
@@ -941,50 +908,21 @@
     ssl->handshake->ecjpake_cache_len = 0;
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-    size_t input_offset = 0;
-
-    /* Repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice */
-    for( unsigned int x = 1 ; x <= 2 ; ++x )
+    if( ( ret = psa_tls12_parse_ecjpake_round_one(
+                            &ssl->handshake->psa_pake_ctx, buf, len ) ) != 0 )
     {
-        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-             step <= PSA_PAKE_STEP_ZK_PROOF ;
-             ++step )
-        {
-            /* Length is stored at the first byte */
-            size_t length = buf[input_offset];
-            input_offset += 1;
+        psa_destroy_key( ssl->handshake->psa_pake_password );
+        psa_pake_abort( &ssl->handshake->psa_pake_ctx );
 
-            if( input_offset + length > len )
-            {
-                ret = MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
-                goto psa_pake_error;
-            }
-
-            status = psa_pake_input( &ssl->handshake->psa_pake_ctx, step,
-                                     buf + input_offset, length );
-            if( status != PSA_SUCCESS)
-            {
-                ret = psa_ssl_status_to_mbedtls( status );
-                goto psa_pake_error;
-            }
-
-            input_offset += length;
-        }
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round one", ret );
+        mbedtls_ssl_send_alert_message(
+                ssl,
+                MBEDTLS_SSL_ALERT_LEVEL_FATAL,
+                MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE );
+        return( ret );
     }
 
     return( 0 );
-
-psa_pake_error:
-    psa_destroy_key( ssl->handshake->psa_pake_password );
-    psa_pake_abort( &ssl->handshake->psa_pake_ctx );
-
-    MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round one", ret );
-    mbedtls_ssl_send_alert_message(
-            ssl,
-            MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-            MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE );
-
-    return( ret );
 #else
     if( ( ret = mbedtls_ecjpake_read_round_one( &ssl->handshake->ecjpake_ctx,
                                                 buf, len ) ) != 0 )
@@ -2395,48 +2333,9 @@
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE )
     {
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-        psa_status_t status;
-        size_t len = end - p;
-        size_t input_offset = 0;
-
-        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-             step <= PSA_PAKE_STEP_ZK_PROOF ;
-             ++step )
-        {
-            size_t length;
-
-            if( step == PSA_PAKE_STEP_KEY_SHARE )
-            {
-                /* Length is stored after 3bytes curve */
-                length = p[input_offset + 3];
-                input_offset += 3 + 1;
-            }
-            else
-            {
-                /* Length is stored at the first byte */
-                length = p[input_offset];
-                input_offset += 1;
-            }
-
-            if( input_offset + length > len )
-            {
-                ret = MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
-                goto psa_pake_out;
-            }
-
-            status = psa_pake_input( &ssl->handshake->psa_pake_ctx, step,
-                                     p + input_offset, length );
-            if( status != PSA_SUCCESS)
-            {
-                ret = psa_ssl_status_to_mbedtls( status );
-                goto psa_pake_out;
-            }
-
-            input_offset += length;
-        }
-
-psa_pake_out:
-        if( ret != 0 )
+        if( ( ret = psa_tls12_parse_ecjpake_round_two(
+                        &ssl->handshake->psa_pake_ctx, p, end - p,
+                        ssl->conf->endpoint ) ) != 0 )
         {
             psa_destroy_key( ssl->handshake->psa_pake_password );
             psa_pake_abort( &ssl->handshake->psa_pake_ctx );
@@ -3393,37 +3292,15 @@
         unsigned char *out_p = ssl->out_msg + header_len;
         unsigned char *end_p = ssl->out_msg + MBEDTLS_SSL_OUT_CONTENT_LEN -
                                header_len;
-        psa_status_t status;
-        size_t output_offset = 0;
-        size_t output_len;
-
-        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-             step <= PSA_PAKE_STEP_ZK_PROOF ;
-             ++step )
+        ret = psa_tls12_write_ecjpake_round_two( &ssl->handshake->psa_pake_ctx,
+                                    out_p, end_p - out_p, &content_len );
+        if( ret != 0 )
         {
-            /* For each step, prepend 1 byte with the length of the data */
-            if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                *(out_p + output_offset) = 65;
-            } else {
-                *(out_p + output_offset) = 32;
-            }
-            output_offset += 1;
-            status = psa_pake_output( &ssl->handshake->psa_pake_ctx,
-                                      step, out_p + output_offset,
-                                      end_p - out_p - output_offset,
-                                      &output_len );
-            if( status != PSA_SUCCESS )
-            {
-                psa_destroy_key( ssl->handshake->psa_pake_password );
-                psa_pake_abort( &ssl->handshake->psa_pake_ctx );
-                MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", status );
-                return( psa_ssl_status_to_mbedtls( status ) );
-            }
-
-            output_offset += output_len;
+            psa_destroy_key( ssl->handshake->psa_pake_password );
+            psa_pake_abort( &ssl->handshake->psa_pake_ctx );
+            MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", ret );
+            return( ret );
         }
-
-        content_len = output_offset;
 #else
         ret = mbedtls_ecjpake_write_round_two( &ssl->handshake->ecjpake_ctx,
                 ssl->out_msg + header_len,
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index 68b4d09..806efd2 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -290,12 +290,9 @@
 MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_parse_ecjpake_kkpp( mbedtls_ssl_context *ssl,
                                    const unsigned char *buf,
                                    size_t len )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    psa_status_t status;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ssl->handshake->psa_pake_ctx_is_ok != 1 )
@@ -308,35 +305,19 @@
     }
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-    size_t input_offset = 0;
-
-    /* Repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice */
-    for( unsigned int x = 1 ; x <= 2 ; ++x )
+    if( ( ret = psa_tls12_parse_ecjpake_round_one(
+                        &ssl->handshake->psa_pake_ctx, buf, len ) ) != 0 )
     {
-        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-             step <= PSA_PAKE_STEP_ZK_PROOF ;
-             ++step )
-        {
-            /* Length is stored at the first byte */
-            size_t length = buf[input_offset];
-            input_offset += 1;
+        psa_destroy_key( ssl->handshake->psa_pake_password );
+        psa_pake_abort( &ssl->handshake->psa_pake_ctx );
 
-            if( input_offset + length > len )
-            {
-                ret = MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
-                goto psa_pake_error;
-            }
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round one", ret );
+        mbedtls_ssl_send_alert_message(
+                ssl,
+                MBEDTLS_SSL_ALERT_LEVEL_FATAL,
+                MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE );
 
-            status = psa_pake_input( &ssl->handshake->psa_pake_ctx, step,
-                                     buf + input_offset, length );
-            if( status != PSA_SUCCESS)
-            {
-                ret = psa_ssl_status_to_mbedtls( status );
-                goto psa_pake_error;
-            }
-
-            input_offset += length;
-        }
+        return( ret );
     }
 #else
     if( ( ret = mbedtls_ecjpake_read_round_one( &ssl->handshake->ecjpake_ctx,
@@ -353,20 +334,6 @@
     ssl->handshake->cli_exts |= MBEDTLS_TLS_EXT_ECJPAKE_KKPP_OK;
 
     return( 0 );
-
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-psa_pake_error:
-    psa_destroy_key( ssl->handshake->psa_pake_password );
-    psa_pake_abort( &ssl->handshake->psa_pake_ctx );
-
-    MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round one", ret );
-    mbedtls_ssl_send_alert_message(
-            ssl,
-            MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-            MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE );
-
-    return( ret );
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
 }
 #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */
 
@@ -2903,13 +2870,13 @@
 #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED)
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE )
     {
+        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
         unsigned char *out_p = ssl->out_msg + ssl->out_msglen;
         unsigned char *end_p = ssl->out_msg + MBEDTLS_SSL_OUT_CONTENT_LEN -
                                ssl->out_msglen;
-        psa_status_t status;
         size_t output_offset = 0;
-        size_t output_len;
+        size_t output_len = 0;
         size_t ec_len;
 
 #if !defined(MBEDTLS_ECJPAKE_ALT)
@@ -2931,34 +2898,20 @@
 #endif //MBEDTLS_PSA_BUILTIN_ALG_JPAKE
         output_offset += ec_len;
 
-        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-             step <= PSA_PAKE_STEP_ZK_PROOF ;
-             ++step )
+        ret = psa_tls12_write_ecjpake_round_two( &ssl->handshake->psa_pake_ctx,
+                                    out_p + output_offset,
+                                    end_p - out_p - output_offset, &output_len );
+        if( ret != 0 )
         {
-            if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                *(out_p + output_offset) = 65;
-            } else {
-                *(out_p + output_offset) = 32;
-            }
-            output_offset += 1;
-            status = psa_pake_output( &ssl->handshake->psa_pake_ctx,
-                                      step, out_p + output_offset,
-                                      end_p - out_p - output_offset,
-                                      &output_len );
-            if( status != PSA_SUCCESS )
-            {
-                psa_destroy_key( ssl->handshake->psa_pake_password );
-                psa_pake_abort( &ssl->handshake->psa_pake_ctx );
-                MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", status );
-                return( psa_ssl_status_to_mbedtls( status ) );
-            }
-
-            output_offset += output_len;
+            psa_destroy_key( ssl->handshake->psa_pake_password );
+            psa_pake_abort( &ssl->handshake->psa_pake_ctx );
+            MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", ret );
+            return( ret );
         }
 
+        output_offset += output_len;
         ssl->out_msglen += output_offset;
 #else
-        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
         size_t len = 0;
 
         ret = mbedtls_ecjpake_write_round_two(
@@ -4192,37 +4145,9 @@
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE )
     {
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-        size_t len = end - p;
-        psa_status_t status;
-        size_t input_offset = 0;
-
-        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-             step <= PSA_PAKE_STEP_ZK_PROOF ;
-             ++step )
-        {
-            /* Length is stored at the first byte */
-            size_t length = p[input_offset];
-            input_offset += 1;
-
-            if( input_offset + length > len )
-            {
-                ret = MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
-                goto psa_pake_out;
-            }
-
-            status = psa_pake_input( &ssl->handshake->psa_pake_ctx, step,
-                                     p + input_offset, length );
-            if( status != PSA_SUCCESS)
-            {
-                ret = psa_ssl_status_to_mbedtls( status );
-                goto psa_pake_out;
-            }
-
-            input_offset += length;
-        }
-
-psa_pake_out:
-        if( ret != 0 )
+        if( ( ret = psa_tls12_parse_ecjpake_round_two(
+                        &ssl->handshake->psa_pake_ctx, p, end - p,
+                        ssl->conf->endpoint ) ) != 0 )
         {
             psa_destroy_key( ssl->handshake->psa_pake_password );
             psa_pake_abort( &ssl->handshake->psa_pake_ctx );