blob: 04367be2e5d2c1fce1bbeaebd5d6a0d1ae42f294 [file] [log] [blame]
package main
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import (
	"crypto/rand"
	"crypto/sha512"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/lestrrat-go/jwx/jwa"
	"github.com/lestrrat-go/jwx/jwt"
	_ "github.com/lib/pq"
)
const (
	// TrafficOpsDomain is the Domain attribute given to the access-token cookie.
	TrafficOpsDomain = "localhost"
	// CookieName is the name of the cookie that carries the signed access token.
	CookieName = "access_token"
	// DBTable is the PostgreSQL table holding each user's username, hash, salt and role.
	DBTable = "traffic_ops_auth_users"
)
// AllowedCreateUserRoles returns the set of database roles permitted to create users.
// Go has no constant map literals, so this function stands in for one: it builds a
// fresh map on every call, and callers may modify the result without affecting others.
func AllowedCreateUserRoles() map[string]struct{} {
	allowed := make(map[string]struct{}, 1)
	allowed["admin"] = struct{}{}
	return allowed
}
// Config is the server configuration, decoded from a JSON file whose keys are
// given by the struct tags below.
type Config struct {
	// DBServer is the database host; it is left out of the connection string when empty.
	DBServer string `json:"database_server"`
	// DBPort is the database port; it is only used when it lies in 1..65535.
	DBPort uint `json:"database_port"`
	// DBName is the name of the database to connect to.
	DBName string `json:"database_name"`
	// DBUser is the database login name.
	DBUser string `json:"database_user"`
	// DBPass is the database login password.
	DBPass string `json:"database_password"`
	// AdminUser is the name of the initial admin, created only when the database has none.
	AdminUser string `json:"admin_user"`
	// AdminPass is the password of the initial admin.
	AdminPass string `json:"admin_pass"`
	// Port is the TCP port the HTTP server listens on.
	Port uint `json:"port"`
	// TokenKey is the HS256 key used to sign and verify access tokens (at least 22 characters).
	TokenKey string `json:"token_key"`
}
// createConnectionStringPostgres builds a libpq-style keyword/value connection string
// for the lib/pq driver, with SSL disabled. The host is included only when server is
// non-empty, and the port only when it lies in 1..65535.
//
// Every caller-supplied value is single-quoted, with backslashes and single quotes
// backslash-escaped. Without quoting, an empty password ("password= sslmode=disable")
// is parsed as password "sslmode=disable", and values containing spaces or quotes
// corrupt the neighbouring keywords.
//
// The error result is currently always nil.
func createConnectionStringPostgres(server, database, user, pass string, port uint) (string, error) {
	var b strings.Builder
	fmt.Fprintf(&b, "dbname=%s user=%s password=%s sslmode=disable",
		quoteConnValue(database), quoteConnValue(user), quoteConnValue(pass))
	if server != "" {
		b.WriteString(" host=" + quoteConnValue(server))
	}
	if port > 0 && port < 65536 {
		fmt.Fprintf(&b, " port=%d", port)
	}
	return b.String(), nil
}

// quoteConnValue single-quotes v for a keyword/value connection string, escaping
// backslashes and single quotes with a backslash.
func quoteConnValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}
// authRolesStr returns the roles from AllowedCreateUserRoles as a parenthesized,
// comma-separated list of SQL string literals for use in an `IN` clause, e.g. ('admin').
//
// Any single quote inside a role is doubled, per SQL string-literal rules. If no roles
// are allowed it returns (NULL): `x IN (NULL)` is valid SQL that matches no rows,
// whereas an empty "()" is a syntax error.
func authRolesStr() string {
	roleSet := AllowedCreateUserRoles()
	if len(roleSet) == 0 {
		return "(NULL)"
	}
	literals := make([]string, 0, len(roleSet))
	for role := range roleSet {
		literals = append(literals, "'"+strings.ReplaceAll(role, "'", "''")+"'")
	}
	return "(" + strings.Join(literals, ",") + ")"
}
// CreateUserPath is the URL prefix of the create_user endpoint, whose requests take the
// form /create_user/<name>/<pass>/<role>.
const CreateUserPath = "/create_user/"

