mirror of
https://github.com/argoproj/argo-cd.git
synced 2026-02-20 01:28:45 +01:00
fix(oidc): check userinfo endpoint in AuthMiddleware (#23586)
Signed-off-by: Nathanael Liechti <technat@technat.ch>
This commit is contained in:
committed by
GitHub
parent
96038ba2a1
commit
5efb184c79
1
.gitignore
vendored
1
.gitignore
vendored
@@ -20,6 +20,7 @@ node_modules/
|
|||||||
.kube/
|
.kube/
|
||||||
./test/cmp/*.sock
|
./test/cmp/*.sock
|
||||||
.envrc.remote
|
.envrc.remote
|
||||||
|
.mirrord/
|
||||||
.*.swp
|
.*.swp
|
||||||
rerunreport.txt
|
rerunreport.txt
|
||||||
|
|
||||||
|
|||||||
@@ -1221,9 +1221,12 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb
|
|||||||
|
|
||||||
terminalOpts := application.TerminalOptions{DisableAuth: server.DisableAuth, Enf: server.enf}
|
terminalOpts := application.TerminalOptions{DisableAuth: server.DisableAuth, Enf: server.enf}
|
||||||
|
|
||||||
|
// SSO ClientApp
|
||||||
|
server.ssoClientApp, _ = oidc.NewClientApp(server.settings, server.DexServerAddr, server.DexTLSConfig, server.BaseHRef, cacheutil.NewRedisCache(server.RedisClient, server.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone))
|
||||||
|
|
||||||
terminal := application.NewHandler(server.appLister, server.Namespace, server.ApplicationNamespaces, server.db, appResourceTreeFn, server.settings.ExecShells, server.sessionMgr, &terminalOpts).
|
terminal := application.NewHandler(server.appLister, server.Namespace, server.ApplicationNamespaces, server.db, appResourceTreeFn, server.settings.ExecShells, server.sessionMgr, &terminalOpts).
|
||||||
WithFeatureFlagMiddleware(server.settingsMgr.GetSettings)
|
WithFeatureFlagMiddleware(server.settingsMgr.GetSettings)
|
||||||
th := util_session.WithAuthMiddleware(server.DisableAuth, server.sessionMgr, terminal)
|
th := util_session.WithAuthMiddleware(server.DisableAuth, server.settings.IsSSOConfigured(), server.ssoClientApp, server.sessionMgr, terminal)
|
||||||
mux.Handle("/terminal", th)
|
mux.Handle("/terminal", th)
|
||||||
|
|
||||||
// Proxy extension is currently an alpha feature and is disabled
|
// Proxy extension is currently an alpha feature and is disabled
|
||||||
@@ -1253,7 +1256,7 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb
|
|||||||
swagger.ServeSwaggerUI(mux, assets.SwaggerJSON, "/swagger-ui", server.RootPath)
|
swagger.ServeSwaggerUI(mux, assets.SwaggerJSON, "/swagger-ui", server.RootPath)
|
||||||
healthz.ServeHealthCheck(mux, server.healthCheck)
|
healthz.ServeHealthCheck(mux, server.healthCheck)
|
||||||
|
|
||||||
// Dex reverse proxy and client app and OAuth2 login/callback
|
// Dex reverse proxy and OAuth2 login/callback
|
||||||
server.registerDexHandlers(mux)
|
server.registerDexHandlers(mux)
|
||||||
|
|
||||||
// Webhook handler for git events (Note: cache timeouts are hardcoded because API server does not write to cache and not really using them)
|
// Webhook handler for git events (Note: cache timeouts are hardcoded because API server does not write to cache and not really using them)
|
||||||
@@ -1305,7 +1308,7 @@ func enforceContentTypes(handler http.Handler, types []string) http.Handler {
|
|||||||
func registerExtensions(mux *http.ServeMux, a *ArgoCDServer, metricsReg HTTPMetricsRegistry) {
|
func registerExtensions(mux *http.ServeMux, a *ArgoCDServer, metricsReg HTTPMetricsRegistry) {
|
||||||
a.log.Info("Registering extensions...")
|
a.log.Info("Registering extensions...")
|
||||||
extHandler := http.HandlerFunc(a.extensionManager.CallExtension())
|
extHandler := http.HandlerFunc(a.extensionManager.CallExtension())
|
||||||
authMiddleware := a.sessionMgr.AuthMiddlewareFunc(a.DisableAuth)
|
authMiddleware := a.sessionMgr.AuthMiddlewareFunc(a.DisableAuth, a.settings.IsSSOConfigured(), a.ssoClientApp)
|
||||||
// auth middleware ensures that requests to all extensions are authenticated first
|
// auth middleware ensures that requests to all extensions are authenticated first
|
||||||
mux.Handle(extension.URLPrefix+"/", authMiddleware(extHandler))
|
mux.Handle(extension.URLPrefix+"/", authMiddleware(extHandler))
|
||||||
|
|
||||||
@@ -1359,7 +1362,7 @@ func (server *ArgoCDServer) serveExtensions(extensionsSharedPath string, w http.
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerDexHandlers will register dex HTTP handlers, creating the OAuth client app
|
// registerDexHandlers will register dex HTTP handlers
|
||||||
func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
|
func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
|
||||||
if !server.settings.IsSSOConfigured() {
|
if !server.settings.IsSSOConfigured() {
|
||||||
return
|
return
|
||||||
@@ -1367,7 +1370,6 @@ func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
|
|||||||
// Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex)
|
// Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex)
|
||||||
var err error
|
var err error
|
||||||
mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig))
|
mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig))
|
||||||
server.ssoClientApp, err = oidc.NewClientApp(server.settings, server.DexServerAddr, server.DexTLSConfig, server.BaseHRef, cacheutil.NewRedisCache(server.RedisClient, server.settings.UserInfoCacheExpiration(), cacheutil.RedisCompressionNone))
|
|
||||||
errorsutil.CheckError(err)
|
errorsutil.CheckError(err)
|
||||||
mux.HandleFunc(common.LoginEndpoint, server.ssoClientApp.HandleLogin)
|
mux.HandleFunc(common.LoginEndpoint, server.ssoClientApp.HandleLogin)
|
||||||
mux.HandleFunc(common.CallbackEndpoint, server.ssoClientApp.HandleCallback)
|
mux.HandleFunc(common.CallbackEndpoint, server.ssoClientApp.HandleCallback)
|
||||||
@@ -1578,34 +1580,15 @@ func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string,
|
|||||||
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
|
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Some SSO implementations (Okta) require a call to
|
finalClaims := claims
|
||||||
// the OIDC user info path to get attributes like groups
|
if server.settings.IsSSOConfigured() {
|
||||||
// we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims
|
finalClaims, err = server.ssoClientApp.SetGroupsFromUserInfo(claims, util_session.SessionManagerClaimsIssuer)
|
||||||
// otherwise this would cause a panic
|
|
||||||
var groupClaims jwt.MapClaims
|
|
||||||
if groupClaims, ok = claims.(jwt.MapClaims); !ok {
|
|
||||||
if tmpClaims, ok := claims.(*jwt.MapClaims); ok {
|
|
||||||
groupClaims = *tmpClaims
|
|
||||||
}
|
|
||||||
}
|
|
||||||
iss := jwtutil.StringField(groupClaims, "iss")
|
|
||||||
if iss != util_session.SessionManagerClaimsIssuer && server.settings.UserInfoGroupsEnabled() && server.settings.UserInfoPath() != "" {
|
|
||||||
userInfo, unauthorized, err := server.ssoClientApp.GetUserInfo(groupClaims, server.settings.IssuerURL(), server.settings.UserInfoPath())
|
|
||||||
if unauthorized {
|
|
||||||
log.Errorf("error while quering userinfo endpoint: %v", err)
|
|
||||||
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session")
|
|
||||||
}
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error fetching user info endpoint: %v", err)
|
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
|
||||||
return claims, "", status.Errorf(codes.Internal, "invalid userinfo response")
|
|
||||||
}
|
}
|
||||||
if groupClaims["sub"] != userInfo["sub"] {
|
|
||||||
return claims, "", status.Error(codes.Unknown, "subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo")
|
|
||||||
}
|
|
||||||
groupClaims["groups"] = userInfo["groups"]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return groupClaims, newToken, nil
|
return finalClaims, newToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getToken extracts the token from gRPC metadata or cookie headers
|
// getToken extracts the token from gRPC metadata or cookie headers
|
||||||
|
|||||||
@@ -493,7 +493,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
sub := jwtutil.StringField(claims, "sub")
|
sub := jwtutil.StringField(claims, "sub")
|
||||||
err = a.clientCache.Set(&cache.Item{
|
err = a.clientCache.Set(&cache.Item{
|
||||||
Key: formatAccessTokenCacheKey(sub),
|
Key: FormatAccessTokenCacheKey(sub),
|
||||||
Object: encToken,
|
Object: encToken,
|
||||||
CacheActionOpts: cache.CacheActionOpts{
|
CacheActionOpts: cache.CacheActionOpts{
|
||||||
Expiration: getTokenExpiration(claims),
|
Expiration: getTokenExpiration(claims),
|
||||||
@@ -640,6 +640,39 @@ func createClaimsAuthenticationRequestParameter(requestedClaims map[string]*oidc
|
|||||||
return oauth2.SetAuthURLParam("claims", string(claimsRequestRAW)), nil
|
return oauth2.SetAuthURLParam("claims", string(claimsRequestRAW)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetGroupsFromUserInfo takes a claims object and adds groups claim from userinfo endpoint if available
|
||||||
|
// This is required by some SSO implementations as they don't provide the groups claim in the ID token
|
||||||
|
// If querying the UserInfo endpoint fails, we return an error to indicate the session is invalid
|
||||||
|
// we assume that everywhere in argocd jwt.MapClaims is used as type for interface jwt.Claims
|
||||||
|
// otherwise this would cause a panic
|
||||||
|
func (a *ClientApp) SetGroupsFromUserInfo(claims jwt.Claims, sessionManagerClaimsIssuer string) (jwt.MapClaims, error) {
|
||||||
|
var groupClaims jwt.MapClaims
|
||||||
|
var ok bool
|
||||||
|
if groupClaims, ok = claims.(jwt.MapClaims); !ok {
|
||||||
|
if tmpClaims, ok := claims.(*jwt.MapClaims); ok {
|
||||||
|
if tmpClaims != nil {
|
||||||
|
groupClaims = *tmpClaims
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
iss := jwtutil.StringField(groupClaims, "iss")
|
||||||
|
if iss != sessionManagerClaimsIssuer && a.settings.UserInfoGroupsEnabled() && a.settings.UserInfoPath() != "" {
|
||||||
|
userInfo, unauthorized, err := a.GetUserInfo(groupClaims, a.settings.IssuerURL(), a.settings.UserInfoPath())
|
||||||
|
if unauthorized {
|
||||||
|
return groupClaims, fmt.Errorf("error while quering userinfo endpoint: %w", err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return groupClaims, fmt.Errorf("error fetching user info endpoint: %w", err)
|
||||||
|
}
|
||||||
|
if groupClaims["sub"] != userInfo["sub"] {
|
||||||
|
return groupClaims, errors.New("subject of claims from user info endpoint didn't match subject of idToken, see https://openid.net/specs/openid-connect-core-1_0.html#UserInfo")
|
||||||
|
}
|
||||||
|
groupClaims["groups"] = userInfo["groups"]
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupClaims, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetUserInfo queries the IDP userinfo endpoint for claims
|
// GetUserInfo queries the IDP userinfo endpoint for claims
|
||||||
func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
|
func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
|
||||||
sub := jwtutil.StringField(actualClaims, "sub")
|
sub := jwtutil.StringField(actualClaims, "sub")
|
||||||
@@ -647,7 +680,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
|||||||
var encClaims []byte
|
var encClaims []byte
|
||||||
|
|
||||||
// in case we got it in the cache, we just return the item
|
// in case we got it in the cache, we just return the item
|
||||||
clientCacheKey := formatUserInfoResponseCacheKey(sub)
|
clientCacheKey := FormatUserInfoResponseCacheKey(sub)
|
||||||
if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil {
|
if err := a.clientCache.Get(clientCacheKey, &encClaims); err == nil {
|
||||||
claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey)
|
claimsRaw, err := crypto.Decrypt(encClaims, a.encryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -664,7 +697,7 @@ func (a *ClientApp) GetUserInfo(actualClaims jwt.MapClaims, issuerURL, userInfoP
|
|||||||
|
|
||||||
// check if the accessToken for the user is still present
|
// check if the accessToken for the user is still present
|
||||||
var encAccessToken []byte
|
var encAccessToken []byte
|
||||||
err := a.clientCache.Get(formatAccessTokenCacheKey(sub), &encAccessToken)
|
err := a.clientCache.Get(FormatAccessTokenCacheKey(sub), &encAccessToken)
|
||||||
// without an accessToken we can't query the user info endpoint
|
// without an accessToken we can't query the user info endpoint
|
||||||
// thus the user needs to reauthenticate for argocd to get a new accessToken
|
// thus the user needs to reauthenticate for argocd to get a new accessToken
|
||||||
if errors.Is(err, cache.ErrCacheMiss) {
|
if errors.Is(err, cache.ErrCacheMiss) {
|
||||||
@@ -774,11 +807,11 @@ func getTokenExpiration(claims jwt.MapClaims) time.Duration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
|
// formatUserInfoResponseCacheKey returns the key which is used to store userinfo of user in cache
|
||||||
func formatUserInfoResponseCacheKey(sub string) string {
|
func FormatUserInfoResponseCacheKey(sub string) string {
|
||||||
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub)
|
return fmt.Sprintf("%s_%s", UserInfoResponseCachePrefix, sub)
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
|
// formatAccessTokenCacheKey returns the key which is used to store the accessToken of a user in cache
|
||||||
func formatAccessTokenCacheKey(sub string) string {
|
func FormatAccessTokenCacheKey(sub string) string {
|
||||||
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub)
|
return fmt.Sprintf("%s_%s", AccessTokenCachePrefix, sub)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -943,7 +943,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
expectError bool
|
expectError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -958,7 +958,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
encrypt bool
|
encrypt bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
key: formatAccessTokenCacheKey("randomUser"),
|
key: FormatAccessTokenCacheKey("randomUser"),
|
||||||
value: "FakeAccessToken",
|
value: "FakeAccessToken",
|
||||||
encrypt: true,
|
encrypt: true,
|
||||||
},
|
},
|
||||||
@@ -977,7 +977,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
expectError bool
|
expectError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -992,7 +992,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
encrypt bool
|
encrypt bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
key: formatAccessTokenCacheKey("randomUser"),
|
key: FormatAccessTokenCacheKey("randomUser"),
|
||||||
value: "FakeAccessToken",
|
value: "FakeAccessToken",
|
||||||
encrypt: true,
|
encrypt: true,
|
||||||
},
|
},
|
||||||
@@ -1011,7 +1011,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
expectError bool
|
expectError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1034,7 +1034,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
encrypt bool
|
encrypt bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
key: formatAccessTokenCacheKey("randomUser"),
|
key: FormatAccessTokenCacheKey("randomUser"),
|
||||||
value: "FakeAccessToken",
|
value: "FakeAccessToken",
|
||||||
encrypt: true,
|
encrypt: true,
|
||||||
},
|
},
|
||||||
@@ -1053,7 +1053,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
expectError bool
|
expectError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||||
expectError: true,
|
expectError: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -1086,7 +1086,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
expectError bool
|
expectError bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
key: formatUserInfoResponseCacheKey("randomUser"),
|
key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||||
value: "{\"groups\":[\"githubOrg:engineers\"]}",
|
value: "{\"groups\":[\"githubOrg:engineers\"]}",
|
||||||
expectEncrypted: true,
|
expectEncrypted: true,
|
||||||
expectError: false,
|
expectError: false,
|
||||||
@@ -1113,7 +1113,7 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
encrypt bool
|
encrypt bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
key: formatAccessTokenCacheKey("randomUser"),
|
key: FormatAccessTokenCacheKey("randomUser"),
|
||||||
value: "FakeAccessToken",
|
value: "FakeAccessToken",
|
||||||
encrypt: true,
|
encrypt: true,
|
||||||
},
|
},
|
||||||
@@ -1172,3 +1172,94 @@ func TestGetUserInfo(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetGroupsFromUserInfo(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
inputClaims jwt.MapClaims // function input
|
||||||
|
cacheClaims jwt.MapClaims // userinfo response
|
||||||
|
expectedClaims jwt.MapClaims // function output
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "set correct groups from userinfo endpoint", // enriches the JWT claims with information from the userinfo endpoint, default case
|
||||||
|
inputClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
cacheClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"githubOrg:example"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
expectedClaims: jwt.MapClaims{"sub": "randomUser", "groups": []any{"githubOrg:example"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())}, // the groups must be of type any since the response we get was parsed by GetUserInfo and we don't yet know the type of the groups claim
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "return error for wrong userinfo claims returned", // when there's an error in this feature, the claims should be untouched for the rest to still proceed
|
||||||
|
inputClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
cacheClaims: jwt.MapClaims{"sub": "wrongUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
expectedClaims: jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "override groups already defined in input claims", // this is expected behavior since input claims might have been truncated (HTTP header 4K limit)
|
||||||
|
inputClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"groupfromjwt"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
cacheClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"superusers", "usergroup", "support-group"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
expectedClaims: jwt.MapClaims{"sub": "randomUser", "groups": []any{"superusers", "usergroup", "support-group"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty cache and non-rechable userinfo endpoint", // this will try to reach the userinfo endpoint defined in the test and fail
|
||||||
|
inputClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"groupfromjwt"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
cacheClaims: nil, // the test doesn't set the cache for an empty object
|
||||||
|
expectedClaims: jwt.MapClaims{"sub": "randomUser", "groups": []string{"groupfromjwt"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// create the ClientApp
|
||||||
|
userInfoCache := cache.NewInMemoryCache(24 * time.Hour)
|
||||||
|
signature, err := util.MakeSignature(32)
|
||||||
|
require.NoError(t, err, "failed creating signature for settings object")
|
||||||
|
cdSettings := &settings.ArgoCDSettings{
|
||||||
|
ServerSignature: signature,
|
||||||
|
OIDCConfigRAW: `
|
||||||
|
issuer: http://localhost:63231
|
||||||
|
enableUserInfoGroups: true
|
||||||
|
userInfoPath: /`,
|
||||||
|
}
|
||||||
|
a, err := NewClientApp(cdSettings, "", nil, "/argo-cd", userInfoCache)
|
||||||
|
require.NoError(t, err, "failed creating clientapp")
|
||||||
|
|
||||||
|
// prepoluate cache to predict what the GetUserInfo function will return to the SetGroupsFromUserInfo function (without having to mock the userinfo response)
|
||||||
|
encryptionKey, err := cdSettings.GetServerEncryptionKey()
|
||||||
|
require.NoError(t, err, "failed obtaining encryption key from settings")
|
||||||
|
|
||||||
|
// set fake accessToken for function to not return early
|
||||||
|
encAccessToken, err := crypto.Encrypt([]byte("123456"), encryptionKey)
|
||||||
|
require.NoError(t, err, "failed encrypting dummy access token")
|
||||||
|
err = a.clientCache.Set(&cache.Item{
|
||||||
|
Key: FormatAccessTokenCacheKey("randomUser"),
|
||||||
|
Object: encAccessToken,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "failed setting item to in-memory cache")
|
||||||
|
|
||||||
|
// set cacheClaims to in-memory cache to let GetUserInfo return early with this information (GetUserInfo has a separate test, here we focus on SetUserInfoGroups)
|
||||||
|
if tt.cacheClaims != nil {
|
||||||
|
cacheClaims, err := json.Marshal(tt.cacheClaims)
|
||||||
|
require.NoError(t, err)
|
||||||
|
encCacheClaims, err := crypto.Encrypt([]byte(cacheClaims), encryptionKey)
|
||||||
|
require.NoError(t, err, "failed encrypting dummy access token")
|
||||||
|
err = a.clientCache.Set(&cache.Item{
|
||||||
|
Key: FormatUserInfoResponseCacheKey("randomUser"),
|
||||||
|
Object: encCacheClaims,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "failed setting item to in-memory cache")
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedClaims, err := a.SetGroupsFromUserInfo(tt.inputClaims, "argocd")
|
||||||
|
if tt.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tt.expectedClaims, receivedClaims) // check that the claims were successfully enriched with what we expect
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -480,9 +480,9 @@ func (mgr *SessionManager) VerifyUsernamePassword(username string, password stri
|
|||||||
|
|
||||||
// AuthMiddlewareFunc returns a function that can be used as an
|
// AuthMiddlewareFunc returns a function that can be used as an
|
||||||
// authentication middleware for HTTP requests.
|
// authentication middleware for HTTP requests.
|
||||||
func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool) func(http.Handler) http.Handler {
|
func (mgr *SessionManager) AuthMiddlewareFunc(disabled bool, isSSOConfigured bool, ssoClientApp *oidcutil.ClientApp) func(http.Handler) http.Handler {
|
||||||
return func(h http.Handler) http.Handler {
|
return func(h http.Handler) http.Handler {
|
||||||
return WithAuthMiddleware(disabled, mgr, h)
|
return WithAuthMiddleware(disabled, isSSOConfigured, ssoClientApp, mgr, h)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -495,26 +495,41 @@ type TokenVerifier interface {
|
|||||||
// WithAuthMiddleware is an HTTP middleware used to ensure incoming
|
// WithAuthMiddleware is an HTTP middleware used to ensure incoming
|
||||||
// requests are authenticated before invoking the target handler. If
|
// requests are authenticated before invoking the target handler. If
|
||||||
// disabled is true, it will just invoke the next handler in the chain.
|
// disabled is true, it will just invoke the next handler in the chain.
|
||||||
func WithAuthMiddleware(disabled bool, authn TokenVerifier, next http.Handler) http.Handler {
|
func WithAuthMiddleware(disabled bool, isSSOConfigured bool, ssoClientApp *oidcutil.ClientApp, authn TokenVerifier, next http.Handler) http.Handler {
|
||||||
|
if disabled {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if !disabled {
|
cookies := r.Cookies()
|
||||||
cookies := r.Cookies()
|
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
|
||||||
tokenString, err := httputil.JoinCookies(common.AuthCookieName, cookies)
|
if err != nil {
|
||||||
if err != nil {
|
http.Error(w, "Auth cookie not found", http.StatusBadRequest)
|
||||||
http.Error(w, "Auth cookie not found", http.StatusBadRequest)
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
claims, _, err := authn.VerifyToken(tokenString)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ctx := r.Context()
|
|
||||||
// Add claims to the context to inspect for RBAC
|
|
||||||
//nolint:staticcheck
|
|
||||||
ctx = context.WithValue(ctx, "claims", claims)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
}
|
}
|
||||||
|
claims, _, err := authn.VerifyToken(tokenString)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid token", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
finalClaims := claims
|
||||||
|
if isSSOConfigured {
|
||||||
|
finalClaims, err = ssoClientApp.SetGroupsFromUserInfo(claims, SessionManagerClaimsIssuer)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "Invalid session", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
// Add claims to the context to inspect for RBAC
|
||||||
|
//nolint:staticcheck
|
||||||
|
ctx = context.WithValue(ctx, "claims", finalClaims)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package session
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
stderrors "errors"
|
stderrors "errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -29,7 +30,11 @@ import (
|
|||||||
apps "github.com/argoproj/argo-cd/v3/pkg/client/clientset/versioned/fake"
|
apps "github.com/argoproj/argo-cd/v3/pkg/client/clientset/versioned/fake"
|
||||||
"github.com/argoproj/argo-cd/v3/pkg/client/listers/application/v1alpha1"
|
"github.com/argoproj/argo-cd/v3/pkg/client/listers/application/v1alpha1"
|
||||||
"github.com/argoproj/argo-cd/v3/test"
|
"github.com/argoproj/argo-cd/v3/test"
|
||||||
|
"github.com/argoproj/argo-cd/v3/util"
|
||||||
|
"github.com/argoproj/argo-cd/v3/util/cache"
|
||||||
|
"github.com/argoproj/argo-cd/v3/util/crypto"
|
||||||
jwtutil "github.com/argoproj/argo-cd/v3/util/jwt"
|
jwtutil "github.com/argoproj/argo-cd/v3/util/jwt"
|
||||||
|
"github.com/argoproj/argo-cd/v3/util/oidc"
|
||||||
"github.com/argoproj/argo-cd/v3/util/password"
|
"github.com/argoproj/argo-cd/v3/util/password"
|
||||||
"github.com/argoproj/argo-cd/v3/util/settings"
|
"github.com/argoproj/argo-cd/v3/util/settings"
|
||||||
utiltest "github.com/argoproj/argo-cd/v3/util/test"
|
utiltest "github.com/argoproj/argo-cd/v3/util/test"
|
||||||
@@ -236,20 +241,39 @@ func strPointer(str string) *string {
|
|||||||
|
|
||||||
func TestSessionManager_WithAuthMiddleware(t *testing.T) {
|
func TestSessionManager_WithAuthMiddleware(t *testing.T) {
|
||||||
handlerFunc := func() func(http.ResponseWriter, *http.Request) {
|
handlerFunc := func() func(http.ResponseWriter, *http.Request) {
|
||||||
return func(w http.ResponseWriter, _ *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Header().Set("Content-Type", "application/text")
|
|
||||||
_, err := w.Write([]byte("Ok"))
|
contextClaims := r.Context().Value("claims")
|
||||||
require.NoError(t, err, "error writing response: %s", err)
|
if contextClaims != nil {
|
||||||
|
var gotClaims jwt.MapClaims
|
||||||
|
var ok bool
|
||||||
|
if gotClaims, ok = contextClaims.(jwt.MapClaims); !ok {
|
||||||
|
if tmpClaims, ok := contextClaims.(*jwt.MapClaims); ok && tmpClaims != nil {
|
||||||
|
gotClaims = *tmpClaims
|
||||||
|
}
|
||||||
|
}
|
||||||
|
jsonClaims, err := json.Marshal(gotClaims)
|
||||||
|
require.NoError(t, err, "erorr marshalling claims set by AuthMiddleware")
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err = w.Write(jsonClaims)
|
||||||
|
require.NoError(t, err, "error writing response: %s", err)
|
||||||
|
} else {
|
||||||
|
w.Header().Set("Content-Type", "application/text")
|
||||||
|
_, err := w.Write([]byte("Ok"))
|
||||||
|
require.NoError(t, err, "error writing response: %s", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
name string
|
name string
|
||||||
authDisabled bool
|
authDisabled bool
|
||||||
|
ssoEnabled bool
|
||||||
cookieHeader bool
|
cookieHeader bool
|
||||||
verifiedClaims *jwt.RegisteredClaims
|
verifiedClaims *jwt.MapClaims
|
||||||
verifyTokenErr error
|
verifyTokenErr error
|
||||||
|
userInfoCacheClaims *jwt.MapClaims
|
||||||
expectedStatusCode int
|
expectedStatusCode int
|
||||||
expectedResponseBody *string
|
expectedResponseBody *string
|
||||||
}
|
}
|
||||||
@@ -258,47 +282,79 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "will authenticate successfully",
|
name: "will authenticate successfully",
|
||||||
authDisabled: false,
|
authDisabled: false,
|
||||||
|
ssoEnabled: false,
|
||||||
cookieHeader: true,
|
cookieHeader: true,
|
||||||
verifiedClaims: &jwt.RegisteredClaims{},
|
verifiedClaims: &jwt.MapClaims{},
|
||||||
verifyTokenErr: nil,
|
verifyTokenErr: nil,
|
||||||
|
userInfoCacheClaims: nil,
|
||||||
expectedStatusCode: http.StatusOK,
|
expectedStatusCode: http.StatusOK,
|
||||||
expectedResponseBody: strPointer("Ok"),
|
expectedResponseBody: strPointer("{}"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "will be noop if auth is disabled",
|
name: "will be noop if auth is disabled",
|
||||||
authDisabled: true,
|
authDisabled: true,
|
||||||
|
ssoEnabled: false,
|
||||||
cookieHeader: false,
|
cookieHeader: false,
|
||||||
verifiedClaims: nil,
|
verifiedClaims: nil,
|
||||||
verifyTokenErr: nil,
|
verifyTokenErr: nil,
|
||||||
|
userInfoCacheClaims: nil,
|
||||||
expectedStatusCode: http.StatusOK,
|
expectedStatusCode: http.StatusOK,
|
||||||
expectedResponseBody: strPointer("Ok"),
|
expectedResponseBody: strPointer("Ok"),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "will return 400 if no cookie header",
|
name: "will return 400 if no cookie header",
|
||||||
authDisabled: false,
|
authDisabled: false,
|
||||||
|
ssoEnabled: false,
|
||||||
cookieHeader: false,
|
cookieHeader: false,
|
||||||
verifiedClaims: &jwt.RegisteredClaims{},
|
verifiedClaims: &jwt.MapClaims{},
|
||||||
verifyTokenErr: nil,
|
verifyTokenErr: nil,
|
||||||
|
userInfoCacheClaims: nil,
|
||||||
expectedStatusCode: http.StatusBadRequest,
|
expectedStatusCode: http.StatusBadRequest,
|
||||||
expectedResponseBody: nil,
|
expectedResponseBody: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "will return 401 verify token fails",
|
name: "will return 401 verify token fails",
|
||||||
authDisabled: false,
|
authDisabled: false,
|
||||||
|
ssoEnabled: false,
|
||||||
cookieHeader: true,
|
cookieHeader: true,
|
||||||
verifiedClaims: &jwt.RegisteredClaims{},
|
verifiedClaims: &jwt.MapClaims{},
|
||||||
verifyTokenErr: stderrors.New("token error"),
|
verifyTokenErr: stderrors.New("token error"),
|
||||||
|
userInfoCacheClaims: nil,
|
||||||
expectedStatusCode: http.StatusUnauthorized,
|
expectedStatusCode: http.StatusUnauthorized,
|
||||||
expectedResponseBody: nil,
|
expectedResponseBody: nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "will return 200 if claims are nil",
|
name: "will return 200 if claims are nil",
|
||||||
authDisabled: false,
|
authDisabled: false,
|
||||||
|
ssoEnabled: false,
|
||||||
cookieHeader: true,
|
cookieHeader: true,
|
||||||
verifiedClaims: nil,
|
verifiedClaims: nil,
|
||||||
verifyTokenErr: nil,
|
verifyTokenErr: nil,
|
||||||
|
userInfoCacheClaims: nil,
|
||||||
expectedStatusCode: http.StatusOK,
|
expectedStatusCode: http.StatusOK,
|
||||||
expectedResponseBody: strPointer("Ok"),
|
expectedResponseBody: strPointer("null"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "will return 401 if sso is enabled but userinfo response not working",
|
||||||
|
authDisabled: false,
|
||||||
|
ssoEnabled: true,
|
||||||
|
cookieHeader: true,
|
||||||
|
verifiedClaims: nil,
|
||||||
|
verifyTokenErr: nil,
|
||||||
|
userInfoCacheClaims: nil, // indicates that the userinfo response will not work since cache is empty and userinfo endpoint not rechable
|
||||||
|
expectedStatusCode: http.StatusUnauthorized,
|
||||||
|
expectedResponseBody: strPointer("Invalid session"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "will return 200 if sso is enabled and userinfo response from cache is valid",
|
||||||
|
authDisabled: false,
|
||||||
|
ssoEnabled: true,
|
||||||
|
cookieHeader: true,
|
||||||
|
verifiedClaims: &jwt.MapClaims{"sub": "randomUser", "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
verifyTokenErr: nil,
|
||||||
|
userInfoCacheClaims: &jwt.MapClaims{"sub": "randomUser", "groups": []string{"superusers"}, "exp": float64(time.Now().Add(5 * time.Minute).Unix())},
|
||||||
|
expectedStatusCode: http.StatusOK,
|
||||||
|
expectedResponseBody: strPointer("\"groups\":[\"superusers\"]"),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
@@ -311,7 +367,47 @@ func TestSessionManager_WithAuthMiddleware(t *testing.T) {
|
|||||||
claims: tc.verifiedClaims,
|
claims: tc.verifiedClaims,
|
||||||
err: tc.verifyTokenErr,
|
err: tc.verifyTokenErr,
|
||||||
}
|
}
|
||||||
ts := httptest.NewServer(WithAuthMiddleware(tc.authDisabled, tm, mux))
|
clientApp := &oidc.ClientApp{} // all testcases need at least the empty struct for the function to work
|
||||||
|
if tc.ssoEnabled {
|
||||||
|
userInfoCache := cache.NewInMemoryCache(24 * time.Hour)
|
||||||
|
signature, err := util.MakeSignature(32)
|
||||||
|
require.NoError(t, err, "failed creating signature for settings object")
|
||||||
|
cdSettings := &settings.ArgoCDSettings{
|
||||||
|
ServerSignature: signature,
|
||||||
|
OIDCConfigRAW: `
|
||||||
|
issuer: http://localhost:63231
|
||||||
|
enableUserInfoGroups: true
|
||||||
|
userInfoPath: /`,
|
||||||
|
}
|
||||||
|
clientApp, err = oidc.NewClientApp(cdSettings, "", nil, "/argo-cd", userInfoCache)
|
||||||
|
require.NoError(t, err, "failed creating clientapp")
|
||||||
|
|
||||||
|
// prepopulate the cache with claims to return for a userinfo call
|
||||||
|
encryptionKey, err := cdSettings.GetServerEncryptionKey()
|
||||||
|
require.NoError(t, err, "failed obtaining encryption key from settings")
|
||||||
|
// set fake accessToken for GetUserInfo to not return early (can be the same for all cases)
|
||||||
|
encAccessToken, err := crypto.Encrypt([]byte("123456"), encryptionKey)
|
||||||
|
require.NoError(t, err, "failed encrypting dummy access token")
|
||||||
|
err = userInfoCache.Set(&cache.Item{
|
||||||
|
Key: oidc.FormatAccessTokenCacheKey("randomUser"),
|
||||||
|
Object: encAccessToken,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "failed setting item to in-memory cache")
|
||||||
|
|
||||||
|
// set cacheClaims to in-memory cache to let GetUserInfo return early with this information
|
||||||
|
if tc.userInfoCacheClaims != nil {
|
||||||
|
cacheClaims, err := json.Marshal(tc.userInfoCacheClaims)
|
||||||
|
require.NoError(t, err)
|
||||||
|
encCacheClaims, err := crypto.Encrypt([]byte(cacheClaims), encryptionKey)
|
||||||
|
require.NoError(t, err, "failed encrypting cache Claims")
|
||||||
|
err = userInfoCache.Set(&cache.Item{
|
||||||
|
Key: oidc.FormatUserInfoResponseCacheKey("randomUser"),
|
||||||
|
Object: encCacheClaims,
|
||||||
|
})
|
||||||
|
require.NoError(t, err, "failed setting item to in-memory cache")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ts := httptest.NewServer(WithAuthMiddleware(tc.authDisabled, tc.ssoEnabled, clientApp, tm, mux))
|
||||||
defer ts.Close()
|
defer ts.Close()
|
||||||
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, ts.URL, http.NoBody)
|
req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, ts.URL, http.NoBody)
|
||||||
require.NoErrorf(t, err, "error creating request: %s", err)
|
require.NoErrorf(t, err, "error creating request: %s", err)
|
||||||
|
|||||||
Reference in New Issue
Block a user