Add block_on, multi-thread example
This commit is contained in:
parent
3aa0f237b3
commit
28b72a7d60
46
examples/block_on.rs
Normal file
46
examples/block_on.rs
Normal file
|
@ -0,0 +1,46 @@
|
|||
use safe_executor::{oneshot, Executor, JoinError};
|
||||
use std::time::Duration;
|
||||
|
||||
fn main() -> Result<(), JoinError> {
|
||||
let executor = Executor::new();
|
||||
|
||||
let execu2r = executor.clone();
|
||||
executor.block_on(async move {
|
||||
// Spawn a ton to invoke task pruning heuristic
|
||||
let join_handles = (0..1007)
|
||||
.map(|i| execu2r.spawn(async move { println!("{}", i) }))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
for i in join_handles {
|
||||
i.await?;
|
||||
}
|
||||
|
||||
Ok(()) as Result<(), JoinError>
|
||||
})??;
|
||||
|
||||
let execu2r = executor.clone();
|
||||
std::thread::spawn(move || {
|
||||
let (tx, rx) = oneshot();
|
||||
|
||||
execu2r.spawn(async move {
|
||||
println!("Started polling");
|
||||
let val = rx.await;
|
||||
println!("delayed print, {:?}", val);
|
||||
});
|
||||
|
||||
// Delay a bit to invoke task pruning heuristic
|
||||
std::thread::sleep(Duration::from_secs(6));
|
||||
|
||||
println!("sending meowdy");
|
||||
let _ = tx.send("meowdy");
|
||||
|
||||
std::thread::sleep(Duration::from_secs(2));
|
||||
|
||||
println!("stopping");
|
||||
execu2r.stop();
|
||||
});
|
||||
|
||||
let res = executor.block_on(std::future::pending::<()>());
|
||||
println!("{:?}", res);
|
||||
Ok(())
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
use safe_executor::Runtime;
|
||||
use safe_executor::Executor;
|
||||
|
||||
fn spawn(runtime: &Runtime) {
|
||||
fn spawn(runtime: &Executor) {
|
||||
println!("Spawning futures");
|
||||
|
||||
let task1 = runtime.spawn(async move {
|
||||
|
@ -32,7 +32,7 @@ fn spawn(runtime: &Runtime) {
|
|||
}
|
||||
|
||||
fn main() {
|
||||
let runtime = Runtime::new();
|
||||
let runtime = Executor::new();
|
||||
|
||||
// This creates 3 new tasks
|
||||
spawn(&runtime);
|
||||
|
|
30
examples/multi-thread.rs
Normal file
30
examples/multi-thread.rs
Normal file
|
@ -0,0 +1,30 @@
|
|||
use safe_executor::{Executor, JoinError};
|
||||
|
||||
fn main() -> Result<(), JoinError> {
|
||||
let executor = Executor::new();
|
||||
|
||||
for _ in 0..4 {
|
||||
let executor = executor.clone();
|
||||
std::thread::spawn(move || {
|
||||
let _ = executor.block_on(std::future::pending::<()>());
|
||||
});
|
||||
}
|
||||
|
||||
let execu2r = executor.clone();
|
||||
executor.block_on(async move {
|
||||
let mut join_handles = Vec::new();
|
||||
for _ in 0..4 {
|
||||
join_handles.extend((0..502).map(|i| execu2r.spawn(async move {
|
||||
println!("{}", i);
|
||||
})));
|
||||
}
|
||||
|
||||
for handle in join_handles {
|
||||
handle.await?;
|
||||
}
|
||||
|
||||
println!("Done waiting");
|
||||
|
||||
Ok(()) as Result<(), JoinError>
|
||||
})?
|
||||
}
|
207
src/lib.rs
207
src/lib.rs
|
@ -7,6 +7,8 @@ use std::{
|
|||
Arc, Mutex, Weak,
|
||||
},
|
||||
task::{Context, Poll, Wake, Waker},
|
||||
thread::{Thread, ThreadId},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
enum OneshotState<T> {
|
||||
|
@ -22,12 +24,12 @@ impl<T> OneshotState<T> {
|
|||
}
|
||||
}
|
||||
|
||||
struct Sender<T> {
|
||||
pub struct Sender<T> {
|
||||
state: Arc<Mutex<OneshotState<T>>>,
|
||||
}
|
||||
|
||||
impl<T> Sender<T> {
|
||||
fn send(self, item: T) -> Result<(), T> {
|
||||
pub fn send(self, item: T) -> Result<(), T> {
|
||||
let mut guard = self.state.lock().unwrap();
|
||||
match guard.take() {
|
||||
OneshotState::New => {
|
||||
|
@ -58,47 +60,10 @@ impl<T> Drop for Sender<T> {
|
|||
}
|
||||
}
|
||||
|
||||
struct Receiver<T> {
|
||||
pub struct Receiver<T> {
|
||||
state: Arc<Mutex<OneshotState<T>>>,
|
||||
}
|
||||
|
||||
fn oneshot<T>() -> (Sender<T>, Receiver<T>) {
|
||||
let state = Arc::new(Mutex::new(OneshotState::New));
|
||||
|
||||
(
|
||||
Sender {
|
||||
state: Arc::clone(&state),
|
||||
},
|
||||
Receiver { state },
|
||||
)
|
||||
}
|
||||
|
||||
pub struct JoinHandle<T> {
|
||||
rx: Receiver<T>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Dropped;
|
||||
|
||||
impl std::fmt::Display for Dropped {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Sender was dropped")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Dropped {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct JoinError;
|
||||
|
||||
impl std::fmt::Display for JoinError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Task panicked")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for JoinError {}
|
||||
|
||||
impl<T> Future for Receiver<T> {
|
||||
type Output = Result<T, Dropped>;
|
||||
|
||||
|
@ -121,6 +86,43 @@ impl<T> Drop for Receiver<T> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn oneshot<T>() -> (Sender<T>, Receiver<T>) {
|
||||
let state = Arc::new(Mutex::new(OneshotState::New));
|
||||
|
||||
(
|
||||
Sender {
|
||||
state: Arc::clone(&state),
|
||||
},
|
||||
Receiver { state },
|
||||
)
|
||||
}
|
||||
|
||||
pub struct JoinHandle<T> {
|
||||
rx: Receiver<T>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Dropped;
|
||||
|
||||
impl std::fmt::Display for Dropped {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Sender was dropped")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Dropped {}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct JoinError;
|
||||
|
||||
impl std::fmt::Display for JoinError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Task panicked")
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for JoinError {}
|
||||
|
||||
impl<T> Future for JoinHandle<T> {
|
||||
type Output = Result<T, JoinError>;
|
||||
|
||||
|
@ -147,6 +149,7 @@ impl Task {
|
|||
let safe_waker = Arc::new(SafeWaker {
|
||||
task: Arc::downgrade(&self),
|
||||
runtime: Weak::clone(&runtime),
|
||||
thread: std::thread::current(),
|
||||
})
|
||||
.into();
|
||||
|
||||
|
@ -169,7 +172,6 @@ impl Task {
|
|||
};
|
||||
|
||||
if let Some(Poll::Ready(())) = &opt {
|
||||
println!("Marking task {} as available", self.task_id);
|
||||
*self.fut.lock().unwrap() = None;
|
||||
if let Some(runtime) = Weak::upgrade(&runtime) {
|
||||
let mut inner = runtime.lock().unwrap();
|
||||
|
@ -193,15 +195,10 @@ impl Task {
|
|||
}
|
||||
}
|
||||
|
||||
impl Drop for Task {
|
||||
fn drop(&mut self) {
|
||||
println!("Dropping task {}", self.task_id);
|
||||
}
|
||||
}
|
||||
|
||||
struct SafeWaker {
|
||||
task: Weak<Task>,
|
||||
runtime: Weak<Mutex<Inner>>,
|
||||
thread: Thread,
|
||||
}
|
||||
|
||||
impl Wake for SafeWaker {
|
||||
|
@ -211,19 +208,37 @@ impl Wake for SafeWaker {
|
|||
if let Some(runtime) = Weak::upgrade(&self.runtime) {
|
||||
let mut inner = runtime.lock().unwrap();
|
||||
inner.wake(task);
|
||||
self.thread.unpark();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct BlockOnWaker {
|
||||
thread: Thread,
|
||||
}
|
||||
|
||||
impl Wake for BlockOnWaker {
|
||||
fn wake(self: Arc<Self>) {
|
||||
self.thread.unpark();
|
||||
}
|
||||
}
|
||||
|
||||
struct Inner {
|
||||
task_id_counter: u64,
|
||||
head_available: Option<Arc<Task>>,
|
||||
head_woken: Option<Arc<Task>>,
|
||||
pending: HashMap<u64, Arc<Task>>,
|
||||
stopping: bool,
|
||||
threads: HashMap<ThreadId, Thread>,
|
||||
spawn_count: u64,
|
||||
prune_time: Instant,
|
||||
}
|
||||
|
||||
const SPAWN_COUNT: u64 = 1000;
|
||||
const PRUNE_DURATION: Duration = Duration::from_secs(5);
|
||||
|
||||
impl Inner {
|
||||
fn take_available_head(&mut self) -> Option<Arc<Task>> {
|
||||
let head = self.head_available.take()?;
|
||||
|
@ -237,14 +252,34 @@ impl Inner {
|
|||
id
|
||||
}
|
||||
|
||||
fn heuristic_prune(&mut self) {
|
||||
if self.spawn_count > SPAWN_COUNT || self.prune_time + PRUNE_DURATION < Instant::now() {
|
||||
self.prune();
|
||||
}
|
||||
}
|
||||
|
||||
fn prune(&mut self) {
|
||||
self.head_available = None;
|
||||
self.prune_time = Instant::now();
|
||||
}
|
||||
|
||||
fn wake(&mut self, task: Arc<Task>) {
|
||||
*task.next_woken.lock().unwrap() = self.head_woken.take();
|
||||
task.woken.store(true, Ordering::Release);
|
||||
self.head_woken = Some(task);
|
||||
if let Some(tid) = self.threads.keys().next() {
|
||||
let tid = *tid;
|
||||
if let Some(thread) = self.threads.remove(&tid) {
|
||||
thread.unpark();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn stop(&mut self) {
|
||||
self.stopping = true;
|
||||
for (_, thread) in self.threads.drain() {
|
||||
thread.unpark();
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn<T: Send + 'static>(
|
||||
|
@ -255,7 +290,6 @@ impl Inner {
|
|||
.take_available_head()
|
||||
.unwrap_or_else(|| Task::allocate(self.next_task_id()));
|
||||
|
||||
println!("spawning future in task {}", task.task_id);
|
||||
self.pending.insert(task.task_id, Arc::clone(&task));
|
||||
|
||||
let (tx, rx) = oneshot();
|
||||
|
@ -263,30 +297,35 @@ impl Inner {
|
|||
let _ = tx.send(future.await);
|
||||
}));
|
||||
self.wake(task);
|
||||
self.spawn_count += 1;
|
||||
|
||||
JoinHandle { rx }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Runtime {
|
||||
pub struct Executor {
|
||||
inner: Arc<Mutex<Inner>>,
|
||||
}
|
||||
|
||||
impl Default for Runtime {
|
||||
impl Default for Executor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Runtime {
|
||||
impl Executor {
|
||||
pub fn new() -> Self {
|
||||
Runtime {
|
||||
Executor {
|
||||
inner: Arc::new(Mutex::new(Inner {
|
||||
task_id_counter: 0,
|
||||
head_available: None,
|
||||
head_woken: None,
|
||||
pending: HashMap::new(),
|
||||
stopping: false,
|
||||
threads: HashMap::new(),
|
||||
spawn_count: 0,
|
||||
prune_time: Instant::now(),
|
||||
})),
|
||||
}
|
||||
}
|
||||
|
@ -298,8 +337,14 @@ impl Runtime {
|
|||
self.inner.lock().unwrap().spawn(future)
|
||||
}
|
||||
|
||||
pub fn tick(&self) {
|
||||
let woken_list = { self.inner.lock().unwrap().head_woken.take() };
|
||||
pub fn tick(&self) -> bool {
|
||||
let woken_list = {
|
||||
let mut guard = self.inner.lock().unwrap();
|
||||
if guard.stopping {
|
||||
return false;
|
||||
}
|
||||
guard.head_woken.take()
|
||||
};
|
||||
|
||||
if let Some(woken) = woken_list {
|
||||
let mut next = Some(woken);
|
||||
|
@ -309,9 +354,30 @@ impl Runtime {
|
|||
next = task.next_woken.lock().unwrap().take();
|
||||
task.poll(Arc::downgrade(&self.inner));
|
||||
}
|
||||
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stopping(&self) -> bool {
|
||||
self.inner.lock().unwrap().stopping
|
||||
}
|
||||
|
||||
pub fn stop(&self) {
|
||||
self.inner.lock().unwrap().stop();
|
||||
}
|
||||
|
||||
pub fn park(&self) {
|
||||
{
|
||||
let mut guard = self.inner.lock().unwrap();
|
||||
let thread = std::thread::current();
|
||||
guard.threads.insert(thread.id(), thread);
|
||||
}
|
||||
std::thread::park();
|
||||
}
|
||||
|
||||
pub fn any_woken(&self) -> bool {
|
||||
self.inner.lock().unwrap().head_woken.is_some()
|
||||
}
|
||||
|
@ -319,4 +385,37 @@ impl Runtime {
|
|||
pub fn prune(&self) {
|
||||
self.inner.lock().unwrap().prune();
|
||||
}
|
||||
|
||||
pub fn heuristic_prune(&self) {
|
||||
self.inner.lock().unwrap().heuristic_prune();
|
||||
}
|
||||
|
||||
pub fn block_on<T: Send + 'static>(
|
||||
&self,
|
||||
f: impl Future<Output = T> + Send + 'static,
|
||||
) -> Result<T, JoinError> {
|
||||
let mut join_handle = self.spawn(f);
|
||||
|
||||
let block_on_waker = Arc::new(BlockOnWaker {
|
||||
thread: std::thread::current(),
|
||||
})
|
||||
.into();
|
||||
let mut block_on_context = Context::from_waker(&block_on_waker);
|
||||
|
||||
loop {
|
||||
while self.tick() {
|
||||
if let Poll::Ready(res) = Pin::new(&mut join_handle).poll(&mut block_on_context) {
|
||||
return res;
|
||||
}
|
||||
self.heuristic_prune();
|
||||
}
|
||||
if let Poll::Ready(res) = Pin::new(&mut join_handle).poll(&mut block_on_context) {
|
||||
return res;
|
||||
}
|
||||
if self.stopping() {
|
||||
return Err(JoinError);
|
||||
}
|
||||
self.park();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue