multipart-client-stream/src/internal.rs

258 lines
6.9 KiB
Rust

use std::{
collections::VecDeque,
pin::Pin,
task::{Context, Poll},
};
use bytes::{BufMut, BytesMut};
use tokio::io::{AsyncRead, ReadBuf};
/// Header name emitted in front of each part's content type.
const CONTENT_TYPE: &[u8] = "content-type".as_bytes();
/// Header name emitted in front of each part's content disposition.
const CONTENT_DISPOSITION: &[u8] = "content-disposition".as_bytes();
/// A boxed, pinned reader that is additionally `Send`.
pub struct SendRead<'a>(pub(super) Pin<Box<dyn AsyncRead + Send + 'a>>);
/// A boxed, pinned reader with no `Send` requirement.
pub struct UnsendRead<'a>(pub(super) Pin<Box<dyn AsyncRead + 'a>>);
impl<'a> From<SendRead<'a>> for UnsendRead<'a> {
    /// Forgets the `Send` bound; the same boxed reader is reused, nothing is copied.
    fn from(SendRead(reader): SendRead<'a>) -> Self {
        Self(reader)
    }
}
/// Multipart body that yields, through `AsyncRead`, the framing bytes (boundaries,
/// part headers, CRLFs) interleaved with the contents of each part's reader.
pub(super) struct Body<R> {
    /// Boundary token, stored without the leading `--` (that is added when writing).
    boundary: Vec<u8>,
    /// Framing bytes already produced but not yet copied into a caller's read buffer.
    pending: BytesMut,
    /// Reader of the part whose contents are currently being streamed, if any.
    current: Option<R>,
    /// Parts not started yet, in the order they will be emitted.
    rest: VecDeque<Part<R>>,
    /// Set once the closing `--boundary--` delimiter has been produced.
    closed: bool,
}
/// One part of a multipart body: its header values plus the reader supplying its bytes.
pub(super) struct Part<R> {
    /// Value written after `content-type: ` in this part's headers.
    content_type: Vec<u8>,
    /// Value written after `content-disposition: ` in this part's headers.
    content_disposition: Vec<u8>,
    /// Source of this part's body bytes.
    reader: R,
}
/// Number of bytes `write_boundary` emits: the `--` prefix plus the boundary token.
fn boundary_len(boundary: &[u8]) -> usize {
    let dashes = b"--".len();
    dashes + boundary.len()
}
/// Appends the delimiter `--<boundary>` (without any line ending) to `buf`.
fn write_boundary<B: BufMut>(boundary: &[u8], buf: &mut B) {
    for piece in [&b"--"[..], boundary] {
        buf.put_slice(piece);
    }
}
/// Number of bytes `write_final_boundary` emits: `--<boundary>--`.
fn final_boundary_len(boundary: &[u8]) -> usize {
    let closing_dashes = b"--".len();
    boundary_len(boundary) + closing_dashes
}
/// Appends the closing delimiter `--<boundary>--` (without any line ending) to `buf`.
fn write_final_boundary<B: BufMut>(boundary: &[u8], buf: &mut B) {
    // The closing delimiter is the ordinary delimiter followed by two more dashes.
    const CLOSING_DASHES: &[u8] = b"--";
    write_boundary(boundary, buf);
    buf.put_slice(CLOSING_DASHES);
}
/// Number of bytes in a CRLF line ending.
fn crlf_len() -> usize {
    b"\r\n".len()
}
/// Appends a CRLF line ending to `buf`.
fn write_crlf<B: BufMut>(buf: &mut B) {
    const CRLF: [u8; 2] = [b'\r', b'\n'];
    buf.put_slice(&CRLF);
}
/// Number of bytes `write_headers` emits for `part`.
fn headers_len<R>(part: &Part<R>) -> usize {
    const NAME_VALUE_SEPARATOR: &[u8] = b": ";
    let content_type_line =
        CONTENT_TYPE.len() + NAME_VALUE_SEPARATOR.len() + part.content_type.len() + crlf_len();
    let disposition_line = CONTENT_DISPOSITION.len()
        + NAME_VALUE_SEPARATOR.len()
        + part.content_disposition.len()
        + crlf_len();
    // One CRLF before the header lines and one blank-line CRLF after them.
    crlf_len() + content_type_line + disposition_line + crlf_len()
}
/// Appends `part`'s header block to `buf`.
///
/// The block is written in this order:
/// - a leading CRLF;
/// - the `content-type` line;
/// - the `content-disposition` line, each header line terminated by CRLF;
/// - a final CRLF that leaves a blank line before the part's body.
fn write_headers<B: BufMut, R>(part: &Part<R>, buf: &mut B) {
    write_crlf(buf);
    let fields: [(&[u8], &[u8]); 2] = [
        (CONTENT_TYPE, part.content_type.as_slice()),
        (CONTENT_DISPOSITION, part.content_disposition.as_slice()),
    ];
    for (name, value) in fields {
        buf.put_slice(name);
        buf.put_slice(b": ");
        buf.put_slice(value);
        write_crlf(buf);
    }
    write_crlf(buf);
}
impl<R> Body<R> {
    /// Creates a body that will emit `parts` in order, delimited by `boundary`.
    pub(super) fn new(boundary: String, parts: VecDeque<Part<R>>) -> Self {
        let boundary = boundary.into_bytes();
        Self {
            boundary,
            pending: BytesMut::new(),
            current: None,
            rest: parts,
            closed: false,
        }
    }
    /// Stages the `--<boundary>` delimiter in `pending`.
    fn write_boundary_to_pending(&mut self) {
        let Self {
            boundary, pending, ..
        } = self;
        write_boundary(boundary, pending);
    }
    /// Stages `part`'s header block in `pending`.
    fn write_headers_to_pending(&mut self, part: &Part<R>) {
        let pending = &mut self.pending;
        write_headers(part, pending);
    }
    /// Stages the closing `--<boundary>--` delimiter in `pending`.
    fn write_final_boundary_to_pending(&mut self) {
        let Self {
            boundary, pending, ..
        } = self;
        write_final_boundary(boundary, pending);
    }
    /// Stages a CRLF in `pending` (the `clrf` spelling is kept because callers use it).
    fn write_clrf_to_pending(&mut self) {
        write_crlf(&mut self.pending);
    }
}
impl<R> Part<R> {
    /// Creates a part from its header values and the reader that supplies its body.
    ///
    /// `content_type` and `content_disposition` are the raw values written after the
    /// `content-type: ` and `content-disposition: ` header names.
    ///
    /// The constructor is generic over the reader so a single definition serves both
    /// `SendRead` and `UnsendRead` parts. `Part::<SendRead>::new(..)` and
    /// `Part::<UnsendRead>::new(..)` still resolve exactly as before.
    pub(super) fn new(content_type: String, content_disposition: String, reader: R) -> Self {
        Part {
            content_type: content_type.into_bytes(),
            content_disposition: content_disposition.into_bytes(),
            reader,
        }
    }
}
impl<'a> AsyncRead for SendRead<'a> {
    /// Forwards the read to the wrapped boxed reader.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // The wrapper only holds a `Pin<Box<_>>`, so it is `Unpin` and may be unwrapped;
        // `as_mut` re-borrows the boxed reader as a pinned `&mut` for the delegated call.
        let inner = &mut self.get_mut().0;
        inner.as_mut().poll_read(cx, buf)
    }
}
impl<'a> AsyncRead for UnsendRead<'a> {
    /// Forwards the read to the wrapped boxed reader.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // The wrapper only holds a `Pin<Box<_>>`, so it is `Unpin` and may be unwrapped;
        // `as_mut` re-borrows the boxed reader as a pinned `&mut` for the delegated call.
        let inner = &mut self.get_mut().0;
        inner.as_mut().poll_read(cx, buf)
    }
}
/// Converts "did this call write any bytes?" into a poll result.
///
/// Returns `Ready(Ok(()))` when bytes were written. Otherwise it returns `Pending`, and the
/// caller is responsible for having arranged a wake-up before returning it.
fn poll_return(any_written: bool) -> Poll<std::io::Result<()>> {
    if !any_written {
        return Poll::Pending;
    }
    Poll::Ready(Ok(()))
}
impl<R> AsyncRead for Body<R>
where
    R: AsyncRead + Unpin,
{
    /// Streams the multipart body.
    ///
    /// For each part it emits `--<boundary>`, the part's header block, the part reader's
    /// contents and a CRLF. After the last part it emits `--<boundary>--\r\n`.
    ///
    /// Framing bytes are always staged in `pending` first and then copied into `buf` as far
    /// as it has room. So whenever `buf` has space and the body is not finished, at least one
    /// byte is produced, and no `Pending` + self-wake round trip is needed for small buffers.
    ///
    /// `Poll::Pending` is only returned when the current part's reader is pending and nothing
    /// was written in this call; that reader has registered the waker.
    /// `Poll::Ready(Ok(()))` with no bytes written only happens at end of stream, or when
    /// `buf` had no room to begin with.
    ///
    /// # Errors
    /// Errors from a part's reader are returned unchanged.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // `Body<R>` is `Unpin` because `R: Unpin`, so we can work through a plain `&mut`.
        let this = self.get_mut();
        let initial_len = buf.filled().len();
        loop {
            // Either the caller gave no room, or we filled it during this call.
            if buf.remaining() == 0 {
                return Poll::Ready(Ok(()));
            }

            // Flush staged framing bytes first so the output order is preserved.
            if !this.pending.is_empty() {
                let n = this.pending.len().min(buf.remaining());
                let chunk = this.pending.split_to(n);
                buf.put_slice(&chunk);
                continue;
            }

            // Closing delimiter already produced and fully flushed: end of stream.
            if this.closed {
                return Poll::Ready(Ok(()));
            }

            if let Some(reader) = this.current.as_mut() {
                let before_poll = buf.filled().len();
                match Pin::new(reader).poll_read(cx, buf) {
                    // No bytes although `buf` had room: this part's reader is at EOF.
                    // Terminate the part with a CRLF and move on to the next one.
                    Poll::Ready(Ok(())) if buf.filled().len() == before_poll => {
                        this.current = None;
                        this.write_clrf_to_pending();
                    }
                    Poll::Ready(Ok(())) => return Poll::Ready(Ok(())),
                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                    // The reader registered the waker; hand back whatever was already written.
                    Poll::Pending => return poll_return(buf.filled().len() > initial_len),
                }
            } else if let Some(part) = this.rest.pop_front() {
                // Start the next part: delimiter and headers are staged, then its reader.
                this.pending
                    .reserve(boundary_len(&this.boundary) + headers_len(&part));
                this.write_boundary_to_pending();
                this.write_headers_to_pending(&part);
                this.current = Some(part.reader);
            } else {
                // No parts left: stage the closing delimiter and its line ending.
                this.pending
                    .reserve(final_boundary_len(&this.boundary) + crlf_len());
                this.write_final_boundary_to_pending();
                this.write_clrf_to_pending();
                this.closed = true;
            }
        }
    }
}