Add abort to bachata
This commit is contained in:
parent
1380aa410a
commit
591318434b
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,2 +1,4 @@
|
|||
/target
|
||||
/.direnv
|
||||
/.envrc
|
||||
Cargo.lock
|
||||
|
|
19
examples/abort.rs
Normal file
19
examples/abort.rs
Normal file
|
@ -0,0 +1,19 @@
|
|||
use std::time::Duration;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
foxtrot::block_on(bachata::run_with(async {
|
||||
println!("hai!!!");
|
||||
|
||||
let handle = bachata::spawn(async move {
|
||||
foxtrot::time::sleep(Duration::from_secs(2)).await;
|
||||
println!("slept");
|
||||
});
|
||||
|
||||
handle.abort();
|
||||
|
||||
println!("This will error");
|
||||
handle.await?;
|
||||
|
||||
Ok(())
|
||||
}))?
|
||||
}
|
|
@ -17,6 +17,8 @@ fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
handle.await?;
|
||||
}
|
||||
|
||||
println!("Joined all");
|
||||
|
||||
Ok(())
|
||||
}))?
|
||||
}
|
||||
|
|
61
flake.lock
Normal file
61
flake.lock
Normal file
|
@ -0,0 +1,61 @@
|
|||
{
|
||||
"nodes": {
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1692799911,
|
||||
"narHash": "sha256-3eihraek4qL744EvQXsK1Ha6C3CR7nnT8X2qWap4RNk=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "f9e7cf818399d17d347f847525c5a5a8032e4e44",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1693003285,
|
||||
"narHash": "sha256-5nm4yrEHKupjn62MibENtfqlP6pWcRTuSKrMiH9bLkc=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "5690c4271f2998c304a45c91a0aeb8fb69feaea7",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": "nixpkgs"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
25
flake.nix
Normal file
25
flake.nix
Normal file
|
@ -0,0 +1,25 @@
|
|||
{
|
||||
description = "bachata";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
|
||||
outputs = { self, nixpkgs, flake-utils }:
|
||||
flake-utils.lib.eachDefaultSystem (system:
|
||||
let
|
||||
pkgs = import nixpkgs {
|
||||
inherit system;
|
||||
};
|
||||
in
|
||||
{
|
||||
packages.default = pkgs.hello;
|
||||
|
||||
devShell = with pkgs; mkShell {
|
||||
nativeBuildInputs = [ cargo cargo-outdated cargo-zigbuild clippy gcc protobuf rust-analyzer rustc rustfmt ];
|
||||
|
||||
RUST_SRC_PATH = "${pkgs.rust.packages.stable.rustPlatform.rustLibSrc}";
|
||||
};
|
||||
});
|
||||
}
|
70
src/lib.rs
70
src/lib.rs
|
@ -75,18 +75,27 @@ mod task_id {
|
|||
}
|
||||
}
|
||||
|
||||
fn joiner<T>() -> (JoinNotifier<T>, JoinHandle<T>) {
|
||||
let state = Rc::new(RefCell::new(JoinState {
|
||||
waker: None,
|
||||
fn joiner<T>() -> (JoinNotifier<T>, Rc<RefCell<JoinState>>, JoinHandle<T>) {
|
||||
let task_state = Rc::new(RefCell::new(TaskState {
|
||||
handle_waker: None,
|
||||
item: None,
|
||||
dropped: false,
|
||||
}));
|
||||
|
||||
let join_state = Rc::new(RefCell::new(JoinState {
|
||||
task_waker: None,
|
||||
aborted: false,
|
||||
}));
|
||||
|
||||
(
|
||||
JoinNotifier {
|
||||
state: Rc::clone(&state),
|
||||
task_state: Rc::clone(&task_state),
|
||||
},
|
||||
Rc::clone(&join_state),
|
||||
JoinHandle {
|
||||
task_state,
|
||||
join_state,
|
||||
},
|
||||
JoinHandle { state },
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -111,20 +120,27 @@ struct RemoteState {
|
|||
|
||||
struct Task {
|
||||
future: Pin<Box<dyn Future<Output = ()>>>,
|
||||
join_state: Rc<RefCell<JoinState>>,
|
||||
}
|
||||
|
||||
struct JoinState<T> {
|
||||
waker: Option<Waker>,
|
||||
struct TaskState<T> {
|
||||
handle_waker: Option<Waker>,
|
||||
item: Option<T>,
|
||||
dropped: bool,
|
||||
}
|
||||
|
||||
struct JoinState {
|
||||
task_waker: Option<Waker>,
|
||||
aborted: bool,
|
||||
}
|
||||
|
||||
struct JoinNotifier<T> {
|
||||
state: Rc<RefCell<JoinState<T>>>,
|
||||
task_state: Rc<RefCell<TaskState<T>>>,
|
||||
}
|
||||
|
||||
pub struct JoinHandle<T> {
|
||||
state: Rc<RefCell<JoinState<T>>>,
|
||||
join_state: Rc<RefCell<JoinState>>,
|
||||
task_state: Rc<RefCell<TaskState<T>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
@ -196,12 +212,13 @@ impl ExecutorState {
|
|||
|
||||
fn spawn<F: Future + 'static>(&self, future: F) -> JoinHandle<F::Output> {
|
||||
let id = self.id_gen.generate();
|
||||
let (notifier, handle) = joiner();
|
||||
let (notifier, join_state, handle) = joiner();
|
||||
|
||||
let task = Task {
|
||||
future: Box::pin(async move {
|
||||
notifier.ready(future.await);
|
||||
}),
|
||||
join_state,
|
||||
};
|
||||
|
||||
self.tasks.borrow_mut().insert(id, task);
|
||||
|
@ -254,11 +271,8 @@ impl RemoteState {
|
|||
|
||||
impl<T> JoinNotifier<T> {
|
||||
fn ready(self, item: T) {
|
||||
let mut guard = self.state.borrow_mut();
|
||||
let mut guard = self.task_state.borrow_mut();
|
||||
guard.item = Some(item);
|
||||
if let Some(waker) = guard.waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -309,11 +323,23 @@ impl ThreadWaker {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T> JoinHandle<T> {
|
||||
pub fn abort(&self) {
|
||||
let mut guard = self.join_state.borrow_mut();
|
||||
|
||||
guard.aborted = true;
|
||||
|
||||
if let Some(waker) = guard.task_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Future for JoinHandle<T> {
|
||||
type Output = Result<T, JoinError>;
|
||||
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut guard = self.state.borrow_mut();
|
||||
let mut guard = self.task_state.borrow_mut();
|
||||
|
||||
if let Some(item) = guard.item.take() {
|
||||
return Poll::Ready(Ok(item));
|
||||
|
@ -323,7 +349,8 @@ impl<T> Future for JoinHandle<T> {
|
|||
return Poll::Ready(Err(JoinError));
|
||||
}
|
||||
|
||||
guard.waker = Some(cx.waker().clone());
|
||||
guard.handle_waker = Some(cx.waker().clone());
|
||||
drop(guard);
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
|
@ -368,6 +395,12 @@ where
|
|||
continue;
|
||||
};
|
||||
|
||||
let guard = task.join_state.borrow();
|
||||
if guard.aborted {
|
||||
continue;
|
||||
}
|
||||
drop(guard);
|
||||
|
||||
let maybe_local_waker = Arc::new(MaybeLocalWaker {
|
||||
task_id,
|
||||
remote_state: Arc::clone(&this.executor.state.remote_state),
|
||||
|
@ -383,6 +416,7 @@ where
|
|||
match res {
|
||||
Ok(Poll::Ready(())) => {}
|
||||
Ok(Poll::Pending) => {
|
||||
task.join_state.borrow_mut().task_waker = Some(maybe_local_waker);
|
||||
this.executor.state.tasks.borrow_mut().insert(task_id, task);
|
||||
}
|
||||
Err(_) => {
|
||||
|
@ -479,9 +513,9 @@ impl Drop for ExecutorToken {
|
|||
|
||||
impl<T> Drop for JoinNotifier<T> {
|
||||
fn drop(&mut self) {
|
||||
let mut guard = self.state.borrow_mut();
|
||||
let mut guard = self.task_state.borrow_mut();
|
||||
guard.dropped = true;
|
||||
if let Some(waker) = guard.waker.take() {
|
||||
if let Some(waker) = guard.handle_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue