Merge pull request #7301 from gilles-peskine-arm/msan-explicit_bzero

Fix Msan failure with explicit_bzero
diff --git a/ChangeLog.d/add-cache-remove-api.txt b/ChangeLog.d/add-cache-remove-api.txt
new file mode 100644
index 0000000..950ff97
--- /dev/null
+++ b/ChangeLog.d/add-cache-remove-api.txt
@@ -0,0 +1,5 @@
+Features
+   * Add a new API, mbedtls_ssl_cache_remove(), to remove a cache entry
+     by its session ID.
+Security
+   * Zeroize SSL cache entries when they are freed.
diff --git a/ChangeLog.d/fix-oid-to-string-bugs.txt b/ChangeLog.d/fix-oid-to-string-bugs.txt
index 799f444..3cf02c3 100644
--- a/ChangeLog.d/fix-oid-to-string-bugs.txt
+++ b/ChangeLog.d/fix-oid-to-string-bugs.txt
@@ -3,4 +3,8 @@
      mbedtls_oid_get_numeric_string(). OIDs such as 2.40.0.25 are now printed
      correctly.
    * Reject OIDs with overlong-encoded subidentifiers when converting
-     OID-to-string.
+     them to a string.
+   * Reject OIDs with subidentifier values exceeding UINT_MAX.  Such
+     subidentifiers can be valid, but Mbed TLS cannot currently handle them.
+   * Reject OIDs that have unterminated subidentifiers, or (equivalently)
+     have the most-significant bit set in their last byte.
diff --git a/include/mbedtls/ssl_cache.h b/include/mbedtls/ssl_cache.h
index 5cd1cd3..55dcf77 100644
--- a/include/mbedtls/ssl_cache.h
+++ b/include/mbedtls/ssl_cache.h
@@ -123,6 +123,23 @@
                           size_t session_id_len,
                           const mbedtls_ssl_session *session);
 
+/**
+ * \brief          Remove the cache entry matching the given session ID
+ *                 (Thread-safe if MBEDTLS_THREADING_C is enabled)
+ *
+ * \param data            The SSL cache context to use.
+ * \param session_id      The pointer to the buffer holding the session ID
+ *                        of the cache entry to remove.
+ * \param session_id_len  The length of \p session_id in bytes.
+ *
+ * \return                0: The cache entry with the provided session ID
+ *                           was removed, or no such entry existed.
+ *                        Otherwise: an error code, e.g. on mutex failure.
+ */
+int mbedtls_ssl_cache_remove(void *data,
+                             unsigned char const *session_id,
+                             size_t session_id_len);
+
 #if defined(MBEDTLS_HAVE_TIME)
 /**
  * \brief          Set the cache timeout
diff --git a/library/oid.c b/library/oid.c
index 86214b2..63b3df3 100644
--- a/library/oid.c
+++ b/library/oid.c
@@ -813,65 +813,26 @@
                  cipher_alg)
 #endif /* MBEDTLS_PKCS12_C */
 
-#define OID_SAFE_SNPRINTF                               \
-    do {                                                \
-        if (ret < 0 || (size_t) ret >= n)              \
-        return MBEDTLS_ERR_OID_BUF_TOO_SMALL;    \
-                                                      \
-        n -= (size_t) ret;                              \
-        p += (size_t) ret;                              \
-    } while (0)
-
 /* Return the x.y.z.... style numeric string for the given OID */
 int mbedtls_oid_get_numeric_string(char *buf, size_t size,
                                    const mbedtls_asn1_buf *oid)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    size_t i, n;
-    unsigned int value;
-    char *p;
+    char *p = buf;
+    size_t n = size;
+    unsigned int value = 0;
 
