diff --git a/Cargo.lock b/Cargo.lock index c3aae4a3e7435..ee8e4aa178457 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -624,7 +624,7 @@ dependencies = [ "once_cell", "rand 0.8.5", "regex", - "ring", + "ring 0.17.5", "rustls 0.21.11", "rustls-native-certs 0.6.3", "rustls-pemfile 1.0.3", @@ -786,7 +786,7 @@ dependencies = [ "hex", "http 0.2.9", "hyper 0.14.28", - "ring", + "ring 0.17.5", "time", "tokio", "tracing 0.1.40", @@ -8159,6 +8159,21 @@ dependencies = [ "subtle", ] +[[package]] +name = "ring" +version = "0.16.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc" +dependencies = [ + "cc", + "libc", + "once_cell", + "spin 0.5.2", + "untrusted 0.7.1", + "web-sys", + "winapi", +] + [[package]] name = "ring" version = "0.17.5" @@ -8169,7 +8184,7 @@ dependencies = [ "getrandom 0.2.15", "libc", "spin 0.9.8", - "untrusted", + "untrusted 0.9.0", "windows-sys 0.48.0", ] @@ -8409,6 +8424,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rustls" +version = "0.20.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99" +dependencies = [ + "log", + "ring 0.16.20", + "sct", + "webpki", +] + [[package]] name = "rustls" version = "0.21.11" @@ -8416,7 +8443,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fecbfb7b1444f477b345853b1fce097a2c6fb637b2bfb87e6bc5db0f043fae4" dependencies = [ "log", - "ring", + "ring 0.17.5", "rustls-webpki 0.101.7", "sct", ] @@ -8428,7 +8455,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" dependencies = [ "log", - "ring", + "ring 0.17.5", "rustls-pki-types", "rustls-webpki 0.102.2", "subtle", @@ -8460,6 +8487,15 @@ dependencies = [ "security-framework", ] +[[package]] +name = "rustls-pemfile" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee86d63972a7c661d1536fefe8c3c8407321c3df668891286de28abcd087360" +dependencies = [ + "base64 0.13.1", +] + [[package]] name = "rustls-pemfile" version = "1.0.3" @@ -8491,8 +8527,8 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring", - "untrusted", + "ring 0.17.5", + "untrusted 0.9.0", ] [[package]] @@ -8501,9 +8537,9 @@ version = "0.102.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" dependencies = [ - "ring", + "ring 0.17.5", "rustls-pki-types", - "untrusted", + "untrusted 0.9.0", ] [[package]] @@ -8616,8 +8652,8 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring", - "untrusted", + "ring 0.17.5", + "untrusted 0.9.0", ] [[package]] @@ -9827,6 +9863,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" +dependencies = [ + "rustls 0.20.9", + "tokio", + "webpki", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -10597,6 +10644,12 @@ version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "673aac59facbab8a9007c7f6108d11f63b603f7cabff99fabf650fea5c32b861" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "untrusted" version = "0.9.0" @@ -10856,6 +10909,8 @@ dependencies = [ "roaring", "rstest", "rumqttc", + "rustls 0.20.9", + "rustls-pemfile 0.3.0", "seahash", "semver 1.0.23", "serde", @@ -10879,6 +10934,7 @@ dependencies = [ "tokio", "tokio-openssl", "tokio-postgres", + "tokio-rustls 0.23.4", "tokio-stream", "tokio-test", "tokio-tungstenite 0.20.1", @@ -11596,6 +11652,16 @@ dependencies = [ "web-sys", ] +[[package]] +name = "webpki" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" +dependencies = [ + "ring 0.17.5", + "untrusted 0.9.0", +] + [[package]] name = "webpki-roots" version = "0.25.2" diff --git a/Cargo.toml b/Cargo.toml index ac1f4b06cf4c9..a6690db8417b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -414,6 +414,10 @@ vrl.workspace = true wiremock = "0.6.2" zstd = { version = "0.13.0", default-features = false } +tokio-rustls = "0.23" +rustls = "0.20" +rustls-pemfile = "0.3" + [patch.crates-io] # The upgrade for `tokio-util` >= 0.6.9 is blocked on https://github.com/vectordotdev/vector/issues/11257. tokio-util = { git = "https://github.com/vectordotdev/tokio", branch = "tokio-util-0.7.11-framed-read-continue-on-error" } diff --git a/changelog.d/add_automatic_bearer_token_acquisition_in_http_client.enhancement.md b/changelog.d/add_automatic_bearer_token_acquisition_in_http_client.enhancement.md new file mode 100644 index 0000000000000..5795f8504ced2 --- /dev/null +++ b/changelog.d/add_automatic_bearer_token_acquisition_in_http_client.enhancement.md @@ -0,0 +1,4 @@ +The `http_client` can now acquire a bearer token using OAuth2 protocol, cache it and refresh before token expires. +OAuth2 and mTLS extension are supported in this implementation. + +authors: KowalczykBartek diff --git a/src/http.rs b/src/http.rs index f8e1c939c58c8..c35dd396af5cb 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,4 +1,6 @@ #![allow(missing_docs)] +use async_trait::async_trait; +use bytes::Buf; use futures::future::BoxFuture; use headers::{Authorization, HeaderMapExt}; use http::{ @@ -13,13 +15,16 @@ use hyper::{ use hyper_openssl::HttpsConnector; use hyper_proxy::ProxyConnector; use rand::Rng; +use serde::Deserialize; use serde_with::serde_as; use snafu::{ResultExt, Snafu}; use std::{ + error::Error, fmt, net::SocketAddr, + sync::{Arc, Mutex}, task::{Context, Poll}, - time::Duration, + time::{Duration, SystemTime, UNIX_EPOCH}, }; use tokio::time::Instant; use tower::{Layer, Service}; @@ -28,8 +33,11 @@ use tower_http::{ trace::TraceLayer, }; use tracing::{Instrument, Span}; -use vector_lib::configurable::configurable_component; use vector_lib::sensitive_string::SensitiveString; +use vector_lib::{ + configurable::configurable_component, + tls::{TlsConfig, TlsSettings}, +}; use crate::{ config::ProxyConfig, @@ -56,6 +64,10 @@ pub enum HttpError { CallRequest { source: hyper::Error }, #[snafu(display("Failed to build HTTP request: {}", source))] BuildRequest { source: http::Error }, + #[snafu(display("Failed to acquire authentication resource."))] + AuthenticationExtension { + source: Box, + }, } impl HttpError { @@ -64,6 +76,7 @@ impl HttpError { HttpError::BuildRequest { .. } | HttpError::MakeProxyConnector { .. } => false, HttpError::CallRequest { .. } | HttpError::BuildTlsConnector { .. } + | HttpError::AuthenticationExtension { .. } | HttpError::MakeHttpsConnector { .. } => true, } } @@ -72,31 +85,270 @@ impl HttpError { pub type HttpClientFuture = >>::Future; type HttpProxyConnector = ProxyConnector>; +#[async_trait] +trait AuthExtension: Send + Sync +where + B: fmt::Debug + HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into + Send, +{ + async fn modify_request(&self, req: &mut Request) -> Result<(), vector_lib::Error>; +} + +#[derive(Clone)] +struct OAuth2Extension { + token_endpoint: String, + client_id: String, + client_secret: Option, + grace_period: u32, + client: Client, + token: Arc>>, +} + +#[derive(Clone)] +struct BasicAuthExtension { + user: String, + password: SensitiveString, +} + +#[derive(Debug, Deserialize)] +struct Token { + access_token: String, + // This property, according to RFC, is expected to be in seconds. + expires_in: u32, +} + +#[derive(Debug, Clone)] +struct ExpirableToken { + access_token: String, + expires_after_ms: u128, +} + +impl OAuth2Extension { + /// Creates a new `OAuth2Extension`. + fn new( + token_endpoint: String, + client_id: String, + client_secret: Option, + grace_period: u32, + client: Client, + ) -> OAuth2Extension { + let initial_empty_token = Arc::new(Mutex::new(None)); + OAuth2Extension { + token_endpoint, + client_id, + client_secret, + grace_period, + client, + token: initial_empty_token, + } + } + + async fn get_token(&self) -> Result { + if let Some(token) = self.acquire_token_from_cache() { + return Ok(token.access_token); + } + + //no valid token in cache (or no token at all) + let new_token = self.request_token().await?; + let token_to_return = new_token.access_token.clone(); + self.save_into_cache(new_token); + + Ok(token_to_return) + } + + fn acquire_token_from_cache(&self) -> Option { + let maybe_token = self.token.lock().expect("Poisoned token lock"); + match &*maybe_token { + Some(token) => { + let time_now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + if time_now.as_millis() < token.expires_after_ms { + //we have token, token is valid for at least 1min, we can use it. + return Some(token.clone()); + } + + None + } + _ => None, + } + } + + fn save_into_cache(&self, token: ExpirableToken) { + self.token + .lock() + .expect("Poisoned token lock") + .replace(token); + } + + async fn request_token( + &self, + ) -> Result> { + let mut request_body = + format!("grant_type=client_credentials&client_id={}", self.client_id); + + // in case of oauth2 with mTLS (https://datatracker.ietf.org/doc/html/rfc8705) we only pass client_id, + // so secret can be considered as optional. + if let Some(client_secret) = &self.client_secret { + let secret_param = format!("&client_secret={}", client_secret.inner()); + request_body.push_str(&secret_param); + } + + let builder = Request::post(self.token_endpoint.clone()); + let builder = builder.header("Content-Type", "application/x-www-form-urlencoded"); + let request = builder + .body(Body::from(request_body)) + .expect("Error creating request"); + + let before = std::time::Instant::now(); + let response_result = self.client.request(request).await; + let roundtrip = before.elapsed(); + + let response = response_result.inspect_err(|error| { + emit!(http_client::GotHttpWarning { error, roundtrip }); + })?; + + emit!(http_client::GotHttpResponse { + response: &response, + roundtrip + }); + + if !response.status().is_success() { + let body_bytes = hyper::body::aggregate(response).await?; + let body_str = std::str::from_utf8(body_bytes.chunk())?.to_string(); + return Err(Box::new(AcquireTokenError { message: body_str })); + } + + let body = hyper::body::aggregate(response).await?; + let token: Token = serde_json::from_reader(body.reader())?; + + let time_now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards"); + let token_will_expire_after_ms = + OAuth2Extension::calculate_valid_until(time_now, self.grace_period, &token); + + Ok(ExpirableToken { + access_token: token.access_token, + expires_after_ms: token_will_expire_after_ms, + }) + } + + const fn calculate_valid_until(now: Duration, grace_period: u32, token: &Token) -> u128 { + // 'expires_in' means, in seconds, for how long it will be valid, lets say 5min, + // to not cause some random 4xx, because token expired in the meantime, we will make some + // room for token refreshing, this room is a grace_period. + let (mut grace_period_seconds, overflow) = token.expires_in.overflowing_sub(grace_period); + + // If time for grace period exceed an expire_in, it basically means: always use new token. + if overflow { + grace_period_seconds = 0; + } + + // We are multiplying by 1000 because expires_in field is in seconds(oauth standard), grace_period also, + // but later we operate on milliseconds. + let token_is_valid_until_ms: u128 = grace_period_seconds as u128 * 1000; + let now_millis = now.as_millis(); + + now_millis + token_is_valid_until_ms + } +} + +#[derive(Debug)] +pub struct AcquireTokenError { + message: String, +} + +impl fmt::Display for AcquireTokenError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Server error from authentication server: {}", + self.message + ) + } +} + +impl Error for AcquireTokenError {} + +#[async_trait] +impl AuthExtension for OAuth2Extension +where + B: fmt::Debug + HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into + Send, +{ + async fn modify_request(&self, req: &mut Request) -> Result<(), vector_lib::Error> { + let token = self.get_token().await?; + let auth = Auth::Bearer { + token: SensitiveString::from(token), + }; + auth.apply(req); + + Ok(()) + } +} + +#[async_trait] +impl AuthExtension for BasicAuthExtension +where + B: fmt::Debug + HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into + Send, +{ + async fn modify_request(&self, req: &mut Request) -> Result<(), vector_lib::Error> { + let user = self.user.clone(); + let password = self.password.clone(); + + let auth = Auth::Basic { user, password }; + auth.apply(req); + + Ok(()) + } +} + pub struct HttpClient { client: Client, user_agent: HeaderValue, proxy_connector: HttpProxyConnector, + auth_extension: Option>>, } impl HttpClient where B: fmt::Debug + HttpBody + Send + 'static, B::Data: Send, - B::Error: Into, + B::Error: Into + Send, { pub fn new( tls_settings: impl Into, proxy_config: &ProxyConfig, ) -> Result, HttpError> { - HttpClient::new_with_custom_client(tls_settings, proxy_config, &mut Client::builder()) + HttpClient::new_with_custom_client(tls_settings, proxy_config, &mut Client::builder(), None) + } + + pub fn new_with_auth_extension( + tls_settings: impl Into, + proxy_config: &ProxyConfig, + auth_config: Option, + ) -> Result, HttpError> { + HttpClient::new_with_custom_client( + tls_settings, + proxy_config, + &mut Client::builder(), + auth_config, + ) } pub fn new_with_custom_client( tls_settings: impl Into, proxy_config: &ProxyConfig, client_builder: &mut client::Builder, + auth_config: Option, ) -> Result, HttpError> { let proxy_connector = build_proxy_connector(tls_settings.into(), proxy_config)?; + let auth_extension = build_auth_extension(auth_config, proxy_config, client_builder); let client = client_builder.build(proxy_connector.clone()); let app_name = crate::get_app_name(); @@ -108,6 +360,7 @@ where client, user_agent, proxy_connector, + auth_extension, }) } @@ -121,11 +374,26 @@ where default_request_headers(&mut request, &self.user_agent); self.maybe_add_proxy_headers(&mut request); - emit!(http_client::AboutToSendHttpRequest { request: &request }); - - let response = self.client.request(request); + let client = self.client.clone(); + let auth_extension = self.auth_extension.clone(); let fut = async move { + if let Some(auth_extension) = auth_extension { + let auth_span = tracing::info_span!("auth_extension"); + auth_extension + .modify_request(&mut request) + .instrument(auth_span.clone().or_current()) + .await + .inspect_err(|error| { + // Emit the error into the internal events system. + emit!(http_client::AuthExtensionError { error }); + }) + .context(AuthenticationExtensionSnafu)?; + } + + emit!(http_client::AboutToSendHttpRequest { request: &request }); + let response: client::ResponseFuture = client.request(request); + // Capture the time right before we issue the request. // Request doesn't start the processing until we start polling it. let before = std::time::Instant::now(); @@ -169,6 +437,50 @@ where } } +fn build_auth_extension( + authorization_config: Option, + proxy_config: &ProxyConfig, + client_builder: &mut client::Builder, +) -> Option>> +where + B: fmt::Debug + HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into + Send, +{ + if let Some(authorization_config) = authorization_config { + match authorization_config.strategy { + HttpClientAuthorizationStrategy::Basic { user, password } => { + let basic_auth_extension = BasicAuthExtension { user, password }; + return Some(Arc::new(basic_auth_extension)); + } + HttpClientAuthorizationStrategy::OAuth2 { + token_endpoint, + client_id, + client_secret, + grace_period, + } => { + let tls_for_auth = authorization_config.tls.clone(); + let tls_for_auth: TlsSettings = TlsSettings::from_options(&tls_for_auth).unwrap(); + + let auth_proxy_connector = + build_proxy_connector(tls_for_auth.into(), proxy_config).unwrap(); + let auth_client = client_builder.build(auth_proxy_connector.clone()); + + let oauth2_extension = OAuth2Extension::new( + token_endpoint, + client_id, + client_secret, + grace_period, + auth_client, + ); + return Some(Arc::new(oauth2_extension)); + } + } + } + + None +} + pub fn build_proxy_connector( tls_settings: MaybeTlsSettings, proxy_config: &ProxyConfig, @@ -249,6 +561,7 @@ impl Clone for HttpClient { client: self.client.clone(), user_agent: self.user_agent.clone(), proxy_connector: self.proxy_connector.clone(), + auth_extension: self.auth_extension.clone(), } } } @@ -262,6 +575,105 @@ impl fmt::Debug for HttpClient { } } +/// Configuration for HTTP client providing an authentication mechanism. +#[configurable_component] +#[configurable(metadata(docs::advanced))] +#[derive(Clone, Debug)] +#[serde(deny_unknown_fields)] +pub struct AuthorizationConfig { + /// Define how to authorize against an upstream. + #[configurable] + strategy: HttpClientAuthorizationStrategy, + + /// The TLS settings for the http client's connection. + /// + /// Optional, constrains TLS settings for this http client. + #[configurable(derived)] + tls: Option, +} + +/// Configuration of the authentication strategy for HTTP requests. +/// +/// HTTP authentication should be used with HTTPS only, as the authentication credentials are passed as an +/// HTTP header without any additional encryption beyond what is provided by the transport itself. +#[configurable_component] +#[derive(Clone, Debug)] +#[serde(deny_unknown_fields, rename_all = "snake_case", tag = "strategy")] +#[configurable(metadata(docs::enum_tag_description = "The authentication strategy to use."))] +pub enum HttpClientAuthorizationStrategy { + /// Basic authentication. + /// + /// The username and password are concatenated and encoded via [base64][base64]. + /// + /// [base64]: https://en.wikipedia.org/wiki/Base64 + Basic { + /// The basic authentication username. + #[configurable(metadata(docs::examples = "username"))] + user: String, + + /// The basic authentication password. + #[configurable(metadata(docs::examples = "password"))] + password: SensitiveString, + }, + + /// Authentication based on OAuth 2.0 protocol. + /// + /// This strategy allows to dynamically acquire and use token based on provided parameters. + /// Both standard client_credentials and mTLS extension is supported, for standard client_credentials just provide both + /// client_id and client_secret parameters: + /// + /// # Example + /// + /// ```yaml + /// strategy: + /// strategy: "o_auth2" + /// client_id: "client.id" + /// client_secret: "secret-value" + /// token_endpoint: "https://yourendpoint.com/oauth/token" + /// ``` + /// In case you want to use mTLS extension [rfc8705](https://datatracker.ietf.org/doc/html/rfc8705), provide desired key and certificate, + /// together with client_id (with no client_secret parameter). + /// + /// # Example + /// + /// ```yaml + /// strategy: + /// strategy: "o_auth2" + /// client_id: "client.id" + /// token_endpoint: "https://yourendpoint.com/oauth/token" + /// tls: + /// crt_path: cert.pem + /// key_file: key.pem + /// ``` + OAuth2 { + /// Token endpoint location, required for token acquisition. + #[configurable(metadata(docs::examples = "https://auth.provider/oauth/token"))] + token_endpoint: String, + + /// The client id. + #[configurable(metadata(docs::examples = "client_id"))] + client_id: String, + + /// The sensitive client secret. + #[configurable(metadata(docs::examples = "client_secret"))] + client_secret: Option, + + /// The grace period configuration for a bearer token. + /// To avoid random authorization failures caused by expired token exception, + /// we will acquire new token, some time (grace period) before current token will be expired, + /// because of that, we will always execute request with fresh enough token. + #[serde(default = "default_oauth2_token_grace_period")] + #[configurable(metadata(docs::examples = 300))] + #[configurable(metadata(docs::type_unit = "seconds"))] + #[configurable(metadata(docs::human_name = "Grace period for bearer token."))] + grace_period: u32, + }, +} + +const fn default_oauth2_token_grace_period() -> u32 { + 300 // 5 minutes +} + /// Configuration of the authentication strategy for HTTP requests. /// /// HTTP authentication should be used with HTTPS only, as the authentication credentials are passed as an @@ -556,10 +968,17 @@ where #[cfg(test)] mod tests { - use std::convert::Infallible; + use std::{convert::Infallible, fs::File, io::BufReader}; - use hyper::{server::conn::AddrStream, service::make_service_fn, Server}; + use hyper::{ + server::conn::AddrStream, + service::{make_service_fn, service_fn}, + Server, + }; use proptest::prelude::*; + use rustls::{Certificate, PrivateKey, RootCertStore, ServerConfig}; + use tokio::net::TcpListener; + use tokio_rustls::TlsAcceptor; use tower::ServiceBuilder; use crate::test_util::next_addr; @@ -811,4 +1230,372 @@ mod tests { let response = client.send(req).await.unwrap(); assert_eq!(response.headers().get("Connection"), None); } + + #[tokio::test] + async fn test_oauth2extension_handle_errors_gently_with_hyper_server() { + let addr: SocketAddr = next_addr(); + // Simplest possible configuration for oauth's client connector. + let tls: vector_lib::tls::MaybeTls<(), TlsSettings> = + MaybeTlsSettings::from_config(&None, false).unwrap(); + let proxy_connector = build_proxy_connector(tls, &ProxyConfig::default()).unwrap(); + let auth_client = Client::builder().build(proxy_connector); + + let token_endpoint = format!("http://{}", addr); + let client_id = String::from("some_client_secret"); + let client_secret = Some(SensitiveString::from(String::from("some_secret"))); + let two_seconds_grace_period: u32 = 2; + + // Setup an OAuth2Extension. + let extension = OAuth2Extension::new( + token_endpoint, + client_id, + client_secret, + two_seconds_grace_period, + auth_client, + ); + + // First token is acquired because cache is empty. + let failed_acquisition = extension.get_token().await; + assert!(failed_acquisition.is_err()); + + let make_svc = make_service_fn(move |_: &AddrStream| { + let svc = ServiceBuilder::new().service(tower::service_fn( + |_req: Request| async move { + let not_a_valid_token = r#" + { + "definetly" : "not a vald response" + } + "#; + + Ok::, hyper::Error>(Response::new(Body::from(not_a_valid_token))) + }, + )); + futures_util::future::ok::<_, Infallible>(svc) + }); + + tokio::spawn(async move { + Server::bind(&addr).serve(make_svc).await.unwrap(); + }); + + // Wait for the server to start. + tokio::time::sleep(Duration::from_millis(10)).await; + + let failed_acquisition = extension.get_token().await; + assert!(failed_acquisition.is_err()); + } + + #[tokio::test] + async fn test_oauth2_strategy_with_hyper_server() { + let oauth_addr: SocketAddr = next_addr(); + let oauth_make_svc = make_service_fn(move |_: &AddrStream| { + let svc = ServiceBuilder::new() + .service(tower::service_fn(|req: Request| async move { + assert_eq!( + req.headers().get("Content-Type"), + Some(&HeaderValue::from_static("application/x-www-form-urlencoded")), + ); + + let body_bytes = hyper::body::to_bytes(req.into_body()).await.unwrap(); + let request_body = String::from_utf8(body_bytes.to_vec()).unwrap(); + + assert_eq!( + // Based on the (later) OAuth2Extension configuration. + "grant_type=client_credentials&client_id=some_client_secret&client_secret=some_secret", + request_body, + ); + + let token = r#" + { + "access_token": "some.jwt.token", + "token_type": "bearer", + "expires_in": 60, + "scope": "some-scope" + } + "#; + + Ok::, hyper::Error>(Response::new(Body::from(token))) + })); + futures_util::future::ok::<_, Infallible>(svc) + }); + + // Server a Http client will request together with acquired bearer token. + let addr: SocketAddr = next_addr(); + let make_svc = make_service_fn(move |_conn: &AddrStream| { + let svc = + ServiceBuilder::new().service(tower::service_fn(|req: Request| async move { + assert_eq!( + req.headers().get("authorization"), + Some(&HeaderValue::from_static("Bearer some.jwt.token")), + ); + + Ok::, hyper::Error>(Response::new(Body::empty())) + })); + futures_util::future::ok::<_, Infallible>(svc) + }); + + tokio::spawn(async move { + Server::bind(&oauth_addr) + .serve(oauth_make_svc) + .await + .unwrap(); + }); + + tokio::spawn(async move { + Server::bind(&addr).serve(make_svc).await.unwrap(); + }); + + // Wait for the server to start. + tokio::time::sleep(Duration::from_millis(10)).await; + + // Http client to test + let token_endpoint = format!("http://{}", oauth_addr); + let client_id: String = String::from("some_client_secret"); + let client_secret = Some(SensitiveString::from(String::from("some_secret"))); + let grace_period = 5; + + let oauth2_strategy = HttpClientAuthorizationStrategy::OAuth2 { + token_endpoint, + client_id, + client_secret, + grace_period, + }; + + let auth_config = AuthorizationConfig { + strategy: oauth2_strategy, + tls: None, + }; + + let client = + HttpClient::new_with_auth_extension(None, &ProxyConfig::default(), Some(auth_config)) + .unwrap(); + + let req = Request::get(format!("http://{}/", addr)) + .body(Body::empty()) + .unwrap(); + + let response = client.send(req).await.unwrap(); + assert!(response.status().is_success()); + } + + #[tokio::test] + async fn test_oauth2_with_mtls_strategy_with_hyper_server() { + let oauth_addr: SocketAddr = next_addr(); + let addr: SocketAddr = next_addr(); + let make_svc = make_service_fn(move |_conn: &AddrStream| { + let svc = + ServiceBuilder::new().service(tower::service_fn(|req: Request| async move { + assert_eq!( + req.headers().get("authorization"), + Some(&HeaderValue::from_static("Bearer some.jwt.token")), + ); + + Ok::, hyper::Error>(Response::new(Body::empty())) + })); + futures_util::future::ok::<_, Infallible>(svc) + }); + + // Load a certificates. + fn load_certs(path: &str) -> Vec { + let certfile = File::open(path).unwrap(); + let mut reader = BufReader::new(certfile); + rustls_pemfile::certs(&mut reader) + .unwrap() + .into_iter() + .map(Certificate) + .collect() + } + + // Load a private key. + fn load_private_key(path: &str) -> PrivateKey { + let keyfile = File::open(path).unwrap(); + let mut reader = BufReader::new(keyfile); + let keys = rustls_pemfile::rsa_private_keys(&mut reader).unwrap(); + PrivateKey(keys[0].clone()) + } + + // Load a server tls context to validate client. + let certs = load_certs("tests/data/ca/certs/ca.cert.pem"); + let key = load_private_key("tests/data/ca/private/ca.key.pem"); + let client_certs = load_certs("tests/data/ca/intermediate_client/certs/ca-chain.cert.pem"); + let mut root_store = RootCertStore::empty(); + for cert in client_certs { + root_store.add(&cert).unwrap(); + } + + tokio::spawn(async move { + let tls_config = ServerConfig::builder() + .with_safe_defaults() + .with_client_cert_verifier(rustls::server::AllowAnyAuthenticatedClient::new( + root_store, + )) + .with_single_cert(certs, key) + .unwrap(); + + let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config)); + let acceptor = Arc::new(tls_acceptor); + let http = hyper::server::conn::Http::new(); + let listener: TcpListener = TcpListener::bind(&oauth_addr).await.unwrap(); + + loop { + let (conn, _) = listener.accept().await.unwrap(); + let acceptor = Arc::::clone(&acceptor); + let http = http.clone(); + let fut = async move { + let stream = acceptor.accept(conn).await.unwrap(); + let service = service_fn(|req: Request| async { + assert_eq!( + req.headers().get("Content-Type"), + Some(&HeaderValue::from_static( + "application/x-www-form-urlencoded" + )), + ); + + let body_bytes = hyper::body::to_bytes(req.into_body()).await.unwrap(); + let request_body = String::from_utf8(body_bytes.to_vec()).unwrap(); + + assert_eq!( + // Based on the (later) OAuth2Extension configuration. + "grant_type=client_credentials&client_id=some_client_secret", + request_body, + ); + + let token = r#" + { + "access_token": "some.jwt.token", + "token_type": "bearer", + "expires_in": 60, + "scope": "some-scope" + } + "#; + + Ok::<_, hyper::Error>(Response::new(Body::from(token))) + }); + + http.serve_connection(stream, service).await.unwrap(); + }; + tokio::spawn(fut); + } + }); + + tokio::spawn(async move { + Server::bind(&addr).serve(make_svc).await.unwrap(); + }); + + // Wait for the server to start. + tokio::time::sleep(Duration::from_millis(10)).await; + + // Http client to test + let token_endpoint = format!("https://{}", oauth_addr); + let client_id: String = String::from("some_client_secret"); + let grace_period = 5; + + let oauth2_strategy = HttpClientAuthorizationStrategy::OAuth2 { + token_endpoint, + client_id, + client_secret: None, + grace_period, + }; + + let auth_config = AuthorizationConfig { + strategy: oauth2_strategy, + tls: Some(TlsConfig { + verify_hostname: Some(false), + ca_file: Some("tests/data/ca/certs/ca.cert.pem".into()), + crt_file: Some("tests/data/ca/intermediate_client/certs/localhost.cert.pem".into()), + key_file: Some( + "tests/data/ca/intermediate_client/private/localhost.key.pem".into(), + ), + ..Default::default() + }), + }; + + let client = + HttpClient::new_with_auth_extension(None, &ProxyConfig::default(), Some(auth_config)) + .unwrap(); + + let req = Request::get(format!("http://{}/", addr)) + .body(Body::empty()) + .unwrap(); + + let response = client.send(req).await.unwrap(); + assert!(response.status().is_success()); + } + + #[tokio::test] + async fn test_basic_auth_strategy_with_hyper_server() { + // Server a Http client will request together with acquired bearer token. + let addr: SocketAddr = next_addr(); + let make_svc = make_service_fn(move |_conn: &AddrStream| { + let svc = + ServiceBuilder::new().service(tower::service_fn(|req: Request| async move { + assert_eq!( + req.headers().get("authorization"), + Some(&HeaderValue::from_static("Basic dXNlcjpwYXNzd29yZA==")), + ); + + Ok::, hyper::Error>(Response::new(Body::empty())) + })); + futures_util::future::ok::<_, Infallible>(svc) + }); + + tokio::spawn(async move { + Server::bind(&addr).serve(make_svc).await.unwrap(); + }); + + // Wait for the server to start. + tokio::time::sleep(Duration::from_millis(10)).await; + + // Http client to test + let user = String::from("user"); + let password = SensitiveString::from(String::from("password")); + + let basic_strategy = HttpClientAuthorizationStrategy::Basic { user, password }; + + let auth_config = AuthorizationConfig { + strategy: basic_strategy, + tls: None, + }; + + let client = + HttpClient::new_with_auth_extension(None, &ProxyConfig::default(), Some(auth_config)) + .unwrap(); + + let req = Request::get(format!("http://{}/", addr)) + .body(Body::empty()) + .unwrap(); + + let response = client.send(req).await.unwrap(); + assert!(response.status().is_success()); + } + + #[tokio::test] + async fn test_grace_period_calculation() { + let now = Duration::from_secs(100); + let grace_period_seconds = 5; + let fake_token = Token { + access_token: String::from("some-jwt"), + expires_in: 20, + }; + + let expires_after_ms = + OAuth2Extension::calculate_valid_until(now, grace_period_seconds, &fake_token); + + assert_eq!(115000, expires_after_ms); + } + + #[tokio::test] + async fn test_grace_period_calculation_with_overflow() { + let now = Duration::from_secs(100); + let grace_period_seconds = 30; + let fake_token = Token { + access_token: String::from("some-jwt"), + expires_in: 20, + }; + + let expires_after_ms = + OAuth2Extension::calculate_valid_until(now, grace_period_seconds, &fake_token); + + // When overflow, we expect grace_period be 0 (so, now + grace = now) + assert_eq!(100000, expires_after_ms); + } } diff --git a/src/internal_events/http_client.rs b/src/internal_events/http_client.rs index 2584ef9c1254d..d69752c11df06 100644 --- a/src/internal_events/http_client.rs +++ b/src/internal_events/http_client.rs @@ -95,6 +95,29 @@ impl<'a> InternalEvent for GotHttpWarning<'a> { } } +#[derive(Debug)] +pub struct AuthExtensionError<'a> { + pub error: &'a vector_lib::Error, +} + +impl<'a> InternalEvent for AuthExtensionError<'a> { + fn emit(self) { + error!( + message = "HTTP Auth extension error.", + error = %self.error, + error_type = error_type::REQUEST_FAILED, + stage = error_stage::PROCESSING, + internal_log_rate_limit = true, + ); + counter!( + "component_errors_total", + "error_type" => error_type::CONFIGURATION_FAILED, + "stage" => error_stage::SENDING, + ) + .increment(1); + } +} + /// Newtype placeholder to provide a formatter for the request and response body. struct FormatBody<'a, B>(&'a B); diff --git a/src/sinks/axiom.rs b/src/sinks/axiom.rs index dfc4ab124dcc1..979fe1241bd82 100644 --- a/src/sinks/axiom.rs +++ b/src/sinks/axiom.rs @@ -119,6 +119,7 @@ impl SinkConfig for AxiomConfig { }), method: HttpMethod::Post, tls: self.tls.clone(), + authorization_config: None, request, acknowledgements: self.acknowledgements, batch: self.batch, diff --git a/src/sinks/http/config.rs b/src/sinks/http/config.rs index ccc9ace780046..55648f35b3bd1 100644 --- a/src/sinks/http/config.rs +++ b/src/sinks/http/config.rs @@ -10,7 +10,7 @@ use vector_lib::codecs::{ use crate::{ codecs::{EncodingConfigWithFraming, SinkType}, - http::{Auth, HttpClient, MaybeAuth}, + http::{Auth, AuthorizationConfig, HttpClient, MaybeAuth}, sinks::{ prelude::*, util::{ @@ -90,6 +90,9 @@ pub struct HttpSinkConfig { #[configurable(derived)] pub tls: Option, + #[configurable(derived)] + pub authorization_config: Option, + #[configurable(derived)] #[serde( default, @@ -153,7 +156,12 @@ impl From for Method { impl HttpSinkConfig { fn build_http_client(&self, cx: &SinkContext) -> crate::Result { let tls = TlsSettings::from_options(&self.tls)?; - Ok(HttpClient::new(tls, cx.proxy())?) + let auth_strategy = self.authorization_config.clone(); + Ok(HttpClient::new_with_auth_extension( + tls, + cx.proxy(), + auth_strategy, + )?) } pub(super) fn build_encoder(&self) -> crate::Result> { @@ -338,6 +346,7 @@ mod tests { batch: BatchConfig::default(), request: RequestConfig::default(), tls: None, + authorization_config: None, acknowledgements: AcknowledgementsConfig::default(), payload_prefix: String::new(), payload_suffix: String::new(), diff --git a/src/sinks/http/tests.rs b/src/sinks/http/tests.rs index 363877380c308..65beca27574ec 100644 --- a/src/sinks/http/tests.rs +++ b/src/sinks/http/tests.rs @@ -59,6 +59,7 @@ fn default_cfg(encoding: EncodingConfigWithFraming) -> HttpSinkConfig { batch: Default::default(), request: Default::default(), tls: Default::default(), + authorization_config: None, acknowledgements: Default::default(), } } diff --git a/website/cue/reference/components/sinks/base/http.cue b/website/cue/reference/components/sinks/base/http.cue index ae8aaa42bf94b..2377e240cb8a9 100644 --- a/website/cue/reference/components/sinks/base/http.cue +++ b/website/cue/reference/components/sinks/base/http.cue @@ -74,6 +74,209 @@ base: components: sinks: http: configuration: { } } } + authorization_config: { + description: "Configuration for HTTP client providing an authentication mechanism." + required: false + type: object: options: { + strategy: { + description: """ + Configuration of the authentication strategy for HTTP requests. + + Define how to authorize against an upstream. + """ + required: true + type: object: options: { + client_id: { + description: "The client id." + relevant_when: "strategy = \"o_auth2\"" + required: true + type: string: examples: ["client_id"] + } + client_secret: { + description: "The sensitive client secret." + relevant_when: "strategy = \"o_auth2\"" + required: false + type: string: examples: ["client_secret"] + } + grace_period: { + description: """ + The grace period configuration for a bearer token. + To avoid random authorization failures caused by expired token exception, + we will acquire new token, some time (grace period) before current token will be expired, + because of that, we will always execute request with fresh enough token. + """ + relevant_when: "strategy = \"o_auth2\"" + required: false + type: uint: { + default: 300 + examples: [300] + unit: "seconds" + } + } + password: { + description: "The basic authentication password." + relevant_when: "strategy = \"basic\"" + required: true + type: string: examples: ["password"] + } + strategy: { + description: "The authentication strategy to use." + required: true + type: string: enum: { + basic: """ + Basic authentication. + + The username and password are concatenated and encoded via [base64][base64]. + + [base64]: https://en.wikipedia.org/wiki/Base64 + """ + o_auth2: """ + Authentication based on OAuth 2.0 protocol. + + This strategy allows to dynamically acquire and use token based on provided parameters. + Both standard client_credentials and mTLS extension is supported, for standard client_credentials just provide both + client_id and client_secret parameters: + + # Example + + ```yaml + strategy: + strategy: "o_auth2" + client_id: "client.id" + client_secret: "secret-value" + token_endpoint: "https://yourendpoint.com/oauth/token" + ``` + In case you want to use mTLS extension [rfc8705](https://datatracker.ietf.org/doc/html/rfc8705), provide desired key and certificate, + together with client_id (with no client_secret parameter). + + # Example + + ```yaml + strategy: + strategy: "o_auth2" + client_id: "client.id" + token_endpoint: "https://yourendpoint.com/oauth/token" + tls: + crt_path: cert.pem + key_file: key.pem + ``` + """ + } + } + token_endpoint: { + description: "Token endpoint location, required for token acquisition." + relevant_when: "strategy = \"o_auth2\"" + required: true + type: string: examples: ["https://auth.provider/oauth/token"] + } + user: { + description: "The basic authentication username." + relevant_when: "strategy = \"basic\"" + required: true + type: string: examples: ["username"] + } + } + } + tls: { + description: """ + The TLS settings for the http client's connection. + + Optional, constrains TLS settings for this http client. + """ + required: false + type: object: options: { + alpn_protocols: { + description: """ + Sets the list of supported ALPN protocols. + + Declare the supported ALPN protocols, which are used during negotiation with peer. They are prioritized in the order + that they are defined. + """ + required: false + type: array: items: type: string: examples: ["h2"] + } + ca_file: { + description: """ + Absolute path to an additional CA certificate file. + + The certificate must be in the DER or PEM (X.509) format. Additionally, the certificate can be provided as an inline string in PEM format. + """ + required: false + type: string: examples: ["/path/to/certificate_authority.crt"] + } + crt_file: { + description: """ + Absolute path to a certificate file used to identify this server. + + The certificate must be in DER, PEM (X.509), or PKCS#12 format. Additionally, the certificate can be provided as + an inline string in PEM format. + + If this is set, and is not a PKCS#12 archive, `key_file` must also be set. + """ + required: false + type: string: examples: ["/path/to/host_certificate.crt"] + } + key_file: { + description: """ + Absolute path to a private key file used to identify this server. + + The key must be in DER or PEM (PKCS#8) format. Additionally, the key can be provided as an inline string in PEM format. + """ + required: false + type: string: examples: ["/path/to/host_certificate.key"] + } + key_pass: { + description: """ + Passphrase used to unlock the encrypted key file. + + This has no effect unless `key_file` is set. + """ + required: false + type: string: examples: ["${KEY_PASS_ENV_VAR}", "PassWord1"] + } + server_name: { + description: """ + Server name to use when using Server Name Indication (SNI). + + Only relevant for outgoing connections. + """ + required: false + type: string: examples: ["www.example.com"] + } + verify_certificate: { + description: """ + Enables certificate verification. For components that create a server, this requires that the + client connections have a valid client certificate. For components that initiate requests, + this validates that the upstream has a valid certificate. + + If enabled, certificates must not be expired and must be issued by a trusted + issuer. This verification operates in a hierarchical manner, checking that the leaf certificate (the + certificate presented by the client/server) is not only valid, but that the issuer of that certificate is also valid, and + so on until the verification process reaches a root certificate. + + Do NOT set this to `false` unless you understand the risks of not verifying the validity of certificates. + """ + required: false + type: bool: {} + } + verify_hostname: { + description: """ + Enables hostname verification. + + If enabled, the hostname used to connect to the remote host must be present in the TLS certificate presented by + the remote host, either as the Common Name or as an entry in the Subject Alternative Name extension. + + Only relevant for outgoing connections. + + Do NOT set this to `false` unless you understand the risks of not verifying the remote hostname. + """ + required: false + type: bool: {} + } + } + } + } + } batch: { description: "Event batching behavior." required: false