Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sftp): make sure to delete last file when watch and delete_on_finish are enabled #3037

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 115 additions & 91 deletions internal/impl/sftp/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,61 @@ func newSFTPReaderFromParsed(conf *service.ParsedConfig, mgr *service.Resources)
}

func (s *sftpReader) Connect(ctx context.Context) (err error) {
file, nextPath, skip, err := s.seekNextPath(ctx)
if err != nil {
return err
}
if skip {
return nil
}

details := service.NewScannerSourceDetails()
details.SetName(nextPath)
if s.scanner, err = s.scannerCtor.Create(file, func(ctx context.Context, aErr error) (outErr error) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the assignment of s.scanner needs to be under the mutex as well based on the ReadBatch function right?

_ = s.pathProvider.Ack(ctx, nextPath, aErr)
if aErr != nil {
return nil
}
if s.deleteOnFinish {
s.scannerMut.Lock()
client := s.client
if client == nil {
if client, outErr = s.creds.GetClient(s.mgr.FS(), s.address); outErr != nil {
outErr = fmt.Errorf("obtain private client: %w", outErr)
}
defer func() {
_ = client.Close()
}()
}
if outErr == nil {
if outErr = client.Remove(nextPath); outErr != nil {
outErr = fmt.Errorf("remove %v: %w", nextPath, outErr)
}
}
s.scannerMut.Unlock()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since we always return after this block we could defer this right? I just get worried if the unlock is not right after.

}
return
}, details); err != nil {
_ = file.Close()
_ = s.pathProvider.Ack(ctx, nextPath, err)
return err
}

s.scannerMut.Lock()
s.currentPath = nextPath
s.scannerMut.Unlock()

s.log.Debugf("Consuming from file '%v'", nextPath)
return
}

