Implement Async traits for Async<T>

This commit is contained in:
Aode (lion) 2022-03-03 20:36:43 -06:00
parent aede3ae0c9
commit 97088874f7
5 changed files with 221 additions and 58 deletions

9
Cargo.lock generated
View file

@ -39,11 +39,18 @@ dependencies = [
name = "foxtrot"
version = "0.1.0"
dependencies = [
"futures-io",
"join-all",
"polldance",
"read-write-buf",
]
[[package]]
name = "futures-io"
version = "0.3.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc4045962a5a5e935ee2fdedaa4e08284547402885ab326734432bed5d12966b"
[[package]]
name = "io-lifetimes"
version = "0.5.3"
@ -70,7 +77,7 @@ checksum = "5284f00d480e1c39af34e72f8ad60b94f47007e3481cd3b731c1d67190ddc7b7"
[[package]]
name = "polldance"
version = "0.1.0"
source = "git+https://git.asonix.dog/safe-async/polldance#43222ec6236fd2ae01ebd1b01f81a386d153490c"
source = "git+https://git.asonix.dog/safe-async/polldance#44c92daf0e1877d351c2b16e8dff07f817c221bd"
dependencies = [
"rustix",
]

View file

@ -7,6 +7,7 @@ edition = "2021"
[dependencies]
polldance = { git = "https://git.asonix.dog/safe-async/polldance" }
futures-io = "0.3.21"
[dev-dependencies]
read-write-buf = { git = "https://git.asonix.dog/safe-async/read-write-buf" }

View file

@ -6,11 +6,11 @@ use join_all::join_all;
use read_write_buf::ReadWriteBuf;
async fn echo(port: u16) -> Result<(), foxtrot::Error> {
let listener = Async::bind(([127, 0, 0, 1], port)).await?;
let mut listener = Async::bind(([127, 0, 0, 1], port)).await?;
println!("bound listener");
loop {
let (stream, _addr) = listener.accept().await?;
let (mut stream, _addr) = listener.accept().await?;
println!("Accepted connection");
let mut interests = Readiness::read();

View file

@ -7,21 +7,21 @@ async fn echo_to(port: u16) -> Result<(), Box<dyn std::error::Error>> {
let stream = TcpStream::connect(sockaddr)?;
stream.set_nonblocking(true)?;
let stream = Async::new(stream);
let mut stream = Async::new(stream);
println!("Connected");
loop {
let mut buf = [0; 1024];
if let Err(e) = stream.read(&mut buf).await {
if e == foxtrot::io::Error::PIPE {
if e.kind() == std::io::ErrorKind::BrokenPipe {
break;
}
return Err(e.into());
}
if let Err(e) = stream.write_all(&buf).await {
if e == foxtrot::io::Error::PIPE {
if e.kind() == std::io::ErrorKind::BrokenPipe {
break;
}

259
src/io.rs
View file

@ -24,8 +24,24 @@ macro_rules! poll_nonblocking {
}};
}
macro_rules! poll_std {
($expr:expr) => {{
match $expr {
Ok(t) => return Poll::Ready(Ok(t)),
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {}
Err(e) => return Poll::Ready(Err(e)),
}
}};
}
enum MaybeArc<T> {
Arc(Arc<T>),
Owned(T),
Intermediate,
}
pub struct Async<T: AsFd + 'static> {
io: Arc<T>,
io: MaybeArc<T>,
}
struct Ready<'a, T: AsFd + 'static> {
@ -34,12 +50,12 @@ struct Ready<'a, T: AsFd + 'static> {
}
struct Read<'a, T: AsFd + 'static> {
io: &'a Arc<T>,
io: &'a mut Async<T>,
bytes: &'a mut [u8],
}
struct Write<'a, T: AsFd + 'static> {
io: &'a Arc<T>,
io: &'a mut Async<T>,
bytes: &'a [u8],
}
@ -57,7 +73,7 @@ struct Connect {
impl<T: AsFd + 'static> Drop for Async<T> {
fn drop(&mut self) {
let _ = ReactorRef::with(|mut reactor| reactor.deregister(&self.io));
let _ = ReactorRef::with(|mut reactor| reactor.deregister(self.io.ensure_arc()));
}
}
@ -70,8 +86,11 @@ impl Async<TcpListener> {
.map(Async::new)
}
pub async fn accept(&self) -> Result<(Async<TcpStream>, Option<SocketAddrAny>)> {
let (stream, addr) = Accept { io: &self.io }.await?;
pub async fn accept(&mut self) -> Result<(Async<TcpStream>, Option<SocketAddrAny>)> {
let (stream, addr) = Accept {
io: self.io.ensure_arc(),
}
.await?;
Ok((Async::new(stream), addr))
}
@ -89,30 +108,97 @@ impl Async<TcpStream> {
impl<T: AsFd + 'static> Async<T> {
pub fn new(io: T) -> Self {
Async { io: Arc::new(io) }
Async {
io: MaybeArc::Owned(io),
}
}
pub async fn ready(&self, interests: Readiness) -> Result<Readiness> {
pub async fn ready(&mut self, interests: Readiness) -> Result<Readiness> {
Ready {
io: &self.io,
io: self.io.ensure_arc(),
interests,
}
.await
}
pub async fn read(&self, bytes: &mut [u8]) -> Result<ReadBytes> {
Read {
io: &self.io,
bytes,
}
.await
pub async fn read(&mut self, bytes: &mut [u8]) -> Result<ReadBytes>
where
T: std::io::Read + Unpin,
{
Read { io: self, bytes }.await
}
pub fn read_nonblocking(&self, buf: &mut [u8]) -> Result<Nonblocking<ReadBytes>> {
fn futures_poll_read(&mut self, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<Result<usize>>
where
T: std::io::Read,
{
if let Some(io) = self.io.ensure_owned() {
poll_std!(std::io::Read::read(io, buf))
}
ReactorRef::with(|mut reactor| {
reactor.register(
Arc::clone(self.io.ensure_arc()),
cx.waker().clone(),
Readiness::read() | Readiness::hangup(),
);
})
.unwrap();
Poll::Pending
}
fn futures_poll_write(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>>
where
T: std::io::Write,
{
if let Some(io) = self.io.ensure_owned() {
poll_std!(std::io::Write::write(io, buf))
}
ReactorRef::with(|mut reactor| {
reactor.register(
Arc::clone(self.io.ensure_arc()),
cx.waker().clone(),
Readiness::write() | Readiness::hangup(),
);
})
.unwrap();
Poll::Pending
}
fn futures_poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>>
where
T: std::io::Write,
{
if let Some(io) = self.io.ensure_owned() {
poll_std!(std::io::Write::flush(io))
}
ReactorRef::with(|mut reactor| {
reactor.register(
Arc::clone(self.io.ensure_arc()),
cx.waker().clone(),
Readiness::write() | Readiness::hangup(),
);
})
.unwrap();
Poll::Pending
}
pub fn read_nonblocking(&self, buf: &mut [u8]) -> Result<Nonblocking<ReadBytes>>
where
T: std::io::Read,
{
polldance::io::try_read(self.io.as_ref(), buf)
}
pub async fn read_exact(&self, bytes: &mut [u8]) -> Result<usize> {
pub async fn read_exact(&mut self, bytes: &mut [u8]) -> Result<usize>
where
T: std::io::Read + Unpin,
{
let mut start = 0;
while start < bytes.len() {
@ -125,19 +211,24 @@ impl<T: AsFd + 'static> Async<T> {
Ok(start)
}
pub async fn write(&self, bytes: &[u8]) -> Result<usize> {
Write {
io: &self.io,
bytes,
}
.await
pub async fn write(&mut self, bytes: &[u8]) -> Result<usize>
where
T: std::io::Write + Unpin,
{
Write { io: self, bytes }.await
}
pub fn write_nonblocking(&self, buf: &[u8]) -> Result<Nonblocking<usize>> {
pub fn write_nonblocking(&self, buf: &[u8]) -> Result<Nonblocking<usize>>
where
T: std::io::Write,
{
polldance::io::try_write(self.io.as_ref(), buf)
}
pub async fn write_all(&self, bytes: &[u8]) -> Result<()> {
pub async fn write_all(&mut self, bytes: &[u8]) -> Result<()>
where
T: std::io::Write + Unpin,
{
let mut start = 0;
while start < bytes.len() {
@ -148,6 +239,44 @@ impl<T: AsFd + 'static> Async<T> {
}
}
impl<T> futures_io::AsyncRead for Async<T>
where
T: std::io::Read + AsFd + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize>> {
self.get_mut()
.futures_poll_read(cx, buf)
.map(|res| res.map_err(From::from))
}
}
impl<'a, T> futures_io::AsyncWrite for Async<T>
where
T: std::io::Write + AsFd + Unpin,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
self.get_mut()
.futures_poll_write(cx, buf)
.map(|res| res.map_err(From::from))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.get_mut()
.futures_poll_flush(cx)
.map(|res| res.map_err(From::from))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
self.get_mut()
.futures_poll_flush(cx)
.map(|res| res.map_err(From::from))
}
}
impl Future for Bind {
type Output = Result<TcpListener>;
@ -225,47 +354,32 @@ impl Future for Connect {
impl<'a, T> Future for Read<'a, T>
where
T: AsFd + 'static,
T: std::io::Read + AsFd + Unpin + 'static,
{
type Output = Result<ReadBytes>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.as_mut();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let io = &mut this.io;
let bytes = &mut this.bytes;
poll_nonblocking!(polldance::io::try_read(this.io.as_ref(), this.bytes));
ReactorRef::with(|mut reactor| {
reactor.register(
Arc::clone(this.io),
cx.waker().clone(),
Readiness::read() | Readiness::hangup(),
);
})
.unwrap();
Poll::Pending
io.futures_poll_read(cx, bytes)
.map(|res| res.map(From::from))
}
}
impl<'a, T> Future for Write<'a, T>
where
T: AsFd + 'static,
T: std::io::Write + AsFd + Unpin + 'static,
{
type Output = Result<usize>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
poll_nonblocking!(polldance::io::try_write(self.io.as_ref(), self.bytes));
let this = self.get_mut();
let io = &mut this.io;
let bytes = &mut this.bytes;
ReactorRef::with(|mut reactor| {
reactor.register(
Arc::clone(self.io),
cx.waker().clone(),
Readiness::write() | Readiness::hangup(),
);
})
.unwrap();
Poll::Pending
io.futures_poll_write(cx, bytes)
}
}
@ -307,3 +421,44 @@ where
Poll::Pending
}
}
impl<T> MaybeArc<T> {
fn ensure_arc(&mut self) -> &Arc<T> {
*self = match std::mem::replace(self, MaybeArc::Intermediate) {
MaybeArc::Owned(owned) => MaybeArc::Arc(Arc::new(owned)),
MaybeArc::Arc(arc) => MaybeArc::Arc(arc),
MaybeArc::Intermediate => unreachable!("Should never be intermediate"),
};
if let MaybeArc::Arc(arc) = self {
return arc;
}
unreachable!("We should always have an Arc by the end")
}
fn ensure_owned(&mut self) -> Option<&mut T> {
*self = match std::mem::replace(self, MaybeArc::Intermediate) {
MaybeArc::Arc(arc) => match Arc::try_unwrap(arc) {
Ok(owned) => MaybeArc::Owned(owned),
Err(arc) => MaybeArc::Arc(arc),
},
MaybeArc::Owned(owned) => MaybeArc::Owned(owned),
MaybeArc::Intermediate => unreachable!("Should never be intermediate"),
};
if let MaybeArc::Owned(owned) = self {
return Some(owned);
}
None
}
fn as_ref(&self) -> &T {
match self {
MaybeArc::Arc(arc) => arc.as_ref(),
MaybeArc::Owned(owned) => owned,
_ => unreachable!("Should never have intermediate state"),
}
}
}