2024-02-04 15:45:47 -06:00

226 lines
5.8 KiB

use std::{
future::{ready, Ready},
time::{Duration, Instant},
use actix_web::{
dev::{Service, ServiceRequest, Transform},
use streem::IntoStreamer;
use tokio::task::JoinSet;
use crate::{future::NowOrNever, stream::LocalBoxStream};
/// Shared bound for the payload drainer.
///
/// `Payload::new` passes it to `crate::sync::channel` (presumably as the channel capacity).
/// `drain` uses it as the number of in-flight drain tasks it tolerates before it waits for
/// some of them to finish.
const LIMIT: usize = 1 << 8;
/// Drop guard that times a single payload drain.
///
/// Create it with [`MetricsGuard::guard`] when a drain starts, and call
/// [`MetricsGuard::disarm`] once the drain has finished. The guard's `Drop` impl reads
/// `armed` to tell a finished drain from an abandoned one.
struct MetricsGuard {
    /// Instant at which the guard was created.
    start: Instant,
    /// `true` until [`MetricsGuard::disarm`] is called.
    armed: bool,
}

impl MetricsGuard {
    /// Starts timing a drain and returns an armed guard.
    // Dropping the returned guard right away would report the drain as unfinished, so let
    // the compiler warn when a caller discards it.
    #[must_use]
    fn guard() -> Self {
        MetricsGuard {
            start: Instant::now(),
            armed: true,
        }
    }

    /// Marks the drain as completed and consumes the guard.
    ///
    /// Because `self` is taken by value, the guard is dropped at the end of this call, with
    /// `armed` already cleared.
    fn disarm(mut self) {
        self.armed = false;
    }
}
/// Reports the end of a payload drain.
///
/// Both metrics carry a `completed` label equal to `!self.armed`. It is `"true"` when the
/// guard was disarmed first (via `disarm`), and `"false"` when it was dropped while still armed.
impl Drop for MetricsGuard {
fn drop(&mut self) {
// NOTE(review): in this view, neither macro invocation is followed by a call on the returned
// handle or terminated with `;`, and `self.start` is never read. With the handle-returning
// `metrics` macro API, nothing is recorded until `.increment(..)` / `.record(..)` is called
// (presumably the duration histogram records `self.start.elapsed()`). Confirm these lines
// are present in the real file.
metrics::counter!(crate::init_metrics::PAYLOAD_DRAIN_END, "completed" => (!self.armed).to_string())
metrics::histogram!(crate::init_metrics::PAYLOAD_DRAIN_DURATION, "completed" => (!self.armed).to_string())
/// Background task that drains payloads received on `rx`, discarding their bytes.
///
/// Every payload is read to the end in its own task, spawned with `JoinSet::spawn_local` and
/// wrapped in a 120-second `tokio::time::timeout`, so a slow payload is abandoned rather than
/// drained forever. After each receive, tasks that have already finished are reaped. While more
/// than `LIMIT` tasks are in flight, the loop awaits completions before going back to `rx`.
/// All remaining tasks are awaited during final cleanup.
// NOTE(review): the receive loop ends when `recv_async` returns `Err`. For flume that
// presumably means every `Sender` has been dropped; confirm.
async fn drain(rx: flume::Receiver<actix_web::dev::Payload>) {
let mut set = JoinSet::new();
while let Ok(payload) = rx.recv_async().await {
tracing::trace!("drain: looping");
// Draining a payload is best-effort: if it cannot be read to the end within 2 minutes, the
// timeout fires and the payload is abandoned.
// NOTE(review): `guard` is moved into the spawned future only if that future uses it. No
// `guard.disarm()` (nor the closing `}));` of this `spawn_local` call) is visible in this
// view. If nothing inside the task calls `disarm`, the guard drops at the end of each outer
// iteration and every drain is reported as not completed. Confirm.
let guard = MetricsGuard::guard();
set.spawn_local(tokio::time::timeout(Duration::from_secs(120), async move {
let mut streamer = payload.into_streamer();
while streamer.next().await.is_some() {
tracing::trace!("drain drop bytes: looping");
// Number of drain tasks reaped during this iteration; only used for the debug log below.
let mut count = 0;
// Reap tasks that have already finished, without waiting (`now_or_never` presumably polls
// once and yields `Some` only when a join result is immediately ready). Only `is_some()` is
// checked, so tasks that timed out or failed are counted too.
while set.join_next().now_or_never().is_some() {
tracing::trace!("drain join now: looping");
count += 1;
// Back-pressure: while more than `LIMIT` tasks are in flight, wait for some to finish
// before receiving another payload.
while set.len() > LIMIT {
tracing::trace!("drain join await: looping");
if set.join_next().await.is_some() {
count += 1;
if count > 0 {
tracing::debug!("Drained {count} dropped payloads");
// Final cleanup: await every remaining drain task.
while set.join_next().await.is_some() {
tracing::trace!("drain join await cleanup: looping");
/// Shared handle to the background drain task.
///
/// The join handle sits behind an `Rc` so that the factory (`Payload`) and every
/// `PayloadMiddleware` it creates can hold the same handle. `Clone` is required because
/// `new_transform` calls `self.handle.clone()`.
#[derive(Clone)]
struct DrainHandle(Option<Rc<tokio::task::JoinHandle<()>>>);
pub(crate) struct Payload {
sender: flume::Sender<actix_web::dev::Payload>,
handle: DrainHandle,
pub(crate) struct PayloadMiddleware<S> {
inner: S,
sender: flume::Sender<actix_web::dev::Payload>,
_handle: DrainHandle,
pub(crate) struct PayloadStream {
inner: Option<actix_web::dev::Payload>,
sender: flume::Sender<actix_web::dev::Payload>,
impl DrainHandle {
/// Wraps the join handle of the background drain task.
// NOTE(review): the body of `new` is not visible in this view. Given the field type
// `Option<Rc<tokio::task::JoinHandle<()>>>`, it presumably returns
// `Self(Some(Rc::new(handle)))`. Confirm.
fn new(handle: tokio::task::JoinHandle<()>) -> Self {
impl Payload {
pub(crate) fn new() -> Self {
let (tx, rx) = crate::sync::channel(LIMIT);
let handle = DrainHandle::new(crate::sync::spawn("drain-payloads", drain(rx)));
Payload { sender: tx, handle }
/// Releases one reference to the drain task.
///
/// `self.0.take()` moves the `Rc` out, and `Rc::into_inner` yields the `JoinHandle` only when
/// this was the last strong reference. As a result, the `if let` body runs only for the last
/// `DrainHandle` sharing that task.
impl Drop for DrainHandle {
fn drop(&mut self) {
// NOTE(review): the body of this `if let` is not visible in this view. Presumably it calls
// `handle.abort()`; merely dropping a tokio `JoinHandle` detaches the task instead of
// stopping it. Confirm.
if let Some(handle) = self.0.take().and_then(Rc::into_inner) {
impl Drop for PayloadStream {
fn drop(&mut self) {
if let Some(payload) = self.inner.take() {
tracing::debug!("Dropped unclosed payload, draining");
if self.sender.try_send(payload).is_err() {
tracing::error!("Failed to send unclosed payload for draining");
/// Exposes the wrapped actix payload as a `futures_core::Stream` of body chunks.
impl futures_core::Stream for PayloadStream {
type Item = Result<actix_web::web::Bytes, actix_web::error::PayloadError>;
/// Polls the wrapped payload for its next chunk.
// NOTE(review): the body of `if opt.is_none()` and the values returned from both outer
// branches are not visible in this view. Presumably:
// - the exhausted payload is taken out of `self.inner`, so `Drop` does not forward a finished
//   payload to the drainer;
// - `opt` is returned as `Poll::Ready(opt)`;
// - `Poll::Ready(None)` is returned once `inner` is gone.
// Confirm.
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
if let Some(inner) = self.inner.as_mut() {
// `ready!` returns `Poll::Pending` from `poll_next` early while no chunk is available yet.
let opt = std::task::ready!(std::pin::Pin::new(inner).poll_next(cx));
if opt.is_none() {
} else {
/// Returns `(0, Some(0))` once no payload is held.
// NOTE(review): the value returned while a payload is held is not visible in this view;
// presumably it forwards `inner.size_hint()`. Confirm.
fn size_hint(&self) -> (usize, Option<usize>) {
if let Some(inner) = self.inner.as_ref() {
} else {
(0, Some(0))
impl<S> Transform<S, ServiceRequest> for Payload
S: Service<ServiceRequest>,
S::Future: 'static,
type Response = S::Response;
type Error = S::Error;
type InitError = ();
type Transform = PayloadMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ready(Ok(PayloadMiddleware {
inner: service,
sender: self.sender.clone(),
_handle: self.handle.clone(),
impl<S> Service<ServiceRequest> for PayloadMiddleware<S>
S: Service<ServiceRequest>,
S::Future: 'static,
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(
ctx: &mut core::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
fn call(&self, mut req: ServiceRequest) -> Self::Future {
let payload = req.take_payload();
if !matches!(payload, actix_web::dev::Payload::None) {
let payload: LocalBoxStream<'static, _> = Box::pin(PayloadStream {
inner: Some(payload),
sender: self.sender.clone(),