// Files
// ProjectNautilus/flight/src/state_vector/mod.rs
// 2025-10-26 08:56:59 -07:00
//
// 142 lines
// 4.1 KiB
// Rust

use log::trace;
use std::any::{type_name, Any};
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::RwLock;
/// A registry of independently locked, type-erased sections of shared state.
///
/// Sections are added with [`StateVector::create_section`], which returns a typed
/// [`SectionWriter`]. Anyone holding a section's [`SectionIdentifier`] can read it through
/// [`StateVector::access_section`].
#[derive(Debug)]
pub struct StateVector {
    // Source of the next `SectionIdentifier`; advanced with `fetch_add` on every new section.
    next_section: AtomicUsize,
    // All sections keyed by identifier. The outer lock guards the map itself; each value has
    // its own lock, and writers only take the outer lock for reading, so different sections
    // can be updated at the same time.
    sections: RwLock<HashMap<SectionIdentifier, Box<RwLock<dyn Any + Send + Sync>>>>,
}
/// Opaque handle naming one section of a [`StateVector`].
///
/// Identifiers are allocated by [`StateVector::create_section`]. They wrap a single `usize`,
/// so they are `Copy` and can be passed around freely.
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
pub struct SectionIdentifier(usize);
/// Typed write handle for a single section of a [`StateVector`], returned by
/// [`StateVector::create_section`].
///
/// The writer borrows the state vector for `'a`. It records the section's value type `T` only
/// at the type level, which is what lets [`SectionWriter::update`] downcast to `T`.
pub struct SectionWriter<'a, T> {
    // Identifier of the section this writer updates.
    id: SectionIdentifier,
    // The state vector that owns the section.
    state_vector: &'a StateVector,
    // Ties the writer to `T` without storing a value of that type.
    _phantom_data: PhantomData<T>,
}
impl<T> Debug for SectionWriter<'_, T> {
    /// Formats as `SectionWriter<T=...> { id: ..., state_vector: ... }`.
    ///
    /// `debug_struct` produces the same single-line `{:?}` output as before. Unlike a
    /// hand-written `write!`, it also honours formatter flags, so `{:#?}` pretty-prints the
    /// way derived `Debug` impls do.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = format!("SectionWriter<T={}>", type_name::<T>());
        f.debug_struct(&name)
            .field("id", &self.id)
            .field("state_vector", self.state_vector)
            .finish()
    }
}
impl<T: 'static> SectionWriter<'_, T> {
    /// Returns the identifier readers pass to [`StateVector::access_section`] for this section.
    pub fn get_identifier(&self) -> SectionIdentifier {
        trace!("SectionWriter<T={}>::get_identifier(self: {self:?})", type_name::<T>());
        self.id.clone()
    }

    /// Runs `f` with exclusive access to this section's value and returns its result.
    ///
    /// The section map is read-locked and the section itself write-locked for the whole call.
    /// `f` must therefore not call [`StateVector::create_section`] on the same state vector,
    /// because that needs the map's write lock. Updating the same section again from inside
    /// `f` is also not allowed. Either would block on a lock this call already holds.
    ///
    /// Poisoned locks are recovered, matching [`StateVector::access_section`]. A panic inside
    /// an earlier `f` therefore does not make the section permanently unwritable, but the
    /// value may hold whatever that closure left behind.
    ///
    /// # Panics
    ///
    /// Panics only if the writer's invariants are broken: its section is missing from the map
    /// or does not hold a `T`. `create_section` always inserts a `T` under the writer's
    /// identifier, and this module never removes sections.
    pub fn update<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut T) -> R,
    {
        trace!("SectionWriter<T={}>::update(self: {self:?}, f)", type_name::<T>());
        self.state_vector.sections.clear_poison();
        // `into_inner` also covers the window where another thread re-poisons the lock
        // between `clear_poison` and acquisition.
        let sections = self
            .state_vector
            .sections
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let section = sections
            .get(&self.id)
            .expect("SectionWriter refers to a section inserted by create_section");
        // Without this, one panicking update would poison the section and make every later
        // update panic.
        section.clear_poison();
        let mut data = section
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let value = data
            .downcast_mut::<T>()
            .expect("SectionWriter<T> refers to a section created with a value of type T");
        f(value)
    }
}
impl StateVector {
    /// Creates an empty state vector with no sections.
    #[must_use]
    pub fn new() -> Self {
        trace!("StateVector::new()");
        Self {
            next_section: AtomicUsize::new(0usize),
            sections: RwLock::new(HashMap::new()),
        }
    }

