blob: 1d55a8c14ae7ef006d9afcae0ee2df0fa7d185be [file] [log] [blame]
/**
* Licensed to jclouds, Inc. (jclouds) under one or more
* contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. jclouds licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.jclouds.sshj;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.instanceOf;
import static com.google.common.base.Predicates.or;
import static com.google.common.base.Throwables.getCausalChain;
import static com.google.common.collect.Iterables.any;
import static org.jclouds.crypto.CryptoStreams.hex;
import static org.jclouds.crypto.CryptoStreams.md5;
import static org.jclouds.crypto.SshKeys.fingerprintPrivateKey;
import static org.jclouds.crypto.SshKeys.sha1PrivateKey;
import java.io.IOException;
import java.io.InputStream;
import java.net.ConnectException;
import java.net.SocketTimeoutException;
import java.util.concurrent.TimeUnit;
import javax.annotation.PreDestroy;
import javax.annotation.Resource;
import javax.inject.Named;
import net.schmizz.sshj.SSHClient;
import net.schmizz.sshj.common.IOUtils;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.channel.direct.Session;
import net.schmizz.sshj.connection.channel.direct.Session.Command;
import net.schmizz.sshj.sftp.SFTPClient;
import net.schmizz.sshj.sftp.SFTPException;
import net.schmizz.sshj.transport.TransportException;
import net.schmizz.sshj.transport.verification.PromiscuousVerifier;
import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.userauth.keyprovider.OpenSSHKeyFile;
import net.schmizz.sshj.xfer.InMemorySourceFile;
import org.apache.commons.io.input.ProxyInputStream;
import org.jclouds.compute.domain.ExecResponse;
import org.jclouds.http.handlers.BackoffLimitedRetryHandler;
import org.jclouds.io.Payload;
import org.jclouds.io.Payloads;
import org.jclouds.logging.Logger;
import org.jclouds.net.IPSocket;
import org.jclouds.rest.AuthorizationException;
import org.jclouds.ssh.SshClient;
import org.jclouds.ssh.SshException;
import org.jclouds.util.Throwables2;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.base.Splitter;
import com.google.common.base.Throwables;
import com.google.inject.Inject;
/**
 * {@link SshClient} implementation backed by sshj.
 * <p>
 * Every remote operation ({@link #connect()}, {@link #exec(String)}, {@link #get(String)},
 * {@link #put(String, Payload)}) is expressed as a {@link Connection} and run through
 * {@link #acquire(Connection)}. That method clears any previous attempt before each new one. It retries failures
 * accepted by {@link #shouldRetry(Exception)}, or caused by an {@link IllegalStateException}, with exponential
 * backoff. On any failure it disconnects the ssh session; for non-session connections it reconnects before the
 * next attempt.
 * <p>
 * This class needs refactoring. It is not thread safe: the ssh session and the sftp client held by
 * {@code sftpConnection} are shared mutable state with no synchronization.
 *
 * @author Adrian Cole
 */
public class SshjSshClient implements SshClient {

   /**
    * Wraps an sftp download stream and also closes the {@link SFTPClient} it was read through. The sftp channel
    * opened by {@link SshjSshClient#get(String)} is therefore released once the caller closes the payload.
    */
   private final class CloseFtpChannelOnCloseInputStream extends ProxyInputStream {
      private final SFTPClient sftp;

      private CloseFtpChannelOnCloseInputStream(InputStream proxy, SFTPClient sftp) {
         super(proxy);
         this.sftp = sftp;
      }

      @Override
      public void close() throws IOException {
         try {
            super.close();
         } finally {
            // release the sftp channel even when closing the wrapped stream fails
            if (sftp != null)
               sftp.close();
         }
      }
   }

   private final String host;
   private final int port;
   private final String username;
   private final String password;
   /** Precomputed description: user, an md5 of the password or the key fingerprint/sha1, host and port. */
   private final String toString;

   /** Maximum number of attempts {@link #acquire(Connection)} makes; must be greater than zero. */
   @Inject(optional = true)
   @Named("jclouds.ssh.max-retries")
   @VisibleForTesting
   int sshRetries = 5;

