From 2803db70737867199786f5363ca6dae8543e00d9 Mon Sep 17 00:00:00 2001 From: Ferdinand Schober Date: Fri, 19 Jan 2024 00:22:00 +0100 Subject: [PATCH] refactor dns task --- src/server.rs | 40 ++++++++----------------------------- src/server/frontend_task.rs | 16 ++++++++------- src/server/resolver_task.rs | 40 +++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 39 deletions(-) create mode 100644 src/server/resolver_task.rs diff --git a/src/server.rs b/src/server.rs index 61701ab..70e0cf2 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,7 +1,6 @@ use log; use std::{ cell::{Cell, RefCell}, - collections::HashSet, io::Result, rc::Rc, time::Duration, @@ -21,11 +20,12 @@ use crate::{ }; use crate::{consumer, producer}; -use self::consumer_task::ConsumerEvent; +use self::{consumer_task::ConsumerEvent, resolver_task::DnsRequest}; mod consumer_task; mod frontend_task; mod producer_task; +mod resolver_task; const MAX_RESPONSE_TIME: Duration = Duration::from_millis(500); @@ -83,7 +83,6 @@ impl Server { }; let (consumer, producer) = tokio::join!(consumer::create(), producer::create()); - let (resolve_tx, mut resolve_rx) = tokio::sync::mpsc::channel(32); let (receiver_tx, receiver_rx) = tokio::sync::mpsc::channel(32); let (sender_tx, mut sender_rx) = tokio::sync::mpsc::channel(32); let (port_tx, mut port_rx) = tokio::sync::mpsc::channel(32); @@ -103,6 +102,10 @@ impl Server { timer_tx, ); + // create dns resolver + let resolver = dns::DnsResolver::new().await?; + let (mut resolver_task, resolve_tx) = resolver_task::new(resolver, self.clone()); + // frontend listener let (mut frontend_task, frontend_tx, frontend_notify_tx) = frontend_task::new( frontend, @@ -113,37 +116,10 @@ impl Server { port_tx, ); - // dns resolver - - // create dns resolver - let resolver = dns::DnsResolver::new().await?; - let server = self.clone(); - let mut resolver_task = tokio::task::spawn_local(async move { - loop { - let (host, client): (String, ClientHandle) = match resolve_rx.recv().await { - Some(r) => r, - None => break, - }; - let ips = match resolver.resolve(&host).await { - Ok(ips) => ips, - Err(e) => { - log::warn!("could not resolve host '{host}': {e}"); - continue; - } - }; - if let Some(state) = server.client_manager.borrow_mut().get_mut(client) { - let mut addrs = HashSet::from_iter(state.client.fix_ips.iter().cloned()); - for ip in ips { - addrs.insert(ip); - } - state.client.ips = addrs; - } - } - }); - // bind the udp socket let listen_addr = SocketAddr::new("0.0.0.0".parse().unwrap(), self.port.get()); let mut socket = UdpSocket::bind(listen_addr).await?; + let server = self.clone(); // udp task let mut udp_task = tokio::task::spawn_local(async move { loop { @@ -318,7 +294,7 @@ impl Server { .send(FrontendEvent::ActivateClient(handle, true)) .await?; if let Some(hostname) = hostname { - let _ = resolve_tx.send((hostname, handle)).await; + let _ = resolve_tx.send(DnsRequest { hostname, handle }).await; } } diff --git a/src/server/frontend_task.rs b/src/server/frontend_task.rs index 611d72d..a0409d1 100644 --- a/src/server/frontend_task.rs +++ b/src/server/frontend_task.rs @@ -17,14 +17,16 @@ use crate::{ frontend::{self, FrontendEvent, FrontendListener, FrontendNotify}, }; -use super::{consumer_task::ConsumerEvent, producer_task::ProducerEvent, Server}; +use super::{ + consumer_task::ConsumerEvent, producer_task::ProducerEvent, resolver_task::DnsRequest, Server, +}; pub(crate) fn new( mut frontend: FrontendListener, server: Server, producer_notify: Sender, consumer_notify: Sender, - resolve_ch: Sender<(String, u32)>, + resolve_ch: Sender, port_tx: Sender, ) -> ( JoinHandle>, @@ -98,7 +100,7 @@ async fn handle_frontend_event( server: &Server, producer_tx: &Sender, consumer_tx: &Sender, - resolve_tx: &Sender<(String, ClientHandle)>, + resolve_tx: &Sender, frontend: &mut FrontendListener, port_tx: &Sender, event: FrontendEvent, @@ -173,7 +175,7 @@ async fn handle_frontend_event( pub async fn add_client( server: &Server, - resolver_tx: &Sender<(String, ClientHandle)>, + resolver_tx: &Sender, hostname: Option, addr: HashSet, port: u16, @@ -194,7 +196,7 @@ pub async fn add_client( log::debug!("add_client {handle}"); if let Some(hostname) = hostname { - let _ = resolver_tx.send((hostname, handle)).await; + let _ = resolver_tx.send(DnsRequest { hostname, handle }).await; } handle @@ -266,7 +268,7 @@ async fn update_client( server: &Server, producer_notify_tx: &Sender, consumer_notify_tx: &Sender, - resolve_tx: &Sender<(String, ClientHandle)>, + resolve_tx: &Sender, client_update: (ClientHandle, Option, u16, Position), ) { let (handle, hostname, port, pos) = client_update; @@ -303,7 +305,7 @@ async fn update_client( // resolve dns if let Some(hostname) = hostname { - let _ = resolve_tx.send((hostname, handle)).await; + let _ = resolve_tx.send(DnsRequest { hostname, handle }).await; } // update state in event consumer & producer diff --git a/src/server/resolver_task.rs b/src/server/resolver_task.rs new file mode 100644 index 0000000..8a0412e --- /dev/null +++ b/src/server/resolver_task.rs @@ -0,0 +1,40 @@ +use std::collections::HashSet; + +use tokio::{sync::mpsc::Sender, task::JoinHandle}; + +use crate::{client::ClientHandle, dns::DnsResolver}; + +use super::Server; + +#[derive(Clone)] +pub struct DnsRequest { + pub hostname: String, + pub handle: ClientHandle, +} + +pub fn new(resolver: DnsResolver, server: Server) -> (JoinHandle<()>, Sender) { + let (dns_tx, mut dns_rx) = tokio::sync::mpsc::channel::(32); + let resolver_task = tokio::task::spawn_local(async move { + loop { + let (host, handle) = match dns_rx.recv().await { + Some(r) => (r.hostname, r.handle), + None => break, + }; + let ips = match resolver.resolve(&host).await { + Ok(ips) => ips, + Err(e) => { + log::warn!("could not resolve host '{host}': {e}"); + continue; + } + }; + if let Some(state) = server.client_manager.borrow_mut().get_mut(handle) { + let mut addrs = HashSet::from_iter(state.client.fix_ips.iter().cloned()); + for ip in ips { + addrs.insert(ip); + } + state.client.ips = addrs; + } + } + }); + (resolver_task, dns_tx) +}