RSS: Use CC3XX KDF

Change-Id: If7bbc3da4558a168e606a7b4f0c99683512b57ad
Signed-off-by: Raef Coles <raef.coles@arm.com>
diff --git a/bl1/bl1_2/scripts/create_bl2_img.py b/bl1/bl1_2/scripts/create_bl2_img.py
index 94231a6..255e659 100644
--- a/bl1/bl1_2/scripts/create_bl2_img.py
+++ b/bl1/bl1_2/scripts/create_bl2_img.py
@@ -1,5 +1,5 @@
 #-------------------------------------------------------------------------------
-# Copyright (c) 2021-2022, Arm Limited. All rights reserved.
+# Copyright (c) 2021-2023, Arm Limited. All rights reserved.
 #
 # SPDX-License-Identifier: BSD-3-Clause
 #
@@ -7,6 +7,7 @@
 
 import hashlib
 from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.primitives import cmac
 from cryptography.hazmat.backends import default_backend
 import secrets
 import argparse
@@ -46,16 +47,19 @@
     with open(args.encrypt_key_file, "rb") as encrypt_key_file:
         encrypt_key = encrypt_key_file.read()
 
-    state = struct_pack([(1).to_bytes(4, byteorder='little'),
-                         # C keeps the null byte, python removes it, so we add
-                         # it back manually.
-                         "BL2_DECRYPTION_KEY".encode('ascii') + bytes(1),
-                         bytes(1), security_counter,
-                         (32).to_bytes(4, byteorder='little')])
-    state_hash_object = hashlib.sha256()
-    state_hash_object.update(state)
-    state_hash = state_hash_object.digest()
-    return Cipher(algorithms.AES(encrypt_key), modes.ECB()).encryptor().update(state_hash)
+    output_key = bytes(0)
+    # Each CMAC iteration yields 16 bytes, so an AES-256 key takes 2 iterations
+    for i in range(2):
+        state = struct_pack([(i + 1).to_bytes(4, byteorder='little'),
+                             # C keeps the null byte, python removes it, so we add
+                             # it back manually.
+                             "BL2_DECRYPTION_KEY".encode('ascii') + bytes(1),
+                             bytes(1), security_counter,
+                             (32).to_bytes(4, byteorder='little')])
+        c = cmac.CMAC(algorithms.AES(encrypt_key))
+        c.update(state)
+        output_key += c.finalize()
+    return output_key
 
 
 
diff --git a/platform/ext/target/arm/rss/common/CMakeLists.txt b/platform/ext/target/arm/rss/common/CMakeLists.txt
index 2ff4466..f47dc75 100644
--- a/platform/ext/target/arm/rss/common/CMakeLists.txt
+++ b/platform/ext/target/arm/rss/common/CMakeLists.txt
@@ -327,6 +327,7 @@
         $<$<BOOL:${RSS_DEBUG_UART}>:${CMAKE_CURRENT_SOURCE_DIR}/cmsis_drivers/Driver_USART_cmsdk.c>
         $<$<BOOL:${RSS_DEBUG_UART}>:${PLATFORM_DIR}/ext/target/arm/drivers/usart/cmsdk/uart_cmsdk_drv.c>
         ./dpa_hardened_word_copy.c
+        ./cc312/cc3xx_aes_external_key_loader.c
 )
 
 target_include_directories(platform_bl1_1_interface
diff --git a/platform/ext/target/arm/rss/common/bl1/bl1_1_shared_symbols.txt b/platform/ext/target/arm/rss/common/bl1/bl1_1_shared_symbols.txt
index 91308a9..a0e8497 100644
--- a/platform/ext/target/arm/rss/common/bl1/bl1_1_shared_symbols.txt
+++ b/platform/ext/target/arm/rss/common/bl1/bl1_1_shared_symbols.txt
@@ -12,6 +12,7 @@
 bl1_sha256_update
 bl1_trng_generate_random
 bl_secure_memcpy
+cc3xx_kdf_cmac
 cc3xx_uninit
 computed_bl1_2_hash
 dpa_hardened_word_copy
@@ -19,7 +20,7 @@
 host_flash_atu_init_regions_for_image
 host_flash_atu_uninit_regions
 kmu_init
-kmu_set_key
+kmu_get_key_buffer_ptr
 kmu_set_key_locked
 kmu_set_slot_invalid
 pq_crypto_verify
diff --git a/platform/ext/target/arm/rss/common/bl1/boot_hal_bl1_2.c b/platform/ext/target/arm/rss/common/bl1/boot_hal_bl1_2.c
index 5814b40..c1fb3aa 100644
--- a/platform/ext/target/arm/rss/common/bl1/boot_hal_bl1_2.c
+++ b/platform/ext/target/arm/rss/common/bl1/boot_hal_bl1_2.c
@@ -31,6 +31,7 @@
 #endif
 #include "tfm_plat_nv_counters.h"
 #include "rss_key_derivation.h"
+#include "rss_kmu_slot_ids.h"
 
 uint32_t image_offsets[2];
 
@@ -203,28 +204,32 @@
 int boot_platform_post_load(uint32_t image_id)
 {
     int rc = 0;
-    uint8_t key_buf[32];
-    size_t key_len;
+    uint32_t vhuk_seed[8];
+    size_t vhuk_seed_len;
 
-    rc = rss_derive_vhuk_seed(key_buf, sizeof(key_buf), &key_len);
+    /* Derive the VHUK seed, then the VHUK, CPAK seed and DAK seed into their
+     * KMU slots. Stops at the first failure.
+     */
+    rc = rss_derive_vhuk_seed(vhuk_seed, sizeof(vhuk_seed), &vhuk_seed_len);
     if (rc) {
         goto exit;
     }
 
-    rc = rss_derive_vhuk(key_buf, sizeof(key_buf), KMU_USER_SLOT_MIN);
+    rc = rss_derive_vhuk((uint8_t *)vhuk_seed, vhuk_seed_len, RSS_KMU_SLOT_VHUK);
     if (rc) {
         goto exit;
     }
 
-    rc = rss_derive_cpak_seed(KMU_USER_SLOT_MIN + 1);
+    rc = rss_derive_cpak_seed(RSS_KMU_SLOT_CPAK_SEED);
     if (rc) {
         goto exit;
     }
 
-    rc = rss_derive_dak_seed(KMU_USER_SLOT_MIN + 2);
+    rc = rss_derive_dak_seed(RSS_KMU_SLOT_DAK_SEED);
 
 exit:
-    memset(key_buf, 0, sizeof(key_buf));
+    /* Don't leave the HUK-derived seed behind on the stack */
+    memset(vhuk_seed, 0, sizeof(vhuk_seed));
 
     return rc;
 }
diff --git a/platform/ext/target/arm/rss/common/bl1/cc312_rom_crypto.c b/platform/ext/target/arm/rss/common/bl1/cc312_rom_crypto.c
index e6e1c86..e36af93 100644
--- a/platform/ext/target/arm/rss/common/bl1/cc312_rom_crypto.c
+++ b/platform/ext/target/arm/rss/common/bl1/cc312_rom_crypto.c
@@ -15,6 +15,7 @@
 #include "otp.h"
 #include "fih.h"
 #include "cc3xx_drv.h"
+#include "kmu_drv.h"
 
 #define KEY_DERIVATION_MAX_BUF_SIZE 128
 
@@ -82,24 +83,25 @@
     FIH_RET(FIH_SUCCESS);
 }
 
