diff --git a/internal/resources/providers/awslib/ec2/ebs_snapshot.go b/internal/resources/providers/awslib/ec2/ebs_snapshot.go index ac5ee32642..e0749f58a4 100644 --- a/internal/resources/providers/awslib/ec2/ebs_snapshot.go +++ b/internal/resources/providers/awslib/ec2/ebs_snapshot.go @@ -20,6 +20,7 @@ package ec2 import ( "fmt" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/elastic/cloudbeat/internal/resources/fetching" @@ -56,8 +57,8 @@ func FromSnapshotInfo(snapshot types.SnapshotInfo, region string, awsAccount str State: snapshot.State, Region: region, awsAccount: awsAccount, - VolumeSize: int(*snapshot.VolumeSize), - IsEncrypted: *snapshot.Encrypted, + VolumeSize: int(aws.ToInt32(snapshot.VolumeSize)), + IsEncrypted: aws.ToBool(snapshot.Encrypted), } } @@ -67,8 +68,8 @@ func FromSnapshot(snapshot types.Snapshot, region string, awsAccount string, ins State: snapshot.State, Region: region, awsAccount: awsAccount, - VolumeSize: int(*snapshot.VolumeSize), + VolumeSize: int(aws.ToInt32(snapshot.VolumeSize)), Instance: ins, - IsEncrypted: *snapshot.Encrypted, + IsEncrypted: aws.ToBool(snapshot.Encrypted), } } diff --git a/internal/resources/providers/awslib/ec2/provider.go b/internal/resources/providers/awslib/ec2/provider.go index 14abc5bf13..31d0bb7d05 100644 --- a/internal/resources/providers/awslib/ec2/provider.go +++ b/internal/resources/providers/awslib/ec2/provider.go @@ -20,6 +20,9 @@ package ec2 import ( "context" "fmt" + "iter" + "strings" + "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/retry" @@ -37,12 +40,24 @@ var ( subnetMainAssociationFilterName = "association.main" ) +const ( + snapshotPrefix = "elastic-vulnerability" +) + type Provider struct { log *clog.Logger clients map[string]Client awsAccountID string } +func NewProviderFromClients(log *clog.Logger, awsAccountID string, clients map[string]Client) *Provider { + return &Provider{ + log: log, + clients: clients, + awsAccountID: awsAccountID, + } +} + type Client interface { CreateSnapshots(ctx context.Context, params *ec2.CreateSnapshotsInput, optFns ...func(*ec2.Options)) (*ec2.CreateSnapshotsOutput, error) DeleteSnapshot(ctx context.Context, params *ec2.DeleteSnapshotInput, optFns ...func(*ec2.Options)) (*ec2.DeleteSnapshotOutput, error) @@ -78,7 +93,7 @@ func (p *Provider) CreateSnapshots(ctx context.Context, ins *Ec2Instance) ([]EBS { ResourceType: "snapshot", Tags: []types.Tag{ - {Key: aws.String("Name"), Value: aws.String(fmt.Sprintf("elastic-vulnerability-%s", *ins.InstanceId))}, + {Key: aws.String("Name"), Value: aws.String(fmt.Sprintf("%s-%s", snapshotPrefix, *ins.InstanceId))}, {Key: aws.String("Workload"), Value: aws.String("Cloudbeat Vulnerability Snapshot")}, }, }, @@ -311,6 +326,60 @@ func (p *Provider) DescribeSnapshots(ctx context.Context, snapshot EBSSnapshot) return result, nil } +// IterOwnedSnapshots will iterate over the snapshots owned by cloudbeat (snapshotPrefix) that are older than the +// specified before time. A snapshot will be yielded if: +// - It has a tag with key "Name" and value starting with snapshotPrefix +// - It is older than the specified before time +// - It is "owned" by the current account (owner ID is "self") +func (p *Provider) IterOwnedSnapshots(ctx context.Context, before time.Time) iter.Seq[EBSSnapshot] { + return func(yield func(EBSSnapshot) bool) { + _, err := awslib.MultiRegionFetch(ctx, p.clients, func(ctx context.Context, region string, c Client) ([]awslib.AwsResource, error) { + input := &ec2.DescribeSnapshotsInput{ + Filters: []types.Filter{ + { + Name: aws.String("tag:Name"), + Values: []string{fmt.Sprintf("%s-*", snapshotPrefix)}, + }, + }, + OwnerIds: []string{"self"}, + } + paginator := ec2.NewDescribeSnapshotsPaginator(c, input) + for paginator.HasMorePages() { + output, err := paginator.NextPage(ctx) + if err != nil { + return nil, err + } + for _, snap := range output.Snapshots { + if filterSnap(snap, before) { + p.log.Infof("Found old snapshot %s", *snap.SnapshotId) + ebsSnap := FromSnapshot(snap, region, p.awsAccountID, Ec2Instance{}) + if !yield(ebsSnap) { + return nil, nil + } + } + } + } + return nil, nil + }) + if err != nil { + p.log.Errorf("Error listing owned snapshots: %v", err) + } + } +} + +func filterSnap(snap types.Snapshot, before time.Time) bool { + if aws.ToTime(snap.StartTime).After(before) { + return false + } + + for _, tag := range snap.Tags { + if aws.ToString(tag.Key) == "Name" { + return strings.HasPrefix(aws.ToString(tag.Value), snapshotPrefix) + } + } + return false +} + func (p *Provider) DescribeSubnets(ctx context.Context) ([]awslib.AwsResource, error) { subnets, err := awslib.MultiRegionFetch(ctx, p.clients, func(ctx context.Context, region string, c Client) ([]awslib.AwsResource, error) { input := &ec2.DescribeSubnetsInput{} diff --git a/internal/vulnerability/mock_snapshot_creator_deleter.go b/internal/vulnerability/mock_snapshot_creator_deleter.go index d1102222db..dd0ed69744 100644 --- a/internal/vulnerability/mock_snapshot_creator_deleter.go +++ b/internal/vulnerability/mock_snapshot_creator_deleter.go @@ -21,9 +21,13 @@ package vulnerability import ( context "context" + iter "iter" ec2 "github.com/elastic/cloudbeat/internal/resources/providers/awslib/ec2" + mock "github.com/stretchr/testify/mock" + + time "time" ) // mockSnapshotCreatorDeleter is an autogenerated mock type for the snapshotCreatorDeleter type @@ -145,6 +149,55 @@ func (_c *mockSnapshotCreatorDeleter_DeleteSnapshot_Call) RunAndReturn(run func( return _c } +// IterOwnedSnapshots provides a mock function with given fields: ctx, before +func (_m *mockSnapshotCreatorDeleter) IterOwnedSnapshots(ctx context.Context, before time.Time) iter.Seq[ec2.EBSSnapshot] { + ret := _m.Called(ctx, before) + + if len(ret) == 0 { + panic("no return value specified for IterOwnedSnapshots") + } + + var r0 iter.Seq[ec2.EBSSnapshot] + if rf, ok := ret.Get(0).(func(context.Context, time.Time) iter.Seq[ec2.EBSSnapshot]); ok { + r0 = rf(ctx, before) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(iter.Seq[ec2.EBSSnapshot]) + } + } + + return r0 +} + +// mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IterOwnedSnapshots' +type mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call struct { + *mock.Call +} + +// IterOwnedSnapshots is a helper method to define mock.On call +// - ctx context.Context +// - before time.Time +func (_e *mockSnapshotCreatorDeleter_Expecter) IterOwnedSnapshots(ctx interface{}, before interface{}) *mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call { + return &mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call{Call: _e.mock.On("IterOwnedSnapshots", ctx, before)} +} + +func (_c *mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call) Run(run func(ctx context.Context, before time.Time)) *mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(time.Time)) + }) + return _c +} + +func (_c *mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call) Return(_a0 iter.Seq[ec2.EBSSnapshot]) *mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call) RunAndReturn(run func(context.Context, time.Time) iter.Seq[ec2.EBSSnapshot]) *mockSnapshotCreatorDeleter_IterOwnedSnapshots_Call { + _c.Call.Return(run) + return _c +} + // newMockSnapshotCreatorDeleter creates a new instance of mockSnapshotCreatorDeleter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func newMockSnapshotCreatorDeleter(t interface { diff --git a/internal/vulnerability/snapshot.go b/internal/vulnerability/snapshot.go index cf2650af98..01c78ec114 100644 --- a/internal/vulnerability/snapshot.go +++ b/internal/vulnerability/snapshot.go @@ -19,6 +19,7 @@ package vulnerability import ( "context" + "iter" "sync" "time" @@ -26,9 +27,15 @@ import ( "github.com/elastic/cloudbeat/internal/resources/providers/awslib/ec2" ) +const ( + backgroundDeleteWorkers = 3 + backgroundDeleteTimeout = 2 * 24 * time.Hour +) + type snapshotCreatorDeleter interface { CreateSnapshots(ctx context.Context, ins *ec2.Ec2Instance) ([]ec2.EBSSnapshot, error) DeleteSnapshot(ctx context.Context, snapshot ec2.EBSSnapshot) error + IterOwnedSnapshots(ctx context.Context, before time.Time) iter.Seq[ec2.EBSSnapshot] } type SnapshotManager struct { @@ -64,7 +71,7 @@ func (s *SnapshotManager) CreateSnapshots(ctx context.Context, ins *ec2.Ec2Insta func (s *SnapshotManager) DeleteSnapshot(ctx context.Context, snapshot ec2.EBSSnapshot) { runWithGrace(ctx, shutdownGracePeriod, func(ctx context.Context) { - s.delete(ctx, snapshot) + s.delete(ctx, snapshot, "DeleteSnapshot") }) s.lock.Lock() @@ -82,17 +89,45 @@ func (s *SnapshotManager) Cleanup(ctx context.Context) { wg.Add(1) go func() { defer wg.Done() - s.delete(ctx, snap) + s.delete(ctx, snap, "Cleanup") }() } }) + clear(s.snapshots) } -func (s *SnapshotManager) delete(ctx context.Context, snapshot ec2.EBSSnapshot) { - s.logger.Infof("VulnerabilityScanner.manager.DeleteSnapshot %s", snapshot.SnapshotId) +func (s *SnapshotManager) DeleteOldSnapshots(ctx context.Context) { + var wg sync.WaitGroup + defer wg.Wait() + + ch := newContextualChan[ec2.EBSSnapshot]() + defer ch.Close() + + wg.Add(backgroundDeleteWorkers) + for range backgroundDeleteWorkers { + go func() { + defer wg.Done() + for { + snap, ok := ch.Read(ctx) + if !ok { + return + } + s.delete(ctx, snap, "DeleteOldSnapshots") + } + }() + } + for snapshot := range s.provider.IterOwnedSnapshots(ctx, time.Now().Add(-backgroundDeleteTimeout)) { + if !ch.Write(ctx, snapshot) { + return + } + } +} + +func (s *SnapshotManager) delete(ctx context.Context, snapshot ec2.EBSSnapshot, message string) { + s.logger.Infof("VulnerabilityScanner.manager.%s %s", message, snapshot.SnapshotId) err := s.provider.DeleteSnapshot(ctx, snapshot) if err != nil { - s.logger.Errorf("VulnerabilityScanner.manager.DeleteSnapshot %s error: %s", snapshot.SnapshotId, err) + s.logger.Errorf("VulnerabilityScanner.manager.%s %s error: %s", message, snapshot.SnapshotId, err) } } @@ -109,3 +144,33 @@ func runWithGrace(ctx context.Context, grace time.Duration, f func(ctx context.C defer stop() // if the callback finishes in time, stop the AfterFunc f(newCtx) // finally, call the actual callback! } + +type contextualChan[T any] struct { + ch chan T +} + +func newContextualChan[T any]() contextualChan[T] { + return contextualChan[T]{ch: make(chan T)} +} + +func (s contextualChan[T]) Write(ctx context.Context, t T) bool { + select { + case <-ctx.Done(): + return false + case s.ch <- t: + return true + } +} + +func (s contextualChan[T]) Read(ctx context.Context) (T, bool) { + select { + case t, ok := <-s.ch: + return t, ok + case <-ctx.Done(): + return *new(T), false + } +} + +func (s contextualChan[T]) Close() { + close(s.ch) +} diff --git a/internal/vulnerability/snapshot_test.go b/internal/vulnerability/snapshot_test.go index 466b8ef572..fed63888f0 100644 --- a/internal/vulnerability/snapshot_test.go +++ b/internal/vulnerability/snapshot_test.go @@ -19,11 +19,13 @@ package vulnerability import ( "context" + "errors" "fmt" "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" + awsec2 "github.com/aws/aws-sdk-go-v2/service/ec2" "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -75,6 +77,14 @@ func Test_SnapshotManager(t *testing.T) { assert.Contains(t, manager.snapshots, "snapshot-4") manager.Cleanup(ctx) + assert.Empty(t, manager.snapshots) + + provider.EXPECT().CreateSnapshots(mock.Anything, mock.Anything).Return(generateSnapshots(5), nil).Times(1) + provider.EXPECT().DeleteSnapshot(mock.Anything, mock.Anything).Return(errors.New("some error")).Times(1) + _, err = manager.CreateSnapshots(ctx, &ec2.Ec2Instance{}) + require.NoError(t, err) + manager.Cleanup(ctx) + assert.Empty(t, manager.snapshots) } func generateSnapshots(ids ...int) []ec2.EBSSnapshot { @@ -146,3 +156,89 @@ func Test_runWithGrace(t *testing.T) { }) } } + +func TestSnapshotManager_DeleteOldSnapshots(t *testing.T) { + tests := []struct { + name string + snapshots []types.Snapshot + describeErr error + expectedDeletedIDs []string + }{ + { + name: "nothing happens", + }, + { + name: "only old snapshots with tags are deleted", + snapshots: []types.Snapshot{ + mockSnapshot("snapshot-1", 1*time.Hour, "Name", "elastic-vulnerability-1"), // too new + mockSnapshot("snapshot-2", 49*time.Hour), // doesn't have tag + mockSnapshot("snapshot-3", 49*time.Hour, "Name", "some-vulnerability-3"), // doesn't match tag value + mockSnapshot("snapshot-4", 49*time.Hour, "Name", "elastic-vulnerability-4"), // matches + mockSnapshot("snapshot-5", 100*time.Hour, "Name", "elastic-vulnerability-5"), // matches + mockSnapshot("snapshot-6", 100*time.Hour, "SomeKey", "elastic-vulnerability-5"), // doesn't match tag key + }, + expectedDeletedIDs: []string{"snapshot-4", "snapshot-5"}, + }, + { + name: "error describing", + describeErr: errors.New("some error"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + ctx := context.Background() + log := testhelper.NewLogger(t) + + ec2p := ec2.NewMockClient(t) + ec2p.EXPECT().DescribeSnapshots(mock.Anything, mock.Anything, mock.Anything).Return(&awsec2.DescribeSnapshotsOutput{ + NextToken: nil, + Snapshots: tt.snapshots, + }, tt.describeErr).Times(1) + for _, deletedSnapshotID := range tt.expectedDeletedIDs { + ec2p.EXPECT().DeleteSnapshot(mock.Anything, mock.MatchedBy(func(input *awsec2.DeleteSnapshotInput) bool { + return aws.ToString(input.SnapshotId) == deletedSnapshotID + }), mock.Anything).Return(nil, nil).Times(1) + } + + p := ec2.NewProviderFromClients(log, "account-id", map[string]ec2.Client{"region": ec2p}) + + s := NewSnapshotManager(log, p) + s.DeleteOldSnapshots(ctx) + }) + } + + t.Run("test context done", func(t *testing.T) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + ctx, cancel := context.WithCancel(context.Background()) + m := newMockSnapshotCreatorDeleter(t) + m.EXPECT().IterOwnedSnapshots(mock.Anything, mock.Anything).Return(func(yield func(ec2.EBSSnapshot) bool) { + <-ctx.Done() + assert.False(t, yield(ec2.EBSSnapshot{})) + }).Times(1) + + s := NewSnapshotManager(testhelper.NewLogger(t), m) + cancel() + s.DeleteOldSnapshots(ctx) + }) +} + +func mockSnapshot(id string, before time.Duration, tagKeyAndValues ...string) types.Snapshot { + if len(tagKeyAndValues)%2 != 0 { + panic("tags must be key-value pairs") + } + lenTags := len(tagKeyAndValues) / 2 + tags := make([]types.Tag, lenTags) + for i := 0; i < lenTags; i++ { + tags[i] = types.Tag{ + Key: aws.String(tagKeyAndValues[i*2]), + Value: aws.String(tagKeyAndValues[i*2+1]), + } + } + return types.Snapshot{ + SnapshotId: &id, + StartTime: aws.Time(time.Now().Add(-before)), + Tags: tags, + } +} diff --git a/internal/vulnerability/worker.go b/internal/vulnerability/worker.go index cf4189dfc8..443ee6998a 100644 --- a/internal/vulnerability/worker.go +++ b/internal/vulnerability/worker.go @@ -98,6 +98,13 @@ func (f *VulnerabilityWorker) Run(ctx context.Context) { name string fn func(ctx context.Context) error }{ + { + name: "DeleteOldSnapshots", + fn: func(ctx context.Context) error { + f.manager.DeleteOldSnapshots(ctx) + return nil + }, + }, { name: "FetchInstances", fn: f.fetcher.FetchInstances, diff --git a/internal/vulnerability/worker_test.go b/internal/vulnerability/worker_test.go index 41a4f91b39..31e3a866b1 100644 --- a/internal/vulnerability/worker_test.go +++ b/internal/vulnerability/worker_test.go @@ -21,6 +21,7 @@ import ( "context" "encoding/json" "errors" + "iter" "os" "sync" "testing" @@ -127,7 +128,24 @@ func TestVulnerabilityWorker_Run(t *testing.T) { Region: "region", }, }, nil).Times(1) - mockSnapshotCreatorDeleterProvider.EXPECT().DeleteSnapshot(mock.Anything, mock.Anything).Return(nil).Times(1) + mockSnapshotCreatorDeleterProvider.EXPECT().DeleteSnapshot(mock.Anything, matchBySnapshotID("snapshot-1")).Return(nil).Times(1) + mockSnapshotCreatorDeleterProvider.EXPECT().IterOwnedSnapshots(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, before time.Time) iter.Seq[ec2.EBSSnapshot] { + return func(yield func(ec2.EBSSnapshot) bool) { + require.NoError(t, ctx.Err()) + assert.Greater(t, time.Now(), before) + snap := ec2.EBSSnapshot{ + Instance: ec2.Ec2Instance{ + Instance: types.Instance{InstanceId: aws.String("instance-1")}, + Region: "region", + RootVolume: nil, + }, + SnapshotId: "snapshot-100", + Region: "region", + } + assert.True(t, yield(snap)) + } + }).Times(1) + mockSnapshotCreatorDeleterProvider.EXPECT().DeleteSnapshot(mock.Anything, matchBySnapshotID("snapshot-100")).Return(nil).Times(1) fetcherProvider := newMockInstancesProvider(t) fetcherProvider.EXPECT().DescribeInstances(mock.Anything).Return([]*ec2.Ec2Instance{ @@ -236,3 +254,9 @@ func TestVulnerabilityWorker_Run(t *testing.T) { assert.Equal(t, ex.VulnerabilityID, gotVulnerability.(Vulnerability).ID) } } + +func matchBySnapshotID(snapshotID string) any { + return mock.MatchedBy(func(snapshot ec2.EBSSnapshot) bool { + return snapshot.SnapshotId == snapshotID + }) +}