Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 46 additions & 74 deletions backend/http/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,60 +20,51 @@ import (
"golang.org/x/oauth2"
)

// userInfo struct to hold user claims from either UserInfo or ID token
// userInfo holds all claims dynamically, plus pre-parsed Groups.
type userInfo struct {
Name string `json:"name"`
PreferredUsername string `json:"preferred_username"`
Username string `json:"username"`
Email string `json:"email"`
Sub string `json:"sub"`
Phone string `json:"phone_number"`
Groups []string `json:"-"` // Handled manually by userInfoUnmarshaller
Claims map[string]interface{}
Groups []string
}

// userInfoUnmarshaller is a custom unmarshaller that handles configurable groups claim
// userInfoUnmarshaller handles unmarshalling all claims dynamically,
// while optionally parsing a configurable groups claim.
type userInfoUnmarshaller struct {
userInfo *userInfo
groupsClaim string
}

// UnmarshalJSON implements the json.Unmarshaler interface
func (u *userInfoUnmarshaller) UnmarshalJSON(data []byte) error {
// First, unmarshal the basic userInfo fields
if err := json.Unmarshal(data, u.userInfo); err != nil {
var raw map[string]interface{}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}

// Parse the JSON to access the groups claim field
var rawData map[string]interface{}
if err := json.Unmarshal(data, &rawData); err != nil {
return err
}

// Look for the groups claim using the configured field name
if groupsValue, exists := rawData[u.groupsClaim]; exists {
switch v := groupsValue.(type) {
case []interface{}:
// It's already an array, convert to []string
groups := make([]string, len(v))
for i, group := range v {
if str, ok := group.(string); ok {
groups[i] = str
// Extract groups if configured
if u.groupsClaim != "" {
if v, ok := raw[u.groupsClaim]; ok {
switch val := v.(type) {
case []interface{}:
groups := make([]string, len(val))
for i, g := range val {
if s, ok := g.(string); ok {
groups[i] = strings.TrimSpace(s)
}
}
}
u.userInfo.Groups = groups
case string:
// It's a string, split by commas
if v != "" {
u.userInfo.Groups = strings.Split(v, ",")
// Trim whitespace from each group
for i, group := range u.userInfo.Groups {
u.userInfo.Groups[i] = strings.TrimSpace(group)
u.userInfo.Groups = groups
case string:
if val != "" {
parts := strings.Split(val, ",")
for i := range parts {
parts[i] = strings.TrimSpace(parts[i])
}
u.userInfo.Groups = parts
}
}
}
}

u.userInfo.Claims = raw
return nil
}

Expand All @@ -100,7 +91,7 @@ func oidcLoginHandler(w http.ResponseWriter, r *http.Request, d *requestContext)
ClientSecret: oidcCfg.ClientSecret,
Endpoint: oidcCfg.Provider.Endpoint(),
RedirectURL: fmt.Sprintf("%s%sapi/auth/oidc/callback", origin, config.Server.BaseURL),
Scopes: strings.Split(oidcCfg.Scopes, " "),
Scopes: strings.Fields(oidcCfg.Scopes),
}

nonce := utils.InsecureRandomIdentifier(16)
Expand Down Expand Up @@ -157,7 +148,7 @@ func oidcCallbackHandler(w http.ResponseWriter, r *http.Request, d *requestConte
ClientSecret: oidcCfg.ClientSecret,
Endpoint: oidcCfg.Provider.Endpoint(), // Use endpoint from discovered provider
RedirectURL: redirectURL, // Use the dynamically determined redirect URL
Scopes: strings.Split(oidcCfg.Scopes, " "),
Scopes: strings.Fields(oidcCfg.Scopes),
}

// Exchange the authorization code for tokens
Expand Down Expand Up @@ -186,10 +177,10 @@ func oidcCallbackHandler(w http.ResponseWriter, r *http.Request, d *requestConte

// Verify the ID token
// This uses the verifier initialized with the provider's JWKS endpoint and client ID
idToken, err := oidcCfg.Verifier.Verify(ctx, rawIDToken)
if err != nil {
idToken, verify_err := oidcCfg.Verifier.Verify(ctx, rawIDToken)
if verify_err != nil {
// this might not be necessary for certain providers like authentik
logger.Debugf("failed to verify ID token: %v. This might be expected, falling back to UserInfo endpoint.", err)
logger.Debugf("failed to verify ID token: %v. This might be expected, falling back to UserInfo endpoint.", verify_err)
// Verification failed, claimsFromIDToken remains false
} else {
// Decode the ID token claims using custom unmarshaller
Expand All @@ -203,30 +194,13 @@ func oidcCallbackHandler(w http.ResponseWriter, r *http.Request, d *requestConte

// Decide if we rely on ID token claims or still need UserInfo
// Even if parsing succeeded, if essential claims are missing, use UserInfo
switch oidcCfg.UserIdentifier {
case "email":
if userdata.Email != "" {
claimsFromIDToken = true
loginUsername = userdata.Email
}
case "username":
if userdata.Username != "" {
claimsFromIDToken = true
loginUsername = userdata.Username
}
case "preferred_username":
if userdata.PreferredUsername != "" {
claimsFromIDToken = true
loginUsername = userdata.PreferredUsername
}
case "phone":
if userdata.Phone != "" {
claimsFromIDToken = true
loginUsername = userdata.Phone
}
if _, ok := userdata.Claims[oidcCfg.UserIdentifier]; ok {
claimsFromIDToken = true
}
}
logger.Debugf("Failed to verify ID token: %v", verify_err)
}

} else {
logger.Debug("No ID token found in token response or it was empty. Falling back to UserInfo endpoint.")
// claimsFromIDToken remains false
Expand All @@ -247,22 +221,20 @@ func oidcCallbackHandler(w http.ResponseWriter, r *http.Request, d *requestConte
logger.Errorf("failed to decode user info from endpoint: %v", err)
return http.StatusInternalServerError, fmt.Errorf("failed to decode user info from endpoint: %v", err)
}
// Decide if we rely on ID token claims or still need UserInfo
// Even if parsing succeeded, if essential claims are missing, use UserInfo
switch oidcCfg.UserIdentifier {
case "email":
loginUsername = userdata.Email
case "username":
loginUsername = userdata.Username
case "preferred_username":
loginUsername = userdata.PreferredUsername
case "phone":
loginUsername = userdata.Phone
}

// --- Determine login username dynamically ---
if val, ok := userdata.Claims[oidcCfg.UserIdentifier]; ok {
switch v := val.(type) {
case string:
loginUsername = v
default:
loginUsername = fmt.Sprintf("%v", v)
}
}
if loginUsername == "" {
logger.Errorf("No valid username found for identifier '%v' in ID token or UserInfo response.", oidcCfg.UserIdentifier)
return http.StatusInternalServerError, fmt.Errorf("no valid username found in ID token or UserInfo response from claims")
logger.Errorf("No valid username found for identifier '%v' in claims.", oidcCfg.UserIdentifier)
return http.StatusInternalServerError, fmt.Errorf("no valid username found for identifier '%v'", oidcCfg.UserIdentifier)
}

// Proceed to log the user in with the OIDC data
Expand Down
34 changes: 24 additions & 10 deletions backend/http/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ func TestUserInfoUnmarshaller(t *testing.T) {
jsonData: `{"name":"John","email":"john@example.com","groups":["admin","users"]}`,
groupsClaim: "groups",
expected: userInfo{
Name: "John",
Email: "john@example.com",
Claims: map[string]interface{}{
"name": "John",
"email": "john@example.com",
"groups": []interface{}{"admin","users"},
},
Groups: []string{"admin", "users"},
},
},
Expand All @@ -28,8 +31,11 @@ func TestUserInfoUnmarshaller(t *testing.T) {
jsonData: `{"name":"Jane","email":"jane@example.com","roles":["admin","users"]}`,
groupsClaim: "roles",
expected: userInfo{
Name: "Jane",
Email: "jane@example.com",
Claims: map[string]interface{}{
"name": "Jane",
"email": "jane@example.com",
"roles": []interface{}{"admin","users"},
},
Groups: []string{"admin", "users"},
},
},
Expand All @@ -38,8 +44,11 @@ func TestUserInfoUnmarshaller(t *testing.T) {
jsonData: `{"name":"Bob","email":"bob@example.com","groups":"admin, users, guests"}`,
groupsClaim: "groups",
expected: userInfo{
Name: "Bob",
Email: "bob@example.com",
Claims: map[string]interface{}{
"name": "Bob",
"email": "bob@example.com",
"groups": "admin, users, guests",
},
Groups: []string{"admin", "users", "guests"},
},
},
Expand All @@ -48,8 +57,10 @@ func TestUserInfoUnmarshaller(t *testing.T) {
jsonData: `{"name":"Alice","email":"alice@example.com"}`,
groupsClaim: "groups",
expected: userInfo{
Name: "Alice",
Email: "alice@example.com",
Claims: map[string]interface{}{
"name": "Alice",
"email": "alice@example.com",
},
Groups: nil,
},
},
Expand All @@ -58,8 +69,11 @@ func TestUserInfoUnmarshaller(t *testing.T) {
jsonData: `{"name":"Charlie","email":"charlie@example.com","groups":[]}`,
groupsClaim: "groups",
expected: userInfo{
Name: "Charlie",
Email: "charlie@example.com",
Claims: map[string]interface{}{
"name": "Charlie",
"email": "charlie@example.com",
"groups": []interface{}{},
},
Groups: []string{},
},
},
Expand Down
Loading