   /** When true, authentication failures are also treated as retryable. */
   @Inject(optional = true)
   @Named("jclouds.ssh.retry-auth")
   @VisibleForTesting
   boolean retryAuth;

   /** Comma-separated fragments; a failure whose causal chain mentions any of them is retried. */
   @Inject(optional = true)
   @Named("jclouds.ssh.retryable-messages")
   @VisibleForTesting
   String retryableMessages = "";

   // NOTE cannot retry io exceptions, as SSHException is a part of the chain.
   // "unchecked" is suppressed because Predicates.or with more than two arguments builds a generic varargs array.
   @Inject(optional = true)
   @Named("jclouds.ssh.retry-predicate")
   @SuppressWarnings("unchecked")
   private Predicate<Throwable> retryPredicate = or(instanceOf(ConnectionException.class),
         instanceOf(ConnectException.class), instanceOf(SocketTimeoutException.class),
         instanceOf(TransportException.class),
         // safe to retry sftp exceptions as they are idempotent
         instanceOf(SFTPException.class));

   @Resource
   @Named("jclouds.ssh")
   protected Logger logger = Logger.NULL;

   /** The current ssh session; replaced by {@link #connect()} and nulled by {@link #disconnect()}. */
   @VisibleForTesting
   SSHClient ssh;

   private final byte[] privateKey;
   // NOTE(review): not referenced in this class; kept because it is package-visible
   final byte[] emptyPassPhrase = new byte[0];
   /**
    * Timeout in milliseconds. It is used as the socket and connect timeout, and as the wait for exec'd commands
    * to finish. A value of 0 leaves the sshj socket/connect timeouts unconfigured.
    */
   private final int timeoutMillis;
   private final BackoffLimitedRetryHandler backoffLimitedRetryHandler;

   /**
    * @param backoffLimitedRetryHandler
    *           imposes the delay between retried attempts
    * @param socket
    *           host and port to connect to; the port must be greater than zero
    * @param timeout
    *           timeout in milliseconds (see {@code timeoutMillis}); 0 skips configuring sshj's socket and connect
    *           timeouts
    * @param username
    *           login user
    * @param password
    *           password to authenticate with; when non-null it is used instead of {@code privateKey}
    * @param privateKey
    *           OpenSSH-format private key, used for authentication only when {@code password} is null
    * @throws NullPointerException
    *            if {@code socket}, {@code username} or {@code backoffLimitedRetryHandler} is null
    * @throws IllegalArgumentException
    *            if the port is not positive, or neither a password nor a private key is supplied
    */
   public SshjSshClient(BackoffLimitedRetryHandler backoffLimitedRetryHandler, IPSocket socket, int timeout,
         String username, String password, byte[] privateKey) {
      this.host = checkNotNull(socket, "socket").getAddress();
      checkArgument(socket.getPort() > 0, "ssh port must be greater than zero: %s", socket.getPort());
      checkArgument(password != null || privateKey != null, "you must specify a password or a key");
      this.port = socket.getPort();
      this.username = checkNotNull(username, "username");
      this.backoffLimitedRetryHandler = checkNotNull(backoffLimitedRetryHandler, "backoffLimitedRetryHandler");
      this.timeoutMillis = timeout;
      this.password = password;
      this.privateKey = privateKey;
      if (privateKey == null) {
         // never expose the password itself; an md5 is enough to tell credentials apart in logs
         this.toString = String.format("%s:pw[%s]@%s:%d", username, hex(md5(password.getBytes())), host, port);
      } else {
         String fingerPrint = fingerprintPrivateKey(new String(privateKey));
         String sha1 = sha1PrivateKey(new String(privateKey));
         this.toString = String.format("%s:rsa[fingerprint(%s),sha1(%s)]@%s:%d", username, fingerPrint, sha1, host,
               port);
      }
   }

   @Override
   public void put(String path, String contents) {
      put(path, Payloads.newStringPayload(checkNotNull(contents, "contents")));
   }

   /** @throws IllegalStateException if there is no connected ssh session */
   private void checkConnected() {
      checkState(ssh != null && ssh.isConnected(), "(%s) ssh not connected!", toString());
   }

