169 lines
4.7 KiB
Rust
169 lines
4.7 KiB
Rust
use bytes::{Bytes, BytesMut};
|
|
use futures::{future::try_join, sink::SinkExt, stream::StreamExt};
|
|
use log::{debug, error, info};
|
|
use mobc::{Manager, Pool};
|
|
use std::{io, net::SocketAddr, sync::Arc};
|
|
use structopt::StructOpt;
|
|
use tokio::{
|
|
net::{TcpStream, UdpSocket},
|
|
sync::mpsc::{channel, Sender},
|
|
};
|
|
use tokio_rustls::{
|
|
client::TlsStream,
|
|
rustls::ClientConfig,
|
|
webpki::{DNSName, DNSNameRef},
|
|
TlsConnector,
|
|
};
|
|
use tokio_util::{
|
|
codec::{length_delimited::LengthDelimitedCodec, BytesCodec, Framed},
|
|
udp::UdpFramed,
|
|
};
|
|
use webpki_roots::TLS_SERVER_ROOTS;
|
|
|
|
pub struct DotManager {
|
|
config: Arc<ClientConfig>,
|
|
upstream: SocketAddr,
|
|
domain: DNSName,
|
|
}
|
|
|
|
impl DotManager {
|
|
pub fn new(config: ClientConfig, upstream: SocketAddr, domain: DNSName) -> Self {
|
|
DotManager {
|
|
config: Arc::new(config),
|
|
upstream,
|
|
domain,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Health-check probe: a complete DNS query message minus its leading
/// 2-byte message ID, which the pool's `check` routine generates randomly
/// and prepends before sending.
///
/// Contents (54 bytes):
/// - header remainder: flags `0x0120` (RD + AD), QDCOUNT 1, ANCOUNT 0,
///   NSCOUNT 0, ARCOUNT 1
/// - question: `doubleclick.net.`, QTYPE A (1), QCLASS IN (1)
/// - additional: an OPT pseudo-record (type 41) advertising a 4096-byte UDP
///   payload and carrying one option with code 10 and 8 bytes of data
static TEST_QUERY: &[u8] = &[
    // Flags: 0x0120.
    0x01, 0x20,
    // QDCOUNT = 1, ANCOUNT = 0, NSCOUNT = 0, ARCOUNT = 1.
    0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
    // QNAME: length-prefixed labels "doubleclick", "net", then the root label.
    11, b'd', b'o', b'u', b'b', b'l', b'e', b'c', b'l', b'i', b'c', b'k',
    3, b'n', b'e', b't',
    0,
    // QTYPE = 1, QCLASS = 1.
    0x00, 0x01, 0x00, 0x01,
    // OPT record: root owner name, TYPE = 41, CLASS (UDP payload size) = 4096.
    0x00, 0x00, 0x29, 0x10, 0x00,
    // TTL field of the OPT record: all zero.
    0x00, 0x00, 0x00, 0x00,
    // RDLENGTH = 12.
    0x00, 0x0c,
    // Option code 10, option length 8, followed by the 8 option bytes.
    0x00, 0x0a, 0x00, 0x08,
    0x7e, 0x37, 0x11, 0xd5, 0xdb, 0xe6, 0x41, 0x78,
];
|
|
|
|
#[mobc::async_trait]
|
|
impl Manager for DotManager {
|
|
type Connection = Framed<TlsStream<TcpStream>, LengthDelimitedCodec>;
|
|
type Error = io::Error;
|
|
|
|
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
|
let stream = TcpStream::connect(&self.upstream).await?;
|
|
let connector = TlsConnector::from(self.config.clone());
|
|
|
|
let tls_stream = connector.connect(self.domain.as_ref(), stream).await?;
|
|
|
|
let framed = LengthDelimitedCodec::builder()
|
|
.length_field_length(2)
|
|
.new_framed(tls_stream);
|
|
Ok(framed)
|
|
}
|
|
|
|
async fn check(&self, mut conn: Self::Connection) -> Result<Self::Connection, Self::Error> {
|
|
let mut id = vec![0u8; 2];
|
|
{
|
|
use rand::RngCore;
|
|
rand::thread_rng()
|
|
.try_fill_bytes(&mut id)
|
|
.map_err(|_| io::Error::from(io::ErrorKind::Other))?;
|
|
}
|
|
id.extend(TEST_QUERY);
|
|
conn.send(id.into()).await?;
|
|
conn.next()
|
|
.await
|
|
.ok_or(io::Error::from(io::ErrorKind::ConnectionReset))??;
|
|
Ok(conn)
|
|
}
|
|
}
|
|
|
|
#[derive(StructOpt)]
|
|
struct Options {
|
|
#[structopt(short, long, help = "The local address to bind the server to")]
|
|
address: Option<SocketAddr>,
|
|
|
|
#[structopt(short, long, help = "The upstream DNS server")]
|
|
upstream: SocketAddr,
|
|
|
|
#[structopt(short, long, help = "The domain of the upstream server")]
|
|
domain: String,
|
|
}
|
|
|
|
async fn forward(
|
|
bytes_mut: BytesMut,
|
|
addr: SocketAddr,
|
|
pool: Pool<DotManager>,
|
|
mut tx: Sender<Result<(Bytes, SocketAddr), io::Error>>,
|
|
) -> Result<(), anyhow::Error> {
|
|
let mut conn = pool.get().await?;
|
|
|
|
debug!("SENDING {:?}", bytes_mut.as_ref());
|
|
|
|
conn.send(bytes_mut.freeze()).await?;
|
|
let bytes_mut = conn
|
|
.next()
|
|
.await
|
|
.ok_or(io::Error::from(io::ErrorKind::ConnectionReset))??;
|
|
|
|
if let Err(_) = tx.send(Ok((bytes_mut.freeze(), addr))).await {
|
|
error!("Error responding to {}", addr);
|
|
}
|
|
|
|
info!("Finished forwarding for {}", addr);
|
|
Ok(())
|
|
}
|
|
|
|
async fn do_forward(
|
|
bytes_mut: BytesMut,
|
|
addr: SocketAddr,
|
|
pool: Pool<DotManager>,
|
|
tx: Sender<Result<(Bytes, SocketAddr), io::Error>>,
|
|
) {
|
|
if let Err(e) = forward(bytes_mut, addr, pool, tx).await {
|
|
error!("Error forwarding for {}, {}", addr, e);
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), anyhow::Error> {
|
|
let mut options = Options::from_args();
|
|
|
|
let local_address = options.address.take().unwrap_or("127.0.0.1:53".parse()?);
|
|
let domain = DNSNameRef::try_from_ascii_str(&options.domain)?.to_owned();
|
|
|
|
env_logger::init();
|
|
|
|
let mut config = ClientConfig::new();
|
|
config
|
|
.root_store
|
|
.add_server_trust_anchors(&TLS_SERVER_ROOTS);
|
|
|
|
let udp_socket = UdpSocket::bind(local_address).await?;
|
|
info!("Listening on {}", local_address);
|
|
let udp_framed = UdpFramed::new(udp_socket, BytesCodec::new());
|
|
|
|
let manager = DotManager::new(config, options.upstream, domain);
|
|
let tls_pool = Pool::builder().max_open(16).build(manager);
|
|
|
|
let (udp_sink, mut udp_stream) = udp_framed.split();
|
|
|
|
let (tx, rx) = channel(32);
|
|
|
|
let f2 = async move {
|
|
rx.forward(udp_sink).await?;
|
|
Ok(()) as Result<_, io::Error>
|
|
};
|
|
|
|
let f1 = async move {
|
|
while let Some(res) = udp_stream.next().await {
|
|
let (bytes_mut, addr) = res?;
|
|
|
|
tokio::spawn(do_forward(bytes_mut, addr, tls_pool.clone(), tx.clone()));
|
|
}
|
|
Ok(()) as Result<_, io::Error>
|
|
};
|
|
|
|
try_join(f1, f2).await?;
|
|
|
|
Ok(())
|
|
}
|