Introduce KernelAddressTranslator trait

Add KernelAddressTranslator as a generic parameter of Xlat in order to
decouple Xlat from KernelSpace. The trait translates between physical
addresses and the virtual addresses of the running kernel context;
Xlat uses it to access the translation tables.
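
As an illustration, a client whose kernel mapping is an identity map
could plug in a translator along the following lines (sketch only;
IdentityTranslator is a hypothetical name and the 1:1 mapping is an
assumption, the real conversion depends on the client's memory layout):

    struct IdentityTranslator;

    impl KernelAddressTranslator for IdentityTranslator {
        fn kernel_to_pa(va: VirtualAddress) -> PhysicalAddress {
            // Assumed identity mapping: kernel VA equals PA.
            PhysicalAddress(va.0)
        }

        fn pa_to_kernel(pa: PhysicalAddress) -> VirtualAddress {
            VirtualAddress(pa.0)
        }
    }

    // Xlat then takes the translator as its new type parameter,
    // e.g. Xlat::<IdentityTranslator, 36>, instead of relying on
    // KernelSpace for the address conversion.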

Signed-off-by: Imre Kis <imre.kis@arm.com>
Change-Id: Iaf4189429f21fced9d40e34fb309388165127124
diff --git a/src/kernel_space.rs b/src/kernel_space.rs
index cc04541..68296e9 100644
--- a/src/kernel_space.rs
+++ b/src/kernel_space.rs
@@ -8,15 +8,29 @@
 use alloc::sync::Arc;
 use spin::Mutex;
 
+use crate::KernelAddressTranslator;
+
 use super::{
     address::{PhysicalAddress, VirtualAddress, VirtualAddressRange},
     page_pool::{Page, PagePool},
     MemoryAccessRights, RegimeVaRange, TranslationGranule, TranslationRegime, Xlat, XlatError,
 };
 
+struct KernelAddressTranslatorIdentity;
+
+impl KernelAddressTranslator for KernelAddressTranslatorIdentity {
+    fn kernel_to_pa(va: VirtualAddress) -> PhysicalAddress {
+        PhysicalAddress(va.0 & 0x0000_000f_ffff_ffff)
+    }
+
+    fn pa_to_kernel(pa: PhysicalAddress) -> VirtualAddress {
+        VirtualAddress(pa.0 | 0xffff_fff0_0000_0000)
+    }
+}
+
 #[derive(Clone)]
 pub struct KernelSpace {
-    xlat: Arc<Mutex<Xlat<36>>>,
+    xlat: Arc<Mutex<Xlat<KernelAddressTranslatorIdentity, 36>>>,
 }
 
 /// # Kernel space memory mapping
@@ -154,16 +168,6 @@
     }
 
     /// Kernel virtual address to physical address
-    #[cfg(not(test))]
-    pub const fn kernel_to_pa(kernel_address: u64) -> u64 {
-        kernel_address & 0x0000_000f_ffff_ffff
-    }
-    /// Physical address to kernel virtual address
-    #[cfg(not(test))]
-    pub const fn pa_to_kernel(pa: u64) -> u64 {
-        // TODO: make this consts assert_eq!(pa & 0xffff_fff0_0000_0000, 0);
-        pa | 0xffff_fff0_0000_0000
-    }
 
     // Do not use any mapping in test build
     #[cfg(test)]
diff --git a/src/lib.rs b/src/lib.rs
index c8abbdd..63f1e4b 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -8,6 +8,7 @@
 
 use core::fmt;
 use core::iter::zip;
+use core::marker::PhantomData;
 use core::panic;
 
 use address::{PhysicalAddress, VirtualAddress, VirtualAddressRange};
@@ -20,7 +21,6 @@
 use self::descriptor::DescriptorType;
 
 use self::descriptor::{Attributes, DataAccessPermissions, Descriptor, Shareability};
-use self::kernel_space::KernelSpace;
 use self::page_pool::{PagePool, Pages};
 use self::region::{PhysicalRegion, VirtualRegion};
 use self::region_pool::{Region, RegionPool, RegionPoolError};
@@ -142,12 +142,20 @@
 
 pub type TranslationGranule<const VA_BITS: usize> = granule::TranslationGranule<VA_BITS>;
 
-pub struct Xlat<const VA_BITS: usize> {
+/// Trait for converting between the virtual address space of the running kernel environment and
+/// the physical address space.
+pub trait KernelAddressTranslator {
+    fn kernel_to_pa(va: VirtualAddress) -> PhysicalAddress;
+    fn pa_to_kernel(pa: PhysicalAddress) -> VirtualAddress;
+}
+
+pub struct Xlat<K: KernelAddressTranslator, const VA_BITS: usize> {
     base_table: Pages,
     page_pool: PagePool,
     regions: RegionPool<VirtualRegion>,
     regime: TranslationRegime,
     granule: TranslationGranule<VA_BITS>,
+    _kernel_address_translator: PhantomData<K>,
 }
 
 /// Memory translation table handling
@@ -175,7 +183,7 @@
 /// * map block
 /// * unmap block
 /// * set access rights of block
-impl<const VA_BITS: usize> Xlat<VA_BITS> {
+impl<K: KernelAddressTranslator, const VA_BITS: usize> Xlat<K, VA_BITS> {
     pub fn new(
         page_pool: PagePool,
         address: VirtualAddressRange,
@@ -210,6 +218,7 @@
             regions,
             regime,
             granule,
+            _kernel_address_translator: PhantomData,
         }
     }
 
@@ -232,7 +241,7 @@
             .allocate_pages(data.len(), Some(self.granule as usize))
             .map_err(|e| XlatError::PageAllocationError(e, data.len()))?;
 
-        pages.copy_data_to_page(data);
+        pages.copy_data_to_page::<K>(data);
 
         let pages_length = pages.length();
         let physical_region = PhysicalRegion::Allocated(self.page_pool.clone(), pages);
@@ -265,7 +274,7 @@
             .allocate_pages(length, Some(self.granule as usize))
             .map_err(|e| XlatError::PageAllocationError(e, length))?;
 
-        pages.zero_init();
+        pages.zero_init::<K>();
 
         let pages_length = pages.length();
         let physical_region = PhysicalRegion::Allocated(self.page_pool.clone(), pages);
@@ -418,7 +427,7 @@
         }
 
         // Set translation table
-        let base_table_pa = KernelSpace::kernel_to_pa(self.base_table.get_pa().0 as u64);
+        let base_table_pa = self.base_table.get_pa().0 as u64;
 
         match &self.regime {
             TranslationRegime::EL1_0(RegimeVaRange::Lower, asid) => core::arch::asm!(
@@ -616,7 +625,7 @@
             block.va,
             block.size,
             self.granule.initial_lookup_level(),
-            unsafe { self.base_table.get_as_mut_slice::<Descriptor>() },
+            unsafe { self.base_table.get_as_mut_slice::<K, Descriptor>() },
             &self.page_pool,
             &self.regime,
             self.granule,
@@ -675,7 +684,7 @@
                         )
                     })?;
 
-                let next_table = unsafe { page.get_as_mut_slice() };
+                let next_table = unsafe { page.get_as_mut_slice::<K, Descriptor>() };
 
                 // Fill next level table
                 let result = Self::set_block_descriptor_recursively(
@@ -692,9 +701,8 @@
 
                 if result.is_ok() {
                     // Set table descriptor if the table is configured properly
-                    let next_table_pa = PhysicalAddress(KernelSpace::kernel_to_pa(
-                        next_table.as_ptr() as u64,
-                    ) as usize);
+                    let next_table_pa =
+                        K::kernel_to_pa(VirtualAddress(next_table.as_ptr() as usize));
                     descriptor.set_table_descriptor(level, next_table_pa, None);
                 } else {
                     // Release next level table on error and keep invalid descriptor on current level
@@ -724,7 +732,7 @@
                         )
                     })?;
 
-                let next_table = unsafe { page.get_as_mut_slice() };
+                let next_table = unsafe { page.get_as_mut_slice::<K, Descriptor>() };
 
                 // Explode existing block descriptor into table entries
                 for exploded_va in VirtualAddressRange::new(
@@ -766,9 +774,8 @@
                 );
 
                 if result.is_ok() {
-                    let next_table_pa = PhysicalAddress(KernelSpace::kernel_to_pa(
-                        next_table.as_ptr() as u64,
-                    ) as usize);
+                    let next_table_pa =
+                        K::kernel_to_pa(VirtualAddress(next_table.as_ptr() as usize));
 
                     // Follow break-before-make sequence
                     descriptor.set_block_or_invalid_descriptor_to_invalid(level);
@@ -815,7 +822,7 @@
             block.va,
             block.size,
             self.granule.initial_lookup_level(),
-            unsafe { self.base_table.get_as_mut_slice::<Descriptor>() },
+            unsafe { self.base_table.get_as_mut_slice::<K, Descriptor>() },
             &self.page_pool,
             &self.regime,
             self.granule,
@@ -883,10 +890,10 @@
                     let mut page = unsafe {
                         let table_pa = descriptor.set_table_descriptor_to_invalid(level);
                         let next_table = Self::get_table_from_pa_mut(table_pa, granule, level + 1);
-                        Pages::from_slice(next_table)
+                        Pages::from_slice::<K, Descriptor>(next_table)
                     };
 
-                    page.zero_init();
+                    page.zero_init::<K>();
                     page_pool.release_pages(page).unwrap();
                 }
             }
@@ -898,7 +905,7 @@
             va,
             block_size,
             self.granule.initial_lookup_level(),
-            unsafe { self.base_table.get_as_mut_slice::<Descriptor>() },
+            unsafe { self.base_table.get_as_mut_slice::<K, Descriptor>() },
             self.granule,
         )
     }
@@ -955,10 +962,10 @@
         granule: TranslationGranule<VA_BITS>,
         level: isize,
     ) -> &'a [Descriptor] {
-        let table_va = KernelSpace::pa_to_kernel(pa.0 as u64);
+        let table_va = K::pa_to_kernel(pa);
         unsafe {
             core::slice::from_raw_parts(
-                table_va as *const Descriptor,
+                table_va.0 as *const Descriptor,
                 granule.entry_count_at_level(level),
             )
         }
@@ -974,10 +981,10 @@
         granule: TranslationGranule<VA_BITS>,
         level: isize,
     ) -> &'a mut [Descriptor] {
-        let table_va = KernelSpace::pa_to_kernel(pa.0 as u64);
+        let table_va = K::pa_to_kernel(pa);
         unsafe {
             core::slice::from_raw_parts_mut(
-                table_va as *mut Descriptor,
+                table_va.0 as *mut Descriptor,
                 granule.entry_count_at_level(level),
             )
         }
@@ -1050,7 +1057,7 @@
     fn invalidate(_regime: &TranslationRegime, _va: Option<VirtualAddress>) {}
 }
 
-impl<const VA_BITS: usize> fmt::Debug for Xlat<VA_BITS> {
+impl<K: KernelAddressTranslator, const VA_BITS: usize> fmt::Debug for Xlat<K, VA_BITS> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> core::fmt::Result {
         f.debug_struct("Xlat")
             .field("regime", &self.regime)
@@ -1063,7 +1070,7 @@
             f,
             self.granule.initial_lookup_level(),
             0,
-            unsafe { self.base_table.get_as_slice() },
+            unsafe { self.base_table.get_as_slice::<K, Descriptor>() },
             self.granule,
         )?;
 
diff --git a/src/page_pool.rs b/src/page_pool.rs
index c2707c7..d767caf 100644
--- a/src/page_pool.rs
+++ b/src/page_pool.rs
@@ -3,14 +3,14 @@
 
 //! Region pool implementation for allocating pages
 
-use core::slice;
-
 use alloc::sync::Arc;
 use alloc::vec::Vec;
 use spin::Mutex;
 
+use crate::address::VirtualAddress;
+use crate::KernelAddressTranslator;
+
 use super::address::PhysicalAddress;
-use super::kernel_space::KernelSpace;
 use super::region_pool::{Region, RegionPool, RegionPoolError};
 
 /// Single 4kB page definition
@@ -48,17 +48,22 @@
     }
 
     /// Copy data to pages
-    pub fn copy_data_to_page(&mut self, data: &[u8]) {
+    pub fn copy_data_to_page<K: KernelAddressTranslator>(&mut self, data: &[u8]) {
         assert!(data.len() <= self.length);
 
-        let page_contents = unsafe { slice::from_raw_parts_mut(self.pa as *mut u8, data.len()) };
+        let page_contents = unsafe {
+            core::slice::from_raw_parts_mut(
+                K::pa_to_kernel(PhysicalAddress(self.pa)).0 as *mut u8,
+                data.len(),
+            )
+        };
         page_contents.clone_from_slice(data);
     }
 
     /// Zero init pages
-    pub fn zero_init(&mut self) {
+    pub fn zero_init<K: KernelAddressTranslator>(&mut self) {
         unsafe {
-            self.get_as_mut_slice::<u8>().fill(0);
+            self.get_as_mut_slice::<K, u8>().fill(0);
         }
     }
 
@@ -72,11 +77,11 @@
     /// # Safety
     /// The returned slice is created from its address and length which is stored in the
     /// object. The caller has to ensure that no other references are being used of the pages.
-    pub unsafe fn get_as_slice<T>(&self) -> &[T] {
+    pub unsafe fn get_as_slice<K: KernelAddressTranslator, T>(&self) -> &[T] {
         assert!((core::mem::align_of::<T>() - 1) & self.pa == 0);
 
         core::slice::from_raw_parts(
-            KernelSpace::pa_to_kernel(self.pa as u64) as *const T,
+            K::pa_to_kernel(PhysicalAddress(self.pa)).0 as *const T,
             self.length / core::mem::size_of::<T>(),
         )
     }
@@ -86,11 +91,11 @@
     /// # Safety
     /// The returned slice is created from its address and length which is stored in the
     /// object. The caller has to ensure that no other references are being used of the pages.
-    pub unsafe fn get_as_mut_slice<T>(&mut self) -> &mut [T] {
+    pub unsafe fn get_as_mut_slice<K: KernelAddressTranslator, T>(&mut self) -> &mut [T] {
         assert!((core::mem::align_of::<T>() - 1) & self.pa == 0);
 
         core::slice::from_raw_parts_mut(
-            KernelSpace::pa_to_kernel(self.pa as u64) as *mut T,
+            K::pa_to_kernel(PhysicalAddress(self.pa)).0 as *mut T,
             self.length / core::mem::size_of::<T>(),
         )
     }
@@ -99,9 +104,9 @@
     ///
     /// # Safety
     ///  The caller has to ensure that the passed slice is a valid page range.
-    pub unsafe fn from_slice<T>(s: &mut [T]) -> Pages {
+    pub unsafe fn from_slice<K: KernelAddressTranslator, T>(s: &mut [T]) -> Pages {
         Pages {
-            pa: KernelSpace::kernel_to_pa(s.as_ptr() as u64) as usize,
+            pa: K::kernel_to_pa(VirtualAddress(s.as_ptr() as usize)).0,
             length: core::mem::size_of_val(s),
             used: true,
         }
@@ -204,8 +209,13 @@
 
 impl PagePool {
     /// Create new page pool
-    pub fn new<const AREA_SIZE: usize>(page_pool_area: &'static PagePoolArea<AREA_SIZE>) -> Self {
-        let pa = KernelSpace::kernel_to_pa(&page_pool_area.area[0] as *const u8 as u64) as usize;
+    pub fn new<K: KernelAddressTranslator, const AREA_SIZE: usize>(
+        page_pool_area: &'static PagePoolArea<AREA_SIZE>,
+    ) -> Self {
+        let pa = K::kernel_to_pa(VirtualAddress(
+            &page_pool_area.area[0] as *const u8 as usize,
+        ))
+        .0;
         let length = page_pool_area.area.len();
 
         let mut region_pool = RegionPool::new();
@@ -240,6 +250,18 @@
 mod tests {
     use super::*;
 
+    struct DummyKernelAddressTranslator {}
+
+    impl KernelAddressTranslator for DummyKernelAddressTranslator {
+        fn kernel_to_pa(va: VirtualAddress) -> PhysicalAddress {
+            va.identity_pa()
+        }
+
+        fn pa_to_kernel(pa: PhysicalAddress) -> VirtualAddress {
+            pa.identity_va()
+        }
+    }
+
     #[test]
     fn test_pages() {
         let area = [0x5au8; 4096];
@@ -253,19 +275,19 @@
         assert_eq!(area.len(), pages.length());
         assert!(pages.used());
 
-        pages.copy_data_to_page(&[0, 1, 2, 3, 4, 5, 6, 7]);
+        pages.copy_data_to_page::<DummyKernelAddressTranslator>(&[0, 1, 2, 3, 4, 5, 6, 7]);
         assert_eq!([0, 1, 2, 3, 4, 5, 6, 7], area[0..8]);
 
-        pages.zero_init();
+        pages.zero_init::<DummyKernelAddressTranslator>();
         assert_eq!([0, 0, 0, 0, 0, 0, 0, 0], area[0..8]);
 
-        let s = unsafe { pages.get_as_mut_slice() };
+        let s = unsafe { pages.get_as_mut_slice::<DummyKernelAddressTranslator, u8>() };
         for (i, e) in s.iter_mut().enumerate().take(8) {
             *e = i as u8;
         }
         assert_eq!([0, 1, 2, 3, 4, 5, 6, 7], area[0..8]);
 
-        let from_slice = unsafe { Pages::from_slice(s) };
+        let from_slice = unsafe { Pages::from_slice::<DummyKernelAddressTranslator, u8>(s) };
         assert_eq!(area.as_ptr() as usize, from_slice.pa);
         assert_eq!(area.len(), from_slice.length);
         assert!(from_slice.used);
diff --git a/src/region.rs b/src/region.rs
index d98afa5..1cba7c4 100644
--- a/src/region.rs
+++ b/src/region.rs
@@ -266,7 +266,19 @@
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::page_pool::PagePoolArea;
+    use crate::{page_pool::PagePoolArea, KernelAddressTranslator};
+
+    struct DummyKernelAddressTranslator {}
+
+    impl KernelAddressTranslator for DummyKernelAddressTranslator {
+        fn kernel_to_pa(va: VirtualAddress) -> PhysicalAddress {
+            va.identity_pa()
+        }
+
+        fn pa_to_kernel(pa: PhysicalAddress) -> VirtualAddress {
+            pa.identity_va()
+        }
+    }
 
     #[test]
     #[should_panic]
@@ -282,7 +294,7 @@
 
         static PAGE_POOL_AREA: PagePoolArea<16> = PagePoolArea::new();
         let region = PhysicalRegion::Allocated(
-            PagePool::new(&PAGE_POOL_AREA),
+            PagePool::new::<DummyKernelAddressTranslator, 16>(&PAGE_POOL_AREA),
             Pages::new(PA.0, LENGTH, true),
         );
         assert_eq!(PA, region.get_pa());
@@ -722,7 +734,7 @@
     #[test]
     fn test_virtual_region_drop() {
         static PAGE_POOL_AREA: PagePoolArea<8192> = PagePoolArea::new();
-        let page_pool = PagePool::new(&PAGE_POOL_AREA);
+        let page_pool = PagePool::new::<DummyKernelAddressTranslator, 8192>(&PAGE_POOL_AREA);
         let page = page_pool.allocate_pages(4096, None).unwrap();
 
         let physical_region = PhysicalRegion::Allocated(page_pool, page);