This commit is contained in:
Ferdinand Schober
2024-07-12 01:01:09 +02:00
parent 3014e404c3
commit 335a1fc3e2
7 changed files with 162 additions and 197 deletions

View File

@@ -3,7 +3,6 @@ use std::{
cell::{Cell, RefCell}, cell::{Cell, RefCell},
collections::HashSet, collections::HashSet,
rc::Rc, rc::Rc,
sync::Arc,
}; };
use tokio::{ use tokio::{
join, signal, join, signal,
@@ -14,7 +13,7 @@ use tokio_util::sync::CancellationToken;
use crate::{ use crate::{
client::{ClientConfig, ClientHandle, ClientManager, ClientState}, client::{ClientConfig, ClientHandle, ClientManager, ClientState},
config::{CaptureBackend, Config, EmulationBackend}, config::{CaptureBackend, Config, EmulationBackend},
dns, dns::DnsResolver,
frontend::{FrontendListener, FrontendRequest}, frontend::{FrontendListener, FrontendRequest},
server::capture_task::CaptureEvent, server::capture_task::CaptureEvent,
}; };
@@ -46,6 +45,15 @@ pub struct Server {
port: Rc<Cell<u16>>, port: Rc<Cell<u16>>,
state: Rc<Cell<State>>, state: Rc<Cell<State>>,
release_bind: Vec<input_event::scancode::Linux>, release_bind: Vec<input_event::scancode::Linux>,
notifies: Rc<Notifies>,
}
#[derive(Default)]
struct Notifies {
ping: Notify,
capture: Notify,
emulation: Notify,
cancel: CancellationToken,
} }
impl Server { impl Server {
@@ -72,13 +80,18 @@ impl Server {
let c = client_manager.get_mut(handle).expect("invalid handle"); let c = client_manager.get_mut(handle).expect("invalid handle");
*c = (client, state); *c = (client, state);
} }
// task notification tokens
let notifies = Rc::new(Notifies::default());
let release_bind = config.release_bind.clone(); let release_bind = config.release_bind.clone();
Self { Self {
active_client, active_client,
client_manager, client_manager,
port, port,
state, state,
release_bind, release_bind,
notifies,
} }
} }
@@ -97,76 +110,69 @@ impl Server {
} }
}; };
let notify_ping = Arc::new(Notify::new()); /* notify ping timer restart */ let (frontend_tx, frontend_rx) = channel(1); /* events for frontends */
let (frontend_tx, frontend_rx) = channel(1); /* events coming from frontends */ let (request_tx, request_rx) = channel(1); /* requests coming from frontends */
let cancellation_token = CancellationToken::new(); /* notify termination */ let (capture_tx, capture_rx) = channel(1); /* requests for input capture */
let notify_capture = Arc::new(Notify::new()); /* notify capture restart */ let (emulation_tx, emulation_rx) = channel(1); /* emulation requests */
let notify_emulation = Arc::new(Notify::new()); /* notify emultation restart */ let (udp_recv_tx, udp_recv_rx) = channel(1); /* udp receiver */
let (udp_send_tx, udp_send_rx) = channel(1); /* udp sender */
let (port_tx, port_rx) = channel(1); /* port change request */
let (dns_tx, dns_rx) = channel(1); /* dns requests */
// udp task // udp task
let (network, udp_send, udp_recv, port_tx) = network_task::new( let network = network_task::new(
self.clone(), self.clone(),
udp_recv_tx,
udp_send_rx,
port_rx,
frontend_tx.clone(), frontend_tx.clone(),
cancellation_token.clone(),
) )
.await?; .await?;
// input capture // input capture
let (capture, capture_channel) = capture_task::new( let capture = capture_task::new(
capture_backend,
self.clone(), self.clone(),
udp_send.clone(), capture_backend,
capture_rx,
udp_send_tx.clone(),
frontend_tx.clone(), frontend_tx.clone(),
notify_ping.clone(),
self.release_bind.clone(), self.release_bind.clone(),
cancellation_token.clone(),
notify_capture.clone(),
); );
// input emulation // input emulation
let (emulation, emulate_channel) = emulation_task::new( let emulation = emulation_task::new(
emulation_backend,
self.clone(), self.clone(),
udp_recv, emulation_backend,
udp_send.clone(), emulation_rx,
capture_channel.clone(), udp_recv_rx,
udp_send_tx.clone(),
capture_tx.clone(),
frontend_tx.clone(), frontend_tx.clone(),
notify_ping.clone(),
cancellation_token.clone(),
notify_emulation.clone(),
); );
// create dns resolver // create dns resolver
let resolver = dns::DnsResolver::new().await?; let resolver = DnsResolver::new().await?;
let (resolver, dns_req) = resolver_task::new( let resolver = resolver_task::new(resolver, dns_rx, self.clone(), frontend_tx);
resolver,
self.clone(),
frontend_tx,
cancellation_token.clone(),
);
// frontend listener // frontend listener
let (frontend, frontend_tx) = frontend_task::new( let frontend = frontend_task::new(
self.clone(),
frontend, frontend,
frontend_rx, frontend_rx,
self.clone(), request_tx.clone(),
notify_emulation, request_rx,
notify_capture, capture_tx.clone(),
capture_channel.clone(), emulation_tx.clone(),
emulate_channel.clone(), dns_tx.clone(),
dns_req.clone(),
port_tx, port_tx,
cancellation_token.clone(),
); );
// task that pings clients to see if they are responding // task that pings clients to see if they are responding
let ping = ping_task::new( let ping = ping_task::new(
self.clone(), self.clone(),
udp_send.clone(), udp_send_tx.clone(),
emulate_channel.clone(), emulation_tx.clone(),
capture_channel.clone(), capture_tx.clone(),
notify_ping,
cancellation_token.clone(),
); );
let active = self let active = self
@@ -182,11 +188,11 @@ impl Server {
}) })
.collect::<Vec<_>>(); .collect::<Vec<_>>();
for (handle, hostname) in active { for (handle, hostname) in active {
frontend_tx request_tx
.send(FrontendRequest::Activate(handle, true)) .send(FrontendRequest::Activate(handle, true))
.await?; .await?;
if let Some(hostname) = hostname { if let Some(hostname) = hostname {
let _ = dns_req.send(DnsRequest { hostname, handle }).await; let _ = dns_tx.send(DnsRequest { hostname, handle }).await;
} }
} }
@@ -194,9 +200,45 @@ impl Server {
signal::ctrl_c().await.expect("failed to listen for CTRL+C"); signal::ctrl_c().await.expect("failed to listen for CTRL+C");
log::info!("terminating service"); log::info!("terminating service");
cancellation_token.cancel(); self.cancel();
let _ = join!(capture, emulation, frontend, network, resolver, ping); let _ = join!(capture, emulation, frontend, network, resolver, ping);
Ok(()) Ok(())
} }
fn cancel(&self) {
self.notifies.cancel.cancel();
}
async fn cancelled(&self) {
self.notifies.cancel.cancelled().await
}
fn is_cancelled(&self) -> bool {
self.notifies.cancel.is_cancelled()
}
fn notify_capture(&self) {
self.notifies.capture.notify_waiters()
}
async fn capture_notified(&self) {
self.notifies.capture.notified().await
}
fn notify_emulation(&self) {
self.notifies.emulation.notify_waiters()
}
async fn emulation_notified(&self) {
self.notifies.emulation.notified().await
}
fn restart_ping_timer(&self) {
self.notifies.ping.notify_waiters()
}
async fn ping_timer_notified(&self) {
self.notifies.ping.notified().await
}
} }

