Add user pointer and data size duplication to ssl context.

Signed-off-by: Shelly Liberman <shelly.liberman@arm.com>
diff --git a/configs/baremetal.h b/configs/baremetal.h
index 5294351..6ed4b84 100644
--- a/configs/baremetal.h
+++ b/configs/baremetal.h
@@ -146,6 +146,9 @@
 
 #define MBEDTLS_DEPRECATED_REMOVED
 
+/* Fault Injection Countermesures */
+#define MBEDTLS_FI_COUNTERMEASURES
+
 #if defined(MBEDTLS_USER_CONFIG_FILE)
 #include MBEDTLS_USER_CONFIG_FILE
 #endif
diff --git a/include/mbedtls/config.h b/include/mbedtls/config.h
index 4ee5920..4ac141e 100644
--- a/include/mbedtls/config.h
+++ b/include/mbedtls/config.h
@@ -655,6 +655,16 @@
 //#define MBEDTLS_AES_SCA_COUNTERMEASURES
 
 /**
+ * \def MBEDTLS_FI_COUNTERMEASURES
+ *
+ * Add countermeasures against possible  FI attack.
+ *
+ * Uncommenting this macro inrease sode size and slow performence,
+ * it peforms double calls and double result checks of some crypto functions
+ */
+//#define MBEDTLS_FI_COUNTERMEASURES
+
+/**
  * \def MBEDTLS_CAMELLIA_SMALL_MEMORY
  *
  * Use less ROM for the Camellia implementation (saves about 768 bytes).
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index e14f58f..ee231a5 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -1460,6 +1460,10 @@
      *  after an initial handshake. */
     unsigned char own_cid[ MBEDTLS_SSL_CID_IN_LEN_MAX ];
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
+#if defined(MBEDTLS_FI_COUNTERMEASURES)
+    unsigned char *out_msg_dup;     /*!< out msg ptr duplication  */
+    size_t out_msglen_dup;          /*!< out msg size duplication */
+#endif
 };
 
 #if defined(MBEDTLS_SSL_HW_RECORD_ACCEL)
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index bbe94cb..a450819 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -2562,7 +2562,6 @@
     /* Not using more secure mbedtls_platform_memcpy as cid is public */
     memcpy( rec->cid, transform->out_cid, transform->out_cid_len );
     MBEDTLS_SSL_DEBUG_BUF( 3, "CID", rec->cid, rec->cid_len );