    /// Adds a new section holding `initial_value` and returns the writer for it.
    ///
    /// Each call takes a fresh [`SectionIdentifier`] from an atomic counter. Readers get the
    /// identifier from [`SectionWriter::get_identifier`] and read through
    /// [`StateVector::access_section`]. A poisoned section map is recovered rather than
    /// panicking.
    #[must_use = "dropping the writer leaves the new section impossible to update or locate"]
    pub fn create_section<T>(&self, initial_value: T) -> SectionWriter<'_, T>
    where
        T: Send + Sync + 'static,
    {
        trace!("StateVector::create_section<T={}>(self: {self:?}, initial_value)", type_name::<T>());
        let id = SectionIdentifier(self.next_section.fetch_add(1usize, Ordering::SeqCst));
        let lock: Box<RwLock<dyn Any + Send + Sync>> = Box::new(RwLock::new(initial_value));
        self.sections.clear_poison();
        {
            let mut sections = self
                .sections
                .write()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            // Identifiers come from `fetch_add`, so this slot is expected to be vacant.
            // `or_insert` keeps the rule of never overwriting an existing section, with a
            // single map lookup.
            sections.entry(id.clone()).or_insert(lock);
        }
        SectionWriter {
            id,
            state_vector: self,
            _phantom_data: PhantomData,
        }
    }

    /// Runs `f` with shared access to the section `id` and returns its result.
    ///
    /// Returns `None` without calling `f` when no section has this identifier, or when the
    /// section does not hold a `T`. Poisoned locks are cleared and recovered. Both the map and
    /// the section are read-locked while `f` runs, so `f` must not create sections or update
    /// this same section.
    pub fn access_section<T, F, R>(&self, id: &SectionIdentifier, f: F) -> Option<R>
    where
        T: 'static,
        F: FnOnce(&T) -> R,
    {
        trace!("StateVector::access_section<T={}, F={}, R={}>(self: {self:?}, id: {id:?}, f)", type_name::<T>(), type_name::<F>(), type_name::<R>());
        self.sections.clear_poison();
        // `into_inner` also covers a re-poison between `clear_poison` and acquisition.
        let sections = self
            .sections
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let section = sections.get(id)?;
        section.clear_poison();
        let data = section
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let inner = data.downcast_ref::<T>()?;
        Some(f(inner))
    }
}

impl Default for StateVector {
    /// Equivalent to [`StateVector::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;

    #[derive(Default)]
    struct TestType {
        value1: i32,
        value2: i32,
    }

    /// Reads both fields of a `TestType` section, or `None` if the lookup fails.
    fn read_pair(state_vector: &StateVector, id: &SectionIdentifier) -> Option<(i32, i32)> {
        state_vector.access_section(id, |s: &TestType| (s.value1, s.value2))
    }

    #[test]
    fn test_two_sections() -> Result<()> {
        let state_vector = StateVector::new();
        let section_1 = state_vector.create_section(TestType::default());
        let section_2 = state_vector.create_section(TestType::default());
        section_1.update(|s| {
            s.value1 = 1;
            s.value2 = 2;
        });
        section_2.update(|s| {
            s.value1 = 3;
            s.value2 = 4;
        });
        // Assert on the returned value. An assertion inside the closure would pass silently
        // if `access_section` returned `None` without running the closure.
        let id_1 = section_1.get_identifier();
        assert_eq!(Some((1, 2)), read_pair(&state_vector, &id_1));
        let id_2 = section_2.get_identifier();
        assert_eq!(Some((3, 4)), read_pair(&state_vector, &id_2));
        Ok(())
    }

    #[test]
    fn test_access_section_mismatch_returns_none() {
        let state_vector = StateVector::new();
        let section = state_vector.create_section(TestType::default());
        let id = section.get_identifier();
        // Existing section, but read as the wrong type.
        assert_eq!(None, state_vector.access_section(&id, |v: &u8| *v));
        // Identifier that was never handed out.
        assert_eq!(None, read_pair(&state_vector, &SectionIdentifier(usize::MAX)));
    }

    #[test]
    fn test_update_returns_closure_result() {
        let state_vector = StateVector::new();
        let section = state_vector.create_section(TestType::default());
        let doubled = section.update(|s| {
            s.value1 = 21;
            s.value1 * 2
        });
        assert_eq!(42, doubled);
        assert_eq!(Some((21, 0)), read_pair(&state_vector, &section.get_identifier()));
    }
}