diff --git a/pubsub/gochannel/pubsub.go b/pubsub/gochannel/pubsub.go index 03a9d548a..fdd22044d 100644 --- a/pubsub/gochannel/pubsub.go +++ b/pubsub/gochannel/pubsub.go @@ -163,7 +163,7 @@ func (g *GoChannel) sendMessage(topic string, message *message.Message) (<-chan wg.Add(1) go func() { - subscriber.sendMessageToSubscriber(message, logFields, g.config.PreserveContext) + subscriber.sendMessageToSubscriber(message, logFields) wg.Done() }() } @@ -196,11 +196,12 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag subLock.(*sync.Mutex).Lock() s := &subscriber{ - ctx: ctx, - uuid: watermill.NewUUID(), - outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer), - logger: g.logger, - closing: make(chan struct{}), + ctx: ctx, + uuid: watermill.NewUUID(), + outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer), + logger: g.logger, + closing: make(chan struct{}), + preserveContext: g.config.PreserveContext, } go func(s *subscriber, g *GoChannel) { @@ -246,7 +247,7 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag msg := g.persistedMessages[topic][i] logFields := watermill.LogFields{"message_uuid": msg.UUID, "topic": topic} - go s.sendMessageToSubscriber(msg, logFields, g.config.PreserveContext) + go s.sendMessageToSubscriber(msg, logFields) } } @@ -329,6 +330,8 @@ type subscriber struct { logger watermill.LoggerAdapter closed bool closing chan struct{} + + preserveContext bool } func (s *subscriber) Close() { @@ -349,19 +352,20 @@ func (s *subscriber) Close() { close(s.outputChannel) } -func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields watermill.LogFields, preserveContext bool) { +func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields watermill.LogFields) { s.sending.Lock() defer s.sending.Unlock() var ctx context.Context - var cancelCtx context.CancelFunc - if preserveContext { - ctx, cancelCtx = context.WithCancel(msg.Context()) + //This is getting the context from the message, not the subscriber + if s.preserveContext { + ctx = msg.Context() } else { + var cancelCtx context.CancelFunc ctx, cancelCtx = context.WithCancel(s.ctx) + defer cancelCtx() } - defer cancelCtx() SendToSubscriber: for { diff --git a/pubsub/gochannel/pubsub_test.go b/pubsub/gochannel/pubsub_test.go index 4f339e5a4..2478d080f 100644 --- a/pubsub/gochannel/pubsub_test.go +++ b/pubsub/gochannel/pubsub_test.go @@ -102,8 +102,8 @@ func TestPublishSubscribe_not_persistent_with_context(t *testing.T) { msgs, err := pubSub.Subscribe(context.Background(), topicName) require.NoError(t, err) - const contextKey = "foo" - sendMessages := tests.PublishSimpleMessagesWithContext(t, messagesCount, contextKey, pubSub, topicName) + const contextKeyString = "foo" + sendMessages := tests.PublishSimpleMessagesWithContext(t, messagesCount, contextKeyString, pubSub, topicName) receivedMsgs, _ := subscriber.BulkRead(msgs, messagesCount, time.Second) expectedContexts := make(map[string]context.Context) @@ -111,7 +111,7 @@ func TestPublishSubscribe_not_persistent_with_context(t *testing.T) { expectedContexts[msg.UUID] = msg.Context() } tests.AssertAllMessagesReceived(t, sendMessages, receivedMsgs) - tests.AssertAllMessagesHaveSameContext(t, contextKey, expectedContexts, receivedMsgs) + tests.AssertAllMessagesHaveSameContext(t, contextKeyString, expectedContexts, receivedMsgs) assert.NoError(t, pubSub.Close()) } diff --git a/pubsub/tests/test_asserts.go b/pubsub/tests/test_asserts.go index 7daa17d5f..4660ed931 100644 --- a/pubsub/tests/test_asserts.go +++ b/pubsub/tests/test_asserts.go @@ -95,13 +95,13 @@ func AssertMessagesMetadata(t *testing.T, key string, expectedValues map[string] } // AssertAllMessagesHaveSameContext checks if context of all received messages is the same as in expectedValues, if PreserveContext is enabled. -func AssertAllMessagesHaveSameContext(t *testing.T, contextKey string, expectedValues map[string]context.Context, received []*message.Message) bool { +func AssertAllMessagesHaveSameContext(t *testing.T, contextKeyString string, expectedValues map[string]context.Context, received []*message.Message) bool { assert.Len(t, received, len(expectedValues)) ok := true for _, msg := range received { - expectedValue := expectedValues[msg.UUID].Value(contextKey) - actualValue := msg.Context().Value(contextKey) + expectedValue := expectedValues[msg.UUID].Value(contextKey(contextKeyString)).(string) + actualValue := msg.Context().Value(contextKeyString) if !assert.Equal(t, expectedValue, actualValue) { ok = false } diff --git a/pubsub/tests/test_pubsub.go b/pubsub/tests/test_pubsub.go index 026fffa18..1a5566c01 100644 --- a/pubsub/tests/test_pubsub.go +++ b/pubsub/tests/test_pubsub.go @@ -1291,14 +1291,14 @@ func PublishSimpleMessages(t *testing.T, messagesCount int, publisher message.Pu } // PublishSimpleMessagesWithContext publishes provided number of simple messages without a payload, but custom context -func PublishSimpleMessagesWithContext(t *testing.T, messagesCount int, contextKey string, publisher message.Publisher, topicName string) message.Messages { +func PublishSimpleMessagesWithContext(t *testing.T, messagesCount int, contextKeyString string, publisher message.Publisher, topicName string) message.Messages { var messagesToPublish []*message.Message for i := 0; i < messagesCount; i++ { id := watermill.NewUUID() msg := message.NewMessage(id, nil) - msg.SetContext(context.WithValue(context.Background(), contextKey, "bar"+strconv.Itoa(i))) + msg.SetContext(context.WithValue(context.Background(), contextKey(contextKeyString), "bar"+strconv.Itoa(i))) messagesToPublish = append(messagesToPublish, msg) err := publishWithRetry(publisher, topicName, msg)