1use std::collections::HashMap;
4use std::fmt::{Debug, Display};
5use std::ops::{Bound, RangeBounds};
6use std::sync::OnceLock;
7
8use proc_macro2::{Ident, Literal, Span, TokenStream};
9use quote::quote_spanned;
10use serde::{Deserialize, Serialize};
11use slotmap::Key;
12use syn::punctuated::Punctuated;
13use syn::{Expr, Token, parse_quote_spanned};
14
15use super::{
16 GraphLoopId, GraphNode, GraphNodeId, GraphSubgraphId, OpInstGenerics, OperatorInstance,
17 PortIndexValue,
18};
19use crate::diagnostic::Diagnostic;
20use crate::parse::{Operator, PortIndex};
21
/// The kind of delay associated with an operator input.
///
/// An optional `DelayType` is returned per input port by
/// [`OperatorConstraints::input_delaytype_fn`].
///
/// The derived `PartialOrd`/`Ord` follow declaration order
/// (`Stratum < MonotoneAccum < Tick < TickLazy`), so reordering the variants changes comparisons.
// NOTE(review): the scheduling semantics of each variant are not visible in this file; the
// per-variant notes below are inferred from the names and should be confirmed.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum DelayType {
    /// Presumably a delay of one stratum — TODO confirm.
    Stratum,
    /// Presumably a monotone accumulation — TODO confirm.
    MonotoneAccum,
    /// Presumably a delay of one tick — TODO confirm.
    Tick,
    /// Presumably a lazy variant of [`DelayType::Tick`] — TODO confirm.
    TickLazy,
}
34
/// Specification of the port names of an operator, as produced by the optional
/// [`OperatorConstraints::ports_inn`] / [`OperatorConstraints::ports_out`] functions.
pub enum PortListSpec {
    /// No fixed port list. NOTE(review): presumably any number/naming of ports is accepted —
    /// confirm against the code that consumes this enum.
    Variadic,
    /// A fixed, comma-punctuated list of port indices.
    Fixed(Punctuated<PortIndex, Token![,]>),
}
42
/// Static description of one operator: its name, categories, arity constraints and code
/// generator.
///
/// All operators are collected in `OPERATORS` and can be looked up by name through
/// [`operator_lookup`].
// NOTE(review): the meaning of "hard" vs "soft" ranges and of several flags below is not visible
// in this file; the hedged notes are inferred from names/types and should be confirmed against
// the validation and code-generation passes.
pub struct OperatorConstraints {
    /// Operator's name. Used as the key of the map built by [`operator_lookup`].
    pub name: &'static str,
    /// Categories this operator is listed under.
    pub categories: &'static [OperatorCategory],

    /// Hard range constraint, presumably on the number of inputs — TODO confirm.
    pub hard_range_inn: &'static dyn RangeTrait<usize>,
    /// Soft range constraint, presumably on the number of inputs — TODO confirm.
    pub soft_range_inn: &'static dyn RangeTrait<usize>,
    /// Hard range constraint, presumably on the number of outputs — TODO confirm.
    pub hard_range_out: &'static dyn RangeTrait<usize>,
    /// Soft range constraint, presumably on the number of outputs — TODO confirm.
    pub soft_range_out: &'static dyn RangeTrait<usize>,
    /// Number of arguments. Presumably the exact count the operator accepts — TODO confirm.
    pub num_args: usize,
    /// Range constraint, presumably on the number of persistence generic arguments.
    pub persistence_args: &'static dyn RangeTrait<usize>,
    /// Range constraint, presumably on the number of type generic arguments.
    pub type_args: &'static dyn RangeTrait<usize>,
    /// Flag; presumably marks operators that receive data from outside the graph — TODO confirm.
    pub is_external_input: bool,
    /// Flag; presumably marks operators whose output can be referenced as a singleton — TODO
    /// confirm.
    pub has_singleton_output: bool,
    /// Optional [`FloType`] of this operator.
    pub flo_type: Option<FloType>,

