Skip to content

Commit

Permalink
fix: race conditions in publisher
Browse files Browse the repository at this point in the history
Signed-off-by: Dusan Malusev <dusan@dusanmalusev.dev>
  • Loading branch information
CodeLieutenant committed Feb 28, 2024
1 parent 89b949a commit e347d86
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 84 deletions.
40 changes: 0 additions & 40 deletions Makefile

This file was deleted.

29 changes: 10 additions & 19 deletions publisher/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ type (
Publisher[T any] struct {
serializer serializer.Serializer[T]
conn *connection.Connection
ch *amqp091.Channel
ch atomic.Pointer[amqp091.Channel]
cancel context.CancelFunc
exchangeName string
routingKey string
wg sync.WaitGroup
ready sync.RWMutex
closing atomic.Bool
gettingCh atomic.Bool
}
Expand Down Expand Up @@ -69,16 +68,13 @@ func (e ExchangeDeclare) declare(ch *amqp091.Channel, logger logging.Logger) err

func (p *Publisher[T]) swapChannel(connection *amqp091.Connection, cfg Config[T]) (chan *amqp091.Error, error) {
p.gettingCh.Store(true)
p.ready.Lock()
defer p.ready.Unlock()

chOrigin, notifyClose, err := newChannel(connection, cfg.exchange, cfg.logger)

if err != nil {
return nil, err
}

p.ch = chOrigin
p.ch.Store(chOrigin)
p.gettingCh.Store(false)
return notifyClose, nil
}
Expand All @@ -101,10 +97,9 @@ func (p *Publisher[T]) onConnectionReady(cfg Config[T]) connection.OnConnectionR
select {
case <-ctx.Done():
p.closing.Store(true)
p.ready.Lock()
defer p.ready.Unlock()
if !p.ch.IsClosed() {
if err := p.ch.Close(); err != nil {
ch := p.ch.Load()
if !ch.IsClosed() {
if err := ch.Close(); err != nil {
cfg.logger.Error("Failed to close channel: %v", err)
}
}
Expand Down Expand Up @@ -185,7 +180,7 @@ func New[T any](exchangeName string, options ...Option[T]) (*Publisher[T], error
panic(err)
}

fmt.Fprintf(os.Stderr, "[ERROR]: An error has occurred! %v\n", err)
_, _ = fmt.Fprintf(os.Stderr, "[ERROR]: An error has occurred! %v\n", err)
},
}

Expand Down Expand Up @@ -231,16 +226,16 @@ func (p *Publisher[T]) Publish(ctx context.Context, msg T, config ...PublishConf
return ErrChannelNotReady
}

p.ready.RLock()

defer p.ready.RUnlock()
//p.ready.RLock()
//
//defer p.ready.RUnlock()

body, err := p.serializer.Marshal(msg)
if err != nil {
return err
}

return p.ch.PublishWithContext(
return (*p.ch.Load()).PublishWithContext(
ctx,
p.exchangeName,
p.routingKey,
Expand All @@ -257,11 +252,7 @@ func (p *Publisher[T]) Publish(ctx context.Context, msg T, config ...PublishConf

func (p *Publisher[T]) Close() error {
p.closing.Store(true)
p.ready.Lock()
defer p.ready.Unlock()

p.cancel()
p.wg.Wait()
p.ch = nil
return p.conn.Close()
}
50 changes: 25 additions & 25 deletions publisher/publisher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,39 +91,39 @@ func TestPublisherPublish(t *testing.T) {

t.Run("Basic", func(t *testing.T) {
mappings := amqp_testing.NewMappings(t).
AddMapping("test_exchange", "test_queue")
AddMapping("test_exchange_basic", "test_queue_basic")

pub, err := publisher.New[Msg](mappings.Exchange("test_exchange"))
pub, err := publisher.New[Msg](mappings.Exchange("test_exchange_basic"))
assert.NoError(err)
assert.NotNil(pub)

assert.NoError(pub.Publish(context.Background(), Msg{Name: "test"}))
assert.NoError(pub.Close())

messages := amqp_testing.ConsumeAMQPMessages[Msg](t, mappings.Queue("test_queue"), connection.DefaultConfig, 200*time.Millisecond)
messages := amqp_testing.ConsumeAMQPMessages[Msg](t, mappings.Queue("test_queue_basic"), connection.DefaultConfig, 200*time.Millisecond)

assert.Len(messages, 1)
assert.Equal("test", messages[0].Name)
})

t.Run("WithSerializer", func(t *testing.T) {
mappings := amqp_testing.NewMappings(t).
AddMapping("test_exchange", "test_queue")
AddMapping("test_exchange_serializer", "test_queue_serializer")

serializer := &MockSerializer{}
mockSerializer := &MockSerializer{}

pub, err := publisher.New[Msg](
mappings.Exchange("test_exchange"),
publisher.WithSerializer[Msg](serializer),
mappings.Exchange("test_exchange_serializer"),
publisher.WithSerializer[Msg](mockSerializer),
)
assert.NoError(err)
assert.NotNil(pub)

serializer.On("Marshal", Msg{Name: "test"}).
mockSerializer.On("Marshal", Msg{Name: "test"}).
Once().
Return([]byte("\"test\""), nil)

serializer.On("GetContentType").
mockSerializer.On("GetContentType").
Once().
Return("application/json")

Expand All @@ -132,32 +132,32 @@ func TestPublisherPublish(t *testing.T) {

messages := amqp_testing.ConsumeAMQPMessages[string](
t,
mappings.Queue("test_queue"),
mappings.Queue("test_queue_serializer"),
connection.DefaultConfig,
200*time.Millisecond,
)

assert.Len(messages, 1)
assert.Equal("test", messages[0])
serializer.AssertExpectations(t)
mockSerializer.AssertExpectations(t)
})

t.Run("WithSerializerFails", func(t *testing.T) {
mappings := amqp_testing.NewMappings(t).
AddMapping("test_exchange", "test_queue")
AddMapping("test_exchange_serializer_fails", "test_queue_serializer_fails")

serializer := &MockSerializer{}
mockSerializer := &MockSerializer{}

pub, err := publisher.New[Msg](
mappings.Exchange("test_exchange"),
publisher.WithSerializer[Msg](serializer),
mappings.Exchange("test_exchange_serializer_fails"),
publisher.WithSerializer[Msg](mockSerializer),
)
assert.NoError(err)
assert.NotNil(pub)

expectedErr := errors.New("failed to serialize")

serializer.On("Marshal", Msg{Name: "test"}).
mockSerializer.On("Marshal", Msg{Name: "test"}).
Once().
Return([]byte{}, expectedErr)

Expand All @@ -166,14 +166,14 @@ func TestPublisherPublish(t *testing.T) {

messages := amqp_testing.ConsumeAMQPMessages[string](
t,
mappings.Queue("test_queue"),
mappings.Queue("test_queue_serializer_fails"),
connection.DefaultConfig,
200*time.Millisecond,
)

assert.Len(messages, 0)
serializer.AssertNotCalled(t, "GetContentType")
serializer.AssertExpectations(t)
mockSerializer.AssertNotCalled(t, "GetContentType")
mockSerializer.AssertExpectations(t)
})
}

Expand All @@ -183,9 +183,9 @@ func TestPublisherClose(t *testing.T) {

t.Run("Basic", func(t *testing.T) {
mappings := amqp_testing.NewMappings(t).
AddMapping("test_exchange", "test_queue")
AddMapping("test_exchange_close", "test_queue_close")

pub, err := publisher.New[Msg](mappings.Exchange("test_exchange"))
pub, err := publisher.New[Msg](mappings.Exchange("test_exchange_close"))
assert.NoError(err)
assert.NotNil(pub)

Expand All @@ -194,9 +194,9 @@ func TestPublisherClose(t *testing.T) {

t.Run("Call_To_Publish_After_Close", func(t *testing.T) {
mappings := amqp_testing.NewMappings(t).
AddMapping("test_exchange", "test_queue")
AddMapping("test_exchange_after_close", "test_queue_after_close")

pub, err := publisher.New[Msg](mappings.Exchange("test_exchange"))
pub, err := publisher.New[Msg](mappings.Exchange("test_exchange_after_close"))
assert.NoError(err)
assert.NotNil(pub)

Expand All @@ -206,9 +206,9 @@ func TestPublisherClose(t *testing.T) {

t.Run("Multiple_Close_Calling", func(t *testing.T) {
mappings := amqp_testing.NewMappings(t).
AddMapping("test_exchange", "test_queue")
AddMapping("test_exchange_multiple_close_call", "test_queue_multiple_close_call")

pub, err := publisher.New[Msg](mappings.Exchange("test_exchange"))
pub, err := publisher.New[Msg](mappings.Exchange("test_exchange_multiple_close_call"))
assert.NoError(err)
assert.NotNil(pub)

Expand Down

0 comments on commit e347d86

Please sign in to comment.