diff --git a/.gitignore b/.gitignore index 33698c9c5a..02607ac674 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ node_modules/ .kube/ ./test/cmp/*.sock .envrc.remote +.mirrord/ .*.swp rerunreport.txt diff --git a/server/server.go b/server/server.go index 28bd794d74..b506dab6ef 100644 --- a/server/server.go +++ b/server/server.go @@ -1221,9 +1221,12 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb terminalOpts := application.TerminalOptions{DisableAuth: server.DisableAuth, Enf: server.enf} + // SSO ClientApp + server.ssoClientApp, _ = oidc.NewClientApp(server.settings, server.DexServerAddr, server.DexTLSConfig, server.BaseHRef, cacheutil.NewRedisCache(server.RedisClient, server.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone)) + terminal := application.NewHandler(server.appLister, server.Namespace, server.ApplicationNamespaces, server.db, appResourceTreeFn, server.settings.ExecShells, server.sessionMgr, &terminalOpts). WithFeatureFlagMiddleware(server.settingsMgr.GetSettings) - th := util_session.WithAuthMiddleware(server.DisableAuth, server.sessionMgr, terminal) + th := util_session.WithAuthMiddleware(server.DisableAuth, server.settings.IsSSOConfigured(), server.ssoClientApp, server.sessionMgr, terminal) mux.Handle("/terminal", th) // Proxy extension is currently an alpha feature and is disabled @@ -1253,7 +1256,7 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb swagger.ServeSwaggerUI(mux, assets.SwaggerJSON, "/swagger-ui", server.RootPath) healthz.ServeHealthCheck(mux, server.healthCheck) - // Dex reverse proxy and client app and OAuth2 login/callback + // Dex reverse proxy and OAuth2 login/callback server.registerDexHandlers(mux) // Webhook handler for git events (Note: cache timeouts are hardcoded because API server does not write to cache and not really using them) @@ -1305,7 +1308,7 @@ func enforceContentTypes(handler http.Handler, types []string) http.Handler { func registerExtensions(mux *http.ServeMux, a *ArgoCDServer, metricsReg HTTPMetricsRegistry) { a.log.Info("Registering extensions...") extHandler := http.HandlerFunc(a.extensionManager.CallExtension()) - authMiddleware := a.sessionMgr.AuthMiddlewareFunc(a.DisableAuth) + authMiddleware := a.sessionMgr.AuthMiddlewareFunc(a.DisableAuth, a.settings.IsSSOConfigured(), a.ssoClientApp) // auth middleware ensures that requests to all extensions are authenticated first mux.Handle(extension.URLPrefix+"/", authMiddleware(extHandler)) @@ -1359,7 +1362,7 @@ func (server *ArgoCDServer) serveExtensions(extensionsSharedPath string, w http. } } -// registerDexHandlers will register dex HTTP handlers, creating the OAuth client app +// registerDexHandlers will register dex HTTP handlers func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) { if !server.settings.IsSSOConfigured() { return @@ -1367,7 +1370,6 @@ func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) { // Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex) var err error mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig)) - server.ssoClientApp, err = oidc.NewClientApp(server.settings, server.DexServerAddr, server.DexTLSConfig, server.BaseHRef, cacheutil.NewRedisCache(server.RedisClient, server.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone)) errorsutil.CheckError(err) mux.HandleFunc(common.LoginEndpoint, server.ssoClientApp.HandleLogin) mux.HandleFunc(common.CallbackEndpoint, server.ssoClientApp.HandleCallback) @@ -1578,34 +1580,15 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err) } - // Some SSO implementations (Okta) require a call to - // the OIDC user info path to get attributes like groups - // we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims - // otherwise this would cause a panic - var groupClaims jwt.MapClaims - if groupClaims, ok = claims.(jwt.MapClaims); !ok { - if tmpClaims, ok := claims.(*jwt.MapClaims); ok { - groupClaims = *tmpClaims - } - } - iss := jwtutil.StringField(groupClaims, "iss") - if iss != util_session.SessionManagerClaimsIssuer && server.settings.UserInfoGroupsEnabled() && server.settings.UserInfoPath() != "" { - userInfo, unauthorized, err := server.ssoClientApp.GetUserInfo(groupClaims, server.settings.IssuerURL(), server.settings.UserInfoPath()) - if unauthorized { - log.Errorf("error while quering userinfo endpoint: %v", err) - return claims, "", status.Errorf(codes.Unauthenticated, "invalid session") - } + finalClaims := claims + if server.settings.IsSSOConfigured() { + finalClaims, err = server.ssoClientApp.SetGroupsFromUserInfo(claims, util_session.SessionManagerClaimsIssuer) if err != nil { - log.Errorf("error fetching user info endpoint: %v", err) - return claims, "", status.Errorf(codes.Internal, "invalid userinfo response") + return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err) } - if groupClaims["sub"] != userInfo["sub"] { - return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo") - } - groupClaims["groups"] = userInfo["groups"] } - return groupClaims, newToken, nil + return finalClaims, newToken, nil } // getToken extracts the token from gRPC metadata or cookie headers diff --git a/util/oidc/oidc.go b/util/oidc/oidc.go index 55d84d5e3c..76c3dfd10f 100644 --- a/util/oidc/oidc.go +++ b/util/oidc/oidc.go @@ -493,7 +493,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) { } sub := jwtutil.StringField(claims, "sub") err = a.clientCache.Set(&cache.Item{ - Key: formatAccessTokenCacheKey(sub), + Key: FormatAccessTokenCacheKey(sub), Object: encToken, CacheActionOpts: cache.CacheActionOpts{ Expiration: getTokenExpiration(claims), @@ -640,6 +640,39 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc return oauth2.SetAuthURLParam("claims", string(claimsRequestRAW)), nil } +// SetGroupsFromUserInfo takes a claims object and adds groups claim from userinfo endpoint if available +// This is required by some SSO implementations as they don't provide the groups claim in the ID token +// If querying the UserInfo endpoint fails, we return an error to indicate the session is invalid +// we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims +// otherwise this would cause a panic +func (a *ClientApp) SetGroupsFromUserInfo(claims jwt.Claims, sessionManagerClaimsIssuer string) (jwt.MapClaims, error) { + var groupClaims jwt.MapClaims + var ok bool + if groupClaims, ok = claims.(jwt.MapClaims); !ok { + if tmpClaims, ok := claims.(*jwt.MapClaims); ok { + if tmpClaims != nil { + groupClaims = *tmpClaims + } + } + } + iss := jwtutil.StringField(groupClaims, "iss") + if iss != sessionManagerClaimsIssuer && a.settings.UserInfoGroupsEnabled() && a.settings.UserInfoPath() != "" { + userInfo, unauthorized, err := a.GetUserInfo(groupClaims, a.settings.IssuerURL(), a.settings.UserInfoPath()) + if unauthorized { + return groupClaims, fmt.Errorf("error while quering userinfo endpoint: %w", err) + } + if err != nil { + return groupClaims, fmt.Errorf("error fetching user info endpoint: %w", err) + } + if groupClaims["sub"] != userInfo["sub"] { + return groupClaims, errors.New("subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo") + } + groupClaims["groups"] = userInfo["groups"] + } + + return groupClaims, nil +} + // GetUserInfo queries the IDP userinfo endpoint for claims func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) { sub := jwtutil.StringField(actualClaims, "sub") @@ -647,7 +680,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP var encClaims []byte // in case we got it in the cache, we just return the item - clientCacheKey := formatUserInfoResponseCacheKey(sub) + clientCacheKey := FormatUserInfoResponseCacheKey(sub) if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil { claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey) if err != nil { @@ -664,7 +697,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP // check if the accessToken for the user is still present var encAccessToken []byte - err := a.clientCache.Get(formatAccessTokenCacheKey(sub), &encAccessToken) + err := a.clientCache.Get(FormatAccessTokenCacheKey(sub), &encAccessToken) // without an accessToken we can't query the user info endpoint // thus the user needs to reauthenticate for argocd to get a new accessToken if errors.Is(err, cache.ErrCacheMiss) { @@ -774,11 +807,11 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration { } // formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache -func formatUserInfoResponseCacheKey(sub string) string { +func FormatUserInfoResponseCacheKey(sub string) string { return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub) } // formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache -func formatAccessTokenCacheKey(sub string) string { +func FormatAccessTokenCacheKey(sub string) string { return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub) } diff --git a/util/oidc/oidc_test.go b/util/oidc/oidc_test.go index 4db70d2cc0..e2ca575e40 100644 --- a/util/oidc/oidc_test.go +++ b/util/oidc/oidc_test.go @@ -943,7 +943,7 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: FormatUserInfoResponseCacheKey("randomUser"), expectError: true, }, }, @@ -958,7 +958,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: FormatAccessTokenCacheKey("randomUser"), value: "FakeAccessToken", encrypt: true, }, @@ -977,7 +977,7 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: FormatUserInfoResponseCacheKey("randomUser"), expectError: true, }, }, @@ -992,7 +992,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: FormatAccessTokenCacheKey("randomUser"), value: "FakeAccessToken", encrypt: true, }, @@ -1011,7 +1011,7 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: FormatUserInfoResponseCacheKey("randomUser"), expectError: true, }, }, @@ -1034,7 +1034,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: FormatAccessTokenCacheKey("randomUser"), value: "FakeAccessToken", encrypt: true, }, @@ -1053,7 +1053,7 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: FormatUserInfoResponseCacheKey("randomUser"), expectError: true, }, }, @@ -1086,7 +1086,7 @@ func TestGetUserInfo(t *testing.T) { expectError bool }{ { - key: formatUserInfoResponseCacheKey("randomUser"), + key: FormatUserInfoResponseCacheKey("randomUser"), value: "{\"groups\":[\"githubOrg:engineers\"]}", expectEncrypted: true, expectError: false, @@ -1113,7 +1113,7 @@ func TestGetUserInfo(t *testing.T) { encrypt bool }{ { - key: formatAccessTokenCacheKey("randomUser"), + key: FormatAccessTokenCacheKey("randomUser"), value: "FakeAccessToken", encrypt: true, }, @@ -1172,3 +1172,94 @@ func TestGetUserInfo(t *testing.T) { }) } } + +func TestSetGroupsFromUserInfo(t *testing.T) { + tests := []struct { + name string + inputClaims jwt.MapClaims // function input + cacheClaims jwt.MapClaims // userinfo response + expectedClaims jwt.MapClaims // function output + expectError bool + }{ + { + name: "set correct groups from userinfo endpoint", // enriches the JWT claims with information from the userinfo endpoint, default case + inputClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + cacheClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"githubOrg:example"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + expectedClaims: jwt.MapClaims{"sub": "randomUser", "groups": []any{"githubOrg:example"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, // the groups must be of type any since the response we get was parsed by GetUserInfo and we don't yet know the type of the groups claim + expectError: false, + }, + { + name: "return error for wrong userinfo claims returned", // when there's an error in this feature, the claims should be untouched for the rest to still proceed + inputClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + cacheClaims: jwt.MapClaims{"sub": "wrongUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + expectedClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + expectError: true, + }, + { + name: "override groups already defined in input claims", // this is expected behavior since input claims might have been truncated (HTTP header 4K limit) + inputClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"groupfromjwt"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + cacheClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"superusers", "usergroup", "support-group"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + expectedClaims: jwt.MapClaims{"sub": "randomUser", "groups": []any{"superusers", "usergroup", "support-group"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + expectError: false, + }, + { + name: "empty cache and non-rechable userinfo endpoint", // this will try to reach the userinfo endpoint defined in the test and fail + inputClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"groupfromjwt"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + cacheClaims: nil, // the test doesn't set the cache for an empty object + expectedClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"groupfromjwt"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // create the ClientApp + userInfoCache := cache.NewInMemoryCache(24 * time.Hour) + signature, err := util.MakeSignature(32) + require.NoError(t, err, "failed creating signature for settings object") + cdSettings := &settings.ArgoCDSettings{ + ServerSignature: signature, + OIDCConfigRAW: ` +issuer: http://localhost:63231 +enableUserInfoGroups: true +userInfoPath: /`, + } + a, err := NewClientApp(cdSettings, "", nil, "/argo-cd", userInfoCache) + require.NoError(t, err, "failed creating clientapp") + + // prepoluate cache to predict what the GetUserInfo function will return to the SetGroupsFromUserInfo function (without having to mock the userinfo response) + encryptionKey, err := cdSettings.GetServerEncryptionKey() + require.NoError(t, err, "failed obtaining encryption key from settings") + + // set fake accessToken for function to not return early + encAccessToken, err := crypto.Encrypt([]byte("123456"), encryptionKey) + require.NoError(t, err, "failed encrypting dummy access token") + err = a.clientCache.Set(&cache.Item{ + Key: FormatAccessTokenCacheKey("randomUser"), + Object: encAccessToken, + }) + require.NoError(t, err, "failed setting item to in-memory cache") + + // set cacheClaims to in-memory cache to let GetUserInfo return early with this information (GetUserInfo has a separate test, here we focus on SetUserInfoGroups) + if tt.cacheClaims != nil { + cacheClaims, err := json.Marshal(tt.cacheClaims) + require.NoError(t, err) + encCacheClaims, err := crypto.Encrypt([]byte(cacheClaims), encryptionKey) + require.NoError(t, err, "failed encrypting dummy access token") + err = a.clientCache.Set(&cache.Item{ + Key: FormatUserInfoResponseCacheKey("randomUser"), + Object: encCacheClaims, + }) + require.NoError(t, err, "failed setting item to in-memory cache") + } + + receivedClaims, err := a.SetGroupsFromUserInfo(tt.inputClaims, "argocd") + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + assert.Equal(t, tt.expectedClaims, receivedClaims) // check that the claims were successfully enriched with what we expect + }) + } +} diff --git a/util/session/sessionmanager.go b/util/session/sessionmanager.go index 39285ee7f5..a15d0e9c0b 100644 --- a/util/session/sessionmanager.go +++ b/util/session/sessionmanager.go @@ -480,9 +480,9 @@ func (mgr *SessionManager) VerifyUsernamePassword(username string, password stri // AuthMiddlewareFunc returns a function that can be used as an // authentication middleware for HTTP requests. -func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool) func(http.Handler) http.Handler { +func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool, isSSOConfigured bool, ssoClientApp *oidcutil.ClientApp) func(http.Handler) http.Handler { return func(h http.Handler) http.Handler { - return WithAuthMiddleware(disabled, mgr, h) + return WithAuthMiddleware(disabled, isSSOConfigured, ssoClientApp, mgr, h) } } @@ -495,26 +495,41 @@ type TokenVerifier interface { // WithAuthMiddleware is an HTTP middleware used to ensure incoming // requests are authenticated before invoking the target handler. If // disabled is true, it will just invoke the next handler in the chain. -func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) http.Handler { +func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcutil.ClientApp, authn TokenVerifier, next http.Handler) http.Handler { + if disabled { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(w, r) + }) + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if !disabled { - cookies := r.Cookies() - tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies) - if err != nil { - http.Error(w, "Auth cookie not found", http.StatusBadRequest) - return - } - claims, _, err := authn.VerifyToken(tokenString) - if err != nil { - http.Error(w, "Invalid token", http.StatusUnauthorized) - return - } - ctx := r.Context() - // Add claims to the context to inspect for RBAC - //nolint:staticcheck - ctx = context.WithValue(ctx, "claims", claims) - r = r.WithContext(ctx) + cookies := r.Cookies() + tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies) + if err != nil { + http.Error(w, "Auth cookie not found", http.StatusBadRequest) + return } + claims, _, err := authn.VerifyToken(tokenString) + if err != nil { + http.Error(w, "Invalid token", http.StatusUnauthorized) + return + } + + finalClaims := claims + if isSSOConfigured { + finalClaims, err = ssoClientApp.SetGroupsFromUserInfo(claims, SessionManagerClaimsIssuer) + if err != nil { + http.Error(w, "Invalid session", http.StatusUnauthorized) + return + } + } + + ctx := r.Context() + // Add claims to the context to inspect for RBAC + //nolint:staticcheck + ctx = context.WithValue(ctx, "claims", finalClaims) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) }) } diff --git a/util/session/sessionmanager_test.go b/util/session/sessionmanager_test.go index f6a5bcc8a4..e2ae1bd080 100644 --- a/util/session/sessionmanager_test.go +++ b/util/session/sessionmanager_test.go @@ -2,6 +2,7 @@ package session import ( "context" + "encoding/json" "encoding/pem" stderrors "errors" "fmt" @@ -29,7 +30,11 @@ import ( apps "github.com/argoproj/argo-cd/v3/pkg/client/clientset/versioned/fake" "github.com/argoproj/argo-cd/v3/pkg/client/listers/application/v1alpha1" "github.com/argoproj/argo-cd/v3/test" + "github.com/argoproj/argo-cd/v3/util" + "github.com/argoproj/argo-cd/v3/util/cache" + "github.com/argoproj/argo-cd/v3/util/crypto" jwtutil "github.com/argoproj/argo-cd/v3/util/jwt" + "github.com/argoproj/argo-cd/v3/util/oidc" "github.com/argoproj/argo-cd/v3/util/password" "github.com/argoproj/argo-cd/v3/util/settings" utiltest "github.com/argoproj/argo-cd/v3/util/test" @@ -236,20 +241,39 @@ func strPointer(str string) *string { func TestSessionManager_WithAuthMiddleware(t *testing.T) { handlerFunc := func() func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, _ *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { t.Helper() w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/text") - _, err := w.Write([]byte("Ok")) - require.NoError(t, err, "error writing response: %s", err) + + contextClaims := r.Context().Value("claims") + if contextClaims != nil { + var gotClaims jwt.MapClaims + var ok bool + if gotClaims, ok = contextClaims.(jwt.MapClaims); !ok { + if tmpClaims, ok := contextClaims.(*jwt.MapClaims); ok && tmpClaims != nil { + gotClaims = *tmpClaims + } + } + jsonClaims, err := json.Marshal(gotClaims) + require.NoError(t, err, "erorr marshalling claims set by AuthMiddleware") + w.Header().Set("Content-Type", "application/json") + _, err = w.Write(jsonClaims) + require.NoError(t, err, "error writing response: %s", err) + } else { + w.Header().Set("Content-Type", "application/text") + _, err := w.Write([]byte("Ok")) + require.NoError(t, err, "error writing response: %s", err) + } } } type testCase struct { name string authDisabled bool + ssoEnabled bool cookieHeader bool - verifiedClaims *jwt.RegisteredClaims + verifiedClaims *jwt.MapClaims verifyTokenErr error + userInfoCacheClaims *jwt.MapClaims expectedStatusCode int expectedResponseBody *string } @@ -258,47 +282,79 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) { { name: "will authenticate successfully", authDisabled: false, + ssoEnabled: false, cookieHeader: true, - verifiedClaims: &jwt.RegisteredClaims{}, + verifiedClaims: &jwt.MapClaims{}, verifyTokenErr: nil, + userInfoCacheClaims: nil, expectedStatusCode: http.StatusOK, - expectedResponseBody: strPointer("Ok"), + expectedResponseBody: strPointer("{}"), }, { name: "will be noop if auth is disabled", authDisabled: true, + ssoEnabled: false, cookieHeader: false, verifiedClaims: nil, verifyTokenErr: nil, + userInfoCacheClaims: nil, expectedStatusCode: http.StatusOK, expectedResponseBody: strPointer("Ok"), }, { name: "will return 400 if no cookie header", authDisabled: false, + ssoEnabled: false, cookieHeader: false, - verifiedClaims: &jwt.RegisteredClaims{}, + verifiedClaims: &jwt.MapClaims{}, verifyTokenErr: nil, + userInfoCacheClaims: nil, expectedStatusCode: http.StatusBadRequest, expectedResponseBody: nil, }, { name: "will return 401 verify token fails", authDisabled: false, + ssoEnabled: false, cookieHeader: true, - verifiedClaims: &jwt.RegisteredClaims{}, + verifiedClaims: &jwt.MapClaims{}, verifyTokenErr: stderrors.New("token error"), + userInfoCacheClaims: nil, expectedStatusCode: http.StatusUnauthorized, expectedResponseBody: nil, }, { name: "will return 200 if claims are nil", authDisabled: false, + ssoEnabled: false, cookieHeader: true, verifiedClaims: nil, verifyTokenErr: nil, + userInfoCacheClaims: nil, expectedStatusCode: http.StatusOK, - expectedResponseBody: strPointer("Ok"), + expectedResponseBody: strPointer("null"), + }, + { + name: "will return 401 if sso is enabled but userinfo response not working", + authDisabled: false, + ssoEnabled: true, + cookieHeader: true, + verifiedClaims: nil, + verifyTokenErr: nil, + userInfoCacheClaims: nil, // indicates that the userinfo response will not work since cache is empty and userinfo endpoint not rechable + expectedStatusCode: http.StatusUnauthorized, + expectedResponseBody: strPointer("Invalid session"), + }, + { + name: "will return 200 if sso is enabled and userinfo response from cache is valid", + authDisabled: false, + ssoEnabled: true, + cookieHeader: true, + verifiedClaims: &jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + verifyTokenErr: nil, + userInfoCacheClaims: &jwt.MapClaims{"sub": "randomUser", "groups": []string{"superusers"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, + expectedStatusCode: http.StatusOK, + expectedResponseBody: strPointer("\"groups\":[\"superusers\"]"), }, } for _, tc := range cases { @@ -311,7 +367,47 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) { claims: tc.verifiedClaims, err: tc.verifyTokenErr, } - ts := httptest.NewServer(WithAuthMiddleware(tc.authDisabled, tm, mux)) + clientApp := &oidc.ClientApp{} // all testcases need at least the empty struct for the function to work + if tc.ssoEnabled { + userInfoCache := cache.NewInMemoryCache(24 * time.Hour) + signature, err := util.MakeSignature(32) + require.NoError(t, err, "failed creating signature for settings object") + cdSettings := &settings.ArgoCDSettings{ + ServerSignature: signature, + OIDCConfigRAW: ` +issuer: http://localhost:63231 +enableUserInfoGroups: true +userInfoPath: /`, + } + clientApp, err = oidc.NewClientApp(cdSettings, "", nil, "/argo-cd", userInfoCache) + require.NoError(t, err, "failed creating clientapp") + + // prepopulate the cache with claims to return for a userinfo call + encryptionKey, err := cdSettings.GetServerEncryptionKey() + require.NoError(t, err, "failed obtaining encryption key from settings") + // set fake accessToken for GetUserInfo to not return early (can be the same for all cases) + encAccessToken, err := crypto.Encrypt([]byte("123456"), encryptionKey) + require.NoError(t, err, "failed encrypting dummy access token") + err = userInfoCache.Set(&cache.Item{ + Key: oidc.FormatAccessTokenCacheKey("randomUser"), + Object: encAccessToken, + }) + require.NoError(t, err, "failed setting item to in-memory cache") + + // set cacheClaims to in-memory cache to let GetUserInfo return early with this information + if tc.userInfoCacheClaims != nil { + cacheClaims, err := json.Marshal(tc.userInfoCacheClaims) + require.NoError(t, err) + encCacheClaims, err := crypto.Encrypt([]byte(cacheClaims), encryptionKey) + require.NoError(t, err, "failed encrypting cache Claims") + err = userInfoCache.Set(&cache.Item{ + Key: oidc.FormatUserInfoResponseCacheKey("randomUser"), + Object: encCacheClaims, + }) + require.NoError(t, err, "failed setting item to in-memory cache") + } + } + ts := httptest.NewServer(WithAuthMiddleware(tc.authDisabled, tc.ssoEnabled, clientApp, tm, mux)) defer ts.Close() req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, ts.URL, http.NoBody) require.NoErrorf(t, err, "error creating request: %s", err)