Skip to content

Commit 815df41

Browse files
committed
chore(server): Vendor hyper-util graceful shutdown feature
1 parent 52a0f2f commit 815df41

File tree

2 files changed

+269
-34
lines changed

2 files changed

+269
-34
lines changed
Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,250 @@
1+
// From https://github.com/hyperium/hyper-util/blob/7afb1ed5337c0689d7341e09d31578f1fcffc8af/src/server/graceful.rs,
2+
// implements Clone for GracefulShutdown.
3+
4+
use std::{
5+
fmt::{self, Debug},
6+
future::Future,
7+
pin::Pin,
8+
task::{self, Poll},
9+
};
10+
11+
use pin_project::pin_project;
12+
use tokio::sync::watch;
13+
14+
/// A graceful shutdown utility
15+
#[derive(Clone)]
16+
pub(super) struct GracefulShutdown {
17+
tx: watch::Sender<()>,
18+
}
19+
20+
impl GracefulShutdown {
21+
/// Create a new graceful shutdown helper.
22+
pub(super) fn new() -> Self {
23+
let (tx, _) = watch::channel(());
24+
Self { tx }
25+
}
26+
27+
/// Wrap a future for graceful shutdown watching.
28+
pub(super) fn watch<C: GracefulConnection>(&self, conn: C) -> impl Future<Output = C::Output> {
29+
let mut rx = self.tx.subscribe();
30+
GracefulConnectionFuture::new(conn, async move {
31+
let _ = rx.changed().await;
32+
// hold onto the rx until the watched future is completed
33+
rx
34+
})
35+
}
36+
37+
/// Signal shutdown for all watched connections.
38+
///
39+
/// This returns a `Future` which will complete once all watched
40+
/// connections have shutdown.
41+
pub(super) async fn shutdown(self) {
42+
let Self { tx } = self;
43+
44+
// signal all the watched futures about the change
45+
let _ = tx.send(());
46+
// and then wait for all of them to complete
47+
tx.closed().await;
48+
}
49+
}
50+
51+
impl Debug for GracefulShutdown {
52+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53+
f.debug_struct("GracefulShutdown").finish()
54+
}
55+
}
56+
57+
impl Default for GracefulShutdown {
58+
fn default() -> Self {
59+
Self::new()
60+
}
61+
}
62+
63+
#[pin_project]
64+
struct GracefulConnectionFuture<C, F: Future> {
65+
#[pin]
66+
conn: C,
67+
#[pin]
68+
cancel: F,
69+
#[pin]
70+
// If cancelled, this is held until the inner conn is done.
71+
cancelled_guard: Option<F::Output>,
72+
}
73+
74+
impl<C, F: Future> GracefulConnectionFuture<C, F> {
75+
fn new(conn: C, cancel: F) -> Self {
76+
Self {
77+
conn,
78+
cancel,
79+
cancelled_guard: None,
80+
}
81+
}
82+
}
83+
84+
impl<C, F: Future> Debug for GracefulConnectionFuture<C, F> {
85+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86+
f.debug_struct("GracefulConnectionFuture").finish()
87+
}
88+
}
89+
90+
impl<C, F> Future for GracefulConnectionFuture<C, F>
91+
where
92+
C: GracefulConnection,
93+
F: Future,
94+
{
95+
type Output = C::Output;
96+
97+
fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
98+
let mut this = self.project();
99+
if this.cancelled_guard.is_none() {
100+
if let Poll::Ready(guard) = this.cancel.poll(cx) {
101+
this.cancelled_guard.set(Some(guard));
102+
this.conn.as_mut().graceful_shutdown();
103+
}
104+
}
105+
this.conn.poll(cx)
106+
}
107+
}
108+
109+
/// An internal utility trait as an umbrella target for all (hyper) connection
110+
/// types that the [`GracefulShutdown`] can watch.
111+
pub(super) trait GracefulConnection:
112+
Future<Output = Result<(), Self::Error>> + private::Sealed
113+
{
114+
/// The error type returned by the connection when used as a future.
115+
type Error;
116+
117+
/// Start a graceful shutdown process for this connection.
118+
fn graceful_shutdown(self: Pin<&mut Self>);
119+
}
120+
121+
impl<I, B, S> GracefulConnection for hyper::server::conn::http1::Connection<I, S>
122+
where
123+
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
124+
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
125+
I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
126+
B: hyper::body::Body + 'static,
127+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
128+
{
129+
type Error = hyper::Error;
130+
131+
fn graceful_shutdown(self: Pin<&mut Self>) {
132+
hyper::server::conn::http1::Connection::graceful_shutdown(self);
133+
}
134+
}
135+
136+
impl<I, B, S, E> GracefulConnection for hyper::server::conn::http2::Connection<I, S, E>
137+
where
138+
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
139+
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
140+
I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
141+
B: hyper::body::Body + 'static,
142+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
143+
E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
144+
{
145+
type Error = hyper::Error;
146+
147+
fn graceful_shutdown(self: Pin<&mut Self>) {
148+
hyper::server::conn::http2::Connection::graceful_shutdown(self);
149+
}
150+
}
151+
152+
impl<'a, I, B, S, E> GracefulConnection for hyper_util::server::conn::auto::Connection<'a, I, S, E>
153+
where
154+
S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
155+
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
156+
S::Future: 'static,
157+
I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
158+
B: hyper::body::Body + 'static,
159+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
160+
E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
161+
{
162+
type Error = Box<dyn std::error::Error + Send + Sync>;
163+
164+
fn graceful_shutdown(self: Pin<&mut Self>) {
165+
hyper_util::server::conn::auto::Connection::graceful_shutdown(self);
166+
}
167+
}
168+
169+
impl<'a, I, B, S, E> GracefulConnection
170+
for hyper_util::server::conn::auto::UpgradeableConnection<'a, I, S, E>
171+
where
172+
S: hyper::service::Service<http::Request<hyper::body::Incoming>, Response = http::Response<B>>,
173+
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
174+
S::Future: 'static,
175+
I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
176+
B: hyper::body::Body + 'static,
177+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
178+
E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
179+
{
180+
type Error = Box<dyn std::error::Error + Send + Sync>;
181+
182+
fn graceful_shutdown(self: Pin<&mut Self>) {
183+
hyper_util::server::conn::auto::UpgradeableConnection::graceful_shutdown(self);
184+
}
185+
}
186+
187+
mod private {
188+
pub(crate) trait Sealed {}
189+
190+
impl<I, B, S> Sealed for hyper::server::conn::http1::Connection<I, S>
191+
where
192+
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
193+
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
194+
I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
195+
B: hyper::body::Body + 'static,
196+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
197+
{
198+
}
199+
200+
impl<I, B, S> Sealed for hyper::server::conn::http1::UpgradeableConnection<I, S>
201+
where
202+
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
203+
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
204+
I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
205+
B: hyper::body::Body + 'static,
206+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
207+
{
208+
}
209+
210+
impl<I, B, S, E> Sealed for hyper::server::conn::http2::Connection<I, S, E>
211+
where
212+
S: hyper::service::HttpService<hyper::body::Incoming, ResBody = B>,
213+
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
214+
I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
215+
B: hyper::body::Body + 'static,
216+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
217+
E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
218+
{
219+
}
220+
221+
impl<'a, I, B, S, E> Sealed for hyper_util::server::conn::auto::Connection<'a, I, S, E>
222+
where
223+
S: hyper::service::Service<
224+
http::Request<hyper::body::Incoming>,
225+
Response = http::Response<B>,
226+
>,
227+
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
228+
S::Future: 'static,
229+
I: hyper::rt::Read + hyper::rt::Write + Unpin + 'static,
230+
B: hyper::body::Body + 'static,
231+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
232+
E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
233+
{
234+
}
235+
236+
impl<'a, I, B, S, E> Sealed for hyper_util::server::conn::auto::UpgradeableConnection<'a, I, S, E>
237+
where
238+
S: hyper::service::Service<
239+
http::Request<hyper::body::Incoming>,
240+
Response = http::Response<B>,
241+
>,
242+
S::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
243+
S::Future: 'static,
244+
I: hyper::rt::Read + hyper::rt::Write + Unpin + Send + 'static,
245+
B: hyper::body::Body + 'static,
246+
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
247+
E: hyper::rt::bounds::Http2ServerConnExec<S::Future, B>,
248+
{
249+
}
250+
}

