// SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Region pool implementation for allocating pages

use core::slice;

use alloc::sync::Arc;
use alloc::vec::Vec;
use spin::Mutex;

use super::address::PhysicalAddress;
use super::kernel_space::KernelSpace;
use super::region_pool::{Region, RegionPool, RegionPoolError};

/// A single 4 KiB page.
///
/// This type carries no data; it only namespaces the page size constant.
pub struct Page {}

impl Page {
    /// Size of one page in bytes (4 KiB).
    pub const SIZE: usize = 4 * 1024;
}

/// Page-aligned backing memory from which pages are allocated.
///
/// `AREA_SIZE` is the size of the area in bytes. The whole struct is aligned to 4096 bytes,
/// so the area starts on a page boundary.
#[repr(C, align(4096))]
pub struct PagePoolArea<const AREA_SIZE: usize> {
    area: [u8; AREA_SIZE],
}

impl<const AREA_SIZE: usize> PagePoolArea<AREA_SIZE> {
    /// Creates a zero-filled area.
    ///
    /// This is a `const fn`, so it can be used to initialize a `static`.
    pub const fn new() -> Self {
        let area = [0u8; AREA_SIZE];
        Self { area }
    }
}

/// Continuous pages
///
/// A contiguous memory range described by its physical start address and its length in
/// bytes, together with an allocation flag.
pub struct Pages {
    /// Physical start address of the range.
    pa: usize,
    /// Length of the range in bytes.
    length: usize,
    /// Whether the range is currently allocated.
    used: bool,
}

impl Pages {
    /// Creates a new page range descriptor.
    pub(crate) fn new(pa: usize, length: usize, used: bool) -> Self {
        Pages { pa, length, used }
    }

    /// Copies `data` to the start of the pages.
    ///
    /// The memory is accessed through its kernel virtual address, the same way as
    /// [`Pages::get_as_slice`] and [`Pages::zero_init`] access it.
    ///
    /// # Panics
    ///
    /// Panics if `data` is longer than the pages.
    pub fn copy_data_to_page(&mut self, data: &[u8]) {
        assert!(data.len() <= self.length);

        // SAFETY: `&mut self` guarantees exclusive use of this descriptor. As in `zero_init`,
        // this relies on a `Pages` object being the only user of its memory range. Going
        // through `get_as_slice` translates the physical address to its kernel mapping
        // instead of dereferencing the physical address directly.
        let page_contents = unsafe { self.get_as_slice::<u8>() };
        page_contents[..data.len()].copy_from_slice(data);
    }

    /// Fills the whole range with zero bytes.
    pub fn zero_init(&mut self) {
        // SAFETY: `&mut self` guarantees exclusive use of this descriptor; this relies on a
        // `Pages` object being the only user of its memory range.
        unsafe {
            self.get_as_slice::<u8>().fill(0);
        }
    }

    /// Returns the physical start address of the pages.
    pub fn get_pa(&self) -> PhysicalAddress {
        PhysicalAddress(self.pa)
    }

    /// Returns the pages as a mutable slice of `T`.
    ///
    /// The memory is accessed through the kernel virtual address of the range. The slice
    /// has `length / size_of::<T>()` elements; trailing bytes that do not fill a whole `T`
    /// are not included.
    ///
    /// # Safety
    ///
    /// The returned slice is created from the address and length stored in the object. The
    /// caller has to ensure that no other references to the pages are in use while the
    /// slice is alive, and that the memory holds valid values of `T`.
    ///
    /// # Panics
    ///
    /// Panics if the physical address is not aligned to `align_of::<T>()`. Also panics
    /// (division by zero) if `T` is zero-sized.
    pub unsafe fn get_as_slice<T>(&mut self) -> &mut [T] {
        assert!((core::mem::align_of::<T>() - 1) & self.pa == 0);

        slice::from_raw_parts_mut(
            KernelSpace::pa_to_kernel(self.pa as u64) as *mut T,
            self.length / core::mem::size_of::<T>(),
        )
    }

    /// Creates a used page descriptor covering the memory behind `s`.
    ///
    /// The slice's kernel address is converted to a physical address.
    ///
    /// # Safety
    ///
    /// The caller has to ensure that the passed slice is a valid page range.
    pub unsafe fn from_slice<T>(s: &mut [T]) -> Pages {
        Pages {
            pa: KernelSpace::kernel_to_pa(s.as_ptr() as u64) as usize,
            length: core::mem::size_of_val(s),
            used: true,
        }
    }
}

impl Region for Pages {
    type Resource = ();
    type Base = usize;
    type Length = usize;
    type Alignment = usize;

    /// Returns the physical start address of the range.
    fn base(&self) -> usize {
        self.pa
    }

