diff --git a/src/capture.rs b/src/capture.rs index ac2ab95..6ecfb3e 100644 --- a/src/capture.rs +++ b/src/capture.rs @@ -1,5 +1,6 @@ use std::{ cell::Cell, + rc::Rc, time::{Duration, Instant}, }; @@ -7,7 +8,7 @@ use futures::StreamExt; use input_capture::{ CaptureError, CaptureEvent, CaptureHandle, InputCapture, InputCaptureError, Position, }; -use lan_mouse_ipc::{ClientHandle, Status}; +use lan_mouse_ipc::ClientHandle; use lan_mouse_proto::ProtoEvent; use local_channel::mpsc::{channel, Receiver, Sender}; use tokio::{ @@ -18,9 +19,19 @@ use tokio::{ use crate::{connect::LanMouseConnection, service::Service}; pub(crate) struct Capture { + _active: Rc>>, tx: Sender, task: JoinHandle<()>, - enter_rx: Receiver, + event_rx: Receiver, +} + +pub(crate) enum ICaptureEvent { + /// a client was entered + ClientEntered(CaptureHandle), + /// capture disabled + CaptureDisabled, + /// capture disabled + CaptureEnabled, } #[derive(Clone, Copy, Debug)] @@ -31,19 +42,48 @@ enum CaptureRequest { Create(CaptureHandle, Position), /// destory a capture client Destroy(CaptureHandle), + /// terminate + Terminate, + /// reenable input capture + Reenable, } impl Capture { - pub(crate) fn new(server: Service, conn: LanMouseConnection) -> Self { + pub(crate) fn new( + backend: Option, + conn: LanMouseConnection, + service: Service, + ) -> Self { let (tx, rx) = channel(); - let (enter_tx, enter_rx) = channel(); - let task = spawn_local(Self::run(server.clone(), rx, conn, enter_tx)); - Self { tx, task, enter_rx } + let (event_tx, event_rx) = channel(); + let active = Rc::new(Cell::new(None)); + let task = spawn_local(Self::run( + active.clone(), + service, + backend, + rx, + conn, + event_tx, + )); + Self { + _active: active, + tx, + task, + event_rx, + } + } + + pub(crate) fn reenable(&self) { + self.tx + .send(CaptureRequest::Reenable) + .expect("channel closed"); } pub(crate) async fn terminate(&mut self) { + self.tx + .send(CaptureRequest::Terminate) + .expect("channel closed"); log::debug!("terminating capture"); - self.tx.close(); if let Err(e) = (&mut self.task).await { log::warn!("{e}"); } @@ -61,36 +101,45 @@ impl Capture { .expect("channel closed"); } - #[allow(unused)] pub(crate) fn release(&self) { self.tx .send(CaptureRequest::Release) .expect("channel closed"); } - pub(crate) async fn entered(&mut self) -> CaptureHandle { - self.enter_rx.recv().await.expect("channel closed") + pub(crate) async fn event(&mut self) -> ICaptureEvent { + self.event_rx.recv().await.expect("channel closed") } async fn run( - server: Service, + active: Rc>>, + service: Service, + backend: Option, mut rx: Receiver, mut conn: LanMouseConnection, - mut enter_tx: Sender, + mut event_tx: Sender, ) { loop { - if let Err(e) = do_capture(&server, &mut conn, &mut rx, &mut enter_tx).await { + if let Err(e) = do_capture( + &active, + &service, + backend, + &mut conn, + &mut rx, + &mut event_tx, + ) + .await + { log::warn!("input capture exited: {e}"); } - server.set_capture_status(Status::Disabled); + event_tx + .send(ICaptureEvent::CaptureDisabled) + .expect("channel closed"); loop { - tokio::select! { - e = rx.recv() => match e { - Some(_) => continue, - None => break, - }, - _ = server.capture_enabled() => break, - _ = server.cancelled() => return, + match rx.recv().await.expect("channel closed") { + CaptureRequest::Reenable => break, + CaptureRequest::Terminate => return, + _ => {} } } } @@ -98,25 +147,27 @@ impl Capture { } async fn do_capture( - server: &Service, + active: &Cell>, + service: &Service, + backend: Option, conn: &mut LanMouseConnection, rx: &mut Receiver, - enter_tx: &mut Sender, + event_tx: &mut Sender, ) -> Result<(), InputCaptureError> { - let backend = server.config.capture_backend.map(|b| b.into()); - /* allow cancelling capture request */ let mut capture = tokio::select! { r = InputCapture::new(backend) => r?, - _ = server.cancelled() => return Ok(()), + _ = wait_for_termination(rx) => return Ok(()), }; - server.set_capture_status(Status::Enabled); + event_tx + .send(ICaptureEvent::CaptureEnabled) + .expect("channel closed"); - let clients = server.client_manager.active_clients(); + let clients = service.client_manager.active_clients(); let clients = clients.iter().copied().map(|handle| { ( handle, - server + service .client_manager .get_pos(handle) .expect("no such client"), @@ -133,11 +184,11 @@ async fn do_capture( loop { tokio::select! { event = capture.next() => match event { - Some(event) => handle_capture_event(server, &mut capture, conn, event?, &mut state, enter_tx).await?, + Some(event) => handle_capture_event(active, &service, &mut capture, conn, event?, &mut state, event_tx).await?, None => return Ok(()), }, (handle, event) = conn.recv() => { - if let Some(active) = server.get_active() { + if let Some(active) = active.get() { if handle != active { // we only care about events coming from the client we are currently connected to // only `Ack` and `Leave` are relevant @@ -154,22 +205,18 @@ async fn do_capture( // client disconnected ProtoEvent::Leave(_) => { log::info!("releasing capture: left remote client device region"); - release_capture(&mut capture, server).await?; + release_capture(&mut capture, &active).await?; }, _ => {} } }, - e = rx.recv() => { - match e { - Some(e) => match e { - CaptureRequest::Release => release_capture(&mut capture, server).await?, - CaptureRequest::Create(h, p) => capture.create(h, p).await?, - CaptureRequest::Destroy(h) => capture.destroy(h).await?, - }, - None => break, - } + e = rx.recv() => match e.expect("channel closed") { + CaptureRequest::Reenable => { /* already active */ }, + CaptureRequest::Release => release_capture(&mut capture, &active).await?, + CaptureRequest::Create(h, p) => capture.create(h, p).await?, + CaptureRequest::Destroy(h) => capture.destroy(h).await?, + CaptureRequest::Terminate => break, } - _ = server.cancelled() => break, } } @@ -205,31 +252,34 @@ enum State { } async fn handle_capture_event( - server: &Service, + active: &Cell>, + service: &Service, capture: &mut InputCapture, conn: &LanMouseConnection, event: (CaptureHandle, CaptureEvent), state: &mut State, - enter_tx: &mut Sender, + event_tx: &mut Sender, ) -> Result<(), CaptureError> { let (handle, event) = event; log::trace!("({handle}): {event:?}"); - if capture.keys_pressed(&server.config.release_bind) { + if capture.keys_pressed(&service.config.release_bind) { log::info!("releasing capture: release-bind pressed"); - return release_capture(capture, server).await; + return release_capture(capture, &active).await; } if event == CaptureEvent::Begin { - enter_tx.send(handle).expect("channel closed"); + event_tx + .send(ICaptureEvent::ClientEntered(handle)) + .expect("channel closed"); } // incoming connection if handle >= Service::ENTER_HANDLE_BEGIN { // if there is no active outgoing connection at the current capture, // we release the capture - if let Some(pos) = server.get_incoming_pos(handle) { - if server.client_manager.client_at(pos).is_none() { + if let Some(pos) = service.get_incoming_pos(handle) { + if service.client_manager.client_at(pos).is_none() { log::info!("releasing capture: no active client at this position"); capture.release().await?; } @@ -239,16 +289,16 @@ async fn handle_capture_event( } // activated a new client - if event == CaptureEvent::Begin && Some(handle) != server.get_active() { + if event == CaptureEvent::Begin && Some(handle) != active.get() { *state = State::WaitingForAck; - server.set_active(Some(handle)); + active.replace(Some(handle)); log::info!("entering client {handle} ..."); - spawn_hook_command(server, handle); + spawn_hook_command(service, handle); } - let pos = match server.client_manager.get_pos(handle) { + let pos = match service.client_manager.get_pos(handle) { Some(pos) => to_proto_pos(pos.opposite()), - None => return release_capture(capture, server).await, + None => return release_capture(capture, active).await, }; let event = match event { @@ -268,8 +318,11 @@ async fn handle_capture_event( Ok(()) } -async fn release_capture(capture: &mut InputCapture, server: &Service) -> Result<(), CaptureError> { - server.set_active(None); +async fn release_capture( + capture: &mut InputCapture, + active: &Cell>, +) -> Result<(), CaptureError> { + active.replace(None); capture.release().await } @@ -291,8 +344,8 @@ fn to_proto_pos(pos: lan_mouse_ipc::Position) -> lan_mouse_proto::Position { } } -fn spawn_hook_command(server: &Service, handle: ClientHandle) { - let Some(cmd) = server.client_manager.get_enter_cmd(handle) else { +fn spawn_hook_command(service: &Service, handle: ClientHandle) { + let Some(cmd) = service.client_manager.get_enter_cmd(handle) else { return; }; tokio::task::spawn_local(async move { @@ -316,3 +369,15 @@ fn spawn_hook_command(server: &Service, handle: ClientHandle) { } }); } + +async fn wait_for_termination(rx: &mut Receiver) { + loop { + match rx.recv().await.expect("channel closed") { + CaptureRequest::Terminate => return, + CaptureRequest::Release => continue, + CaptureRequest::Create(_, _) => continue, + CaptureRequest::Destroy(_) => continue, + CaptureRequest::Reenable => continue, + } + } +} diff --git a/src/dns.rs b/src/dns.rs index 5e11921..f88f13a 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -1,69 +1,97 @@ +use std::net::IpAddr; + use local_channel::mpsc::{channel, Receiver, Sender}; use tokio::task::{spawn_local, JoinHandle}; use hickory_resolver::{error::ResolveError, TokioAsyncResolver}; +use tokio_util::sync::CancellationToken; -use crate::service::Service; use lan_mouse_ipc::ClientHandle; pub(crate) struct DnsResolver { - _task: JoinHandle<()>, - tx: Sender, + cancellation_token: CancellationToken, + task: Option>, + request_tx: Sender, + event_rx: Receiver, +} + +struct DnsRequest { + handle: ClientHandle, + hostname: String, +} + +pub(crate) enum DnsEvent { + Resolving(ClientHandle), + Resolved(ClientHandle, String, Result, ResolveError>), } impl DnsResolver { - pub(crate) fn new(server: Service) -> Result { + pub(crate) fn new() -> Result { let resolver = TokioAsyncResolver::tokio_from_system_conf()?; - let (tx, rx) = channel(); - let _task = spawn_local(Self::run(server, resolver, rx)); - Ok(Self { _task, tx }) + let (request_tx, request_rx) = channel(); + let (event_tx, event_rx) = channel(); + let cancellation_token = CancellationToken::new(); + let task = Some(spawn_local(Self::run( + resolver, + request_rx, + event_tx, + cancellation_token.clone(), + ))); + Ok(Self { + cancellation_token, + task, + event_rx, + request_tx, + }) } - pub(crate) fn resolve(&self, host: ClientHandle) { - self.tx.send(host).expect("channel closed"); + pub(crate) fn resolve(&self, handle: ClientHandle, hostname: String) { + let request = DnsRequest { handle, hostname }; + self.request_tx.send(request).expect("channel closed"); } - async fn run(server: Service, resolver: TokioAsyncResolver, mut rx: Receiver) { + pub(crate) async fn event(&mut self) -> DnsEvent { + self.event_rx.recv().await.expect("channel closed") + } + + async fn run( + resolver: TokioAsyncResolver, + mut request_rx: Receiver, + event_tx: Sender, + cancellation_token: CancellationToken, + ) { tokio::select! { - _ = server.cancelled() => {}, - _ = Self::do_dns(&server, &resolver, &mut rx) => {}, + _ = Self::do_dns(&resolver, &mut request_rx, &event_tx) => {}, + _ = cancellation_token.cancelled() => {}, } } async fn do_dns( - server: &Service, resolver: &TokioAsyncResolver, - rx: &mut Receiver, + request_rx: &mut Receiver, + event_tx: &Sender, ) { loop { - let handle = rx.recv().await.expect("channel closed"); + let DnsRequest { handle, hostname } = request_rx.recv().await.expect("channel closed"); - /* update resolving status */ - let hostname = match server.client_manager.get_hostname(handle) { - Some(hostname) => hostname, - None => continue, - }; - - log::info!("resolving ({handle}) `{hostname}` ..."); - server.set_resolving(handle, true); + event_tx + .send(DnsEvent::Resolving(handle)) + .expect("channel closed"); /* resolve host */ - let ips = match resolver.lookup_ip(&hostname).await { - Ok(response) => { - let ips = response.iter().collect::>(); - for ip in ips.iter() { - log::info!("{hostname}: adding ip {ip}"); - } - ips - } - Err(e) => { - log::warn!("could not resolve host '{hostname}': {e}"); - vec![] - } - }; + let ips = resolver + .lookup_ip(&hostname) + .await + .map(|ips| ips.iter().collect::>()); - server.update_dns_ips(handle, ips); - server.set_resolving(handle, false); + event_tx + .send(DnsEvent::Resolved(handle, hostname, ips)) + .expect("channel closed"); } } + + pub(crate) async fn terminate(&mut self) { + self.cancellation_token.cancel(); + self.task.take().expect("task").await.expect("join error"); + } } diff --git a/src/service.rs b/src/service.rs index 930fbbe..16f5fe4 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,10 +1,10 @@ use crate::{ - capture::Capture, + capture::{Capture, ICaptureEvent}, client::ClientManager, config::Config, connect::LanMouseConnection, crypto, - dns::DnsResolver, + dns::{DnsEvent, DnsResolver}, emulation::{Emulation, EmulationEvent}, listen::{LanMouseListener, ListenerCreationError}, }; @@ -25,7 +25,6 @@ use std::{ }; use thiserror::Error; use tokio::{signal, sync::Notify}; -use tokio_util::sync::CancellationToken; use webrtc_dtls::crypto::Certificate; #[derive(Debug, Error)] @@ -64,7 +63,6 @@ pub struct Incoming { #[derive(Clone)] pub struct Service { - active: Rc>>, authorized_keys: Arc>>, pub(crate) client_manager: ClientManager, port: Rc>, @@ -85,10 +83,8 @@ pub struct Service { #[derive(Default)] struct Notifies { - reenable_capture: Notify, incoming: Notify, frontend_event_pending: Notify, - cancel: CancellationToken, } impl Service { @@ -118,7 +114,6 @@ impl Service { let public_key_fingerprint = crypto::certificate_fingerprint(&cert); let service = Self { - active: Rc::new(Cell::new(None)), authorized_keys: Arc::new(RwLock::new(config.authorized_fingerprints.clone())), cert, public_key_fingerprint, @@ -133,7 +128,6 @@ impl Service { incoming_conn_info: Default::default(), incoming_conns: Default::default(), next_trigger_handle: 0, - requested_port: Default::default(), }; Ok(service) } @@ -152,15 +146,18 @@ impl Service { let conn = LanMouseConnection::new(self.clone(), self.cert.clone()); // input capture + emulation - let mut capture = Capture::new(self.clone(), conn); + let capture_backend = self.config.capture_backend.map(|b| b.into()); + let mut capture = Capture::new(capture_backend, conn, self.clone()); let emulation_backend = self.config.emulation_backend.map(|b| b.into()); let mut emulation = Emulation::new(emulation_backend, listener); // create dns resolver - let resolver = DnsResolver::new(self.clone())?; + let mut resolver = DnsResolver::new()?; for handle in self.client_manager.active_clients() { - resolver.resolve(handle); + if let Some(hostname) = self.client_manager.get_hostname(handle) { + resolver.resolve(handle, hostname); + } } loop { @@ -226,19 +223,38 @@ impl Service { self.emulation_status.replace(Status::Enabled); self.notify_frontend(FrontendEvent::EmulationStatus(Status::Enabled)); }, - EmulationEvent::ReleaseNotify => { - self.set_active(None); - capture.release(); + EmulationEvent::ReleaseNotify => capture.release(), + }, + event = capture.event() => match event { + ICaptureEvent::ClientEntered(handle) => { + // we entered the capture zone for an incoming connection + // => notify it that its capture should be released + if let Some(incoming) = self.incoming_conn_info.borrow().get(&handle) { + emulation.send_leave_event(incoming.addr); + } + } + ICaptureEvent::CaptureDisabled => { + self.capture_status.replace(Status::Disabled); + self.notify_frontend(FrontendEvent::CaptureStatus(Status::Disabled)); + } + ICaptureEvent::CaptureEnabled => { + self.capture_status.replace(Status::Enabled); + self.notify_frontend(FrontendEvent::CaptureStatus(Status::Enabled)); } }, - handle = capture.entered() => { - // we entered the capture zone for an incoming connection - // => notify it that its capture should be released - if let Some(incoming) = self.incoming_conn_info.borrow().get(&handle) { - emulation.send_leave_event(incoming.addr); + event = resolver.event() => match event { + DnsEvent::Resolving(handle) => self.set_resolving(handle, true), + DnsEvent::Resolved(handle, hostname, ips) => { + self.set_resolving(handle, false); + match ips { + Ok(ips) => self.update_dns_ips(handle, ips), + Err(e) => { + log::warn!("could not resolve {hostname}: {e}"); + self.update_dns_ips(handle, vec![]); + }, + } } }, - _ = self.cancelled() => break, r = signal::ctrl_c() => { r.expect("failed to wait for CTRL+C"); break; @@ -248,10 +264,9 @@ impl Service { log::info!("terminating service"); - self.cancel(); - capture.terminate().await; emulation.terminate().await; + resolver.terminate().await; Ok(()) } @@ -305,23 +320,6 @@ impl Service { self.notifies.frontend_event_pending.notify_one(); } - fn cancel(&self) { - self.notifies.cancel.cancel(); - } - - pub(crate) async fn cancelled(&self) { - self.notifies.cancel.cancelled().await - } - - fn request_capture_reenable(&self) { - log::info!("received capture enable request"); - self.notifies.reenable_capture.notify_waiters() - } - - pub(crate) async fn capture_enabled(&self) { - self.notifies.reenable_capture.notified().await - } - pub(crate) fn client_updated(&self, handle: ClientHandle) { self.notify_frontend(FrontendEvent::Changed(handle)); } @@ -335,7 +333,7 @@ impl Service { ) -> bool { log::debug!("frontend: {event:?}"); match event { - FrontendRequest::EnableCapture => self.request_capture_reenable(), + FrontendRequest::EnableCapture => capture.reenable(), FrontendRequest::EnableEmulation => emulation.reenable(), FrontendRequest::Create => { self.add_client(); @@ -362,7 +360,11 @@ impl Service { FrontendRequest::UpdatePosition(handle, pos) => { self.update_pos(handle, capture, pos); } - FrontendRequest::ResolveDns(handle) => dns.resolve(handle), + FrontendRequest::ResolveDns(handle) => { + if let Some(hostname) = self.client_manager.get_hostname(handle) { + dns.resolve(handle, hostname); + } + } FrontendRequest::Sync => { self.enumerate(); self.notify_frontend(FrontendEvent::EmulationStatus(self.emulation_status.get())); @@ -464,8 +466,10 @@ impl Service { } fn update_hostname(&self, handle: ClientHandle, hostname: Option, dns: &DnsResolver) { - if self.client_manager.set_hostname(handle, hostname) { - dns.resolve(handle); + if self.client_manager.set_hostname(handle, hostname.clone()) { + if let Some(hostname) = hostname { + dns.resolve(handle, hostname); + } self.client_updated(handle); } } @@ -491,25 +495,11 @@ impl Service { self.notify_frontend(event); } - pub(crate) fn set_capture_status(&self, status: Status) { - self.capture_status.replace(status); - let status = FrontendEvent::CaptureStatus(status); - self.notify_frontend(status); - } - pub(crate) fn set_resolving(&self, handle: ClientHandle, status: bool) { self.client_manager.set_resolving(handle, status); self.client_updated(handle); } - pub(crate) fn set_active(&self, handle: Option) { - self.active.replace(handle); - } - - pub(crate) fn get_active(&self) -> Option { - self.active.get() - } - pub(crate) fn register_incoming(&self, addr: SocketAddr, pos: Position, fingerprint: String) { self.pending_incoming .borrow_mut()