From 875a31907a4151c5e41b4f1439f2500de7d9238c Mon Sep 17 00:00:00 2001 From: Ferdinand Schober Date: Fri, 8 Nov 2024 22:45:11 +0100 Subject: [PATCH] cleanup --- src/dns.rs | 29 ++-- src/service.rs | 427 +++++++++++++++++++++++++------------------------ 2 files changed, 235 insertions(+), 221 deletions(-) diff --git a/src/dns.rs b/src/dns.rs index 6aa45a0..68647a3 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -78,24 +78,29 @@ impl DnsTask { } async fn do_dns(&mut self) { - loop { - let DnsRequest { handle, hostname } = - self.request_rx.recv().await.expect("channel closed"); + while let Some(dns_request) = self.request_rx.recv().await { + let DnsRequest { handle, hostname } = dns_request; self.event_tx .send(DnsEvent::Resolving(handle)) .expect("channel closed"); - /* resolve host */ - let ips = self - .resolver - .lookup_ip(&hostname) - .await - .map(|ips| ips.iter().collect::>()); + /* spawn task for dns request */ + let event_tx = self.event_tx.clone(); + let resolver = self.resolver.clone(); + let cancellation_token = self.cancellation_token.clone(); - self.event_tx - .send(DnsEvent::Resolved(handle, hostname, ips)) - .expect("channel closed"); + tokio::task::spawn_local(async move { + tokio::select! { + ips = resolver.lookup_ip(&hostname) => { + let ips = ips.map(|ips| ips.iter().collect::>()); + event_tx + .send(DnsEvent::Resolved(handle, hostname, ips)) + .expect("channel closed"); + } + _ = cancellation_token.cancelled() => {}, + } + }); } } } diff --git a/src/service.rs b/src/service.rs index 3f14907..9ab9122 100644 --- a/src/service.rs +++ b/src/service.rs @@ -12,7 +12,7 @@ use futures::StreamExt; use hickory_resolver::error::ResolveError; use lan_mouse_ipc::{ AsyncFrontendListener, ClientConfig, ClientHandle, ClientState, FrontendEvent, FrontendRequest, - IpcListenerCreationError, Position, Status, + IpcError, IpcListenerCreationError, Position, Status, }; use log; use std::{ @@ -39,17 +39,29 @@ pub enum ServiceError { } pub struct Service { + /// input capture capture: Capture, + /// input emulation emulation: Emulation, + /// dns resolver resolver: DnsResolver, + /// frontend listener frontend_listener: AsyncFrontendListener, + /// authorized public key sha256 fingerprints authorized_keys: Arc>>, + /// (outgoing) client information client_manager: ClientManager, + /// current port port: u16, + /// the public key fingerprint for (D)TLS public_key_fingerprint: String, + /// notify for pending frontend events frontend_event_pending: Notify, + /// frontend events queued for sending pending_frontend_events: VecDeque, + /// status of input capture (enabled / disabled) capture_status: Status, + /// status of input emulation (enabled / disabled) emulation_status: Status, /// keep track of registered connections to avoid duplicate barriers incoming_conns: HashSet, @@ -131,198 +143,164 @@ impl Service { pub async fn run(&mut self) -> Result<(), ServiceError> { for handle in self.client_manager.active_clients() { - if let Some(hostname) = self.client_manager.get_hostname(handle) { - self.resolver.resolve(handle, hostname); - } - if let Some(pos) = self.client_manager.get_pos(handle) { - self.capture.create(handle, pos, CaptureType::Default); - } + self.activate_client(handle); } loop { tokio::select! { - request = self.frontend_listener.next() => { - let request = match request { - Some(Ok(r)) => r, - Some(Err(e)) => { - log::error!("error receiving request: {e}"); - continue; - } - None => break, - }; - match request { - FrontendRequest::EnableCapture => self.capture.reenable(), - FrontendRequest::EnableEmulation => self.emulation.reenable(), - FrontendRequest::Create => { - self.add_client(); - } - FrontendRequest::Activate(handle, active) => { - if active { - if let Some(hostname) = self.client_manager.get_hostname(handle) { - self.resolver.resolve(handle, hostname); - } - self.activate_client(handle); - } else { - self.deactivate_client(handle); - } - } - FrontendRequest::ChangePort(port) => { - if self.port != port { - self.emulation.request_port_change(port); - } else { - self.notify_frontend(FrontendEvent::PortChanged(self.port, None)); - } - } - FrontendRequest::Delete(handle) => { - self.remove_client(handle); - self.notify_frontend(FrontendEvent::Deleted(handle)); - } - FrontendRequest::Enumerate() => self.enumerate(), - FrontendRequest::GetState(handle) => self.broadcast_client(handle), - FrontendRequest::UpdateFixIps(handle, fix_ips) => self.update_fix_ips(handle, fix_ips), - FrontendRequest::UpdateHostname(handle, host) => { - self.update_hostname(handle, host) - } - FrontendRequest::UpdatePort(handle, port) => self.update_port(handle, port), - FrontendRequest::UpdatePosition(handle, pos) => { - self.update_pos(handle, pos); - } - FrontendRequest::ResolveDns(handle) => { - if let Some(hostname) = self.client_manager.get_hostname(handle) { - self.resolver.resolve(handle, hostname); - } - } - FrontendRequest::Sync => { - self.enumerate(); - self.notify_frontend(FrontendEvent::EmulationStatus(self.emulation_status)); - self.notify_frontend(FrontendEvent::CaptureStatus(self.capture_status)); - self.notify_frontend(FrontendEvent::PortChanged(self.port, None)); - self.notify_frontend(FrontendEvent::PublicKeyFingerprint( - self.public_key_fingerprint.clone(), - )); - let keys = self.authorized_keys.read().expect("lock").clone(); - self.notify_frontend(FrontendEvent::AuthorizedUpdated(keys)); - } - FrontendRequest::AuthorizeKey(desc, fp) => { - self.add_authorized_key(desc, fp); - } - FrontendRequest::RemoveAuthorizedKey(key) => { - self.remove_authorized_key(key); - } - } - } - _ = self.frontend_event_pending.notified() => { - while let Some(event) = self.pending_frontend_events.pop_front() { - self.frontend_listener.broadcast(event).await; - } - }, - event = self.emulation.event() => match event { - EmulationEvent::Connected { addr, pos, fingerprint } => { - // check if already registered - if !self.incoming_conns.contains(&addr) { - self.add_incoming(addr, pos, fingerprint.clone()); - self.notify_frontend(FrontendEvent::IncomingConnected(fingerprint, addr, pos)); - } else { - let handle = self - .incoming_conn_info - .iter() - .find(|(_, incoming)| incoming.addr == addr) - .map(|(k, _)| *k) - .expect("no such client"); - let mut changed = false; - if let Some(incoming) = self.incoming_conn_info.get_mut(&handle) { - if incoming.fingerprint != fingerprint { - incoming.fingerprint = fingerprint.clone(); - changed = true; - } - if incoming.pos != pos { - incoming.pos = pos; - changed = true; - } - } - if changed { - self.remove_incoming(addr); - self.add_incoming(addr, pos, fingerprint.clone()); - self.notify_frontend(FrontendEvent::IncomingDisconnected(addr)); - self.notify_frontend(FrontendEvent::IncomingConnected(fingerprint, addr, pos)); - } - } - } - EmulationEvent::Disconnected { addr } => { - if let Some(addr) = self.remove_incoming(addr) { - self.notify_frontend(FrontendEvent::IncomingDisconnected(addr)); - } - } - EmulationEvent::PortChanged(port) => match port { - Ok(port) => { - self.port = port; - self.notify_frontend(FrontendEvent::PortChanged(port, None)); - }, - Err(e) => self.notify_frontend(FrontendEvent::PortChanged(self.port, Some(format!("{e}")))), - } - EmulationEvent::EmulationDisabled => { - self.emulation_status = Status::Disabled; - self.notify_frontend(FrontendEvent::EmulationStatus(self.emulation_status)); - }, - EmulationEvent::EmulationEnabled => { - self.emulation_status = Status::Enabled; - self.notify_frontend(FrontendEvent::EmulationStatus(self.emulation_status)); - }, - EmulationEvent::ReleaseNotify => self.capture.release(), - }, - event = self.capture.event() => match event { - ICaptureEvent::CaptureBegin(handle) => { - // we entered the capture zone for an incoming connection - // => notify it that its capture should be released - if let Some(incoming) = self.incoming_conn_info.get(&handle) { - self.emulation.send_leave_event(incoming.addr); - } - } - ICaptureEvent::CaptureDisabled => { - self.capture_status = Status::Disabled; - self.notify_frontend(FrontendEvent::CaptureStatus(self.capture_status)); - } - ICaptureEvent::CaptureEnabled => { - self.capture_status = Status::Enabled; - self.notify_frontend(FrontendEvent::CaptureStatus(self.capture_status)); - } - ICaptureEvent::ClientEntered(handle) => { - log::info!("entering client {handle} ..."); - self.spawn_hook_command(handle); - }, - }, - event = self.resolver.event() => match event { - DnsEvent::Resolving(handle) => self.set_resolving(handle, true), - DnsEvent::Resolved(handle, hostname, ips) => { - self.set_resolving(handle, false); - match ips { - Ok(ips) => self.update_dns_ips(handle, ips), - Err(e) => { - log::warn!("could not resolve {hostname}: {e}"); - self.update_dns_ips(handle, vec![]); - }, - } - } - }, - r = signal::ctrl_c() => { - r.expect("failed to wait for CTRL+C"); - break; - } + request = self.frontend_listener.next() => self.handle_frontend_request(request), + _ = self.frontend_event_pending.notified() => self.handle_frontend_pending().await, + event = self.emulation.event() => self.handle_emulation_event(event), + event = self.capture.event() => self.handle_capture_event(event), + event = self.resolver.event() => self.handle_resolver_event(event), + r = signal::ctrl_c() => break r.expect("failed to wait for CTRL+C"), } } log::info!("terminating service ..."); - log::info!("terminating capture ..."); + log::debug!("terminating capture ..."); self.capture.terminate().await; - log::info!("terminating emulation ..."); + log::debug!("terminating emulation ..."); self.emulation.terminate().await; - log::info!("terminating dns resolver ..."); + log::debug!("terminating dns resolver ..."); self.resolver.terminate().await; Ok(()) } - pub(crate) const ENTER_HANDLE_BEGIN: u64 = u64::MAX / 2 + 1; + fn handle_frontend_request(&mut self, request: Option>) { + let request = match request.expect("frontend listener closed") { + Ok(r) => r, + Err(e) => return log::error!("error receiving request: {e}"), + }; + match request { + FrontendRequest::Activate(handle, active) => self.set_client_active(handle, active), + FrontendRequest::AuthorizeKey(desc, fp) => self.add_authorized_key(desc, fp), + FrontendRequest::ChangePort(port) => self.change_port(port), + FrontendRequest::Create => self.add_client(), + FrontendRequest::Delete(handle) => self.remove_client(handle), + FrontendRequest::EnableCapture => self.capture.reenable(), + FrontendRequest::EnableEmulation => self.emulation.reenable(), + FrontendRequest::Enumerate() => self.enumerate(), + FrontendRequest::GetState(handle) => self.broadcast_client(handle), + FrontendRequest::UpdateFixIps(handle, fix_ips) => self.update_fix_ips(handle, fix_ips), + FrontendRequest::UpdateHostname(handle, host) => self.update_hostname(handle, host), + FrontendRequest::UpdatePort(handle, port) => self.update_port(handle, port), + FrontendRequest::UpdatePosition(handle, pos) => self.update_pos(handle, pos), + FrontendRequest::ResolveDns(handle) => self.resolve(handle), + FrontendRequest::Sync => self.sync_frontend(), + FrontendRequest::RemoveAuthorizedKey(key) => self.remove_authorized_key(key), + } + } + + async fn handle_frontend_pending(&mut self) { + while let Some(event) = self.pending_frontend_events.pop_front() { + self.frontend_listener.broadcast(event).await; + } + } + + fn handle_emulation_event(&mut self, event: EmulationEvent) { + match event { + EmulationEvent::Connected { + addr, + pos, + fingerprint, + } => { + // check if already registered + if !self.incoming_conns.contains(&addr) { + self.add_incoming(addr, pos, fingerprint.clone()); + self.notify_frontend(FrontendEvent::IncomingConnected(fingerprint, addr, pos)); + } else { + self.update_incoming(addr, pos, fingerprint); + } + } + EmulationEvent::Disconnected { addr } => { + if let Some(addr) = self.remove_incoming(addr) { + self.notify_frontend(FrontendEvent::IncomingDisconnected(addr)); + } + } + EmulationEvent::PortChanged(port) => match port { + Ok(port) => { + self.port = port; + self.notify_frontend(FrontendEvent::PortChanged(port, None)); + } + Err(e) => self + .notify_frontend(FrontendEvent::PortChanged(self.port, Some(format!("{e}")))), + }, + EmulationEvent::EmulationDisabled => { + self.emulation_status = Status::Disabled; + self.notify_frontend(FrontendEvent::EmulationStatus(self.emulation_status)); + } + EmulationEvent::EmulationEnabled => { + self.emulation_status = Status::Enabled; + self.notify_frontend(FrontendEvent::EmulationStatus(self.emulation_status)); + } + EmulationEvent::ReleaseNotify => self.capture.release(), + } + } + + fn handle_capture_event(&mut self, event: ICaptureEvent) { + match event { + ICaptureEvent::CaptureBegin(handle) => { + // we entered the capture zone for an incoming connection + // => notify it that its capture should be released + if let Some(incoming) = self.incoming_conn_info.get(&handle) { + self.emulation.send_leave_event(incoming.addr); + } + } + ICaptureEvent::CaptureDisabled => { + self.capture_status = Status::Disabled; + self.notify_frontend(FrontendEvent::CaptureStatus(self.capture_status)); + } + ICaptureEvent::CaptureEnabled => { + self.capture_status = Status::Enabled; + self.notify_frontend(FrontendEvent::CaptureStatus(self.capture_status)); + } + ICaptureEvent::ClientEntered(handle) => { + log::info!("entering client {handle} ..."); + self.spawn_hook_command(handle); + } + } + } + + fn handle_resolver_event(&mut self, event: DnsEvent) { + let handle = match event { + DnsEvent::Resolving(handle) => { + self.client_manager.set_resolving(handle, true); + handle + } + DnsEvent::Resolved(handle, hostname, ips) => { + self.client_manager.set_resolving(handle, false); + if let Err(e) = &ips { + log::warn!("could not resolve {hostname}: {e}"); + } + let ips = ips.unwrap_or_default(); + self.client_manager.set_dns_ips(handle, ips); + handle + } + }; + self.notify_frontend(FrontendEvent::Changed(handle)); + } + + fn resolve(&self, handle: ClientHandle) { + if let Some(hostname) = self.client_manager.get_hostname(handle) { + self.resolver.resolve(handle, hostname); + } + } + + fn sync_frontend(&mut self) { + self.enumerate(); + self.notify_frontend(FrontendEvent::EmulationStatus(self.emulation_status)); + self.notify_frontend(FrontendEvent::CaptureStatus(self.capture_status)); + self.notify_frontend(FrontendEvent::PortChanged(self.port, None)); + self.notify_frontend(FrontendEvent::PublicKeyFingerprint( + self.public_key_fingerprint.clone(), + )); + let keys = self.authorized_keys.read().expect("lock").clone(); + self.notify_frontend(FrontendEvent::AuthorizedUpdated(keys)); + } + + const ENTER_HANDLE_BEGIN: u64 = u64::MAX / 2 + 1; fn add_incoming(&mut self, addr: SocketAddr, pos: Position, fingerprint: String) { let handle = Self::ENTER_HANDLE_BEGIN + self.next_trigger_handle; @@ -339,6 +317,30 @@ impl Service { ); } + fn update_incoming(&mut self, addr: SocketAddr, pos: Position, fingerprint: String) { + let incoming = self + .incoming_conn_info + .iter_mut() + .find(|(_, i)| i.addr == addr) + .map(|(_, i)| i) + .expect("no such client"); + let mut changed = false; + if incoming.fingerprint != fingerprint { + incoming.fingerprint = fingerprint.clone(); + changed = true; + } + if incoming.pos != pos { + incoming.pos = pos; + changed = true; + } + if changed { + self.remove_incoming(addr); + self.add_incoming(addr, pos, fingerprint.clone()); + self.notify_frontend(FrontendEvent::IncomingDisconnected(addr)); + self.notify_frontend(FrontendEvent::IncomingConnected(fingerprint, addr, pos)); + } + } + fn remove_incoming(&mut self, addr: SocketAddr) -> Option { let handle = self .incoming_conn_info @@ -357,10 +359,6 @@ impl Service { self.frontend_event_pending.notify_one(); } - fn client_updated(&mut self, handle: ClientHandle) { - self.notify_frontend(FrontendEvent::Changed(handle)); - } - fn add_authorized_key(&mut self, desc: String, fp: String) { self.authorized_keys.write().expect("lock").insert(fp, desc); let keys = self.authorized_keys.read().expect("lock").clone(); @@ -378,25 +376,36 @@ impl Service { self.notify_frontend(FrontendEvent::Enumerate(clients)); } - fn add_client(&mut self) -> ClientHandle { + fn add_client(&mut self) { let handle = self.client_manager.add_client(); log::info!("added client {handle}"); let (c, s) = self.client_manager.get_state(handle).unwrap(); self.notify_frontend(FrontendEvent::Created(handle, c, s)); - handle + } + + fn set_client_active(&mut self, handle: ClientHandle, active: bool) { + if active { + self.activate_client(handle); + } else { + self.deactivate_client(handle); + } } fn deactivate_client(&mut self, handle: ClientHandle) { log::debug!("deactivating client {handle}"); if self.client_manager.deactivate_client(handle) { self.capture.destroy(handle); - self.client_updated(handle); + self.notify_frontend(FrontendEvent::Changed(handle)); log::info!("deactivated client {handle}"); } } fn activate_client(&mut self, handle: ClientHandle) { log::debug!("activating client"); + + /* resolve dns on activate */ + self.resolve(handle); + /* deactivate potential other client at this position */ let Some(pos) = self.client_manager.get_pos(handle) else { return; @@ -412,42 +421,46 @@ impl Service { if self.client_manager.activate_client(handle) { /* notify capture and frontends */ self.capture.create(handle, pos, CaptureType::Default); - self.client_updated(handle); + self.notify_frontend(FrontendEvent::Changed(handle)); log::info!("activated client {handle} ({pos})"); } } - fn remove_client(&self, handle: ClientHandle) { - if let Some(true) = self + fn change_port(&mut self, port: u16) { + if self.port != port { + self.emulation.request_port_change(port); + } else { + self.notify_frontend(FrontendEvent::PortChanged(self.port, None)); + } + } + + fn remove_client(&mut self, handle: ClientHandle) { + if self .client_manager .remove_client(handle) .map(|(_, s)| s.active) + .unwrap_or(false) { self.capture.destroy(handle); } + self.notify_frontend(FrontendEvent::Deleted(handle)); } fn update_fix_ips(&mut self, handle: ClientHandle, fix_ips: Vec) { self.client_manager.set_fix_ips(handle, fix_ips); - self.client_updated(handle); - } - - fn update_dns_ips(&mut self, handle: ClientHandle, dns_ips: Vec) { - self.client_manager.set_dns_ips(handle, dns_ips); - self.client_updated(handle); + self.notify_frontend(FrontendEvent::Changed(handle)); } fn update_hostname(&mut self, handle: ClientHandle, hostname: Option) { if self.client_manager.set_hostname(handle, hostname.clone()) { - if let Some(hostname) = hostname { - self.resolver.resolve(handle, hostname); - } - self.client_updated(handle); + self.resolve(handle); } + self.notify_frontend(FrontendEvent::Changed(handle)); } - fn update_port(&self, handle: ClientHandle, port: u16) { + fn update_port(&mut self, handle: ClientHandle, port: u16) { self.client_manager.set_port(handle, port); + self.notify_frontend(FrontendEvent::Changed(handle)); } fn update_pos(&mut self, handle: ClientHandle, pos: Position) { @@ -456,22 +469,18 @@ impl Service { self.deactivate_client(handle); self.activate_client(handle); } + self.notify_frontend(FrontendEvent::Changed(handle)); } fn broadcast_client(&mut self, handle: ClientHandle) { - let event = if let Some((config, state)) = self.client_manager.get_state(handle) { - FrontendEvent::State(handle, config, state) - } else { - FrontendEvent::NoSuchClient(handle) - }; + let event = self + .client_manager + .get_state(handle) + .map(|(c, s)| FrontendEvent::State(handle, c, s)) + .unwrap_or(FrontendEvent::NoSuchClient(handle)); self.notify_frontend(event); } - fn set_resolving(&mut self, handle: ClientHandle, status: bool) { - self.client_manager.set_resolving(handle, status); - self.client_updated(handle); - } - fn spawn_hook_command(&self, handle: ClientHandle) { let Some(cmd) = self.client_manager.get_enter_cmd(handle) else { return;