diff --git a/input-capture/src/libei.rs b/input-capture/src/libei.rs index 9517fb0..728730b 100644 --- a/input-capture/src/libei.rs +++ b/input-capture/src/libei.rs @@ -48,7 +48,7 @@ use super::{ /// events that necessitate restarting the capture session #[derive(Clone, Copy, Debug)] -enum CaptureEvent { +enum LibeiNotifyEvent { Create(CaptureHandle, Position), Destroy(CaptureHandle), } @@ -58,7 +58,7 @@ pub struct LibeiInputCapture<'a> { input_capture: Pin>>, capture_task: JoinHandle>, event_rx: Receiver<(CaptureHandle, Event)>, - notify_capture: Sender, + notify_capture: Sender, notify_release: Arc, cancellation_token: CancellationToken, } @@ -239,7 +239,7 @@ impl<'a> LibeiInputCapture<'a> { async fn do_capture<'a>( input_capture: *const InputCapture<'a>, - mut capture_event: Receiver, + mut capture_event: Receiver, notify_release: Arc, session: Option<(Session<'a>, BitFlags)>, event_tx: Sender<(CaptureHandle, Event)>, @@ -259,7 +259,7 @@ async fn do_capture<'a>( let cancel_session = CancellationToken::new(); let cancel_update = CancellationToken::new(); - let mut capture_event_occured: Option = None; + let mut capture_event_occured: Option = None; let mut zones_have_changed = false; // kill session if clients need to be updated @@ -315,8 +315,8 @@ async fn do_capture<'a>( // update clients if requested if let Some(event) = capture_event_occured.take() { match event { - CaptureEvent::Create(c, p) => active_clients.push((c, p)), - CaptureEvent::Destroy(c) => active_clients.retain(|(h, _)| *h != c), + LibeiNotifyEvent::Create(c, p) => active_clients.push((c, p)), + LibeiNotifyEvent::Destroy(c) => active_clients.retain(|(h, _)| *h != c), } } @@ -658,7 +658,7 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> { async fn create(&mut self, handle: CaptureHandle, pos: Position) -> Result<(), CaptureError> { let _ = self .notify_capture - .send(CaptureEvent::Create(handle, pos)) + .send(LibeiNotifyEvent::Create(handle, pos)) .await; Ok(()) } @@ -666,7 +666,7 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> { async fn destroy(&mut self, handle: CaptureHandle) -> Result<(), CaptureError> { let _ = self .notify_capture - .send(CaptureEvent::Destroy(handle)) + .send(LibeiNotifyEvent::Destroy(handle)) .await; Ok(()) } diff --git a/src/server.rs b/src/server.rs index ea32ed0..3f5709e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -5,7 +5,10 @@ use std::{ rc::Rc, sync::Arc, }; -use tokio::{signal, sync::Notify}; +use tokio::{ + join, signal, + sync::{mpsc::channel, Notify}, +}; use tokio_util::sync::CancellationToken; use crate::{ @@ -84,7 +87,7 @@ impl Server { capture_backend: Option, emulation_backend: Option, ) -> anyhow::Result<()> { - // create frontend communication adapter + // create frontend communication adapter, exit if already running let frontend = match FrontendListener::new().await { Some(f) => f?, None => { @@ -95,23 +98,25 @@ impl Server { }; let timer_notify = Arc::new(Notify::new()); - let (frontend_notify_tx, frontend_notify_rx) = tokio::sync::mpsc::channel(1); - let cancellation_token = CancellationToken::new(); + let (frontend_tx, frontend_rx) = channel(1); /* events coming from frontends */ + let cancellation_token = CancellationToken::new(); /* notify termination */ + let notify_capture = Arc::new(Notify::new()); /* notify capture restart */ + let notify_emulation = Arc::new(Notify::new()); /* notify emultation restart */ // udp task - let (mut udp_task, sender_tx, receiver_rx, port_tx) = - network_task::new(self.clone(), frontend_notify_tx.clone()).await?; - - // restart notify tokens - let notify_capture = Arc::new(Notify::new()); - let notify_emulation = Arc::new(Notify::new()); + let (mut udp_task, udp_send, udp_recv, port_tx) = network_task::new( + self.clone(), + frontend_tx.clone(), + cancellation_token.clone(), + ) + .await?; // input capture let (mut capture_task, capture_channel) = capture_task::new( capture_backend, self.clone(), - sender_tx.clone(), - frontend_notify_tx.clone(), + udp_send.clone(), + frontend_tx.clone(), timer_notify.clone(), self.release_bind.clone(), cancellation_token.clone(), @@ -122,10 +127,10 @@ impl Server { let (mut emulation_task, emulate_channel) = emulation_task::new( emulation_backend, self.clone(), - receiver_rx, - sender_tx.clone(), + udp_recv, + udp_send.clone(), capture_channel.clone(), - frontend_notify_tx.clone(), + frontend_tx.clone(), timer_notify.clone(), cancellation_token.clone(), notify_emulation.clone(), @@ -133,19 +138,23 @@ impl Server { // create dns resolver let resolver = dns::DnsResolver::new().await?; - let (mut resolver_task, resolve_tx) = - resolver_task::new(resolver, self.clone(), frontend_notify_tx); + let (mut resolver_task, dns_req) = resolver_task::new( + resolver, + self.clone(), + frontend_tx, + cancellation_token.clone(), + ); // frontend listener let (mut frontend_task, frontend_tx) = frontend_task::new( frontend, - frontend_notify_rx, + frontend_rx, self.clone(), notify_emulation, notify_capture, capture_channel.clone(), emulate_channel.clone(), - resolve_tx.clone(), + dns_req.clone(), port_tx, cancellation_token.clone(), ); @@ -153,7 +162,7 @@ impl Server { // task that pings clients to see if they are responding let mut ping_task = ping_task::new( self.clone(), - sender_tx.clone(), + udp_send.clone(), emulate_channel.clone(), capture_channel.clone(), timer_notify, @@ -176,7 +185,7 @@ impl Server { .send(FrontendRequest::Activate(handle, true)) .await?; if let Some(hostname) = hostname { - let _ = resolve_tx.send(DnsRequest { hostname, handle }).await; + let _ = dns_req.send(DnsRequest { hostname, handle }).await; } } log::info!("running service"); @@ -197,20 +206,12 @@ impl Server { _ = &mut ping_task => { } } + // cancel tasks cancellation_token.cancel(); - if !capture_task.is_finished() { - let _ = capture_task.await; - } - if !emulation_task.is_finished() { - let _ = emulation_task.await; - } - if !frontend_task.is_finished() { - let _ = frontend_task.await; - } + let _ = join!(capture_task, emulation_task, frontend_task, udp_task); resolver_task.abort(); - udp_task.abort(); ping_task.abort(); Ok(()) diff --git a/src/server/capture_task.rs b/src/server/capture_task.rs index 2bbb4c0..423b707 100644 --- a/src/server/capture_task.rs +++ b/src/server/capture_task.rs @@ -123,15 +123,22 @@ async fn do_capture( .await; // FIXME DUPLICATES - // let clients = server - // .client_manager - // .borrow() - // .get_client_states() - // .map(|(h, (c, _))| (h, c.pos)) - // .collect::>(); - // for (handle, pos) in clients { - // capture.create(handle, pos.into()).await?; - // } + let clients = server + .client_manager + .borrow() + .get_client_states() + .map(|(h, s)| (h, s.clone())) + .collect::>(); + log::info!("{clients:?}"); + let clients = server + .client_manager + .borrow() + .get_client_states() + .map(|(h, (c, _))| (h, c.pos)) + .collect::>(); + for (handle, pos) in clients { + capture.create(handle, pos.into()).await?; + } let mut pressed_keys = HashSet::new(); loop { diff --git a/src/server/network_task.rs b/src/server/network_task.rs index 1e8fb0f..3f269a7 100644 --- a/src/server/network_task.rs +++ b/src/server/network_task.rs @@ -6,6 +6,7 @@ use tokio::{ sync::mpsc::{Receiver, Sender}, task::JoinHandle, }; +use tokio_util::sync::CancellationToken; use crate::frontend::FrontendEvent; use input_event::{Event, ProtocolError}; @@ -15,6 +16,7 @@ use super::Server; pub async fn new( server: Server, frontend_notify_tx: Sender, + cancellation_token: CancellationToken, ) -> io::Result<( JoinHandle<()>, Sender<(Event, SocketAddr)>, @@ -38,8 +40,9 @@ pub async fn new( _ = udp_sender => break, /* channel closed */ port = port_rx.recv() => match port { Some(port) => update_port(&server, &frontend_notify_tx, &mut socket, port).await, - _ => break, - } + _ => continue, + }, + _ = cancellation_token.cancelled() => break, /* cancellation requested */ } } }); @@ -80,18 +83,13 @@ async fn udp_receiver( ) { loop { let event = receive_event(socket).await; - if receiver_tx.send(event).await.is_err() { - break; - } + receiver_tx.send(event).await.expect("channel closed"); } } async fn udp_sender(socket: &UdpSocket, rx: &mut Receiver<(Event, SocketAddr)>) { loop { - let (event, addr) = match rx.recv().await { - Some(e) => e, - None => return, - }; + let (event, addr) = rx.recv().await.expect("channel closed"); if let Err(e) = send_event(socket, event, addr) { log::warn!("udp send failed: {e}"); }; diff --git a/src/server/resolver_task.rs b/src/server/resolver_task.rs index dc1e604..cf16850 100644 --- a/src/server/resolver_task.rs +++ b/src/server/resolver_task.rs @@ -1,6 +1,10 @@ use std::collections::HashSet; -use tokio::{sync::mpsc::Sender, task::JoinHandle}; +use tokio::{ + sync::mpsc::{Receiver, Sender}, + task::JoinHandle, +}; +use tokio_util::sync::CancellationToken; use crate::{client::ClientHandle, dns::DnsResolver, frontend::FrontendEvent}; @@ -14,52 +18,65 @@ pub struct DnsRequest { pub fn new( resolver: DnsResolver, - mut server: Server, - mut frontend: Sender, + server: Server, + frontend: Sender, + cancellation_token: CancellationToken, ) -> (JoinHandle<()>, Sender) { - let (dns_tx, mut dns_rx) = tokio::sync::mpsc::channel::(32); + let (dns_tx, dns_rx) = tokio::sync::mpsc::channel::(32); let resolver_task = tokio::task::spawn_local(async move { - loop { - let (host, handle) = match dns_rx.recv().await { - Some(r) => (r.hostname, r.handle), - None => break, - }; - - /* update resolving status */ - if let Some((_, s)) = server.client_manager.borrow_mut().get_mut(handle) { - s.resolving = true; - } - notify_state_change(&mut frontend, &mut server, handle).await; - - let ips = match resolver.resolve(&host).await { - Ok(ips) => ips, - Err(e) => { - log::warn!("could not resolve host '{host}': {e}"); - vec![] - } - }; - - /* update ips and resolving state */ - if let Some((c, s)) = server.client_manager.borrow_mut().get_mut(handle) { - let mut addrs = HashSet::from_iter(c.fix_ips.iter().cloned()); - for ip in ips { - addrs.insert(ip); - } - s.ips = addrs; - s.resolving = false; - } - notify_state_change(&mut frontend, &mut server, handle).await; + tokio::select! { + _ = cancellation_token.cancelled() => {}, + _ = do_dns(resolver, server, frontend, dns_rx) => {}, } }); (resolver_task, dns_tx) } +async fn do_dns( + resolver: DnsResolver, + server: Server, + frontend: Sender, + mut dns_rx: Receiver, +) { + loop { + let (host, handle) = match dns_rx.recv().await { + Some(r) => (r.hostname, r.handle), + None => break, + }; + + /* update resolving status */ + if let Some((_, s)) = server.client_manager.borrow_mut().get_mut(handle) { + s.resolving = true; + } + notify_state_change(&frontend, &server, handle).await; + + let ips = match resolver.resolve(&host).await { + Ok(ips) => ips, + Err(e) => { + log::warn!("could not resolve host '{host}': {e}"); + vec![] + } + }; + + /* update ips and resolving state */ + if let Some((c, s)) = server.client_manager.borrow_mut().get_mut(handle) { + let mut addrs = HashSet::from_iter(c.fix_ips.iter().cloned()); + for ip in ips { + addrs.insert(ip); + } + s.ips = addrs; + s.resolving = false; + } + notify_state_change(&frontend, &server, handle).await; + } +} + async fn notify_state_change( - frontend: &mut Sender, - server: &mut Server, + frontend: &Sender, + server: &Server, handle: ClientHandle, ) { - let state = server.client_manager.borrow_mut().get_mut(handle).cloned(); + let state = server.client_manager.borrow().get(handle).cloned(); if let Some((config, state)) = state { let _ = frontend .send(FrontendEvent::State(handle, config, state))