dns-over-tls-client-proxy/src/main.rs
2020-05-26 17:25:43 -05:00

169 lines
4.7 KiB
Rust

use bytes::{Bytes, BytesMut};
use futures::{future::try_join, sink::SinkExt, stream::StreamExt};
use log::{debug, error, info};
use mobc::{Manager, Pool};
use std::{io, net::SocketAddr, sync::Arc};
use structopt::StructOpt;
use tokio::{
net::{TcpStream, UdpSocket},
sync::mpsc::{channel, Sender},
};
use tokio_rustls::{
client::TlsStream,
rustls::ClientConfig,
webpki::{DNSName, DNSNameRef},
TlsConnector,
};
use tokio_util::{
codec::{length_delimited::LengthDelimitedCodec, BytesCodec, Framed},
udp::UdpFramed,
};
use webpki_roots::TLS_SERVER_ROOTS;
/// Connection manager handed to the `mobc` pool: everything needed to open a
/// DNS-over-TLS session to one upstream resolver.
pub struct DotManager {
    /// Address of the upstream resolver, dialled over TCP for each new connection.
    upstream: SocketAddr,
    /// Server name passed to the TLS connector when handshaking with `upstream`.
    domain: DNSName,
    /// TLS client settings; the `Arc` is cloned into a `TlsConnector` per connection.
    config: Arc<ClientConfig>,
}
impl DotManager {
pub fn new(config: ClientConfig, upstream: SocketAddr, domain: DNSName) -> Self {
DotManager {
config: Arc::new(config),
upstream,
domain,
}
}
}
/// Wire-format DNS query used by the pool health check, *without* its leading 2-byte
/// message ID (a random ID is prepended before each send).
#[rustfmt::skip]
static TEST_QUERY: &[u8] = &[
    // Header after the ID: flags 0x0120 (RD | AD), QDCOUNT=1, ANCOUNT=0, NSCOUNT=0, ARCOUNT=1.
    0x01, 0x20, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
    // QNAME: "doubleclick.net" as length-prefixed labels, terminated by the root label.
    11, b'd', b'o', b'u', b'b', b'l', b'e', b'c', b'l', b'i', b'c', b'k',
    3, b'n', b'e', b't',
    0,
    // QTYPE = A (1), QCLASS = IN (1).
    0x00, 0x01, 0x00, 0x01,
    // Additional record: EDNS(0) OPT (type 41) on the root name, UDP payload size 4096,
    // extended RCODE/version/flags all zero, RDLENGTH = 12.
    0x00, 0x00, 0x29, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0c,
    // EDNS option COOKIE (code 10), length 8, followed by a fixed 8-byte client cookie.
    0x00, 0x0a, 0x00, 0x08, 0x7e, 0x37, 0x11, 0xd5, 0xdb, 0xe6, 0x41, 0x78,
];
/// Lets the `mobc` pool open and health-check DNS-over-TLS connections.
///
/// Each pooled connection is a TLS session over TCP, framed with a 2-byte length prefix
/// (the DNS-over-TCP message framing) by `LengthDelimitedCodec`.
#[mobc::async_trait]
impl Manager for DotManager {
    type Connection = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
    type Error = io::Error;