-static int32_t bl1_key_to_cc3xx_key(enum tfm_bl1_key_id_t key_id,
-                                    cc3xx_aes_key_id_t *cc3xx_key_type,
-                                    uint8_t *key_buf, size_t key_buf_size)
+static int32_t bl1_key_to_kmu_key(enum tfm_bl1_key_id_t key_id,
+                                    enum kmu_hardware_keyslot_t *cc3xx_key_type,
+                                    uint8_t **key_buf, size_t key_buf_size)
 {
     int32_t rc;
 
     switch(key_id) {
     case TFM_BL1_KEY_HUK:
-        *cc3xx_key_type = CC3XX_AES_KEY_ID_HUK;
+        *cc3xx_key_type = KMU_HW_SLOT_HUK;
+        *key_buf = NULL;
         break;
     case TFM_BL1_KEY_GUK:
-        *cc3xx_key_type = CC3XX_AES_KEY_ID_GUK;
+        *cc3xx_key_type = KMU_HW_SLOT_GUK;
+        *key_buf = NULL;
         break;
     default:
-        *cc3xx_key_type = CC3XX_AES_KEY_ID_USER_KEY;
-        rc = bl1_otp_read_key(key_id, key_buf);
+        rc = bl1_otp_read_key(key_id, *key_buf);
         if (rc) {
-            memset(key_buf, 0, key_buf_size);
+            memset(*key_buf, 0, key_buf_size);
             return rc;
         }
         break;
@@ -115,10 +117,10 @@
                                 size_t ciphertext_length,
                                 uint8_t *plaintext)
 {
-    cc3xx_aes_key_id_t cc3xx_key_type;
+    enum kmu_hardware_keyslot_t kmu_key_slot;
     uint32_t key_buf[32 / sizeof(uint32_t)];
     int32_t rc = 0;
-    const uint8_t *input_key = key_buf;
+    uint8_t *input_key = (uint8_t *)key_buf;
     cc3xx_err_t err;
 
     if (ciphertext_length == 0) {
@@ -134,19 +136,18 @@
     }
 
     if (key_material == NULL) {
-        rc = bl1_key_to_cc3xx_key(key_id, &cc3xx_key_type, (uint8_t *)key_buf,
-                                  sizeof(key_buf));
+        rc = bl1_key_to_kmu_key(key_id, &kmu_key_slot, &input_key,
+                                sizeof(key_buf));
         if (rc) {
             return rc;
         }
     } else {
-        cc3xx_key_type = CC3XX_AES_KEY_ID_USER_KEY;
-        input_key = key_material;
+        input_key = (uint8_t *)key_material;
     }
 
     err = cc3xx_aes_init(CC3XX_AES_DIRECTION_DECRYPT, CC3XX_AES_MODE_CTR,
-                         cc3xx_key_type, input_key, CC3XX_AES_KEYSIZE_256,
-                         (uint32_t *)counter, 16);
+                         kmu_key_slot, (uint32_t *)input_key,
+                         CC3XX_AES_KEYSIZE_256, (uint32_t *)counter, 16);
     if (err != CC3XX_ERR_SUCCESS) {
         return 1;
     }
@@ -158,132 +159,28 @@
     return 0;
 }
 
