Add success type for FFA_NOTIFICATION_INFO_GET

Add specialized success type for converting between
FFA_NOTIFICATION_GET_INFO return arguments and generic success
arguments.

Signed-off-by: Imre Kis <imre.kis@arm.com>
Change-Id: Ibe740e2798c88197224278343881ccc1f331f802
diff --git a/src/lib.rs b/src/lib.rs
index 0a7b6ca..417bad8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,6 +10,7 @@
 use num_enum::{IntoPrimitive, TryFromPrimitive};
 use thiserror::Error;
 use uuid::Uuid;
+use zerocopy::transmute;
 
 pub mod boot_info;
 mod ffa_v1_1;
@@ -51,6 +52,8 @@
     InvalidPartitionInfoGetFlag(u32),
     #[error("Invalid success argument variant")]
     InvalidSuccessArgsVariant,
+    #[error("Invalid notification count")]
+    InvalidNotificationCount,
 }
 
 impl From<Error> for FfaError {
@@ -69,7 +72,8 @@
             | Error::InvalidVmId(_)
             | Error::UnrecognisedWarmBootType(_)
             | Error::InvalidPartitionInfoGetFlag(_)
-            | Error::InvalidSuccessArgsVariant => Self::InvalidParameters,
+            | Error::InvalidSuccessArgsVariant
+            | Error::InvalidNotificationCount => Self::InvalidParameters,
         }
     }
 }
@@ -209,6 +213,8 @@
 /// * `FFA_SPM_ID_GET` - [`SuccessArgsSpmIdGet`]
 /// * `FFA_PARTITION_INFO_GET` - [`partition_info::SuccessArgsPartitionInfoGet`]
 /// * `FFA_NOTIFICATION_GET` - [`SuccessArgsNotificationGet`]
+/// * `FFA_NOTIFICATION_INFO_GET_32` - [`SuccessArgsNotificationInfoGet32`]
+/// * `FFA_NOTIFICATION_INFO_GET_64` - [`SuccessArgsNotificationInfoGet64`]
 #[derive(Debug, Eq, PartialEq, Clone, Copy)]
 pub enum SuccessArgs {
     Args32([u32; 6]),
@@ -833,6 +839,195 @@
         })
     }
 }
