Move mbedtls_ct_hmac into ssl_msg.c

Signed-off-by: Dave Rodgman <dave.rodgman@arm.com>
diff --git a/library/constant_time.c b/library/constant_time.c
index 1f6c2ca..a786d38 100644
--- a/library/constant_time.c
+++ b/library/constant_time.c
@@ -431,227 +431,6 @@
     }
 }
 
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-
-#if defined(PSA_WANT_ALG_SHA_384)
-#define MAX_HASH_BLOCK_LENGTH PSA_HASH_BLOCK_LENGTH(PSA_ALG_SHA_384)
-#elif defined(PSA_WANT_ALG_SHA_256)
-#define MAX_HASH_BLOCK_LENGTH PSA_HASH_BLOCK_LENGTH(PSA_ALG_SHA_256)
-#else /* See check_config.h */
-#define MAX_HASH_BLOCK_LENGTH PSA_HASH_BLOCK_LENGTH(PSA_ALG_SHA_1)
-#endif
-
-int mbedtls_ct_hmac(mbedtls_svc_key_id_t key,
-                    psa_algorithm_t mac_alg,
-                    const unsigned char *add_data,
-                    size_t add_data_len,
-                    const unsigned char *data,
-                    size_t data_len_secret,
-                    size_t min_data_len,
-                    size_t max_data_len,
-                    unsigned char *output)
-{
-    /*
-     * This function breaks the HMAC abstraction and uses psa_hash_clone()
-     * extension in order to get constant-flow behaviour.
-     *
-     * HMAC(msg) is defined as HASH(okey + HASH(ikey + msg)) where + means
-     * concatenation, and okey/ikey are the XOR of the key with some fixed bit
-     * patterns (see RFC 2104, sec. 2).
-     *
-     * We'll first compute ikey/okey, then inner_hash = HASH(ikey + msg) by
-     * hashing up to minlen, then cloning the context, and for each byte up
-     * to maxlen finishing up the hash computation, keeping only the
-     * correct result.
-     *
-     * Then we only need to compute HASH(okey + inner_hash) and we're done.
-     */
-    psa_algorithm_t hash_alg = PSA_ALG_HMAC_GET_HASH(mac_alg);
-    const size_t block_size = PSA_HASH_BLOCK_LENGTH(hash_alg);
-    unsigned char key_buf[MAX_HASH_BLOCK_LENGTH];
-    const size_t hash_size = PSA_HASH_LENGTH(hash_alg);
-    psa_hash_operation_t operation = PSA_HASH_OPERATION_INIT;
-    size_t hash_length;
-
-    unsigned char aux_out[PSA_HASH_MAX_SIZE];
-    psa_hash_operation_t aux_operation = PSA_HASH_OPERATION_INIT;
-    size_t offset;
-    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-
-    size_t mac_key_length;
-    size_t i;
-
-#define PSA_CHK(func_call)        \
-    do {                            \
-        status = (func_call);       \
-        if (status != PSA_SUCCESS) \
-        goto cleanup;           \
-    } while (0)
-
-    /* Export MAC key
-     * We assume key length is always exactly the output size
-     * which is never more than the block size, thus we use block_size
-     * as the key buffer size.
-     */
-    PSA_CHK(psa_export_key(key, key_buf, block_size, &mac_key_length));
-
-    /* Calculate ikey */
-    for (i = 0; i < mac_key_length; i++) {
-        key_buf[i] = (unsigned char) (key_buf[i] ^ 0x36);
-    }
-    for (; i < block_size; ++i) {
-        key_buf[i] = 0x36;
-    }
-
-    PSA_CHK(psa_hash_setup(&operation, hash_alg));
-
-    /* Now compute inner_hash = HASH(ikey + msg) */
-    PSA_CHK(psa_hash_update(&operation, key_buf, block_size));
-    PSA_CHK(psa_hash_update(&operation, add_data, add_data_len));
-    PSA_CHK(psa_hash_update(&operation, data, min_data_len));
-
-    /* Fill the hash buffer in advance with something that is
-     * not a valid hash (barring an attack on the hash and
-     * deliberately-crafted input), in case the caller doesn't
-     * check the return status properly. */
-    memset(output, '!', hash_size);
-
-    /* For each possible length, compute the hash up to that point */
-    for (offset = min_data_len; offset <= max_data_len; offset++) {
-        PSA_CHK(psa_hash_clone(&operation, &aux_operation));
-        PSA_CHK(psa_hash_finish(&aux_operation, aux_out,
-                                PSA_HASH_MAX_SIZE, &hash_length));
-        /* Keep only the correct inner_hash in the output buffer */
-        mbedtls_ct_memcpy_if_eq(output, aux_out, hash_size,
-                                offset, data_len_secret);
-
-        if (offset < max_data_len) {
-            PSA_CHK(psa_hash_update(&operation, data + offset, 1));
-        }
-    }
-
-    /* Abort current operation to prepare for final operation */
-    PSA_CHK(psa_hash_abort(&operation));
-
-    /* Calculate okey */
-    for (i = 0; i < mac_key_length; i++) {
-        key_buf[i] = (unsigned char) ((key_buf[i] ^ 0x36) ^ 0x5C);
-    }
-    for (; i < block_size; ++i) {
-        key_buf[i] = 0x5C;
-    }
-
-    /* Now compute HASH(okey + inner_hash) */
-    PSA_CHK(psa_hash_setup(&operation, hash_alg));
-    PSA_CHK(psa_hash_update(&operation, key_buf, block_size));
-    PSA_CHK(psa_hash_update(&operation, output, hash_size));
-    PSA_CHK(psa_hash_finish(&operation, output, hash_size, &hash_length));
-
-#undef PSA_CHK
-
-cleanup:
-    mbedtls_platform_zeroize(key_buf, MAX_HASH_BLOCK_LENGTH);
-    mbedtls_platform_zeroize(aux_out, PSA_HASH_MAX_SIZE);
-
-    psa_hash_abort(&operation);
-    psa_hash_abort(&aux_operation);
-    return PSA_TO_MBEDTLS_ERR(status);
-}
-
-#undef MAX_HASH_BLOCK_LENGTH
-
-#else
-int mbedtls_ct_hmac(mbedtls_md_context_t *ctx,
-                    const unsigned char *add_data,
-                    size_t add_data_len,
-                    const unsigned char *data,
-                    size_t data_len_secret,
-                    size_t min_data_len,
-                    size_t max_data_len,
-                    unsigned char *output)
-{
-    /*
-     * This function breaks the HMAC abstraction and uses the md_clone()
-     * extension to the MD API in order to get constant-flow behaviour.
-     *
-     * HMAC(msg) is defined as HASH(okey + HASH(ikey + msg)) where + means
-     * concatenation, and okey/ikey are the XOR of the key with some fixed bit
-     * patterns (see RFC 2104, sec. 2), which are stored in ctx->hmac_ctx.
-     *
-     * We'll first compute inner_hash = HASH(ikey + msg) by hashing up to
-     * minlen, then cloning the context, and for each byte up to maxlen
-     * finishing up the hash computation, keeping only the correct result.
-     *
-     * Then we only need to compute HASH(okey + inner_hash) and we're done.
-     */
-    const mbedtls_md_type_t md_alg = mbedtls_md_get_type(ctx->md_info);
-    /* TLS 1.2 only supports SHA-384, SHA-256, SHA-1, MD-5,
-     * all of which have the same block size except SHA-384. */
-    const size_t block_size = md_alg == MBEDTLS_MD_SHA384 ? 128 : 64;
-    const unsigned char * const ikey = ctx->hmac_ctx;
-    const unsigned char * const okey = ikey + block_size;
-    const size_t hash_size = mbedtls_md_get_size(ctx->md_info);
-
-    unsigned char aux_out[MBEDTLS_MD_MAX_SIZE];
-    mbedtls_md_context_t aux;
-    size_t offset;
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-
-    mbedtls_md_init(&aux);
-
-#define MD_CHK(func_call) \
-    do {                    \
-        ret = (func_call);  \
-        if (ret != 0)      \
-        goto cleanup;   \
-    } while (0)
-
-    MD_CHK(mbedtls_md_setup(&aux, ctx->md_info, 0));
-
-    /* After hmac_start() of hmac_reset(), ikey has already been hashed,
-     * so we can start directly with the message */
-    MD_CHK(mbedtls_md_update(ctx, add_data, add_data_len));
-    MD_CHK(mbedtls_md_update(ctx, data, min_data_len));
-
-    /* Fill the hash buffer in advance with something that is
-     * not a valid hash (barring an attack on the hash and
-     * deliberately-crafted input), in case the caller doesn't
-     * check the return status properly. */
-    memset(output, '!', hash_size);
-
-    /* For each possible length, compute the hash up to that point */
-    for (offset = min_data_len; offset <= max_data_len; offset++) {
-        MD_CHK(mbedtls_md_clone(&aux, ctx));
-        MD_CHK(mbedtls_md_finish(&aux, aux_out));
-        /* Keep only the correct inner_hash in the output buffer */
-        mbedtls_ct_memcpy_if_eq(output, aux_out, hash_size,
-                                offset, data_len_secret);
-
-        if (offset < max_data_len) {
-            MD_CHK(mbedtls_md_update(ctx, data + offset, 1));
-        }
-    }
-
-    /* The context needs to finish() before it starts() again */
-    MD_CHK(mbedtls_md_finish(ctx, aux_out));
-
-    /* Now compute HASH(okey + inner_hash) */
-    MD_CHK(mbedtls_md_starts(ctx));
-    MD_CHK(mbedtls_md_update(ctx, okey, block_size));
-    MD_CHK(mbedtls_md_update(ctx, output, hash_size));
-    MD_CHK(mbedtls_md_finish(ctx, output));
-
-    /* Done, get ready for next time */
-    MD_CHK(mbedtls_md_hmac_reset(ctx));
-
-#undef MD_CHK
-
-cleanup:
-    mbedtls_md_free(&aux);
-    return ret;
-}
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
-
 #endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */
 
 #if defined(MBEDTLS_BIGNUM_C)
