Check total_page_count overflow

Check for overflows when constructing the total_page_count field of a
composite_memory_region_descriptor. This also results in changing the
return type of MemTransactionDesc::pack from usize to
Result<usize, Error>.

Signed-off-by: Imre Kis <imre.kis@arm.com>
Change-Id: Ie3f7c70bb244a1c25c4fa55c9c189228a9dea097
diff --git a/src/memory_management.rs b/src/memory_management.rs
index 925fe44..dc0fc41 100644
--- a/src/memory_management.rs
+++ b/src/memory_management.rs
@@ -431,9 +431,16 @@
 
 impl<'a> ConstituentMemRegionIterator<'a> {
     /// Create an iterator of constituent memory region descriptors from a buffer.
-    fn new(buf: &'a [u8], count: usize, offset: usize) -> Result<Self, Error> {
-        let Some(total_size) = count
-            .checked_mul(size_of::<constituent_memory_region_descriptor>())
+    fn new(
+        buf: &'a [u8],
+        region_count: usize,
+        total_page_count: u32,
+        offset: usize,
+    ) -> Result<Self, Error> {
+        let descriptor_size = size_of::<constituent_memory_region_descriptor>();
+
+        let Some(total_size) = region_count
+            .checked_mul(descriptor_size)
             .and_then(|x| x.checked_add(offset))
         else {
             return Err(Error::InvalidBufferSize);
@@ -443,7 +450,32 @@
             return Err(Error::InvalidBufferSize);
         }
 
-        Ok(Self { buf, offset, count })
+        // Check if the sum of of page counts in the constituent_memory_region_descriptors matches
+        // the total_page_count field of the composite_memory_region_descriptor.
+        let mut page_count_sum: u32 = 0;
+        for desc_offset in
+            (offset..offset + descriptor_size * region_count).step_by(descriptor_size)
+        {
+            let Ok(desc_raw) = constituent_memory_region_descriptor::ref_from_bytes(
+                &buf[desc_offset..desc_offset + descriptor_size],
+            ) else {
+                return Err(Error::MalformedDescriptor);
+            };
+
+            page_count_sum = page_count_sum
+                .checked_add(desc_raw.page_count)
+                .ok_or(Error::MalformedDescriptor)?;
+        }
+
+        if page_count_sum != total_page_count {
+            return Err(Error::MalformedDescriptor);
+        }
+
+        Ok(Self {
+            buf,
+            offset,
+            count: region_count,
+        })
     }
 }
 
@@ -566,7 +598,7 @@
             offset += mem_access_desc_size;
         }
 
-        let mut total_page_count = 0;
+        let mut total_page_count: u32 = 0;
 
         offset = composite_offset + Self::CONSTITUENT_ARRAY_OFFSET;
         for constituent in constituents {
@@ -579,7 +611,9 @@
             constituent_raw.write_to_prefix(&mut buf[offset..]).unwrap();
             offset += size_of::<constituent_memory_region_descriptor>();
 
-            total_page_count += constituent_raw.page_count;
+            total_page_count = total_page_count
+                .checked_add(constituent_raw.page_count)
+                .expect("total_page_count overflow");
         }
 
         let composite_desc_raw = composite_memory_region_descriptor {
@@ -689,12 +723,10 @@
         let constituent_iter = ConstituentMemRegionIterator::new(
             buf,
             composite_desc_raw.address_range_count as usize,
+            composite_desc_raw.total_page_count,
             offset + Self::CONSTITUENT_ARRAY_OFFSET,
         )?;
 
-        // TODO: add a sainty check to compare the composite descriptor's total page count and the
-        // sum of page counts from constituent memory regions (not sure if it's really valuable)
-
         Ok((transaction_desc, access_desc_iter, Some(constituent_iter)))
     }
 }