// blob: f60aca3a9aef9a971e901dca0ee13910a52fb445 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd.client.subsystem.sftp;
import java.io.ByteArrayOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.io.OutputStream;
import java.io.StreamCorruptedException;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.sshd.client.channel.ChannelSubsystem;
import org.apache.sshd.client.channel.ClientChannel;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.common.PropertyResolverUtils;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.subsystem.sftp.SftpConstants;
import org.apache.sshd.common.subsystem.sftp.extensions.ParserUtils;
import org.apache.sshd.common.subsystem.sftp.extensions.VersionsParser.Versions;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
/**
 * Default SFTP client that runs the SFTP protocol over an SSH subsystem channel
 * named {@link SftpConstants#SFTP_SUBSYSTEM_NAME}.
 *
 * <p>Requests are written to the channel's inverted input stream. Incoming bytes are
 * delivered through the {@link OutputStream} installed on the channel in the
 * constructor. They are split into complete SFTP packets, which are stored by
 * request id until a caller claims them through {@link #receive(int)}.</p>
 *
 * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
 */
public class DefaultSftpClient extends AbstractSftpClient {
    private final ClientSession clientSession;
    private final ChannelSubsystem channel;
    /** Received but not yet claimed response packets keyed by request id - accessed while synchronized on the map itself. */
    private final Map<Integer, Buffer> messages = new HashMap<>();
    private final AtomicInteger cmdId = new AtomicInteger(100);
    /** Carries an incomplete trailing packet over to the next {@link #data(byte[], int, int)} call. */
    private final Buffer receiveBuffer = new ByteArrayBuffer();
    /** Scratch bytes for writing 4-byte integers; its monitor also serializes outgoing packet writes. */
    private final byte[] workBuf = new byte[Integer.SIZE / Byte.SIZE]; // TODO in JDK-8 use Integer.BYTES
    private final AtomicInteger versionHolder = new AtomicInteger(0);
    private final AtomicBoolean closing = new AtomicBoolean(false);
    private final Map<String, byte[]> extensions = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);
    private final Map<String, byte[]> exposedExtensions = Collections.unmodifiableMap(extensions);

    /**
     * Creates the SFTP subsystem channel on the given session, opens it and performs
     * the initial SSH_FXP_INIT / SSH_FXP_VERSION exchange.
     *
     * @param clientSession The {@link ClientSession} to use - may not be {@code null}
     * @throws IOException If opening the channel or the initialization exchange fails.
     * The channel is closed immediately in that case.
     */
    public DefaultSftpClient(ClientSession clientSession) throws IOException {
        this.clientSession = ValidateUtils.checkNotNull(clientSession, "No client session");
        this.channel = clientSession.createSubsystemChannel(SftpConstants.SFTP_SUBSYSTEM_NAME);
        this.channel.setOut(new OutputStream() {
            private final byte[] singleByte = new byte[1];

            @Override
            public void write(int b) throws IOException {
                synchronized (singleByte) {
                    singleByte[0] = (byte) b;
                    write(singleByte);
                }
            }

            @Override
            public void write(byte[] b, int off, int len) throws IOException {
                // every chunk the channel delivers is fed to the packet splitter
                data(b, off, len);
            }
        });
        this.channel.setErr(new ByteArrayOutputStream(Byte.MAX_VALUE));

        long initializationTimeout = PropertyResolverUtils.getLongProperty(clientSession, SFTP_CHANNEL_OPEN_TIMEOUT, DEFAULT_CHANNEL_OPEN_TIMEOUT);
        try {
            // the open is inside the try so that a failed/timed-out open does not leak the channel
            this.channel.open().verify(initializationTimeout);
            this.channel.onClose(new Runnable() {
                @SuppressWarnings("synthetic-access")
                @Override
                public void run() {
                    // wake up any thread blocked in receive()/init() so it can notice the close
                    synchronized (messages) {
                        closing.set(true);
                        messages.notifyAll();
                    }
                    if (versionHolder.get() <= 0) {
                        log.warn("onClose({}) closed before version negotiated", channel);
                    }
                }
            });
            init(initializationTimeout);
        } catch (IOException | RuntimeException e) {
            this.channel.close(true);
            throw e;
        }
    }

    @Override
    public int getVersion() {
        return versionHolder.get();
    }

    @Override
    public ClientSession getClientSession() {
        return clientSession;
    }

    @Override
    public ClientChannel getClientChannel() {
        return channel;
    }

    @Override
    public Map<String, byte[]> getServerExtensions() {
        return exposedExtensions;
    }

    @Override
    public boolean isClosing() {
        return closing.get();
    }

    @Override
    public boolean isOpen() {
        return this.channel.isOpen();
    }

    @Override
    public void close() throws IOException {
        if (isOpen()) {
            this.channel.close(false);
        }
    }

    /**
     * Receive binary data from the channel and dispatch every complete SFTP packet.
     * Any incomplete trailing packet is kept for the next invocation.
     *
     * @param buf The buffer for the incoming data
     * @param start Offset in buffer to place the data
     * @param len Available space in buffer for the data
     * @return Number of buffered bytes consumed by the complete packets processed during this call
     * @throws IOException If failed to receive incoming data
     */
    protected int data(byte[] buf, int start, int len) throws IOException {
        Buffer incoming = new ByteArrayBuffer(buf, start, len);
        // If we already have partial data, we need to append it to the buffer and use it
        if (receiveBuffer.available() > 0) {
            receiveBuffer.putBuffer(incoming);
            incoming = receiveBuffer;
        }

        // Process commands
        int rpos = incoming.rpos();
        for (int count = 1; receive(incoming); count++) {
            if (log.isTraceEnabled()) {
                log.trace("data({}) Processed {} data messages", getClientChannel(), count);
            }
        }

        int read = incoming.rpos() - rpos;
        // Compact and add remaining data
        receiveBuffer.compact();
        if (receiveBuffer != incoming && incoming.available() > 0) {
            receiveBuffer.putBuffer(incoming);
        }
        return read;
    }

    /**
     * Read SFTP packets from buffer
     *
     * @param incoming The received {@link Buffer}
     * @return {@code true} if a complete packet was available and processed. In that case
     * the read position has been moved past it; otherwise the read position is unchanged.
     * @throws IOException if failed to process the buffer or the packet length is illegal
     * @see #process(Buffer)
     */
    protected boolean receive(Buffer incoming) throws IOException {
        int rpos = incoming.rpos();
        int wpos = incoming.wpos();
        ClientSession session = getClientSession();
        session.resetIdleTimeout();

        int avail = wpos - rpos;
        if (avail > 4) {
            int length = incoming.getInt();
            if (length < 5) {
                throw new IOException("Illegal sftp packet length: " + length);
            }
            // Compare against the payload bytes actually present instead of (length + 4):
            // the latter overflows for a (hostile) length close to Integer.MAX_VALUE
            if ((avail - 4) >= length) {
                int packetEnd = rpos + 4 + length; // cannot overflow - bounded by wpos
                incoming.rpos(rpos);
                incoming.wpos(packetEnd);
                process(incoming);
                incoming.rpos(packetEnd);
                incoming.wpos(wpos);
                return true;
            }
        }
        incoming.rpos(rpos);
        return false;
    }

    /**
     * Process an SFTP packet: copy it, peek at its header without consuming it, and
     * store the copy under its request id for {@link #receive(int)} or {@link #init(long)}.
     *
     * @param incoming The received {@link Buffer} positioned at the packet length field
     * @throws IOException if failed to process the buffer
     */
    protected void process(Buffer incoming) throws IOException {
        // create a copy of the buffer in case it is being re-used
        Buffer buffer = new ByteArrayBuffer(incoming.available() + Long.SIZE, false);
        buffer.putBuffer(incoming);

        int rpos = buffer.rpos();
        int length = buffer.getInt();
        int type = buffer.getUByte();
        Integer id = buffer.getInt();
        buffer.rpos(rpos); // leave the header in place for the consumer

        if (log.isTraceEnabled()) {
            log.trace("process({}) id={}, type={}, len={}",
                      getClientChannel(), id, SftpConstants.getCommandMessageName(type), length);
        }

        synchronized (messages) {
            messages.put(id, buffer);
            messages.notifyAll();
        }
    }

    @Override
    public int send(int cmd, Buffer buffer) throws IOException {
        int id = cmdId.incrementAndGet();
        int len = buffer.available();
        if (log.isTraceEnabled()) {
            log.trace("send({}) cmd={}, len={}, id={}",
                      getClientChannel(), SftpConstants.getCommandMessageName(cmd), len, id);
        }

        OutputStream dos = channel.getInvertedIn();
        // workBuf is shared scratch space and a packet is emitted through several writes,
        // so concurrent senders are serialized to keep both intact
        synchronized (workBuf) {
            BufferUtils.writeInt(dos, 1 /* cmd */ + (Integer.SIZE / Byte.SIZE) /* id */ + len, workBuf);
            dos.write(cmd & 0xFF);
            BufferUtils.writeInt(dos, id, workBuf);
            dos.write(buffer.array(), buffer.rpos(), len);
            dos.flush();
        }
        return id;
    }

    @Override
    public Buffer receive(int id) throws IOException {
        Integer reqId = id;
        synchronized (messages) {
            for (int count = 1;; count++) {
                if (isClosing() || (!isOpen())) {
                    throw new SshException("Channel is being closed");
                }

                Buffer buffer = messages.remove(reqId);
                if (buffer != null) {
                    return buffer;
                }

                try {
                    messages.wait();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt(); // preserve the interrupt for the caller
                    throw (IOException) new InterruptedIOException("Interrupted while waiting for messages at iteration #" + count).initCause(e);
                }
            }
        }
    }

    /**
     * Synchronously reads one length-prefixed SFTP packet from the channel's inverted output stream.
     *
     * @return A {@link Buffer} holding the packet, including its 4-byte length prefix
     * @throws IOException If failed to read from the stream
     * @throws IllegalArgumentException If the length is too small or the stream ends prematurely
     */
    protected Buffer read() throws IOException {
        InputStream dis = channel.getInvertedOut();
        // local scratch so reading does not race with senders using workBuf
        byte[] lenBuf = new byte[Integer.SIZE / Byte.SIZE];
        int length = BufferUtils.readInt(dis, lenBuf);
        // must have at least command + length
        // TODO in jdk-8 use Integer.BYTES
        if (length < (1 + (Integer.SIZE / Byte.SIZE))) {
            throw new IllegalArgumentException("Bad length: " + length);
        }

        // TODO in jdk-8 use Integer.BYTES
        Buffer buffer = new ByteArrayBuffer(length + (Integer.SIZE / Byte.SIZE), false);
        buffer.putInt(length);
        int nb = length;
        while (nb > 0) {
            int readLen = dis.read(buffer.array(), buffer.wpos(), nb);
            if (readLen < 0) {
                throw new IllegalArgumentException("Premature EOF while read " + length + " bytes - remaining=" + nb);
            }
            buffer.wpos(buffer.wpos() + readLen);
            nb -= readLen;
        }

        return buffer;
    }

    /**
     * Sends SSH_FXP_INIT and waits up to the given timeout for the server's reply.
     * On SSH_FXP_VERSION, records the version and the server extensions.
     *
     * @param initializationTimeout Maximum wait time in milliseconds - must be positive
     * @throws IOException If the channel closes, the wait times out or is interrupted,
     * the server reports an error status, or the reply is unexpected or unsupported
     */
    protected void init(long initializationTimeout) throws IOException {
        ValidateUtils.checkTrue(initializationTimeout > 0L, "Invalid initialization timeout: %d", initializationTimeout);

        // Send init packet
        OutputStream dos = channel.getInvertedIn();
        synchronized (workBuf) {
            BufferUtils.writeInt(dos, 5 /* total length */, workBuf);
            dos.write(SftpConstants.SSH_FXP_INIT);
            BufferUtils.writeInt(dos, SftpConstants.SFTP_V6, workBuf);
            dos.flush();
        }

        Buffer buffer;
        Integer reqId;
        synchronized (messages) {
            /*
             * We need to use a timeout since if the remote server does not support
             * SFTP, we will not know it immediately. This is due to the fact that the
             * request for the subsystem does not contain a reply as to its success or
             * failure. Thus, the SFTP channel is created by the client, but there is
             * no one on the other side to reply - thus the need for the timeout
             */
            for (long remainingTimeout = initializationTimeout; (remainingTimeout > 0L) && messages.isEmpty() && (!isClosing()) && isOpen();) {
                try {
                    long sleepStart = System.nanoTime();
                    messages.wait(remainingTimeout);
                    long sleepEnd = System.nanoTime();
                    long sleepDuration = sleepEnd - sleepStart;
                    long sleepMillis = TimeUnit.NANOSECONDS.toMillis(sleepDuration);
                    // always make progress, even on sub-millisecond (e.g. spurious) wake-ups
                    if (sleepMillis < 1L) {
                        remainingTimeout--;
                    } else {
                        remainingTimeout -= sleepMillis;
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt(); // preserve the interrupt for the caller
                    throw (IOException) new InterruptedIOException("Interrupted init()").initCause(e);
                }
            }

            if (isClosing() || (!isOpen())) {
                throw new EOFException("Closing while await init message");
            }
            if (messages.isEmpty()) {
                throw new SocketTimeoutException("No incoming initialization response received within " + initializationTimeout + " msec.");
            }

            // only a single (version) reply is expected at this stage
            Collection<Integer> ids = messages.keySet();
            Iterator<Integer> iter = ids.iterator();
            reqId = iter.next();
            buffer = messages.remove(reqId);
        }

        int length = buffer.getInt();
        int type = buffer.getUByte();
        int id = buffer.getInt(); // for SSH_FXP_VERSION this field carries the server version
        if (log.isTraceEnabled()) {
            log.trace("init({}) id={} type={} len={}",
                      getClientChannel(), id, SftpConstants.getCommandMessageName(type), length);
        }

        if (type == SftpConstants.SSH_FXP_VERSION) {
            if (id < SftpConstants.SFTP_V3) {
                throw new SshException("Unsupported sftp version " + id);
            }
            versionHolder.set(id);

            if (log.isTraceEnabled()) {
                log.trace("init({}) version={}", getClientChannel(), versionHolder);
            }

            // the rest of the packet is a sequence of (name, data) extension pairs
            while (buffer.available() > 0) {
                String name = buffer.getString();
                byte[] data = buffer.getBytes();
                if (log.isTraceEnabled()) {
                    log.trace("init({}) added extension={}", getClientChannel(), name);
                }
                extensions.put(name, data);
            }
        } else if (type == SftpConstants.SSH_FXP_STATUS) {
            int substatus = buffer.getInt();
            String msg = buffer.getString();
            String lang = buffer.getString();
            if (log.isTraceEnabled()) {
                log.trace("init({})[id={}] - status: {} [{}] {}",
                          getClientChannel(), id, SftpConstants.getStatusName(substatus), lang, msg);
            }

            throwStatusException(SftpConstants.SSH_FXP_INIT, id, substatus, msg, lang);
        } else {
            handleUnexpectedPacket(SftpConstants.SSH_FXP_INIT, SftpConstants.SSH_FXP_VERSION, id, type, length, buffer);
        }
    }

    /**
     * Switches to a different SFTP version if the server advertises the
     * {@link SftpConstants#EXT_VERSION_SELECT} extension and the selector picks one.
     *
     * @param selector The {@link SftpVersionSelector} to use - ignored if {@code null}
     * @return The selected version (may be same as current)
     * @throws IOException If failed to negotiate
     */
    public int negotiateVersion(SftpVersionSelector selector) throws IOException {
        int current = getVersion();
        if (selector == null) {
            return current;
        }

        Set<Integer> available = GenericUtils.asSortedSet(Collections.singleton(current));
        Map<String, ?> parsed = getParsedServerExtensions();
        // named 'supported' so it does not shadow the 'extensions' field
        Collection<String> supported = ParserUtils.supportedExtensions(parsed);
        if ((GenericUtils.size(supported) > 0) && supported.contains(SftpConstants.EXT_VERSION_SELECT)) {
            Versions vers = GenericUtils.isEmpty(parsed) ? null : (Versions) parsed.get(SftpConstants.EXT_VERSIONS);
            Collection<String> reported = (vers == null) ? null : vers.getVersions();
            if (GenericUtils.size(reported) > 0) {
                for (String v : reported) {
                    available.add(Integer.valueOf(v)); // duplicates are simply ignored by the set
                }
            }
        }

        int selected = selector.selectVersion(current, new ArrayList<>(available));
        if (log.isDebugEnabled()) {
            log.debug("negotiateVersion({}) current={} {} -> {}", getClientChannel(), current, available, selected);
        }

        if (selected == current) {
            return current;
        }

        if (!available.contains(Integer.valueOf(selected))) {
            throw new StreamCorruptedException("Selected version (" + selected + ") not part of available: " + available);
        }

        String verVal = String.valueOf(selected);
        Buffer buffer = new ByteArrayBuffer((Integer.SIZE / Byte.SIZE) + SftpConstants.EXT_VERSION_SELECT.length() // extension name
                + (Integer.SIZE / Byte.SIZE) + verVal.length() + Byte.SIZE, false);
        buffer.putString(SftpConstants.EXT_VERSION_SELECT);
        buffer.putString(verVal);
        checkCommandStatus(SftpConstants.SSH_FXP_EXTENDED, buffer);
        versionHolder.set(selected);
        return selected;
    }
}