dfir_rs/util/
deploy.rs

1#![allow(clippy::allow_attributes, missing_docs, reason = "// TODO(mingwei)")]
2
3use std::cell::RefCell;
4use std::collections::HashMap;
5
6pub use hydro_deploy_integration::*;
7use serde::de::DeserializeOwned;
8
9use crate::scheduled::graph::Dfir;
10
11#[macro_export]
12macro_rules! launch {
13    ($f:expr) => {
14        async {
15            let ports = $crate::util::deploy::init_no_ack_start().await;
16            let flow = $f(&ports);
17
18            println!("ack start");
19
20            $crate::util::deploy::launch_flow(flow).await
21        }
22    };
23}
24
25pub use crate::launch;
26
27pub async fn launch_flow(mut flow: Dfir<'_>) {
28    let stop = tokio::sync::oneshot::channel();
29    tokio::task::spawn_blocking(|| {
30        let mut line = String::new();
31        std::io::stdin().read_line(&mut line).unwrap();
32        if line.starts_with("stop") {
33            stop.0.send(()).unwrap();
34        } else {
35            eprintln!("Unexpected stdin input: {:?}", line);
36        }
37    });
38
39    let local_set = tokio::task::LocalSet::new();
40    let flow = local_set.run_until(flow.run_async());
41
42    tokio::select! {
43        _ = stop.1 => {},
44        _ = flow => {}
45    }
46}
47
48/// Contains runtime information passed by Hydro Deploy to a program,
49/// describing how to connect to other services and metadata about them.
50pub struct DeployPorts<T = Option<()>> {
51    ports: RefCell<HashMap<String, Connection>>,
52    pub meta: T,
53}
54
55impl<T> DeployPorts<T> {
56    pub fn port(&self, name: &str) -> Connection {
57        self.ports
58            .try_borrow_mut()
59            .unwrap()
60            .remove(name)
61            .unwrap_or_else(|| panic!("port {} not found", name))
62    }
63}
64
65pub async fn init_no_ack_start<T: DeserializeOwned + Default>() -> DeployPorts<T> {
66    let mut input = String::new();
67    std::io::stdin().read_line(&mut input).unwrap();
68    let trimmed = input.trim();
69
70    let bind_config = serde_json::from_str::<InitConfig>(trimmed).unwrap();
71
72    // config telling other services how to connect to me
73    let mut bind_results: HashMap<String, ServerPort> = HashMap::new();
74    let mut binds = HashMap::new();
75    for (name, config) in bind_config.0 {
76        let bound = config.bind().await;
77        bind_results.insert(name.clone(), bound.server_port());
78        binds.insert(name.clone(), bound);
79    }
80
81    let bind_serialized = serde_json::to_string(&bind_results).unwrap();
82    println!("ready: {bind_serialized}");
83
84    let mut start_buf = String::new();
85    std::io::stdin().read_line(&mut start_buf).unwrap();
86    let connection_defns = if start_buf.starts_with("start: ") {
87        serde_json::from_str::<HashMap<String, ServerPort>>(
88            start_buf.trim_start_matches("start: ").trim(),
89        )
90        .unwrap()
91    } else {
92        panic!("expected start");
93    };
94
95    let mut all_connected = HashMap::new();
96    for (name, defn) in connection_defns {
97        all_connected.insert(name, Connection::AsClient(defn.connect()));
98    }
99
100    for (name, defn) in binds {
101        all_connected.insert(name, Connection::AsServer(defn));
102    }
103
104    DeployPorts {
105        ports: RefCell::new(all_connected),
106        meta: bind_config
107            .1
108            .map(|b| serde_json::from_str(&b).unwrap())
109            .unwrap_or_default(),
110    }
111}
112
113pub async fn init<T: DeserializeOwned + Default>() -> DeployPorts<T> {
114    let ret = init_no_ack_start::<T>().await;
115
116    println!("ack start");
117
118    ret
119}