+
+/// `FFA_NOTIFICATION_INFO_GET` specific success argument structure. The `MAX_COUNT` parameter
+/// depends on the 32-bit or 64-bit packing.
+#[derive(Debug, Eq, PartialEq, Clone, Copy)]
+pub struct SuccessArgsNotificationInfoGet<const MAX_COUNT: usize> {
+    pub more_pending_notifications: bool,
+    list_count: usize,
+    id_counts: [u8; MAX_COUNT],
+    ids: [u16; MAX_COUNT],
+}
+
+impl<const MAX_COUNT: usize> Default for SuccessArgsNotificationInfoGet<MAX_COUNT> {
+    fn default() -> Self {
+        Self {
+            more_pending_notifications: false,
+            list_count: 0,
+            id_counts: [0; MAX_COUNT],
+            ids: [0; MAX_COUNT],
+        }
+    }
+}
+
+impl<const MAX_COUNT: usize> SuccessArgsNotificationInfoGet<MAX_COUNT> {
+    const MORE_PENDING_NOTIFICATIONS_FLAG: u64 = 1 << 0;
+    const LIST_COUNT_SHIFT: usize = 7;
+    const LIST_COUNT_MASK: u64 = 0x1f;
+    const ID_COUNT_SHIFT: usize = 12;
+    const ID_COUNT_MASK: u64 = 0x03;
+    const ID_COUNT_BITS: usize = 2;
+
+    pub fn add_list(&mut self, endpoint: u16, vcpu_ids: &[u16]) -> Result<(), Error> {
+        if self.list_count >= MAX_COUNT || vcpu_ids.len() > Self::ID_COUNT_MASK as usize {
+            return Err(Error::InvalidNotificationCount);
+        }
+
+        // Each list contains at least one ID: the partition ID, followed by vCPU IDs. The number
+        // of vCPU IDs is recorded in `id_counts`.
+        let mut current_id_index = self.list_count + self.id_counts.iter().sum::<u8>() as usize;
+        if current_id_index + 1 + vcpu_ids.len() > MAX_COUNT {
+            // The new list does not fit into the available space for IDs.
+            return Err(Error::InvalidNotificationCount);
+        }
+
+        self.id_counts[self.list_count] = vcpu_ids.len() as u8;
+        self.list_count += 1;
+
+        // The first ID is the endpoint ID.
+        self.ids[current_id_index] = endpoint;
+        current_id_index += 1;
+
+        // Insert the vCPU IDs.
+        self.ids[current_id_index..current_id_index + vcpu_ids.len()].copy_from_slice(vcpu_ids);
+
+        Ok(())
+    }
+
+    pub fn iter(&self) -> NotificationInfoGetIterator<'_> {
+        NotificationInfoGetIterator {
+            list_index: 0,
+            id_index: 0,
+            id_count: &self.id_counts[0..self.list_count],
+            ids: &self.ids,
+        }
+    }
+
+    /// Pack flags field and IDs.
+    fn pack(self) -> (u64, [u16; MAX_COUNT]) {
+        let mut flags = if self.more_pending_notifications {
+            Self::MORE_PENDING_NOTIFICATIONS_FLAG
+        } else {
+            0
+        };
+
+        flags |= (self.list_count as u64) << Self::LIST_COUNT_SHIFT;
+        for (count, shift) in self.id_counts.iter().take(self.list_count).zip(
+            (Self::ID_COUNT_SHIFT..Self::ID_COUNT_SHIFT + Self::ID_COUNT_BITS * MAX_COUNT)
+                .step_by(Self::ID_COUNT_BITS),
+        ) {
+            flags |= u64::from(*count) << shift;
+        }
+
+        (flags, self.ids)
+    }
+
+    /// Unpack flags field and IDs.
+    fn unpack(flags: u64, ids: [u16; MAX_COUNT]) -> Result<Self, Error> {
+        let count_of_lists = ((flags >> Self::LIST_COUNT_SHIFT) & Self::LIST_COUNT_MASK) as usize;
+
+        if count_of_lists > MAX_COUNT {
+            return Err(Error::InvalidNotificationCount);
+        }
+
+        let mut count_of_ids = [0; MAX_COUNT];
+        let mut count_of_ids_bits = flags >> Self::ID_COUNT_SHIFT;
+
+        for id in count_of_ids.iter_mut().take(count_of_lists) {
+            *id = (count_of_ids_bits & Self::ID_COUNT_MASK) as u8;
+            count_of_ids_bits >>= Self::ID_COUNT_BITS;
+        }
+
+        Ok(Self {
+            more_pending_notifications: (flags & Self::MORE_PENDING_NOTIFICATIONS_FLAG) != 0,
+            list_count: count_of_lists,
+            id_counts: count_of_ids,
+            ids,
+        })
+    }
+}
+
+/// `FFA_NOTIFICATION_INFO_GET_32` specific success argument structure.
+pub type SuccessArgsNotificationInfoGet32 = SuccessArgsNotificationInfoGet<10>;
+
+impl From<SuccessArgsNotificationInfoGet32> for SuccessArgs {
+    fn from(value: SuccessArgsNotificationInfoGet32) -> Self {
+        let (flags, ids) = value.pack();
+        let id_regs: [u32; 5] = transmute!(ids);
+
+        let mut args = [0; 6];
+        args[0] = flags as u32;
+        args[1..6].copy_from_slice(&id_regs);
+
+        SuccessArgs::Args32(args)
+    }
+}
+
+impl TryFrom<SuccessArgs> for SuccessArgsNotificationInfoGet32 {
+    type Error = Error;
+
+    fn try_from(value: SuccessArgs) -> Result<Self, Self::Error> {
+        let args = value.try_get_args32()?;
+        let flags = args[0].into();
+        let id_regs: [u32; 5] = args[1..6].try_into().unwrap();
+        Self::unpack(flags, transmute!(id_regs))
+    }
+}
+
+/// `FFA_NOTIFICATION_INFO_GET_64` specific success argument structure.
+pub type SuccessArgsNotificationInfoGet64 = SuccessArgsNotificationInfoGet<20>;
+
+impl From<SuccessArgsNotificationInfoGet64> for SuccessArgs {
+    fn from(value: SuccessArgsNotificationInfoGet64) -> Self {
+        let (flags, ids) = value.pack();
+        let id_regs: [u64; 5] = transmute!(ids);
+
+        let mut args = [0; 6];
+        args[0] = flags;
+        args[1..6].copy_from_slice(&id_regs);
+
+        SuccessArgs::Args64(args)
+    }
+}
+
+impl TryFrom<SuccessArgs> for SuccessArgsNotificationInfoGet64 {
+    type Error = Error;
+
+    fn try_from(value: SuccessArgs) -> Result<Self, Self::Error> {
+        let args = value.try_get_args64()?;
+        let flags = args[0];
+        let id_regs: [u64; 5] = args[1..6].try_into().unwrap();
+        Self::unpack(flags, transmute!(id_regs))
+    }
+}
+
+pub struct NotificationInfoGetIterator<'a> {
+    list_index: usize,
+    id_index: usize,
+    id_count: &'a [u8],
+    ids: &'a [u16],
+}
+
+impl<'a> Iterator for NotificationInfoGetIterator<'a> {
+    type Item = (u16, &'a [u16]);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.list_index < self.id_count.len() {
+            let partition_id = self.ids[self.id_index];
+            let id_range =
+                (self.id_index + 1)..=(self.id_index + self.id_count[self.list_index] as usize);
+
+            self.id_index += 1 + self.id_count[self.list_index] as usize;
+            self.list_index += 1;
+
+            Some((partition_id, &self.ids[id_range]))
+        } else {
+            None
+        }
+    }
+}
+
 /// FF-A "message types", the terminology used by the spec is "interfaces".
 ///
 /// The interfaces are used by FF-A components for communication at an FF-A instance. The spec also
