psa_tls12_prf_psk_to_ms_set_key(): add support for other secret input

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index f3a2258..e1f1e7b 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -5299,31 +5299,58 @@
     size_t data_length )
 {
     psa_status_t status;
-    uint8_t pms[ 4 + 2 * PSA_TLS12_PSK_TO_MS_PSK_MAX_SIZE ];
-    uint8_t *cur = pms;
+    const size_t pms_len = ( prf->state == PSA_TLS12_PRF_STATE_OTHER_KEY_SET ?
+        4 + data_length + prf->other_secret_length :
+        4 + 2 * data_length );
 
     if( data_length > PSA_TLS12_PSK_TO_MS_PSK_MAX_SIZE )
         return( PSA_ERROR_INVALID_ARGUMENT );
 
-    /* Quoting RFC 4279, Section 2:
+    uint8_t *pms = mbedtls_calloc( 1, pms_len );
+    uint8_t *cur = pms;
+
+    /* pure-PSK:
+     * Quoting RFC 4279, Section 2:
      *
      * The premaster secret is formed as follows: if the PSK is N octets
      * long, concatenate a uint16 with the value N, N zero octets, a second
      * uint16 with the value N, and the PSK itself.
+     *
+     * mixed-PSK:
+     * In a DHE-PSK, RSA-PSK, ECDHE-PSK the premaster secret is formed as
+     * follows: concatenate a uint16 with the length of the other secret,
+     * the other secret itself, uint16 with the length of PSK, and the
+     * PSK itself.
+     * For details please check:
+     * - RFC 4279, Section 4 for the definition of RSA-PSK,
+     * - RFC 4279, Section 3 for the definition of DHE-PSK,
+     * - RFC 5489 for the definition of ECDHE-PSK.
      */
 
+    if ( prf->state == PSA_TLS12_PRF_STATE_OTHER_KEY_SET )
+    {
+        *cur++ = MBEDTLS_BYTE_1( prf->other_secret_length );
+        *cur++ = MBEDTLS_BYTE_0( prf->other_secret_length );
+        memcpy( cur, prf->other_secret, prf->other_secret_length );
+        cur += prf->other_secret_length;
+    }
+    else
+    {
+        *cur++ = MBEDTLS_BYTE_1( data_length );
+        *cur++ = MBEDTLS_BYTE_0( data_length );
+        memset( cur, 0, data_length );
+        cur += data_length;
+    }
+
     *cur++ = MBEDTLS_BYTE_1( data_length );
     *cur++ = MBEDTLS_BYTE_0( data_length );
-    memset( cur, 0, data_length );
-    cur += data_length;
-    *cur++ = pms[0];
-    *cur++ = pms[1];
     memcpy( cur, data, data_length );
     cur += data_length;
 
     status = psa_tls12_prf_set_key( prf, pms, cur - pms );
 
-    mbedtls_platform_zeroize( pms, sizeof( pms ) );
+    mbedtls_platform_zeroize( pms, pms_len );
+    mbedtls_free( pms );
     return( status );
 }