diff --git a/src/HeaderPropagationMessageHandler.cs b/src/HeaderPropagationMessageHandler.cs index f863998..3fa50fc 100644 --- a/src/HeaderPropagationMessageHandler.cs +++ b/src/HeaderPropagationMessageHandler.cs @@ -53,11 +53,31 @@ protected override Task SendAsync(HttpRequestMessage reques { var outputName = string.IsNullOrEmpty(entry?.OutboundHeaderName) ? headerName : entry.OutboundHeaderName; - if (!request.Headers.Contains(outputName) && - _values.Headers.TryGetValue(headerName, out var values) && - !StringValues.IsNullOrEmpty(values)) + var hasContent = request.Content != null; + + if (!request.Headers.TryGetValues(outputName, out var _) && + !(hasContent && request.Content.Headers.TryGetValues(outputName, out var _))) { - request.Headers.TryAddWithoutValidation(outputName, (string[])values); + if (_values.Headers.TryGetValue(headerName, out var stringValues) && + !StringValues.IsNullOrEmpty(stringValues)) + { + if (stringValues.Count == 1) + { + var value = (string)stringValues; + if (!request.Headers.TryAddWithoutValidation(outputName, value) && hasContent) + { + request.Content.Headers.TryAddWithoutValidation(outputName, value); + } + } + else + { + var values = (string[])stringValues; + if (!request.Headers.TryAddWithoutValidation(outputName, values) && hasContent) + { + request.Content.Headers.TryAddWithoutValidation(outputName, values); + } + } + } } } diff --git a/test/HeaderPropagationMessageHandlerTest.cs b/test/HeaderPropagationMessageHandlerTest.cs index ebdf82f..9813cf8 100644 --- a/test/HeaderPropagationMessageHandlerTest.cs +++ b/test/HeaderPropagationMessageHandlerTest.cs @@ -52,6 +52,63 @@ public async Task HeaderInState_AddCorrectValue() Assert.Equal(new[] { "test" }, Handler.Headers.GetValues("out")); } + [Fact] + public async Task HeaderInState_WithMultipleValues_AddAllValues() + { + // Arrange + Configuration.Headers.Add("in", new HeaderPropagationEntry { OutboundHeaderName = "out" }); + State.Headers.Add("in", new[] { "one", "two" }); + + // Act + await Client.SendAsync(new HttpRequestMessage()); + + // Assert + Assert.True(Handler.Headers.Contains("out")); + Assert.Equal(new[] { "one", "two" }, Handler.Headers.GetValues("out")); + } + + [Fact] + public async Task HeaderInState_RequestWithContent_ContentHeaderPresent_DoesNotAddIt() + { + Configuration.Headers.Add("in", new HeaderPropagationEntry() { OutboundHeaderName = "Content-Type" }); + State.Headers.Add("in", "test"); + + // Act + await Client.SendAsync(new HttpRequestMessage() { Content = new StringContent("test") }); + + // Assert + Assert.True(Handler.Content.Headers.Contains("Content-Type")); + Assert.Equal(new[] { "text/plain; charset=utf-8" }, Handler.Content.Headers.GetValues("Content-Type")); + } + + [Fact] + public async Task HeaderInState_RequestWithContent_ContentHeaderNotPresent_AddValue() + { + Configuration.Headers.Add("in", new HeaderPropagationEntry() { OutboundHeaderName = "Content-Language" }); + State.Headers.Add("in", "test"); + + // Act + await Client.SendAsync(new HttpRequestMessage() { Content = new StringContent("test") }); + + // Assert + Assert.True(Handler.Content.Headers.Contains("Content-Language")); + Assert.Equal(new[] { "test" }, Handler.Content.Headers.GetValues("Content-Language")); + } + + [Fact] + public async Task HeaderInState_WithMultipleValues_RequestWithContent_ContentHeaderNotPresent_AddAllValues() + { + Configuration.Headers.Add("in", new HeaderPropagationEntry() { OutboundHeaderName = "Content-Language" }); + State.Headers.Add("in", new[] { "one", "two" }); + + // Act + await Client.SendAsync(new HttpRequestMessage() { Content = new StringContent("test") }); + + // Assert + Assert.True(Handler.Content.Headers.Contains("Content-Language")); + Assert.Equal(new[] { "one", "two" }, Handler.Content.Headers.GetValues("Content-Language")); + } + [Fact] public async Task HeaderInState_NoOutputName_UseInputName() { @@ -168,11 +225,13 @@ public async Task NullEntryInConfiguration_AddCorrectValue() private class SimpleHandler : DelegatingHandler { public HttpHeaders Headers { get; private set; } + public HttpContent Content { get; private set; } protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { Headers = request.Headers; + Content = request.Content; return Task.FromResult(new HttpResponseMessage()); } }