pict-rs/src/middleware/payload.rs

181 lines
4.5 KiB
Rust
Raw Normal View History

use std::{
future::{ready, Ready},
sync::Arc,
};
use actix_web::{
dev::{Service, ServiceRequest, Transform},
http::Method,
HttpMessage,
};
use streem::IntoStreamer;
use tokio::task::JoinSet;
use crate::{future::NowOrNever, stream::LocalBoxStream};
// Shared bound for the payload-draining machinery: it is the argument given to
// `crate::sync::channel` when the drain queue is created, and the number of
// in-flight drain tasks above which `drain` waits for completions.
const LIMIT: usize = 256;
/// Background task that reads abandoned request payloads to completion.
///
/// Every payload received on `rx` is consumed on its own local task, with each
/// chunk discarded. After spawning, finished tasks are reaped without
/// blocking; if more than `LIMIT` tasks are still outstanding the loop waits
/// for completions before accepting another payload. Once `recv_async`
/// reports an error the remaining tasks are awaited and the function returns.
async fn drain(rx: flume::Receiver<actix_web::dev::Payload>) {
    let mut set = JoinSet::new();

    while let Ok(payload) = rx.recv_async().await {
        set.spawn_local(async move {
            let mut streamer = payload.into_streamer();
            // Read until the stream ends; the chunks themselves are not needed.
            while streamer.next().await.is_some() {}
        });

        let mut count = 0;

        // Reap tasks that have already completed, without waiting.
        // `join_next` resolves immediately with `None` on an empty set, so only
        // `Some(Some(_))` is a reaped task. Testing `is_some()` on the outer
        // option would count an empty set as a completion and never leave
        // this loop.
        while let Some(Some(_)) = set.join_next().now_or_never() {
            count += 1;
        }

        // if we're past the limit, wait for completions
        while set.len() > LIMIT {
            if set.join_next().await.is_some() {
                count += 1;
            }
        }

        if count > 0 {
            tracing::info!("Drained {count} dropped payloads");
        }
    }

    // The channel is closed: wait for the tasks that are still running.
    while set.join_next().await.is_some() {}
}
// Cloneable, reference-counted handle to the spawned drain task.
// The `Option` exists so the `Drop` impl can move the `Arc` out of `&mut self`;
// `DrainHandle::new` always stores `Some`.
#[derive(Clone)]
struct DrainHandle(Option<Arc<actix_web::rt::task::JoinHandle<()>>>);
// Middleware factory: implements `Transform` and produces a
// `PayloadMiddleware` for each wrapped service.
pub(crate) struct Payload {
// Sending half of the channel whose receiving half is read by `drain`.
sender: flume::Sender<actix_web::dev::Payload>,
// Handle to the drain task spawned in `Payload::new`.
handle: DrainHandle,
}
// Service wrapper that replaces the payload of POST/PATCH/PUT requests with a
// `PayloadStream` before forwarding the request to `inner`.
pub(crate) struct PayloadMiddleware<S> {
// The wrapped service that receives every request.
inner: S,
// Cloned into each `PayloadStream` so it can hand off an unfinished payload.
sender: flume::Sender<actix_web::dev::Payload>,
// Never read; held so this middleware counts as an owner of the drain task,
// which is aborted only when the last `DrainHandle` clone is dropped.
_handle: DrainHandle,
}
// Stream wrapper around a request payload that forwards the payload to the
// drain task if it is dropped before reaching the end of the stream.
pub(crate) struct PayloadStream {
// The wrapped payload; set to `None` once the stream has ended or the
// payload has been handed off in `Drop`.
inner: Option<actix_web::dev::Payload>,
// Channel used to hand an unfinished payload to the drain task.
sender: flume::Sender<actix_web::dev::Payload>,
}
impl DrainHandle {
    /// Wraps the join handle of the drain task so it can be shared by cloning.
    fn new(handle: actix_web::rt::task::JoinHandle<()>) -> Self {
        let shared = Arc::new(handle);
        Self(Some(shared))
    }
}
impl Payload {
    /// Creates the middleware factory.
    ///
    /// Opens the drain channel with `LIMIT` as its argument and spawns the
    /// `drain` task on the receiving half; the sending half and the task
    /// handle are stored in the returned value.
    pub(crate) fn new() -> Self {
        let (sender, receiver) = crate::sync::channel(LIMIT);
        let task = crate::sync::spawn(async move { drain(receiver).await });

        Self {
            sender,
            handle: DrainHandle::new(task),
        }
    }
}
impl Drop for DrainHandle {
    /// Aborts the drain task when the last clone of the handle is dropped.
    fn drop(&mut self) {
        let Some(shared) = self.0.take() else {
            return;
        };

        // `Arc::into_inner` yields the join handle only for the final owner;
        // every other clone simply releases its reference here.
        if let Some(task) = Arc::into_inner(shared) {
            task.abort();
        }
    }
}
impl Drop for PayloadStream {
    /// Hands a payload that was not read to the end over to the drain task.
    fn drop(&mut self) {
        // `inner` is already `None` when the stream reached its end.
        let Some(unfinished) = self.inner.take() else {
            return;
        };

        tracing::warn!("Dropped unclosed payload, draining");

        // Non-blocking send: `drop` must not wait, so a failed send is only
        // reported.
        let handed_off = self.sender.try_send(unfinished).is_ok();
        if !handed_off {
            tracing::error!("Failed to send unclosed payload for draining");
        }
    }
}
impl futures_core::Stream for PayloadStream {
    type Item = Result<actix_web::web::Bytes, actix_web::error::PayloadError>;

