diff --git a/src/capture.rs b/src/capture.rs index e0dd6ee..00b7702 100644 --- a/src/capture.rs +++ b/src/capture.rs @@ -15,7 +15,7 @@ use tokio::{ task::{spawn_local, JoinHandle}, }; -use crate::{connect::LanMouseConnection, server::Server}; +use crate::{connect::LanMouseConnection, service::Service}; pub(crate) struct Capture { tx: Sender, @@ -33,7 +33,7 @@ enum CaptureRequest { } impl Capture { - pub(crate) fn new(server: Server, conn: LanMouseConnection) -> Self { + pub(crate) fn new(server: Service, conn: LanMouseConnection) -> Self { let (tx, rx) = channel(); let task = spawn_local(Self::run(server.clone(), rx, conn)); Self { tx, task } @@ -66,7 +66,7 @@ impl Capture { .expect("channel closed"); } - async fn run(server: Server, mut rx: Receiver, mut conn: LanMouseConnection) { + async fn run(server: Service, mut rx: Receiver, mut conn: LanMouseConnection) { loop { if let Err(e) = do_capture(&server, &mut conn, &mut rx).await { log::warn!("input capture exited: {e}"); @@ -87,7 +87,7 @@ impl Capture { } async fn do_capture( - server: &Server, + server: &Service, conn: &mut LanMouseConnection, rx: &mut Receiver, ) -> Result<(), InputCaptureError> { @@ -191,7 +191,7 @@ enum State { } async fn handle_capture_event( - server: &Server, + server: &Service, capture: &mut InputCapture, conn: &LanMouseConnection, event: (CaptureHandle, CaptureEvent), @@ -241,7 +241,7 @@ async fn handle_capture_event( Ok(()) } -async fn release_capture(capture: &mut InputCapture, server: &Server) -> Result<(), CaptureError> { +async fn release_capture(capture: &mut InputCapture, server: &Service) -> Result<(), CaptureError> { server.set_active(None); capture.release().await } @@ -264,7 +264,7 @@ fn to_proto_pos(pos: lan_mouse_ipc::Position) -> lan_mouse_proto::Position { } } -fn spawn_hook_command(server: &Server, handle: ClientHandle) { +fn spawn_hook_command(server: &Service, handle: ClientHandle) { let Some(cmd) = server.client_manager.get_enter_cmd(handle) else { return; }; diff --git a/src/client.rs b/src/client.rs index 60aac85..0b36384 100644 --- a/src/client.rs +++ b/src/client.rs @@ -39,7 +39,7 @@ impl ClientManager { pub fn activate_client(&self, handle: ClientHandle) -> bool { let mut clients = self.clients.borrow_mut(); match clients.get_mut(handle as usize) { - Some((_, s)) if s.active == false => { + Some((_, s)) if !s.active => { s.active = true; true } diff --git a/src/connect.rs b/src/connect.rs index a867f31..851a6d6 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,4 +1,4 @@ -use crate::server::Server; +use crate::service::Service; use lan_mouse_ipc::{ClientHandle, DEFAULT_PORT}; use lan_mouse_proto::{ProtoEvent, MAX_EVENT_SIZE}; use local_channel::mpsc::{channel, Receiver, Sender}; @@ -68,7 +68,7 @@ async fn connect_any( } pub(crate) struct LanMouseConnection { - server: Server, + server: Service, cert: Certificate, conns: Rc>>>, connecting: Rc>>, @@ -77,7 +77,7 @@ pub(crate) struct LanMouseConnection { } impl LanMouseConnection { - pub(crate) fn new(server: Server, cert: Certificate) -> Self { + pub(crate) fn new(server: Service, cert: Certificate) -> Self { let (recv_tx, recv_rx) = channel(); Self { server, @@ -139,7 +139,7 @@ impl LanMouseConnection { } async fn connect_to_handle( - server: Server, + server: Service, cert: Certificate, handle: ClientHandle, conns: Rc>>>, @@ -189,7 +189,7 @@ async fn connect_to_handle( } async fn ping_pong( - server: Server, + server: Service, handle: ClientHandle, addr: SocketAddr, conn: Arc, @@ -214,7 +214,7 @@ async fn ping_pong( } async fn receive_loop( - server: Server, + server: Service, handle: ClientHandle, addr: SocketAddr, conn: Arc, @@ -238,7 +238,7 @@ async fn receive_loop( } async fn disconnect( - server: &Server, + server: &Service, handle: ClientHandle, addr: SocketAddr, conns: &Mutex>>, diff --git a/src/crypto.rs b/src/crypto.rs index 74c8c7b..ddbb285 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -64,6 +64,6 @@ pub(crate) fn generate_key_and_cert(path: &Path) -> Result { } /* FIXME windows permissions */ let mut writer = BufWriter::new(f); - writer.write(serialized.as_bytes())?; + writer.write_all(serialized.as_bytes())?; Ok(cert) } diff --git a/src/dns.rs b/src/dns.rs index 3c7f7da..5e11921 100644 --- a/src/dns.rs +++ b/src/dns.rs @@ -3,7 +3,7 @@ use tokio::task::{spawn_local, JoinHandle}; use hickory_resolver::{error::ResolveError, TokioAsyncResolver}; -use crate::server::Server; +use crate::service::Service; use lan_mouse_ipc::ClientHandle; pub(crate) struct DnsResolver { @@ -12,7 +12,7 @@ pub(crate) struct DnsResolver { } impl DnsResolver { - pub(crate) fn new(server: Server) -> Result { + pub(crate) fn new(server: Service) -> Result { let resolver = TokioAsyncResolver::tokio_from_system_conf()?; let (tx, rx) = channel(); let _task = spawn_local(Self::run(server, resolver, rx)); @@ -23,7 +23,7 @@ impl DnsResolver { self.tx.send(host).expect("channel closed"); } - async fn run(server: Server, resolver: TokioAsyncResolver, mut rx: Receiver) { + async fn run(server: Service, resolver: TokioAsyncResolver, mut rx: Receiver) { tokio::select! { _ = server.cancelled() => {}, _ = Self::do_dns(&server, &resolver, &mut rx) => {}, @@ -31,7 +31,7 @@ impl DnsResolver { } async fn do_dns( - server: &Server, + server: &Service, resolver: &TokioAsyncResolver, rx: &mut Receiver, ) { diff --git a/src/emulation.rs b/src/emulation.rs index f2cd896..4ba1b77 100644 --- a/src/emulation.rs +++ b/src/emulation.rs @@ -1,4 +1,4 @@ -use crate::{listen::LanMouseListener, server::Server}; +use crate::{listen::LanMouseListener, service::Service}; use futures::StreamExt; use input_emulation::{EmulationHandle, InputEmulation, InputEmulationError}; use input_event::Event; @@ -21,14 +21,14 @@ pub(crate) struct Emulation { } impl Emulation { - pub(crate) fn new(server: Server, listener: LanMouseListener) -> Self { + pub(crate) fn new(server: Service, listener: LanMouseListener) -> Self { let emulation_proxy = EmulationProxy::new(server.clone()); let task = spawn_local(Self::run(server, listener, emulation_proxy)); Self { task } } async fn run( - server: Server, + server: Service, mut listener: LanMouseListener, mut emulation_proxy: EmulationProxy, ) { @@ -45,10 +45,12 @@ impl Emulation { last_response.insert(addr, Instant::now()); match event { ProtoEvent::Enter(pos) => { - log::info!("{addr} entered this device"); - server.release_capture(); - listener.reply(addr, ProtoEvent::Ack(0)).await; - server.register_incoming(addr, to_ipc_pos(pos)); + if let Some(cert) = listener.get_certificate_fingerprint(addr).await { + log::info!("{addr} entered this device"); + server.release_capture(); + listener.reply(addr, ProtoEvent::Ack(0)).await; + server.register_incoming(addr, to_ipc_pos(pos), cert); + } } ProtoEvent::Leave(_) => { emulation_proxy.release_keys(addr); @@ -89,7 +91,7 @@ impl Emulation { /// proxy handling the actual input emulation, /// discarding events when it is disabled pub(crate) struct EmulationProxy { - server: Server, + server: Service, tx: Sender<(ProxyEvent, SocketAddr)>, task: JoinHandle<()>, } @@ -100,7 +102,7 @@ enum ProxyEvent { } impl EmulationProxy { - fn new(server: Server) -> Self { + fn new(server: Service) -> Self { let (tx, rx) = channel(); let task = spawn_local(Self::emulation_task(server.clone(), rx)); Self { server, tx, task } @@ -121,7 +123,7 @@ impl EmulationProxy { .expect("channel closed"); } - async fn emulation_task(server: Server, mut rx: Receiver<(ProxyEvent, SocketAddr)>) { + async fn emulation_task(server: Service, mut rx: Receiver<(ProxyEvent, SocketAddr)>) { let mut handles = HashMap::new(); let mut next_id = 0; loop { @@ -136,7 +138,7 @@ impl EmulationProxy { } async fn do_emulation( - server: &Server, + server: &Service, handles: &mut HashMap, next_id: &mut EmulationHandle, rx: &mut Receiver<(ProxyEvent, SocketAddr)>, @@ -163,7 +165,7 @@ impl EmulationProxy { } async fn do_emulation_session( - server: &Server, + server: &Service, emulation: &mut InputEmulation, handles: &mut HashMap, next_id: &mut EmulationHandle, diff --git a/src/lib.rs b/src/lib.rs index 67618b7..ae9c66e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,4 +8,4 @@ mod dns; mod emulation; pub mod emulation_test; mod listen; -pub mod server; +pub mod service; diff --git a/src/listen.rs b/src/listen.rs index d1745e1..dac0b10 100644 --- a/src/listen.rs +++ b/src/listen.rs @@ -16,6 +16,7 @@ use tokio::{ }; use webrtc_dtls::{ config::{ClientAuthType::RequireAnyClientCert, Config, ExtendedMasterSecretType}, + conn::DTLSConn, crypto::Certificate, listener::listen, }; @@ -57,7 +58,7 @@ impl LanMouseListener { move |certs: &[Vec], _chains: &[CertificateDer<'static>]| { assert!(certs.len() == 1); let fingerprints = certs - .into_iter() + .iter() .map(|c| crypto::generate_fingerprint(c)) .collect::>(); if authorized_keys @@ -143,6 +144,25 @@ impl LanMouseListener { } } } + + pub(crate) async fn get_certificate_fingerprint(&self, addr: SocketAddr) -> Option { + if let Some(conn) = self + .conns + .lock() + .await + .iter() + .find(|(a, _)| *a == addr) + .map(|(_, c)| c.clone()) + { + let conn: &DTLSConn = conn.as_any().downcast_ref().expect("dtls conn"); + let certs = conn.connection_state().await.peer_certificates; + let cert = certs.get(0)?; + let fingerprint = crypto::generate_fingerprint(cert); + Some(fingerprint) + } else { + None + } + } } impl Stream for LanMouseListener { diff --git a/src/main.rs b/src/main.rs index c663278..37c7a92 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,7 +5,7 @@ use lan_mouse::{ capture_test, config::{Config, ConfigError, Frontend}, emulation_test, - server::{Server, ServiceError}, + service::{Service, ServiceError}, }; use lan_mouse_ipc::IpcError; use std::{ @@ -101,7 +101,8 @@ fn start_service() -> Result { async fn run_service(config: Config) -> Result<(), ServiceError> { log::info!("Press {:?} to release the mouse", config.release_bind); - Server::new(config).run().await?; + let mut server = Service::new(config).await?; + server.run().await?; log::info!("service exited!"); Ok(()) } diff --git a/src/server.rs b/src/service.rs similarity index 89% rename from src/server.rs rename to src/service.rs index 370e1f5..56841db 100644 --- a/src/server.rs +++ b/src/service.rs @@ -26,6 +26,7 @@ use std::{ use thiserror::Error; use tokio::{signal, sync::Notify}; use tokio_util::sync::CancellationToken; +use webrtc_dtls::crypto::Certificate; #[derive(Debug, Error)] pub enum ServiceError { @@ -44,33 +45,35 @@ pub enum ServiceError { pub struct ReleaseToken; #[derive(Clone)] -pub struct Server { +pub struct Service { active: Rc>>, authorized_keys: Arc>>, - _known_hosts: Rc>>, pub(crate) client_manager: ClientManager, port: Rc>, - public_key_fingerprint: Option, + public_key_fingerprint: String, notifies: Rc, pub(crate) config: Rc, pending_frontend_events: Rc>>, + pending_incoming: Rc>>, capture_status: Rc>, pub(crate) emulation_status: Rc>, pub(crate) should_release: Rc>>, incoming_conns: Rc>>, + cert: Certificate, } #[derive(Default)] struct Notifies { capture: Notify, emulation: Notify, + incoming: Notify, port_changed: Notify, frontend_event_pending: Notify, cancel: CancellationToken, } -impl Server { - pub fn new(config: Config) -> Self { +impl Service { + pub async fn new(config: Config) -> Result { let client_manager = ClientManager::default(); let port = Rc::new(Cell::new(config.port)); for client in config.get_clients() { @@ -96,13 +99,18 @@ impl Server { let config = Rc::new(config); - Self { + // load certificate + let cert = crypto::load_or_generate_key_and_cert(&config.cert_path)?; + let public_key_fingerprint = crypto::certificate_fingerprint(&cert); + + let service = Self { active: Rc::new(Cell::new(None)), authorized_keys: Default::default(), - _known_hosts: Default::default(), - public_key_fingerprint: None, + cert, + public_key_fingerprint, config, client_manager, + pending_incoming: Default::default(), port, notifies, pending_frontend_events: Rc::new(RefCell::new(VecDeque::new())), @@ -110,30 +118,22 @@ impl Server { emulation_status: Default::default(), incoming_conns: Rc::new(RefCell::new(HashMap::new())), should_release: Default::default(), - } + }; + Ok(service) } pub async fn run(&mut self) -> Result<(), ServiceError> { // create frontend communication adapter, exit if already running - let mut frontend = match AsyncFrontendListener::new().await { - Ok(f) => f, - Err(IpcListenerCreationError::AlreadyRunning) => { - log::info!("service already running, exiting"); - return Ok(()); - } - e => e?, - }; - - // load certificate - let cert = crypto::load_or_generate_key_and_cert(&self.config.cert_path)?; - let public_key_fingerprint = crypto::certificate_fingerprint(&cert); - self.public_key_fingerprint.replace(public_key_fingerprint); + let mut frontend_listener = AsyncFrontendListener::new().await?; // listener + connection - let listener = - LanMouseListener::new(self.config.port, cert.clone(), self.authorized_keys.clone()) - .await?; - let conn = LanMouseConnection::new(self.clone(), cert); + let listener = LanMouseListener::new( + self.config.port, + self.cert.clone(), + self.authorized_keys.clone(), + ) + .await?; + let conn = LanMouseConnection::new(self.clone(), self.cert.clone()); // input capture + emulation let mut capture = Capture::new(self.clone(), conn); @@ -148,7 +148,7 @@ impl Server { loop { tokio::select! { - request = frontend.next() => { + request = frontend_listener.next() => { let request = match request { Some(Ok(r)) => r, Some(Err(e)) => { @@ -167,7 +167,15 @@ impl Server { let event = self.pending_frontend_events.borrow_mut().pop_front(); event } { - frontend.broadcast(event).await; + frontend_listener.broadcast(event).await; + } + }, + _ = self.notifies.incoming.notified() => { + while let Some((addr, pos, fingerprint)) = { + let incoming = self.pending_incoming.borrow_mut().pop_front(); + incoming + } { + // capture.register(addr, pos); } }, _ = self.cancelled() => break, @@ -271,10 +279,7 @@ impl Server { self.notify_frontend(FrontendEvent::CaptureStatus(self.capture_status.get())); self.notify_frontend(FrontendEvent::PortChanged(self.port.get(), None)); self.notify_frontend(FrontendEvent::PublicKeyFingerprint( - self.public_key_fingerprint - .as_ref() - .expect("fingerprint") - .clone(), + self.public_key_fingerprint.clone(), )); self.notify_frontend(FrontendEvent::AuthorizedUpdated( self.authorized_keys.read().expect("lock").clone(), @@ -425,7 +430,10 @@ impl Server { self.active.get() } - pub(crate) fn register_incoming(&self, addr: SocketAddr, pos: Position) { - self.incoming_conns.borrow_mut().insert(addr, pos); + pub(crate) fn register_incoming(&self, addr: SocketAddr, pos: Position, fingerprint: String) { + self.pending_incoming + .borrow_mut() + .push_back((addr, pos, fingerprint)); + self.notifies.incoming.notify_one(); } }