Add connect_with_bind for TcpStream

This commit is contained in:
Aode (lion) 2022-07-22 18:01:45 -05:00
parent cf88cef839
commit d6d08c8bce

View file

@ -14,7 +14,7 @@ pub struct UdpSocket {
}
pub struct UdpSocketBuilder {
sock_addr: SocketAddr,
bind_addr: SocketAddr,
fd: OwnedFd,
}
@ -23,7 +23,8 @@ pub struct TcpStream {
}
pub struct TcpStreamBuilder {
sock_addr: SocketAddr,
connect_addr: SocketAddr,
bind_addr: Option<SocketAddr>,
fd: OwnedFd,
}
@ -32,15 +33,15 @@ pub struct TcpListener {
}
pub struct TcpListenerBuilder {
sock_addr: SocketAddr,
bind_addr: SocketAddr,
fd: OwnedFd,
}
impl UdpSocket {
pub fn bind<A: Into<SocketAddr>>(socket_address: A) -> Result<UdpSocketBuilder> {
let sock_addr = socket_address.into();
let bind_addr = socket_address.into();
let family = match sock_addr.ip() {
let family = match bind_addr.ip() {
IpAddr::V4(_) => AddressFamily::INET,
IpAddr::V6(_) => AddressFamily::INET6,
};
@ -55,7 +56,7 @@ impl UdpSocket {
rustix::net::sockopt::set_socket_reuseaddr(&sock, true)?;
Ok(UdpSocketBuilder {
sock_addr,
bind_addr,
fd: sock,
})
}
@ -90,7 +91,7 @@ impl UdpSocket {
impl UdpSocketBuilder {
pub fn try_finish(self) -> Result<std::result::Result<UdpSocket, UdpSocketBuilder>> {
let res = match self.sock_addr {
let res = match self.bind_addr {
SocketAddr::V4(v4) => rustix::net::bind_v4(&self.fd, &v4),
SocketAddr::V6(v6) => rustix::net::bind_v6(&self.fd, &v6),
};
@ -125,9 +126,9 @@ impl TcpListener {
}
pub fn bind<A: Into<SocketAddr>>(socket_address: A) -> Result<TcpListenerBuilder> {
let sock_addr = socket_address.into();
let bind_addr = socket_address.into();
let family = match sock_addr.ip() {
let family = match bind_addr.ip() {
IpAddr::V4(_) => AddressFamily::INET,
IpAddr::V6(_) => AddressFamily::INET6,
};
@ -142,7 +143,7 @@ impl TcpListener {
rustix::net::sockopt::set_socket_reuseaddr(&sock, true)?;
Ok(TcpListenerBuilder {
sock_addr,
bind_addr,
fd: sock,
})
}
@ -150,10 +151,19 @@ impl TcpListener {
impl TcpStream {
pub fn connect<A: Into<SocketAddr>>(socket_address: A) -> Result<TcpStreamBuilder> {
let sock_addr = socket_address.into();
let family = match sock_addr.ip() {
IpAddr::V4(_) => AddressFamily::INET,
IpAddr::V6(_) => AddressFamily::INET6,
Self::connect_with_bind(socket_address, None as Option<SocketAddr>)
}
pub fn connect_with_bind<A: Into<SocketAddr>, B: Into<SocketAddr>>(
socket_address: A,
bind_address: Option<B>,
) -> Result<TcpStreamBuilder> {
let connect_addr = socket_address.into();
let bind_addr = bind_address.map(|a| a.into());
let family = match (connect_addr.ip(), bind_addr.map(|a| a.ip())) {
(IpAddr::V4(_), None | Some(IpAddr::V4(_))) => AddressFamily::INET,
(IpAddr::V6(_), None | Some(IpAddr::V6(_))) => AddressFamily::INET6,
_ => return Err(std::io::ErrorKind::InvalidInput.into()),
};
let sock = rustix::net::socket_with(
@ -164,7 +174,8 @@ impl TcpStream {
)?;
Ok(TcpStreamBuilder {
sock_addr,
connect_addr,
bind_addr,
fd: sock,
})
}
@ -180,7 +191,7 @@ impl TcpStream {
impl TcpListenerBuilder {
pub fn try_finish(self) -> Result<std::result::Result<TcpListener, TcpListenerBuilder>> {
let res = match self.sock_addr {
let res = match self.bind_addr {
SocketAddr::V4(v4) => rustix::net::bind_v4(&self.fd, &v4),
SocketAddr::V6(v6) => rustix::net::bind_v6(&self.fd, &v6),
};
@ -196,8 +207,22 @@ impl TcpListenerBuilder {
}
impl TcpStreamBuilder {
pub fn try_finish(self) -> Result<std::result::Result<TcpStream, TcpStreamBuilder>> {
let res = match self.sock_addr {
pub fn try_finish(mut self) -> Result<std::result::Result<TcpStream, TcpStreamBuilder>> {
if let Some(bind_addr) = self.bind_addr {
let res = match bind_addr {
SocketAddr::V4(v4) => rustix::net::bind_v4(&self.fd, &v4),
SocketAddr::V6(v6) => rustix::net::bind_v6(&self.fd, &v6),
};
if let InProgress::InProgress = in_progress(res)? {
return Ok(Err(self));
}
// we have succesfully bound, don't call bind again after this
self.bind_addr.take();
}
let res = match self.connect_addr {
SocketAddr::V4(v4) => rustix::net::connect_v4(&self.fd, &v4),
SocketAddr::V6(v6) => rustix::net::connect_v6(&self.fd, &v6),
};