diff --git a/CHANGELOG.md b/CHANGELOG.md
index cecedb6f5f..5316fa96f4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ All notable changes to this project will be documented in this file.
 
 ### Added
 
+- `aws_sqs` now has a `max_outstanding_messages` field to prevent unbounded memory usage. (@rockwotj)
 - `avro` scanner now emits metadata for the Avro schema it used along with the schema fingerprint. (@rockwotj)
 - Field `content_type` added to the `amqp_1` output. (@timo102)
 - `kafka_franz`, `ockam_kafka`, `redpanda`, `redpanda_common`, `redpanda_migrator` now support `fetch_max_wait` configuration field.
@@ -18,6 +19,7 @@ All notable changes to this project will be documented in this file.
 
 - The `code` and `file` fields on the `javascript` processor docs no longer erroneously mention interpolation support. (@mihaitodor)
 - The `postgres_cdc` now correctly handles `null` values. (@rockwotj)
+- Fix an issue in `aws_sqs` with refreshing in-flight message leases which could prevent acks from being processed. (@rockwotj)
 
 ## 4.44.0 - 2024-12-13
 
diff --git a/docs/modules/components/pages/inputs/aws_sqs.adoc b/docs/modules/components/pages/inputs/aws_sqs.adoc
index c0f34f44af..f61e9f051a 100644
--- a/docs/modules/components/pages/inputs/aws_sqs.adoc
+++ b/docs/modules/components/pages/inputs/aws_sqs.adoc
@@ -38,6 +38,7 @@ input:
   label: ""
   aws_sqs:
     url: "" # No default (required)
+    max_outstanding_messages: 1000
 ```
 --
 
@@ -54,7 +55,9 @@ input:
     delete_message: true
     reset_visibility: true
     max_number_of_messages: 10
+    max_outstanding_messages: 1000
     wait_time_seconds: 0
+    message_timeout: 30s
     region: ""
     endpoint: ""
     credentials:
@@ -127,6 +130,15 @@ The maximum number of messages to return on one poll. Valid values: 1 to 10.
 
 *Default*: `10`
 
+=== `max_outstanding_messages`
+
+The maximum number of outstanding messages to be consumed at a given time.
+
+
+*Type*: `int`
+
+*Default*: `1000`
+
 === `wait_time_seconds`
 
 Whether to set the wait time. Enabling this activates long-polling. Valid values: 0 to 20.
@@ -136,6 +148,15 @@ Whether to set the wait time. Enabling this activates long-polling. Valid values
 
 *Default*: `0`
 
+=== `message_timeout`
+
+The time to process a message before its receipt handle needs to be refreshed. Messages become eligible for refresh once half of the timeout has elapsed. This sets the visibility timeout for each received message.
+
+
+*Type*: `string`
+
+*Default*: `"30s"`
+
 === `region`
 
 The AWS region to target.
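Aside: the `message_timeout` semantics documented above (a message becomes refresh-eligible once half of its timeout has elapsed) can be illustrated in isolation. This is a minimal sketch, not part of the diff; `eligibleForRefresh` is a hypothetical helper that mirrors the deadline check `PullToRefresh` performs in `input_sqs.go` below.

```go
package main

import (
	"fmt"
	"time"
)

// eligibleForRefresh mirrors the documented rule: a message's receipt handle
// becomes a refresh candidate once half of message_timeout has elapsed, i.e.
// once less than half of the timeout remains before its visibility deadline.
func eligibleForRefresh(deadline time.Time, timeout time.Duration, now time.Time) bool {
	return deadline.Sub(now) <= timeout/2
}

func main() {
	timeout := 30 * time.Second // the default message_timeout
	received := time.Now()
	deadline := received.Add(timeout)

	fmt.Println(eligibleForRefresh(deadline, timeout, received))                      // false: just received
	fmt.Println(eligibleForRefresh(deadline, timeout, received.Add(16*time.Second))) // true: past the halfway point
}
```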
diff --git a/internal/impl/aws/input_sqs.go b/internal/impl/aws/input_sqs.go
index d02ca9ab4d..aa5e486725 100644
--- a/internal/impl/aws/input_sqs.go
+++ b/internal/impl/aws/input_sqs.go
@@ -15,8 +15,12 @@ package aws
 
 import (
+	"container/list"
 	"context"
-	"strconv"
+	"errors"
+	"fmt"
+	"slices"
+	"strings"
 	"sync"
 	"time"
 
@@ -39,8 +43,8 @@ const (
 	sqsiFieldDeleteMessage       = "delete_message"
 	sqsiFieldResetVisibility     = "reset_visibility"
 	sqsiFieldMaxNumberOfMessages = "max_number_of_messages"
-
-	sqsiAttributeNameVisibilityTimeout = "VisibilityTimeout"
+	sqsiFieldMaxOutstanding      = "max_outstanding_messages"
+	sqsiFieldMessageTimeout      = "message_timeout"
 )
 
 type sqsiConfig struct {
@@ -49,6 +53,8 @@ type sqsiConfig struct {
 	DeleteMessage       bool
 	ResetVisibility     bool
 	MaxNumberOfMessages int
+	MaxOutstanding      int
+	MessageTimeout      time.Duration
 }
 
 func sqsiConfigFromParsed(pConf *service.ParsedConfig) (conf sqsiConfig, err error) {
@@ -67,6 +73,12 @@ func sqsiConfigFromParsed(pConf *service.ParsedConfig) (conf sqsiConfig, err err
 	if conf.MaxNumberOfMessages, err = pConf.FieldInt(sqsiFieldMaxNumberOfMessages); err != nil {
 		return
 	}
+	if conf.MaxOutstanding, err = pConf.FieldInt(sqsiFieldMaxOutstanding); err != nil {
+		return
+	}
+	if conf.MessageTimeout, err = pConf.FieldDuration(sqsiFieldMessageTimeout); err != nil {
+		return
+	}
 	return
 }
 
@@ -110,10 +122,17 @@ xref:configuration:interpolation.adoc#bloblang-queries[function interpolation].`
 			Description("The maximum number of messages to return on one poll. Valid values: 1 to 10.").
 			Default(10).
 			Advanced(),
-		service.NewIntField("wait_time_seconds").
+		service.NewIntField(sqsiFieldMaxOutstanding).
+			Description("The maximum number of outstanding messages to be consumed at a given time.").
+			Default(1000),
+		service.NewIntField(sqsiFieldWaitTimeSeconds).
 			Description("Whether to set the wait time. Enabling this activates long-polling. Valid values: 0 to 20.").
 			Default(0).
 			Advanced(),
+		service.NewDurationField(sqsiFieldMessageTimeout).
+			Description("The time to process a message before its receipt handle needs to be refreshed. Messages become eligible for refresh once half of the timeout has elapsed. This sets the visibility timeout for each received message.").
+			Default("30s").
+			Advanced(),
 	).
 	Fields(config.SessionFields()...)
 }
 
@@ -153,9 +172,9 @@ type awsSQSReader struct {
 	aconf aws.Config
 	sqs   sqsAPI
 
-	messagesChan     chan types.Message
-	ackMessagesChan  chan sqsMessageHandle
-	nackMessagesChan chan sqsMessageHandle
+	messagesChan     chan sqsMessage
+	ackMessagesChan  chan *sqsMessageHandle
+	nackMessagesChan chan *sqsMessageHandle
 
 	closeSignal *shutdown.Signaller
 	log         *service.Logger
@@ -166,9 +185,9 @@ func newAWSSQSReader(conf sqsiConfig, aconf aws.Config, log *service.Logger) (*a
 		conf:             conf,
 		aconf:            aconf,
 		log:              log,
-		messagesChan:     make(chan types.Message),
-		ackMessagesChan:  make(chan sqsMessageHandle),
-		nackMessagesChan: make(chan sqsMessageHandle),
+		messagesChan:     make(chan sqsMessage),
+		ackMessagesChan:  make(chan *sqsMessageHandle),
+		nackMessagesChan: make(chan *sqsMessageHandle),
 		closeSignal:      shutdown.NewSignaller(),
 	}, nil
 }
@@ -181,13 +200,18 @@ func (a *awsSQSReader) Connect(ctx context.Context) error {
 	}
 
 	ift := &sqsInFlightTracker{
-		handles: map[string]sqsInFlightHandle{},
+		handles: map[string]*list.Element{},
+		fifo:    list.New(),
+		limit:   a.conf.MaxOutstanding,
+		timeout: a.conf.MessageTimeout,
 	}
+	ift.l = sync.NewCond(&ift.m)
 
 	var wg sync.WaitGroup
-	wg.Add(2)
+	wg.Add(3)
 	go a.readLoop(&wg, ift)
 	go a.ackLoop(&wg, ift)
+	go a.refreshLoop(&wg, ift)
 	go func() {
 		wg.Wait()
 		a.closeSignal.TriggerHasStopped()
@@ -195,135 +219,160 @@ func (a *awsSQSReader) Connect(ctx context.Context) error {
 	return nil
 }
 
-type sqsInFlightHandle struct {
-	receiptHandle  string
-	timeoutSeconds int
-	addedAt        time.Time
-}
-
 type sqsInFlightTracker struct {
-	handles map[string]sqsInFlightHandle
+	handles map[string]*list.Element
+	fifo    *list.List // contains *sqsMessageHandle
+	limit   int
+	timeout time.Duration
 	m       sync.Mutex
+	l       *sync.Cond
 }
 
-func (t *sqsInFlightTracker) PullToRefresh() (handles []sqsMessageHandle, timeoutSeconds int) {
+func (t *sqsInFlightTracker) PullToRefresh(limit int) []*sqsMessageHandle {
 	t.m.Lock()
 	defer t.m.Unlock()
 
-	handles = make([]sqsMessageHandle, 0, len(t.handles))
-	for k, v := range t.handles {
-		if time.Since(v.addedAt) < time.Second {
-			continue
-		}
-		handles = append(handles, sqsMessageHandle{
-			id:            k,
-			receiptHandle: v.receiptHandle,
-		})
-		if v.timeoutSeconds > timeoutSeconds {
-			timeoutSeconds = v.timeoutSeconds
+	handles := make([]*sqsMessageHandle, 0, limit)
+	now := time.Now()
+	// Pull from the front of our fifo until we reach our limit or we reach
+	// elements that do not need to be refreshed.
+	for e := t.fifo.Front(); e != nil && len(handles) < limit; e = t.fifo.Front() {
+		v := e.Value.(*sqsMessageHandle)
+		if v.deadline.Sub(now) > (t.timeout / 2) {
+			break
 		}
+		handles = append(handles, v)
+		v.deadline = now.Add(t.timeout)
+		// Keep our fifo in deadline-sorted order.
+		t.fifo.MoveToBack(e)
 	}
-	return
+	return handles
+}
+
+func (t *sqsInFlightTracker) Size() int {
+	t.m.Lock()
+	defer t.m.Unlock()
+	return len(t.handles)
 }
 
 func (t *sqsInFlightTracker) Remove(id string) {
 	t.m.Lock()
 	defer t.m.Unlock()
-	delete(t.handles, id)
+	entry, ok := t.handles[id]
+	if ok {
+		t.fifo.Remove(entry)
+		delete(t.handles, id)
+	}
+	t.l.Signal()
 }
 
-func (t *sqsInFlightTracker) AddNew(messages ...types.Message) {
+func (t *sqsInFlightTracker) IsTracking(id string) bool {
 	t.m.Lock()
 	defer t.m.Unlock()
+	_, ok := t.handles[id]
+	return ok
+}
 
-	for _, m := range messages {
-		if m.MessageId == nil || m.ReceiptHandle == nil {
-			continue
-		}
+
+func (t *sqsInFlightTracker) Clear() {
+	t.m.Lock()
+	defer t.m.Unlock()
+	clear(t.handles)
+	t.fifo = list.New()
+	t.l.Signal()
+}
 
-		handle := sqsInFlightHandle{
-			timeoutSeconds: 30,
-			receiptHandle:  *m.ReceiptHandle,
-			addedAt:        time.Now(),
-		}
-		if timeoutStr, exists := m.Attributes[sqsiAttributeNameVisibilityTimeout]; exists {
-			// Might as well keep the queue timeout setting refreshed as we
-			// consume new data.
-			if tmpTimeoutSeconds, err := strconv.Atoi(timeoutStr); err == nil {
-				handle.timeoutSeconds = tmpTimeoutSeconds
-			}
+func (t *sqsInFlightTracker) AddNew(ctx context.Context, messages ...sqsMessage) {
+	t.m.Lock()
+	defer t.m.Unlock()
+
+	// Treat this as a soft limit: we can burst over it, but we should always be able to make progress.
+	for len(t.handles) >= t.limit {
+		if ctx.Err() != nil {
+			return
 		}
-		t.handles[*m.MessageId] = handle
+		t.l.Wait()
 	}
-}
 
-func flushMapToHandles(m map[string]string) (s []sqsMessageHandle) {
-	s = make([]sqsMessageHandle, 0, len(m))
-	for k, v := range m {
-		s = append(s, sqsMessageHandle{id: k, receiptHandle: v})
-		delete(m, k)
+	for _, m := range messages {
+		if m.handle == nil {
+			continue
+		}
+		// If this is a duplicate (a re-receive of an in-flight message due to timeout)
+		// we can just update the existing handle.
+		if e, ok := t.handles[m.handle.id]; ok {
+			e.Value = m.handle
+			t.fifo.MoveToBack(e)
+		} else {
+			e := t.fifo.PushBack(m.handle)
+			t.handles[m.handle.id] = e
+		}
 	}
-	return
 }
 
 func (a *awsSQSReader) ackLoop(wg *sync.WaitGroup, inFlightTracker *sqsInFlightTracker) {
 	defer wg.Done()
+	defer inFlightTracker.Clear()
 
 	closeNowCtx, done := a.closeSignal.HardStopCtx(context.Background())
 	defer done()
 
-	flushFinishedHandles := func(m map[string]string, erase bool) {
-		handles := flushMapToHandles(m)
+	flushFinishedHandles := func(handles []*sqsMessageHandle, erase bool) {
 		if len(handles) == 0 {
 			return
 		}
+		seen := make(map[string]bool, len(handles))
+		// Deduplicate handles; duplicates are unlikely, so this is purely defensive.
+		handles = slices.DeleteFunc(handles, func(h *sqsMessageHandle) bool {
+			if seen[h.id] {
+				return true
+			}
+			seen[h.id] = true
+			return false
+		})
 		if erase {
 			if err := a.deleteMessages(closeNowCtx, handles...); err != nil {
 				a.log.Errorf("Failed to delete messages: %v", err)
 			}
 		} else {
 			if err := a.resetMessages(closeNowCtx, handles...); err != nil {
-				a.log.Errorf("Failed to reset the visibility timeout of messages: %v", err)
+				// Downgrade this to Info level - it's not really an error, it just means resetting
+				// the visibility takes longer, so the messages might be delayed. Delays are possible
+				// even if this succeeds, as it might be racing with the refresh loop. Fixing that
+				// would mean moving nacks to the refresh loop, but I don't think this will be a big
+				// deal in practice.
+ a.log.Infof("Failed to reset the visibility timeout of messages: %v", err) } } } - refreshCurrentHandles := func() { - currentHandles, timeoutSeconds := inFlightTracker.PullToRefresh() - if len(currentHandles) == 0 { - return - } - if err := a.updateVisibilityMessages(closeNowCtx, timeoutSeconds, currentHandles...); err != nil { - a.log.Debugf("Failed to update messages visibility timeout: %v", err) - } - } - flushTimer := time.NewTicker(time.Second) defer flushTimer.Stop() - // Both maps are of the message ID to the receipt handle - pendingAcks := map[string]string{} - pendingNacks := map[string]string{} + pendingAcks := []*sqsMessageHandle{} + pendingNacks := []*sqsMessageHandle{} ackLoop: for { select { case h := <-a.ackMessagesChan: - pendingAcks[h.id] = h.receiptHandle + pendingAcks = append(pendingAcks, h) inFlightTracker.Remove(h.id) if len(pendingAcks) >= a.conf.MaxNumberOfMessages { flushFinishedHandles(pendingAcks, true) + pendingAcks = pendingAcks[:0] } case h := <-a.nackMessagesChan: - pendingNacks[h.id] = h.receiptHandle + pendingNacks = append(pendingNacks, h) inFlightTracker.Remove(h.id) if len(pendingNacks) >= a.conf.MaxNumberOfMessages { flushFinishedHandles(pendingNacks, false) + pendingNacks = pendingNacks[:0] } case <-flushTimer.C: flushFinishedHandles(pendingAcks, true) + pendingAcks = pendingAcks[:0] flushFinishedHandles(pendingNacks, false) - refreshCurrentHandles() + pendingNacks = pendingNacks[:0] case <-a.closeSignal.SoftStopChan(): break ackLoop } @@ -333,21 +382,65 @@ ackLoop: flushFinishedHandles(pendingNacks, false) } +func (a *awsSQSReader) refreshLoop(wg *sync.WaitGroup, inFlightTracker *sqsInFlightTracker) { + defer wg.Done() + closeNowCtx, done := a.closeSignal.HardStopCtx(context.Background()) + defer done() + refreshCurrentHandles := func() { + for !a.closeSignal.IsSoftStopSignalled() { + // updateVisibilityMessages can only make an API request with 10 messages at most, so grab 10 then refresh to prevent + // an issue where we grab a ton of messages and they are acked before we actual make the API call. Note that this scenario + // can still happen because we refresh async with acking, but this makes it a lot less likely. + currentHandles := inFlightTracker.PullToRefresh(10) + if len(currentHandles) == 0 { + // There is nothing to refresh, return and sleep for a second + return + } + err := a.updateVisibilityMessages(closeNowCtx, int(a.conf.MessageTimeout.Seconds()), currentHandles...) 
+			if err == nil {
+				continue
+			}
+			partialErr := &batchUpdateVisibilityError{}
+			if errors.As(err, &partialErr) {
+				for _, fail := range partialErr.entries {
+					// Mitigate erroneous log statements due to the race described above by making sure
+					// we're still tracking the message.
+					if !inFlightTracker.IsTracking(*fail.Id) {
+						continue
+					}
+					msg := "(no message)"
+					if fail.Message != nil {
+						msg = *fail.Message
+					}
+					a.log.Debugf("Failed to update SQS message '%v', response code: %v, message: %q, sender fault: %v", *fail.Id, *fail.Code, msg, fail.SenderFault)
+				}
+			} else {
+				a.log.Debugf("Failed to update messages visibility timeout: %v", err)
+			}
+		}
+	}
+
+	for {
+		select {
+		case <-time.After(time.Second):
+			refreshCurrentHandles()
+		case <-a.closeSignal.SoftStopChan():
+			return
+		}
+	}
+}
+
 func (a *awsSQSReader) readLoop(wg *sync.WaitGroup, inFlightTracker *sqsInFlightTracker) {
 	defer wg.Done()
 
-	var pendingMsgs []types.Message
+	var pendingMsgs []sqsMessage
 	defer func() {
 		if len(pendingMsgs) > 0 {
-			tmpNacks := make([]sqsMessageHandle, 0, len(pendingMsgs))
+			tmpNacks := make([]*sqsMessageHandle, 0, len(pendingMsgs))
 			for _, m := range pendingMsgs {
-				if m.MessageId == nil || m.ReceiptHandle == nil {
+				if m.handle == nil {
 					continue
 				}
-				tmpNacks = append(tmpNacks, sqsMessageHandle{
-					id:            *m.MessageId,
-					receiptHandle: *m.ReceiptHandle,
-				})
+				tmpNacks = append(tmpNacks, m.handle)
 			}
 			ctx, done := a.closeSignal.HardStopCtx(context.Background())
 			defer done()
@@ -371,6 +464,7 @@ func (a *awsSQSReader) readLoop(wg *sync.WaitGroup, inFlightTracker *sqsInFlight
 			MaxNumberOfMessages:   int32(a.conf.MaxNumberOfMessages),
 			WaitTimeSeconds:       int32(a.conf.WaitTimeSeconds),
 			AttributeNames:        []types.QueueAttributeName{types.QueueAttributeNameAll},
+			VisibilityTimeout:     int32(a.conf.MessageTimeout.Seconds()),
 			MessageAttributeNames: []string{"All"},
 		})
 		if err != nil {
@@ -380,8 +474,21 @@ func (a *awsSQSReader) readLoop(wg *sync.WaitGroup, inFlightTracker *sqsInFlight
 			return
 		}
 		if len(res.Messages) > 0 {
-			inFlightTracker.AddNew(res.Messages...)
-			pendingMsgs = append(pendingMsgs, res.Messages...)
+			for _, msg := range res.Messages {
+				var handle *sqsMessageHandle
+				if msg.MessageId != nil && msg.ReceiptHandle != nil {
+					handle = &sqsMessageHandle{
+						id:            *msg.MessageId,
+						receiptHandle: *msg.ReceiptHandle,
+						deadline:      time.Now().Add(a.conf.MessageTimeout),
+					}
+				}
+				pendingMsgs = append(pendingMsgs, sqsMessage{
+					Message: msg,
+					handle:  handle,
+				})
+			}
+			inFlightTracker.AddNew(closeAtLeisureCtx, pendingMsgs[len(pendingMsgs)-len(res.Messages):]...)
 		}
 		if len(res.Messages) > 0 || a.conf.WaitTimeSeconds > 0 {
 			// When long polling we want to reset our back off even if we didn't
@@ -412,24 +519,35 @@ func (a *awsSQSReader) readLoop(wg *sync.WaitGroup, inFlightTracker *sqsInFlight
 	}
 }
 
+type sqsMessage struct {
+	types.Message
+	handle *sqsMessageHandle
+}
+
 type sqsMessageHandle struct {
 	id, receiptHandle string
+	// The timestamp at which the message's visibility expires.
+	deadline time.Time
 }
 
-func (a *awsSQSReader) deleteMessages(ctx context.Context, msgs ...sqsMessageHandle) error {
+func (a *awsSQSReader) deleteMessages(ctx context.Context, msgs ...*sqsMessageHandle) error {
+	if !a.conf.DeleteMessage {
+		return nil
+	}
+
+	const maxBatchSize = 10
 	for len(msgs) > 0 {
 		input := sqs.DeleteMessageBatchInput{
 			QueueUrl: aws.String(a.conf.URL),
 			Entries:  []types.DeleteMessageBatchRequestEntry{},
 		}
-		for _, msg := range msgs {
-			msg := msg
+		for i := range msgs {
+			msg := msgs[i]
 			input.Entries = append(input.Entries, types.DeleteMessageBatchRequestEntry{
 				Id:            &msg.id,
 				ReceiptHandle: &msg.receiptHandle,
 			})
-			if len(input.Entries) == a.conf.MaxNumberOfMessages {
+			if len(input.Entries) == maxBatchSize {
 				break
 			}
 		}
@@ -440,48 +558,79 @@ func (a *awsSQSReader) deleteMessages(ctx context.Context, msgs ...sqsMessageHan
 			return err
 		}
 		for _, fail := range response.Failed {
-			a.log.Errorf("Failed to delete consumed SQS message '%v', response code: %v\n", *fail.Id, *fail.Code)
+			msg := "(no message)"
+			if fail.Message != nil {
+				msg = *fail.Message
+			}
+			a.log.Errorf("Failed to delete consumed SQS message '%v', response code: %v, message: %q, sender fault: %v", *fail.Id, *fail.Code, msg, fail.SenderFault)
 		}
 	}
 
 	return nil
 }
 
-func (a *awsSQSReader) resetMessages(ctx context.Context, msgs ...sqsMessageHandle) error {
+func (a *awsSQSReader) resetMessages(ctx context.Context, msgs ...*sqsMessageHandle) error {
 	if !a.conf.ResetVisibility {
 		return nil
 	}
 	return a.updateVisibilityMessages(ctx, 0, msgs...)
 }
 
-func (a *awsSQSReader) updateVisibilityMessages(ctx context.Context, timeout int, msgs ...sqsMessageHandle) error {
+type batchUpdateVisibilityError struct {
+	entries []types.BatchResultErrorEntry
+}
+
+func (err *batchUpdateVisibilityError) Error() string {
+	if len(err.entries) == 0 {
+		return "(no failures)"
+	}
+	var msg strings.Builder
+	msg.WriteString("failed to update visibility for messages: [")
+	for i, fail := range err.entries {
+		if i > 0 {
+			msg.WriteByte(',')
+		}
+		msg.WriteString(fmt.Sprintf("%q", *fail.Id))
+	}
+	msg.WriteByte(']')
+	return msg.String()
+}
+
+func (a *awsSQSReader) updateVisibilityMessages(ctx context.Context, timeout int, msgs ...*sqsMessageHandle) error {
+	const maxBatchSize = 10
+	batchError := &batchUpdateVisibilityError{}
 	for len(msgs) > 0 {
 		input := sqs.ChangeMessageVisibilityBatchInput{
 			QueueUrl: aws.String(a.conf.URL),
 			Entries:  []types.ChangeMessageVisibilityBatchRequestEntry{},
 		}
-		for _, msg := range msgs {
-			msg := msg
+		for i := range msgs {
+			msg := msgs[i]
 			input.Entries = append(input.Entries, types.ChangeMessageVisibilityBatchRequestEntry{
 				Id:                &msg.id,
 				ReceiptHandle:     &msg.receiptHandle,
 				VisibilityTimeout: int32(timeout),
 			})
-			if len(input.Entries) == a.conf.MaxNumberOfMessages {
+			if len(input.Entries) == maxBatchSize {
 				break
 			}
 		}
 		msgs = msgs[len(input.Entries):]
+		if len(input.Entries) == 0 {
+			continue
+		}
 		response, err := a.sqs.ChangeMessageVisibilityBatch(ctx, &input)
 		if err != nil {
 			return err
 		}
-		for _, fail := range response.Failed {
-			a.log.Debugf("Failed to update consumed SQS message '%v' visibility, response code: %v\n", *fail.Id, *fail.Code)
+		if len(response.Failed) != 0 {
+			batchError.entries = append(batchError.entries, response.Failed...)
 		}
 	}
+	if len(batchError.entries) > 0 {
+		return batchError
+	}
 	return nil
 }
 
@@ -504,7 +653,7 @@ func (a *awsSQSReader) Read(ctx context.Context) (*service.Message, service.AckF
 		return nil, nil, service.ErrNotConnected
 	}
 
-	var next types.Message
+	var next sqsMessage
 	var open bool
 	select {
 	case next, open = <-a.messagesChan:
@@ -522,23 +671,13 @@ func (a *awsSQSReader) Read(ctx context.Context) (*service.Message, service.AckF
 	}
 
 	msg := service.NewMessage([]byte(*next.Body))
-	addSQSMetadata(msg, next)
-
-	mHandle := sqsMessageHandle{
-		id: *next.MessageId,
-	}
-	if next.ReceiptHandle != nil {
-		mHandle.receiptHandle = *next.ReceiptHandle
-	}
+	addSQSMetadata(msg, next.Message)
+	mHandle := next.handle
 	return msg, func(rctx context.Context, res error) error {
-		if mHandle.receiptHandle == "" {
+		if mHandle == nil {
 			return nil
 		}
-		if res == nil {
-			if !a.conf.DeleteMessage {
-				return nil
-			}
 		select {
 		case <-rctx.Done():
 			return rctx.Err()
diff --git a/internal/impl/aws/input_sqs_test.go b/internal/impl/aws/input_sqs_test.go
index 91e48292cd..84c5094a3e 100644
--- a/internal/impl/aws/input_sqs_test.go
+++ b/internal/impl/aws/input_sqs_test.go
@@ -17,7 +17,8 @@ package aws
 import (
 	"context"
 	"fmt"
-	"strconv"
+	"slices"
+	"sync"
 	"testing"
 	"time"
 
@@ -35,15 +36,15 @@ import (
 type mockSqsInput struct {
 	sqsAPI
 
-	mtx          chan struct{}
+	mtx          sync.Mutex
 	queueTimeout int32
 	messages     []types.Message
 	mesTimeouts  map[string]int32
 }
 
 func (m *mockSqsInput) do(fn func()) {
-	<-m.mtx
-	defer func() { m.mtx <- struct{}{} }()
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
 
 	fn()
 }
 
@@ -54,7 +55,7 @@ func (m *mockSqsInput) TimeoutLoop(ctx context.Context) {
 	for {
 		select {
 		case <-t.C:
-			<-m.mtx
+			m.mtx.Lock()
 
 			for mesID, timeout := range m.mesTimeouts {
 				timeout = timeout - 1
@@ -65,7 +66,7 @@ func (m *mockSqsInput) TimeoutLoop(ctx context.Context) {
 			}
 		}
 
-			m.mtx <- struct{}{}
+			m.mtx.Unlock()
 		case <-ctx.Done():
 			return
 		}
 	}
@@ -73,8 +74,8 @@ func (m *mockSqsInput) TimeoutLoop(ctx context.Context) {
 }
 
 func (m *mockSqsInput) ReceiveMessage(context.Context, *sqs.ReceiveMessageInput, ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
-	<-m.mtx
-	defer func() { m.mtx <- struct{}{} }()
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
 
 	messages := make([]types.Message, 0, len(m.messages))
 
@@ -88,13 +89,9 @@ func (m *mockSqsInput) ReceiveMessage(context.Context, *sqs.ReceiveMessageInput,
 	return &sqs.ReceiveMessageOutput{Messages: messages}, nil
 }
 
-func (m *mockSqsInput) GetQueueAttributes(input *sqs.GetQueueAttributesInput) (*sqs.GetQueueAttributesOutput, error) {
-	return &sqs.GetQueueAttributesOutput{Attributes: map[string]string{sqsiAttributeNameVisibilityTimeout: strconv.Itoa(int(m.queueTimeout))}}, nil
-}
-
 func (m *mockSqsInput) ChangeMessageVisibilityBatch(ctx context.Context, input *sqs.ChangeMessageVisibilityBatchInput, opts ...func(*sqs.Options)) (*sqs.ChangeMessageVisibilityBatchOutput, error) {
-	<-m.mtx
-	defer func() { m.mtx <- struct{}{} }()
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
 
 	for _, entry := range input.Entries {
 		if _, found := m.mesTimeouts[*entry.Id]; found {
@@ -108,16 +105,14 @@ func (m *mockSqsInput) ChangeMessageVisibilityBatch(ctx context.Context, input *
 }
 
 func (m *mockSqsInput) DeleteMessageBatch(ctx context.Context, input *sqs.DeleteMessageBatchInput, opts ...func(*sqs.Options)) (*sqs.DeleteMessageBatchOutput, error) {
-	<-m.mtx
-	defer func() { m.mtx <- struct{}{} }()
+	m.mtx.Lock()
+	defer m.mtx.Unlock()
 
 	for _, entry := range input.Entries {
 		delete(m.mesTimeouts, *entry.Id)
-		for i, message := range m.messages {
-			if *entry.Id == *message.MessageId {
-				m.messages = append(m.messages[:i], m.messages[i+1:]...)
-			}
-		}
+		m.messages = slices.DeleteFunc(m.messages, func(msg types.Message) bool {
+			return *entry.Id == *msg.MessageId
+		})
 	}
 
 	return &sqs.DeleteMessageBatchOutput{}, nil
@@ -158,6 +153,8 @@ func TestSQSInput(t *testing.T) {
 			DeleteMessage:       true,
 			ResetVisibility:     true,
 			MaxNumberOfMessages: 10,
+			MaxOutstanding:      100,
+			MessageTimeout:      10 * time.Second,
 		},
 		conf,
 		nil,
 	)
 	require.NoError(t, err)
 
 	mockInput := &mockSqsInput{
-		mtx:          make(chan struct{}, 1),
 		queueTimeout: 10,
 		messages:     messages,
 		mesTimeouts:  make(map[string]int32, expectedMessages),
 	}
-	mockInput.mtx <- struct{}{}
 	r.sqs = mockInput
 
 	go mockInput.TimeoutLoop(tCtx)
 
@@ -178,7 +173,7 @@ func TestSQSInput(t *testing.T) {
 	err = r.Connect(tCtx)
 	require.NoError(t, err)
 
-	receivedMessages := make([]types.Message, 0, expectedMessages)
+	receivedMessages := make([]sqsMessage, 0, expectedMessages)
 
 	// Check that all messages are received from the reader
 	require.Eventually(t, func() bool {
@@ -192,7 +187,7 @@ func TestSQSInput(t *testing.T) {
 			}
 		}
 		return len(receivedMessages) == expectedMessages
-	}, 30*time.Second, time.Second)
+	}, 30*time.Second, 100*time.Millisecond)
 
 	// Wait over the defined queue timeout and check that messages have not been received again
 	time.Sleep(time.Duration(mockInput.queueTimeout+5) * time.Second)
 
@@ -209,7 +204,9 @@ func TestSQSInput(t *testing.T) {
 	// Ack all messages and ensure that they are deleted from SQS
 	for _, message := range receivedMessages {
-		r.ackMessagesChan <- sqsMessageHandle{id: *message.MessageId, receiptHandle: *message.ReceiptHandle}
+		if message.handle != nil {
+			r.ackMessagesChan <- message.handle
+		}
 	}
 
 	require.Eventually(t, func() bool {
@@ -218,7 +215,7 @@ func TestSQSInput(t *testing.T) {
 			msgsLen = len(mockInput.messages)
 		})
 		return msgsLen == 0
-	}, 5*time.Second, time.Second)
+	}, 5*time.Second, 100*time.Millisecond)
 }
 
 func TestSQSInputBatchAck(t *testing.T) {
@@ -247,6 +244,8 @@ func TestSQSInputBatchAck(t *testing.T) {
 			DeleteMessage:       true,
 			ResetVisibility:     true,
 			MaxNumberOfMessages: 10,
+			MaxOutstanding:      100,
+			MessageTimeout:      10 * time.Second,
 		},
 		conf,
 		nil,
 	)
 	require.NoError(t, err)
 
 	mockInput := &mockSqsInput{
-		mtx:          make(chan struct{}, 1),
 		queueTimeout: 10,
 		messages:     messages,
 		mesTimeouts:  make(map[string]int32, expectedMessages),
 	}
-	mockInput.mtx <- struct{}{}
 	r.sqs = mockInput
 
 	go mockInput.TimeoutLoop(tCtx)
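For reviewers who want to poke at the core data structure without wiring up SQS, here is a stripped-down, self-contained sketch of the deadline-sorted FIFO this diff introduces. All names (`tracker`, `handle`, `pullToRefresh`) are illustrative stand-ins for `sqsInFlightTracker`, `sqsMessageHandle`, and `PullToRefresh`; the locking and the `sync.Cond`-based backpressure that enforces `max_outstanding_messages` are omitted for brevity.

```go
package main

import (
	"container/list"
	"fmt"
	"time"
)

// handle is a stand-in for sqsMessageHandle: just an ID and a visibility deadline.
type handle struct {
	id       string
	deadline time.Time
}

// tracker demonstrates the map + container/list pattern from the diff: the map
// gives O(1) lookup/removal by message ID, while the list keeps handles in
// deadline order so refresh candidates are always at the front.
type tracker struct {
	handles map[string]*list.Element
	fifo    *list.List // of *handle
	timeout time.Duration
}

func newTracker(timeout time.Duration) *tracker {
	return &tracker{handles: map[string]*list.Element{}, fifo: list.New(), timeout: timeout}
}

func (t *tracker) add(h *handle) {
	t.handles[h.id] = t.fifo.PushBack(h)
}

func (t *tracker) remove(id string) {
	if e, ok := t.handles[id]; ok {
		t.fifo.Remove(e)
		delete(t.handles, id)
	}
}

// pullToRefresh pops up to limit handles whose deadlines are within half the
// timeout, extends their deadlines by the full timeout, and re-queues them at
// the back, so the list stays sorted by deadline.
func (t *tracker) pullToRefresh(limit int, now time.Time) []*handle {
	var out []*handle
	for e := t.fifo.Front(); e != nil && len(out) < limit; e = t.fifo.Front() {
		h := e.Value.(*handle)
		if h.deadline.Sub(now) > t.timeout/2 {
			break // everything behind this element is even further from its deadline
		}
		h.deadline = now.Add(t.timeout)
		out = append(out, h)
		t.fifo.MoveToBack(e)
	}
	return out
}

func main() {
	now := time.Now()
	tr := newTracker(30 * time.Second)
	tr.add(&handle{id: "a", deadline: now.Add(10 * time.Second)}) // close to its deadline
	tr.add(&handle{id: "b", deadline: now.Add(30 * time.Second)}) // freshly received
	for _, h := range tr.pullToRefresh(10, now) {
		fmt.Println("refresh:", h.id) // only "a" needs a refresh
	}
	tr.remove("a")
	fmt.Println("outstanding:", len(tr.handles))
}
```

The design choice worth noting: because every refreshed handle gets its deadline pushed out by the full timeout and is moved to the back, the list stays deadline-sorted without an explicit sort, so each `pullToRefresh` scan can stop at the first element that does not yet need refreshing.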