diff --git a/Cargo.lock b/Cargo.lock index 2254718..7a23c56 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1196,6 +1196,7 @@ version = "0.1.0" dependencies = [ "anyhow", "ashpd", + "async-trait", "core-graphics", "futures", "futures-core", diff --git a/input-capture/Cargo.toml b/input-capture/Cargo.toml index eadb004..2cd1f33 100644 --- a/input-capture/Cargo.toml +++ b/input-capture/Cargo.toml @@ -17,6 +17,7 @@ tempfile = "3.8" thiserror = "1.0.61" tokio = { version = "1.32.0", features = ["io-util", "io-std", "macros", "net", "process", "rt", "sync", "signal"] } once_cell = "1.19.0" +async-trait = "0.1.81" [target.'cfg(all(unix, not(target_os="macos")))'.dependencies] diff --git a/input-capture/src/dummy.rs b/input-capture/src/dummy.rs index 42a119c..12b1463 100644 --- a/input-capture/src/dummy.rs +++ b/input-capture/src/dummy.rs @@ -2,6 +2,7 @@ use std::io; use std::pin::Pin; use std::task::{Context, Poll}; +use async_trait::async_trait; use futures_core::Stream; use input_event::Event; @@ -24,16 +25,21 @@ impl Default for DummyInputCapture { } } -impl InputCapture for DummyInputCapture { - fn create(&mut self, _handle: CaptureHandle, _pos: Position) -> io::Result<()> { +#[async_trait] +impl<'a> InputCapture for DummyInputCapture { + async fn create(&mut self, _handle: CaptureHandle, _pos: Position) -> io::Result<()> { Ok(()) } - fn destroy(&mut self, _handle: CaptureHandle) -> io::Result<()> { + async fn destroy(&mut self, _handle: CaptureHandle) -> io::Result<()> { Ok(()) } - fn release(&mut self) -> io::Result<()> { + async fn release(&mut self) -> io::Result<()> { + Ok(()) + } + + async fn async_drop(&mut self) -> Result<(), CaptureError> { Ok(()) } } diff --git a/input-capture/src/lib.rs b/input-capture/src/lib.rs index c4a64a2..c5380e6 100644 --- a/input-capture/src/lib.rs +++ b/input-capture/src/lib.rs @@ -1,5 +1,6 @@ use std::{fmt::Display, io}; +use async_trait::async_trait; use futures_core::Stream; use input_event::Event; @@ -92,17 +93,21 @@ impl Display for Backend { } } +#[async_trait] pub trait InputCapture: Stream> + Unpin { /// create a new client with the given id - fn create(&mut self, id: CaptureHandle, pos: Position) -> io::Result<()>; + async fn create(&mut self, id: CaptureHandle, pos: Position) -> io::Result<()>; /// destroy the client with the given id, if it exists - fn destroy(&mut self, id: CaptureHandle) -> io::Result<()>; + async fn destroy(&mut self, id: CaptureHandle) -> io::Result<()>; /// release mouse - fn release(&mut self) -> io::Result<()>; + async fn release(&mut self) -> io::Result<()>; + + /// destroy the input acpture + async fn async_drop(&mut self) -> Result<(), CaptureError>; } pub async fn create_backend( diff --git a/input-capture/src/libei.rs b/input-capture/src/libei.rs index 0aa5525..688def1 100644 --- a/input-capture/src/libei.rs +++ b/input-capture/src/libei.rs @@ -5,6 +5,7 @@ use ashpd::{ }, enumflags2::BitFlags, }; +use async_trait::async_trait; use futures::{FutureExt, StreamExt}; use reis::{ ei::{self, keyboard::KeyState}, @@ -17,7 +18,7 @@ use std::{ collections::HashMap, io, os::unix::net::UnixStream, - pin::{pin, Pin}, + pin::Pin, rc::Rc, task::{Context, Poll}, }; @@ -37,19 +38,30 @@ use super::{ error::LibeiCaptureCreationError, CaptureHandle, InputCapture as LanMouseInputCapture, Position, }; -#[derive(Debug)] -enum ProducerEvent { - Release, +struct PoisonPill; + +/* there is a bug in xdg-remote-desktop-portal-gnome / mutter that + * prevents receiving further events after a session has been disabled once. + * Therefore the session needs to recreated when the barriers are updated */ + +/// events that necessitate restarting the capture session +#[derive(Clone, Copy, Debug)] +enum CaptureEvent { Create(CaptureHandle, Position), Destroy(CaptureHandle), } +/// events that do not necessitate restarting the capture session +#[derive(Clone, Copy, Debug)] +struct ReleaseCaptureEvent; + #[allow(dead_code)] pub struct LibeiInputCapture<'a> { input_capture: Pin>>, - libei_task: JoinHandle>, - event_rx: tokio::sync::mpsc::Receiver<(CaptureHandle, Event)>, - notify_tx: tokio::sync::mpsc::Sender, + capture_task: JoinHandle>, + event_rx: Option>, + notify_capture: tokio::sync::mpsc::Sender, + notify_capture_session: tokio::sync::mpsc::Sender, } static INTERFACES: Lazy> = Lazy::new(|| { @@ -125,12 +137,6 @@ async fn update_barriers( Ok(id_map) } -impl<'a> Drop for LibeiInputCapture<'a> { - fn drop(&mut self) { - self.libei_task.abort(); - } -} - async fn create_session<'a>( input_capture: &'a InputCapture<'a>, ) -> std::result::Result<(Session<'a>, BitFlags), ashpd::Error> { @@ -182,6 +188,7 @@ async fn libei_event_handler( mut ei_event_stream: EiConvertEventStream, context: ei::Context, event_tx: Sender<(CaptureHandle, Event)>, + capture_tx: Sender, current_client: Rc>>, ) -> Result<(), CaptureError> { loop { @@ -192,26 +199,14 @@ async fn libei_event_handler( .map_err(ReisConvertEventStreamError::from)?; log::trace!("from ei: {ei_event:?}"); let client = current_client.get(); - if !handle_ei_event(ei_event, client, &context, &event_tx).await? { - /* close requested */ + handle_ei_event(ei_event, client, &context, &event_tx, &capture_tx).await?; + if event_tx.is_closed() { + log::info!("event_tx closed -> exiting"); break Ok(()); } } } -async fn wait_for_active_client( - notify_rx: &mut Receiver, - active_clients: &mut Vec<(CaptureHandle, Position)>, -) { - // wait for a client update - while let Some(producer_event) = notify_rx.recv().await { - if let ProducerEvent::Create(c, p) = producer_event { - handle_producer_event(ProducerEvent::Create(c, p), active_clients); - break; - } - } -} - impl<'a> LibeiInputCapture<'a> { pub async fn new() -> std::result::Result { let input_capture = Box::pin(InputCapture::new().await?); @@ -219,15 +214,24 @@ impl<'a> LibeiInputCapture<'a> { let first_session = Some(create_session(unsafe { &*input_capture_ptr }).await?); let (event_tx, event_rx) = tokio::sync::mpsc::channel(32); - let (notify_tx, notify_rx) = tokio::sync::mpsc::channel(32); - let capture = do_capture(input_capture_ptr, notify_rx, first_session, event_tx); - let libei_task = tokio::task::spawn_local(capture); + let (notify_capture, notify_rx) = tokio::sync::mpsc::channel(32); + let (notify_capture_session, notify_session_rx) = tokio::sync::mpsc::channel(32); + let capture = do_capture( + input_capture_ptr, + notify_rx, + notify_session_rx, + first_session, + event_tx, + ); + let capture_task = tokio::task::spawn_local(capture); + let event_rx = Some(event_rx); let producer = Self { input_capture, event_rx, - libei_task, - notify_tx, + capture_task, + notify_capture, + notify_capture_session, }; Ok(producer) @@ -236,64 +240,138 @@ impl<'a> LibeiInputCapture<'a> { async fn do_capture<'a>( input_capture_ptr: *const InputCapture<'static>, - mut notify_rx: Receiver, - mut first_session: Option<(Session<'a>, BitFlags)>, + mut capture_event: Receiver, + mut release_capture_channel: Receiver, + session: Option<(Session<'a>, BitFlags)>, event_tx: Sender<(CaptureHandle, Event)>, ) -> Result<(), CaptureError> { + let mut session = session.map(|s| s.0); + /* safety: libei_task does not outlive Self */ let input_capture = unsafe { &*input_capture_ptr }; let mut active_clients: Vec<(CaptureHandle, Position)> = vec![]; let mut next_barrier_id = 1u32; - /* there is a bug in xdg-remote-desktop-portal-gnome / mutter that - * prevents receiving further events after a session has been disabled once. - * Therefore the session needs to recreated when the barriers are updated */ + let mut zones_changed = input_capture.receive_zones_changed().await?; loop { - // otherwise it asks to capture input even with no active clients - if active_clients.is_empty() { - wait_for_active_client(&mut notify_rx, &mut active_clients).await; - if notify_rx.is_closed() { - break; - } - continue; - } - - let current_client = Rc::new(Cell::new(None)); - // create session - let (session, _) = match first_session.take() { + let mut session = match session.take() { Some(s) => s, - _ => create_session(input_capture).await?, + None => create_session(input_capture).await?.0, }; - // connect to eis server - let (context, ei_event_stream) = connect_to_eis(input_capture, &session).await?; + // do capture session + let (session_kill, session_kill_rx) = tokio::sync::mpsc::channel(1); + let (kill_update_tx, mut kill_update) = tokio::sync::mpsc::channel(1); - // async event task - let mut ei_task = pin!(libei_event_handler( - ei_event_stream, - context, - event_tx.clone(), - current_client.clone(), - )); - - let mut activated = input_capture.receive_activated().await?; - let mut zones_changed = input_capture.receive_zones_changed().await?; - - // set barriers - let client_for_barrier_id = update_barriers( + let capture_session = do_capture_session( input_capture, - &session, - &active_clients, + &mut session, + &event_tx, + &mut active_clients, &mut next_barrier_id, - ) - .await?; + &mut release_capture_channel, + session_kill_rx, + session_kill.clone(), + kill_update_tx, + ); - log::debug!("enabling session"); - input_capture.enable(&session).await?; + let mut capture_event_occured: Option = None; + let mut zones_have_changed = false; + // kill session if clients need to be updated + let handle_session_update_request = async { + tokio::select! { + /* session exited */ + _ = kill_update.recv() => {}, + /* zones have changed */ + _ = zones_changed.next() => zones_have_changed = true, + /* clients changed */ + e = capture_event.recv() => if let Some(e) = e { + capture_event_occured.replace(e); + }, + } + // kill session (might already be dead!) + let _ = session_kill.send(PoisonPill).await; + }; + + let (a, _) = tokio::join!(capture_session, handle_session_update_request); + if let Err(e) = a { + log::warn!("{e}"); + } + log::info!("capture session + session_update task done!"); + + log::info!("no error occured"); + + if let Some(event) = capture_event_occured.take() { + match event { + CaptureEvent::Create(c, p) => active_clients.push((c, p)), + CaptureEvent::Destroy(c) => active_clients.retain(|(h, _)| *h != c), + } + } + + // disable capture + log::info!("disabling input capture"); + input_capture.disable(&session).await?; + + // break + if event_tx.is_closed() { + break Ok(()); + } + } +} + +async fn do_capture_session( + input_capture: &InputCapture<'_>, + session: &mut Session<'_>, + event_tx: &Sender<(CaptureHandle, Event)>, + active_clients: &mut Vec<(CaptureHandle, Position)>, + next_barrier_id: &mut u32, + capture_session_event: &mut Receiver, + mut kill_session: Receiver, + kill_session_tx: Sender, + kill_update_tx: Sender, +) -> Result<(), CaptureError> { + // current client + let current_client = Rc::new(Cell::new(None)); + + // connect to eis server + let (context, ei_event_stream) = connect_to_eis(input_capture, session).await?; + + // set barriers + let client_for_barrier_id = + update_barriers(input_capture, session, &active_clients, next_barrier_id).await?; + + log::debug!("enabling session"); + input_capture.enable(session).await?; + + // async event task + let (kill_ei, mut kill_ei_rx) = tokio::sync::mpsc::channel(1); + let (break_tx, mut break_rx) = tokio::sync::mpsc::channel(1); + let event_chan = event_tx.clone(); + let client = current_client.clone(); + let ei_task = async move { + tokio::select! { + r = libei_event_handler( + ei_event_stream, + context, + event_chan, + break_tx, + client, + ) => { + log::info!("libei exited: {r:?} killing session task"); + let _ = kill_session_tx.send(PoisonPill).await; + } + _ = kill_ei_rx.recv() => {}, + } + Ok::<(), CaptureError>(()) + }; + + let capture_session_task = async { + // receiver for activation tokens + let mut activated = input_capture.receive_activated().await?; loop { tokio::select! { activated = activated.next() => { @@ -310,46 +388,32 @@ async fn do_capture<'a>( }; tokio::select! { - producer_event = notify_rx.recv() => { - let producer_event = producer_event.expect("channel closed"); - if handle_producer_event(producer_event, &mut active_clients) { - break; /* clients updated */ - } - } - zones_changed = zones_changed.next() => { - log::debug!("zones changed: {zones_changed:?}"); - break; - } - res = &mut ei_task => { - /* propagate errors to toplevel task */ - res?; - } + _ = capture_session_event.recv() => {}, /* capture release */ + _ = break_rx.recv() => {}, /* ei notifying that it needs to restart */ + _ = kill_session.recv() => break, /* kill session notify */ } - release_capture( - input_capture, - &session, - activated, - client, - &active_clients, - ).await?; - } - producer_event = notify_rx.recv() => { - let producer_event = producer_event.expect("channel closed"); - if handle_producer_event(producer_event, &mut active_clients) { - /* clients updated */ - break; - } - }, - res = &mut ei_task => { - res?; + + release_capture(input_capture, session, activated, client, &active_clients).await?; } + _ = break_rx.recv() => {}, /* ei notifying that it needs to restart */ + _ = capture_session_event.recv() => {}, /* capture release -> we are not capturing anyway, so ignore */ + _ = kill_session.recv() => break, /* kill session notify */ } } - input_capture.disable(&session).await?; - if event_tx.is_closed() { - break; - } - } + // kill libei task + log::info!("session exited: killing libei task"); + let _ = kill_ei.send(()).await; + Ok::<(), CaptureError>(()) + }; + + let (a, b) = tokio::join!(ei_task, capture_session_task); + + let _ = kill_update_tx.send(PoisonPill).await; + + log::info!("both session and ei task finished!"); + a?; + b?; + Ok(()) } @@ -383,30 +447,15 @@ async fn release_capture( Ok(()) } -fn handle_producer_event( - producer_event: ProducerEvent, - active_clients: &mut Vec<(CaptureHandle, Position)>, -) -> bool { - log::debug!("handling event: {producer_event:?}"); - match producer_event { - ProducerEvent::Release => false, - ProducerEvent::Create(c, p) => { - active_clients.push((c, p)); - true - } - ProducerEvent::Destroy(c) => { - active_clients.retain(|(h, _)| *h != c); - true - } - } -} +struct DeviceUpdate; async fn handle_ei_event( ei_event: EiEvent, current_client: Option, context: &ei::Context, event_tx: &Sender<(CaptureHandle, Event)>, -) -> Result { + capture_tx: &Sender, +) -> Result<(), CaptureError> { match ei_event { EiEvent::SeatAdded(s) => { s.seat.bind_capabilities(&[ @@ -419,11 +468,79 @@ async fn handle_ei_event( ]); context.flush().map_err(|e| io::Error::new(e.kind(), e))?; } - EiEvent::SeatRemoved(_) => {} - EiEvent::DeviceAdded(_) => {} - EiEvent::DeviceRemoved(_) => {} - EiEvent::DevicePaused(_) => {} - EiEvent::DeviceResumed(_) => {} + EiEvent::SeatRemoved(_) + | EiEvent::DeviceAdded(_) + | EiEvent::DeviceRemoved(_) + | EiEvent::DevicePaused(_) + | EiEvent::DeviceResumed(_) => { + if capture_tx.send(DeviceUpdate).await.is_err() { + return Ok(()); + } + } + EiEvent::DeviceStartEmulating(_) => log::debug!("START EMULATING"), + EiEvent::DeviceStopEmulating(_) => log::debug!("STOP EMULATING"), + EiEvent::Disconnected(d) => { + return Err(CaptureError::Disconnected(format!("{:?}", d.reason))) + } + _ => { + if let Some(handle) = current_client { + let events = to_input_events(ei_event); + for event in events.into_iter() { + if event_tx.send((handle, event)).await.is_err() { + return Ok(()); + }; + } + } + } + } + Ok(()) +} + +/* not pretty but saves a heap allocation */ +enum Events { + None, + One(Event), + Two(Event, Event), +} + +impl Events { + fn into_iter(self) -> impl Iterator { + EventIterator::new(self) + } +} + +struct EventIterator { + events: [Option; 2], + pos: usize, +} + +impl EventIterator { + fn new(events: Events) -> Self { + let events = match events { + Events::None => [None, None], + Events::One(e) => [Some(e), None], + Events::Two(e, f) => [Some(e), Some(f)], + }; + Self { events, pos: 0 } + } +} + +impl Iterator for EventIterator { + type Item = Event; + + fn next(&mut self) -> Option { + let res = if self.pos >= self.events.len() { + None + } else { + self.events[self.pos] + }; + self.pos += 1; + res + } +} + +fn to_input_events(ei_event: EiEvent) -> Events { + match ei_event { EiEvent::KeyboardModifiers(mods) => { let modifier_event = KeyboardEvent::Modifiers { mods_depressed: mods.depressed, @@ -431,36 +548,18 @@ async fn handle_ei_event( mods_locked: mods.locked, group: mods.group, }; - if let Some(current_client) = current_client { - if event_tx - .send((current_client, Event::Keyboard(modifier_event))) - .await.is_err() { - return Ok(false); - } - } - } - EiEvent::Frame(_) => {} - EiEvent::DeviceStartEmulating(_) => { - log::debug!("START EMULATING =============>"); - } - EiEvent::DeviceStopEmulating(_) => { - log::debug!("==================> STOP EMULATING"); + Events::One(Event::Keyboard(modifier_event)) } + EiEvent::Frame(_) => Events::None, /* FIXME */ EiEvent::PointerMotion(motion) => { let motion_event = PointerEvent::Motion { time: motion.time as u32, dx: motion.dx as f64, dy: motion.dy as f64, }; - if let Some(current_client) = current_client { - if event_tx - .send((current_client, Event::Pointer(motion_event))) - .await.is_err() { - return Ok(false); - } - } + Events::One(Event::Pointer(motion_event)) } - EiEvent::PointerMotionAbsolute(_) => {} + EiEvent::PointerMotionAbsolute(_) => Events::None, EiEvent::Button(button) => { let button_event = PointerEvent::Button { time: button.time as u32, @@ -470,69 +569,49 @@ async fn handle_ei_event( ButtonState::Press => 1, }, }; - if let Some(current_client) = current_client { - if event_tx - .send((current_client, Event::Pointer(button_event))) - .await.is_err() { - return Ok(false); - } - } + Events::One(Event::Pointer(button_event)) } EiEvent::ScrollDelta(delta) => { - if let Some(handle) = current_client { - let mut events = vec![]; - if delta.dy != 0. { - events.push(PointerEvent::Axis { - time: 0, - axis: 0, - value: delta.dy as f64, - }); - } - if delta.dx != 0. { - events.push(PointerEvent::Axis { - time: 0, - axis: 1, - value: delta.dx as f64, - }); - } - for event in events { - if event_tx - .send((handle, Event::Pointer(event))) - .await.is_err() { - return Ok(false); - } - } + let dy = Event::Pointer(PointerEvent::Axis { + time: 0, + axis: 0, + value: delta.dy as f64, + }); + let dx = Event::Pointer(PointerEvent::Axis { + time: 0, + axis: 1, + value: delta.dx as f64, + }); + if delta.dy != 0. && delta.dx != 0. { + Events::Two(dy, dx) + } else if delta.dy != 0. { + Events::One(dy) + } else if delta.dx != 0. { + Events::One(dx) + } else { + Events::None } } - EiEvent::ScrollStop(_) => {} - EiEvent::ScrollCancel(_) => {} + EiEvent::ScrollStop(_) => Events::None, /* TODO */ + EiEvent::ScrollCancel(_) => Events::None, /* TODO */ EiEvent::ScrollDiscrete(scroll) => { - if scroll.discrete_dy != 0 { - let event = PointerEvent::AxisDiscrete120 { - axis: 0, - value: scroll.discrete_dy, - }; - if let Some(current_client) = current_client { - if event_tx - .send((current_client, Event::Pointer(event))) - .await.is_err() { - return Ok(false); - } - } + let dy = Event::Pointer(PointerEvent::AxisDiscrete120 { + axis: 0, + value: scroll.discrete_dy, + }); + let dx = Event::Pointer(PointerEvent::AxisDiscrete120 { + axis: 1, + value: scroll.discrete_dx, + }); + if scroll.discrete_dy != 0 && scroll.discrete_dx != 0 { + Events::Two(dy, dx) + } else if scroll.discrete_dy != 0 { + Events::One(dy) + } else if scroll.discrete_dx != 0 { + Events::One(dx) + } else { + Events::None } - if scroll.discrete_dx != 0 { - let event = PointerEvent::AxisDiscrete120 { - axis: 1, - value: scroll.discrete_dx, - }; - if let Some(current_client) = current_client { - if event_tx - .send((current_client, Event::Pointer(event))) - .await.is_err() { - return Ok(false); - } - } - }; } EiEvent::KeyboardKey(key) => { let key_event = KeyboardEvent::Key { @@ -543,58 +622,61 @@ async fn handle_ei_event( }, time: key.time as u32, }; - if let Some(current_client) = current_client { - if event_tx - .send((current_client, Event::Keyboard(key_event))) - .await.is_err() { - return Ok(false); - } - } + Events::One(Event::Keyboard(key_event)) } - EiEvent::TouchDown(_) => {} - EiEvent::TouchUp(_) => {} - EiEvent::TouchMotion(_) => {} - EiEvent::Disconnected(d) => return Err(CaptureError::Disconnected(format!("{:?}", d.reason))), + EiEvent::TouchDown(_) => Events::None, /* TODO */ + EiEvent::TouchUp(_) => Events::None, /* TODO */ + EiEvent::TouchMotion(_) => Events::None, /* TODO */ + _ => Events::None, } - Ok(true) } +#[async_trait] impl<'a> LanMouseInputCapture for LibeiInputCapture<'a> { - fn create(&mut self, handle: super::CaptureHandle, pos: super::Position) -> io::Result<()> { - let notify_tx = self.notify_tx.clone(); - tokio::task::spawn_local(async move { - let _ = notify_tx.send(ProducerEvent::Create(handle, pos)).await; - }); + async fn create(&mut self, handle: CaptureHandle, pos: Position) -> io::Result<()> { + let _ = self + .notify_capture + .send(CaptureEvent::Create(handle, pos)) + .await; Ok(()) } - fn destroy(&mut self, handle: super::CaptureHandle) -> io::Result<()> { - let notify_tx = self.notify_tx.clone(); - tokio::task::spawn_local(async move { - let _ = notify_tx.send(ProducerEvent::Destroy(handle)).await; - }); + async fn destroy(&mut self, handle: CaptureHandle) -> io::Result<()> { + let _ = self + .notify_capture + .send(CaptureEvent::Destroy(handle)) + .await; Ok(()) } - fn release(&mut self) -> io::Result<()> { - let notify_tx = self.notify_tx.clone(); - tokio::task::spawn_local(async move { - let _ = notify_tx.send(ProducerEvent::Release).await; - }); + async fn release(&mut self) -> io::Result<()> { + let _ = self.notify_capture_session.send(ReleaseCaptureEvent).await; Ok(()) } + + async fn async_drop(&mut self) -> Result<(), CaptureError> { + let event_rx = self.event_rx.take().expect("no channel"); + std::mem::drop(event_rx); + let task = &mut self.capture_task; + task.await.expect("libei task panic") + } } impl<'a> Stream for LibeiInputCapture<'a> { type Item = Result<(CaptureHandle, Event), CaptureError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.libei_task.poll_unpin(cx) { + match self.capture_task.poll_unpin(cx) { Poll::Ready(r) => match r.expect("failed to join") { Ok(()) => Poll::Ready(None), Err(e) => Poll::Ready(Some(Err(e))), }, - Poll::Pending => self.event_rx.poll_recv(cx).map(|e| e.map(Result::Ok)), + Poll::Pending => self + .event_rx + .as_mut() + .expect("no channel") + .poll_recv(cx) + .map(|e| e.map(Result::Ok)), } } } diff --git a/input-capture/src/wayland.rs b/input-capture/src/wayland.rs index efb65fb..f8e83ea 100644 --- a/input-capture/src/wayland.rs +++ b/input-capture/src/wayland.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use futures_core::Stream; use memmap::MmapOptions; use std::{ @@ -14,7 +15,7 @@ use std::{ fs::File, io::{BufWriter, Write}, os::unix::prelude::{AsRawFd, FromRawFd}, - rc::Rc, + sync::Arc, }; use wayland_protocols::{ @@ -104,8 +105,8 @@ struct State { pointer_lock: Option, rel_pointer: Option, shortcut_inhibitor: Option, - client_for_window: Vec<(Rc, CaptureHandle)>, - focused: Option<(Rc, CaptureHandle)>, + client_for_window: Vec<(Arc, CaptureHandle)>, + focused: Option<(Arc, CaptureHandle)>, g: Globals, wayland_fd: OwnedFd, read_guard: Option, @@ -477,7 +478,7 @@ impl State { log::debug!("outputs: {outputs:?}"); outputs.iter().for_each(|(o, i)| { let window = Window::new(self, &self.qh, o, pos, i.size); - let window = Rc::new(window); + let window = Arc::new(window); self.client_for_window.push((window, client)); }); } @@ -563,24 +564,30 @@ impl Inner { } } +#[async_trait] impl InputCapture for WaylandInputCapture { - fn create(&mut self, handle: CaptureHandle, pos: Position) -> io::Result<()> { + async fn create(&mut self, handle: CaptureHandle, pos: Position) -> io::Result<()> { self.add_client(handle, pos); let inner = self.0.get_mut(); inner.flush_events() } - fn destroy(&mut self, handle: CaptureHandle) -> io::Result<()> { + + async fn destroy(&mut self, handle: CaptureHandle) -> io::Result<()> { self.delete_client(handle); let inner = self.0.get_mut(); inner.flush_events() } - fn release(&mut self) -> io::Result<()> { + async fn release(&mut self) -> io::Result<()> { log::debug!("releasing pointer"); let inner = self.0.get_mut(); inner.state.ungrab(); inner.flush_events() } + + async fn async_drop(&mut self) -> Result<(), CaptureError> { + Ok(()) + } } impl Stream for WaylandInputCapture { diff --git a/input-capture/src/x11.rs b/input-capture/src/x11.rs index fcd1624..02dcfe4 100644 --- a/input-capture/src/x11.rs +++ b/input-capture/src/x11.rs @@ -1,6 +1,7 @@ use std::io; use std::task::Poll; +use async_trait::async_trait; use futures_core::Stream; use crate::CaptureError; @@ -19,16 +20,21 @@ impl X11InputCapture { } } +#[async_trait] impl InputCapture for X11InputCapture { - fn create(&mut self, _id: CaptureHandle, _pos: Position) -> io::Result<()> { + async fn create(&mut self, _id: CaptureHandle, _pos: Position) -> io::Result<()> { Ok(()) } - fn destroy(&mut self, _id: CaptureHandle) -> io::Result<()> { + async fn destroy(&mut self, _id: CaptureHandle) -> io::Result<()> { Ok(()) } - fn release(&mut self) -> io::Result<()> { + async fn release(&mut self) -> io::Result<()> { + Ok(()) + } + + async fn async_drop(&mut self) -> Result<(), CaptureError> { Ok(()) } } diff --git a/src/capture_test.rs b/src/capture_test.rs index 68b94c5..a809809 100644 --- a/src/capture_test.rs +++ b/src/capture_test.rs @@ -20,26 +20,49 @@ pub fn run() -> Result<()> { async fn input_capture_test(config: Config) -> Result<()> { log::info!("creating input capture"); let backend = config.capture_backend.map(|b| b.into()); - let mut input_capture = input_capture::create(backend).await?; - log::info!("creating clients"); - input_capture.create(0, Position::Left)?; - input_capture.create(1, Position::Right)?; - input_capture.create(2, Position::Top)?; - input_capture.create(3, Position::Bottom)?; - loop { - let (client, event) = input_capture - .next() - .await - .ok_or(anyhow!("capture stream closed"))??; - let pos = match client { - 0 => Position::Left, - 1 => Position::Right, - 2 => Position::Top, - _ => Position::Bottom, - }; - log::info!("position: {pos}, event: {event}"); - if let Event::Keyboard(KeyboardEvent::Key { key: 1, .. }) = event { - input_capture.release()?; + for _ in 0..2 { + let mut input_capture = Some(input_capture::create(backend).await?); + log::info!("creating clients"); + input_capture + .as_mut() + .unwrap() + .create(0, Position::Left) + .await?; + input_capture + .as_mut() + .unwrap() + .create(1, Position::Right) + .await?; + input_capture + .as_mut() + .unwrap() + .create(2, Position::Top) + .await?; + input_capture + .as_mut() + .unwrap() + .create(3, Position::Bottom) + .await?; + loop { + let (client, event) = input_capture + .as_mut() + .unwrap() + .next() + .await + .ok_or(anyhow!("capture stream closed"))??; + let pos = match client { + 0 => Position::Left, + 1 => Position::Right, + 2 => Position::Top, + _ => Position::Bottom, + }; + log::info!("position: {pos}, event: {event}"); + if let Event::Keyboard(KeyboardEvent::Key { key: 1, .. }) = event { + // input_capture.as_mut().unwrap().release()?; + break; + } } + input_capture.take().unwrap().async_drop().await.unwrap(); } + Ok(()) } diff --git a/src/server/capture_task.rs b/src/server/capture_task.rs index 6044a14..ce50a59 100644 --- a/src/server/capture_task.rs +++ b/src/server/capture_task.rs @@ -52,16 +52,16 @@ pub fn new( match e { Some(e) => match e { CaptureEvent::Release => { - capture.release()?; + capture.release().await?; server.state.replace(State::Receiving); } - CaptureEvent::Create(h, p) => capture.create(h, p)?, - CaptureEvent::Destroy(h) => capture.destroy(h)?, + CaptureEvent::Create(h, p) => capture.create(h, p).await?, + CaptureEvent::Destroy(h) => capture.destroy(h).await?, CaptureEvent::Restart => { let clients = server.client_manager.borrow().get_client_states().map(|(h, (c,_))| (h, c.pos)).collect::>(); capture = input_capture::create(backend).await?; for (handle, pos) in clients { - capture.create(handle, pos.into())?; + capture.create(handle, pos.into()).await?; } } CaptureEvent::Terminate => break, @@ -104,7 +104,7 @@ async fn handle_capture_event( if release_bind.iter().all(|k| pressed_keys.contains(k)) { pressed_keys.clear(); log::info!("releasing pointer"); - capture.release()?; + capture.release().await?; server.state.replace(State::Receiving); log::trace!("STATE ===> Receiving"); // send an event to release all the modifiers @@ -123,7 +123,7 @@ async fn handle_capture_event( None => { // should not happen log::warn!("unknown client!"); - capture.release()?; + capture.release().await?; server.state.replace(State::Receiving); log::trace!("STATE ===> Receiving"); return Ok(());