Skip to content

Commit

Permalink
fix(sftp): reduce criticl sections of mutexes
Browse files Browse the repository at this point in the history
ReadBatch was holding the state lock the while it polled for new files,
which blocked AckFns from cleaning up successfully processed files when
deleteOnFinish is set to true.
  • Loading branch information
ooesili committed Jan 7, 2025
1 parent 0142589 commit b22a741
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 68 deletions.
187 changes: 120 additions & 67 deletions internal/impl/sftp/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,12 @@ type sftpReader struct {

// State
stateLock sync.Mutex
client *sftp.Client
scanner codec.DeprecatedFallbackStream
currentPath string
closed bool

clientLock sync.Mutex
client *sftp.Client
clientClosedForever bool

pathProvider pathProvider
}
Expand Down Expand Up @@ -178,60 +180,80 @@ func (s *sftpReader) Connect(ctx context.Context) error {
s.stateLock.Lock()
defer s.stateLock.Unlock()

client, cleanup, err := s.initClient()
client, releaseClient, err := s.borrowClient()
if err != nil {
if errors.Is(err, sftp.ErrSSHFxConnectionLost) {
err = service.ErrNotConnected
}
return err
}
defer cleanup()
defer releaseClient()

if s.pathProvider == nil {
s.pathProvider = s.getFilePathProvider(client)
}
return nil
}

