Get the PSK length and check the buffer size before writing in the ECDHE-PSK PSA version of ssl_parse_client_key_exchange()

Signed-off-by: Neil Armstrong <narmstrong@baylibre.com>
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index d9a29dc..7b6efb1 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -4115,10 +4115,6 @@
         MBEDTLS_PUT_UINT16_BE( zlen, psm, 0 );
         psm += zlen_size + zlen;
 
-        /* opaque psk<0..2^16-1>; */
-        if( psm_end - psm < 2 )
-            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
-
         const unsigned char *psk = NULL;
         size_t psk_len = 0;
 
@@ -4130,13 +4126,14 @@
              */
             return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
 
+        /* opaque psk<0..2^16-1>; */
+        if( (size_t)( psm_end - psm ) < ( 2 + psk_len ) )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
         /* Write the PSK length as uint16 */
         MBEDTLS_PUT_UINT16_BE( psk_len, psm, 0 );
         psm += 2;
 
-        if( psm_end < psm || (size_t)( psm_end - psm ) < psk_len )
-            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
-
         /* Write the PSK itself */
         memcpy( psm, psk, psk_len );
         psm += psk_len;