    /// Optional function producing the [`PortListSpec`] for the input ports.
    pub ports_inn: Option<fn() -> PortListSpec>,
    /// Optional function producing the [`PortListSpec`] for the output ports.
    pub ports_out: Option<fn() -> PortListSpec>,

    /// Given a port index, returns the [`DelayType`] for that input, if any.
    pub input_delaytype_fn: fn(&PortIndexValue) -> Option<DelayType>,
    /// Code-generation function for this operator; see [`WriteFn`].
    pub write_fn: WriteFn,
}
87
/// Signature of an operator's code-generation function (see [`OperatorConstraints::write_fn`]).
///
/// Receives the per-operator [`WriteContextArgs`] and a mutable list to which [`Diagnostic`]s can
/// be pushed. Returns the generated [`OperatorWriteOutput`], or `Err(())` on failure
/// (presumably with the details reported through the diagnostics list — TODO confirm).
pub type WriteFn =
    fn(&WriteContextArgs<'_>, &mut Vec<Diagnostic>) -> Result<OperatorWriteOutput, ()>;
91
92impl Debug for OperatorConstraints {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 f.debug_struct("OperatorConstraints")
95 .field("name", &self.name)
96 .field("hard_range_inn", &self.hard_range_inn)
97 .field("soft_range_inn", &self.soft_range_inn)
98 .field("hard_range_out", &self.hard_range_out)
99 .field("soft_range_out", &self.soft_range_out)
100 .field("num_args", &self.num_args)
101 .field("persistence_args", &self.persistence_args)
102 .field("type_args", &self.type_args)
103 .field("is_external_input", &self.is_external_input)
104 .field("ports_inn", &self.ports_inn)
105 .field("ports_out", &self.ports_out)
106 .finish()
110 }
111}
112
/// Token streams produced by an operator's [`WriteFn`].
///
/// The struct is `#[non_exhaustive]` and implements `Default`, so it is built with struct-update
/// syntax (`OperatorWriteOutput { write_iterator, ..Default::default() }`), leaving unused
/// parts as empty token streams.
// NOTE(review): where each stream is spliced into the generated subgraph code is not visible in
// this file; the per-field notes are inferred from the names — confirm against the code
// generator.
#[derive(Default)]
#[non_exhaustive]
pub struct OperatorWriteOutput {
    /// Prologue code; presumably emitted once, before the per-run iterator code.
    pub write_prologue: TokenStream,
    /// Iterator (or pusherator) code. The built-in write fns in this file emit a
    /// `let #ident = ...;` binding here.
    pub write_iterator: TokenStream,
    /// Code presumably emitted after the iterator code has run.
    pub write_iterator_after: TokenStream,
}
130
131pub const RANGE_ANY: &'static dyn RangeTrait<usize> = &(0..);
133pub const RANGE_0: &'static dyn RangeTrait<usize> = &(0..=0);
135pub const RANGE_1: &'static dyn RangeTrait<usize> = &(1..=1);
137
138pub fn identity_write_iterator_fn(
141 &WriteContextArgs {
142 root,
143 op_span,
144 ident,
145 inputs,
146 outputs,
147 is_pull,
148 op_inst:
149 OperatorInstance {
150 generics: OpInstGenerics { type_args, .. },
151 ..
152 },
153 ..
154 }: &WriteContextArgs,
155) -> TokenStream {
156 let generic_type = type_args
157 .first()
158 .map(quote::ToTokens::to_token_stream)
159 .unwrap_or(quote_spanned!(op_span=> _));
160
161 if is_pull {
162 let input = &inputs[0];
163 quote_spanned! {op_span=>
164 let #ident = {
165 fn check_input<Iter: ::std::iter::Iterator<Item = Item>, Item>(iter: Iter) -> impl ::std::iter::Iterator<Item = Item> { iter }
166 check_input::<_, #generic_type>(#input)
167 };
168 }
169 } else {
170 let output = &outputs[0];
171 quote_spanned! {op_span=>
172 let #ident = {
173 fn check_output<Push: #root::pusherator::Pusherator<Item = Item>, Item>(push: Push) -> impl #root::pusherator::Pusherator<Item = Item> { push }
174 check_output::<_, #generic_type>(#output)
175 };
176 }
177 }
178}
179
180pub const IDENTITY_WRITE_FN: WriteFn = |write_context_args, _| {
182 let write_iterator = identity_write_iterator_fn(write_context_args);
183 Ok(OperatorWriteOutput {
184 write_iterator,
185 ..Default::default()
186 })
187};
188
189pub fn null_write_iterator_fn(
192 &WriteContextArgs {
193 root,
194 op_span,
195 ident,
196 inputs,
197 outputs,
198 is_pull,
199 op_inst:
200 OperatorInstance {
201 generics: OpInstGenerics { type_args, .. },
202 ..
203 },
204 ..
205 }: &WriteContextArgs,
206) -> TokenStream {
207 let default_type = parse_quote_spanned! {op_span=> _};
208 let iter_type = type_args.first().unwrap_or(&default_type);
209
210 if is_pull {
211 quote_spanned! {op_span=>
212 #(
213 #inputs.for_each(std::mem::drop);
214 )*
215 let #ident = std::iter::empty::<#iter_type>();
216 }
217 } else {
218 quote_spanned! {op_span=>
219 #[allow(clippy::let_unit_value)]
220 let _ = (#(#outputs),*);
221 let #ident = #root::pusherator::null::Null::<#iter_type>::new();
222 }
223 }
224}
225
226pub const NULL_WRITE_FN: WriteFn = |write_context_args, _| {
229 let write_iterator = null_write_iterator_fn(write_context_args);
230 Ok(OperatorWriteOutput {
231 write_iterator,
232 ..Default::default()
233 })
234};
235
/// Declares one `pub(crate)` submodule per listed operator and gathers each module's
/// constraints constant into the `OPERATORS` slice, in the order listed.
///
/// Each entry has the form `module_name::CONSTANT_NAME,` (trailing comma required).
macro_rules! declare_ops {
    ( $( $module:ident :: $constraints:ident, )* ) => {
        $( pub(crate) mod $module; )*
        /// The [`OperatorConstraints`] of every declared operator.
        pub const OPERATORS: &[OperatorConstraints] = &[
            $( $module :: $constraints, )*
        ];
    };
}
// Registry of all operators. Each `module::CONSTANT` entry declares the submodule `module` and
// adds `module::CONSTANT` to `OPERATORS`; the list order is the order of `OPERATORS`.
declare_ops![
    all_iterations::ALL_ITERATIONS,
    all_once::ALL_ONCE,
    anti_join::ANTI_JOIN,
    anti_join_multiset::ANTI_JOIN_MULTISET,
    assert::ASSERT,
    assert_eq::ASSERT_EQ,
    batch::BATCH,
    chain::CHAIN,
    _counter::_COUNTER,
    cross_join::CROSS_JOIN,
    cross_join_multiset::CROSS_JOIN_MULTISET,
    cross_singleton::CROSS_SINGLETON,
    demux::DEMUX,
    demux_enum::DEMUX_ENUM,
    dest_file::DEST_FILE,
    dest_sink::DEST_SINK,
    dest_sink_serde::DEST_SINK_SERDE,
    difference::DIFFERENCE,
    difference_multiset::DIFFERENCE_MULTISET,
    enumerate::ENUMERATE,
    filter::FILTER,
    filter_map::FILTER_MAP,
    flat_map::FLAT_MAP,
    flatten::FLATTEN,
    fold::FOLD,
    for_each::FOR_EACH,
    identity::IDENTITY,
    initialize::INITIALIZE,
    inspect::INSPECT,
    join::JOIN,
    join_fused::JOIN_FUSED,
    join_fused_lhs::JOIN_FUSED_LHS,
    join_fused_rhs::JOIN_FUSED_RHS,
    join_multiset::JOIN_MULTISET,
    fold_keyed::FOLD_KEYED,
    reduce_keyed::REDUCE_KEYED,
    repeat_n::REPEAT_N,
    lattice_bimorphism::LATTICE_BIMORPHISM,
    _lattice_fold_batch::_LATTICE_FOLD_BATCH,
    lattice_fold::LATTICE_FOLD,
    _lattice_join_fused_join::_LATTICE_JOIN_FUSED_JOIN,
    lattice_reduce::LATTICE_REDUCE,
    map::MAP,
    union::UNION,
    multiset_delta::MULTISET_DELTA,
    next_iteration::NEXT_ITERATION,
    next_stratum::NEXT_STRATUM,
    defer_signal::DEFER_SIGNAL,
    defer_tick::DEFER_TICK,
    defer_tick_lazy::DEFER_TICK_LAZY,
    null::NULL,
    partition::PARTITION,
    persist::PERSIST,
    persist_mut::PERSIST_MUT,
    persist_mut_keyed::PERSIST_MUT_KEYED,
    prefix::PREFIX,
    py_udf::PY_UDF,
    reduce::REDUCE,
    spin::SPIN,
    sort::SORT,
    sort_by_key::SORT_BY_KEY,
    source_file::SOURCE_FILE,
    source_interval::SOURCE_INTERVAL,
    source_iter::SOURCE_ITER,
    source_json::SOURCE_JSON,
    source_stdin::SOURCE_STDIN,
    source_stream::SOURCE_STREAM,
    source_stream_serde::SOURCE_STREAM_SERDE,
    state::STATE,
    state_by::STATE_BY,
    tee::TEE,
    unique::UNIQUE,
    unzip::UNZIP,
    zip::ZIP,
    zip_longest::ZIP_LONGEST,
];
323
324pub fn operator_lookup() -> &'static HashMap<&'static str, &'static OperatorConstraints> {
326 pub static OPERATOR_LOOKUP: OnceLock<HashMap<&'static str, &'static OperatorConstraints>> =
327 OnceLock::new();
328 OPERATOR_LOOKUP.get_or_init(|| OPERATORS.iter().map(|op| (op.name, op)).collect())
329}
330pub fn find_node_op_constraints(node: &GraphNode) -> Option<&'static OperatorConstraints> {
332 if let GraphNode::Operator(operator) = node {
333 find_op_op_constraints(operator)
334 } else {
335 None
336 }
337}
338pub fn find_op_op_constraints(operator: &Operator) -> Option<&'static OperatorConstraints> {
340 let name = &*operator.name_string();
341 operator_lookup().get(name).copied()
342}
343
/// Arguments handed to an operator's [`WriteFn`] when generating its code.
// NOTE(review): several fields are not used anywhere in this file; their notes are inferred from
// names/types and marked "presumably" — confirm against the code generator that fills this in.
#[derive(Clone)]
pub struct WriteContextArgs<'a> {
    /// Tokens for the root path of the runtime crate; spliced in front of paths such as
    /// `#root::pusherator::...` in generated code.
    pub root: &'a TokenStream,
    /// Identifier of the context variable; presumably available to generated code.
    pub context: &'a Ident,
    /// Identifier of the dataflow instance variable; presumably available to generated code.
    pub df_ident: &'a Ident,
    /// Id of the subgraph this operator belongs to. Part of the names built by
    /// [`WriteContextArgs::make_ident`].
    pub subgraph_id: GraphSubgraphId,
    /// Id of this operator's node. Part of the names built by
    /// [`WriteContextArgs::make_ident`].
    pub node_id: GraphNodeId,
    /// Id of the enclosing loop, if any.
    pub loop_id: Option<GraphLoopId>,
    /// Span of the operator; used as the span of generated tokens and identifiers.
    pub op_span: Span,
    /// Optional tag for the operator.
    pub op_tag: Option<String>,
    /// Identifier of a work function; presumably a wrapper generated code may call — TODO
    /// confirm.
    pub work_fn: &'a Ident,

    /// Identifier the operator's generated code must bind (`let #ident = ...;`).
    pub ident: &'a Ident,
    /// `true` to generate pull-based (iterator) code, `false` for push-based (pusherator) code.
    pub is_pull: bool,
    /// Identifiers of the operator's inputs.
    pub inputs: &'a [Ident],
    /// Identifiers of the operator's outputs.
    pub outputs: &'a [Ident],
    /// Identifier for the operator's singleton output; presumably only meaningful when
    /// [`OperatorConstraints::has_singleton_output`] is set — TODO confirm.
    pub singleton_output_ident: &'a Ident,

    /// Name of the operator.
    pub op_name: &'static str,
    /// The operator instance, including its generic arguments.
    pub op_inst: &'a OperatorInstance,
    /// Argument expressions of the operator.
    pub arguments: &'a Punctuated<Expr, Token![,]>,
    /// Argument expressions; presumably a variant of `arguments` with handle references
    /// resolved — TODO confirm.
    pub arguments_handles: &'a Punctuated<Expr, Token![,]>,
}
393impl WriteContextArgs<'_> {
394 pub fn make_ident(&self, suffix: impl AsRef<str>) -> Ident {
400 Ident::new(
401 &format!(
402 "sg_{:?}_node_{:?}_{}",
403 self.subgraph_id.data(),
404 self.node_id.data(),
405 suffix.as_ref(),
406 ),
407 self.op_span,
408 )
409 }
410}
411
/// Object-safe counterpart of [`std::ops::RangeBounds`], additionally requiring
/// `Send + Sync + Debug`, so that ranges can be stored as `&'static dyn RangeTrait<T>`.
pub trait RangeTrait<T>: Send + Sync + Debug
where
    T: ?Sized,
{
    /// Lower bound of the range.
    fn start_bound(&self) -> Bound<&T>;
    /// Upper bound of the range.
    fn end_bound(&self) -> Bound<&T>;
    /// Returns `true` if `item` lies within the range.
    fn contains(&self, item: &T) -> bool
    where
        T: PartialOrd<T>;

    /// Describes the range in English, e.g. `"exactly 1"` or `"at least 2 and less than 5"`.
    ///
    /// The result is phrased so that it can be followed by a noun. A fully unbounded range
    /// yields `"any number of"`.
    fn human_string(&self) -> String
    where
        T: Display + PartialEq,
    {
        use std::ops::Bound::{Excluded, Included, Unbounded};

        let lower = self.start_bound();
        let upper = self.end_bound();
        match (lower, upper) {
            // No lower bound.
            (Unbounded, Unbounded) => String::from("any number of"),
            (Unbounded, Included(hi)) => format!("at most {}", hi),
            (Unbounded, Excluded(hi)) => format!("less than {}", hi),

            // Inclusive lower bound; the degenerate `n..=n` case must be tested first.
            (Included(lo), Included(hi)) if lo == hi => format!("exactly {}", lo),
            (Included(lo), Included(hi)) => format!("at least {} and at most {}", lo, hi),
            (Included(lo), Excluded(hi)) => format!("at least {} and less than {}", lo, hi),
            (Included(lo), Unbounded) => format!("at least {}", lo),

            // Exclusive lower bound.
            (Excluded(lo), Included(hi)) => format!("more than {} and at most {}", lo, hi),
            (Excluded(lo), Excluded(hi)) => format!("more than {} and less than {}", lo, hi),
            (Excluded(lo), Unbounded) => format!("more than {}", lo),
        }
    }
}
456
457impl<R, T> RangeTrait<T> for R
458where
459 R: RangeBounds<T> + Send + Sync + Debug,
460{
461 fn start_bound(&self) -> Bound<&T> {
462 self.start_bound()
463 }
464
465 fn end_bound(&self) -> Bound<&T> {
466 self.end_bound()
467 }
468
469 fn contains(&self, item: &T) -> bool
470 where
471 T: PartialOrd<T>,
472 {
473 self.contains(item)
474 }
475}
476
/// Persistence setting of an operator (serializable via serde).
///
/// The derived `PartialOrd`/`Ord` follow declaration order
/// (`None < Tick < Static < Mutable`), so reordering the variants changes comparisons.
// NOTE(review): the exact lifetime semantics of each variant are not visible in this file; the
// per-variant notes are inferred from the names — confirm against the operators using them.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug, Serialize, Deserialize)]
pub enum Persistence {
    /// Presumably no persistence — TODO confirm.
    None,
    /// Presumably state persists for a single tick — TODO confirm.
    Tick,
    /// Presumably state persists across ticks (`'static`) — TODO confirm.
    Static,
    /// Presumably persistent state that may also be mutated — TODO confirm.
    Mutable,
}
489
490fn make_missing_runtime_msg(op_name: &str) -> Literal {
492 Literal::string(&format!(
493 "`{}()` must be used within a Tokio runtime. For example, use `#[dfir_rs::main]` on your main method.",
494 op_name
495 ))
496}
497
/// Categories operators are grouped under; an operator lists its categories in
/// `OperatorConstraints::categories`.
///
/// Human-readable text for each category comes from [`OperatorCategory::name`] and
/// [`OperatorCategory::description`].
#[allow(
    clippy::allow_attributes,
    missing_docs,
    reason = "see `Self::description`"
)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum OperatorCategory {
    Map,
    Filter,
    Flatten,
    Fold,
    KeyedFold,
    LatticeFold,
    Persistence,
    MultiIn,
    MultiOut,
    Source,
    Sink,
    Control,
    CompilerFusionOperator,
    Windowing,
    Unwindowing,
}

