diff --git a/internal/impl/sftp/input.go b/internal/impl/sftp/input.go
index e4fb3bed94..e980616b0a 100644
--- a/internal/impl/sftp/input.go
+++ b/internal/impl/sftp/input.go
@@ -174,11 +174,61 @@ func newSFTPReaderFromParsed(conf *service.ParsedConfig, mgr *service.Resources)
 }
 
 func (s *sftpReader) Connect(ctx context.Context) (err error) {
+	file, nextPath, skip, err := s.seekNextPath(ctx)
+	if err != nil {
+		return err
+	}
+	if skip {
+		return nil
+	}
+
+	details := service.NewScannerSourceDetails()
+	details.SetName(nextPath)
+	if s.scanner, err = s.scannerCtor.Create(file, func(ctx context.Context, aErr error) (outErr error) {
+		_ = s.pathProvider.Ack(ctx, nextPath, aErr)
+		if aErr != nil {
+			return nil
+		}
+		if s.deleteOnFinish {
+			s.scannerMut.Lock()
+			client := s.client
+			if client == nil {
+				if client, outErr = s.creds.GetClient(s.mgr.FS(), s.address); outErr != nil {
+					outErr = fmt.Errorf("obtain private client: %w", outErr)
+				}
+				defer func() {
+					_ = client.Close()
+				}()
+			}
+			if outErr == nil {
+				if outErr = client.Remove(nextPath); outErr != nil {
+					outErr = fmt.Errorf("remove %v: %w", nextPath, outErr)
+				}
+			}
+			s.scannerMut.Unlock()
+		}
+		return
+	}, details); err != nil {
+		_ = file.Close()
+		_ = s.pathProvider.Ack(ctx, nextPath, err)
+		return err
+	}
+
+	s.scannerMut.Lock()
+	s.currentPath = nextPath
+	s.scannerMut.Unlock()
+
+	s.log.Debugf("Consuming from file '%v'", nextPath)
+	return
+}
+
+func (s *sftpReader) initState(ctx context.Context) (client *sftp.Client, pathProvider pathProvider, skip bool, err error) {
 	s.scannerMut.Lock()
 	defer s.scannerMut.Unlock()
 
 	if s.scanner != nil {
-		return nil
+		skip = true
+		return
 	}
 
 	if s.client == nil {
@@ -191,13 +241,22 @@ func (s *sftpReader) Connect(ctx context.Context) (err error) {
 		s.pathProvider = s.getFilePathProvider(ctx)
 	}
 
-	var nextPath string
-	var file *sftp.File
+	return s.client, s.pathProvider, false, nil
+}
+
+func (s *sftpReader) seekNextPath(ctx context.Context) (file *sftp.File, nextPath string, skip bool, err error) {
+	client, pathProvider, skip, err := s.initState(ctx)
+	if err != nil || skip {
+		return
+	}
+
 	for {
-		if nextPath, err = s.pathProvider.Next(ctx, s.client); err != nil {
+		if nextPath, err = pathProvider.Next(ctx, client); err != nil {
 			if errors.Is(err, sftp.ErrSshFxConnectionLost) {
-				_ = s.client.Close()
+				_ = client.Close()
+				s.scannerMut.Lock()
 				s.client = nil
+				s.scannerMut.Unlock()
 				return
 			}
 			if errors.Is(err, errEndOfPaths) {
@@ -206,62 +265,28 @@ func (s *sftpReader) Connect(ctx context.Context) (err error) {
 			return
 		}
 
-		if file, err = s.client.Open(nextPath); err != nil {
+		if file, err = client.Open(nextPath); err != nil {
 			if errors.Is(err, sftp.ErrSshFxConnectionLost) {
-				_ = s.client.Close()
+				_ = client.Close()
+				s.scannerMut.Lock()
 				s.client = nil
+				s.scannerMut.Unlock()
 			}
 			s.log.With("path", nextPath, "err", err.Error()).Warn("Unable to open previously identified file")
 			if os.IsNotExist(err) {
 				// If we failed to open the file because it no longer exists
 				// then we can "ack" the path as we're done with it.
-				_ = s.pathProvider.Ack(ctx, nextPath, nil)
+				_ = pathProvider.Ack(ctx, nextPath, nil)
 			} else {
 				// Otherwise we "nack" it with the error as we'll want to
 				// reprocess it again later.
-				_ = s.pathProvider.Ack(ctx, nextPath, err)
+				_ = pathProvider.Ack(ctx, nextPath, err)
 			}
 		} else {
-			break
-		}
-	}
-
-	details := service.NewScannerSourceDetails()
-	details.SetName(nextPath)
-	if s.scanner, err = s.scannerCtor.Create(file, func(ctx context.Context, aErr error) (outErr error) {
-		_ = s.pathProvider.Ack(ctx, nextPath, aErr)
-		if aErr != nil {
-			return nil
-		}
-		if s.deleteOnFinish {
-			s.scannerMut.Lock()
-			client := s.client
-			if client == nil {
-				if client, outErr = s.creds.GetClient(s.mgr.FS(), s.address); outErr != nil {
-					outErr = fmt.Errorf("obtain private client: %w", outErr)
-				}
-				defer func() {
-					_ = client.Close()
-				}()
-			}
-			if outErr == nil {
-				if outErr = client.Remove(nextPath); outErr != nil {
-					outErr = fmt.Errorf("remove %v: %w", nextPath, outErr)
-				}
-			}
-			s.scannerMut.Unlock()
+			return
 		}
-		return
-	}, details); err != nil {
-		_ = file.Close()
-		_ = s.pathProvider.Ack(ctx, nextPath, err)
-		return err
 	}
 
-	s.currentPath = nextPath
-
-	s.log.Debugf("Consuming from file '%v'", nextPath)
-	return
 }
 
 func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, service.AckFunc, error) {
@@ -297,9 +322,7 @@ func (s *sftpReader) ReadBatch(ctx context.Context) (service.MessageBatch, servi
 		part.MetaSetMut("sftp_path", currentPath)
 	}
 
-	return parts, func(ctx context.Context, res error) error {
-		return codecAckFn(ctx, res)
-	}, nil
+	return parts, codecAckFn, nil
 }
 
 func (s *sftpReader) Close(ctx context.Context) error {
@@ -363,61 +386,62 @@ type watcherPathProvider struct {
 }
 
 func (w *watcherPathProvider) Next(ctx context.Context, client *sftp.Client) (string, error) {
-	if len(w.expandedPaths) > 0 {
-		nextPath := w.expandedPaths[0]
-		w.expandedPaths = w.expandedPaths[1:]
-		return nextPath, nil
-	}
-
-	if waitFor := time.Until(w.nextPoll); waitFor > 0 {
-		w.nextPoll = time.Now().Add(w.pollInterval)
-		select {
-		case <-time.After(waitFor):
-		case <-ctx.Done():
-			return "", ctx.Err()
+	for {
+		if len(w.expandedPaths) > 0 {
+			nextPath := w.expandedPaths[0]
+			w.expandedPaths = w.expandedPaths[1:]
+			return nextPath, nil
 		}
-	}
 
-	if cerr := w.mgr.AccessCache(ctx, w.cacheName, func(cache service.Cache) {
-		for _, p := range w.targetPaths {
-			paths, err := client.Glob(p)
-			if err != nil {
-				w.mgr.Logger().With("error", err, "path", p).Warn("Failed to scan files from path")
-				continue
+		if waitFor := time.Until(w.nextPoll); w.nextPoll.IsZero() || waitFor > 0 {
+			w.nextPoll = time.Now().Add(w.pollInterval)
+			select {
+			case <-time.After(waitFor):
+			case <-ctx.Done():
+				return "", ctx.Err()
 			}
+		}
 
-			for _, path := range paths {
-				info, err := client.Stat(path)
+		if cerr := w.mgr.AccessCache(ctx, w.cacheName, func(cache service.Cache) {
+			for _, p := range w.targetPaths {
+				paths, err := client.Glob(p)
 				if err != nil {
-					w.mgr.Logger().With("error", err, "path", path).Warn("Failed to stat path")
-					continue
-				}
-				if time.Since(info.ModTime()) < w.minAge {
+					w.mgr.Logger().With("error", err, "path", p).Warn("Failed to scan files from path")
 					continue
 				}
 
-				// We process it if the marker is a pending symbol (!) and we're
-				// polling for the first time, or if the path isn't found in the
-				// cache.
-				//
-				// If we got an unexpected error obtaining a marker for this
-				// path from the cache then we skip that path because the
-				// watcher will eventually poll again, and the cache.Get
-				// operation will re-run.
-				if v, err := cache.Get(ctx, path); errors.Is(err, service.ErrKeyNotFound) || (!w.followUpPoll && string(v) == "!") {
-					w.expandedPaths = append(w.expandedPaths, path)
-					if err = cache.Set(ctx, path, []byte("!"), nil); err != nil {
-						// Mark the file target as pending so that we do not reprocess it
-						w.mgr.Logger().With("error", err, "path", path).Warn("Failed to mark path as pending")
+				for _, path := range paths {
+					info, err := client.Stat(path)
+					if err != nil {
+						w.mgr.Logger().With("error", err, "path", path).Warn("Failed to stat path")
+						continue
+					}
+					if time.Since(info.ModTime()) < w.minAge {
+						continue
+					}
+
+					// We process it if the marker is a pending symbol (!) and we're
+					// polling for the first time, or if the path isn't found in the
+					// cache.
+					//
+					// If we got an unexpected error obtaining a marker for this
+					// path from the cache then we skip that path because the
+					// watcher will eventually poll again, and the cache.Get
+					// operation will re-run.
+					if v, err := cache.Get(ctx, path); errors.Is(err, service.ErrKeyNotFound) || (!w.followUpPoll && string(v) == "!") {
+						w.expandedPaths = append(w.expandedPaths, path)
+						if err = cache.Set(ctx, path, []byte("!"), nil); err != nil {
+							// Mark the file target as pending so that we do not reprocess it
+							w.mgr.Logger().With("error", err, "path", path).Warn("Failed to mark path as pending")
+						}
 					}
 				}
 			}
+		}); cerr != nil {
+			return "", fmt.Errorf("error obtaining cache: %v", cerr)
 		}
-	}); cerr != nil {
-		return "", fmt.Errorf("error obtaining cache: %v", cerr)
+		w.followUpPoll = true
 	}
-	w.followUpPoll = true
-	return w.Next(ctx, client)
 }
 
 func (w *watcherPathProvider) Ack(ctx context.Context, name string, err error) (outErr error) {
diff --git a/internal/impl/sftp/integration_test.go b/internal/impl/sftp/integration_test.go
index be86760960..887f5963e1 100644
--- a/internal/impl/sftp/integration_test.go
+++ b/internal/impl/sftp/integration_test.go
@@ -15,15 +15,22 @@
 package sftp
 
 import (
+	"context"
+	"errors"
+	"fmt"
 	"io/fs"
 	"os"
+	"strings"
+	"sync"
 	"testing"
 	"time"
 
 	"github.com/ory/dockertest/v3"
+	"github.com/pkg/sftp"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
+	"github.com/redpanda-data/benthos/v4/public/service"
 	"github.com/redpanda-data/benthos/v4/public/service/integration"
 
 	// Bring in memory cache.
@@ -39,34 +46,7 @@ func TestIntegrationSFTP(t *testing.T) {
 	integration.CheckSkip(t)
 	t.Parallel()
 
-	pool, err := dockertest.NewPool("")
-	require.NoError(t, err)
-
-	pool.MaxWait = time.Second * 30
-	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
-		Repository: "atmoz/sftp",
-		Tag: "alpine",
-		Cmd: []string{
-			// https://github.com/atmoz/sftp/issues/401
-			"/bin/sh", "-c", "ulimit -n 65535 && exec /entrypoint " + sftpUsername + ":" + sftpPassword + ":1001:100:upload",
-		},
-	})
-	require.NoError(t, err)
-	t.Cleanup(func() {
-		assert.NoError(t, pool.Purge(resource))
-	})
-
-	_ = resource.Expire(900)
-
-	creds := credentials{
-		Username: sftpUsername,
-		Password: sftpPassword,
-	}
-
-	require.NoError(t, pool.Retry(func() error {
-		_, err = creds.GetClient(&osPT{}, "localhost:"+resource.GetPort("22/tcp"))
-		return err
-	}))
+	resource := setupDockerPool(t)
 
 	t.Run("sftp", func(t *testing.T) {
 		template := `
@@ -129,6 +109,133 @@ cache_resources:
 	})
 }
 
+func TestIntegrationSFTPDeleteOnFinish(t *testing.T) {
+	integration.CheckSkip(t)
+	t.Parallel()
+
+	resource := setupDockerPool(t)
+
+	client, err := getClient(resource)
+	require.NoError(t, err)
+
+	writeSFTPFile(t, client, "/upload/1.txt", "data-1")
+	writeSFTPFile(t, client, "/upload/2.txt", "data-2")
+	writeSFTPFile(t, client, "/upload/3.txt", "data-3")
+
+	config := `
+output:
+  drop: {}
+
+input:
+  sftp:
+    address: localhost:$PORT
+    paths:
+      - /upload/*.txt
+    credentials:
+      username: foo
+      password: pass
+    delete_on_finish: true
+    watcher:
+      enabled: true
+      poll_interval: 100ms
+      cache: files_memory
+
+cache_resources:
+  - label: files_memory
+    memory:
+      default_ttl: 900s
+`
+	config = strings.NewReplacer(
+		"$PORT", resource.GetPort("22/tcp"),
+	).Replace(config)
+
+	var receivedPathsMut sync.Mutex
+	var receivedPaths []string
+
+	builder := service.NewStreamBuilder()
+	require.NoError(t, builder.SetYAML(config))
+	require.NoError(t, builder.AddConsumerFunc(func(_ context.Context, msg *service.Message) error {
+		receivedPathsMut.Lock()
+		defer receivedPathsMut.Unlock()
+		path, ok := msg.MetaGet("sftp_path")
+		if !ok {
+			return errors.New("sftp_path metadata not found")
+		}
+		receivedPaths = append(receivedPaths, path)
+		return nil
+	}))
+	stream, err := builder.Build()
+	require.NoError(t, err)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	runErr := make(chan error)
+	go func() { runErr <- stream.Run(ctx) }()
+	defer func() {
+		cancel()
+		err := <-runErr
+		if err != context.Canceled {
+			require.NoError(t, err, "stream.Run() failed")
+		}
+	}()
+
+	require.EventuallyWithT(t, func(c *assert.CollectT) {
+		receivedPathsMut.Lock()
+		defer receivedPathsMut.Unlock()
+		assert.Len(c, receivedPaths, 3)
+
+		files, err := client.Glob("/upload/*.txt")
+		assert.NoError(c, err)
+		assert.Empty(c, files)
+	}, time.Second, time.Millisecond*100)
+}
+
+func setupDockerPool(t *testing.T) *dockertest.Resource {
+	t.Helper()
+
+	pool, err := dockertest.NewPool("")
+	require.NoError(t, err)
+
+	pool.MaxWait = time.Second * 30
+	resource, err := pool.RunWithOptions(&dockertest.RunOptions{
+		Repository: "atmoz/sftp",
+		Tag: "alpine",
+		Cmd: []string{
+			// https://github.com/atmoz/sftp/issues/401
+			"/bin/sh", "-c", "ulimit -n 65535 && exec /entrypoint " + sftpUsername + ":" + sftpPassword + ":1001:100:upload",
+		},
+	})
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		assert.NoError(t, pool.Purge(resource))
+	})
+
+	_ = resource.Expire(900)
+
+	// wait for server to be ready to accept connections
+	require.NoError(t, pool.Retry(func() error {
+		_, err := getClient(resource)
+		return err
+	}))
+
+	return resource
+}
+func getClient(resource *dockertest.Resource) (*sftp.Client, error) {
+	creds := credentials{
+		Username: sftpUsername,
+		Password: sftpPassword,
+	}
+	return creds.GetClient(&osPT{}, "localhost:"+resource.GetPort("22/tcp"))
+}
+
+func writeSFTPFile(t *testing.T, client *sftp.Client, path, data string) {
+	t.Helper()
+	file, err := client.Create(path)
+	require.NoError(t, err, "creating file")
+	defer file.Close()
+	_, err = fmt.Fprint(file, data)
+	require.NoError(t, err, "writing file contents")
+}
+
 type osPT struct{}
 
 func (o *osPT) Open(name string) (fs.File, error) {