diff --git a/src/connect.rs b/src/connect.rs index 61bbc7c..cba5506 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -2,6 +2,7 @@ use crate::server::Server; use lan_mouse_ipc::{ClientHandle, DEFAULT_PORT}; use lan_mouse_proto::{ProtoEvent, MAX_EVENT_SIZE}; use local_channel::mpsc::{channel, Receiver, Sender}; +use rustls::pki_types::CertificateDer; use sha2::{Digest, Sha256}; use std::{ collections::{HashMap, HashSet}, @@ -36,6 +37,12 @@ pub(crate) enum LanMouseConnectionError { NotConnected, } +type VerifyPeerCertificateFn = Arc< + dyn (Fn(&[Vec], &[CertificateDer<'static>]) -> Result<(), webrtc_dtls::Error>) + + Send + + Sync, +>; + async fn connect( addr: SocketAddr, ) -> Result<(Arc, SocketAddr), LanMouseConnectionError> { @@ -43,40 +50,38 @@ async fn connect( let conn = Arc::new(UdpSocket::bind("0.0.0.0:0").await?); conn.connect(addr).await?; let certificate = Certificate::generate_self_signed(["localhost".to_owned()])?; + let verify_peer_certificate: Option = Some(Arc::new( + |certs: &[Vec], _chains: &[CertificateDer<'static>]| { + let fingerprints = certs + .into_iter() + .map(|cert| { + let mut hash = Sha256::new(); + hash.update(cert); + let bytes = hash + .finalize() + .iter() + .map(|x| format!("{x:02x}")) + .collect::>(); + let fingerprint = bytes.join(":").to_lowercase(); + fingerprint + }) + .collect::>(); + log::info!("fingerprints: {fingerprints:?}"); + Ok(()) + }, + )); let config = Config { certificates: vec![certificate], insecure_skip_verify: true, extended_master_secret: ExtendedMasterSecretType::Require, + verify_peer_certificate, ..Default::default() }; let dtls_conn = DTLSConn::new(conn, config, true, None).await?; log::info!("{addr} connected successfully!"); - let peer_certificates = dtls_conn.connection_state().await.peer_certificates; - verify_peer_certificates(peer_certificates)?; Ok((Arc::new(dtls_conn), addr)) } -fn verify_peer_certificates( - peer_certificates: Vec>, -) -> Result<(), LanMouseConnectionError> { - let fingerprints = peer_certificates - .into_iter() - .map(|cert| { - let mut hash = Sha256::new(); - hash.update(cert); - let bytes = hash - .finalize() - .iter() - .map(|x| format!("{x:02x}")) - .collect::>(); - let fingerprint = bytes.join(":").to_lowercase(); - fingerprint - }) - .collect::>(); - log::info!("fingerprints: {fingerprints:?}"); - Ok(()) -} - async fn connect_any( addrs: &[SocketAddr], ) -> Result<(Arc, SocketAddr), LanMouseConnectionError> {