Enable proper draining of dropped request payloads

Doing this as the outermost middleware ensures all endpoints are covered.

Update request deadline to turn negative deadlines into immediate failures
This commit is contained in:
asonix 2023-09-30 16:26:43 -05:00
parent 8ed5484efe
commit 66e1711723
4 changed files with 216 additions and 2 deletions

View file

@ -1,10 +1,39 @@
use std::{
future::Future,
sync::{Arc, OnceLock},
time::{Duration, Instant},
};
static NOOP_WAKER: OnceLock<std::task::Waker> = OnceLock::new();
fn noop_waker() -> &'static std::task::Waker {
NOOP_WAKER.get_or_init(|| std::task::Waker::from(Arc::new(NoopWaker)))
}
struct NoopWaker;
impl std::task::Wake for NoopWaker {
fn wake(self: std::sync::Arc<Self>) {}
fn wake_by_ref(self: &std::sync::Arc<Self>) {}
}
pub(crate) type LocalBoxFuture<'a, T> = std::pin::Pin<Box<dyn Future<Output = T> + 'a>>;
pub(crate) trait NowOrNever: Future {
fn now_or_never(self) -> Option<Self::Output>
where
Self: Sized,
{
let fut = std::pin::pin!(self);
let mut cx = std::task::Context::from_waker(noop_waker());
match fut.poll(&mut cx) {
std::task::Poll::Pending => None,
std::task::Poll::Ready(out) => Some(out),
}
}
}
pub(crate) trait WithTimeout: Future {
fn with_timeout(self, duration: Duration) -> actix_web::rt::time::Timeout<Self>
where
@ -30,6 +59,7 @@ pub(crate) trait WithMetrics: Future {
}
}
impl<F> NowOrNever for F where F: Future {}
impl<F> WithMetrics for F where F: Future {}
impl<F> WithTimeout for F where F: Future {}

View file

@ -42,7 +42,7 @@ use details::{ApiDetails, HumanDate};
use future::WithTimeout;
use futures_core::Stream;
use metrics_exporter_prometheus::PrometheusBuilder;
use middleware::Metrics;
use middleware::{Metrics, Payload};
use repo::ArcRepo;
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use reqwest_tracing::TracingMiddleware;
@ -1784,6 +1784,7 @@ async fn launch_file_store<F: Fn(&mut web::ServiceConfig) + Send + Clone + 'stat
.wrap(TracingLogger::default())
.wrap(Deadline)
.wrap(Metrics)
.wrap(Payload::new())
.app_data(web::Data::new(process_map.clone()))
.configure(move |sc| configure_endpoints(sc, repo, store, config, client, extra_config))
})
@ -1824,6 +1825,7 @@ async fn launch_object_store<F: Fn(&mut web::ServiceConfig) + Send + Clone + 'st
.wrap(TracingLogger::default())
.wrap(Deadline)
.wrap(Metrics)
.wrap(Payload::new())
.app_data(web::Data::new(process_map.clone()))
.configure(move |sc| configure_endpoints(sc, repo, store, config, client, extra_config))
})

View file

