blob: 54670dee7ef4588161abead2968e2a8ea2bb68e3 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// NOTE: this code is based on "github.com/containers/podman/v3/pkg/bindings"
package ssh
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"net"
urlPkg "net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/docker/cli/cli/connhelper"
"github.com/docker/docker/pkg/homedir"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"golang.org/x/crypto/ssh/knownhosts"
)
type (
PasswordCallback func() (string, error)
PassPhraseCallback func() (string, error)
HostKeyCallback func(hostPort string, pubKey ssh.PublicKey) error
)
type Config struct {
Identity string
PassPhrase string
PasswordCallback PasswordCallback
PassPhraseCallback PassPhraseCallback
HostKeyCallback HostKeyCallback
}
type DialContextFn = func(ctx context.Context, network, addr string) (net.Conn, error)
// NewDialContext allows access to docker daemon in a remote machine using SSH.
//
// It creates a new ContextDialer which dials docker daemon in the remote
// and also returns Docker Host URI as seen by the remote.
//
// Knowing the Docker Host is useful when mounting docker socket into a container.
//
// Dialing the remote docker daemon can be done in two ways:
//
// - Use SSH to tunnel Unix/TCP socket.
//
// - Use SSH to execute the "docker system dial-stdio" command in the remote and forward its stdio.
//
// The tunnel method is used whenever possible.
// The "stdio" method is used as a fallback when tunneling is not possible:
// e.g. when remote uses Windows' named pipe.
//
// When tunneling is used all connection dialed
// by the returned ContextDialer are tunneled via single SSH connection.
// The connection should be disposed when dialer is no longer needed.
//
// For this reason returned ContextDialer may also implement io.Closer.
// Caller of this function should check if the returned ContextDialer
// is also an instance of io.Closer and call Close() on it if it is.
func NewDialContext(url *urlPkg.URL, config Config) (ContextDialer, string, error) {
sshConfig, err := NewSSHClientConfig(url, config)
if err != nil {
return nil, "", err
}
port := url.Port()
if port == "" {
port = "22"
}
host := url.Hostname()
sshClient, err := ssh.Dial("tcp", net.JoinHostPort(host, port), sshConfig)
if err != nil {
return nil, "", fmt.Errorf("failed to dial ssh: %w", err)
}
defer func() {
if sshClient != nil {
sshClient.Close()
}
}()
var remoteDockerHost string
if url.Path != "" {
remoteDockerHost = fmt.Sprintf(`unix://%s`, url.Path)
} else {
remoteDockerHost, err = getRemoteDockerHost(sshClient)
if err != nil {
return nil, "", err
}
}
network, addr, err := getNetworkAndAddress(remoteDockerHost)
if err != nil {
return nil, "", err
}
if network == "npipe" {
// ssh tunneling doesn't support tunneling of Windows' named pipes
dialContext, err := stdioDialContext(url, sshClient, config.Identity)
return contextDialerFn(dialContext), remoteDockerHost, err
}
d := dialer{sshClient: sshClient, addr: addr, network: network}
// moving ownership of sshClient from this function to the returned structure
sshClient = nil
return &d, remoteDockerHost, nil
}
type dialer struct {
sshClient *ssh.Client
network string
addr string
}
type ContextDialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
type contextDialerFn DialContextFn
func (n contextDialerFn) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return n(ctx, network, address)
}
func (d *dialer) DialContext(ctx context.Context, n, a string) (net.Conn, error) {
conn, err := d.Dial(d.network, d.addr)
if err != nil {
return nil, err
}
go func() {
if ctx != nil {
<-ctx.Done()
conn.Close()
}
}()
return conn, nil
}
func (d *dialer) Dial(n, a string) (net.Conn, error) {
return d.sshClient.Dial(d.network, d.addr)
}
func (d *dialer) Close() error {
return d.sshClient.Close()
}
func isWindowsMachine(sshClient *ssh.Client) (bool, error) {
session, err := sshClient.NewSession()
if err != nil {
return false, err
}
defer session.Close()
out, err := session.CombinedOutput("systeminfo")
if err == nil && strings.Contains(string(out), "Windows") {
return true, nil
}
return false, nil
}
func getRemoteDockerHost(sshClient *ssh.Client) (remoteDockerHost string, err error) {
session, err := sshClient.NewSession()
if err != nil {
return
}
defer session.Close()
out, err := session.CombinedOutput("set")
if err != nil {
return
}
remoteDockerHost = "unix:///var/run/docker.sock"
isWin, err := isWindowsMachine(sshClient)
if err != nil {
return
}
if isWin {
remoteDockerHost = "npipe:////./pipe/docker_engine"
}
scanner := bufio.NewScanner(bytes.NewBuffer(out))
for scanner.Scan() {
if strings.HasPrefix(scanner.Text(), "DOCKER_HOST=") {
parts := strings.SplitN(scanner.Text(), "=", 2)
remoteDockerHost = strings.Trim(parts[1], `"'`)
break
}
}
return remoteDockerHost, err
}
func getNetworkAndAddress(remoteDockerHost string) (network string, addr string, err error) {
remoteDockerHostURL, err := urlPkg.Parse(remoteDockerHost)
if err != nil {
return
}
switch remoteDockerHostURL.Scheme {
case "unix", "npipe":
addr = remoteDockerHostURL.Path
case "fd":
remoteDockerHostURL.Scheme = "tcp" // don't know why it works that way
fallthrough
case "tcp":
addr = remoteDockerHostURL.Host
default:
return "", "", errors.New("scheme is not supported")
}
network = remoteDockerHostURL.Scheme
return network, addr, err
}
func stdioDialContext(url *urlPkg.URL, sshClient *ssh.Client, identity string) (DialContextFn, error) {
session, err := sshClient.NewSession()
if err != nil {
return nil, err
}
defer session.Close()
out, err := session.CombinedOutput("docker system dial-stdio --help")
if err != nil {
return nil, fmt.Errorf("cannot use dial-stdio: %w (%q)", err, out)
}
var opts []string
if identity != "" {
opts = append(opts, "-i", identity)
}
connHelper, err := connhelper.GetConnectionHelperWithSSHOpts(url.String(), opts)
if err != nil {
return nil, err
}
return connHelper.Dialer, nil
}
// Default key names.
var knownKeyNames = []string{"id_rsa", "id_dsa", "id_ecdsa", "id_ecdsa_sk", "id_ed25519", "id_ed25519_sk"}
func NewSSHClientConfig(url *urlPkg.URL, credentialsConfig Config) (*ssh.ClientConfig, error) {
var (
authMethods []ssh.AuthMethod
signers []ssh.Signer
err error
)
if pw, found := url.User.Password(); found {
authMethods = append(authMethods, ssh.Password(pw))
}
// add signer from explicit identity parameter
if credentialsConfig.Identity != "" {
s, err := publicKey(credentialsConfig.Identity, []byte(credentialsConfig.Identity), credentialsConfig.PassPhraseCallback)
if err != nil {
return nil, fmt.Errorf("failed to parse identity file: %w", err)
}
signers = append(signers, s)
}
// add signers from ssh-agent
if sock, found := os.LookupEnv("SSH_AUTH_SOCK"); found && sock != "" {
var agentSigners []ssh.Signer
var agentConn net.Conn
agentConn, err = dialSSHAgentConnection(sock)
if err != nil {
return nil, fmt.Errorf("failed to connect to ssh-agent's socket: %w", err)
}
agentSigners, err = agent.NewClient(agentConn).Signers()
if err != nil {
return nil, fmt.Errorf("failed to get signers from ssh-agent: %w", err)
}
signers = append(signers, agentSigners...)
}
// if there is no explicit identity file nor keys from ssh-agent then
// add keys with standard name from ~/.ssh/
if len(signers) == 0 {
var defaultKeyPaths []string
if home, err := os.UserHomeDir(); err == nil {
for _, keyName := range knownKeyNames {
p := filepath.Join(home, ".ssh", keyName)
fi, err := os.Stat(p)
if err != nil {
continue
}
if fi.Mode().IsRegular() {
defaultKeyPaths = append(defaultKeyPaths, p)
}
}
}
if len(defaultKeyPaths) == 1 {
s, err := publicKey(defaultKeyPaths[0], []byte(credentialsConfig.PassPhrase), credentialsConfig.PassPhraseCallback)
if err != nil {
return nil, err
}
signers = append(signers, s)
}
}
if len(signers) > 0 {
dedup := make(map[string]ssh.Signer)
// Dedup signers based on fingerprint, ssh-agent keys override explicit identity
for _, s := range signers {
fp := ssh.FingerprintSHA256(s.PublicKey())
//if _, found := dedup[fp]; found {
// key updated
//}
dedup[fp] = s
}
var uniq []ssh.Signer
for _, s := range dedup {
uniq = append(uniq, s)
}
authMethods = append(authMethods, ssh.PublicKeysCallback(func() ([]ssh.Signer, error) {
return uniq, nil
}))
}
if len(authMethods) == 0 && credentialsConfig.PasswordCallback != nil {
authMethods = append(authMethods, ssh.PasswordCallback(credentialsConfig.PasswordCallback))
}
const sshTimeout = 5
clientConfig := &ssh.ClientConfig{
User: url.User.Username(),
Auth: authMethods,
HostKeyCallback: createHostKeyCallback(credentialsConfig.HostKeyCallback),
HostKeyAlgorithms: []string{
ssh.KeyAlgoECDSA256,
ssh.KeyAlgoECDSA384,
ssh.KeyAlgoECDSA521,
ssh.KeyAlgoED25519,
ssh.KeyAlgoRSASHA256,
ssh.KeyAlgoRSASHA512,
ssh.KeyAlgoRSA,
ssh.KeyAlgoDSA,
},
Timeout: sshTimeout * time.Second,
}
return clientConfig, nil
}
func publicKey(path string, passphrase []byte, passPhraseCallback PassPhraseCallback) (ssh.Signer, error) {
key, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("failed to read key file: %w", err)
}
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
var missingPhraseError *ssh.PassphraseMissingError
if ok := errors.As(err, &missingPhraseError); !ok {
return nil, fmt.Errorf("failed to parse private key: %w", err)
}
if len(passphrase) == 0 && passPhraseCallback != nil {
b, err := passPhraseCallback()
if err != nil {
return nil, err
}
passphrase = []byte(b)
}
return ssh.ParsePrivateKeyWithPassphrase(key, passphrase)
}
return signer, nil
}
func createHostKeyCallback(hostKeyCallback HostKeyCallback) func(hostPort string, remote net.Addr, key ssh.PublicKey) error {
return func(hostPort string, remote net.Addr, pubKey ssh.PublicKey) error {
host, port := hostPort, "22"
if _h, _p, err := net.SplitHostPort(host); err == nil {
host, port = _h, _p
}
knownHosts := filepath.Join(homedir.Get(), ".ssh", "known_hosts")
_, err := os.Stat(knownHosts)
if err != nil && errors.Is(err, os.ErrNotExist) {
if hostKeyCallback != nil && hostKeyCallback(hostPort, pubKey) == nil {
return nil
}
return errUnknownServerKey
}
f, err := os.Open(knownHosts)
if err != nil {
return fmt.Errorf("failed to open known_hosts: %w", err)
}
defer f.Close()
hashhost := knownhosts.HashHostname(host)
var errs []error
scanner := bufio.NewScanner(f)
for scanner.Scan() {
_, hostPorts, _key, _, _, err := ssh.ParseKnownHosts(scanner.Bytes())
if err != nil {
errs = append(errs, err)
continue
}
for _, hp := range hostPorts {
h, p := hp, "22"
if _h, _p, err := net.SplitHostPort(hp); err == nil {
h, p = _h, _p
}
if (h == host || h == hashhost) && port == p {
if pubKey.Type() != _key.Type() {
errs = append(errs, fmt.Errorf("missmatch in type of a key"))
continue
}
if bytes.Equal(_key.Marshal(), pubKey.Marshal()) {
return nil
}
return errBadServerKey
}
}
}
if hostKeyCallback != nil && hostKeyCallback(hostPort, pubKey) == nil {
return nil
}
if len(errs) > 0 {
return fmt.Errorf("server is not trusted (%v)", errs)
}
return errUnknownServerKey
}
}
var (
ErrBadServerKeyMsg = "server key for given host differs from key in known_host"
ErrUnknownServerKeyMsg = "server key not found in known_hosts"
)
// I would expose those but since ssh pkg doesn't do correct error wrapping it would be entirely futile
var (
errBadServerKey = errors.New(ErrBadServerKeyMsg)
errUnknownServerKey = errors.New(ErrUnknownServerKeyMsg)
)