diff --git a/.gitignore b/.gitignore index 96ef6c0..83157df 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /target +/.direnv +/.envrc Cargo.lock diff --git a/examples/abort.rs b/examples/abort.rs new file mode 100644 index 0000000..06be131 --- /dev/null +++ b/examples/abort.rs @@ -0,0 +1,19 @@ +use std::time::Duration; + +fn main() -> Result<(), Box> { + foxtrot::block_on(bachata::run_with(async { + println!("hai!!!"); + + let handle = bachata::spawn(async move { + foxtrot::time::sleep(Duration::from_secs(2)).await; + println!("slept"); + }); + + handle.abort(); + + println!("This will error"); + handle.await?; + + Ok(()) + }))? +} diff --git a/examples/time.rs b/examples/time.rs index ba3f2bf..615aa39 100644 --- a/examples/time.rs +++ b/examples/time.rs @@ -17,6 +17,8 @@ fn main() -> Result<(), Box> { handle.await?; } + println!("Joined all"); + Ok(()) }))? } diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..2a20bfa --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1692799911, + "narHash": "sha256-3eihraek4qL744EvQXsK1Ha6C3CR7nnT8X2qWap4RNk=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "f9e7cf818399d17d347f847525c5a5a8032e4e44", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1693003285, + "narHash": "sha256-5nm4yrEHKupjn62MibENtfqlP6pWcRTuSKrMiH9bLkc=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "5690c4271f2998c304a45c91a0aeb8fb69feaea7", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..0833233 --- /dev/null +++ b/flake.nix @@ -0,0 +1,25 @@ +{ + description = "bachata"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { + inherit system; + }; + in + { + packages.default = pkgs.hello; + + devShell = with pkgs; mkShell { + nativeBuildInputs = [ cargo cargo-outdated cargo-zigbuild clippy gcc protobuf rust-analyzer rustc rustfmt ]; + + RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}"; + }; + }); +} diff --git a/src/lib.rs b/src/lib.rs index ed64578..b9af54b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,18 +75,27 @@ mod task_id { } } -fn joiner() -> (JoinNotifier, JoinHandle) { - let state = Rc::new(RefCell::new(JoinState { - waker: None, +fn joiner() -> (JoinNotifier, Rc>, JoinHandle) { + let task_state = Rc::new(RefCell::new(TaskState { + handle_waker: None, item: None, dropped: false, })); + let join_state = Rc::new(RefCell::new(JoinState { + task_waker: None, + aborted: false, + })); + ( JoinNotifier { - state: Rc::clone(&state), + task_state: Rc::clone(&task_state), + }, + Rc::clone(&join_state), + JoinHandle { + task_state, + join_state, }, - JoinHandle { state }, ) } @@ -111,20 +120,27 @@ struct RemoteState { struct Task { future: Pin>>, + join_state: Rc>, } -struct JoinState { - waker: Option, +struct TaskState { + handle_waker: Option, item: Option, dropped: bool, } +struct JoinState { + task_waker: Option, + aborted: bool, +} + struct JoinNotifier { - state: Rc>>, + task_state: Rc>>, } pub struct JoinHandle { - state: Rc>>, + join_state: Rc>, + task_state: Rc>>, } #[derive(Debug)] @@ -196,12 +212,13 @@ impl ExecutorState { fn spawn(&self, future: F) -> JoinHandle { let id = self.id_gen.generate(); - let (notifier, handle) = joiner(); + let (notifier, join_state, handle) = joiner(); let task = Task { future: Box::pin(async move { notifier.ready(future.await); }), + join_state, }; self.tasks.borrow_mut().insert(id, task); @@ -254,11 +271,8 @@ impl RemoteState { impl JoinNotifier { fn ready(self, item: T) { - let mut guard = self.state.borrow_mut(); + let mut guard = self.task_state.borrow_mut(); guard.item = Some(item); - if let Some(waker) = guard.waker.take() { - waker.wake(); - } } } @@ -309,11 +323,23 @@ impl ThreadWaker { } } +impl JoinHandle { + pub fn abort(&self) { + let mut guard = self.join_state.borrow_mut(); + + guard.aborted = true; + + if let Some(waker) = guard.task_waker.take() { + waker.wake(); + } + } +} + impl Future for JoinHandle { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut guard = self.state.borrow_mut(); + let mut guard = self.task_state.borrow_mut(); if let Some(item) = guard.item.take() { return Poll::Ready(Ok(item)); @@ -323,7 +349,8 @@ impl Future for JoinHandle { return Poll::Ready(Err(JoinError)); } - guard.waker = Some(cx.waker().clone()); + guard.handle_waker = Some(cx.waker().clone()); + drop(guard); Poll::Pending } @@ -368,6 +395,12 @@ where continue; }; + let guard = task.join_state.borrow(); + if guard.aborted { + continue; + } + drop(guard); + let maybe_local_waker = Arc::new(MaybeLocalWaker { task_id, remote_state: Arc::clone(&this.executor.state.remote_state), @@ -383,6 +416,7 @@ where match res { Ok(Poll::Ready(())) => {} Ok(Poll::Pending) => { + task.join_state.borrow_mut().task_waker = Some(maybe_local_waker); this.executor.state.tasks.borrow_mut().insert(task_id, task); } Err(_) => { @@ -479,9 +513,9 @@ impl Drop for ExecutorToken { impl Drop for JoinNotifier { fn drop(&mut self) { - let mut guard = self.state.borrow_mut(); + let mut guard = self.task_state.borrow_mut(); guard.dropped = true; - if let Some(waker) = guard.waker.take() { + if let Some(waker) = guard.handle_waker.take() { waker.wake(); } }