feat: add OTEL instrumentation for authentication and handlers (#25296)

Signed-off-by: Mike Cutsail <mcutsail15@apple.com>
Signed-off-by: Alexandre Gaudreault <alexandre_gaudreault@intuit.com>
Co-authored-by: Alexandre Gaudreault <alexandre_gaudreault@intuit.com>
This commit is contained in:
Mike Cutsail
2026-02-10 05:31:55 -08:00
committed by GitHub
parent 2c5f7317a5
commit 2793097480
6 changed files with 156 additions and 27 deletions

View File

@@ -19,6 +19,12 @@ import (
"sync"
"time"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
gooidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
@@ -38,6 +44,13 @@ import (
var ErrInvalidRedirectURL = errors.New("invalid return URL")
// OpenTelemetry tracer for this package
var tracer trace.Tracer
func init() {
tracer = otel.Tracer("github.com/argoproj/argo-cd/v3/util/oidc")
}
const (
GrantTypeAuthorizationCode = "authorization_code"
GrantTypeImplicit = "implicit"
@@ -147,6 +160,9 @@ func GetOidcTokenCacheFromJSON(jsonBytes []byte) (*OidcTokenCache, error) {
// GetTokenSourceFromCache creates an oauth2 TokenSource from a cached oidc token. The TokenSource will be configured
// with an early expiration based on the refreshTokenThreshold.
func (a *ClientApp) GetTokenSourceFromCache(ctx context.Context, oidcTokenCache *OidcTokenCache) (oauth2.TokenSource, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.ClientApp.GetTokenSourceFromCache")
defer span.End()
if oidcTokenCache == nil {
return nil, errors.New("oidcTokenCache is required")
}
@@ -198,10 +214,18 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTL
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
_, span := tracer.Start(ctx, "oidc.ClientApp.client")
defer span.End()
span.SetAttributes(
attribute.String("network", network),
attribute.String("addr", addr),
)
return (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial(network, addr)
},
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
@@ -541,7 +565,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
}
// save the accessToken in memory for later use
sub := jwtutil.StringField(claims, "sub")
err = a.SetValueInEncryptedCache(FormatAccessTokenCacheKey(sub), []byte(token.AccessToken), GetTokenExpiration(claims))
err = a.SetValueInEncryptedCache(ctx, FormatAccessTokenCacheKey(sub), []byte(token.AccessToken), GetTokenExpiration(claims))
if err != nil {
claimsJSON, _ := json.Marshal(claims)
log.Errorf("cannot cache encrypted accessToken: %v (claims=%s)", err, claimsJSON)
@@ -557,7 +581,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
return
}
sid := jwtutil.StringField(claims, "sid")
err = a.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, GetTokenExpiration(claims))
err = a.SetValueInEncryptedCache(ctx, formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, GetTokenExpiration(claims))
if err != nil {
claimsJSON, _ := json.Marshal(claims)
log.Errorf("cannot cache encrypted oidc token: %v (claims=%s)", err, claimsJSON)
@@ -587,30 +611,48 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
// GetValueFromEncryptedCache is a convenience method for retreiving a value from cache and decrypting it. If the cache
// does not contain a value for the given key, a nil value is returned. Return handling should check for error and then
// check for nil.
func (a *ClientApp) GetValueFromEncryptedCache(key string) (value []byte, err error) {
func (a *ClientApp) GetValueFromEncryptedCache(ctx context.Context, key string) (value []byte, err error) {
_, span := tracer.Start(ctx, "oidc.ClientApp.GetValueFromEncryptedCache")
defer span.End()
var encryptedValue []byte
span.AddEvent("start cache read")
err = a.clientCache.Get(key, &encryptedValue)
span.AddEvent("end cache read")
if err != nil {
if errors.Is(err, cache.ErrCacheMiss) {
span.SetAttributes(attribute.Bool("cache_hit", false))
// Return nil to signify a cache miss
return nil, nil
}
return nil, fmt.Errorf("failed to get encrypted value from cache: %w", err)
err = fmt.Errorf("failed to get encrypted value from cache: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
span.SetAttributes(attribute.Bool("cache_hit", true))
value, err = crypto.Decrypt(encryptedValue, a.encryptionKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt value from cache: %w", err)
err = fmt.Errorf("failed to decrypt value from cache: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return value, err
}
// SetValueFromEncyrptedCache is a convenience method for encrypting a value and storing it in the cache at a given key.
// Cache expiration is set based on input.
func (a *ClientApp) SetValueInEncryptedCache(key string, value []byte, expiration time.Duration) error {
func (a *ClientApp) SetValueInEncryptedCache(ctx context.Context, key string, value []byte, expiration time.Duration) error {
_, span := tracer.Start(ctx, "oidc.ClientApp.SetValueInEncryptedCache")
defer span.End()
encryptedValue, err := crypto.Encrypt(value, a.encryptionKey)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return err
}
span.SetAttributes(
attribute.String("key", key),
attribute.Int("value_length", len(value)),
)
span.AddEvent("start cache write")
err = a.clientCache.Set(&cache.Item{
Key: key,
Object: encryptedValue,
@@ -618,25 +660,38 @@ func (a *ClientApp) SetValueInEncryptedCache(key string, value []byte, expiratio
Expiration: expiration,
},
})
span.AddEvent("end cache write")
if err != nil {
span.SetStatus(codes.Error, err.Error())
return err
}
return nil
}
func (a *ClientApp) CheckAndRefreshToken(ctx context.Context, groupClaims jwt.MapClaims, refreshTokenThreshold time.Duration) (string, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.ClientApp.CheckAndRefreshToken")
defer span.End()
iss := jwtutil.StringField(groupClaims, "iss")
sub := jwtutil.StringField(groupClaims, "sub")
sid := jwtutil.StringField(groupClaims, "sid")
span.SetAttributes(
attribute.String("iss", iss),
attribute.String("sub", sub),
attribute.String("sid", sid))
if GetTokenExpiration(groupClaims) < refreshTokenThreshold {
token, err := a.GetUpdatedOidcTokenFromCache(ctx, sub, sid)
if err != nil {
log.Errorf("Failed to get token from cache: %v", err)
span.SetStatus(codes.Error, err.Error())
return "", err
}
if token != nil {
idTokenRAW, ok := token.Extra("id_token").(string)
if !ok {
return "", errors.New("empty id_token")
err = errors.New("empty id_token")
span.SetStatus(codes.Error, err.Error())
return "", err
}
return idTokenRAW, nil
}
@@ -647,12 +702,21 @@ func (a *ClientApp) CheckAndRefreshToken(ctx context.Context, groupClaims jwt.Ma
// GetUpdatedOidcTokenFromCache fetches a token from cache and refreshes it if under the threshold for expiration.
// The cached token will also be updated if it is refreshed. Returns latest token or an error if the process fails.
func (a *ClientApp) GetUpdatedOidcTokenFromCache(ctx context.Context, subject string, sessionId string) (*oauth2.Token, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.ClientApp.GetUpdatedOidcTokenFromCache")
defer span.End()
ctx = gooidc.ClientContext(ctx, a.client)
span.SetAttributes(
attribute.String("subject", subject),
attribute.String("sessionId", sessionId),
)
// Get oauth2 config
cacheKey := formatOidcTokenCacheKey(subject, sessionId)
oidcTokenCacheJSON, err := a.GetValueFromEncryptedCache(cacheKey)
oidcTokenCacheJSON, err := a.GetValueFromEncryptedCache(ctx, cacheKey)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
if oidcTokenCacheJSON == nil {
@@ -662,25 +726,35 @@ func (a *ClientApp) GetUpdatedOidcTokenFromCache(ctx context.Context, subject st
oidcTokenCache, err := GetOidcTokenCacheFromJSON(oidcTokenCacheJSON)
if err != nil {
err = fmt.Errorf("failed to unmarshal cached oidc token: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
tokenSource, err := a.GetTokenSourceFromCache(ctx, oidcTokenCache)
if err != nil {
err = fmt.Errorf("failed to get token source from cached oidc token: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
span.AddEvent("starting tokenSource.Token()")
token, err := tokenSource.Token()
span.AddEvent("finished tokenSource.Token()")
if err != nil {
return nil, fmt.Errorf("failed to refresh token from source: %w", err)
err = fmt.Errorf("failed to refresh token from source: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
if token.AccessToken != oidcTokenCache.Token.AccessToken {
span.AddEvent("updating cache with latest token")
oidcTokenCache = NewOidcTokenCache(oidcTokenCache.RedirectURL, token)
oidcTokenCacheJSON, err = json.Marshal(oidcTokenCache)
if err != nil {
return nil, fmt.Errorf("failed to marshal oidc oidcTokenCache refresher: %w", err)
err = fmt.Errorf("failed to marshal oidc oidcTokenCache refresher: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
err = a.SetValueInEncryptedCache(cacheKey, oidcTokenCacheJSON, time.Until(token.Expiry))
err = a.SetValueInEncryptedCache(ctx, cacheKey, oidcTokenCacheJSON, time.Until(token.Expiry))
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
}
@@ -827,6 +901,9 @@ func (a *ClientApp) SetGroupsFromUserInfo(ctx context.Context, claims jwt.Claims
// GetUserInfo queries the IDP userinfo endpoint for claims
func (a *ClientApp) GetUserInfo(ctx context.Context, actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.ClientApp.GetUserInfo")
defer span.End()
sub := jwtutil.StringField(actualClaims, "sub")
var claims jwt.MapClaims
var encClaims []byte
@@ -848,7 +925,7 @@ func (a *ClientApp) GetUserInfo(ctx context.Context, actualClaims jwt.MapClaims,
}
// check if the accessToken for the user is still present
accessTokenBytes, err := a.GetValueFromEncryptedCache(FormatAccessTokenCacheKey(sub))
accessTokenBytes, err := a.GetValueFromEncryptedCache(ctx, FormatAccessTokenCacheKey(sub))
if err != nil {
return claims, true, fmt.Errorf("could not read accessToken from cache for %s: %w", sub, err)
}

View File

@@ -1418,7 +1418,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
if tt.insertIntoCache {
oidcTokenCacheJSON, err := json.Marshal(tt.oidcTokenCache)
require.NoError(t, err)
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(tt.subject, tt.session), oidcTokenCacheJSON, time.Minute))
require.NoError(t, app.SetValueInEncryptedCache(t.Context(), formatOidcTokenCacheKey(tt.subject, tt.session), oidcTokenCacheJSON, time.Minute))
}
token, err := app.GetUpdatedOidcTokenFromCache(t.Context(), tt.subject, tt.session)
if tt.expectErrorContains != "" {
@@ -1509,7 +1509,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL, tt.refreshTokenThreshold),
require.NotEmpty(t, sub)
sid := jwtutil.StringField(tt.groupClaims, "sid")
require.NotEmpty(t, sid)
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, time.Minute))
require.NoError(t, app.SetValueInEncryptedCache(t.Context(), formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, time.Minute))
token, err := app.CheckAndRefreshToken(t.Context(), tt.groupClaims, cdSettings.RefreshTokenThreshold())
if tt.expectErrorContains != "" {
require.ErrorContains(t, err, tt.expectErrorContains)

View File

@@ -7,6 +7,10 @@ import (
"net/http"
"strings"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
gooidc "github.com/coreos/go-oidc/v3/oidc"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
@@ -103,8 +107,12 @@ func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSetti
//
// At this point, we have not verified that the token has not been altered. All code paths below MUST VERIFY
// THE TOKEN SIGNATURE to confirm that an attacker did not maliciously remove the "aud" claim.
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.providerImpl.Verify")
defer span.End()
unverifiedHasAudClaim, err := security.UnverifiedHasAudClaim(tokenString)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, fmt.Errorf("failed to determine whether the token has an aud claim: %w", err)
}
@@ -113,7 +121,9 @@ func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSetti
idToken, err = p.verify(ctx, "", tokenString, argoSettings.SkipAudienceCheckWhenTokenHasNoAudience())
} else {
allowedAudiences := argoSettings.OAuth2AllowedAudiences()
span.SetAttributes(attribute.StringSlice("allowedAudiences", allowedAudiences))
if len(allowedAudiences) == 0 {
span.SetStatus(codes.Error, "token has an audience claim, but no allowed audiences are configured")
return nil, errors.New("token has an audience claim, but no allowed audiences are configured")
}
tokenVerificationErrors := make(map[string]error)
@@ -143,6 +153,7 @@ func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSetti
}
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, fmt.Errorf("failed to verify provider token: %w", err)
}
@@ -150,8 +161,12 @@ func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSetti
}
func (p *providerImpl) verify(ctx context.Context, clientID, tokenString string, skipClientIDCheck bool) (*gooidc.IDToken, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.providerImpl.verify")
defer span.End()
prov, err := p.provider()
if err != nil {
span.SetStatus(codes.Error, fmt.Sprintf("failed to query provider: %v", err))
return nil, err
}
config := &gooidc.Config{ClientID: clientID, SkipClientIDCheck: skipClientIDCheck}
@@ -166,16 +181,19 @@ func (p *providerImpl) verify(ctx context.Context, clientID, tokenString string,
// 3. re-attempting token verification
// NOTE: the error message is sensitive to implementation of verifier.Verify()
if !strings.Contains(err.Error(), "failed to verify signature") {
span.SetStatus(codes.Error, fmt.Sprintf("error verifying token: %v", err))
return nil, err
}
newProvider, retryErr := p.newGoOIDCProvider()
if retryErr != nil {
span.SetStatus(codes.Error, fmt.Sprintf("hack: error verifying token on retry: %v", err))
// return original error if we fail to re-initialize OIDC
return nil, err
}
verifier = newProvider.Verifier(config)
idToken, err = verifier.Verify(ctx, tokenString)
if err != nil {
span.SetStatus(codes.Error, fmt.Sprintf("hack: error verifying token: %v", err))
return nil, err
}
// If we get here, we successfully re-initialized OIDC and after re-initialization,

View File

@@ -13,6 +13,11 @@ import (
"sync"
"time"
otel_codes "go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/trace"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
@@ -100,6 +105,13 @@ const (
var InvalidLoginErr = status.Errorf(codes.Unauthenticated, invalidLoginError)
// OpenTelemetry tracer for this package
var tracer trace.Tracer
func init() {
tracer = otel.Tracer("github.com/argoproj/argo-cd/v3/util/session")
}
// Returns the maximum cache size as number of entries
func getMaximumCacheSize() int {
return env.ParseNumFromEnv(envLoginMaxCacheSize, defaultMaxCacheSize, 1, math.MaxInt32)
@@ -536,6 +548,9 @@ func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcu
// VerifyToken verifies if a token is correct. Tokens can be issued either from us or by an IDP.
// We choose how to verify based on the issuer.
func (mgr *SessionManager) VerifyToken(ctx context.Context, tokenString string) (jwt.Claims, string, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "session.SessionManager.VerifyToken")
defer span.End()
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
claims := jwt.MapClaims{}
_, _, err := parser.ParseUnverified(tokenString, &claims)
@@ -568,7 +583,9 @@ func (mgr *SessionManager) VerifyToken(ctx context.Context, tokenString string)
// return a dummy claims only containing a value for the issuer, so the
// UI can handle expired tokens appropriately.
if err != nil {
log.Warnf("Failed to verify session token: %s", err)
errorMsg := "Failed to verify session token: " + err.Error()
span.SetStatus(otel_codes.Error, errorMsg)
log.Warn(errorMsg)
tokenExpiredError := &oidc.TokenExpiredError{}
if errors.As(err, &tokenExpiredError) {
claims = jwt.MapClaims{