feat(s2tt): add support for FEAT_LPA2 to the s2tte library

During realm creation, RMM verifies the IPA size and start level
against the LPA2 settings set up by the Host.

This patch also checks that the memory banks specified in the
Boot Manifest are within the platform address space.

Signed-off-by: Javier Almansa Sobrino <javier.almansasobrino@arm.com>
Change-Id: I2c42b7cb9ce40cd17a7426bafa15f0a2fbee0dac
diff --git a/lib/s2tt/src/s2tt.c b/lib/s2tt/src/s2tt.c
index 730bed5..26c68d5 100644
--- a/lib/s2tt/src/s2tt.c
+++ b/lib/s2tt/src/s2tt.c
@@ -4,16 +4,51 @@
  */
 
 #include <arch_helpers.h>
+#include <assert.h>
 #include <bitmap.h>
 #include <granule.h>
 #include <ripas.h>
 #include <s2tt.h>
 #include <s2tt_pvt_defs.h>
 #include <smc.h>
+#include <stdbool.h>
 #include <stddef.h>
 
-/* TODO: Fix this when introducing LPA2 support */
-COMPILER_ASSERT(S2TT_MIN_STARTING_LEVEL >= 0);
+/*
+ * Return the mask selecting the output address bits of an S2TTE at @level.
+ */
+static unsigned long s2tte_lvl_mask(long level, bool lpa2)
+{
+	assert(level <= S2TT_PAGE_LEVEL);
+	assert(level >= S2TT_MIN_STARTING_LEVEL_LPA2);
+
+	unsigned long mask;
+	unsigned long levels = (unsigned long)(S2TT_PAGE_LEVEL - level);
+	unsigned long lsb = (levels * S2TTE_STRIDE) + GRANULE_SHIFT;
+	/* Bits [S2TTE_OA_BITS - 1 : lsb], kept whatever the value of @lpa2 */
+	mask = BIT_MASK_ULL((S2TTE_OA_BITS - 1U), lsb);
+	/* LPA2 adds the LPA2_S2TTE_51_50 and LPA2_OA_49_48 fields */
+	if (lpa2 == true) {
+		mask |= (MASK(LPA2_S2TTE_51_50) | MASK(LPA2_OA_49_48));
+	}
+
+	return mask;
+}
+
+/*
+ * Extract the PA held by @s2tte, masked to the alignment of @level.
+ */
+static unsigned long s2tte_to_pa(unsigned long s2tte, long level, bool lpa2)
+{
+	unsigned long pa = s2tte & s2tte_lvl_mask(level, lpa2);
+	/* LPA2: relocate field LPA2_S2TTE_51_50 to LPA2_OA_51_50 */
+	if (lpa2 == true) {
+		pa &= ~MASK(LPA2_S2TTE_51_50);
+		pa |= INPLACE(LPA2_OA_51_50, EXTRACT(LPA2_S2TTE_51_50, s2tte));
+	}
+
+	return pa;
+}
 
 /*
  * Invalidates S2 TLB entries from [ipa, ipa + size] region tagged with `vmid`.
@@ -77,6 +112,40 @@
 }
 
 /*
+ * Returns true if @s2tte has an 'output address' field, i.e. it is one of:
+ * - assigned_empty
+ * - assigned_ram
+ * - assigned_ns
+ * - assigned_destroyed
+ * - table
+ */
+static bool s2tte_has_pa(const struct s2tt_context *s2_ctx,
+			 unsigned long s2tte, long level)
+{
+	unsigned long desc_type = s2tte & S2TT_DESC_TYPE_MASK;
+	/* Non-invalid descriptors, plus two invalid 'assigned' states */
+	return ((desc_type != S2TTE_INVALID) ||	/* block, page or table */
+		s2tte_is_assigned_empty(s2_ctx, s2tte, level) ||
+		s2tte_is_assigned_destroyed(s2_ctx, s2tte, level));
+}
+
+/*
+ * Creates a TTE containing only the PA.
+ * This function expects 'pa' to be aligned and bounded; it is not checked.
+ */
+static unsigned long pa_to_s2tte(unsigned long pa, bool lpa2)
+{
+	unsigned long tte = pa;
+	/* LPA2: relocate field LPA2_OA_51_50 to LPA2_S2TTE_51_50 */
+	if (lpa2 == true) {
+		tte &= ~MASK(LPA2_OA_51_50);
+		tte |= INPLACE(LPA2_S2TTE_51_50, EXTRACT(LPA2_OA_51_50, pa));
+	}
+
+	return tte;
+}
+
+/*
  * Invalidate S2 TLB entries with "addr" IPA.
  * Call this function after:
  * 1.  A L3 page desc has been removed.
@@ -112,7 +181,7 @@
 
 /*
  * Return the index of the entry describing @addr in the translation table at
- * level @level.  This only works for non-concatenated page tables, so should
+ * level @level. This only works for non-concatenated page tables, so should
  * not be called to get the index for the starting level.
  *
  * See the library pseudocode
@@ -122,18 +191,20 @@
 static unsigned long s2_addr_to_idx(unsigned long addr, long level)
 {
 	unsigned int levels, lsb;
+	unsigned int s2tte_stride = (level < S2TT_MIN_STARTING_LEVEL) ?
+					S2TTE_STRIDE_LM1 : S2TTE_STRIDE;
 
 	levels = (unsigned int)(S2TT_PAGE_LEVEL - level);
 	lsb = (levels * S2TTE_STRIDE) + GRANULE_SHIFT;
 
 	addr >>= lsb;
-	addr &= (1UL << S2TTE_STRIDE) - 1UL;
+	addr &= (1UL << s2tte_stride) - 1UL; /* level-dependent index width */
 	return addr;
 }
 
 /*
  * Return the index of the entry describing @addr in the translation table
- * starting level.  This may return an index >= S2TTES_PER_S2TT when the
+ * starting level. This may return an index >= S2TTES_PER_S2TT when the
  * combination of @start_level and @ipa_bits implies concatenated
  * stage 2 tables.
  *
@@ -149,26 +220,12 @@
 	levels = (unsigned int)(S2TT_PAGE_LEVEL - start_level);
 	lsb = (levels * S2TTE_STRIDE) + GRANULE_SHIFT;
 
-	addr &= (1UL << ipa_bits) - 1UL;
+	addr &= ((1UL << ipa_bits) - 1UL);
 	addr >>= lsb;
 	return addr;
 }
 
-static unsigned long addr_level_mask(unsigned long addr, long level)
-{
-	unsigned int levels, lsb, msb;
-
-	assert(level <= S2TT_PAGE_LEVEL);
-	assert(level >= S2TT_MIN_STARTING_LEVEL);
-
-	levels = (unsigned int)(S2TT_PAGE_LEVEL - level);
-	lsb = (levels * S2TTE_STRIDE) + GRANULE_SHIFT;
-	msb = S2TTE_OA_BITS - 1U;
-
-	return (addr & BIT_MASK_ULL(msb, lsb));
-}
-
-static inline bool entry_is_table(unsigned long entry)
+static bool entry_is_table(unsigned long entry)
 {
 	return ((entry & S2TT_DESC_TYPE_MASK) == S2TTE_L012_TABLE);
 }
@@ -190,19 +247,22 @@
 	return entry;
 }
 
-#define table_entry_to_phys(tte)	addr_level_mask(tte, S2TT_PAGE_LEVEL)
+#define table_entry_to_phys(tte, lpa2)			\
+				s2tte_to_pa(tte, S2TT_PAGE_LEVEL, lpa2)
 
 static struct granule *find_next_level_idx(const struct s2tt_context *s2_ctx,
 					   struct granule *g_tbl,
 					   unsigned long idx)
 {
+	assert(s2_ctx != NULL);
+	/* Only a table descriptor at @idx leads to a next-level granule */
 	const unsigned long entry = table_get_entry(s2_ctx, g_tbl, idx);
 
 	if (!entry_is_table(entry)) {
 		return NULL;
 	}
 
