Add abort to bachata

This commit is contained in:
asonix 2023-08-27 16:55:37 -05:00
parent 1380aa410a
commit 591318434b
6 changed files with 161 additions and 18 deletions

2
.gitignore vendored
View file

@ -1,2 +1,4 @@
/target
/.direnv
/.envrc
Cargo.lock

19
examples/abort.rs Normal file
View file

@ -0,0 +1,19 @@
use std::time::Duration;
fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
foxtrot::block_on(bachata::run_with(async {
println!("hai!!!");
let handle = bachata::spawn(async move {
foxtrot::time::sleep(Duration::from_secs(2)).await;
println!("slept");
});
handle.abort();
println!("This will error");
handle.await?;
Ok(())
}))?
}

View file

@ -17,6 +17,8 @@ fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
handle.await?;
}
println!("Joined all");
Ok(())
}))?
}

61
flake.lock Normal file
View file

@ -0,0 +1,61 @@
{
"nodes": {
"flake-utils": {
"inputs": {
"systems": "systems"
},
"locked": {
"lastModified": 1692799911,
"narHash": "sha256-3eihraek4qL744EvQXsK1Ha6C3CR7nnT8X2qWap4RNk=",
"owner": "numtide",
"repo": "flake-utils",
"rev": "f9e7cf818399d17d347f847525c5a5a8032e4e44",
"type": "github"
},
"original": {
"owner": "numtide",
"repo": "flake-utils",
"type": "github"
}
},
"nixpkgs": {
"locked": {
"lastModified": 1693003285,
"narHash": "sha256-5nm4yrEHKupjn62MibENtfqlP6pWcRTuSKrMiH9bLkc=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "5690c4271f2998c304a45c91a0aeb8fb69feaea7",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixos-unstable",
"repo": "nixpkgs",
"type": "github"
}
},
"root": {
"inputs": {
"flake-utils": "flake-utils",
"nixpkgs": "nixpkgs"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
"owner": "nix-systems",
"repo": "default",
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
"type": "github"
},
"original": {
"owner": "nix-systems",
"repo": "default",
"type": "github"
}
}
},
"root": "root",
"version": 7
}

25
flake.nix Normal file
View file

@ -0,0 +1,25 @@
{
description = "bachata";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
flake-utils.url = "github:numtide/flake-utils";
};
outputs = { self, nixpkgs, flake-utils }:
flake-utils.lib.eachDefaultSystem (system:
let
pkgs = import nixpkgs {
inherit system;
};
in
{
packages.default = pkgs.hello;
devShell = with pkgs; mkShell {
nativeBuildInputs = [ cargo cargo-outdated cargo-zigbuild clippy gcc protobuf rust-analyzer rustc rustfmt ];
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
};
});
}

View file

@ -75,18 +75,27 @@ mod task_id {
}
}
fn joiner<T>() -> (JoinNotifier<T>, JoinHandle<T>) {
let state = Rc::new(RefCell::new(JoinState {
waker: None,
fn joiner<T>() -> (JoinNotifier<T>, Rc<RefCell<JoinState>>, JoinHandle<T>) {
let task_state = Rc::new(RefCell::new(TaskState {
handle_waker: None,
item: None,
dropped: false,
}));
let join_state = Rc::new(RefCell::new(JoinState {
task_waker: None,
aborted: false,
}));
(
JoinNotifier {
state: Rc::clone(&state),
task_state: Rc::clone(&task_state),
},
Rc::clone(&join_state),
JoinHandle {
task_state,
join_state,
},
JoinHandle { state },
)
}
@ -111,20 +120,27 @@ struct RemoteState {
struct Task {
future: Pin<Box<dyn Future<Output = ()>>>,
join_state: Rc<RefCell<JoinState>>,
}
struct JoinState<T> {
waker: Option<Waker>,
struct TaskState<T> {
handle_waker: Option<Waker>,
item: Option<T>,
dropped: bool,
}
struct JoinState {
task_waker: Option<Waker>,
aborted: bool,
}
struct JoinNotifier<T> {
state: Rc<RefCell<JoinState<T>>>,
task_state: Rc<RefCell<TaskState<T>>>,
}
pub struct JoinHandle<T> {
state: Rc<RefCell<JoinState<T>>>,
join_state: Rc<RefCell<JoinState>>,
task_state: Rc<RefCell<TaskState<T>>>,
}
#[derive(Debug)]
@ -196,12 +212,13 @@ impl ExecutorState {
fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
let id = self.id_gen.generate();
let (notifier, handle) = joiner();
let (notifier, join_state, handle) = joiner();
let task = Task {
future: Box::pin(async move {
notifier.ready(future.await);
}),
join_state,
};
self.tasks.borrow_mut().insert(id, task);
@ -254,11 +271,8 @@ impl RemoteState {
impl<T> JoinNotifier<T> {
fn ready(self, item: T) {
let mut guard = self.state.borrow_mut();
let mut guard = self.task_state.borrow_mut();
guard.item = Some(item);
if let Some(waker) = guard.waker.take() {
waker.wake();
}
}
}
@ -309,11 +323,23 @@ impl ThreadWaker {
}
}
impl<T> JoinHandle<T> {
pub fn abort(&self) {
let mut guard = self.join_state.borrow_mut();
guard.aborted = true;
if let Some(waker) = guard.task_waker.take() {
waker.wake();
}
}
}
impl<T> Future for JoinHandle<T> {
type Output = Result<T, JoinError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut guard = self.state.borrow_mut();
let mut guard = self.task_state.borrow_mut();
if let Some(item) = guard.item.take() {
return Poll::Ready(Ok(item));
@ -323,7 +349,8 @@ impl<T> Future for JoinHandle<T> {
return Poll::Ready(Err(JoinError));
}
guard.waker = Some(cx.waker().clone());
guard.handle_waker = Some(cx.waker().clone());
drop(guard);
Poll::Pending
}
@ -368,6 +395,12 @@ where
continue;
};
let guard = task.join_state.borrow();
if guard.aborted {
continue;
}
drop(guard);
let maybe_local_waker = Arc::new(MaybeLocalWaker {
task_id,
remote_state: Arc::clone(&this.executor.state.remote_state),
@ -383,6 +416,7 @@ where
match res {
Ok(Poll::Ready(())) => {}
Ok(Poll::Pending) => {
task.join_state.borrow_mut().task_waker = Some(maybe_local_waker);
this.executor.state.tasks.borrow_mut().insert(task_id, task);
}
Err(_) => {
@ -479,9 +513,9 @@ impl Drop for ExecutorToken {
impl<T> Drop for JoinNotifier<T> {
fn drop(&mut self) {
let mut guard = self.state.borrow_mut();
let mut guard = self.task_state.borrow_mut();
guard.dropped = true;
if let Some(waker) = guard.waker.take() {
if let Some(waker) = guard.handle_waker.take() {
waker.wake();
}
}