blob: 794b21a8095095cab9a902c0e2ffb8ccb9970b3c [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package docker_test
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
)
import (
"github.com/docker/docker/client"
"golang.org/x/crypto/ssh"
)
import (
"github.com/apache/dubbo-kubernetes/app/dubboctl/internal/docker"
)
// TestNewDockerClientWithSSH verifies that docker.NewClient, when DOCKER_HOST is an
// ssh:// URL, asks the remote machine for its DOCKER_HOST (expected to be the unix
// socket sshDockerSocket) and can ping the Docker daemon through the SSH tunnel.
func TestNewDockerClientWithSSH(t *testing.T) {
	withCleanHome(t)

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	server := startSSH(t)
	withKnowHosts(t, server.address, server.pubHostKey)
	t.Setenv("DOCKER_HOST", fmt.Sprintf("ssh://user:pwd@%s", server.address))

	cli, remoteHost, err := docker.NewClient(client.DefaultDockerHost)
	if err != nil {
		t.Fatal(err)
	}
	defer cli.Close()

	wantHost := `unix://` + sshDockerSocket
	if remoteHost != wantHost {
		t.Errorf("bad remote DOCKER_HOST: expected %q but got %q", wantHost, remoteHost)
	}

	if _, err := cli.Ping(ctx); err != nil {
		t.Error(err)
	}
}
const (
	// sshDockerSocket is the path of the Docker daemon's unix socket on the emulated
	// remote host; it is the only socket path the SSH server will tunnel to.
	sshDockerSocket = "/some/path/docker.sock"
)
// sshConfig describes a running emulated SSH server, as returned by startSSH.
type sshConfig struct {
	// address is the host:port the SSH server is listening on.
	address string
	// pubHostKey is the server's public host key (used to populate known_hosts).
	pubHostKey ssh.PublicKey
}
// startSSH starts an in-process SSH server on a random localhost port. The server
// emulates a remote machine whose Docker daemon listens on the unix socket
// sshDockerSocket.
//
// Clients may authenticate with the password "pwd" (any user name) or with any key in
// authorizedKeys. "session" channels are served by handleSession.
// "direct-streamlocal@openssh.com" channels are served by handleTunnel, which hands the
// tunnelled connection to an in-process HTTP server. That server answers every request
// with 200 "OK", which is enough for a Docker /_ping. All other channel types are
// rejected.
//
// The returned settings carry the listening address and the server's public host key.
// t.Cleanup tears down the listener, the SSH connections and the HTTP server. At that
// point the test fails on any accept/handshake-loop error other than net.ErrClosed.
func startSSH(t *testing.T, authorizedKeys ...ssh.PublicKey) (settings sshConfig) {
	config := &ssh.ServerConfig{
		PasswordCallback: func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
			if string(password) != "pwd" {
				return nil, errors.New("bad pwd")
			}
			return &ssh.Permissions{}, nil
		},
		PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
			for _, authKey := range authorizedKeys {
				if bytes.Equal(authKey.Marshal(), key.Marshal()) {
					return &ssh.Permissions{}, nil
				}
			}
			return nil, fmt.Errorf("unknown public key")
		},
	}

	// Without a host key no handshake can succeed. Stop the test right away instead of
	// continuing with a nil key.
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatal(err)
	}
	hostKey, err := ssh.NewSignerFromKey(key)
	if err != nil {
		t.Fatal(err)
	}
	config.AddHostKey(hostKey)
	settings.pubHostKey = hostKey.PublicKey()

	sshTCPListener, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatal(err)
	}
	settings.address = sshTCPListener.Addr().String()
	t.Logf("Listening on %s", sshTCPListener.Addr())

	// Emulated Docker daemon: mimics the /_ping endpoint by answering every request
	// with 200 "OK".
	dockerDaemonServer := http.Server{
		Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
			writer.Header().Add("Content-Type", "text/plain")
			writer.WriteHeader(200)
			_, _ = writer.Write([]byte("OK"))
		}),
	}
	// Stands in for the daemon's unix socket on the remote side; fed by handleTunnel.
	dockerDaemonListener := listener{make(chan io.ReadWriteCloser, 128)}

	// Create ctx only after every t.Fatal above, so cancel is never leaked.
	ctx, cancel := context.WithCancel(context.Background())
	httpServerErrChan := make(chan error, 1)
	pollingLoopErr := make(chan error, 1)

	t.Cleanup(func() {
		cancel()
		if err := sshTCPListener.Close(); err != nil {
			t.Error(err)
		}
		if err := <-pollingLoopErr; err != nil && !errors.Is(err, net.ErrClosed) {
			t.Error(err)
		}
		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*5)
		defer shutdownCancel()
		if err := dockerDaemonServer.Shutdown(shutdownCtx); err != nil {
			t.Error(err)
		}
		if err := <-httpServerErrChan; err != nil && !strings.Contains(err.Error(), "Server closed") {
			t.Error(err)
		}
	})

	go func() {
		httpServerErrChan <- dockerDaemonServer.Serve(dockerDaemonListener)
	}()

	// handleChannel runs on its own goroutine for every new channel, so it uses only
	// local error variables. Sharing startSSH's err would be a data race.
	handleChannel := func(newChannel ssh.NewChannel) {
		switch newChannel.ChannelType() {
		case "session":
			handleSession(t, newChannel)
		case "direct-streamlocal@openssh.com":
			handleTunnel(t, newChannel, dockerDaemonListener)
		default:
			msg := fmt.Sprintf("type of channel %q is not supported", newChannel.ChannelType())
			if err := newChannel.Reject(ssh.UnknownChannelType, msg); err != nil {
				t.Error(err)
			}
		}
	}
	handleChannels := func(newChannels <-chan ssh.NewChannel) {
		for newChannel := range newChannels {
			go handleChannel(newChannel)
		}
	}

	// Accept loop. It stops on the first accept or handshake error, and reports that
	// error to the cleanup through pollingLoopErr.
	go func() {
		for {
			tcpConn, err := sshTCPListener.Accept()
			if err != nil {
				pollingLoopErr <- err
				return
			}
			sshConn, newChannels, reqs, err := ssh.NewServerConn(tcpConn, config)
			if err != nil {
				pollingLoopErr <- err
				return
			}
			go func() {
				<-ctx.Done()
				if err := sshConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
					t.Error(err)
				}
			}()
			go ssh.DiscardRequests(reqs)
			go handleChannels(newChannels)
		}
	}()
	return settings
}
// handleSession accepts an SSH "session" channel and serves its first "exec" request.
//
// The command "set" writes "DOCKER_HOST=unix://" + sshDockerSocket (plus a newline) to
// stdout and exits with status 0. Any other command writes an error to stderr and exits
// with status 127. The channel is closed after the exec request, or when the request
// stream ends. Other requests that want a reply (for example "env") are refused, as
// ssh.DiscardRequests would do, so the client is not left waiting.
func handleSession(t *testing.T, newChannel ssh.NewChannel) {
	ch, reqs, err := newChannel.Accept()
	if err != nil {
		// ch and reqs are not usable when Accept fails. Ranging over reqs would block
		// forever.
		t.Error(err)
		return
	}
	go func() {
		defer func() {
			_ = ch.Close()
		}()
		for req := range reqs {
			if req.Type != "exec" {
				if req.WantReply {
					if err := req.Reply(false, nil); err != nil {
						t.Error(err)
					}
				}
				continue
			}
			if err := req.Reply(true, nil); err != nil {
				t.Error(err)
			}
			// "exec" payload: the command line as a single SSH string.
			var data struct {
				Command string
			}
			if err := ssh.Unmarshal(req.Payload, &data); err != nil {
				t.Error(err)
			}
			var ret uint32
			switch data.Command {
			case "set":
				ret = 0
				_, _ = fmt.Fprintf(ch, "DOCKER_HOST=unix://%s\n", sshDockerSocket)
			default:
				_, _ = fmt.Fprintf(ch.Stderr(), "unknown command: %q\n", data.Command)
				ret = 127
			}
			// "exit-status" payload: the exit code encoded as a big-endian uint32.
			msg := make([]byte, 4)
			binary.BigEndian.PutUint32(msg, ret)
			if _, err := ch.SendRequest("exit-status", false, msg); err != nil {
				t.Error(err)
			}
			return
		}
	}()
}
// handleTunnel serves a "direct-streamlocal@openssh.com" channel, which is a request to
// forward a connection to a unix socket on the emulated remote host.
//
// Only sshDockerSocket may be targeted. A malformed payload, or any other socket path,
// is rejected with ssh.ConnectionFailed. An accepted channel is queued on
// dockerDaemonListener for the emulated Docker daemon; if the queue is full, the channel
// is closed instead. In both cases the channel's out-of-band requests are then
// discarded until its request stream ends, so this call blocks until then.
func handleTunnel(t *testing.T, newChannel ssh.NewChannel, dockerDaemonListener listener) {
	// Channel-open payload: the target socket path followed by two reserved fields.
	var data struct {
		SocketPath string
		Reserved0  string
		Reserved1  uint32
	}
	if err := ssh.Unmarshal(newChannel.ExtraData(), &data); err != nil {
		// A partially decoded payload must not be trusted, so refuse the channel.
		t.Error(err)
		if rejectErr := newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad payload: %v", err)); rejectErr != nil {
			t.Error(rejectErr)
		}
		return
	}
	if data.SocketPath != sshDockerSocket {
		if err := newChannel.Reject(ssh.ConnectionFailed, fmt.Sprintf("bad socket: %q", data.SocketPath)); err != nil {
			t.Error(err)
		}
		return
	}
	ch, reqs, err := newChannel.Accept()
	if err != nil {
		// Nothing usable came back. Queuing it would hand the HTTP server a nil channel.
		t.Error(err)
		return
	}
	select {
	case dockerDaemonListener.connections <- ch:
	default:
		// Queue full: drop the connection rather than block.
		if err := ch.Close(); err != nil {
			t.Error(err)
		}
	}
	// x/crypto/ssh requires the channel's request stream to be serviced, even for a
	// dropped channel; otherwise the channel can hang.
	ssh.DiscardRequests(reqs)
}
// listener implements net.Listener (see its Accept/Close/Addr methods). Its
// "connections" are SSH channels queued by handleTunnel, and it stands in for the Docker
// daemon's unix socket on the emulated remote host.
type listener struct {
	// connections buffers tunnelled channels until Accept picks them up.
	connections chan io.ReadWriteCloser
}
// channelConnection adapts an io.ReadWriteCloser (here, a tunnelled SSH channel) to
// net.Conn so that http.Server can serve it. Deadlines are not supported: the
// Set*Deadline methods do nothing and report success.
type channelConnection struct {
	ch io.ReadWriteCloser
}

