Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(redirect): Coverting policy into tower-http policy. #2617

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ tower-service = "0.3"
futures-core = { version = "0.3.28", default-features = false }
futures-util = { version = "0.3.28", default-features = false }
sync_wrapper = { version = "1.0", features = ["futures"] }
tower-http = { git="https://github.com/firefantasy/tower-http.git", rev= "89599d88673bfae74a96ffa44e46df41661d36bc", default-features = false, features = ["follow-redirect"] }

# Optional deps...

Expand Down
127 changes: 76 additions & 51 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::any::Any;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::str::FromStr;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::Duration;
use std::{collections::HashMap, convert::TryInto, net::SocketAddr};
Expand Down Expand Up @@ -57,6 +58,9 @@ use quinn::VarInt;
use tokio::time::Sleep;
use tower::util::BoxCloneSyncServiceLayer;
use tower::{Layer, Service};
use tower_http::follow_redirect::policy::{
Action as TowerAction, Attempt as TowerAttempt, Policy as TowerPolicy,
};

type HyperResponseFuture = hyper_util::client::legacy::ResponseFuture;

Expand Down Expand Up @@ -795,6 +799,14 @@ impl ClientBuilder {

let proxies_maybe_http_auth = proxies.iter().any(|p| p.maybe_has_http_auth());

let redirect_policy_display = {
if config.redirect_policy.is_default() {
None
} else {
Some(format!("{:?}", &config.redirect_policy))
}
};

Ok(Client {
inner: Arc::new(ClientRef {
accepts: config.accepts,
Expand All @@ -811,13 +823,14 @@ impl ClientBuilder {
},
hyper: builder.build(connector_builder.build(config.connector_layers)),
headers: config.headers,
redirect_policy: config.redirect_policy,
redirect_policy: Mutex::new(config.redirect_policy.into_tower_policy()),
referer: config.referer,
read_timeout: config.read_timeout,
request_timeout: config.timeout,
proxies,
proxies_maybe_http_auth,
https_only: config.https_only,
redirect_policy_display,
}),
})
}
Expand Down Expand Up @@ -2416,13 +2429,16 @@ struct ClientRef {
hyper: HyperClient,
#[cfg(feature = "http3")]
h3_client: Option<H3Client>,
redirect_policy: redirect::Policy,
redirect_policy: Mutex<
Option<Box<dyn TowerPolicy<(), Box<dyn std::error::Error + Send + Sync>> + Send + Sync>>,
>,
referer: bool,
request_timeout: Option<Duration>,
read_timeout: Option<Duration>,
proxies: Arc<Vec<Proxy>>,
proxies_maybe_http_auth: bool,
https_only: bool,
redirect_policy_display: Option<String>,
}

