mirror of
https://github.com/argoproj/argo-cd.git
synced 2026-02-20 01:28:45 +01:00
Signed-off-by: Jagpreet Singh Tamber <jagpreetstamber@gmail.com>
This commit is contained in:
committed by
GitHub
parent
b39ca155dc
commit
c0c6abedc4
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/argoproj/argo-cd/v3/util/workloadidentity"
|
||||
"github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks"
|
||||
)
|
||||
|
||||
@@ -861,7 +862,7 @@ func Test_nativeGitClient_CommitAndPush(t *testing.T) {
|
||||
|
||||
func Test_newAuth_AzureWorkloadIdentity(t *testing.T) {
|
||||
tokenprovider := new(mocks.TokenProvider)
|
||||
tokenprovider.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
|
||||
tokenprovider.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
|
||||
|
||||
creds := AzureWorkloadIdentityCreds{store: NoopCredsStore{}, tokenProvider: tokenprovider}
|
||||
|
||||
|
||||
@@ -735,7 +735,7 @@ func (creds AzureWorkloadIdentityCreds) getAccessToken(scope string) (string, er
|
||||
|
||||
t, found := azureTokenCache.Get(key)
|
||||
if found {
|
||||
return t.(string), nil
|
||||
return t.(*workloadidentity.Token).AccessToken, nil
|
||||
}
|
||||
|
||||
token, err := creds.tokenProvider.GetToken(scope)
|
||||
@@ -743,8 +743,11 @@ func (creds AzureWorkloadIdentityCreds) getAccessToken(scope string) (string, er
|
||||
return "", fmt.Errorf("failed to get Azure access token: %w", err)
|
||||
}
|
||||
|
||||
azureTokenCache.Set(key, token, 2*time.Hour)
|
||||
return token, nil
|
||||
cacheExpiry := workloadidentity.CalculateCacheExpiryBasedOnTokenExpiry(token.ExpiresOn)
|
||||
if cacheExpiry > 0 {
|
||||
azureTokenCache.Set(key, token, cacheExpiry)
|
||||
}
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
func (creds AzureWorkloadIdentityCreds) GetAzureDevOpsAccessToken() (string, error) {
|
||||
|
||||
@@ -8,8 +8,10 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -412,9 +414,10 @@ func TestGoogleCloudCreds_Environ_cleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAzureWorkloadIdentityCreds_Environ(t *testing.T) {
|
||||
resetAzureTokenCache()
|
||||
store := &memoryCredsStore{creds: make(map[string]cred)}
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil)
|
||||
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
|
||||
_, _, err := creds.Environ()
|
||||
require.NoError(t, err)
|
||||
@@ -427,9 +430,10 @@ func TestAzureWorkloadIdentityCreds_Environ(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAzureWorkloadIdentityCreds_Environ_cleanup(t *testing.T) {
|
||||
resetAzureTokenCache()
|
||||
store := &memoryCredsStore{creds: make(map[string]cred)}
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil)
|
||||
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
|
||||
closer, _, err := creds.Environ()
|
||||
require.NoError(t, err)
|
||||
@@ -439,9 +443,10 @@ func TestAzureWorkloadIdentityCreds_Environ_cleanup(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAzureWorkloadIdentityCreds_GetUserInfo(t *testing.T) {
|
||||
resetAzureTokenCache()
|
||||
store := &memoryCredsStore{creds: make(map[string]cred)}
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return("accessToken", nil)
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(&workloadidentity.Token{AccessToken: "accessToken", ExpiresOn: time.Now().Add(time.Minute)}, nil)
|
||||
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
|
||||
|
||||
user, email, err := creds.GetUserInfo(t.Context())
|
||||
@@ -456,3 +461,45 @@ func TestGetHelmCredsShouldReturnHelmCredsIfAzureWorkloadIdentityNotSpecified(t
|
||||
_, ok := creds.(AzureWorkloadIdentityCreds)
|
||||
require.Truef(t, ok, "expected HelmCreds but got %T", creds)
|
||||
}
|
||||
|
||||
func TestAzureWorkloadIdentityCreds_FetchNewTokenIfExistingIsExpired(t *testing.T) {
|
||||
resetAzureTokenCache()
|
||||
store := &memoryCredsStore{creds: make(map[string]cred)}
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).
|
||||
Return(&workloadidentity.Token{AccessToken: "firstToken", ExpiresOn: time.Now().Add(time.Minute)}, nil).Once()
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).
|
||||
Return(&workloadidentity.Token{AccessToken: "secondToken"}, nil).Once()
|
||||
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
|
||||
token, err := creds.GetAzureDevOpsAccessToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "firstToken", token)
|
||||
time.Sleep(5 * time.Second)
|
||||
token, err = creds.GetAzureDevOpsAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "secondToken", token)
|
||||
}
|
||||
|
||||
func TestAzureWorkloadIdentityCreds_ReuseTokenIfExistingIsNotExpired(t *testing.T) {
|
||||
resetAzureTokenCache()
|
||||
store := &memoryCredsStore{creds: make(map[string]cred)}
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
firstToken := &workloadidentity.Token{AccessToken: "firstToken", ExpiresOn: time.Now().Add(6 * time.Minute)}
|
||||
secondToken := &workloadidentity.Token{AccessToken: "secondToken"}
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(firstToken, nil).Once()
|
||||
workloadIdentityMock.On("GetToken", azureDevopsEntraResourceId).Return(secondToken, nil).Once()
|
||||
creds := AzureWorkloadIdentityCreds{store, workloadIdentityMock}
|
||||
token, err := creds.GetAzureDevOpsAccessToken()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "firstToken", token)
|
||||
time.Sleep(5 * time.Second)
|
||||
token, err = creds.GetAzureDevOpsAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "firstToken", token)
|
||||
}
|
||||
|
||||
func resetAzureTokenCache() {
|
||||
azureTokenCache = gocache.New(gocache.NoExpiration, 0)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
utilio "github.com/argoproj/argo-cd/v3/util/io"
|
||||
"github.com/argoproj/argo-cd/v3/util/workloadidentity"
|
||||
"github.com/argoproj/argo-cd/v3/util/workloadidentity/mocks"
|
||||
)
|
||||
|
||||
@@ -308,7 +309,7 @@ func TestGetTagsFromURLPrivateRepoWithAzureWorkloadIdentityAuthentication(t *tes
|
||||
}
|
||||
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
|
||||
|
||||
mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Logf("called %s", r.URL.Path)
|
||||
|
||||
@@ -11,7 +11,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
argoutils "github.com/argoproj/argo-cd/v3/util"
|
||||
"github.com/argoproj/argo-cd/v3/util/env"
|
||||
@@ -146,11 +148,33 @@ func (creds AzureWorkloadIdentityCreds) GetAccessToken() (string, error) {
|
||||
return "", fmt.Errorf("failed to get Azure access token after challenge: %w", err)
|
||||
}
|
||||
|
||||
// Access token has a lifetime of 3 hours
|
||||
storeAzureToken(key, token, 2*time.Hour)
|
||||
tokenExpiry, err := getJWTExpiry(token)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get token expiry from JWT: %v, using current time as fallback", err)
|
||||
tokenExpiry = time.Now()
|
||||
}
|
||||
|
||||
cacheExpiry := workloadidentity.CalculateCacheExpiryBasedOnTokenExpiry(tokenExpiry)
|
||||
if cacheExpiry > 0 {
|
||||
storeAzureToken(key, token, cacheExpiry)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func getJWTExpiry(token string) (time.Time, error) {
|
||||
parser := jwt.NewParser()
|
||||
claims := jwt.MapClaims{}
|
||||
_, _, err := parser.ParseUnverified(token, claims)
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("failed to parse JWT: %w", err)
|
||||
}
|
||||
exp, err := claims.GetExpirationTime()
|
||||
if err != nil {
|
||||
return time.Time{}, fmt.Errorf("'exp' claim not found or invalid in token: %w", err)
|
||||
}
|
||||
return time.UnixMilli(exp.UnixMilli()), nil
|
||||
}
|
||||
|
||||
func (creds AzureWorkloadIdentityCreds) getAccessTokenAfterChallenge(tokenParams map[string]string) (string, error) {
|
||||
realm := tokenParams["realm"]
|
||||
service := tokenParams["service"]
|
||||
@@ -177,7 +201,7 @@ func (creds AzureWorkloadIdentityCreds) getAccessTokenAfterChallenge(tokenParams
|
||||
formValues := url.Values{}
|
||||
formValues.Add("grant_type", "access_token")
|
||||
formValues.Add("service", service)
|
||||
formValues.Add("access_token", armAccessToken)
|
||||
formValues.Add("access_token", armAccessToken.AccessToken)
|
||||
|
||||
resp, err := client.PostForm(refreshTokenURL, formValues)
|
||||
if err != nil {
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -79,7 +81,7 @@ func TestGetPasswordShouldGenerateTokenIfNotPresentInCache(t *testing.T) {
|
||||
defer mockServer.Close()
|
||||
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
|
||||
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
|
||||
|
||||
// Retrieve the token from the cache
|
||||
@@ -191,7 +193,7 @@ func TestGetAccessTokenAfterChallenge_Success(t *testing.T) {
|
||||
defer mockServer.Close()
|
||||
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
|
||||
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
|
||||
|
||||
tokenParams := map[string]string{
|
||||
@@ -216,7 +218,7 @@ func TestGetAccessTokenAfterChallenge_Failure(t *testing.T) {
|
||||
|
||||
// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
|
||||
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
|
||||
|
||||
tokenParams := map[string]string{
|
||||
@@ -241,7 +243,7 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
|
||||
|
||||
// Create an instance of AzureWorkloadIdentityCreds with the mock credential wrapper
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return("accessToken", nil)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
|
||||
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
|
||||
|
||||
tokenParams := map[string]string{
|
||||
@@ -253,3 +255,125 @@ func TestGetAccessTokenAfterChallenge_MalformedResponse(t *testing.T) {
|
||||
require.ErrorContains(t, err, "failed to unmarshal response body")
|
||||
assert.Empty(t, refreshToken)
|
||||
}
|
||||
|
||||
// Helper to generate a mock JWT token with a given expiry time
|
||||
func generateMockJWT(expiry time.Time) (string, error) {
|
||||
claims := jwt.MapClaims{
|
||||
"exp": expiry.Unix(),
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
// Use a dummy secret for signing
|
||||
return token.SignedString([]byte("dummy-secret"))
|
||||
}
|
||||
|
||||
func TestGetAccessToken_FetchNewTokenIfExistingIsExpired(t *testing.T) {
|
||||
resetAzureTokenCache()
|
||||
accessToken1, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
|
||||
accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
|
||||
|
||||
mockServerURL := ""
|
||||
mockedServerURL := func() string {
|
||||
return mockServerURL
|
||||
}
|
||||
|
||||
callCount := 0
|
||||
mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v2/":
|
||||
assert.Equal(t, "/v2/", r.URL.Path)
|
||||
w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm="%s",service="%s"`, mockedServerURL(), mockedServerURL()[8:]))
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
case "/oauth2/exchange":
|
||||
assert.Equal(t, "/oauth2/exchange", r.URL.Path)
|
||||
var response string
|
||||
switch callCount {
|
||||
case 0:
|
||||
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken1)
|
||||
case 1:
|
||||
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken2)
|
||||
default:
|
||||
response = `{"refresh_token": "defaultToken"}`
|
||||
}
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write([]byte(response))
|
||||
require.NoError(t, err)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
mockServerURL = mockServer.URL
|
||||
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
|
||||
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
|
||||
|
||||
refreshToken, err := creds.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accessToken1, refreshToken)
|
||||
|
||||
time.Sleep(5 * time.Second) // Wait for the token to expire
|
||||
|
||||
refreshToken, err = creds.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accessToken2, refreshToken)
|
||||
}
|
||||
|
||||
func TestGetAccessToken_ReuseTokenIfExistingIsNotExpired(t *testing.T) {
|
||||
resetAzureTokenCache()
|
||||
accessToken1, _ := generateMockJWT(time.Now().Add(6 * time.Minute))
|
||||
accessToken2, _ := generateMockJWT(time.Now().Add(1 * time.Minute))
|
||||
|
||||
mockServerURL := ""
|
||||
mockedServerURL := func() string {
|
||||
return mockServerURL
|
||||
}
|
||||
|
||||
callCount := 0
|
||||
mockServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v2/":
|
||||
assert.Equal(t, "/v2/", r.URL.Path)
|
||||
w.Header().Set("Www-Authenticate", fmt.Sprintf(`Bearer realm="%s",service="%s"`, mockedServerURL(), mockedServerURL()[8:]))
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
case "/oauth2/exchange":
|
||||
assert.Equal(t, "/oauth2/exchange", r.URL.Path)
|
||||
var response string
|
||||
switch callCount {
|
||||
case 0:
|
||||
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken1)
|
||||
case 1:
|
||||
response = fmt.Sprintf(`{"refresh_token": "%s"}`, accessToken2)
|
||||
default:
|
||||
response = `{"refresh_token": "defaultToken"}`
|
||||
}
|
||||
callCount++
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, err := w.Write([]byte(response))
|
||||
require.NoError(t, err)
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer mockServer.Close()
|
||||
mockServerURL = mockServer.URL
|
||||
|
||||
workloadIdentityMock := new(mocks.TokenProvider)
|
||||
workloadIdentityMock.On("GetToken", "https://management.core.windows.net/.default").Return(&workloadidentity.Token{AccessToken: "accessToken"}, nil)
|
||||
creds := NewAzureWorkloadIdentityCreds(mockServer.URL[8:], "", nil, nil, true, workloadIdentityMock)
|
||||
|
||||
refreshToken, err := creds.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accessToken1, refreshToken)
|
||||
|
||||
time.Sleep(5 * time.Second) // Wait for the token to expire
|
||||
|
||||
refreshToken, err = creds.GetAccessToken()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, accessToken1, refreshToken)
|
||||
}
|
||||
|
||||
func resetAzureTokenCache() {
|
||||
azureTokenCache = gocache.New(gocache.NoExpiration, 0)
|
||||
}
|
||||
|
||||
19
util/workloadidentity/mocks/TokenProvider.go
generated
19
util/workloadidentity/mocks/TokenProvider.go
generated
@@ -5,6 +5,7 @@
|
||||
package mocks
|
||||
|
||||
import (
|
||||
"github.com/argoproj/argo-cd/v3/util/workloadidentity"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
@@ -36,22 +37,24 @@ func (_m *TokenProvider) EXPECT() *TokenProvider_Expecter {
|
||||
}
|
||||
|
||||
// GetToken provides a mock function for the type TokenProvider
|
||||
func (_mock *TokenProvider) GetToken(scope string) (string, error) {
|
||||
func (_mock *TokenProvider) GetToken(scope string) (*workloadidentity.Token, error) {
|
||||
ret := _mock.Called(scope)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetToken")
|
||||
}
|
||||
|
||||
var r0 string
|
||||
var r0 *workloadidentity.Token
|
||||
var r1 error
|
||||
if returnFunc, ok := ret.Get(0).(func(string) (string, error)); ok {
|
||||
if returnFunc, ok := ret.Get(0).(func(string) (*workloadidentity.Token, error)); ok {
|
||||
return returnFunc(scope)
|
||||
}
|
||||
if returnFunc, ok := ret.Get(0).(func(string) string); ok {
|
||||
if returnFunc, ok := ret.Get(0).(func(string) *workloadidentity.Token); ok {
|
||||
r0 = returnFunc(scope)
|
||||
} else {
|
||||
r0 = ret.Get(0).(string)
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*workloadidentity.Token)
|
||||
}
|
||||
}
|
||||
if returnFunc, ok := ret.Get(1).(func(string) error); ok {
|
||||
r1 = returnFunc(scope)
|
||||
@@ -79,12 +82,12 @@ func (_c *TokenProvider_GetToken_Call) Run(run func(scope string)) *TokenProvide
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *TokenProvider_GetToken_Call) Return(s string, err error) *TokenProvider_GetToken_Call {
|
||||
_c.Call.Return(s, err)
|
||||
func (_c *TokenProvider_GetToken_Call) Return(token *workloadidentity.Token, err error) *TokenProvider_GetToken_Call {
|
||||
_c.Call.Return(token, err)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *TokenProvider_GetToken_Call) RunAndReturn(run func(scope string) (string, error)) *TokenProvider_GetToken_Call {
|
||||
func (_c *TokenProvider_GetToken_Call) RunAndReturn(run func(scope string) (*workloadidentity.Token, error)) *TokenProvider_GetToken_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package workloadidentity
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
@@ -12,8 +13,13 @@ const (
|
||||
EmptyGuid = "00000000-0000-0000-0000-000000000000" //nolint:revive //FIXME(var-naming)
|
||||
)
|
||||
|
||||
type Token struct {
|
||||
AccessToken string
|
||||
ExpiresOn time.Time
|
||||
}
|
||||
|
||||
type TokenProvider interface {
|
||||
GetToken(scope string) (string, error)
|
||||
GetToken(scope string) (*Token, error)
|
||||
}
|
||||
|
||||
type WorkloadIdentityTokenProvider struct {
|
||||
@@ -29,17 +35,23 @@ func NewWorkloadIdentityTokenProvider() TokenProvider {
|
||||
return WorkloadIdentityTokenProvider{tokenCredential: cred}
|
||||
}
|
||||
|
||||
func (c WorkloadIdentityTokenProvider) GetToken(scope string) (string, error) {
|
||||
func (c WorkloadIdentityTokenProvider) GetToken(scope string) (*Token, error) {
|
||||
if initError != nil {
|
||||
return "", initError
|
||||
return nil, initError
|
||||
}
|
||||
|
||||
token, err := c.tokenCredential.GetToken(context.Background(), policy.TokenRequestOptions{
|
||||
Scopes: []string{scope},
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token.Token, nil
|
||||
return &Token{AccessToken: token.Token, ExpiresOn: token.ExpiresOn}, nil
|
||||
}
|
||||
|
||||
func CalculateCacheExpiryBasedOnTokenExpiry(tokenExpiry time.Time) time.Duration {
|
||||
// Calculate the cache expiry as 5 minutes before the token expires
|
||||
cacheExpiry := time.Until(tokenExpiry) - time.Minute*5
|
||||
return cacheExpiry
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
@@ -37,7 +38,7 @@ func TestGetToken_Success(t *testing.T) {
|
||||
|
||||
token, err := provider.GetToken(scope)
|
||||
require.NoError(t, err, "Expected no error from GetToken")
|
||||
assert.Equal(t, "mocked_token", token, "Expected token to match")
|
||||
assert.Equal(t, "mocked_token", token.AccessToken, "Expected token to match")
|
||||
}
|
||||
|
||||
func TestGetToken_Failure(t *testing.T) {
|
||||
@@ -47,7 +48,7 @@ func TestGetToken_Failure(t *testing.T) {
|
||||
|
||||
token, err := provider.GetToken(scope)
|
||||
require.Error(t, err, "Expected error from GetToken")
|
||||
assert.Empty(t, token, "Expected token to be empty on error")
|
||||
assert.Nil(t, token, "Expected token to be empty on error")
|
||||
}
|
||||
|
||||
func TestGetToken_InitError(t *testing.T) {
|
||||
@@ -56,5 +57,58 @@ func TestGetToken_InitError(t *testing.T) {
|
||||
|
||||
token, err := provider.GetToken("https://management.core.windows.net/.default")
|
||||
require.Error(t, err, "Expected error from GetToken due to initialization error")
|
||||
assert.Empty(t, token, "Expected token to be empty on initialization error")
|
||||
assert.Nil(t, token, "Expected token to be empty on initialization error")
|
||||
}
|
||||
|
||||
func TestCalculateCacheExpiryBasedOnTokenExpiry(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
expiry time.Time
|
||||
expected time.Duration
|
||||
delta float64
|
||||
}{
|
||||
{
|
||||
name: "Future expiry (10min ahead)",
|
||||
expiry: now.Add(10 * time.Minute),
|
||||
expected: 5 * time.Minute,
|
||||
delta: 10, // allow 10s difference
|
||||
},
|
||||
{
|
||||
name: "Expiring in 5 minutes",
|
||||
expiry: now.Add(5 * time.Second),
|
||||
expected: now.Sub(now.Add(5 * time.Minute)),
|
||||
delta: 10, // allow 10s difference
|
||||
},
|
||||
{
|
||||
name: "Expires soon (4min ahead)",
|
||||
expiry: now.Add(4 * time.Minute),
|
||||
expected: now.Sub(now.Add(1 * time.Minute)),
|
||||
delta: 10, // allow 10s difference
|
||||
},
|
||||
{
|
||||
name: "Just expired (1s ago)",
|
||||
expiry: now.Add(-1 * time.Second),
|
||||
expected: now.Sub(now.Add(5 * time.Minute)),
|
||||
delta: 10, // allow 10s difference
|
||||
},
|
||||
{
|
||||
name: "Already expired (1m ago)",
|
||||
expiry: now.Add(-1 * time.Minute),
|
||||
expected: now.Sub(now.Add(6 * time.Minute)),
|
||||
delta: 10, // allow 10s difference
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
actual := CalculateCacheExpiryBasedOnTokenExpiry(tt.expiry)
|
||||
if tt.delta > 0 {
|
||||
assert.InDelta(t, tt.expected.Seconds(), actual.Seconds(), tt.delta)
|
||||
} else {
|
||||
assert.Equal(t, tt.expected, actual)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user