impl OperatorCategory {
    /// Human-readable heading for the category.
    pub fn name(self) -> &'static str {
        match self {
            Self::Map => "Maps",
            Self::Filter => "Filters",
            Self::Flatten => "Flattens",
            Self::Fold => "Folds",
            Self::KeyedFold => "Keyed Folds",
            Self::LatticeFold => "Lattice Folds",
            Self::Persistence => "Persistent Operators",
            Self::MultiIn => "Multi-Input Operators",
            Self::MultiOut => "Multi-Output Operators",
            Self::Source => "Sources",
            Self::Sink => "Sinks",
            Self::Control => "Control Flow Operators",
            Self::CompilerFusionOperator => "Compiler Fusion Operators",
            Self::Windowing => "Windowing Operator",
            Self::Unwindowing => "Un-Windowing Operator",
        }
    }

    /// One-sentence, human-readable description of the category.
    pub fn description(self) -> &'static str {
        match self {
            Self::Map => "Simple one-in-one-out operators.",
            Self::Filter => "One-in zero-or-one-out operators.",
            Self::Flatten => "One-in multiple-out operators.",
            Self::Fold => "Operators which accumulate elements together.",
            Self::KeyedFold => "Keyed fold operators.",
            Self::LatticeFold => "Folds based on lattice-merge.",
            Self::Persistence => "Persistent (stateful) operators.",
            Self::MultiIn => "Operators with multiple inputs.",
            Self::MultiOut => "Operators with multiple outputs.",
            Self::Source => "Operators which produce output elements (and consume no inputs).",
            Self::Sink => "Operators which consume input elements (and produce no outputs).",
            Self::Control => "Operators which affect control flow/scheduling.",
            Self::CompilerFusionOperator => {
                "Operators which are necessary to implement certain optimizations and rewrite rules"
            }
            Self::Windowing => "Operators for windowing `loop` inputs.",
            Self::Unwindowing => "Operators for collecting `loop` outputs.",
        }
    }
}
572
/// Optional classification of an operator, stored in `OperatorConstraints::flo_type`.
///
/// The derived `PartialOrd`/`Ord` follow declaration order, so reordering the variants changes
/// comparisons.
// NOTE(review): how each variant affects `loop` handling is not visible in this file; the
// per-variant notes are inferred from the names and from the matching `OperatorCategory`
// descriptions — confirm against the graph-building code.
#[derive(Clone, Copy, PartialOrd, Ord, PartialEq, Eq, Debug)]
pub enum FloType {
    /// Presumably a source operator — TODO confirm.
    Source,
    /// Presumably an operator that windows data entering a `loop` — TODO confirm.
    Windowing,
    /// Presumably an operator that collects data leaving a `loop` — TODO confirm.
    Unwindowing,
    /// Presumably an operator that passes data to the next loop iteration — TODO confirm.
    NextIteration,
}