Add block_on, multi-thread example

This commit is contained in:
Aode (Lion) 2022-01-29 20:26:35 -06:00
parent 3aa0f237b3
commit 28b72a7d60
4 changed files with 232 additions and 57 deletions

46
examples/block_on.rs Normal file
View file

@ -0,0 +1,46 @@
use safe_executor::{oneshot, Executor, JoinError};
use std::time::Duration;
fn main() -> Result<(), JoinError> {
let executor = Executor::new();
let execu2r = executor.clone();
executor.block_on(async move {
// Spawn a ton to invoke task pruning heuristic
let join_handles = (0..1007)
.map(|i| execu2r.spawn(async move { println!("{}", i) }))
.collect::<Vec<_>>();
for i in join_handles {
i.await?;
}
Ok(()) as Result<(), JoinError>
})??;
let execu2r = executor.clone();
std::thread::spawn(move || {
let (tx, rx) = oneshot();
execu2r.spawn(async move {
println!("Started polling");
let val = rx.await;
println!("delayed print, {:?}", val);
});
// Delay a bit to invoke task pruning heuristic
std::thread::sleep(Duration::from_secs(6));
println!("sending meowdy");
let _ = tx.send("meowdy");
std::thread::sleep(Duration::from_secs(2));
println!("stopping");
execu2r.stop();
});
let res = executor.block_on(std::future::pending::<()>());
println!("{:?}", res);
Ok(())
}

View file