   /**
    * A resource or result that {@link SshjSshClient#acquire(Connection)} obtains, possibly over several attempts.
    *
    * @param <T>
    *           type of the acquired resource or result
    */
   public static interface Connection<T> {
      /** Releases whatever a previous {@link #create()} left behind; invoked before every attempt. */
      void clear() throws Exception;

      /** Performs a single attempt at acquiring the resource or result. */
      T create() throws Exception;
   }

   /** Opens and authenticates the ssh session itself. */
   Connection<net.schmizz.sshj.SSHClient> sshConnection = new Connection<net.schmizz.sshj.SSHClient>() {
      @Override
      public void clear() {
         if (ssh != null && ssh.isConnected()) {
            try {
               ssh.disconnect();
            } catch (IOException e) {
               logger.warn(e, "<< exception disconnecting from %s: %s", SshjSshClient.this, e.getMessage());
            }
            ssh = null;
         }
      }

      @Override
      public net.schmizz.sshj.SSHClient create() throws Exception {
         net.schmizz.sshj.SSHClient ssh = new net.schmizz.sshj.SSHClient();
         // NOTE(review): PromiscuousVerifier accepts any host key, so the server's identity is not verified
         ssh.addHostKeyVerifier(new PromiscuousVerifier());
         if (timeoutMillis != 0) {
            ssh.setTimeout(timeoutMillis);
            ssh.setConnectTimeout(timeoutMillis);
         }
         ssh.connect(host, port);
         if (password != null) {
            ssh.authPassword(username, password);
         } else {
            OpenSSHKeyFile key = new OpenSSHKeyFile();
            key.init(new String(privateKey), null);
            ssh.authPublickey(username, key);
         }
         return ssh;
      }

      @Override
      public String toString() {
         return String.format("SSHClient(timeout=%d)", timeoutMillis);
      }
   };

   private void backoffForAttempt(int retryAttempt, String message) {
      backoffLimitedRetryHandler.imposeBackoffExponentialDelay(200L, 2, retryAttempt, sshRetries, message);
   }

   /**
    * Runs {@code connection} until it succeeds, a non-retryable failure occurs, or {@link #sshRetries} attempts
    * have been made.
    * <p>
    * Before each attempt the connection is cleared. After every failure the ssh session is disconnected. If the
    * failure is retryable, the method backs off and, unless {@code connection} is the ssh session itself,
    * reconnects before trying again.
    *
    * @return the value produced by the first successful {@link Connection#create()}
    * @throws IllegalStateException
    *            if {@link #sshRetries} is not greater than zero
    * @throws AuthorizationException
    *            if the final failure is a {@link UserAuthException}
    * @throws SshException
    *            for any other final failure
    */
   protected <T, C extends Connection<T>> T acquire(C connection) {
      // with no attempts allowed the loop below would silently yield null
      checkState(sshRetries > 0, "(%s) ssh retries must be greater than zero, was %s", toString(), sshRetries);
      String errorMessage = String.format("(%s) error acquiring %s", toString(), connection);
      for (int i = 0; i < sshRetries; i++) {
         try {
            connection.clear();
            logger.debug(">> (%s) acquiring %s", toString(), connection);
            T returnVal = connection.create();
            logger.debug("<< (%s) acquired %s", toString(), returnVal);
            return returnVal;
         } catch (Exception from) {
            try {
               disconnect();
            } catch (Exception e1) {
               logger.warn(e1, "<< (%s) error closing connection", toString());
            }
            if (i + 1 == sshRetries) {
               throw propagate(from, errorMessage + " (out of retries - max " + sshRetries + ")");
            } else if (shouldRetry(from)
                  || (Throwables2.getFirstThrowableOfType(from, IllegalStateException.class) != null)) {
               // pass dynamic text as arguments so '%' in a command or path cannot break log formatting
               logger.info("<< %s (attempt %s of %s): %s", errorMessage, i + 1, sshRetries, from.getMessage());
               backoffForAttempt(i + 1, errorMessage + ": " + from.getMessage());
               if (connection != sshConnection)
                  connect();
               continue;
            } else {
               throw propagate(from, errorMessage + " (not retryable)");
            }
         }
      }
      // every iteration returns, throws or continues, and the last one always throws
      throw new AssertionError("should not reach here");
   }

