Merge pull request #7217 from lpy4105/issue/6840/add-cache-entry-removal-api

ssl_cache: Add cache entry removal API
diff --git a/ChangeLog.d/add-cache-remove-api.txt b/ChangeLog.d/add-cache-remove-api.txt
new file mode 100644
index 0000000..950ff97
--- /dev/null
+++ b/ChangeLog.d/add-cache-remove-api.txt
@@ -0,0 +1,5 @@
+Features
+   * Add new API mbedtls_ssl_cache_remove() to remove a cache entry
+     identified by its session ID.
+Security
+   * Zeroize SSL cache entries when they are freed.
diff --git a/include/mbedtls/ssl_cache.h b/include/mbedtls/ssl_cache.h
index 5cd1cd3..55dcf77 100644
--- a/include/mbedtls/ssl_cache.h
+++ b/include/mbedtls/ssl_cache.h
@@ -123,6 +123,23 @@
                           size_t session_id_len,
                           const mbedtls_ssl_session *session);
 
+/**
+ * \brief          Remove the cache entry with the given session ID
+ *                 (Thread-safe if MBEDTLS_THREADING_C is enabled)
+ *
+ * \param data            The SSL cache context to use.
+ * \param session_id      The pointer to the buffer holding the session ID
+ *                        of the cache entry to remove.
+ * \param session_id_len  The length of \p session_id in bytes.
+ *
+ * \return                \c 0 if the cache entry for the provided session
+ *                        ID was removed or does not exist.
+ * \return                A non-zero error code on failure.
+ */
+int mbedtls_ssl_cache_remove(void *data,
+                             unsigned char const *session_id,
+                             size_t session_id_len);
+
 #if defined(MBEDTLS_HAVE_TIME)
 /**
  * \brief          Set the cache timeout
diff --git a/library/ssl_cache.c b/library/ssl_cache.c
index 7c16e10..048c21d 100644
--- a/library/ssl_cache.c
+++ b/library/ssl_cache.c
@@ -92,8 +92,8 @@
     mbedtls_ssl_cache_entry *entry;
 
 #if defined(MBEDTLS_THREADING_C)
-    if (mbedtls_mutex_lock(&cache->mutex) != 0) {
-        return 1;
+    if ((ret = mbedtls_mutex_lock(&cache->mutex)) != 0) {
+        return ret;
     }
 #endif
 
@@ -114,13 +114,30 @@
 exit:
 #if defined(MBEDTLS_THREADING_C)
     if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
-        ret = 1;
+        ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
     }
 #endif
 
     return ret;
 }
 
+/* Wipe a cache entry: zeroize and free its session buffer
+ * (`session_len` bytes), then zeroize the entry structure itself.
+ * This also clears `next`, and the entry memory is not freed here:
+ * restoring the chain link and freeing the entry are up to the caller. */
+static void ssl_cache_entry_zeroize(mbedtls_ssl_cache_entry *entry)
+{
+    if (entry != NULL) {
+        if (entry->session != NULL) {
+            mbedtls_platform_zeroize(entry->session, entry->session_len);
+            mbedtls_free(entry->session);
+        }
+
+        /* Every field goes, including session_id and the chain link */
+        mbedtls_platform_zeroize(entry, sizeof(*entry));
+    }
+}
+
 MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_cache_pick_writing_slot(mbedtls_ssl_cache_context *cache,
                                        unsigned char const *session_id,
@@ -220,19 +237,19 @@
 
 found:
 
+    /* If we're reusing an entry, free it first. */
+    if (cur->session != NULL) {
+        /* `ssl_cache_entry_zeroize` would break the chain,
+         * so we reuse `old` to record `next` temporarily. */
+        old = cur->next;
+        ssl_cache_entry_zeroize(cur);
+        cur->next = old;
+    }
+
 #if defined(MBEDTLS_HAVE_TIME)
     cur->timestamp = t;
 #endif
 
-    /* If we're reusing an entry, free it first. */
-    if (cur->session != NULL) {
-        mbedtls_free(cur->session);
-        cur->session = NULL;
-        cur->session_len = 0;
-        memset(cur->session_id, 0, sizeof(cur->session_id));
-        cur->session_id_len = 0;
-    }
-
     *dst = cur;
     return 0;
 }
@@ -301,7 +318,7 @@
 exit:
 #if defined(MBEDTLS_THREADING_C)
     if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
-        ret = 1;
+        ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
     }
 #endif
 
@@ -314,6 +331,55 @@
     return ret;
 }
 