@ -1,6 +1,6 @@
use safe_executor::Runtime;
use safe_executor::Executor;
fn spawn(runtime: &Runtime) {
fn spawn(runtime: &Executor) {
println!("Spawning futures");
let task1 = runtime.spawn(async move {
@ -32,7 +32,7 @@ fn spawn(runtime: &Runtime) {
}
fn main() {
let runtime = Runtime::new();
let runtime = Executor::new();
// This creates 3 new tasks
spawn(&runtime);

30
examples/multi-thread.rs Normal file
View file

@ -0,0 +1,30 @@
use safe_executor::{Executor, JoinError};
fn main() -> Result<(), JoinError> {
let executor = Executor::new();
for _ in 0..4 {
let executor = executor.clone();
std::thread::spawn(move || {
let _ = executor.block_on(std::future::pending::<()>());
});
}
let execu2r = executor.clone();
executor.block_on(async move {
let mut join_handles = Vec::new();
for _ in 0..4 {
join_handles.extend((0..502).map(|i| execu2r.spawn(async move {
println!("{}", i);
})));
}
for handle in join_handles {
handle.await?;
}
println!("Done waiting");
Ok(()) as Result<(), JoinError>
})?
}

View file

@ -7,6 +7,8 @@ use std::{
Arc, Mutex, Weak,
},
task::{Context, Poll, Wake, Waker},
thread::{Thread, ThreadId},
time::{Duration, Instant},
};
enum OneshotState<T> {
@ -22,12 +24,12 @@ impl<T> OneshotState<T> {
}
}
struct Sender<T> {
pub struct Sender<T> {
state: Arc<Mutex<OneshotState<T>>>,
}
impl<T> Sender<T> {
fn send(self, item: T) -> Result<(), T> {
pub fn send(self, item: T) -> Result<(), T> {
let mut guard = self.state.lock().unwrap();
match guard.take() {
OneshotState::New => {
@ -58,47 +60,10 @@ impl<T> Drop for Sender<T> {
}
}
struct Receiver<T> {
pub struct Receiver<T> {
state: Arc<Mutex<OneshotState<T>>>,
}
fn oneshot<T>() -> (Sender<T>, Receiver<T>) {
let state = Arc::new(Mutex::new(OneshotState::New));
(
Sender {
state: Arc::clone(&state),
},
Receiver { state },
)
}
pub struct JoinHandle<T> {
rx: Receiver<T>,
}
#[derive(Debug)]
struct Dropped;
impl std::fmt::Display for Dropped {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Sender was dropped")
}
}
impl std::error::Error for Dropped {}
#[derive(Debug)]
pub struct JoinError;
impl std::fmt::Display for JoinError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Task panicked")
}
}
impl std::error::Error for JoinError {}
impl<T> Future for Receiver<T> {
type Output = Result<T, Dropped>;
@ -121,6 +86,43 @@ impl<T> Drop for Receiver<T> {
}
}
pub fn oneshot<T>() -> (Sender<T>, Receiver<T>) {
let state = Arc::new(Mutex::new(OneshotState::New));
(
Sender {
state: Arc::clone(&state),
},
Receiver { state },
)
}
pub struct JoinHandle<T> {
rx: Receiver<T>,
}
#[derive(Debug)]
pub struct Dropped;
impl std::fmt::Display for Dropped {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Sender was dropped")
}
}
impl std::error::Error for Dropped {}
#[derive(Debug)]
pub struct JoinError;
impl std::fmt::Display for JoinError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Task panicked")
}
}
impl std::error::Error for JoinError {}
impl<T> Future for JoinHandle<T> {
type Output = Result<T, JoinError>;
@ -147,6 +149,7 @@ impl Task {
let safe_waker = Arc::new(SafeWaker {
task: Arc::downgrade(&self),
runtime: Weak::clone(&runtime),
thread: std::thread::current(),
})
.into();
@ -169,7 +172,6 @@ impl Task {
};
if let Some(Poll::Ready(())) = &opt {
println!("Marking task {} as available", self.task_id);
*self.fut.lock().unwrap() = None;
if let Some(runtime) = Weak::upgrade(&runtime) {
let mut inner = runtime.lock().unwrap();
@ -193,15 +195,10 @@ impl Task {
}
}
impl Drop for Task {
fn drop(&mut self) {
println!("Dropping task {}", self.task_id);
}
}
struct SafeWaker {
task: Weak<Task>,
runtime: Weak<Mutex<Inner>>,
thread: Thread,
}
impl Wake for SafeWaker {
@ -211,19 +208,37 @@ impl Wake for SafeWaker {
if let Some(runtime) = Weak::upgrade(&self.runtime) {
let mut inner = runtime.lock().unwrap();
inner.wake(task);
self.thread.unpark();
}
}
}
}
}
struct BlockOnWaker {
thread: Thread,
}
impl Wake for BlockOnWaker {
fn wake(self: Arc<Self>) {
self.thread.unpark();
}
}
struct Inner {
task_id_counter: u64,
head_available: Option<Arc<Task>>,
head_woken: Option<Arc<Task>>,
pending: HashMap<u64, Arc<Task>>,
stopping: bool,
threads: HashMap<ThreadId, Thread>,
spawn_count: u64,
prune_time: Instant,
}
const SPAWN_COUNT: u64 = 1000;
const PRUNE_DURATION: Duration = Duration::from_secs(5);
impl Inner {
fn take_available_head(&mut self) -> Option<Arc<Task>> {
let head = self.head_available.take()?;
@ -237,14 +252,34 @@ impl Inner {
id
}
fn heuristic_prune(&mut self) {
if self.spawn_count > SPAWN_COUNT || self.prune_time + PRUNE_DURATION < Instant::now() {
self.prune();
}
}
fn prune(&mut self) {
self.head_available = None;
self.prune_time = Instant::now();
}
fn wake(&mut self, task: Arc<Task>) {
*task.next_woken.lock().unwrap() = self.head_woken.take();
task.woken.store(true, Ordering::Release);
self.head_woken = Some(task);
if let Some(tid) = self.threads.keys().next() {
let tid = *tid;
if let Some(thread) = self.threads.remove(&tid) {
thread.unpark();
}
}
}
fn stop(&mut self) {
self.stopping = true;
for (_, thread) in self.threads.drain() {
thread.unpark();
}
}
fn spawn<T: Send + 'static>(
@ -255,7 +290,6 @@ impl Inner {
.take_available_head()
.unwrap_or_else(|| Task::allocate(self.next_task_id()));
println!("spawning future in task {}", task.task_id);
self.pending.insert(task.task_id, Arc::clone(&task));
let (tx, rx) = oneshot();
@ -263,30 +297,35 @@ impl Inner {
let _ = tx.send(future.await);
}));
self.wake(task);
self.spawn_count += 1;
JoinHandle { rx }
}
}
#[derive(Clone)]
pub struct Runtime {
pub struct Executor {
inner: Arc<Mutex<Inner>>,
}
impl Default for Runtime {
impl Default for Executor {
fn default() -> Self {
Self::new()
}
}
impl Runtime {
impl Executor {
pub fn new() -> Self {
Runtime {
Executor {
inner: Arc::new(Mutex::new(Inner {
task_id_counter: 0,
head_available: None,
head_woken: None,
pending: HashMap::new(),
stopping: false,
threads: HashMap::new(),
spawn_count: 0,
prune_time: Instant::now(),
})),
}
}
@ -298,8 +337,14 @@ impl Runtime {
self.inner.lock().unwrap().spawn(future)
}
pub fn tick(&self) {
let woken_list = { self.inner.lock().unwrap().head_woken.take() };
pub fn tick(&self) -> bool {
let woken_list = {
let mut guard = self.inner.lock().unwrap();
if guard.stopping {
return false;
}
guard.head_woken.take()
};
if let Some(woken) = woken_list {
let mut next = Some(woken);
@ -309,9 +354,30 @@ impl Runtime {
next = task.next_woken.lock().unwrap().take();
task.poll(Arc::downgrade(&self.inner));
}
true
} else {
false
}
}
pub fn stopping(&self) -> bool {
self.inner.lock().unwrap().stopping
}
pub fn stop(&self) {
self.inner.lock().unwrap().stop();
}
pub fn park(&self) {
{
let mut guard = self.inner.lock().unwrap();
let thread = std::thread::current();
guard.threads.insert(thread.id(), thread);
}
std::thread::park();
}
pub fn any_woken(&self) -> bool {
self.inner.lock().unwrap().head_woken.is_some()
}
@ -319,4 +385,37 @@ impl Runtime {
pub fn prune(&self) {
self.inner.lock().unwrap().prune();
}
pub fn heuristic_prune(&self) {
self.inner.lock().unwrap().heuristic_prune();
}
pub fn block_on<T: Send + 'static>(
&self,
f: impl Future<Output = T> + Send + 'static,
) -> Result<T, JoinError> {
let mut join_handle = self.spawn(f);
let block_on_waker = Arc::new(BlockOnWaker {
thread: std::thread::current(),
})
.into();
let mut block_on_context = Context::from_waker(&block_on_waker);
loop {
while self.tick() {
if let Poll::Ready(res) = Pin::new(&mut join_handle).poll(&mut block_on_context) {
return res;
}
self.heuristic_prune();
}
if let Poll::Ready(res) = Pin::new(&mut join_handle).poll(&mut block_on_context) {
return res;
}
if self.stopping() {
return Err(JoinError);
}
self.park();
}
}
}