simplify cancellation of service

This commit is contained in:
Ferdinand Schober
2024-07-11 00:42:32 +02:00
parent ae3ea2c497
commit 8ba178dce0
9 changed files with 132 additions and 95 deletions

1
Cargo.lock generated
View File

@@ -1329,6 +1329,7 @@ dependencies = [
"slab",
"thiserror",
"tokio",
"tokio-util",
"toml",
]

View File

@@ -39,6 +39,7 @@ hostname = "0.4.0"
slab = "0.4.9"
endi = "1.1.0"
thiserror = "1.0.61"
tokio-util = "0.7.11"
[target.'cfg(unix)'.dependencies]
libc = "0.2.148"

View File

@@ -62,7 +62,7 @@ struct ReleaseCaptureEvent;
pub struct LibeiInputCapture<'a> {
input_capture: Pin<Box<InputCapture<'a>>>,
capture_task: JoinHandle<Result<(), CaptureError>>,
event_rx: Option<Receiver<(CaptureHandle, Event)>>,
event_rx: Receiver<(CaptureHandle, Event)>,
notify_capture: Sender<CaptureEvent>,
notify_capture_session: Sender<ReleaseCaptureEvent>,
cancellation_token: CancellationToken,
@@ -204,10 +204,6 @@ async fn libei_event_handler(
log::trace!("from ei: {ei_event:?}");
let client = current_client.get();
handle_ei_event(ei_event, client, &context, &event_tx, &release_session).await?;
if event_tx.is_closed() {
log::info!("event_tx closed -> exiting");
break Ok(());
}
}
}
@@ -232,7 +228,6 @@ impl<'a> LibeiInputCapture<'a> {
cancellation_token.clone(),
);
let capture_task = tokio::task::spawn_local(capture);
let event_rx = Some(event_rx);
let producer = Self {
input_capture,
@@ -305,11 +300,16 @@ async fn do_capture<'a>(
);
let (capture_result, ()) = tokio::join!(capture_session, handle_session_update_request);
log::info!("capture session + session_update task done!");
log::debug!("capture session + session_update task done!");
// disable capture
log::info!("disabling input capture");
input_capture.disable(&session).await?;
log::debug!("disabling input capture");
if let Err(e) = input_capture.disable(&session).await {
log::warn!("input_capture.disable(&session) {e}");
}
if let Err(e) = session.close().await {
log::warn!("session.close(): {e}");
}
// propagate error from capture session
if capture_result.is_err() {
@@ -327,8 +327,6 @@ async fn do_capture<'a>(
}
}
log::info!("no error occured");
// break
if cancellation_token.is_cancelled() {
break Ok(());
@@ -378,7 +376,7 @@ async fn do_capture_session(
release_session_clone,
client,
) => {
log::info!("libei exited: {r:?} cancelling session task");
log::debug!("libei exited: {r:?} cancelling session task");
cancel_session_clone.cancel();
}
_ = cancel_ei_handler_clone.cancelled() => {},
@@ -389,6 +387,7 @@ async fn do_capture_session(
let capture_session_task = async {
// receiver for activation tokens
let mut activated = input_capture.receive_activated().await?;
let mut ei_devices_changed = false;
loop {
tokio::select! {
activated = activated.next() => {
@@ -401,29 +400,34 @@ async fn do_capture_session(
current_client.replace(Some(client));
// client entered => send event
if event_tx.send((client, Event::Enter())).await.is_err() {
break;
};
event_tx.send((client, Event::Enter())).await.expect("no channel");
tokio::select! {
_ = capture_session_event.recv() => {}, /* capture release */
_ = release_session.notified() => {
log::warn!("release session aquired (a): {release_session:?}");
_ = release_session.notified() => { /* release session */
ei_devices_changed = true;
},
_ = cancel_session.cancelled() => break, /* kill session notify */
}
release_capture(input_capture, session, activated, client, &active_clients).await?;
}
_ = capture_session_event.recv() => {}, /* capture release -> we are not capturing anyway, so ignore */
_ = release_session.notified() => {
log::warn!("release session aquired (b): {release_session:?}");
_ = release_session.notified() => { /* release session */
ei_devices_changed = true;
},
_ = cancel_session.cancelled() => break, /* kill session notify */
}
if ei_devices_changed {
/* for whatever reason, GNOME seems to kill the session
* as soon as devices are added or removed, so we need
* to cancel */
break;
}
}
// cancel libei task
log::info!("session exited: killing libei task");
log::debug!("session exited: killing libei task");
cancel_ei_handler.cancel();
Ok::<(), CaptureError>(())
};
@@ -432,7 +436,7 @@ async fn do_capture_session(
cancel_update.cancel();
log::info!("both session and ei task finished!");
log::debug!("both session and ei task finished!");
a?;
b?;
@@ -488,7 +492,8 @@ async fn handle_ei_event(
]);
context.flush().map_err(|e| io::Error::new(e.kind(), e))?;
}
EiEvent::SeatRemoved(_) | EiEvent::DeviceAdded(_) | EiEvent::DeviceRemoved(_) => {
EiEvent::SeatRemoved(_) | /* EiEvent::DeviceAdded(_) | */ EiEvent::DeviceRemoved(_) => {
log::debug!("releasing session: {ei_event:?}");
release_session.notify_waiters();
}
EiEvent::DevicePaused(_) | EiEvent::DeviceResumed(_) => {}
@@ -500,9 +505,7 @@ async fn handle_ei_event(
_ => {
if let Some(handle) = current_client {
for event in to_input_events(ei_event).into_iter() {
if event_tx.send((handle, event)).await.is_err() {
return Ok(());
};
event_tx.send((handle, event)).await.expect("no channel");
}
}
}
@@ -669,13 +672,11 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> {
}
async fn terminate(&mut self) -> Result<(), CaptureError> {
let event_rx = self.event_rx.take().expect("no channel");
std::mem::drop(event_rx);
self.cancellation_token.cancel();
let task = &mut self.capture_task;
log::info!("waiting for capture to terminate...");
log::debug!("waiting for capture to terminate...");
let res = task.await.expect("libei task panic");
log::info!("done!");
log::debug!("done!");
res
}
}
@@ -689,12 +690,7 @@ impl<'a> Stream for LibeiInputCapture<'a> {
Ok(()) => Poll::Ready(None),
Err(e) => Poll::Ready(Some(Err(e))),
},
Poll::Pending => self
.event_rx
.as_mut()
.expect("no channel")
.poll_recv(cx)
.map(|e| e.map(Result::Ok)),
Poll::Pending => self.event_rx.poll_recv(cx).map(|e| e.map(Result::Ok)),
}
}
}

View File

@@ -58,11 +58,11 @@ async fn input_capture_test(config: Config) -> Result<()> {
};
log::info!("position: {pos}, event: {event}");
if let Event::Keyboard(KeyboardEvent::Key { key: 1, .. }) = event {
// input_capture.as_mut().unwrap().release()?;
break;
input_capture.as_mut().unwrap().release().await?;
// break;
}
}
input_capture.take().unwrap().terminate().await.unwrap();
// input_capture.take().unwrap().terminate().await.unwrap();
}
Ok(())
}

