diff --git a/src/capture.rs b/src/capture.rs index 26207de..01e249d 100644 --- a/src/capture.rs +++ b/src/capture.rs @@ -19,7 +19,7 @@ use crate::{connect::LanMouseConnection, server::Server}; pub(crate) struct Capture { tx: Sender, - _task: JoinHandle<()>, + task: JoinHandle<()>, } #[derive(Clone, Copy, Debug)] @@ -35,8 +35,16 @@ enum CaptureRequest { impl Capture { pub(crate) fn new(server: Server, conn: LanMouseConnection) -> Self { let (tx, rx) = channel(); - let _task = spawn_local(Self::run(server.clone(), rx, conn)); - Self { tx, _task } + let task = spawn_local(Self::run(server.clone(), rx, conn)); + Self { tx, task } + } + + pub(crate) async fn terminate(&mut self) { + log::debug!("terminating capture"); + self.tx.close(); + if let Err(e) = (&mut self.task).await { + log::warn!("{e}"); + } } pub(crate) fn create(&self, handle: CaptureHandle, pos: Position) { @@ -66,7 +74,10 @@ impl Capture { server.set_capture_status(Status::Disabled); loop { tokio::select! { - _ = rx.recv() => continue, + e = rx.recv() => match e { + Some(_) => continue, + None => break, + }, _ = server.capture_enabled() => break, _ = server.cancelled() => return, } diff --git a/src/emulation.rs b/src/emulation.rs index 3ce5c7c..fe9706f 100644 --- a/src/emulation.rs +++ b/src/emulation.rs @@ -5,37 +5,79 @@ use input_event::Event; use lan_mouse_ipc::Status; use lan_mouse_proto::ProtoEvent; use local_channel::mpsc::{channel, Receiver, Sender}; -use std::{collections::HashMap, net::SocketAddr}; -use tokio::task::{spawn_local, JoinHandle}; +use std::{ + collections::HashMap, + net::SocketAddr, + time::{Duration, Instant}, +}; +use tokio::{ + select, + task::{spawn_local, JoinHandle}, +}; /// emulation handling events received from a listener pub(crate) struct Emulation { - _tx: Sender, - _task: JoinHandle<()>, + task: JoinHandle<()>, } impl Emulation { pub(crate) fn new(server: Server, listener: LanMouseListener) -> Self { - let (_tx, _rx) = channel(); let emulation_proxy = EmulationProxy::new(server.clone()); - let _task = spawn_local(Self::run(server, listener, emulation_proxy)); - Self { _tx, _task } + let task = spawn_local(Self::run(server, listener, emulation_proxy)); + Self { task } } - async fn run(server: Server, mut listener: LanMouseListener, emulation_proxy: EmulationProxy) { - while let Some((event, addr)) = listener.next().await { - match event { - ProtoEvent::Enter(_) => { - server.release_capture(); - listener.reply(addr, ProtoEvent::Ack(0)).await; + async fn run( + server: Server, + mut listener: LanMouseListener, + mut emulation_proxy: EmulationProxy, + ) { + let mut interval = tokio::time::interval(Duration::from_secs(5)); + let mut last_response = HashMap::new(); + loop { + select! { + e = listener.next() => { + let (event, addr) = match e { + Some(e) => e, + None => break, + }; + last_response.insert(addr, Instant::now()); + match event { + ProtoEvent::Enter(_) => { + server.release_capture(); + listener.reply(addr, ProtoEvent::Ack(0)).await; + } + ProtoEvent::Leave(_) => { + emulation_proxy.release_keys(addr); + listener.reply(addr, ProtoEvent::Ack(0)).await; + } + ProtoEvent::Ack(_) => {} + ProtoEvent::Input(event) => emulation_proxy.consume(event, addr), + ProtoEvent::Ping => listener.reply(addr, ProtoEvent::Pong).await, + ProtoEvent::Pong => {}, + } } - ProtoEvent::Leave(_) => emulation_proxy.release_keys(addr).await, - ProtoEvent::Ack(_) => {} - ProtoEvent::Input(event) => emulation_proxy.consume(event, addr).await, - ProtoEvent::Ping => listener.reply(addr, ProtoEvent::Pong).await, - ProtoEvent::Pong => todo!(), + _ = interval.tick() => { + for (addr, last_response) in last_response.iter() { + if last_response.elapsed() > Duration::from_secs(5) { + log::warn!("{addr} is not responding, releasing keys!"); + emulation_proxy.release_keys(*addr); + } + } + } + _ = server.cancelled() => break, } } + listener.terminate().await; + emulation_proxy.terminate().await; + } + + /// wait for termination + pub(crate) async fn terminate(&mut self) { + log::debug!("terminating emulation"); + if let Err(e) = (&mut self.task).await { + log::warn!("{e}"); + } } } @@ -44,7 +86,6 @@ impl Emulation { pub(crate) struct EmulationProxy { server: Server, tx: Sender<(ProxyEvent, SocketAddr)>, - #[allow(unused)] task: JoinHandle<()>, } @@ -60,7 +101,7 @@ impl EmulationProxy { Self { server, tx, task } } - async fn consume(&self, event: Event, addr: SocketAddr) { + fn consume(&self, event: Event, addr: SocketAddr) { // ignore events if emulation is currently disabled if let Status::Enabled = self.server.emulation_status.get() { self.tx @@ -69,7 +110,7 @@ impl EmulationProxy { } } - async fn release_keys(&self, addr: SocketAddr) { + fn release_keys(&self, addr: SocketAddr) { self.tx .send((ProxyEvent::ReleaseKeys, addr)) .expect("channel closed"); @@ -146,4 +187,8 @@ impl EmulationProxy { } } } + + async fn terminate(&mut self) { + let _ = (&mut self.task).await; + } } diff --git a/src/listen.rs b/src/listen.rs index 66331a2..6a9afde 100644 --- a/src/listen.rs +++ b/src/listen.rs @@ -24,16 +24,14 @@ pub enum ListenerCreationError { pub(crate) struct LanMouseListener { listen_rx: Receiver<(ProtoEvent, SocketAddr)>, - _listen_task: JoinHandle<()>, + listen_tx: Sender<(ProtoEvent, SocketAddr)>, + listen_task: JoinHandle<()>, conns: Rc>>>, } impl LanMouseListener { pub(crate) async fn new(port: u16) -> Result { - let (listen_tx, listen_rx): ( - Sender<(ProtoEvent, SocketAddr)>, - Receiver<(ProtoEvent, SocketAddr)>, - ) = channel(); + let (listen_tx, listen_rx) = channel(); let listen_addr = SocketAddr::new("0.0.0.0".parse().expect("invalid ip"), port); let certificate = Certificate::generate_self_signed(["localhost".to_owned()])?; @@ -49,7 +47,8 @@ impl LanMouseListener { let conns_clone = conns.clone(); - let _listen_task: JoinHandle<()> = spawn_local(async move { + let tx = listen_tx.clone(); + let listen_task: JoinHandle<()> = spawn_local(async move { loop { let (conn, addr) = match listener.accept().await { Ok(c) => c, @@ -61,17 +60,27 @@ impl LanMouseListener { log::info!("dtls client connected, ip: {addr}"); let mut conns = conns_clone.lock().await; conns.push(conn.clone()); - spawn_local(read_loop(conns_clone.clone(), conn, listen_tx.clone())); + spawn_local(read_loop(conns_clone.clone(), conn, tx.clone())); } }); Ok(Self { conns, listen_rx, - _listen_task, + listen_tx, + listen_task, }) } + pub(crate) async fn terminate(&mut self) { + self.listen_task.abort(); + let conns = self.conns.lock().await; + for conn in conns.iter() { + let _ = conn.close().await; + } + self.listen_tx.close(); + } + #[allow(unused)] pub(crate) async fn broadcast(&self, event: ProtoEvent) { let (buf, len): ([u8; MAX_EVENT_SIZE], usize) = event.into(); diff --git a/src/server.rs b/src/server.rs index b1d8f8d..5f1ee65 100644 --- a/src/server.rs +++ b/src/server.rs @@ -43,7 +43,6 @@ pub struct ReleaseToken; pub struct Server { pub(crate) client_manager: Rc>, port: Rc>, - #[allow(unused)] pub(crate) release_bind: Vec, notifies: Rc, pub(crate) config: Rc, @@ -120,8 +119,8 @@ impl Server { let conn = LanMouseConnection::new(self.clone()); // input capture + emulation - let capture = Capture::new(self.clone(), conn); - let _emulation = Emulation::new(self.clone(), listener); + let mut capture = Capture::new(self.clone(), conn); + let mut emulation = Emulation::new(self.clone(), listener); // create dns resolver let resolver = DnsResolver::new(self.clone())?; @@ -166,6 +165,9 @@ impl Server { self.cancel(); + capture.terminate().await; + emulation.terminate().await; + Ok(()) } @@ -224,12 +226,7 @@ impl Server { .collect() } - fn handle_request( - &self, - capture: &Capture, - event: FrontendRequest, - dns: &DnsResolver, - ) -> bool { + fn handle_request(&self, capture: &Capture, event: FrontendRequest, dns: &DnsResolver) -> bool { log::debug!("frontend: {event:?}"); match event { FrontendRequest::EnableCapture => self.notify_capture(), @@ -386,12 +383,7 @@ impl Server { } } - fn update_hostname( - &self, - handle: ClientHandle, - hostname: Option, - dns: &DnsResolver, - ) { + fn update_hostname(&self, handle: ClientHandle, hostname: Option, dns: &DnsResolver) { let mut client_manager = self.client_manager.borrow_mut(); let Some((c, s)) = client_manager.get_mut(handle) else { return;