Skip to content

Commit

Permalink
chore: unit tests for the helper.go file
Browse files Browse the repository at this point in the history
Signed-off-by: Shahram Kalantari <shahramk@gmail.com>
  • Loading branch information
shahramk64 committed Oct 21, 2024
1 parent 2a0c8d8 commit 947f28c
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 86 deletions.
10 changes: 5 additions & 5 deletions pkg/common/oras/authprovider/azure/azureidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ type ManagedIdentityTokenGetter interface {
type DefaultManagedIdentityTokenGetterImpl struct{}

func (g *DefaultManagedIdentityTokenGetterImpl) GetManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
return getManagedIdentityToken(ctx, clientID)
return getManagedIdentityToken(ctx, clientID, azidentity.NewManagedIdentityCredential)
}

func getManagedIdentityToken(ctx context.Context, clientID string) (azcore.AccessToken, error) {
func getManagedIdentityToken(ctx context.Context, clientID string, newCredentialFunc func(opts *azidentity.ManagedIdentityCredentialOptions) (*azidentity.ManagedIdentityCredential, error)) (azcore.AccessToken, error) {
id := azidentity.ClientID(clientID)
opts := azidentity.ManagedIdentityCredentialOptions{ID: id}
cred, err := azidentity.NewManagedIdentityCredential(&opts)
cred, err := newCredentialFunc(&opts)
if err != nil {
return azcore.AccessToken{}, err
}
Expand Down Expand Up @@ -110,7 +110,7 @@ func (s *azureManagedIdentityProviderFactory) Create(authProviderConfig provider
return nil, err
}
// retrieve an AAD Access token
token, err := getManagedIdentityToken(context.Background(), client)
token, err := getManagedIdentityToken(context.Background(), client, azidentity.NewManagedIdentityCredential)
if err != nil {
return nil, re.ErrorCodeAuthDenied.NewError(re.AuthProvider, "", re.AzureManagedIdentityLink, err, "", re.HideStackTrace)
}
Expand Down Expand Up @@ -177,7 +177,7 @@ func (d *MIAuthProvider) Provide(ctx context.Context, artifact string) (provider

response, err := client.ExchangeAADAccessTokenForACRRefreshToken(
ctx,
"access_token",
azcontainerregistry.PostContentSchemaGrantType("access_token"),
artifactHostName,
&azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
AccessToken: &d.identityToken.Token,
Expand Down
31 changes: 30 additions & 1 deletion pkg/common/oras/authprovider/azure/azureidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
azcontainerregistry "github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
ratifyerrors "github.com/ratify-project/ratify/errors"
"github.com/ratify-project/ratify/pkg/common/oras/authprovider"
Expand Down Expand Up @@ -155,7 +156,7 @@ func TestMIAuthProvider_Provide_TokenRefreshSuccess(t *testing.T) {
// Setup mock expectations
mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil)
mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil)
mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil)
mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil)
mockManagedIdentityTokenGetter.On("GetManagedIdentityToken", mock.Anything, "clientID").Return(newAADToken, nil)

// Initialize provider with expired token
Expand Down Expand Up @@ -241,3 +242,31 @@ func TestMIAuthProvider_Provide_InvalidHostName(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "HOST_NAME_INVALID")
}

// Unit tests
func TestGetManagedIdentityToken(t *testing.T) {
ctx := context.Background()
clientID := "test-client-id"
expectedToken := azcore.AccessToken{Token: "test-token", ExpiresOn: time.Now().Add(time.Hour)}

mockGetter := new(MockManagedIdentityTokenGetter)
mockGetter.On("GetManagedIdentityToken", ctx, clientID).Return(expectedToken, nil)

token, err := mockGetter.GetManagedIdentityToken(ctx, clientID)
assert.Nil(t, err)
assert.Equal(t, expectedToken, token)
}

