blob: 26f571d08af255e82c5cdd4f79e8975639c39ef2 [file] [log] [blame]
/*
Copyright 2016 The Kubernetes Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package filters
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"sync"
"testing"
"time"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/util/diff"
"k8s.io/apimachinery/pkg/util/runtime"
)
// recorder is a mutex-guarded invocation counter. The test hands its Record
// method to the timeout filter as a callback and later inspects Count to see
// how many times the callback ran.
type recorder struct {
	lock  sync.Mutex // guards count
	count int        // number of Record calls so far
}

// Record bumps the counter by one while holding the lock.
func (rec *recorder) Record() {
	rec.lock.Lock()
	rec.count++
	rec.lock.Unlock()
}

// Count reports how many times Record has been called, reading the counter
// under the lock.
func (rec *recorder) Count() int {
	rec.lock.Lock()
	n := rec.count
	rec.lock.Unlock()
	return n
}
// newHandler returns an HTTP handler whose behavior is driven by the caller
// through channels. Every request blocks until one of the input channels
// delivers:
//
//   - a string received from responseCh is written to the response, and the
//     error returned by that Write (nil on success) is sent on writeErrCh;
//   - a value received from panicCh makes the handler panic.
func newHandler(responseCh <-chan string, panicCh <-chan struct{}, writeErrCh chan<- error) http.HandlerFunc {
	serve := func(rw http.ResponseWriter, _ *http.Request) {
		select {
		case <-panicCh:
			panic("inner handler panics")
		case payload := <-responseCh:
			_, writeErr := rw.Write([]byte(payload))
			// Report the Write outcome so the caller can assert on it.
			writeErrCh <- writeErr
		}
	}
	return http.HandlerFunc(serve)
}
// TestTimeout drives a test server whose handler chain is
// withPanicRecovery(WithTimeout(newHandler(...))) through three scenarios:
//
//  1. No timeout: the inner handler's response is passed through with 200 and
//     the record callback is not invoked.
//  2. Timeout: once a value is available on the timeout channel the client
//     receives 504 with timeoutErr's status as the JSON body, the record
//     callback is invoked once, and a late Write by the inner handler fails
//     with http.ErrHandlerTimeout.
//  3. Panic: a panic in the inner handler reaches the recovery callback and
//     the client receives 500.
func TestTimeout(t *testing.T) {
	// Turn ReallyCrash off for the duration of the test and restore it on
	// return. NOTE(review): presumably this keeps the recovered panic in the
	// last scenario from crashing the test process — confirm against
	// k8s.io/apimachinery/pkg/util/runtime.
	origReallyCrash := runtime.ReallyCrash
	runtime.ReallyCrash = false
	defer func() {
		runtime.ReallyCrash = origReallyCrash
	}()

	// All channels are buffered by one so the test can queue the next action
	// before issuing the request that consumes it.
	sendResponse := make(chan string, 1)
	doPanic := make(chan struct{}, 1)
	writeErrors := make(chan error, 1)
	gotPanic := make(chan interface{}, 1)
	timeout := make(chan time.Time, 1)
	resp := "test response"
	timeoutErr := apierrors.NewServerTimeout(schema.GroupResource{Group: "foo", Resource: "bar"}, "get", 0)
	record := &recorder{}

	handler := newHandler(sendResponse, doPanic, writeErrors)
	ts := httptest.NewServer(withPanicRecovery(
		WithTimeout(handler, func(req *http.Request) (*http.Request, <-chan time.Time, func(), *apierrors.StatusError) {
			return req, timeout, record.Record, timeoutErr
		}), func(w http.ResponseWriter, req *http.Request, err interface{}) {
			gotPanic <- err
			http.Error(w, "This request caused apiserver to panic. Look in the logs for details.", http.StatusInternalServerError)
		}),
	)
	defer ts.Close()

	// No timeouts
	sendResponse <- resp
	res, err := http.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	// The receiver of a deferred method call is evaluated here, so this closes
	// this response's body even though res is reassigned below.
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusOK)
	}
	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("reading body of first response: %v", err)
	}
	if string(body) != resp {
		t.Errorf("got body %q; expected %q", string(body), resp)
	}
	if err := <-writeErrors; err != nil {
		t.Errorf("got unexpected Write error on first request: %v", err)
	}
	if record.Count() != 0 {
		t.Errorf("invoked record method: %#v", record)
	}

	// Times out
	timeout <- time.Time{}
	res, err = http.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusGatewayTimeout {
		t.Errorf("got res.StatusCode %d; expected %d", res.StatusCode, http.StatusGatewayTimeout)
	}
	body, err = ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("reading body of timed-out response: %v", err)
	}
	status := &metav1.Status{}
	if err := json.Unmarshal(body, status); err != nil {
		t.Fatal(err)
	}
	if !reflect.DeepEqual(status, &timeoutErr.ErrStatus) {
		t.Errorf("unexpected object: %s", diff.ObjectReflectDiff(&timeoutErr.ErrStatus, status))
	}
	if record.Count() != 1 {
		t.Errorf("did not invoke record method: %#v", record)
	}

	// Now try to send a response
	// The inner handler of the timed-out request is still waiting on its
	// channels; releasing it now must make its Write fail.
	sendResponse <- resp
	if err := <-writeErrors; err != http.ErrHandlerTimeout {
		t.Errorf("got Write error of %v; expected %v", err, http.ErrHandlerTimeout)
	}

	// Panics
	doPanic <- struct{}{}
	res, err = http.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusInternalServerError {
		t.Errorf("got res.StatusCode %d; expected %d due to panic", res.StatusCode, http.StatusInternalServerError)
	}
	select {
	case err := <-gotPanic:
		// The recovered value is expected to mention newHandler, i.e. to carry
		// the stack of the goroutine that actually panicked.
		msg := fmt.Sprintf("%v", err)
		if !strings.Contains(msg, "newHandler") {
			t.Errorf("expected line with root cause panic in the stack trace, but didn't: %v", err)
		}
	case <-time.After(30 * time.Second):
		t.Fatalf("expected to see a handler panic, but didn't")
	}
}