polldance/src/net.rs
2023-08-25 12:11:07 -05:00

292 lines
8.6 KiB
Rust

use crate::io::{in_progress, nonblocking, InProgress, Nonblocking, ReadBytes, Result};
use rustix::{
fd::{AsFd, OwnedFd},
net::{AddressFamily, RecvFlags, SendFlags, SocketAddrAny, SocketFlags, SocketType},
};
use std::net::{IpAddr, SocketAddr};
/// UDP socket handle that owns its underlying file descriptor.
///
/// Instances are produced by finishing a [`UdpSocketBuilder`] obtained from
/// [`UdpSocket::bind`].
#[derive(Debug)]
pub struct UdpSocket {
    /// Owned descriptor of the socket; closed when the value is dropped.
    fd: OwnedFd,
}
/// A UDP socket that has been created but not yet bound.
///
/// Implements `Debug` so that the `Result<UdpSocket, UdpSocketBuilder>` returned by
/// `try_finish` can be used with `unwrap`/`expect`.
#[derive(Debug)]
pub struct UdpSocketBuilder {
    /// Local address the socket will be bound to by `try_finish`.
    bind_addr: SocketAddr,
    /// Owned descriptor of the not-yet-bound socket.
    fd: OwnedFd,
}
/// TCP stream handle that owns its underlying file descriptor.
///
/// Instances come from finishing a [`TcpStreamBuilder`] or from
/// [`TcpListener::try_accept`].
#[derive(Debug)]
pub struct TcpStream {
    /// Owned descriptor of the stream socket; closed when the value is dropped.
    fd: OwnedFd,
}
/// A TCP stream whose socket exists but whose connection has not completed yet.
///
/// Implements `Debug` so that the `Result<TcpStream, TcpStreamBuilder>` returned by
/// `try_finish` can be used with `unwrap`/`expect`.
#[derive(Debug)]
pub struct TcpStreamBuilder {
    /// Remote address to connect to.
    connect_addr: SocketAddr,
    /// Optional local address to bind before connecting; cleared once the bind succeeded.
    bind_addr: Option<SocketAddr>,
    /// Owned descriptor of the socket being connected.
    fd: OwnedFd,
}
/// Listening TCP socket handle that owns its underlying file descriptor.
///
/// Instances are produced by finishing a [`TcpListenerBuilder`] obtained from
/// [`TcpListener::bind`].
#[derive(Debug)]
pub struct TcpListener {
    /// Owned descriptor of the listening socket; closed when the value is dropped.
    fd: OwnedFd,
}
/// A TCP listener whose socket exists but is not yet bound and listening.
///
/// Implements `Debug` so that the `Result<TcpListener, TcpListenerBuilder>` returned by
/// `try_finish` can be used with `unwrap`/`expect`.
#[derive(Debug)]
pub struct TcpListenerBuilder {
    /// Local address the listener will be bound to by `try_finish`.
    bind_addr: SocketAddr,
    /// Owned descriptor of the not-yet-bound socket.
    fd: OwnedFd,
}
impl UdpSocket {
pub fn bind<A: Into<SocketAddr>>(socket_address: A) -> Result<UdpSocketBuilder> {
let bind_addr = socket_address.into();
let family = match bind_addr.ip() {
IpAddr::V4(_) => AddressFamily::INET,
IpAddr::V6(_) => AddressFamily::INET6,
};
let sock = rustix::net::socket_with(
family,
SocketType::DGRAM,
SocketFlags::NONBLOCK | SocketFlags::CLOEXEC,
Some(rustix::net::ipproto::UDP),
)?;
rustix::net::sockopt::set_socket_reuseaddr(&sock, true)?;
Ok(UdpSocketBuilder {
bind_addr,
fd: sock,
})
}
pub fn try_recv_from(
&self,
buf: &mut [u8],
) -> Result<Nonblocking<(usize, Option<SocketAddr>)>> {
nonblocking(
rustix::net::recvfrom(self, buf, RecvFlags::empty()).map(|(size, opt)| {
(
size,
opt.map(|addr| match addr {
SocketAddrAny::V4(v4) => SocketAddr::V4(v4),
SocketAddrAny::V6(v6) => SocketAddr::V6(v6),
SocketAddrAny::Unix(_) => {
unreachable!("Cannot have unix socket origin for UdpSocket")
}
e => {
panic!("Unsupported SocketAddrAny variant, {:?}", e)
}
}),
)
}),
)
}
pub fn try_send_to(&self, buf: &[u8], target: SocketAddr) -> Result<Nonblocking<usize>> {
nonblocking(rustix::net::sendto(self, buf, SendFlags::NOSIGNAL, &target))
}
}
impl UdpSocketBuilder {
    /// Tries to bind the socket to the stored address.
    ///
    /// Returns `Ok(Ok(socket))` once the bind is reported ready, and `Ok(Err(self))` when
    /// `in_progress` reports the operation as still in progress, so the caller can retry
    /// with the same builder.
    ///
    /// # Errors
    ///
    /// Any other bind failure is returned as the outer error and the builder is dropped.
    pub fn try_finish(self) -> Result<std::result::Result<UdpSocket, UdpSocketBuilder>> {
        let bound = match self.bind_addr {
            SocketAddr::V4(v4) => rustix::net::bind_v4(&self.fd, &v4),
            SocketAddr::V6(v6) => rustix::net::bind_v6(&self.fd, &v6),
        };
        if let InProgress::InProgress = in_progress(bound)? {
            // Not done yet: hand the builder back unchanged.
            return Ok(Err(self));
        }
        Ok(Ok(UdpSocket { fd: self.fd }))
    }
}
impl TcpListener {
pub fn try_accept(&self) -> Result<Nonblocking<(TcpStream, Option<SocketAddr>)>> {
nonblocking(
rustix::net::acceptfrom_with(&self.fd, SocketFlags::NONBLOCK | SocketFlags::CLOEXEC)
.map(|(fd, addr)| {
(
TcpStream { fd },
addr.map(|addr| match addr {
SocketAddrAny::V4(v4) => SocketAddr::V4(v4),
SocketAddrAny::V6(v6) => SocketAddr::V6(v6),
SocketAddrAny::Unix(_) => {
unreachable!("Cannot have unix socket origin for TcpListener")
}
e => {
panic!("Unsupported SocketAddrAny variant, {:?}", e)
}
}),
)
}),
)
}
pub fn bind<A: Into<SocketAddr>>(socket_address: A) -> Result<TcpListenerBuilder> {
let bind_addr = socket_address.into();
let family = match bind_addr.ip() {
IpAddr::V4(_) => AddressFamily::INET,
IpAddr::V6(_) => AddressFamily::INET6,
};
let sock = rustix::net::socket_with(
family,
SocketType::STREAM,
SocketFlags::NONBLOCK | SocketFlags::CLOEXEC,
Some(rustix::net::ipproto::TCP),
)?;
rustix::net::sockopt::set_socket_reuseaddr(&sock, true)?;
Ok(TcpListenerBuilder {
bind_addr,
fd: sock,
})
}
}
impl TcpStream {
pub fn connect<A: Into<SocketAddr>>(socket_address: A) -> Result<TcpStreamBuilder> {
Self::connect_with_bind(socket_address, None as Option<SocketAddr>)
}
pub fn connect_with_bind<A: Into<SocketAddr>, B: Into<SocketAddr>>(
socket_address: A,
bind_address: Option<B>,
) -> Result<TcpStreamBuilder> {
let connect_addr = socket_address.into();
let bind_addr = bind_address.map(|a| a.into());
let family = match (connect_addr.ip(), bind_addr.map(|a| a.ip())) {
(IpAddr::V4(_), None | Some(IpAddr::V4(_))) => AddressFamily::INET,
(IpAddr::V6(_), None | Some(IpAddr::V6(_))) => AddressFamily::INET6,
_ => return Err(std::io::ErrorKind::InvalidInput.into()),
};
let sock = rustix::net::socket_with(
family,
SocketType::STREAM,
SocketFlags::NONBLOCK | SocketFlags::CLOEXEC,
Some(rustix::net::ipproto::TCP),
)?;
Ok(TcpStreamBuilder {
connect_addr,
bind_addr,
fd: sock,
})
}
pub fn try_write(&self, buf: &[u8]) -> Result<Nonblocking<usize>> {
nonblocking(rustix::io::write(&self, buf))
}
pub fn try_read(&self, buf: &mut [u8]) -> Result<Nonblocking<ReadBytes>> {
nonblocking(rustix::io::read(&self, buf).map(ReadBytes::from))
}
}
impl TcpListenerBuilder {
    /// Tries to bind the socket to the stored address and put it into listening mode.
    ///
    /// Returns `Ok(Ok(listener))` once the bind is reported ready and `listen` (backlog
    /// 1024) succeeded, and `Ok(Err(self))` when `in_progress` reports the bind as still in
    /// progress, so the caller can retry with the same builder.
    ///
    /// # Errors
    ///
    /// Any other bind failure, or a failing `listen`, is returned as the outer error and the
    /// builder is dropped.
    pub fn try_finish(self) -> Result<std::result::Result<TcpListener, TcpListenerBuilder>> {
        let bound = match self.bind_addr {
            SocketAddr::V4(v4) => rustix::net::bind_v4(&self.fd, &v4),
            SocketAddr::V6(v6) => rustix::net::bind_v6(&self.fd, &v6),
        };
        if let InProgress::InProgress = in_progress(bound)? {
            // Not bound yet: hand the builder back unchanged.
            return Ok(Err(self));
        }
        rustix::net::listen(&self.fd, 1024)?;
        Ok(Ok(TcpListener { fd: self.fd }))
    }
}
impl TcpStreamBuilder {
    /// Tries to complete the optional local bind and then the connect.
    ///
    /// Returns `Ok(Ok(stream))` once the connect is reported ready, and `Ok(Err(self))` when
    /// `in_progress` reports either step as still in progress, so the caller can retry with
    /// the same builder. A bind that already succeeded is not repeated on later calls.
    ///
    /// NOTE(review): every call issues `connect` again on the same socket; whether the
    /// results of such repeated calls are treated as ready/in-progress depends on the
    /// `in_progress` helper, which is not visible here — confirm.
    ///
    /// # Errors
    ///
    /// Any other bind or connect failure is returned as the outer error and the builder is
    /// dropped.
    pub fn try_finish(mut self) -> Result<std::result::Result<TcpStream, TcpStreamBuilder>> {
        if let Some(local) = self.bind_addr {
            let bound = match local {
                SocketAddr::V4(v4) => rustix::net::bind_v4(&self.fd, &v4),
                SocketAddr::V6(v6) => rustix::net::bind_v6(&self.fd, &v6),
            };
            match in_progress(bound)? {
                InProgress::InProgress => return Ok(Err(self)),
                // We have successfully bound; forget the address so bind is never called again.
                InProgress::Ready(_) => self.bind_addr = None,
            }
        }
        let connected = match self.connect_addr {
            SocketAddr::V4(v4) => rustix::net::connect_v4(&self.fd, &v4),
            SocketAddr::V6(v6) => rustix::net::connect_v6(&self.fd, &v6),
        };
        Ok(match in_progress(connected)? {
            InProgress::Ready(_) => Ok(TcpStream { fd: self.fd }),
            InProgress::InProgress => Err(self),
        })
    }
}
/// `std::io::Read` adapter over [`TcpStream::try_read`].
impl std::io::Read for TcpStream {
    /// Reads into `buf`, reporting `Nonblocking::WouldBlock` as an
    /// `ErrorKind::WouldBlock` error and converting a ready result into a `usize`.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        match self.try_read(buf)? {
            Nonblocking::Ready(bytes) => Ok(bytes.into()),
            Nonblocking::WouldBlock => Err(std::io::ErrorKind::WouldBlock.into()),
        }
    }
}
/// `std::io::Write` adapter over [`TcpStream::try_write`].
impl std::io::Write for TcpStream {
    /// Writes `buf`, reporting `Nonblocking::WouldBlock` as an `ErrorKind::WouldBlock`
    /// error and otherwise returning the number of bytes written.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        match self.try_write(buf)? {
            Nonblocking::Ready(written) => Ok(written),
            Nonblocking::WouldBlock => Err(std::io::ErrorKind::WouldBlock.into()),
        }
    }

    /// No-op: this type keeps no buffer of its own, so there is nothing to flush.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}
impl AsFd for UdpSocket {
fn as_fd(&self) -> rustix::fd::BorrowedFd<'_> {
self.fd.as_fd()
}
}
impl AsFd for UdpSocketBuilder {
fn as_fd(&self) -> rustix::fd::BorrowedFd<'_> {
self.fd.as_fd()
}
}
impl AsFd for TcpStream {
fn as_fd(&self) -> rustix::fd::BorrowedFd<'_> {
self.fd.as_fd()
}
}
impl AsFd for TcpListener {
fn as_fd(&self) -> rustix::fd::BorrowedFd<'_> {
self.fd.as_fd()
}
}
impl AsFd for TcpStreamBuilder {
fn as_fd(&self) -> rustix::fd::BorrowedFd<'_> {
self.fd.as_fd()
}
}
impl AsFd for TcpListenerBuilder {
fn as_fd(&self) -> rustix::fd::BorrowedFd<'_> {
self.fd.as_fd()
}
}