func (s *sftpReader) initClient() (*sftp.Client, func(), error) {
type clientBorrower func() (client *sftp.Client, release func(), err error)

func (s *sftpReader) borrowClient() (*sftp.Client, func(), error) {
s.clientLock.Lock()
defer s.clientLock.Unlock()

if s.client != nil {
return s.client, func() {}, nil
}

client, err := s.creds.GetClient(s.mgr.FS(), s.address)
if err != nil {
return nil, nil, fmt.Errorf("initializing SFTP client: %w", err)
}

if s.clientClosedForever {
release := func() { s.closeClient(false) }
return client, release, nil
}

if s.clientClosedForever {
client, err := s.creds.GetClient(s.mgr.FS(), s.address)
if err != nil {
return nil, nil, fmt.Errorf("initializing SFTP client: %w", err)
}
return client, func() { s.closeClient(true) }, nil
}

if s.client == nil {
var err error
if s.client, err = s.creds.GetClient(s.mgr.FS(), s.address); err != nil {
return nil, nil, fmt.Errorf("initializing SFTP client: %w", err)
}
}

cleanup := func() {}
if s.closed {
cleanup = s.closeClient
}

return s.client, cleanup, nil
return s.client, func() {}, nil
}

func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) {
s.stateLock.Lock()
defer s.stateLock.Unlock()

parts, codecAckFn, err := s.tryReadBatch(ctx)
if errors.Is(err, sftp.ErrSSHFxConnectionLost) {
s.closeScanner(ctx)
s.closeClient()
return nil, nil, service.ErrNotConnected
if err != nil {
if errors.Is(err, sftp.ErrSSHFxConnectionLost) {
s.stateLock.Lock()
s.closeScanner(ctx)
s.closeClient(false)
s.stateLock.Unlock()
err = service.ErrNotConnected
}
return nil, nil, err
}
return parts, codecAckFn, nil
}

func (s *sftpReader) tryReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) {
client, cleanup, err := s.initClient()
if err != nil {
if err := s.initScanner(ctx); err != nil {
return nil, nil, err
}
defer cleanup()

if err := s.initScanner(ctx, client); err != nil {
return nil, nil, err
}
s.stateLock.Lock()
defer s.stateLock.Unlock()

parts, codecAckFn, err := s.scanner.NextBatch(ctx)
if err != nil {
Expand All @@ -254,8 +276,11 @@ func (s *sftpReader) tryReadBatch(ctx context.Context) (service.MessageBatch, se
return parts, codecAckFn, nil
}

func (s *sftpReader) initScanner(ctx context.Context, client *sftp.Client) error {
if s.scanner != nil {
func (s *sftpReader) initScanner(ctx context.Context) error {
s.stateLock.Lock()
scanner := s.scanner
s.stateLock.Unlock()
if scanner != nil {
return nil
}

Expand All @@ -264,15 +289,20 @@ func (s *sftpReader) initScanner(ctx context.Context, client *sftp.Client) error
for {
var ok bool
var err error
path, ok, err = s.pathProvider.Next(ctx, client)
path, ok, err = s.pathProvider.Next(ctx, s.borrowClient)
if err != nil {
return fmt.Errorf("finding next file path: %w", err)
}
if !ok {
return service.ErrEndOfInput
}

client, releaseClient, err := s.borrowClient()
if err != nil {
return err
}
file, err = client.Open(path)
releaseClient()
if err != nil {
s.log.With("path", path, "err", err.Error()).Warn("Unable to open previously identified file")
if os.IsNotExist(err) {
Expand All @@ -295,8 +325,10 @@ func (s *sftpReader) initScanner(ctx context.Context, client *sftp.Client) error
return fmt.Errorf("creating scanner: %w", err)
}

s.stateLock.Lock()
s.scanner = scanner
s.currentPath = path
s.stateLock.Unlock()
return nil
}
}
Expand All @@ -314,17 +346,19 @@ func (s *sftpReader) newCodecAckFn(path string) service.AckFunc {
}

if s.deleteOnFinish {
client, cleanup, err := s.initClient()
client, releaseClient, err := s.borrowClient()
if err != nil {
return err
}
defer cleanup()
defer releaseClient()

if err := client.Remove(path); err != nil {
return fmt.Errorf("remove %v: %w", path, err)
}
}

time.Sleep(time.Millisecond * 100)

return nil
}
}
Expand All @@ -334,8 +368,7 @@ func (s *sftpReader) Close(ctx context.Context) error {
defer s.stateLock.Unlock()

s.closeScanner(ctx)
s.closeClient()
s.closed = true
s.closeClient(true)
return nil
}

Expand All @@ -349,25 +382,31 @@ func (s *sftpReader) closeScanner(ctx context.Context) {
}
}

func (s *sftpReader) closeClient() {
func (s *sftpReader) closeClient(closedForever bool) {
s.clientLock.Lock()
defer s.clientLock.Unlock()

if s.client == nil {
if err := s.client.Close(); err != nil {
s.log.With("error", err).Error("Failed to close client")
}
s.client = nil
}
if closedForever {
s.clientClosedForever = true
}
}

type pathProvider interface {
Next(context.Context, *sftp.Client) (string, bool, error)
Next(context.Context, clientBorrower) (string, bool, error)
Ack(context.Context, string, error) error
}

type staticPathProvider struct {
expandedPaths []string
}

func (s *staticPathProvider) Next(context.Context, *sftp.Client) (string, bool, error) {
func (s *staticPathProvider) Next(context.Context, clientBorrower) (string, bool, error) {
if len(s.expandedPaths) == 0 {
return "", false, nil
}
Expand All @@ -392,7 +431,7 @@ type watcherPathProvider struct {
followUpPoll bool
}

func (w *watcherPathProvider) Next(ctx context.Context, client *sftp.Client) (string, bool, error) {
func (w *watcherPathProvider) Next(ctx context.Context, borrowClient clientBorrower) (string, bool, error) {
for {
if len(w.expandedPaths) > 0 {
nextPath := w.expandedPaths[0]
Expand All @@ -409,46 +448,60 @@ func (w *watcherPathProvider) Next(ctx context.Context, client *sftp.Client) (st
}
}

if cerr := w.mgr.AccessCache(ctx, w.cacheName, func(cache service.Cache) {
for _, p := range w.targetPaths {
paths, err := client.Glob(p)
if err := w.findNewPaths(ctx, borrowClient); err != nil {
return "", false, fmt.Errorf("expanding new paths: %w", err)
}
w.followUpPoll = true
}
}

func (w *watcherPathProvider) findNewPaths(ctx context.Context, borrowClient clientBorrower) error {
client, releaseClient, err := borrowClient()
if err != nil {
return fmt.Errorf("obtaining sftp client: %w", err)
}
defer releaseClient()

if cerr := w.mgr.AccessCache(ctx, w.cacheName, func(cache service.Cache) {
for _, p := range w.targetPaths {
paths, err := client.Glob(p)
if err != nil {
w.mgr.Logger().With("error", err, "path", p).Warn("Failed to scan files from path")
continue
}

for _, path := range paths {
info, err := client.Stat(path)
if err != nil {
w.mgr.Logger().With("error", err, "path", p).Warn("Failed to scan files from path")
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to stat path")
continue
}
if time.Since(info.ModTime()) < w.minAge {
continue
}

for _, path := range paths {
info, err := client.Stat(path)
if err != nil {
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to stat path")
continue
}
if time.Since(info.ModTime()) < w.minAge {
continue
}

// We process it if the marker is a pending symbol (!) and we're
// polling for the first time, or if the path isn't found in the
// cache.
//
// If we got an unexpected error obtaining a marker for this
// path from the cache then we skip that path because the
// watcher will eventually poll again, and the cache.Get
// operation will re-run.
if v, err := cache.Get(ctx, path); errors.Is(err, service.ErrKeyNotFound) || (!w.followUpPoll && string(v) == "!") {
w.expandedPaths = append(w.expandedPaths, path)
if err = cache.Set(ctx, path, []byte("!"), nil); err != nil {
// Mark the file target as pending so that we do not reprocess it
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to mark path as pending")
}
// We process it if the marker is a pending symbol (!) and we're
// polling for the first time, or if the path isn't found in the
// cache.
//
// If we got an unexpected error obtaining a marker for this
// path from the cache then we skip that path because the
// watcher will eventually poll again, and the cache.Get
// operation will re-run.
if v, err := cache.Get(ctx, path); errors.Is(err, service.ErrKeyNotFound) || (!w.followUpPoll && string(v) == "!") {
w.expandedPaths = append(w.expandedPaths, path)
if err = cache.Set(ctx, path, []byte("!"), nil); err != nil {
// Mark the file target as pending so that we do not reprocess it
w.mgr.Logger().With("error", err, "path", path).Warn("Failed to mark path as pending")
}
}
}
}); cerr != nil {
return "", false, fmt.Errorf("error obtaining cache: %v", cerr)
}
w.followUpPoll = true
}); cerr != nil {
return fmt.Errorf("error obtaining cache: %v", cerr)
}

return nil
}

func (w *watcherPathProvider) Ack(ctx context.Context, name string, err error) (outErr error) {
Expand Down
2 changes: 1 addition & 1 deletion internal/impl/sftp/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ cache_resources:
files, err := client.Glob("/upload/*.txt")
assert.NoError(c, err)
assert.Empty(c, files)
}, time.Second, time.Millisecond*100)
}, time.Second*10, time.Millisecond*100)
}

func setupDockerPool(t *testing.T) *dockertest.Resource {
Expand Down

0 comments on commit b22a741

Please sign in to comment.