Fix TLB invalidation

Add synchronization barriers and use the correct TLBI calls when
invalidating single VA entries. TLBI instructions expect the virtual
address shifted right by 12 bits, which was missing from the
implementation.

Signed-off-by: Imre Kis <imre.kis@arm.com>
Change-Id: I413f986fffbdecb875a8ddc3356bae61b73e51d8
diff --git a/src/address.rs b/src/address.rs
index 2ccea8c..3a35c6c 100644
--- a/src/address.rs
+++ b/src/address.rs
@@ -89,15 +89,6 @@
         PhysicalAddress(self.0)
     }
 
-    /// Mask the lower bits of the virtual address for the given granule and level
-    pub const fn mask_for_level<const VA_BITS: usize>(
-        self,
-        translation_granule: TranslationGranule<VA_BITS>,
-        level: isize,
-    ) -> Self {
-        Self(self.0 & (translation_granule.block_size_at_level(level) - 1))
-    }
-
     /// Calculate the index of the virtual address in a translation table at the
     /// given granule and level.
     pub const fn get_level_index<const VA_BITS: usize>(
@@ -105,7 +96,8 @@
         translation_granule: TranslationGranule<VA_BITS>,
         level: isize,
     ) -> usize {
-        self.0 >> translation_granule.total_bits_at_level(level)
+        let mask = translation_granule.entry_count_at_level(level) - 1;
+        (self.0 >> translation_granule.total_bits_at_level(level)) & mask
     }
 
     /// Check if the address is valid in the translation regime, i.e. if the top bits match the
diff --git a/src/descriptor.rs b/src/descriptor.rs
index b5db7cc..c3ed68f 100644
--- a/src/descriptor.rs
+++ b/src/descriptor.rs
@@ -354,7 +354,7 @@
         unsafe {
             ptr::write_volatile(self.cell.get(), value);
             #[cfg(target_arch = "aarch64")]
-            core::arch::asm!("dsb nsh");
+            core::arch::asm!("dsb ishst");
         }
     }
 
diff --git a/src/lib.rs b/src/lib.rs
index f1f1234..63d24cc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -117,7 +117,7 @@
             contiguous: false,
             not_global: !access_rights.contains(MemoryAccessRights::GLOBAL),
             access_flag: true,
-            shareability: Shareability::NonShareable,
+            shareability: Shareability::Inner,
             data_access_permissions,
             non_secure: access_rights.contains(MemoryAccessRights::NS),
             mem_attr_index,
@@ -788,7 +788,7 @@
                 let result = Self::set_block_descriptor_recursively(
                     attributes,
                     pa,
-                    va.mask_for_level(granule, level),
+                    va,
                     block_size,
                     level + 1,
                     next_table,
@@ -847,7 +847,7 @@
                     Self::set_block_descriptor_recursively(
                         current_attributes.clone(),
                         current_pa.add_offset(offset).unwrap(),
-                        exploded_va.mask_for_level(granule, level),
+                        exploded_va,
                         granule.block_size_at_level(level + 1),
                         level + 1,
                         next_table,
@@ -862,7 +862,7 @@
                 let result = Self::set_block_descriptor_recursively(
                     attributes,
                     pa,
-                    va.mask_for_level(granule, level),
+                    va,
                     block_size,
                     level + 1,
                     next_table,
@@ -900,7 +900,7 @@
                 Self::set_block_descriptor_recursively(
                     attributes,
                     pa,
-                    va.mask_for_level(granule, level),
+                    va,
                     block_size,
                     level + 1,
                     next_level_table,
@@ -974,7 +974,7 @@
                 };
 
                 Self::remove_block_descriptor_recursively(
-                    va.mask_for_level(granule, level),
+                    va,
                     block_size,
                     level + 1,
                     next_level_table,
@@ -1041,7 +1041,7 @@
                 };
 
                 Self::walk_descriptors(
-                    va.mask_for_level(granule, level),
+                    va,
                     level + 1,
                     next_level_table,
                     granule,
@@ -1093,61 +1093,54 @@
         // SAFETY: The assembly code invalidates the translation table entry of
         // the VA or all entries of the translation regime.
         unsafe {
-            if let Some(VirtualAddress(va)) = va {
+            // Wait for store in inner shareable
+            core::arch::asm!("dsb ishst");
+
+            if let Some(va) = va {
+                // Invalidate single virtual address for translation regime
+                let index = (va.0 >> 12) as u64 & 0x0000_0fff_ffff_ffff;
+
                 match regime {
-                    TranslationRegime::EL1_0(_, _) => {
+                    TranslationRegime::EL1_0(_, asid) => {
                         core::arch::asm!(
-                        "tlbi vaae1is, {0}
-                        dsb nsh
-                        isb",
-                        in(reg) va)
+                        "tlbi vale1, {0}",
+                        in(reg) ((*asid as u64) << 48) | index)
                     }
                     #[cfg(target_feature = "vh")]
-                    TranslationRegime::EL2_0(_, _) => {
+                    TranslationRegime::EL2_0(_, asid) => {
                         core::arch::asm!(
-                        "tlbi vaae1is, {0}
-                        dsb nsh
-                        isb",
-                        in(reg) va)
+                        "tlbi vale1, {0}",
+                        in(reg) ((*asid as u64) << 48) | index)
                     }
                     TranslationRegime::EL2 => core::arch::asm!(
-                        "tlbi vae2is, {0}
-                        dsb nsh
-                        isb",
-                        in(reg) va),
+                        "tlbi vae2, {0}",
+                        in(reg) index),
                     TranslationRegime::EL3 => core::arch::asm!(
-                        "tlbi vae3is, {0}
-                        dsb nsh
-                        isb",
-                        in(reg) va),
+                        "tlbi vae3, {0}",
+                        in(reg) index),
                 }
             } else {
+                // Invalidate all entries for translation regime
                 match regime {
                     TranslationRegime::EL1_0(_, asid) => core::arch::asm!(
-                        "tlbi aside1, {0}
-                        dsb nsh
-                        isb",
+                        "tlbi aside1, {0}",
                         in(reg) (*asid as u64) << 48
                     ),
                     #[cfg(target_feature = "vh")]
                     TranslationRegime::EL2_0(_, asid) => core::arch::asm!(
-                        "tlbi aside1, {0}
-                        dsb nsh
-                        isb",
+                        "tlbi aside1, {0}",
                         in(reg) (*asid as u64) << 48
                     ),
-                    TranslationRegime::EL2 => core::arch::asm!(
-                        "tlbi alle2
-                        dsb nsh
-                        isb"
-                    ),
-                    TranslationRegime::EL3 => core::arch::asm!(
-                        "tlbi alle3
-                        dsb nsh
-                        isb"
-                    ),
+                    TranslationRegime::EL2 => core::arch::asm!("tlbi alle2"),
+                    TranslationRegime::EL3 => core::arch::asm!("tlbi alle3"),
                 }
             }
+
+            // Synchronize TLB invalidation
+            core::arch::asm!(
+                "dsb ish
+                isb"
+            );
         }
     }