-	return addr_to_granule(table_entry_to_phys(entry));
+	return addr_to_granule(table_entry_to_phys(entry, s2_ctx->enable_lpa2));
 }
 
 static struct granule *find_lock_next_level(const struct s2tt_context *s2_ctx,
@@ -222,7 +282,7 @@
 
 /*
  * Walk an RTT until level @level using @map_addr.
- * @g_root is the root (level 0) table and must be locked before the call.
+ * @g_root is the root (level 0/-1) table and must be locked before the call.
  * @start_level is the initial lookup level used for the stage 2 translation
  * tables which may depend on the configuration of the realm, factoring in the
  * IPA size of the realm and the desired starting level (within the limits
@@ -249,7 +309,7 @@
 			   long level,
 			   struct s2tt_walk *wi)
 {
-	struct granule *g_tbls[NR_RTT_LEVELS] = { (struct granule *)NULL };
+	struct granule *g_tbls[NR_RTT_LEVELS_LPA2] = { (struct granule *)NULL };
 	struct granule *g_root;
 	unsigned long sl_idx, ipa_bits;
 	int i, start_level, last_level;
@@ -275,34 +335,39 @@
 		assert(tt_num < S2TTE_MAX_CONCAT_TABLES);
 
 		g_concat_root = (struct granule *)((uintptr_t)g_root +
-						(tt_num * sizeof(struct granule)));
+					(tt_num * sizeof(struct granule)));
 
 		granule_lock(g_concat_root, GRANULE_STATE_RTT);
 		granule_unlock(g_root);
 		g_root = g_concat_root;
 	}
 
-	g_tbls[start_level] = g_root;
+	/* 'start_level' can be '-1', so add 1 when used as an index */
+	g_tbls[start_level + 1] = g_root;
 	for (i = start_level; i < level; i++) {
 		/*
 		 * Lock next RTT level. Correct locking order is guaranteed
 		 * because reference is obtained from a locked granule
 		 * (previous level). Also, hand-over-hand locking/unlocking is
 		 * used to avoid race conditions.
+		 *
+		 * Note that as 'start_level' can be -1, we add '1' to the
+		 * index 'i' to compensate for the negative value when we
+		 * use it to index the 'g_tbls' list.
 		 */
-		g_tbls[i + 1] = find_lock_next_level(s2_ctx, g_tbls[i],
+		g_tbls[i + 1 + 1] = find_lock_next_level(s2_ctx, g_tbls[i + 1],
 						     map_addr, i);
-		if (g_tbls[i + 1] == NULL) {
+		if (g_tbls[i + 1 + 1] == NULL) {
 			last_level = i;
 			goto out;
 		}
-		granule_unlock(g_tbls[i]);
+		granule_unlock(g_tbls[i + 1]);
 	}
 
 	last_level = (int)level;
 out:
 	wi->last_level = last_level;
-	wi->g_llt = g_tbls[last_level];
+	wi->g_llt = g_tbls[last_level + 1];
 	wi->index = s2_addr_to_idx(map_addr, last_level);
 }
 
@@ -355,22 +420,31 @@
 					   unsigned long pa, long level,
 					   unsigned long s2tte_ripas)
 {
-	(void)level;
-	(void)s2_ctx;
-
 	assert(level >= S2TT_MIN_BLOCK_LEVEL);
 	assert(level <= S2TT_PAGE_LEVEL);
 	assert(s2tte_ripas <= S2TTE_INVALID_RIPAS_DESTROYED);
 	assert(s2tte_is_addr_lvl_aligned(s2_ctx, pa, level));
+	assert(s2_ctx != NULL);
+
+	unsigned long tte = pa_to_s2tte(pa, s2_ctx->enable_lpa2);
+	unsigned long s2tte_page, s2tte_block;
+
+	if (s2_ctx->enable_lpa2 == true) {
+		s2tte_page = S2TTE_PAGE_LPA2;
+		s2tte_block = S2TTE_BLOCK_LPA2;
+	} else {
+		s2tte_page = S2TTE_PAGE;
+		s2tte_block = S2TTE_BLOCK;
+	}
 
 	if (s2tte_ripas == S2TTE_INVALID_RIPAS_RAM) {
 		if (level == S2TT_PAGE_LEVEL) {
-			return (pa | S2TTE_PAGE);
+			return (tte | s2tte_page);
 		}
-		return (pa | S2TTE_BLOCK);
+		return (tte | s2tte_block);
 	}
 
-	return (pa | S2TTE_INVALID_HIPAS_ASSIGNED | s2tte_ripas);
+	return (tte | S2TTE_INVALID_HIPAS_ASSIGNED | s2tte_ripas);
 }
 
 /*
@@ -426,11 +500,17 @@
  * - The physical address
  * - MemAttr
  * - S2AP
- * - Shareability
+ * - Shareability (when FEAT_LPA2 is disabled)
  */
 unsigned long s2tte_create_assigned_ns(const struct s2tt_context *s2_ctx,
 				       unsigned long s2tte, long level)
 {
+	/*
+	 * We just mask out the DESC_TYPE below. The Shareability bits
+	 * without FEAT_LPA2 are at the same position as OA bits [51:50]
+	 * with FEAT_LPA2 enabled, so we don't need to cater for that
+	 * separately.
+	 */
 	unsigned long new_s2tte = s2tte & ~S2TT_DESC_TYPE_MASK;
 
 	(void)s2_ctx;
@@ -452,13 +532,26 @@
 			    unsigned long s2tte, long level)
 {
 
-	unsigned long mask = addr_level_mask(~0UL, level) |
-						S2TTE_NS_ATTR_HOST_MASK;
+	bool lpa2;
+	unsigned long mask;
 
-	(void)s2_ctx;
-
+	assert(s2_ctx != NULL);
 	assert(level >= S2TT_MIN_BLOCK_LEVEL);
 
+	lpa2 = s2_ctx->enable_lpa2;
+
+	mask = s2tte_lvl_mask(level, lpa2);
+	if (lpa2 == true) {
+		mask |= S2TTE_NS_ATTR_LPA2_MASK;
+	} else {
+		mask |= S2TTE_NS_ATTR_MASK;
+
+		/* Only SH_IS is allowed */
+		if ((s2tte & S2TTE_SH_MASK) != S2TTE_SH_IS) {
+			return false;
+		}
+	}
+
 	/*
 	 * Test that all fields that are not controlled by the host are zero
 	 * and that the output address is correctly aligned. Note that
@@ -476,13 +569,6 @@
 	}
 
 	/*
-	 * Only one value masked by S2TTE_SH_MASK is invalid/reserved.
-	 */
-	if ((s2tte & S2TTE_SH_MASK) != S2TTE_SH_IS) {
-		return false;
-	}
-
-	/*
 	 * Note that all the values that are masked by S2TTE_AP_MASK are valid.
 	 */
 	return true;
@@ -494,9 +580,12 @@
 unsigned long host_ns_s2tte(const struct s2tt_context *s2_ctx,
 			    unsigned long s2tte, long level)
 {
-	unsigned long mask = addr_level_mask(~0UL, level) |
-						S2TTE_NS_ATTR_HOST_MASK;
-	(void)s2_ctx;
+	assert(s2_ctx != NULL);
+
+	unsigned long mask = s2tte_lvl_mask(level, s2_ctx->enable_lpa2);
+
+	mask |= (s2_ctx->enable_lpa2 == true) ? S2TTE_NS_ATTR_LPA2_MASK :
+					     S2TTE_NS_ATTR_MASK;
 
 	assert(level >= S2TT_MIN_BLOCK_LEVEL);
 
@@ -509,14 +598,19 @@
 unsigned long s2tte_create_table(const struct s2tt_context *s2_ctx,
 				 unsigned long pa, long level)
 {
+	__unused int min_starting_level;
+
 	(void)level;
-	(void)s2_ctx;
+	/* Lowest legal level depends on LPA2; it only feeds an assert() */
+	assert(s2_ctx != NULL);
+	min_starting_level = (s2_ctx->enable_lpa2 == true) ?
+			S2TT_MIN_STARTING_LEVEL_LPA2 : S2TT_MIN_STARTING_LEVEL;
 
 	assert(level < S2TT_PAGE_LEVEL);
-	assert(level >= S2TT_MIN_STARTING_LEVEL);
+	assert(level >= min_starting_level);
 	assert(GRANULE_ALIGNED(pa));
 
-	return (pa | S2TTE_TABLE);
+	return (pa_to_s2tte(pa, s2_ctx->enable_lpa2) | S2TTE_TABLE);
 }
 
 /*
@@ -820,9 +914,9 @@
 				  unsigned long *s2tt, unsigned long pa,
 				  long level)
 {
+	assert(s2tt != NULL);
 	assert(level >= S2TT_MIN_BLOCK_LEVEL);
 	assert(level <= S2TT_PAGE_LEVEL);
-	assert(s2tt != NULL);
 	assert(s2tte_is_addr_lvl_aligned(s2_ctx, pa, level));
 
 	const unsigned long map_size = s2tte_map_size(level);
@@ -901,11 +995,15 @@
 	assert(s2tte_is_addr_lvl_aligned(s2_ctx, pa, level));
 
 	const unsigned long map_size = s2tte_map_size(level);
+	unsigned long ns_attr_host_mask = (s2_ctx->enable_lpa2 == true) ?
+		S2TTE_NS_ATTR_LPA2_MASK : S2TTE_NS_ATTR_MASK;
 
 	for (unsigned int i = 0U; i < S2TTES_PER_S2TT; i++) {
-		unsigned long s2tte = attrs & S2TTE_NS_ATTR_HOST_MASK;
+		unsigned long s2tte = attrs & ns_attr_host_mask;
 
-		s2tt[i] = s2tte_create_assigned_ns(s2_ctx, s2tte | pa, level);
+		s2tt[i] = s2tte_create_assigned_ns(s2_ctx,
+				s2tte | pa_to_s2tte(pa, s2_ctx->enable_lpa2),
+				level);
 		pa += map_size;
 	}
 
@@ -913,24 +1011,6 @@
 }
 
 /*
- * Returns true if s2tte has 'output address' field, namely, if it is one of:
- * - assigned_empty
- * - assigned_ram
- * - assigned_ns
- * - assigned_destroyed
- * - table
- */
-static bool s2tte_has_pa(const struct s2tt_context *s2_ctx,
-			 unsigned long s2tte, long level)
-{
-	unsigned long desc_type = s2tte & S2TT_DESC_TYPE_MASK;
-
-	return ((desc_type != S2TTE_INVALID) ||	/* block, page or table */
-		s2tte_is_assigned_empty(s2_ctx, s2tte, level) ||
-		s2tte_is_assigned_destroyed(s2_ctx, s2tte, level));
-}
-
-/*
  * Returns true if s2tte is a live RTTE entry. i.e.,
  * HIPAS is ASSIGNED.
  *
@@ -947,26 +1027,45 @@
 unsigned long s2tte_pa(const struct s2tt_context *s2_ctx, unsigned long s2tte,
 		       long level)
 {
-	assert(level <= S2TT_PAGE_LEVEL);
-	assert(level >= S2TT_MIN_STARTING_LEVEL);
+	bool lpa2;
+	unsigned long pa;
+	__unused long min_starting_level;
 
-	if (!s2tte_has_pa(s2_ctx, s2tte, level)) {
-		assert(false);
-	}
+	assert(s2_ctx != NULL);
+	/* The lowest level allowed depends on whether LPA2 is enabled */
+	min_starting_level = (s2_ctx->enable_lpa2 == true) ?
+		S2TT_MIN_STARTING_LEVEL_LPA2 : S2TT_MIN_STARTING_LEVEL;
+	assert(level >= min_starting_level);
+	assert(level <= S2TT_PAGE_LEVEL);
+	assert(s2tte_has_pa(s2_ctx, s2tte, level));
+
+	lpa2 = s2_ctx->enable_lpa2;
 
 	if (s2tte_is_table(s2_ctx, s2tte, level)) {
-		return addr_level_mask(s2tte, S2TT_PAGE_LEVEL);
+		pa = table_entry_to_phys(s2tte, lpa2);
+	} else {
+		pa = s2tte_to_pa(s2tte, level, lpa2);
 	}
 
-	return addr_level_mask(s2tte, level);
+	return pa;
 }
 
 bool s2tte_is_addr_lvl_aligned(const struct s2tt_context *s2_ctx,
 			      unsigned long addr, long level)
 {
-	(void)s2_ctx;
+	assert(s2_ctx != NULL);
 
-	return (addr == addr_level_mask(addr, level));
+	__unused long min_starting_level = (s2_ctx->enable_lpa2 == true) ?
+		S2TT_MIN_STARTING_LEVEL_LPA2 : S2TT_MIN_STARTING_LEVEL;
+	unsigned long levels = (unsigned long)(S2TT_PAGE_LEVEL - level);
+	unsigned long lsb = (levels * S2TTE_STRIDE) + GRANULE_SHIFT;
+	unsigned long s2tte_oa_bits = (s2_ctx->enable_lpa2 == true) ?
+	       S2TTE_OA_BITS_LPA2 : S2TTE_OA_BITS;
+
+	assert(level <= S2TT_PAGE_LEVEL);
+	assert(level >= min_starting_level);
+	/* Aligned iff no bit is set outside [s2tte_oa_bits - 1 : lsb] */
+	return (addr == (addr & BIT_MASK_ULL((s2tte_oa_bits - 1U), lsb)));
 }
 
 typedef bool (*s2tte_type_checker)(const struct s2tt_context *s2_ctx,
@@ -1042,8 +1141,9 @@
 			     bool check_ns_attrs)
 {
 	assert(table != NULL);
+	assert(s2_ctx != NULL);
 
-	unsigned long base_pa;
+	unsigned long base_pa, ns_attr_host_mask;
 	unsigned long map_size = s2tte_map_size(level);
 	unsigned long s2tte = s2tte_read(&table[0]);
 	unsigned int i;
@@ -1057,6 +1157,9 @@
 		return false;
 	}
 
+	ns_attr_host_mask = (s2_ctx->enable_lpa2 == true) ?
+		S2TTE_NS_ATTR_LPA2_MASK : S2TTE_NS_ATTR_MASK;
+
 	for (i = 1U; i < S2TTES_PER_S2TT; i++) {
 		unsigned long expected_pa = base_pa + (i * map_size);
 
@@ -1072,13 +1175,13 @@
 
 		if (check_ns_attrs) {
 			unsigned long ns_attrs =
-					s2tte & S2TTE_NS_ATTR_HOST_MASK;
+					s2tte & ns_attr_host_mask;
 
 			/*
 			 * We match all the attributes in the S2TTE
 			 * except for the AF bit.
 			 */
-			if ((s2tte & S2TTE_NS_ATTR_HOST_MASK) != ns_attrs) {
+			if ((s2tte & ns_attr_host_mask) != ns_attrs) {
 				return false;
 			}
 		}
@@ -1162,12 +1265,16 @@
 	assert(table != NULL);
 	assert(wi != NULL);
 	assert(wi->index <= S2TTES_PER_S2TT);
-	assert(wi->last_level >= S2TT_MIN_STARTING_LEVEL);
 	assert(wi->last_level <= S2TT_PAGE_LEVEL);
+	assert(s2_ctx != NULL);
 
 	unsigned long i, index = wi->index;
 	long level = wi->last_level;
 	unsigned long map_size;
+	__unused long min_starting_level = (s2_ctx->enable_lpa2 == true) ?
+			S2TT_MIN_STARTING_LEVEL_LPA2 : S2TT_MIN_STARTING_LEVEL;
+
+	assert(wi->last_level >= min_starting_level);
 
 	/*
 	 * If the entry for the map_addr is live,