Skip to content

Commit 17dae9d

Browse files
committed
s
1 parent cc9b2e9 commit 17dae9d

File tree

4 files changed

+110
-56
lines changed

4 files changed

+110
-56
lines changed

go.mod

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/hashicorp/boundary-plugin-aws
22

3-
go 1.21
3+
go 1.23.3
44

55
require (
66
github.com/aws/aws-sdk-go-v2 v1.32.5
@@ -13,7 +13,7 @@ require (
1313
github.com/google/uuid v1.4.0
1414
github.com/hashicorp/boundary/sdk v0.0.43-0.20240717182311-a20aae98794a
1515
github.com/hashicorp/go-multierror v1.1.1
16-
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.0
16+
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.1-0.20241123041515-c99d3d7bba9a
1717
github.com/hashicorp/terraform-json v0.14.0
1818
github.com/mitchellh/mapstructure v1.5.0
1919
github.com/stretchr/testify v1.8.4
@@ -43,7 +43,7 @@ require (
4343
github.com/hashicorp/eventlogger v0.2.6-0.20231025104552-802587e608f0 // indirect
4444
github.com/hashicorp/eventlogger/filters/encrypt v0.1.8-0.20231025104552-802587e608f0 // indirect
4545
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
46-
github.com/hashicorp/go-hclog v1.5.0 // indirect
46+
github.com/hashicorp/go-hclog v1.6.3 // indirect
4747
github.com/hashicorp/go-kms-wrapping/v2 v2.0.14 // indirect
4848
github.com/hashicorp/go-uuid v1.0.3 // indirect
4949
github.com/hashicorp/go-version v1.5.0 // indirect

go.sum

+4-4
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ github.com/hashicorp/eventlogger/filters/encrypt v0.1.8-0.20231025104552-802587e
8484
github.com/hashicorp/eventlogger/filters/encrypt v0.1.8-0.20231025104552-802587e608f0/go.mod h1:tMywUTIvdB/FXhwm6HMTt61C8/eODY6gitCHhXtyojg=
8585
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
8686
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
87-
github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c=
88-
github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
87+
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
88+
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
8989
github.com/hashicorp/go-kms-wrapping/plugin/v2 v2.0.5 h1:jrnDfQm2hCQ0/hEselgqzV4fK16gpZoY0OWGZpVPNHM=
9090
github.com/hashicorp/go-kms-wrapping/plugin/v2 v2.0.5/go.mod h1:psh1qKep5ukvuNobFY/hCybuudlkkACpmazOsCgX5Rg=
9191
github.com/hashicorp/go-kms-wrapping/v2 v2.0.14 h1:1ZuhfnZgRnLK8S0KovJkoTCRIQId5pv3sDR7pG5VQBw=
@@ -94,8 +94,8 @@ github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+l
9494
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
9595
github.com/hashicorp/go-plugin v1.5.2 h1:aWv8eimFqWlsEiMrYZdPYl+FdHaBJSN4AWwGWfT1G2Y=
9696
github.com/hashicorp/go-plugin v1.5.2/go.mod h1:w1sAEES3g3PuV/RzUrgow20W2uErMly84hhD3um1WL4=
97-
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.0 h1:ca5TSI4AgaOncPpyzLDtCGjVEtKukONpeM95vFxXCOQ=
98-
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.0/go.mod h1:7CUvZtfTp2U0CYQCLzMtS2ngckjAZePSfwrE2aeDP1M=
97+
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.1-0.20241123041515-c99d3d7bba9a h1:/CCTVDc9Q8GGIB/cCImf+x3XtL7qySfogGAPY8XVseA=
98+
github.com/hashicorp/go-secure-stdlib/awsutil/v2 v2.0.1-0.20241123041515-c99d3d7bba9a/go.mod h1:OeRwM2eWNW62L1Z+8GvoZM5nQJMRWBewHSoo77qmb4Y=
9999
github.com/hashicorp/go-secure-stdlib/base62 v0.1.1/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=
100100
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 h1:ET4pqyjiGmY09R5y+rSd70J2w45CtbWDNvGqWp/R3Ng=
101101
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2/go.mod h1:EdWO6czbmthiwZ3/PUsDV+UD1D5IRU4ActiaWGwt0Yw=

plugin/service/storage/plugin.go

+79-36
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,17 @@ func New() *StoragePlugin {
6464
}
6565
}
6666

67+
func (p *StoragePlugin) deleteClient(storageBucketId string) {
68+
p.clients.Lock()
69+
defer p.clients.Unlock()
70+
delete(p.clients.cache, storageBucketId)
71+
}
72+
6773
// GetClient returns an S3API client for the given storage bucket id.
68-
func (p *StoragePlugin) GetClient(ctx context.Context, storageBucketId string, storageState *awsStoragePersistedState, opts ...s3Option) (S3API, error) {
69-
if storageBucketId == "" {
70-
// No storage bucket ID to key cache on, create and return client
74+
func (p *StoragePlugin) getClient(ctx context.Context, storageBucketId string, storageState *awsStoragePersistedState, opts ...s3Option) (S3API, error) {
75+
roleArn := storageState.CredentialsConfig.RoleARN
76+
if storageBucketId == "" || roleArn == "" {
77+
// No storage bucket ID to key cache on or not a roleArn cred type
7178
client, err := storageState.S3Client(ctx, opts...)
7279
if err != nil {
7380
return nil, errors.BadRequestStatus("error creating S3 client: %s", err)
@@ -78,11 +85,13 @@ func (p *StoragePlugin) GetClient(ctx context.Context, storageBucketId string, s
7885
p.clients.RLock()
7986
client, ok := p.clients.cache[storageBucketId]
8087
p.clients.RUnlock()
81-
if !ok {
88+
89+
if !ok || client.Credentials().Expired() || client.RoleArn() != roleArn {
8290
p.clients.Lock()
83-
// Check again in case another caller created it since we got lock
91+
92+
// Check again in case another caller updated it since we got the lock
8493
client, ok = p.clients.cache[storageBucketId]
85-
if ok {
94+
if ok && !client.Credentials().Expired() && client.RoleArn() == roleArn {
8695
p.clients.Unlock()
8796
return client, nil
8897
}
@@ -98,6 +107,39 @@ func (p *StoragePlugin) GetClient(ctx context.Context, storageBucketId string, s
98107
return client, nil
99108
}
100109

110+
type s3Caller func(client S3API) (any, error)
111+
112+
func (p *StoragePlugin) call(
113+
ctx context.Context,
114+
fn s3Caller,
115+
storageBucketId string,
116+
storageState *awsStoragePersistedState,
117+
opts ...s3Option) (any, error) {
118+
119+
var attempt int
120+
for {
121+
s3Client, err := p.getClient(ctx, storageBucketId, storageState, opts...)
122+
if err != nil {
123+
return nil, errors.BadRequestStatus("error getting S3 client: %s", err)
124+
}
125+
resp, err := fn(s3Client)
126+
if err != nil {
127+
st, _ := errors.ParseAWSError("", err)
128+
if st.Code() == codes.PermissionDenied {
129+
// We got a permission error, if this is the first time re-create client otherwise return error
130+
if attempt != 0 || storageState.CredentialsConfig.RoleARN == "" {
131+
return nil, err
132+
}
133+
attempt++
134+
p.deleteClient(storageBucketId)
135+
continue
136+
}
137+
return nil, err
138+
}
139+
return resp, nil
140+
}
141+
}
142+
101143
// OnCreateStorageBucket is called when a storage bucket is created.
102144
func (p *StoragePlugin) OnCreateStorageBucket(ctx context.Context, req *pb.OnCreateStorageBucketRequest) (*pb.OnCreateStorageBucketResponse, error) {
103145
bucket := req.GetBucket()
@@ -400,19 +442,20 @@ func (p *StoragePlugin) HeadObject(ctx context.Context, req *pb.HeadObjectReques
400442
if storageAttributes.EndpointUrl != "" {
401443
opts = append(opts, WithEndpoint(storageAttributes.EndpointUrl))
402444
}
403-
s3Client, err := storageState.S3Client(ctx, opts...)
404-
if err != nil {
405-
return nil, errors.BadRequestStatus("error getting S3 client: %s", err)
406-
}
407445

408446
objectKey := path.Join(bucket.GetBucketPrefix(), req.GetKey())
409-
resp, err := s3Client.HeadObject(ctx, &s3.HeadObjectInput{
410-
Bucket: aws.String(bucket.GetBucketName()),
411-
Key: aws.String(objectKey),
412-
})
447+
headCall := func(s3Client S3API) (any, error) {
448+
return s3Client.HeadObject(ctx, &s3.HeadObjectInput{
449+
Bucket: aws.String(bucket.GetBucketName()),
450+
Key: aws.String(objectKey),
451+
})
452+
}
453+
headResp, err := p.call(ctx, headCall, bucket.GetId(), storageState, opts...)
413454
if err != nil {
414455
return nil, parseS3Error("head object", err, req).Err()
415456
}
457+
resp := headResp.(*s3.HeadObjectOutput)
458+
416459
return &pb.HeadObjectResponse{
417460
ContentLength: aws.ToInt64(resp.ContentLength),
418461
LastModified: timestamppb.New(*resp.LastModified),
@@ -526,19 +569,19 @@ func (p *StoragePlugin) GetObject(req *pb.GetObjectRequest, stream pb.StoragePlu
526569
if storageAttributes.EndpointUrl != "" {
527570
opts = append(opts, WithEndpoint(storageAttributes.EndpointUrl))
528571
}
529-
s3Client, err := storageState.S3Client(stream.Context(), opts...)
530-
if err != nil {
531-
return errors.BadRequestStatus("error getting S3 client: %s", err)
532-
}
533572

534573
objectKey := path.Join(bucket.GetBucketPrefix(), req.GetKey())
535-
resp, err := s3Client.GetObject(stream.Context(), &s3.GetObjectInput{
536-
Bucket: aws.String(bucket.GetBucketName()),
537-
Key: aws.String(objectKey),
538-
})
574+
getCall := func(s3Client S3API) (any, error) {
575+
return s3Client.GetObject(stream.Context(), &s3.GetObjectInput{
576+
Bucket: aws.String(bucket.GetBucketName()),
577+
Key: aws.String(objectKey),
578+
})
579+
}
580+
getResp, err := p.call(stream.Context(), getCall, bucket.GetId(), storageState, opts...)
539581
if err != nil {
540582
return parseS3Error("get object", err, req).Err()
541583
}
584+
resp := getResp.(*s3.GetObjectOutput)
542585

543586
defer resp.Body.Close()
544587
reader := bufio.NewReader(resp.Body)
@@ -638,11 +681,6 @@ func (p *StoragePlugin) PutObject(ctx context.Context, req *pb.PutObjectRequest)
638681
opts = append(opts, WithEndpoint(storageAttributes.EndpointUrl))
639682
}
640683

641-
s3Client, err := p.GetClient(ctx, bucket.GetId(), storageState, opts...)
642-
if err != nil {
643-
return nil, errors.BadRequestStatus("error getting S3 client: %s", err)
644-
}
645-
646684
hash := sha256.New()
647685
if _, err := io.Copy(hash, file); err != nil {
648686
return nil, errors.UnknownStatus("failed to calcualte hash")
@@ -653,18 +691,23 @@ func (p *StoragePlugin) PutObject(ctx context.Context, req *pb.PutObjectRequest)
653691
if err != nil {
654692
return nil, errors.UnknownStatus("failed to rewind file pointer")
655693
}
656-
657694
objectKey := path.Join(bucket.GetBucketPrefix(), req.GetKey())
658-
resp, err := s3Client.PutObject(ctx, &s3.PutObjectInput{
659-
Bucket: aws.String(bucket.GetBucketName()),
660-
Key: aws.String(objectKey),
661-
Body: file,
662-
ChecksumAlgorithm: types.ChecksumAlgorithmSha256,
663-
ChecksumSHA256: aws.String(checksum),
664-
})
695+
696+
putCall := func(s3Client S3API) (any, error) {
697+
return s3Client.PutObject(ctx, &s3.PutObjectInput{
698+
Bucket: aws.String(bucket.GetBucketName()),
699+
Key: aws.String(objectKey),
700+
Body: file,
701+
ChecksumAlgorithm: types.ChecksumAlgorithmSha256,
702+
ChecksumSHA256: aws.String(checksum),
703+
})
704+
}
705+
putResp, err := p.call(ctx, putCall, bucket.GetId(), storageState, opts...)
665706
if err != nil {
666707
return nil, parseS3Error("put object", err, req).Err()
667708
}
709+
resp := putResp.(*s3.PutObjectOutput)
710+
668711
if resp.ChecksumSHA256 == nil {
669712
return nil, errors.UnknownStatus("missing checksum response from aws")
670713
}
@@ -733,7 +776,7 @@ func (p *StoragePlugin) DeleteObjects(ctx context.Context, req *pb.DeleteObjects
733776
if storageAttributes.EndpointUrl != "" {
734777
opts = append(opts, WithEndpoint(storageAttributes.EndpointUrl))
735778
}
736-
client, err := storageState.S3Client(ctx, opts...)
779+
client, err := p.getClient(ctx, bucket.GetId(), storageState, opts...)
737780
if err != nil {
738781
return nil, errors.BadRequestStatus("error getting S3 client: %s", err)
739782
}

plugin/service/storage/state.go

+24-13
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@ import (
77
"context"
88
"errors"
99
"fmt"
10-
"net"
11-
"net/http"
1210
"net/url"
13-
"time"
1411

1512
"github.com/aws/aws-sdk-go-v2/aws"
1613
"github.com/aws/aws-sdk-go-v2/service/s3"
@@ -27,17 +24,22 @@ type S3API interface {
2724
PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error)
2825
DeleteObjects(ctx context.Context, params *s3.DeleteObjectsInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectsOutput, error)
2926
ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error)
27+
Credentials() aws.Credentials
28+
RoleArn() string
3029
}
3130

32-
var customClient = &http.Client{
33-
Transport: &http.Transport{
34-
Proxy: http.ProxyFromEnvironment,
35-
DialContext: (&net.Dialer{
36-
Timeout: 30 * time.Second,
37-
KeepAlive: 30 * time.Second,
38-
DualStack: true,
39-
}).DialContext,
40-
},
31+
type s3Client struct {
32+
*s3.Client
33+
creds aws.Credentials
34+
arn string
35+
}
36+
37+
func (c *s3Client) Credentials() aws.Credentials {
38+
return c.creds
39+
}
40+
41+
func (c *s3Client) RoleArn() string {
42+
return c.arn
4143
}
4244

4345
type awsStoragePersistedState struct {
@@ -114,7 +116,16 @@ func (s *awsStoragePersistedState) S3Client(ctx context.Context, opt ...s3Option
114116
if s.testS3APIFunc != nil {
115117
return s.testS3APIFunc(*awsCfg)
116118
}
117-
return s3.NewFromConfig(*awsCfg, s3Opts...), nil
119+
120+
// Retrieve the credentials provider from the client
121+
credsProvider := awsCfg.Credentials
122+
creds, err := credsProvider.Retrieve(ctx)
123+
if err != nil {
124+
return nil, fmt.Errorf("failed to retrieve credentials: %v\n", err)
125+
}
126+
127+
c := s3.NewFromConfig(*awsCfg, s3Opts...)
128+
return &s3Client{Client: c, creds: creds, arn: s.CredentialsConfig.RoleARN}, nil
118129
}
119130

120131
// getOpts iterates the inbound s3Options and returns a struct

0 commit comments

Comments
 (0)