improve cancellation

This commit is contained in:
Ferdinand Schober
2024-07-11 16:34:41 +02:00
parent f0c9290579
commit 2d26bd6a0b
5 changed files with 117 additions and 94 deletions

View File

@@ -48,7 +48,7 @@ use super::{
/// events that necessitate restarting the capture session /// events that necessitate restarting the capture session
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
enum CaptureEvent { enum LibeiNotifyEvent {
Create(CaptureHandle, Position), Create(CaptureHandle, Position),
Destroy(CaptureHandle), Destroy(CaptureHandle),
} }
@@ -58,7 +58,7 @@ pub struct LibeiInputCapture<'a> {
input_capture: Pin<Box<InputCapture<'a>>>, input_capture: Pin<Box<InputCapture<'a>>>,
capture_task: JoinHandle<Result<(), CaptureError>>, capture_task: JoinHandle<Result<(), CaptureError>>,
event_rx: Receiver<(CaptureHandle, Event)>, event_rx: Receiver<(CaptureHandle, Event)>,
notify_capture: Sender<CaptureEvent>, notify_capture: Sender<LibeiNotifyEvent>,
notify_release: Arc<Notify>, notify_release: Arc<Notify>,
cancellation_token: CancellationToken, cancellation_token: CancellationToken,
} }
@@ -239,7 +239,7 @@ impl<'a> LibeiInputCapture<'a> {
async fn do_capture<'a>( async fn do_capture<'a>(
input_capture: *const InputCapture<'a>, input_capture: *const InputCapture<'a>,
mut capture_event: Receiver<CaptureEvent>, mut capture_event: Receiver<LibeiNotifyEvent>,
notify_release: Arc<Notify>, notify_release: Arc<Notify>,
session: Option<(Session<'a>, BitFlags<Capabilities>)>, session: Option<(Session<'a>, BitFlags<Capabilities>)>,
event_tx: Sender<(CaptureHandle, Event)>, event_tx: Sender<(CaptureHandle, Event)>,
@@ -259,7 +259,7 @@ async fn do_capture<'a>(
let cancel_session = CancellationToken::new(); let cancel_session = CancellationToken::new();
let cancel_update = CancellationToken::new(); let cancel_update = CancellationToken::new();
let mut capture_event_occured: Option<CaptureEvent> = None; let mut capture_event_occured: Option<LibeiNotifyEvent> = None;
let mut zones_have_changed = false; let mut zones_have_changed = false;
// kill session if clients need to be updated // kill session if clients need to be updated
@@ -315,8 +315,8 @@ async fn do_capture<'a>(
// update clients if requested // update clients if requested
if let Some(event) = capture_event_occured.take() { if let Some(event) = capture_event_occured.take() {
match event { match event {
CaptureEvent::Create(c, p) => active_clients.push((c, p)), LibeiNotifyEvent::Create(c, p) => active_clients.push((c, p)),
CaptureEvent::Destroy(c) => active_clients.retain(|(h, _)| *h != c), LibeiNotifyEvent::Destroy(c) => active_clients.retain(|(h, _)| *h != c),
} }
} }
@@ -658,7 +658,7 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> {
async fn create(&mut self, handle: CaptureHandle, pos: Position) -> Result<(), CaptureError> { async fn create(&mut self, handle: CaptureHandle, pos: Position) -> Result<(), CaptureError> {
let _ = self let _ = self
.notify_capture .notify_capture
.send(CaptureEvent::Create(handle, pos)) .send(LibeiNotifyEvent::Create(handle, pos))
.await; .await;
Ok(()) Ok(())
} }
@@ -666,7 +666,7 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> {
async fn destroy(&mut self, handle: CaptureHandle) -> Result<(), CaptureError> { async fn destroy(&mut self, handle: CaptureHandle) -> Result<(), CaptureError> {
let _ = self let _ = self
.notify_capture .notify_capture
.send(CaptureEvent::Destroy(handle)) .send(LibeiNotifyEvent::Destroy(handle))
.await; .await;
Ok(()) Ok(())
} }

View File

@@ -5,7 +5,10 @@ use std::{
rc::Rc, rc::Rc,
sync::Arc, sync::Arc,
}; };
use tokio::{signal, sync::Notify}; use tokio::{
join, signal,
sync::{mpsc::channel, Notify},
};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
use crate::{ use crate::{
@@ -84,7 +87,7 @@ impl Server {
capture_backend: Option<CaptureBackend>, capture_backend: Option<CaptureBackend>,
emulation_backend: Option<EmulationBackend>, emulation_backend: Option<EmulationBackend>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// create frontend communication adapter // create frontend communication adapter, exit if already running
let frontend = match FrontendListener::new().await { let frontend = match FrontendListener::new().await {
Some(f) => f?, Some(f) => f?,
None => { None => {
@@ -95,23 +98,25 @@ impl Server {
}; };
let timer_notify = Arc::new(Notify::new()); let timer_notify = Arc::new(Notify::new());
let (frontend_notify_tx, frontend_notify_rx) = tokio::sync::mpsc::channel(1); let (frontend_tx, frontend_rx) = channel(1); /* events coming from frontends */
let cancellation_token = CancellationToken::new(); let cancellation_token = CancellationToken::new(); /* notify termination */
let notify_capture = Arc::new(Notify::new()); /* notify capture restart */
let notify_emulation = Arc::new(Notify::new()); /* notify emultation restart */
// udp task // udp task
let (mut udp_task, sender_tx, receiver_rx, port_tx) = let (mut udp_task, udp_send, udp_recv, port_tx) = network_task::new(
network_task::new(self.clone(), frontend_notify_tx.clone()).await?; self.clone(),
frontend_tx.clone(),
// restart notify tokens cancellation_token.clone(),
let notify_capture = Arc::new(Notify::new()); )
let notify_emulation = Arc::new(Notify::new()); .await?;
// input capture // input capture
let (mut capture_task, capture_channel) = capture_task::new( let (mut capture_task, capture_channel) = capture_task::new(
capture_backend, capture_backend,
self.clone(), self.clone(),
sender_tx.clone(), udp_send.clone(),
frontend_notify_tx.clone(), frontend_tx.clone(),
timer_notify.clone(), timer_notify.clone(),
self.release_bind.clone(), self.release_bind.clone(),
cancellation_token.clone(), cancellation_token.clone(),
@@ -122,10 +127,10 @@ impl Server {
let (mut emulation_task, emulate_channel) = emulation_task::new( let (mut emulation_task, emulate_channel) = emulation_task::new(
emulation_backend, emulation_backend,
self.clone(), self.clone(),
receiver_rx, udp_recv,
sender_tx.clone(), udp_send.clone(),
capture_channel.clone(), capture_channel.clone(),
frontend_notify_tx.clone(), frontend_tx.clone(),
timer_notify.clone(), timer_notify.clone(),
cancellation_token.clone(), cancellation_token.clone(),
notify_emulation.clone(), notify_emulation.clone(),
@@ -133,19 +138,23 @@ impl Server {
// create dns resolver // create dns resolver
let resolver = dns::DnsResolver::new().await?; let resolver = dns::DnsResolver::new().await?;
let (mut resolver_task, resolve_tx) = let (mut resolver_task, dns_req) = resolver_task::new(
resolver_task::new(resolver, self.clone(), frontend_notify_tx); resolver,
self.clone(),
frontend_tx,
cancellation_token.clone(),
);
// frontend listener // frontend listener
let (mut frontend_task, frontend_tx) = frontend_task::new( let (mut frontend_task, frontend_tx) = frontend_task::new(
frontend, frontend,
frontend_notify_rx, frontend_rx,
self.clone(), self.clone(),
notify_emulation, notify_emulation,
notify_capture, notify_capture,
capture_channel.clone(), capture_channel.clone(),
emulate_channel.clone(), emulate_channel.clone(),
resolve_tx.clone(), dns_req.clone(),
port_tx, port_tx,
cancellation_token.clone(), cancellation_token.clone(),
); );
@@ -153,7 +162,7 @@ impl Server {
// task that pings clients to see if they are responding // task that pings clients to see if they are responding
let mut ping_task = ping_task::new( let mut ping_task = ping_task::new(
self.clone(), self.clone(),
sender_tx.clone(), udp_send.clone(),
emulate_channel.clone(), emulate_channel.clone(),
capture_channel.clone(), capture_channel.clone(),
timer_notify, timer_notify,
@@ -176,7 +185,7 @@ impl Server {
.send(FrontendRequest::Activate(handle, true)) .send(FrontendRequest::Activate(handle, true))
.await?; .await?;
if let Some(hostname) = hostname { if let Some(hostname) = hostname {
let _ = resolve_tx.send(DnsRequest { hostname, handle }).await; let _ = dns_req.send(DnsRequest { hostname, handle }).await;
} }
} }
log::info!("running service"); log::info!("running service");
@@ -197,20 +206,12 @@ impl Server {
_ = &mut ping_task => { } _ = &mut ping_task => { }
} }
// cancel tasks
cancellation_token.cancel(); cancellation_token.cancel();
if !capture_task.is_finished() { let _ = join!(capture_task, emulation_task, frontend_task, udp_task);
let _ = capture_task.await;
}
if !emulation_task.is_finished() {
let _ = emulation_task.await;
}
if !frontend_task.is_finished() {
let _ = frontend_task.await;
}
resolver_task.abort(); resolver_task.abort();
udp_task.abort();
ping_task.abort(); ping_task.abort();
Ok(()) Ok(())

View File

@@ -123,15 +123,22 @@ async fn do_capture(
.await; .await;
// FIXME DUPLICATES // FIXME DUPLICATES
// let clients = server let clients = server
// .client_manager .client_manager
// .borrow() .borrow()
// .get_client_states() .get_client_states()
// .map(|(h, (c, _))| (h, c.pos)) .map(|(h, s)| (h, s.clone()))
// .collect::<Vec<_>>(); .collect::<Vec<_>>();
// for (handle, pos) in clients { log::info!("{clients:?}");
// capture.create(handle, pos.into()).await?; let clients = server
// } .client_manager
.borrow()
.get_client_states()
.map(|(h, (c, _))| (h, c.pos))
.collect::<Vec<_>>();
for (handle, pos) in clients {
capture.create(handle, pos.into()).await?;
}
let mut pressed_keys = HashSet::new(); let mut pressed_keys = HashSet::new();
loop { loop {

View File

@@ -6,6 +6,7 @@ use tokio::{
sync::mpsc::{Receiver, Sender}, sync::mpsc::{Receiver, Sender},
task::JoinHandle, task::JoinHandle,
}; };
use tokio_util::sync::CancellationToken;
use crate::frontend::FrontendEvent; use crate::frontend::FrontendEvent;
use input_event::{Event, ProtocolError}; use input_event::{Event, ProtocolError};
@@ -15,6 +16,7 @@ use super::Server;
pub async fn new( pub async fn new(
server: Server, server: Server,
frontend_notify_tx: Sender<FrontendEvent>, frontend_notify_tx: Sender<FrontendEvent>,
cancellation_token: CancellationToken,
) -> io::Result<( ) -> io::Result<(
JoinHandle<()>, JoinHandle<()>,
Sender<(Event, SocketAddr)>, Sender<(Event, SocketAddr)>,
@@ -38,8 +40,9 @@ pub async fn new(
_ = udp_sender => break, /* channel closed */ _ = udp_sender => break, /* channel closed */
port = port_rx.recv() => match port { port = port_rx.recv() => match port {
Some(port) => update_port(&server, &frontend_notify_tx, &mut socket, port).await, Some(port) => update_port(&server, &frontend_notify_tx, &mut socket, port).await,
_ => break, _ => continue,
} },
_ = cancellation_token.cancelled() => break, /* cancellation requested */
} }
} }
}); });
@@ -80,18 +83,13 @@ async fn udp_receiver(
) { ) {
loop { loop {
let event = receive_event(socket).await; let event = receive_event(socket).await;
if receiver_tx.send(event).await.is_err() { receiver_tx.send(event).await.expect("channel closed");
break;
}
} }
} }
async fn udp_sender(socket: &UdpSocket, rx: &mut Receiver<(Event, SocketAddr)>) { async fn udp_sender(socket: &UdpSocket, rx: &mut Receiver<(Event, SocketAddr)>) {
loop { loop {
let (event, addr) = match rx.recv().await { let (event, addr) = rx.recv().await.expect("channel closed");
Some(e) => e,
None => return,
};
if let Err(e) = send_event(socket, event, addr) { if let Err(e) = send_event(socket, event, addr) {
log::warn!("udp send failed: {e}"); log::warn!("udp send failed: {e}");
}; };

View File

@@ -1,6 +1,10 @@
use std::collections::HashSet; use std::collections::HashSet;
use tokio::{sync::mpsc::Sender, task::JoinHandle}; use tokio::{
sync::mpsc::{Receiver, Sender},
task::JoinHandle,
};
use tokio_util::sync::CancellationToken;
use crate::{client::ClientHandle, dns::DnsResolver, frontend::FrontendEvent}; use crate::{client::ClientHandle, dns::DnsResolver, frontend::FrontendEvent};
@@ -14,52 +18,65 @@ pub struct DnsRequest {
pub fn new( pub fn new(
resolver: DnsResolver, resolver: DnsResolver,
mut server: Server, server: Server,
mut frontend: Sender<FrontendEvent>, frontend: Sender<FrontendEvent>,
cancellation_token: CancellationToken,
) -> (JoinHandle<()>, Sender<DnsRequest>) { ) -> (JoinHandle<()>, Sender<DnsRequest>) {
let (dns_tx, mut dns_rx) = tokio::sync::mpsc::channel::<DnsRequest>(32); let (dns_tx, dns_rx) = tokio::sync::mpsc::channel::<DnsRequest>(32);
let resolver_task = tokio::task::spawn_local(async move { let resolver_task = tokio::task::spawn_local(async move {
loop { tokio::select! {
let (host, handle) = match dns_rx.recv().await { _ = cancellation_token.cancelled() => {},
Some(r) => (r.hostname, r.handle), _ = do_dns(resolver, server, frontend, dns_rx) => {},
None => break,
};
/* update resolving status */
if let Some((_, s)) = server.client_manager.borrow_mut().get_mut(handle) {
s.resolving = true;
}
notify_state_change(&mut frontend, &mut server, handle).await;
let ips = match resolver.resolve(&host).await {
Ok(ips) => ips,
Err(e) => {
log::warn!("could not resolve host '{host}': {e}");
vec![]
}
};
/* update ips and resolving state */
if let Some((c, s)) = server.client_manager.borrow_mut().get_mut(handle) {
let mut addrs = HashSet::from_iter(c.fix_ips.iter().cloned());
for ip in ips {
addrs.insert(ip);
}
s.ips = addrs;
s.resolving = false;
}
notify_state_change(&mut frontend, &mut server, handle).await;
} }
}); });
(resolver_task, dns_tx) (resolver_task, dns_tx)
} }
async fn do_dns(
resolver: DnsResolver,
server: Server,
frontend: Sender<FrontendEvent>,
mut dns_rx: Receiver<DnsRequest>,
) {
loop {
let (host, handle) = match dns_rx.recv().await {
Some(r) => (r.hostname, r.handle),
None => break,
};
/* update resolving status */
if let Some((_, s)) = server.client_manager.borrow_mut().get_mut(handle) {
s.resolving = true;
}
notify_state_change(&frontend, &server, handle).await;
let ips = match resolver.resolve(&host).await {
Ok(ips) => ips,
Err(e) => {
log::warn!("could not resolve host '{host}': {e}");
vec![]
}
};
/* update ips and resolving state */
if let Some((c, s)) = server.client_manager.borrow_mut().get_mut(handle) {
let mut addrs = HashSet::from_iter(c.fix_ips.iter().cloned());
for ip in ips {
addrs.insert(ip);
}
s.ips = addrs;
s.resolving = false;
}
notify_state_change(&frontend, &server, handle).await;
}
}
async fn notify_state_change( async fn notify_state_change(
frontend: &mut Sender<FrontendEvent>, frontend: &Sender<FrontendEvent>,
server: &mut Server, server: &Server,
handle: ClientHandle, handle: ClientHandle,
) { ) {
let state = server.client_manager.borrow_mut().get_mut(handle).cloned(); let state = server.client_manager.borrow().get(handle).cloned();
if let Some((config, state)) = state { if let Some((config, state)) = state {
let _ = frontend let _ = frontend
.send(FrontendEvent::State(handle, config, state)) .send(FrontendEvent::State(handle, config, state))