-    p = buf;
-    n = size;
-
-    /* First subidentifier contains first two OID components */
-    i = 0;
-    value = 0;
-    if ((oid->p[0]) == 0x80) {
-        /* Overlong encoding is not allowed */
-        return MBEDTLS_ERR_ASN1_INVALID_DATA;
+    if (size > INT_MAX) {
+        /* Avoid overflow computing return value */
+        return MBEDTLS_ERR_ASN1_INVALID_LENGTH;
     }
 
-    while (i < oid->len && ((oid->p[i] & 0x80) != 0)) {
-        /* Prevent overflow in value. */
-        if (value > (UINT_MAX >> 7)) {
-            return MBEDTLS_ERR_ASN1_INVALID_DATA;
-        }
-
-        value |= oid->p[i] & 0x7F;
-        value <<= 7;
-        i++;
-    }
-    if (i >= oid->len) {
+    if (oid->len <= 0) {
+        /* OID must not be empty */
         return MBEDTLS_ERR_ASN1_OUT_OF_DATA;
     }
-    /* Last byte of first subidentifier */
-    value |= oid->p[i] & 0x7F;
-    i++;
 
-    unsigned int component1 = value / 40;
-    if (component1 > 2) {
-        /* The first component can only be 0, 1 or 2.
-         * If oid->p[0] / 40 is greater than 2, the leftover belongs to
-         * the second component. */
-        component1 = 2;
-    }
-    unsigned int component2 = value - (40 * component1);
-    ret = mbedtls_snprintf(p, n, "%u.%u", component1, component2);
-    OID_SAFE_SNPRINTF;
-
-    value = 0;
-    for (; i < oid->len; i++) {
+    for (size_t i = 0; i < oid->len; i++) {
         /* Prevent overflow in value. */
         if (value > (UINT_MAX >> 7)) {
             return MBEDTLS_ERR_ASN1_INVALID_DATA;
@@ -886,12 +847,38 @@
 
         if (!(oid->p[i] & 0x80)) {
             /* Last byte */
-            ret = mbedtls_snprintf(p, n, ".%u", value);
-            OID_SAFE_SNPRINTF;
+            if (n == size) {
+                int component1;
+                unsigned int component2;
+                /* First subidentifier contains first two OID components */
+                if (value >= 80) {
+                    component1 = '2';
+                    component2 = value - 80;
+                } else if (value >= 40) {
+                    component1 = '1';
+                    component2 = value - 40;
+                } else {
+                    component1 = '0';
+                    component2 = value;
+                }
+                ret = mbedtls_snprintf(p, n, "%c.%u", component1, component2);
+            } else {
+                ret = mbedtls_snprintf(p, n, ".%u", value);
+            }
+            if (ret < 2 || (size_t) ret >= n) {
+                return MBEDTLS_ERR_OID_BUF_TOO_SMALL;
+            }
+            n -= (size_t) ret;
+            p += ret;
             value = 0;
         }
     }
 
+    if (value != 0) {
+        /* Unterminated subidentifier */
+        return MBEDTLS_ERR_ASN1_OUT_OF_DATA;
+    }
+
     return (int) (size - n);
 }
 
diff --git a/library/ssl_cache.c b/library/ssl_cache.c
index 7c16e10..048c21d 100644
--- a/library/ssl_cache.c
+++ b/library/ssl_cache.c
@@ -92,8 +92,8 @@
     mbedtls_ssl_cache_entry *entry;
 
 #if defined(MBEDTLS_THREADING_C)
-    if (mbedtls_mutex_lock(&cache->mutex) != 0) {
-        return 1;
+    if ((ret = mbedtls_mutex_lock(&cache->mutex)) != 0) {
+        return ret;
     }
 #endif
 
@@ -114,13 +114,30 @@
 exit:
 #if defined(MBEDTLS_THREADING_C)
     if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
-        ret = 1;
+        ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
     }
 #endif
 
     return ret;
 }
 
+/* Wipe one cache entry. An attached serialized session buffer, if any,
+ * is zeroized and freed; then every byte of the entry struct, including
+ * its `next` link, is zeroized. The struct itself is not freed, and a
+ * NULL entry is ignored. */
+static void ssl_cache_entry_zeroize(mbedtls_ssl_cache_entry *entry)
+{
+    if (entry != NULL) {
+        if (entry->session != NULL) {
+            /* Scrub the serialized session before releasing it. */
+            mbedtls_platform_zeroize(entry->session, entry->session_len);
+            mbedtls_free(entry->session);
+        }
+
+        mbedtls_platform_zeroize(entry, sizeof(*entry));
+    }
+}
+
 MBEDTLS_CHECK_RETURN_CRITICAL
 static int ssl_cache_pick_writing_slot(mbedtls_ssl_cache_context *cache,
                                        unsigned char const *session_id,
@@ -220,19 +237,19 @@
 
 found:
 
+    /* If we're reusing an entry, free it first. */
+    if (cur->session != NULL) {
+        /* `ssl_cache_entry_zeroize` would break the chain,
+         * so we reuse `old` to record `next` temporarily. */
+        old = cur->next;
+        ssl_cache_entry_zeroize(cur);
+        cur->next = old;
+    }
+
 #if defined(MBEDTLS_HAVE_TIME)
     cur->timestamp = t;
 #endif
 
-    /* If we're reusing an entry, free it first. */
-    if (cur->session != NULL) {
-        mbedtls_free(cur->session);
-        cur->session = NULL;
-        cur->session_len = 0;
-        memset(cur->session_id, 0, sizeof(cur->session_id));
-        cur->session_id_len = 0;
-    }
-
     *dst = cur;
     return 0;
 }
@@ -301,7 +318,7 @@
 exit:
 #if defined(MBEDTLS_THREADING_C)
     if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
-        ret = 1;
+        ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
     }
 #endif
 
@@ -314,6 +331,55 @@
     return ret;
 }
 