    /// Dials the upstream over TCP, performs the TLS handshake using the configured
    /// server name, and wraps the stream in a 2-byte length-delimited codec.
    ///
    /// # Errors
    ///
    /// Propagates TCP connect and TLS handshake failures.
    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
        let stream = TcpStream::connect(&self.upstream).await?;
        let connector = TlsConnector::from(self.config.clone());
        let tls_stream = connector.connect(self.domain.as_ref(), stream).await?;
        let framed = LengthDelimitedCodec::builder()
            .length_field_length(2)
            .new_framed(tls_stream);
        Ok(framed)
    }

    /// Confirms a connection is still usable by sending `TEST_QUERY` under a fresh random
    /// message ID and requiring that the next frame read back carries the same ID.
    ///
    /// NOTE(review): there is no timeout here, so an upstream that accepts the query but
    /// never answers stalls the checkout.
    ///
    /// # Errors
    ///
    /// - the RNG fails to produce an ID (`Other`, wrapping the RNG error);
    /// - sending or decoding fails (propagated);
    /// - the stream ends before a response arrives (`ConnectionReset`);
    /// - the response ID does not match the query ID (`InvalidData`). This means the stream
    ///   is out of sync and must not be reused.
    async fn check(&self, mut conn: Self::Connection) -> Result<Self::Connection, Self::Error> {
        use rand::RngCore;

        let mut id = [0u8; 2];
        rand::thread_rng()
            .try_fill_bytes(&mut id)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;

        let mut query = Vec::with_capacity(id.len() + TEST_QUERY.len());
        query.extend_from_slice(&id);
        query.extend_from_slice(TEST_QUERY);
        conn.send(query.into()).await?;

        let response = conn
            .next()
            .await
            .ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionReset))??;
        // A DNS response echoes the query's ID in its first two bytes; anything else means
        // we read a stale or unrelated frame.
        if !response.starts_with(&id) {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "health-check response ID does not match query ID",
            ));
        }
        Ok(conn)
    }
}
// Command-line options, parsed with `Options::from_args()` in `main`.
//
// Plain `//` comments are used on purpose: `///` doc comments on a `StructOpt` type are
// picked up by structopt as help/about text and could change the `--help` output.
#[derive(StructOpt)]
struct Options {
    // `-a/--address`: optional; `main` falls back to 127.0.0.1:53 when it is omitted.
    #[structopt(short, long, help = "The local address to bind the server to")]
    address: Option<SocketAddr>,
    // `-u/--upstream`: required socket address of the DNS-over-TLS resolver to forward to.
    #[structopt(short, long, help = "The upstream DNS server")]
    upstream: SocketAddr,
    // `-d/--domain`: required; converted to a DNS name and used as the TLS server name
    // when connecting to the upstream.
    #[structopt(short, long, help = "The domain of the upstream server")]
    domain: String,
}
/// Relays one UDP DNS query from `addr` to the upstream DoT server and queues the reply
/// for delivery back to `addr`.
///
/// `bytes_mut` is the raw datagram payload. It is sent as one length-prefixed frame on a
/// pooled TLS connection, and the first frame read back is pushed onto `tx` as
/// `(response, addr)`.
///
/// # Errors
///
/// Returns an error if no pooled connection can be obtained, if sending the query fails,
/// or if the upstream fails or closes the stream before a response frame arrives. A closed
/// response channel is only logged.
async fn forward(
    bytes_mut: BytesMut,
    addr: SocketAddr,
    pool: Pool<DotManager>,
    mut tx: Sender<Result<(Bytes, SocketAddr), io::Error>>,
) -> Result<(), anyhow::Error> {
    // Scope the pooled connection to the upstream round trip only: dropping it returns it
    // to the pool before we await `tx.send`, which can wait while the channel is full.
    let response = {
        let mut conn = pool.get().await?;
        debug!("SENDING {:?}", bytes_mut.as_ref());
        conn.send(bytes_mut.freeze()).await?;
        conn.next()
            .await
            .ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionReset))??
    };
    if tx.send(Ok((response.freeze(), addr))).await.is_err() {
        error!("Error responding to {}", addr);
    }
    info!("Finished forwarding for {}", addr);
    Ok(())
}
async fn do_forward(
bytes_mut: BytesMut,
addr: SocketAddr,
pool: Pool<DotManager>,
tx: Sender<Result<(Bytes, SocketAddr), io::Error>>,
) {
if let Err(e) = forward(bytes_mut, addr, pool, tx).await {
error!("Error forwarding for {}, {}", addr, e);
}
}
#[tokio::main]
async fn main() -> Result<(), anyhow::Error> {
let mut options = Options::from_args();
let local_address = options.address.take().unwrap_or("127.0.0.1:53".parse()?);
let domain = DNSNameRef::try_from_ascii_str(&options.domain)?.to_owned();
env_logger::init();
let mut config = ClientConfig::new();
config
.root_store
.add_server_trust_anchors(&TLS_SERVER_ROOTS);
let udp_socket = UdpSocket::bind(local_address).await?;
info!("Listening on {}", local_address);
let udp_framed = UdpFramed::new(udp_socket, BytesCodec::new());
let manager = DotManager::new(config, options.upstream, domain);
let tls_pool = Pool::builder().max_open(16).build(manager);
let (udp_sink, mut udp_stream) = udp_framed.split();
let (tx, rx) = channel(32);
let f2 = async move {
rx.forward(udp_sink).await?;
Ok(()) as Result<_, io::Error>
};
let f1 = async move {
while let Some(res) = udp_stream.next().await {
let (bytes_mut, addr) = res?;
tokio::spawn(do_forward(bytes_mut, addr, tls_pool.clone(), tx.clone()));
}
Ok(()) as Result<_, io::Error>
};
try_join(f1, f2).await?;
Ok(())
}