diff --git a/Cargo.lock b/Cargo.lock index 7374180..619c9a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "anyhow" +version = "1.0.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" + [[package]] name = "async-trait" version = "0.1.57" @@ -233,6 +239,7 @@ checksum = "879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b" name = "lan-mouse" version = "0.1.0" dependencies = [ + "anyhow", "memmap", "serde", "serde_derive", diff --git a/Cargo.toml b/Cargo.toml index 26713f6..da91d8a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ memmap = "0.7" toml = "0.5" serde = "1.0" serde_derive = "1.0" +anyhow = "1.0.71" [target.'cfg(unix)'.dependencies] wayland-client = { version="0.30.0", optional = true } @@ -30,13 +31,8 @@ winapi = { version = "0.3.9", features = ["winuser"] } [features] -default = [ "wayland", "x11", "xdg_desktop_portal", "libei" ] -wayland = [ - "dep:wayland-client", - "dep:wayland-protocols", - "dep:wayland-protocols-wlr", - "dep:wayland-protocols-misc", - "dep:wayland-protocols-plasma" ] -x11 = [ "dep:x11" ] +default = ["wayland", "x11", "xdg_desktop_portal", "libei"] +wayland = ["dep:wayland-client", "dep:wayland-protocols", "dep:wayland-protocols-wlr", "dep:wayland-protocols-misc", "dep:wayland-protocols-plasma"] +x11 = ["dep:x11"] xdg_desktop_portal = [] libei = [] diff --git a/src/backend/consumer/wlroots.rs b/src/backend/consumer/wlroots.rs index e162645..54b4b70 100644 --- a/src/backend/consumer/wlroots.rs +++ b/src/backend/consumer/wlroots.rs @@ -69,7 +69,7 @@ impl App { (_, _, Ok(fake_input)) => { fake_input.authenticate( "lan-mouse".into(), - "Allow remote clients to control this devices".into(), + "Allow remote clients to control this device".into(), ); VirtualInputManager::Kde { fake_input } } diff --git a/src/client.rs b/src/client.rs index bac18c5..b1c63d3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,4 +1,4 @@ -use std::{net::SocketAddr, error::Error, fmt::Display}; +use std::{net::SocketAddr, error::Error, fmt::Display, sync::{Arc, atomic::{AtomicBool, Ordering, AtomicU32}, RwLock}}; use crate::{config, dns}; @@ -29,8 +29,9 @@ pub enum ClientEvent { } pub struct ClientManager { - next_id: u32, - clients: Vec, + next_id: AtomicU32, + clients: RwLock>, + subscribers: RwLock>>, } pub type ClientHandle = u32; @@ -47,7 +48,7 @@ impl Display for ClientConfigError { impl Error for ClientConfigError {} impl ClientManager { - fn add_client(&mut self, client: &config::Client, pos: Position) -> Result<(), Box> { + fn add_client(&self, client: &config::Client, pos: Position) -> Result<(), Box> { let ip = match client.ip { Some(ip) => ip, None => match &client.host_name { @@ -60,16 +61,24 @@ impl ClientManager { Ok(()) } - fn new_id(&mut self) -> ClientHandle { - self.next_id += 1; - self.next_id + fn notify(&self) { + for subscriber in self.subscribers.read().unwrap().iter() { + subscriber.store(true, Ordering::SeqCst); + } + } + + fn new_id(&self) -> ClientHandle { + let id = self.next_id.load(Ordering::Acquire); + self.next_id.store(id + 1, Ordering::Release); + id as ClientHandle } pub fn new(config: &config::Config) -> Result> { - let mut client_manager = ClientManager { - next_id: 0, - clients: Vec::new(), + let client_manager = ClientManager { + next_id: AtomicU32::new(0), + clients: RwLock::new(Vec::new()), + subscribers: RwLock::new(vec![]), }; // add clients from config @@ -80,13 +89,18 @@ impl ClientManager { Ok(client_manager) } - pub fn register_client(&mut self, addr: SocketAddr, pos: Position) { + pub fn register_client(&self, addr: SocketAddr, pos: Position) { let handle = self.new_id(); let client = Client { addr, pos, handle }; - self.clients.push(client); + self.clients.write().unwrap().push(client); + self.notify(); } pub fn get_clients(&self) -> Vec { - self.clients.clone() + self.clients.read().unwrap().clone() + } + + pub fn subscribe(&self, subscriber: Arc) { + self.subscribers.write().unwrap().push(subscriber); } } diff --git a/src/event/server.rs b/src/event/server.rs index 41d4713..11c57f6 100644 --- a/src/event/server.rs +++ b/src/event/server.rs @@ -1,3 +1,5 @@ +use anyhow::Result; + use std::{ collections::HashMap, error::Error, @@ -10,7 +12,7 @@ use std::{ thread::{self, JoinHandle}, }; -use crate::client::{ClientHandle, ClientManager}; +use crate::{client::{ClientHandle, ClientManager}, ioutils::{ask_confirmation, ask_position}}; use super::Event; @@ -30,25 +32,25 @@ impl Server { } pub fn run( - self, - client_manager: &mut ClientManager, + &self, + client_manager: Arc, produce_rx: Receiver<(Event, ClientHandle)>, consume_tx: SyncSender<(Event, ClientHandle)>, - ) -> Result<(JoinHandle<()>, JoinHandle<()>), Box> { + ) -> Result<(JoinHandle>, JoinHandle>), Box> { let udp_socket = UdpSocket::bind(self.listen_addr)?; let rx = udp_socket.try_clone()?; let tx = udp_socket; let sending = self.sending.clone(); + let clients_updated = Arc::new(AtomicBool::new(true)); + client_manager.subscribe(clients_updated.clone()); + let client_manager_clone = client_manager.clone(); - let mut client_for_socket = HashMap::new(); - for client in client_manager.get_clients() { - println!("{}: {}", client.handle, client.addr); - client_for_socket.insert(client.addr, client.handle); - } let receiver = thread::Builder::new() .name("event receiver".into()) .spawn(move || { + let mut client_for_socket = HashMap::new(); + loop { let (event, addr) = match Server::receive_event(&rx) { Ok(e) => e, @@ -58,10 +60,30 @@ impl Server { } }; + if let Ok(_) = clients_updated.compare_exchange( + true, + false, + Ordering::SeqCst, + Ordering::SeqCst, + ) { + clients_updated.store(false, Ordering::SeqCst); + client_for_socket.clear(); + println!("updating clients: "); + for client in client_manager_clone.get_clients() { + println!("{}: {}", client.handle, client.addr); + client_for_socket.insert(client.addr, client.handle); + } + } + let client_handle = match client_for_socket.get(&addr) { Some(c) => *c, None => { - println!("Allow connection from {:?}? [Y/n]", addr); + eprint!("Allow connection from {:?}? ", addr); + if ask_confirmation(false)? { + client_manager_clone.register_client(addr, ask_position()?); + } else { + eprintln!("rejecting client: {:?}?", addr); + } continue; } }; diff --git a/src/ioutils.rs b/src/ioutils.rs new file mode 100644 index 0000000..ec03cfe --- /dev/null +++ b/src/ioutils.rs @@ -0,0 +1,49 @@ +use std::io::{self, Write}; + +use crate::client::Position; + + +pub fn ask_confirmation(default: bool) -> Result { + eprint!("{}", if default {" [Y,n] "} else { " [y,N] "}); + io::stderr().flush()?; + let answer = loop { + let mut buffer = String::new(); + io::stdin().read_line(&mut buffer)?; + let answer = buffer.to_lowercase(); + let answer = answer.trim(); + match answer { + "" => break default, + "y" => break true, + "n" => break false, + _ => { + eprint!("Enter y for Yes or n for No: "); + io::stderr().flush()?; + continue + } + } + }; + Ok(answer) +} + +pub fn ask_position() -> Result { + eprint!("Enter position - top (t) | bottom (b) | left(l) | right(r): "); + io::stderr().flush()?; + let pos = loop { + let mut buffer = String::new(); + io::stdin().read_line(&mut buffer)?; + let answer = buffer.to_lowercase(); + let answer = answer.trim(); + match answer { + "t" | "top" => break Position::Top, + "b" | "bottom" => break Position::Bottom, + "l" | "left" => break Position::Right, + "r" | "right" => break Position::Left, + _ => { + eprint!("Invalid position: {answer} - enter top (t) | bottom (b) | left(l) | right(r): "); + io::stderr().flush()?; + continue + } + }; + }; + Ok(pos) +} diff --git a/src/lib.rs b/src/lib.rs index cb0d557..391e042 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,3 +8,4 @@ pub mod consumer; pub mod producer; pub mod backend; +pub mod ioutils; diff --git a/src/main.rs b/src/main.rs index cb45d9c..5c2ef1e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use std::{sync::mpsc, process, env}; +use std::{sync::{mpsc, Arc}, process, env}; use lan_mouse::{ client::ClientManager, @@ -32,7 +32,7 @@ pub fn main() { let (consume_tx, consume_rx) = mpsc::sync_channel(128); // create client manager - let mut client_manager = match ClientManager::new(&config) { + let client_manager = match ClientManager::new(&config) { Err(e) => { eprintln!("{e}"); process::exit(1); @@ -73,7 +73,7 @@ pub fn main() { process::exit(1); } }; - let (receiver, sender) = match event_server.run(&mut client_manager, produce_rx, consume_tx) { + let (receiver, sender) = match event_server.run(Arc::new(client_manager), produce_rx, consume_tx) { Ok((r,s)) => (r,s), Err(e) => { eprintln!("{e}"); @@ -84,7 +84,11 @@ pub fn main() { request_thread.join().unwrap(); // stop receiving events and terminate event-consumer - receiver.join().unwrap(); + if let Err(e) = receiver.join().unwrap() { + eprint!("{e}"); + process::exit(1); + } + if let Some(thread) = event_consumer { thread.join().unwrap(); } @@ -93,5 +97,9 @@ pub fn main() { if let Some(thread) = event_producer { thread.join().unwrap(); } - sender.join().unwrap(); + + if let Err(e) = sender.join().unwrap() { + eprint!("{e}"); + process::exit(1); + } }