+/* Remove the cache entry matching session_id, if there is one,
+ * zeroizing it before it is released. Returns 0 whether or not an
+ * entry was found; only mutex lock/unlock failures yield an error. */
+int mbedtls_ssl_cache_remove(void *data,
+                             unsigned char const *session_id,
+                             size_t session_id_len)
+{
+    /* Not finding an entry is not an error, so start from success */
+    int ret = 0;
+    mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
+    mbedtls_ssl_cache_entry *entry;
+    mbedtls_ssl_cache_entry **link;
+
+#if defined(MBEDTLS_THREADING_C)
+    if ((ret = mbedtls_mutex_lock(&cache->mutex)) != 0) {
+        return ret;
+    }
+#endif
+
+    if (ssl_cache_find_entry(cache, session_id, session_id_len,
+                             &entry) != 0) {
+        /* No valid entry for this ID: there is nothing to remove,
+         * which is reported as success. */
+        goto exit;
+    }
+
+    /* Unlink the entry by walking the pointers that reference each
+     * node. This treats the head of the chain and interior nodes the
+     * same way, and guarantees the entry is only zeroized and freed
+     * once nothing in the chain points to it any more. */
+    for (link = &cache->chain; *link != NULL; link = &(*link)->next) {
+        if (*link == entry) {
+            *link = entry->next;
+            ssl_cache_entry_zeroize(entry);
+            mbedtls_free(entry);
+            break;
+        }
+    }
+
+exit:
+#if defined(MBEDTLS_THREADING_C)
+    if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
+        ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
+    }
+#endif
+
+    return ret;
+}
+
 #if defined(MBEDTLS_HAVE_TIME)
 void mbedtls_ssl_cache_set_timeout(mbedtls_ssl_cache_context *cache, int timeout)
 {
@@ -344,7 +410,7 @@
         prv = cur;
         cur = cur->next;
 
-        mbedtls_free(prv->session);
+        ssl_cache_entry_zeroize(prv);
         mbedtls_free(prv);
     }
 
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 88c2192..1c34861 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -129,6 +129,7 @@
 #define DFL_TICKET_AEAD         MBEDTLS_CIPHER_AES_256_GCM
 #define DFL_CACHE_MAX           -1
 #define DFL_CACHE_TIMEOUT       -1
+#define DFL_CACHE_REMOVE        0
 #define DFL_SNI                 NULL
 #define DFL_ALPN_STRING         NULL
 #define DFL_CURVES              NULL
@@ -321,7 +322,8 @@
 
 #if defined(MBEDTLS_SSL_CACHE_C)
 #define USAGE_CACHE                                             \
-    "    cache_max=%%d        default: cache default (50)\n"
+    "    cache_max=%%d        default: cache default (50)\n"    \
+    "    cache_remove=%%d     default: 0 (don't remove)\n"
 #if defined(MBEDTLS_HAVE_TIME)
 #define USAGE_CACHE_TIME \
     "    cache_timeout=%%d    default: cache default (1d)\n"
@@ -669,6 +671,7 @@
 #if defined(MBEDTLS_HAVE_TIME)
     int cache_timeout;          /* expiration delay of session cache entries*/
 #endif
+    int cache_remove;           /* enable / disable cache removal           */
     char *sni;                  /* string describing sni information        */
     const char *curves;         /* list of supported elliptic curves        */
     const char *sig_algs;       /* supported TLS 1.3 signature algorithms   */
@@ -1731,6 +1734,7 @@
 #if defined(MBEDTLS_HAVE_TIME)
     opt.cache_timeout       = DFL_CACHE_TIMEOUT;
 #endif
+    opt.cache_remove        = DFL_CACHE_REMOVE;
     opt.sni                 = DFL_SNI;
     opt.alpn_string         = DFL_ALPN_STRING;
     opt.curves              = DFL_CURVES;
@@ -2144,7 +2148,12 @@
             }
         }
 #endif
-        else if (strcmp(p, "cookies") == 0) {
+        else if (strcmp(p, "cache_remove") == 0) {
+            opt.cache_remove = atoi(q);
+            if (opt.cache_remove < 0 || opt.cache_remove > 1) {
+                goto usage;
+            }
+        } else if (strcmp(p, "cookies") == 0) {
             opt.cookies = atoi(q);
             if (opt.cookies < -1 || opt.cookies > 1) {
                 goto usage;
@@ -4127,6 +4136,12 @@
 
     mbedtls_printf(" done\n");
 
+#if defined(MBEDTLS_SSL_CACHE_C)
+    if (opt.cache_remove > 0) {
+        mbedtls_ssl_cache_remove(&cache, ssl.session->id, ssl.session->id_len);
+    }
+#endif
+
     goto reset;
 
     /*
diff --git a/tests/ssl-opt.sh b/tests/ssl-opt.sh
index e2b1e04..663f194 100755
--- a/tests/ssl-opt.sh
+++ b/tests/ssl-opt.sh
@@ -4142,6 +4142,22 @@
 
 requires_config_enabled MBEDTLS_SSL_PROTO_TLS1_2
 requires_config_enabled MBEDTLS_SSL_CACHE_C
+run_test    "Session resume using cache: cache removed" \
+            "$P_SRV debug_level=3 tickets=0 cache_remove=1" \
+            "$P_CLI debug_level=3 tickets=0 reconnect=1" \
+            0 \
+            -C "client hello, adding session ticket extension" \
+            -S "found session ticket extension" \
+            -S "server hello, adding session ticket extension" \
+            -C "found session_ticket extension" \
+            -C "parse new session ticket" \
+            -S "session successfully restored from cache" \
+            -S "session successfully restored from ticket" \
+            -S "a session has been resumed" \
+            -C "a session has been resumed"
+
+requires_config_enabled MBEDTLS_SSL_PROTO_TLS1_2
+requires_config_enabled MBEDTLS_SSL_CACHE_C
 run_test    "Session resume using cache: timeout > delay" \
             "$P_SRV debug_level=3 tickets=0" \
             "$P_CLI debug_level=3 tickets=0 reconnect=1 reco_delay=0" \