   /** (Re)establishes the ssh session, retrying as described on {@link #acquire(Connection)}. */
   public void connect() {
      // acquire only throws unchecked exceptions, so they propagate unchanged
      ssh = acquire(sshConnection);
   }

   /** Opens an sftp client on the current session, closing the previously opened one first. */
   Connection<SFTPClient> sftpConnection = new Connection<SFTPClient>() {
      private SFTPClient sftp;

      @Override
      public void clear() {
         if (sftp != null) {
            try {
               sftp.close();
            } catch (IOException e) {
               throw Throwables.propagate(e);
            } finally {
               // forget the client so a later clear() never closes it a second time
               sftp = null;
            }
         }
      }

      @Override
      public SFTPClient create() throws IOException {
         checkConnected();
         sftp = ssh.newSFTPClient();
         return sftp;
      }

      @Override
      public String toString() {
         return "SFTPClient()";
      }
   };

   /** Opens {@code path} over sftp; the returned payload closes its sftp client when its stream is closed. */
   class GetConnection implements Connection<Payload> {
      private final String path;
      private SFTPClient sftp;

      GetConnection(String path) {
         this.path = checkNotNull(path, "path");
      }

      @Override
      public void clear() throws IOException {
         if (sftp != null) {
            try {
               sftp.close();
            } finally {
               sftp = null;
            }
         }
      }

      @Override
      public Payload create() throws Exception {
         sftp = acquire(sftpConnection);
         return Payloads.newInputStreamPayload(new CloseFtpChannelOnCloseInputStream(sftp.getSFTPEngine().open(path)
               .getInputStream(), sftp));
      }

      @Override
      public String toString() {
         return "Payload(path=[" + path + "])";
      }
   }

   /**
    * Streams the remote file at {@code path} over sftp. The caller must close the payload's stream to release the
    * sftp channel.
    */
   public Payload get(String path) {
      return acquire(new GetConnection(path));
   }

   /** Uploads {@code contents} to {@code path} over sftp, releasing the payload afterwards. */
   class PutConnection implements Connection<Void> {
      private final String path;
      private final Payload contents;
      private SFTPClient sftp;

      PutConnection(String path, Payload contents) {
         this.path = checkNotNull(path, "path");
         this.contents = checkNotNull(contents, "contents");
      }

      @Override
      public void clear() {
         if (sftp != null) {
            try {
               sftp.close();
            } catch (IOException e) {
               throw Throwables.propagate(e);
            } finally {
               sftp = null;
            }
         }
      }

      @Override
      public Void create() throws Exception {
         sftp = acquire(sftpConnection);
         try {
            sftp.put(new InMemorySourceFile() {
               @Override
               public String getName() {
                  return path;
               }

               @Override
               public long getLength() {
                  return contents.getContentMetadata().getContentLength();
               }

               @Override
               public InputStream getInputStream() throws IOException {
                  return checkNotNull(contents.getInput(), "inputstream for path %s", path);
               }
            }, path);
         } finally {
            contents.release();
         }
         return null;
      }

      @Override
      public String toString() {
         return "Put(path=[" + path + "])";
      }
   }

   @Override
   public void put(String path, Payload contents) {
      acquire(new PutConnection(path, contents));
   }

   /**
    * Decides whether {@code from} is transient.
    * <p>
    * It is, if any throwable in its causal chain matches {@link #retryPredicate}, or an auth failure when
    * {@link #retryAuth} is set. Failing that, it is transient if the chain mentions one of the comma-separated
    * {@link #retryableMessages}.
    */
   @VisibleForTesting
   @SuppressWarnings("unchecked")
   boolean shouldRetry(Exception from) {
      Predicate<Throwable> predicate = retryAuth ? Predicates.<Throwable> or(retryPredicate,
            instanceOf(AuthorizationException.class), instanceOf(UserAuthException.class)) : retryPredicate;
      if (any(getCausalChain(from), predicate))
         return true;
      if (!retryableMessages.isEmpty())
         return any(Splitter.on(",").split(retryableMessages), causalChainHasMessageContaining(from));
      return false;
   }

