1#![warn(missing_docs)]
2pub mod clear;
5#[cfg(feature = "dfir_macro")]
6pub mod demux_enum;
7pub mod monotonic_map;
8pub mod multiset;
9pub mod priority_stack;
10pub mod slot_vec;
11pub mod sparse_vec;
12pub mod unsync;
13
14pub mod simulation;
15
16mod monotonic;
17pub use monotonic::*;
18
19mod udp;
20#[cfg(not(target_arch = "wasm32"))]
21pub use udp::*;
22
23mod tcp;
24#[cfg(not(target_arch = "wasm32"))]
25pub use tcp::*;
26
27#[cfg(unix)]
28mod socket;
29#[cfg(unix)]
30pub use socket::*;
31
32#[cfg(feature = "deploy_integration")]
33pub mod deploy;
34
35use std::io::Read;
36use std::net::SocketAddr;
37use std::num::NonZeroUsize;
38use std::process::{Child, ChildStdin, ChildStdout, Stdio};
39use std::task::{Context, Poll};
40
41use futures::Stream;
42use serde::de::DeserializeOwned;
43use serde::ser::Serialize;
44
/// A value tagged with whether it should be persisted or deleted.
///
/// How each variant is acted upon is decided by whatever consumes this type.
pub enum Persistence<T> {
    /// Request that the contained value be persisted.
    Persist(T),
    /// Request that the contained value be deleted.
    Delete(T),
}
52
/// A keyed persistence instruction: persist a key/value pair, or delete by key.
///
/// How each variant is acted upon is decided by whatever consumes this type.
pub enum PersistenceKeyed<K, V> {
    /// Request that the value `V` be persisted under the key `K`.
    Persist(K, V),
    /// Request that the entry for the key `K` be deleted.
    Delete(K),
}
60
61pub fn unbounded_channel<T>() -> (
63 tokio::sync::mpsc::UnboundedSender<T>,
64 tokio_stream::wrappers::UnboundedReceiverStream<T>,
65) {
66 let (send, recv) = tokio::sync::mpsc::unbounded_channel();
67 let recv = tokio_stream::wrappers::UnboundedReceiverStream::new(recv);
68 (send, recv)
69}
70
71pub fn unsync_channel<T>(
73 capacity: Option<NonZeroUsize>,
74) -> (unsync::mpsc::Sender<T>, unsync::mpsc::Receiver<T>) {
75 unsync::mpsc::channel(capacity)
76}
77
78pub fn ready_iter<S>(stream: S) -> impl Iterator<Item = S::Item>
80where
81 S: Stream,
82{
83 let mut stream = Box::pin(stream);
84 std::iter::from_fn(move || {
85 match stream
86 .as_mut()
87 .poll_next(&mut Context::from_waker(futures::task::noop_waker_ref()))
88 {
89 Poll::Ready(opt) => opt,
90 Poll::Pending => None,
91 }
92 })
93}
94
95pub fn collect_ready<C, S>(stream: S) -> C
100where
101 C: FromIterator<S::Item>,
102 S: Stream,
103{
104 assert!(
105 tokio::runtime::Handle::try_current().is_err(),
106 "Calling `collect_ready` from an async runtime may cause incorrect results, use `collect_ready_async` instead."
107 );
108 ready_iter(stream).collect()
109}
110
111pub async fn collect_ready_async<C, S>(stream: S) -> C
116where
117 C: Default + Extend<S::Item>,
118 S: Stream,
119{
120 use std::sync::atomic::Ordering;
121
122 tokio::task::yield_now().await;
124
125 let got_any_items = std::sync::atomic::AtomicBool::new(true);
126 let mut unfused_iter =
127 ready_iter(stream).inspect(|_| got_any_items.store(true, Ordering::Relaxed));
128 let mut out = C::default();
129 while got_any_items.swap(false, Ordering::Relaxed) {
130 out.extend(unfused_iter.by_ref());
131 tokio::task::yield_now().await;
134 }
135 out
136}
137
138pub fn serialize_to_bytes<T>(msg: T) -> bytes::Bytes
140where
141 T: Serialize,
142{
143 bytes::Bytes::from(bincode::serialize(&msg).unwrap())
144}
145
146pub fn deserialize_from_bytes<T>(msg: impl AsRef<[u8]>) -> bincode::Result<T>
148where
149 T: DeserializeOwned,
150{
151 bincode::deserialize(msg.as_ref())
152}
153
/// Resolves `addr` (e.g. `"127.0.0.1:8080"` or `"host:port"`) to its first IPv4 socket address.
///
/// # Errors
/// Propagates any error from address resolution. Returns an error with the message
/// `"Unable to resolve IPv4 address"` if resolution succeeds but yields no IPv4 address.
pub fn ipv4_resolve(addr: &str) -> Result<SocketAddr, std::io::Error> {
    use std::net::ToSocketAddrs;
    addr.to_socket_addrs()?
        .find(SocketAddr::is_ipv4)
        .ok_or_else(|| std::io::Error::other("Unable to resolve IPv4 address"))
}
164
165#[cfg(not(target_arch = "wasm32"))]
168pub async fn bind_udp_bytes(addr: SocketAddr) -> (UdpSink, UdpStream, SocketAddr) {
169 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
170 udp_bytes(socket)
171}
172
173#[cfg(not(target_arch = "wasm32"))]
176pub async fn bind_udp_lines(addr: SocketAddr) -> (UdpLinesSink, UdpLinesStream, SocketAddr) {
177 let socket = tokio::net::UdpSocket::bind(addr).await.unwrap();
178 udp_lines(socket)
179}
180
181#[cfg(not(target_arch = "wasm32"))]
188pub async fn bind_tcp_bytes(
189 addr: SocketAddr,
190) -> (
191 unsync::mpsc::Sender<(bytes::Bytes, SocketAddr)>,
192 unsync::mpsc::Receiver<Result<(bytes::BytesMut, SocketAddr), std::io::Error>>,
193 SocketAddr,
194) {
195 bind_tcp(addr, tokio_util::codec::LengthDelimitedCodec::new())
196 .await
197 .unwrap()
198}
199
200#[cfg(not(target_arch = "wasm32"))]
202pub async fn bind_tcp_lines(
203 addr: SocketAddr,
204) -> (
205 unsync::mpsc::Sender<(String, SocketAddr)>,
206 unsync::mpsc::Receiver<Result<(String, SocketAddr), tokio_util::codec::LinesCodecError>>,
207 SocketAddr,
208) {
209 bind_tcp(addr, tokio_util::codec::LinesCodec::new())
210 .await
211 .unwrap()
212}
213
214#[cfg(not(target_arch = "wasm32"))]
219pub fn connect_tcp_bytes() -> (
220 TcpFramedSink<bytes::Bytes>,
221 TcpFramedStream<tokio_util::codec::LengthDelimitedCodec>,
222) {
223 connect_tcp(tokio_util::codec::LengthDelimitedCodec::new())
224}
225
226#[cfg(not(target_arch = "wasm32"))]
228pub fn connect_tcp_lines() -> (
229 TcpFramedSink<String>,
230 TcpFramedStream<tokio_util::codec::LinesCodec>,
231) {
232 connect_tcp(tokio_util::codec::LinesCodec::new())
233}
234
/// Sorts `slice` in place (unstably) by a key borrowed from each element.
///
/// Unlike `slice::sort_unstable_by_key`, the key function may return a reference tied to the
/// element's lifetime (`for<'a> Fn(&'a T) -> &'a K`), so keys need not be cloned.
pub fn sort_unstable_by_key_hrtb<T, F, K>(slice: &mut [T], f: F)
where
    F: for<'a> Fn(&'a T) -> &'a K,
    K: Ord,
{
    slice.sort_unstable_by(|lhs, rhs| {
        let lhs_key = f(lhs);
        let rhs_key = f(rhs);
        lhs_key.cmp(rhs_key)
    });
}
246
247pub fn wait_for_process_output(
255 output_so_far: &mut String,
256 output: &mut ChildStdout,
257 wait_for: &str,
258) {
259 let re = regex::Regex::new(wait_for).unwrap();
260
261 while !re.is_match(output_so_far) {
262 println!("waiting: {}", output_so_far);
263 let mut buffer = [0u8; 1024];
264 let bytes_read = output.read(&mut buffer).unwrap();
265
266 if bytes_read == 0 {
267 panic!();
268 }
269
270 output_so_far.push_str(&String::from_utf8_lossy(&buffer[0..bytes_read]));
271
272 println!("XXX {}", output_so_far);
273 }
274}
275
/// Owns a spawned [`Child`] process and terminates it when dropped.
///
/// On drop, the process is killed and then waited on so its exit status is collected.
pub struct DroppableChild(Child);