func (s *sftpReader) initState(ctx context.Context) (client *sftp.Client, pathProvider pathProvider, skip bool, err error) {
s.scannerMut.Lock()
defer s.scannerMut.Unlock()

if s.scanner != nil {
return nil
skip = true
Copy link
Collaborator

@rockwotj rockwotj Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we skip if there is a scanner (and what are we skipping)? Can you add a comment?

return
}

if s.client == nil {
Expand All @@ -191,13 +241,22 @@ func (s *sftpReader) Connect(ctx context.Context) (err error) {
s.pathProvider = s.getFilePathProvider(ctx)
}

var nextPath string
var file *sftp.File
return s.client, s.pathProvider, false, nil
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a newcomer to this code this feels unsafe.

If the only important part of this code that all these variables are accessed/set together atomically, I do wonder if an atomic is better suited. You can use Swap to set the new value and destroy in Close, Store in Connect and Load in ReadBatch. I don't quite understand the higher level contract here of why it's only required that they are accessed concurrently and we don't have to worry about Close clobbering something ongoing in Connect or ReadBatch.

Copy link
Collaborator

@rockwotj rockwotj Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And when I talk about using atomics I mean using typed.AtomicValue in our typed package in internal/typed and wrapping all this state into a struct so it becomes typed.AtomicValue[*sftpReaderState]

}

func (s *sftpReader) seekNextPath(ctx context.Context) (file *sftp.File, nextPath string, skip bool, err error) {
client, pathProvider, skip, err := s.initState(ctx)
if err != nil || skip {
return
}

for {
if nextPath, err = s.pathProvider.Next(ctx, s.client); err != nil {
if nextPath, err = pathProvider.Next(ctx, client); err != nil {
if errors.Is(err, sftp.ErrSshFxConnectionLost) {
_ = s.client.Close()
_ = client.Close()
s.scannerMut.Lock()
s.client = nil
s.scannerMut.Unlock()
return
}
if errors.Is(err, errEndOfPaths) {
Expand All @@ -206,62 +265,28 @@ func (s *sftpReader) Connect(ctx context.Context) (err error) {
return
}

if file, err = s.client.Open(nextPath); err != nil {
if file, err = client.Open(nextPath); err != nil {
if errors.Is(err, sftp.ErrSshFxConnectionLost) {
_ = s.client.Close()
_ = client.Close()
s.scannerMut.Lock()
s.client = nil
s.scannerMut.Unlock()
}

s.log.With("path", nextPath, "err", err.Error()).Warn("Unable to open previously identified file")
if os.IsNotExist(err) {
// If we failed to open the file because it no longer exists
// then we can "ack" the path as we're done with it.
_ = s.pathProvider.Ack(ctx, nextPath, nil)
_ = pathProvider.Ack(ctx, nextPath, nil)
} else {
// Otherwise we "nack" it with the error as we'll want to
// reprocess it again later.
_ = s.pathProvider.Ack(ctx, nextPath, err)
_ = pathProvider.Ack(ctx, nextPath, err)
}
} else {
break
}
}

details := service.NewScannerSourceDetails()
details.SetName(nextPath)
if s.scanner, err = s.scannerCtor.Create(file, func(ctx context.Context, aErr error) (outErr error) {
_ = s.pathProvider.Ack(ctx, nextPath, aErr)
if aErr != nil {
return nil
}
if s.deleteOnFinish {
s.scannerMut.Lock()
client := s.client
if client == nil {
if client, outErr = s.creds.GetClient(s.mgr.FS(), s.address); outErr != nil {
outErr = fmt.Errorf("obtain private client: %w", outErr)
}
defer func() {
_ = client.Close()
}()
}
if outErr == nil {
if outErr = client.Remove(nextPath); outErr != nil {
outErr = fmt.Errorf("remove %v: %w", nextPath, outErr)
}
}
s.scannerMut.Unlock()
return
}
return
}, details); err != nil {
_ = file.Close()
_ = s.pathProvider.Ack(ctx, nextPath, err)
return err
}
s.currentPath = nextPath

s.log.Debugf("Consuming from file '%v'", nextPath)
return
}

func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) {
Expand Down Expand Up @@ -297,9 +322,7 @@ func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, servi
part.MetaSetMut("sftp_path", currentPath)
}

return parts, func(ctx context.Context, res error) error {
return codecAckFn(ctx, res)
}, nil
return parts, codecAckFn, nil
}

func (s *sftpReader) Close(ctx context.Context) error {
Expand Down Expand Up @@ -363,61 +386,62 @@ type watcherPathProvider struct {
}

func (w *watcherPathProvider) Next(ctx context.Context, client *sftp.Client) (string, error) {
if len(w.expandedPaths) > 0 {
nextPath := w.expandedPaths[0]
w.expandedPaths = w.expandedPaths[1:]
return nextPath, nil
}

if waitFor := time.Until(w.nextPoll); waitFor > 0 {
w.nextPoll = time.Now().Add(w.pollInterval)
select {
case <-time.After(waitFor):
case <-ctx.Done():
return "", ctx.Err()
for {
if len(w.expandedPaths) > 0 {
nextPath := w.expandedPaths[0]
w.expandedPaths = w.expandedPaths[1:]
return nextPath, nil
}
}

if cerr := w.mgr.AccessCache(ctx, w.cacheName, func(cache service.Cache) {
for _, p := range w.targetPaths {
paths, err := client.Glob(p)
if err != nil {
w.mgr.Logger().With("error", err, "path", p).Warn("Failed to scan files from path")
continue
if waitFor := time.Until(w.nextPoll); w.nextPoll.IsZero() || waitFor > 0 {
w.nextPoll = time.Now().Add(w.pollInterval)
select {
case <-time.After(waitFor):
case <-ctx.Done():
return "", ctx.Err()
}
}

for _, path := range paths {
info, err := client.Stat(path)
if cerr := w.mgr.AccessCache(ctx, w.cacheName, func(cache service.Cache) {
for _, p := range w.targetPaths {
paths, err := client.Glob(p)
if err != nil {
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to stat path")
continue
}
if time.Since(info.ModTime()) < w.minAge {
w.mgr.Logger().With("error", err, "path", p).Warn("Failed to scan files from path")
continue
}

// We process it if the marker is a pending symbol (!) and we're
// polling for the first time, or if the path isn't found in the
// cache.
//
// If we got an unexpected error obtaining a marker for this
// path from the cache then we skip that path because the
// watcher will eventually poll again, and the cache.Get
// operation will re-run.
if v, err := cache.Get(ctx, path); errors.Is(err, service.ErrKeyNotFound) || (!w.followUpPoll && string(v) == "!") {
w.expandedPaths = append(w.expandedPaths, path)
if err = cache.Set(ctx, path, []byte("!"), nil); err != nil {
// Mark the file target as pending so that we do not reprocess it
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to mark path as pending")
for _, path := range paths {
info, err := client.Stat(path)
if err != nil {
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to stat path")
continue
}
if time.Since(info.ModTime()) < w.minAge {
continue
}

// We process it if the marker is a pending symbol (!) and we're
// polling for the first time, or if the path isn't found in the
// cache.
//
// If we got an unexpected error obtaining a marker for this
// path from the cache then we skip that path because the
// watcher will eventually poll again, and the cache.Get
// operation will re-run.
if v, err := cache.Get(ctx, path); errors.Is(err, service.ErrKeyNotFound) || (!w.followUpPoll && string(v) == "!") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: separately it would be nice to use constants for the pending symbol and other cache values :)

w.expandedPaths = append(w.expandedPaths, path)
if err = cache.Set(ctx, path, []byte("!"), nil); err != nil {
// Mark the file target as pending so that we do not reprocess it
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to mark path as pending")
}
}
}
}
}); cerr != nil {
return "", fmt.Errorf("error obtaining cache: %v", cerr)
}
}); cerr != nil {
return "", fmt.Errorf("error obtaining cache: %v", cerr)
w.followUpPoll = true
}
w.followUpPoll = true
return w.Next(ctx, client)
}

func (w *watcherPathProvider) Ack(ctx context.Context, name string, err error) (outErr error) {
Expand Down
Loading
Loading