-
     if( rec->cid_len != 0 )
     {
         int ret = MBEDTLS_ERR_PLATFORM_FAULT_DETECTED;
@@ -11221,8 +11220,6 @@
 {
     int ret = mbedtls_ssl_get_max_out_record_payload( ssl );
     const size_t max_len = (size_t) ret;
-    volatile const unsigned char *buf_dup = buf;
-    volatile size_t len_dup = len;
 
     if( ret < 0 )
     {
@@ -11245,7 +11242,6 @@
 #if defined(MBEDTLS_SSL_PROTO_TLS)
         {
             len = max_len;
-            len_dup = len;
         }
 #endif
     }
@@ -11271,22 +11267,40 @@
          * copy the data into the internal buffers and setup the data structure
          * to keep track of partial writes
          */
-        ssl->out_msglen  = len;
+        ssl->out_msglen = len;
         ssl->out_msgtype = MBEDTLS_SSL_MSG_APPLICATION_DATA;
-        mbedtls_platform_memcpy( ssl->out_msg, buf, len );
+        mbedtls_platform_memcpy(ssl->out_msg, buf, len);
 
-        if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
-        {
-            MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret );
-            return( ret );
+#if defined(MBEDTLS_FI_COUNTERMEASURES) && !defined(MBEDTLS_SSL_CBC_RECORD_SPLITTING)
+        /* Secure against buffer substitution */
+        if (buf == ssl->out_msg_dup &&
+            ssl->out_msglen == ssl->out_msglen_dup &&
+            ssl->out_msg_dup[0] == ssl->out_msg[0])
+        {/*write record only if data was copied from correct user pointer */
+#endif
+            if ((ret = mbedtls_ssl_write_record(ssl, SSL_FORCE_FLUSH)) != 0)
+            {
+                MBEDTLS_SSL_DEBUG_RET(1, "mbedtls_ssl_write_record", ret);
+                return(ret);
+            }
+
+#if defined(MBEDTLS_FI_COUNTERMEASURES) && !defined(MBEDTLS_SSL_CBC_RECORD_SPLITTING)
         }
+        else
+        {
+            return(MBEDTLS_ERR_PLATFORM_FAULT_DETECTED);
+        }
+#endif
     }
-    /* Secure against buffer substitution */
-    if( buf_dup == buf && len_dup == len )
+    if (ret == 0)
     {
-        return( (int) len );
+        return((int)len);
     }
-    return( MBEDTLS_ERR_PLATFORM_FAULT_DETECTED );
+    else
+    {
+        return(MBEDTLS_ERR_PLATFORM_FAULT_DETECTED);
+    }
+
 }
 
 /*
@@ -11334,10 +11348,11 @@
  */
 int mbedtls_ssl_write( mbedtls_ssl_context *ssl, const unsigned char *buf, size_t len )
 {
-    int ret;
+    int ret = MBEDTLS_ERR_PLATFORM_FAULT_DETECTED;
+#if defined(MBEDTLS_FI_COUNTERMEASURES) && !defined(MBEDTLS_SSL_CBC_RECORD_SPLITTING)
     volatile const unsigned char *buf_dup = buf;
     volatile size_t len_dup = len;
-
+#endif
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write" ) );
 
     if( ssl == NULL || ssl->conf == NULL )
@@ -11363,17 +11378,19 @@
 #if defined(MBEDTLS_SSL_CBC_RECORD_SPLITTING)
     ret = ssl_write_split( ssl, buf, len );
 #else
+#if defined(MBEDTLS_FI_COUNTERMEASURES)
+    /*Add const user pointers to context. We will be able to check its validity before copy to context*/
+    ssl->out_msg_dup = (unsigned char*)buf_dup;
+    ssl->out_msglen_dup = len_dup;
+#endif //MBEDTLS_FI_COUNTERMEASURES
     ret = ssl_write_real( ssl, buf, len );
 #endif
 
     MBEDTLS_SSL_DEBUG_MSG( 2, ( "<= write" ) );
 
-    /* Secure against buffer substitution */
-    if( buf_dup == buf && len_dup == len )
-    {
-        return( ret );
-    }
-    return( MBEDTLS_ERR_PLATFORM_FAULT_DETECTED );
+
+    return( ret );
+
 }
 
 /*
diff --git a/library/version_features.c b/library/version_features.c
index d60758c..38a7cee 100644
--- a/library/version_features.c
+++ b/library/version_features.c
@@ -273,6 +273,9 @@
 #if defined(MBEDTLS_AES_SCA_COUNTERMEASURES)
     "MBEDTLS_AES_SCA_COUNTERMEASURES",
 #endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
+#if defined(MBEDTLS_FI_COUNTERMEASURES)
+    "MBEDTLS_FI_COUNTERMEASURES",
+#endif /* MBEDTLS_FI_COUNTERMEASURES */
 #if defined(MBEDTLS_CAMELLIA_SMALL_MEMORY)
     "MBEDTLS_CAMELLIA_SMALL_MEMORY",
 #endif /* MBEDTLS_CAMELLIA_SMALL_MEMORY */
diff --git a/programs/ssl/query_config.c b/programs/ssl/query_config.c
index 8093c0d..8db6d22 100644
--- a/programs/ssl/query_config.c
+++ b/programs/ssl/query_config.c
@@ -770,6 +770,14 @@
     }
 #endif /* MBEDTLS_AES_SCA_COUNTERMEASURES */
 
+#if defined(MBEDTLS_FI_COUNTERMEASURES)
+    if( strcmp( "MBEDTLS_FI_COUNTERMEASURES", config ) == 0 )
+    {
+        MACRO_EXPANSION_TO_STR( MBEDTLS_FI_COUNTERMEASURES );
+        return( 0 );
+    }
+#endif /* MBEDTLS_FI_COUNTERMEASURES */
+
 #if defined(MBEDTLS_CAMELLIA_SMALL_MEMORY)
     if( strcmp( "MBEDTLS_CAMELLIA_SMALL_MEMORY", config ) == 0 )
     {