polldance/src/lib.rs
2023-08-25 12:11:07 -05:00

421 lines
9.1 KiB
Rust

use rustix::{
event::{PollFd, PollFlags},
pipe::PipeFlags,
};
use std::{
borrow::Borrow,
collections::HashMap,
ops::{BitAnd, BitOr, BitXor},
sync::Arc,
};
pub mod io;
pub mod net;
pub mod fd {
    //! File-descriptor types re-exported from `rustix`, plus a duplication helper.

    pub use rustix::fd::{AsFd, AsRawFd, BorrowedFd, OwnedFd};

    /// Duplicates `fd` with `fcntl_dupfd_cloexec` (minimum descriptor number 0), yielding a
    /// new owned descriptor that has close-on-exec set.
    ///
    /// # Errors
    ///
    /// Returns the converted error if the underlying `fcntl` call fails.
    pub fn try_clone<A: AsFd>(fd: A) -> super::io::Result<OwnedFd> {
        let duplicate = rustix::fs::fcntl_dupfd_cloexec(fd, 0)?;
        Ok(duplicate)
    }
}
use fd::{AsFd, BorrowedFd, OwnedFd};
/// Type-erased identity of a registered I/O object: the address of the value inside its `Arc`.
///
/// Within this module it is only copied, compared and hashed — never dereferenced.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct KeyRef(*const ());

/// Owned handle returned by [`PollManager::register`]; wraps the registration's [`KeyRef`].
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct Key(KeyRef);
/// Owns a set of registered I/O objects and polls them together with an internal
/// notification pipe, whose write end (see `notifier()`) can wake a pending `poll()`.
pub struct PollManager {
    /// Read end of the internal notification pipe; added to every poll.
    notify: Notify,
    /// Write end of the same pipe; cloned out by `notifier()`.
    notify_token: NotifyToken,
    /// Poller index assigned to `notify` during the most recent `poll()`, if any.
    notify_key: Option<usize>,
    /// Registered objects, keyed by the address of the value inside their `Arc`.
    io: HashMap<KeyRef, Managed>,
}

/// A registered I/O object together with the readiness it should be polled for.
struct Managed {
    io: Io,
    interests: Readiness,
}

/// Shared, type-erased handle to a registered I/O object.
struct Io {
    inner: Arc<dyn AsFd + 'static>,
}
/// One-shot wrapper around `poll(2)`: collect descriptors with `add`, then consume the
/// poller with `poll`.
#[derive(Default)]
pub struct Poller<'a> {
    /// Timeout handed to `rustix::event::poll` (milliseconds, per poll(2)); `None` is
    /// passed as -1, i.e. wait indefinitely.
    timeout: Option<i32>,
    /// Requested event flags, kept parallel to `fds` (same index).
    flags: Vec<PollFlags>,
    /// Descriptors to poll; an entry's index here is the key reported by `poll`.
    fds: Vec<PollFd<'a>>,
}
/// A set of I/O readiness conditions, used both as registration interests and as the
/// readiness reported for each ready descriptor.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Readiness {
    /// Readable; corresponds to `PollFlags::IN`.
    read: bool,
    /// Writable; corresponds to `PollFlags::OUT`.
    write: bool,
    /// Hang-up; corresponds to `PollFlags::HUP`.
    hangup: bool,
    /// Error condition; corresponds to `PollFlags::ERR`.
    error: bool,
}
/// Receiving half of a notification pair: owns the read end of the pipe.
#[derive(Debug)]
pub struct Notify {
    registered: OwnedFd,
}

/// Cloneable sending half of a notification pair: shares the pipe's write end.
#[derive(Clone, Debug)]
pub struct NotifyToken {
    free: Arc<OwnedFd>,
}
/// Creates a notification pair backed by a non-blocking, close-on-exec pipe.
///
/// The [`Notify`] half owns the read end; the [`NotifyToken`] half holds the write end
/// behind an `Arc` so it can be cloned and shared.
///
/// # Errors
///
/// Returns an error if the pipe cannot be created.
pub fn notify_pair() -> io::Result<(Notify, NotifyToken)> {
    let flags = PipeFlags::NONBLOCK | PipeFlags::CLOEXEC;
    let (read_end, write_end) = rustix::pipe::pipe_with(flags)?;
    let notify = Notify {
        registered: read_end,
    };
    let token = NotifyToken {
        free: Arc::new(write_end),
    };
    Ok((notify, token))
}
impl KeyRef {
fn from_arc<A>(arc: &Arc<A>) -> Self {
KeyRef(Arc::as_ptr(arc) as *const _)
}
}
impl Key {
pub fn is_key(&self, keyref: KeyRef) -> bool {
self.0 == keyref
}
}
impl Borrow<KeyRef> for Key {
fn borrow(&self) -> &KeyRef {
&self.0
}
}
impl AsRef<KeyRef> for Key {
fn as_ref(&self) -> &KeyRef {
&self.0
}
}
impl AsFd for Io {
    /// Delegates to the descriptor of the wrapped I/O object.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&*self.inner)
    }
}
impl PollManager {
    /// Creates an empty manager together with its internal notification pipe.
    ///
    /// # Errors
    ///
    /// Returns an error if the notification pipe cannot be created.
    pub fn new() -> io::Result<Self> {
        let (notify, notify_token) = notify_pair()?;
        Ok(Self {
            notify,
            notify_token,
            notify_key: None,
            io: HashMap::new(),
        })
    }