View File

@@ -1,14 +1,10 @@
use futures::StreamExt; use futures::StreamExt;
use std::{collections::HashSet, net::SocketAddr, sync::Arc}; use std::{collections::HashSet, net::SocketAddr};
use thiserror::Error; use thiserror::Error;
use tokio_util::sync::CancellationToken;
use tokio::{ use tokio::{
process::Command, process::Command,
sync::{ sync::mpsc::{Receiver, Sender},
mpsc::{Receiver, Sender},
Notify,
},
task::JoinHandle, task::JoinHandle,
}; };
@@ -46,41 +42,31 @@ pub enum CaptureEvent {
} }
pub fn new( pub fn new(
backend: Option<CaptureBackend>,
server: Server, server: Server,
backend: Option<CaptureBackend>,
capture_rx: Receiver<CaptureEvent>,
udp_send: Sender<(Event, SocketAddr)>, udp_send: Sender<(Event, SocketAddr)>,
frontend_tx: Sender<FrontendEvent>, frontend_tx: Sender<FrontendEvent>,
notify_ping: Arc<Notify>,
release_bind: Vec<scancode::Linux>, release_bind: Vec<scancode::Linux>,
cancellation_token: CancellationToken, ) -> JoinHandle<()> {
notify_capture: Arc<Notify>,
) -> (JoinHandle<()>, Sender<CaptureEvent>) {
let (tx, rx) = tokio::sync::mpsc::channel(32);
let backend = backend.map(|b| b.into()); let backend = backend.map(|b| b.into());
let task = tokio::task::spawn_local(capture_task( tokio::task::spawn_local(capture_task(
backend,
server, server,
backend,
udp_send, udp_send,
rx, capture_rx,
frontend_tx, frontend_tx,
notify_ping,
release_bind, release_bind,
cancellation_token, ))
notify_capture,
));
(task, tx)
} }
async fn capture_task( async fn capture_task(
backend: Option<input_capture::Backend>,
server: Server, server: Server,
backend: Option<input_capture::Backend>,
sender_tx: Sender<(Event, SocketAddr)>, sender_tx: Sender<(Event, SocketAddr)>,
mut notify_rx: Receiver<CaptureEvent>, mut notify_rx: Receiver<CaptureEvent>,
frontend_tx: Sender<FrontendEvent>, frontend_tx: Sender<FrontendEvent>,
timer_notify: Arc<Notify>,
release_bind: Vec<scancode::Linux>, release_bind: Vec<scancode::Linux>,
cancellation_token: CancellationToken,
notify_capture: Arc<Notify>,
) { ) {
loop { loop {
if let Err(e) = do_capture( if let Err(e) = do_capture(
@@ -89,9 +75,7 @@ async fn capture_task(
&sender_tx, &sender_tx,
&mut notify_rx, &mut notify_rx,
&frontend_tx, &frontend_tx,
&timer_notify,
&release_bind, &release_bind,
&cancellation_token,
) )
.await .await
{ {
@@ -100,10 +84,10 @@ async fn capture_task(
let _ = frontend_tx let _ = frontend_tx
.send(FrontendEvent::CaptureStatus(Status::Disabled)) .send(FrontendEvent::CaptureStatus(Status::Disabled))
.await; .await;
if cancellation_token.is_cancelled() { if server.is_cancelled() {
break; break;
} }
notify_capture.notified().await; server.capture_notified().await;
} }
} }
@@ -113,9 +97,7 @@ async fn do_capture(
sender_tx: &Sender<(Event, SocketAddr)>, sender_tx: &Sender<(Event, SocketAddr)>,
notify_rx: &mut Receiver<CaptureEvent>, notify_rx: &mut Receiver<CaptureEvent>,
frontend_tx: &Sender<FrontendEvent>, frontend_tx: &Sender<FrontendEvent>,
timer_notify: &Notify,
release_bind: &[scancode::Linux], release_bind: &[scancode::Linux],
cancellation_token: &CancellationToken,
) -> Result<(), LanMouseCaptureError> { ) -> Result<(), LanMouseCaptureError> {
let mut capture = input_capture::create(backend).await?; let mut capture = input_capture::create(backend).await?;
let _ = frontend_tx let _ = frontend_tx
@@ -145,7 +127,7 @@ async fn do_capture(
tokio::select! { tokio::select! {
event = capture.next() => { event = capture.next() => {
match event { match event {
Some(Ok(event)) => handle_capture_event(server, &mut capture, sender_tx, timer_notify, event, &mut pressed_keys, release_bind).await?, Some(Ok(event)) => handle_capture_event(server, &mut capture, sender_tx, event, &mut pressed_keys, release_bind).await?,
Some(Err(e)) => return Err(e.into()), Some(Err(e)) => return Err(e.into()),
None => return Ok(()), None => return Ok(()),
} }
@@ -164,7 +146,7 @@ async fn do_capture(
None => break, None => break,
} }
} }
_ = cancellation_token.cancelled() => break, _ = server.cancelled() => break,
} }
} }
capture.terminate().await?; capture.terminate().await?;
@@ -185,7 +167,6 @@ async fn handle_capture_event(
server: &Server, server: &Server,
capture: &mut Box<dyn InputCapture>, capture: &mut Box<dyn InputCapture>,
sender_tx: &Sender<(Event, SocketAddr)>, sender_tx: &Sender<(Event, SocketAddr)>,
timer_notify: &Notify,
event: (CaptureHandle, Event), event: (CaptureHandle, Event),
pressed_keys: &mut HashSet<scancode::Linux>, pressed_keys: &mut HashSet<scancode::Linux>,
release_bind: &[scancode::Linux], release_bind: &[scancode::Linux],
@@ -249,7 +230,7 @@ async fn handle_capture_event(
}; };
if start_timer { if start_timer {
timer_notify.notify_waiters(); server.restart_ping_timer();
} }
if enter { if enter {
spawn_hook_command(server, handle); spawn_hook_command(server, handle);

View File

@@ -1,14 +1,10 @@
use std::{net::SocketAddr, sync::Arc}; use std::net::SocketAddr;
use thiserror::Error; use thiserror::Error;
use tokio::{ use tokio::{
sync::{ sync::mpsc::{Receiver, Sender},
mpsc::{Receiver, Sender},
Notify,
},
task::JoinHandle, task::JoinHandle,
}; };
use tokio_util::sync::CancellationToken;
use crate::{ use crate::{
client::{ClientHandle, ClientManager}, client::{ClientHandle, ClientManager},
@@ -36,31 +32,24 @@ pub enum EmulationEvent {
} }
pub fn new( pub fn new(
backend: Option<EmulationBackend>,
server: Server, server: Server,
backend: Option<EmulationBackend>,
emulation_rx: Receiver<EmulationEvent>,
udp_rx: Receiver<Result<(Event, SocketAddr), NetworkError>>, udp_rx: Receiver<Result<(Event, SocketAddr), NetworkError>>,
sender_tx: Sender<(Event, SocketAddr)>, sender_tx: Sender<(Event, SocketAddr)>,
capture_tx: Sender<CaptureEvent>, capture_tx: Sender<CaptureEvent>,
frontend_tx: Sender<FrontendEvent>, frontend_tx: Sender<FrontendEvent>,
timer_notify: Arc<Notify>, ) -> JoinHandle<()> {
cancellation_token: CancellationToken,
notify_emulation: Arc<Notify>,
) -> (JoinHandle<()>, Sender<EmulationEvent>) {
let (tx, rx) = tokio::sync::mpsc::channel(32);
let emulation_task = emulation_task( let emulation_task = emulation_task(
backend, backend,
rx, emulation_rx,
server, server,
udp_rx, udp_rx,
sender_tx, sender_tx,
capture_tx, capture_tx,
frontend_tx, frontend_tx,
timer_notify,
cancellation_token,
notify_emulation,
); );
let emulate_task = tokio::task::spawn_local(emulation_task); tokio::task::spawn_local(emulation_task)
(emulate_task, tx)
} }
#[derive(Debug, Error)] #[derive(Debug, Error)]
@@ -79,21 +68,16 @@ async fn emulation_task(
sender_tx: Sender<(Event, SocketAddr)>, sender_tx: Sender<(Event, SocketAddr)>,
capture_tx: Sender<CaptureEvent>, capture_tx: Sender<CaptureEvent>,
frontend_tx: Sender<FrontendEvent>, frontend_tx: Sender<FrontendEvent>,
timer_notify: Arc<Notify>,
cancellation_token: CancellationToken,
notify_emulation: Arc<Notify>,
) { ) {
loop { loop {
match do_emulation( match do_emulation(
&server,
backend, backend,
&mut rx, &mut rx,
&server,
&mut udp_rx, &mut udp_rx,
&sender_tx, &sender_tx,
&capture_tx, &capture_tx,
&frontend_tx, &frontend_tx,
&timer_notify,
&cancellation_token,
) )
.await .await
{ {
@@ -105,25 +89,23 @@ async fn emulation_task(
let _ = frontend_tx let _ = frontend_tx
.send(FrontendEvent::EmulationStatus(Status::Disabled)) .send(FrontendEvent::EmulationStatus(Status::Disabled))
.await; .await;
if cancellation_token.is_cancelled() { if server.notifies.cancel.is_cancelled() {
break; break;
} }
log::info!("waiting for user to request input emulation ..."); log::info!("waiting for user to request input emulation ...");
notify_emulation.notified().await; server.emulation_notified().await;
log::info!("... done"); log::info!("... done");
} }
} }
async fn do_emulation( async fn do_emulation(
server: &Server,
backend: Option<EmulationBackend>, backend: Option<EmulationBackend>,
rx: &mut Receiver<EmulationEvent>, rx: &mut Receiver<EmulationEvent>,
server: &Server,
udp_rx: &mut Receiver<Result<(Event, SocketAddr), NetworkError>>, udp_rx: &mut Receiver<Result<(Event, SocketAddr), NetworkError>>,
sender_tx: &Sender<(Event, SocketAddr)>, sender_tx: &Sender<(Event, SocketAddr)>,
capture_tx: &Sender<CaptureEvent>, capture_tx: &Sender<CaptureEvent>,
frontend_tx: &Sender<FrontendEvent>, frontend_tx: &Sender<FrontendEvent>,
timer_notify: &Notify,
cancellation_token: &CancellationToken,
) -> Result<(), LanMouseEmulationError> { ) -> Result<(), LanMouseEmulationError> {
let backend = backend.map(|b| b.into()); let backend = backend.map(|b| b.into());
log::info!("creating input emulation..."); log::info!("creating input emulation...");
@@ -132,17 +114,7 @@ async fn do_emulation(
.send(FrontendEvent::EmulationStatus(Status::Enabled)) .send(FrontendEvent::EmulationStatus(Status::Enabled))
.await; .await;
let res = do_emulation_session( let res = do_emulation_session(server, &mut emulation, rx, udp_rx, sender_tx, capture_tx).await;
&mut emulation,
rx,
server,
udp_rx,
sender_tx,
capture_tx,
timer_notify,
cancellation_token,
)
.await;
emulation.terminate().await; emulation.terminate().await;
res?; res?;
@@ -165,14 +137,12 @@ async fn do_emulation(
} }
async fn do_emulation_session( async fn do_emulation_session(
server: &Server,
emulation: &mut Box<dyn InputEmulation>, emulation: &mut Box<dyn InputEmulation>,
rx: &mut Receiver<EmulationEvent>, rx: &mut Receiver<EmulationEvent>,
server: &Server,
udp_rx: &mut Receiver<Result<(Event, SocketAddr), NetworkError>>, udp_rx: &mut Receiver<Result<(Event, SocketAddr), NetworkError>>,
sender_tx: &Sender<(Event, SocketAddr)>, sender_tx: &Sender<(Event, SocketAddr)>,
capture_tx: &Sender<CaptureEvent>, capture_tx: &Sender<CaptureEvent>,
timer_notify: &Notify,
cancellation_token: &CancellationToken,
) -> Result<(), LanMouseEmulationError> { ) -> Result<(), LanMouseEmulationError> {
let mut last_ignored = None; let mut last_ignored = None;
@@ -186,7 +156,7 @@ async fn do_emulation_session(
continue; continue;
} }
}; };
handle_udp_rx(&server, &capture_tx, emulation, &sender_tx, &mut last_ignored, udp_event, &timer_notify).await?; handle_udp_rx(&server, &capture_tx, emulation, &sender_tx, &mut last_ignored, udp_event).await?;
} }
emulate_event = rx.recv() => { emulate_event = rx.recv() => {
match emulate_event.expect("channel closed") { match emulate_event.expect("channel closed") {
@@ -195,7 +165,7 @@ async fn do_emulation_session(
EmulationEvent::ReleaseKeys(c) => release_keys(&server, emulation, c).await?, EmulationEvent::ReleaseKeys(c) => release_keys(&server, emulation, c).await?,
} }
} }
_ = cancellation_token.cancelled() => break Ok(()), _ = server.notifies.cancel.cancelled() => break Ok(()),
} }
} }
} }
@@ -207,7 +177,6 @@ async fn handle_udp_rx(
sender_tx: &Sender<(Event, SocketAddr)>, sender_tx: &Sender<(Event, SocketAddr)>,
last_ignored: &mut Option<SocketAddr>, last_ignored: &mut Option<SocketAddr>,
event: (Event, SocketAddr), event: (Event, SocketAddr),
timer_notify: &Notify,
) -> Result<(), EmulationError> { ) -> Result<(), EmulationError> {
let (event, addr) = event; let (event, addr) = event;
@@ -257,7 +226,7 @@ async fn handle_udp_rx(
); );
// restart timer if necessary // restart timer if necessary
if restart_timer { if restart_timer {
timer_notify.notify_waiters(); server.restart_ping_timer();
} }
ignore_event ignore_event
} else { } else {

View File

@@ -2,7 +2,6 @@ use std::{
collections::HashSet, collections::HashSet,
io::ErrorKind, io::ErrorKind,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
sync::Arc,
}; };
#[cfg(unix)] #[cfg(unix)]
use tokio::net::UnixStream; use tokio::net::UnixStream;
@@ -12,13 +11,9 @@ use tokio::net::TcpStream;
use tokio::{ use tokio::{
io::ReadHalf, io::ReadHalf,
sync::{ sync::mpsc::{Receiver, Sender},
mpsc::{Receiver, Sender},
Notify,
},
task::JoinHandle, task::JoinHandle,
}; };
use tokio_util::sync::CancellationToken;
use crate::{ use crate::{
client::{ClientHandle, Position}, client::{ClientHandle, Position},
@@ -30,32 +25,30 @@ use super::{
}; };
pub(crate) fn new( pub(crate) fn new(
server: Server,
mut frontend: FrontendListener, mut frontend: FrontendListener,
mut event: Receiver<FrontendEvent>, mut event: Receiver<FrontendEvent>,
server: Server, request_tx: Sender<FrontendRequest>,
notify_emulation: Arc<Notify>, mut request_rx: Receiver<FrontendRequest>,
notify_capture: Arc<Notify>,
capture: Sender<CaptureEvent>, capture: Sender<CaptureEvent>,
emulate: Sender<EmulationEvent>, emulate: Sender<EmulationEvent>,
resolve_ch: Sender<DnsRequest>, resolve_ch: Sender<DnsRequest>,
port_tx: Sender<u16>, port_tx: Sender<u16>,
cancellation_token: CancellationToken, ) -> JoinHandle<()> {
) -> (JoinHandle<()>, Sender<FrontendRequest>) { let request = request_tx.clone();
let (request_tx, mut request) = tokio::sync::mpsc::channel(32); tokio::task::spawn_local(async move {
let request_tx_clone = request_tx.clone();
let frontend_task = tokio::task::spawn_local(async move {
let mut join_handles = vec![]; let mut join_handles = vec![];
loop { loop {
tokio::select! { tokio::select! {
stream = frontend.accept() => { stream = frontend.accept() => {
match stream { match stream {
Ok(s) => join_handles.push(handle_frontend_stream(&request_tx_clone, s, cancellation_token.clone())), Ok(s) => join_handles.push(handle_frontend_stream(server.clone(), &request, s)),
Err(e) => log::warn!("error accepting frontend connection: {e}"), Err(e) => log::warn!("error accepting frontend connection: {e}"),
}; };
} }
request = request.recv() => { request = request_rx.recv() => {
let request = request.expect("frontend request channel closed"); let request = request.expect("frontend request channel closed");
if handle_frontend_event(&server, &notify_capture, &notify_emulation, &capture, &emulate, &resolve_ch, &mut frontend, &port_tx, request).await { if handle_frontend_event(&server, &capture, &emulate, &resolve_ch, &mut frontend, &port_tx, request).await {
break; break;
} }
} }
@@ -63,33 +56,32 @@ pub(crate) fn new(
let event = event.expect("channel closed"); let event = event.expect("channel closed");
let _ = frontend.broadcast_event(event).await; let _ = frontend.broadcast_event(event).await;
} }
_ = cancellation_token.cancelled() => { _ = server.cancelled() => {
futures::future::join_all(join_handles).await; futures::future::join_all(join_handles).await;
break; break;
} }
} }
} }
}); })
(frontend_task, request_tx)
} }
fn handle_frontend_stream( fn handle_frontend_stream(
frontend_tx: &Sender<FrontendRequest>, server: Server,
request_tx: &Sender<FrontendRequest>,
#[cfg(unix)] stream: ReadHalf<UnixStream>, #[cfg(unix)] stream: ReadHalf<UnixStream>,
#[cfg(windows)] stream: ReadHalf<TcpStream>, #[cfg(windows)] stream: ReadHalf<TcpStream>,
cancellation_token: CancellationToken,
) -> JoinHandle<()> { ) -> JoinHandle<()> {
let tx = frontend_tx.clone(); let tx = request_tx.clone();
tokio::task::spawn_local(async move { tokio::task::spawn_local(async move {
tokio::select! { tokio::select! {
_ = listen_frontend(tx, stream) => {}, _ = listen_frontend(tx, stream) => {},
_ = cancellation_token.cancelled() => {}, _ = server.cancelled() => {},
} }
}) })
} }
async fn listen_frontend( async fn listen_frontend(
tx: Sender<FrontendRequest>, request_tx: Sender<FrontendRequest>,
#[cfg(unix)] mut stream: ReadHalf<UnixStream>, #[cfg(unix)] mut stream: ReadHalf<UnixStream>,
#[cfg(windows)] mut stream: ReadHalf<TcpStream>, #[cfg(windows)] mut stream: ReadHalf<TcpStream>,
) { ) {
@@ -98,7 +90,7 @@ async fn listen_frontend(
let request = frontend::wait_for_request(&mut stream).await; let request = frontend::wait_for_request(&mut stream).await;
match request { match request {
Ok(request) => { Ok(request) => {
let _ = tx.send(request).await; let _ = request_tx.send(request).await;
} }
Err(e) => { Err(e) => {
if let Some(e) = e.downcast_ref::<io::Error>() { if let Some(e) = e.downcast_ref::<io::Error>() {
@@ -115,8 +107,6 @@ async fn listen_frontend(
async fn handle_frontend_event( async fn handle_frontend_event(
server: &Server, server: &Server,
notify_capture: &Notify,
notify_emulation: &Notify,
capture: &Sender<CaptureEvent>, capture: &Sender<CaptureEvent>,
emulate: &Sender<EmulationEvent>, emulate: &Sender<EmulationEvent>,
resolve_tx: &Sender<DnsRequest>, resolve_tx: &Sender<DnsRequest>,
@@ -127,11 +117,12 @@ async fn handle_frontend_event(
log::debug!("frontend: {event:?}"); log::debug!("frontend: {event:?}");
match event { match event {
FrontendRequest::EnableCapture => { FrontendRequest::EnableCapture => {
notify_capture.notify_waiters(); log::info!("received capture enable request");
server.notify_capture();
} }
FrontendRequest::EnableEmulation => { FrontendRequest::EnableEmulation => {
log::info!("received emulation enable request"); log::info!("received emulation enable request");
notify_emulation.notify_waiters(); server.notify_emulation();
} }
FrontendRequest::Create => { FrontendRequest::Create => {
let handle = add_client(server, frontend).await; let handle = add_client(server, frontend).await;

View File

@@ -6,7 +6,6 @@ use tokio::{
sync::mpsc::{Receiver, Sender}, sync::mpsc::{Receiver, Sender},
task::JoinHandle, task::JoinHandle,
}; };
use tokio_util::sync::CancellationToken;
use crate::frontend::FrontendEvent; use crate::frontend::FrontendEvent;
use input_event::{Event, ProtocolError}; use input_event::{Event, ProtocolError};
@@ -15,25 +14,19 @@ use super::Server;
pub async fn new( pub async fn new(
server: Server, server: Server,
udp_recv_tx: Sender<Result<(Event, SocketAddr), NetworkError>>,
udp_send_rx: Receiver<(Event, SocketAddr)>,
mut port_rx: Receiver<u16>,
frontend_notify_tx: Sender<FrontendEvent>, frontend_notify_tx: Sender<FrontendEvent>,
cancellation_token: CancellationToken, ) -> io::Result<JoinHandle<()>> {
) -> io::Result<(
JoinHandle<()>,
Sender<(Event, SocketAddr)>,
Receiver<Result<(Event, SocketAddr), NetworkError>>,
Sender<u16>,
)> {
// bind the udp socket // bind the udp socket
let listen_addr = SocketAddr::new("0.0.0.0".parse().unwrap(), server.port.get()); let listen_addr = SocketAddr::new("0.0.0.0".parse().unwrap(), server.port.get());
let mut socket = UdpSocket::bind(listen_addr).await?; let mut socket = UdpSocket::bind(listen_addr).await?;
let (receiver_tx, receiver_rx) = tokio::sync::mpsc::channel(32);
let (sender_tx, sender_rx) = tokio::sync::mpsc::channel(32);
let (port_tx, mut port_rx) = tokio::sync::mpsc::channel(32);
let udp_task = tokio::task::spawn_local(async move { Ok(tokio::task::spawn_local(async move {
let mut sender_rx = sender_rx; let mut sender_rx = udp_send_rx;
loop { loop {
let udp_receiver = udp_receiver(&socket, &receiver_tx); let udp_receiver = udp_receiver(&socket, &udp_recv_tx);
let udp_sender = udp_sender(&socket, &mut sender_rx); let udp_sender = udp_sender(&socket, &mut sender_rx);
tokio::select! { tokio::select! {
_ = udp_receiver => break, /* channel closed */ _ = udp_receiver => break, /* channel closed */
@@ -42,11 +35,10 @@ pub async fn new(
Some(port) => update_port(&server, &frontend_notify_tx, &mut socket, port).await, Some(port) => update_port(&server, &frontend_notify_tx, &mut socket, port).await,
_ => continue, _ => continue,
}, },
_ = cancellation_token.cancelled() => break, /* cancellation requested */ _ = server.cancelled() => break, /* cancellation requested */
} }
} }
}); }))
Ok((udp_task, sender_tx, receiver_rx, port_tx))
} }
async fn update_port( async fn update_port(

View File

@@ -1,12 +1,8 @@
use std::{net::SocketAddr, sync::Arc, time::Duration}; use std::{net::SocketAddr, time::Duration};
use tokio::{ use tokio::{sync::mpsc::Sender, task::JoinHandle};
sync::{mpsc::Sender, Notify},
task::JoinHandle,
};
use input_event::Event; use input_event::Event;
use tokio_util::sync::CancellationToken;
use crate::client::ClientHandle; use crate::client::ClientHandle;
@@ -19,28 +15,25 @@ pub fn new(
sender_ch: Sender<(Event, SocketAddr)>, sender_ch: Sender<(Event, SocketAddr)>,
emulate_notify: Sender<EmulationEvent>, emulate_notify: Sender<EmulationEvent>,
capture_notify: Sender<CaptureEvent>, capture_notify: Sender<CaptureEvent>,
timer_notify: Arc<Notify>,
cancellation_token: CancellationToken,
) -> JoinHandle<()> { ) -> JoinHandle<()> {
// timer task // timer task
tokio::task::spawn_local(async move { tokio::task::spawn_local(async move {
tokio::select! { tokio::select! {
_ = cancellation_token.cancelled() => {} _ = server.notifies.cancel.cancelled() => {}
_ = ping_task(server, sender_ch, emulate_notify, capture_notify, timer_notify) => {} _ = ping_task(&server, sender_ch, emulate_notify, capture_notify) => {}
} }
}) })
} }
async fn ping_task( async fn ping_task(
server: Server, server: &Server,
sender_ch: Sender<(Event, SocketAddr)>, sender_ch: Sender<(Event, SocketAddr)>,
emulate_notify: Sender<EmulationEvent>, emulate_notify: Sender<EmulationEvent>,
capture_notify: Sender<CaptureEvent>, capture_notify: Sender<CaptureEvent>,
timer_notify: Arc<Notify>,
) { ) {
loop { loop {
// wait for wake up signal // wait for wake up signal
timer_notify.notified().await; server.ping_timer_notified().await;
loop { loop {
let receiving = server.state.get() == State::Receiving; let receiving = server.state.get() == State::Receiving;
let (ping_clients, ping_addrs) = { let (ping_clients, ping_addrs) = {

View File

@@ -4,7 +4,6 @@ use tokio::{
sync::mpsc::{Receiver, Sender}, sync::mpsc::{Receiver, Sender},
task::JoinHandle, task::JoinHandle,
}; };
use tokio_util::sync::CancellationToken;
use crate::{client::ClientHandle, dns::DnsResolver, frontend::FrontendEvent}; use crate::{client::ClientHandle, dns::DnsResolver, frontend::FrontendEvent};
@@ -18,23 +17,21 @@ pub struct DnsRequest {
pub fn new( pub fn new(
resolver: DnsResolver, resolver: DnsResolver,
dns_rx: Receiver<DnsRequest>,
server: Server, server: Server,
frontend: Sender<FrontendEvent>, frontend: Sender<FrontendEvent>,
cancellation_token: CancellationToken, ) -> JoinHandle<()> {
) -> (JoinHandle<()>, Sender<DnsRequest>) { tokio::task::spawn_local(async move {
let (dns_tx, dns_rx) = tokio::sync::mpsc::channel::<DnsRequest>(32);
let resolver_task = tokio::task::spawn_local(async move {
tokio::select! { tokio::select! {
_ = cancellation_token.cancelled() => {}, _ = server.cancelled() => {},
_ = do_dns(resolver, server, frontend, dns_rx) => {}, _ = do_dns(&server, resolver, frontend, dns_rx) => {},
} }
}); })
(resolver_task, dns_tx)
} }
async fn do_dns( async fn do_dns(
server: &Server,
resolver: DnsResolver, resolver: DnsResolver,
server: Server,
frontend: Sender<FrontendEvent>, frontend: Sender<FrontendEvent>,
mut dns_rx: Receiver<DnsRequest>, mut dns_rx: Receiver<DnsRequest>,
) { ) {