diff --git a/Cargo.toml b/Cargo.toml index ccce032..7ac916c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,10 @@ required-features = ["server"] name = "axum" required-features = ["server"] +[[example]] +name = "client" +required-features = ["client"] + # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] default = ["client", "server"] @@ -20,7 +24,6 @@ client = [ "futures-io", "hyper/client", "jive/futures-io-compat", - "tower-service", "trust-dns-proto", "trust-dns-resolver", ] @@ -33,7 +36,6 @@ jive = { git = "https://git.asonix.dog/safe-async/jive", features = [ "tokio-io-compat", ] } tokio = { version = "1", default-features = false } -tower-service = { version = "0.3.1", optional = true } trust-dns-proto = { version = "0.21.0", default-features = false, optional = true } trust-dns-resolver = { version = "0.21.1", default-features = false, features = [ "system-config", diff --git a/examples/client.rs b/examples/client.rs new file mode 100644 index 0000000..8ec5ce2 --- /dev/null +++ b/examples/client.rs @@ -0,0 +1,14 @@ +fn main() -> Result<(), Box> { + jive::block_on(async { + let client = hyperjive::client::new()?; + + let res = client.get("http://httpbin.org/ip".parse()?).await?; + + println!("GOT: {}", res.status()); + + let buf = hyper::body::to_bytes(res).await?; + let s = String::from_utf8_lossy(&buf); + println!("{}", s); + Ok(()) + }) +} diff --git a/src/lib.rs b/src/lib.rs index b42f720..82584f4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,46 +1,27 @@ -use std::future::Future; - -#[derive(Clone)] -pub struct JiveRuntime; - -pub struct JiveTcpStream { - io: jive::io::Async, -} - -impl hyper::rt::Executor for JiveRuntime -where - Fut: Future + Send + 'static, -{ - fn execute(&self, fut: Fut) { - jive::spawn(async move { - fut.await; - }); - } -} - -#[cfg(feature = "server")] -pub mod server { - use super::JiveTcpStream; +pub mod common { use std::{ + future::Future, pin::Pin, task::{Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite}; - pub struct JiveIncoming { - io: jive::io::Async, + #[derive(Clone)] + pub struct JiveRuntime; + + pub struct JiveTcpStream { + pub(crate) io: jive::io::Async, } - pub async fn builder>( - socket_address: A, - ) -> std::io::Result> { - let listener = jive::io::Async::::bind(socket_address).await?; - - let incoming = JiveIncoming { io: listener }; - - let http = hyper::server::conn::Http::new().with_executor(super::JiveRuntime); - - Ok(hyper::server::Builder::new(incoming, http)) + impl hyper::rt::Executor for JiveRuntime + where + Fut: Future + Send + 'static, + { + fn execute(&self, fut: Fut) { + jive::spawn(async move { + fut.await; + }); + } } impl AsyncRead for JiveTcpStream { @@ -70,6 +51,31 @@ pub mod server { Pin::new(&mut self.get_mut().io).poll_shutdown(cx) } } +} + +#[cfg(feature = "server")] +pub mod server { + use crate::common::{JiveRuntime, JiveTcpStream}; + use std::{ + pin::Pin, + task::{Context, Poll}, + }; + + pub struct JiveIncoming { + io: jive::io::Async, + } + + pub async fn builder>( + socket_address: A, + ) -> std::io::Result> { + let listener = jive::io::Async::::bind(socket_address).await?; + + let incoming = JiveIncoming { io: listener }; + + let http = hyper::server::conn::Http::new().with_executor(JiveRuntime); + + Ok(hyper::server::Builder::new(incoming, http)) + } impl hyper::server::accept::Accept for JiveIncoming { type Conn = JiveTcpStream; @@ -90,8 +96,328 @@ pub mod server { #[cfg(feature = "client")] pub mod client { + use crate::common::{JiveRuntime, JiveTcpStream}; + use dns::JiveResolver; + use hyper::{service::Service, Uri}; + use std::{ + future::Future, + net::{IpAddr, SocketAddr}, + pin::Pin, + task::{Context, Poll}, + }; + + #[derive(Clone)] + pub struct ConnectService { + resolver: R, + } + + #[derive(Clone)] + pub struct ResolverService { + resolver: JiveResolver, + } + + #[derive(Debug)] + pub enum ResolveError { + Dns(trust_dns_resolver::error::ResolveError), + Invalid, + } + + pub enum ConnectFuture { + Resolving { + fut: F, + }, + Connecting { + fut: BoxFuture<'static, std::io::Result>, + }, + Complete, + } + + pub enum ResolverFuture { + Pending { + port: u16, + fut: BoxFuture< + 'static, + Result< + trust_dns_resolver::lookup_ip::LookupIp, + trust_dns_resolver::error::ResolveError, + >, + >, + }, + Resolved { + port: u16, + addr: IpAddr, + }, + Invalid, + Complete, + } + + type BoxFuture<'a, T> = Pin + Send + 'a>>; + + pub fn new() -> Result< + hyper::Client>, + trust_dns_resolver::error::ResolveError, + > { + Ok(builder().build(connect_service()?)) + } + + pub fn builder() -> hyper::client::Builder { + let mut builder = hyper::Client::builder(); + builder.executor(JiveRuntime); + builder + } + + pub fn connect_service( + ) -> Result, trust_dns_resolver::error::ResolveError> { + ConnectService::new() + } + + pub fn connect_service_with_resolver(resolver: R) -> ConnectService + where + R: Service, + R::Response: IntoIterator, + R::Error: Into>, + ConnectService: Service, + { + ConnectService::new_with_resolver(resolver) + } + + impl ConnectService { + pub fn new() -> Result { + Ok(Self::new_with_resolver(ResolverService::from_system_conf()?)) + } + } + + impl ConnectService + where + R: Service, + R::Response: IntoIterator, + R::Error: Into>, + Self: Service, + { + pub fn new_with_resolver(resolver: R) -> Self { + Self { resolver } + } + } + + impl JiveTcpStream { + async fn connect>( + addrs: A, + ) -> std::io::Result { + let mut last_error = None; + + for addr in addrs.into_iter() { + match jive::io::Async::::connect(addr).await { + Ok(io) => return Ok(JiveTcpStream { io }), + Err(e) => last_error = Some(e), + } + } + + Err(last_error.unwrap_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "No addresses supplied to TcpStream", + ) + })) + } + } + + impl hyper::client::connect::Connection for JiveTcpStream { + fn connected(&self) -> hyper::client::connect::Connected { + hyper::client::connect::Connected::new() + } + } + + impl ResolverService { + pub fn from_system_conf() -> Result { + Ok(ResolverService { + resolver: dns::resolver_from_system_conf()?, + }) + } + + pub fn new( + config: trust_dns_resolver::config::ResolverConfig, + options: trust_dns_resolver::config::ResolverOpts, + ) -> Result { + Ok(ResolverService { + resolver: dns::resolver(config, options)?, + }) + } + } + + impl Service for ConnectService + where + R: Service, + R::Response: IntoIterator + Send + 'static, + ::IntoIter: Send, + R::Error: Into>, + R::Future: Unpin, + { + type Response = JiveTcpStream; + type Error = Box; + type Future = ConnectFuture; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Uri) -> Self::Future { + ConnectFuture::new(self.resolver.call(req)) + } + } + + impl Service for ResolverService { + type Response = Vec; + type Error = ResolveError; + type Future = ResolverFuture; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Uri) -> Self::Future { + ResolverFuture::new(self.resolver.clone(), req) + } + } + + impl std::fmt::Display for ResolveError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Dns(_) => write!(f, "Error in DNS resolution"), + Self::Invalid => write!(f, "Invalid URI"), + } + } + } + + impl std::error::Error for ResolveError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Dns(dns) => Some(dns), + Self::Invalid => None, + } + } + } + + impl From for ResolveError { + fn from(e: trust_dns_resolver::error::ResolveError) -> Self { + ResolveError::Dns(e) + } + } + + impl ResolverFuture { + fn new(resolver: JiveResolver, uri: Uri) -> Self { + let host = if let Some(host) = uri.host() { + host + } else { + return ResolverFuture::Invalid; + }; + + let port = if let Some(port) = uri.port_u16() { + port + } else if let Some(scheme) = uri.scheme_str() { + if scheme == "http" { + 80 + } else if scheme == "https" { + 443 + } else { + return ResolverFuture::Invalid; + } + } else { + return ResolverFuture::Invalid; + }; + + if let Ok(addr) = host.parse() { + return ResolverFuture::Resolved { port, addr }; + } + + let host = host.to_string(); + + ResolverFuture::Pending { + port, + fut: Box::pin(async move { resolver.lookup_ip(host).await }), + } + } + } + + impl Future for ResolverFuture { + type Output = Result, ResolveError>; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + match std::mem::replace(this, ResolverFuture::Complete) { + ResolverFuture::Pending { port, mut fut } => match fut.as_mut().poll(cx) { + Poll::Ready(Ok(lookup_ip)) => { + Poll::Ready(Ok(lookup_ip.iter().map(|ip| (ip, port).into()).collect())) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Pending => { + *this = ResolverFuture::Pending { port, fut }; + Poll::Pending + } + }, + ResolverFuture::Resolved { port, addr } => { + Poll::Ready(Ok(vec![(addr, port).into()])) + } + ResolverFuture::Invalid => Poll::Ready(Err(ResolveError::Invalid)), + ResolverFuture::Complete => panic!("ResolverFuture polled after completion"), + } + } + } + + impl ConnectFuture + where + F: Future> + Unpin, + T: IntoIterator + Send + 'static, + T::IntoIter: Send, + E: Into>, + Self: Future, + { + fn new(fut: F) -> Self { + ConnectFuture::Resolving { fut } + } + } + + impl Future for ConnectFuture + where + F: Future> + Unpin, + T: IntoIterator + Send + 'static, + T::IntoIter: Send, + E: Into>, + { + type Output = Result>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + match std::mem::replace(this, ConnectFuture::Complete) { + ConnectFuture::Resolving { mut fut } => match Pin::new(&mut fut).poll(cx) { + Poll::Ready(Ok(addrs)) => { + *this = ConnectFuture::Connecting { + fut: Box::pin(async move { JiveTcpStream::connect(addrs).await }), + }; + + Pin::new(this).poll(cx) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())), + Poll::Pending => { + *this = ConnectFuture::Resolving { fut }; + Poll::Pending + } + }, + ConnectFuture::Connecting { mut fut } => match fut.as_mut().poll(cx) { + Poll::Ready(res) => Poll::Ready(res.map_err(From::from)), + Poll::Pending => { + *this = ConnectFuture::Connecting { fut }; + Poll::Pending + } + }, + ConnectFuture::Complete => panic!("ConnectFuture polled after completion"), + } + } + } + pub mod dns { - use crate::{JiveRuntime, JiveTcpStream}; + use crate::common::{JiveRuntime, JiveTcpStream}; use futures_io::{AsyncRead, AsyncWrite}; use std::{ future::Future, @@ -100,7 +426,7 @@ pub mod client { }; use trust_dns_resolver::config::{ResolverConfig, ResolverOpts}; - pub async fn resolver( + pub fn resolver( config: ResolverConfig, options: ResolverOpts, ) -> Result { @@ -151,7 +477,7 @@ pub mod client { bind_addr: Option, ) -> std::io::Result { if let Some(_bind_addr) = bind_addr { - todo!("Implement connect with binnd"); + todo!("Implement connect with bind"); } else { let io = jive::io::Async::::connect(addr).await?;