// mpsc/src/lib.rs
// 2022-02-17 13:39:47 -05:00
//
// 117 lines
// 2.5 KiB
// Rust

use std::{
collections::VecDeque,
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
thread::Thread,
};
/// Creates a new channel and returns its sending and receiving halves.
///
/// Both halves hold a handle to the same mutex-protected state. The queue
/// starts out empty and no receiver is registered for wake-up.
pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
    let shared = Arc::new(Mutex::new(State {
        wake: None,
        items: VecDeque::new(),
    }));
    let sender = Sender {
        state: Arc::clone(&shared),
    };
    // The receiver takes over the original handle; the sender got the clone.
    let receiver = Receiver { state: shared };
    (sender, receiver)
}
/// Sending half of a channel created by `channel`.
///
/// Every sender holds a shared handle to the same mutex-protected channel
/// state as the receiver.
pub struct Sender<T> {
// Shared queue and wake-up registration.
state: Arc<Mutex<State<T>>>,
}
/// Receiving half of a channel created by `channel`.
///
/// Items can be taken asynchronously (`recv`), by blocking the current thread
/// (`recv_blocking`), or without waiting (`try_recv`).
pub struct Receiver<T> {
// Shared queue and wake-up registration.
state: Arc<Mutex<State<T>>>,
}
/// Channel state shared by all endpoints; always accessed through a `Mutex`.
struct State<T> {
// Items that have been sent but not yet received; pushed at the back and
// popped from the front.
items: VecDeque<T>,
// How to wake the receiver that last registered itself as waiting, if any.
wake: Option<WakerKind>,
}
/// The two ways a waiting receiver can ask to be woken.
enum WakerKind {
// Task waker stored by the `Receive` future when it returns `Poll::Pending`.
Waker(Waker),
// Thread stored by `Receiver::recv_blocking` before it parks; woken with `unpark`.
Thread(Thread),
}
/// Future awaited by `Receiver::recv`; its output is `Option<T>`.
struct Receive<'a, T> {
// Borrowed handle to the receiver's shared channel state.
state: &'a Arc<Mutex<State<T>>>,
}
impl<T> Sender<T> {
    /// Appends `item` to the back of the queue and wakes the receiver that is
    /// currently registered as waiting, if any.
    ///
    /// The registration is removed from the shared state while the lock is
    /// held, and the wake-up itself is delivered only after the lock has been
    /// released. Waking runs foreign code (`Waker::wake`), which must not run
    /// under the state mutex: a waker that touches the channel would deadlock,
    /// and the woken receiver would contend for the lock straight away. A
    /// receiver that still needs to wait registers itself again.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn send(&self, item: T) {
        let registered = {
            let mut guard = self.state.lock().unwrap();
            guard.items.push_back(item);
            guard.wake.take()
            // The guard is dropped here, before anyone is woken.
        };
        match registered {
            Some(WakerKind::Waker(waker)) => waker.wake(),
            Some(WakerKind::Thread(thread)) => thread.unpark(),
            None => {}
        }
    }
}
impl<T> Receiver<T> {
    /// Waits asynchronously for the next item.
    ///
    /// Resolves to `Some(item)` as soon as the queue holds an item, or to
    /// `None` when the queue is empty and this receiver owns the only
    /// remaining handle to the shared state.
    pub async fn recv(&mut self) -> Option<T> {
        Receive { state: &self.state }.await
    }

    /// Blocks the current thread until an item is available.
    ///
    /// Returns `Some(item)` for the oldest queued item, or `None` when the
    /// queue is empty and this receiver owns the only remaining handle to the
    /// shared state.
    ///
    /// NOTE(review): in the code visible here a parked receiver is only woken
    /// by `Sender::send`. If nothing else wakes it when the last sender goes
    /// away, this call can park forever. Confirm the sender's drop path.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn recv_blocking(&mut self) -> Option<T> {
        loop {
            {
                let mut guard = self.state.lock().unwrap();
                if let Some(item) = guard.items.pop_front() {
                    // We are no longer waiting, so drop our registration.
                    guard.wake = None;
                    return Some(item);
                }
                if Arc::strong_count(&self.state) == 1 {
                    return None;
                }
                // Register under the lock so a concurrent `send` either sees
                // this registration or we see its item on the next iteration.
                guard.wake = Some(WakerKind::Thread(std::thread::current()));
            }
            // The lock is released before parking. `park` may also return
            // spuriously, which the loop handles by re-checking the queue.
            std::thread::park();
        }
    }

    /// Takes the oldest queued item without waiting; `None` if the queue is
    /// empty.
    ///
    /// Also clears any wake-up registration, as the other receive paths do
    /// when they consume an item. Every receive method takes `&mut self`, so
    /// no other receive can be in flight and a registration found here is
    /// stale.
    ///
    /// # Panics
    ///
    /// Panics if the state mutex is poisoned.
    pub fn try_recv(&mut self) -> Option<T> {
        let mut guard = self.state.lock().unwrap();
        guard.wake = None;
        guard.items.pop_front()
    }
}
impl<'a, T> Future for Receive<'a, T> {
    type Output = Option<T>;

    /// Polls the channel once.
    ///
    /// * An item is queued: clears the wake-up registration and yields
    ///   `Ready(Some(item))`.
    /// * The queue is empty and this is the only handle to the shared state:
    ///   yields `Ready(None)`.
    /// * Otherwise: stores a clone of the task's waker in the shared state and
    ///   yields `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let shared = self.state;
        let mut guard = shared.lock().unwrap();
        match guard.items.pop_front() {
            Some(item) => {
                guard.wake = None;
                Poll::Ready(Some(item))
            }
            None if Arc::strong_count(shared) == 1 => Poll::Ready(None),
            None => {
                // Registered under the lock, so a concurrent send cannot slip
                // in between the emptiness check and the registration.
                guard.wake = Some(WakerKind::Waker(cx.waker().clone()));
                Poll::Pending
            }
        }
    }
}
impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
}
}
}