// HandleCreateUser is the HTTP handler for the create_user endpoint.
// It verifies the caller's access-token cookie, checks that the caller's role is one of
// AllowedCreateUserRoles, and then creates the user named in the URL path.
//
// Responses: 401 if the token is missing or invalid, or the caller's role may not create
// users; 400 if the path does not hold a non-empty name, password and role; 500 if the
// insert fails; 200 on success.
//
// TODO change to POST: credentials in a URL end up in logs and browser history, and a
// state-changing GET is easy to trigger from another site.
func HandleCreateUser(db *sql.DB, jwtSigningKey string, w http.ResponseWriter, r *http.Request) {
	token, err := getTokenData(jwtSigningKey, r)
	if err != nil {
		log.Printf("%s ERROR unauthorized: %s\n", r.RemoteAddr, err)
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	if _, ok := AllowedCreateUserRoles()[token.Role]; !ok {
		log.Printf("%s ERROR unauthorized: %s role %s is not allowed to create users\n", r.RemoteAddr, token.User, token.Role)
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	// TrimPrefix instead of slicing, so an unexpected URL form can never index out of range.
	path := strings.TrimPrefix(r.URL.String(), CreateUserPath)
	parts := strings.Split(path, "/")
	if len(parts) < 3 {
		// Log only the part count: the raw path may contain the new user's password.
		log.Printf("%s ERROR invalid request: not enough parts: got %d. Syntax is '/create_user/name/pass/role'\n", r.RemoteAddr, len(parts))
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	user, pass, role := parts[0], parts[1], parts[2]
	if user == "" || pass == "" || role == "" {
		log.Printf("%s ERROR invalid request: name, pass and role must be non-empty. Syntax is '/create_user/name/pass/role'\n", r.RemoteAddr)
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	if err := createUser(db, user, pass, role); err != nil {
		// TODO return 400 if the error is that the user exists
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("%s ERROR failed to insert into database: '%v'\n", r.RemoteAddr, err)
		return
	}
	w.WriteHeader(http.StatusOK)
}
// TokenData holds the identity claims read from a verified access token.
type TokenData struct {
	User string
	Role string
}

// getTokenData reads the CookieName cookie from r, parses it as a JWT whose HS256
// signature must verify against jwtSigningKey, and returns its "user" and "role"
// claims. It returns an error if the cookie is absent, parsing or verification fails,
// or either claim is missing or not a string.
func getTokenData(jwtSigningKey string, r *http.Request) (*TokenData, error) {
	cookie, err := r.Cookie(CookieName)
	if err != nil {
		return nil, err
	}
	parsed, err := jwt.Parse([]byte(cookie.Value), jwt.WithVerify(jwa.HS256, []byte(jwtSigningKey)))
	if err != nil {
		return nil, err
	}

	// stringClaim fetches a claim and reports whether it was present and a string.
	stringClaim := func(name string) (string, bool) {
		raw, present := parsed.Get(name)
		if !present {
			return "", false
		}
		str, isStr := raw.(string)
		return str, isStr
	}
	user, userOK := stringClaim("user")
	role, roleOK := stringClaim("role")
	if !userOK || !roleOK {
		// Login always sets both claims, so a token that verified should never get here.
		return nil, fmt.Errorf("token missing claims")
	}
	return &TokenData{User: user, Role: role}, nil
}
// LoginPath is the URL prefix of the login endpoint, whose requests take the form
// /login/<name>/<pass>.
const LoginPath = "/login/"

// HandleLogin is the HTTP handler for the login endpoint.
// It checks the name and password from the URL path against the database. On success it
// sets an HttpOnly cookie named CookieName containing an HS256-signed JWT whose "user"
// and "role" claims identify the caller.
//
// Responses: 400 if the path lacks a name and password; 401 if authentication fails;
// 500 if the token cannot be built or signed; 200 on success.
//
// NOTE(review): no "exp" claim is set, so nothing in this file limits how long an issued
// token stays usable; consider adding an expiry.
func HandleLogin(db *sql.DB, jwtSigningKey string, w http.ResponseWriter, r *http.Request) {
	// w.Header().Set("Content-Type", "text/plain")
	path := strings.TrimPrefix(r.URL.String(), LoginPath)
	parts := strings.Split(path, "/")
	if len(parts) < 2 {
		w.WriteHeader(http.StatusBadRequest)
		log.Printf("%s ERROR invalid request: not enough parts: '%s'. Syntax is '/login/name/pass'\n", r.RemoteAddr, path)
		return
	}
	user := parts[0]
	pass := parts[1]
	role, err := getUser(db, user, pass)
	if err != nil {
		w.WriteHeader(http.StatusUnauthorized)
		// Log only the user name: the path also contains the submitted password.
		log.Printf("%s ERROR failed to get user '%s' with %v\n", r.RemoteAddr, user, err)
		return
	}
	token, err := jwt.NewBuilder().
		Claim(`user`, user).
		Claim(`role`, role).
		Build()
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("%s ERROR creating token for '%s': %v\n", r.RemoteAddr, user, err)
		return
	}
	signed, err := jwt.Sign(token, jwa.HS256, []byte(jwtSigningKey))
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		log.Printf("%s ERROR creating token for '%s': %v\n", r.RemoteAddr, user, err)
		return
	}
	http.SetCookie(w, &http.Cookie{
		Name:   CookieName,
		Value:  string(signed),
		Domain: TrafficOpsDomain,
		Path:   "/",
		//Secure: true, // TODO uncomment when https is implemented
		HttpOnly: true, // prevents the cookie being accessed by Javascript. DO NOT remove, security vulnerability
		// Strict stops browsers attaching the cookie to cross-site requests. This matters
		// because create_user is a cookie-authenticated GET, which is otherwise open to CSRF.
		SameSite: http.SameSiteStrictMode,
	})
	w.WriteHeader(http.StatusOK)
	log.Printf("%s got user %s role %s\n", r.RemoteAddr, user, role)
}
// UserInfo is the JSON body returned by the user_info endpoint.
type UserInfo struct {
	User string `json:"user"`
	Role string `json:"role"`
}

// UserInfoPath is the URL path of the user_info endpoint.
const UserInfoPath = "/user_info"

// HandleUserInfo is the HTTP handler for the user_info endpoint.
// It verifies the caller's access-token cookie and responds with the token's user and
// role as a JSON UserInfo object. It responds 401 if the token is missing or invalid,
// and 500 if the response cannot be encoded.
// db is unused; it is accepted so the handler matches the signature used by GetRoutes.
func HandleUserInfo(db *sql.DB, jwtSigningKey string, w http.ResponseWriter, r *http.Request) {
	token, err := getTokenData(jwtSigningKey, r)
	if err != nil {
		log.Printf("%s ERROR unauthorized: %s\n", r.RemoteAddr, err)
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	userInfo := UserInfo{User: token.User, Role: token.Role}
	jsonUserInfo, err := json.Marshal(userInfo)
	if err != nil {
		log.Printf("%s ERROR marshalling userinfo %v: %s\n", r.RemoteAddr, userInfo, err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	// Declare the body as JSON; otherwise net/http sniffs it and reports text/plain.
	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(jsonUserInfo); err != nil {
		log.Printf("%s ERROR writing userinfo response: %v\n", r.RemoteAddr, err)
	}
}
// getConfig reads the JSON configuration file at path file and decodes it into a Config.
// On any failure it returns the zero Config with an error that names the file, so
// callers never receive a partially decoded configuration.
func getConfig(file string) (Config, error) {
	configJSON, err := ioutil.ReadFile(file)
	if err != nil {
		return Config{}, fmt.Errorf("reading config file %q: %w", file, err)
	}
	var config Config
	if err := json.Unmarshal(configJSON, &config); err != nil {
		return Config{}, fmt.Errorf("parsing config file %q: %w", file, err)
	}
	return config, nil
}
// getDB builds a PostgreSQL connection string from the given settings and returns a
// *sql.DB for the "postgres" driver. Any error from building the string or from
// sql.Open is returned unchanged.
func getDB(server string, port uint, name string, user string, pass string) (*sql.DB, error) {
	var (
		dsn string
		err error
	)
	if dsn, err = createConnectionStringPostgres(server, name, user, pass, port); err != nil {
		return nil, err
	}
	return sql.Open("postgres", dsn)
}
// randGUIDStr returns 16 cryptographically random bytes (128 bits) encoded as unpadded
// URL-safe base64, which is always 22 characters long.
//
// It panics if the system random source fails. The result is used for password salts
// and as a suggested token key, and silently returning an all-zero value would make
// both predictable.
func randGUIDStr() string {
	guid := make([]byte, 16)
	if _, err := rand.Read(guid); err != nil {
		panic(fmt.Sprintf("crypto/rand failed: %v", err))
	}
	return base64.RawURLEncoding.EncodeToString(guid)
}
// adminExists reports whether DBTable contains at least one user whose role is in
// AllowedCreateUserRoles, i.e. someone who can create further users.
func adminExists(db *sql.DB) (bool, error) {
	query := `SELECT COUNT(1) FROM "` + DBTable + `" WHERE role IN ` + authRolesStr() + `;`
	var count int
	if err := db.QueryRow(query).Scan(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}
// hashPassword returns the stored form of a password: unpadded URL-safe base64 of
// SHA-512(pass + salt).
//
// NOTE(review): a single fast hash is cheap to brute-force. A deliberately slow KDF
// (bcrypt, scrypt, argon2, PBKDF2) would be stronger but needs a dependency this file
// does not have.
func hashPassword(pass, salt string) string {
	sum := sha512.Sum512([]byte(pass + salt))
	return base64.RawURLEncoding.EncodeToString(sum[:])
}

// legacyHashPassword reproduces the value earlier versions of createUser stored, computed
// as sha512.New().Sum([]byte(pass+salt)). hash.Hash.Sum appends the digest of the
// (empty) written input to its argument, so that value is base64(pass + salt +
// SHA-512("")): the password itself, not a hash of it. It is kept only so accounts
// created before the fix can still log in.
func legacyHashPassword(pass, salt string) string {
	return base64.RawURLEncoding.EncodeToString(sha512.New().Sum([]byte(pass + salt)))
}

// hashMatches reports whether two encoded hashes are equal. For equal-length inputs the
// comparison takes constant time, so it does not leak how many leading bytes match.
// This matters for legacy values, which begin with the password itself.
func hashMatches(stored, computed string) bool {
	return subtle.ConstantTimeCompare([]byte(stored), []byte(computed)) == 1
}

// createUser inserts a user with the given name, password and role into DBTable.
// A fresh random salt is generated per user, and only the salted hash of the password
// is stored.
// TODO prepare query
func createUser(db *sql.DB, user, pass, role string) error {
	salt := randGUIDStr()
	_, err := db.Exec(`insert into "`+DBTable+`" (username, hash, salt, role) VALUES ($1, $2, $3, $4);`, user, hashPassword(pass, salt), salt, role)
	return err
}

// getUser returns the user's role, if authentication was successful. If authentication fails, err != nil.
// A database error (including no row for user) is returned as-is. A password
// mismatch yields "authentication failed". Hashes in the current format and in the
// legacy format written by earlier versions are both accepted.
//
// NOTE(review): legacy rows still hold the password in recoverable form. Consider
// re-hashing them on successful login, or forcing a password reset.
func getUser(db *sql.DB, user, pass string) (string, error) {
	var salt, storedHash, role string
	err := db.QueryRow(`select salt, hash, role from "`+DBTable+`" where username = $1;`, user).Scan(&salt, &storedHash, &role)
	if err != nil {
		return "", err
	}
	if !hashMatches(storedHash, hashPassword(pass, salt)) && !hashMatches(storedHash, legacyHashPassword(pass, salt)) {
		return "", fmt.Errorf("authentication failed")
	}
	return role, nil
}
// GetRoutes maps each URL path served by this program to its handler. Every handler
// receives the shared database handle and JWT signing key in addition to the usual
// HTTP response writer and request.
func GetRoutes() map[string]func(db *sql.DB, jwtSigningKey string, w http.ResponseWriter, r *http.Request) {
	routes := make(map[string]func(db *sql.DB, jwtSigningKey string, w http.ResponseWriter, r *http.Request), 3)
	routes[LoginPath] = HandleLogin
	routes[CreateUserPath] = HandleCreateUser
	routes[UserInfoPath] = HandleUserInfo
	return routes
}
// main starts the auth server. It:
//   - reads config.json from the working directory and checks the token key length;
//   - opens the database;
//   - creates the initial admin user from the config if no user in the database has a
//     role from AllowedCreateUserRoles;
//   - registers the handlers from GetRoutes and serves HTTP on config.Port.
//
// Startup failures are fatal: log.Fatal logs them and exits with a non-zero status, so
// supervisors and scripts can detect the failure.
func main() {
	config, err := getConfig("config.json")
	if err != nil {
		log.Fatalln(err)
	}
	// randGUIDStr yields 22 characters for 128 random bits, hence the minimum length.
	if len(config.TokenKey) < 22 {
		log.Fatalf("Token key is too short. Must be at least 128 bits. Suggested alternative: '%s'\n", randGUIDStr())
	}
	db, err := getDB(config.DBServer, config.DBPort, config.DBName, config.DBUser, config.DBPass)
	if err != nil {
		log.Fatalln(err)
	}
	dbHasAdmin, err := adminExists(db)
	if err != nil {
		log.Fatalf("Error checking for admin: %v\n", err)
	}
	if !dbHasAdmin && (config.AdminUser == "" || config.AdminPass == "") {
		log.Fatalln("Error: no admin in config or database")
	}
	if !dbHasAdmin {
		firstAllowedCreateRole := ""
		for role := range AllowedCreateUserRoles() {
			firstAllowedCreateRole = role
			break
		}
		if firstAllowedCreateRole == "" {
			log.Fatalf("ERROR creating initial admin user: no roles are allowed to create users\n")
		}
		if err := createUser(db, config.AdminUser, config.AdminPass, firstAllowedCreateRole); err != nil {
			log.Fatalf("ERROR creating initial admin user: %v\n", err)
		}
		log.Printf("INFORMATION created initial admin user: %s\n", config.AdminUser)
	}
	if dbHasAdmin && (config.AdminUser != "" || config.AdminPass != "") {
		log.Printf("WARNING admin exists in both database and config: ignoring config admin\n")
	}

	// wrapHandleFunc adapts a route handler to net/http by closing over db and the token key.
	wrapHandleFunc := func(f func(db *sql.DB, jwtSigningKey string, w http.ResponseWriter, r *http.Request)) func(w http.ResponseWriter, r *http.Request) {
		return func(w http.ResponseWriter, r *http.Request) {
			f(db, config.TokenKey, w, r)
		}
	}
	for path, handleFunc := range GetRoutes() {
		http.HandleFunc(path, wrapHandleFunc(handleFunc))
	}

	// http.ListenAndServe sets no timeouts, so a slow or idle client could hold a
	// connection open indefinitely. A nil Handler serves the routes registered above.
	server := &http.Server{
		Addr:              ":" + strconv.Itoa(int(config.Port)),
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	// ListenAndServe only returns on failure, always with a non-nil error.
	log.Fatalln(server.ListenAndServe())
}