tonic/src/transport/server/mod.rs

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Server implementation and builder.
22
33
mod conn;
4+
mod graceful;
45
mod incoming;
56
mod service;
67
#[cfg(feature = "tls")]
@@ -36,7 +37,10 @@ pub use incoming::TcpIncoming;
3637
#[cfg(feature = "tls")]
3738
use crate::transport::Error;
3839

39-
use self::service::{RecoverError, ServerIo};
40+
use self::{
41+
graceful::GracefulShutdown,
42+
service::{RecoverError, ServerIo},
43+
};
4044
use super::service::GrpcTimeout;
4145
use crate::body::{boxed, BoxBody};
4246
use crate::server::NamedService;
@@ -561,10 +565,7 @@ impl<L> Server<L> {
561565
builder
562566
};
563567

564-
let (signal_tx, signal_rx) = tokio::sync::watch::channel(());
565-
let signal_tx = Arc::new(signal_tx);
566-
567-
let graceful = signal.is_some();
568+
let graceful = signal.is_some().then(GracefulShutdown::new);
568569
let mut sig = pin!(Fuse { inner: signal });
569570
let mut incoming = pin!(incoming);
570571

@@ -600,21 +601,13 @@ impl<L> Server<L> {
600601
let hyper_io = TokioIo::new(io);
601602
let hyper_svc = TowerToHyperService::new(req_svc.map_request(|req: Request<Incoming>| req.map(boxed)));
602603

603-
serve_connection(hyper_io, hyper_svc, server.clone(), graceful.then(|| signal_rx.clone()));
604+
serve_connection(hyper_io, hyper_svc, server.clone(), graceful.clone());
604605
}
605606
}
606607
}
607608

