use log::trace; use std::any::{type_name, Any}; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; use std::marker::PhantomData; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::RwLock; #[derive(Debug)] pub struct StateVector { next_section: AtomicUsize, sections: RwLock>>>, } #[derive(Clone, Eq, PartialEq, Hash, Debug)] pub struct SectionIdentifier(usize); pub struct SectionWriter<'a, T> { id: SectionIdentifier, state_vector: &'a StateVector, _phantom_data: PhantomData, } impl<'a, T> Debug for SectionWriter<'a, T> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "SectionWriter {{ id: {:?}, state_vector: {:?} }}", type_name::(), self.id, self.state_vector) } } impl<'a, T: 'static> SectionWriter<'a, T> { pub fn get_identifier(&self) -> SectionIdentifier { trace!("SectionWriter::get_identifier(self: {self:?})", type_name::()); self.id.clone() } pub fn update(&self, f: F) -> R where F: FnOnce(&mut T) -> R, { trace!("SectionWriter::update(self: {self:?}, f)", type_name::()); self.state_vector.sections.clear_poison(); let sections = self.state_vector.sections.read().unwrap(); let section = sections.get(&self.id).unwrap(); let mut data = section.write().unwrap(); let result = data.downcast_mut::().unwrap(); f(result) } } impl StateVector { pub fn new() -> Self { trace!("StateVector::new()"); Self { next_section: AtomicUsize::new(0usize), sections: RwLock::new(HashMap::new()), } } pub fn create_section(&self, initial_value: T) -> SectionWriter<'_, T> where T: Send + Sync + 'static, { trace!("StateVector::create_section(self: {self:?}, initial_value)", type_name::()); let id = SectionIdentifier(self.next_section.fetch_add(1usize, Ordering::SeqCst)); let lock = Box::new(RwLock::new(initial_value)); self.sections.clear_poison(); let mut sections = self.sections.write().unwrap(); if !sections.contains_key(&id) { sections.insert(id.clone(), lock); } drop(sections); SectionWriter { id, state_vector: &self, _phantom_data: PhantomData, } } pub fn access_section(&self, id: &SectionIdentifier, f: F) -> Option where T: 'static, F: FnOnce(&T) -> R, { trace!("StateVector::access_section(self: {self:?}, id: {id:?}, f)", type_name::(), type_name::(), type_name::()); self.sections.clear_poison(); let Ok(sections) = self.sections.read() else { return None; }; let Some(section) = sections.get(id) else { return None; }; section.clear_poison(); let Ok(data) = section.read() else { return None; }; let Some(inner) = data.downcast_ref::() else { return None; }; Some(f(inner)) } } #[cfg(test)] mod tests { use super::*; use anyhow::Result; #[derive(Default)] struct TestType { value1: i32, value2: i32, } #[test] fn test_two_sections() -> Result<()> { let state_vector = StateVector::new(); let section_1 = state_vector.create_section(TestType::default()); let section_2 = state_vector.create_section(TestType::default()); section_1.update(|s| { s.value1 = 1; s.value2 = 2; }); section_2.update(|s| { s.value1 = 3; s.value2 = 4; }); let id_1 = section_1.get_identifier(); state_vector.access_section(&id_1, |s: &TestType| { assert_eq!(1, s.value1); assert_eq!(2, s.value2); }); let id_2 = section_2.get_identifier(); state_vector.access_section(&id_2, |s: &TestType| { assert_eq!(3, s.value1); assert_eq!(4, s.value2); }); Ok(()) } }