Calculate hashes of ssl encryption and decryption keys

Optimize the key switching mechanism to set the key only if 
a different operation is performed with the context.
Signed-off-by: Andrzej Kurek <andrzej.kurek@arm.com>
diff --git a/configs/baremetal.h b/configs/baremetal.h
index b63584f..9fa3918 100644
--- a/configs/baremetal.h
+++ b/configs/baremetal.h
@@ -138,6 +138,7 @@
 
 #define MBEDTLS_OID_C
 #define MBEDTLS_PLATFORM_C
+#define MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY
 #define MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY
 
 /* I/O buffer configuration */
diff --git a/include/mbedtls/check_config.h b/include/mbedtls/check_config.h
index f91f6b4..5e2a661 100644
--- a/include/mbedtls/check_config.h
+++ b/include/mbedtls/check_config.h
@@ -677,6 +677,11 @@
 #error "MBEDTLS_ARC4_C cannot be defined with MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS on"
 #endif
 
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY) && !defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS) && \
+    defined(MBEDTLS_ARC4_C)
+#error "MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY requires MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS to be defined."
+#endif
+
 #if defined(MBEDTLS_SSL_TLS_C) && (defined(MBEDTLS_SSL_PROTO_SSL3) && \
     defined(MBEDTLS_SSL_PROTO_TLS1_1) && !defined(MBEDTLS_SSL_PROTO_TLS1))
 #error "Illegal protocol selection"
diff --git a/include/mbedtls/config.h b/include/mbedtls/config.h
index 1cf868f..06cdde9 100644
--- a/include/mbedtls/config.h
+++ b/include/mbedtls/config.h
@@ -2740,6 +2740,18 @@
 //#define MBEDTLS_CRC_C
 
 /**
+ * \def MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY
+ *
+ * Enable validation of ssl keys by checking their hash
+ * during every encryption/decryption.
+ *
+ * Module:  library/ssl_tls.c
+ *
+ * Requires: MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS
+ */
+//#define MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY
+
+/**
  * \def MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY
  *
  * Enable validation of AES keys by checking their hash
diff --git a/include/mbedtls/platform_util.h b/include/mbedtls/platform_util.h
index b44be23..e0bd08c 100644
--- a/include/mbedtls/platform_util.h
+++ b/include/mbedtls/platform_util.h
@@ -340,6 +340,18 @@
                                       struct tm *tm_buf );
 #endif /* MBEDTLS_HAVE_TIME_DATE */
 
+#if defined(MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY) || defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+/**
+ * \brief      Calculate a hash from the given data.
+ *
+ * \param data            Data from which the hash is calculated.
+ * \param data_len_bytes  Length of the data in bytes.
+ *
+ * \return     A hash calculated from the provided data.
+ */
+uint32_t mbedtls_hash( const void *data, size_t data_len_bytes );
+#endif
+
 /**
  * \brief      Convert a 32-bit number to the big endian format and write it to
  *             the given buffer.
diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h
index 40d246e..a1c5d1d 100644
--- a/include/mbedtls/ssl_internal.h
+++ b/include/mbedtls/ssl_internal.h
@@ -760,7 +760,11 @@
     unsigned char *key_enc;
     unsigned char *key_dec;
     unsigned int key_bitlen;
-    mbedtls_cipher_context_t cipher_ctx;        /*!<  encryption/decryption context */
+    mbedtls_cipher_context_t cipher_ctx;    /*!<  encryption/decryption context */
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    uint32_t key_enc_hash;                  /*!< hash of the encryption key */
+    uint32_t key_dec_hash;                  /*!< hash of the decryption key */
+#endif
 #else
     mbedtls_cipher_context_t cipher_ctx_enc;    /*!<  encryption context      */
     mbedtls_cipher_context_t cipher_ctx_dec;    /*!<  decryption context      */
diff --git a/library/aes.c b/library/aes.c
index 8005172..04c8208 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -85,19 +85,6 @@
 }
 #endif
 
