Add fallback to plaintext is tls server isn't online
This commit is contained in:
parent
474985b1f4
commit
b45a061184
66
src/conn.rs
66
src/conn.rs
|
@ -2,24 +2,36 @@ use crate::cache::RequestCache;
|
|||
use async_lock::Lock;
|
||||
use bytes::Bytes;
|
||||
use futures::{
|
||||
future::try_join,
|
||||
sink::SinkExt,
|
||||
stream::{SplitSink, SplitStream, StreamExt},
|
||||
try_join,
|
||||
};
|
||||
use log::{debug, error};
|
||||
use std::{io, net::SocketAddr, sync::Arc};
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use tokio::{
|
||||
net::TcpStream,
|
||||
net::{TcpStream, UdpSocket},
|
||||
sync::mpsc::{Receiver, Sender},
|
||||
};
|
||||
use tokio_rustls::{client::TlsStream, rustls::ClientConfig, webpki::DNSName, TlsConnector};
|
||||
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
|
||||
use tokio_util::{
|
||||
codec::{length_delimited::LengthDelimitedCodec, BytesCodec, Framed},
|
||||
udp::UdpFramed,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
config: Arc<ClientConfig>,
|
||||
upstream: SocketAddr,
|
||||
domain: DNSName,
|
||||
use_fallback: Arc<AtomicBool>,
|
||||
fallback: SocketAddr,
|
||||
cache: RequestCache,
|
||||
}
|
||||
|
||||
|
@ -36,11 +48,19 @@ pub struct RecvHalf {
|
|||
}
|
||||
|
||||
impl Config {
|
||||
pub fn new(config: ClientConfig, upstream: SocketAddr, domain: DNSName) -> Self {
|
||||
pub fn new(
|
||||
config: ClientConfig,
|
||||
upstream: SocketAddr,
|
||||
domain: DNSName,
|
||||
use_fallback: Arc<AtomicBool>,
|
||||
fallback: SocketAddr,
|
||||
) -> Self {
|
||||
Config {
|
||||
config: Arc::new(config),
|
||||
upstream,
|
||||
domain,
|
||||
use_fallback,
|
||||
fallback,
|
||||
cache: RequestCache::new(),
|
||||
}
|
||||
}
|
||||
|
@ -82,11 +102,17 @@ async fn do_run(
|
|||
rx: Lock<Receiver<(Bytes, SocketAddr)>>,
|
||||
mut tx: Sender<(Bytes, SocketAddr)>,
|
||||
) -> Result<(), io::Error> {
|
||||
let fallback_socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
let fallback_framed = UdpFramed::new(fallback_socket, BytesCodec::new());
|
||||
let (mut fallback_sink, mut fallback_stream) = fallback_framed.split();
|
||||
let fallback = config.fallback;
|
||||
|
||||
let conn = config.connect().await?;
|
||||
|
||||
let (mut sender, mut receiver) = conn.split();
|
||||
|
||||
let cache = config.cache.clone();
|
||||
let use_fallback = config.use_fallback.clone();
|
||||
let f1 = async move {
|
||||
debug!("Locking");
|
||||
let mut rx_guard = rx.lock().await;
|
||||
|
@ -100,16 +126,42 @@ async fn do_run(
|
|||
);
|
||||
cache.insert(bytes.as_ref(), addr).await;
|
||||
|
||||
sender.inner.send(bytes).await?;
|
||||
if use_fallback.load(Ordering::Relaxed) {
|
||||
fallback_sink.send((bytes, fallback)).await?;
|
||||
} else {
|
||||
sender.inner.send(bytes).await?;
|
||||
}
|
||||
}
|
||||
Ok(()) as Result<(), io::Error>
|
||||
};
|
||||
|
||||
let cache = config.cache.clone();
|
||||
let mut tx2 = tx.clone();
|
||||
let f2 = async move {
|
||||
while let Some(res) = receiver.inner.next().await {
|
||||
let bytes = res?.freeze();
|
||||
|
||||
if let Some(addr) = cache.remove(bytes.as_ref()).await {
|
||||
if let Err(_) = tx2.send((bytes, addr)).await {
|
||||
error!("Error replying to {}", addr);
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
"Couldn't find addr for {}{}",
|
||||
bytes.as_ref()[0],
|
||||
bytes.as_ref()[1]
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(()) as Result<(), io::Error>
|
||||
};
|
||||
|
||||
let cache = config.cache.clone();
|
||||
let f3 = async move {
|
||||
while let Some(res) = fallback_stream.next().await {
|
||||
let (bytes, _) = res?;
|
||||
let bytes = bytes.freeze();
|
||||
|
||||
if let Some(addr) = cache.remove(bytes.as_ref()).await {
|
||||
if let Err(_) = tx.send((bytes, addr)).await {
|
||||
error!("Error replying to {}", addr);
|
||||
|
@ -125,6 +177,6 @@ async fn do_run(
|
|||
Ok(()) as Result<(), io::Error>
|
||||
};
|
||||
|
||||
try_join(f1, f2).await?;
|
||||
try_join!(f1, f2, f3)?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
102
src/main.rs
102
src/main.rs
|
@ -1,12 +1,19 @@
|
|||
use async_lock::Lock;
|
||||
use bytes::Bytes;
|
||||
use bytes::BytesMut;
|
||||
use futures::{
|
||||
future::try_join,
|
||||
sink::SinkExt,
|
||||
stream::{select_all, FuturesUnordered, StreamExt},
|
||||
try_join,
|
||||
stream::{select, select_all, FuturesUnordered, StreamExt},
|
||||
};
|
||||
use log::{debug, error, info, warn};
|
||||
use std::{io, net::SocketAddr};
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
use structopt::StructOpt;
|
||||
use tokio::{net::UdpSocket, sync::mpsc::channel};
|
||||
use tokio_rustls::{rustls::ClientConfig, webpki::DNSNameRef};
|
||||
|
@ -22,12 +29,16 @@ static TEST_QUERY: &[u8] = &[
|
|||
219, 230, 65, 120,
|
||||
];
|
||||
|
||||
fn build_test_query() -> Bytes {
|
||||
fn build_test_query() -> BytesMut {
|
||||
use rand::RngCore;
|
||||
let mut id = vec![0u8; 2];
|
||||
rand::thread_rng().fill_bytes(&mut id);
|
||||
id.extend(TEST_QUERY);
|
||||
id.into()
|
||||
|
||||
let mut bm = BytesMut::new();
|
||||
bm.extend(id);
|
||||
bm.extend(TEST_QUERY);
|
||||
bm
|
||||
}
|
||||
|
||||
#[derive(StructOpt)]
|
||||
|
@ -35,9 +46,12 @@ struct Options {
|
|||
#[structopt(short, long, help = "The local address to bind the server to")]
|
||||
address: Option<SocketAddr>,
|
||||
|
||||
#[structopt(short, long, help = "The upstream DNS server")]
|
||||
#[structopt(short, long, help = "The IP of the upstream DNS server")]
|
||||
upstream: SocketAddr,
|
||||
|
||||
#[structopt(short, long, help = "Fallback plaintext DNS server")]
|
||||
fallback: SocketAddr,
|
||||
|
||||
#[structopt(short, long, help = "The domain of the upstream server")]
|
||||
domain: String,
|
||||
}
|
||||
|
@ -46,6 +60,7 @@ struct Options {
|
|||
async fn main() -> Result<(), anyhow::Error> {
|
||||
let mut options = Options::from_args();
|
||||
|
||||
let fallback = options.fallback;
|
||||
let local_address = options.address.take().unwrap_or("127.0.0.1:53".parse()?);
|
||||
let domain = DNSNameRef::try_from_ascii_str(&options.domain)?.to_owned();
|
||||
|
||||
|
@ -60,9 +75,19 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
info!("Listening on {}", local_address);
|
||||
let udp_framed = UdpFramed::new(udp_socket, BytesCodec::new());
|
||||
|
||||
let (mut udp_sink, mut udp_stream) = udp_framed.split();
|
||||
let (mut udp_sink, udp_stream) = udp_framed.split();
|
||||
|
||||
let config = self::conn::Config::new(config, options.upstream, domain);
|
||||
let use_fallback = Arc::new(AtomicBool::new(false));
|
||||
let config = self::conn::Config::new(
|
||||
config,
|
||||
options.upstream,
|
||||
domain,
|
||||
use_fallback.clone(),
|
||||
fallback,
|
||||
);
|
||||
|
||||
let (mut ctrl_cmd_tx, ctrl_cmd_rx) = channel(8);
|
||||
let (mut ctrl_dns_tx, mut ctrl_dns_rx) = channel(8);
|
||||
|
||||
let mut unordered = FuturesUnordered::new();
|
||||
let mut txs = vec![];
|
||||
|
@ -80,18 +105,55 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
rxs.push(rx2);
|
||||
}
|
||||
|
||||
let f1 = async move {
|
||||
let any_addr = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
|
||||
|
||||
let ctrl_fut = tokio::spawn(async move {
|
||||
use tokio::time::{delay_for, interval, timeout, Duration};
|
||||
|
||||
let mut i = interval(Duration::from_secs(5));
|
||||
|
||||
let max_fails: usize = 5;
|
||||
let mut fail_count = 0;
|
||||
|
||||
loop {
|
||||
if fail_count == 0 {
|
||||
i.tick().await;
|
||||
} else if fail_count == max_fails {
|
||||
error!("Reached max failures, falling back to plaintext");
|
||||
use_fallback.store(true, Ordering::Relaxed);
|
||||
delay_for(Duration::from_secs(60)).await;
|
||||
i = interval(Duration::from_secs(5));
|
||||
}
|
||||
|
||||
let _ = ctrl_cmd_tx.send(Ok((build_test_query(), any_addr))).await;
|
||||
|
||||
if let Err(_) = timeout(Duration::from_secs(2), ctrl_dns_rx.next()).await {
|
||||
warn!("Failed to get response for test query");
|
||||
fail_count += 1;
|
||||
} else if use_fallback.load(Ordering::Relaxed) {
|
||||
info!("TLS server back online, disabling fallback");
|
||||
fail_count = 0;
|
||||
use_fallback.store(false, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let handler_fut = tokio::spawn(async move {
|
||||
while let Some(_) = unordered.next().await {
|
||||
warn!("Handler completed");
|
||||
// pass
|
||||
}
|
||||
Ok(()) as Result<(), io::Error>
|
||||
};
|
||||
});
|
||||
|
||||
let mut rx_stream = select_all(rxs);
|
||||
|
||||
let f2 = async move {
|
||||
let reply_fut = async move {
|
||||
while let Some(tup) = rx_stream.next().await {
|
||||
if tup.1 == any_addr {
|
||||
let _ = ctrl_dns_tx.send(tup.0).await;
|
||||
continue;
|
||||
}
|
||||
debug!("Responding to {}", tup.1);
|
||||
|
||||
udp_sink.send(tup).await?;
|
||||
|
@ -99,25 +161,31 @@ async fn main() -> Result<(), anyhow::Error> {
|
|||
Ok(()) as Result<(), io::Error>
|
||||
};
|
||||
|
||||
let f3 = async move {
|
||||
let request_fut = async move {
|
||||
let mut idx = 0;
|
||||
let max_idx = txs.len();
|
||||
while let Some(res) = udp_stream.next().await {
|
||||
let mut stream = select(udp_stream, ctrl_cmd_rx);
|
||||
while let Some(res) = stream.next().await {
|
||||
let (bytes, addr) = res?;
|
||||
let tup = (bytes.freeze(), addr);
|
||||
|
||||
debug!("Requesting for {}", addr);
|
||||
|
||||
let tup = (bytes.freeze(), addr);
|
||||
|
||||
if let Err(_) = txs[idx].send(tup).await {
|
||||
error!("Unexpected dropped receiver!!!");
|
||||
};
|
||||
}
|
||||
|
||||
idx = (idx + 1) % max_idx;
|
||||
}
|
||||
Ok(()) as Result<(), io::Error>
|
||||
};
|
||||
|
||||
try_join!(f1, f2, f3)?;
|
||||
let udp_fut = tokio::spawn(try_join(reply_fut, request_fut));
|
||||
|
||||
let (r1, r2, _) = futures::try_join!(handler_fut, udp_fut, ctrl_fut)?;
|
||||
r1?;
|
||||
r2?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue