@@ -64,10 +64,17 @@ func New() *StoragePlugin {
64
64
}
65
65
}
66
66
67
+ func (p * StoragePlugin ) deleteClient (storageBucketId string ) {
68
+ p .clients .Lock ()
69
+ defer p .clients .Unlock ()
70
+ delete (p .clients .cache , storageBucketId )
71
+ }
72
+
67
73
// GetClient returns an S3API client for the given storage bucket id.
68
- func (p * StoragePlugin ) GetClient (ctx context.Context , storageBucketId string , storageState * awsStoragePersistedState , opts ... s3Option ) (S3API , error ) {
69
- if storageBucketId == "" {
70
- // No storage bucket ID to key cache on, create and return client
74
+ func (p * StoragePlugin ) getClient (ctx context.Context , storageBucketId string , storageState * awsStoragePersistedState , opts ... s3Option ) (S3API , error ) {
75
+ roleArn := storageState .CredentialsConfig .RoleARN
76
+ if storageBucketId == "" || roleArn == "" {
77
+ // No storage bucket ID to key cache on or not a roleArn cred type
71
78
client , err := storageState .S3Client (ctx , opts ... )
72
79
if err != nil {
73
80
return nil , errors .BadRequestStatus ("error creating S3 client: %s" , err )
@@ -78,11 +85,13 @@ func (p *StoragePlugin) GetClient(ctx context.Context, storageBucketId string, s
78
85
p .clients .RLock ()
79
86
client , ok := p .clients .cache [storageBucketId ]
80
87
p .clients .RUnlock ()
81
- if ! ok {
88
+
89
+ if ! ok || client .Credentials ().Expired () || client .RoleArn () != roleArn {
82
90
p .clients .Lock ()
83
- // Check again in case another caller created it since we got lock
91
+
92
+ // Check again in case another caller updated it since we got the lock
84
93
client , ok = p .clients .cache [storageBucketId ]
85
- if ok {
94
+ if ok && ! client . Credentials (). Expired () && client . RoleArn () == roleArn {
86
95
p .clients .Unlock ()
87
96
return client , nil
88
97
}
@@ -98,6 +107,39 @@ func (p *StoragePlugin) GetClient(ctx context.Context, storageBucketId string, s
98
107
return client , nil
99
108
}
100
109
110
+ type s3Caller func (client S3API ) (any , error )
111
+
112
+ func (p * StoragePlugin ) call (
113
+ ctx context.Context ,
114
+ fn s3Caller ,
115
+ storageBucketId string ,
116
+ storageState * awsStoragePersistedState ,
117
+ opts ... s3Option ) (any , error ) {
118
+
119
+ var attempt int
120
+ for {
121
+ s3Client , err := p .getClient (ctx , storageBucketId , storageState , opts ... )
122
+ if err != nil {
123
+ return nil , errors .BadRequestStatus ("error getting S3 client: %s" , err )
124
+ }
125
+ resp , err := fn (s3Client )
126
+ if err != nil {
127
+ st , _ := errors .ParseAWSError ("" , err )
128
+ if st .Code () == codes .PermissionDenied {
129
+ // We got a permission error, if this is the first time re-create client otherwise return error
130
+ if attempt != 0 || storageState .CredentialsConfig .RoleARN == "" {
131
+ return nil , err
132
+ }
133
+ attempt ++
134
+ p .deleteClient (storageBucketId )
135
+ continue
136
+ }
137
+ return nil , err
138
+ }
139
+ return resp , nil
140
+ }
141
+ }
142
+
101
143
// OnCreateStorageBucket is called when a storage bucket is created.
102
144
func (p * StoragePlugin ) OnCreateStorageBucket (ctx context.Context , req * pb.OnCreateStorageBucketRequest ) (* pb.OnCreateStorageBucketResponse , error ) {
103
145
bucket := req .GetBucket ()
@@ -400,19 +442,20 @@ func (p *StoragePlugin) HeadObject(ctx context.Context, req *pb.HeadObjectReques
400
442
if storageAttributes .EndpointUrl != "" {
401
443
opts = append (opts , WithEndpoint (storageAttributes .EndpointUrl ))
402
444
}
403
- s3Client , err := storageState .S3Client (ctx , opts ... )
404
- if err != nil {
405
- return nil , errors .BadRequestStatus ("error getting S3 client: %s" , err )
406
- }
407
445
408
446
objectKey := path .Join (bucket .GetBucketPrefix (), req .GetKey ())
409
- resp , err := s3Client .HeadObject (ctx , & s3.HeadObjectInput {
410
- Bucket : aws .String (bucket .GetBucketName ()),
411
- Key : aws .String (objectKey ),
412
- })
447
+ headCall := func (s3Client S3API ) (any , error ) {
448
+ return s3Client .HeadObject (ctx , & s3.HeadObjectInput {
449
+ Bucket : aws .String (bucket .GetBucketName ()),
450
+ Key : aws .String (objectKey ),
451
+ })
452
+ }
453
+ headResp , err := p .call (ctx , headCall , bucket .GetId (), storageState , opts ... )
413
454
if err != nil {
414
455
return nil , parseS3Error ("head object" , err , req ).Err ()
415
456
}
457
+ resp := headResp .(* s3.HeadObjectOutput )
458
+
416
459
return & pb.HeadObjectResponse {
417
460
ContentLength : aws .ToInt64 (resp .ContentLength ),
418
461
LastModified : timestamppb .New (* resp .LastModified ),
@@ -526,19 +569,19 @@ func (p *StoragePlugin) GetObject(req *pb.GetObjectRequest, stream pb.StoragePlu
526
569
if storageAttributes .EndpointUrl != "" {
527
570
opts = append (opts , WithEndpoint (storageAttributes .EndpointUrl ))
528
571
}
529
- s3Client , err := storageState .S3Client (stream .Context (), opts ... )
530
- if err != nil {
531
- return errors .BadRequestStatus ("error getting S3 client: %s" , err )
532
- }
533
572
534
573
objectKey := path .Join (bucket .GetBucketPrefix (), req .GetKey ())
535
- resp , err := s3Client .GetObject (stream .Context (), & s3.GetObjectInput {
536
- Bucket : aws .String (bucket .GetBucketName ()),
537
- Key : aws .String (objectKey ),
538
- })
574
+ getCall := func (s3Client S3API ) (any , error ) {
575
+ return s3Client .GetObject (stream .Context (), & s3.GetObjectInput {
576
+ Bucket : aws .String (bucket .GetBucketName ()),
577
+ Key : aws .String (objectKey ),
578
+ })
579
+ }
580
+ getResp , err := p .call (stream .Context (), getCall , bucket .GetId (), storageState , opts ... )
539
581
if err != nil {
540
582
return parseS3Error ("get object" , err , req ).Err ()
541
583
}
584
+ resp := getResp .(* s3.GetObjectOutput )
542
585
543
586
defer resp .Body .Close ()
544
587
reader := bufio .NewReader (resp .Body )
@@ -638,11 +681,6 @@ func (p *StoragePlugin) PutObject(ctx context.Context, req *pb.PutObjectRequest)
638
681
opts = append (opts , WithEndpoint (storageAttributes .EndpointUrl ))
639
682
}
640
683
641
- s3Client , err := p .GetClient (ctx , bucket .GetId (), storageState , opts ... )
642
- if err != nil {
643
- return nil , errors .BadRequestStatus ("error getting S3 client: %s" , err )
644
- }
645
-
646
684
hash := sha256 .New ()
647
685
if _ , err := io .Copy (hash , file ); err != nil {
648
686
return nil , errors .UnknownStatus ("failed to calcualte hash" )
@@ -653,18 +691,23 @@ func (p *StoragePlugin) PutObject(ctx context.Context, req *pb.PutObjectRequest)
653
691
if err != nil {
654
692
return nil , errors .UnknownStatus ("failed to rewind file pointer" )
655
693
}
656
-
657
694
objectKey := path .Join (bucket .GetBucketPrefix (), req .GetKey ())
658
- resp , err := s3Client .PutObject (ctx , & s3.PutObjectInput {
659
- Bucket : aws .String (bucket .GetBucketName ()),
660
- Key : aws .String (objectKey ),
661
- Body : file ,
662
- ChecksumAlgorithm : types .ChecksumAlgorithmSha256 ,
663
- ChecksumSHA256 : aws .String (checksum ),
664
- })
695
+
696
+ putCall := func (s3Client S3API ) (any , error ) {
697
+ return s3Client .PutObject (ctx , & s3.PutObjectInput {
698
+ Bucket : aws .String (bucket .GetBucketName ()),
699
+ Key : aws .String (objectKey ),
700
+ Body : file ,
701
+ ChecksumAlgorithm : types .ChecksumAlgorithmSha256 ,
702
+ ChecksumSHA256 : aws .String (checksum ),
703
+ })
704
+ }
705
+ putResp , err := p .call (ctx , putCall , bucket .GetId (), storageState , opts ... )
665
706
if err != nil {
666
707
return nil , parseS3Error ("put object" , err , req ).Err ()
667
708
}
709
+ resp := putResp .(* s3.PutObjectOutput )
710
+
668
711
if resp .ChecksumSHA256 == nil {
669
712
return nil , errors .UnknownStatus ("missing checksum response from aws" )
670
713
}
@@ -733,7 +776,7 @@ func (p *StoragePlugin) DeleteObjects(ctx context.Context, req *pb.DeleteObjects
733
776
if storageAttributes .EndpointUrl != "" {
734
777
opts = append (opts , WithEndpoint (storageAttributes .EndpointUrl ))
735
778
}
736
- client , err := storageState . S3Client (ctx , opts ... )
779
+ client , err := p . getClient (ctx , bucket . GetId (), storageState , opts ... )
737
780
if err != nil {
738
781
return nil , errors .BadRequestStatus ("error getting S3 client: %s" , err )
739
782
}
0 commit comments