-#if defined(MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY)
-static uint32_t mbedtls_hash( const void *data, size_t data_len_bytes )
-{
-    uint32_t result = 0;
-    size_t i;
-    /* data_len_bytes - only multiples of 4 are considered, rest is truncated */
-    for( i = 0; i < data_len_bytes >> 2; i++ )
-    {
-        result ^= ( (uint32_t*) data )[i];
-    }
-    return result;
-}
-#endif
 /*
  * Data structure for AES round data
  */
diff --git a/library/platform_util.c b/library/platform_util.c
index 458dfc9..15309aa 100644
--- a/library/platform_util.c
+++ b/library/platform_util.c
@@ -442,6 +442,20 @@
 }
 #endif /* MBEDTLS_HAVE_TIME_DATE && MBEDTLS_PLATFORM_GMTIME_R_ALT */
 
+#if defined(MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY) || defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+uint32_t mbedtls_hash( const void *data, size_t data_len_bytes )
+{
+    uint32_t result = 0;
+    size_t i;
+    /* data_len_bytes - only multiples of 4 are considered, rest is truncated */
+    for( i = 0; i < data_len_bytes >> 2; i++ )
+    {
+        result ^= ( (uint32_t*) data )[i];
+    }
+    return result;
+}
+#endif
+
 unsigned char* mbedtls_platform_put_uint32_be( unsigned char *buf,
                                                size_t num )
 {
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 2c363fd..0a763f4 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -58,6 +58,54 @@
 
 #define PROPER_HS_FRAGMENT 0x75555555
 
+#if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
+static int mbedtls_ssl_switch_key( mbedtls_ssl_transform *transform,
+                                   const mbedtls_operation_t operation )
+{
+    unsigned char * key;
+    int ret;
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    uint32_t hash;
+#endif
+    if( operation == MBEDTLS_ENCRYPT )
+    {
+        key = transform->key_enc;
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+        hash = transform->key_enc_hash;
+#endif
+    }
+    else if ( operation == MBEDTLS_DECRYPT )
+    {
+        key = transform->key_dec;
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+        hash = transform->key_dec_hash;
+#endif
+    }
+    else
+    {
+        return ( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+    }
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    /* Check hash */
+    if( hash != mbedtls_hash( key, transform->key_bitlen >> 3 ) )
+    {
+        return( MBEDTLS_ERR_PLATFORM_FAULT_DETECTED );
+    }
+#endif
+    if( operation != transform->cipher_ctx.operation )
+    {
+        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
+                                           key,
+                                           transform->key_bitlen,
+                                           operation ) ) != 0 )
+        {
+            return( ret );
+        }
+    }
+    return( 0 );
+}
+#endif
+
 #if defined(MBEDTLS_USE_TINYCRYPT)
 static int uecc_rng_wrapper( uint8_t *dest, unsigned int size )
 {
@@ -1577,6 +1625,11 @@
     memcpy( transform->key_dec, key2, cipher_info->key_bitlen >> 3 );
 
     transform->key_bitlen = cipher_info->key_bitlen;
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    transform->key_enc_hash = mbedtls_hash( transform->key_enc, transform->key_bitlen >> 3 );
+    transform->key_dec_hash = mbedtls_hash( transform->key_dec, transform->key_bitlen >> 3 );
+#endif
+
 #else
     if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_enc,
                                  cipher_info ) ) != 0 )
@@ -2697,12 +2750,10 @@
                                     "including %d bytes of padding",
                                     rec->data_len, 0 ) );
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_enc,
-                                           transform->key_bitlen,
-                                           MBEDTLS_ENCRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_ENCRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
 
@@ -2798,12 +2849,10 @@
          * Encrypt and authenticate
          */
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_enc,
-                                           transform->key_bitlen,
-                                           MBEDTLS_ENCRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_ENCRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
 
@@ -2905,12 +2954,10 @@
                             rec->data_len, transform->ivlen,
                             padlen + 1 ) );
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_enc,
-                                           transform->key_bitlen,
-                                           MBEDTLS_ENCRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_ENCRYPT ) )
+             != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
 