View File

@@ -2,9 +2,10 @@ use log;
use std::{
cell::{Cell, RefCell},
collections::HashSet,
rc::Rc,
rc::Rc, sync::Arc,
};
use tokio::signal;
use tokio::{signal, sync::Notify};
use tokio_util::sync::CancellationToken;
use crate::{
client::{ClientConfig, ClientHandle, ClientManager, ClientState},
@@ -14,7 +15,7 @@ use crate::{
server::capture_task::CaptureEvent,
};
use self::{emulation_task::EmulationEvent, resolver_task::DnsRequest};
use self::resolver_task::DnsRequest;
mod capture_task;
mod emulation_task;
@@ -92,8 +93,9 @@ impl Server {
}
};
let (timer_tx, timer_rx) = tokio::sync::mpsc::channel(1);
let timer_notify = Arc::new(Notify::new());
let (frontend_notify_tx, frontend_notify_rx) = tokio::sync::mpsc::channel(1);
let cancellation_token = CancellationToken::new();
// udp task
let (mut udp_task, sender_tx, receiver_rx, port_tx) =
@@ -104,8 +106,9 @@ impl Server {
capture_backend,
self.clone(),
sender_tx.clone(),
timer_tx.clone(),
timer_notify.clone(),
self.release_bind.clone(),
cancellation_token.clone(),
)?;
// input emulation
@@ -115,13 +118,17 @@ impl Server {
receiver_rx,
sender_tx.clone(),
capture_channel.clone(),
timer_tx,
timer_notify.clone(),
cancellation_token.clone(),
);
// create dns resolver
let resolver = dns::DnsResolver::new().await?;
let (mut resolver_task, resolve_tx) =
resolver_task::new(resolver, self.clone(), frontend_notify_tx);
let (mut resolver_task, resolve_tx) = resolver_task::new(
resolver,
self.clone(),
frontend_notify_tx,
);
// frontend listener
let (mut frontend_task, frontend_tx) = frontend_task::new(
@@ -132,6 +139,7 @@ impl Server {
emulate_channel.clone(),
resolve_tx.clone(),
port_tx,
cancellation_token.clone(),
);
// task that pings clients to see if they are responding
@@ -140,7 +148,7 @@ impl Server {
sender_tx.clone(),
emulate_channel.clone(),
capture_channel.clone(),
timer_rx,
timer_notify,
);
let active = self
@@ -189,9 +197,7 @@ impl Server {
_ = &mut ping_task => { }
}
let _ = emulate_channel.send(EmulationEvent::Terminate).await;
let _ = capture_channel.send(CaptureEvent::Terminate).await;
let _ = frontend_tx.send(FrontendRequest::Terminate()).await;
cancellation_token.cancel();
if !capture_task.is_finished() {
if let Err(e) = capture_task.await {

View File

@@ -1,8 +1,13 @@
use anyhow::{anyhow, Result};
use futures::StreamExt;
use std::{collections::HashSet, net::SocketAddr};
use tokio_util::sync::CancellationToken;
use tokio::{process::Command, sync::mpsc::Sender, task::JoinHandle};
use tokio::{
process::Command,
sync::{mpsc::Sender, Notify},
task::JoinHandle,
};
use input_capture::{self, error::CaptureCreationError, CaptureHandle, InputCapture, Position};
@@ -20,8 +25,6 @@ pub enum CaptureEvent {
Create(CaptureHandle, Position),
/// destroy a capture client
Destroy(CaptureHandle),
/// termination signal
Terminate,
/// restart input capture
Restart,
}
@@ -30,8 +33,9 @@ pub fn new(
backend: Option<CaptureBackend>,
server: Server,
sender_tx: Sender<(Event, SocketAddr)>,
timer_tx: Sender<()>,
timer_notify: Notify,
release_bind: Vec<scancode::Linux>,
cancellation_token: CancellationToken,
) -> Result<(JoinHandle<Result<()>>, Sender<CaptureEvent>), CaptureCreationError> {
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
let backend = backend.map(|b| b.into());
@@ -42,7 +46,7 @@ pub fn new(
tokio::select! {
event = capture.next() => {
match event {
Some(Ok(event)) => handle_capture_event(&server, &mut capture, &sender_tx, &timer_tx, event, &mut pressed_keys, &release_bind).await?,
Some(Ok(event)) => handle_capture_event(&server, &mut capture, &sender_tx, &timer_notify, event, &mut pressed_keys, &release_bind).await?,
Some(Err(e)) => return Err(anyhow!("input capture: {e:?}")),
None => return Err(anyhow!("input capture terminated")),
}
@@ -65,11 +69,11 @@ pub fn new(
capture.create(handle, pos.into()).await?;
}
}
CaptureEvent::Terminate => break,
},
None => break,
}
}
_ = cancellation_token.cancelled() => break,
}
}
anyhow::Ok(())
@@ -91,7 +95,7 @@ async fn handle_capture_event(
server: &Server,
capture: &mut Box<dyn InputCapture>,
sender_tx: &Sender<(Event, SocketAddr)>,
timer_tx: &Sender<()>,
timer_notify: &Notify,
event: (CaptureHandle, Event),
pressed_keys: &mut HashSet<scancode::Linux>,
release_bind: &[scancode::Linux],
@@ -150,7 +154,7 @@ async fn handle_capture_event(
(client_state.active_addr, enter, start_timer)
};
if start_timer {
let _ = timer_tx.try_send(());
timer_notify.notify_waiters();
}
if enter {
spawn_hook_command(server, handle);

View File

@@ -1,10 +1,14 @@
use std::net::SocketAddr;
use std::{net::SocketAddr, sync::Arc};
use thiserror::Error;
use tokio::{
sync::mpsc::{Receiver, Sender},
sync::{
mpsc::{Receiver, Sender},
Notify,
},
task::JoinHandle,
};
use tokio_util::sync::CancellationToken;
use crate::{
client::{ClientHandle, ClientManager},
@@ -28,8 +32,6 @@ pub enum EmulationEvent {
Destroy(EmulationHandle),
/// input emulation must release keys for client
ReleaseKeys(ClientHandle),
/// termination signal
Terminate,
/// restart input emulation
Restart,
}
@@ -40,14 +42,23 @@ pub fn new(
udp_rx: Receiver<Result<(Event, SocketAddr), NetworkError>>,
sender_tx: Sender<(Event, SocketAddr)>,
capture_tx: Sender<CaptureEvent>,
timer_tx: Sender<()>,
timer_notify: Arc<Notify>,
cancellation_token: CancellationToken,
) -> (
JoinHandle<Result<(), LanMouseEmulationError>>,
Sender<EmulationEvent>,
) {
let (tx, rx) = tokio::sync::mpsc::channel(32);
let emulation_task =
emulation_task(backend, rx, server, udp_rx, sender_tx, capture_tx, timer_tx);
let emulation_task = emulation_task(
backend,
rx,
server,
udp_rx,
sender_tx,
capture_tx,
timer_notify,
cancellation_token,
);
let emulate_task = tokio::task::spawn_local(emulation_task);
(emulate_task, tx)
}
@@ -67,7 +78,8 @@ async fn emulation_task(
mut udp_rx: Receiver<Result<(Event, SocketAddr), NetworkError>>,
sender_tx: Sender<(Event, SocketAddr)>,
capture_tx: Sender<CaptureEvent>,
timer_tx: Sender<()>,
timer_notify: Arc<Notify>,
cancellation_token: CancellationToken,
) -> Result<(), LanMouseEmulationError> {
let backend = backend.map(|b| b.into());
let mut emulation = input_emulation::create(backend).await?;
@@ -84,7 +96,7 @@ async fn emulation_task(
}
None => break,
};
handle_udp_rx(&server, &capture_tx, &mut emulation, &sender_tx, &mut last_ignored, udp_event, &timer_tx).await?;
handle_udp_rx(&server, &capture_tx, &mut emulation, &sender_tx, &mut last_ignored, udp_event, &timer_notify).await?;
}
emulate_event = rx.recv() => {
match emulate_event {
@@ -99,11 +111,11 @@ async fn emulation_task(
emulation.create(handle).await;
}
},
EmulationEvent::Terminate => break,
},
None => break,
}
}
_ = cancellation_token.cancelled() => break,
}
}
@@ -120,7 +132,7 @@ async fn handle_udp_rx(
sender_tx: &Sender<(Event, SocketAddr)>,
last_ignored: &mut Option<SocketAddr>,
event: (Event, SocketAddr),
timer_tx: &Sender<()>,
timer_notify: &Notify,
) -> Result<(), EmulationError> {
let (event, addr) = event;
@@ -170,7 +182,7 @@ async fn handle_udp_rx(
);
// restart timer if necessary
if restart_timer {
let _ = timer_tx.try_send(());
timer_notify.notify_waiters();
}
ignore_event
} else {

View File

@@ -15,6 +15,7 @@ use tokio::{
sync::mpsc::{Receiver, Sender},
task::JoinHandle,
};
use tokio_util::sync::CancellationToken;
use crate::{
client::{ClientHandle, Position},
@@ -33,10 +34,12 @@ pub(crate) fn new(
emulate: Sender<EmulationEvent>,
resolve_ch: Sender<DnsRequest>,
port_tx: Sender<u16>,
cancellation_token: CancellationToken,
) -> (JoinHandle<Result<()>>, Sender<FrontendRequest>) {
let (event_tx, mut event_rx) = tokio::sync::mpsc::channel(32);
let event_tx_clone = event_tx.clone();
let frontend_task = tokio::task::spawn_local(async move {
let mut join_handles = vec![];
loop {
tokio::select! {
stream = frontend.accept() => {
@@ -47,7 +50,7 @@ pub(crate) fn new(
continue;
}
};
handle_frontend_stream(&event_tx_clone, stream).await;
join_handles.push(handle_frontend_stream(&event_tx_clone, stream, cancellation_token.clone()));
}
event = event_rx.recv() => {
let frontend_event = event.ok_or(anyhow!("frontend channel closed"))?;
@@ -59,6 +62,10 @@ pub(crate) fn new(
let notify = notify.ok_or(anyhow!("frontend notify closed"))?;
let _ = frontend.broadcast_event(notify).await;
}
_ = cancellation_token.cancelled() => {
futures::future::join_all(join_handles).await;
break;
}
}
}
anyhow::Ok(())
@@ -66,33 +73,45 @@ pub(crate) fn new(
(frontend_task, event_tx)
}
async fn handle_frontend_stream(
fn handle_frontend_stream(
frontend_tx: &Sender<FrontendRequest>,
#[cfg(unix)] stream: ReadHalf<UnixStream>,
#[cfg(windows)] stream: ReadHalf<TcpStream>,
cancellation_token: CancellationToken,
) -> JoinHandle<()> {
let tx = frontend_tx.clone();
tokio::task::spawn_local(async move {
tokio::select! {
_ = listen_frontend(tx, stream) => return,
_ = cancellation_token.cancelled() => return,
}
})
}
async fn listen_frontend(
tx: Sender<FrontendRequest>,
#[cfg(unix)] mut stream: ReadHalf<UnixStream>,
#[cfg(windows)] mut stream: ReadHalf<TcpStream>,
) {
use std::io;
let tx = frontend_tx.clone();
tokio::task::spawn_local(async move {
loop {
let request = frontend::wait_for_request(&mut stream).await;
match request {
Ok(request) => {
let _ = tx.send(request).await;
}
Err(e) => {
if let Some(e) = e.downcast_ref::<io::Error>() {
if e.kind() == ErrorKind::UnexpectedEof {
return;
}
loop {
let request = frontend::wait_for_request(&mut stream).await;
match request {
Ok(request) => {
let _ = tx.send(request).await;
}
Err(e) => {
if let Some(e) = e.downcast_ref::<io::Error>() {
if e.kind() == ErrorKind::UnexpectedEof {
return;
}
log::error!("error reading frontend event: {e}");
return;
}
log::error!("error reading frontend event: {e}");
return;
}
}
});
}
}
async fn handle_frontend_event(

View File

@@ -1,7 +1,7 @@
use std::{net::SocketAddr, time::Duration};
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::{
sync::mpsc::{Receiver, Sender},
sync::{mpsc::Sender, Notify},
task::JoinHandle,
};
@@ -18,15 +18,13 @@ pub fn new(
sender_ch: Sender<(Event, SocketAddr)>,
emulate_notify: Sender<EmulationEvent>,
capture_notify: Sender<CaptureEvent>,
mut timer_rx: Receiver<()>,
timer_notify: Arc<Notify>,
) -> JoinHandle<()> {
// timer task
let ping_task = tokio::task::spawn_local(async move {
loop {
// wait for wake up signal
let Some(_): Option<()> = timer_rx.recv().await else {
break;
};
timer_notify.notified().await;
loop {
let receiving = server.state.get() == State::Receiving;
let (ping_clients, ping_addrs) = {