dfir_rs/scheduled/
context.rs1use std::any::Any;
6use std::cell::Cell;
7use std::collections::VecDeque;
8use std::future::Future;
9use std::marker::PhantomData;
10use std::ops::DerefMut;
11use std::pin::Pin;
12
13use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
14use tokio::task::JoinHandle;
15use web_time::SystemTime;
16
17use super::state::StateHandle;
18use super::{LoopTag, StateId, SubgraphId};
19use crate::scheduled::ticks::TickInstant;
20use crate::util::priority_stack::PriorityStack;
21use crate::util::slot_vec::SlotVec;
22
23pub struct Context {
29 states: Vec<StateData>,
31
32 pub(super) stratum_stack: PriorityStack<usize>,
34
35 pub(super) loop_nonce_stack: Vec<usize>,
37
38 pub(super) schedule_deferred: Vec<SubgraphId>,
41
42 pub(super) stratum_queues: Vec<VecDeque<SubgraphId>>,
45
46 pub(super) event_queue_recv: UnboundedReceiver<(SubgraphId, bool)>,
48 pub(super) can_start_tick: bool,
50 pub(super) events_received_tick: bool,
52
53 pub(super) event_queue_send: UnboundedSender<(SubgraphId, bool)>,
56
57 pub(super) reschedule_loop_block: Cell<bool>,
59 pub(super) allow_another_iteration: Cell<bool>,
60
61 pub(super) current_tick: TickInstant,
62 pub(super) current_stratum: usize,
63
64 pub(super) current_tick_start: SystemTime,
65 pub(super) is_first_run_this_tick: bool,
66 pub(super) loop_iter_count: usize,
67
68 pub(super) loop_depth: SlotVec<LoopTag, usize>,
70
71 pub(super) loop_nonce: usize,
72
73 pub(super) subgraph_id: SubgraphId,
76
77 tasks_to_spawn: Vec<Pin<Box<dyn Future<Output = ()> + 'static>>>,
78 task_join_handles: Vec<JoinHandle<()>>,
80}
81impl Context {
83 pub fn current_tick(&self) -> TickInstant {
85 self.current_tick
86 }
87
88 pub fn current_tick_start(&self) -> SystemTime {
90 self.current_tick_start
91 }
92
93 pub fn is_first_run_this_tick(&self) -> bool {
95 self.is_first_run_this_tick
96 }
97
98 pub fn loop_iter_count(&self) -> usize {
100 self.loop_iter_count
101 }
102
103 pub fn current_stratum(&self) -> usize {
105 self.current_stratum
106 }
107
108 pub fn current_subgraph(&self) -> SubgraphId {
110 self.subgraph_id
111 }
112
113 pub fn schedule_subgraph(&self, sg_id: SubgraphId, is_external: bool) {
119 self.event_queue_send.send((sg_id, is_external)).unwrap()
120 }
121
122 pub fn reschedule_loop_block(&self) {
124 self.reschedule_loop_block.set(true);
125 }
126
127 pub fn allow_another_iteration(&self) {
129 self.allow_another_iteration.set(true);
130 }
131
132 pub fn waker(&self) -> std::task::Waker {
135 use std::sync::Arc;
136
137 use futures::task::ArcWake;
138
139 struct ContextWaker {
140 subgraph_id: SubgraphId,
141 event_queue_send: UnboundedSender<(SubgraphId, bool)>,
142 }
143 impl ArcWake for ContextWaker {
144 fn wake_by_ref(arc_self: &Arc<Self>) {
145 let _recv_closed_error =
146 arc_self.event_queue_send.send((arc_self.subgraph_id, true));
147 }
148 }
149
150 let context_waker = ContextWaker {
151 subgraph_id: self.subgraph_id,
152 event_queue_send: self.event_queue_send.clone(),
153 };
154 futures::task::waker(Arc::new(context_waker))
155 }
156
157 pub unsafe fn state_ref_unchecked<T>(&self, handle: StateHandle<T>) -> &'_ T
162 where
163 T: Any,
164 {
165 let state = self
166 .states
167 .get(handle.state_id.0)
168 .expect("Failed to find state with given handle.")
169 .state
170 .as_ref();
171
172 debug_assert!(state.is::<T>());
173
174 unsafe {
175 &*(state as *const dyn Any as *const T)
178 }
179 }
180
181 pub fn state_ref<T>(&self, handle: StateHandle<T>) -> &'_ T
183 where
184 T: Any,
185 {
186 self.states
187 .get(handle.state_id.0)
188 .expect("Failed to find state with given handle.")
189 .state
190 .downcast_ref()
191 .expect("StateHandle wrong type T for casting.")
192 }
193
194 pub fn state_mut<T>(&mut self, handle: StateHandle<T>) -> &'_ mut T
196 where
197 T: Any,
198 {
199 self.states
200 .get_mut(handle.state_id.0)
201 .expect("Failed to find state with given handle.")
202 .state
203 .downcast_mut()
204 .expect("StateHandle wrong type T for casting.")
205 }
206
207 pub fn add_state<T>(&mut self, state: T) -> StateHandle<T>
209 where
210 T: Any,
211 {
212 let state_id = StateId(self.states.len());
213
214 let state_data = StateData {
215 state: Box::new(state),
216 tick_reset: None,
217 };
218 self.states.push(state_data);
219
220 StateHandle {
221 state_id,
222 _phantom: PhantomData,
223 }
224 }
225
226 pub fn set_state_tick_hook<T>(
228 &mut self,
229 handle: StateHandle<T>,
230 mut tick_hook_fn: impl 'static + FnMut(&mut T),
231 ) where
232 T: Any,
233 {
234 self.states
235 .get_mut(handle.state_id.0)
236 .expect("Failed to find state with given handle.")
237 .tick_reset = Some(Box::new(move |state| {
238 (tick_hook_fn)(state.downcast_mut::<T>().unwrap());
239 }));
240 }
241
242 pub fn remove_state<T>(&mut self, handle: StateHandle<T>) -> Box<T>
244 where
245 T: Any,
246 {
247 self.states
248 .remove(handle.state_id.0)
249 .state
250 .downcast()
251 .expect("StateHandle wrong type T for casting.")
252 }
253
254 pub fn request_task<Fut>(&mut self, future: Fut)
256 where
257 Fut: Future<Output = ()> + 'static,
258 {
259 self.tasks_to_spawn.push(Box::pin(future));
260 }
261
262 pub fn spawn_tasks(&mut self) {
264 for task in self.tasks_to_spawn.drain(..) {
265 self.task_join_handles.push(tokio::task::spawn_local(task));
266 }
267 }
268
269 pub fn abort_tasks(&mut self) {
271 for task in self.task_join_handles.drain(..) {
272 task.abort();
273 }
274 }
275
276 pub async fn join_tasks(&mut self) {
280 futures::future::join_all(self.task_join_handles.drain(..)).await;
281 }
282}
283
284impl Default for Context {
285 fn default() -> Self {
286 let stratum_queues = vec![Default::default()]; let (event_queue_send, event_queue_recv) = mpsc::unbounded_channel();
288 let (stratum_stack, loop_depth) = Default::default();
289 Self {
290 states: Vec::new(),
291
292 stratum_stack,
293
294 loop_nonce_stack: Vec::new(),
295
296 schedule_deferred: Vec::new(),
297
298 stratum_queues,
299 event_queue_recv,
300 can_start_tick: false,
301 events_received_tick: false,
302
303 event_queue_send,
304 reschedule_loop_block: Cell::new(false),
305 allow_another_iteration: Cell::new(false),
306
307 current_stratum: 0,
308 current_tick: TickInstant::default(),
309
310 current_tick_start: SystemTime::now(),
311 is_first_run_this_tick: false,
312 loop_iter_count: 0,
313
314 loop_depth,
315 loop_nonce: 0,
316
317 subgraph_id: SubgraphId::from_raw(0),
319
320 tasks_to_spawn: Vec::new(),
321 task_join_handles: Vec::new(),
322 }
323 }
324}
325impl Context {
327 pub(super) fn init_stratum(&mut self, stratum: usize) {
329 if self.stratum_queues.len() <= stratum {
330 self.stratum_queues
331 .resize_with(stratum + 1, Default::default);
332 }
333 }
334
335 pub(super) fn reset_state_at_end_of_tick(&mut self) {
337 for StateData { state, tick_reset } in self.states.iter_mut() {
338 if let Some(tick_reset) = tick_reset {
339 (tick_reset)(Box::deref_mut(state));
340 }
341 }
342 }
343}
344
/// One slot of user state: the type-erased value plus an optional tick-reset hook.
struct StateData {
    /// The stored value, type-erased; recovered by downcasting with the handle's `T`.
    state: Box<dyn Any>,
    /// Hook run on `state` by `reset_state_at_end_of_tick`, if installed.
    tick_reset: Option<TickResetFn>,
}
/// Boxed tick-reset hook; receives its state type-erased as `&mut dyn Any`.
type TickResetFn = Box<dyn FnMut(&mut dyn Any)>;