@@ -3076,12 +3123,10 @@
     {
         padlen = 0;
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_dec,
-                                           transform->key_bitlen,
-                                           MBEDTLS_DECRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_DECRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx,
@@ -3192,12 +3237,10 @@
          * Decrypt and authenticate
          */
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_dec,
-                                           transform->key_bitlen,
-                                           MBEDTLS_DECRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_DECRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
         if( ( ret = mbedtls_cipher_auth_decrypt( &transform->cipher_ctx,
@@ -3376,14 +3419,13 @@
 
         /* We still have data_len % ivlen == 0 and data_len >= ivlen here. */
 #if defined(MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS)
-        if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx,
-                                           transform->key_dec,
-                                           transform->key_bitlen,
-                                           MBEDTLS_DECRYPT ) ) != 0 )
+        if( ( ret = mbedtls_ssl_switch_key( transform, MBEDTLS_DECRYPT ) )
+                != 0 )
         {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_switch_key", ret );
             return( ret );
         }
+
         if( ( ret = mbedtls_cipher_crypt( &transform->cipher_ctx,
                                    transform->iv_dec, transform->ivlen,
                                    data, rec->data_len, data, &olen ) ) != 0 )
diff --git a/library/version_features.c b/library/version_features.c
index ec4a692..2ef9d12 100644
--- a/library/version_features.c
+++ b/library/version_features.c
@@ -687,6 +687,9 @@
 #if defined(MBEDTLS_CRC_C)
     "MBEDTLS_CRC_C",
 #endif /* MBEDTLS_CRC_C */
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    "MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY",
+#endif /* MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY */
 #if defined(MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY)
     "MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY",
 #endif /* MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY */
diff --git a/programs/ssl/query_config.c b/programs/ssl/query_config.c
index e8fd634..ac0ef2e 100644
--- a/programs/ssl/query_config.c
+++ b/programs/ssl/query_config.c
@@ -1874,6 +1874,14 @@
     }
 #endif /* MBEDTLS_CRC_C */
 
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    if( strcmp( "MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY", config ) == 0 )
+    {
+        MACRO_EXPANSION_TO_STR( MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY );
+        return( 0 );
+    }
+#endif /* MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY */
+
 #if defined(MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY)
     if( strcmp( "MBEDTLS_VALIDATE_AES_KEYS_INTEGRITY", config ) == 0 )
     {
diff --git a/scripts/config.pl b/scripts/config.pl
index 5d2b28e..1c3422e 100755
--- a/scripts/config.pl
+++ b/scripts/config.pl
@@ -58,6 +58,7 @@
 #   MBEDTLS_AES_SCA_COUNTERMEASURES
 #   MBEDTLS_CTR_DRBG_USE_128_BIT_KEY
 #   MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS
+#   MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY
 #   and any symbol beginning _ALT
 #
 # The baremetal configuration excludes options that require a library or
@@ -142,6 +143,7 @@
 MBEDTLS_AES_SCA_COUNTERMEASURES
 MBEDTLS_CTR_DRBG_USE_128_BIT_KEY
 MBEDTLS_SSL_TRANSFORM_OPTIMIZE_CIPHERS
+MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY
 _ALT\s*$
 );
 
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index a689d45..68d8442 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -83,6 +83,14 @@
     memcpy( t_out->key_dec, key0, keylen);
     t_out->key_bitlen = cipher_info->key_bitlen;
 
+#if defined(MBEDTLS_VALIDATE_SSL_KEYS_INTEGRITY)
+    t_in->key_enc_hash = mbedtls_hash( t_in->key_enc, t_in->key_bitlen >> 3 );
+    t_in->key_dec_hash = mbedtls_hash( t_in->key_dec, t_in->key_bitlen >> 3 );
+
+    t_out->key_enc_hash = mbedtls_hash( t_out->key_enc, t_out->key_bitlen >> 3 );
+    t_out->key_dec_hash = mbedtls_hash( t_out->key_dec, t_out->key_bitlen >> 3 );
+#endif
+
     /* Setup cipher contexts */
     CHK( mbedtls_cipher_setup( &t_in->cipher_ctx,  cipher_info ) == 0 );
     CHK( mbedtls_cipher_setup( &t_out->cipher_ctx, cipher_info ) == 0 );