diff --git a/src/connect.rs b/src/connect.rs index dd01901..521aaa7 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -1,9 +1,19 @@ use crate::server::Server; use lan_mouse_ipc::{ClientHandle, DEFAULT_PORT}; use lan_mouse_proto::{ProtoEvent, MAX_EVENT_SIZE}; -use std::{collections::HashMap, io, net::SocketAddr, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + io, + net::SocketAddr, + rc::Rc, + sync::Arc, +}; use thiserror::Error; -use tokio::{net::UdpSocket, task::JoinSet}; +use tokio::{ + net::UdpSocket, + sync::Mutex, + task::{spawn_local, JoinSet}, +}; use webrtc_dtls::{ config::{Config, ExtendedMasterSecretType}, conn::DTLSConn, @@ -21,6 +31,8 @@ pub(crate) enum LanMouseConnectionError { Webrtc(#[from] webrtc_util::Error), #[error("no ips associated with the client")] NoIps, + #[error("not connected")] + NotConnected, } async fn connect( @@ -54,7 +66,8 @@ async fn connect_any( pub(crate) struct LanMouseConnection { server: Server, - conns: HashMap>, + conns: Rc>>>, + connecting: Rc>>, } impl LanMouseConnection { @@ -62,6 +75,7 @@ impl LanMouseConnection { Self { server, conns: Default::default(), + connecting: Default::default(), } } @@ -73,24 +87,45 @@ impl LanMouseConnection { let (buf, len): ([u8; MAX_EVENT_SIZE], usize) = event.into(); let buf = &buf[..len]; if let Some(addr) = self.server.active_addr(handle) { - if let Some(conn) = self.conns.get(&addr) { - if let Ok(_) = conn.send(buf).await { - return Ok(()); + if let Some(conn) = self.conns.lock().await.get(&addr) { + match conn.send(buf).await { + Ok(_) => return Ok(()), + Err(e) => { + log::warn!("failed to connect: {e}"); + } } } } - // sending did not work, figure out active conn. - if let Some(addrs) = self.server.get_ips(handle) { - let port = self.server.get_port(handle).unwrap_or(DEFAULT_PORT); - let addrs = addrs - .into_iter() - .map(|a| SocketAddr::new(a, port)) - .collect::>(); - let (conn, addr) = connect_any(&addrs).await?; - self.server.set_active_addr(handle, addr); - conn.send(buf).await?; - return Ok(()); + + // check if we are already trying to connect + { + let mut connecting = self.connecting.lock().await; + if connecting.contains(&handle) { + return Err(LanMouseConnectionError::NotConnected); + } else { + connecting.insert(handle); + } } - Err(LanMouseConnectionError::NoIps) + let server = self.server.clone(); + let conns = self.conns.clone(); + let connecting = self.connecting.clone(); + + // connect in the background + spawn_local(async move { + // sending did not work, figure out active conn. + if let Some(addrs) = server.get_ips(handle) { + let port = server.get_port(handle).unwrap_or(DEFAULT_PORT); + let addrs = addrs + .into_iter() + .map(|a| SocketAddr::new(a, port)) + .collect::>(); + let (conn, addr) = connect_any(&addrs).await?; + server.set_active_addr(handle, addr); + conns.lock().await.insert(addr, conn); + connecting.lock().await.remove(&handle); + } + Result::<(), LanMouseConnectionError>::Ok(()) + }); + Err(LanMouseConnectionError::NotConnected) } }