150 lines
4.2 KiB
Rust
150 lines
4.2 KiB
Rust
use actix_web::{
|
|
body::{BodyStream, MessageBody},
|
|
dev::{ServiceRequest, ServiceResponse},
|
|
error::{ErrorBadRequest, ErrorInternalServerError},
|
|
http::{header::USER_AGENT, uri::Parts, Uri},
|
|
web::{to, Data, Payload, PayloadConfig},
|
|
App, HttpRequest, HttpResponse, HttpServer,
|
|
};
|
|
use actix_web_lab::middleware::{from_fn, Next};
|
|
use awc::Client;
|
|
use std::{net::SocketAddr, rc::Rc, time::Duration};
|
|
use tracing_actix_web::TracingLogger;
|
|
use tracing_log::LogTracer;
|
|
use tracing_subscriber::{
|
|
filter::Targets, fmt::format::FmtSpan, layer::SubscriberExt, Layer, Registry,
|
|
};
|
|
|
|
struct State {
|
|
upstream: Uri,
|
|
client: Client,
|
|
}
|
|
|
|
/// Rejection reason used when a request carries no `User-Agent` header at all.
#[derive(Debug)]
struct MissingUserAgent;
|
|
|
|
/// Rejection reason used when a request's `User-Agent` contains a blocked substring.
#[derive(Debug)]
struct InvalidUserAgent;
|
|
|
|
/// Catch-all boxed error returned by `main`, so `?` works on every startup step.
type Error = Box<dyn std::error::Error + 'static>;
|
|
|
|
#[actix_web::main]
|
|
async fn main() -> Result<(), Error> {
|
|
LogTracer::init()?;
|
|
|
|
let targets: Targets = std::env::var("RUST_LOG")
|
|
.unwrap_or_else(|_| "info".into())
|
|
.parse()?;
|
|
|
|
let format_layer = tracing_subscriber::fmt::layer()
|
|
.with_span_events(FmtSpan::NEW | FmtSpan::CLOSE)
|
|
.with_filter(targets.clone());
|
|
|
|
let subscriber = Registry::default().with(format_layer);
|
|
|
|
tracing::subscriber::set_global_default(subscriber)?;
|
|
|
|
let upstream: Uri = std::env::var("PROXY_UPSTREAM")?.parse()?;
|
|
|
|
let blocked: Vec<String> = std::env::var("PROXY_BLOCKS")?
|
|
.split(',')
|
|
.map(String::from)
|
|
.collect();
|
|
|
|
let addr: SocketAddr = std::env::var("PROXY_ADDR")?.parse()?;
|
|
|
|
HttpServer::new(move || {
|
|
let state = State {
|
|
upstream: upstream.clone(),
|
|
client: Client::builder().timeout(Duration::from_secs(30)).finish(),
|
|
};
|
|
|
|
let blocked = Rc::new(blocked.clone());
|
|
|
|
App::new()
|
|
.app_data(Data::new(state))
|
|
.app_data(PayloadConfig::new(1024 * 1024 * 200))
|
|
.wrap(from_fn(move |req, next| {
|
|
block_user_agents(req, next, Rc::clone(&blocked))
|
|
}))
|
|
.wrap(TracingLogger::default())
|
|
.default_service(to(proxy))
|
|
})
|
|
.bind(addr)?
|
|
.run()
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn block_user_agents(
|
|
req: ServiceRequest,
|
|
next: Next<impl MessageBody>,
|
|
blocked: Rc<Vec<String>>,
|
|
) -> Result<ServiceResponse<impl MessageBody>, actix_web::Error> {
|
|
let Some(user_agent) = req.headers().get(USER_AGENT) else {
|
|
return Err(ErrorBadRequest(MissingUserAgent).into());
|
|
};
|
|
|
|
let user_agent = user_agent.to_str().map_err(ErrorBadRequest)?.to_lowercase();
|
|
|
|
for part in blocked.iter() {
|
|
if user_agent.contains(part.as_str()) {
|
|
return Err(ErrorBadRequest(InvalidUserAgent).into());
|
|
}
|
|
}
|
|
|
|
next.call(req).await
|
|
}
|
|
|
|
async fn proxy(
|
|
inbound: HttpRequest,
|
|
body: Option<Payload>,
|
|
state: Data<State>,
|
|
) -> Result<HttpResponse, actix_web::Error> {
|
|
let mut upstream = Parts::default();
|
|
upstream.scheme = state.upstream.scheme().cloned();
|
|
upstream.authority = state.upstream.authority().cloned();
|
|
upstream.path_and_query = inbound.uri().path_and_query().cloned();
|
|
|
|
let mut req = state.client.request_from(upstream, inbound.head());
|
|
|
|
if let Some(peer_addr) = inbound.head().peer_addr {
|
|
req = req.append_header(("x-forwarded-for", peer_addr.to_string()));
|
|
}
|
|
|
|
req = req.no_decompress();
|
|
|
|
let response = if let Some(payload) = body {
|
|
req.send_stream(payload)
|
|
.await
|
|
.map_err(ErrorInternalServerError)?
|
|
} else {
|
|
req.send().await.map_err(ErrorInternalServerError)?
|
|
};
|
|
|
|
let mut downstream = HttpResponse::build(response.status());
|
|
|
|
for (name, value) in response
|
|
.headers()
|
|
.iter()
|
|
.filter(|(h, _)| *h != "connection")
|
|
{
|
|
downstream.insert_header((name.clone(), value.clone()));
|
|
}
|
|
|
|
Ok(downstream.body(BodyStream::new(response)))
|
|
}
|
|
|
|
impl std::fmt::Display for MissingUserAgent {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "No user-agent header provided")
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for InvalidUserAgent {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "Invalid user-agent header provided")
|
|
}
|
|
}
|