    /// Polls the wrapped payload and forwards its items unchanged.
    ///
    /// When the payload reports the end of the stream it is released, so
    /// `Drop` will not send it for draining; later polls return `None`.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let Some(payload) = self.inner.as_mut() else {
            // Already finished.
            return std::task::Poll::Ready(None);
        };

        let item = std::task::ready!(std::pin::Pin::new(payload).poll_next(cx));

        if item.is_none() {
            // End of stream: forget the payload.
            self.inner = None;
        }

        std::task::Poll::Ready(item)
    }

    /// Returns the wrapped payload's size hint, or `(0, Some(0))` once the
    /// payload is gone.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner
            .as_ref()
            .map_or((0, Some(0)), |payload| payload.size_hint())
    }
}
impl<S> Transform<S, ServiceRequest> for Payload
where
    S: Service<ServiceRequest>,
    S::Future: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type InitError = ();
    type Transform = PayloadMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` in a `PayloadMiddleware` that shares this factory's
    /// channel sender and drain handle. Construction cannot fail, so the
    /// returned future is immediately ready.
    fn new_transform(&self, service: S) -> Self::Future {
        let Self { sender, handle } = self;

        let middleware = PayloadMiddleware {
            inner: service,
            sender: sender.clone(),
            _handle: handle.clone(),
        };

        ready(Ok(middleware))
    }
}
impl<S> Service<ServiceRequest> for PayloadMiddleware<S>
where
    S: Service<ServiceRequest>,
    S::Future: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = S::Future;

    /// Readiness is delegated to the wrapped service.
    fn poll_ready(
        &self,
        ctx: &mut core::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(ctx)
    }

    /// Forwards the request to the wrapped service.
    ///
    /// For POST, PATCH and PUT requests the payload is first replaced with a
    /// `PayloadStream`, so that it is sent for draining if it is dropped
    /// before being fully read. Other methods pass through untouched.
    fn call(&self, mut req: ServiceRequest) -> Self::Future {
        let wraps_payload = [Method::POST, Method::PATCH, Method::PUT].contains(req.method());

        if wraps_payload {
            let wrapped: LocalBoxStream<'static, _> = Box::pin(PayloadStream {
                inner: Some(req.take_payload()),
                sender: self.sender.clone(),
            });
            req.set_payload(wrapped.into());
        }

        self.inner.call(req)
    }
}