Move mbedtls_ct_hmac into ssl_msg.c

Signed-off-by: Dave Rodgman <dave.rodgman@arm.com>
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index 18c19f9..69706cf 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -54,6 +54,234 @@
                                                            psa_generic_status_to_mbedtls)
 #endif
 
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+
+#if defined(PSA_WANT_ALG_SHA_384)
+#define MAX_HASH_BLOCK_LENGTH PSA_HASH_BLOCK_LENGTH(PSA_ALG_SHA_384)
+#elif defined(PSA_WANT_ALG_SHA_256)
+#define MAX_HASH_BLOCK_LENGTH PSA_HASH_BLOCK_LENGTH(PSA_ALG_SHA_256)
+#else /* Neither SHA-384 nor SHA-256: use SHA-1 (see check_config.h) */
+#define MAX_HASH_BLOCK_LENGTH PSA_HASH_BLOCK_LENGTH(PSA_ALG_SHA_1)
+#endif
+
+MBEDTLS_STATIC_TESTABLE
+int mbedtls_ct_hmac(mbedtls_svc_key_id_t key,
+                    psa_algorithm_t mac_alg,
+                    const unsigned char *add_data,
+                    size_t add_data_len,
+                    const unsigned char *data,
+                    size_t data_len_secret,
+                    size_t min_data_len,
+                    size_t max_data_len,
+                    unsigned char *output)
+{
+    /*
+     * Compute HMAC(add_data + data[0..data_len_secret)) into output. The loop
+     * below runs from min_data_len to max_data_len whatever data_len_secret is,
+     * which only selects the inner hash kept; data is read up to max_data_len.
+     *
+     * This breaks the HMAC abstraction and uses psa_hash_clone() in order to
+     * get constant-flow behaviour. HMAC(msg) is HASH(okey + HASH(ikey + msg)),
+     * where + means concatenation and okey/ikey are the XOR of the key with
+     * some fixed bit patterns (see RFC 2104, sec. 2).
+     *
+     * We first compute ikey and inner_hash = HASH(ikey + msg) by hashing up to
+     * minlen, then for each length up to maxlen we clone the context, finish
+     * the clone and keep only the correct result. Finally we compute okey and
+     * HASH(okey + inner_hash).
+     */
+    psa_algorithm_t hash_alg = PSA_ALG_HMAC_GET_HASH(mac_alg);
+    const size_t block_size = PSA_HASH_BLOCK_LENGTH(hash_alg);
+    unsigned char key_buf[MAX_HASH_BLOCK_LENGTH];
+    const size_t hash_size = PSA_HASH_LENGTH(hash_alg);
+    psa_hash_operation_t operation = PSA_HASH_OPERATION_INIT;
+    size_t hash_length;
+
+    unsigned char aux_out[PSA_HASH_MAX_SIZE];
+    psa_hash_operation_t aux_operation = PSA_HASH_OPERATION_INIT;
+    size_t offset;
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+
+    size_t mac_key_length;
+    size_t i;
+
+#define PSA_CHK(func_call)        \
+    do {                            \
+        status = (func_call);       \
+        if (status != PSA_SUCCESS) \
+        goto cleanup;           \
+    } while (0)
+
+    /* Export MAC key
+     * We assume key length is always exactly the output size
+     * which is never more than the block size, thus we use block_size
+     * as the key buffer size.
+     */
+    PSA_CHK(psa_export_key(key, key_buf, block_size, &mac_key_length));
+
+    /* Calculate ikey: key XOR 0x36, padded with 0x36 up to block_size */
+    for (i = 0; i < mac_key_length; i++) {
+        key_buf[i] = (unsigned char) (key_buf[i] ^ 0x36);
+    }
+    for (; i < block_size; ++i) {
+        key_buf[i] = 0x36;
+    }
+
+    PSA_CHK(psa_hash_setup(&operation, hash_alg));
+
+    /* Now compute inner_hash = HASH(ikey + msg) */
+    PSA_CHK(psa_hash_update(&operation, key_buf, block_size));
+    PSA_CHK(psa_hash_update(&operation, add_data, add_data_len));
+    PSA_CHK(psa_hash_update(&operation, data, min_data_len));
+
+    /* Fill the hash buffer in advance with something that is
+     * not a valid hash (barring an attack on the hash and
+     * deliberately-crafted input), in case the caller doesn't
+     * check the return status properly. */
+    memset(output, '!', hash_size);
+
+    /* For each possible length, compute the hash up to that point */
+    for (offset = min_data_len; offset <= max_data_len; offset++) {
+        PSA_CHK(psa_hash_clone(&operation, &aux_operation));
+        PSA_CHK(psa_hash_finish(&aux_operation, aux_out,
+                                PSA_HASH_MAX_SIZE, &hash_length));
+        /* Keep only the correct inner_hash in the output buffer */
+        mbedtls_ct_memcpy_if_eq(output, aux_out, hash_size,
+                                offset, data_len_secret);
+
+        if (offset < max_data_len) {
+            PSA_CHK(psa_hash_update(&operation, data + offset, 1));
+        }
+    }
+
+    /* Abort current operation to prepare for final operation */
+    PSA_CHK(psa_hash_abort(&operation));
+
+    /* Calculate okey from ikey: undo the 0x36 mask, then apply 0x5C */
+    for (i = 0; i < mac_key_length; i++) {
+        key_buf[i] = (unsigned char) ((key_buf[i] ^ 0x36) ^ 0x5C);
+    }
+    for (; i < block_size; ++i) {
+        key_buf[i] = 0x5C;
+    }
+
+    /* Now compute HASH(okey + inner_hash) */
+    PSA_CHK(psa_hash_setup(&operation, hash_alg));
+    PSA_CHK(psa_hash_update(&operation, key_buf, block_size));
+    PSA_CHK(psa_hash_update(&operation, output, hash_size));
+    PSA_CHK(psa_hash_finish(&operation, output, hash_size, &hash_length));
+
+#undef PSA_CHK
+
+cleanup:
+    mbedtls_platform_zeroize(key_buf, MAX_HASH_BLOCK_LENGTH);
+    mbedtls_platform_zeroize(aux_out, PSA_HASH_MAX_SIZE);
+
+    psa_hash_abort(&operation);
+    psa_hash_abort(&aux_operation);
+    return PSA_TO_MBEDTLS_ERR(status);
+}
+
+#undef MAX_HASH_BLOCK_LENGTH
+
+#else
+MBEDTLS_STATIC_TESTABLE
+int mbedtls_ct_hmac(mbedtls_md_context_t *ctx,
+                    const unsigned char *add_data,
+                    size_t add_data_len,
+                    const unsigned char *data,
+                    size_t data_len_secret,
+                    size_t min_data_len,
+                    size_t max_data_len,
+                    unsigned char *output)
+{
+    /*
+     * Compute HMAC(add_data + data[0..data_len_secret)) into output. The loop
+     * runs from min_data_len to max_data_len whatever data_len_secret is.
+     *
+     * This breaks the HMAC abstraction and uses the md_clone() extension to the
+     * MD API in order to get constant-flow behaviour. HMAC(msg) is
+     * HASH(okey + HASH(ikey + msg)), where + means concatenation and okey/ikey
+     * are the XOR of the key with some fixed bit patterns (see RFC 2104,
+     * sec. 2), which are stored in ctx->hmac_ctx.
+     *
+     * We first compute inner_hash = HASH(ikey + msg) by hashing up to minlen,
+     * then for each length up to maxlen we clone the context, finish the clone
+     * and keep only the correct result; last comes HASH(okey + inner_hash).
+     */
+    const mbedtls_md_type_t md_alg = mbedtls_md_get_type(ctx->md_info);
+    /* TLS 1.2 only supports SHA-384, SHA-256, SHA-1, MD-5,
+     * all of which have the same block size except SHA-384. */
+    const size_t block_size = md_alg == MBEDTLS_MD_SHA384 ? 128 : 64;
+    const unsigned char * const ikey = ctx->hmac_ctx;
+    const unsigned char * const okey = ikey + block_size;
+    const size_t hash_size = mbedtls_md_get_size(ctx->md_info);
+
+    unsigned char aux_out[MBEDTLS_MD_MAX_SIZE];
+    mbedtls_md_context_t aux;
+    size_t offset;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
+    mbedtls_md_init(&aux);
+
+#define MD_CHK(func_call) \
+    do {                    \
+        ret = (func_call);  \
+        if (ret != 0)      \
+        goto cleanup;   \
+    } while (0)
+
+    MD_CHK(mbedtls_md_setup(&aux, ctx->md_info, 0));
+
+    /* After hmac_start() or hmac_reset(), ikey has already been hashed,
+     * so we can start directly with the message */
+    MD_CHK(mbedtls_md_update(ctx, add_data, add_data_len));
+    MD_CHK(mbedtls_md_update(ctx, data, min_data_len));
+
+    /* Fill the hash buffer in advance with something that is
+     * not a valid hash (barring an attack on the hash and
+     * deliberately-crafted input), in case the caller doesn't
+     * check the return status properly. */
+    memset(output, '!', hash_size);
+
+    /* For each possible length, compute the hash up to that point */
+    for (offset = min_data_len; offset <= max_data_len; offset++) {
+        MD_CHK(mbedtls_md_clone(&aux, ctx));
+        MD_CHK(mbedtls_md_finish(&aux, aux_out));
+        /* Keep only the correct inner_hash in the output buffer */
+        mbedtls_ct_memcpy_if_eq(output, aux_out, hash_size,
+                                offset, data_len_secret);
+
+        if (offset < max_data_len) {
+            MD_CHK(mbedtls_md_update(ctx, data + offset, 1));
+        }
+    }
+
+    /* The context needs to finish() before it starts() again */
+    MD_CHK(mbedtls_md_finish(ctx, aux_out));
+
+    /* Now compute HASH(okey + inner_hash) */
+    MD_CHK(mbedtls_md_starts(ctx));
+    MD_CHK(mbedtls_md_update(ctx, okey, block_size));
+    MD_CHK(mbedtls_md_update(ctx, output, hash_size));
+    MD_CHK(mbedtls_md_finish(ctx, output));
+
+    /* Done, get ready for next time */
+    MD_CHK(mbedtls_md_hmac_reset(ctx));
+
+#undef MD_CHK
+
+cleanup:
+    mbedtls_md_free(&aux);
+    return ret;
+}
+
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */
+
 static uint32_t ssl_get_hs_total_len(mbedtls_ssl_context const *ssl);
 
 /*