blob: abcd8c1f92e90063dabfcabf3452fc867b6c8567 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package server
import (
"context"
"io/fs"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"testing"
"testing/fstest"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/spf13/viper"
"github.com/stretchr/testify/suite"
)
// ServerTestSuite groups the unit tests for this package's log server.
// Embedding suite.Suite supplies the assertion helpers (s.Equal, s.NoError,
// s.Require, ...) that the test methods call on the receiver.
type ServerTestSuite struct {
suite.Suite
}
// TestContainsDotFile runs containsDotFile over a table of slash-separated
// paths: dot-prefixed segments in the first, middle and last position, plus
// the "." and empty-string edge cases.
func (s *ServerTestSuite) TestContainsDotFile() {
	type testCase struct {
		path string
		want bool
	}
	cases := []testCase{
		{path: "foo/bar/baz", want: false},
		{path: ".foo/bar", want: true},
		{path: "foo/.bar/baz", want: true},
		{path: "foo/bar/.baz", want: true},
		{path: ".", want: true},
		{path: "./foo", want: true},
		{path: "", want: false},
	}

	for i := range cases {
		tc := cases[i]
		got := containsDotFile(tc.path)
		s.Equal(tc.want, got, "input: %s", tc.path)
	}
}
// TestDotFileHidingFileReaddir_FiltersDotFiles feeds dotFileHidingFile a
// directory listing with two regular and two dot-prefixed files and expects
// Readdir to return only the regular ones.
func (s *ServerTestSuite) TestDotFileHidingFileReaddir_FiltersDotFiles() {
	// In-memory directory: two visible files and two dot files.
	memFS := fstest.MapFS{
		"foo.txt":     {Data: []byte("foo")},
		".hidden.txt": {Data: []byte("hidden")},
		"bar.log":     {Data: []byte("bar")},
		".dotfile":    {Data: []byte("dot")},
	}
	sub, err := fs.Sub(memFS, ".")
	s.Require().NoError(err)

	// List everything at the root of the in-memory file system.
	entries, err := fs.ReadDir(sub, ".")
	s.Require().NoError(err)

	// Readdir deals in fs.FileInfo, so resolve each DirEntry up front.
	infos := make([]fs.FileInfo, len(entries))
	for i, entry := range entries {
		fi, infoErr := entry.Info()
		s.Require().NoError(infoErr)
		infos[i] = fi
	}

	hiding := dotFileHidingFile{mockHTTPFile{files: infos}}
	visible, err := hiding.Readdir(-1)
	s.NoError(err)

	got := make([]string, 0, len(visible))
	for _, fi := range visible {
		got = append(got, fi.Name())
	}
	s.Equal([]string{"bar.log", "foo.txt"}, got)
}
// mockHTTPFile is a test double for http.File with a canned Readdir result.
// The tests in this file leave the embedded http.File nil, so Readdir is the
// only method that is safe to call on it.
type mockHTTPFile struct {
	http.File
	files []fs.FileInfo
}