diff --git a/library/constant_time_internal.h b/library/constant_time_internal.h
index dde6a0b..d8a0fc7 100644
--- a/library/constant_time_internal.h
+++ b/library/constant_time_internal.h
@@ -236,63 +236,6 @@
                               size_t offset_max,
                               size_t len);
 
-/** Compute the HMAC of variable-length data with constant flow.
- *
- * This function computes the HMAC of the concatenation of \p add_data and \p
- * data, and does with a code flow and memory access pattern that does not
- * depend on \p data_len_secret, but only on \p min_data_len and \p
- * max_data_len. In particular, this function always reads exactly \p
- * max_data_len bytes from \p data.
- *
- * \param ctx               The HMAC context. It must have keys configured
- *                          with mbedtls_md_hmac_starts() and use one of the
- *                          following hashes: SHA-384, SHA-256, SHA-1 or MD-5.
- *                          It is reset using mbedtls_md_hmac_reset() after
- *                          the computation is complete to prepare for the
- *                          next computation.
- * \param add_data          The first part of the message whose HMAC is being
- *                          calculated. This must point to a readable buffer
- *                          of \p add_data_len bytes.
- * \param add_data_len      The length of \p add_data in bytes.
- * \param data              The buffer containing the second part of the
- *                          message. This must point to a readable buffer
- *                          of \p max_data_len bytes.
- * \param data_len_secret   The length of the data to process in \p data.
- *                          This must be no less than \p min_data_len and no
- *                          greater than \p max_data_len.
- * \param min_data_len      The minimal length of the second part of the
- *                          message, read from \p data.
- * \param max_data_len      The maximal length of the second part of the
- *                          message, read from \p data.
- * \param output            The HMAC will be written here. This must point to
- *                          a writable buffer of sufficient size to hold the
- *                          HMAC value.
- *
- * \retval 0 on success.
- * \retval #MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED
- *         The hardware accelerator failed.
- */
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-int mbedtls_ct_hmac(mbedtls_svc_key_id_t key,
-                    psa_algorithm_t alg,
-                    const unsigned char *add_data,
-                    size_t add_data_len,
-                    const unsigned char *data,
-                    size_t data_len_secret,
-                    size_t min_data_len,
-                    size_t max_data_len,
-                    unsigned char *output);
-#else
-int mbedtls_ct_hmac(mbedtls_md_context_t *ctx,
-                    const unsigned char *add_data,
-                    size_t add_data_len,
-                    const unsigned char *data,
-                    size_t data_len_secret,
-                    size_t min_data_len,
-                    size_t max_data_len,
-                    unsigned char *output);
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
-
 #endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */
 
 #if defined(MBEDTLS_PKCS1_V15) && defined(MBEDTLS_RSA_C) && !defined(MBEDTLS_RSA_ALT)
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 17149c5..eb11f7b 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -2788,4 +2788,64 @@
 int mbedtls_ssl_tls13_finalize_client_hello(mbedtls_ssl_context *ssl);
 #endif
 
