prefer CancellationToken / Notify to channel

This commit is contained in:
Ferdinand Schober
2024-07-10 22:03:56 +02:00
parent 8ba92ede34
commit 9c0a40563e
4 changed files with 92 additions and 57 deletions

14
Cargo.lock generated
View File

@@ -1208,6 +1208,7 @@ dependencies = [
"tempfile", "tempfile",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-util",
"wayland-client", "wayland-client",
"wayland-protocols", "wayland-protocols",
"wayland-protocols-wlr", "wayland-protocols-wlr",
@@ -2049,6 +2050,19 @@ dependencies = [
"syn", "syn",
] ]
[[package]]
name = "tokio-util"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "toml" name = "toml"
version = "0.8.14" version = "0.8.14"

View File

@@ -18,6 +18,7 @@ thiserror = "1.0.61"
tokio = { version = "1.32.0", features = ["io-util", "io-std", "macros", "net", "process", "rt", "sync", "signal"] } tokio = { version = "1.32.0", features = ["io-util", "io-std", "macros", "net", "process", "rt", "sync", "signal"] }
once_cell = "1.19.0" once_cell = "1.19.0"
async-trait = "0.1.81" async-trait = "0.1.81"
tokio-util = "0.7.11"
[target.'cfg(all(unix, not(target_os="macos")))'.dependencies] [target.'cfg(all(unix, not(target_os="macos")))'.dependencies]

View File

