use hbb_common::{ anyhow, bytes::{Bytes, BytesMut}, bytes_codec::BytesCodec, config, log, tcp::{DynTcpStream, FramedStream}, tokio::{self, net::UdpSocket, sync::mpsc, sync::oneshot}, tokio_util, ResultType, Stream, }; use kcp_sys::{ endpoint::KcpEndpoint, packet_def::{KcpPacket, KcpPacketHeader}, stream, }; use std::{net::SocketAddr, sync::Arc}; pub struct KcpStream { _endpoint: KcpEndpoint, stop_sender: Option>, } impl KcpStream { fn create_framed(stream: stream::KcpStream, local_addr: Option) -> Stream { Stream::Tcp(FramedStream( tokio_util::codec::Framed::new(DynTcpStream(Box::new(stream)), BytesCodec::new()), local_addr.unwrap_or(config::Config::get_any_listen_addr(true)), None, 0, )) } pub async fn accept( udp_socket: Arc, timeout: std::time::Duration, init_packet: Option, ) -> ResultType<(Self, Stream)> { let mut endpoint = KcpEndpoint::new(); endpoint.run().await; let (input, output) = ( endpoint.input_sender(), endpoint .output_receiver() .ok_or_else(|| anyhow::anyhow!("Failed to get output receiver"))?, ); let (stop_sender, stop_receiver) = oneshot::channel(); if let Some(packet) = init_packet { if packet.len() >= std::mem::size_of::() { input.send(packet.into()).await?; } } Self::kcp_io(udp_socket.clone(), input, output, stop_receiver).await; let conn_id = tokio::time::timeout(timeout, endpoint.accept()).await??; if let Some(stream) = stream::KcpStream::new(&endpoint, conn_id) { Ok(( Self { _endpoint: endpoint, stop_sender: Some(stop_sender), }, Self::create_framed(stream, udp_socket.local_addr().ok()), )) } else { Err(anyhow::anyhow!("Failed to create KcpStream")) } } pub async fn connect( udp_socket: Arc, timeout: std::time::Duration, ) -> ResultType<(Self, Stream)> { let mut endpoint = KcpEndpoint::new(); endpoint.run().await; let (input, output) = ( endpoint.input_sender(), endpoint .output_receiver() .ok_or_else(|| anyhow::anyhow!("Failed to get output receiver"))?, ); let (stop_sender, stop_receiver) = oneshot::channel(); Self::kcp_io(udp_socket.clone(), input, output, stop_receiver).await; let conn_id = endpoint.connect(timeout, 0, 0, Bytes::new()).await?; if let Some(stream) = stream::KcpStream::new(&endpoint, conn_id) { Ok(( Self { _endpoint: endpoint, stop_sender: Some(stop_sender), }, Self::create_framed(stream, udp_socket.local_addr().ok()), )) } else { Err(anyhow::anyhow!("Failed to create KcpStream")) } } async fn kcp_io( udp_socket: Arc, input: mpsc::Sender, mut output: mpsc::Receiver, mut stop_receiver: oneshot::Receiver<()>, ) { let udp = udp_socket.clone(); tokio::spawn(async move { let mut buf = vec![0; 1500]; loop { tokio::select! { _ = &mut stop_receiver => { log::debug!("KCP io loop received stop signal"); break; } Some(data) = output.recv() => { if let Err(e) = udp.send(&data.inner()).await { log::debug!("KCP send error: {:?}", e); break; } } result = udp.recv_from(&mut buf) => { match result { Ok((size, _)) => { if size < std::mem::size_of::() { continue; } input .send(BytesMut::from(&buf[..size]).into()) .await.ok(); } Err(e) => { log::debug!("KCP recv_from error: {:?}", e); break; } } } else => { log::debug!("KCP endpoint input closed"); break; } } } }); } } impl Drop for KcpStream { fn drop(&mut self) { if let Some(sender) = self.stop_sender.take() { let _ = sender.send(()); } } }