core: mbedtls: use AES crypto accelerated routines

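Uses the recently provided accelerated AES crypto routines in mbedtls by
supplying an MBEDTLS_AES_ALT implementation, which is enabled when
CFG_CORE_CRYPTO_AES_ACCEL is set.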

Acked-by: Etienne Carriere <etienne.carriere@linaro.org>
Signed-off-by: Jens Wiklander <jens.wiklander@linaro.org>
diff --git a/lib/libmbedtls/core/aes.c b/lib/libmbedtls/core/aes.c
index 4d317f0..1dfa28c 100644
--- a/lib/libmbedtls/core/aes.c
+++ b/lib/libmbedtls/core/aes.c
@@ -4,15 +4,23 @@
  * Copyright (C) 2019, Linaro Limited
  */
 
+#include <assert.h>
+#include <compiler.h>
+#include <crypto/crypto_accel.h>
 #include <crypto/crypto.h>
 #include <kernel/panic.h>
 #include <mbedtls/aes.h>
+#include <mbedtls/platform_util.h>
 #include <string.h>
 
 TEE_Result crypto_aes_expand_enc_key(const void *key, size_t key_len,
 				     void *enc_key, size_t enc_keylen,
 				     unsigned int *rounds)
 {
+#if defined(MBEDTLS_AES_ALT)
+	return crypto_accel_aes_expand_keys(key, key_len, enc_key, NULL,
+					    enc_keylen, rounds);
+#else
 	mbedtls_aes_context ctx;
 
 	memset(&ctx, 0, sizeof(ctx));
@@ -26,11 +34,15 @@
 	*rounds = ctx.nr;
 	mbedtls_aes_free(&ctx);
 	return TEE_SUCCESS;
+#endif
 }
 
-void crypto_aes_enc_block(const void *enc_key, size_t enc_keylen,
+void crypto_aes_enc_block(const void *enc_key, size_t enc_keylen __maybe_unused,
 			  unsigned int rounds, const void *src, void *dst)
 {
+#if defined(MBEDTLS_AES_ALT)
+	crypto_accel_aes_ecb_enc(dst, src, enc_key, rounds, 1);
+#else
 	mbedtls_aes_context ctx;
 
 	memset(&ctx, 0, sizeof(ctx));
@@ -42,4 +54,79 @@
 	ctx.nr = rounds;
 	mbedtls_aes_encrypt(&ctx, src, dst);
 	mbedtls_aes_free(&ctx);
+#endif
 }
+
+#if defined(MBEDTLS_AES_ALT)
+/*
+ * The functions below implement the mbedtls AES API when MBEDTLS_AES_ALT
+ * is defined. Key schedules are kept in the mbedtls_aes_context defined in
+ * aes_alt.h, in the format produced by crypto_accel_aes_expand_keys().
+ */
+
+/* Zero-initializes @ctx, which must not be NULL */
+void mbedtls_aes_init(mbedtls_aes_context *ctx)
+{
+	assert(ctx);
+	memset(ctx, 0, sizeof(*ctx));
+}
+
+/* Wipes the key schedule held in @ctx, a NULL @ctx is ignored */
+void mbedtls_aes_free(mbedtls_aes_context *ctx)
+{
+	if (ctx)
+		mbedtls_platform_zeroize(ctx, sizeof(*ctx));
+}
+
+/*
+ * Expands the @keybits long @key into the encryption key schedule in @ctx.
+ * Returns 0 on success, MBEDTLS_ERR_AES_INVALID_KEY_LENGTH if @keybits
+ * isn't 128, 192 or 256, or MBEDTLS_ERR_AES_BAD_INPUT_DATA if the key
+ * expansion fails.
+ */
+int mbedtls_aes_setkey_enc(mbedtls_aes_context *ctx, const unsigned char *key,
+			   unsigned int keybits)
+{
+	assert(ctx && key);
+
+	if (keybits != 128 && keybits != 192 && keybits != 256)
+		return MBEDTLS_ERR_AES_INVALID_KEY_LENGTH;
+
+	if (crypto_accel_aes_expand_keys(key, keybits / 8, ctx->key, NULL,
+					 sizeof(ctx->key), &ctx->round_count))
+		return MBEDTLS_ERR_AES_BAD_INPUT_DATA;
+
+	return 0;
+}
+
+/*
+ * Expands the @keybits long @key into the decryption key schedule in @ctx.
+ * Returns the same values as mbedtls_aes_setkey_enc().
+ */
+int mbedtls_aes_setkey_dec(mbedtls_aes_context *ctx, const unsigned char *key,
+			   unsigned int keybits)
+{
+	/*
+	 * Scratch buffer for the encryption key schedule that is expanded
+	 * alongside the decryption one. It has as many elements as ctx->key
+	 * (sizeof(ctx->key) alone would count bytes, not words) and is wiped
+	 * before returning as it holds key material.
+	 */
+	uint32_t enc_key[sizeof(ctx->key) / sizeof(ctx->key[0])] = { 0 };
+	int res = 0;
+
+	assert(ctx && key);
+
+	if (keybits != 128 && keybits != 192 && keybits != 256)
+		return MBEDTLS_ERR_AES_INVALID_KEY_LENGTH;
+
+	if (crypto_accel_aes_expand_keys(key, keybits / 8, enc_key, ctx->key,
+					 sizeof(ctx->key), &ctx->round_count))
+		res = MBEDTLS_ERR_AES_BAD_INPUT_DATA;
+
+	mbedtls_platform_zeroize(enc_key, sizeof(enc_key));
+
+	return res;
+}
+#endif /*MBEDTLS_AES_ALT*/
diff --git a/lib/libmbedtls/core/aes_cbc.c b/lib/libmbedtls/core/aes_cbc.c
index aa97969..b568f99 100644
--- a/lib/libmbedtls/core/aes_cbc.c
+++ b/lib/libmbedtls/core/aes_cbc.c
@@ -5,6 +5,7 @@
 
 #include <assert.h>
 #include <compiler.h>
+#include <crypto/crypto_accel.h>
 #include <crypto/crypto.h>
 #include <crypto/crypto_impl.h>
 #include <mbedtls/aes.h>
@@ -118,3 +119,33 @@
 
 	return TEE_SUCCESS;
 }
+
+#if defined(MBEDTLS_AES_ALT)
+/*
+ * Encrypts (@mode == MBEDTLS_AES_ENCRYPT) or otherwise decrypts @length
+ * bytes from @input into @output in CBC mode, with the key schedule in @ctx
+ * and the chaining value in @iv. Returns 0 on success or
+ * MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH if @length is not a multiple of the
+ * 16 byte block size.
+ */
+int mbedtls_aes_crypt_cbc(mbedtls_aes_context *ctx, int mode, size_t length,
+			  unsigned char iv[16], const unsigned char *input,
+			  unsigned char *output)
+{
+	size_t block_count = length / 16;
+
+	/* Partial blocks are rejected, no padding is applied */
+	if (block_count * 16 != length)
+		return MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH;
+
+	if (mode != MBEDTLS_AES_ENCRYPT) {
+		crypto_accel_aes_cbc_dec(output, input, ctx->key,
+					 ctx->round_count, block_count, iv);
+		return 0;
+	}
+
+	crypto_accel_aes_cbc_enc(output, input, ctx->key, ctx->round_count,
+				 block_count, iv);
+	return 0;
+}
+#endif /*MBEDTLS_AES_ALT*/
diff --git a/lib/libmbedtls/core/aes_ctr.c b/lib/libmbedtls/core/aes_ctr.c
index 3dc16c6..eb33319 100644
--- a/lib/libmbedtls/core/aes_ctr.c
+++ b/lib/libmbedtls/core/aes_ctr.c
@@ -5,6 +5,7 @@
 
 #include <assert.h>
 #include <compiler.h>
+#include <crypto/crypto_accel.h>
 #include <crypto/crypto.h>
 #include <crypto/crypto_impl.h>
 #include <mbedtls/aes.h>
@@ -113,3 +114,71 @@
 
 	return TEE_SUCCESS;
 }
+
+#if defined(MBEDTLS_AES_ALT)
+/*
+ * Computes the key stream block for the counter block in @nonce_counter
+ * into @stream_block, using the key schedule in @ctx.
+ * NOTE(review): mbedtls_aes_crypt_ctr() relies on
+ * crypto_accel_aes_ctr_be_enc() also advancing @nonce_counter.
+ */
+static void next_ctr(unsigned char stream_block[16], mbedtls_aes_context *ctx,
+		     unsigned char nonce_counter[16])
+{
+	/* CTR encryption of an all-zero block yields the key stream itself */
+	static const unsigned char zeroes[16] = { 0 };
+
+	crypto_accel_aes_ctr_be_enc(stream_block, zeroes, ctx->key,
+				    ctx->round_count, 1, nonce_counter);
+}
+
+/*
+ * Encrypts or decrypts (the same operation in CTR mode) @length bytes from
+ * @input into @output. *@nc_off is the offset of the next unused key
+ * stream byte in @stream_block, 0 meaning a fresh block is needed; both
+ * carry state between calls. Returns 0 on success or
+ * MBEDTLS_ERR_AES_BAD_INPUT_DATA if *@nc_off is not below 16.
+ */
+int mbedtls_aes_crypt_ctr(mbedtls_aes_context *ctx, size_t length,
+			  size_t *nc_off, unsigned char nonce_counter[16],
+			  unsigned char stream_block[16],
+			  const unsigned char *input, unsigned char *output)
+{
+	size_t offs = 0;
+
+	if (*nc_off >= 16)
+		return MBEDTLS_ERR_AES_BAD_INPUT_DATA;
+
+	/*
+	 * If the stream_block is in use, continue until done or
+	 * stream_block is consumed. offs is checked against length before
+	 * each byte so that nothing is read or written when length is 0.
+	 */
+	while (*nc_off && offs < length) {
+		output[offs] = stream_block[*nc_off] ^ input[offs];
+		offs++;
+		*nc_off = (*nc_off + 1) % 16;
+	}
+
+	/* Whole blocks go straight through, bypassing stream_block */
+	if (length - offs >= 16) {
+		size_t block_count = (length - offs) / 16;
+
+		crypto_accel_aes_ctr_be_enc(output + offs, input + offs,
+					    ctx->key, ctx->round_count,
+					    block_count, nonce_counter);
+		offs += block_count * 16;
+	}
+
+	/* A trailing partial block leaves unused key stream for later calls */
+	while (offs < length) {
+		if (!*nc_off)
+			next_ctr(stream_block, ctx, nonce_counter);
+		output[offs] = stream_block[*nc_off] ^ input[offs];
+		offs++;
+		*nc_off = (*nc_off + 1) % 16;
+	}
+
+	return 0;
+}
+#endif /*MBEDTLS_AES_ALT*/
diff --git a/lib/libmbedtls/core/aes_ecb.c b/lib/libmbedtls/core/aes_ecb.c
index 58169f9..8aa78a0 100644
--- a/lib/libmbedtls/core/aes_ecb.c
+++ b/lib/libmbedtls/core/aes_ecb.c
@@ -5,6 +5,7 @@
 
 #include <assert.h>
 #include <compiler.h>
+#include <crypto/crypto_accel.h>
 #include <crypto/crypto.h>
 #include <crypto/crypto_impl.h>
 #include <mbedtls/aes.h>
@@ -120,3 +121,24 @@
 
 	return TEE_SUCCESS;
 }
+
+#if defined(MBEDTLS_AES_ALT)
+/*
+ * Encrypts (@mode == MBEDTLS_AES_ENCRYPT) or otherwise decrypts the single
+ * 16 byte block @input into @output with the key schedule in @ctx.
+ * Always returns 0.
+ */
+int mbedtls_aes_crypt_ecb(mbedtls_aes_context *ctx, int mode,
+			  const unsigned char input[16],
+			  unsigned char output[16])
+{
+	if (mode != MBEDTLS_AES_ENCRYPT) {
+		crypto_accel_aes_ecb_dec(output, input, ctx->key,
+					 ctx->round_count, 1);
+		return 0;
+	}
+
+	crypto_accel_aes_ecb_enc(output, input, ctx->key, ctx->round_count, 1);
+	return 0;
+}
+#endif /*MBEDTLS_AES_ALT*/
diff --git a/lib/libmbedtls/include/aes_alt.h b/lib/libmbedtls/include/aes_alt.h
new file mode 100644
index 0000000..88a7103
--- /dev/null
+++ b/lib/libmbedtls/include/aes_alt.h
@@ -0,0 +1,20 @@
+/* SPDX-License-Identifier: BSD-2-Clause */
+/* Copyright (c) 2020, Linaro Limited */
+
+#ifndef __MBEDTLS_AES_ALT_H
+#define __MBEDTLS_AES_ALT_H
+
+#include <stdint.h>
+
+/*
+ * AES context for the MBEDTLS_AES_ALT implementation in
+ * lib/libmbedtls/core/aes*.c, based on the crypto_accel_aes_*() routines.
+ */
+typedef struct mbedtls_aes_context {
+	/* Expanded key, 60 words hold an AES-256 schedule (15 x 4 words) */
+	uint32_t key[60];
+	/* Number of rounds, as set by crypto_accel_aes_expand_keys() */
+	unsigned int round_count;
+} mbedtls_aes_context;
+
+#endif /*__MBEDTLS_AES_ALT_H*/
diff --git a/lib/libmbedtls/include/mbedtls_config_kernel.h b/lib/libmbedtls/include/mbedtls_config_kernel.h
index 9ec22f1..2cb6124 100644
--- a/lib/libmbedtls/include/mbedtls_config_kernel.h
+++ b/lib/libmbedtls/include/mbedtls_config_kernel.h
@@ -42,6 +42,9 @@
 #if defined(CFG_CRYPTO_AES)
 #define MBEDTLS_AES_C
 #define MBEDTLS_AES_ROM_TABLES
+#ifdef CFG_CORE_CRYPTO_AES_ACCEL
+#define MBEDTLS_AES_ALT	/* AES via crypto_accel_aes_*(), see core/aes*.c */
+#endif /*CFG_CORE_CRYPTO_AES_ACCEL*/
 #endif
 
 #if defined(CFG_CRYPTO_DES)