var _ net.Conn = channelConnection{}

// Read reads from the wrapped channel.
func (cc channelConnection) Read(p []byte) (int, error) {
	return cc.ch.Read(p)
}

// Write writes to the wrapped channel.
func (cc channelConnection) Write(p []byte) (int, error) {
	return cc.ch.Write(p)
}

// Close closes the wrapped channel.
func (cc channelConnection) Close() error {
	return cc.ch.Close()
}

// LocalAddr reports the emulated Docker socket path.
func (channelConnection) LocalAddr() net.Addr {
	return &net.UnixAddr{Net: "unix", Name: sshDockerSocket}
}

// RemoteAddr reports an anonymous unix peer ("@"); the real peer sits on the SSH side.
func (channelConnection) RemoteAddr() net.Addr {
	return &net.UnixAddr{Net: "unix", Name: "@"}
}

// SetDeadline is a no-op.
func (channelConnection) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is a no-op.
func (channelConnection) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is a no-op.
func (channelConnection) SetWriteDeadline(time.Time) error { return nil }
var _ net.Listener = listener{}

// Accept blocks until handleTunnel queues a channel, then returns it wrapped as a
// net.Conn. It returns an error once Close has been called and the queue is drained.
func (l listener) Accept() (net.Conn, error) {
	if rwc, open := <-l.connections; open {
		return channelConnection{rwc}, nil
	}
	return nil, errors.New("listener closed")
}

// Close closes the connection queue, which unblocks pending Accept calls. It must be
// called at most once, because closing an already-closed channel panics.
func (l listener) Close() error {
	close(l.connections)
	return nil
}

// Addr reports the emulated Docker socket path as a unix address.
func (l listener) Addr() net.Addr {
	return &net.UnixAddr{Net: "unix", Name: sshDockerSocket}
}
// withCleanHome points the user's home directory at a fresh, empty temporary directory
// for the rest of the test. This keeps the test away from the real home, which may
// contain .ssh/ with keys and a known_hosts file.
//
// The variable changed is USERPROFILE on Windows and HOME elsewhere. When the test
// ends, its previous value is restored (or it is unset again if it was unset) and the
// temporary directory is removed. Like t.Setenv, this must not be used in parallel
// tests.
func withCleanHome(t *testing.T) {
	t.Helper()
	homeName := "HOME"
	if runtime.GOOS == "windows" {
		homeName = "USERPROFILE"
	}
	// t.TempDir registers its own removal. t.Setenv records the old value, or its
	// absence, and restores it on cleanup; cleanups run LIFO, so the env is restored
	// before the directory is deleted.
	t.Setenv(homeName, t.TempDir())
}
// withKnowHosts creates $HOME/.ssh/known_hosts with a single entry that trusts pubKey
// for host. It is meant to run after withCleanHome, so the real user's known_hosts is
// never touched.
//
// The test fails if the file already exists or cannot be created or written. The file
// is removed via t.Cleanup.
func withKnowHosts(t *testing.T, host string, pubKey ssh.PublicKey) {
	t.Helper()
	home, err := os.UserHomeDir()
	if err != nil {
		t.Fatal(err)
	}
	sshDir := filepath.Join(home, ".ssh")
	if err := os.MkdirAll(sshDir, 0o700); err != nil {
		t.Fatal(err)
	}
	knownHosts := filepath.Join(sshDir, "known_hosts")
	// O_EXCL makes "already exists" an atomic open failure, so an existing file is never
	// overwritten. Any other error (permissions, I/O, ...) is reported as-is.
	knownHostFile, err := os.OpenFile(knownHosts, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600)
	if errors.Is(err, os.ErrExist) {
		t.Fatal("known_hosts already exists")
	}
	if err != nil {
		t.Fatal(err)
	}
	// Register removal before writing, so a failed write does not leave the file behind.
	t.Cleanup(func() {
		_ = os.Remove(knownHosts)
	})
	// NOTE(review): host is written verbatim (e.g. "127.0.0.1:1234"). OpenSSH writes
	// non-default ports as "[host]:port"; confirm the client's parser accepts this form.
	// ssh.MarshalAuthorizedKey already ends with a newline, so none is appended here.
	if _, err := fmt.Fprintf(knownHostFile, "%s %s", host, ssh.MarshalAuthorizedKey(pubKey)); err != nil {
		_ = knownHostFile.Close()
		t.Fatal(err)
	}
	// A failed Close can mean the entry was not persisted.
	if err := knownHostFile.Close(); err != nil {
		t.Fatal(err)
	}
}