Add fallback to plaintext is tls server isn't online

This commit is contained in:
asonix 2020-05-26 21:28:45 -05:00
parent 474985b1f4
commit b45a061184
2 changed files with 144 additions and 24 deletions

View file

@ -2,24 +2,36 @@ use crate::cache::RequestCache;
use async_lock::Lock;
use bytes::Bytes;
use futures::{
future::try_join,
sink::SinkExt,
stream::{SplitSink, SplitStream, StreamExt},
try_join,
};
use log::{debug, error};
use std::{io, net::SocketAddr, sync::Arc};
use std::{
io,
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use tokio::{
net::TcpStream,
net::{TcpStream, UdpSocket},
sync::mpsc::{Receiver, Sender},
};
use tokio_rustls::{client::TlsStream, rustls::ClientConfig, webpki::DNSName, TlsConnector};
use tokio_util::codec::{length_delimited::LengthDelimitedCodec, Framed};
use tokio_util::{
codec::{length_delimited::LengthDelimitedCodec, BytesCodec, Framed},
udp::UdpFramed,
};
#[derive(Clone)]
pub struct Config {
config: Arc<ClientConfig>,
upstream: SocketAddr,
domain: DNSName,
use_fallback: Arc<AtomicBool>,
fallback: SocketAddr,
cache: RequestCache,
}
@ -36,11 +48,19 @@ pub struct RecvHalf {
}
impl Config {
pub fn new(config: ClientConfig, upstream: SocketAddr, domain: DNSName) -> Self {
pub fn new(
config: ClientConfig,
upstream: SocketAddr,
domain: DNSName,
use_fallback: Arc<AtomicBool>,
fallback: SocketAddr,
) -> Self {
Config {
config: Arc::new(config),
upstream,
domain,
use_fallback,
fallback,
cache: RequestCache::new(),
}
}
@ -82,11 +102,17 @@ async fn do_run(
rx: Lock<Receiver<(Bytes, SocketAddr)>>,
mut tx: Sender<(Bytes, SocketAddr)>,
) -> Result<(), io::Error> {
let fallback_socket = UdpSocket::bind("0.0.0.0:0").await?;
let fallback_framed = UdpFramed::new(fallback_socket, BytesCodec::new());
let (mut fallback_sink, mut fallback_stream) = fallback_framed.split();
let fallback = config.fallback;
let conn = config.connect().await?;
let (mut sender, mut receiver) = conn.split();
let cache = config.cache.clone();
let use_fallback = config.use_fallback.clone();
let f1 = async move {
debug!("Locking");
let mut rx_guard = rx.lock().await;
@ -100,16 +126,42 @@ async fn do_run(
);
cache.insert(bytes.as_ref(), addr).await;
sender.inner.send(bytes).await?;
if use_fallback.load(Ordering::Relaxed) {
fallback_sink.send((bytes, fallback)).await?;
} else {
sender.inner.send(bytes).await?;
}
}
Ok(()) as Result<(), io::Error>
};
let cache = config.cache.clone();
let mut tx2 = tx.clone();
let f2 = async move {
while let Some(res) = receiver.inner.next().await {
let bytes = res?.freeze();
if let Some(addr) = cache.remove(bytes.as_ref()).await {
if let Err(_) = tx2.send((bytes, addr)).await {
error!("Error replying to {}", addr);
}
} else {
debug!(
"Couldn't find addr for {}{}",
bytes.as_ref()[0],
bytes.as_ref()[1]
);
}
}
Ok(()) as Result<(), io::Error>
};
let cache = config.cache.clone();
let f3 = async move {
while let Some(res) = fallback_stream.next().await {
let (bytes, _) = res?;
let bytes = bytes.freeze();
if let Some(addr) = cache.remove(bytes.as_ref()).await {
if let Err(_) = tx.send((bytes, addr)).await {
error!("Error replying to {}", addr);
@ -125,6 +177,6 @@ async fn do_run(
Ok(()) as Result<(), io::Error>
};
try_join(f1, f2).await?;
try_join!(f1, f2, f3)?;
Ok(())
}

View file

@ -1,12 +1,19 @@
use async_lock::Lock;
use bytes::Bytes;
use bytes::BytesMut;
use futures::{
future::try_join,
sink::SinkExt,
stream::{select_all, FuturesUnordered, StreamExt},
try_join,
stream::{select, select_all, FuturesUnordered, StreamExt},
};
use log::{debug, error, info, warn};
use std::{io, net::SocketAddr};
use std::{
io,
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use structopt::StructOpt;
use tokio::{net::UdpSocket, sync::mpsc::channel};
use tokio_rustls::{rustls::ClientConfig, webpki::DNSNameRef};
@ -22,12 +29,16 @@ static TEST_QUERY: &[u8] = &[
219, 230, 65, 120,
];
fn build_test_query() -> Bytes {
fn build_test_query() -> BytesMut {
use rand::RngCore;
let mut id = vec![0u8; 2];
rand::thread_rng().fill_bytes(&mut id);
id.extend(TEST_QUERY);
id.into()
let mut bm = BytesMut::new();
bm.extend(id);
bm.extend(TEST_QUERY);
bm
}
#[derive(StructOpt)]
@ -35,9 +46,12 @@ struct Options {
#[structopt(short, long, help = "The local address to bind the server to")]
address: Option<SocketAddr>,
#[structopt(short, long, help = "The upstream DNS server")]
#[structopt(short, long, help = "The IP of the upstream DNS server")]
upstream: SocketAddr,
#[structopt(short, long, help = "Fallback plaintext DNS server")]
fallback: SocketAddr,
#[structopt(short, long, help = "The domain of the upstream server")]
domain: String,
}
@ -46,6 +60,7 @@ struct Options {
async fn main() -> Result<(), anyhow::Error> {
let mut options = Options::from_args();
let fallback = options.fallback;
let local_address = options.address.take().unwrap_or("127.0.0.1:53".parse()?);
let domain = DNSNameRef::try_from_ascii_str(&options.domain)?.to_owned();
@ -60,9 +75,19 @@ async fn main() -> Result<(), anyhow::Error> {
info!("Listening on {}", local_address);
let udp_framed = UdpFramed::new(udp_socket, BytesCodec::new());
let (mut udp_sink, mut udp_stream) = udp_framed.split();
let (mut udp_sink, udp_stream) = udp_framed.split();
let config = self::conn::Config::new(config, options.upstream, domain);
let use_fallback = Arc::new(AtomicBool::new(false));
let config = self::conn::Config::new(
config,
options.upstream,
domain,
use_fallback.clone(),
fallback,
);
let (mut ctrl_cmd_tx, ctrl_cmd_rx) = channel(8);
let (mut ctrl_dns_tx, mut ctrl_dns_rx) = channel(8);
let mut unordered = FuturesUnordered::new();
let mut txs = vec![];
@ -80,18 +105,55 @@ async fn main() -> Result<(), anyhow::Error> {
rxs.push(rx2);
}
let f1 = async move {
let any_addr = "0.0.0.0:0".parse::<SocketAddr>().unwrap();
let ctrl_fut = tokio::spawn(async move {
use tokio::time::{delay_for, interval, timeout, Duration};
let mut i = interval(Duration::from_secs(5));
let max_fails: usize = 5;
let mut fail_count = 0;
loop {
if fail_count == 0 {
i.tick().await;
} else if fail_count == max_fails {
error!("Reached max failures, falling back to plaintext");
use_fallback.store(true, Ordering::Relaxed);
delay_for(Duration::from_secs(60)).await;
i = interval(Duration::from_secs(5));
}
let _ = ctrl_cmd_tx.send(Ok((build_test_query(), any_addr))).await;
if let Err(_) = timeout(Duration::from_secs(2), ctrl_dns_rx.next()).await {
warn!("Failed to get response for test query");
fail_count += 1;
} else if use_fallback.load(Ordering::Relaxed) {
info!("TLS server back online, disabling fallback");
fail_count = 0;
use_fallback.store(false, Ordering::Relaxed);
}
}
});
let handler_fut = tokio::spawn(async move {
while let Some(_) = unordered.next().await {
warn!("Handler completed");
// pass
}
Ok(()) as Result<(), io::Error>
};
});
let mut rx_stream = select_all(rxs);
let f2 = async move {
let reply_fut = async move {
while let Some(tup) = rx_stream.next().await {
if tup.1 == any_addr {
let _ = ctrl_dns_tx.send(tup.0).await;
continue;
}
debug!("Responding to {}", tup.1);
udp_sink.send(tup).await?;
@ -99,25 +161,31 @@ async fn main() -> Result<(), anyhow::Error> {
Ok(()) as Result<(), io::Error>
};
let f3 = async move {
let request_fut = async move {
let mut idx = 0;
let max_idx = txs.len();
while let Some(res) = udp_stream.next().await {
let mut stream = select(udp_stream, ctrl_cmd_rx);
while let Some(res) = stream.next().await {
let (bytes, addr) = res?;
let tup = (bytes.freeze(), addr);
debug!("Requesting for {}", addr);
let tup = (bytes.freeze(), addr);
if let Err(_) = txs[idx].send(tup).await {
error!("Unexpected dropped receiver!!!");
};
}
idx = (idx + 1) % max_idx;
}
Ok(()) as Result<(), io::Error>
};
try_join!(f1, f2, f3)?;
let udp_fut = tokio::spawn(try_join(reply_fut, request_fut));
let (r1, r2, _) = futures::try_join!(handler_fut, udp_fut, ctrl_fut)?;
r1?;
r2?;
Ok(())
}