blob: 600f31a8b2972b6a6f25d8a9f8b92ef1892c04f0 [file] [log] [blame]
package middleware
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/apache/trafficcontrol/lib/go-rfc"
"github.com/apache/trafficcontrol/lib/go-tc"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/api"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/auth"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/config"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/tocookie"
"github.com/jmoiron/sqlx"
"github.com/lestrrat-go/jwx/jwa"
"github.com/lestrrat-go/jwx/jwt"
sqlmock "gopkg.in/DATA-DOG/go-sqlmock.v1"
)
// debugLogging is the "-debug" test flag; when it is set, TestWrapAuth prints
// the received and expected bodies of its unauthenticated request.
var debugLogging = flag.Bool("debug", false, "enable debug logging in test")
// TestWrapHeaders checks that appropriate default headers are added to a
// request.
//
// A trivial handler is wrapped with WrapHeaders and invoked once against a
// response recorder. The test then verifies that the recorded body equals what
// the handler wrote and that exactly the expected header names are present.
// Headers mapped to nil are only checked for presence; headers mapped to a
// non-empty slice must match that slice exactly.
func TestWrapHeaders(t *testing.T) {
	body := "We are here!!"
	f := WrapHeaders(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(body))
	})

	w := httptest.NewRecorder()
	r, err := http.NewRequest("", ".", nil)
	if err != nil {
		// Without a request there is nothing to exercise; stop instead of
		// invoking the handler with a nil *http.Request.
		t.Fatalf("Error creating new request: %v", err)
	}

	// Call to add the headers
	f(w, r)
	if w.Body.String() != body {
		t.Error("Expected body", body, "got", w.Body.String())
	}

	expected := map[string][]string{
		"Access-Control-Allow-Credentials": nil,
		"Access-Control-Allow-Headers":     nil,
		"Access-Control-Allow-Methods":     nil,
		"Access-Control-Allow-Origin":      nil,
		rfc.Vary:                           {rfc.AcceptEncoding},
		"Content-Type":                     nil,
		"Whole-Content-Sha512":             nil,
		"X-Server-Name":                    nil,
		rfc.PermissionsPolicy:              {"interest-cohort=()"},
	}

	// Header() is the supported accessor; the HeaderMap field is deprecated.
	m := w.Header()
	if len(expected) != len(m) {
		t.Error("Expected", len(expected), "header, got", len(m))
	}
	for k, want := range expected {
		got, ok := m[k]
		if !ok {
			t.Error("Expected header", k, "not found")
			continue
		}
		if len(want) > 0 && !reflect.DeepEqual(want, got) {
			t.Errorf("expected: %v, actual: %v", want, got)
		}
	}
}
// TestWrapPanicRecover checks that a recovered panic returns a 500.
//
// The inner handler dereferences a nil pointer on purpose; it is wrapped with
// WrapPanicRecover and then WrapHeaders, and the recorded status code must be
// http.StatusInternalServerError.
func TestWrapPanicRecover(t *testing.T) {
	f := WrapPanicRecover(func(w http.ResponseWriter, r *http.Request) {
		var foo *string
		bar := *foo // will throw nil dereference panic
		w.Write([]byte(bar))
	})
	f = WrapHeaders(f)

	w := httptest.NewRecorder()
	r, err := http.NewRequest("", "/", nil)
	if err != nil {
		// Stop here: continuing would call the handler with a nil request.
		t.Fatalf("Error creating new request: %v", err)
	}

	// Call to wrap the panic recovery
	f(w, r)
	if w.Code != http.StatusInternalServerError {
		t.Error("expected panic recovery to return a 500, got", w.Code)
	}
}
// TestGzip checks that if Accept-Encoding contains "gzip" that the body is
// indeed gzip'd.
//
// The same WrapHeaders-wrapped handler is called twice with the same request:
// first without an Accept-Encoding header, where the recorded body must equal
// the plain text, and then with "Accept-Encoding: gzip", where the recorded
// body must equal a reference compression produced locally with
// compress/gzip.
func TestGzip(t *testing.T) {
	body := "am I gzip'd?"

	// Build the reference gzip'd payload. A failure here invalidates every
	// later comparison, so it is fatal.
	gz := bytes.Buffer{}
	zw := gzip.NewWriter(&gz)
	if _, err := zw.Write([]byte(body)); err != nil {
		t.Fatal("Error gzipping", err)
	}
	if err := zw.Close(); err != nil {
		t.Fatal("Error closing gzipper", err)
	}

	f := WrapHeaders(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(body))
	})

	w := httptest.NewRecorder()
	r, err := http.NewRequest("", "/", nil)
	if err != nil {
		// Stop here: r is dereferenced below when the header is added.
		t.Fatalf("Error creating new request: %v", err)
	}

	f(w, r)

	// body should not be gzip'd
	if !bytes.Equal(w.Body.Bytes(), []byte(body)) {
		t.Error("Expected body to be NOT gzip'd!")
	}

	// Call with gzip
	w = httptest.NewRecorder()
	r.Header.Add("Accept-Encoding", "gzip")
	f(w, r)
	if !bytes.Equal(w.Body.Bytes(), gz.Bytes()) {
		t.Error("Expected body to be gzip'd!")
	}
}
// newRWPair returns a fresh response recorder together with a request for
// "/api/4.0/blah". When cookie is non-nil, its Value is attached to the request
// in a Cookie header under the name tocookie.Name. The calling test is failed
// immediately if the request cannot be created.
func newRWPair(t *testing.T, cookie *http.Cookie) (*httptest.ResponseRecorder, *http.Request) {
	// Attribute failures to the calling test's line rather than to this helper.
	t.Helper()

	w := httptest.NewRecorder()
	r, err := http.NewRequest("", "/api/4.0/blah", nil)
	if err != nil {
		t.Fatalf("Failed to create new request: %v", err)
	}
	if cookie != nil {
		r.Header.Add("Cookie", tocookie.Name+"="+cookie.Value)
	}
	return w, r
}
// TestWrapAuth runs a handler behind the wrapper produced by
// AuthBase.GetWrapper(15) in two scenarios:
//
//   - with a cookie from tocookie.GetCookie and a mocked user row, the handler
//     must see the current user in the request context and echo it as JSON;
//   - without any cookie, the response body must be the "unauthorized" alert.
func TestWrapAuth(t *testing.T) {
	const (
		user          = "user1"
		userID        = 1
		signingSecret = "secret"
	)

	sqlDB, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer sqlDB.Close()

	db := sqlx.NewDb(sqlDB, "sqlmock")
	defer db.Close()

	// The single user lookup answered by the mock database.
	userRow := sqlmock.NewRows([]string{"priv_level", "username", "id", "tenant_id"}).
		AddRow(30, user, userID, 1)
	mock.ExpectQuery("SELECT").WithArgs(user).WillReturnRows(userRow)

	// echoUser writes the user found in the request context back as JSON.
	echoUser := func(w http.ResponseWriter, r *http.Request) {
		current, err := auth.GetCurrentUser(r.Context())
		if err != nil {
			t.Fatalf("unable to get privLevel: %v", err)
		}
		encoded, err := json.Marshal(current)
		if err != nil {
			t.Fatalf("unable to marshal: %v", err)
		}
		w.Header().Set(rfc.ContentType, rfc.ApplicationJSON)
		fmt.Fprintf(w, "%s", encoded)
	}
	wrapped := AuthBase{signingSecret, nil}.GetWrapper(15)(echoUser)

	// Scenario 1: authenticated request.
	sessionCookie := tocookie.GetCookie(user, time.Minute, signingSecret)
	rec, req := newRWPair(t, sessionCookie)

	wantBody, err := json.Marshal(auth.CurrentUser{UserName: user, ID: userID, PrivLevel: 30, TenantID: 1})
	if err != nil {
		t.Fatalf("unable to marshal: %v", err)
	}

	reqCtx := context.WithValue(context.Background(), api.DBContextKey, db)
	reqCtx = context.WithValue(reqCtx, api.ConfigContextKey, &config.Config{ConfigTrafficOpsGolang: config.ConfigTrafficOpsGolang{DBQueryTimeoutSeconds: 20}})
	req = req.WithContext(reqCtx)

	wrapped(rec, req)
	if got := rec.Body.Bytes(); !bytes.Equal(got, wantBody) {
		t.Errorf("received: %s\n expected: %s\n", got, wantBody)
	}

	// Scenario 2: no cookie at all.
	rec, req = newRWPair(t, nil)
	wrapped(rec, req)

	wantErr := `{"alerts":[{"text":"unauthorized, please log in.","level":"error"}]}` + "\n"
	gotErr := rec.Body.Bytes()
	if *debugLogging {
		fmt.Printf("received: %s\n expected: %s\n", gotErr, wantErr)
	}
	if !bytes.Equal(gotErr, []byte(wantErr)) {
		t.Errorf("received: %s\n expected: %s\n", gotErr, wantErr)
	}
}
// TestRequiredPermissionsMiddleware checks RequiredPermissionsMiddleware with
// the required Permission "foo" and RoleBasedPermissions enabled:
//
//   - when the mocked user row lists the "foo" capability, the response must
//     be 200 OK;
//   - when the mocked user row has no capabilities, the response must be
//     403 Forbidden and carry an error alert that mentions "foo".
func TestRequiredPermissionsMiddleware(t *testing.T) {
	mockDB, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
	}
	defer mockDB.Close()

	db := sqlx.NewDb(mockDB, "sqlmock")
	defer db.Close()

	userName := "user1"
	secret := "secret"
	columns := []string{"priv_level", "username", "id", "tenant_id", "capabilities"}

	// First lookup: the user holds the required capability.
	rows := sqlmock.NewRows(columns)
	rows.AddRow(30, userName, 1, 1, "{foo}")
	mock.ExpectQuery("SELECT").WithArgs(userName).WillReturnRows(rows)

	authBase := AuthBase{secret, nil}
	cookie := tocookie.GetCookie(userName, time.Minute, secret)

	handler := func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("success\n"))
	}

	authWrapper := authBase.GetWrapper(0)
	f := authWrapper(WrapHeaders(RequiredPermissionsMiddleware([]string{"foo"})(handler)))

	conf := config.Config{
		ConfigTrafficOpsGolang: config.ConfigTrafficOpsGolang{
			DBQueryTimeoutSeconds: 20,
		},
		RoleBasedPermissions: true,
	}
	// Both requests carry the same database handle and configuration.
	dbctx := context.WithValue(context.Background(), api.DBContextKey, db)
	reqCtx := context.WithValue(dbctx, api.ConfigContextKey, &conf)

	w, r := newRWPair(t, cookie)
	r = r.WithContext(reqCtx)

	f(w, r)
	if w.Code != http.StatusOK {
		t.Errorf("Expected a 200 OK response when the user had all the required Permissions, got: %d", w.Code)
	}

	// Second lookup: same user, but without any capabilities.
	w, r = newRWPair(t, cookie)
	r = r.WithContext(reqCtx)
	rows = sqlmock.NewRows(columns)
	rows.AddRow(30, userName, 1, 1, "{}")
	mock.ExpectQuery("SELECT").WithArgs(userName).WillReturnRows(rows)

	f(w, r)
	result := w.Result()
	if result.StatusCode != http.StatusForbidden {
		t.Errorf("Expected a 403 Forbidden response when the user was missing the required Permissions, got: %d", result.StatusCode)
	}

	rawResp, err := ioutil.ReadAll(result.Body)
	if err != nil {
		t.Fatalf("Failed to read response body: %v", err)
	}

	var alerts tc.Alerts
	if err := json.Unmarshal(rawResp, &alerts); err != nil {
		// The body was read successfully above; this is a decoding problem.
		// Checking the alerts of an undecoded response would only add noise.
		t.Fatalf("Failed to decode response body as alerts: %v", err)
	}
	if !strings.Contains(alerts.ErrorString(), "foo") {
		t.Errorf("Expected an error-level alert mentioning the missing Permission, got: %s", alerts.ErrorString())
	}
}
// TestConfigRoleBasedPermissionsHandling sends four requests through the same
// handler chain - WrapHeaders around the wrapper from GetWrapper(5) around
// RequiredPermissionsMiddleware([]string{"foo"}) - and flips
// conf.RoleBasedPermissions between them, checking the response status each
// time.
//
// The statement order matters: every resetRows call registers one more SELECT
// expectation on the mock database.
// NOTE(review): this presumably relies on each authenticated request running
// exactly one user query and on sqlmock handing out expectations in
// registration order - confirm before reordering anything in this test.
func TestConfigRoleBasedPermissionsHandling(t *testing.T) {
mockDB, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("an error '%s' was not expected when opening a stub database connection", err)
}
defer mockDB.Close()
db := sqlx.NewDb(mockDB, "sqlmock")
defer db.Close()
userName := "user1"
secret := "secret"
var rows *sqlmock.Rows
// resetRows registers a SELECT expectation for userName that returns a single
// row with the given priv_level and a capabilities column formatted as
// "{cap1,cap2,...}" ("{}" when no capabilities are passed).
resetRows := func(privLevel int, caps ...string) {
rows = sqlmock.NewRows([]string{"priv_level", "username", "id", "tenant_id", "capabilities"})
rows.AddRow(privLevel, userName, 1, 1, fmt.Sprintf("{%s}", strings.Join(caps, ",")))
mock.ExpectQuery("SELECT").WithArgs(userName).WillReturnRows(rows)
}
// First queued expectation: priv_level 3 with the "foo" capability.
resetRows(3, "foo")
cookie := tocookie.GetCookie(userName, time.Minute, secret)
conf := config.Config{
ConfigTrafficOpsGolang: config.ConfigTrafficOpsGolang{
DBQueryTimeoutSeconds: 20,
},
RoleBasedPermissions: false,
}
// The context stores a pointer to conf, so the assignments to
// conf.RoleBasedPermissions further down are visible through ctx.
ctx := context.WithValue(context.Background(), api.DBContextKey, db)
ctx = context.WithValue(ctx, api.ConfigContextKey, &conf)
handler := func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("successs\n"))
}
f := WrapHeaders(AuthBase{secret, nil}.GetWrapper(5)(RequiredPermissionsMiddleware([]string{"foo"})(handler)))
// Request 1: RoleBasedPermissions is false and priv_level 3 is below the 5
// passed to GetWrapper; a 403 is expected.
w, r := newRWPair(t, cookie)
r = r.WithContext(ctx)
// Second queued expectation with the same data; two requests (this one and
// the next) are issued before resetRows is called again.
resetRows(3, "foo")
f(w, r)
result := w.Result()
if result.StatusCode != http.StatusForbidden {
t.Errorf("Expected a 403 Forbidden response when the user has insufficient PrivLevel and RoleBasedPermissions is configured to false; got: %d", result.StatusCode)
}
// Request 2: RoleBasedPermissions is true and the queued row carries "foo";
// a 200 is expected despite the low priv_level.
conf.RoleBasedPermissions = true
w, r = newRWPair(t, cookie)
r = r.WithContext(ctx)
f(w, r)
result = w.Result()
if result.StatusCode != http.StatusOK {
t.Errorf("Expected a user with the right Permissions for an endpoint to get a 200 OK response regardless of PrivLevel when RoleBasedPermissions is configured to true; got: %d", result.StatusCode)
}
// Request 3: still RoleBasedPermissions true, but the row now has priv_level
// 30 and no capabilities; a 403 whose alert mentions "foo" is expected.
resetRows(30)
w, r = newRWPair(t, cookie)
r = r.WithContext(ctx)
f(w, r)
result = w.Result()
if result.StatusCode != http.StatusForbidden {
t.Errorf("Expected a user with the wrong Permissions for an endpoint to get a 403 Forbidden response regardless of PrivLevel, got: %d", result.StatusCode)
}
var alerts tc.Alerts
if err := json.NewDecoder(result.Body).Decode(&alerts); err != nil {
t.Errorf("Failed to read and decode response body: %v", err)
} else {
errStr := alerts.ErrorString()
if !strings.Contains(errStr, "foo") {
t.Errorf("Expected the reason the user was denied access to be a missing Permission, actual: %s", errStr)
}
}
// Request 4: RoleBasedPermissions is false again; priv_level 30 with no
// capabilities is expected to yield a 200.
conf.RoleBasedPermissions = false
w, r = newRWPair(t, cookie)
r = r.WithContext(ctx)
resetRows(30)
f(w, r)
result = w.Result()
if result.StatusCode != http.StatusOK {
t.Errorf("Expected a user with the right PrivLevel for an endpoint to get a 200 OK response regardless of Permissions when RoleBasedPermissions is configured to false; got: %d", result.StatusCode)
}
}
// TestNoOpWhenNoPermissionsRequired verifies that RequiredPermissionsMiddleware
// passes the request straight to the wrapped handler - 200 status, untouched
// body - both for an empty and for a nil list of required Permissions. The
// requests carry no cookie.
func TestNoOpWhenNoPermissionsRequired(t *testing.T) {
	const want = "success"
	endpoint := func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte(want))
	}

	scenarios := []struct {
		required []string
		// label is the word spliced into the status failure message.
		label string
	}{
		{required: []string{}, label: "no"},
		{required: nil, label: "nil"},
	}

	for _, sc := range scenarios {
		wrapped := RequiredPermissionsMiddleware(sc.required)(endpoint)
		rec, req := newRWPair(t, nil)
		wrapped(rec, req)

		resp := rec.Result()
		if resp.StatusCode != http.StatusOK {
			t.Errorf("Expected checking for required Permissions to be a no-op when %s Permissions are required, but response had status code %d", sc.label, resp.StatusCode)
		}

		got, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			t.Errorf("Failed to read response body: %v", err)
			continue
		}
		if string(got) != want {
			t.Errorf("Expected normal response '%s' from endpoint, but got: %s", want, string(got))
		}
	}
}
func TestGetCookieToken(t *testing.T) {
var cookies []http.Cookie
var jwtToken jwt.Token
var jwtSigned []byte
authUser := "foobar"
httpCookie := tocookie.GetCookie(authUser, 0, "fOObAR.")
jwtToken, _ = jwt.NewBuilder().Claim(api.MojoCookie, httpCookie.Value).Build()
jwtSigned, _ = jwt.Sign(jwtToken, jwa.HS256, []byte("fOObAR."))
mojoCookie := http.Cookie{Name: httpCookie.Name, Value: httpCookie.Value}
accessToken := http.Cookie{Name: "access_token", Value: string(jwtSigned)}
bearerToken := "Bearer " + string(jwtSigned)
cookies = append(cookies, mojoCookie, accessToken, http.Cookie{})
getUserFromCookie := func(cookieToken string) {
secret := "fOObAR."
user := ""
cookie, userErr, sysErr := tocookie.Parse(secret, cookieToken)
if userErr == nil && sysErr == nil {
user = cookie.AuthData
}
if user != "foobar" {
t.Errorf("Error: Unable to parse user from cookie. Expected: %v Got: %v", authUser, user)
}
}
r, err := http.NewRequest("GET", "https://localhost:8888", nil)
if err == nil && r != nil {
for i := range cookies {
if cookies[i].Name != "" {
r.AddCookie(&cookies[i])
cookieToken := getCookieToken(r)
getUserFromCookie(cookieToken)
} else {
r.Header.Add("Authorization", bearerToken)
cookieToken := getCookieToken(r)
getUserFromCookie(cookieToken)
}
}
}
}