initial state vector implementation

This commit is contained in:
2025-10-25 16:11:39 -07:00
parent b067ae5cec
commit e0f17649b2
8 changed files with 244 additions and 28 deletions

View File

@@ -0,0 +1,122 @@
use std::any::Any;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::RwLock;
pub struct StateVector {
next_section: AtomicUsize,
sections: RwLock<HashMap<SectionIdentifier, Box<RwLock<dyn Any + Send + Sync>>>>,
}
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct SectionIdentifier(usize);
pub struct SectionWriter<'a, T> {
id: SectionIdentifier,
state_vector: &'a StateVector,
_phantom_data: PhantomData<T>,
}
impl<'a, T: 'static> SectionWriter<'a, T> {
pub fn get_identifier(&self) -> SectionIdentifier {
self.id.clone()
}
pub fn update<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut T) -> R,
{
self.state_vector.sections.clear_poison();
let sections = self.state_vector.sections.read().unwrap();
let section = sections.get(&self.id).unwrap();
let mut data = section.write().unwrap();
let result = data.downcast_mut::<T>().unwrap();
f(result)
}
}
impl StateVector {
pub fn new() -> Self {
Self {
next_section: AtomicUsize::new(0usize),
sections: RwLock::new(HashMap::new()),
}
}
pub fn create_section<T>(&self, initial_value: T) -> SectionWriter<'_, T>
where
T: Send + Sync + 'static,
{
let id = SectionIdentifier(self.next_section.fetch_add(1usize, Ordering::SeqCst));
let lock = Box::new(RwLock::new(initial_value));
self.sections.clear_poison();
self.sections.write().unwrap().insert(id.clone(), lock);
SectionWriter {
id,
state_vector: &self,
_phantom_data: PhantomData,
}
}
pub fn access_section<T, F, R>(&self, id: &SectionIdentifier, f: F) -> Option<R>
where
T: 'static,
F: FnOnce(&T) -> R,
{
self.sections.clear_poison();
let Ok(sections) = self.sections.read() else { return None; };
let Some(section) = sections.get(id) else { return None; };
section.clear_poison();
let Ok(data) = section.read() else { return None; };
let Some(inner) = data.downcast_ref::<T>() else { return None; };
Some(f(inner))
}
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
#[derive(Default)]
struct TestType {
value1: i32,
value2: i32,
}
#[test]
fn test_two_sections() -> Result<()> {
let state_vector = StateVector::new();
let section_1 = state_vector.create_section(TestType::default());
let section_2 = state_vector.create_section(TestType::default());
section_1.update(|s| {
s.value1 = 1;
s.value2 = 2;
});
section_2.update(|s| {
s.value1 = 3;
s.value2 = 4;
});
let id_1 = section_1.get_identifier();
state_vector.access_section(&id_1, |s: &TestType| {
assert_eq!(1, s.value1);
assert_eq!(2, s.value2);
});
let id_2 = section_2.get_identifier();
state_vector.access_section(&id_2, |s: &TestType| {
assert_eq!(3, s.value1);
assert_eq!(4, s.value2);
});
Ok(())
}
}