@@ -2309,4 +2504,66 @@
         };
         assert!(interface_32.is_32bit());
     }
+
+    #[test]
+    fn success_args_notification_info_get32() {
+        let mut notifications = SuccessArgsNotificationInfoGet32::default();
+
+        // 16.7.1.1 Example usage
+        notifications.add_list(0x0000, &[0, 2, 3]).unwrap();
+        notifications.add_list(0x0000, &[4, 6]).unwrap();
+        notifications.add_list(0x0002, &[]).unwrap();
+        notifications.add_list(0x0003, &[1]).unwrap();
+
+        let args: SuccessArgs = notifications.into();
+        assert_eq!(
+            SuccessArgs::Args32([
+                0x0004_b200,
+                0x0000_0000,
+                0x0003_0002,
+                0x0004_0000,
+                0x0002_0006,
+                0x0001_0003
+            ]),
+            args
+        );
+
+        let notifications = SuccessArgsNotificationInfoGet32::try_from(args).unwrap();
+        let mut iter = notifications.iter();
+        assert_eq!(Some((0x0000, &[0, 2, 3][..])), iter.next());
+        assert_eq!(Some((0x0000, &[4, 6][..])), iter.next());
+        assert_eq!(Some((0x0002, &[][..])), iter.next());
+        assert_eq!(Some((0x0003, &[1][..])), iter.next());
+    }
+
+    #[test]
+    fn success_args_notification_info_get64() {
+        let mut notifications = SuccessArgsNotificationInfoGet64::default();
+
+        // 16.7.1.1 Example usage
+        notifications.add_list(0x0000, &[0, 2, 3]).unwrap();
+        notifications.add_list(0x0000, &[4, 6]).unwrap();
+        notifications.add_list(0x0002, &[]).unwrap();
+        notifications.add_list(0x0003, &[1]).unwrap();
+
+        let args: SuccessArgs = notifications.into();
+        assert_eq!(
+            SuccessArgs::Args64([
+                0x0004_b200,
+                0x0003_0002_0000_0000,
+                0x0002_0006_0004_0000,
+                0x0000_0000_0001_0003,
+                0x0000_0000_0000_0000,
+                0x0000_0000_0000_0000,
+            ]),
+            args
+        );
+
+        let notifications = SuccessArgsNotificationInfoGet64::try_from(args).unwrap();
+        let mut iter = notifications.iter();
+        assert_eq!(Some((0x0000, &[0, 2, 3][..])), iter.next());
+        assert_eq!(Some((0x0000, &[4, 6][..])), iter.next());
+        assert_eq!(Some((0x0002, &[][..])), iter.next());
+        assert_eq!(Some((0x0003, &[1][..])), iter.next());
+    }
 }