func TestGetManagedIdentityToken_Error(t *testing.T) {
ctx := context.Background()
clientID := "test-client-id"

// Mock the newCredentialFunc to return an error
mockNewCredentialFunc := func(_ *azidentity.ManagedIdentityCredentialOptions) (*azidentity.ManagedIdentityCredential, error) {
return nil, assert.AnError
}

token, err := getManagedIdentityToken(ctx, clientID, mockNewCredentialFunc)
assert.NotNil(t, err)
assert.Equal(t, azcore.AccessToken{}, token)
}
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (d *WIAuthProvider) Provide(ctx context.Context, artifact string) (provider
startTime := time.Now()
response, err := client.ExchangeAADAccessTokenForACRRefreshToken(
ctx,
"access_token",
azcontainerregistry.PostContentSchemaGrantType("access_token"),
artifactHostName,
&azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{
AccessToken: &d.aadToken.AccessToken,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func TestWIAuthProvider_Provide_Success(t *testing.T) {
// Set expectations for mocked functions
mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil)
mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil)
mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil)
mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil)
mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(initialToken, nil)
mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return()

Expand Down Expand Up @@ -113,7 +113,7 @@ func TestWIAuthProvider_Provide_RefreshToken(t *testing.T) {
// Set expectations for mocked functions
mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil)
mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil)
mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil)
mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil)
mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil)
mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return()

Expand Down Expand Up @@ -225,7 +225,7 @@ func TestWIAuthProvider_Provide_TokenRefresh_Success(t *testing.T) {
// Set expectations
mockRegistryHostGetter.On("GetRegistryHost", "artifact_name").Return("example.azurecr.io", nil)
mockAuthClientFactory.On("CreateAuthClient", "https://example.azurecr.io", mock.Anything).Return(mockAuthClient, nil)
mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, "access_token", "example.azurecr.io", mock.Anything).Return(refreshToken, nil)
mockAuthClient.On("ExchangeAADAccessTokenForACRRefreshToken", mock.Anything, azcontainerregistry.PostContentSchemaGrantType("access_token"), "example.azurecr.io", mock.Anything).Return(refreshToken, nil)
mockAADAccessTokenGetter.On("GetAADAccessToken", mock.Anything, "tenantID", "clientID", mock.Anything).Return(newToken, nil)
mockMetricsReporter.On("ReportMetrics", mock.Anything, mock.Anything, "example.azurecr.io").Return()

Expand Down
13 changes: 9 additions & 4 deletions pkg/common/oras/authprovider/azure/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,21 @@ func DefaultAuthClientFactory(serverURL string, options *azcontainerregistry.Aut
return &AuthenticationClientWrapper{client: client}, nil
}

// Define the interface for azcontainerregistry.AuthenticationClient methods used
type AuthenticationClientInterface interface {
ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error)
}

type AuthenticationClientWrapper struct {
client *azcontainerregistry.AuthenticationClient
client AuthenticationClientInterface
}

func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, azcontainerregistry.PostContentSchemaGrantType(grantType), service, options)
func (w *AuthenticationClientWrapper) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
return w.client.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options)
}

type AuthClient interface {
ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error)
ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error)
}

// RegistryHostGetter defines an interface for getting the registry host.
Expand Down
108 changes: 36 additions & 72 deletions pkg/common/oras/authprovider/azure/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ package azure

import (
"context"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/containers/azcontainerregistry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)

Expand All @@ -28,7 +30,7 @@ type MockAuthClient struct {
}

// Mock method for ExchangeAADAccessTokenForACRRefreshToken
func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
func (m *MockAuthClient) ExchangeAADAccessTokenForACRRefreshToken(ctx context.Context, grantType azcontainerregistry.PostContentSchemaGrantType, service string, options *azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions) (azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse, error) {
args := m.Called(ctx, grantType, service, options)
return args.Get(0).(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse), args.Error(1)
}
Expand All @@ -55,82 +57,44 @@ func (m *MockRegistryHostGetter) GetRegistryHost(artifact string) (string, error
return args.String(0), args.Error(1)
}