impl ClientRef {
Expand All @@ -2443,8 +2459,8 @@ impl ClientRef {
f.field("proxies", &self.proxies);
}

if !self.redirect_policy.is_default() {
f.field("redirect_policy", &self.redirect_policy);
if let Some(msg) = &self.redirect_policy_display {
f.field("redirect_policy", &msg);
}

if self.referer {
Expand Down Expand Up @@ -2769,47 +2785,51 @@ impl Future for PendingRequest {
}
let url = self.url.clone();
self.as_mut().urls().push(url);
let action = self
.client
.redirect_policy
.check(res.status(), &loc, &self.urls);

match action {
redirect::ActionKind::Follow => {
debug!("redirecting '{}' to '{}'", self.url, loc);

if loc.scheme() != "http" && loc.scheme() != "https" {
return Poll::Ready(Err(error::url_bad_scheme(loc)));
}
let pervious = Uri::from_str(self.url.clone().as_str()).unwrap();
let next = Uri::from_str(loc.as_str()).unwrap();
let tower_attempt = TowerAttempt::new(res.status(), &next, &pervious);
let client = self.client.clone();
let mut tower_policy = client.redirect_policy.lock().unwrap();
if let Some(tower_policy) = tower_policy.as_deref_mut() {
match tower_policy.redirect(&tower_attempt) {
Ok(TowerAction::Follow) => {
debug!("redirecting '{}' to '{}'", self.url, loc);

if loc.scheme() != "http" && loc.scheme() != "https" {
return Poll::Ready(Err(error::url_bad_scheme(loc)));
}

if self.client.https_only && loc.scheme() != "https" {
return Poll::Ready(Err(error::redirect(
error::url_bad_scheme(loc.clone()),
loc,
)));
}
if self.client.https_only && loc.scheme() != "https" {
return Poll::Ready(Err(error::redirect(
error::url_bad_scheme(loc.clone()),
loc,
)));
}

self.url = loc;
let mut headers =
std::mem::replace(self.as_mut().headers(), HeaderMap::new());
self.url = loc;
let mut headers =
std::mem::replace(self.as_mut().headers(), HeaderMap::new());

remove_sensitive_headers(&mut headers, &self.url, &self.urls);
let uri = try_uri(&self.url)?;
let body = match self.body {
Some(Some(ref body)) => Body::reusable(body.clone()),
_ => Body::empty(),
};
remove_sensitive_headers(&mut headers, &self.url, &self.urls);
let uri = try_uri(&self.url)?;
let body = match self.body {
Some(Some(ref body)) => Body::reusable(body.clone()),
_ => Body::empty(),
};

// Add cookies from the cookie store.
#[cfg(feature = "cookies")]
{
if let Some(ref cookie_store) = self.client.cookie_store {
add_cookie_header(&mut headers, &**cookie_store, &self.url);
// Add cookies from the cookie store.
#[cfg(feature = "cookies")]
{
if let Some(ref cookie_store) = self.client.cookie_store {
add_cookie_header(&mut headers, &**cookie_store, &self.url);
}
}
}

*self.as_mut().in_flight().get_mut() =
match *self.as_mut().in_flight().as_ref() {
*self.as_mut().in_flight().get_mut() = match *self
.as_mut()
.in_flight()
.as_ref()
{
#[cfg(feature = "http3")]
ResponseFuture::H3(_) => {
let mut req = hyper::Request::builder()
Expand All @@ -2820,9 +2840,9 @@ impl Future for PendingRequest {
*req.headers_mut() = headers.clone();
std::mem::swap(self.as_mut().headers(), &mut headers);
ResponseFuture::H3(self.client.h3_client
.as_ref()
.expect("H3 client must exists, otherwise we can't have a h3 request here")
.request(req))
.as_ref()
.expect("H3 client must exists, otherwise we can't have a h3 request here")
.request(req))
}
_ => {
let mut req = hyper::Request::builder()
Expand All @@ -2835,15 +2855,20 @@ impl Future for PendingRequest {
ResponseFuture::Default(self.client.hyper.request(req))
}
};

continue;
}
redirect::ActionKind::Stop => {
debug!("redirect policy disallowed redirection to '{loc}'");
}
redirect::ActionKind::Error(err) => {
return Poll::Ready(Err(crate::error::redirect(err, self.url.clone())));
continue;
}
Ok(TowerAction::Stop) => {
debug!("redirect policy disallowed redirection to '{loc}'");
}
Err(err) => {
return Poll::Ready(Err(crate::error::redirect(
err,
self.url.clone(),
)));
}
}
} else {
debug!("redirect policy disallowed redirection to '{loc}'");
}
}
}
Expand Down
41 changes: 41 additions & 0 deletions src/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ use crate::header::{HeaderMap, AUTHORIZATION, COOKIE, PROXY_AUTHORIZATION, WWW_A
use hyper::StatusCode;

use crate::Url;
use tower_http::follow_redirect::policy::{
redirect_fn, Action as TowerAction, Attempt as TowerAttempt, Limited, Policy as TowerPolicy,
};

/// A type that controls the policy on how to handle the following of redirects.
///
Expand Down Expand Up @@ -150,6 +153,44 @@ impl Policy {
pub(crate) fn is_default(&self) -> bool {
matches!(self.inner, PolicyKind::Limit(10))
}

pub(crate) fn into_tower_policy(
self,
) -> Option<Box<dyn TowerPolicy<(), Box<dyn StdError + Send + Sync>> + Send + Sync>>
where {
match self.inner {
PolicyKind::Custom(custom) => {
let t = redirect_fn(move |attemp: &TowerAttempt<'_>| -> Result<TowerAction, Box<dyn StdError + Send + Sync>> {
Ok(match custom(Attempt {
status: attemp.status(),
next: &Url::parse(&attemp.location().to_string()).unwrap(),
previous: &[Url::parse(&attemp.previous().to_string()).unwrap()],
})
.inner
{
ActionKind::Follow => TowerAction::Follow,
ActionKind::Stop => TowerAction::Stop,
ActionKind::Error(err) => return Err(err),
})
});
Some(Box::new(t))
}
PolicyKind::Limit(max) => {
let mut policy = Limited::new(max);
let t = redirect_fn(move |attemp: &TowerAttempt<'_>| -> Result<TowerAction, Box<dyn StdError + Send + Sync>> {
match <dyn TowerPolicy<(), Box<dyn StdError + Send + Sync>>>::redirect(&mut policy, attemp)
{
Ok(TowerAction::Follow) => Ok(TowerAction::Follow),
Ok(TowerAction::Stop) => Err(Box::new(TooManyRedirects)),
Err(err) => Err(err),
}
});
Some(Box::new(t))
}

PolicyKind::None => None,
}
}
}

impl Default for Policy {
Expand Down
65 changes: 65 additions & 0 deletions tests/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,68 @@ async fn test_redirect_https_only_enforced_gh1312() {
let err = res.unwrap_err();
assert!(err.is_redirect());
}

#[tokio::test]
async fn test_redirect_limit() {
let server = server::http(move |req| async move {
let i: i32 = req
.uri()
.path()
.rsplit('/')
.next()
.unwrap()
.parse::<i32>()
.unwrap();
assert!(req.uri().path().ends_with(&format!("/redirect/{i}")));
http::Response::builder()
.status(302)
.header("location", format!("/redirect/{}", i + 1))
.body(Body::default())
.unwrap()
});

let client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::limited(3))
.build()
.unwrap();

let url = format!("http://{}/redirect/0", server.addr());
let res = client.get(&url).send().await.unwrap_err();
assert_eq!(
res.url().unwrap().as_str(),
format!("http://{}/redirect/3", server.addr()).as_str()
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The behavior is difference.

);
assert!(res.is_redirect());
}

#[tokio::test]
async fn test_redirect_custom() {
let server = server::http(move |req| async move {
assert!(req.uri().path().ends_with("/foo"));
http::Response::builder()
.status(302)
.header("location", "/should_not_be_called")
.body(Body::default())
.unwrap()
});

let url = format!("http://{}/foo", server.addr());

let res = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::custom(|attempt| {
if attempt.url().path().ends_with("/should_not_be_called") {
attempt.stop()
} else {
attempt.follow()
}
}))
.build()
.unwrap()
.get(&url)
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), url);
assert_eq!(res.status(), reqwest::StatusCode::FOUND);
}