use std::{ collections::HashMap, error::Error, fmt::Display, io::prelude::*, net::{SocketAddr, TcpListener, TcpStream}, sync::{Arc, RwLock}, thread::{self, JoinHandle}, }; use memmap::MmapMut; #[derive(Copy, Clone, PartialEq, Eq, Hash)] pub enum Request { KeyMap, Connect, } impl TryFrom<[u8; 4]> for Request { fn try_from(buf: [u8; 4]) -> Result { let val = u32::from_ne_bytes(buf); match val { x if x == Request::KeyMap as u32 => Ok(Self::KeyMap), x if x == Request::Connect as u32 => Ok(Self::Connect), _ => Err("Bad Request"), } } type Error = &'static str; } #[derive(Clone)] pub struct Server { data: Arc>>, } impl Server { fn handle_request(&self, mut stream: TcpStream) -> Result<(), Box> { let mut buf = [0u8; 4]; stream.read_exact(&mut buf)?; match Request::try_from(buf) { Ok(Request::KeyMap) => { let data = self.data.read().unwrap(); let buf = data.get(&Request::KeyMap); match buf { None => { stream.write(&0u32.to_ne_bytes())?; } Some(buf) => { stream.write(&buf[..].len().to_ne_bytes())?; stream.write(&buf[..])?; } } stream.flush()?; } Ok(Request::Connect) => todo!(), Err(msg) => eprintln!("{}", msg), } Ok(()) } pub fn listen(port: u16) -> Result<(Server, JoinHandle<()>), Box> { let data: Arc>> = Arc::new(RwLock::new(HashMap::new())); let listen_addr = SocketAddr::new("0.0.0.0".parse()?, port); let server = Server { data }; let server_copy = server.clone(); let listen_socket = TcpListener::bind(listen_addr)?; let thread = thread::Builder::new() .name("tcp server".into()) .spawn(move || { for stream in listen_socket.incoming() { match stream { Ok(stream) => { if let Err(e) = server.handle_request(stream) { eprintln!("{}", e); } } Err(e) => { eprintln!("{}", e); } } } })?; Ok((server_copy, thread)) } pub fn offer_data(&self, req: Request, d: MmapMut) { self.data.write().unwrap().insert(req, d); } } #[derive(Debug)] pub struct BadRequest; impl Display for BadRequest { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "BadRequest") } } impl Error for BadRequest {} pub fn request_data(addr: SocketAddr, req: Request) -> Result, Box> { // connect to server let mut sock = match TcpStream::connect(addr) { Ok(sock) => sock, Err(e) => return Err(Box::new(e)), }; // write the request to the socket // convert to u32 let req: u32 = req as u32; if let Err(e) = sock.write(&req.to_ne_bytes()) { return Err(Box::new(e)); } if let Err(e) = sock.flush() { return Err(Box::new(e)); } // read the response = (len, data) - len 0 means no data / bad request // read len let mut buf = [0u8; 8]; if let Err(e) = sock.read_exact(&mut buf[..]) { return Err(Box::new(e)); } let len = usize::from_ne_bytes(buf); // check for bad request if len == 0 { return Err(Box::new(BadRequest {})); } // read the data let mut data: Vec = vec![0u8; len]; if let Err(e) = sock.read_exact(&mut data[..]) { return Err(Box::new(e)); } Ok(data) }