Merged renegotiation refactoring
diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h
index e5ca9d5..e51e507 100644
--- a/include/polarssl/ssl.h
+++ b/include/polarssl/ssl.h
@@ -1570,6 +1570,40 @@
 }
 #endif /* POLARSSL_X509_CRT_PARSE_C */
 
+/*
+ * Constant-time buffer comparison.
+ *
+ * Returns 0 if the first n bytes of a and b are identical, non-zero
+ * otherwise. The result is not an ordering like memcmp()'s: only test
+ * it against 0.
+ *
+ * Every byte is always visited and there is no data-dependent branch, so
+ * the running time depends on n only, not on where the buffers differ.
+ * This matters when comparing MACs, Finished/verify data and other values
+ * an attacker must not learn byte-by-byte through timing.
+ *
+ * The pointers and the accumulator are volatile so that an optimizing
+ * compiler cannot turn the loop back into an early-exit comparison after
+ * inlining it into a caller that only checks the result against 0.
+ */
+static inline int safer_memcmp( const void *a, const void *b, size_t n )
+{
+    size_t i;
+    volatile const unsigned char *A = (volatile const unsigned char *) a;
+    volatile const unsigned char *B = (volatile const unsigned char *) b;
+    volatile unsigned char diff = 0;
+
+    for( i = 0; i < n; i++ )
+    {
+        /* Load both bytes into plain locals first so the order of the
+         * two volatile reads is well defined. */
+        unsigned char x = A[i], y = B[i];
+        diff |= x ^ y;
+    }
+
+    return( diff );
+}
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/library/bignum.c b/library/bignum.c
index e2da5a8..2a97a59 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1493,7 +1493,7 @@
for( i = 0; i < wsize - 1; i++ )
mpi_montmul( &W[j], &W[j], N, mm, &T );
-
+
/*
* W[i] = W[i - 1] * W[1]
*/
@@ -1516,9 +1516,11 @@
{
if( bufsize == 0 )
{
- if( nblimbs-- == 0 )
+ if( nblimbs == 0 )
break;
+ nblimbs--;
+
bufsize = sizeof( t_uint ) << 3;
}
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index ad6583b..0eaa531 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -628,11 +628,13 @@
}
else
{
+ /* Check verify-data in constant-time. The length OTOH is no secret */
if( len != 1 + ssl->verify_data_len * 2 ||
buf[0] != ssl->verify_data_len * 2 ||
- memcmp( buf + 1, ssl->own_verify_data, ssl->verify_data_len ) != 0 ||
- memcmp( buf + 1 + ssl->verify_data_len,
- ssl->peer_verify_data, ssl->verify_data_len ) != 0 )
+ safer_memcmp( buf + 1,
+ ssl->own_verify_data, ssl->verify_data_len ) != 0 ||
+ safer_memcmp( buf + 1 + ssl->verify_data_len,
+ ssl->peer_verify_data, ssl->verify_data_len ) != 0 )
{
SSL_DEBUG_MSG( 1, ( "non-matching renegotiated connection field" ) );
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 7d81fc9..e44bf72 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -254,7 +254,7 @@
unsigned char *mac;
unsigned char computed_mac[32];
size_t enc_len, clear_len, i;
- unsigned char pad_len;
+ unsigned char pad_len, diff;
SSL_DEBUG_BUF( 3, "session ticket structure", buf, len );
@@ -267,19 +267,23 @@
if( len != enc_len + 66 )
return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
- /* Check name */
- if( memcmp( key_name, ssl->ticket_keys->key_name, 16 ) != 0 )
- return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
+ /* Check name, in constant time though it's not a big secret */
+ diff = 0;
+ for( i = 0; i < 16; i++ )
+ diff |= key_name[i] ^ ssl->ticket_keys->key_name[i];
+ /* don't return yet, check the MAC anyway */
- /* Check mac */
+ /* Check mac, with constant-time buffer comparison */
sha256_hmac( ssl->ticket_keys->mac_key, 16, buf, len - 32,
computed_mac, 0 );
- ret = 0;
+
for( i = 0; i < 32; i++ )
- if( mac[i] != computed_mac[i] )
- ret = POLARSSL_ERR_SSL_INVALID_MAC;
- if( ret != 0 )
- return( ret );
+ diff |= mac[i] ^ computed_mac[i];
+
+ /* Now return if ticket is not authentic, since we want to avoid
+ * decrypting arbitrary attacker-chosen data */
+ if( diff != 0 )
+ return( POLARSSL_ERR_SSL_INVALID_MAC );
/* Decrypt */
if( ( ret = aes_crypt_cbc( &ssl->ticket_keys->dec, AES_DECRYPT,
@@ -428,9 +432,11 @@
}
else
{
+ /* Check verify-data in constant-time. The length OTOH is no secret */
if( len != 1 + ssl->verify_data_len ||
buf[0] != ssl->verify_data_len ||
- memcmp( buf + 1, ssl->peer_verify_data, ssl->verify_data_len ) != 0 )
+ safer_memcmp( buf + 1, ssl->peer_verify_data,
+ ssl->verify_data_len ) != 0 )
{
SSL_DEBUG_MSG( 1, ( "non-matching renegotiated connection field" ) );
@@ -2408,8 +2414,10 @@
if( ret == 0 )
{
+ /* Identity is not a big secret since clients send it in the clear,
+ * but treat it carefully anyway, just in case */
if( n != ssl->psk_identity_len ||
- memcmp( ssl->psk_identity, *p, n ) != 0 )
+ safer_memcmp( ssl->psk_identity, *p, n ) != 0 )
{
ret = POLARSSL_ERR_SSL_UNKNOWN_IDENTITY;
}
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 4654ea6..055798f 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1711,7 +1711,7 @@
SSL_DEBUG_BUF( 4, "computed mac", ssl->in_msg + ssl->in_msglen,
ssl->transform_in->maclen );
- if( memcmp( tmp, ssl->in_msg + ssl->in_msglen,
+ if( safer_memcmp( tmp, ssl->in_msg + ssl->in_msglen,
ssl->transform_in->maclen ) != 0 )
{
#if defined(POLARSSL_SSL_DEBUG_ALL)
@@ -3196,7 +3196,7 @@
return( POLARSSL_ERR_SSL_BAD_HS_FINISHED );
}
- if( memcmp( ssl->in_msg + 4, buf, hash_len ) != 0 )
+ if( safer_memcmp( ssl->in_msg + 4, buf, hash_len ) != 0 )
{
SSL_DEBUG_MSG( 1, ( "bad finished message" ) );
return( POLARSSL_ERR_SSL_BAD_HS_FINISHED );
diff --git a/programs/aes/aescrypt2.c b/programs/aes/aescrypt2.c
index 4c1f8ea..1239ca2 100644
--- a/programs/aes/aescrypt2.c
+++ b/programs/aes/aescrypt2.c
@@ -75,6 +75,7 @@
unsigned char key[512];
unsigned char digest[32];
unsigned char buffer[1024];
+ unsigned char diff;
aes_context aes_ctx;
sha256_context sha_ctx;
@@ -397,7 +398,12 @@
goto exit;
}
- if( memcmp( digest, buffer, 32 ) != 0 )
+ /* Use constant-time buffer comparison */
+ diff = 0;
+ for( i = 0; i < 32; i++ )
+ diff |= digest[i] ^ buffer[i];
+
+ if( diff != 0 )
{
fprintf( stderr, "HMAC check failed: wrong key, "
"or file corrupted.\n" );
diff --git a/programs/aes/crypt_and_hash.c b/programs/aes/crypt_and_hash.c
index d2845de..50218e1 100644
--- a/programs/aes/crypt_and_hash.c
+++ b/programs/aes/crypt_and_hash.c
@@ -76,6 +76,7 @@
unsigned char digest[POLARSSL_MD_MAX_SIZE];
unsigned char buffer[1024];
unsigned char output[1024];
+ unsigned char diff;
const cipher_info_t *cipher_info;
const md_info_t *md_info;
@@ -476,7 +477,12 @@
goto exit;
}
- if( memcmp( digest, buffer, md_get_size( md_info ) ) != 0 )
+ /* Use constant-time buffer comparison */
+ diff = 0;
+ for( i = 0; i < md_get_size( md_info ); i++ )
+ diff |= digest[i] ^ buffer[i];
+
+ if( diff != 0 )
{
fprintf( stderr, "HMAC check failed: wrong key, "
"or file corrupted.\n" );
diff --git a/programs/hash/generic_sum.c b/programs/hash/generic_sum.c
index 8ca4d92..3f29058 100644
--- a/programs/hash/generic_sum.c
+++ b/programs/hash/generic_sum.c
@@ -77,6 +77,7 @@
int nb_tot1, nb_tot2;
unsigned char sum[POLARSSL_MD_MAX_SIZE];
char buf[POLARSSL_MD_MAX_SIZE * 2 + 1], line[1024];
+ char diff;
if( ( f = fopen( filename, "rb" ) ) == NULL )
{
@@ -123,7 +124,12 @@
for( i = 0; i < md_info->size; i++ )
sprintf( buf + i * 2, "%02x", sum[i] );
- if( memcmp( line, buf, 2 * md_info->size ) != 0 )
+ /* Use constant-time buffer comparison */
+ diff = 0;
+ for( i = 0; i < 2 * md_info->size; i++ )
+ diff |= line[i] ^ buf[i];
+
+ if( diff != 0 )
{
nb_err2++;
fprintf( stderr, "wrong checksum: %s\n", line + 66 );
diff --git a/programs/hash/md5sum.c b/programs/hash/md5sum.c
index 6ddc673..d614aa1 100644
--- a/programs/hash/md5sum.c
+++ b/programs/hash/md5sum.c
@@ -77,6 +77,7 @@
int nb_tot1, nb_tot2;
unsigned char sum[16];
char buf[33], line[1024];
+ char diff;
if( ( f = fopen( filename, "rb" ) ) == NULL )
{
@@ -117,7 +118,12 @@
for( i = 0; i < 16; i++ )
sprintf( buf + i * 2, "%02x", sum[i] );
- if( memcmp( line, buf, 32 ) != 0 )
+ /* Use constant-time buffer comparison */
+ diff = 0;
+ for( i = 0; i < 32; i++ )
+ diff |= line[i] ^ buf[i];
+
+ if( diff != 0 )
{
nb_err2++;
fprintf( stderr, "wrong checksum: %s\n", line + 34 );
diff --git a/programs/hash/sha1sum.c b/programs/hash/sha1sum.c
index adde916..ff0514a 100644
--- a/programs/hash/sha1sum.c
+++ b/programs/hash/sha1sum.c
@@ -77,6 +77,7 @@
int nb_tot1, nb_tot2;
unsigned char sum[20];
char buf[41], line[1024];
+ char diff;
if( ( f = fopen( filename, "rb" ) ) == NULL )
{
@@ -117,7 +118,12 @@
for( i = 0; i < 20; i++ )
sprintf( buf + i * 2, "%02x", sum[i] );
- if( memcmp( line, buf, 40 ) != 0 )
+ /* Use constant-time buffer comparison */
+ diff = 0;
+ for( i = 0; i < 40; i++ )
+ diff |= line[i] ^ buf[i];
+
+ if( diff != 0 )
{
nb_err2++;
fprintf( stderr, "wrong checksum: %s\n", line + 42 );
diff --git a/programs/hash/sha2sum.c b/programs/hash/sha2sum.c
index 2f3acf8..c3f1a0d 100644
--- a/programs/hash/sha2sum.c
+++ b/programs/hash/sha2sum.c
@@ -77,6 +77,7 @@
int nb_tot1, nb_tot2;
unsigned char sum[32];
char buf[65], line[1024];
+ char diff;
if( ( f = fopen( filename, "rb" ) ) == NULL )
{
@@ -117,7 +118,12 @@
for( i = 0; i < 32; i++ )
sprintf( buf + i * 2, "%02x", sum[i] );
- if( memcmp( line, buf, 64 ) != 0 )
+ /* Use constant-time buffer comparison */
+ diff = 0;
+ for( i = 0; i < 64; i++ )
+ diff |= line[i] ^ buf[i];
+
+ if( diff != 0 )
{
nb_err2++;
fprintf( stderr, "wrong checksum: %s\n", line + 66 );