   /**
    * Returns a predicate that is true for a fragment when any throwable in {@code from}'s causal chain contains it.
    * Both the throwable's {@code toString()} and its message are checked.
    */
   @VisibleForTesting
   Predicate<String> causalChainHasMessageContaining(final Exception from) {
      return new Predicate<String>() {
         @Override
         public boolean apply(final String input) {
            return any(getCausalChain(from), new Predicate<Throwable>() {
               @Override
               public boolean apply(Throwable arg0) {
                  return arg0.toString().contains(input)
                        || (arg0.getMessage() != null && arg0.getMessage().contains(input));
               }
            });
         }
      };
   }

   /**
    * Logs {@code e} and always throws; the return type only lets callers write {@code throw propagate(...)}.
    *
    * @throws AuthorizationException
    *            if {@code e} is a {@link UserAuthException}
    * @throws SshException
    *            otherwise: {@code e} itself when already an {@link SshException}, else a wrapper around it
    */
   @VisibleForTesting
   SshException propagate(Exception e, String message) {
      message += ": " + e.getMessage();
      logger.error(e, "<< %s", message);
      if (e instanceof UserAuthException)
         throw new AuthorizationException("(" + toString() + ") " + message, e);
      throw e instanceof SshException ? SshException.class.cast(e) : new SshException(
            "(" + toString() + ") " + message, e);
   }

   @Override
   public String toString() {
      return toString;
   }

   /** Disconnects the ssh session if it is connected. */
   @PreDestroy
   public void disconnect() {
      try {
         sshConnection.clear();
      } catch (Exception e) {
         throw Throwables.propagate(e);
      }
   }

   /** Returns a connection that starts a session with a default pty on the current ssh connection. */
   protected Connection<Session> execConnection() {
      return new Connection<Session>() {
         private Session session = null;

         @Override
         public void clear() throws TransportException, ConnectionException {
            if (session != null) {
               try {
                  session.close();
               } finally {
                  session = null;
               }
            }
         }

         @Override
         public Session create() throws Exception {
            checkConnected();
            session = ssh.startSession();
            session.allocateDefaultPTY();
            return session;
         }

         @Override
         public String toString() {
            return "Session()";
         }
      };
   }

   /** Runs {@code command} in a fresh session and collects stdout, stderr and the exit status. */
   class ExecConnection implements Connection<ExecResponse> {
      private final String command;
      private Session session;

      ExecConnection(String command) {
         this.command = checkNotNull(command, "command");
      }

      @Override
      public void clear() throws TransportException, ConnectionException {
         if (session != null) {
            try {
               session.close();
            } finally {
               // create() already clears in its finally block; forgetting the session avoids a second close
               session = null;
            }
         }
      }

      @Override
      public ExecResponse create() throws Exception {
         try {
            session = acquire(execConnection());
            Command output = session.exec(checkNotNull(command, "command"));
            // NOTE(review): stdout is drained completely before stderr is read; confirm a command producing a lot
            // of stderr cannot stall here
            String outputString = IOUtils.readFully(output.getInputStream()).toString();
            // timeoutMillis holds milliseconds; waiting in SECONDS would stretch the wait a thousand-fold
            output.join(timeoutMillis, TimeUnit.MILLISECONDS);
            int errorStatus = output.getExitStatus();
            String errorString = IOUtils.readFully(output.getErrorStream()).toString();
            return new ExecResponse(outputString, errorString, errorStatus);
         } finally {
            clear();
         }
      }

      @Override
      public String toString() {
         return "ExecResponse(command=[" + command + "])";
      }
   }

   /** Executes {@code command} remotely, retrying as described on {@link #acquire(Connection)}. */
   public ExecResponse exec(String command) {
      return acquire(new ExecConnection(command));
   }

   @Override
   public String getHostAddress() {
      return this.host;
   }

   @Override
   public String getUsername() {
      return this.username;
   }
}