+/* Remove the cache entry matching `session_id`, if there is one.
+ * Returns 0 whether or not a matching entry existed. When
+ * MBEDTLS_THREADING_C is enabled the cache mutex is held for the whole
+ * operation: a locking failure returns the lock's error code, and an
+ * unlocking failure returns MBEDTLS_ERR_THREADING_MUTEX_ERROR. */
+int mbedtls_ssl_cache_remove(void *data,
+                             unsigned char const *session_id,
+                             size_t session_id_len)
+{
+    mbedtls_ssl_cache_context *cache = (mbedtls_ssl_cache_context *) data;
+    mbedtls_ssl_cache_entry *victim = NULL;
+    mbedtls_ssl_cache_entry **link;
+    /* A missing entry is not an error, so unless the mutex misbehaves
+     * the result is always success. */
+    int ret = 0;
+
+#if defined(MBEDTLS_THREADING_C)
+    ret = mbedtls_mutex_lock(&cache->mutex);
+    if (ret != 0) {
+        return ret;
+    }
+#endif
+
+    if (ssl_cache_find_entry(cache, session_id, session_id_len,
+                             &victim) == 0) {
+        /* Advance through the `next` links until we hit the one that
+         * points at the victim (possibly cache->chain itself), then
+         * splice the victim out by redirecting that link. */
+        for (link = &cache->chain; *link != NULL; link = &(*link)->next) {
+            if (*link == victim) {
+                *link = victim->next;
+                break;
+            }
+        }
+
+        /* The victim is now detached: scrub it, then release it. */
+        ssl_cache_entry_zeroize(victim);
+        mbedtls_free(victim);
+    }
+
+#if defined(MBEDTLS_THREADING_C)
+    if (mbedtls_mutex_unlock(&cache->mutex) != 0) {
+        ret = MBEDTLS_ERR_THREADING_MUTEX_ERROR;
+    }
+#endif
+
+    return ret;
+}
+
 #if defined(MBEDTLS_HAVE_TIME)
 void mbedtls_ssl_cache_set_timeout(mbedtls_ssl_cache_context *cache, int timeout)
 {
@@ -344,7 +410,7 @@
         prv = cur;
         cur = cur->next;
 
-        mbedtls_free(prv->session);
+        ssl_cache_entry_zeroize(prv);
         mbedtls_free(prv);
     }
 
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index 88c2192..1c34861 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -129,6 +129,7 @@
 #define DFL_TICKET_AEAD         MBEDTLS_CIPHER_AES_256_GCM
 #define DFL_CACHE_MAX           -1
 #define DFL_CACHE_TIMEOUT       -1
+#define DFL_CACHE_REMOVE        0
 #define DFL_SNI                 NULL
 #define DFL_ALPN_STRING         NULL
 #define DFL_CURVES              NULL
@@ -321,7 +322,8 @@
 
 #if defined(MBEDTLS_SSL_CACHE_C)
 #define USAGE_CACHE                                             \
-    "    cache_max=%%d        default: cache default (50)\n"
+    "    cache_max=%%d        default: cache default (50)\n"    \
+    "    cache_remove=%%d     default: 0 (don't remove)\n"
 #if defined(MBEDTLS_HAVE_TIME)
 #define USAGE_CACHE_TIME \
     "    cache_timeout=%%d    default: cache default (1d)\n"
@@ -669,6 +671,7 @@
 #if defined(MBEDTLS_HAVE_TIME)
     int cache_timeout;          /* expiration delay of session cache entries*/
 #endif
+    int cache_remove;           /* enable / disable cache removement        */
     char *sni;                  /* string describing sni information        */
     const char *curves;         /* list of supported elliptic curves        */
     const char *sig_algs;       /* supported TLS 1.3 signature algorithms   */
@@ -1731,6 +1734,7 @@
 #if defined(MBEDTLS_HAVE_TIME)
     opt.cache_timeout       = DFL_CACHE_TIMEOUT;
 #endif
+    opt.cache_remove        = DFL_CACHE_REMOVE;
     opt.sni                 = DFL_SNI;
     opt.alpn_string         = DFL_ALPN_STRING;
     opt.curves              = DFL_CURVES;
@@ -2144,7 +2148,12 @@
             }
         }
 #endif
