mirror of
https://github.com/argoproj/argo-cd.git
synced 2026-02-20 01:28:45 +01:00
feat: add OTEL instrumentation for authentication and handlers (#25296)
Signed-off-by: Mike Cutsail <mcutsail15@apple.com> Signed-off-by: Alexandre Gaudreault <alexandre_gaudreault@intuit.com> Co-authored-by: Alexandre Gaudreault <alexandre_gaudreault@intuit.com>
This commit is contained in:
4
go.mod
4
go.mod
@@ -90,9 +90,11 @@ require (
|
||||
github.com/yuin/gopher-lua v1.1.1
|
||||
gitlab.com/gitlab-org/api/client-go v1.29.0
|
||||
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.64.0
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0
|
||||
go.opentelemetry.io/otel v1.39.0
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0
|
||||
go.opentelemetry.io/otel/sdk v1.39.0
|
||||
go.opentelemetry.io/otel/trace v1.39.0
|
||||
golang.org/x/crypto v0.48.0
|
||||
golang.org/x/net v0.50.0
|
||||
golang.org/x/oauth2 v0.34.0
|
||||
@@ -273,10 +275,8 @@ require (
|
||||
github.com/xlab/treeprint v1.2.0 // indirect
|
||||
go.mongodb.org/mongo-driver v1.17.6 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect
|
||||
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.39.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.39.0 // indirect
|
||||
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.2 // indirect
|
||||
go.yaml.in/yaml/v3 v3.0.4 // indirect
|
||||
|
||||
@@ -44,8 +44,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/soheilhy/cmux"
|
||||
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
|
||||
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
|
||||
"go.opentelemetry.io/otel"
|
||||
otel_codes "go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/propagation"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@@ -166,6 +169,9 @@ var (
|
||||
enableGRPCTimeHistogram = true
|
||||
)
|
||||
|
||||
// OpenTelemetry tracer for this package
|
||||
var tracer trace.Tracer
|
||||
|
||||
func init() {
|
||||
maxConcurrentLoginRequestsCount = env.ParseNumFromEnv(maxConcurrentLoginRequestsCountEnv, maxConcurrentLoginRequestsCount, 0, math.MaxInt32)
|
||||
replicasCount = env.ParseNumFromEnv(replicasCountEnv, replicasCount, 0, math.MaxInt32)
|
||||
@@ -173,6 +179,7 @@ func init() {
|
||||
maxConcurrentLoginRequestsCount = maxConcurrentLoginRequestsCount / replicasCount
|
||||
}
|
||||
enableGRPCTimeHistogram = env.ParseBoolFromEnv(common.EnvEnableGRPCTimeHistogramEnv, false)
|
||||
tracer = otel.Tracer("github.com/argoproj/argo-cd/v3/server")
|
||||
}
|
||||
|
||||
// ArgoCDServer is the API server for Argo CD
|
||||
@@ -1164,8 +1171,8 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb
|
||||
Handler: &handlerSwitcher{
|
||||
handler: mux,
|
||||
urlToHandler: map[string]http.Handler{
|
||||
"/api/badge": badge.NewHandler(server.AppClientset, server.settingsMgr, server.Namespace, server.ApplicationNamespaces),
|
||||
common.LogoutEndpoint: logout.NewHandler(server.settingsMgr, server.sessionMgr, server.RootPath, server.BaseHRef),
|
||||
"/api/badge": otelhttp.NewHandler(badge.NewHandler(server.AppClientset, server.settingsMgr, server.Namespace, server.ApplicationNamespaces), "server.ArgoCDServer/badge"),
|
||||
common.LogoutEndpoint: otelhttp.NewHandler(logout.NewHandler(server.settingsMgr, server.sessionMgr, server.RootPath, server.BaseHRef), "server.ArgoCDServer/logout"),
|
||||
},
|
||||
contentTypeToHandler: map[string]http.Handler{
|
||||
"application/grpc-web+proto": grpcWebHandler,
|
||||
@@ -1293,7 +1300,7 @@ func registerExtensions(mux *http.ServeMux, a *ArgoCDServer, metricsReg HTTPMetr
|
||||
extHandler := http.HandlerFunc(a.extensionManager.CallExtension())
|
||||
authMiddleware := a.sessionMgr.AuthMiddlewareFunc(a.DisableAuth, a.settings.IsSSOConfigured(), a.ssoClientApp)
|
||||
// auth middleware ensures that requests to all extensions are authenticated first
|
||||
mux.Handle(extension.URLPrefix+"/", authMiddleware(extHandler))
|
||||
mux.Handle(extension.URLPrefix+"/", otelhttp.NewHandler(authMiddleware(extHandler), "server.ArgoCDServer/extensions"))
|
||||
|
||||
a.extensionManager.AddMetricsRegistry(metricsReg)
|
||||
|
||||
@@ -1351,9 +1358,10 @@ func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
|
||||
return
|
||||
}
|
||||
// Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex)
|
||||
mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig))
|
||||
mux.HandleFunc(common.LoginEndpoint, server.ssoClientApp.HandleLogin)
|
||||
mux.HandleFunc(common.CallbackEndpoint, server.ssoClientApp.HandleCallback)
|
||||
mux.Handle(common.DexAPIEndpoint+"/", otelhttp.NewHandler(http.HandlerFunc(dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig)), "server.dex/Proxy"))
|
||||
|
||||
mux.Handle(common.LoginEndpoint, otelhttp.NewHandler(http.HandlerFunc(server.ssoClientApp.HandleLogin), "server.ClientApp/HandleLogin"))
|
||||
mux.Handle(common.CallbackEndpoint, otelhttp.NewHandler(http.HandlerFunc(server.ssoClientApp.HandleCallback), "server.ClientApp/HandleCallback"))
|
||||
}
|
||||
|
||||
// newRedirectServer returns an HTTP server which does a 307 redirect to the HTTPS server
|
||||
@@ -1510,6 +1518,9 @@ func replaceBaseHRef(data string, replaceWith string) string {
|
||||
|
||||
// Authenticate checks for the presence of a valid token when accessing server-side resources.
|
||||
func (server *ArgoCDServer) Authenticate(ctx context.Context) (context.Context, error) {
|
||||
var span trace.Span
|
||||
ctx, span = tracer.Start(ctx, "server.ArgoCDServer.Authenticate")
|
||||
defer span.End()
|
||||
if server.DisableAuth {
|
||||
return ctx, nil
|
||||
}
|
||||
@@ -1549,18 +1560,24 @@ func (server *ArgoCDServer) Authenticate(ctx context.Context) (context.Context,
|
||||
|
||||
// getClaims extracts, validates and refreshes a JWT token from an incoming request context.
|
||||
func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error) {
|
||||
var span trace.Span
|
||||
ctx, span = tracer.Start(ctx, "server.ArgoCDServer.getClaims")
|
||||
defer span.End()
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
span.SetStatus(otel_codes.Error, ErrNoSession.Error())
|
||||
return nil, "", ErrNoSession
|
||||
}
|
||||
tokenString := getToken(md)
|
||||
if tokenString == "" {
|
||||
span.SetStatus(otel_codes.Error, ErrNoSession.Error())
|
||||
return nil, "", ErrNoSession
|
||||
}
|
||||
// A valid argocd-issued token is automatically refreshed here prior to expiration.
|
||||
// OIDC tokens will be verified but will not be refreshed here.
|
||||
claims, newToken, err := server.sessionMgr.VerifyToken(ctx, tokenString)
|
||||
if err != nil {
|
||||
span.SetStatus(otel_codes.Error, err.Error())
|
||||
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
gooidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -38,6 +44,13 @@ import (
|
||||
|
||||
var ErrInvalidRedirectURL = errors.New("invalid return URL")
|
||||
|
||||
// OpenTelemetry tracer for this package
|
||||
var tracer trace.Tracer
|
||||
|
||||
func init() {
|
||||
tracer = otel.Tracer("github.com/argoproj/argo-cd/v3/util/oidc")
|
||||
}
|
||||
|
||||
const (
|
||||
GrantTypeAuthorizationCode = "authorization_code"
|
||||
GrantTypeImplicit = "implicit"
|
||||
@@ -147,6 +160,9 @@ func GetOidcTokenCacheFromJSON(jsonBytes []byte) (*OidcTokenCache, error) {
|
||||
// GetTokenSourceFromCache creates an oauth2 TokenSource from a cached oidc token. The TokenSource will be configured
|
||||
// with an early expiration based on the refreshTokenThreshold.
|
||||
func (a *ClientApp) GetTokenSourceFromCache(ctx context.Context, oidcTokenCache *OidcTokenCache) (oauth2.TokenSource, error) {
|
||||
var span trace.Span
|
||||
ctx, span = tracer.Start(ctx, "oidc.ClientApp.GetTokenSourceFromCache")
|
||||
defer span.End()
|
||||
if oidcTokenCache == nil {
|
||||
return nil, errors.New("oidcTokenCache is required")
|
||||
}
|
||||
@@ -198,10 +214,18 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTL
|
||||
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
_, span := tracer.Start(ctx, "oidc.ClientApp.client")
|
||||
defer span.End()
|
||||
span.SetAttributes(
|
||||
attribute.String("network", network),
|
||||
attribute.String("addr", addr),
|
||||
)
|
||||
return (&net.Dialer{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial(network, addr)
|
||||
},
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
@@ -541,7 +565,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
// save the accessToken in memory for later use
|
||||
sub := jwtutil.StringField(claims, "sub")
|
||||
err = a.SetValueInEncryptedCache(FormatAccessTokenCacheKey(sub), []byte(token.AccessToken), GetTokenExpiration(claims))
|
||||
err = a.SetValueInEncryptedCache(ctx, FormatAccessTokenCacheKey(sub), []byte(token.AccessToken), GetTokenExpiration(claims))
|
||||
if err != nil {
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
log.Errorf("cannot cache encrypted accessToken: %v (claims=%s)", err, claimsJSON)
|
||||
@@ -557,7 +581,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
sid := jwtutil.StringField(claims, "sid")
|
||||
err = a.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, GetTokenExpiration(claims))
|
||||
err = a.SetValueInEncryptedCache(ctx, formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, GetTokenExpiration(claims))
|
||||
if err != nil {
|
||||
claimsJSON, _ := json.Marshal(claims)
|
||||
log.Errorf("cannot cache encrypted oidc token: %v (claims=%s)", err, claimsJSON)
|
||||
@@ -587,30 +611,48 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
// GetValueFromEncryptedCache is a convenience method for retreiving a value from cache and decrypting it. If the cache
|
||||
// does not contain a value for the given key, a nil value is returned. Return handling should check for error and then
|
||||
// check for nil.
|
||||
func (a *ClientApp) GetValueFromEncryptedCache(key string) (value []byte, err error) {
|
||||
func (a *ClientApp) GetValueFromEncryptedCache(ctx context.Context, key string) (value []byte, err error) {
|
||||
_, span := tracer.Start(ctx, "oidc.ClientApp.GetValueFromEncryptedCache")
|
||||
defer span.End()
|
||||
var encryptedValue []byte
|
||||
span.AddEvent("start cache read")
|
||||
err = a.clientCache.Get(key, &encryptedValue)
|
||||
span.AddEvent("end cache read")
|
||||
if err != nil {
|
||||
if errors.Is(err, cache.ErrCacheMiss) {
|
||||
span.SetAttributes(attribute.Bool("cache_hit", false))
|
||||
// Return nil to signify a cache miss
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get encrypted value from cache: %w", err)
|
||||
err = fmt.Errorf("failed to get encrypted value from cache: %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
span.SetAttributes(attribute.Bool("cache_hit", true))
|
||||
value, err = crypto.Decrypt(encryptedValue, a.encryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt value from cache: %w", err)
|
||||
err = fmt.Errorf("failed to decrypt value from cache: %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
return value, err
|
||||
}
|
||||
|
||||
// SetValueFromEncyrptedCache is a convenience method for encrypting a value and storing it in the cache at a given key.
|
||||
// Cache expiration is set based on input.
|
||||
func (a *ClientApp) SetValueInEncryptedCache(key string, value []byte, expiration time.Duration) error {
|
||||
func (a *ClientApp) SetValueInEncryptedCache(ctx context.Context, key string, value []byte, expiration time.Duration) error {
|
||||
_, span := tracer.Start(ctx, "oidc.ClientApp.SetValueInEncryptedCache")
|
||||
defer span.End()
|
||||
encryptedValue, err := crypto.Encrypt(value, a.encryptionKey)
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return err
|
||||
}
|
||||
span.SetAttributes(
|
||||
attribute.String("key", key),
|
||||
attribute.Int("value_length", len(value)),
|
||||
)
|
||||
span.AddEvent("start cache write")
|
||||
err = a.clientCache.Set(&cache.Item{
|
||||
Key: key,
|
||||
Object: encryptedValue,
|
||||
@@ -618,25 +660,38 @@ func (a *ClientApp) SetValueInEncryptedCache(key string, value []byte, expiratio
|
||||
Expiration: expiration,
|
||||
},
|
||||
})
|
||||
span.AddEvent("end cache write")
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *ClientApp) CheckAndRefreshToken(ctx context.Context, groupClaims jwt.MapClaims, refreshTokenThreshold time.Duration) (string, error) {
|
||||
var span trace.Span
|
||||
ctx, span = tracer.Start(ctx, "oidc.ClientApp.CheckAndRefreshToken")
|
||||
defer span.End()
|
||||
iss := jwtutil.StringField(groupClaims, "iss")
|
||||
sub := jwtutil.StringField(groupClaims, "sub")
|
||||
sid := jwtutil.StringField(groupClaims, "sid")
|
||||
span.SetAttributes(
|
||||
attribute.String("iss", iss),
|
||||
attribute.String("sub", sub),
|
||||
attribute.String("sid", sid))
|
||||
if GetTokenExpiration(groupClaims) < refreshTokenThreshold {
|
||||
token, err := a.GetUpdatedOidcTokenFromCache(ctx, sub, sid)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get token from cache: %v", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return "", err
|
||||
}
|
||||
if token != nil {
|
||||
idTokenRAW, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return "", errors.New("empty id_token")
|
||||
err = errors.New("empty id_token")
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return "", err
|
||||
}
|
||||
return idTokenRAW, nil
|
||||
}
|
||||
@@ -647,12 +702,21 @@ func (a *ClientApp) CheckAndRefreshToken(ctx context.Context, groupClaims jwt.Ma
|
||||
// GetUpdatedOidcTokenFromCache fetches a token from cache and refreshes it if under the threshold for expiration.
|
||||
// The cached token will also be updated if it is refreshed. Returns latest token or an error if the process fails.
|
||||
func (a *ClientApp) GetUpdatedOidcTokenFromCache(ctx context.Context, subject string, sessionId string) (*oauth2.Token, error) {
|
||||
var span trace.Span
|
||||
ctx, span = tracer.Start(ctx, "oidc.ClientApp.GetUpdatedOidcTokenFromCache")
|
||||
defer span.End()
|
||||
|
||||
ctx = gooidc.ClientContext(ctx, a.client)
|
||||
span.SetAttributes(
|
||||
attribute.String("subject", subject),
|
||||
attribute.String("sessionId", sessionId),
|
||||
)
|
||||
|
||||
// Get oauth2 config
|
||||
cacheKey := formatOidcTokenCacheKey(subject, sessionId)
|
||||
oidcTokenCacheJSON, err := a.GetValueFromEncryptedCache(cacheKey)
|
||||
oidcTokenCacheJSON, err := a.GetValueFromEncryptedCache(ctx, cacheKey)
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
if oidcTokenCacheJSON == nil {
|
||||
@@ -662,25 +726,35 @@ func (a *ClientApp) GetUpdatedOidcTokenFromCache(ctx context.Context, subject st
|
||||
oidcTokenCache, err := GetOidcTokenCacheFromJSON(oidcTokenCacheJSON)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to unmarshal cached oidc token: %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
tokenSource, err := a.GetTokenSourceFromCache(ctx, oidcTokenCache)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get token source from cached oidc token: %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
span.AddEvent("starting tokenSource.Token()")
|
||||
token, err := tokenSource.Token()
|
||||
span.AddEvent("finished tokenSource.Token()")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to refresh token from source: %w", err)
|
||||
err = fmt.Errorf("failed to refresh token from source: %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
if token.AccessToken != oidcTokenCache.Token.AccessToken {
|
||||
span.AddEvent("updating cache with latest token")
|
||||
oidcTokenCache = NewOidcTokenCache(oidcTokenCache.RedirectURL, token)
|
||||
oidcTokenCacheJSON, err = json.Marshal(oidcTokenCache)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal oidc oidcTokenCache refresher: %w", err)
|
||||
err = fmt.Errorf("failed to marshal oidc oidcTokenCache refresher: %w", err)
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
err = a.SetValueInEncryptedCache(cacheKey, oidcTokenCacheJSON, time.Until(token.Expiry))
|
||||
err = a.SetValueInEncryptedCache(ctx, cacheKey, oidcTokenCacheJSON, time.Until(token.Expiry))
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -827,6 +901,9 @@ func (a *ClientApp) SetGroupsFromUserInfo(ctx context.Context, claims jwt.Claims
|
||||
|
||||
// GetUserInfo queries the IDP userinfo endpoint for claims
|
||||
func (a *ClientApp) GetUserInfo(ctx context.Context, actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
|
||||
var span trace.Span
|
||||
ctx, span = tracer.Start(ctx, "oidc.ClientApp.GetUserInfo")
|
||||
defer span.End()
|
||||
sub := jwtutil.StringField(actualClaims, "sub")
|
||||
var claims jwt.MapClaims
|
||||
var encClaims []byte
|
||||
@@ -848,7 +925,7 @@ func (a *ClientApp) GetUserInfo(ctx context.Context, actualClaims jwt.MapClaims,
|
||||
}
|
||||
|
||||
// check if the accessToken for the user is still present
|
||||
accessTokenBytes, err := a.GetValueFromEncryptedCache(FormatAccessTokenCacheKey(sub))
|
||||
accessTokenBytes, err := a.GetValueFromEncryptedCache(ctx, FormatAccessTokenCacheKey(sub))
|
||||
if err != nil {
|
||||
return claims, true, fmt.Errorf("could not read accessToken from cache for %s: %w", sub, err)
|
||||
}
|
||||
|
||||
@@ -1418,7 +1418,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
|
||||
if tt.insertIntoCache {
|
||||
oidcTokenCacheJSON, err := json.Marshal(tt.oidcTokenCache)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(tt.subject, tt.session), oidcTokenCacheJSON, time.Minute))
|
||||
require.NoError(t, app.SetValueInEncryptedCache(t.Context(), formatOidcTokenCacheKey(tt.subject, tt.session), oidcTokenCacheJSON, time.Minute))
|
||||
}
|
||||
token, err := app.GetUpdatedOidcTokenFromCache(t.Context(), tt.subject, tt.session)
|
||||
if tt.expectErrorContains != "" {
|
||||
@@ -1509,7 +1509,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL, tt.refreshTokenThreshold),
|
||||
require.NotEmpty(t, sub)
|
||||
sid := jwtutil.StringField(tt.groupClaims, "sid")
|
||||
require.NotEmpty(t, sid)
|
||||
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, time.Minute))
|
||||
require.NoError(t, app.SetValueInEncryptedCache(t.Context(), formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, time.Minute))
|
||||
token, err := app.CheckAndRefreshToken(t.Context(), tt.groupClaims, cdSettings.RefreshTokenThreshold())
|
||||
if tt.expectErrorContains != "" {
|
||||
require.ErrorContains(t, err, tt.expectErrorContains)
|
||||
|
||||
@@ -7,6 +7,10 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
gooidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
@@ -103,8 +107,12 @@ func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSetti
|
||||
//
|
||||
// At this point, we have not verified that the token has not been altered. All code paths below MUST VERIFY
|
||||
// THE TOKEN SIGNATURE to confirm that an attacker did not maliciously remove the "aud" claim.
|
||||
var span trace.Span
|
||||
ctx, span = tracer.Start(ctx, "oidc.providerImpl.Verify")
|
||||
defer span.End()
|
||||
unverifiedHasAudClaim, err := security.UnverifiedHasAudClaim(tokenString)
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, fmt.Errorf("failed to determine whether the token has an aud claim: %w", err)
|
||||
}
|
||||
|
||||
@@ -113,7 +121,9 @@ func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSetti
|
||||
idToken, err = p.verify(ctx, "", tokenString, argoSettings.SkipAudienceCheckWhenTokenHasNoAudience())
|
||||
} else {
|
||||
allowedAudiences := argoSettings.OAuth2AllowedAudiences()
|
||||
span.SetAttributes(attribute.StringSlice("allowedAudiences", allowedAudiences))
|
||||
if len(allowedAudiences) == 0 {
|
||||
span.SetStatus(codes.Error, "token has an audience claim, but no allowed audiences are configured")
|
||||
return nil, errors.New("token has an audience claim, but no allowed audiences are configured")
|
||||
}
|
||||
tokenVerificationErrors := make(map[string]error)
|
||||
@@ -143,6 +153,7 @@ func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSetti
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
return nil, fmt.Errorf("failed to verify provider token: %w", err)
|
||||
}
|
||||
|
||||
@@ -150,8 +161,12 @@ func (p *providerImpl) Verify(ctx context.Context, tokenString string, argoSetti
|
||||
}
|
||||
|
||||
func (p *providerImpl) verify(ctx context.Context, clientID, tokenString string, skipClientIDCheck bool) (*gooidc.IDToken, error) {
|
||||
var span trace.Span
|
||||
ctx, span = tracer.Start(ctx, "oidc.providerImpl.verify")
|
||||
defer span.End()
|
||||
prov, err := p.provider()
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, fmt.Sprintf("failed to query provider: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
config := &gooidc.Config{ClientID: clientID, SkipClientIDCheck: skipClientIDCheck}
|
||||
@@ -166,16 +181,19 @@ func (p *providerImpl) verify(ctx context.Context, clientID, tokenString string,
|
||||
// 3. re-attempting token verification
|
||||
// NOTE: the error message is sensitive to implementation of verifier.Verify()
|
||||
if !strings.Contains(err.Error(), "failed to verify signature") {
|
||||
span.SetStatus(codes.Error, fmt.Sprintf("error verifying token: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
newProvider, retryErr := p.newGoOIDCProvider()
|
||||
if retryErr != nil {
|
||||
span.SetStatus(codes.Error, fmt.Sprintf("hack: error verifying token on retry: %v", err))
|
||||
// return original error if we fail to re-initialize OIDC
|
||||
return nil, err
|
||||
}
|
||||
verifier = newProvider.Verifier(config)
|
||||
idToken, err = verifier.Verify(ctx, tokenString)
|
||||
if err != nil {
|
||||
span.SetStatus(codes.Error, fmt.Sprintf("hack: error verifying token: %v", err))
|
||||
return nil, err
|
||||
}
|
||||
// If we get here, we successfully re-initialized OIDC and after re-initialization,
|
||||
|
||||
@@ -13,6 +13,11 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
otel_codes "go.opentelemetry.io/otel/codes"
|
||||
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
@@ -100,6 +105,13 @@ const (
|
||||
|
||||
var InvalidLoginErr = status.Errorf(codes.Unauthenticated, invalidLoginError)
|
||||
|
||||
// OpenTelemetry tracer for this package
|
||||
var tracer trace.Tracer
|
||||
|
||||
func init() {
|
||||
tracer = otel.Tracer("github.com/argoproj/argo-cd/v3/util/session")
|
||||
}
|
||||
|
||||
// Returns the maximum cache size as number of entries
|
||||
func getMaximumCacheSize() int {
|
||||
return env.ParseNumFromEnv(envLoginMaxCacheSize, defaultMaxCacheSize, 1, math.MaxInt32)
|
||||
@@ -536,6 +548,9 @@ func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcu
|
||||
// VerifyToken verifies if a token is correct. Tokens can be issued either from us or by an IDP.
|
||||
// We choose how to verify based on the issuer.
|
||||
func (mgr *SessionManager) VerifyToken(ctx context.Context, tokenString string) (jwt.Claims, string, error) {
|
||||
var span trace.Span
|
||||
ctx, span = tracer.Start(ctx, "session.SessionManager.VerifyToken")
|
||||
defer span.End()
|
||||
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
|
||||
claims := jwt.MapClaims{}
|
||||
_, _, err := parser.ParseUnverified(tokenString, &claims)
|
||||
@@ -568,7 +583,9 @@ func (mgr *SessionManager) VerifyToken(ctx context.Context, tokenString string)
|
||||
// return a dummy claims only containing a value for the issuer, so the
|
||||
// UI can handle expired tokens appropriately.
|
||||
if err != nil {
|
||||
log.Warnf("Failed to verify session token: %s", err)
|
||||
errorMsg := "Failed to verify session token: " + err.Error()
|
||||
span.SetStatus(otel_codes.Error, errorMsg)
|
||||
log.Warn(errorMsg)
|
||||
tokenExpiredError := &oidc.TokenExpiredError{}
|
||||
if errors.As(err, &tokenExpiredError) {
|
||||
claims = jwt.MapClaims{
|
||||
|
||||
Reference in New Issue
Block a user