use std::sync::Arc; use tokio::{ sync::{Notify, Semaphore}, task::JoinHandle, }; use crate::future::WithPollTimer; pub(crate) struct DropHandle { handle: JoinHandle, } pub(crate) fn abort_on_drop(handle: JoinHandle) -> DropHandle { DropHandle { handle } } impl DropHandle { pub(crate) fn abort(&self) { self.handle.abort(); } } impl Drop for DropHandle { fn drop(&mut self) { self.handle.abort(); } } impl std::future::Future for DropHandle { type Output = as std::future::Future>::Output; fn poll( mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll { std::pin::Pin::new(&mut self.handle).poll(cx) } } #[track_caller] pub(crate) fn channel( bound: usize, ) -> (tokio::sync::mpsc::Sender, tokio::sync::mpsc::Receiver) { let span = tracing::trace_span!(parent: None, "make channel"); let guard = span.enter(); let channel = tokio::sync::mpsc::channel(bound); drop(guard); channel } #[track_caller] pub(crate) fn notify() -> Arc { Arc::new(bare_notify()) } #[track_caller] pub(crate) fn bare_notify() -> Notify { let span = tracing::trace_span!(parent: None, "make notifier"); let guard = span.enter(); let notify = Notify::new(); drop(guard); notify } #[track_caller] pub(crate) fn bare_semaphore(permits: usize) -> Semaphore { let span = tracing::trace_span!(parent: None, "make semaphore"); let guard = span.enter(); let semaphore = Semaphore::new(permits); drop(guard); semaphore } // best effort cooperation mechanism pub(crate) async fn cooperate() { #[cfg(tokio_unstable)] tokio::task::consume_budget().await; #[cfg(not(tokio_unstable))] tokio::task::yield_now().await; } #[track_caller] pub(crate) fn spawn(name: &'static str, future: F) -> tokio::task::JoinHandle where F: std::future::Future + 'static, F::Output: 'static, { let future = future.with_poll_timer(name); let span = tracing::trace_span!(parent: None, "spawn task"); let guard = span.enter(); #[cfg(tokio_unstable)] let handle = tokio::task::Builder::new() .name(name) .spawn_local(future) .expect("Failed to spawn"); #[cfg(not(tokio_unstable))] let handle = tokio::task::spawn_local(future); drop(guard); handle } #[track_caller] pub(crate) fn spawn_sendable(name: &'static str, future: F) -> tokio::task::JoinHandle where F: std::future::Future + Send + 'static, F::Output: Send + 'static, { let future = future.with_poll_timer(name); let span = tracing::trace_span!(parent: None, "spawn task"); let guard = span.enter(); #[cfg(tokio_unstable)] let handle = tokio::task::Builder::new() .name(name) .spawn(future) .expect("Failed to spawn"); #[cfg(not(tokio_unstable))] let handle = tokio::task::spawn(future); drop(guard); handle } #[track_caller] pub(crate) fn spawn_blocking(name: &str, function: F) -> tokio::task::JoinHandle where F: FnOnce() -> Out + Send + 'static, Out: Send + 'static, { #[cfg(not(tokio_unstable))] let _ = name; let outer_span = tracing::Span::current(); let span = tracing::trace_span!(parent: None, "spawn blocking task"); let guard = span.enter(); #[cfg(tokio_unstable)] let handle = tokio::task::Builder::new() .name(name) .spawn_blocking(move || outer_span.in_scope(function)) .expect("Failed to spawn"); #[cfg(not(tokio_unstable))] let handle = tokio::task::spawn_blocking(move || outer_span.in_scope(function)); drop(guard); handle }