    /// Returns the length of the range in bytes.
    fn length(&self) -> usize {
        self.length
    }

    /// Returns whether the range is allocated.
    fn used(&self) -> bool {
        self.used
    }

    /// Checks whether `[base, base + length)` lies entirely inside this range.
    ///
    /// Returns `false` if computing either end address overflows `usize`.
    fn contains(&self, base: usize, length: usize) -> bool {
        match (base.checked_add(length), self.pa.checked_add(self.length)) {
            (Some(end), Some(self_end)) => self.pa <= base && end <= self_end,
            _ => false,
        }
    }

    /// Finds an allocation address for `length` bytes inside this range.
    ///
    /// Returns the lowest address in this range that is a multiple of `alignment` and still
    /// leaves room for `length` bytes before the end of the range.
    ///
    /// Returns `None` in three cases: no such address exists, `alignment` is zero, or
    /// rounding the start address up to `alignment` would overflow.
    fn try_alloc_aligned(
        &self,
        length: Self::Length,
        alignment: Self::Alignment,
    ) -> Option<Self::Base> {
        // Plain `next_multiple_of` panics on a zero alignment. On overflow it panics in
        // debug builds and wraps in release builds. The checked variant turns all of these
        // into `None`.
        let aligned_base = self.pa.checked_next_multiple_of(alignment)?;
        // The rounded-up address is never below the start, so this cannot underflow.
        let base_offset = aligned_base - self.pa;

        let required_length = base_offset.checked_add(length)?;
        (required_length <= self.length).then_some(aligned_base)
    }

    /// Extends this range by `other` when they can be merged.
    ///
    /// Merging happens only if `other` starts exactly where this range ends and has the
    /// same `used` state. Returns whether the ranges were merged. Nothing is changed if
    /// either the end address or the merged length would overflow.
    fn try_append(&mut self, other: &Self) -> bool {
        let (Some(self_end), Some(new_length)) = (
            self.pa.checked_add(self.length),
            self.length.checked_add(other.length),
        ) else {
            return false;
        };

        if self.used == other.used && self_end == other.pa {
            self.length = new_length;
            true
        } else {
            false
        }
    }

    /// Splits `[base, base + length)` out of this range.
    ///
    /// The carved-out part is marked used when `resource` is `Some`. Returns that part,
    /// together with the pieces covering this whole range in address order:
    /// - an optional leading remainder,
    /// - the carved-out part,
    /// - an optional trailing remainder.
    ///
    /// The remainders keep this range's `used` state.
    ///
    /// # Panics
    ///
    /// Panics if the requested range is not contained in this range.
    fn create_split(
        &self,
        base: usize,
        length: usize,
        resource: Option<Self::Resource>,
    ) -> (Self, Vec<Self>) {
        assert!(self.contains(base, length));

        let used = resource.is_some();
        // `contains` succeeded, so neither end computation overflows and
        // `self.pa <= base <= end <= self_end`.
        let self_end = self.pa + self.length;
        let end = base + length;

        let mut res = Vec::with_capacity(3);
        if self.pa != base {
            res.push(Pages::new(self.pa, base - self.pa, self.used));
        }
        res.push(Pages::new(base, length, used));
        if self_end != end {
            res.push(Pages::new(end, self_end - end, self.used));
        }

        (Pages::new(base, length, used), res)
    }
}

/// RegionPool implementation for pages
///
/// Cloning a `PagePool` is cheap. All clones share the same underlying pool through an
/// `Arc`, and every access goes through its mutex.
#[derive(Clone)]
pub struct PagePool {
    pages: Arc<Mutex<RegionPool<Pages>>>,
}

/// Error type returned by [`PagePool`] operations.
type PagePoolError = RegionPoolError;

impl PagePool {
    /// Creates a new page pool that allocates from `page_pool_area`.
    ///
    /// The area's kernel address is converted to a physical address. The whole area is
    /// added to the pool as a single free range.
    ///
    /// # Panics
    ///
    /// Panics if the region pool rejects the initial range (`RegionPool::add` fails).
    pub fn new<const AREA_SIZE: usize>(page_pool_area: &'static PagePoolArea<AREA_SIZE>) -> Self {
        // `as_ptr` instead of `&area[0]`: the latter panics on an index out of bounds when
        // `AREA_SIZE == 0`.
        let pa = KernelSpace::kernel_to_pa(page_pool_area.area.as_ptr() as u64) as usize;
        let length = page_pool_area.area.len();

        let mut region_pool = RegionPool::new();
        region_pool.add(Pages::new(pa, length, false)).unwrap();
        Self {
            pages: Arc::new(Mutex::new(region_pool)),
        }
    }

