142 lines
4.1 KiB
Rust
142 lines
4.1 KiB
Rust
use log::trace;
|
|
use std::any::{type_name, Any};
|
|
use std::collections::HashMap;
|
|
use std::fmt::{Debug, Formatter};
|
|
use std::marker::PhantomData;
|
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
use std::sync::RwLock;
|
|
|
|
#[derive(Debug)]
|
|
pub struct StateVector {
|
|
next_section: AtomicUsize,
|
|
sections: RwLock<HashMap<SectionIdentifier, Box<RwLock<dyn Any + Send + Sync>>>>,
|
|
}
|
|
|
|
/// Opaque handle naming one section of a [`StateVector`].
///
/// Returned by [`SectionWriter::get_identifier`] and accepted by
/// [`StateVector::access_section`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SectionIdentifier(usize);
|
|
|
|
pub struct SectionWriter<'a, T> {
|
|
id: SectionIdentifier,
|
|
state_vector: &'a StateVector,
|
|
_phantom_data: PhantomData<T>,
|
|
}
|
|
|
|
impl<'a, T> Debug for SectionWriter<'a, T> {
|
|
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "SectionWriter<T={}> {{ id: {:?}, state_vector: {:?} }}", type_name::<T>(), self.id, self.state_vector)
|
|
}
|
|
}
|
|
|
|
impl<'a, T: 'static> SectionWriter<'a, T> {
|
|
pub fn get_identifier(&self) -> SectionIdentifier {
|
|
trace!("SectionWriter<T={}>::get_identifier(self: {self:?})", type_name::<T>());
|
|
self.id.clone()
|
|
}
|
|
|
|
pub fn update<F, R>(&self, f: F) -> R
|
|
where
|
|
F: FnOnce(&mut T) -> R,
|
|
{
|
|
trace!("SectionWriter<T={}>::update(self: {self:?}, f)", type_name::<T>());
|
|
self.state_vector.sections.clear_poison();
|
|
let sections = self.state_vector.sections.read().unwrap();
|
|
let section = sections.get(&self.id).unwrap();
|
|
let mut data = section.write().unwrap();
|
|
let result = data.downcast_mut::<T>().unwrap();
|
|
f(result)
|
|
}
|
|
}
|
|
|
|
impl StateVector {
|
|
pub fn new() -> Self {
|
|
trace!("StateVector::new()");
|
|
Self {
|
|
next_section: AtomicUsize::new(0usize),
|
|
sections: RwLock::new(HashMap::new()),
|
|
}
|
|
}
|
|
|
|
pub fn create_section<T>(&self, initial_value: T) -> SectionWriter<'_, T>
|
|
where
|
|
T: Send + Sync + 'static,
|
|
{
|
|
trace!("StateVector::create_section<T={}>(self: {self:?}, initial_value)", type_name::<T>());
|
|
let id = SectionIdentifier(self.next_section.fetch_add(1usize, Ordering::SeqCst));
|
|
let lock = Box::new(RwLock::new(initial_value));
|
|
|
|
self.sections.clear_poison();
|
|
let mut sections = self.sections.write().unwrap();
|
|
if !sections.contains_key(&id) {
|
|
sections.insert(id.clone(), lock);
|
|
}
|
|
|
|
drop(sections);
|
|
|
|
SectionWriter {
|
|
id,
|
|
state_vector: &self,
|
|
_phantom_data: PhantomData,
|
|
}
|
|
}
|
|
|
|
pub fn access_section<T, F, R>(&self, id: &SectionIdentifier, f: F) -> Option<R>
|
|
where
|
|
T: 'static,
|
|
F: FnOnce(&T) -> R,
|
|
{
|
|
trace!("StateVector::access_section<T={}, F={}, R={}>(self: {self:?}, id: {id:?}, f)", type_name::<T>(), type_name::<F>(), type_name::<R>());
|
|
self.sections.clear_poison();
|
|
let Ok(sections) = self.sections.read() else { return None; };
|
|
let Some(section) = sections.get(id) else { return None; };
|
|
section.clear_poison();
|
|
let Ok(data) = section.read() else { return None; };
|
|
let Some(inner) = data.downcast_ref::<T>() else { return None; };
|
|
Some(f(inner))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::{Context, Result};

    #[derive(Default)]
    struct TestType {
        value1: i32,
        value2: i32,
    }

    /// Two sections of the same type on one state vector keep independent values.
    #[test]
    fn test_two_sections() -> Result<()> {
        let state_vector = StateVector::new();

        let section_1 = state_vector.create_section(TestType::default());
        let section_2 = state_vector.create_section(TestType::default());

        section_1.update(|s| {
            s.value1 = 1;
            s.value2 = 2;
        });

        section_2.update(|s| {
            s.value1 = 3;
            s.value2 = 4;
        });

        // `access_section` returns `None` without running the closure when the section is
        // missing or has another type. Propagate that as an error; ignoring the result would
        // let the assertions inside the closure be skipped silently.
        let id_1 = section_1.get_identifier();
        state_vector
            .access_section(&id_1, |s: &TestType| {
                assert_eq!(1, s.value1);
                assert_eq!(2, s.value2);
            })
            .context("section 1 was not readable as TestType")?;

        let id_2 = section_2.get_identifier();
        state_vector
            .access_section(&id_2, |s: &TestType| {
                assert_eq!(3, s.value1);
                assert_eq!(4, s.value2);
            })
            .context("section 2 was not readable as TestType")?;

        Ok(())
    }

    /// Reading with the wrong value type or an unknown identifier yields `None`.
    #[test]
    fn test_access_section_mismatch_returns_none() {
        let state_vector = StateVector::new();
        let section = state_vector.create_section(TestType::default());
        let id = section.get_identifier();

        assert!(state_vector.access_section(&id, |_: &i32| ()).is_none());
        assert!(state_vector
            .access_section(&SectionIdentifier(usize::MAX), |_: &TestType| ())
            .is_none());
    }
}
|