Simplify process future by removing explicit channel, don't lock mutex as often

This commit is contained in:
asonix 2023-12-22 12:03:05 -06:00
parent 46ceac3432
commit db43392a3b

View file

@ -1,5 +1,4 @@
use actix_web::web::Bytes;
use flume::r#async::RecvFut;
use std::{
ffi::OsStr,
future::Future,
@ -73,7 +72,7 @@ impl std::fmt::Debug for Process {
}
struct DropHandle {
inner: JoinHandle<()>,
inner: JoinHandle<std::io::Result<()>>,
}
struct ProcessReadState {
@ -88,7 +87,6 @@ struct ProcessReadWaker {
pub(crate) struct ProcessRead {
inner: ChildStdout,
err_recv: RecvFut<'static, std::io::Error>,
handle: DropHandle,
closed: bool,
state: Arc<ProcessReadState>,
@ -233,9 +231,6 @@ impl Process {
let stdin = child.stdin.take().expect("stdin exists");
let stdout = child.stdout.take().expect("stdout exists");
let (tx, rx) = crate::sync::channel::<std::io::Error>(1);
let rx = rx.into_recv_async();
let background_span =
tracing::info_span!(parent: None, "Background process task", %command);
background_span.follows_from(Span::current());
@ -255,7 +250,7 @@ impl Process {
let error = match child_fut.with_timeout(timeout).await {
Ok(Ok(status)) if status.success() => {
guard.disarm();
return;
return Ok(());
}
Ok(Ok(status)) => {
std::io::Error::new(std::io::ErrorKind::Other, StatusError(status))
@ -264,15 +259,15 @@ impl Process {
Err(_) => std::io::ErrorKind::TimedOut.into(),
};
let _ = tx.send(error);
let _ = child.kill().await;
child.kill().await?;
Err(error)
}
.instrument(background_span),
);
ProcessRead {
inner: stdout,
err_recv: rx,
handle: DropHandle { inner: handle },
closed: false,
state: ProcessReadState::new_woken(),
@ -291,7 +286,7 @@ impl ProcessReadState {
fn clone_parent(&self) -> Option<Waker> {
let guard = self.parent.lock().unwrap();
guard.as_ref().map(|w| w.clone())
guard.as_ref().cloned()
}
fn into_parts(self) -> (AtomicU8, Option<Waker>) {
@ -322,19 +317,26 @@ impl ProcessRead {
}
}
fn set_parent_waker(&self, parent: &Waker) {
fn set_parent_waker(&self, parent: &Waker) -> bool {
let mut guard = self.state.parent.lock().unwrap();
if let Some(waker) = guard.as_mut() {
if !waker.will_wake(parent) {
*waker = parent.clone();
true
} else {
false
}
} else {
*guard = Some(parent.clone());
true
}
}
fn mark_all_woken(&self) {
self.state.flags.store(0xff, Ordering::Release);
}
}
const RECV_WAKER: u8 = 0b_0010;
const HANDLE_WAKER: u8 = 0b_0100;
impl AsyncRead for ProcessRead {
@ -343,8 +345,6 @@ impl AsyncRead for ProcessRead {
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
self.set_parent_waker(cx.waker());
let span = self.span.clone();
let guard = span.enter();
@ -364,29 +364,30 @@ impl AsyncRead for ProcessRead {
} else {
break Poll::Ready(Ok(()));
}
} else if let Some(waker) = self.get_waker(RECV_WAKER) {
// only poll recv if we've been explicitly woken
let mut recv_cx = Context::from_waker(&waker);
if let Poll::Ready(res) = Pin::new(&mut self.err_recv).poll(&mut recv_cx) {
if let Ok(err) = res {
self.closed = true;
break Poll::Ready(Err(err));
}
}
} else if let Some(waker) = self.get_waker(HANDLE_WAKER) {
// only poll handle if we've been explicitly woken
let mut handle_cx = Context::from_waker(&waker);
if let Poll::Ready(res) = Pin::new(&mut self.handle.inner).poll(&mut handle_cx) {
if let Err(e) = res {
self.closed = true;
break Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)));
}
let error = match res {
Ok(Ok(())) => continue,
Ok(Err(e)) => e,
Err(e) => std::io::Error::new(std::io::ErrorKind::Other, e),
};
self.closed = true;
break Poll::Ready(Err(error));
}
} else if self.closed {
// Stop if we're closed
break Poll::Ready(Ok(()));
} else if self.set_parent_waker(cx.waker()) {
// if we updated the stored waker, mark all as woken an try polling again
// This doesn't actually "wake" the waker, it just allows the handle to be polled
// again next iteration
self.mark_all_woken();
} else {
// if the waker hasn't changed and nothing polled ready, return pending
break Poll::Pending;
}
};