-static int32_t aes_256_ecb_encrypt(enum tfm_bl1_key_id_t key_id,
-                                   const uint8_t *plaintext,
-                                   size_t ciphertext_length,
-                                   uint8_t *ciphertext)
-{
-    cc3xx_aes_key_id_t cc3xx_key_type;
-    uint32_t key_buf[32 / sizeof(uint32_t)];
-    int32_t rc = 0;
-    cc3xx_err_t err;
-
-    if (ciphertext_length == 0) {
-        return 0;
-    }
-
-    if (ciphertext == NULL || plaintext == NULL) {
-        return -1;
-    }
-
-    rc = bl1_key_to_cc3xx_key(key_id, &cc3xx_key_type, key_buf, sizeof(key_buf));
-    if (rc) {
-        return rc;
-    }
-
-    err = cc3xx_aes_init(CC3XX_AES_DIRECTION_ENCRYPT, CC3XX_AES_MODE_ECB,
-                         cc3xx_key_type, (uint32_t *)key_buf,
-                         CC3XX_AES_KEYSIZE_256,
-                         NULL, 0);
-    if (err != CC3XX_ERR_SUCCESS) {
-        return 1;
-    }
-
-    cc3xx_aes_set_output_buffer(ciphertext, ciphertext_length);
-    cc3xx_aes_update(plaintext, ciphertext_length);
-    cc3xx_aes_finish(NULL);
-}
-
-/* This is a counter-mode KDF complying with NIST SP800-108 where the PRF is a
- * combined sha256 hash and an ECB-mode AES encryption. ECB is acceptable here
- * since the input to the PRF is a hash, and the hash input is different every
- * time because of the counter being part of the input.
- */
-int32_t bl1_derive_key(enum tfm_bl1_key_id_t input_key, const uint8_t *label,
+int32_t bl1_derive_key(enum tfm_bl1_key_id_t key_id, const uint8_t *label,
                        size_t label_length, const uint8_t *context,
                        size_t context_length, uint8_t *output_key,
                        size_t output_length)
 {
-    uint8_t state[KEY_DERIVATION_MAX_BUF_SIZE];
-    uint8_t state_size = label_length + context_length + sizeof(uint8_t)
-                         + 2 * sizeof(uint32_t);
-    uint8_t state_hash[32];
-    uint32_t L = output_length;
-    uint32_t n = (output_length + sizeof(state_hash) - 1) / sizeof(state_hash);
-    uint32_t i = 1;
-    size_t output_idx = 0;
-    cc3xx_err_t rc;
+    enum kmu_hardware_keyslot_t kmu_key_slot;
+    uint32_t key_buf[32 / sizeof(uint32_t)];
+    uint8_t *input_key = (uint8_t *)key_buf;
+    int32_t rc = 0;
+    cc3xx_err_t err;
 
-    if (output_length == 0) {
-        return 0;
+    rc = bl1_key_to_kmu_key(key_id, &kmu_key_slot, &input_key, sizeof(key_buf));
+    if (rc) {
+        return rc;
     }
 
-    if (label == NULL || label_length == 0 ||
-        context == NULL || context_length == 0 ||
-        output_key == NULL) {
-        return -1;
+    err = cc3xx_kdf_cmac(kmu_key_slot, (uint32_t *)input_key,
+                         CC3XX_AES_KEYSIZE_256, label, label_length, context,
+                         context_length, (uint32_t *)output_key, output_length);
+    if (err != CC3XX_ERR_SUCCESS) {
+        return 1;
     }
 
-    if (state_size > KEY_DERIVATION_MAX_BUF_SIZE) {
-        return -1;
-    }
-
-    memcpy(state + sizeof(uint32_t), label, label_length);
-    memset(state + sizeof(uint32_t) + label_length, 0, sizeof(uint8_t));
-    memcpy(state + sizeof(uint32_t) + label_length + sizeof(uint8_t),
-           context, context_length);
-    memcpy(state + sizeof(uint32_t) + label_length + sizeof(uint8_t) + context_length,
-           &L, sizeof(uint32_t));
-
-    for (i = 1; i < n; i++) {
-        memcpy(state, &i, sizeof(uint32_t));
-
-        /* Hash the state to make it a constant size */
-        rc = bl1_sha256_compute(state, state_size, state_hash);
-        if (rc != CC3XX_ERR_SUCCESS) {
-            goto err;
-        }
-
-        /* Encrypt using ECB, which is fine because the state is different every
-         * time and we're hashing it.
-         */
-        rc = aes_256_ecb_encrypt(input_key, state_hash, sizeof(state_hash),
-                                 output_key + output_idx);
-        if (rc != CC3XX_ERR_SUCCESS) {
-            goto err;
-        }
-
-        output_idx += sizeof(state_hash);
-    }
-
-    /* For the last block, encrypt into the state buf and then memcpy out how
-     * much we need
-     */
-    memcpy(state, &i, sizeof(uint32_t));
-
-    rc = bl1_sha256_compute(state, state_size, state_hash);
-    if (rc != CC3XX_ERR_SUCCESS) {
-        goto err;
-    }
-
-    /* This relies on us being able to have overlapping input and output
-     * pointers.
-     */
-    rc = aes_256_ecb_encrypt(input_key, state_hash, sizeof(state_hash),
-                             state_hash);
-    if (rc != CC3XX_ERR_SUCCESS) {
-        goto err;
-    }
-
-    memcpy(output_key + output_idx, state_hash, output_length - output_idx);
-    memset(state, 0, sizeof(state));
-    memset(state_hash, 0, sizeof(state_hash));
-
     return 0;
-
-err:
-    memset(output_key, 0, output_length);
-    memset(state, 0, sizeof(state));
-    memset(state_hash, 0, sizeof(state_hash));
-    return rc;
 }
diff --git a/platform/ext/target/arm/rss/common/cc312/cc3xx_aes_external_key_loader.c b/platform/ext/target/arm/rss/common/cc312/cc3xx_aes_external_key_loader.c
new file mode 100644
index 0000000..36bc894
--- /dev/null
+++ b/platform/ext/target/arm/rss/common/cc312/cc3xx_aes_external_key_loader.c
@@ -0,0 +1,103 @@
+/*
+ * Copyright (c) 2023, Arm Limited. All rights reserved.
+ *
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ */
+
+#include "cc3xx_aes_external_key_loader.h"
+
+#include "cc3xx_stdlib.h"
+#include "cc3xx_dev.h"
+
+#include "device_definition.h"
+
+/*
+ * Load an AES key for the CryptoCell AES engine.
+ *
+ * If key is non-NULL, its words are written to the aes_key_0 registers (or
+ * aes_key_1 when is_tun1 is set and both CCM and tunnelling are configured)
+ * and key_id is not consulted. If key is NULL, key_id is passed to the KMU
+ * driver as a slot number: the slot must report as locked, and is then
+ * exported by the KMU.
+ *
+ * Returns CC3XX_ERR_SUCCESS on success, CC3XX_ERR_INVALID_STATE when bit 0 of
+ * host_ao_lock_bits is set, or CC3XX_ERR_KEY_IMPORT_FAILED when the KMU slot
+ * cannot be used.
+ */
+cc3xx_err_t set_key(cc3xx_aes_key_id_t key_id, const uint32_t *key,
+                    cc3xx_aes_keysize_t key_size, bool is_tun1)
+{
+    enum kmu_error_t kmu_err;
+    volatile uint32_t *hw_key_buf_ptr;
+    /* Words in the key; assumes keysize enum values 0/1/2 for 128/192/256 */
+    size_t key_word_size = 4 + (key_size * 2);
+
+    hw_key_buf_ptr = P_CC3XX->aes.aes_key_0;
+#if defined(CC3XX_CONFIG_AES_CCM_ENABLE) && defined(CC3XX_CONFIG_AES_TUNNELLING_ENABLE)
+    if (is_tun1) {
+        hw_key_buf_ptr = P_CC3XX->aes.aes_key_1;
+    }
+#endif /* defined(CC3XX_CONFIG_AES_CCM_ENABLE) && defined(CC3XX_CONFIG_AES_TUNNELLING_ENABLE) */
+
+    /* Check if the HOST_FATAL_ERROR mode is enabled */
+    if (P_CC3XX->ao.host_ao_lock_bits & 0x1U) {
+        return CC3XX_ERR_INVALID_STATE;
+    }
+
+    /* Set key0 size */
+    if (!is_tun1) {
+        P_CC3XX->aes.aes_control &= ~(0b11U << 12);
+        P_CC3XX->aes.aes_control |= (key_size & 0b11U) << 12;
+    }
+#if defined(CC3XX_CONFIG_AES_CCM_ENABLE) && defined(CC3XX_CONFIG_AES_TUNNELLING_ENABLE)
+    if (is_tun1) {
+        /* Set key1 size */
+        P_CC3XX->aes.aes_control &= ~(0b11U << 14);
+        P_CC3XX->aes.aes_control |= (key_size & 0b11U) << 14;
+    }
+#endif /* defined(CC3XX_CONFIG_AES_CCM_ENABLE) && defined(CC3XX_CONFIG_AES_TUNNELLING_ENABLE) */
+
+    if (key != NULL) {
+#ifdef CC3XX_CONFIG_DPA_MITIGATIONS_ENABLE
+        /* The final word is stored directly, after the hardened copy */
+        cc3xx_dpa_hardened_word_copy(hw_key_buf_ptr, key, key_word_size - 1);
+        hw_key_buf_ptr[key_word_size - 1] = key[key_word_size - 1];
+#else
+        hw_key_buf_ptr[0] = key[0];
+        hw_key_buf_ptr[1] = key[1];
+        hw_key_buf_ptr[2] = key[2];
+        hw_key_buf_ptr[3] = key[3];
+        if (key_size > CC3XX_AES_KEYSIZE_128) {
+            hw_key_buf_ptr[4] = key[4];
+            hw_key_buf_ptr[5] = key[5];
+        }
+        if (key_size > CC3XX_AES_KEYSIZE_192) {
+            hw_key_buf_ptr[6] = key[6];
+            hw_key_buf_ptr[7] = key[7];
+        }
+#endif /* CC3XX_CONFIG_DPA_MITIGATIONS_ENABLE */
+    } else {
+        /* Hardware keys are locked to aes_key_0 */
+        if (is_tun1 && key_id < KMU_USER_SLOT_MIN) {
+            return CC3XX_ERR_KEY_IMPORT_FAILED;
+        }
+
+        /* It's an error to use an unlocked slot */
+        kmu_err = kmu_get_key_export_config_locked(&KMU_DEV_S, key_id);
+        if (kmu_err != KMU_ERROR_SLOT_LOCKED) {
+            return CC3XX_ERR_KEY_IMPORT_FAILED;
+        }
+        kmu_err = kmu_get_key_locked(&KMU_DEV_S, key_id);
+        if (kmu_err != KMU_ERROR_SLOT_LOCKED) {
+            return CC3XX_ERR_KEY_IMPORT_FAILED;
+        }
+
+        kmu_err = kmu_export_key(&KMU_DEV_S, key_id);
+        if (kmu_err != KMU_ERROR_NONE) {
+            return CC3XX_ERR_KEY_IMPORT_FAILED;
+        }
+    }
+
+    return CC3XX_ERR_SUCCESS;
+}
diff --git a/platform/ext/target/arm/rss/common/cc312/cc3xx_aes_external_key_loader.h b/platform/ext/target/arm/rss/common/cc312/cc3xx_aes_external_key_loader.h
new file mode 100644
index 0000000..683af3f
--- /dev/null
+++ b/platform/ext/target/arm/rss/common/cc312/cc3xx_aes_external_key_loader.h
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2023, Arm Limited. All rights reserved.
+ *
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ */
+
+#ifndef CC3XX_AES_EXTERNAL_KEY_LOADER_H
+#define CC3XX_AES_EXTERNAL_KEY_LOADER_H
+
+#include "cc3xx_error.h"
+#include "cc3xx_aes.h"
+
+#include <stdbool.h>
+#include <stdint.h>
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * \brief                Load an AES key into the CryptoCell key registers.
+ *
+ * \param[in] key_id     KMU slot to export the key from. Only consulted when
+ *                       \p key is NULL.
+ * \param[in] key        Key words to load directly, or NULL to use \p key_id.
+ * \param[in] key_size   Size of the key.
+ * \param[in] is_tun1    Whether the key is for the second (tunnelling) key
+ *                       register rather than the first.
+ *
+ * \return               CC3XX_ERR_SUCCESS on success, another cc3xx_err_t
+ *                       value on error.
+ */
+cc3xx_err_t set_key(cc3xx_aes_key_id_t key_id, const uint32_t *key,
+                    cc3xx_aes_keysize_t key_size, bool is_tun1);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* CC3XX_AES_EXTERNAL_KEY_LOADER_H */
diff --git a/platform/ext/target/arm/rss/common/cc312/cc3xx_config.h b/platform/ext/target/arm/rss/common/cc312/cc3xx_config.h
index c3417f2..659ab78 100644
--- a/platform/ext/target/arm/rss/common/cc312/cc3xx_config.h
+++ b/platform/ext/target/arm/rss/common/cc312/cc3xx_config.h
@@ -39,7 +39,7 @@
 /* #define CC3XX_CONFIG_AES_GCM_VARIABLE_IV_ENABLE */
 
 /* Whether the AES CMAC support is enabled */
-/* #define CC3XX_CONFIG_AES_CMAC_ENABLE */
+#define CC3XX_CONFIG_AES_CMAC_ENABLE
 
 /* Whether the AES CCM support is enabled */
 #define CC3XX_CONFIG_AES_CCM_ENABLE
@@ -52,6 +52,11 @@
  * CTR decryption having to be done seperately. */
 #define CC3XX_CONFIG_AES_TUNNELLING_ENABLE
 
+/* Whether an external key-loader should be invoked instead of the standard AES
+ * hardware key loading mechanism
+ */
+#define CC3XX_CONFIG_AES_EXTERNAL_KEY_LOADER
+
 /* Whether CHACHA is enabled */
 /* #define CC3XX_CONFIG_CHACHA_ENABLE */
 
diff --git a/platform/ext/target/arm/rss/common/cpak_generator/cpak_generator.c b/platform/ext/target/arm/rss/common/cpak_generator/cpak_generator.c
index 70d8cf0..294f7c2 100644
--- a/platform/ext/target/arm/rss/common/cpak_generator/cpak_generator.c
+++ b/platform/ext/target/arm/rss/common/cpak_generator/cpak_generator.c
@@ -1,12 +1,12 @@
 /*
- * Copyright (c) 2022, Arm Limited. All rights reserved.
+ * Copyright (c) 2022-2023, Arm Limited. All rights reserved.
  *
  * SPDX-License-Identifier: BSD-3-Clause
  *
  */
 
 #include "psa/crypto.h"
-#include "mbedtls/aes.h"
+#include "mbedtls/cmac.h"
 
 #include <stdint.h>
 #include <stdio.h>
@@ -52,70 +52,54 @@
     return 0;
 }
 
-
-int generate_seed(uint8_t *bl1_2_hash, uint8_t *guk, uint8_t *seed_buf)
+int generate_boot_state(uint8_t *bl1_2_hash, uint8_t *boot_state)
 {
-    uint8_t label[] = "BL1_CPAK_SEED_DERIVATION";
-    uint8_t context[32 + sizeof(uint32_t) * 2] = {0};
-    uint8_t state[sizeof(context) + sizeof(label) + sizeof(uint8_t)
-                  + sizeof(uint32_t) * 2];
-    uint8_t state_hash[PSA_HASH_LENGTH(PSA_ALG_SHA_256)];
-    mbedtls_aes_context ctx;
+    uint8_t context[PSA_HASH_LENGTH(PSA_ALG_SHA_256) + 2 * sizeof(uint32_t)];
     uint32_t reprovisioning_bits = 0;
     uint32_t lcs = 3;
+
+    memcpy(context, &lcs, sizeof(uint32_t));
+
+    memcpy(context + sizeof(uint32_t), &reprovisioning_bits, sizeof(uint32_t));
+
+    memcpy(context + (2 * sizeof(uint32_t)), bl1_2_hash, 32);
+
+    return mbedtls_sha256(context, sizeof(context), boot_state, 0);
+}
+
+int generate_seed(uint8_t *boot_state, uint8_t *guk, uint8_t *seed_buf)
+{
+    uint8_t label[] = "BL1_CPAK_SEED_DERIVATION";
+    uint8_t state[PSA_HASH_LENGTH(PSA_ALG_SHA_256) + sizeof(label) + sizeof(uint8_t)
+                  + (sizeof(uint32_t) * 2)];
     uint32_t seed_output_length = 32;
     uint32_t block_index = 1;
-    size_t state_hash_len;
     psa_status_t status;
     int rc;
 
-    memcpy(context, bl1_2_hash, 32);
-
-    memcpy(context + 32, &lcs, sizeof(uint32_t));
-
-    memcpy(context + 32 + sizeof(uint32_t), &reprovisioning_bits,
-           sizeof(uint32_t));
-
+    memcpy(state, &block_index, sizeof(uint32_t));
     memcpy(state + sizeof(uint32_t), label, sizeof(label));
     memset(state + sizeof(uint32_t) + sizeof(label), 0, sizeof(uint8_t));
     memcpy(state + sizeof(uint32_t) + sizeof(label) + sizeof(uint8_t),
-           context, sizeof(context));
-    memcpy(state + sizeof(uint32_t) + sizeof(label) + sizeof(uint8_t) +
-           sizeof(context), &seed_output_length, sizeof(uint32_t));
+           boot_state, 32);
+    memcpy(state + sizeof(uint32_t) + sizeof(label) + sizeof(uint8_t) + 32,
+           &seed_output_length, sizeof(uint32_t));
 
-    status = psa_hash_compute(PSA_ALG_SHA_256, state, sizeof(state),
-                              state_hash, sizeof(state_hash), &state_hash_len);
-    if (status != PSA_SUCCESS) {
-        return 1;
+    rc = mbedtls_cipher_cmac(mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_256_ECB),
+                             guk, 256, state, sizeof(state), seed_buf);
+    if (rc) {
+        return rc;
     }
 
+    block_index += 1;
     memcpy(state, &block_index, sizeof(uint32_t));
 
-    status = psa_hash_compute(PSA_ALG_SHA_256, state, sizeof(state),
-                              state_hash, sizeof(state_hash), &state_hash_len);
-    if (status != PSA_SUCCESS) {
-        return 1;
-    }
-
-    mbedtls_aes_init(&ctx);
-    rc = mbedtls_aes_setkey_enc(&ctx, guk, 256);
+    rc = mbedtls_cipher_cmac(mbedtls_cipher_info_from_type(MBEDTLS_CIPHER_AES_256_ECB),
+                             guk, 256, state, sizeof(state), seed_buf + 16);
     if (rc) {
         return rc;
     }
 
-    rc = mbedtls_aes_crypt_ecb(&ctx, MBEDTLS_AES_ENCRYPT, state_hash, seed_buf);
-    if (rc) {
-        return rc;
-    }
-
-    rc = mbedtls_aes_crypt_ecb(&ctx, MBEDTLS_AES_ENCRYPT, state_hash + 16,
-                               seed_buf + 16);
-    if (rc) {
-        return rc;
-    }
-
-    mbedtls_aes_free(&ctx);
-
     return 0;
 }
 
@@ -218,7 +202,8 @@
 int main (int argc, char *argv[])
 {
     int rc;
-    uint8_t bl1_2_hash[32];
+    uint8_t bl1_2_hash[PSA_HASH_LENGTH(PSA_ALG_SHA_256)];
+    uint8_t boot_state[PSA_HASH_LENGTH(PSA_ALG_SHA_256)];
     uint8_t guk[32];
     uint8_t seed_buf[32];
     psa_key_handle_t cpak_handle;
@@ -241,7 +226,13 @@
         return rc;
     }
 
-    rc = generate_seed(bl1_2_hash, guk, seed_buf);
+    rc = generate_boot_state(bl1_2_hash, boot_state);
+    if (rc) {
+        printf("boot state generation failed\r\n");
+        return rc;
+    }
+
+    rc = generate_seed(boot_state, guk, seed_buf);
     if (rc) {
         printf("cpak seed generation failed\r\n");
         return rc;
diff --git a/platform/ext/target/arm/rss/common/provisioning/bundle_cm/cm_dummy_provisioning_data.c b/platform/ext/target/arm/rss/common/provisioning/bundle_cm/cm_dummy_provisioning_data.c
index 9406d9d..aafc9e9 100644
--- a/platform/ext/target/arm/rss/common/provisioning/bundle_cm/cm_dummy_provisioning_data.c
+++ b/platform/ext/target/arm/rss/common/provisioning/bundle_cm/cm_dummy_provisioning_data.c
@@ -16,10 +16,10 @@
     0,
     /* GUK */
     {
-        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
-        0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
-        0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07,
-        0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
+        0x01, 0x23, 0x45, 0x67, 0x89, 0x01, 0x23, 0x45,
+        0x67, 0x89, 0x01, 0x23, 0x45, 0x67, 0x89, 0x01,
+        0x23, 0x45, 0x67, 0x89, 0x01, 0x23, 0x45, 0x67,
+        0x89, 0x01, 0x23, 0x45, 0x67, 0x89, 0x01, 0x23,
     },
     /* CCA system properties placeholder */
     0xDEADBEEF,
diff --git a/platform/ext/target/arm/rss/common/rss_key_derivation.c b/platform/ext/target/arm/rss/common/rss_key_derivation.c
index c5e4fa2..fd60e14 100644
--- a/platform/ext/target/arm/rss/common/rss_key_derivation.c
+++ b/platform/ext/target/arm/rss/common/rss_key_derivation.c
@@ -11,32 +11,16 @@
 #include "crypto.h"
 #include "otp.h"
 #include "tfm_plat_otp.h"
+#include "dpa_hardened_word_copy.h"
+#include "cc3xx_drv.h"
 
 #include <stdint.h>
 #include <string.h>
 
 extern uint8_t computed_bl1_2_hash[];
 
-static int rss_derive_key(enum tfm_bl1_key_id_t key_id, const uint8_t *label,
-                          size_t label_len, uint8_t *out)
-{
-    int rc;
-    uint8_t context[32] = {0};
-    size_t context_len;
-
-    rc = rss_get_boot_state(context, sizeof(context), &context_len);
-    if (rc) {
-        return rc;
-    }
-
-    rc = bl1_derive_key(key_id, label, label_len, context,
-                        context_len, out, 32);
-
-    return rc;
-}
-
-int rss_get_boot_state(uint8_t *state, size_t state_buf_len,
-                       size_t *state_size)
+static int rss_get_boot_state(uint8_t *state, size_t state_buf_len,
+                              size_t *state_size)
 {
     int rc;
     enum plat_otp_lcs_t lcs;
@@ -82,36 +66,98 @@
     return 0;
 }
 
-static int set_and_lock_kmu_slot(uint8_t *key, uint32_t kmu_output_slot)
+static int rss_derive_key(enum tfm_bl1_key_id_t key_id, const uint8_t *label,
+                          size_t label_len, enum rss_kmu_slot_id_t slot,
+                          bool duplicate_into_next_slot)
 {
+    int rc;
+    uint8_t context[32] = {0};
+    size_t context_len;
     enum kmu_error_t kmu_err;
+    volatile uint32_t *p_kmu_slot_buf;
+    volatile uint32_t *p_kmu_secondary_slot_buf;
+    size_t kmu_slot_size;
 
-    kmu_err = kmu_set_key(&KMU_DEV_S, kmu_output_slot, key, 32);
+    rc = rss_get_boot_state(context, sizeof(context), &context_len);
+    if (rc) {
+        return rc;
+    }
+
+    kmu_err = kmu_get_key_buffer_ptr(&KMU_DEV_S, slot,
+                                     &p_kmu_slot_buf, &kmu_slot_size);
     if (kmu_err != KMU_ERROR_NONE) {
         return -1;
     }
 
-    /* TODO lock the key slots once they can be used by the runtime CC driver */
-    /* kmu_err = kmu_set_key_locked(&KMU_DEV_S, kmu_output_slot); */
-    /* if (kmu_err != KMU_ERROR_NONE) { */
-    /*     return -2; */
-    /* } */
-
-    return 0;
-}
-
-int rss_derive_vhuk_seed(uint8_t *vhuk_seed, size_t vhuk_seed_buf_len,
-                         size_t *vhuk_seed_size)
-{
-    uint8_t vhuk_label[] = "BL1_VHUK_DERIVATION";
-    int rc = 0;
-
-    if (vhuk_seed_buf_len < 32) {
-        return -1;
+    rc = bl1_derive_key(key_id, label, label_len, context,
+                        context_len, (uint8_t *)p_kmu_slot_buf, 32);
+    if (rc) {
+        return rc;
     }
 
-    rc = rss_derive_key(TFM_BL1_KEY_HUK, vhuk_label,
-                        sizeof(vhuk_label), vhuk_seed);
+    /* Due to limitations in CryptoCell, any key that needs to be used for
+     * AES-CCM needs to be duplicated into a second slot.
+     */
+    if (duplicate_into_next_slot) {
+        kmu_err = kmu_get_key_buffer_ptr(&KMU_DEV_S, slot + 1,
+                                         &p_kmu_secondary_slot_buf,
+                                         &kmu_slot_size);
+        if (kmu_err != KMU_ERROR_NONE) {
+            return -3;
+        }
+
+        dpa_hardened_word_copy(p_kmu_secondary_slot_buf, p_kmu_slot_buf,
+                        kmu_slot_size / sizeof(uint32_t));
+
+        /* TODO lock keyslots once the runtime CC3XX driver supports locked KMU
+         * keyslots
+         */
+        /* kmu_err = kmu_set_key_locked(&KMU_DEV_S, slot + 1); */
+        /* if (kmu_err != KMU_ERROR_NONE) { */
+        /*     return -5; */
+        /* } */
+    }
+
+    /* TODO lock keyslots once the runtime CC3XX driver supports locked KMU
+     * keyslots
+     */
+    /* kmu_err = kmu_set_key_locked(&KMU_DEV_S, slot); */
+    /* if (kmu_err != KMU_ERROR_NONE) { */
+    /*     return -4; */
+    /* } */
+
+    return rc;
+}
+
+int rss_derive_cpak_seed(enum rss_kmu_slot_id_t slot)
+{
+    uint8_t cpak_seed_label[] = "BL1_CPAK_SEED_DERIVATION";
+
+    return rss_derive_key(TFM_BL1_KEY_GUK, cpak_seed_label,
+                          sizeof(cpak_seed_label), slot, false);
+}
+
+int rss_derive_dak_seed(enum rss_kmu_slot_id_t slot)
+{
+    uint8_t dak_seed_label[]  = "BL1_DAK_SEED_DERIVATION";
+
+    return rss_derive_key(TFM_BL1_KEY_GUK, dak_seed_label,
+                          sizeof(dak_seed_label), slot, false);
+}
+
+int rss_derive_vhuk_seed(uint32_t *vhuk_seed, size_t vhuk_seed_buf_len,
+                         size_t *vhuk_seed_size)
+{
+    uint8_t vhuk_seed_label[]  = "VHUK_SEED_DERIVATION";
+    int rc;
+
+    if (vhuk_seed_buf_len != 32) {
+        return 1;
+    }
+
+    rc = cc3xx_kdf_cmac(KMU_HW_SLOT_HUK, NULL, CC3XX_AES_KEYSIZE_256,
+                        vhuk_seed_label, sizeof(vhuk_seed_label), NULL, 0,
+                        vhuk_seed, 32);
     if (rc) {
         return rc;
     }
@@ -121,108 +167,15 @@
     return 0;
 }
 
-int rss_derive_cpak_seed(uint32_t kmu_output_slot)
-{
-    uint8_t cpak_seed_label[] = "BL1_CPAK_SEED_DERIVATION";
-    uint8_t __attribute__((__aligned__(4))) key_buf[32];
-    int rc = 0;
-
-    rc = rss_derive_key(TFM_BL1_KEY_GUK, cpak_seed_label,
-                        sizeof(cpak_seed_label), key_buf);
-    if (rc) {
-        goto out;
-    }
-
-    rc = set_and_lock_kmu_slot(key_buf, kmu_output_slot);
-    if (rc) {
-        goto out;
-    }
-
-out:
-    memset(key_buf, 0, sizeof(key_buf));
-
-    return rc;
-}
-
-int rss_derive_dak_seed(uint32_t kmu_output_slot)
-{
-    uint8_t dak_seed_label[]  = "BL1_DAK_SEED_DERIVATION";
-    uint8_t __attribute__((__aligned__(4))) key_buf[32];
-    int rc = 0;
-
-    rc = rss_derive_key(TFM_BL1_KEY_GUK, dak_seed_label,
-                        sizeof(dak_seed_label), key_buf);
-    if (rc) {
-        goto out;
-    }
-
-    rc = set_and_lock_kmu_slot(key_buf, kmu_output_slot);
-    if (rc) {
-        goto out;
-    }
-
-out:
-    memset(key_buf, 0, sizeof(key_buf));
-
-    return rc;
-}
-
 int rss_derive_vhuk(const uint8_t *vhuk_seeds, size_t vhuk_seeds_len,
-                    uint32_t kmu_output_slot)
+                    enum rss_kmu_slot_id_t slot)
 {
-    uint8_t __attribute__((__aligned__(4))) key_buf[32];
-    int rc;
-
-    if (vhuk_seeds_len != RSS_AMOUNT * 32) {
-        return -1;
-    }
-
-    rc = bl1_sha256_compute(vhuk_seeds, vhuk_seeds_len, key_buf);
-    if (rc) {
-        return rc;
-    }
-
-    rc = set_and_lock_kmu_slot(key_buf, kmu_output_slot);
-    if (rc) {
-        goto out;
-    }
-
-out:
-    memset(key_buf, 0, sizeof(key_buf));
-
-    return rc;
+    return rss_derive_key(TFM_BL1_KEY_GUK, vhuk_seeds, vhuk_seeds_len,
+                          slot, false);
 }
 
 int rss_derive_session_key(const uint8_t *ivs, size_t ivs_len,
-                           uint32_t kmu_output_slot)
+                           enum rss_kmu_slot_id_t slot)
 {
-    uint8_t common_iv[32];
-    uint8_t __attribute__((__aligned__(4))) key_buf[32];
-    int rc;
-
-    if (ivs_len != RSS_AMOUNT * 32) {
-        return -1;
-    }
-
-    rc = bl1_sha256_compute(ivs, ivs_len, common_iv);
-    if (rc) {
-        return rc;
-    }
-
-    rc = rss_derive_key(TFM_BL1_KEY_GUK, common_iv,
-                        sizeof(common_iv), key_buf);
-    if (rc) {
-        goto out;
-    }
-
-    rc = set_and_lock_kmu_slot(key_buf, kmu_output_slot);
-    if (rc) {
-        goto out;
-    }
-
-out:
-    memset(key_buf, 0, sizeof(key_buf));
-
-    return rc;
+    return rss_derive_key(TFM_BL1_KEY_GUK, ivs, ivs_len, slot, true);
 }
-
diff --git a/platform/ext/target/arm/rss/common/rss_key_derivation.h b/platform/ext/target/arm/rss/common/rss_key_derivation.h
index 0f711fb..128b691 100644
--- a/platform/ext/target/arm/rss/common/rss_key_derivation.h
+++ b/platform/ext/target/arm/rss/common/rss_key_derivation.h
@@ -10,24 +10,13 @@
 
 #include <stdint.h>
 #include <stddef.h>
+#include "rss_kmu_slot_ids.h"
 
 #ifdef __cplusplus
 extern "C" {
 #endif
 
 /**
- * \brief                     Get the boot state.
-
- * \param[out] state          The buffer to get the state into.
- * \param[in]  state_buf_len  The size of the state buffer.
- * \param[out] state_size     The size of the state.
- *
- * \return                    0 on success, non-zero on error.
- */
-int rss_get_boot_state(uint8_t *state, size_t state_buf_len,
-                       size_t *state_size);
-
-/**
  * \brief                     Derive a VHUK seed.
 
  * \param[out] vhuk_seed         The buffer to derive the seed into.
@@ -36,7 +25,7 @@
  *
  * \return                    0 on success, non-zero on error.
  */
-int rss_derive_vhuk_seed(uint8_t *vhuk_seed, size_t vhuk_seed_buf_len,
+int rss_derive_vhuk_seed(uint32_t *vhuk_seed, size_t vhuk_seed_buf_len,
                          size_t *vhuk_seed_size);
 
 /**
@@ -46,7 +35,7 @@
  *
  * \return                    0 on success, non-zero on error.
  */
-int rss_derive_cpak_seed(uint32_t kmu_output_slot);
+int rss_derive_cpak_seed(enum rss_kmu_slot_id_t slot);
 
 /**
  * \brief                     Derive the DAK seed, and lock in a KMU slot.
@@ -55,7 +44,7 @@
  *
  * \return                    0 on success, non-zero on error.
  */
-int rss_derive_dak_seed(uint32_t kmu_output_slot);
+int rss_derive_dak_seed(enum rss_kmu_slot_id_t slot);
 
 /**
  * \brief                     Derive the VHUK, and lock in a KMU slot.
@@ -68,7 +57,7 @@
  * \return                    0 on success, non-zero on error.
  */
 int rss_derive_vhuk(const uint8_t *vhuk_seeds, size_t vhuk_seeds_len,
-                    uint32_t kmu_output_slot);
+                    enum rss_kmu_slot_id_t slot);
 
 /**
  * \brief                     Derive the session key, and lock in a KMU slot.
@@ -81,7 +70,7 @@
  * \return                    0 on success, non-zero on error.
  */
 int rss_derive_session_key(const uint8_t *ivs, size_t ivs_len,
-                           uint32_t kmu_output_slot);
+                           enum rss_kmu_slot_id_t slot);
 
 #ifdef __cplusplus
 }
diff --git a/platform/ext/target/arm/rss/common/rss_kmu_slot_ids.h b/platform/ext/target/arm/rss/common/rss_kmu_slot_ids.h
new file mode 100644
index 0000000..d82811b
--- /dev/null
+++ b/platform/ext/target/arm/rss/common/rss_kmu_slot_ids.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (c) 2023, Arm Limited. All rights reserved.
+ *
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ */
+
+#ifndef __RSS_KMU_SLOT_IDS_H__
+#define __RSS_KMU_SLOT_IDS_H__
+
+#include "kmu_drv.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+enum rss_kmu_slot_id_t {
+    RSS_KMU_SLOT_VHUK = KMU_USER_SLOT_MIN,
+    RSS_KMU_SLOT_CPAK_SEED,
+    RSS_KMU_SLOT_DAK_SEED,
+    /* The session key is used for AEAD, so requires two contiguous slots. Only
+     * the first should be used for calls, the key loader and derivation code
+     * will transparently use the second where necessary.
+     */
+    RSS_KMU_SLOT_SESSION_KEY_0,
+    _RSS_KMU_AEAD_RESERVED_SLOT_SESSION_KEY,
+};
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif /* __RSS_KMU_SLOT_IDS_H__ */