tls12: psa_pake: use common code for parsing/writing round one and round two data

Let the server and the client share common code for parsing and
writing the EC J-PAKE round one and round two data.

Signed-off-by: Valerio Setti <vsetti@baylibre.com>
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index 68b4d09..806efd2 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -290,12 +290,9 @@
 MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_parse_ecjpake_kkpp( mbedtls_ssl_context *ssl,
                                    const unsigned char *buf,
-                                   size_t len )
+                                   size_t len)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    psa_status_t status;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ssl->handshake->psa_pake_ctx_is_ok != 1 )
@@ -308,35 +305,19 @@
     }
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-    size_t input_offset = 0;
-
-    /* Repeat the KEY_SHARE, ZK_PUBLIC & ZF_PROOF twice */
-    for( unsigned int x = 1 ; x <= 2 ; ++x )
+    if ( ( ret = psa_tls12_parse_ecjpake_round_one(
+                        &ssl->handshake->psa_pake_ctx, buf, len ) ) != 0 )
     {
-        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-             step <= PSA_PAKE_STEP_ZK_PROOF ;
-             ++step )
-        {
-            /* Length is stored at the first byte */
-            size_t length = buf[input_offset];
-            input_offset += 1;
+        psa_destroy_key( ssl->handshake->psa_pake_password );
+        psa_pake_abort( &ssl->handshake->psa_pake_ctx );
 
-            if( input_offset + length > len )
-            {
-                ret = MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
-                goto psa_pake_error;
-            }
+        MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round one", ret );
+        mbedtls_ssl_send_alert_message(
+                ssl,
+                MBEDTLS_SSL_ALERT_LEVEL_FATAL,
+                MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE );
 
-            status = psa_pake_input( &ssl->handshake->psa_pake_ctx, step,
-                                     buf + input_offset, length );
-            if( status != PSA_SUCCESS)
-            {
-                ret = psa_ssl_status_to_mbedtls( status );
-                goto psa_pake_error;
-            }
-
-            input_offset += length;
-        }
+        return( ret );
     }
 #else
     if( ( ret = mbedtls_ecjpake_read_round_one( &ssl->handshake->ecjpake_ctx,
@@ -353,20 +334,6 @@
     ssl->handshake->cli_exts |= MBEDTLS_TLS_EXT_ECJPAKE_KKPP_OK;
 
     return( 0 );
-
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-psa_pake_error:
-    psa_destroy_key( ssl->handshake->psa_pake_password );
-    psa_pake_abort( &ssl->handshake->psa_pake_ctx );
-
-    MBEDTLS_SSL_DEBUG_RET( 1, "psa_pake_input round one", ret );
-    mbedtls_ssl_send_alert_message(
-            ssl,
-            MBEDTLS_SSL_ALERT_LEVEL_FATAL,
-            MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE );
-
-    return( ret );
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
 }
 #endif /* MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED */
 
@@ -2903,13 +2870,13 @@
 #if defined(MBEDTLS_KEY_EXCHANGE_ECJPAKE_ENABLED)
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE )
     {
+        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
         unsigned char *out_p = ssl->out_msg + ssl->out_msglen;
         unsigned char *end_p = ssl->out_msg + MBEDTLS_SSL_OUT_CONTENT_LEN -
                                ssl->out_msglen;
-        psa_status_t status;
         size_t output_offset = 0;
-        size_t output_len;
+        size_t output_len = 0;
         size_t ec_len;
 
 #if !defined(MBEDTLS_ECJPAKE_ALT)
@@ -2931,34 +2898,20 @@
 #endif //MBEDTLS_PSA_BUILTIN_ALG_JPAKE
         output_offset += ec_len;
 
-        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-             step <= PSA_PAKE_STEP_ZK_PROOF ;
-             ++step )
+        ret = psa_tls12_write_ecjpake_round_two( &ssl->handshake->psa_pake_ctx,
+                                    out_p + output_offset,
+                                    end_p - out_p - output_offset, &output_len );
+        if( ret != 0 )
         {
-            if (step != PSA_PAKE_STEP_ZK_PROOF) {
-                *(out_p + output_offset) = 65;
-            } else {
-                *(out_p + output_offset) = 32;
-            }
-            output_offset += 1;
-            status = psa_pake_output( &ssl->handshake->psa_pake_ctx,
-                                      step, out_p + output_offset,
-                                      end_p - out_p - output_offset,
-                                      &output_len );
-            if( status != PSA_SUCCESS )
-            {
-                psa_destroy_key( ssl->handshake->psa_pake_password );
-                psa_pake_abort( &ssl->handshake->psa_pake_ctx );
-                MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", status );
-                return( psa_ssl_status_to_mbedtls( status ) );
-            }
-
-            output_offset += output_len;
+            psa_destroy_key( ssl->handshake->psa_pake_password );
+            psa_pake_abort( &ssl->handshake->psa_pake_ctx );
+            MBEDTLS_SSL_DEBUG_RET( 1 , "psa_pake_output", ret );
+            return( ret );
         }
 
+        output_offset += output_len;
         ssl->out_msglen += output_offset;
 #else
-        int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
         size_t len = 0;
 
         ret = mbedtls_ecjpake_write_round_two(
@@ -4192,37 +4145,9 @@
     if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_ECJPAKE )
     {
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-        size_t len = end - p;
-        psa_status_t status;
-        size_t input_offset = 0;
-
-        for( psa_pake_step_t step = PSA_PAKE_STEP_KEY_SHARE ;
-             step <= PSA_PAKE_STEP_ZK_PROOF ;
-             ++step )
-        {
-            /* Length is stored at the first byte */
-            size_t length = p[input_offset];
-            input_offset += 1;
-
-            if( input_offset + length > len )
-            {
-                ret = MBEDTLS_ERR_SSL_HANDSHAKE_FAILURE;
-                goto psa_pake_out;
-            }
-
-            status = psa_pake_input( &ssl->handshake->psa_pake_ctx, step,
-                                     p + input_offset, length );
-            if( status != PSA_SUCCESS)
-            {
-                ret = psa_ssl_status_to_mbedtls( status );
-                goto psa_pake_out;
-            }
-
-            input_offset += length;
-        }
-
-psa_pake_out:
-        if( ret != 0 )
+        if( ( ret = psa_tls12_parse_ecjpake_round_two(
+                        &ssl->handshake->psa_pake_ctx, p, end - p,
+                        ssl->conf->endpoint ) ) != 0 )
         {
             psa_destroy_key( ssl->handshake->psa_pake_password );
             psa_pake_abort( &ssl->handshake->psa_pake_ctx );