-        else if (strcmp(p, "cookies") == 0) {
+        else if (strcmp(p, "cache_remove") == 0) {
+            opt.cache_remove = atoi(q);
+            if (opt.cache_remove < 0 || opt.cache_remove > 1) {
+                goto usage;
+            }
+        } else if (strcmp(p, "cookies") == 0) {
             opt.cookies = atoi(q);
             if (opt.cookies < -1 || opt.cookies > 1) {
                 goto usage;
@@ -4127,6 +4136,12 @@
 
     mbedtls_printf(" done\n");
 
+#if defined(MBEDTLS_SSL_CACHE_C)
+    if (opt.cache_remove > 0) {
+        mbedtls_ssl_cache_remove(&cache, ssl.session->id, ssl.session->id_len);
+    }
+#endif
+
     goto reset;
 
     /*
diff --git a/tests/ssl-opt.sh b/tests/ssl-opt.sh
index e2b1e04..663f194 100755
--- a/tests/ssl-opt.sh
+++ b/tests/ssl-opt.sh
@@ -4142,6 +4142,22 @@
 
 requires_config_enabled MBEDTLS_SSL_PROTO_TLS1_2
 requires_config_enabled MBEDTLS_SSL_CACHE_C
+run_test    "Session resume using cache: cache removed" \
+            "$P_SRV debug_level=3 tickets=0 cache_remove=1" \
+            "$P_CLI debug_level=3 tickets=0 reconnect=1" \
+            0 \
+            -C "client hello, adding session ticket extension" \
+            -S "found session ticket extension" \
+            -S "server hello, adding session ticket extension" \
+            -C "found session_ticket extension" \
+            -C "parse new session ticket" \
+            -S "session successfully restored from cache" \
+            -S "session successfully restored from ticket" \
+            -S "a session has been resumed" \
+            -C "a session has been resumed"
+
+requires_config_enabled MBEDTLS_SSL_PROTO_TLS1_2
+requires_config_enabled MBEDTLS_SSL_CACHE_C
 run_test    "Session resume using cache: timeout > delay" \
             "$P_SRV debug_level=3 tickets=0" \
             "$P_CLI debug_level=3 tickets=0 reconnect=1 reco_delay=0" \
diff --git a/tests/suites/test_suite_oid.data b/tests/suites/test_suite_oid.data
index b9fa654..75213e9 100644
--- a/tests/suites/test_suite_oid.data
+++ b/tests/suites/test_suite_oid.data
@@ -101,12 +101,30 @@
 OID get numeric string - multi-byte first subidentifier
 oid_get_numeric_string:"8837":0:"2.999"
 
+OID get numeric string - second subidentifier not terminated
+oid_get_numeric_string:"0081":MBEDTLS_ERR_ASN1_OUT_OF_DATA:""
+
 OID get numeric string - empty oid buffer
 oid_get_numeric_string:"":MBEDTLS_ERR_ASN1_OUT_OF_DATA:""
 
 OID get numeric string - no final / all bytes have top bit set
 oid_get_numeric_string:"818181":MBEDTLS_ERR_ASN1_OUT_OF_DATA:""
 
+OID get numeric string - 0.39
+oid_get_numeric_string:"27":0:"0.39"
+
+OID get numeric string - 1.0
+oid_get_numeric_string:"28":0:"1.0"
+
+OID get numeric string - 1.39
+oid_get_numeric_string:"4f":0:"1.39"
+
+OID get numeric string - 2.0
+oid_get_numeric_string:"50":0:"2.0"
+
+OID get numeric string - 1 byte first subidentifier beyond 2.39
+oid_get_numeric_string:"7f":0:"2.47"
+
 # Encodes the number 0x0400000000 as a subidentifier which overflows 32-bits
 OID get numeric string - 32-bit overflow
 oid_get_numeric_string:"C080808000":MBEDTLS_ERR_ASN1_INVALID_DATA:""
diff --git a/tests/suites/test_suite_oid.function b/tests/suites/test_suite_oid.function
index 3004b65..5fbc9b5 100644
--- a/tests/suites/test_suite_oid.function
+++ b/tests/suites/test_suite_oid.function
@@ -105,13 +105,16 @@
     int ret;
 
     input_oid.tag = MBEDTLS_ASN1_OID;
-    input_oid.p = oid->x;
+    /* Test that an empty OID is not dereferenced */
+    input_oid.p = oid->len ? oid->x : (void *) 1;
     input_oid.len = oid->len;
 
     ret = mbedtls_oid_get_numeric_string(buf, sizeof(buf), &input_oid);
 
     if (error_ret == 0) {
-        TEST_ASSERT(strcmp(buf, result_str) == 0);
+        TEST_EQUAL(ret, strlen(result_str));
+        TEST_ASSERT(ret >= 3);
+        TEST_EQUAL(strcmp(buf, result_str), 0);
     } else {
         TEST_EQUAL(ret, error_ret);
     }