Calculate hashes of SSL encryption and decryption keys

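Verify the SSL encryption and decryption keys against hashes computed
when the keys are derived (MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY), and
optimize the key switching mechanism to set the key only if a different
operation is performed with the context. mbedtls_hash() is moved from
aes.c to platform_util.c so both the AES and SSL checks can use it.

Signed-off-by: Andrzej Kurek <andrzej.kurek@arm.com>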
diff --git a/library/aes.c b/library/aes.c
index 8005172..04c8208 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -85,19 +85,6 @@
 }
 #endif
 
-#if defined(MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY)
-static uint32_t mbedtls_hash( const void *data, size_t data_len_bytes )
-{
-    uint32_t result = 0;
-    size_t i;
-    /* data_len_bytes - only multiples of 4 are considered, rest is truncated */
-    for( i = 0; i < data_len_bytes >> 2; i++ )
-    {
-        result ^= ( (uint32_t*) data )[i];
-    }
-    return result;
-}
-#endif
 /*
  * Data structure for AES round data
  */
diff --git a/library/platform_util.c b/library/platform_util.c
index 458dfc9..15309aa 100644
--- a/library/platform_util.c
+++ b/library/platform_util.c
@@ -442,6 +442,35 @@
 }
 #endif /* MBEDTLS_HAVE_TIME_DATE && MBEDTLS_PLATFORM_GMTIME_R_ALT */
 
+#if defined(MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY) || defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+/*
+ * XOR checksum of a buffer, used by the key-integrity checks to detect
+ * corruption of stored key material. This is not a cryptographic hash.
+ *
+ * The buffer is read as consecutive little-endian 32-bit words which are
+ * XORed together. Only the first (data_len_bytes / 4) * 4 bytes are used;
+ * any trailing 1-3 bytes are ignored. The words are assembled byte by
+ * byte, so `data` does not need to be 4-byte aligned and is never
+ * accessed through a cast (uint32_t *).
+ */
+uint32_t mbedtls_hash( const void *data, size_t data_len_bytes )
+{
+    const unsigned char *bytes = (const unsigned char *) data;
+    uint32_t result = 0;
+    size_t i;
+
+    for( i = 0; i < ( data_len_bytes >> 2 ); i++ )
+    {
+        const unsigned char *word = bytes + ( i << 2 );
+        result ^= ( (uint32_t) word[0]       ) |
+                  ( (uint32_t) word[1] <<  8 ) |
+                  ( (uint32_t) word[2] << 16 ) |
+                  ( (uint32_t) word[3] << 24 );
+    }
+    return( result );
+}
+#endif
+
 unsigned char* mbedtls_platform_put_uint32_be( unsigned char *buf,
                                                size_t num )
 {
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 2c363fd..0a763f4 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -58,6 +58,85 @@
 
 #define PROPER_HS_FRAGMENT 0x75555555
 
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+/*
+ * Make sure transform->cipher_ctx is keyed for `operation` before it is
+ * used for a record: key_enc for MBEDTLS_ENCRYPT, key_dec for
+ * MBEDTLS_DECRYPT.
+ *
+ * With MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY, the selected key is first
+ * checked against the hash stored in the transform and
+ * MBEDTLS_ERR_PLATFORM_FAULT_DETECTED is returned on mismatch.
+ *
+ * mbedtls_cipher_setkey() is only called when the context is not already
+ * set up for `operation`, so consecutive records in the same direction
+ * do not recompute the key schedule.
+ *
+ * Returns 0 on success, MBEDTLS_ERR_SSL_INTERNAL_ERROR for an operation
+ * other than encrypt/decrypt, or the error from mbedtls_cipher_setkey().
+ */
+static int mbedtls_ssl_switch_key( mbedtls_ssl_transform *transform,
+                                   const mbedtls_operation_t operation )
+{
+    const unsigned char *key;
+    int ret;
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    uint32_t hash;
+#endif
+
+    if( operation == MBEDTLS_ENCRYPT )
+    {
+        key = transform->key_enc;
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+        hash = transform->key_enc_hash;
+#endif
+    }
+    else if( operation == MBEDTLS_DECRYPT )
+    {
+        key = transform->key_dec;
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+        hash = transform->key_dec_hash;
+#endif
+    }
+    else
+    {
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+    }
+
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    /* Refuse to use a key whose contents no longer match its hash. */
+    if( hash != mbedtls_hash( key, transform->key_bitlen >> 3 ) )
+    {
+        return( MBEDTLS_ERR_PLATFORM_FAULT_DETECTED );
+    }
+#endif
+
+    /* Already set up for this direction: keep the current key schedule.
+     * NOTE(review): this relies on cipher_ctx.operation not matching
+     * `operation` before the first successful setkey. A zero-initialized
+     * context may read as MBEDTLS_DECRYPT (0); confirm the caller sets
+     * it to MBEDTLS_OPERATION_NONE after mbedtls_cipher_setup(). */
+    if( operation == transform->cipher_ctx.operation )
+    {
+        return( 0 );
+    }
+
+    if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
+                                       key,
+                                       transform->key_bitlen,
+                                       operation ) ) != 0 )
+    {
+        /* mbedtls_cipher_setkey() may already have stored `operation` in
+         * the context when loading the key fails. Reset it so the next
+         * call retries setkey instead of skipping it above. */
+        transform->cipher_ctx.operation = MBEDTLS_OPERATION_NONE;
+        return( ret );
+    }
+
+    return( 0 );
+}
+#endif
+
 #if defined(MBEDTLS_USE_TINYCRYPT)
 static int uecc_rng_wrapper( uint8_t *dest, unsigned int size )
 {
@@ -1577,6 +1625,11 @@
     memcpy( transform->key_dec, key2, cipher_info->key_bitlen >> 3 );
 
     transform->key_bitlen = cipher_info->key_bitlen;
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    transform->key_enc_hash = mbedtls_hash( transform->key_enc, transform->key_bitlen >> 3 );
+    transform->key_dec_hash = mbedtls_hash( transform->key_dec, transform->key_bitlen >> 3 );
+#endif
+
 #else
     if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_enc,
                                  cipher_info ) ) != 0 )
@@ -2697,12 +2750,10 @@
                                     "including %d bytes of padding",
                                     rec->data_len, 0 ) );
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_enc,
-                                           transform->key_bitlen,
-                                           MBEDTLS_ENCRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_ENCRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
 
@@ -2798,12 +2849,10 @@
          * Encrypt and authenticate
          */
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_enc,
-                                           transform->key_bitlen,
-                                           MBEDTLS_ENCRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_ENCRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
 
@@ -2905,12 +2954,10 @@
                             rec->data_len, transform->ivlen,
                             padlen + 1 ) );
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_enc,
-                                           transform->key_bitlen,
-                                           MBEDTLS_ENCRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_ENCRYPT ) )
+             != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
 
@@ -3076,12 +3123,10 @@
     {
         padlen = 0;
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_dec,
-                                           transform->key_bitlen,
-                                           MBEDTLS_DECRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_DECRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx,
@@ -3192,12 +3237,10 @@
          * Decrypt and authenticate
          */
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_dec,
-                                           transform->key_bitlen,
-                                           MBEDTLS_DECRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_DECRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
         if( ( ret = mbedtls_cipher_auth_decrypt( &transform->cipher_ctx,
@@ -3376,14 +3419,13 @@
 
         /* We still have data_len % ivlen == 0 and data_len >= ivlen here. */
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_dec,
-                                           transform->key_bitlen,
-                                           MBEDTLS_DECRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_DECRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
+
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx,
                                    transform->iv_dec, transform->ivlen,
                                    data, rec->data_len, data, &olen ) ) != 0 )
diff --git a/library/version_features.c b/library/version_features.c
index ec4a692..2ef9d12 100644
--- a/library/version_features.c
+++ b/library/version_features.c
@@ -687,6 +687,9 @@
 #if defined(MBEDTLS_CRC_C)
     "MBEDTLS_CRC_C",
 #endif /* MBEDTLS_CRC_C */
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    "MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY",
+#endif /* MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY */
 #if defined(MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY)
     "MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY",
 #endif /* MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY */