| /* |
| Copyright 2016 The Kubernetes Authors. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| */ |
| |
| package streaming |
| |
| import ( |
| "crypto/tls" |
| "io" |
| "net/http" |
| "net/http/httptest" |
| "net/url" |
| "strconv" |
| "strings" |
| "sync" |
| "testing" |
| |
| "github.com/stretchr/testify/assert" |
| "github.com/stretchr/testify/require" |
| |
| restclient "k8s.io/client-go/rest" |
| "k8s.io/client-go/tools/remotecommand" |
| "k8s.io/client-go/transport/spdy" |
| api "k8s.io/kubernetes/pkg/apis/core" |
| runtimeapi "k8s.io/kubernetes/pkg/kubelet/apis/cri/runtime/v1alpha2" |
| kubeletportforward "k8s.io/kubernetes/pkg/kubelet/server/portforward" |
| ) |
| |
| const ( |
| testAddr = "localhost:12345" |
| testContainerID = "container789" |
| testPodSandboxID = "pod0987" |
| ) |
| |
| func TestGetExec(t *testing.T) { |
| serv, err := NewServer(Config{ |
| Addr: testAddr, |
| }, nil) |
| assert.NoError(t, err) |
| |
| tlsServer, err := NewServer(Config{ |
| Addr: testAddr, |
| TLSConfig: &tls.Config{}, |
| }, nil) |
| assert.NoError(t, err) |
| |
| const pathPrefix = "cri/shim" |
| prefixServer, err := NewServer(Config{ |
| Addr: testAddr, |
| BaseURL: &url.URL{ |
| Scheme: "http", |
| Host: testAddr, |
| Path: "/" + pathPrefix + "/", |
| }, |
| }, nil) |
| assert.NoError(t, err) |
| |
| assertRequestToken := func(expectedReq *runtimeapi.ExecRequest, cache *requestCache, token string) { |
| req, ok := cache.Consume(token) |
| require.True(t, ok, "token %s not found!", token) |
| assert.Equal(t, expectedReq, req) |
| } |
| request := &runtimeapi.ExecRequest{ |
| ContainerId: testContainerID, |
| Cmd: []string{"echo", "foo"}, |
| Tty: true, |
| Stdin: true, |
| } |
| { // Non-TLS |
| resp, err := serv.GetExec(request) |
| assert.NoError(t, err) |
| expectedURL := "http://" + testAddr + "/exec/" |
| assert.Contains(t, resp.Url, expectedURL) |
| token := strings.TrimPrefix(resp.Url, expectedURL) |
| assertRequestToken(request, serv.(*server).cache, token) |
| } |
| |
| { // TLS |
| resp, err := tlsServer.GetExec(request) |
| assert.NoError(t, err) |
| expectedURL := "https://" + testAddr + "/exec/" |
| assert.Contains(t, resp.Url, expectedURL) |
| token := strings.TrimPrefix(resp.Url, expectedURL) |
| assertRequestToken(request, tlsServer.(*server).cache, token) |
| } |
| |
| { // Path prefix |
| resp, err := prefixServer.GetExec(request) |
| assert.NoError(t, err) |
| expectedURL := "http://" + testAddr + "/" + pathPrefix + "/exec/" |
| assert.Contains(t, resp.Url, expectedURL) |
| token := strings.TrimPrefix(resp.Url, expectedURL) |
| assertRequestToken(request, prefixServer.(*server).cache, token) |
| } |
| } |
| |
| func TestValidateExecAttachRequest(t *testing.T) { |
| type config struct { |
| tty bool |
| stdin bool |
| stdout bool |
| stderr bool |
| } |
| for _, tc := range []struct { |
| desc string |
| configs []config |
| expectErr bool |
| }{ |
| { |
| desc: "at least one stream must be true", |
| expectErr: true, |
| configs: []config{ |
| {false, false, false, false}, |
| {true, false, false, false}}, |
| }, |
| { |
| desc: "tty and stderr cannot both be true", |
| expectErr: true, |
| configs: []config{ |
| {true, false, false, true}, |
| {true, false, true, true}, |
| {true, true, false, true}, |
| {true, true, true, true}, |
| }, |
| }, |
| { |
| desc: "a valid config should pass", |
| expectErr: false, |
| configs: []config{ |
| {false, false, false, true}, |
| {false, false, true, false}, |
| {false, false, true, true}, |
| {false, true, false, false}, |
| {false, true, false, true}, |
| {false, true, true, false}, |
| {false, true, true, true}, |
| {true, false, true, false}, |
| {true, true, false, false}, |
| {true, true, true, false}, |
| }, |
| }, |
| } { |
| t.Run(tc.desc, func(t *testing.T) { |
| for _, c := range tc.configs { |
| // validate the exec request. |
| execReq := &runtimeapi.ExecRequest{ |
| ContainerId: testContainerID, |
| Cmd: []string{"date"}, |
| Tty: c.tty, |
| Stdin: c.stdin, |
| Stdout: c.stdout, |
| Stderr: c.stderr, |
| } |
| err := validateExecRequest(execReq) |
| assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err) |
| |
| // validate the attach request. |
| attachReq := &runtimeapi.AttachRequest{ |
| ContainerId: testContainerID, |
| Tty: c.tty, |
| Stdin: c.stdin, |
| Stdout: c.stdout, |
| Stderr: c.stderr, |
| } |
| err = validateAttachRequest(attachReq) |
| assert.Equal(t, tc.expectErr, err != nil, "config: %v, err: %v", c, err) |
| } |
| }) |
| } |
| } |
| |
| func TestGetAttach(t *testing.T) { |
| serv, err := NewServer(Config{ |
| Addr: testAddr, |
| }, nil) |
| require.NoError(t, err) |
| |
| tlsServer, err := NewServer(Config{ |
| Addr: testAddr, |
| TLSConfig: &tls.Config{}, |
| }, nil) |
| require.NoError(t, err) |
| |
| assertRequestToken := func(expectedReq *runtimeapi.AttachRequest, cache *requestCache, token string) { |
| req, ok := cache.Consume(token) |
| require.True(t, ok, "token %s not found!", token) |
| assert.Equal(t, expectedReq, req) |
| } |
| |
| request := &runtimeapi.AttachRequest{ |
| ContainerId: testContainerID, |
| Stdin: true, |
| Tty: true, |
| } |
| { // Non-TLS |
| resp, err := serv.GetAttach(request) |
| assert.NoError(t, err) |
| expectedURL := "http://" + testAddr + "/attach/" |
| assert.Contains(t, resp.Url, expectedURL) |
| token := strings.TrimPrefix(resp.Url, expectedURL) |
| assertRequestToken(request, serv.(*server).cache, token) |
| } |
| |
| { // TLS |
| resp, err := tlsServer.GetAttach(request) |
| assert.NoError(t, err) |
| expectedURL := "https://" + testAddr + "/attach/" |
| assert.Contains(t, resp.Url, expectedURL) |
| token := strings.TrimPrefix(resp.Url, expectedURL) |
| assertRequestToken(request, tlsServer.(*server).cache, token) |
| } |
| } |
| |
| func TestGetPortForward(t *testing.T) { |
| podSandboxID := testPodSandboxID |
| request := &runtimeapi.PortForwardRequest{ |
| PodSandboxId: podSandboxID, |
| Port: []int32{1, 2, 3, 4}, |
| } |
| |
| { // Non-TLS |
| serv, err := NewServer(Config{ |
| Addr: testAddr, |
| }, nil) |
| assert.NoError(t, err) |
| resp, err := serv.GetPortForward(request) |
| assert.NoError(t, err) |
| expectedURL := "http://" + testAddr + "/portforward/" |
| assert.True(t, strings.HasPrefix(resp.Url, expectedURL)) |
| token := strings.TrimPrefix(resp.Url, expectedURL) |
| req, ok := serv.(*server).cache.Consume(token) |
| require.True(t, ok, "token %s not found!", token) |
| assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId) |
| } |
| |
| { // TLS |
| tlsServer, err := NewServer(Config{ |
| Addr: testAddr, |
| TLSConfig: &tls.Config{}, |
| }, nil) |
| assert.NoError(t, err) |
| resp, err := tlsServer.GetPortForward(request) |
| assert.NoError(t, err) |
| expectedURL := "https://" + testAddr + "/portforward/" |
| assert.True(t, strings.HasPrefix(resp.Url, expectedURL)) |
| token := strings.TrimPrefix(resp.Url, expectedURL) |
| req, ok := tlsServer.(*server).cache.Consume(token) |
| require.True(t, ok, "token %s not found!", token) |
| assert.Equal(t, testPodSandboxID, req.(*runtimeapi.PortForwardRequest).PodSandboxId) |
| } |
| } |
| |
| func TestServeExec(t *testing.T) { |
| runRemoteCommandTest(t, "exec") |
| } |
| |
| func TestServeAttach(t *testing.T) { |
| runRemoteCommandTest(t, "attach") |
| } |
| |
| func TestServePortForward(t *testing.T) { |
| s, testServer := startTestServer(t) |
| defer testServer.Close() |
| |
| resp, err := s.GetPortForward(&runtimeapi.PortForwardRequest{ |
| PodSandboxId: testPodSandboxID, |
| }) |
| require.NoError(t, err) |
| reqURL, err := url.Parse(resp.Url) |
| require.NoError(t, err) |
| |
| transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{}) |
| require.NoError(t, err) |
| dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", reqURL) |
| streamConn, _, err := dialer.Dial(kubeletportforward.ProtocolV1Name) |
| require.NoError(t, err) |
| defer streamConn.Close() |
| |
| // Create the streams. |
| headers := http.Header{} |
| // Error stream is required, but unused in this test. |
| headers.Set(api.StreamType, api.StreamTypeError) |
| headers.Set(api.PortHeader, strconv.Itoa(testPort)) |
| _, err = streamConn.CreateStream(headers) |
| require.NoError(t, err) |
| // Setup the data stream. |
| headers.Set(api.StreamType, api.StreamTypeData) |
| headers.Set(api.PortHeader, strconv.Itoa(testPort)) |
| stream, err := streamConn.CreateStream(headers) |
| require.NoError(t, err) |
| |
| doClientStreams(t, "portforward", stream, stream, nil) |
| } |
| |
| // |
| // Run the remote command test. |
| // commandType is either "exec" or "attach". |
| func runRemoteCommandTest(t *testing.T, commandType string) { |
| s, testServer := startTestServer(t) |
| defer testServer.Close() |
| |
| var reqURL *url.URL |
| stdin, stdout, stderr := true, true, true |
| containerID := testContainerID |
| switch commandType { |
| case "exec": |
| resp, err := s.GetExec(&runtimeapi.ExecRequest{ |
| ContainerId: containerID, |
| Cmd: []string{"echo"}, |
| Stdin: stdin, |
| Stdout: stdout, |
| Stderr: stderr, |
| }) |
| require.NoError(t, err) |
| reqURL, err = url.Parse(resp.Url) |
| require.NoError(t, err) |
| case "attach": |
| resp, err := s.GetAttach(&runtimeapi.AttachRequest{ |
| ContainerId: containerID, |
| Stdin: stdin, |
| Stdout: stdout, |
| Stderr: stderr, |
| }) |
| require.NoError(t, err) |
| reqURL, err = url.Parse(resp.Url) |
| require.NoError(t, err) |
| } |
| |
| wg := sync.WaitGroup{} |
| wg.Add(2) |
| |
| stdinR, stdinW := io.Pipe() |
| stdoutR, stdoutW := io.Pipe() |
| stderrR, stderrW := io.Pipe() |
| |
| go func() { |
| defer wg.Done() |
| exec, err := remotecommand.NewSPDYExecutor(&restclient.Config{}, "POST", reqURL) |
| require.NoError(t, err) |
| |
| opts := remotecommand.StreamOptions{ |
| Stdin: stdinR, |
| Stdout: stdoutW, |
| Stderr: stderrW, |
| Tty: false, |
| } |
| require.NoError(t, exec.Stream(opts)) |
| }() |
| |
| go func() { |
| defer wg.Done() |
| doClientStreams(t, commandType, stdinW, stdoutR, stderrR) |
| }() |
| |
| wg.Wait() |
| |
| // Repeat request with the same URL should be a 404. |
| resp, err := http.Get(reqURL.String()) |
| require.NoError(t, err) |
| assert.Equal(t, http.StatusNotFound, resp.StatusCode) |
| } |
| |
| func startTestServer(t *testing.T) (Server, *httptest.Server) { |
| var s Server |
| testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| s.ServeHTTP(w, r) |
| })) |
| cleanup := true |
| defer func() { |
| if cleanup { |
| testServer.Close() |
| } |
| }() |
| |
| testURL, err := url.Parse(testServer.URL) |
| require.NoError(t, err) |
| |
| rt := newFakeRuntime(t) |
| config := DefaultConfig |
| config.BaseURL = testURL |
| s, err = NewServer(config, rt) |
| require.NoError(t, err) |
| |
| cleanup = false // Caller must close the test server. |
| return s, testServer |
| } |
| |
| const ( |
| testInput = "abcdefg" |
| testOutput = "fooBARbaz" |
| testErr = "ERROR!!!" |
| testPort = 12345 |
| ) |
| |
| func newFakeRuntime(t *testing.T) *fakeRuntime { |
| return &fakeRuntime{ |
| t: t, |
| } |
| } |
| |
| type fakeRuntime struct { |
| t *testing.T |
| } |
| |
| func (f *fakeRuntime) Exec(containerID string, cmd []string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { |
| assert.Equal(f.t, testContainerID, containerID) |
| doServerStreams(f.t, "exec", stdin, stdout, stderr) |
| return nil |
| } |
| |
| func (f *fakeRuntime) Attach(containerID string, stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan remotecommand.TerminalSize) error { |
| assert.Equal(f.t, testContainerID, containerID) |
| doServerStreams(f.t, "attach", stdin, stdout, stderr) |
| return nil |
| } |
| |
| func (f *fakeRuntime) PortForward(podSandboxID string, port int32, stream io.ReadWriteCloser) error { |
| assert.Equal(f.t, testPodSandboxID, podSandboxID) |
| assert.EqualValues(f.t, testPort, port) |
| doServerStreams(f.t, "portforward", stream, stream, nil) |
| return nil |
| } |
| |
| // Send & receive expected input/output. Must be the inverse of doClientStreams. |
| // Function will block until the expected i/o is finished. |
| func doServerStreams(t *testing.T, prefix string, stdin io.Reader, stdout, stderr io.Writer) { |
| if stderr != nil { |
| writeExpected(t, "server stderr", stderr, prefix+testErr) |
| } |
| readExpected(t, "server stdin", stdin, prefix+testInput) |
| writeExpected(t, "server stdout", stdout, prefix+testOutput) |
| } |
| |
| // Send & receive expected input/output. Must be the inverse of doServerStreams. |
| // Function will block until the expected i/o is finished. |
| func doClientStreams(t *testing.T, prefix string, stdin io.Writer, stdout, stderr io.Reader) { |
| if stderr != nil { |
| readExpected(t, "client stderr", stderr, prefix+testErr) |
| } |
| writeExpected(t, "client stdin", stdin, prefix+testInput) |
| readExpected(t, "client stdout", stdout, prefix+testOutput) |
| } |
| |
| // Read and verify the expected string from the stream. |
| func readExpected(t *testing.T, streamName string, r io.Reader, expected string) { |
| result := make([]byte, len(expected)) |
| _, err := io.ReadAtLeast(r, result, len(expected)) |
| assert.NoError(t, err, "stream %s", streamName) |
| assert.Equal(t, expected, string(result), "stream %s", streamName) |
| } |
| |
| // Write and verify success of the data over the stream. |
| func writeExpected(t *testing.T, streamName string, w io.Writer, data string) { |
| n, err := io.WriteString(w, data) |
| assert.NoError(t, err, "stream %s", streamName) |
| assert.Equal(t, len(data), n, "stream %s", streamName) |
| } |