608-
if graceful {
609-
let _ = signal_tx.send(());
610-
drop(signal_rx);
611-
trace!(
612-
"waiting for {} connections to close",
613-
signal_tx.receiver_count()
614-
);
615-
616-
// Wait for all connections to close
617-
signal_tx.closed().await;
609+
if let Some(graceful) = graceful {
610+
graceful.shutdown().await;
618611
}
619612

620613
Ok(())
@@ -627,7 +620,7 @@ fn serve_connection<B, IO, S, E>(
627620
hyper_io: IO,
628621
hyper_svc: S,
629622
builder: ConnectionBuilder<E>,
630-
mut watcher: Option<tokio::sync::watch::Receiver<()>>,
623+
graceful: Option<GracefulShutdown>,
631624
) where
632625
B: http_body::Body + Send + 'static,
633626
B::Data: Send,
@@ -640,28 +633,20 @@ fn serve_connection<B, IO, S, E>(
640633
{
641634
tokio::spawn(async move {
642635
{
643-
let mut sig = pin!(Fuse {
644-
inner: watcher.as_mut().map(|w| w.changed()),
645-
});
636+
let conn = builder.serve_connection(hyper_io, hyper_svc);
646637

647-
let mut conn = pin!(builder.serve_connection(hyper_io, hyper_svc));
638+
let result = if let Some(graceful) = graceful {
639+
let conn = graceful.watch(conn);
640+
conn.await
641+
} else {
642+
conn.await
643+
};
648644

649-
loop {
650-
tokio::select! {
651-
rv = &mut conn => {
652-
if let Err(err) = rv {
653-
debug!("failed serving connection: {:#}", err);
654-
}
655-
break;
656-
},
657-
_ = &mut sig => {
658-
conn.as_mut().graceful_shutdown();
659-
}
660-
}
645+
if let Err(err) = result {
646+
debug!("failed serving connection: {:#}", err);
661647
}
662648
}
663649

664-
drop(watcher);
665650
trace!("connection closed");
666651
});
667652
}

0 commit comments

Comments
 (0)