    /// Registers `io` to be polled for `interests` and returns its [`Key`].
    ///
    /// The key is the address of the value inside the `Arc`, so registering the same
    /// `Arc` (or a clone of it) again replaces the earlier registration and its interests.
    pub fn register<A: AsFd + 'static>(&mut self, io: Arc<A>, interests: Readiness) -> Key {
        let key = KeyRef::from_arc(&io);
        self.io.insert(
            key,
            Managed {
                io: Io { inner: io },
                interests,
            },
        );
        Key(key)
    }

    /// Replaces the interests of `io` with `callback(current_interests)`.
    ///
    /// Does nothing (and never calls `callback`) if `io` is not registered.
    pub fn update_interests<A: AsFd + 'static>(
        &mut self,
        io: &Arc<A>,
        callback: impl FnOnce(Readiness) -> Readiness,
    ) {
        if let Some(managed) = self.io.get_mut(&KeyRef::from_arc(io)) {
            managed.interests = callback(managed.interests);
        }
    }

    /// Same as [`PollManager::update_interests`], but looks the registration up by `key`.
    pub fn update_interests_by_key(
        &mut self,
        key: &Key,
        callback: impl FnOnce(Readiness) -> Readiness,
    ) {
        if let Some(managed) = self.io.get_mut(&key.0) {
            managed.interests = callback(managed.interests);
        }
    }

    /// Removes `io`, dropping the manager's clone of its `Arc`. No-op if not registered.
    pub fn deregister<A: AsFd + 'static>(&mut self, io: &Arc<A>) {
        self.io.remove(&KeyRef::from_arc(io));
    }

    /// Removes the registration identified by `key`. No-op if not registered.
    pub fn deregister_by_key(&mut self, key: Key) {
        self.io.remove(&key.0);
    }

    /// Returns a token whose [`NotifyToken::notify`] writes to the pipe this manager
    /// polls, waking a pending [`PollManager::poll`].
    pub fn notifier(&self) -> NotifyToken {
        self.notify_token.clone()
    }

    /// Polls every registered object plus the internal notification pipe.
    ///
    /// `timeout` is forwarded to [`Poller::timeout`]; `None` waits indefinitely.
    /// Returns the key and reported readiness of each registered object that became ready.
    /// A wake-up through the notification pipe is consumed here and not reported.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying poll fails, or if draining the notification
    /// pipe fails with anything other than "would block" / "interrupted".
    pub fn poll(&mut self, timeout: Option<i32>) -> io::Result<Vec<(KeyRef, Readiness)>> {
        let mut poller = Poller::new();
        if let Some(timeout) = timeout {
            poller.timeout(timeout);
        }
        let notify_key = poller.add(
            &self.notify,
            Readiness::hangup() | Readiness::error() | Readiness::read(),
        );
        self.notify_key = Some(notify_key);
        // Poller key -> registration key, so results can be mapped back.
        let mut mapping = HashMap::with_capacity(self.io.len());
        for (mkey, managed) in self.io.iter() {
            let key = poller.add(&managed.io, managed.interests);
            mapping.insert(key, *mkey);
        }
        let ready = poller.poll()?;
        let mut events = Vec::with_capacity(ready.len());
        for (key, readiness) in ready {
            if key == notify_key {
                self.drain_notify()?;
            } else if let Some(mkey) = mapping.remove(&key) {
                events.push((mkey, readiness));
            }
        }
        Ok(events)
    }

    /// Reads the (non-blocking) notification pipe until it is empty.
    ///
    /// Unlike a loop that only stops on `EAGAIN`, this also stops on EOF (all write ends
    /// closed) and on unexpected errors, so it can never spin forever.
    fn drain_notify(&self) -> io::Result<()> {
        let mut buf = [0u8; 16];
        loop {
            match rustix::io::read(&self.notify, &mut buf) {
                // Consumed some wake-up bytes; keep reading until the pipe is empty.
                Ok(n) if n > 0 => {}
                // EOF: no writers remain and nothing is left to drain.
                Ok(_) => return Ok(()),
                Err(e)
                    if e == rustix::io::Errno::AGAIN || e == rustix::io::Errno::WOULDBLOCK =>
                {
                    return Ok(());
                }
                // Interrupted by a signal before any data was read; retry.
                Err(e) if e == rustix::io::Errno::INTR => {}
                Err(e) => return Err(e.into()),
            }
        }
    }
}
impl Readiness {
    /// A readiness set with no conditions.
    pub fn empty() -> Self {
        Readiness {
            read: false,
            write: false,
            hangup: false,
            error: false,
        }
    }

    /// A readiness set with every condition: read, write, hangup and error.
    pub fn all() -> Self {
        Readiness {
            read: true,
            write: true,
            hangup: true,
            error: true,
        }
    }

    /// Only the read condition.
    pub fn read() -> Self {
        Readiness {
            read: true,
            ..Self::empty()
        }
    }