@@ -20,12 +20,17 @@ use std::{
os::unix::net::UnixStream, os::unix::net::UnixStream,
pin::Pin, pin::Pin,
rc::Rc, rc::Rc,
sync::Arc,
task::{Context, Poll}, task::{Context, Poll},
}; };
use tokio::{ use tokio::{
sync::mpsc::{Receiver, Sender}, sync::{
mpsc::{self, Receiver, Sender},
Notify,
},
task::JoinHandle, task::JoinHandle,
}; };
use tokio_util::sync::CancellationToken;
use futures_core::Stream; use futures_core::Stream;
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
@@ -38,8 +43,6 @@ use super::{
error::LibeiCaptureCreationError, CaptureHandle, InputCapture as LanMouseInputCapture, Position, error::LibeiCaptureCreationError, CaptureHandle, InputCapture as LanMouseInputCapture, Position,
}; };
struct PoisonPill;
/* there is a bug in xdg-remote-desktop-portal-gnome / mutter that /* there is a bug in xdg-remote-desktop-portal-gnome / mutter that
* prevents receiving further events after a session has been disabled once. * prevents receiving further events after a session has been disabled once.
* Therefore the session needs to recreated when the barriers are updated */ * Therefore the session needs to recreated when the barriers are updated */
@@ -59,9 +62,10 @@ struct ReleaseCaptureEvent;
pub struct LibeiInputCapture<'a> { pub struct LibeiInputCapture<'a> {
input_capture: Pin<Box<InputCapture<'a>>>, input_capture: Pin<Box<InputCapture<'a>>>,
capture_task: JoinHandle<Result<(), CaptureError>>, capture_task: JoinHandle<Result<(), CaptureError>>,
event_rx: Option<tokio::sync::mpsc::Receiver<(CaptureHandle, Event)>>, event_rx: Option<Receiver<(CaptureHandle, Event)>>,
notify_capture: tokio::sync::mpsc::Sender<CaptureEvent>, notify_capture: Sender<CaptureEvent>,
notify_capture_session: tokio::sync::mpsc::Sender<ReleaseCaptureEvent>, notify_capture_session: Sender<ReleaseCaptureEvent>,
cancellation_token: CancellationToken,
} }
static INTERFACES: Lazy<HashMap<&'static str, u32>> = Lazy::new(|| { static INTERFACES: Lazy<HashMap<&'static str, u32>> = Lazy::new(|| {
@@ -188,7 +192,7 @@ async fn libei_event_handler(
mut ei_event_stream: EiConvertEventStream, mut ei_event_stream: EiConvertEventStream,
context: ei::Context, context: ei::Context,
event_tx: Sender<(CaptureHandle, Event)>, event_tx: Sender<(CaptureHandle, Event)>,
capture_tx: Sender<DeviceUpdate>, release_session: Arc<Notify>,
current_client: Rc<Cell<Option<CaptureHandle>>>, current_client: Rc<Cell<Option<CaptureHandle>>>,
) -> Result<(), CaptureError> { ) -> Result<(), CaptureError> {
loop { loop {
@@ -199,7 +203,7 @@ async fn libei_event_handler(
.map_err(ReisConvertEventStreamError::from)?; .map_err(ReisConvertEventStreamError::from)?;
log::trace!("from ei: {ei_event:?}"); log::trace!("from ei: {ei_event:?}");
let client = current_client.get(); let client = current_client.get();
handle_ei_event(ei_event, client, &context, &event_tx, &capture_tx).await?; handle_ei_event(ei_event, client, &context, &event_tx, &release_session).await?;
if event_tx.is_closed() { if event_tx.is_closed() {
log::info!("event_tx closed -> exiting"); log::info!("event_tx closed -> exiting");
break Ok(()); break Ok(());
@@ -213,15 +217,19 @@ impl<'a> LibeiInputCapture<'a> {
let input_capture_ptr = input_capture.as_ref().get_ref() as *const InputCapture<'static>; let input_capture_ptr = input_capture.as_ref().get_ref() as *const InputCapture<'static>;
let first_session = Some(create_session(unsafe { &*input_capture_ptr }).await?); let first_session = Some(create_session(unsafe { &*input_capture_ptr }).await?);
let (event_tx, event_rx) = tokio::sync::mpsc::channel(32); let (event_tx, event_rx) = mpsc::channel(1);
let (notify_capture, notify_rx) = tokio::sync::mpsc::channel(32); let (notify_capture, notify_rx) = mpsc::channel(1);
let (notify_capture_session, notify_session_rx) = tokio::sync::mpsc::channel(32); let (notify_capture_session, notify_session_rx) = mpsc::channel(1);
let cancellation_token = CancellationToken::new();
let capture = do_capture( let capture = do_capture(
input_capture_ptr, input_capture_ptr,
notify_rx, notify_rx,
notify_session_rx, notify_session_rx,
first_session, first_session,
event_tx, event_tx,
cancellation_token.clone(),
); );
let capture_task = tokio::task::spawn_local(capture); let capture_task = tokio::task::spawn_local(capture);
let event_rx = Some(event_rx); let event_rx = Some(event_rx);
@@ -232,6 +240,7 @@ impl<'a> LibeiInputCapture<'a> {
capture_task, capture_task,
notify_capture, notify_capture,
notify_capture_session, notify_capture_session,
cancellation_token,
}; };
Ok(producer) Ok(producer)
@@ -244,6 +253,7 @@ async fn do_capture<'a>(
mut release_capture_channel: Receiver<ReleaseCaptureEvent>, mut release_capture_channel: Receiver<ReleaseCaptureEvent>,
session: Option<(Session<'a>, BitFlags<Capabilities>)>, session: Option<(Session<'a>, BitFlags<Capabilities>)>,
event_tx: Sender<(CaptureHandle, Event)>, event_tx: Sender<(CaptureHandle, Event)>,
cancellation_token: CancellationToken,
) -> Result<(), CaptureError> { ) -> Result<(), CaptureError> {
let mut session = session.map(|s| s.0); let mut session = session.map(|s| s.0);
@@ -263,8 +273,8 @@ async fn do_capture<'a>(
}; };
// do capture session // do capture session
let (session_kill, session_kill_rx) = tokio::sync::mpsc::channel(1); let cancel_session = CancellationToken::new();
let (kill_update_tx, mut kill_update) = tokio::sync::mpsc::channel(1); let cancel_update = CancellationToken::new();
let capture_session = do_capture_session( let capture_session = do_capture_session(
input_capture, input_capture,
@@ -273,9 +283,8 @@ async fn do_capture<'a>(
&mut active_clients, &mut active_clients,
&mut next_barrier_id, &mut next_barrier_id,
&mut release_capture_channel, &mut release_capture_channel,
session_kill_rx, cancel_session.clone(),
session_kill.clone(), cancel_update.clone(),
kill_update_tx,
); );
let mut capture_event_occured: Option<CaptureEvent> = None; let mut capture_event_occured: Option<CaptureEvent> = None;
@@ -284,8 +293,10 @@ async fn do_capture<'a>(
// kill session if clients need to be updated // kill session if clients need to be updated
let handle_session_update_request = async { let handle_session_update_request = async {
tokio::select! { tokio::select! {
/* exit requested */
_ = cancellation_token.cancelled() => {},
/* session exited */ /* session exited */
_ = kill_update.recv() => {}, _ = cancel_update.cancelled() => {},
/* zones have changed */ /* zones have changed */
_ = zones_changed.next() => zones_have_changed = true, _ = zones_changed.next() => zones_have_changed = true,
/* clients changed */ /* clients changed */
@@ -294,17 +305,17 @@ async fn do_capture<'a>(
}, },
} }
// kill session (might already be dead!) // kill session (might already be dead!)
let _ = session_kill.send(PoisonPill).await; cancel_session.cancel();
}; };
let (a, _) = tokio::join!(capture_session, handle_session_update_request); let (capture_result, ()) = tokio::join!(capture_session, handle_session_update_request);
if let Err(e) = a {
log::warn!("{e}");
}
log::info!("capture session + session_update task done!"); log::info!("capture session + session_update task done!");
log::info!("no error occured"); // disable capture
log::info!("disabling input capture");
input_capture.disable(&session).await?;
// update clients if requested
if let Some(event) = capture_event_occured.take() { if let Some(event) = capture_event_occured.take() {
match event { match event {
CaptureEvent::Create(c, p) => active_clients.push((c, p)), CaptureEvent::Create(c, p) => active_clients.push((c, p)),
@@ -312,12 +323,15 @@ async fn do_capture<'a>(
} }
} }
// disable capture // propagate error from capture session
log::info!("disabling input capture"); if capture_result.is_err() {
input_capture.disable(&session).await?; return capture_result;
}
log::info!("no error occured");
// break // break
if event_tx.is_closed() { if cancellation_token.is_cancelled() {
break Ok(()); break Ok(());
} }
} }
@@ -330,9 +344,8 @@ async fn do_capture_session(
active_clients: &mut Vec<(CaptureHandle, Position)>, active_clients: &mut Vec<(CaptureHandle, Position)>,
next_barrier_id: &mut u32, next_barrier_id: &mut u32,
capture_session_event: &mut Receiver<ReleaseCaptureEvent>, capture_session_event: &mut Receiver<ReleaseCaptureEvent>,
mut kill_session: Receiver<PoisonPill>, cancel_session: CancellationToken,
kill_session_tx: Sender<PoisonPill>, cancel_update: CancellationToken,
kill_update_tx: Sender<PoisonPill>,
) -> Result<(), CaptureError> { ) -> Result<(), CaptureError> {
// current client // current client
let current_client = Rc::new(Cell::new(None)); let current_client = Rc::new(Cell::new(None));
@@ -347,24 +360,29 @@ async fn do_capture_session(
log::debug!("enabling session"); log::debug!("enabling session");
input_capture.enable(session).await?; input_capture.enable(session).await?;
// cancellation token to release session
let release_session = Arc::new(Notify::new());
// async event task // async event task
let (kill_ei, mut kill_ei_rx) = tokio::sync::mpsc::channel(1); let cancel_ei_handler = CancellationToken::new();
let (break_tx, mut break_rx) = tokio::sync::mpsc::channel(1);
let event_chan = event_tx.clone(); let event_chan = event_tx.clone();
let client = current_client.clone(); let client = current_client.clone();
let cancel_session_clone = cancel_session.clone();
let release_session_clone = release_session.clone();
let cancel_ei_handler_clone = cancel_ei_handler.clone();
let ei_task = async move { let ei_task = async move {
tokio::select! { tokio::select! {
r = libei_event_handler( r = libei_event_handler(
ei_event_stream, ei_event_stream,
context, context,
event_chan, event_chan,
break_tx, release_session_clone,
client, client,
) => { ) => {
log::info!("libei exited: {r:?} killing session task"); log::info!("libei exited: {r:?} cancelling session task");
let _ = kill_session_tx.send(PoisonPill).await; cancel_session_clone.cancel();
} }
_ = kill_ei_rx.recv() => {}, _ = cancel_ei_handler_clone.cancelled() => {},
} }
Ok::<(), CaptureError>(()) Ok::<(), CaptureError>(())
}; };
@@ -383,32 +401,37 @@ async fn do_capture_session(
.expect("invalid barrier id"); .expect("invalid barrier id");
current_client.replace(Some(client)); current_client.replace(Some(client));
// client entered => send event
if event_tx.send((client, Event::Enter())).await.is_err() { if event_tx.send((client, Event::Enter())).await.is_err() {
break; break;
}; };
tokio::select! { tokio::select! {
_ = capture_session_event.recv() => {}, /* capture release */ _ = capture_session_event.recv() => {}, /* capture release */
_ = break_rx.recv() => {}, /* ei notifying that it needs to restart */ _ = release_session.notified() => {
_ = kill_session.recv() => break, /* kill session notify */ log::warn!("release session aquired (a): {release_session:?}");
},
_ = cancel_session.cancelled() => break, /* kill session notify */
} }
release_capture(input_capture, session, activated, client, &active_clients).await?; release_capture(input_capture, session, activated, client, &active_clients).await?;
} }
_ = break_rx.recv() => {}, /* ei notifying that it needs to restart */
_ = capture_session_event.recv() => {}, /* capture release -> we are not capturing anyway, so ignore */ _ = capture_session_event.recv() => {}, /* capture release -> we are not capturing anyway, so ignore */
_ = kill_session.recv() => break, /* kill session notify */ _ = release_session.notified() => {
log::warn!("release session aquired (b): {release_session:?}");
},
_ = cancel_session.cancelled() => break, /* kill session notify */
} }
} }
// kill libei task // cancel libei task
log::info!("session exited: killing libei task"); log::info!("session exited: killing libei task");
let _ = kill_ei.send(()).await; cancel_ei_handler.cancel();
Ok::<(), CaptureError>(()) Ok::<(), CaptureError>(())
}; };
let (a, b) = tokio::join!(ei_task, capture_session_task); let (a, b) = tokio::join!(ei_task, capture_session_task);
let _ = kill_update_tx.send(PoisonPill).await; cancel_update.cancel();
log::info!("both session and ei task finished!"); log::info!("both session and ei task finished!");
a?; a?;
@@ -447,14 +470,12 @@ async fn release_capture(
Ok(()) Ok(())
} }
struct DeviceUpdate;
async fn handle_ei_event( async fn handle_ei_event(
ei_event: EiEvent, ei_event: EiEvent,
current_client: Option<CaptureHandle>, current_client: Option<CaptureHandle>,
context: &ei::Context, context: &ei::Context,
event_tx: &Sender<(CaptureHandle, Event)>, event_tx: &Sender<(CaptureHandle, Event)>,
capture_tx: &Sender<DeviceUpdate>, release_session: &Notify,
) -> Result<(), CaptureError> { ) -> Result<(), CaptureError> {
match ei_event { match ei_event {
EiEvent::SeatAdded(s) => { EiEvent::SeatAdded(s) => {
@@ -468,15 +489,10 @@ async fn handle_ei_event(
]); ]);
context.flush().map_err(|e| io::Error::new(e.kind(), e))?; context.flush().map_err(|e| io::Error::new(e.kind(), e))?;
} }
EiEvent::SeatRemoved(_) EiEvent::SeatRemoved(_) | EiEvent::DeviceAdded(_) | EiEvent::DeviceRemoved(_) => {
| EiEvent::DeviceAdded(_) release_session.notify_waiters();
| EiEvent::DeviceRemoved(_)
| EiEvent::DevicePaused(_)
| EiEvent::DeviceResumed(_) => {
if capture_tx.send(DeviceUpdate).await.is_err() {
return Ok(());
}
} }
EiEvent::DevicePaused(_) | EiEvent::DeviceResumed(_) => {}
EiEvent::DeviceStartEmulating(_) => log::debug!("START EMULATING"), EiEvent::DeviceStartEmulating(_) => log::debug!("START EMULATING"),
EiEvent::DeviceStopEmulating(_) => log::debug!("STOP EMULATING"), EiEvent::DeviceStopEmulating(_) => log::debug!("STOP EMULATING"),
EiEvent::Disconnected(d) => { EiEvent::Disconnected(d) => {
@@ -484,8 +500,7 @@ async fn handle_ei_event(
} }
_ => { _ => {
if let Some(handle) = current_client { if let Some(handle) = current_client {
let events = to_input_events(ei_event); for event in to_input_events(ei_event).into_iter() {
for event in events.into_iter() {
if event_tx.send((handle, event)).await.is_err() { if event_tx.send((handle, event)).await.is_err() {
return Ok(()); return Ok(());
}; };
@@ -657,8 +672,12 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> {
async fn async_drop(&mut self) -> Result<(), CaptureError> { async fn async_drop(&mut self) -> Result<(), CaptureError> {
let event_rx = self.event_rx.take().expect("no channel"); let event_rx = self.event_rx.take().expect("no channel");
std::mem::drop(event_rx); std::mem::drop(event_rx);
self.cancellation_token.cancel();
let task = &mut self.capture_task; let task = &mut self.capture_task;
task.await.expect("libei task panic") log::info!("waiting for capture to terminate...");
let res = task.await.expect("libei task panic");
log::info!("done!");
res
} }
} }

View File

@@ -59,6 +59,7 @@ pub fn new(
CaptureEvent::Destroy(h) => capture.destroy(h).await?, CaptureEvent::Destroy(h) => capture.destroy(h).await?,
CaptureEvent::Restart => { CaptureEvent::Restart => {
let clients = server.client_manager.borrow().get_client_states().map(|(h, (c,_))| (h, c.pos)).collect::<Vec<_>>(); let clients = server.client_manager.borrow().get_client_states().map(|(h, (c,_))| (h, c.pos)).collect::<Vec<_>>();
capture.async_drop().await?;
capture = input_capture::create(backend).await?; capture = input_capture::create(backend).await?;
for (handle, pos) in clients { for (handle, pos) in clients {
capture.create(handle, pos.into()).await?; capture.create(handle, pos.into()).await?;