psa_tls12_prf_psk_to_ms_set_key(): add support for other secret input
Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index f3a2258..e1f1e7b 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -5299,31 +5299,58 @@
size_t data_length )
{
psa_status_t status;
- uint8_t pms[ 4 + 2 * PSA_TLS12_PSK_TO_MS_PSK_MAX_SIZE ];
- uint8_t *cur = pms;
+ const size_t pms_len = ( prf->state == PSA_TLS12_PRF_STATE_OTHER_KEY_SET ?
+ 4 + data_length + prf->other_secret_length :
+ 4 + 2 * data_length );
if( data_length > PSA_TLS12_PSK_TO_MS_PSK_MAX_SIZE )
return( PSA_ERROR_INVALID_ARGUMENT );
- /* Quoting RFC 4279, Section 2:
+ uint8_t *pms = mbedtls_calloc( 1, pms_len );
+ uint8_t *cur = pms;
+
+ /* pure-PSK:
+ * Quoting RFC 4279, Section 2:
*
* The premaster secret is formed as follows: if the PSK is N octets
* long, concatenate a uint16 with the value N, N zero octets, a second
* uint16 with the value N, and the PSK itself.
+ *
+ * mixed-PSK:
+ * In a DHE-PSK, RSA-PSK, ECDHE-PSK the premaster secret is formed as
+ * follows: concatenate a uint16 with the length of the other secret,
+ * the other secret itself, uint16 with the length of PSK, and the
+ * PSK itself.
+ * For details please check:
+ * - RFC 4279, Section 4 for the definition of RSA-PSK,
+ * - RFC 4279, Section 3 for the definition of DHE-PSK,
+ * - RFC 5489 for the definition of ECDHE-PSK.
*/
+ if ( prf->state == PSA_TLS12_PRF_STATE_OTHER_KEY_SET )
+ {
+ *cur++ = MBEDTLS_BYTE_1( prf->other_secret_length );
+ *cur++ = MBEDTLS_BYTE_0( prf->other_secret_length );
+ memcpy( cur, prf->other_secret, prf->other_secret_length );
+ cur += prf->other_secret_length;
+ }
+ else
+ {
+ *cur++ = MBEDTLS_BYTE_1( data_length );
+ *cur++ = MBEDTLS_BYTE_0( data_length );
+ memset( cur, 0, data_length );
+ cur += data_length;
+ }
+
*cur++ = MBEDTLS_BYTE_1( data_length );
*cur++ = MBEDTLS_BYTE_0( data_length );
- memset( cur, 0, data_length );
- cur += data_length;
- *cur++ = pms[0];
- *cur++ = pms[1];
memcpy( cur, data, data_length );
cur += data_length;
status = psa_tls12_prf_set_key( prf, pms, cur - pms );
- mbedtls_platform_zeroize( pms, sizeof( pms ) );
+ mbedtls_platform_zeroize( pms, pms_len );
+ mbedtls_free( pms );
return( status );
}