impl Drop for DroppableChild {
    /// Kills and then waits on the child process.
    ///
    /// Failures are reported by panicking, except when the current thread is already
    /// panicking. A second panic during unwinding would abort the whole process and hide the
    /// original error, so in that case failures are ignored.
    fn drop(&mut self) {
        let unwinding = std::thread::panicking();

        if let Err(err) = self.0.kill() {
            // On Windows, killing a process that has already exited returns an error, so a
            // failed kill is tolerated there and we still wait for the exit status.
            if !cfg!(target_family = "windows") {
                if !unwinding {
                    panic!("failed to kill child process: {err}");
                }
                // Skip `wait`: a process we failed to kill might never exit.
                return;
            }
        }

        if let Err(err) = self.0.wait() {
            if !unwinding {
                panic!("failed to wait for child process: {err}");
            }
        }
    }
}
292
293pub fn run_cargo_example(test_name: &str, args: &str) -> (DroppableChild, ChildStdin, ChildStdout) {
300 let mut server = if args.is_empty() {
301 std::process::Command::new("cargo")
302 .args(["run", "-p", "dfir_rs", "--example"])
303 .arg(test_name)
304 .stdin(Stdio::piped())
305 .stdout(Stdio::piped())
306 .spawn()
307 .unwrap()
308 } else {
309 std::process::Command::new("cargo")
310 .args(["run", "-p", "dfir_rs", "--example"])
311 .arg(test_name)
312 .arg("--")
313 .args(args.split(' '))
314 .stdin(Stdio::piped())
315 .stdout(Stdio::piped())
316 .spawn()
317 .unwrap()
318 };
319
320 let stdin = server.stdin.take().unwrap();
321 let stdout = server.stdout.take().unwrap();
322
323 (DroppableChild(server), stdin, stdout)
324}
325
326pub fn iter_batches_stream<I>(
333 iter: I,
334 n: usize,
335) -> futures::stream::PollFn<impl FnMut(&mut Context<'_>) -> Poll<Option<I::Item>>>
336where
337 I: IntoIterator + Unpin,
338{
339 let mut count = 0;
340 let mut iter = iter.into_iter();
341 futures::stream::poll_fn(move |ctx| {
342 count += 1;
343 if n < count {
344 count = 0;
345 ctx.waker().wake_by_ref();
346 Poll::Pending
347 } else {
348 Poll::Ready(iter.next())
349 }
350 })
351}
352
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    pub fn test_collect_ready() {
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..1000).for_each(|x| send.send(x).unwrap());
        let collected: Vec<_> = collect_ready(&mut recv);
        assert_eq!(1000, collected.len());
    }

    #[crate::test]
    pub async fn test_collect_ready_async() {
        // Enough items that a single drain may stop early (presumably due to tokio's
        // cooperative budget), exercising the multi-round loop in `collect_ready_async`.
        let (send, mut recv) = unbounded_channel::<usize>();
        (0..1000).for_each(|x| send.send(x).unwrap());
        let collected: Vec<_> = collect_ready_async(&mut recv).await;
        assert_eq!(1000, collected.len());
    }
}