// Readdir hands back the preconfigured entries unchanged and never reports an
// error. The count argument n is ignored.
func (f mockHTTPFile) Readdir(n int) ([]fs.FileInfo, error) {
	entries := f.files
	return entries, nil
}
// TestDotFileHidingFileSystem_Open_RejectsDotFiles expects Open to fail with
// fs.ErrPermission when the requested path lives under a dot-prefixed
// directory.
func (s *ServerTestSuite) TestDotFileHidingFileSystem_Open_RejectsDotFiles() {
	const hiddenPath = ".hidden/file.txt"

	backing := fstest.MapFS{
		hiddenPath: {Data: []byte("secret")},
	}
	hidingFS := dotFileHidingFileSystem{http.FS(backing)}

	_, openErr := hidingFS.Open(hiddenPath)
	s.ErrorIs(openErr, fs.ErrPermission)
}
// TestDotFileHidingFileSystem_Open_DelegatesForNormalFile opens a path with no
// dot-prefixed segment and expects the call to succeed and to return the file
// wrapped in dotFileHidingFile.
func (s *ServerTestSuite) TestDotFileHidingFileSystem_Open_DelegatesForNormalFile() {
	fsMap := fstest.MapFS{
		"normal.txt": {Data: []byte("normal")},
	}
	fsys := dotFileHidingFileSystem{http.FS(fsMap)}

	f, err := fsys.Open("normal.txt")
	// Stop here on failure: the checks below are meaningless without a file.
	s.Require().NoError(err)
	// Release the handle once the assertions are done.
	defer func() { s.NoError(f.Close()) }()

	_, ok := f.(dotFileHidingFile)
	s.True(ok)
}
// TestMakeJWTParser_ValidAlgorithm builds a parser from DefaultConfig with a
// secret key set and expects a non-nil parser and no error.
func (s *ServerTestSuite) TestMakeJWTParser_ValidAlgorithm() {
	conf := DefaultConfig
	conf.SecretKey = "testkey"

	p, parserErr := makeJWTParser(conf)

	s.NoError(parserErr)
	s.NotNil(p)
}
// TestMakeJWTParser_InvalidAlgorithm expects makeJWTParser to report an error
// when the configured algorithm name is not a recognised one.
func (s *ServerTestSuite) TestMakeJWTParser_InvalidAlgorithm() {
	conf := DefaultConfig
	conf.Algorithm = "invalid-algo"

	_, parserErr := makeJWTParser(conf)

	s.Error(parserErr)
}
// TestNewLogServer_MissingSecretKey expects NewLogServer to refuse a
// configuration whose secret key is empty.
func (s *ServerTestSuite) TestNewLogServer_MissingSecretKey() {
	conf := DefaultConfig
	conf.SecretKey = ""

	_, buildErr := NewLogServer(nil, conf)

	s.Error(buildErr)
}
// TestNewFromConfig_MissingSecretKey expects NewFromConfig to fail when the
// viper configuration carries an empty logging.secret_key.
func (s *ServerTestSuite) TestNewFromConfig_MissingSecretKey() {
	conf := viper.New()
	conf.Set("logging.base_log_folder", ".")
	conf.Set("logging.secret_key", "")

	_, buildErr := NewFromConfig(conf)

	s.Error(buildErr)
}
// TestNewFromConfig_AllFields populates every logging.* key this test knows
// about and expects NewFromConfig to return a server without error.
func (s *ServerTestSuite) TestNewFromConfig_AllFields() {
	settings := []struct {
		key   string
		value any
	}{
		{"logging.base_log_folder", "."},
		{"logging.secret_key", "mysecret"},
		{"logging.worker_log_server_port", 12345},
		{"logging.clock_leeway", "10s"},
		{"logging.audiences", []string{"aud1", "aud2"}},
		{"logging.algorithm", "HS256"},
	}

	conf := viper.New()
	for _, kv := range settings {
		conf.Set(kv.key, kv.value)
	}

	logServer, buildErr := NewFromConfig(conf)
	s.NoError(buildErr)
	s.NotNil(logServer)
}
// TestLogServer_ServeLog_ForwardsToFileServer swaps the server's file system
// and file server for in-memory ones and checks that ServeLog answers a GET
// for a log file with either 200 or 404.
func (s *ServerTestSuite) TestLogServer_ServeLog_ForwardsToFileServer() {
	fsMap := fstest.MapFS{
		"foo.log": {Data: []byte("logdata")},
	}
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"
	cfg.BaseLogFolder = "."

	srv, err := NewLogServer(slog.Default(), cfg)
	// Must be fatal: srv is dereferenced immediately below.
	s.Require().NoError(err)

	// Override fileServer and fs to use MapFS.
	srv.fs = dotFileHidingFileSystem{http.FS(fsMap)}
	srv.fileServer = http.FileServer(srv.fs)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/foo.log", nil)
	srv.ServeLog(rec, req)

	s.Contains([]int{http.StatusOK, http.StatusNotFound}, rec.Code)
}
// TestValidateToken_MissingAuthHeader sends a request without an
// Authorization header through the validateToken middleware and expects a 403
// with an explanatory body, without the wrapped handler being invoked.
func (s *ServerTestSuite) TestValidateToken_MissingAuthHeader() {
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"

	srv, err := NewLogServer(slog.Default(), cfg)
	// Must be fatal: srv is used immediately below.
	s.Require().NoError(err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/foo.log", nil)

	handler := srv.validateToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		s.Fail("should not call next")
	}))
	handler.ServeHTTP(rec, req)

	s.Equal(http.StatusForbidden, rec.Code)
	s.Contains(rec.Body.String(), "Authorization header missing")
}
// TestValidateToken_InvalidToken sends a bearer token that is not a valid JWT
// and expects the middleware to answer 403 without invoking the wrapped
// handler.
func (s *ServerTestSuite) TestValidateToken_InvalidToken() {
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"

	srv, err := NewLogServer(slog.Default(), cfg)
	// Must be fatal: srv is used immediately below.
	s.Require().NoError(err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/foo.log", nil)
	req.Header.Set("Authorization", "Bearer invalidtoken")

	handler := srv.validateToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		s.Fail("should not call next")
	}))
	handler.ServeHTTP(rec, req)

	s.Equal(http.StatusForbidden, rec.Code)
	s.Contains(rec.Body.String(), "Invalid Authorization header")
}
// TestValidateToken_ValidTokenWrongFilename signs a well-formed token whose
// "filename" claim names a different path than the one requested and expects
// the middleware to answer 403 without invoking the wrapped handler.
func (s *ServerTestSuite) TestValidateToken_ValidTokenWrongFilename() {
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"

	srv, err := NewLogServer(slog.Default(), cfg)
	// Must be fatal: srv is used immediately below.
	s.Require().NoError(err)

	claims := jwt.MapClaims{
		"filename": "/other.log",
		"aud":      DefaultAudience,
		"exp":      time.Now().Add(1 * time.Minute).Unix(),
	}
	token := jwt.NewWithClaims(jwt.GetSigningMethod(cfg.Algorithm), claims)
	signed, err := token.SignedString([]byte(cfg.SecretKey))
	// A signing failure would otherwise send an empty token and let the test
	// pass for the wrong reason.
	s.Require().NoError(err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/foo.log", nil)
	req.Header.Set("Authorization", "Bearer "+signed)

	handler := srv.validateToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		s.Fail("should not call next")
	}))
	handler.ServeHTTP(rec, req)

	s.Equal(http.StatusForbidden, rec.Code)
	s.Contains(rec.Body.String(), "Invalid Authorization header")
}
// TestValidateToken_ValidToken_CallsNext signs a token whose "filename" claim
// matches the requested path and expects the middleware to pass the request
// on to the wrapped handler, which answers 200.
func (s *ServerTestSuite) TestValidateToken_ValidToken_CallsNext() {
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"

	srv, err := NewLogServer(slog.Default(), cfg)
	// Must be fatal: srv is used immediately below.
	s.Require().NoError(err)

	claims := jwt.MapClaims{
		"filename": "/foo.log",
		"aud":      DefaultAudience,
		"exp":      time.Now().Add(1 * time.Minute).Unix(),
	}
	token := jwt.NewWithClaims(jwt.GetSigningMethod(cfg.Algorithm), claims)
	signed, err := token.SignedString([]byte(cfg.SecretKey))
	// Without a signed token the rest of the test cannot be meaningful.
	s.Require().NoError(err)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/foo.log", nil)
	req.Header.Set("Authorization", "Bearer "+signed)

	called := false
	handler := srv.validateToken(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		w.WriteHeader(http.StatusOK)
	}))
	handler.ServeHTTP(rec, req)

	s.True(called)
	s.Equal(http.StatusOK, rec.Code)
}
// TestLogServer_Run_StartsAndShutsDownCleanly serves on a loopback listener
// bound to an OS-chosen port, confirms the server answers a request, then
// cancels the context and expects Serve to return without error in time.
func (s *ServerTestSuite) TestLogServer_Run_StartsAndShutsDownCleanly() {
	cfg := DefaultConfig
	cfg.SecretKey = "testkey"
	cfg.Port = 0 // 0 means choose any available port

	srv, err := NewLogServer(slog.Default(), cfg)
	s.Require().NoError(err)

	// Override handler to prevent actual file serving and simplify the test.
	srv.server.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	// Listen on a random port.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	s.Require().NoError(err)
	defer ln.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered so the goroutine can always deliver its result and exit, even
	// if the test has already stopped waiting (timeout path below).
	runErrCh := make(chan error, 1)
	go func() {
		runErrCh <- srv.Serve(ctx, 2*time.Second, ln)
	}()

	// Connect to the server to ensure it is up, retrying briefly while the
	// goroutine above gets going.
	client := &http.Client{Timeout: 1 * time.Second}
	url := "http://" + ln.Addr().String()
	var resp *http.Response
	for i := 0; i < 10; i++ {
		resp, err = client.Get(url)
		if err == nil {
			break
		}
		time.Sleep(100 * time.Millisecond)
	}
	s.Require().NoError(err)
	s.Equal(http.StatusOK, resp.StatusCode)
	// Always release the response body so the connection is not leaked.
	s.NoError(resp.Body.Close())

	// Now shut down.
	cancel()

	// Wait for shutdown.
	select {
	case err := <-runErrCh:
		s.NoError(err)
	case <-time.After(5 * time.Second):
		s.Fail("server did not shut down in time")
	}
}
func TestServerTestSuite(t *testing.T) {
suite.Run(t, new(ServerTestSuite))
}