improve cancellation

This commit is contained in:
Ferdinand Schober
2024-07-11 16:34:41 +02:00
parent f0c9290579
commit 2d26bd6a0b
5 changed files with 117 additions and 94 deletions

View File

@@ -48,7 +48,7 @@ use super::{
/// events that necessitate restarting the capture session
#[derive(Clone, Copy, Debug)]
enum CaptureEvent {
enum LibeiNotifyEvent {
Create(CaptureHandle, Position),
Destroy(CaptureHandle),
}
@@ -58,7 +58,7 @@ pub struct LibeiInputCapture<'a> {
input_capture: Pin<Box<InputCapture<'a>>>,
capture_task: JoinHandle<Result<(), CaptureError>>,
event_rx: Receiver<(CaptureHandle, Event)>,
notify_capture: Sender<CaptureEvent>,
notify_capture: Sender<LibeiNotifyEvent>,
notify_release: Arc<Notify>,
cancellation_token: CancellationToken,
}
@@ -239,7 +239,7 @@ impl<'a> LibeiInputCapture<'a> {
async fn do_capture<'a>(
input_capture: *const InputCapture<'a>,
mut capture_event: Receiver<CaptureEvent>,
mut capture_event: Receiver<LibeiNotifyEvent>,
notify_release: Arc<Notify>,
session: Option<(Session<'a>, BitFlags<Capabilities>)>,
event_tx: Sender<(CaptureHandle, Event)>,
@@ -259,7 +259,7 @@ async fn do_capture<'a>(
let cancel_session = CancellationToken::new();
let cancel_update = CancellationToken::new();
let mut capture_event_occured: Option<CaptureEvent> = None;
let mut capture_event_occured: Option<LibeiNotifyEvent> = None;
let mut zones_have_changed = false;
// kill session if clients need to be updated
@@ -315,8 +315,8 @@ async fn do_capture<'a>(
// update clients if requested
if let Some(event) = capture_event_occured.take() {
match event {
CaptureEvent::Create(c, p) => active_clients.push((c, p)),
CaptureEvent::Destroy(c) => active_clients.retain(|(h, _)| *h != c),
LibeiNotifyEvent::Create(c, p) => active_clients.push((c, p)),
LibeiNotifyEvent::Destroy(c) => active_clients.retain(|(h, _)| *h != c),
}
}
@@ -658,7 +658,7 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> {
async fn create(&mut self, handle: CaptureHandle, pos: Position) -> Result<(), CaptureError> {
let _ = self
.notify_capture
.send(CaptureEvent::Create(handle, pos))
.send(LibeiNotifyEvent::Create(handle, pos))
.await;
Ok(())
}
@@ -666,7 +666,7 @@ impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> {
async fn destroy(&mut self, handle: CaptureHandle) -> Result<(), CaptureError> {
let _ = self
.notify_capture
.send(CaptureEvent::Destroy(handle))
.send(LibeiNotifyEvent::Destroy(handle))
.await;
Ok(())
}

View File

@@ -5,7 +5,10 @@ use std::{
rc::Rc,
sync::Arc,
};
use tokio::{signal, sync::Notify};
use tokio::{
join, signal,
sync::{mpsc::channel, Notify},
};
use tokio_util::sync::CancellationToken;
use crate::{
@@ -84,7 +87,7 @@ impl Server {
capture_backend: Option<CaptureBackend>,
emulation_backend: Option<EmulationBackend>,
) -> anyhow::Result<()> {
// create frontend communication adapter
// create frontend communication adapter, exit if already running
let frontend = match FrontendListener::new().await {
Some(f) => f?,
None => {
@@ -95,23 +98,25 @@ impl Server {
};
let timer_notify = Arc::new(Notify::new());
let (frontend_notify_tx, frontend_notify_rx) = tokio::sync::mpsc::channel(1);
let cancellation_token = CancellationToken::new();
let (frontend_tx, frontend_rx) = channel(1); /* events coming from frontends */
let cancellation_token = CancellationToken::new(); /* notify termination */
let notify_capture = Arc::new(Notify::new()); /* notify capture restart */
let notify_emulation = Arc::new(Notify::new()); /* notify emultation restart */
// udp task
let (mut udp_task, sender_tx, receiver_rx, port_tx) =
network_task::new(self.clone(), frontend_notify_tx.clone()).await?;
// restart notify tokens
let notify_capture = Arc::new(Notify::new());
let notify_emulation = Arc::new(Notify::new());
let (mut udp_task, udp_send, udp_recv, port_tx) = network_task::new(
self.clone(),
frontend_tx.clone(),
cancellation_token.clone(),
)
.await?;
// input capture
let (mut capture_task, capture_channel) = capture_task::new(
capture_backend,
self.clone(),
sender_tx.clone(),
frontend_notify_tx.clone(),
udp_send.clone(),
frontend_tx.clone(),
timer_notify.clone(),
self.release_bind.clone(),
cancellation_token.clone(),
@@ -122,10 +127,10 @@ impl Server {
let (mut emulation_task, emulate_channel) = emulation_task::new(
emulation_backend,
self.clone(),
receiver_rx,
sender_tx.clone(),
udp_recv,
udp_send.clone(),
capture_channel.clone(),
frontend_notify_tx.clone(),
frontend_tx.clone(),
timer_notify.clone(),
cancellation_token.clone(),
notify_emulation.clone(),
@@ -133,19 +138,23 @@ impl Server {
// create dns resolver
let resolver = dns::DnsResolver::new().await?;
let (mut resolver_task, resolve_tx) =
resolver_task::new(resolver, self.clone(), frontend_notify_tx);
let (mut resolver_task, dns_req) = resolver_task::new(
resolver,
self.clone(),
frontend_tx,
cancellation_token.clone(),
);
// frontend listener
let (mut frontend_task, frontend_tx) = frontend_task::new(
frontend,
frontend_notify_rx,
frontend_rx,
self.clone(),
notify_emulation,
notify_capture,
capture_channel.clone(),
emulate_channel.clone(),
resolve_tx.clone(),
dns_req.clone(),
port_tx,
cancellation_token.clone(),
);
@@ -153,7 +162,7 @@ impl Server {
// task that pings clients to see if they are responding
let mut ping_task = ping_task::new(
self.clone(),
sender_tx.clone(),
udp_send.clone(),
emulate_channel.clone(),
capture_channel.clone(),
timer_notify,
@@ -176,7 +185,7 @@ impl Server {
.send(FrontendRequest::Activate(handle, true))
.await?;
if let Some(hostname) = hostname {
let _ = resolve_tx.send(DnsRequest { hostname, handle }).await;
let _ = dns_req.send(DnsRequest { hostname, handle }).await;
}
}
log::info!("running service");
@@ -197,20 +206,12 @@ impl Server {
_ = &mut ping_task => { }
}
// cancel tasks
cancellation_token.cancel();
if !capture_task.is_finished() {
let _ = capture_task.await;
}
if !emulation_task.is_finished() {
let _ = emulation_task.await;
}
if !frontend_task.is_finished() {
let _ = frontend_task.await;
}
let _ = join!(capture_task, emulation_task, frontend_task, udp_task);
resolver_task.abort();
udp_task.abort();
ping_task.abort();
Ok(())

View File

@@ -123,15 +123,22 @@ async fn do_capture(
.await;
// FIXME DUPLICATES
// let clients = server
// .client_manager
// .borrow()
// .get_client_states()
// .map(|(h, (c, _))| (h, c.pos))
// .collect::<Vec<_>>();
// for (handle, pos) in clients {
// capture.create(handle, pos.into()).await?;
// }
let clients = server
.client_manager
.borrow()
.get_client_states()
.map(|(h, s)| (h, s.clone()))
.collect::<Vec<_>>();
log::info!("{clients:?}");
let clients = server
.client_manager
.borrow()
.get_client_states()
.map(|(h, (c, _))| (h, c.pos))
.collect::<Vec<_>>();
for (handle, pos) in clients {
capture.create(handle, pos.into()).await?;
}
let mut pressed_keys = HashSet::new();
loop {

View File

@@ -6,6 +6,7 @@ use tokio::{
sync::mpsc::{Receiver, Sender},
task::JoinHandle,
};
use tokio_util::sync::CancellationToken;
use crate::frontend::FrontendEvent;
use input_event::{Event, ProtocolError};
@@ -15,6 +16,7 @@ use super::Server;
pub async fn new(
server: Server,
frontend_notify_tx: Sender<FrontendEvent>,
cancellation_token: CancellationToken,
) -> io::Result<(
JoinHandle<()>,
Sender<(Event, SocketAddr)>,
@@ -38,8 +40,9 @@ pub async fn new(
_ = udp_sender => break, /* channel closed */
port = port_rx.recv() => match port {
Some(port) => update_port(&server, &frontend_notify_tx, &mut socket, port).await,
_ => break,
}
_ => continue,
},
_ = cancellation_token.cancelled() => break, /* cancellation requested */
}
}
});
@@ -80,18 +83,13 @@ async fn udp_receiver(
) {
loop {
let event = receive_event(socket).await;
if receiver_tx.send(event).await.is_err() {
break;
}
receiver_tx.send(event).await.expect("channel closed");
}
}
async fn udp_sender(socket: &UdpSocket, rx: &mut Receiver<(Event, SocketAddr)>) {
loop {
let (event, addr) = match rx.recv().await {
Some(e) => e,
None => return,
};
let (event, addr) = rx.recv().await.expect("channel closed");
if let Err(e) = send_event(socket, event, addr) {
log::warn!("udp send failed: {e}");
};

View File

@@ -1,6 +1,10 @@
use std::collections::HashSet;
use tokio::{sync::mpsc::Sender, task::JoinHandle};
use tokio::{
sync::mpsc::{Receiver, Sender},
task::JoinHandle,
};
use tokio_util::sync::CancellationToken;
use crate::{client::ClientHandle, dns::DnsResolver, frontend::FrontendEvent};
@@ -14,52 +18,65 @@ pub struct DnsRequest {
pub fn new(
resolver: DnsResolver,
mut server: Server,
mut frontend: Sender<FrontendEvent>,
server: Server,
frontend: Sender<FrontendEvent>,
cancellation_token: CancellationToken,
) -> (JoinHandle<()>, Sender<DnsRequest>) {
let (dns_tx, mut dns_rx) = tokio::sync::mpsc::channel::<DnsRequest>(32);
let (dns_tx, dns_rx) = tokio::sync::mpsc::channel::<DnsRequest>(32);
let resolver_task = tokio::task::spawn_local(async move {
loop {
let (host, handle) = match dns_rx.recv().await {
Some(r) => (r.hostname, r.handle),
None => break,
};
/* update resolving status */
if let Some((_, s)) = server.client_manager.borrow_mut().get_mut(handle) {
s.resolving = true;
}
notify_state_change(&mut frontend, &mut server, handle).await;
let ips = match resolver.resolve(&host).await {
Ok(ips) => ips,
Err(e) => {
log::warn!("could not resolve host '{host}': {e}");
vec![]
}
};
/* update ips and resolving state */
if let Some((c, s)) = server.client_manager.borrow_mut().get_mut(handle) {
let mut addrs = HashSet::from_iter(c.fix_ips.iter().cloned());
for ip in ips {
addrs.insert(ip);
}
s.ips = addrs;
s.resolving = false;
}
notify_state_change(&mut frontend, &mut server, handle).await;
tokio::select! {
_ = cancellation_token.cancelled() => {},
_ = do_dns(resolver, server, frontend, dns_rx) => {},
}
});
(resolver_task, dns_tx)
}
async fn do_dns(
resolver: DnsResolver,
server: Server,
frontend: Sender<FrontendEvent>,
mut dns_rx: Receiver<DnsRequest>,
) {
loop {
let (host, handle) = match dns_rx.recv().await {
Some(r) => (r.hostname, r.handle),
None => break,
};
/* update resolving status */
if let Some((_, s)) = server.client_manager.borrow_mut().get_mut(handle) {
s.resolving = true;
}
notify_state_change(&frontend, &server, handle).await;
let ips = match resolver.resolve(&host).await {
Ok(ips) => ips,
Err(e) => {
log::warn!("could not resolve host '{host}': {e}");
vec![]
}
};
/* update ips and resolving state */
if let Some((c, s)) = server.client_manager.borrow_mut().get_mut(handle) {
let mut addrs = HashSet::from_iter(c.fix_ips.iter().cloned());
for ip in ips {
addrs.insert(ip);
}
s.ips = addrs;
s.resolving = false;
}
notify_state_change(&frontend, &server, handle).await;
}
}
async fn notify_state_change(
frontend: &mut Sender<FrontendEvent>,
server: &mut Server,
frontend: &Sender<FrontendEvent>,
server: &Server,
handle: ClientHandle,
) {
let state = server.client_manager.borrow_mut().get_mut(handle).cloned();
let state = server.client_manager.borrow().get(handle).cloned();
if let Some((config, state)) = state {
let _ = frontend
.send(FrontendEvent::State(handle, config, state))