@ -1,4 +1,5 @@
mod metrics;
mod payload;
use actix_web::{
dev::{Service, ServiceRequest, Transform},
@ -15,6 +16,7 @@ use std::{
use crate::future::WithTimeout;
pub(crate) use self::metrics::Metrics;
pub(crate) use self::payload::Payload;
pub(crate) struct Deadline;
pub(crate) struct DeadlineMiddleware<S> {
@ -128,7 +130,7 @@ where
if now < deadline {
Some((deadline - now).try_into().ok()?)
} else {
None
Some(std::time::Duration::from_secs(0))
}
});
DeadlineFuture::new(self.inner.call(req), duration)

180
src/middleware/payload.rs Normal file
View file

@ -0,0 +1,180 @@
use std::{
future::{ready, Ready},
sync::Arc,
};
use actix_web::{
dev::{Service, ServiceRequest, Transform},
http::Method,
HttpMessage,
};
use streem::IntoStreamer;
use tokio::task::JoinSet;
use crate::{future::NowOrNever, stream::LocalBoxStream};
const LIMIT: usize = 256;
async fn drain(rx: flume::Receiver<actix_web::dev::Payload>) {
let mut set = JoinSet::new();
while let Ok(payload) = rx.recv_async().await {
set.spawn_local(async move {
let mut streamer = payload.into_streamer();
while streamer.next().await.is_some() {}
});
let mut count = 0;
// drain completed tasks
while set.join_next().now_or_never().is_some() {
count += 1;
}
// if we're past the limit, wait for completions
while set.len() > LIMIT {
if set.join_next().await.is_some() {
count += 1;
}
}
if count > 0 {
tracing::info!("Drained {count} dropped payloads");
}
}
// drain set
while set.join_next().await.is_some() {}
}
#[derive(Clone)]
struct DrainHandle(Option<Arc<actix_web::rt::task::JoinHandle<()>>>);
pub(crate) struct Payload {
sender: flume::Sender<actix_web::dev::Payload>,
handle: DrainHandle,
}
pub(crate) struct PayloadMiddleware<S> {
inner: S,
sender: flume::Sender<actix_web::dev::Payload>,
_handle: DrainHandle,
}
pub(crate) struct PayloadStream {
inner: Option<actix_web::dev::Payload>,
sender: flume::Sender<actix_web::dev::Payload>,
}
impl DrainHandle {
fn new(handle: actix_web::rt::task::JoinHandle<()>) -> Self {
Self(Some(Arc::new(handle)))
}
}
impl Payload {
pub(crate) fn new() -> Self {
let (tx, rx) = crate::sync::channel(LIMIT);
let handle = DrainHandle::new(crate::sync::spawn(async move { drain(rx).await }));
Payload { sender: tx, handle }
}
}
impl Drop for DrainHandle {
fn drop(&mut self) {
if let Some(handle) = self.0.take().and_then(Arc::into_inner) {
handle.abort();
}
}
}
impl Drop for PayloadStream {
fn drop(&mut self) {
if let Some(payload) = self.inner.take() {
tracing::warn!("Dropped unclosed payload, draining");
if self.sender.try_send(payload).is_err() {
tracing::error!("Failed to send unclosed payload for draining");
}
}
}
}
impl futures_core::Stream for PayloadStream {
type Item = Result<actix_web::web::Bytes, actix_web::error::PayloadError>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
if let Some(inner) = self.inner.as_mut() {
let opt = std::task::ready!(std::pin::Pin::new(inner).poll_next(cx));
if opt.is_none() {
self.inner.take();
}
std::task::Poll::Ready(opt)
} else {
std::task::Poll::Ready(None)
}
}
fn size_hint(&self) -> (usize, Option<usize>) {
if let Some(inner) = self.inner.as_ref() {
inner.size_hint()
} else {
(0, Some(0))
}
}
}
impl<S> Transform<S, ServiceRequest> for Payload
where
S: Service<ServiceRequest>,
S::Future: 'static,
{
type Response = S::Response;
type Error = S::Error;
type InitError = ();
type Transform = PayloadMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(PayloadMiddleware {
inner: service,
sender: self.sender.clone(),
_handle: self.handle.clone(),
}))
}
}
impl<S> Service<ServiceRequest> for PayloadMiddleware<S>
where
S: Service<ServiceRequest>,
S::Future: 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(
&self,
ctx: &mut core::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.inner.poll_ready(ctx)
}
fn call(&self, mut req: ServiceRequest) -> Self::Future {
if matches!(*req.method(), Method::POST | Method::PATCH | Method::PUT) {
let payload = req.take_payload();
let payload: LocalBoxStream<'static, _> = Box::pin(PayloadStream {
inner: Some(payload),
sender: self.sender.clone(),
});
req.set_payload(payload.into());
}
self.inner.call(req)
}
}