mirror of
https://github.com/argoproj/argo-cd.git
synced 2026-02-20 01:28:45 +01:00
fix(oidc): check userinfo endpoint in AuthMiddleware (#23586)
Signed-off-by: Nathanael Liechti <technat@technat.ch>
This commit is contained in:
committed by
GitHub
parent
96038ba2a1
commit
5efb184c79
1
.gitignore
vendored
1
.gitignore
vendored
@@ -20,6 +20,7 @@ node_modules/
|
||||
.kube/
|
||||
./test/cmp/*.sock
|
||||
.envrc.remote
|
||||
.mirrord/
|
||||
.*.swp
|
||||
rerunreport.txt
|
||||
|
||||
|
||||
@@ -1221,9 +1221,12 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb
|
||||
|
||||
terminalOpts := application.TerminalOptions{DisableAuth: server.DisableAuth, Enf: server.enf}
|
||||
|
||||
// SSO ClientApp
|
||||
server.ssoClientApp, _ = oidc.NewClientApp(server.settings, server.DexServerAddr, server.DexTLSConfig, server.BaseHRef, cacheutil.NewRedisCache(server.RedisClient, server.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone))
|
||||
|
||||
terminal := application.NewHandler(server.appLister, server.Namespace, server.ApplicationNamespaces, server.db, appResourceTreeFn, server.settings.ExecShells, server.sessionMgr, &terminalOpts).
|
||||
WithFeatureFlagMiddleware(server.settingsMgr.GetSettings)
|
||||
th := util_session.WithAuthMiddleware(server.DisableAuth, server.sessionMgr, terminal)
|
||||
th := util_session.WithAuthMiddleware(server.DisableAuth, server.settings.IsSSOConfigured(), server.ssoClientApp, server.sessionMgr, terminal)
|
||||
mux.Handle("/terminal", th)
|
||||
|
||||
// Proxy extension is currently an alpha feature and is disabled
|
||||
@@ -1253,7 +1256,7 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb
|
||||
swagger.ServeSwaggerUI(mux, assets.SwaggerJSON, "/swagger-ui", server.RootPath)
|
||||
healthz.ServeHealthCheck(mux, server.healthCheck)
|
||||
|
||||
// Dex reverse proxy and client app and OAuth2 login/callback
|
||||
// Dex reverse proxy and OAuth2 login/callback
|
||||
server.registerDexHandlers(mux)
|
||||
|
||||
// Webhook handler for git events (Note: cache timeouts are hardcoded because API server does not write to cache and not really using them)
|
||||
@@ -1305,7 +1308,7 @@ func enforceContentTypes(handler http.Handler, types []string) http.Handler {
|
||||
func registerExtensions(mux *http.ServeMux, a *ArgoCDServer, metricsReg HTTPMetricsRegistry) {
|
||||
a.log.Info("Registering extensions...")
|
||||
extHandler := http.HandlerFunc(a.extensionManager.CallExtension())
|
||||
authMiddleware := a.sessionMgr.AuthMiddlewareFunc(a.DisableAuth)
|
||||
authMiddleware := a.sessionMgr.AuthMiddlewareFunc(a.DisableAuth, a.settings.IsSSOConfigured(), a.ssoClientApp)
|
||||
// auth middleware ensures that requests to all extensions are authenticated first
|
||||
mux.Handle(extension.URLPrefix+"/", authMiddleware(extHandler))
|
||||
|
||||
@@ -1359,7 +1362,7 @@ func (server *ArgoCDServer) serveExtensions(extensionsSharedPath string, w http.
|
||||
}
|
||||
}
|
||||
|
||||
// registerDexHandlers will register dex HTTP handlers, creating the OAuth client app
|
||||
// registerDexHandlers will register dex HTTP handlers
|
||||
func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
|
||||
if !server.settings.IsSSOConfigured() {
|
||||
return
|
||||
@@ -1367,7 +1370,6 @@ func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
|
||||
// Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex)
|
||||
var err error
|
||||
mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig))
|
||||
server.ssoClientApp, err = oidc.NewClientApp(server.settings, server.DexServerAddr, server.DexTLSConfig, server.BaseHRef, cacheutil.NewRedisCache(server.RedisClient, server.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone))
|
||||
errorsutil.CheckError(err)
|
||||
mux.HandleFunc(common.LoginEndpoint, server.ssoClientApp.HandleLogin)
|
||||
mux.HandleFunc(common.CallbackEndpoint, server.ssoClientApp.HandleCallback)
|
||||
@@ -1578,34 +1580,15 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string,
|
||||
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
|
||||
}
|
||||
|
||||
// Some SSO implementations (Okta) require a call to
|
||||
// the OIDC user info path to get attributes like groups
|
||||
// we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims
|
||||
// otherwise this would cause a panic
|
||||
var groupClaims jwt.MapClaims
|
||||
if groupClaims, ok = claims.(jwt.MapClaims); !ok {
|
||||
if tmpClaims, ok := claims.(*jwt.MapClaims); ok {
|
||||
groupClaims = *tmpClaims
|
||||
}
|
||||
}
|
||||
iss := jwtutil.StringField(groupClaims, "iss")
|
||||
if iss != util_session.SessionManagerClaimsIssuer && server.settings.UserInfoGroupsEnabled() && server.settings.UserInfoPath() != "" {
|
||||
userInfo, unauthorized, err := server.ssoClientApp.GetUserInfo(groupClaims, server.settings.IssuerURL(), server.settings.UserInfoPath())
|
||||
if unauthorized {
|
||||
log.Errorf("error while quering userinfo endpoint: %v", err)
|
||||
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session")
|
||||
}
|
||||
finalClaims := claims
|
||||
if server.settings.IsSSOConfigured() {
|
||||
finalClaims, err = server.ssoClientApp.SetGroupsFromUserInfo(claims, util_session.SessionManagerClaimsIssuer)
|
||||
if err != nil {
|
||||
log.Errorf("error fetching user info endpoint: %v", err)
|
||||
return claims, "", status.Errorf(codes.Internal, "invalid userinfo response")
|
||||
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
|
||||
}
|
||||
if groupClaims["sub"] != userInfo["sub"] {
|
||||
return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo")
|
||||
}
|
||||
groupClaims["groups"] = userInfo["groups"]
|
||||
}
|
||||
|
||||
return groupClaims, newToken, nil
|
||||
return finalClaims, newToken, nil
|
||||
}
|
||||
|
||||
// getToken extracts the token from gRPC metadata or cookie headers
|
||||
|
||||
@@ -493,7 +493,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
sub := jwtutil.StringField(claims, "sub")
|
||||
err = a.clientCache.Set(&cache.Item{
|
||||
Key: formatAccessTokenCacheKey(sub),
|
||||
Key: FormatAccessTokenCacheKey(sub),
|
||||
Object: encToken,
|
||||
CacheActionOpts: cache.CacheActionOpts{
|
||||
Expiration: getTokenExpiration(claims),
|
||||
@@ -640,6 +640,39 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc
|
||||
return oauth2.SetAuthURLParam("claims", string(claimsRequestRAW)), nil
|
||||
}
|
||||
|
||||
// SetGroupsFromUserInfo takes a claims object and adds groups claim from userinfo endpoint if available
|
||||
// This is required by some SSO implementations as they don't provide the groups claim in the ID token
|
||||
// If querying the UserInfo endpoint fails, we return an error to indicate the session is invalid
|
||||
// we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims
|
||||
// otherwise this would cause a panic
|
||||
func (a *ClientApp) SetGroupsFromUserInfo(claims jwt.Claims, sessionManagerClaimsIssuer string) (jwt.MapClaims, error) {
|
||||
var groupClaims jwt.MapClaims
|
||||
var ok bool
|
||||
if groupClaims, ok = claims.(jwt.MapClaims); !ok {
|
||||
if tmpClaims, ok := claims.(*jwt.MapClaims); ok {
|
||||
if tmpClaims != nil {
|
||||
groupClaims = *tmpClaims
|
||||
}
|
||||
}
|
||||
}
|
||||
iss := jwtutil.StringField(groupClaims, "iss")
|
||||
if iss != sessionManagerClaimsIssuer && a.settings.UserInfoGroupsEnabled() && a.settings.UserInfoPath() != "" {
|
||||
userInfo, unauthorized, err := a.GetUserInfo(groupClaims, a.settings.IssuerURL(), a.settings.UserInfoPath())
|
||||
if unauthorized {
|
||||
return groupClaims, fmt.Errorf("error while quering userinfo endpoint: %w", err)
|
||||
}
|
||||
if err != nil {
|
||||
return groupClaims, fmt.Errorf("error fetching user info endpoint: %w", err)
|
||||
}
|
||||
if groupClaims["sub"] != userInfo["sub"] {
|
||||
return groupClaims, errors.New("subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo")
|
||||
}
|
||||
groupClaims["groups"] = userInfo["groups"]
|
||||
}
|
||||
|
||||
return groupClaims, nil
|
||||
}
|
||||
|
||||
// GetUserInfo queries the IDP userinfo endpoint for claims
|
||||
func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
|
||||
sub := jwtutil.StringField(actualClaims, "sub")
|
||||
@@ -647,7 +680,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
||||
var encClaims []byte
|
||||
|
||||
// in case we got it in the cache, we just return the item
|
||||
clientCacheKey := formatUserInfoResponseCacheKey(sub)
|
||||
clientCacheKey := FormatUserInfoResponseCacheKey(sub)
|
||||
if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil {
|
||||
claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey)
|
||||
if err != nil {
|
||||
@@ -664,7 +697,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
||||
|
||||
// check if the accessToken for the user is still present
|
||||
var encAccessToken []byte
|
||||
err := a.clientCache.Get(formatAccessTokenCacheKey(sub), &encAccessToken)
|
||||
err := a.clientCache.Get(FormatAccessTokenCacheKey(sub), &encAccessToken)
|
||||
// without an accessToken we can't query the user info endpoint
|
||||
// thus the user needs to reauthenticate for argocd to get a new accessToken
|
||||
if errors.Is(err, cache.ErrCacheMiss) {
|
||||
@@ -774,11 +807,11 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration {
|
||||
}
|
||||
|
||||
// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
|
||||
func formatUserInfoResponseCacheKey(sub string) string {
|
||||
func FormatUserInfoResponseCacheKey(sub string) string {
|
||||
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub)
|
||||
}
|
||||
|
||||
// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
|
||||
func formatAccessTokenCacheKey(sub string) string {
|
||||
func FormatAccessTokenCacheKey(sub string) string {
|
||||
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub)
|
||||
}
|
||||
|
||||
@@ -943,7 +943,7 @@ func TestGetUserInfo(t *testing.T) {
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
||||
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||
expectError: true,
|
||||
},
|
||||
},
|
||||
@@ -958,7 +958,7 @@ func TestGetUserInfo(t *testing.T) {
|
||||
encrypt bool
|
||||
}{
|
||||
{
|
||||
key: formatAccessTokenCacheKey("randomUser"),
|
||||
key: FormatAccessTokenCacheKey("randomUser"),
|
||||
value: "FakeAccessToken",
|
||||
encrypt: true,
|
||||
},
|
||||
@@ -977,7 +977,7 @@ func TestGetUserInfo(t *testing.T) {
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
||||
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||
expectError: true,
|
||||
},
|
||||
},
|
||||
@@ -992,7 +992,7 @@ func TestGetUserInfo(t *testing.T) {
|
||||
encrypt bool
|
||||
}{
|
||||
{
|
||||
key: formatAccessTokenCacheKey("randomUser"),
|
||||
key: FormatAccessTokenCacheKey("randomUser"),
|
||||
value: "FakeAccessToken",
|
||||
encrypt: true,
|
||||
},
|
||||
@@ -1011,7 +1011,7 @@ func TestGetUserInfo(t *testing.T) {
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
||||
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||
expectError: true,
|
||||
},
|
||||
},
|
||||
@@ -1034,7 +1034,7 @@ func TestGetUserInfo(t *testing.T) {
|
||||
encrypt bool
|
||||
}{
|
||||
{
|
||||
key: formatAccessTokenCacheKey("randomUser"),
|
||||
key: FormatAccessTokenCacheKey("randomUser"),
|
||||
value: "FakeAccessToken",
|
||||
encrypt: true,
|
||||
},
|
||||
@@ -1053,7 +1053,7 @@ func TestGetUserInfo(t *testing.T) {
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
||||
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||
expectError: true,
|
||||
},
|
||||
},
|
||||
@@ -1086,7 +1086,7 @@ func TestGetUserInfo(t *testing.T) {
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
||||
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||
value: "{\"groups\":[\"githubOrg:engineers\"]}",
|
||||
expectEncrypted: true,
|
||||
expectError: false,
|
||||
@@ -1113,7 +1113,7 @@ func TestGetUserInfo(t *testing.T) {
|
||||
encrypt bool
|
||||
}{
|
||||
{
|
||||
key: formatAccessTokenCacheKey("randomUser"),
|
||||
key: FormatAccessTokenCacheKey("randomUser"),
|
||||
value: "FakeAccessToken",
|
||||
encrypt: true,
|
||||
},
|
||||
@@ -1172,3 +1172,94 @@ func TestGetUserInfo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetGroupsFromUserInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputClaims jwt.MapClaims // function input
|
||||
cacheClaims jwt.MapClaims // userinfo response
|
||||
expectedClaims jwt.MapClaims // function output
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "set correct groups from userinfo endpoint", // enriches the JWT claims with information from the userinfo endpoint, default case
|
||||
inputClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
cacheClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"githubOrg:example"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
expectedClaims: jwt.MapClaims{"sub": "randomUser", "groups": []any{"githubOrg:example"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, // the groups must be of type any since the response we get was parsed by GetUserInfo and we don't yet know the type of the groups claim
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "return error for wrong userinfo claims returned", // when there's an error in this feature, the claims should be untouched for the rest to still proceed
|
||||
inputClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
cacheClaims: jwt.MapClaims{"sub": "wrongUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
expectedClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "override groups already defined in input claims", // this is expected behavior since input claims might have been truncated (HTTP header 4K limit)
|
||||
inputClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"groupfromjwt"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
cacheClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"superusers", "usergroup", "support-group"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
expectedClaims: jwt.MapClaims{"sub": "randomUser", "groups": []any{"superusers", "usergroup", "support-group"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "empty cache and non-rechable userinfo endpoint", // this will try to reach the userinfo endpoint defined in the test and fail
|
||||
inputClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"groupfromjwt"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
cacheClaims: nil, // the test doesn't set the cache for an empty object
|
||||
expectedClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"groupfromjwt"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// create the ClientApp
|
||||
userInfoCache := cache.NewInMemoryCache(24 * time.Hour)
|
||||
signature, err := util.MakeSignature(32)
|
||||
require.NoError(t, err, "failed creating signature for settings object")
|
||||
cdSettings := &settings.ArgoCDSettings{
|
||||
ServerSignature: signature,
|
||||
OIDCConfigRAW: `
|
||||
issuer: http://localhost:63231
|
||||
enableUserInfoGroups: true
|
||||
userInfoPath: /`,
|
||||
}
|
||||
a, err := NewClientApp(cdSettings, "", nil, "/argo-cd", userInfoCache)
|
||||
require.NoError(t, err, "failed creating clientapp")
|
||||
|
||||
// prepoluate cache to predict what the GetUserInfo function will return to the SetGroupsFromUserInfo function (without having to mock the userinfo response)
|
||||
encryptionKey, err := cdSettings.GetServerEncryptionKey()
|
||||
require.NoError(t, err, "failed obtaining encryption key from settings")
|
||||
|
||||
// set fake accessToken for function to not return early
|
||||
encAccessToken, err := crypto.Encrypt([]byte("123456"), encryptionKey)
|
||||
require.NoError(t, err, "failed encrypting dummy access token")
|
||||
err = a.clientCache.Set(&cache.Item{
|
||||
Key: FormatAccessTokenCacheKey("randomUser"),
|
||||
Object: encAccessToken,
|
||||
})
|
||||
require.NoError(t, err, "failed setting item to in-memory cache")
|
||||
|
||||
// set cacheClaims to in-memory cache to let GetUserInfo return early with this information (GetUserInfo has a separate test, here we focus on SetUserInfoGroups)
|
||||
if tt.cacheClaims != nil {
|
||||
cacheClaims, err := json.Marshal(tt.cacheClaims)
|
||||
require.NoError(t, err)
|
||||
encCacheClaims, err := crypto.Encrypt([]byte(cacheClaims), encryptionKey)
|
||||
require.NoError(t, err, "failed encrypting dummy access token")
|
||||
err = a.clientCache.Set(&cache.Item{
|
||||
Key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||
Object: encCacheClaims,
|
||||
})
|
||||
require.NoError(t, err, "failed setting item to in-memory cache")
|
||||
}
|
||||
|
||||
receivedClaims, err := a.SetGroupsFromUserInfo(tt.inputClaims, "argocd")
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
assert.Equal(t, tt.expectedClaims, receivedClaims) // check that the claims were successfully enriched with what we expect
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -480,9 +480,9 @@ func (mgr *SessionManager) VerifyUsernamePassword(username string, password stri
|
||||
|
||||
// AuthMiddlewareFunc returns a function that can be used as an
|
||||
// authentication middleware for HTTP requests.
|
||||
func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool) func(http.Handler) http.Handler {
|
||||
func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool, isSSOConfigured bool, ssoClientApp *oidcutil.ClientApp) func(http.Handler) http.Handler {
|
||||
return func(h http.Handler) http.Handler {
|
||||
return WithAuthMiddleware(disabled, mgr, h)
|
||||
return WithAuthMiddleware(disabled, isSSOConfigured, ssoClientApp, mgr, h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -495,26 +495,41 @@ type TokenVerifier interface {
|
||||
// WithAuthMiddleware is an HTTP middleware used to ensure incoming
|
||||
// requests are authenticated before invoking the target handler. If
|
||||
// disabled is true, it will just invoke the next handler in the chain.
|
||||
func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) http.Handler {
|
||||
func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcutil.ClientApp, authn TokenVerifier, next http.Handler) http.Handler {
|
||||
if disabled {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !disabled {
|
||||
cookies := r.Cookies()
|
||||
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
|
||||
if err != nil {
|
||||
http.Error(w, "Auth cookie not found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
claims, _, err := authn.VerifyToken(tokenString)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
// Add claims to the context to inspect for RBAC
|
||||
//nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, "claims", claims)
|
||||
r = r.WithContext(ctx)
|
||||
cookies := r.Cookies()
|
||||
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
|
||||
if err != nil {
|
||||
http.Error(w, "Auth cookie not found", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
claims, _, err := authn.VerifyToken(tokenString)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
finalClaims := claims
|
||||
if isSSOConfigured {
|
||||
finalClaims, err = ssoClientApp.SetGroupsFromUserInfo(claims, SessionManagerClaimsIssuer)
|
||||
if err != nil {
|
||||
http.Error(w, "Invalid session", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
// Add claims to the context to inspect for RBAC
|
||||
//nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, "claims", finalClaims)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
@@ -29,7 +30,11 @@ import (
|
||||
apps "github.com/argoproj/argo-cd/v3/pkg/client/clientset/versioned/fake"
|
||||
"github.com/argoproj/argo-cd/v3/pkg/client/listers/application/v1alpha1"
|
||||
"github.com/argoproj/argo-cd/v3/test"
|
||||
"github.com/argoproj/argo-cd/v3/util"
|
||||
"github.com/argoproj/argo-cd/v3/util/cache"
|
||||
"github.com/argoproj/argo-cd/v3/util/crypto"
|
||||
jwtutil "github.com/argoproj/argo-cd/v3/util/jwt"
|
||||
"github.com/argoproj/argo-cd/v3/util/oidc"
|
||||
"github.com/argoproj/argo-cd/v3/util/password"
|
||||
"github.com/argoproj/argo-cd/v3/util/settings"
|
||||
utiltest "github.com/argoproj/argo-cd/v3/util/test"
|
||||
@@ -236,20 +241,39 @@ func strPointer(str string) *string {
|
||||
|
||||
func TestSessionManager_WithAuthMiddleware(t *testing.T) {
|
||||
handlerFunc := func() func(http.ResponseWriter, *http.Request) {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Helper()
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Header().Set("Content-Type", "application/text")
|
||||
_, err := w.Write([]byte("Ok"))
|
||||
require.NoError(t, err, "error writing response: %s", err)
|
||||
|
||||
contextClaims := r.Context().Value("claims")
|
||||
if contextClaims != nil {
|
||||
var gotClaims jwt.MapClaims
|
||||
var ok bool
|
||||
if gotClaims, ok = contextClaims.(jwt.MapClaims); !ok {
|
||||
if tmpClaims, ok := contextClaims.(*jwt.MapClaims); ok && tmpClaims != nil {
|
||||
gotClaims = *tmpClaims
|
||||
}
|
||||
}
|
||||
jsonClaims, err := json.Marshal(gotClaims)
|
||||
require.NoError(t, err, "erorr marshalling claims set by AuthMiddleware")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, err = w.Write(jsonClaims)
|
||||
require.NoError(t, err, "error writing response: %s", err)
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "application/text")
|
||||
_, err := w.Write([]byte("Ok"))
|
||||
require.NoError(t, err, "error writing response: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
authDisabled bool
|
||||
ssoEnabled bool
|
||||
cookieHeader bool
|
||||
verifiedClaims *jwt.RegisteredClaims
|
||||
verifiedClaims *jwt.MapClaims
|
||||
verifyTokenErr error
|
||||
userInfoCacheClaims *jwt.MapClaims
|
||||
expectedStatusCode int
|
||||
expectedResponseBody *string
|
||||
}
|
||||
@@ -258,47 +282,79 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) {
|
||||
{
|
||||
name: "will authenticate successfully",
|
||||
authDisabled: false,
|
||||
ssoEnabled: false,
|
||||
cookieHeader: true,
|
||||
verifiedClaims: &jwt.RegisteredClaims{},
|
||||
verifiedClaims: &jwt.MapClaims{},
|
||||
verifyTokenErr: nil,
|
||||
userInfoCacheClaims: nil,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedResponseBody: strPointer("Ok"),
|
||||
expectedResponseBody: strPointer("{}"),
|
||||
},
|
||||
{
|
||||
name: "will be noop if auth is disabled",
|
||||
authDisabled: true,
|
||||
ssoEnabled: false,
|
||||
cookieHeader: false,
|
||||
verifiedClaims: nil,
|
||||
verifyTokenErr: nil,
|
||||
userInfoCacheClaims: nil,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedResponseBody: strPointer("Ok"),
|
||||
},
|
||||
{
|
||||
name: "will return 400 if no cookie header",
|
||||
authDisabled: false,
|
||||
ssoEnabled: false,
|
||||
cookieHeader: false,
|
||||
verifiedClaims: &jwt.RegisteredClaims{},
|
||||
verifiedClaims: &jwt.MapClaims{},
|
||||
verifyTokenErr: nil,
|
||||
userInfoCacheClaims: nil,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
expectedResponseBody: nil,
|
||||
},
|
||||
{
|
||||
name: "will return 401 verify token fails",
|
||||
authDisabled: false,
|
||||
ssoEnabled: false,
|
||||
cookieHeader: true,
|
||||
verifiedClaims: &jwt.RegisteredClaims{},
|
||||
verifiedClaims: &jwt.MapClaims{},
|
||||
verifyTokenErr: stderrors.New("token error"),
|
||||
userInfoCacheClaims: nil,
|
||||
expectedStatusCode: http.StatusUnauthorized,
|
||||
expectedResponseBody: nil,
|
||||
},
|
||||
{
|
||||
name: "will return 200 if claims are nil",
|
||||
authDisabled: false,
|
||||
ssoEnabled: false,
|
||||
cookieHeader: true,
|
||||
verifiedClaims: nil,
|
||||
verifyTokenErr: nil,
|
||||
userInfoCacheClaims: nil,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedResponseBody: strPointer("Ok"),
|
||||
expectedResponseBody: strPointer("null"),
|
||||
},
|
||||
{
|
||||
name: "will return 401 if sso is enabled but userinfo response not working",
|
||||
authDisabled: false,
|
||||
ssoEnabled: true,
|
||||
cookieHeader: true,
|
||||
verifiedClaims: nil,
|
||||
verifyTokenErr: nil,
|
||||
userInfoCacheClaims: nil, // indicates that the userinfo response will not work since cache is empty and userinfo endpoint not rechable
|
||||
expectedStatusCode: http.StatusUnauthorized,
|
||||
expectedResponseBody: strPointer("Invalid session"),
|
||||
},
|
||||
{
|
||||
name: "will return 200 if sso is enabled and userinfo response from cache is valid",
|
||||
authDisabled: false,
|
||||
ssoEnabled: true,
|
||||
cookieHeader: true,
|
||||
verifiedClaims: &jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
verifyTokenErr: nil,
|
||||
userInfoCacheClaims: &jwt.MapClaims{"sub": "randomUser", "groups": []string{"superusers"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedResponseBody: strPointer("\"groups\":[\"superusers\"]"),
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
@@ -311,7 +367,47 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) {
|
||||
claims: tc.verifiedClaims,
|
||||
err: tc.verifyTokenErr,
|
||||
}
|
||||
ts := httptest.NewServer(WithAuthMiddleware(tc.authDisabled, tm, mux))
|
||||
clientApp := &oidc.ClientApp{} // all testcases need at least the empty struct for the function to work
|
||||
if tc.ssoEnabled {
|
||||
userInfoCache := cache.NewInMemoryCache(24 * time.Hour)
|
||||
signature, err := util.MakeSignature(32)
|
||||
require.NoError(t, err, "failed creating signature for settings object")
|
||||
cdSettings := &settings.ArgoCDSettings{
|
||||
ServerSignature: signature,
|
||||
OIDCConfigRAW: `
|
||||
issuer: http://localhost:63231
|
||||
enableUserInfoGroups: true
|
||||
userInfoPath: /`,
|
||||
}
|
||||
clientApp, err = oidc.NewClientApp(cdSettings, "", nil, "/argo-cd", userInfoCache)
|
||||
require.NoError(t, err, "failed creating clientapp")
|
||||
|
||||
// prepopulate the cache with claims to return for a userinfo call
|
||||
encryptionKey, err := cdSettings.GetServerEncryptionKey()
|
||||
require.NoError(t, err, "failed obtaining encryption key from settings")
|
||||
// set fake accessToken for GetUserInfo to not return early (can be the same for all cases)
|
||||
encAccessToken, err := crypto.Encrypt([]byte("123456"), encryptionKey)
|
||||
require.NoError(t, err, "failed encrypting dummy access token")
|
||||
err = userInfoCache.Set(&cache.Item{
|
||||
Key: oidc.FormatAccessTokenCacheKey("randomUser"),
|
||||
Object: encAccessToken,
|
||||
})
|
||||
require.NoError(t, err, "failed setting item to in-memory cache")
|
||||
|
||||
// set cacheClaims to in-memory cache to let GetUserInfo return early with this information
|
||||
if tc.userInfoCacheClaims != nil {
|
||||
cacheClaims, err := json.Marshal(tc.userInfoCacheClaims)
|
||||
require.NoError(t, err)
|
||||
encCacheClaims, err := crypto.Encrypt([]byte(cacheClaims), encryptionKey)
|
||||
require.NoError(t, err, "failed encrypting cache Claims")
|
||||
err = userInfoCache.Set(&cache.Item{
|
||||
Key: oidc.FormatUserInfoResponseCacheKey("randomUser"),
|
||||
Object: encCacheClaims,
|
||||
})
|
||||
require.NoError(t, err, "failed setting item to in-memory cache")
|
||||
}
|
||||
}
|
||||
ts := httptest.NewServer(WithAuthMiddleware(tc.authDisabled, tc.ssoEnabled, clientApp, tm, mux))
|
||||
defer ts.Close()
|
||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, ts.URL, http.NoBody)
|
||||
require.NoErrorf(t, err, "error creating request: %s", err)
|
||||
|
||||
Reference in New Issue
Block a user