From 9c0a40563e2649870dca98cfd3dcf84726b6cab1 Mon Sep 17 00:00:00 2001 From: Ferdinand Schober Date: Wed, 10 Jul 2024 22:03:56 +0200 Subject: [PATCH] prefer CancellationToken / Notify to channel --- Cargo.lock | 14 ++++ input-capture/Cargo.toml | 1 + input-capture/src/libei.rs | 133 +++++++++++++++++++++---------------- src/server/capture_task.rs | 1 + 4 files changed, 92 insertions(+), 57 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7a23c56..5b2cea0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1208,6 +1208,7 @@ dependencies = [ "tempfile", "thiserror", "tokio", + "tokio-util", "wayland-client", "wayland-protocols", "wayland-protocols-wlr", @@ -2049,6 +2050,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-util" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.8.14" diff --git a/input-capture/Cargo.toml b/input-capture/Cargo.toml index 2cd1f33..0a92017 100644 --- a/input-capture/Cargo.toml +++ b/input-capture/Cargo.toml @@ -18,6 +18,7 @@ thiserror = "1.0.61" tokio = { version = "1.32.0", features = ["io-util", "io-std", "macros", "net", "process", "rt", "sync", "signal"] } once_cell = "1.19.0" async-trait = "0.1.81" +tokio-util = "0.7.11" [target.'cfg(all(unix, not(target_os="macos")))'.dependencies] diff --git a/input-capture/src/libei.rs b/input-capture/src/libei.rs index 688def1..9fda9f5 100644 --- a/input-capture/src/libei.rs +++ b/input-capture/src/libei.rs @@ -20,12 +20,17 @@ use std::{ os::unix::net::UnixStream, pin::Pin, rc::Rc, + sync::Arc, task::{Context, Poll}, }; use tokio::{ - sync::mpsc::{Receiver, Sender}, + sync::{ + mpsc::{self, Receiver, Sender}, + Notify, + }, task::JoinHandle, }; +use tokio_util::sync::CancellationToken; use futures_core::Stream; use once_cell::sync::Lazy; @@ -38,8 +43,6 @@ use super::{ error::LibeiCaptureCreationError, CaptureHandle, InputCapture as LanMouseInputCapture, Position, }; -struct PoisonPill; - /* there is a bug in xdg-remote-desktop-portal-gnome / mutter that * prevents receiving further events after a session has been disabled once. * Therefore the session needs to recreated when the barriers are updated */ @@ -59,9 +62,10 @@ struct ReleaseCaptureEvent; pub struct LibeiInputCapture<'a> { input_capture: Pin>>, capture_task: JoinHandle>, - event_rx: Option>, - notify_capture: tokio::sync::mpsc::Sender, - notify_capture_session: tokio::sync::mpsc::Sender, + event_rx: Option>, + notify_capture: Sender, + notify_capture_session: Sender, + cancellation_token: CancellationToken, } static INTERFACES: Lazy> = Lazy::new(|| { @@ -188,7 +192,7 @@ async fn libei_event_handler( mut ei_event_stream: EiConvertEventStream, context: ei::Context, event_tx: Sender<(CaptureHandle, Event)>, - capture_tx: Sender, + release_session: Arc, current_client: Rc>>, ) -> Result<(), CaptureError> { loop { @@ -199,7 +203,7 @@ async fn libei_event_handler( .map_err(ReisConvertEventStreamError::from)?; log::trace!("from ei: {ei_event:?}"); let client = current_client.get(); - handle_ei_event(ei_event, client, &context, &event_tx, &capture_tx).await?; + handle_ei_event(ei_event, client, &context, &event_tx, &release_session).await?; if event_tx.is_closed() { log::info!("event_tx closed -> exiting"); break Ok(()); @@ -213,15 +217,19 @@ impl<'a> LibeiInputCapture<'a> { let input_capture_ptr = input_capture.as_ref().get_ref() as *const InputCapture<'static>; let first_session = Some(create_session(unsafe { &*input_capture_ptr }).await?); - let (event_tx, event_rx) = tokio::sync::mpsc::channel(32); - let (notify_capture, notify_rx) = tokio::sync::mpsc::channel(32); - let (notify_capture_session, notify_session_rx) = tokio::sync::mpsc::channel(32); + let (event_tx, event_rx) = mpsc::channel(1); + let (notify_capture, notify_rx) = mpsc::channel(1); + let (notify_capture_session, notify_session_rx) = mpsc::channel(1); + + let cancellation_token = CancellationToken::new(); + let capture = do_capture( input_capture_ptr, notify_rx, notify_session_rx, first_session, event_tx, + cancellation_token.clone(), ); let capture_task = tokio::task::spawn_local(capture); let event_rx = Some(event_rx); @@ -232,6 +240,7 @@ impl<'a> LibeiInputCapture<'a> { capture_task, notify_capture, notify_capture_session, + cancellation_token, }; Ok(producer) @@ -244,6 +253,7 @@ async fn do_capture<'a>( mut release_capture_channel: Receiver, session: Option<(Session<'a>, BitFlags)>, event_tx: Sender<(CaptureHandle, Event)>, + cancellation_token: CancellationToken, ) -> Result<(), CaptureError> { let mut session = session.map(|s| s.0); @@ -263,8 +273,8 @@ async fn do_capture<'a>( }; // do capture session - let (session_kill, session_kill_rx) = tokio::sync::mpsc::channel(1); - let (kill_update_tx, mut kill_update) = tokio::sync::mpsc::channel(1); + let cancel_session = CancellationToken::new(); + let cancel_update = CancellationToken::new(); let capture_session = do_capture_session( input_capture, @@ -273,9 +283,8 @@ async fn do_capture<'a>( &mut active_clients, &mut next_barrier_id, &mut release_capture_channel, - session_kill_rx, - session_kill.clone(), - kill_update_tx, + cancel_session.clone(), + cancel_update.clone(), ); let mut capture_event_occured: Option = None; @@ -284,8 +293,10 @@ async fn do_capture<'a>( // kill session if clients need to be updated let handle_session_update_request = async { tokio::select! { + /* exit requested */ + _ = cancellation_token.cancelled() => {}, /* session exited */ - _ = kill_update.recv() => {}, + _ = cancel_update.cancelled() => {}, /* zones have changed */ _ = zones_changed.next() => zones_have_changed = true, /* clients changed */ @@ -294,17 +305,17 @@ async fn do_capture<'a>( }, } // kill session (might already be dead!) - let _ = session_kill.send(PoisonPill).await; + cancel_session.cancel(); }; - let (a, _) = tokio::join!(capture_session, handle_session_update_request); - if let Err(e) = a { - log::warn!("{e}"); - } + let (capture_result, ()) = tokio::join!(capture_session, handle_session_update_request); log::info!("capture session + session_update task done!"); - log::info!("no error occured"); + // disable capture + log::info!("disabling input capture"); + input_capture.disable(&session).await?; + // update clients if requested if let Some(event) = capture_event_occured.take() { match event { CaptureEvent::Create(c, p) => active_clients.push((c, p)), @@ -312,12 +323,15 @@ async fn do_capture<'a>( } } - // disable capture - log::info!("disabling input capture"); - input_capture.disable(&session).await?; + // propagate error from capture session + if capture_result.is_err() { + return capture_result; + } + + log::info!("no error occured"); // break - if event_tx.is_closed() { + if cancellation_token.is_cancelled() { break Ok(()); } } @@ -330,9 +344,8 @@ async fn do_capture_session( active_clients: &mut Vec<(CaptureHandle, Position)>, next_barrier_id: &mut u32, capture_session_event: &mut Receiver, - mut kill_session: Receiver, - kill_session_tx: Sender, - kill_update_tx: Sender, + cancel_session: CancellationToken, + cancel_update: CancellationToken, ) -> Result<(), CaptureError> { // current client let current_client = Rc::new(Cell::new(None)); @@ -347,24 +360,29 @@ async fn do_capture_session( log::debug!("enabling session"); input_capture.enable(session).await?; + // cancellation token to release session + let release_session = Arc::new(Notify::new()); + // async event task - let (kill_ei, mut kill_ei_rx) = tokio::sync::mpsc::channel(1); - let (break_tx, mut break_rx) = tokio::sync::mpsc::channel(1); + let cancel_ei_handler = CancellationToken::new(); let event_chan = event_tx.clone(); let client = current_client.clone(); + let cancel_session_clone = cancel_session.clone(); + let release_session_clone = release_session.clone(); + let cancel_ei_handler_clone = cancel_ei_handler.clone(); let ei_task = async move { tokio::select! { r = libei_event_handler( ei_event_stream, context, event_chan, - break_tx, + release_session_clone, client, ) => { - log::info!("libei exited: {r:?} killing session task"); - let _ = kill_session_tx.send(PoisonPill).await; + log::info!("libei exited: {r:?} cancelling session task"); + cancel_session_clone.cancel(); } - _ = kill_ei_rx.recv() => {}, + _ = cancel_ei_handler_clone.cancelled() => {}, } Ok::<(), CaptureError>(()) }; @@ -383,32 +401,37 @@ async fn do_capture_session( .expect("invalid barrier id"); current_client.replace(Some(client)); + // client entered => send event if event_tx.send((client, Event::Enter())).await.is_err() { break; }; tokio::select! { _ = capture_session_event.recv() => {}, /* capture release */ - _ = break_rx.recv() => {}, /* ei notifying that it needs to restart */ - _ = kill_session.recv() => break, /* kill session notify */ + _ = release_session.notified() => { + log::warn!("release session aquired (a): {release_session:?}"); + }, + _ = cancel_session.cancelled() => break, /* kill session notify */ } release_capture(input_capture, session, activated, client, &active_clients).await?; } - _ = break_rx.recv() => {}, /* ei notifying that it needs to restart */ _ = capture_session_event.recv() => {}, /* capture release -> we are not capturing anyway, so ignore */ - _ = kill_session.recv() => break, /* kill session notify */ + _ = release_session.notified() => { + log::warn!("release session aquired (b): {release_session:?}"); + }, + _ = cancel_session.cancelled() => break, /* kill session notify */ } } - // kill libei task + // cancel libei task log::info!("session exited: killing libei task"); - let _ = kill_ei.send(()).await; + cancel_ei_handler.cancel(); Ok::<(), CaptureError>(()) }; let (a, b) = tokio::join!(ei_task, capture_session_task); - let _ = kill_update_tx.send(PoisonPill).await; + cancel_update.cancel(); log::info!("both session and ei task finished!"); a?; @@ -447,14 +470,12 @@ async fn release_capture( Ok(()) } -struct DeviceUpdate; - async fn handle_ei_event( ei_event: EiEvent, current_client: Option, context: &ei::Context, event_tx: &Sender<(CaptureHandle, Event)>, - capture_tx: &Sender, + release_session: &Notify, ) -> Result<(), CaptureError> { match ei_event { EiEvent::SeatAdded(s) => { @@ -468,15 +489,10 @@ async fn handle_ei_event( ]); context.flush().map_err(|e| io::Error::new(e.kind(), e))?; } - EiEvent::SeatRemoved(_) - | EiEvent::DeviceAdded(_) - | EiEvent::DeviceRemoved(_) - | EiEvent::DevicePaused(_) - | EiEvent::DeviceResumed(_) => { - if capture_tx.send(DeviceUpdate).await.is_err() { - return Ok(()); - } + EiEvent::SeatRemoved(_) | EiEvent::DeviceAdded(_) | EiEvent::DeviceRemoved(_) => { + release_session.notify_waiters(); } + EiEvent::DevicePaused(_) | EiEvent::DeviceResumed(_) => {} EiEvent::DeviceStartEmulating(_) => log::debug!("START EMULATING"), EiEvent::DeviceStopEmulating(_) => log::debug!("STOP EMULATING"), EiEvent::Disconnected(d) => { @@ -484,8 +500,7 @@ async fn handle_ei_event( } _ => { if let Some(handle) = current_client { - let events = to_input_events(ei_event); - for event in events.into_iter() { + for event in to_input_events(ei_event).into_iter() { if event_tx.send((handle, event)).await.is_err() { return Ok(()); }; @@ -657,8 +672,12 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> { async fn async_drop(&mut self) -> Result<(), CaptureError> { let event_rx = self.event_rx.take().expect("no channel"); std::mem::drop(event_rx); + self.cancellation_token.cancel(); let task = &mut self.capture_task; - task.await.expect("libei task panic") + log::info!("waiting for capture to terminate..."); + let res = task.await.expect("libei task panic"); + log::info!("done!"); + res } } diff --git a/src/server/capture_task.rs b/src/server/capture_task.rs index ce50a59..e1b4dff 100644 --- a/src/server/capture_task.rs +++ b/src/server/capture_task.rs @@ -59,6 +59,7 @@ pub fn new( CaptureEvent::Destroy(h) => capture.destroy(h).await?, CaptureEvent::Restart => { let clients = server.client_manager.borrow().get_client_states().map(|(h, (c,_))| (h, c.pos)).collect::>(); + capture.async_drop().await?; capture = input_capture::create(backend).await?; for (handle, pos) in clients { capture.create(handle, pos.into()).await?;