+#if defined(MBEDTLS_TEST_HOOKS) && defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
+
+/** Compute the HMAC of variable-length data with constant flow.
+ *
+ * This function computes the HMAC of the concatenation of \p add_data and \p
+ * data, and does so with a code flow and memory access pattern that does not
+ * depend on \p data_len_secret, but only on \p min_data_len and \p
+ * max_data_len. In particular, this function always reads exactly \p
+ * max_data_len bytes from \p data.
+ *
+ * \param ctx               The HMAC context. It must have keys configured
+ *                          with mbedtls_md_hmac_starts() and use one of the
+ *                          following hashes: SHA-384, SHA-256, SHA-1 or MD-5.
+ *                          It is reset using mbedtls_md_hmac_reset() after
+ *                          the computation is complete to prepare for the
+ *                          next computation.
+ * \param add_data          The first part of the message whose HMAC is being
+ *                          calculated. This must point to a readable buffer
+ *                          of \p add_data_len bytes.
+ * \param add_data_len      The length of \p add_data in bytes.
+ * \param data              The buffer containing the second part of the
+ *                          message. This must point to a readable buffer
+ *                          of \p max_data_len bytes.
+ * \param data_len_secret   The length of the data to process in \p data.
+ *                          This must be no less than \p min_data_len and no
+ *                          greater than \p max_data_len.
+ * \param min_data_len      The minimal length of the second part of the
+ *                          message, read from \p data.
+ * \param max_data_len      The maximal length of the second part of the
+ *                          message, read from \p data.
+ * \param output            The HMAC will be written here. This must point to
+ *                          a writable buffer of sufficient size to hold the
+ *                          HMAC value.
+ *
+ * \retval 0 on success.
+ * \retval #MBEDTLS_ERR_PLATFORM_HW_ACCEL_FAILED
+ *         The hardware accelerator failed.
+ */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+int mbedtls_ct_hmac(mbedtls_svc_key_id_t key,
+                    psa_algorithm_t mac_alg,
+                    const unsigned char *add_data,
+                    size_t add_data_len,
+                    const unsigned char *data,
+                    size_t data_len_secret,
+                    size_t min_data_len,
+                    size_t max_data_len,
+                    unsigned char *output);
+#else
+int mbedtls_ct_hmac(mbedtls_md_context_t *ctx,
+                    const unsigned char *add_data,
+                    size_t add_data_len,
+                    const unsigned char *data,
+                    size_t data_len_secret,
+                    size_t min_data_len,
+                    size_t max_data_len,
+                    unsigned char *output);
+#endif /* defined(MBEDTLS_USE_PSA_CRYPTO) */
+#endif /* defined(MBEDTLS_TEST_HOOKS) && defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC) */
+
 #endif /* ssl_misc.h */
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index 18c19f9..69706cf 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -54,6 +54,234 @@
                                                            psa_generic_status_to_mbedtls)
 #endif
 
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+
+#if defined(PSA_WANT_ALG_SHA_384)
+#define MAX_HASH_BLOCK_LENGTH PSA_HASH_BLOCK_LENGTH(PSA_ALG_SHA_384)
+#elif defined(PSA_WANT_ALG_SHA_256)
+#define MAX_HASH_BLOCK_LENGTH PSA_HASH_BLOCK_LENGTH(PSA_ALG_SHA_256)
+#else /* See check_config.h */
+#define MAX_HASH_BLOCK_LENGTH PSA_HASH_BLOCK_LENGTH(PSA_ALG_SHA_1)
+#endif
+
+MBEDTLS_STATIC_TESTABLE
+int mbedtls_ct_hmac(mbedtls_svc_key_id_t key,
+                    psa_algorithm_t mac_alg,
+                    const unsigned char *add_data,
+                    size_t add_data_len,
+                    const unsigned char *data,
+                    size_t data_len_secret,
+                    size_t min_data_len,
+                    size_t max_data_len,
+                    unsigned char *output)
+{
+    /*
+     * This function breaks the HMAC abstraction and uses psa_hash_clone()
+     * extension in order to get constant-flow behaviour.
+     *
+     * HMAC(msg) is defined as HASH(okey + HASH(ikey + msg)) where + means
+     * concatenation, and okey/ikey are the XOR of the key with some fixed bit
+     * patterns (see RFC 2104, sec. 2).
+     *
+     * We'll first compute ikey/okey, then inner_hash = HASH(ikey + msg) by
+     * hashing up to minlen, then cloning the context, and for each byte up
+     * to maxlen finishing up the hash computation, keeping only the
+     * correct result.
+     *
+     * Then we only need to compute HASH(okey + inner_hash) and we're done.
+     */
+    psa_algorithm_t hash_alg = PSA_ALG_HMAC_GET_HASH(mac_alg);
+    const size_t block_size = PSA_HASH_BLOCK_LENGTH(hash_alg);
+    unsigned char key_buf[MAX_HASH_BLOCK_LENGTH];
+    const size_t hash_size = PSA_HASH_LENGTH(hash_alg);
+    psa_hash_operation_t operation = PSA_HASH_OPERATION_INIT;
+    size_t hash_length;
+
+    unsigned char aux_out[PSA_HASH_MAX_SIZE];
+    psa_hash_operation_t aux_operation = PSA_HASH_OPERATION_INIT;
+    size_t offset;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    size_t mac_key_length;
+    size_t i;
+
+#define PSA_CHK(func_call)        \
+    do {                            \
+        status = (func_call);       \
+        if (status != PSA_SUCCESS) \
+        goto cleanup;           \
+    } while (0)
+
+    /* Export MAC key
+     * We assume key length is always exactly the output size
+     * which is never more than the block size, thus we use block_size
+     * as the key buffer size.
+     */
+    PSA_CHK(psa_export_key(key, key_buf, block_size, &mac_key_length));
+
+    /* Calculate ikey */
+    for (i = 0; i < mac_key_length; i++) {
+        key_buf[i] = (unsigned char) (key_buf[i] ^ 0x36);
+    }
+    for (; i < block_size; ++i) {
+        key_buf[i] = 0x36;
+    }
+
+    PSA_CHK(psa_hash_setup(&operation, hash_alg));
+
+    /* Now compute inner_hash = HASH(ikey + msg) */
+    PSA_CHK(psa_hash_update(&operation, key_buf, block_size));
+    PSA_CHK(psa_hash_update(&operation, add_data, add_data_len));
+    PSA_CHK(psa_hash_update(&operation, data, min_data_len));
+
+    /* Fill the hash buffer in advance with something that is
+     * not a valid hash (barring an attack on the hash and
+     * deliberately-crafted input), in case the caller doesn't
+     * check the return status properly. */
+    memset(output, '!', hash_size);
+
+    /* For each possible length, compute the hash up to that point */
+    for (offset = min_data_len; offset <= max_data_len; offset++) {
+        PSA_CHK(psa_hash_clone(&operation, &aux_operation));
+        PSA_CHK(psa_hash_finish(&aux_operation, aux_out,
+                                PSA_HASH_MAX_SIZE, &hash_length));
+        /* Keep only the correct inner_hash in the output buffer */
+        mbedtls_ct_memcpy_if_eq(output, aux_out, hash_size,
+                                offset, data_len_secret);
+
+        if (offset < max_data_len) {
+            PSA_CHK(psa_hash_update(&operation, data + offset, 1));
+        }
+    }
+
+    /* Abort current operation to prepare for final operation */
+    PSA_CHK(psa_hash_abort(&operation));
+
+    /* Calculate okey */
+    for (i = 0; i < mac_key_length; i++) {
+        key_buf[i] = (unsigned char) ((key_buf[i] ^ 0x36) ^ 0x5C);
+    }
+    for (; i < block_size; ++i) {
+        key_buf[i] = 0x5C;
+    }
+
+    /* Now compute HASH(okey + inner_hash) */
+    PSA_CHK(psa_hash_setup(&operation, hash_alg));
+    PSA_CHK(psa_hash_update(&operation, key_buf, block_size));
+    PSA_CHK(psa_hash_update(&operation, output, hash_size));
+    PSA_CHK(psa_hash_finish(&operation, output, hash_size, &hash_length));
+
+#undef PSA_CHK
+
+cleanup:
+    mbedtls_platform_zeroize(key_buf, MAX_HASH_BLOCK_LENGTH);
+    mbedtls_platform_zeroize(aux_out, PSA_HASH_MAX_SIZE);
+
+    psa_hash_abort(&operation);
+    psa_hash_abort(&aux_operation);
+    return PSA_TO_MBEDTLS_ERR(status);
+}
+
+#undef MAX_HASH_BLOCK_LENGTH
+
+#else
+MBEDTLS_STATIC_TESTABLE
+int mbedtls_ct_hmac(mbedtls_md_context_t *ctx,
+                    const unsigned char *add_data,
+                    size_t add_data_len,
+                    const unsigned char *data,
+                    size_t data_len_secret,
+                    size_t min_data_len,
+                    size_t max_data_len,
+                    unsigned char *output)
+{
+    /*
+     * This function breaks the HMAC abstraction and uses the md_clone()
+     * extension to the MD API in order to get constant-flow behaviour.
+     *
+     * HMAC(msg) is defined as HASH(okey + HASH(ikey + msg)) where + means
+     * concatenation, and okey/ikey are the XOR of the key with some fixed bit
+     * patterns (see RFC 2104, sec. 2), which are stored in ctx->hmac_ctx.
+     *
+     * We'll first compute inner_hash = HASH(ikey + msg) by hashing up to
+     * minlen, then cloning the context, and for each byte up to maxlen
+     * finishing up the hash computation, keeping only the correct result.
+     *
+     * Then we only need to compute HASH(okey + inner_hash) and we're done.
+     */
+    const mbedtls_md_type_t md_alg = mbedtls_md_get_type(ctx->md_info);
+    /* TLS 1.2 only supports SHA-384, SHA-256, SHA-1, MD-5,
+     * all of which have the same block size except SHA-384. */
+    const size_t block_size = md_alg == MBEDTLS_MD_SHA384 ? 128 : 64;
+    const unsigned char * const ikey = ctx->hmac_ctx;
+    const unsigned char * const okey = ikey + block_size;
+    const size_t hash_size = mbedtls_md_get_size(ctx->md_info);
+
+    unsigned char aux_out[MBEDTLS_MD_MAX_SIZE];
+    mbedtls_md_context_t aux;
+    size_t offset;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
+    mbedtls_md_init(&aux);
+
+#define MD_CHK(func_call) \
+    do {                    \
+        ret = (func_call);  \
+        if (ret != 0)      \
+        goto cleanup;   \
+    } while (0)
+
+    MD_CHK(mbedtls_md_setup(&aux, ctx->md_info, 0));
+
+    /* After hmac_starts() or hmac_reset(), ikey has already been hashed,
+     * so we can start directly with the message */
+    MD_CHK(mbedtls_md_update(ctx, add_data, add_data_len));
+    MD_CHK(mbedtls_md_update(ctx, data, min_data_len));
+
+    /* Fill the hash buffer in advance with something that is
+     * not a valid hash (barring an attack on the hash and
+     * deliberately-crafted input), in case the caller doesn't
+     * check the return status properly. */
+    memset(output, '!', hash_size);
+
+    /* For each possible length, compute the hash up to that point */
+    for (offset = min_data_len; offset <= max_data_len; offset++) {
+        MD_CHK(mbedtls_md_clone(&aux, ctx));
+        MD_CHK(mbedtls_md_finish(&aux, aux_out));
+        /* Keep only the correct inner_hash in the output buffer */
+        mbedtls_ct_memcpy_if_eq(output, aux_out, hash_size,
+                                offset, data_len_secret);
+
+        if (offset < max_data_len) {
+            MD_CHK(mbedtls_md_update(ctx, data + offset, 1));
+        }
+    }
+
+    /* The context needs to finish() before it starts() again */
+    MD_CHK(mbedtls_md_finish(ctx, aux_out));
+
+    /* Now compute HASH(okey + inner_hash) */
+    MD_CHK(mbedtls_md_starts(ctx));
+    MD_CHK(mbedtls_md_update(ctx, okey, block_size));
+    MD_CHK(mbedtls_md_update(ctx, output, hash_size));
+    MD_CHK(mbedtls_md_finish(ctx, output));
+
+    /* Done, get ready for next time */
+    MD_CHK(mbedtls_md_hmac_reset(ctx));
+
+#undef MD_CHK
+
+cleanup:
+    mbedtls_md_free(&aux);
+    return ret;
+}
+
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */
+
 static uint32_t ssl_get_hs_total_len(mbedtls_ssl_context const *ssl);
 
 /*