    /// Only the write condition.
    pub fn write() -> Self {
        Readiness {
            write: true,
            ..Self::empty()
        }
    }

    /// Only the hangup condition.
    pub fn hangup() -> Self {
        Readiness {
            hangup: true,
            ..Self::empty()
        }
    }

    /// Only the error condition.
    pub fn error() -> Self {
        Readiness {
            error: true,
            ..Self::empty()
        }
    }

    /// Conditions present in `self` but absent from `rhs`.
    pub fn difference(self, rhs: Self) -> Self {
        self.zip_flags(rhs, |mine, theirs| mine && !theirs)
    }

    /// `true` when no condition is set.
    pub fn is_empty(self) -> bool {
        !(self.read || self.write || self.hangup || self.error)
    }

    /// `true` when `self` and `rhs` share at least one condition.
    pub fn is_intersect(self, rhs: Self) -> bool {
        !(self & rhs).is_empty()
    }

    /// `true` when `self` and `rhs` share no condition.
    pub fn is_disjoint(self, rhs: Self) -> bool {
        !self.is_intersect(rhs)
    }

    /// Whether the read condition is set.
    pub fn is_read(self) -> bool {
        self.read
    }

    /// Whether the write condition is set.
    pub fn is_write(self) -> bool {
        self.write
    }

    /// Whether the hangup condition is set.
    pub fn is_hangup(self) -> bool {
        self.hangup
    }

    /// Whether the error condition is set.
    pub fn is_error(self) -> bool {
        self.error
    }

    /// Combines `self` and `rhs` field by field with `op`.
    fn zip_flags(self, rhs: Self, op: impl Fn(bool, bool) -> bool) -> Self {
        Readiness {
            read: op(self.read, rhs.read),
            write: op(self.write, rhs.write),
            hangup: op(self.hangup, rhs.hangup),
            error: op(self.error, rhs.error),
        }
    }
}

impl BitXor for Readiness {
    type Output = Self;
    /// Conditions set in exactly one of the operands.
    fn bitxor(self, rhs: Self) -> Self::Output {
        self.zip_flags(rhs, |a, b| a != b)
    }
}

impl BitOr for Readiness {
    type Output = Self;
    /// Conditions set in either operand.
    fn bitor(self, rhs: Self) -> Self::Output {
        self.zip_flags(rhs, |a, b| a || b)
    }
}

impl BitAnd for Readiness {
    type Output = Self;
    /// Conditions set in both operands.
    fn bitand(self, rhs: Self) -> Self::Output {
        self.zip_flags(rhs, |a, b| a && b)
    }
}
impl From<PollFlags> for Readiness {
fn from(flags: PollFlags) -> Self {
Self {
read: flags.intersects(PollFlags::IN),
write: flags.intersects(PollFlags::OUT),
hangup: flags.intersects(PollFlags::HUP),
error: flags.intersects(PollFlags::ERR),
}
}
}
impl From<Readiness> for PollFlags {
fn from(r: Readiness) -> Self {
let mut output = PollFlags::empty();
if r.read {
output |= PollFlags::IN;
}
if r.write {
output |= PollFlags::OUT;
}
if r.hangup {
output |= PollFlags::HUP;
}
if r.error {
output |= PollFlags::ERR;
}
output
}
}
impl<'a> Poller<'a> {
pub fn new() -> Self {
Default::default()
}
pub fn add<'b: 'a, A: AsFd>(&mut self, fd: &'b A, interests: Readiness) -> usize {
let len = self.fds.len();
self.flags.push(interests.into());
self.fds.push(PollFd::new(fd, interests.into()));
len
}
pub fn timeout(&mut self, timeout: i32) {
self.timeout = Some(timeout);
}
pub fn poll(mut self) -> io::Result<Vec<(usize, Readiness)>> {
let timeout = self.timeout.unwrap_or(-1);
let n = rustix::event::poll(&mut self.fds, timeout)?;
Ok(self
.fds
.into_iter()
.zip(self.flags)
.enumerate()
.filter_map(|(key, (pollfd, flags))| {
if pollfd.revents().intersects(flags) {
Some((key, (pollfd.revents().into())))
} else {
None
}
})
.take(n)
.collect())
}
}
impl NotifyToken {
    /// Writes one byte into the notification pipe and hands the token back.
    ///
    /// `EAGAIN`/`EWOULDBLOCK` is treated as success. The pipe is non-blocking, so this
    /// means it is already full of unread bytes.
    ///
    /// # Errors
    ///
    /// Any other write error is returned, and the token is dropped.
    pub fn notify(self) -> io::Result<Self> {
        let outcome = rustix::io::write(&*self.free, &[0u8]);
        match outcome {
            Ok(_) => Ok(self),
            Err(errno)
                if errno == rustix::io::Errno::AGAIN
                    || errno == rustix::io::Errno::WOULDBLOCK =>
            {
                Ok(self)
            }
            Err(errno) => Err(errno.into()),
        }
    }
}
impl AsFd for Notify {
fn as_fd(&self) -> BorrowedFd<'_> {
self.registered.as_fd()
}
}