blob: 1cfebef9ab8dd779b8b357720cb743bea438a9d8 [file] [log] [blame]
package auth
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import (
"context"
"database/sql"
"errors"
"sync"
"time"
"github.com/apache/trafficcontrol/lib/go-log"
"github.com/lib/pq"
)
const (
// getUsersQuery selects the id, local_passwd, role, tenant_id, token, ucdn,
// and username columns of every row in tm_user. The column order matches the
// Scan call in getUsers.
getUsersQuery = `
SELECT
u.id,
u.local_passwd,
u.role,
u.tenant_id,
u.token,
u.ucdn,
u.username
FROM
tm_user AS u
`
// getRolesQuery selects, for every row in role, an array of the cap_name
// values linked to it through role_capability, followed by the role's id,
// name, and priv_level. The column order matches the Scan call in getUsers.
getRolesQuery = `
SELECT
ARRAY(SELECT rc.cap_name FROM role_capability AS rc WHERE rc.role_id=r.id) AS capabilities,
r.id as role,
r.name as role_name,
r.priv_level
FROM role r
`
)
// user is one cached tm_user row: the embedded CurrentUser plus the
// local_passwd and token columns that getUsers scans alongside it.
type user struct {
CurrentUser
// LocalPasswd is scanned from tm_user.local_passwd; it is a pointer so a NULL column scans as nil.
LocalPasswd *string
// Token is scanned from tm_user.token; it is a pointer so a NULL column scans as nil.
Token *string
}
// role is one row returned by getRolesQuery; getUsers copies its fields onto
// every user that references the role's ID.
type role struct {
// Capabilities is the array of cap_name values selected for this role.
Capabilities pq.StringArray
// ID is the role's id column; getUsers uses it as the lookup key.
ID int
// Name is the role's name column.
Name string
// PrivLevel is the role's priv_level column.
PrivLevel int
}
// users is the in-memory users cache together with the lock that guards it.
type users struct {
// userMap holds the cached users keyed by username.
userMap map[string]user
// usernamesByToken maps a token to the username of the user holding it.
usernamesByToken map[string]string
// The embedded RWMutex is held whenever userMap, usernamesByToken, or initialized is read or written.
*sync.RWMutex
// initialized becomes true once refreshUsersCache has stored a successful load.
initialized bool
enabled bool // note: enabled is only written to once at startup, before serving requests, so it doesn't need synchronized access
}
// usersCache is the package-wide users cache. Its maps stay nil until the
// first successful refresh; only the lock is allocated up front.
var usersCache = users{RWMutex: new(sync.RWMutex)}
// usersCacheIsEnabled reports whether the users cache has been enabled and
// has been populated by at least one successful refresh.
func usersCacheIsEnabled() bool {
	// enabled is written only once at startup, so it is read without the lock.
	if !usersCache.enabled {
		return false
	}
	usersCache.RLock()
	ready := usersCache.initialized
	usersCache.RUnlock()
	return ready
}
// getUserFromCache looks up username in the cached user map under the read
// lock. The boolean is false (and the user is the zero value) when no such
// user is cached.
func getUserFromCache(username string) (user, bool) {
	usersCache.RLock()
	cached, found := usersCache.userMap[username]
	usersCache.RUnlock()
	return cached, found
}
// getUserNameFromCacheByToken looks up token in the cached token index under
// the read lock. The boolean is false (and the username is empty) when no
// cached entry matches the token.
func getUserNameFromCacheByToken(token string) (string, bool) {
	usersCache.RLock()
	username, found := usersCache.usernamesByToken[token]
	usersCache.RUnlock()
	return username, found
}
// once ensures the setup performed by InitUsersCache runs at most one time.
var once sync.Once
// InitUsersCache attempts to initialize the in-memory users data (if enabled) then
// starts a goroutine to periodically refresh the in-memory data from the database.
func InitUsersCache(interval time.Duration, db *sql.DB, timeout time.Duration) {
once.Do(func() {
if interval <= 0 {
return
}
usersCache.enabled = true
refreshUsersCache(db, timeout)
startUsersCacheRefresher(interval, db, timeout)
})
}
// startUsersCacheRefresher launches a goroutine that loops forever, sleeping
// for interval and then refreshing the users cache from db. The goroutine is
// never stopped, and this function returns immediately.
func startUsersCacheRefresher(interval time.Duration, db *sql.DB, timeout time.Duration) {
	refreshForever := func() {
		for {
			// Sleep first: the caller is expected to have done the initial refresh.
			time.Sleep(interval)
			refreshUsersCache(db, timeout)
		}
	}
	go refreshForever()
}
// refreshUsersCache reloads all users from db (bounded by timeout) and swaps
// the result into usersCache. If the load fails, the error is logged and the
// previously cached data is left in place.
func refreshUsersCache(db *sql.DB, timeout time.Duration) {
	newUsers, err := getUsers(db, timeout)
	if err != nil {
		log.Errorf("refreshing users cache: %s", err.Error())
		return
	}
	// Build the token index before taking the write lock: newUsers is still
	// private to this goroutine, so readers are not blocked while it is built.
	newUsernamesByToken := createTokenToUsernameMap(newUsers)

	// Hold the write lock only for the swap itself.
	usersCache.Lock()
	usersCache.userMap = newUsers
	usersCache.usernamesByToken = newUsernamesByToken
	usersCache.initialized = true
	usersCache.Unlock()

	// newUsers is the map that was just stored, so its length is the cache size.
	log.Infof("refreshed users cache (len = %d)", len(newUsers))
}
// createTokenToUsernameMap builds a token -> username index from users.
// Users with a nil Token, or whose RoleName equals disallowed, are left out.
func createTokenToUsernameMap(users map[string]user) map[string]string {
	index := make(map[string]string)
	for name, entry := range users {
		if entry.Token != nil && entry.RoleName != disallowed {
			index[*entry.Token] = name
		}
	}
	return index
}
// getUsers loads every role and every user from db inside one transaction
// whose context expires after timeout. It returns the users keyed by
// username, each with the name, privilege level, and capabilities of its
// role copied on.
//
// A user whose role ID was not returned by the roles query gets the zero
// values for those role-derived fields. On any error the returned map is nil.
func getUsers(db *sql.DB, timeout time.Duration) (map[string]user, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	rolesByID := make(map[int]role)
	usersByName := make(map[string]user)

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return nil, errors.New("beginning users transaction: " + err.Error())
	}
	// Deferred first so it runs last, after both row sets have been closed.
	// The transaction only reads, so it is committed on every exit path.
	defer func() {
		commitErr := tx.Commit()
		if commitErr == nil || commitErr == sql.ErrTxDone {
			return
		}
		log.Errorln("committing users transaction: " + commitErr.Error())
	}()

	// Roles are loaded first so each user can be joined to its role in memory.
	roleRows, err := tx.QueryContext(ctx, getRolesQuery)
	if err != nil {
		return nil, errors.New("querying roles: " + err.Error())
	}
	defer log.Close(roleRows, "closing role rows")
	for roleRows.Next() {
		var current role
		if scanErr := roleRows.Scan(&current.Capabilities, &current.ID, &current.Name, &current.PrivLevel); scanErr != nil {
			return nil, errors.New("scanning roles: " + scanErr.Error())
		}
		rolesByID[current.ID] = current
	}
	if err = roleRows.Err(); err != nil {
		return nil, errors.New("iterating over role rows: " + err.Error())
	}

	userRows, err := tx.QueryContext(ctx, getUsersQuery)
	if err != nil {
		return nil, errors.New("querying users: " + err.Error())
	}
	defer log.Close(userRows, "closing users rows")
	for userRows.Next() {
		var entry user
		if scanErr := userRows.Scan(&entry.ID, &entry.LocalPasswd, &entry.Role, &entry.TenantID, &entry.Token, &entry.UCDN, &entry.UserName); scanErr != nil {
			return nil, errors.New("scanning users: " + scanErr.Error())
		}

		// A role ID missing from rolesByID yields the zero role here.
		assigned := rolesByID[entry.Role]
		entry.RoleName = assigned.Name
		entry.PrivLevel = assigned.PrivLevel
		entry.Capabilities = assigned.Capabilities

		// Mirror the capability list as a set keyed by capability name.
		entry.perms = make(map[string]struct{}, len(entry.Capabilities))
		for _, capName := range entry.Capabilities {
			entry.perms[capName] = struct{}{}
		}

		usersByName[entry.UserName] = entry
	}
	if err = userRows.Err(); err != nil {
		return nil, errors.New("iterating over user rows: " + err.Error())
	}
	return usersByName, nil
}