From 38e72a8c5be8116cfc3f99d0fa96bcd1b0724a4b Mon Sep 17 00:00:00 2001 From: asonix Date: Sat, 15 Jul 2023 20:29:02 -0500 Subject: [PATCH] Multipart client implementation --- Cargo.toml | 11 +- flake.nix | 2 +- src/internal.rs | 239 +++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 275 ++++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 519 insertions(+), 8 deletions(-) create mode 100644 src/internal.rs diff --git a/Cargo.toml b/Cargo.toml index d1660b9..4d3fdc5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,17 @@ [package] -name = "awc-multipart" +name = "multipart-client-stream" version = "0.1.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +bytes = "1" +futures-core = "0.3" +mime = "0.3" +rand = { version = "0.8", features = ["small_rng"] } +tokio = { version = "1", default-features = false, features = [ "io-util" ] } +tokio-util = { version = "0.7", default-features = false, features = ["io"] } + +[dev-dependencies] +tokio = { version = "1", features = ["full"] } diff --git a/flake.nix b/flake.nix index f535a28..dd73d42 100644 --- a/flake.nix +++ b/flake.nix @@ -1,5 +1,5 @@ { - description = "awc-multipart"; + description = "multipart-client-stream"; inputs = { nixpkgs.url = "nixpkgs/nixos-unstable"; diff --git a/src/internal.rs b/src/internal.rs new file mode 100644 index 0000000..4e5e12a --- /dev/null +++ b/src/internal.rs @@ -0,0 +1,239 @@ +use std::{ + collections::VecDeque, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::{BufMut, BytesMut}; +use tokio::io::{AsyncRead, ReadBuf}; + +const CONTENT_TYPE: &[u8] = b"content-type"; +const CONTENT_DISPOSITION: &[u8] = b"content-disposition"; + +pub struct SendRead<'a>(pub(super) Pin>); +pub struct UnsendRead<'a>(pub(super) Pin>); + +impl<'a> From> for UnsendRead<'a> { + fn from(value: SendRead<'a>) -> Self { + UnsendRead(value.0) + } +} + +pub(super) struct Body { + boundary: Vec, + pending: BytesMut, + current: Option, + rest: VecDeque>, + closed: bool, +} + +pub(super) struct Part { + content_type: Vec, + content_disposition: Vec, + reader: R, +} + +fn boundary_len(boundary: &[u8]) -> usize { + boundary.len() + 2 +} +fn write_boundary(boundary: &[u8], buf: &mut B) { + buf.put_slice(b"--"); + buf.put_slice(boundary); +} + +fn final_boundary_len(boundary: &[u8]) -> usize { + boundary_len(boundary) + 2 +} +fn write_final_boundary(boundary: &[u8], buf: &mut B) { + write_boundary(boundary, buf); + buf.put_slice(b"--"); +} + +fn crlf_len() -> usize { + 2 +} +fn write_crlf(buf: &mut B) { + buf.put_slice(b"\r\n"); +} + +fn headers_len(part: &Part) -> usize { + crlf_len() + + CONTENT_TYPE.len() + + 2 + + part.content_type.len() + + crlf_len() + + CONTENT_DISPOSITION.len() + + 2 + + part.content_disposition.len() + + crlf_len() + + crlf_len() +} +fn write_headers(part: &Part, buf: &mut B) { + write_crlf(buf); + buf.put_slice(CONTENT_TYPE); + buf.put_slice(b": "); + buf.put_slice(&part.content_type); + write_crlf(buf); + buf.put_slice(CONTENT_DISPOSITION); + buf.put_slice(b": "); + buf.put_slice(&part.content_disposition); + write_crlf(buf); + write_crlf(buf); +} + +impl Body { + pub(super) fn new(boundary: String, parts: VecDeque>) -> Self { + Self { + boundary: Vec::from(boundary), + pending: BytesMut::new(), + current: None, + rest: parts, + closed: false, + } + } + + fn write_boundary_to_pending(&mut self) { + write_boundary(&self.boundary, &mut self.pending); + } + + fn write_headers_to_pending(&mut self, part: &Part) { + write_headers(part, &mut self.pending); + } + + fn write_final_boundary_to_pending(&mut self) { + write_final_boundary(&self.boundary, &mut self.pending); + } + + fn write_clrf_to_pending(&mut self) { + write_crlf(&mut self.pending) + } +} + +impl<'a> Part> { + pub(super) fn new( + content_type: String, + content_disposition: String, + reader: SendRead<'a>, + ) -> Self { + Part { + content_type: Vec::from(content_type), + content_disposition: Vec::from(content_disposition), + reader, + } + } +} + +impl<'a> Part> { + pub(super) fn new( + content_type: String, + content_disposition: String, + reader: UnsendRead<'a>, + ) -> Self { + Part { + content_type: Vec::from(content_type), + content_disposition: Vec::from(content_disposition), + reader, + } + } +} + +impl<'a> AsyncRead for SendRead<'a> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_read(cx, buf) + } +} + +impl<'a> AsyncRead for UnsendRead<'a> { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.get_mut().0).poll_read(cx, buf) + } +} + +impl AsyncRead for Body +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + if buf.remaining() == 0 { + return Poll::Ready(Ok(())); + } + + let initial_len = buf.filled().len(); + + if !self.pending.is_empty() { + if self.pending.len() > buf.remaining() { + let bytes = self.pending.split_to(buf.remaining()); + buf.put_slice(&bytes); + return Poll::Ready(Ok(())); + } else { + buf.put_slice(&self.pending); + self.pending.clear(); + } + } + + if self.closed { + return Poll::Ready(Ok(())); + } + + let before_poll = buf.filled().len(); + if let Some(ref mut reader) = self.current { + match Pin::new(reader).poll_read(cx, buf) { + Poll::Ready(Ok(())) if buf.filled().len() == before_poll => { + self.current.take(); + if crlf_len() < buf.remaining() { + write_crlf(buf); + } else { + self.write_clrf_to_pending(); + } + self.poll_read(cx, buf) + } + Poll::Ready(Ok(())) => Poll::Ready(Ok(())), + Poll::Ready(otherwise) => Poll::Ready(otherwise), + Poll::Pending if buf.filled().len() > initial_len => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, + } + } else if let Some(part) = self.rest.pop_front() { + let fill_buf = if boundary_len(&self.boundary) < buf.remaining() { + write_boundary(&self.boundary, buf); + true + } else { + self.write_boundary_to_pending(); + false + }; + + if fill_buf && headers_len(&part) < buf.remaining() { + write_headers(&part, buf); + } else { + self.write_headers_to_pending(&part); + }; + + self.current = Some(part.reader); + + if buf.remaining() > 0 { + self.poll_read(cx, buf) + } else { + Poll::Ready(Ok(())) + } + } else if buf.remaining() > final_boundary_len(&self.boundary) { + write_final_boundary(&self.boundary, buf); + self.closed = true; + Poll::Ready(Ok(())) + } else { + self.write_final_boundary_to_pending(); + self.closed = true; + self.poll_read(cx, buf) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 7d12d9a..278a27f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,14 +1,277 @@ -pub fn add(left: usize, right: usize) -> usize { - left + right +mod internal; + +use std::{collections::VecDeque, io::Cursor, pin::Pin}; + +use bytes::Bytes; +use futures_core::Stream; +use internal::{SendRead, UnsendRead}; +use mime::Mime; +use rand::{distributions::Alphanumeric, rngs::SmallRng, Rng, SeedableRng}; +use tokio::io::AsyncRead; +use tokio_util::io::ReaderStream; + +pub struct Body { + stream: ReaderStream>, +} + +pub struct BodyBuilder { + boundary: String, + parts: Vec>, +} + +pub struct Part { + name: String, + content_type: Option, + filename: Option, + reader: R, +} + +#[derive(Debug)] +pub struct Empty; + +impl<'a> Body> { + pub fn builder() -> BodyBuilder> { + let boundary = SmallRng::from_entropy() + .sample_iter(&Alphanumeric) + .map(char::from) + .take(6) + .collect::(); + + BodyBuilder { + boundary, + parts: Vec::new(), + } + } +} + +impl<'a> BodyBuilder> { + pub fn boundary(mut self, boundary: String) -> Self { + self.boundary = boundary; + self + } + + pub fn append(mut self, part: Part>) -> Self { + self.parts.push(part); + self + } + + pub fn append_unsend(self, part: Part>) -> BodyBuilder> { + let mut parts: Vec<_> = self + .parts + .into_iter() + .map(Part::>::from) + .collect(); + + parts.push(part); + + BodyBuilder { + boundary: self.boundary, + parts, + } + } + + pub fn build(self) -> Body> { + let parts = self + .parts + .into_iter() + .map(Part::>::build) + .collect(); + + Body { + stream: ReaderStream::new(internal::Body::new(self.boundary, parts)), + } + } +} + +impl<'a> BodyBuilder> { + pub fn boundary(mut self, boundary: String) -> Self { + self.boundary = boundary; + self + } + + pub fn append(mut self, part: Part>) -> Self { + self.parts.push(From::from(part)); + self + } + + pub fn append_unsend(mut self, part: Part>) -> BodyBuilder> { + self.parts.push(part); + self + } + + pub fn build(self) -> Body> { + let parts: VecDeque>> = self + .parts + .into_iter() + .map(Part::>::build) + .collect(); + + Body { + stream: ReaderStream::new(internal::Body::new(self.boundary, parts)), + } + } +} + +fn encode(value: String) -> String { + value.replace('"', "\\\"") +} + +impl<'a> Part> { + pub fn new(name: String, reader: R) -> Self { + Part { + name, + content_type: None, + filename: None, + reader: SendRead(Box::pin(reader)), + } + } + + pub fn new_unsend(name: String, reader: R) -> Part> { + Part { + name, + content_type: None, + filename: None, + reader: UnsendRead(Box::pin(reader)), + } + } + + pub fn new_str(name: String, text: &'a str) -> Self { + Self::new(name, text.as_bytes()).content_type(mime::TEXT_PLAIN) + } + + pub fn new_string(name: String, text: String) -> Self { + Self::new(name, Cursor::new(text)).content_type(mime::TEXT_PLAIN) + } + + pub fn content_type(mut self, content_type: mime::Mime) -> Self { + self.content_type = Some(content_type); + self + } + + pub fn filename(mut self, filename: String) -> Self { + self.filename = Some(filename); + self + } + + fn build(self) -> internal::Part> { + let content_type = self.content_type.unwrap_or(mime::APPLICATION_OCTET_STREAM); + + let name = encode(self.name); + let filename = self.filename.map(encode); + + let content_disposition = if let Some(filename) = filename { + format!("form-data; name=\"{name}\"; filename=\"{filename}\"") + } else { + format!("form-data; name=\"{name}\"") + }; + + internal::Part::>::new( + content_type.to_string(), + content_disposition, + self.reader, + ) + } +} +impl<'a> Part> { + pub fn content_type(mut self, content_type: mime::Mime) -> Self { + self.content_type = Some(content_type); + self + } + + pub fn filename(mut self, filename: String) -> Self { + self.filename = Some(filename); + self + } + + fn build(self) -> internal::Part> { + let content_type = self.content_type.unwrap_or(mime::APPLICATION_OCTET_STREAM); + + let name = encode(self.name); + let filename = self.filename.map(encode); + + let content_disposition = if let Some(filename) = filename { + format!("form-data; name=\"{name}\"; filename=\"{filename}\"") + } else { + format!("form-data; name=\"{name}\"") + }; + + internal::Part::>::new( + content_type.to_string(), + content_disposition, + self.reader, + ) + } +} + +impl Stream for Body +where + R: AsyncRead + Unpin, +{ + type Item = std::io::Result; + + fn poll_next( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.get_mut().stream).poll_next(cx) + } +} + +impl<'a> From>> for Part> { + fn from(value: Part>) -> Self { + Self { + name: value.name, + content_type: value.content_type, + filename: value.filename, + reader: UnsendRead::from(value.reader), + } + } } #[cfg(test)] mod tests { - use super::*; + use std::{future::poll_fn, pin::Pin}; + + struct Streamer(S); + + impl Streamer { + async fn next(&mut self) -> Option + where + S: futures_core::Stream + Unpin, + { + poll_fn(|cx| Pin::new(&mut self.0).poll_next(cx)).await + } + } + + #[tokio::test] + async fn build_text() { + let body = super::Body::builder() + .boundary(String::from("hello")) + .append(super::Part::new_str(String::from("first_name"), "John")) + .append(super::Part::new_str(String::from("last_name"), "Doe")) + .build(); + + let mut out = Vec::new(); + + let mut streamer = Streamer(body); + + while let Some(res) = streamer.next().await { + out.extend(res.expect("read success")); + } + + let out = String::from_utf8(out).expect("Valid string"); + + assert_eq!(out, "--hello\r\ncontent-type: text/plain\r\ncontent-disposition: form-data; name=\"first_name\"\r\n\r\nJohn\r\n--hello\r\ncontent-type: text/plain\r\ncontent-disposition: form-data; name=\"last_name\"\r\n\r\nDoe\r\n--hello--") + } #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); + fn encode() { + let cases = [("hello", "hello"), ("Hello \"John\"", "Hello \\\"John\\\"")]; + + for (input, expected) in cases { + let output = super::encode(String::from(input)); + + assert_eq!(output, expected); + } } }