diff --git a/input-capture/src/dummy.rs b/input-capture/src/dummy.rs index f5f8639..24da6d9 100644 --- a/input-capture/src/dummy.rs +++ b/input-capture/src/dummy.rs @@ -8,7 +8,7 @@ use futures_core::Stream; use input_event::PointerEvent; use tokio::time::{self, Instant, Interval}; -use super::{Capture, CaptureError, CaptureEvent, CaptureHandle, Position}; +use super::{Capture, CaptureError, CaptureEvent, Position}; pub struct DummyInputCapture { start: Option, @@ -34,11 +34,11 @@ impl Default for DummyInputCapture { #[async_trait] impl Capture for DummyInputCapture { - async fn create(&mut self, _handle: CaptureHandle, _pos: Position) -> Result<(), CaptureError> { + async fn create(&mut self, _pos: Position) -> Result<(), CaptureError> { Ok(()) } - async fn destroy(&mut self, _handle: CaptureHandle) -> Result<(), CaptureError> { + async fn destroy(&mut self, _pos: Position) -> Result<(), CaptureError> { Ok(()) } @@ -55,7 +55,7 @@ const FREQUENCY_HZ: f64 = 1.0; const RADIUS: f64 = 100.0; impl Stream for DummyInputCapture { - type Item = Result<(CaptureHandle, CaptureEvent), CaptureError>; + type Item = Result<(Position, CaptureEvent), CaptureError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let current = ready!(self.interval.poll_tick(cx)); @@ -81,6 +81,6 @@ impl Stream for DummyInputCapture { })) } }; - Poll::Ready(Some(Ok((0, event)))) + Poll::Ready(Some(Ok((Position::Left, event)))) } } diff --git a/input-capture/src/lib.rs b/input-capture/src/lib.rs index a4cc39b..7efcaa5 100644 --- a/input-capture/src/lib.rs +++ b/input-capture/src/lib.rs @@ -1,4 +1,9 @@ -use std::{collections::HashSet, fmt::Display, task::Poll}; +use std::{ + collections::{HashMap, HashSet, VecDeque}, + fmt::Display, + mem::swap, + task::{ready, Poll}, +}; use async_trait::async_trait; use futures::StreamExt; @@ -112,19 +117,48 @@ impl Display for Backend { } pub struct InputCapture { + /// capture backend capture: Box, + /// keys pressed by active capture pressed_keys: HashSet, + /// map from position to ids + position_map: HashMap>, + /// map from id to position + id_map: HashMap, + /// pending events + pending: VecDeque<(CaptureHandle, CaptureEvent)>, } impl InputCapture { /// create a new client with the given id pub async fn create(&mut self, id: CaptureHandle, pos: Position) -> Result<(), CaptureError> { - self.capture.create(id, pos).await + if let Some(v) = self.position_map.get_mut(&pos) { + v.push(id); + Ok(()) + } else { + self.position_map.insert(pos, vec![id]); + self.id_map.insert(id, pos); + self.capture.create(pos).await + } } /// destroy the client with the given id, if it exists pub async fn destroy(&mut self, id: CaptureHandle) -> Result<(), CaptureError> { - self.capture.destroy(id).await + if let Some(pos) = self.id_map.remove(&id) { + let destroy = if let Some(v) = self.position_map.get_mut(&pos) { + v.retain(|&i| i != id); + // we were the last id registered at this position + v.is_empty() + } else { + // nothing to destroy + false + }; + if destroy { + self.position_map.remove(&pos); + self.capture.destroy(pos).await?; + } + } + Ok(()) } /// release mouse @@ -143,6 +177,9 @@ impl InputCapture { let capture = create(backend).await?; Ok(Self { capture, + id_map: Default::default(), + pending: Default::default(), + position_map: Default::default(), pressed_keys: HashSet::new(), }) } @@ -170,29 +207,65 @@ impl Stream for InputCapture { mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> Poll> { - match self.capture.poll_next_unpin(cx) { - Poll::Ready(e) => { - if let Some(Ok(( - _, - CaptureEvent::Input(Event::Keyboard(KeyboardEvent::Key { key, state, .. })), - ))) = e + if let Some(e) = self.pending.pop_front() { + return Poll::Ready(Some(Ok(e))); + } + + // ready + let event = ready!(self.capture.poll_next_unpin(cx)); + + // stream closed + let event = match event { + Some(e) => e, + None => return Poll::Ready(None), + }; + + // error occurred + let (pos, event) = match event { + Ok(e) => e, + Err(e) => return Poll::Ready(Some(Err(e))), + }; + + // handle key presses + if let CaptureEvent::Input(Event::Keyboard(KeyboardEvent::Key { key, state, .. })) = event { + self.update_pressed_keys(key, state); + } + + let len = self + .position_map + .get(&pos) + .map(|ids| ids.len()) + .unwrap_or(0); + + match len { + 0 => Poll::Pending, + 1 => Poll::Ready(Some(Ok(( + self.position_map.get(&pos).expect("no id")[0], + event, + )))), + _ => { + let mut position_map = HashMap::new(); + swap(&mut self.position_map, &mut position_map); { - self.update_pressed_keys(key, state); + for &id in position_map.get(&pos).expect("position") { + self.pending.push_back((id, event)); + } } - Poll::Ready(e) + swap(&mut self.position_map, &mut position_map); + + Poll::Ready(Some(Ok(self.pending.pop_front().expect("event")))) } - Poll::Pending => Poll::Pending, } } } #[async_trait] -trait Capture: Stream> + Unpin { +trait Capture: Stream> + Unpin { /// create a new client with the given id - async fn create(&mut self, id: CaptureHandle, pos: Position) -> Result<(), CaptureError>; + async fn create(&mut self, pos: Position) -> Result<(), CaptureError>; /// destroy the client with the given id, if it exists - async fn destroy(&mut self, id: CaptureHandle) -> Result<(), CaptureError>; + async fn destroy(&mut self, pos: Position) -> Result<(), CaptureError>; /// release mouse async fn release(&mut self) -> Result<(), CaptureError>; @@ -204,14 +277,14 @@ trait Capture: Stream async fn create_backend( backend: Backend, ) -> Result< - Box>>, + Box>>, CaptureCreationError, > { match backend { #[cfg(all(unix, feature = "libei", not(target_os = "macos")))] Backend::InputCapturePortal => Ok(Box::new(libei::LibeiInputCapture::new().await?)), #[cfg(all(unix, feature = "wayland", not(target_os = "macos")))] - Backend::LayerShell => Ok(Box::new(wayland::WaylandInputCapture::new()?)), + Backend::LayerShell => Ok(Box::new(wayland::LayerShellInputCapture::new()?)), #[cfg(all(unix, feature = "x11", not(target_os = "macos")))] Backend::X11 => Ok(Box::new(x11::X11InputCapture::new()?)), #[cfg(windows)] @@ -225,7 +298,7 @@ async fn create_backend( async fn create( backend: Option, ) -> Result< - Box>>, + Box>>, CaptureCreationError, > { if let Some(backend) = backend { diff --git a/input-capture/src/libei.rs b/input-capture/src/libei.rs index b169ee4..2e25823 100644 --- a/input-capture/src/libei.rs +++ b/input-capture/src/libei.rs @@ -40,7 +40,7 @@ use crate::CaptureEvent; use super::{ error::{CaptureError, LibeiCaptureCreationError, ReisConvertEventStreamError}, - Capture as LanMouseInputCapture, CaptureHandle, Position, + Capture as LanMouseInputCapture, Position, }; /* there is a bug in xdg-remote-desktop-portal-gnome / mutter that @@ -50,15 +50,15 @@ use super::{ /// events that necessitate restarting the capture session #[derive(Clone, Copy, Debug)] enum LibeiNotifyEvent { - Create(CaptureHandle, Position), - Destroy(CaptureHandle), + Create(Position), + Destroy(Position), } #[allow(dead_code)] pub struct LibeiInputCapture<'a> { input_capture: Pin>>, capture_task: JoinHandle>, - event_rx: Receiver<(CaptureHandle, CaptureEvent)>, + event_rx: Receiver<(Position, CaptureEvent)>, notify_capture: Sender, notify_release: Arc, cancellation_token: CancellationToken, @@ -117,13 +117,13 @@ impl From for Barrier { fn select_barriers( zones: &Zones, - clients: &[(CaptureHandle, Position)], + clients: &[Position], next_barrier_id: &mut u32, -) -> (Vec, HashMap) { - let mut client_for_barrier = HashMap::new(); +) -> (Vec, HashMap) { + let mut pos_for_barrier = HashMap::new(); let mut barriers: Vec = vec![]; - for (handle, pos) in clients { + for pos in clients { let mut client_barriers = zones .regions() .iter() @@ -131,21 +131,21 @@ fn select_barriers( let id = *next_barrier_id; *next_barrier_id = id + 1; let position = pos_to_barrier(r, *pos); - client_for_barrier.insert(id, *handle); + pos_for_barrier.insert(id, *pos); ICBarrier::new(id, position) }) .collect(); barriers.append(&mut client_barriers); } - (barriers, client_for_barrier) + (barriers, pos_for_barrier) } async fn update_barriers( input_capture: &InputCapture<'_>, session: &Session<'_, InputCapture<'_>>, - active_clients: &[(CaptureHandle, Position)], + active_clients: &[Position], next_barrier_id: &mut u32, -) -> Result<(Vec, HashMap), ashpd::Error> { +) -> Result<(Vec, HashMap), ashpd::Error> { let zones = input_capture.zones(session).await?.response()?; log::debug!("zones: {zones:?}"); @@ -203,9 +203,9 @@ async fn connect_to_eis( async fn libei_event_handler( mut ei_event_stream: EiConvertEventStream, context: ei::Context, - event_tx: Sender<(CaptureHandle, CaptureEvent)>, + event_tx: Sender<(Position, CaptureEvent)>, release_session: Arc, - current_client: Rc>>, + current_pos: Rc>>, ) -> Result<(), CaptureError> { loop { let ei_event = ei_event_stream @@ -214,7 +214,7 @@ async fn libei_event_handler( .ok_or(CaptureError::EndOfStream)? .map_err(ReisConvertEventStreamError::from)?; log::trace!("from ei: {ei_event:?}"); - let client = current_client.get(); + let client = current_pos.get(); handle_ei_event(ei_event, client, &context, &event_tx, &release_session).await?; } } @@ -260,14 +260,14 @@ async fn do_capture( mut capture_event: Receiver, notify_release: Arc, session: Option<(Session<'_, InputCapture<'_>>, BitFlags)>, - event_tx: Sender<(CaptureHandle, CaptureEvent)>, + event_tx: Sender<(Position, CaptureEvent)>, cancellation_token: CancellationToken, ) -> Result<(), CaptureError> { let mut session = session.map(|s| s.0); /* safety: libei_task does not outlive Self */ let input_capture = unsafe { &*input_capture }; - let mut active_clients: Vec<(CaptureHandle, Position)> = vec![]; + let mut active_clients: Vec = vec![]; let mut next_barrier_id = 1u32; let mut zones_changed = input_capture.receive_zones_changed().await?; @@ -341,8 +341,8 @@ async fn do_capture( // update clients if requested if let Some(event) = capture_event_occured.take() { match event { - LibeiNotifyEvent::Create(c, p) => active_clients.push((c, p)), - LibeiNotifyEvent::Destroy(c) => active_clients.retain(|(h, _)| *h != c), + LibeiNotifyEvent::Create(p) => active_clients.push(p), + LibeiNotifyEvent::Destroy(p) => active_clients.retain(|&pos| pos != p), } } @@ -356,21 +356,21 @@ async fn do_capture( async fn do_capture_session( input_capture: &InputCapture<'_>, session: &mut Session<'_, InputCapture<'_>>, - event_tx: &Sender<(CaptureHandle, CaptureEvent)>, - active_clients: &[(CaptureHandle, Position)], + event_tx: &Sender<(Position, CaptureEvent)>, + active_clients: &[Position], next_barrier_id: &mut u32, notify_release: &Notify, cancel: (CancellationToken, CancellationToken), ) -> Result<(), CaptureError> { let (cancel_session, cancel_update) = cancel; // current client - let current_client = Rc::new(Cell::new(None)); + let current_pos = Rc::new(Cell::new(None)); // connect to eis server let (context, ei_event_stream) = connect_to_eis(input_capture, session).await?; // set barriers - let (barriers, client_for_barrier_id) = + let (barriers, pos_for_barrier_id) = update_barriers(input_capture, session, active_clients, next_barrier_id).await?; log::debug!("enabling session"); @@ -382,7 +382,7 @@ async fn do_capture_session( // async event task let cancel_ei_handler = CancellationToken::new(); let event_chan = event_tx.clone(); - let client = current_client.clone(); + let pos = current_pos.clone(); let cancel_session_clone = cancel_session.clone(); let release_session_clone = release_session.clone(); let cancel_ei_handler_clone = cancel_ei_handler.clone(); @@ -393,7 +393,7 @@ async fn do_capture_session( context, event_chan, release_session_clone, - client, + pos, ) => { log::debug!("libei exited: {r:?} cancelling session task"); cancel_session_clone.cancel(); @@ -421,11 +421,11 @@ async fn do_capture_session( }; // find client corresponding to barrier - let client = *client_for_barrier_id.get(&barrier_id).expect("invalid barrier id"); - current_client.replace(Some(client)); + let pos = *pos_for_barrier_id.get(&barrier_id).expect("invalid barrier id"); + current_pos.replace(Some(pos)); // client entered => send event - event_tx.send((client, CaptureEvent::Begin)).await.expect("no channel"); + event_tx.send((pos, CaptureEvent::Begin)).await.expect("no channel"); tokio::select! { _ = notify_release.notified() => { /* capture release */ @@ -441,7 +441,7 @@ async fn do_capture_session( }, } - release_capture(input_capture, session, activated, client, active_clients).await?; + release_capture(input_capture, session, activated, pos).await?; } _ = notify_release.notified() => { /* capture release -> we are not capturing anyway, so ignore */ @@ -484,8 +484,7 @@ async fn release_capture<'a>( input_capture: &InputCapture<'a>, session: &Session<'a, InputCapture<'a>>, activated: Activated, - current_client: CaptureHandle, - active_clients: &[(CaptureHandle, Position)], + current_pos: Position, ) -> Result<(), CaptureError> { if let Some(activation_id) = activated.activation_id() { log::debug!("releasing input capture {activation_id}"); @@ -494,13 +493,7 @@ async fn release_capture<'a>( .cursor_position() .expect("compositor did not report cursor position!"); log::debug!("client entered @ ({x}, {y})"); - let pos = active_clients - .iter() - .filter(|(c, _)| *c == current_client) - .map(|(_, p)| p) - .next() - .unwrap(); // FIXME - let (dx, dy) = match pos { + let (dx, dy) = match current_pos { // offset cursor position to not enter again immediately Position::Left => (1., 0.), Position::Right => (-1., 0.), @@ -554,9 +547,9 @@ static ALL_CAPABILITIES: &[DeviceCapability] = &[ async fn handle_ei_event( ei_event: EiEvent, - current_client: Option, + current_client: Option, context: &ei::Context, - event_tx: &Sender<(CaptureHandle, CaptureEvent)>, + event_tx: &Sender<(Position, CaptureEvent)>, release_session: &Notify, ) -> Result<(), CaptureError> { match ei_event { @@ -575,9 +568,9 @@ async fn handle_ei_event( return Err(CaptureError::Disconnected(format!("{:?}", d.reason))) } _ => { - if let Some(handle) = current_client { + if let Some(pos) = current_client { for event in Event::from_ei_event(ei_event) { - event_tx.send((handle, CaptureEvent::Input(event))).await.expect("no channel"); + event_tx.send((pos, CaptureEvent::Input(event))).await.expect("no channel"); } } } @@ -587,18 +580,18 @@ async fn handle_ei_event( #[async_trait] impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> { - async fn create(&mut self, handle: CaptureHandle, pos: Position) -> Result<(), CaptureError> { + async fn create(&mut self, pos: Position) -> Result<(), CaptureError> { let _ = self .notify_capture - .send(LibeiNotifyEvent::Create(handle, pos)) + .send(LibeiNotifyEvent::Create(pos)) .await; Ok(()) } - async fn destroy(&mut self, handle: CaptureHandle) -> Result<(), CaptureError> { + async fn destroy(&mut self, pos: Position) -> Result<(), CaptureError> { let _ = self .notify_capture - .send(LibeiNotifyEvent::Destroy(handle)) + .send(LibeiNotifyEvent::Destroy(pos)) .await; Ok(()) } @@ -629,7 +622,7 @@ impl<'a> Drop for LibeiInputCapture<'a> { } impl<'a> Stream for LibeiInputCapture<'a> { - type Item = Result<(CaptureHandle, CaptureEvent), CaptureError>; + type Item = Result<(Position, CaptureEvent), CaptureError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.capture_task.poll_unpin(cx) { diff --git a/input-capture/src/macos.rs b/input-capture/src/macos.rs index f293800..5ff1ded 100644 --- a/input-capture/src/macos.rs +++ b/input-capture/src/macos.rs @@ -1,6 +1,4 @@ -use super::{ - error::MacosCaptureCreationError, Capture, CaptureError, CaptureEvent, CaptureHandle, Position, -}; +use super::{error::MacosCaptureCreationError, Capture, CaptureError, CaptureEvent, Position}; use async_trait::async_trait; use bitflags::bitflags; use core_foundation::base::{kCFAllocatorDefault, CFRelease}; @@ -20,7 +18,7 @@ use input_event::{Event, KeyboardEvent, PointerEvent, BTN_LEFT, BTN_MIDDLE, BTN_ use keycode::{KeyMap, KeyMapping}; use libc::c_void; use once_cell::unsync::Lazy; -use std::collections::HashMap; +use std::collections::HashSet; use std::ffi::{c_char, CString}; use std::pin::Pin; use std::sync::Arc; @@ -39,44 +37,44 @@ struct Bounds { #[derive(Debug)] struct InputCaptureState { - client_for_pos: Lazy>, - current_client: Option<(CaptureHandle, Position)>, + active_clients: Lazy>, + current_pos: Option, bounds: Bounds, } #[derive(Debug)] enum ProducerEvent { Release, - Create(CaptureHandle, Position), - Destroy(CaptureHandle), - Grab((CaptureHandle, Position)), + Create(Position), + Destroy(Position), + Grab(Position), EventTapDisabled, } impl InputCaptureState { fn new() -> Result { let mut res = Self { - client_for_pos: Lazy::new(HashMap::new), - current_client: None, + active_clients: Lazy::new(HashSet::new), + current_pos: None, bounds: Bounds::default(), }; res.update_bounds()?; Ok(res) } - fn crossed(&mut self, event: &CGEvent) -> Option<(CaptureHandle, Position)> { + fn crossed(&mut self, event: &CGEvent) -> Option { let location = event.location(); let relative_x = event.get_double_value_field(EventField::MOUSE_EVENT_DELTA_X); let relative_y = event.get_double_value_field(EventField::MOUSE_EVENT_DELTA_Y); - for (position, client) in self.client_for_pos.iter() { - if (position == &Position::Left && (location.x + relative_x) <= self.bounds.xmin) - || (position == &Position::Right && (location.x + relative_x) >= self.bounds.xmax) - || (position == &Position::Top && (location.y + relative_y) <= self.bounds.ymin) - || (position == &Position::Bottom && (location.y + relative_y) >= self.bounds.ymax) + for &position in self.active_clients.iter() { + if (position == Position::Left && (location.x + relative_x) <= self.bounds.xmin) + || (position == Position::Right && (location.x + relative_x) >= self.bounds.xmax) + || (position == Position::Top && (location.y + relative_y) <= self.bounds.ymin) + || (position == Position::Bottom && (location.y + relative_y) >= self.bounds.ymax) { - log::debug!("Crossed barrier into client: {client}, {position:?}"); - return Some((*client, *position)); + log::debug!("Crossed barrier into position: {position:?}"); + return Some(position); } } None @@ -102,7 +100,7 @@ impl InputCaptureState { // to the edge of the screen, the cursor will be hidden but we dont want it to appear in a // random location when we exit the client fn reset_mouse_position(&self, event: &CGEvent) -> Result<(), CaptureError> { - if let Some((_, pos)) = self.current_client { + if let Some(pos) = self.current_pos { let location = event.location(); let edge_offset = 1.0; @@ -146,40 +144,31 @@ impl InputCaptureState { log::debug!("handling event: {producer_event:?}"); match producer_event { ProducerEvent::Release => { - if self.current_client.is_some() { + if self.current_pos.is_some() { CGDisplay::show_cursor(&CGDisplay::main()) .map_err(CaptureError::CoreGraphics)?; - self.current_client = None; + self.current_pos = None; } } - ProducerEvent::Grab(client) => { - if self.current_client.is_none() { + ProducerEvent::Grab(pos) => { + if self.current_pos.is_none() { CGDisplay::hide_cursor(&CGDisplay::main()) .map_err(CaptureError::CoreGraphics)?; - self.current_client = Some(client); + self.current_pos = Some(pos); } } - ProducerEvent::Create(c, p) => { - self.client_for_pos.insert(p, c); + ProducerEvent::Create(p) => { + self.active_clients.insert(p); } - ProducerEvent::Destroy(c) => { - for pos in [ - Position::Left, - Position::Right, - Position::Top, - Position::Bottom, - ] { - if let Some((current_c, _)) = self.current_client { - if current_c == c { - CGDisplay::show_cursor(&CGDisplay::main()) - .map_err(CaptureError::CoreGraphics)?; - self.current_client = None; - }; - } - if self.client_for_pos.get(&pos).copied() == Some(c) { - self.client_for_pos.remove(&pos); - } + ProducerEvent::Destroy(p) => { + if let Some(current) = self.current_pos { + if current == p { + CGDisplay::show_cursor(&CGDisplay::main()) + .map_err(CaptureError::CoreGraphics)?; + self.current_pos = None; + }; } + self.active_clients.remove(&p); } ProducerEvent::EventTapDisabled => return Err(CaptureError::EventTapDisabled), }; @@ -335,7 +324,7 @@ fn get_events( fn event_tap_thread( client_state: Arc>, - event_tx: Sender<(CaptureHandle, CaptureEvent)>, + event_tx: Sender<(Position, CaptureEvent)>, notify_tx: Sender, exit: tokio::sync::oneshot::Sender>, ) { @@ -364,7 +353,7 @@ fn event_tap_thread( |_proxy: CGEventTapProxy, event_type: CGEventType, cg_ev: &CGEvent| { log::trace!("Got event from tap: {event_type:?}"); let mut state = client_state.blocking_lock(); - let mut client = None; + let mut pos = None; let mut res_events = vec![]; if matches!( @@ -380,8 +369,8 @@ fn event_tap_thread( } // Are we in a client? - if let Some((current_client, _)) = state.current_client { - client = Some(current_client); + if let Some(current_pos) = state.current_pos { + pos = Some(current_pos); get_events(&event_type, cg_ev, &mut res_events).unwrap_or_else(|e| { log::error!("Failed to get events: {e}"); }); @@ -395,19 +384,19 @@ fn event_tap_thread( } // Did we cross a barrier? else if matches!(event_type, CGEventType::MouseMoved) { - if let Some((new_client, pos)) = state.crossed(cg_ev) { - client = Some(new_client); + if let Some(new_pos) = state.crossed(cg_ev) { + pos = Some(new_pos); res_events.push(CaptureEvent::Begin); notify_tx - .blocking_send(ProducerEvent::Grab((new_client, pos))) + .blocking_send(ProducerEvent::Grab(new_pos)) .expect("Failed to send notification"); } } - if let Some(client) = client { + if let Some(pos) = pos { res_events.iter().for_each(|e| { event_tx - .blocking_send((client, *e)) + .blocking_send((pos, *e)) .expect("Failed to send event"); }); // Returning None should stop the event from being processed @@ -434,7 +423,7 @@ fn event_tap_thread( } pub struct MacOSInputCapture { - event_rx: Receiver<(CaptureHandle, CaptureEvent)>, + event_rx: Receiver<(Position, CaptureEvent)>, notify_tx: Sender, } @@ -491,21 +480,21 @@ impl MacOSInputCapture { #[async_trait] impl Capture for MacOSInputCapture { - async fn create(&mut self, id: CaptureHandle, pos: Position) -> Result<(), CaptureError> { + async fn create(&mut self, pos: Position) -> Result<(), CaptureError> { let notify_tx = self.notify_tx.clone(); tokio::task::spawn_local(async move { - log::debug!("creating client {id}, {pos}"); - let _ = notify_tx.send(ProducerEvent::Create(id, pos)).await; + log::debug!("creating capture, {pos}"); + let _ = notify_tx.send(ProducerEvent::Create(pos)).await; log::debug!("done !"); }); Ok(()) } - async fn destroy(&mut self, id: CaptureHandle) -> Result<(), CaptureError> { + async fn destroy(&mut self, pos: Position) -> Result<(), CaptureError> { let notify_tx = self.notify_tx.clone(); tokio::task::spawn_local(async move { - log::debug!("destroying client {id}"); - let _ = notify_tx.send(ProducerEvent::Destroy(id)).await; + log::debug!("destroying capture {pos}"); + let _ = notify_tx.send(ProducerEvent::Destroy(pos)).await; log::debug!("done !"); }); Ok(()) @@ -526,7 +515,7 @@ impl Capture for MacOSInputCapture { } impl Stream for MacOSInputCapture { - type Item = Result<(CaptureHandle, CaptureEvent), CaptureError>; + type Item = Result<(Position, CaptureEvent), CaptureError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match ready!(self.event_rx.poll_recv(cx)) { diff --git a/input-capture/src/wayland.rs b/input-capture/src/wayland.rs index e990e3e..6f2e498 100644 --- a/input-capture/src/wayland.rs +++ b/input-capture/src/wayland.rs @@ -64,7 +64,7 @@ use crate::{CaptureError, CaptureEvent}; use super::{ error::{LayerShellCaptureCreationError, WaylandBindError}, - Capture, CaptureHandle, Position, + Capture, Position, }; struct Globals { @@ -102,13 +102,13 @@ struct State { pointer_lock: Option, rel_pointer: Option, shortcut_inhibitor: Option, - client_for_window: Vec<(Arc, CaptureHandle)>, - focused: Option<(Arc, CaptureHandle)>, + active_windows: Vec>, + focused: Option>, g: Globals, wayland_fd: RawFd, read_guard: Option, qh: QueueHandle, - pending_events: VecDeque<(CaptureHandle, CaptureEvent)>, + pending_events: VecDeque<(Position, CaptureEvent)>, output_info: Vec<(WlOutput, OutputInfo)>, scroll_discrete_pending: bool, } @@ -124,7 +124,7 @@ impl AsRawFd for Inner { } } -pub struct WaylandInputCapture(AsyncFd); +pub struct LayerShellInputCapture(AsyncFd); struct Window { buffer: wl_buffer::WlBuffer, @@ -256,7 +256,7 @@ fn draw(f: &mut File, (width, height): (u32, u32)) { } } -impl WaylandInputCapture { +impl LayerShellInputCapture { pub fn new() -> std::result::Result { let conn = Connection::connect_to_env()?; let (g, mut queue) = registry_queue_init::(&conn)?; @@ -323,7 +323,7 @@ impl WaylandInputCapture { pointer_lock: None, rel_pointer: None, shortcut_inhibitor: None, - client_for_window: Vec::new(), + active_windows: Vec::new(), focused: None, qh, wayland_fd, @@ -370,23 +370,18 @@ impl WaylandInputCapture { let inner = AsyncFd::new(Inner { queue, state })?; - Ok(WaylandInputCapture(inner)) + Ok(LayerShellInputCapture(inner)) } - fn add_client(&mut self, handle: CaptureHandle, pos: Position) { - self.0.get_mut().state.add_client(handle, pos); + fn add_client(&mut self, pos: Position) { + self.0.get_mut().state.add_client(pos); } - fn delete_client(&mut self, handle: CaptureHandle) { + fn delete_client(&mut self, pos: Position) { let inner = self.0.get_mut(); // remove all windows corresponding to this client - while let Some(i) = inner - .state - .client_for_window - .iter() - .position(|(_, c)| *c == handle) - { - inner.state.client_for_window.remove(i); + while let Some(i) = inner.state.active_windows.iter().position(|w| w.pos == pos) { + inner.state.active_windows.remove(i); inner.state.focused = None; } } @@ -400,7 +395,7 @@ impl State { serial: u32, qh: &QueueHandle, ) { - let (window, _) = self.focused.as_ref().unwrap(); + let window = self.focused.as_ref().unwrap(); // hide the cursor pointer.set_cursor(serial, None, 0, 0); @@ -443,7 +438,7 @@ impl State { fn ungrab(&mut self) { // get focused client - let (window, _client) = match self.focused.as_ref() { + let window = match self.focused.as_ref() { Some(focused) => focused, None => return, }; @@ -473,27 +468,23 @@ impl State { } } - fn add_client(&mut self, client: CaptureHandle, pos: Position) { + fn add_client(&mut self, pos: Position) { let outputs = get_output_configuration(self, pos); log::debug!("outputs: {outputs:?}"); outputs.iter().for_each(|(o, i)| { let window = Window::new(self, &self.qh, o, pos, i.size); let window = Arc::new(window); - self.client_for_window.push((window, client)); + self.active_windows.push(window); }); } fn update_windows(&mut self) { log::debug!("updating windows"); log::debug!("output info: {:?}", self.output_info); - let clients: Vec<_> = self - .client_for_window - .drain(..) - .map(|(w, c)| (c, w.pos)) - .collect(); - for (client, pos) in clients { - self.add_client(client, pos); + let clients: Vec<_> = self.active_windows.drain(..).map(|w| w.pos).collect(); + for pos in clients { + self.add_client(pos); } } } @@ -566,15 +557,15 @@ impl Inner { } #[async_trait] -impl Capture for WaylandInputCapture { - async fn create(&mut self, handle: CaptureHandle, pos: Position) -> Result<(), CaptureError> { - self.add_client(handle, pos); +impl Capture for LayerShellInputCapture { + async fn create(&mut self, pos: Position) -> Result<(), CaptureError> { + self.add_client(pos); let inner = self.0.get_mut(); Ok(inner.flush_events()?) } - async fn destroy(&mut self, handle: CaptureHandle) -> Result<(), CaptureError> { - self.delete_client(handle); + async fn destroy(&mut self, pos: Position) -> Result<(), CaptureError> { + self.delete_client(pos); let inner = self.0.get_mut(); Ok(inner.flush_events()?) } @@ -591,8 +582,8 @@ impl Capture for WaylandInputCapture { } } -impl Stream for WaylandInputCapture { - type Item = Result<(CaptureHandle, CaptureEvent), CaptureError>; +impl Stream for LayerShellInputCapture { + type Item = Result<(Position, CaptureEvent), CaptureError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if let Some(event) = self.0.get_mut().state.pending_events.pop_front() { @@ -685,23 +676,20 @@ impl Dispatch for State { } => { // get client corresponding to the focused surface { - if let Some((window, client)) = app - .client_for_window - .iter() - .find(|(w, _c)| w.surface == surface) - { - app.focused = Some((window.clone(), *client)); + if let Some(window) = app.active_windows.iter().find(|w| w.surface == surface) { + app.focused = Some(window.clone()); app.grab(&surface, pointer, serial, qh); } else { return; } } - let (_, client) = app - .client_for_window + let pos = app + .active_windows .iter() - .find(|(w, _c)| w.surface == surface) + .find(|w| w.surface == surface) + .map(|w| w.pos) .unwrap(); - app.pending_events.push_back((*client, CaptureEvent::Begin)); + app.pending_events.push_back((pos, CaptureEvent::Begin)); } wl_pointer::Event::Leave { .. } => { /* There are rare cases, where when a window is opened in @@ -722,9 +710,9 @@ impl Dispatch for State { button, state, } => { - let (_, client) = app.focused.as_ref().unwrap(); + let window = app.focused.as_ref().unwrap(); app.pending_events.push_back(( - *client, + window.pos, CaptureEvent::Input(Event::Pointer(PointerEvent::Button { time, button, @@ -733,7 +721,7 @@ impl Dispatch for State { )); } wl_pointer::Event::Axis { time, axis, value } => { - let (_, client) = app.focused.as_ref().unwrap(); + let window = app.focused.as_ref().unwrap(); if app.scroll_discrete_pending { // each axisvalue120 event is coupled with // a corresponding axis event, which needs to @@ -741,7 +729,7 @@ impl Dispatch for State { app.scroll_discrete_pending = false; } else { app.pending_events.push_back(( - *client, + window.pos, CaptureEvent::Input(Event::Pointer(PointerEvent::Axis { time, axis: u32::from(axis) as u8, @@ -751,10 +739,10 @@ impl Dispatch for State { } } wl_pointer::Event::AxisValue120 { axis, value120 } => { - let (_, client) = app.focused.as_ref().unwrap(); + let window = app.focused.as_ref().unwrap(); app.scroll_discrete_pending = true; app.pending_events.push_back(( - *client, + window.pos, CaptureEvent::Input(Event::Pointer(PointerEvent::AxisDiscrete120 { axis: u32::from(axis) as u8, value: value120, @@ -780,10 +768,7 @@ impl Dispatch for State { _: &Connection, _: &QueueHandle, ) { - let (_window, client) = match &app.focused { - Some(focused) => (Some(&focused.0), Some(&focused.1)), - None => (None, None), - }; + let window = &app.focused; match event { wl_keyboard::Event::Key { serial: _, @@ -791,9 +776,9 @@ impl Dispatch for State { key, state, } => { - if let Some(client) = client { + if let Some(window) = window { app.pending_events.push_back(( - *client, + window.pos, CaptureEvent::Input(Event::Keyboard(KeyboardEvent::Key { time, key, @@ -809,9 +794,9 @@ impl Dispatch for State { mods_locked, group, } => { - if let Some(client) = client { + if let Some(window) = window { app.pending_events.push_back(( - *client, + window.pos, CaptureEvent::Input(Event::Keyboard(KeyboardEvent::Modifiers { depressed: mods_depressed, latched: mods_latched, @@ -843,10 +828,10 @@ impl Dispatch for State { .. } = event { - if let Some((_window, client)) = &app.focused { + if let Some(window) = &app.focused { let time = (((utime_hi as u64) << 32 | utime_lo as u64) / 1000) as u32; app.pending_events.push_back(( - *client, + window.pos, CaptureEvent::Input(Event::Pointer(PointerEvent::Motion { time, dx, dy })), )); } @@ -864,10 +849,10 @@ impl Dispatch for State { _: &QueueHandle, ) { if let zwlr_layer_surface_v1::Event::Configure { serial, .. } = event { - if let Some((window, _client)) = app - .client_for_window + if let Some(window) = app + .active_windows .iter() - .find(|(w, _c)| &w.layer_surface == layer_surface) + .find(|w| &w.layer_surface == layer_surface) { // client corresponding to the layer_surface let surface = &window.surface; diff --git a/input-capture/src/windows.rs b/input-capture/src/windows.rs index 1d8f8fd..7517984 100644 --- a/input-capture/src/windows.rs +++ b/input-capture/src/windows.rs @@ -3,7 +3,7 @@ use core::task::{Context, Poll}; use futures::Stream; use once_cell::unsync::Lazy; -use std::collections::HashMap; +use std::collections::HashSet; use std::ptr::{addr_of, addr_of_mut}; use futures::executor::block_on; @@ -37,15 +37,15 @@ use input_event::{ Event, KeyboardEvent, PointerEvent, BTN_BACK, BTN_FORWARD, BTN_LEFT, BTN_MIDDLE, BTN_RIGHT, }; -use super::{Capture, CaptureError, CaptureEvent, CaptureHandle, Position}; +use super::{Capture, CaptureError, CaptureEvent, Position}; enum Request { - Create(CaptureHandle, Position), - Destroy(CaptureHandle), + Create(Position), + Destroy(Position), } pub struct WindowsInputCapture { - event_rx: Receiver<(CaptureHandle, CaptureEvent)>, + event_rx: Receiver<(Position, CaptureEvent)>, msg_thread: Option>, } @@ -65,22 +65,22 @@ unsafe fn signal_message_thread(event_type: EventType) { #[async_trait] impl Capture for WindowsInputCapture { - async fn create(&mut self, handle: CaptureHandle, pos: Position) -> Result<(), CaptureError> { + async fn create(&mut self, pos: Position) -> Result<(), CaptureError> { unsafe { { let mut requests = REQUEST_BUFFER.lock().unwrap(); - requests.push(Request::Create(handle, pos)); + requests.push(Request::Create(pos)); } signal_message_thread(EventType::Request); } Ok(()) } - async fn destroy(&mut self, handle: CaptureHandle) -> Result<(), CaptureError> { + async fn destroy(&mut self, pos: Position) -> Result<(), CaptureError> { unsafe { { let mut requests = REQUEST_BUFFER.lock().unwrap(); - requests.push(Request::Destroy(handle)); + requests.push(Request::Destroy(pos)); } signal_message_thread(EventType::Request); } @@ -98,9 +98,9 @@ impl Capture for WindowsInputCapture { } static mut REQUEST_BUFFER: Mutex> = Mutex::new(Vec::new()); -static mut ACTIVE_CLIENT: Option = None; -static mut CLIENT_FOR_POS: Lazy> = Lazy::new(HashMap::new); -static mut EVENT_TX: Option> = None; +static mut ACTIVE_CLIENT: Option = None; +static mut CLIENTS: Lazy> = Lazy::new(HashSet::new); +static mut EVENT_TX: Option> = None; static mut EVENT_THREAD_ID: AtomicU32 = AtomicU32::new(0); unsafe fn set_event_tid(tid: u32) { EVENT_THREAD_ID.store(tid, Ordering::SeqCst); @@ -281,12 +281,12 @@ unsafe fn check_client_activation(wparam: WPARAM, lparam: LPARAM) -> bool { }; /* check if a client is registered for the barrier */ - let Some(client) = CLIENT_FOR_POS.get(&pos) else { + if !CLIENTS.contains(&pos) { return ret; - }; + } /* update active client and entry point */ - ACTIVE_CLIENT.replace(*client); + ACTIVE_CLIENT.replace(pos); ENTRY_POINT = clamp_to_display_bounds(prev_pos, curr_pos); /* notify main thread */ @@ -305,7 +305,7 @@ unsafe extern "system" fn mouse_proc(ncode: i32, wparam: WPARAM, lparam: LPARAM) } /* get active client if any */ - let Some(client) = ACTIVE_CLIENT else { + let Some(pos) = ACTIVE_CLIENT else { return LRESULT(1); }; @@ -313,7 +313,7 @@ unsafe extern "system" fn mouse_proc(ncode: i32, wparam: WPARAM, lparam: LPARAM) let Some(pointer_event) = to_mouse_event(wparam, lparam) else { return LRESULT(1); }; - let event = (client, CaptureEvent::Input(Event::Pointer(pointer_event))); + let event = (pos, CaptureEvent::Input(Event::Pointer(pointer_event))); /* notify mainthread (drop events if sending too fast) */ if let Err(e) = EVENT_TX.as_ref().unwrap().try_send(event) { @@ -575,23 +575,16 @@ fn message_thread(ready_tx: mpsc::Sender<()>) { fn update_clients(request: Request) { match request { - Request::Create(handle, pos) => { - unsafe { CLIENT_FOR_POS.insert(pos, handle) }; + Request::Create(pos) => { + unsafe { CLIENTS.insert(pos) }; } - Request::Destroy(handle) => unsafe { - for pos in [ - Position::Left, - Position::Right, - Position::Top, - Position::Bottom, - ] { - if ACTIVE_CLIENT == Some(handle) { - ACTIVE_CLIENT.take(); - } - if CLIENT_FOR_POS.get(&pos).copied() == Some(handle) { - CLIENT_FOR_POS.remove(&pos); + Request::Destroy(pos) => unsafe { + if let Some(active_pos) = ACTIVE_CLIENT { + if pos == active_pos { + let _ = ACTIVE_CLIENT.take(); } } + CLIENTS.remove(&pos); }, } } @@ -614,7 +607,7 @@ impl WindowsInputCapture { } impl Stream for WindowsInputCapture { - type Item = Result<(CaptureHandle, CaptureEvent), CaptureError>; + type Item = Result<(Position, CaptureEvent), CaptureError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match ready!(self.event_rx.poll_recv(cx)) { None => Poll::Ready(None), diff --git a/input-capture/src/x11.rs b/input-capture/src/x11.rs index 8bcff94..7cf5b58 100644 --- a/input-capture/src/x11.rs +++ b/input-capture/src/x11.rs @@ -3,10 +3,7 @@ use std::task::Poll; use async_trait::async_trait; use futures_core::Stream; -use super::{ - error::X11InputCaptureCreationError, Capture, CaptureError, CaptureEvent, CaptureHandle, - Position, -}; +use super::{error::X11InputCaptureCreationError, Capture, CaptureError, CaptureEvent, Position}; pub struct X11InputCapture {} @@ -18,11 +15,11 @@ impl X11InputCapture { #[async_trait] impl Capture for X11InputCapture { - async fn create(&mut self, _id: CaptureHandle, _pos: Position) -> Result<(), CaptureError> { + async fn create(&mut self, _pos: Position) -> Result<(), CaptureError> { Ok(()) } - async fn destroy(&mut self, _id: CaptureHandle) -> Result<(), CaptureError> { + async fn destroy(&mut self, _pos: Position) -> Result<(), CaptureError> { Ok(()) } @@ -36,7 +33,7 @@ impl Capture for X11InputCapture { } impl Stream for X11InputCapture { - type Item = Result<(CaptureHandle, CaptureEvent), CaptureError>; + type Item = Result<(Position, CaptureEvent), CaptureError>; fn poll_next( self: std::pin::Pin<&mut Self>, diff --git a/src/capture_test.rs b/src/capture_test.rs index eadefa4..15a3019 100644 --- a/src/capture_test.rs +++ b/src/capture_test.rs @@ -11,6 +11,7 @@ pub async fn run(config: Config) -> Result<(), InputCaptureError> { let mut input_capture = InputCapture::new(backend).await?; log::info!("creating clients"); input_capture.create(0, Position::Left).await?; + input_capture.create(4, Position::Left).await?; input_capture.create(1, Position::Right).await?; input_capture.create(2, Position::Top).await?; input_capture.create(3, Position::Bottom).await?; @@ -28,12 +29,13 @@ async fn do_capture(input_capture: &mut InputCapture) -> Result<(), CaptureError .await .ok_or(CaptureError::EndOfStream)??; let pos = match client { - 0 => Position::Left, + 0 | 4 => Position::Left, 1 => Position::Right, 2 => Position::Top, - _ => Position::Bottom, + 3 => Position::Bottom, + _ => panic!(), }; - log::info!("position: {pos}, event: {event}"); + log::info!("position: {client} ({pos}), event: {event}"); if let CaptureEvent::Input(Event::Keyboard(KeyboardEvent::Key { key: 1, .. })) = event { input_capture.release().await?; break Ok(());