| // Licensed to the Apache Software Foundation (ASF) under one or more |
| // contributor license agreements. See the NOTICE file distributed with |
| // this work for additional information regarding copyright ownership. |
| // The ASF licenses this file to You under the Apache License, Version 2.0 |
| // (the "License"); you may not use this file except in compliance with |
| // the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package docker_test |
| |
| import ( |
| "bytes" |
| "context" |
| "crypto/ecdsa" |
| "crypto/elliptic" |
| "crypto/rand" |
| "encoding/binary" |
| "errors" |
| "fmt" |
| "io" |
| "net" |
| "net/http" |
| "os" |
| "path/filepath" |
| "runtime" |
| "strings" |
| "testing" |
| "time" |
| |
| "github.com/apache/dubbo-kubernetes/app/dubboctl/internal/docker" |
| |
| "github.com/docker/docker/client" |
| "golang.org/x/crypto/ssh" |
| ) |
| |
| func TestNewDockerClientWithSSH(t *testing.T) { |
| withCleanHome(t) |
| |
| ctx, cancel := context.WithTimeout(context.Background(), time.Minute*1) |
| defer cancel() |
| |
| sshConf := startSSH(t) |
| |
| withKnowHosts(t, sshConf.address, sshConf.pubHostKey) |
| |
| t.Setenv("DOCKER_HOST", fmt.Sprintf("ssh://user:pwd@%s", sshConf.address)) |
| |
| dockerClient, dockerHostInRemote, err := docker.NewClient(client.DefaultDockerHost) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer dockerClient.Close() |
| |
| if dockerHostInRemote != `unix://`+sshDockerSocket { |
| t.Errorf("bad remote DOCKER_HOST: expected %q but got %q", `unix://`+sshDockerSocket, dockerHostInRemote) |
| } |
| |
| _, err = dockerClient.Ping(ctx) |
| if err != nil { |
| t.Error(err) |
| } |
| } |
| |
// sshDockerSocket is the Docker daemon socket path on the emulated remote
// host; it is the only unix socket the fake SSH server will forward to.
const (
	sshDockerSocket = "/some/path/docker.sock"
)
| |
| type sshConfig struct { |
| address string |
| pubHostKey ssh.PublicKey |
| } |
| |
| // emulates remote machine with docker unix socket at "/some/path/docker.sock" |
| func startSSH(t *testing.T, authorizedKeys ...ssh.PublicKey) (settings sshConfig) { |
| var err error |
| |
| ctx, cancel := context.WithCancel(context.Background()) |
| httpServerErrChan := make(chan error, 1) |
| pollingLoopErr := make(chan error, 1) |
| |
| config := &ssh.ServerConfig{ |
| PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { |
| if string(password) != "pwd" { |
| return nil, errors.New("bad pwd") |
| } |
| return &ssh.Permissions{}, nil |
| }, |
| PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { |
| for _, authKey := range authorizedKeys { |
| if bytes.Equal(authKey.Marshal(), key.Marshal()) { |
| return &ssh.Permissions{}, nil |
| } |
| } |
| return nil, fmt.Errorf("unknown public key") |
| }, |
| } |
| |
| key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) |
| if err != nil { |
| t.Error(err) |
| } |
| hostKey, err := ssh.NewSignerFromKey(key) |
| if err != nil { |
| t.Error(err) |
| } |
| config.AddHostKey(hostKey) |
| settings.pubHostKey = hostKey.PublicKey() |
| |
| sshTCPListener, err := net.Listen("tcp", "localhost:0") |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| dockerDaemonServer := http.Server{} |
| t.Cleanup(func() { |
| var err error |
| cancel() |
| |
| err = sshTCPListener.Close() |
| if err != nil { |
| t.Error(err) |
| } |
| err = <-pollingLoopErr |
| if err != nil && !errors.Is(err, net.ErrClosed) { |
| t.Error(err) |
| } |
| ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) |
| defer cancel() |
| err = dockerDaemonServer.Shutdown(ctx) |
| if err != nil { |
| t.Error(err) |
| } |
| err = <-httpServerErrChan |
| if err != nil && !strings.Contains(err.Error(), "Server closed") { |
| t.Error(err) |
| } |
| }) |
| |
| settings.address = sshTCPListener.Addr().String() |
| |
| t.Logf("Listening on %s", sshTCPListener.Addr()) |
| |
| // mimics /_ping endpoint |
| dockerDaemonServer.Handler = http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { |
| writer.Header().Add("Content-Type", "text/plain") |
| writer.WriteHeader(200) |
| _, _ = writer.Write([]byte("OK")) |
| }) |
| |
| // listener that emulates unix socket in remote accessed via SSH |
| dockerDaemonListener := listener{make(chan io.ReadWriteCloser, 128)} |
| |
| go func() { |
| httpServerErrChan <- dockerDaemonServer.Serve(dockerDaemonListener) |
| }() |
| |
| handleChannel := func(newChannel ssh.NewChannel) { |
| switch newChannel.ChannelType() { |
| case "session": |
| handleSession(t, newChannel) |
| case "direct-streamlocal@openssh.com": |
| handleTunnel(t, newChannel, dockerDaemonListener) |
| default: |
| err = newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("type of channel %q is not supported", newChannel.ChannelType())) |
| if err != nil { |
| t.Error(err) |
| } |
| } |
| } |
| |
| handleChannels := func(newChannels <-chan ssh.NewChannel) { |
| for newChannel := range newChannels { |
| go handleChannel(newChannel) |
| } |
| } |
| |
| go func() { |
| for { |
| tcpConn, err := sshTCPListener.Accept() |
| if err != nil { |
| pollingLoopErr <- err |
| return |
| } |
| |
| sshConn, newChannels, reqs, err := ssh.NewServerConn(tcpConn, config) |
| if err != nil { |
| pollingLoopErr <- err |
| return |
| } |
| go func() { |
| <-ctx.Done() |
| err = sshConn.Close() |
| if err != nil && !errors.Is(err, net.ErrClosed) { |
| t.Error(err) |
| } |
| }() |
| |
| go ssh.DiscardRequests(reqs) |
| |
| go handleChannels(newChannels) |
| } |
| }() |
| |
| return |
| } |
| |
| func handleSession(t *testing.T, newChannel ssh.NewChannel) { |
| ch, reqs, err := newChannel.Accept() |
| if err != nil { |
| t.Error(err) |
| } |
| go func() { |
| defer func() { |
| _ = ch.Close() |
| }() |
| for req := range reqs { |
| if req.Type == "exec" { |
| err = req.Reply(true, nil) |
| if err != nil { |
| t.Error(err) |
| } |
| data := struct { |
| Command string |
| }{} |
| err = ssh.Unmarshal(req.Payload, &data) |
| if err != nil { |
| t.Error(err) |
| } |
| var ret uint32 |
| switch { |
| case data.Command == "set": |
| ret = 0 |
| _, _ = fmt.Fprintf(ch, "DOCKER_HOST=unix://%s\n", sshDockerSocket) |
| default: |
| _, _ = fmt.Fprintf(ch.Stderr(), "unknown command: %q\n", data.Command) |
| ret = 127 |
| } |
| msg := []byte{0, 0, 0, 0} |
| binary.BigEndian.PutUint32(msg, ret) |
| _, err = ch.SendRequest("exit-status", false, msg) |
| if err != nil { |
| t.Error(err) |
| } |
| |
| return |
| } |
| } |
| }() |
| } |
| |
| func handleTunnel(t *testing.T, newChannel ssh.NewChannel, dockerDaemonListener listener) { |
| var err error |
| extraData := newChannel.ExtraData() |
| data := struct { |
| SocketPath string |
| Reserved0 string |
| Reserved1 uint32 |
| }{} |
| |
| err = ssh.Unmarshal(extraData, &data) |
| if err != nil { |
| t.Error(err) |
| } |
| |
| if data.SocketPath != sshDockerSocket { |
| err = newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: %q", data.SocketPath)) |
| if err != nil { |
| t.Error(err) |
| } |
| return |
| } |
| |
| ch, reqs, err := newChannel.Accept() |
| if err != nil { |
| t.Error(err) |
| } |
| select { |
| case dockerDaemonListener.connections <- ch: |
| default: |
| err = ch.Close() |
| if err != nil { |
| t.Error(err) |
| } |
| return |
| } |
| |
| ssh.DiscardRequests(reqs) |
| } |
| |
// listener is an in-memory net.Listener standing in for the remote Docker
// unix socket: every value sent on connections is handed out by Accept.
type listener struct {
	// connections carries tunnelled SSH channels; closing it ends Accept.
	connections chan io.ReadWriteCloser
}
| |
| type channelConnection struct { |
| ch io.ReadWriteCloser |
| } |
| |
| func (c channelConnection) Read(b []byte) (n int, err error) { |
| return c.ch.Read(b) |
| } |
| |
| func (c channelConnection) Write(b []byte) (n int, err error) { |
| return c.ch.Write(b) |
| } |
| |
| func (c channelConnection) Close() error { |
| return c.ch.Close() |
| } |
| |
| func (c channelConnection) LocalAddr() net.Addr { |
| return &net.UnixAddr{Name: sshDockerSocket, Net: "unix"} |
| } |
| |
| func (c channelConnection) RemoteAddr() net.Addr { |
| return &net.UnixAddr{Name: "@", Net: "unix"} |
| } |
| |
| func (c channelConnection) SetDeadline(t time.Time) error { return nil } |
| |
| func (c channelConnection) SetReadDeadline(t time.Time) error { return nil } |
| |
| func (c channelConnection) SetWriteDeadline(t time.Time) error { return nil } |
| |
| func (l listener) Accept() (net.Conn, error) { |
| rwc, ok := <-l.connections |
| if !ok { |
| return nil, errors.New("listener closed") |
| } |
| return channelConnection{rwc}, nil |
| } |
| |
| func (l listener) Close() error { |
| close(l.connections) |
| return nil |
| } |
| |
| func (l listener) Addr() net.Addr { |
| return &net.UnixAddr{Name: sshDockerSocket, Net: "unix"} |
| } |
| |
// withCleanHome points the user's home directory variable ($HOME, or
// %USERPROFILE% on Windows) at a fresh, empty temporary directory for the
// duration of the test. This prevents interaction with the real user home,
// which may contain .ssh/ configuration.
//
// The previous value is restored (or unset) and the directory removed when the
// test finishes. Because it uses t.Setenv, it cannot be used in parallel tests.
func withCleanHome(t *testing.T) {
	t.Helper()
	homeName := "HOME"
	if runtime.GOOS == "windows" {
		homeName = "USERPROFILE"
	}
	// t.TempDir and t.Setenv register their own cleanup and fail the test on
	// errors that the manual os.Setenv/os.RemoveAll version silently ignored.
	t.Setenv(homeName, t.TempDir())
}
| |
| // withKnowHosts creates $HOME/.ssh/known_hosts that trust the host |
| func withKnowHosts(t *testing.T, host string, pubKey ssh.PublicKey) { |
| t.Helper() |
| |
| var err error |
| var home string |
| |
| home, err = os.UserHomeDir() |
| if err != nil { |
| t.Fatal(err) |
| } |
| knownHosts := filepath.Join(home, ".ssh", "known_hosts") |
| |
| _, err = os.Stat(knownHosts) |
| if err == nil || !errors.Is(err, os.ErrNotExist) { |
| t.Fatal("known_hosts already exists") |
| } |
| |
| err = os.MkdirAll(filepath.Join(home, ".ssh"), 0o700) |
| if err != nil { |
| t.Fatal(err) |
| } |
| |
| knownHostFile, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_WRONLY, 0o600) |
| if err != nil { |
| t.Fatal(err) |
| } |
| defer knownHostFile.Close() |
| |
| fmt.Fprintf(knownHostFile, "%s %s\n", host, string(ssh.MarshalAuthorizedKey(pubKey))) |
| |
| t.Cleanup(func() { |
| os.Remove(knownHosts) |
| }) |
| } |