diff --git a/Cargo.lock b/Cargo.lock index 5b2cea0..09893f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1329,6 +1329,7 @@ dependencies = [ "slab", "thiserror", "tokio", + "tokio-util", "toml", ] diff --git a/Cargo.toml b/Cargo.toml index 814c2bb..90efeba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ hostname = "0.4.0" slab = "0.4.9" endi = "1.1.0" thiserror = "1.0.61" +tokio-util = "0.7.11" [target.'cfg(unix)'.dependencies] libc = "0.2.148" diff --git a/input-capture/src/libei.rs b/input-capture/src/libei.rs index 034d798..057db16 100644 --- a/input-capture/src/libei.rs +++ b/input-capture/src/libei.rs @@ -62,7 +62,7 @@ struct ReleaseCaptureEvent; pub struct LibeiInputCapture<'a> { input_capture: Pin>>, capture_task: JoinHandle>, - event_rx: Option>, + event_rx: Receiver<(CaptureHandle, Event)>, notify_capture: Sender, notify_capture_session: Sender, cancellation_token: CancellationToken, @@ -204,10 +204,6 @@ async fn libei_event_handler( log::trace!("from ei: {ei_event:?}"); let client = current_client.get(); handle_ei_event(ei_event, client, &context, &event_tx, &release_session).await?; - if event_tx.is_closed() { - log::info!("event_tx closed -> exiting"); - break Ok(()); - } } } @@ -232,7 +228,6 @@ impl<'a> LibeiInputCapture<'a> { cancellation_token.clone(), ); let capture_task = tokio::task::spawn_local(capture); - let event_rx = Some(event_rx); let producer = Self { input_capture, @@ -305,11 +300,16 @@ async fn do_capture<'a>( ); let (capture_result, ()) = tokio::join!(capture_session, handle_session_update_request); - log::info!("capture session + session_update task done!"); + log::debug!("capture session + session_update task done!"); // disable capture - log::info!("disabling input capture"); - input_capture.disable(&session).await?; + log::debug!("disabling input capture"); + if let Err(e) = input_capture.disable(&session).await { + log::warn!("input_capture.disable(&session) {e}"); + } + if let Err(e) = session.close().await { + log::warn!("session.close(): {e}"); + } // propagate error from capture session if capture_result.is_err() { @@ -327,8 +327,6 @@ async fn do_capture<'a>( } } - log::info!("no error occured"); - // break if cancellation_token.is_cancelled() { break Ok(()); @@ -378,7 +376,7 @@ async fn do_capture_session( release_session_clone, client, ) => { - log::info!("libei exited: {r:?} cancelling session task"); + log::debug!("libei exited: {r:?} cancelling session task"); cancel_session_clone.cancel(); } _ = cancel_ei_handler_clone.cancelled() => {}, @@ -389,6 +387,7 @@ async fn do_capture_session( let capture_session_task = async { // receiver for activation tokens let mut activated = input_capture.receive_activated().await?; + let mut ei_devices_changed = false; loop { tokio::select! { activated = activated.next() => { @@ -401,29 +400,34 @@ async fn do_capture_session( current_client.replace(Some(client)); // client entered => send event - if event_tx.send((client, Event::Enter())).await.is_err() { - break; - }; + event_tx.send((client, Event::Enter())).await.expect("no channel"); tokio::select! { _ = capture_session_event.recv() => {}, /* capture release */ - _ = release_session.notified() => { - log::warn!("release session aquired (a): {release_session:?}"); + _ = release_session.notified() => { /* release session */ + ei_devices_changed = true; }, _ = cancel_session.cancelled() => break, /* kill session notify */ } release_capture(input_capture, session, activated, client, &active_clients).await?; + } _ = capture_session_event.recv() => {}, /* capture release -> we are not capturing anyway, so ignore */ - _ = release_session.notified() => { - log::warn!("release session aquired (b): {release_session:?}"); + _ = release_session.notified() => { /* release session */ + ei_devices_changed = true; }, _ = cancel_session.cancelled() => break, /* kill session notify */ } + if ei_devices_changed { + /* for whatever reason, GNOME seems to kill the session + * as soon as devices are added or removed, so we need + * to cancel */ + break; + } } // cancel libei task - log::info!("session exited: killing libei task"); + log::debug!("session exited: killing libei task"); cancel_ei_handler.cancel(); Ok::<(), CaptureError>(()) }; @@ -432,7 +436,7 @@ async fn do_capture_session( cancel_update.cancel(); - log::info!("both session and ei task finished!"); + log::debug!("both session and ei task finished!"); a?; b?; @@ -488,7 +492,8 @@ async fn handle_ei_event( ]); context.flush().map_err(|e| io::Error::new(e.kind(), e))?; } - EiEvent::SeatRemoved(_) | EiEvent::DeviceAdded(_) | EiEvent::DeviceRemoved(_) => { + EiEvent::SeatRemoved(_) | /* EiEvent::DeviceAdded(_) | */ EiEvent::DeviceRemoved(_) => { + log::debug!("releasing session: {ei_event:?}"); release_session.notify_waiters(); } EiEvent::DevicePaused(_) | EiEvent::DeviceResumed(_) => {} @@ -500,9 +505,7 @@ async fn handle_ei_event( _ => { if let Some(handle) = current_client { for event in to_input_events(ei_event).into_iter() { - if event_tx.send((handle, event)).await.is_err() { - return Ok(()); - }; + event_tx.send((handle, event)).await.expect("no channel"); } } } @@ -669,13 +672,11 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> { } async fn terminate(&mut self) -> Result<(), CaptureError> { - let event_rx = self.event_rx.take().expect("no channel"); - std::mem::drop(event_rx); self.cancellation_token.cancel(); let task = &mut self.capture_task; - log::info!("waiting for capture to terminate..."); + log::debug!("waiting for capture to terminate..."); let res = task.await.expect("libei task panic"); - log::info!("done!"); + log::debug!("done!"); res } } @@ -689,12 +690,7 @@ impl<'a> Stream for LibeiInputCapture<'a> { Ok(()) => Poll::Ready(None), Err(e) => Poll::Ready(Some(Err(e))), }, - Poll::Pending => self - .event_rx - .as_mut() - .expect("no channel") - .poll_recv(cx) - .map(|e| e.map(Result::Ok)), + Poll::Pending => self.event_rx.poll_recv(cx).map(|e| e.map(Result::Ok)), } } } diff --git a/src/capture_test.rs b/src/capture_test.rs index aff63ba..b555473 100644 --- a/src/capture_test.rs +++ b/src/capture_test.rs @@ -58,11 +58,11 @@ async fn input_capture_test(config: Config) -> Result<()> { }; log::info!("position: {pos}, event: {event}"); if let Event::Keyboard(KeyboardEvent::Key { key: 1, .. }) = event { - // input_capture.as_mut().unwrap().release()?; - break; + input_capture.as_mut().unwrap().release().await?; + // break; } } - input_capture.take().unwrap().terminate().await.unwrap(); + // input_capture.take().unwrap().terminate().await.unwrap(); } Ok(()) } diff --git a/src/server.rs b/src/server.rs index 4f6b9c8..965a57e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -2,9 +2,10 @@ use log; use std::{ cell::{Cell, RefCell}, collections::HashSet, - rc::Rc, + rc::Rc, sync::Arc, }; -use tokio::signal; +use tokio::{signal, sync::Notify}; +use tokio_util::sync::CancellationToken; use crate::{ client::{ClientConfig, ClientHandle, ClientManager, ClientState}, @@ -14,7 +15,7 @@ use crate::{ server::capture_task::CaptureEvent, }; -use self::{emulation_task::EmulationEvent, resolver_task::DnsRequest}; +use self::resolver_task::DnsRequest; mod capture_task; mod emulation_task; @@ -92,8 +93,9 @@ impl Server { } }; - let (timer_tx, timer_rx) = tokio::sync::mpsc::channel(1); + let timer_notify = Arc::new(Notify::new()); let (frontend_notify_tx, frontend_notify_rx) = tokio::sync::mpsc::channel(1); + let cancellation_token = CancellationToken::new(); // udp task let (mut udp_task, sender_tx, receiver_rx, port_tx) = @@ -104,8 +106,9 @@ impl Server { capture_backend, self.clone(), sender_tx.clone(), - timer_tx.clone(), + timer_notify.clone(), self.release_bind.clone(), + cancellation_token.clone(), )?; // input emulation @@ -115,13 +118,17 @@ impl Server { receiver_rx, sender_tx.clone(), capture_channel.clone(), - timer_tx, + timer_notify.clone(), + cancellation_token.clone(), ); // create dns resolver let resolver = dns::DnsResolver::new().await?; - let (mut resolver_task, resolve_tx) = - resolver_task::new(resolver, self.clone(), frontend_notify_tx); + let (mut resolver_task, resolve_tx) = resolver_task::new( + resolver, + self.clone(), + frontend_notify_tx, + ); // frontend listener let (mut frontend_task, frontend_tx) = frontend_task::new( @@ -132,6 +139,7 @@ impl Server { emulate_channel.clone(), resolve_tx.clone(), port_tx, + cancellation_token.clone(), ); // task that pings clients to see if they are responding @@ -140,7 +148,7 @@ impl Server { sender_tx.clone(), emulate_channel.clone(), capture_channel.clone(), - timer_rx, + timer_notify, ); let active = self @@ -189,9 +197,7 @@ impl Server { _ = &mut ping_task => { } } - let _ = emulate_channel.send(EmulationEvent::Terminate).await; - let _ = capture_channel.send(CaptureEvent::Terminate).await; - let _ = frontend_tx.send(FrontendRequest::Terminate()).await; + cancellation_token.cancel(); if !capture_task.is_finished() { if let Err(e) = capture_task.await { diff --git a/src/server/capture_task.rs b/src/server/capture_task.rs index 387317d..5037e02 100644 --- a/src/server/capture_task.rs +++ b/src/server/capture_task.rs @@ -1,8 +1,13 @@ use anyhow::{anyhow, Result}; use futures::StreamExt; use std::{collections::HashSet, net::SocketAddr}; +use tokio_util::sync::CancellationToken; -use tokio::{process::Command, sync::mpsc::Sender, task::JoinHandle}; +use tokio::{ + process::Command, + sync::{mpsc::Sender, Notify}, + task::JoinHandle, +}; use input_capture::{self, error::CaptureCreationError, CaptureHandle, InputCapture, Position}; @@ -20,8 +25,6 @@ pub enum CaptureEvent { Create(CaptureHandle, Position), /// destory a capture client Destroy(CaptureHandle), - /// termination signal - Terminate, /// restart input capture Restart, } @@ -30,8 +33,9 @@ pub fn new( backend: Option, server: Server, sender_tx: Sender<(Event, SocketAddr)>, - timer_tx: Sender<()>, + timer_notify: Notify, release_bind: Vec, + cancellation_token: CancellationToken, ) -> Result<(JoinHandle>, Sender), CaptureCreationError> { let (tx, mut rx) = tokio::sync::mpsc::channel(32); let backend = backend.map(|b| b.into()); @@ -42,7 +46,7 @@ pub fn new( tokio::select! { event = capture.next() => { match event { - Some(Ok(event)) => handle_capture_event(&server, &mut capture, &sender_tx, &timer_tx, event, &mut pressed_keys, &release_bind).await?, + Some(Ok(event)) => handle_capture_event(&server, &mut capture, &sender_tx, &timer_notify, event, &mut pressed_keys, &release_bind).await?, Some(Err(e)) => return Err(anyhow!("input capture: {e:?}")), None => return Err(anyhow!("input capture terminated")), } @@ -65,11 +69,11 @@ pub fn new( capture.create(handle, pos.into()).await?; } } - CaptureEvent::Terminate => break, }, None => break, } } + _ = cancellation_token.cancelled() => break, } } anyhow::Ok(()) @@ -91,7 +95,7 @@ async fn handle_capture_event( server: &Server, capture: &mut Box, sender_tx: &Sender<(Event, SocketAddr)>, - timer_tx: &Sender<()>, + timer_notify: &Notify, event: (CaptureHandle, Event), pressed_keys: &mut HashSet, release_bind: &[scancode::Linux], @@ -150,7 +154,7 @@ async fn handle_capture_event( (client_state.active_addr, enter, start_timer) }; if start_timer { - let _ = timer_tx.try_send(()); + timer_notify.notify_waiters(); } if enter { spawn_hook_command(server, handle); diff --git a/src/server/emulation_task.rs b/src/server/emulation_task.rs index 9d58bec..0e97a3c 100644 --- a/src/server/emulation_task.rs +++ b/src/server/emulation_task.rs @@ -1,10 +1,14 @@ -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use thiserror::Error; use tokio::{ - sync::mpsc::{Receiver, Sender}, + sync::{ + mpsc::{Receiver, Sender}, + Notify, + }, task::JoinHandle, }; +use tokio_util::sync::CancellationToken; use crate::{ client::{ClientHandle, ClientManager}, @@ -28,8 +32,6 @@ pub enum EmulationEvent { Destroy(EmulationHandle), /// input emulation must release keys for client ReleaseKeys(ClientHandle), - /// termination signal - Terminate, /// restart input emulation Restart, } @@ -40,14 +42,23 @@ pub fn new( udp_rx: Receiver>, sender_tx: Sender<(Event, SocketAddr)>, capture_tx: Sender, - timer_tx: Sender<()>, + timer_notify: Arc, + cancellation_token: CancellationToken, ) -> ( JoinHandle>, Sender, ) { let (tx, rx) = tokio::sync::mpsc::channel(32); - let emulation_task = - emulation_task(backend, rx, server, udp_rx, sender_tx, capture_tx, timer_tx); + let emulation_task = emulation_task( + backend, + rx, + server, + udp_rx, + sender_tx, + capture_tx, + timer_notify, + cancellation_token, + ); let emulate_task = tokio::task::spawn_local(emulation_task); (emulate_task, tx) } @@ -67,7 +78,8 @@ async fn emulation_task( mut udp_rx: Receiver>, sender_tx: Sender<(Event, SocketAddr)>, capture_tx: Sender, - timer_tx: Sender<()>, + timer_notify: Arc, + cancellation_token: CancellationToken, ) -> Result<(), LanMouseEmulationError> { let backend = backend.map(|b| b.into()); let mut emulation = input_emulation::create(backend).await?; @@ -84,7 +96,7 @@ async fn emulation_task( } None => break, }; - handle_udp_rx(&server, &capture_tx, &mut emulation, &sender_tx, &mut last_ignored, udp_event, &timer_tx).await?; + handle_udp_rx(&server, &capture_tx, &mut emulation, &sender_tx, &mut last_ignored, udp_event, &timer_notify).await?; } emulate_event = rx.recv() => { match emulate_event { @@ -99,11 +111,11 @@ async fn emulation_task( emulation.create(handle).await; } }, - EmulationEvent::Terminate => break, }, None => break, } } + _ = cancellation_token.cancelled() => break, } } @@ -120,7 +132,7 @@ async fn handle_udp_rx( sender_tx: &Sender<(Event, SocketAddr)>, last_ignored: &mut Option, event: (Event, SocketAddr), - timer_tx: &Sender<()>, + timer_notify: &Notify, ) -> Result<(), EmulationError> { let (event, addr) = event; @@ -170,7 +182,7 @@ async fn handle_udp_rx( ); // restart timer if necessary if restart_timer { - let _ = timer_tx.try_send(()); + timer_notify.notify_waiters(); } ignore_event } else { diff --git a/src/server/frontend_task.rs b/src/server/frontend_task.rs index eb945d7..b7feb33 100644 --- a/src/server/frontend_task.rs +++ b/src/server/frontend_task.rs @@ -15,6 +15,7 @@ use tokio::{ sync::mpsc::{Receiver, Sender}, task::JoinHandle, }; +use tokio_util::sync::CancellationToken; use crate::{ client::{ClientHandle, Position}, @@ -33,10 +34,12 @@ pub(crate) fn new( emulate: Sender, resolve_ch: Sender, port_tx: Sender, + cancellation_token: CancellationToken, ) -> (JoinHandle>, Sender) { let (event_tx, mut event_rx) = tokio::sync::mpsc::channel(32); let event_tx_clone = event_tx.clone(); let frontend_task = tokio::task::spawn_local(async move { + let mut join_handles = vec![]; loop { tokio::select! { stream = frontend.accept() => { @@ -47,7 +50,7 @@ pub(crate) fn new( continue; } }; - handle_frontend_stream(&event_tx_clone, stream).await; + join_handles.push(handle_frontend_stream(&event_tx_clone, stream, cancellation_token.clone())); } event = event_rx.recv() => { let frontend_event = event.ok_or(anyhow!("frontend channel closed"))?; @@ -59,6 +62,10 @@ pub(crate) fn new( let notify = notify.ok_or(anyhow!("frontend notify closed"))?; let _ = frontend.broadcast_event(notify).await; } + _ = cancellation_token.cancelled() => { + futures::future::join_all(join_handles).await; + break; + } } } anyhow::Ok(()) @@ -66,33 +73,45 @@ pub(crate) fn new( (frontend_task, event_tx) } -async fn handle_frontend_stream( +fn handle_frontend_stream( frontend_tx: &Sender, + #[cfg(unix)] stream: ReadHalf, + #[cfg(windows)] stream: ReadHalf, + cancellation_token: CancellationToken, +) -> JoinHandle<()> { + + let tx = frontend_tx.clone(); + tokio::task::spawn_local(async move { + tokio::select! { + _ = listen_frontend(tx, stream) => return, + _ = cancellation_token.cancelled() => return, + } + }) +} + +async fn listen_frontend( + tx: Sender, #[cfg(unix)] mut stream: ReadHalf, #[cfg(windows)] mut stream: ReadHalf, ) { use std::io; - - let tx = frontend_tx.clone(); - tokio::task::spawn_local(async move { - loop { - let request = frontend::wait_for_request(&mut stream).await; - match request { - Ok(request) => { - let _ = tx.send(request).await; - } - Err(e) => { - if let Some(e) = e.downcast_ref::() { - if e.kind() == ErrorKind::UnexpectedEof { - return; - } + loop { + let request = frontend::wait_for_request(&mut stream).await; + match request { + Ok(request) => { + let _ = tx.send(request).await; + } + Err(e) => { + if let Some(e) = e.downcast_ref::() { + if e.kind() == ErrorKind::UnexpectedEof { + return; } - log::error!("error reading frontend event: {e}"); - return; } + log::error!("error reading frontend event: {e}"); + return; } } - }); + } } async fn handle_frontend_event( diff --git a/src/server/ping_task.rs b/src/server/ping_task.rs index 2177671..0e7189f 100644 --- a/src/server/ping_task.rs +++ b/src/server/ping_task.rs @@ -1,7 +1,7 @@ -use std::{net::SocketAddr, time::Duration}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; use tokio::{ - sync::mpsc::{Receiver, Sender}, + sync::{mpsc::Sender, Notify}, task::JoinHandle, }; @@ -18,15 +18,13 @@ pub fn new( sender_ch: Sender<(Event, SocketAddr)>, emulate_notify: Sender, capture_notify: Sender, - mut timer_rx: Receiver<()>, + timer_notify: Arc, ) -> JoinHandle<()> { // timer task let ping_task = tokio::task::spawn_local(async move { loop { // wait for wake up signal - let Some(_): Option<()> = timer_rx.recv().await else { - break; - }; + timer_notify.notified().await; loop { let receiving = server.state.get() == State::Receiving; let (ping_clients, ping_addrs) = {