From 5662415b05c35130cbe720058e48a7ada5bcb1bc Mon Sep 17 00:00:00 2001 From: "Aode (lion)" Date: Mon, 6 Dec 2021 11:58:57 -0600 Subject: [PATCH] Don't lose parent context for response body --- src/lib.rs | 135 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 88 insertions(+), 47 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 0246f3d..ce79334 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,7 +11,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tracing::{instrument::Instrumented, Instrument, Span}; +use tracing::{instrument::Instrumented, Id, Instrument, Span}; #[cfg(feature = "opentelemetry_0_13")] use opentelemetry_0_13_pkg as opentelemetry; @@ -138,6 +138,29 @@ where } } +enum IdOrSpan { + Empty, + Id(Id), + Span(Span), +} + +impl IdOrSpan { + fn from_id(id: Option) -> Self { + id.map(IdOrSpan::Id).unwrap_or(IdOrSpan::Empty) + } + + fn take(&mut self) -> Self { + std::mem::replace(self, IdOrSpan::Empty) + } + + fn as_span(&self) -> Option<&Span> { + match self { + IdOrSpan::Span(ref span) => Some(span), + _ => None, + } + } +} + pin_project_lite::pin_project! { pub struct TracingFuture { span: Span, @@ -159,58 +182,58 @@ where let future = this.future; let span = this.span; - future - .poll(cx) - .map_ok(|succ| match succ { - ConnectResponse::Client(client_response) => { - let status: i32 = client_response.status().as_u16().into(); - span.record("http.status_code", &status); - if client_response.status().is_client_error() { - span.record("otel.status_code", &"ERROR"); + span.in_scope(|| { + future + .poll(cx) + .map_ok(|succ| match succ { + ConnectResponse::Client(client_response) => { + let status: i32 = client_response.status().as_u16().into(); + span.record("http.status_code", &status); + if client_response.status().is_client_error() { + span.record("otel.status_code", &"ERROR"); + } + + ConnectResponse::Client(client_response.map_body(|_, payload| { + let instrumented = + InstrumentedBody::new(IdOrSpan::from_id(span.id()), payload); + let pinned: Pin>>> = + Box::pin(instrumented); + + Payload::Stream(pinned) + })) } - - let body_span = tracing::info_span!(parent: None, "HTTP Client Response Body"); - body_span.follows_from(span.clone()); - - ConnectResponse::Client(client_response.map_body(|_, payload| { - let instrumented = InstrumentedBody::new(body_span, payload); - let pinned: Pin>>> = - Box::pin(instrumented); - - Payload::Stream(pinned) - })) - } - ConnectResponse::Tunnel(response_head, etc) => { - let status: i32 = response_head.status.as_u16().into(); - span.record("http.status_code", &status); - if response_head.status.is_client_error() { - span.record("otel.status_code", &"ERROR"); + ConnectResponse::Tunnel(response_head, etc) => { + let status: i32 = response_head.status.as_u16().into(); + span.record("http.status_code", &status); + if response_head.status.is_client_error() { + span.record("otel.status_code", &"ERROR"); + } + ConnectResponse::Tunnel(response_head, etc) } - ConnectResponse::Tunnel(response_head, etc) - } - }) - .map_err(|err| { - span.record("otel.status_code", &"ERROR"); - span.record( - "exception.message", - &tracing::field::display(&format!("{}", err)), - ); - span.record( - "exception.details", - &tracing::field::display(&format!("{:?}", err)), - ); + }) + .map_err(|err| { + span.record("otel.status_code", &"ERROR"); + span.record( + "exception.message", + &tracing::field::display(&format!("{}", err)), + ); + span.record( + "exception.details", + &tracing::field::display(&format!("{:?}", err)), + ); - #[cfg(feature = "emit_event_on_error")] - tracing::warn!("Error in request: {}", err); + #[cfg(feature = "emit_event_on_error")] + tracing::warn!("Error in request: {}", err); - err - }) + err + }) + }) } } pin_project_lite::pin_project! { struct InstrumentedBody { - span: Span, + span: IdOrSpan, #[pin] body: S, @@ -221,7 +244,7 @@ impl InstrumentedBody where S: Stream>, { - fn new(span: Span, body: S) -> InstrumentedBody { + fn new(span: IdOrSpan, body: S) -> InstrumentedBody { InstrumentedBody { span, body } } } @@ -235,10 +258,28 @@ where fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.as_mut().project(); - let span = this.span; + let span = this.span.take(); let body = this.body; - span.in_scope(|| body.poll_next(cx)) + let span = match span { + IdOrSpan::Empty => { + return body.poll_next(cx); + } + IdOrSpan::Id(id) => { + let span = tracing::info_span!("HTTP Client Response Body"); + span.follows_from(id); + span + } + IdOrSpan::Span(span) => span, + }; + + *this.span = IdOrSpan::Span(span); + + if let Some(span) = this.span.as_span() { + span.in_scope(|| body.poll_next(cx)) + } else { + unreachable!("The span should always exist by this point") + } } }