From 335a1fc3e2c4ac70b1900d782ce74bb1182c60b8 Mon Sep 17 00:00:00 2001 From: Ferdinand Schober Date: Fri, 12 Jul 2024 01:01:09 +0200 Subject: [PATCH] simplify --- src/server.rs | 134 +++++++++++++++++++++++------------ src/server/capture_task.rs | 49 ++++--------- src/server/emulation_task.rs | 63 +++++----------- src/server/frontend_task.rs | 51 ++++++------- src/server/network_task.rs | 26 +++---- src/server/ping_task.rs | 19 ++--- src/server/resolver_task.rs | 17 ++--- 7 files changed, 162 insertions(+), 197 deletions(-) diff --git a/src/server.rs b/src/server.rs index 8bdfc69..02a78ed 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,7 +3,6 @@ use std::{ cell::{Cell, RefCell}, collections::HashSet, rc::Rc, - sync::Arc, }; use tokio::{ join, signal, @@ -14,7 +13,7 @@ use tokio_util::sync::CancellationToken; use crate::{ client::{ClientConfig, ClientHandle, ClientManager, ClientState}, config::{CaptureBackend, Config, EmulationBackend}, - dns, + dns::DnsResolver, frontend::{FrontendListener, FrontendRequest}, server::capture_task::CaptureEvent, }; @@ -46,6 +45,15 @@ pub struct Server { port: Rc>, state: Rc>, release_bind: Vec, + notifies: Rc, +} + +#[derive(Default)] +struct Notifies { + ping: Notify, + capture: Notify, + emulation: Notify, + cancel: CancellationToken, } impl Server { @@ -72,13 +80,18 @@ impl Server { let c = client_manager.get_mut(handle).expect("invalid handle"); *c = (client, state); } + + // task notification tokens + let notifies = Rc::new(Notifies::default()); let release_bind = config.release_bind.clone(); + Self { active_client, client_manager, port, state, release_bind, + notifies, } } @@ -97,76 +110,69 @@ impl Server { } }; - let notify_ping = Arc::new(Notify::new()); /* notify ping timer restart */ - let (frontend_tx, frontend_rx) = channel(1); /* events coming from frontends */ - let cancellation_token = CancellationToken::new(); /* notify termination */ - let notify_capture = Arc::new(Notify::new()); /* notify capture restart */ - let notify_emulation = Arc::new(Notify::new()); /* notify emultation restart */ + let (frontend_tx, frontend_rx) = channel(1); /* events for frontends */ + let (request_tx, request_rx) = channel(1); /* requests coming from frontends */ + let (capture_tx, capture_rx) = channel(1); /* requests for input capture */ + let (emulation_tx, emulation_rx) = channel(1); /* emulation requests */ + let (udp_recv_tx, udp_recv_rx) = channel(1); /* udp receiver */ + let (udp_send_tx, udp_send_rx) = channel(1); /* udp sender */ + let (port_tx, port_rx) = channel(1); /* port change request */ + let (dns_tx, dns_rx) = channel(1); /* dns requests */ // udp task - let (network, udp_send, udp_recv, port_tx) = network_task::new( + let network = network_task::new( self.clone(), + udp_recv_tx, + udp_send_rx, + port_rx, frontend_tx.clone(), - cancellation_token.clone(), ) .await?; // input capture - let (capture, capture_channel) = capture_task::new( - capture_backend, + let capture = capture_task::new( self.clone(), - udp_send.clone(), + capture_backend, + capture_rx, + udp_send_tx.clone(), frontend_tx.clone(), - notify_ping.clone(), self.release_bind.clone(), - cancellation_token.clone(), - notify_capture.clone(), ); // input emulation - let (emulation, emulate_channel) = emulation_task::new( - emulation_backend, + let emulation = emulation_task::new( self.clone(), - udp_recv, - udp_send.clone(), - capture_channel.clone(), + emulation_backend, + emulation_rx, + udp_recv_rx, + udp_send_tx.clone(), + capture_tx.clone(), frontend_tx.clone(), - notify_ping.clone(), - cancellation_token.clone(), - notify_emulation.clone(), ); // create dns resolver - let resolver = dns::DnsResolver::new().await?; - let (resolver, dns_req) = resolver_task::new( - resolver, - self.clone(), - frontend_tx, - cancellation_token.clone(), - ); + let resolver = DnsResolver::new().await?; + let resolver = resolver_task::new(resolver, dns_rx, self.clone(), frontend_tx); // frontend listener - let (frontend, frontend_tx) = frontend_task::new( + let frontend = frontend_task::new( + self.clone(), frontend, frontend_rx, - self.clone(), - notify_emulation, - notify_capture, - capture_channel.clone(), - emulate_channel.clone(), - dns_req.clone(), + request_tx.clone(), + request_rx, + capture_tx.clone(), + emulation_tx.clone(), + dns_tx.clone(), port_tx, - cancellation_token.clone(), ); // task that pings clients to see if they are responding let ping = ping_task::new( self.clone(), - udp_send.clone(), - emulate_channel.clone(), - capture_channel.clone(), - notify_ping, - cancellation_token.clone(), + udp_send_tx.clone(), + emulation_tx.clone(), + capture_tx.clone(), ); let active = self @@ -182,11 +188,11 @@ impl Server { }) .collect::>(); for (handle, hostname) in active { - frontend_tx + request_tx .send(FrontendRequest::Activate(handle, true)) .await?; if let Some(hostname) = hostname { - let _ = dns_req.send(DnsRequest { hostname, handle }).await; + let _ = dns_tx.send(DnsRequest { hostname, handle }).await; } } @@ -194,9 +200,45 @@ impl Server { signal::ctrl_c().await.expect("failed to listen for CTRL+C"); log::info!("terminating service"); - cancellation_token.cancel(); + self.cancel(); let _ = join!(capture, emulation, frontend, network, resolver, ping); Ok(()) } + + fn cancel(&self) { + self.notifies.cancel.cancel(); + } + + async fn cancelled(&self) { + self.notifies.cancel.cancelled().await + } + + fn is_cancelled(&self) -> bool { + self.notifies.cancel.is_cancelled() + } + + fn notify_capture(&self) { + self.notifies.capture.notify_waiters() + } + + async fn capture_notified(&self) { + self.notifies.capture.notified().await + } + + fn notify_emulation(&self) { + self.notifies.emulation.notify_waiters() + } + + async fn emulation_notified(&self) { + self.notifies.emulation.notified().await + } + + fn restart_ping_timer(&self) { + self.notifies.ping.notify_waiters() + } + + async fn ping_timer_notified(&self) { + self.notifies.ping.notified().await + } } diff --git a/src/server/capture_task.rs b/src/server/capture_task.rs index 04802df..0f26a24 100644 --- a/src/server/capture_task.rs +++ b/src/server/capture_task.rs @@ -1,14 +1,10 @@ use futures::StreamExt; -use std::{collections::HashSet, net::SocketAddr, sync::Arc}; +use std::{collections::HashSet, net::SocketAddr}; use thiserror::Error; -use tokio_util::sync::CancellationToken; use tokio::{ process::Command, - sync::{ - mpsc::{Receiver, Sender}, - Notify, - }, + sync::mpsc::{Receiver, Sender}, task::JoinHandle, }; @@ -46,41 +42,31 @@ pub enum CaptureEvent { } pub fn new( - backend: Option, server: Server, + backend: Option, + capture_rx: Receiver, udp_send: Sender<(Event, SocketAddr)>, frontend_tx: Sender, - notify_ping: Arc, release_bind: Vec, - cancellation_token: CancellationToken, - notify_capture: Arc, -) -> (JoinHandle<()>, Sender) { - let (tx, rx) = tokio::sync::mpsc::channel(32); +) -> JoinHandle<()> { let backend = backend.map(|b| b.into()); - let task = tokio::task::spawn_local(capture_task( - backend, + tokio::task::spawn_local(capture_task( server, + backend, udp_send, - rx, + capture_rx, frontend_tx, - notify_ping, release_bind, - cancellation_token, - notify_capture, - )); - (task, tx) + )) } async fn capture_task( - backend: Option, server: Server, + backend: Option, sender_tx: Sender<(Event, SocketAddr)>, mut notify_rx: Receiver, frontend_tx: Sender, - timer_notify: Arc, release_bind: Vec, - cancellation_token: CancellationToken, - notify_capture: Arc, ) { loop { if let Err(e) = do_capture( @@ -89,9 +75,7 @@ async fn capture_task( &sender_tx, &mut notify_rx, &frontend_tx, - &timer_notify, &release_bind, - &cancellation_token, ) .await { @@ -100,10 +84,10 @@ async fn capture_task( let _ = frontend_tx .send(FrontendEvent::CaptureStatus(Status::Disabled)) .await; - if cancellation_token.is_cancelled() { + if server.is_cancelled() { break; } - notify_capture.notified().await; + server.capture_notified().await; } } @@ -113,9 +97,7 @@ async fn do_capture( sender_tx: &Sender<(Event, SocketAddr)>, notify_rx: &mut Receiver, frontend_tx: &Sender, - timer_notify: &Notify, release_bind: &[scancode::Linux], - cancellation_token: &CancellationToken, ) -> Result<(), LanMouseCaptureError> { let mut capture = input_capture::create(backend).await?; let _ = frontend_tx @@ -145,7 +127,7 @@ async fn do_capture( tokio::select! { event = capture.next() => { match event { - Some(Ok(event)) => handle_capture_event(server, &mut capture, sender_tx, timer_notify, event, &mut pressed_keys, release_bind).await?, + Some(Ok(event)) => handle_capture_event(server, &mut capture, sender_tx, event, &mut pressed_keys, release_bind).await?, Some(Err(e)) => return Err(e.into()), None => return Ok(()), } @@ -164,7 +146,7 @@ async fn do_capture( None => break, } } - _ = cancellation_token.cancelled() => break, + _ = server.cancelled() => break, } } capture.terminate().await?; @@ -185,7 +167,6 @@ async fn handle_capture_event( server: &Server, capture: &mut Box, sender_tx: &Sender<(Event, SocketAddr)>, - timer_notify: &Notify, event: (CaptureHandle, Event), pressed_keys: &mut HashSet, release_bind: &[scancode::Linux], @@ -249,7 +230,7 @@ async fn handle_capture_event( }; if start_timer { - timer_notify.notify_waiters(); + server.restart_ping_timer(); } if enter { spawn_hook_command(server, handle); diff --git a/src/server/emulation_task.rs b/src/server/emulation_task.rs index 327852b..7312cd1 100644 --- a/src/server/emulation_task.rs +++ b/src/server/emulation_task.rs @@ -1,14 +1,10 @@ -use std::{net::SocketAddr, sync::Arc}; +use std::net::SocketAddr; use thiserror::Error; use tokio::{ - sync::{ - mpsc::{Receiver, Sender}, - Notify, - }, + sync::mpsc::{Receiver, Sender}, task::JoinHandle, }; -use tokio_util::sync::CancellationToken; use crate::{ client::{ClientHandle, ClientManager}, @@ -36,31 +32,24 @@ pub enum EmulationEvent { } pub fn new( - backend: Option, server: Server, + backend: Option, + emulation_rx: Receiver, udp_rx: Receiver>, sender_tx: Sender<(Event, SocketAddr)>, capture_tx: Sender, frontend_tx: Sender, - timer_notify: Arc, - cancellation_token: CancellationToken, - notify_emulation: Arc, -) -> (JoinHandle<()>, Sender) { - let (tx, rx) = tokio::sync::mpsc::channel(32); +) -> JoinHandle<()> { let emulation_task = emulation_task( backend, - rx, + emulation_rx, server, udp_rx, sender_tx, capture_tx, frontend_tx, - timer_notify, - cancellation_token, - notify_emulation, ); - let emulate_task = tokio::task::spawn_local(emulation_task); - (emulate_task, tx) + tokio::task::spawn_local(emulation_task) } #[derive(Debug, Error)] @@ -79,21 +68,16 @@ async fn emulation_task( sender_tx: Sender<(Event, SocketAddr)>, capture_tx: Sender, frontend_tx: Sender, - timer_notify: Arc, - cancellation_token: CancellationToken, - notify_emulation: Arc, ) { loop { match do_emulation( + &server, backend, &mut rx, - &server, &mut udp_rx, &sender_tx, &capture_tx, &frontend_tx, - &timer_notify, - &cancellation_token, ) .await { @@ -105,25 +89,23 @@ async fn emulation_task( let _ = frontend_tx .send(FrontendEvent::EmulationStatus(Status::Disabled)) .await; - if cancellation_token.is_cancelled() { + if server.notifies.cancel.is_cancelled() { break; } log::info!("waiting for user to request input emulation ..."); - notify_emulation.notified().await; + server.emulation_notified().await; log::info!("... done"); } } async fn do_emulation( + server: &Server, backend: Option, rx: &mut Receiver, - server: &Server, udp_rx: &mut Receiver>, sender_tx: &Sender<(Event, SocketAddr)>, capture_tx: &Sender, frontend_tx: &Sender, - timer_notify: &Notify, - cancellation_token: &CancellationToken, ) -> Result<(), LanMouseEmulationError> { let backend = backend.map(|b| b.into()); log::info!("creating input emulation..."); @@ -132,17 +114,7 @@ async fn do_emulation( .send(FrontendEvent::EmulationStatus(Status::Enabled)) .await; - let res = do_emulation_session( - &mut emulation, - rx, - server, - udp_rx, - sender_tx, - capture_tx, - timer_notify, - cancellation_token, - ) - .await; + let res = do_emulation_session(server, &mut emulation, rx, udp_rx, sender_tx, capture_tx).await; emulation.terminate().await; res?; @@ -165,14 +137,12 @@ async fn do_emulation( } async fn do_emulation_session( + server: &Server, emulation: &mut Box, rx: &mut Receiver, - server: &Server, udp_rx: &mut Receiver>, sender_tx: &Sender<(Event, SocketAddr)>, capture_tx: &Sender, - timer_notify: &Notify, - cancellation_token: &CancellationToken, ) -> Result<(), LanMouseEmulationError> { let mut last_ignored = None; @@ -186,7 +156,7 @@ async fn do_emulation_session( continue; } }; - handle_udp_rx(&server, &capture_tx, emulation, &sender_tx, &mut last_ignored, udp_event, &timer_notify).await?; + handle_udp_rx(&server, &capture_tx, emulation, &sender_tx, &mut last_ignored, udp_event).await?; } emulate_event = rx.recv() => { match emulate_event.expect("channel closed") { @@ -195,7 +165,7 @@ async fn do_emulation_session( EmulationEvent::ReleaseKeys(c) => release_keys(&server, emulation, c).await?, } } - _ = cancellation_token.cancelled() => break Ok(()), + _ = server.notifies.cancel.cancelled() => break Ok(()), } } } @@ -207,7 +177,6 @@ async fn handle_udp_rx( sender_tx: &Sender<(Event, SocketAddr)>, last_ignored: &mut Option, event: (Event, SocketAddr), - timer_notify: &Notify, ) -> Result<(), EmulationError> { let (event, addr) = event; @@ -257,7 +226,7 @@ async fn handle_udp_rx( ); // restart timer if necessary if restart_timer { - timer_notify.notify_waiters(); + server.restart_ping_timer(); } ignore_event } else { diff --git a/src/server/frontend_task.rs b/src/server/frontend_task.rs index 99c0c82..a430aa4 100644 --- a/src/server/frontend_task.rs +++ b/src/server/frontend_task.rs @@ -2,7 +2,6 @@ use std::{ collections::HashSet, io::ErrorKind, net::{IpAddr, SocketAddr}, - sync::Arc, }; #[cfg(unix)] use tokio::net::UnixStream; @@ -12,13 +11,9 @@ use tokio::net::TcpStream; use tokio::{ io::ReadHalf, - sync::{ - mpsc::{Receiver, Sender}, - Notify, - }, + sync::mpsc::{Receiver, Sender}, task::JoinHandle, }; -use tokio_util::sync::CancellationToken; use crate::{ client::{ClientHandle, Position}, @@ -30,32 +25,30 @@ use super::{ }; pub(crate) fn new( + server: Server, mut frontend: FrontendListener, mut event: Receiver, - server: Server, - notify_emulation: Arc, - notify_capture: Arc, + request_tx: Sender, + mut request_rx: Receiver, capture: Sender, emulate: Sender, resolve_ch: Sender, port_tx: Sender, - cancellation_token: CancellationToken, -) -> (JoinHandle<()>, Sender) { - let (request_tx, mut request) = tokio::sync::mpsc::channel(32); - let request_tx_clone = request_tx.clone(); - let frontend_task = tokio::task::spawn_local(async move { +) -> JoinHandle<()> { + let request = request_tx.clone(); + tokio::task::spawn_local(async move { let mut join_handles = vec![]; loop { tokio::select! { stream = frontend.accept() => { match stream { - Ok(s) => join_handles.push(handle_frontend_stream(&request_tx_clone, s, cancellation_token.clone())), + Ok(s) => join_handles.push(handle_frontend_stream(server.clone(), &request, s)), Err(e) => log::warn!("error accepting frontend connection: {e}"), }; } - request = request.recv() => { + request = request_rx.recv() => { let request = request.expect("frontend request channel closed"); - if handle_frontend_event(&server, ¬ify_capture, ¬ify_emulation, &capture, &emulate, &resolve_ch, &mut frontend, &port_tx, request).await { + if handle_frontend_event(&server, &capture, &emulate, &resolve_ch, &mut frontend, &port_tx, request).await { break; } } @@ -63,33 +56,32 @@ pub(crate) fn new( let event = event.expect("channel closed"); let _ = frontend.broadcast_event(event).await; } - _ = cancellation_token.cancelled() => { + _ = server.cancelled() => { futures::future::join_all(join_handles).await; break; } } } - }); - (frontend_task, request_tx) + }) } fn handle_frontend_stream( - frontend_tx: &Sender, + server: Server, + request_tx: &Sender, #[cfg(unix)] stream: ReadHalf, #[cfg(windows)] stream: ReadHalf, - cancellation_token: CancellationToken, ) -> JoinHandle<()> { - let tx = frontend_tx.clone(); + let tx = request_tx.clone(); tokio::task::spawn_local(async move { tokio::select! { _ = listen_frontend(tx, stream) => {}, - _ = cancellation_token.cancelled() => {}, + _ = server.cancelled() => {}, } }) } async fn listen_frontend( - tx: Sender, + request_tx: Sender, #[cfg(unix)] mut stream: ReadHalf, #[cfg(windows)] mut stream: ReadHalf, ) { @@ -98,7 +90,7 @@ async fn listen_frontend( let request = frontend::wait_for_request(&mut stream).await; match request { Ok(request) => { - let _ = tx.send(request).await; + let _ = request_tx.send(request).await; } Err(e) => { if let Some(e) = e.downcast_ref::() { @@ -115,8 +107,6 @@ async fn listen_frontend( async fn handle_frontend_event( server: &Server, - notify_capture: &Notify, - notify_emulation: &Notify, capture: &Sender, emulate: &Sender, resolve_tx: &Sender, @@ -127,11 +117,12 @@ async fn handle_frontend_event( log::debug!("frontend: {event:?}"); match event { FrontendRequest::EnableCapture => { - notify_capture.notify_waiters(); + log::info!("received capture enable request"); + server.notify_capture(); } FrontendRequest::EnableEmulation => { log::info!("received emulation enable request"); - notify_emulation.notify_waiters(); + server.notify_emulation(); } FrontendRequest::Create => { let handle = add_client(server, frontend).await; diff --git a/src/server/network_task.rs b/src/server/network_task.rs index 3f269a7..8a2023b 100644 --- a/src/server/network_task.rs +++ b/src/server/network_task.rs @@ -6,7 +6,6 @@ use tokio::{ sync::mpsc::{Receiver, Sender}, task::JoinHandle, }; -use tokio_util::sync::CancellationToken; use crate::frontend::FrontendEvent; use input_event::{Event, ProtocolError}; @@ -15,25 +14,19 @@ use super::Server; pub async fn new( server: Server, + udp_recv_tx: Sender>, + udp_send_rx: Receiver<(Event, SocketAddr)>, + mut port_rx: Receiver, frontend_notify_tx: Sender, - cancellation_token: CancellationToken, -) -> io::Result<( - JoinHandle<()>, - Sender<(Event, SocketAddr)>, - Receiver>, - Sender, -)> { +) -> io::Result> { // bind the udp socket let listen_addr = SocketAddr::new("0.0.0.0".parse().unwrap(), server.port.get()); let mut socket = UdpSocket::bind(listen_addr).await?; - let (receiver_tx, receiver_rx) = tokio::sync::mpsc::channel(32); - let (sender_tx, sender_rx) = tokio::sync::mpsc::channel(32); - let (port_tx, mut port_rx) = tokio::sync::mpsc::channel(32); - let udp_task = tokio::task::spawn_local(async move { - let mut sender_rx = sender_rx; + Ok(tokio::task::spawn_local(async move { + let mut sender_rx = udp_send_rx; loop { - let udp_receiver = udp_receiver(&socket, &receiver_tx); + let udp_receiver = udp_receiver(&socket, &udp_recv_tx); let udp_sender = udp_sender(&socket, &mut sender_rx); tokio::select! { _ = udp_receiver => break, /* channel closed */ @@ -42,11 +35,10 @@ pub async fn new( Some(port) => update_port(&server, &frontend_notify_tx, &mut socket, port).await, _ => continue, }, - _ = cancellation_token.cancelled() => break, /* cancellation requested */ + _ = server.cancelled() => break, /* cancellation requested */ } } - }); - Ok((udp_task, sender_tx, receiver_rx, port_tx)) + })) } async fn update_port( diff --git a/src/server/ping_task.rs b/src/server/ping_task.rs index a75ce6e..4908d65 100644 --- a/src/server/ping_task.rs +++ b/src/server/ping_task.rs @@ -1,12 +1,8 @@ -use std::{net::SocketAddr, sync::Arc, time::Duration}; +use std::{net::SocketAddr, time::Duration}; -use tokio::{ - sync::{mpsc::Sender, Notify}, - task::JoinHandle, -}; +use tokio::{sync::mpsc::Sender, task::JoinHandle}; use input_event::Event; -use tokio_util::sync::CancellationToken; use crate::client::ClientHandle; @@ -19,28 +15,25 @@ pub fn new( sender_ch: Sender<(Event, SocketAddr)>, emulate_notify: Sender, capture_notify: Sender, - timer_notify: Arc, - cancellation_token: CancellationToken, ) -> JoinHandle<()> { // timer task tokio::task::spawn_local(async move { tokio::select! { - _ = cancellation_token.cancelled() => {} - _ = ping_task(server, sender_ch, emulate_notify, capture_notify, timer_notify) => {} + _ = server.notifies.cancel.cancelled() => {} + _ = ping_task(&server, sender_ch, emulate_notify, capture_notify) => {} } }) } async fn ping_task( - server: Server, + server: &Server, sender_ch: Sender<(Event, SocketAddr)>, emulate_notify: Sender, capture_notify: Sender, - timer_notify: Arc, ) { loop { // wait for wake up signal - timer_notify.notified().await; + server.ping_timer_notified().await; loop { let receiving = server.state.get() == State::Receiving; let (ping_clients, ping_addrs) = { diff --git a/src/server/resolver_task.rs b/src/server/resolver_task.rs index cf16850..10b8f75 100644 --- a/src/server/resolver_task.rs +++ b/src/server/resolver_task.rs @@ -4,7 +4,6 @@ use tokio::{ sync::mpsc::{Receiver, Sender}, task::JoinHandle, }; -use tokio_util::sync::CancellationToken; use crate::{client::ClientHandle, dns::DnsResolver, frontend::FrontendEvent}; @@ -18,23 +17,21 @@ pub struct DnsRequest { pub fn new( resolver: DnsResolver, + dns_rx: Receiver, server: Server, frontend: Sender, - cancellation_token: CancellationToken, -) -> (JoinHandle<()>, Sender) { - let (dns_tx, dns_rx) = tokio::sync::mpsc::channel::(32); - let resolver_task = tokio::task::spawn_local(async move { +) -> JoinHandle<()> { + tokio::task::spawn_local(async move { tokio::select! { - _ = cancellation_token.cancelled() => {}, - _ = do_dns(resolver, server, frontend, dns_rx) => {}, + _ = server.cancelled() => {}, + _ = do_dns(&server, resolver, frontend, dns_rx) => {}, } - }); - (resolver_task, dns_tx) + }) } async fn do_dns( + server: &Server, resolver: DnsResolver, - server: Server, frontend: Sender, mut dns_rx: Receiver, ) {