    /// Allocates pages for `length` bytes, rounded up to a whole number of pages.
    ///
    /// # Errors
    ///
    /// Returns the region pool's error if the allocation fails.
    pub fn allocate_pages(&self, length: usize) -> Result<Pages, PagePoolError> {
        self.pages
            .lock()
            .allocate(Self::round_up_to_page_size(length), (), None)
    }

    /// Returns `pages_to_release` to the pool.
    ///
    /// # Errors
    ///
    /// Returns the region pool's error if the release fails.
    pub fn release_pages(&self, pages_to_release: Pages) -> Result<(), PagePoolError> {
        self.pages.lock().release(pages_to_release)
    }

    /// Rounds `length` up to a multiple of [`Page::SIZE`].
    ///
    /// Rounding with `length + SIZE - 1` would overflow for lengths close to `usize::MAX`.
    /// In release builds that wraps to a tiny or zero-length request. Instead, such lengths
    /// saturate to the largest page-aligned `usize`.
    ///
    /// A `PagePoolArea` is a Rust object and so is at most `isize::MAX` bytes. The saturated
    /// request is therefore expected to be rejected by the pool rather than silently
    /// satisfied. NOTE(review): confirm that `RegionPool::allocate` reports an error in
    /// this case.
    fn round_up_to_page_size(length: usize) -> usize {
        length
            .checked_next_multiple_of(Page::SIZE)
            .unwrap_or(usize::MAX & !(Page::SIZE - 1))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pages() {
        // The area is written through raw pointers derived from `pages`, so it must be a
        // mutable local whose address comes from `as_mut_ptr`. Writing into an immutable
        // local is undefined behavior.
        let mut area = [0x5au8; 4096];
        let area_pa = area.as_mut_ptr() as usize;
        let mut pages = Pages::new(area_pa, area.len(), true);

        assert_eq!(area_pa, pages.pa);
        assert_eq!(area.len(), pages.length);
        assert!(pages.used);
        assert_eq!(PhysicalAddress(area_pa), pages.get_pa());
        assert_eq!(area_pa, pages.base());
        assert_eq!(area.len(), pages.length());
        assert!(pages.used());

        pages.copy_data_to_page(&[0, 1, 2, 3, 4, 5, 6, 7]);
        assert_eq!([0, 1, 2, 3, 4, 5, 6, 7], area[0..8]);

        pages.zero_init();
        assert_eq!([0, 0, 0, 0, 0, 0, 0, 0], area[0..8]);

        // Stop using the `&mut` slice from `get_as_slice` before reading `area` directly
        // again, so the two accesses never overlap.
        let from_slice = {
            let s = unsafe { pages.get_as_slice::<u8>() };
            for (i, e) in s.iter_mut().enumerate().take(8) {
                *e = i as u8;
            }
            unsafe { Pages::from_slice(s) }
        };
        assert_eq!([0, 1, 2, 3, 4, 5, 6, 7], area[0..8]);

        assert_eq!(area_pa, from_slice.pa);
        assert_eq!(area.len(), from_slice.length);
        assert!(from_slice.used);
    }

    #[test]
    fn test_pages_contains() {
        let pages = Pages::new(0x4000_0000, 0x4000, true);

        assert!(!pages.contains(0x3fff_f000, 0x1000));
        assert!(!pages.contains(0x3fff_f000, 0x1_0000));
        assert!(!pages.contains(0x4000_4000, 0x1000));
        assert!(!pages.contains(0x4000_0000, 0x1_0000));

        assert!(pages.contains(0x4000_0000, 0x4000));
        assert!(pages.contains(0x4000_1000, 0x1000));
        assert!(pages.contains(0x4000_3000, 0x1000));

        // Overflow tests: the end of the queried range overflows
        assert!(!pages.contains(usize::MAX, 1));
        assert!(!pages.contains(0x4000_1000, usize::MAX));

        // Overflow tests: the end of the region itself overflows
        let wrapping_pages = Pages::new(usize::MAX - 0xfff, 0x2000, true);
        assert!(!wrapping_pages.contains(usize::MAX - 0xfff, 0x1000));
    }

    #[test]
    fn test_pages_try_alloc() {
        let pages = Pages::new(0x4000_1000, 0x10000, false);

        assert_eq!(Some(0x4000_1000), pages.try_alloc(0x1000, None));
        assert_eq!(Some(0x4000_2000), pages.try_alloc(0x1000, Some(0x2000)));
        assert_eq!(None, pages.try_alloc(0x1000, Some(0x10_0000)));

        // Boundaries: exactly filling the region succeeds, one byte more fails
        assert_eq!(Some(0x4000_1000), pages.try_alloc(0x10000, None));
        assert_eq!(Some(0x4000_2000), pages.try_alloc(0xf000, Some(0x2000)));
        assert_eq!(None, pages.try_alloc(0xf001, Some(0x2000)));
    }
}