// // TestDefaultAuthClientFactoryImpl tests the default factory implementation.
// func TestDefaultAuthClientFactoryImpl(t *testing.T) {
// mockFactory := new(MockAuthClientFactory)
// mockAuthClient := new(MockAuthClient)
func TestDefaultAuthClientFactoryImpl_CreateAuthClient(t *testing.T) {
factory := &DefaultAuthClientFactoryImpl{}
serverURL := "https://example.com"
options := &azcontainerregistry.AuthenticationClientOptions{}

// serverURL := "https://example.azurecr.io"
// options := &azcontainerregistry.AuthenticationClientOptions{}

// // Set up expectations
// mockFactory.On("CreateAuthClient", serverURL, options).Return(mockAuthClient, nil)

// factory := &DefaultAuthClientFactoryImpl{}
// client, err := factory.CreateAuthClient(serverURL, options)

// // Verify expectations
// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options)
// assert.NoError(t, err)
// assert.NotNil(t, client)
// }

// // TestDefaultAuthClientFactory_Error tests error handling during client creation.
// func TestDefaultAuthClientFactory_Error(t *testing.T) {
// mockFactory := new(MockAuthClientFactory)

// serverURL := "https://example.azurecr.io"
// options := &azcontainerregistry.AuthenticationClientOptions{}
// expectedError := errors.New("failed to create client")

// // Set up expectations
// mockFactory.On("CreateAuthClient", serverURL, options).Return(nil, expectedError)

// factory := &DefaultAuthClientFactoryImpl{}
// client, err := factory.CreateAuthClient(serverURL, options)

// // Verify expectations
// mockFactory.AssertCalled(t, "CreateAuthClient", serverURL, options)
// assert.Error(t, err)
// assert.Nil(t, client)
// assert.Equal(t, expectedError, err)
// }

// // TestGetRegistryHost tests the GetRegistryHost function.
// func TestGetRegistryHost(t *testing.T) {
// mockGetter := new(MockRegistryHostGetter)

// artifact := "test/artifact"
// expectedHost := "example.azurecr.io"

// // Set up expectations
// mockGetter.On("GetRegistryHost", artifact).Return(expectedHost, nil)
client, err := factory.CreateAuthClient(serverURL, options)
assert.Nil(t, err)
assert.NotNil(t, client)
}

// getter := &DefaultRegistryHostGetterImpl{}
// host, err := getter.GetRegistryHost(artifact)
func TestDefaultAuthClientFactory(t *testing.T) {
serverURL := "https://example.com"
options := &azcontainerregistry.AuthenticationClientOptions{}

// // Verify expectations
// mockGetter.AssertCalled(t, "GetRegistryHost", artifact)
// assert.NoError(t, err)
// assert.Equal(t, expectedHost, host)
// }
client, err := DefaultAuthClientFactory(serverURL, options)
assert.Nil(t, err)
assert.NotNil(t, client)
}

// // TestGetRegistryHost_Error tests error handling in GetRegistryHost.
// func TestGetRegistryHost_Error(t *testing.T) {
// mockGetter := new(MockRegistryHostGetter)
func TestDefaultRegistryHostGetterImpl_GetRegistryHost(t *testing.T) {
getter := &DefaultRegistryHostGetterImpl{}
artifact := "example.azurecr.io/myArtifact"

// artifact := "test/artifact"
// expectedError := errors.New("failed to get registry host")
host, err := getter.GetRegistryHost(artifact)
assert.Nil(t, err)
assert.Equal(t, "example.azurecr.io", host)
}

// // Set up expectations
// mockGetter.On("GetRegistryHost", artifact).Return("", expectedError)
func TestAuthenticationClientWrapper_ExchangeAADAccessTokenForACRRefreshToken(t *testing.T) {
mockClient := new(MockAuthClient)
wrapper := &AuthenticationClientWrapper{client: mockClient}
ctx := context.Background()
grantType := azcontainerregistry.PostContentSchemaGrantType("grantType")
service := "service"
options := &azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenOptions{}

// getter := &DefaultRegistryHostGetterImpl{}
// host, err := getter.GetRegistryHost(artifact)
mockClient.On("ExchangeAADAccessTokenForACRRefreshToken", ctx, grantType, service, options).Return(azcontainerregistry.AuthenticationClientExchangeAADAccessTokenForACRRefreshTokenResponse{}, nil)

// // Verify expectations
// mockGetter.AssertCalled(t, "GetRegistryHost", artifact)
// assert.Error(t, err)
// assert.Empty(t, host)
// assert.Equal(t, expectedError, err)
// }
_, err := wrapper.ExchangeAADAccessTokenForACRRefreshToken(ctx, grantType, service, options)
assert.Nil(t, err)
}

0 comments on commit 947f28c

Please sign in to comment.