Skip to content

Commit

Permalink
Add changing visibility timeout for each call (#242)
Browse files Browse the repository at this point in the history
* Add changing visibility timeout for each call
  • Loading branch information
danielle-tfh authored Dec 2, 2024
1 parent 3e23538 commit 0df1c8f
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 76 deletions.
34 changes: 12 additions & 22 deletions consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sqsclient

import (
"context"
"errors"
"fmt"
"sync"

Expand All @@ -12,11 +13,10 @@ import (
)

type Config struct {
QueueURL string
WorkersNum int
VisibilityTimeout int32
BatchSize int32
ExtendEnabled bool
QueueURL string
WorkersNum int
VisibilityTimeoutSeconds int32
BatchSize int32
}

type Consumer struct {
Expand All @@ -26,13 +26,17 @@ type Consumer struct {
cfg Config
}

func NewConsumer(awsCfg aws.Config, cfg Config, handler Handler) *Consumer {
func NewConsumer(awsCfg aws.Config, cfg Config, handler Handler) (*Consumer, error) {
if cfg.VisibilityTimeoutSeconds < 30 {
return nil, errors.New("VisibilityTimeoutSeconds must be greater or equal to 30")
}

return &Consumer{
sqs: sqs.NewFromConfig(awsCfg),
handler: handler,
wg: &sync.WaitGroup{},
cfg: cfg,
}
}, nil
}

func (c *Consumer) Consume(ctx context.Context) {
Expand All @@ -56,6 +60,7 @@ loop:
MaxNumberOfMessages: c.cfg.BatchSize,
WaitTimeSeconds: int32(5),
MessageAttributeNames: []string{"TraceID", "SpanID"},
VisibilityTimeout: c.cfg.VisibilityTimeoutSeconds,
})
if err != nil {
zap.S().With(zap.Error(err)).Error("could not receive messages from SQS")
Expand Down Expand Up @@ -83,9 +88,6 @@ func (c *Consumer) worker(ctx context.Context, messages <-chan *Message) {

func (c *Consumer) handleMsg(ctx context.Context, m *Message) error {
if c.handler != nil {
if c.cfg.ExtendEnabled {
c.extend(ctx, m)
}
if err := c.handler.Run(ctx, m); err != nil {
return m.ErrorResponse(err)
}
Expand All @@ -104,15 +106,3 @@ func (c *Consumer) delete(ctx context.Context, m *Message) error {
zap.S().Debug("message deleted")
return nil
}

func (c *Consumer) extend(ctx context.Context, m *Message) {
_, err := c.sqs.ChangeMessageVisibility(ctx, &sqs.ChangeMessageVisibilityInput{
QueueUrl: &c.cfg.QueueURL,
ReceiptHandle: m.ReceiptHandle,
VisibilityTimeout: c.cfg.VisibilityTimeout,
})
if err != nil {
zap.S().With(zap.Error(err)).Error("unable to extend message")
return
}
}
65 changes: 52 additions & 13 deletions consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
const (
awsRegion = "us-east-1"
localAwsEndpoint = "http://localhost:4566"
visibilityTimeout = 20
visibilityTimeout = 30
batchSize = 10
workersNum = 1
traceId = "traceid123"
Expand Down Expand Up @@ -63,13 +63,13 @@ func TestConsume(t *testing.T) {

msgHandler := handler(t, expectedMsg, expectedMsgAttributes)
config := Config{
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeout: visibilityTimeout,
BatchSize: batchSize,
ExtendEnabled: true,
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeoutSeconds: visibilityTimeout,
BatchSize: batchSize,
}
consumer := NewConsumer(awsCfg, config, msgHandler)
consumer, err := NewConsumer(awsCfg, config, msgHandler)
assert.NoError(t, err)
go consumer.Consume(ctx)

t.Cleanup(func() {
Expand Down Expand Up @@ -105,14 +105,14 @@ func TestConsume_GracefulShutdown(t *testing.T) {
queueUrl := createQueue(t, ctx, awsCfg, queueName)

config := Config{
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeout: visibilityTimeout,
BatchSize: batchSize,
ExtendEnabled: true,
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeoutSeconds: visibilityTimeout,
BatchSize: batchSize,
}
msgHandler := MsgHandler{}
consumer := NewConsumer(awsCfg, config, &msgHandler)
consumer, err := NewConsumer(awsCfg, config, &msgHandler)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(2)
go func() {
Expand Down Expand Up @@ -147,6 +147,45 @@ func TestConsume_GracefulShutdown(t *testing.T) {
}, time.Second*2, time.Millisecond*100)
}

func TestConsume_ErrorsIfConfigIssues(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), time.Second*10)
awsCfg := loadAWSDefaultConfig(ctx)

queueName := strings.ToLower(t.Name())
queueUrl := createQueue(t, ctx, awsCfg, queueName)

msgHandler := MsgHandlerWithIdleTrigger{
t: t,
msgsReceivedCount: 0,
}
tests := []struct {
name string
visibilityTimeoutSeconds int32
}{
{
name: "VisibilityTimeoutSeconds is less than 30",
visibilityTimeoutSeconds: int32(29),
},
{
name: "VisibilityTimeoutSeconds is less than 0",
visibilityTimeoutSeconds: int32(-1),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := Config{
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeoutSeconds: tt.visibilityTimeoutSeconds,
BatchSize: batchSize,
}
consumer, err := NewConsumer(awsCfg, config, &msgHandler)
assert.Error(t, err)
assert.Nil(t, consumer)
})
}
}

func createQueue(t *testing.T, ctx context.Context, awsCfg aws.Config, queueName string) *string {
sqsSvc := sqs.NewFromConfig(awsCfg)

Expand Down
24 changes: 7 additions & 17 deletions consumer_with_idle_trigger.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sqsclient

import (
"context"
"errors"
"fmt"
"sync"
"time"
Expand All @@ -21,15 +22,18 @@ type ConsumerWithIdleTrigger struct {
sqsReceiveWaitTimeSeconds int32
}

func NewConsumerWithIdleTrigger(awsCfg aws.Config, cfg Config, handler HandlerWithIdleTrigger, idleDurationTimeout time.Duration, sqsReceiveWaitTimeSeconds int32) *ConsumerWithIdleTrigger {
func NewConsumerWithIdleTrigger(awsCfg aws.Config, cfg Config, handler HandlerWithIdleTrigger, idleDurationTimeout time.Duration, sqsReceiveWaitTimeSeconds int32) (*ConsumerWithIdleTrigger, error) {
if cfg.VisibilityTimeoutSeconds < 30 {
return nil, errors.New("VisibilityTimeoutSeconds must be greater or equal to 30")
}
return &ConsumerWithIdleTrigger{
sqs: sqs.NewFromConfig(awsCfg),
handler: handler,
wg: &sync.WaitGroup{},
cfg: cfg,
idleDurationTimeout: idleDurationTimeout,
sqsReceiveWaitTimeSeconds: sqsReceiveWaitTimeSeconds,
}
}, nil
}

func (c *ConsumerWithIdleTrigger) Consume(ctx context.Context) {
Expand Down Expand Up @@ -61,6 +65,7 @@ loop:
MaxNumberOfMessages: c.cfg.BatchSize,
WaitTimeSeconds: c.sqsReceiveWaitTimeSeconds,
MessageAttributeNames: []string{"TraceID", "SpanID"},
VisibilityTimeout: c.cfg.VisibilityTimeoutSeconds,
})
if err != nil {
zap.S().With(zap.Error(err)).Error("could not receive messages from SQS")
Expand Down Expand Up @@ -103,9 +108,6 @@ func (c *ConsumerWithIdleTrigger) handleMsg(ctx context.Context, m *Message) err
return m.ErrorResponse(err)
}
} else {
if c.cfg.ExtendEnabled {
c.extend(ctx, m)
}
if err := c.handler.Run(ctx, m); err != nil {
return m.ErrorResponse(err)
}
Expand All @@ -125,15 +127,3 @@ func (c *ConsumerWithIdleTrigger) delete(ctx context.Context, m *Message) error
zap.S().Debug("message deleted")
return nil
}

func (c *ConsumerWithIdleTrigger) extend(ctx context.Context, m *Message) {
_, err := c.sqs.ChangeMessageVisibility(ctx, &sqs.ChangeMessageVisibilityInput{
QueueUrl: &c.cfg.QueueURL,
ReceiptHandle: m.ReceiptHandle,
VisibilityTimeout: c.cfg.VisibilityTimeout,
})
if err != nil {
zap.S().With(zap.Error(err)).Error("unable to extend message")
return
}
}
88 changes: 64 additions & 24 deletions consumer_with_idle_trigger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ func TestConsumeWithIdleTrigger(t *testing.T) {

msgHandler := handlerWithIdleTrigger(t, expectedMsg, expectedMsgAttributes)
config := Config{
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeout: visibilityTimeout,
BatchSize: batchSize,
ExtendEnabled: true,
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeoutSeconds: visibilityTimeout,
BatchSize: batchSize,
}
consumer := NewConsumerWithIdleTrigger(awsCfg, config, msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds)
consumer, err := NewConsumerWithIdleTrigger(awsCfg, config, msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds)
assert.NoError(t, err)
go consumer.Consume(ctx)

t.Cleanup(func() {
Expand Down Expand Up @@ -95,17 +95,17 @@ func TestConsumeWithIdleTimeout_GracefulShutdown(t *testing.T) {
queueUrl := createQueue(t, ctx, awsCfg, queueName)

config := Config{
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeout: visibilityTimeout,
BatchSize: batchSize,
ExtendEnabled: true,
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeoutSeconds: visibilityTimeout,
BatchSize: batchSize,
}
msgHandler := MsgHandlerWithIdleTrigger{
t: t,
msgsReceivedCount: 0,
}
consumer := NewConsumerWithIdleTrigger(awsCfg, config, &msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds)
consumer, err := NewConsumerWithIdleTrigger(awsCfg, config, &msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(2)

Expand Down Expand Up @@ -146,17 +146,17 @@ func TestConsumeWithIdleTimeout_TimesOut(t *testing.T) {
queueUrl := createQueue(t, ctx, awsCfg, queueName)

config := Config{
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeout: visibilityTimeout,
BatchSize: batchSize,
ExtendEnabled: true,
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeoutSeconds: visibilityTimeout,
BatchSize: batchSize,
}
msgHandler := MsgHandlerWithIdleTrigger{
t: t,
msgsReceivedCount: 0,
}
consumer := NewConsumerWithIdleTrigger(awsCfg, config, &msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds)
consumer, err := NewConsumerWithIdleTrigger(awsCfg, config, &msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds)
assert.NoError(t, err)
go consumer.Consume(ctx)

t.Cleanup(func() {
Expand All @@ -169,6 +169,46 @@ func TestConsumeWithIdleTimeout_TimesOut(t *testing.T) {
// ensure that it gets called multiple times
assert.GreaterOrEqual(t, msgHandler.idleTimeoutTriggeredCount, 2)
}

func TestConsumeWithIdleTimeout_ErrorsIfConfigIssues(t *testing.T) {
ctx, _ := context.WithTimeout(context.Background(), time.Second*10)
awsCfg := loadAWSDefaultConfig(ctx)

queueName := strings.ToLower(t.Name())
queueUrl := createQueue(t, ctx, awsCfg, queueName)

msgHandler := MsgHandlerWithIdleTrigger{
t: t,
msgsReceivedCount: 0,
}
tests := []struct {
name string
visibilityTimeoutSeconds int32
}{
{
name: "VisibilityTimeoutSeconds is less than 30",
visibilityTimeoutSeconds: int32(29),
},
{
name: "VisibilityTimeoutSeconds is less than 0",
visibilityTimeoutSeconds: int32(-1),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := Config{
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeoutSeconds: tt.visibilityTimeoutSeconds,
BatchSize: batchSize,
}
consumer, err := NewConsumerWithIdleTrigger(awsCfg, config, &msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds)
t.Logf("error: %v", err)
assert.Error(t, err)
assert.Nil(t, consumer)
})
}
}
func TestConsumeWithIdleTimeout_TimesOutAndConsumes(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
awsCfg := loadAWSDefaultConfig(ctx)
Expand All @@ -189,14 +229,14 @@ func TestConsumeWithIdleTimeout_TimesOutAndConsumes(t *testing.T) {
}

config := Config{
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeout: visibilityTimeout,
BatchSize: batchSize,
ExtendEnabled: true,
QueueURL: *queueUrl,
WorkersNum: workersNum,
VisibilityTimeoutSeconds: visibilityTimeout,
BatchSize: batchSize,
}
msgHandler := handlerWithIdleTrigger(t, expectedMsg, expectedMsgAttributes)
consumer := NewConsumerWithIdleTrigger(awsCfg, config, msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds)
consumer, err := NewConsumerWithIdleTrigger(awsCfg, config, msgHandler, IdleTimeout, SqsReceiveWaitTimeSeconds)
assert.NoError(t, err)
go consumer.Consume(ctx)

t.Cleanup(func() {
Expand Down

0 comments on commit 0df1c8f

Please sign in to comment.