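refactor(mm): return pointer from `arch_mm_table_from_pte`

Make `arch_mm_table_from_pte` return a `struct mm_page_table *` instead
of a `paddr_t`. Every caller immediately converted the result with
`mm_page_table_from_pa`, so the conversion now happens inside the arch
implementations and that helper is removed. The test helper `get_table`
likewise takes a table pointer instead of a physical address.
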
Change-Id: I371f1b40a3d36d835e67198de203023bd0a32960
Signed-off-by: Karl Meakin <karl.meakin@arm.com>
diff --git a/src/arch/aarch64/mm.c b/src/arch/aarch64/mm.c
index 01f95da..e53027b 100644
--- a/src/arch/aarch64/mm.c
+++ b/src/arch/aarch64/mm.c
@@ -15,6 +15,7 @@
 #include "hf/arch/std.h"
 #include "hf/arch/types.h"
 
+#include "hf/addr.h"
 #include "hf/check.h"
 #include "hf/dlog.h"
 
@@ -211,9 +212,9 @@
 	return PTE_TYPE_VALID_BLOCK;
 }
 
-static uint64_t pte_addr(pte_t pte)
+static paddr_t pte_addr(pte_t pte)
 {
-	return pte & PTE_ADDR_MASK;
+	return pa_init(pte & PTE_ADDR_MASK);
 }
 
 /**
@@ -225,19 +226,18 @@
 	(void)level;
 
 	assert(arch_mm_pte_is_block(pte, level));
-	return pa_init(pte_addr(pte));
+	return pte_addr(pte);
 }
 
 /**
- * Extracts the physical address of the page table referred to by the given page
- * table entry.
+ * Returns a pointer to the page table that the given table PTE refers to.
  */
-paddr_t arch_mm_table_from_pte(pte_t pte, mm_level_t level)
+struct mm_page_table *arch_mm_table_from_pte(pte_t pte, mm_level_t level)
 {
 	(void)level;
 
 	assert(arch_mm_pte_is_table(pte, level));
-	return pa_init(pte_addr(pte));
+	return ptr_from_pa(pte_addr(pte));
 }
 
 /**
diff --git a/src/arch/fake/mm.c b/src/arch/fake/mm.c
index 2859fb9..9c23e41 100644
--- a/src/arch/fake/mm.c
+++ b/src/arch/fake/mm.c
@@ -88,10 +88,10 @@
 	return pte_addr(pte, level);
 }
 
-paddr_t arch_mm_table_from_pte(pte_t pte, mm_level_t level)
+struct mm_page_table *arch_mm_table_from_pte(pte_t pte, mm_level_t level)
 {
 	assert(arch_mm_pte_is_table(pte, level));
-	return pte_addr(pte, level);
+	return ptr_from_pa(pte_addr(pte, level));
 }
 
 mm_attr_t arch_mm_pte_attrs(pte_t pte, mm_level_t level)
diff --git a/src/mm.c b/src/mm.c
index d361b71..270acea 100644
--- a/src/mm.c
+++ b/src/mm.c
@@ -54,14 +54,6 @@
 }
 
 /**
- * Get the page table from the physical address.
- */
-static struct mm_page_table *mm_page_table_from_pa(paddr_t pa)
-{
-	return ptr_from_va(va_from_pa(pa));
-}
-
-/**
  * Rounds an address down to a page boundary.
  */
 static ptable_addr_t mm_round_down_to_page(ptable_addr_t addr)
@@ -191,7 +183,7 @@
 	}
 
 	/* Recursively free any subtables. */
-	table = mm_page_table_from_pa(arch_mm_table_from_pte(pte, level));
+	table = arch_mm_table_from_pte(pte, level);
 	for (size_t i = 0; i < MM_PTE_PER_PAGE; ++i) {
 		mm_free_page_pte(table->entries[i], level - 1, ppool);
 	}
@@ -312,7 +304,7 @@
 
 	/* Just return pointer to table if it's already populated. */
 	if (arch_mm_pte_is_table(v, level)) {
-		return mm_page_table_from_pa(arch_mm_table_from_pte(v, level));
+		return arch_mm_table_from_pte(v, level);
 	}
 
 	/* Allocate a new table. */
@@ -587,9 +579,10 @@
 	indent += 1;
 	{
 		mm_attr_t attrs = arch_mm_pte_attrs(entry, level);
-		paddr_t addr = arch_mm_table_from_pte(entry, level);
 		const struct mm_page_table *child_table =
-			mm_page_table_from_pa(addr);
+			arch_mm_table_from_pte(entry, level);
+		/* Undo the PA-to-pointer conversion of the arch layer. */
+		paddr_t addr = pa_from_va(va_from_ptr(child_table));
 
 		dlog_indent(indent, ".pte   = %#016lx,\n", entry);
 		dlog_indent(indent, ".attrs = %#016lx,\n", attrs);
@@ -694,7 +687,7 @@
 	mm_attr_t combined_attrs;
 	paddr_t block_address;
 
-	table = mm_page_table_from_pa(arch_mm_table_from_pte(table_pte, level));
+	table = arch_mm_table_from_pte(table_pte, level);
 
 	if (!arch_mm_pte_is_present(table->entries[0], level - 1)) {
 		return arch_mm_absent_pte(level);
@@ -735,8 +728,7 @@
 		return;
 	}
 
-	child_table =
-		mm_page_table_from_pa(arch_mm_table_from_pte(*entry, level));
+	child_table = arch_mm_table_from_pte(*entry, level);
 
 	/* Defrag the first entry in the table and use it as the base entry. */
 	static_assert(MM_PTE_PER_PAGE >= 1, "There must be at least one PTE.");
@@ -857,10 +849,8 @@
 	while (begin < end) {
 		if (arch_mm_pte_is_table(*pte, level)) {
 			if (!mm_ptable_get_attrs_level(
-				    mm_page_table_from_pa(
-					    arch_mm_table_from_pte(*pte,
-								   level)),
-				    begin, end, level - 1, got_attrs, attrs)) {
+				    arch_mm_table_from_pte(*pte, level), begin,
+				    end, level - 1, got_attrs, attrs)) {
 				return false;
 			}
 			got_attrs = true;
diff --git a/src/mm_test.cc b/src/mm_test.cc
index 5737595..99bf08a 100644
--- a/src/mm_test.cc
+++ b/src/mm_test.cc
@@ -62,10 +62,8 @@
 /**
  * Get an STL representation of the page table.
  */
-std::span<pte_t, MM_PTE_PER_PAGE> get_table(paddr_t pa)
+std::span<pte_t, MM_PTE_PER_PAGE> get_table(struct mm_page_table *table)
 {
-	auto table = reinterpret_cast<struct mm_page_table *>(
-		ptr_from_va(va_from_pa(pa)));
 	return std::span<pte_t, MM_PTE_PER_PAGE>(table->entries,
 						 std::end(table->entries));
 }
@@ -1066,8 +1064,7 @@
 	std::vector<std::span<pte_t, MM_PTE_PER_PAGE>> all;
 	const uint8_t root_table_count = arch_mm_stage2_root_table_count();
 	for (uint8_t i = 0; i < root_table_count; ++i) {
-		all.push_back(get_table(
-			pa_init((uintpaddr_t)&ptable.root_tables[i])));
+		all.push_back(get_table(&ptable.root_tables[i]));
 	}
 	return all;
 }