Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
yash.bansal committed Sep 9, 2024
1 parent 70d2e59 commit aeaedb3
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 20 deletions.
28 changes: 16 additions & 12 deletions pubsub/gochannel/pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func (g *GoChannel) sendMessage(topic string, message *message.Message) (<-chan

wg.Add(1)
go func() {
subscriber.sendMessageToSubscriber(message, logFields, g.config.PreserveContext)
subscriber.sendMessageToSubscriber(message, logFields)
wg.Done()
}()
}
Expand Down Expand Up @@ -196,11 +196,12 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag
subLock.(*sync.Mutex).Lock()

s := &subscriber{
ctx: ctx,
uuid: watermill.NewUUID(),
outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer),
logger: g.logger,
closing: make(chan struct{}),
ctx: ctx,
uuid: watermill.NewUUID(),
outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer),
logger: g.logger,
closing: make(chan struct{}),
preserveContext: g.config.PreserveContext,
}

go func(s *subscriber, g *GoChannel) {
Expand Down Expand Up @@ -246,7 +247,7 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag
msg := g.persistedMessages[topic][i]
logFields := watermill.LogFields{"message_uuid": msg.UUID, "topic": topic}

go s.sendMessageToSubscriber(msg, logFields, g.config.PreserveContext)
go s.sendMessageToSubscriber(msg, logFields)
}
}

Expand Down Expand Up @@ -329,6 +330,8 @@ type subscriber struct {
logger watermill.LoggerAdapter
closed bool
closing chan struct{}

preserveContext bool
}

func (s *subscriber) Close() {
Expand All @@ -349,19 +352,20 @@ func (s *subscriber) Close() {
close(s.outputChannel)
}

func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields watermill.LogFields, preserveContext bool) {
func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields watermill.LogFields) {
s.sending.Lock()
defer s.sending.Unlock()

var ctx context.Context
var cancelCtx context.CancelFunc

if preserveContext {
ctx, cancelCtx = context.WithCancel(msg.Context())
//This is getting the context from the message, not the subscriber
if s.preserveContext {
ctx = msg.Context()
} else {
var cancelCtx context.CancelFunc
ctx, cancelCtx = context.WithCancel(s.ctx)
defer cancelCtx()
}
defer cancelCtx()

SendToSubscriber:
for {
Expand Down
6 changes: 3 additions & 3 deletions pubsub/gochannel/pubsub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,16 @@ func TestPublishSubscribe_not_persistent_with_context(t *testing.T) {
msgs, err := pubSub.Subscribe(context.Background(), topicName)
require.NoError(t, err)

const contextKey = "foo"
sendMessages := tests.PublishSimpleMessagesWithContext(t, messagesCount, contextKey, pubSub, topicName)
const contextKeyString = "foo"
sendMessages := tests.PublishSimpleMessagesWithContext(t, messagesCount, contextKeyString, pubSub, topicName)
receivedMsgs, _ := subscriber.BulkRead(msgs, messagesCount, time.Second)

expectedContexts := make(map[string]context.Context)
for _, msg := range sendMessages {
expectedContexts[msg.UUID] = msg.Context()
}
tests.AssertAllMessagesReceived(t, sendMessages, receivedMsgs)
tests.AssertAllMessagesHaveSameContext(t, contextKey, expectedContexts, receivedMsgs)
tests.AssertAllMessagesHaveSameContext(t, contextKeyString, expectedContexts, receivedMsgs)

assert.NoError(t, pubSub.Close())
}
Expand Down
6 changes: 3 additions & 3 deletions pubsub/tests/test_asserts.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ func AssertMessagesMetadata(t *testing.T, key string, expectedValues map[string]
}

// AssertAllMessagesHaveSameContext checks if context of all received messages is the same as in expectedValues, if PreserveContext is enabled.
func AssertAllMessagesHaveSameContext(t *testing.T, contextKey string, expectedValues map[string]context.Context, received []*message.Message) bool {
func AssertAllMessagesHaveSameContext(t *testing.T, contextKeyString string, expectedValues map[string]context.Context, received []*message.Message) bool {
assert.Len(t, received, len(expectedValues))

ok := true
for _, msg := range received {
expectedValue := expectedValues[msg.UUID].Value(contextKey)
actualValue := msg.Context().Value(contextKey)
expectedValue := expectedValues[msg.UUID].Value(contextKey(contextKeyString)).(string)
actualValue := msg.Context().Value(contextKeyString)
if !assert.Equal(t, expectedValue, actualValue) {
ok = false
}
Expand Down
4 changes: 2 additions & 2 deletions pubsub/tests/test_pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -1291,14 +1291,14 @@ func PublishSimpleMessages(t *testing.T, messagesCount int, publisher message.Pu
}

// PublishSimpleMessagesWithContext publishes provided number of simple messages without a payload, but custom context
func PublishSimpleMessagesWithContext(t *testing.T, messagesCount int, contextKey string, publisher message.Publisher, topicName string) message.Messages {
func PublishSimpleMessagesWithContext(t *testing.T, messagesCount int, contextKeyString string, publisher message.Publisher, topicName string) message.Messages {
var messagesToPublish []*message.Message

for i := 0; i < messagesCount; i++ {
id := watermill.NewUUID()

msg := message.NewMessage(id, nil)
msg.SetContext(context.WithValue(context.Background(), contextKey, "bar"+strconv.Itoa(i)))
msg.SetContext(context.WithValue(context.Background(), contextKey(contextKeyString), "bar"+strconv.Itoa(i)))
messagesToPublish = append(messagesToPublish, msg)

err := publishWithRetry(publisher, topicName, msg)
Expand Down

0 comments on commit aeaedb3

Please sign in to comment.