[SSHD-639] Review code and reduce buffer re-use
diff --git a/sshd-core/src/main/java/org/apache/sshd/agent/common/AbstractAgentClient.java b/sshd-core/src/main/java/org/apache/sshd/agent/common/AbstractAgentClient.java
index 770f967..e28e947 100644
--- a/sshd-core/src/main/java/org/apache/sshd/agent/common/AbstractAgentClient.java
+++ b/sshd-core/src/main/java/org/apache/sshd/agent/common/AbstractAgentClient.java
@@ -66,8 +66,7 @@
return;
}
- // we can re-use the incoming message buffer since its data has been copied to the request buffer
- Buffer rep = BufferUtils.clear(message);
+ Buffer rep = new ByteArrayBuffer();
rep.putInt(0);
rep.rpos(rep.wpos());
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/auth/keyboard/UserAuthKeyboardInteractive.java b/sshd-core/src/main/java/org/apache/sshd/client/auth/keyboard/UserAuthKeyboardInteractive.java
index 608f451..73c545c 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/auth/keyboard/UserAuthKeyboardInteractive.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/auth/keyboard/UserAuthKeyboardInteractive.java
@@ -32,7 +32,6 @@
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
/**
* Manages a "keyboard-interactive" exchange according to
@@ -188,7 +187,7 @@
session, service, num, rep.length);
}
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_USERAUTH_INFO_RESPONSE, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_INFO_RESPONSE, rep.length * Long.SIZE + Byte.SIZE);
buffer.putInt(rep.length);
for (int index = 0; index < rep.length; index++) {
String r = rep[index];
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/auth/password/UserAuthPassword.java b/sshd-core/src/main/java/org/apache/sshd/client/auth/password/UserAuthPassword.java
index 3db61be..5107465 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/auth/password/UserAuthPassword.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/auth/password/UserAuthPassword.java
@@ -28,7 +28,6 @@
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
/**
* Implements the "password" authentication mechanism
@@ -122,7 +121,10 @@
session, service, name, modified);
}
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_USERAUTH_REQUEST, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_REQUEST,
+ GenericUtils.length(username) + GenericUtils.length(service)
+ + GenericUtils.length(name) + GenericUtils.length(oldPassword)
+ + (modified ? GenericUtils.length(newPassword) : 0) + Long.SIZE);
buffer.putString(username);
buffer.putString(service);
buffer.putString(name);
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/auth/pubkey/UserAuthPublicKey.java b/sshd-core/src/main/java/org/apache/sshd/client/auth/pubkey/UserAuthPublicKey.java
index 95bec60..8b283e2 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/auth/pubkey/UserAuthPublicKey.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/auth/pubkey/UserAuthPublicKey.java
@@ -31,8 +31,8 @@
import org.apache.sshd.common.config.keys.KeyUtils;
import org.apache.sshd.common.signature.Signature;
import org.apache.sshd.common.signature.SignatureFactoriesManager;
+import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
/**
@@ -122,7 +122,10 @@
}
String username = session.getUsername();
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_USERAUTH_REQUEST, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_REQUEST,
+ GenericUtils.length(username) + GenericUtils.length(service)
+ + GenericUtils.length(name) + GenericUtils.length(algo)
+ + ByteArrayBuffer.DEFAULT_SIZE + Long.SIZE);
buffer.putString(username);
buffer.putString(service);
buffer.putString(name);
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGEXClient.java b/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGEXClient.java
index 4f7748d..644ab48 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGEXClient.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGEXClient.java
@@ -35,7 +35,6 @@
import org.apache.sshd.common.util.SecurityUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
@@ -125,7 +124,7 @@
if (log.isDebugEnabled()) {
log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_INIT", this, session);
}
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_INIT, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_INIT, e.length + Byte.SIZE);
buffer.putMPInt(e);
session.writePacket(buffer);
expected = SshConstants.SSH_MSG_KEX_DH_GEX_REPLY;
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/SshConstants.java b/sshd-core/src/main/java/org/apache/sshd/common/SshConstants.java
index 3c7b336..af46ab7 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/SshConstants.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/SshConstants.java
@@ -121,6 +121,7 @@
// Some more constants
public static final int SSH_EXTENDED_DATA_STDERR = 1; // see RFC4254 section 5.2
+ public static final int SSH_PACKET_HEADER_LEN = 5; // 32-bit length + 8-bit pad length
private SshConstants() {
throw new UnsupportedOperationException("No instance allowed");
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java
index 51eae39..141707a 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/AbstractChannel.java
@@ -303,14 +303,10 @@
byte cmd = RequestHandler.Result.ReplySuccess.equals(result)
? SshConstants.SSH_MSG_CHANNEL_SUCCESS
: SshConstants.SSH_MSG_CHANNEL_FAILURE;
- buffer.clear();
- // leave room for the SSH header
- buffer.ensureCapacity(5 + 1 + (Integer.SIZE / Byte.SIZE), RESPONSE_BUFFER_GROWTH_FACTOR);
- buffer.rpos(5);
- buffer.wpos(5);
- buffer.putByte(cmd);
- buffer.putInt(recipient);
- session.writePacket(buffer);
+ Session session = getSession();
+ Buffer rsp = session.createBuffer(cmd, Integer.SIZE / Byte.SIZE);
+ rsp.putInt(recipient);
+ session.writePacket(rsp);
}
@Override
@@ -543,9 +539,9 @@
log.debug("handleExtendedData({}) send SSH_MSG_CHANNEL_FAILURE - non STDERR type: {}", this, ex);
}
Session s = getSession();
- buffer = s.prepareBuffer(SshConstants.SSH_MSG_CHANNEL_FAILURE, BufferUtils.clear(buffer));
- buffer.putInt(getRecipient());
- writePacket(buffer);
+ Buffer rsp = s.createBuffer(SshConstants.SSH_MSG_CHANNEL_FAILURE, Integer.SIZE / Byte.SIZE);
+ rsp.putInt(getRecipient());
+ writePacket(rsp);
return;
}
int len = buffer.getInt();
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractConnectionService.java b/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractConnectionService.java
index 34f726a..707e38b 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractConnectionService.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractConnectionService.java
@@ -45,7 +45,6 @@
import org.apache.sshd.common.util.Int2IntFunction;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.closeable.AbstractInnerCloseable;
import org.apache.sshd.server.channel.OpenChannelException;
import org.apache.sshd.server.x11.X11ForwardSupport;
@@ -459,8 +458,9 @@
this, sender, SshConstants.getOpenErrorCodeName(reasonCode), lang, message);
}
- final Session session = getSession();
- Buffer buf = session.prepareBuffer(SshConstants.SSH_MSG_CHANNEL_OPEN_FAILURE, BufferUtils.clear(buffer));
+ Session session = getSession();
+ Buffer buf = session.createBuffer(SshConstants.SSH_MSG_CHANNEL_OPEN_FAILURE,
+ Long.SIZE + GenericUtils.length(message) + GenericUtils.length(lang));
buf.putInt(sender);
buf.putInt(reasonCode);
buf.putString(message);
@@ -534,9 +534,8 @@
? SshConstants.SSH_MSG_REQUEST_SUCCESS
: SshConstants.SSH_MSG_REQUEST_FAILURE;
Session session = getSession();
- buffer = session.prepareBuffer(cmd, BufferUtils.clear(buffer));
- buffer.putByte(cmd);
- session.writePacket(buffer);
+ Buffer rsp = session.createBuffer(cmd, 2);
+ session.writePacket(rsp);
}
protected void requestSuccess(Buffer buffer) throws Exception {
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractSession.java
index e5eaa54..d2516df 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/AbstractSession.java
@@ -46,6 +46,7 @@
import org.apache.sshd.common.NamedResource;
import org.apache.sshd.common.PropertyResolver;
import org.apache.sshd.common.PropertyResolverUtils;
+import org.apache.sshd.common.RuntimeSshException;
import org.apache.sshd.common.Service;
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.SshException;
@@ -165,8 +166,8 @@
protected Compression inCompression;
protected long seqi;
protected long seqo;
- protected Buffer decoderBuffer = new ByteArrayBuffer();
- protected Buffer uncompressBuffer;
+ protected SessionWorkBuffer uncompressBuffer;
+ protected final SessionWorkBuffer decoderBuffer;
protected int decoderState;
protected int decoderLength;
protected final Object encodeLock = new Object();
@@ -224,6 +225,7 @@
this.isServer = isServer;
this.factoryManager = factoryManager;
this.ioSession = ioSession;
+ this.decoderBuffer = new SessionWorkBuffer(this);
Factory<Random> factory = ValidateUtils.checkNotNull(factoryManager.getRandomFactory(), "No random factory for %s", ioSession);
random = ValidateUtils.checkNotNull(factory.create(), "No randomizer instance for %s", ioSession);
@@ -277,7 +279,7 @@
* @param session the session to attach
*/
public static void attachSession(IoSession ioSession, AbstractSession session) {
- ioSession.setAttribute(SESSION, session);
+ ValidateUtils.checkNotNull(ioSession, "No I/O session").setAttribute(SESSION, ValidateUtils.checkNotNull(session, "No SSH session"));
}
@Override
@@ -412,11 +414,9 @@
* Abstract method for processing incoming decoded packets.
* The given buffer will hold the decoded packet, starting from
* the command byte at the read position.
- * Packets must be processed within this call or be copied because
- * the given buffer is meant to be changed and updated when this
- * method returns.
*
- * @param buffer the buffer containing the packet
+ * @param buffer The {@link Buffer} containing the packet - it may be
+ * re-used to generate the response once request has been decoded
* @throws Exception if an exception occurs while handling this packet.
* @see #doHandleMessage(Buffer)
*/
@@ -425,7 +425,7 @@
synchronized (lock) {
doHandleMessage(buffer);
}
- } catch (Exception e) {
+ } catch (Throwable e) {
DefaultKeyExchangeFuture kexFuture = kexFutureHolder.get();
// if have any ongoing KEX notify it about the failure
if (kexFuture != null) {
@@ -437,7 +437,11 @@
}
}
- throw e;
+ if (e instanceof Exception) {
+ throw (Exception) e;
+ } else {
+ throw new RuntimeSshException(e);
+ }
}
}
@@ -622,7 +626,7 @@
log.debug("handleServiceRequest({}) Accepted service {}", this, serviceName);
}
- Buffer response = prepareBuffer(SshConstants.SSH_MSG_SERVICE_ACCEPT, BufferUtils.clear(buffer));
+ Buffer response = createBuffer(SshConstants.SSH_MSG_SERVICE_ACCEPT, Byte.SIZE + GenericUtils.length(serviceName));
response.putString(serviceName);
writePacket(response);
}
@@ -782,12 +786,21 @@
SessionListener listener = getSessionListenerProxy();
try {
listener.sessionClosed(this);
- } catch (RuntimeException t) {
+ } catch (Throwable t) {
Throwable e = GenericUtils.peelException(t);
log.warn("preClose({}) {} while signal session closed: {}", this, e.getClass().getSimpleName(), e.getMessage());
if (log.isDebugEnabled()) {
log.debug("preClose(" + this + ") signal session closed exception details", e);
}
+
+ if (log.isTraceEnabled()) {
+ Throwable[] suppressed = e.getSuppressed();
+ if (GenericUtils.length(suppressed) > 0) {
+ for (Throwable s : suppressed) {
+ log.trace("preClose(" + this + ") suppressed session closed signalling", s);
+ }
+ }
+ }
} finally {
// clear the listeners since we are closing the session (quicker GC)
this.sessionListeners.clear();
@@ -893,6 +906,11 @@
}
}
+ int curPos = buffer.rpos();
+ byte[] data = buffer.array();
+ int cmd = data[curPos] & 0xFF; // usually the 1st byte is the command
+ buffer = validateTargetBuffer(cmd, buffer);
+
// Synchronize all write requests as needed by the encoding algorithm
// and also queue the write request in this synchronized block to ensure
// packets are sent in the correct order
@@ -1007,7 +1025,7 @@
// they actually send exactly this amount.
//
int bsize = outCipherSize;
- len += 5;
+ len += SshConstants.SSH_PACKET_HEADER_LEN;
int pad = (-len) & (bsize - 1);
if (pad < bsize) {
pad += bsize;
@@ -1022,14 +1040,30 @@
@Override
public Buffer prepareBuffer(byte cmd, Buffer buffer) {
- ValidateUtils.checkNotNull(buffer, "No buffer to prepare");
- buffer.rpos(5);
- buffer.wpos(5);
+ buffer = validateTargetBuffer(cmd & 0xFF, buffer);
+ buffer.rpos(SshConstants.SSH_PACKET_HEADER_LEN);
+ buffer.wpos(SshConstants.SSH_PACKET_HEADER_LEN);
buffer.putByte(cmd);
return buffer;
}
/**
+ * Makes sure that the buffer used for output is not {@code null} or one
+ * of the session's internal ones used for decoding and uncompressing
+ *
+ * @param cmd The most likely command this buffer refers to (not guaranteed to be correct)
+ * @param buffer The buffer to be examined
+ * @return The validated target instance - default same as input
+ * @throws IllegalArgumentException if any of the conditions is violated
+ */
+ protected <B extends Buffer> B validateTargetBuffer(int cmd, B buffer) {
+ ValidateUtils.checkNotNull(buffer, "No target buffer to examine for command=%d", cmd);
+ ValidateUtils.checkTrue(buffer != decoderBuffer, "Not allowed to use the internal decoder buffer for command=%d", cmd);
+ ValidateUtils.checkTrue(buffer != uncompressBuffer, "Not allowed to use the internal uncompress buffer for command=%d", cmd);
+ return buffer;
+ }
+
+ /**
* Encode a buffer into the SSH protocol.
* This method need to be called into a synchronized block around encodeLock
*
@@ -1039,21 +1073,27 @@
protected void encode(Buffer buffer) throws IOException {
try {
// Check that the packet has some free space for the header
- if (buffer.rpos() < 5) {
- log.warn("Performance cost: when sending a packet, ensure that "
- + "5 bytes are available in front of the buffer");
+ int curPos = buffer.rpos();
+ if (curPos < SshConstants.SSH_PACKET_HEADER_LEN) {
+ byte[] data = buffer.array();
+ int cmd = data[curPos] & 0xFF; // usually the 1st byte is an SSH opcode
+ log.warn("encode({}) command={} performance cost: available buffer packet header length ({}) below min. required ({})",
+ this, SshConstants.getCommandMessageName(cmd), curPos, SshConstants.SSH_PACKET_HEADER_LEN);
Buffer nb = new ByteArrayBuffer(buffer.available() + Long.SIZE, false);
- nb.wpos(5);
+ nb.wpos(SshConstants.SSH_PACKET_HEADER_LEN);
nb.putBuffer(buffer);
buffer = nb;
+ curPos = buffer.rpos();
}
+
// Grab the length of the packet (excluding the 5 header bytes)
int len = buffer.available();
- int off = buffer.rpos() - 5;
+ int off = curPos - SshConstants.SSH_PACKET_HEADER_LEN;
// Debug log the packet
if (log.isTraceEnabled()) {
- log.trace("encode({}) Sending packet #{}: {}", this, Long.valueOf(seqo), buffer.printHex());
+ log.trace("encode({}) packet #{}: {}", this, seqo, buffer.printHex());
}
+
// Compress the packet if needed
if ((outCompression != null) && outCompression.isCompressionExecuted() && (authed || (!outCompression.isDelayed()))) {
outCompression.compress(buffer);
@@ -1063,7 +1103,7 @@
// Compute padding length
int bsize = outCipherSize;
int oldLen = len;
- len += 5;
+ len += SshConstants.SSH_PACKET_HEADER_LEN;
int pad = (-len) & (bsize - 1);
if (pad < bsize) {
pad += bsize;
@@ -1074,7 +1114,7 @@
buffer.putInt(len);
buffer.putByte((byte) pad);
// Fill padding
- buffer.wpos(off + oldLen + 5 + pad);
+ buffer.wpos(off + oldLen + SshConstants.SSH_PACKET_HEADER_LEN + pad);
random.fill(buffer.array(), buffer.wpos() - pad, pad);
// Compute mac
if (outMac != null) {
@@ -1130,8 +1170,8 @@
// Read packet length
decoderLength = decoderBuffer.getInt();
// Check packet length validity
- if ((decoderLength < 5) || (decoderLength > (256 * 1024))) {
- log.warn("decode({}) Error decoding packet(invalid length) {}", this, decoderBuffer.printHex());
+ if ((decoderLength < SshConstants.SSH_PACKET_HEADER_LEN) || (decoderLength > (256 * 1024))) {
+ log.warn("decode({}) Error decoding packet(invalid length): {}", this, decoderLength);
throw new SshException(SshConstants.SSH2_DISCONNECT_PROTOCOL_ERROR,
"Invalid packet length: " + decoderLength);
}
@@ -1174,33 +1214,33 @@
seqi = (seqi + 1) & 0xffffffffL;
// Get padding
int pad = decoderBuffer.getUByte();
- Buffer buf;
+ Buffer packet;
int wpos = decoderBuffer.wpos();
// Decompress if needed
if ((inCompression != null) && inCompression.isCompressionExecuted() && (authed || (!inCompression.isDelayed()))) {
if (uncompressBuffer == null) {
- uncompressBuffer = new ByteArrayBuffer();
+ uncompressBuffer = new SessionWorkBuffer(this);
} else {
- uncompressBuffer.clear();
+ uncompressBuffer.forceClear();
}
decoderBuffer.wpos(decoderBuffer.rpos() + decoderLength - 1 - pad);
inCompression.uncompress(decoderBuffer, uncompressBuffer);
- buf = uncompressBuffer;
+ packet = uncompressBuffer;
} else {
decoderBuffer.wpos(decoderLength + 4 - pad);
- buf = decoderBuffer;
+ packet = decoderBuffer;
}
if (log.isTraceEnabled()) {
- log.trace("decode({}) Received packet #{}: {}", this, seqi, buf.printHex());
+ log.trace("decode({}) packet #{}: {}", this, seqi, packet.printHex());
}
// Update stats
inPacketsCount.incrementAndGet();
- inBytesCount.addAndGet(buf.available());
+ inBytesCount.addAndGet(packet.available());
// Process decoded packet
- handleMessage(buf);
+ handleMessage(packet);
// Set ready to handle next packet
decoderBuffer.rpos(decoderLength + 4 + macSize);
decoderBuffer.wpos(wpos);
@@ -1714,8 +1754,10 @@
}
protected void requestSuccess(Buffer buffer) throws Exception {
+ // use a copy of the original data in case it is re-used on return
+ Buffer resultBuf = ByteArrayBuffer.getCompactClone(buffer.array(), buffer.rpos(), buffer.available());
synchronized (requestResult) {
- requestResult.set(new ByteArrayBuffer(buffer.getCompactData()));
+ requestResult.set(resultBuf);
resetIdleTimeout();
requestResult.notify();
}
@@ -1830,7 +1872,18 @@
protected void sendSessionEvent(SessionListener.Event event) throws IOException {
SessionListener listener = getSessionListenerProxy();
- listener.sessionEvent(this, event);
+ try {
+ listener.sessionEvent(this, event);
+ } catch (Throwable e) {
+ Throwable t = GenericUtils.peelException(e);
+ if (t instanceof IOException) {
+ throw (IOException) t;
+ } else if (t instanceof RuntimeException) {
+ throw (RuntimeException) t;
+ } else {
+ throw new IOException("Failed (" + t.getClass().getSimpleName() + ") to send session event: " + t.getMessage(), t);
+ }
+ }
}
@Override
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/SessionWorkBuffer.java b/sshd-core/src/main/java/org/apache/sshd/common/session/SessionWorkBuffer.java
new file mode 100644
index 0000000..1d4ce28
--- /dev/null
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/SessionWorkBuffer.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sshd.common.session;
+
+import org.apache.sshd.common.util.ValidateUtils;
+import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
+
+/**
+ * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
+ */
+public class SessionWorkBuffer extends ByteArrayBuffer implements SessionHolder<Session> {
+ private final Session session;
+
+ public SessionWorkBuffer(Session session) {
+ this.session = ValidateUtils.checkNotNull(session, "No session");
+ }
+
+ @Override
+ public Session getSession() {
+ return session;
+ }
+
+ @Override
+ public void clear() {
+ throw new UnsupportedOperationException("Not allowed to clear session work buffer of " + getSession());
+ }
+
+ public void forceClear() {
+ super.clear();
+ }
+}
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/util/buffer/ByteArrayBuffer.java b/sshd-core/src/main/java/org/apache/sshd/common/util/buffer/ByteArrayBuffer.java
index 650af18..8ffa954 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/util/buffer/ByteArrayBuffer.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/util/buffer/ByteArrayBuffer.java
@@ -21,6 +21,7 @@
import java.nio.charset.Charset;
+import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.Int2IntFunction;
import org.apache.sshd.common.util.Readable;
import org.apache.sshd.common.util.ValidateUtils;
@@ -30,7 +31,7 @@
*
* @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
*/
-public final class ByteArrayBuffer extends Buffer {
+public class ByteArrayBuffer extends Buffer {
public static final int DEFAULT_SIZE = 256;
public static final int MAX_LEN = 65536;
@@ -195,4 +196,20 @@
protected int size() {
return data.length;
}
+
+ /**
+ * @param data The original data buffer
+ * @param off The valid data offset
+ * @param len The valid data length
+ * @return A buffer with a <U>copy</U> of the original data, positioned to
+ * start at offset zero, regardless of the original offset
+ */
+ public static ByteArrayBuffer getCompactClone(byte[] data, int off, int len) {
+ byte[] cloned = (len > 0) ? new byte[len] : GenericUtils.EMPTY_BYTE_ARRAY;
+ if (len > 0) {
+ System.arraycopy(data, off, cloned, 0, len);
+ }
+
+ return new ByteArrayBuffer(cloned, true);
+ }
}
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/auth/keyboard/UserAuthKeyboardInteractive.java b/sshd-core/src/main/java/org/apache/sshd/server/auth/keyboard/UserAuthKeyboardInteractive.java
index 009ddb4..fbfe4a6 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/auth/keyboard/UserAuthKeyboardInteractive.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/auth/keyboard/UserAuthKeyboardInteractive.java
@@ -26,7 +26,6 @@
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.server.auth.AbstractUserAuth;
import org.apache.sshd.server.session.ServerSession;
@@ -75,7 +74,7 @@
}
// Prompt for password
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_USERAUTH_INFO_REQUEST, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_INFO_REQUEST);
challenge.append(buffer);
session.writePacket(buffer);
return null;
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/auth/password/UserAuthPassword.java b/sshd-core/src/main/java/org/apache/sshd/server/auth/password/UserAuthPassword.java
index 35704bc..0793c39 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/auth/password/UserAuthPassword.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/auth/password/UserAuthPassword.java
@@ -19,6 +19,7 @@
package org.apache.sshd.server.auth.password;
import org.apache.sshd.common.SshConstants;
+import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.server.auth.AbstractUserAuth;
@@ -132,7 +133,8 @@
session, prompt, lang);
}
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_USERAUTH_PASSWD_CHANGEREQ, buffer);
+ buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_PASSWD_CHANGEREQ,
+ GenericUtils.length(prompt) + GenericUtils.length(lang) + Integer.SIZE);
buffer.putString(prompt);
buffer.putString(lang);
session.writePacket(buffer);
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java b/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
index 7540747..48c1bb7 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/auth/pubkey/UserAuthPublicKey.java
@@ -19,15 +19,18 @@
package org.apache.sshd.server.auth.pubkey;
import java.security.PublicKey;
+import java.security.SignatureException;
import java.util.Collection;
import java.util.List;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.NamedResource;
+import org.apache.sshd.common.RuntimeSshException;
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.config.keys.KeyUtils;
import org.apache.sshd.common.signature.Signature;
import org.apache.sshd.common.signature.SignatureFactoriesManager;
+import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.buffer.BufferUtils;
@@ -105,32 +108,50 @@
return Boolean.FALSE;
}
- boolean authed = authenticator.authenticate(username, key, session);
+ boolean authed;
+ try {
+ authed = authenticator.authenticate(username, key, session);
+ } catch (Error e) {
+ log.warn("doAuth({}@{}) failed ({}) to consult delegate for {} key={}: {}",
+ username, session, e.getClass().getSimpleName(), alg, KeyUtils.getFingerPrint(key), e.getMessage());
+ if (log.isDebugEnabled()) {
+ log.debug("doAuth(" + username + "@" + session + ") delegate failure details", e);
+ }
+
+ throw new RuntimeSshException(e);
+ }
+
if (log.isDebugEnabled()) {
log.debug("doAuth({}@{}) key type={}, fingerprint={} - authentication result: {}",
username, session, alg, KeyUtils.getFingerPrint(key), authed);
}
+
if (!authed) {
return Boolean.FALSE;
}
if (!hasSig) {
- if (log.isDebugEnabled()) {
- log.debug("doAuth({}@{}) send SSH_MSG_USERAUTH_PK_OK for key type={}, fingerprint={}",
- username, session, alg, KeyUtils.getFingerPrint(key));
- }
-
- Buffer buf = session.prepareBuffer(SshConstants.SSH_MSG_USERAUTH_PK_OK, BufferUtils.clear(buffer));
- buf.putString(alg);
- buf.putRawBytes(buffer.array(), oldPos, 4 + len);
- session.writePacket(buf);
+ sendPublicKeyResponse(session, username, alg, key, buffer.array(), oldPos, 4 + len, buffer);
return null;
}
- // verify signature
+ buffer.rpos(oldPos);
+ buffer.wpos(oldPos + 4 + len);
+ if (!verifySignature(session, getService(), getName(), username, alg, key, buffer, verifier, sig)) {
+ throw new SignatureException("Key verification failed");
+ }
+
+ if (log.isDebugEnabled()) {
+ log.debug("doAuth({}@{}) key type={}, fingerprint={} - verified",
+ username, session, alg, KeyUtils.getFingerPrint(key));
+ }
+
+ return Boolean.TRUE;
+ }
+
+ protected boolean verifySignature(ServerSession session, String service, String name, String username,
+ String alg, PublicKey key, Buffer buffer, Signature verifier, byte[] sig) throws Exception {
byte[] id = session.getSessionId();
- String service = getService();
- String name = getName();
Buffer buf = new ByteArrayBuffer(id.length + username.length() + service.length() + name.length()
+ alg.length() + ByteArrayBuffer.DEFAULT_SIZE + Long.SIZE, false);
buf.putBytes(id);
@@ -140,27 +161,30 @@
buf.putString(name);
buf.putBoolean(true);
buf.putString(alg);
- buffer.rpos(oldPos);
- buffer.wpos(oldPos + 4 + len);
buf.putBuffer(buffer);
if (log.isTraceEnabled()) {
- log.trace("doAuth({}@{}) key type={}, fingerprint={} - verification data={}",
- username, session, alg, KeyUtils.getFingerPrint(key), buf.printHex());
- log.trace("doAuth({}@{}) key type={}, fingerprint={} - expected signature={}",
- username, session, alg, KeyUtils.getFingerPrint(key), BufferUtils.printHex(sig));
+ log.trace("verifySignature({}@{})[{}][{}] key type={}, fingerprint={} - verification data={}",
+ username, session, service, name, alg, KeyUtils.getFingerPrint(key), buf.printHex());
+ log.trace("verifySignature({}@{})[{}][{}] key type={}, fingerprint={} - expected signature={}",
+ username, session, service, name, alg, KeyUtils.getFingerPrint(key), BufferUtils.printHex(sig));
}
verifier.update(buf.array(), buf.rpos(), buf.available());
- if (!verifier.verify(sig)) {
- throw new Exception("Key verification failed");
- }
+ return verifier.verify(sig);
+ }
+ protected void sendPublicKeyResponse(ServerSession session, String username, String alg, PublicKey key,
+ byte[] keyBlob, int offset, int blobLen, Buffer buffer) throws Exception {
if (log.isDebugEnabled()) {
- log.debug("doAuth({}@{}) key type={}, fingerprint={} - verified",
- username, session, alg, KeyUtils.getFingerPrint(key));
+ log.debug("doAuth({}@{}) send SSH_MSG_USERAUTH_PK_OK for key type={}, fingerprint={}",
+ username, session, alg, KeyUtils.getFingerPrint(key));
}
- return Boolean.TRUE;
+ Buffer buf = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_PK_OK,
+ GenericUtils.length(alg) + blobLen + Integer.SIZE);
+ buf.putString(alg);
+ buf.putRawBytes(keyBlob, offset, blobLen);
+ session.writePacket(buf);
}
}
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/channel/ChannelSession.java b/sshd-core/src/main/java/org/apache/sshd/server/channel/ChannelSession.java
index ca8ba28..3b53b84 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/channel/ChannelSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/channel/ChannelSession.java
@@ -22,9 +22,11 @@
import java.io.OutputStream;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.TimerTask;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.sshd.agent.SshAgent;
import org.apache.sshd.agent.SshAgentFactory;
@@ -33,12 +35,15 @@
import org.apache.sshd.common.FactoryManager;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.PropertyResolverUtils;
+import org.apache.sshd.common.RuntimeSshException;
import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.channel.Channel;
import org.apache.sshd.common.channel.ChannelAsyncOutputStream;
import org.apache.sshd.common.channel.ChannelOutputStream;
import org.apache.sshd.common.channel.PtyMode;
import org.apache.sshd.common.channel.RequestHandler;
+import org.apache.sshd.common.channel.RequestHandler.Result;
+import org.apache.sshd.common.channel.Window;
import org.apache.sshd.common.file.FileSystemAware;
import org.apache.sshd.common.file.FileSystemFactory;
import org.apache.sshd.common.future.CloseFuture;
@@ -78,9 +83,10 @@
protected ChannelAsyncOutputStream asyncErr;
protected OutputStream out;
protected OutputStream err;
- protected Command command;
+ protected Command commandInstance;
protected ChannelDataReceiver receiver;
protected Buffer tempBuffer;
+ protected final AtomicBoolean commandStarted = new AtomicBoolean(false);
protected final StandardEnvironment env = new StandardEnvironment();
protected final CloseFuture commandExitFuture = new DefaultCloseFuture(lock);
@@ -117,7 +123,7 @@
@Override
public CloseFuture close(boolean immediately) {
- if (immediately || command == null) {
+ if (immediately || (commandInstance == null)) {
commandExitFuture.setClosed();
} else if (!commandExitFuture.isClosed()) {
IOException e = IoUtils.closeQuietly(receiver);
@@ -159,21 +165,21 @@
@Override
protected void doCloseImmediately() {
- if (command != null) {
+ if (commandInstance != null) {
try {
- command.destroy();
- } catch (Exception e) {
+ commandInstance.destroy();
+ } catch (Throwable e) {
log.warn("doCloseImmediately({}) failed ({}) to destroy command: {}",
this, e.getClass().getSimpleName(), e.getMessage());
if (log.isDebugEnabled()) {
log.debug("doCloseImmediately(" + this + ") command destruction failure details", e);
}
} finally {
- command = null;
+ commandInstance = null;
}
}
- IOException e = IoUtils.closeQuietly(remoteWindow, out, err, receiver);
+ IOException e = IoUtils.closeQuietly(getRemoteWindow(), out, err, receiver);
if (e != null) {
if (log.isDebugEnabled()) {
log.debug("doCloseImmediately({}) failed ({}) to close resources: {}",
@@ -203,6 +209,10 @@
log.debug("handleEof({}) failed ({}) to close receiver: {}",
this, e.getClass().getSimpleName(), e.getMessage());
}
+
+ if (log.isTraceEnabled()) {
+ log.trace("handleEof(" + this + ") receiver close failure details", e);
+ }
}
}
@@ -215,7 +225,8 @@
if (receiver != null) {
int r = receiver.data(this, data, off, len);
if (r > 0) {
- localWindow.consumeAndCheck(r);
+ Window wLocal = getLocalWindow();
+ wLocal.consumeAndCheck(r);
}
} else {
if (tempBuffer == null) {
@@ -245,7 +256,7 @@
return handleBreak(buffer, wantReply);
case Channel.CHANNEL_SHELL:
if (this.type == null) {
- RequestHandler.Result r = handleShell(buffer, wantReply);
+ RequestHandler.Result r = handleShell(requestType, buffer, wantReply);
if (RequestHandler.Result.ReplySuccess.equals(r) || RequestHandler.Result.Replied.equals(r)) {
this.type = requestType;
}
@@ -259,7 +270,7 @@
}
case Channel.CHANNEL_EXEC:
if (this.type == null) {
- RequestHandler.Result r = handleExec(buffer, wantReply);
+ RequestHandler.Result r = handleExec(requestType, buffer, wantReply);
if (RequestHandler.Result.ReplySuccess.equals(r) || RequestHandler.Result.Replied.equals(r)) {
this.type = requestType;
}
@@ -273,7 +284,7 @@
}
case Channel.CHANNEL_SUBSYSTEM:
if (this.type == null) {
- RequestHandler.Result r = handleSubsystem(buffer, wantReply);
+ RequestHandler.Result r = handleSubsystem(requestType, buffer, wantReply);
if (RequestHandler.Result.ReplySuccess.equals(r) || RequestHandler.Result.Replied.equals(r)) {
this.type = requestType;
}
@@ -294,6 +305,42 @@
}
}
+ @Override
+ protected void sendResponse(Buffer buffer, String req, Result result, boolean wantReply) throws IOException {
+ super.sendResponse(buffer, req, result, wantReply);
+
+ if (!RequestHandler.Result.ReplySuccess.equals(result)) {
+ return;
+ }
+
+ if (commandInstance == null) {
+ if (log.isDebugEnabled()) {
+ log.debug("sendResponse({}) request={} no pending command", this, req);
+ }
+ return; // no pending command to activate
+ }
+
+ if (!Objects.equals(this.type, req)) {
+ if (log.isDebugEnabled()) {
+ log.debug("sendResponse({}) request={} mismatched channel type: {}", this, req, this.type);
+ }
+ return; // request does not match the current channel type
+ }
+
+ if (commandStarted.getAndSet(true)) {
+ if (log.isDebugEnabled()) {
+ log.debug("sendResponse({}) request={} pending command already started", this, req);
+ }
+ return;
+ }
+
+ // TODO - consider if (Channel.CHANNEL_SHELL.equals(req) || Channel.CHANNEL_EXEC.equals(req) || Channel.CHANNEL_SUBSYSTEM.equals(req)) {
+ if (log.isDebugEnabled()) {
+ log.debug("sendResponse({}) request={} activate command", this, req);
+ }
+ commandInstance.start(getEnvironment());
+ }
+
protected RequestHandler.Result handleEnv(Buffer buffer, boolean wantReply) throws IOException {
String name = buffer.getString();
String value = buffer.getString();
@@ -388,7 +435,7 @@
return RequestHandler.Result.ReplySuccess;
}
- protected RequestHandler.Result handleShell(Buffer buffer, boolean wantReply) throws IOException {
+ protected RequestHandler.Result handleShell(String request, Buffer buffer, boolean wantReply) throws IOException {
// If we're already closing, ignore incoming data
if (isClosing()) {
if (log.isDebugEnabled()) {
@@ -406,20 +453,28 @@
return RequestHandler.Result.ReplyFailure;
}
- command = factory.create();
- if (command == null) {
+ try {
+ commandInstance = factory.create();
+ } catch (RuntimeException | Error e) {
+ log.warn("handleShell({}) Failed ({}) to create shell: {}",
+ this, e.getClass().getSimpleName(), e.getMessage());
+ if (log.isDebugEnabled()) {
+ log.debug("handleShell(" + this + ") shell creation failure details", e);
+ }
+ return RequestHandler.Result.ReplyFailure;
+ }
+
+ if (commandInstance == null) {
if (log.isDebugEnabled()) {
log.debug("handleShell({}) - no shell command", this);
}
return RequestHandler.Result.ReplyFailure;
}
- prepareCommand();
- command.start(getEnvironment());
- return RequestHandler.Result.ReplySuccess;
+ return prepareChannelCommand(request, commandInstance);
}
- protected RequestHandler.Result handleExec(Buffer buffer, boolean wantReply) throws IOException {
+ protected RequestHandler.Result handleExec(String request, Buffer buffer, boolean wantReply) throws IOException {
// If we're already closing, ignore incoming data
if (isClosing()) {
return RequestHandler.Result.ReplyFailure;
@@ -438,20 +493,26 @@
}
try {
- command = factory.createCommand(commandLine);
- } catch (RuntimeException e) {
+ commandInstance = factory.createCommand(commandLine);
+ } catch (RuntimeException | Error e) {
log.warn("handleExec({}) Failed ({}) to create command for {}: {}",
this, e.getClass().getSimpleName(), commandLine, e.getMessage());
+ if (log.isDebugEnabled()) {
+ log.debug("handleExec(" + this + ") command=" + commandLine + " creation failure details", e);
+ }
+
return RequestHandler.Result.ReplyFailure;
}
- prepareCommand();
- // Launch command
- command.start(getEnvironment());
- return RequestHandler.Result.ReplySuccess;
+ if (commandInstance == null) {
+ log.warn("handleExec({}) Unsupported command: {}", this, commandLine);
+ return RequestHandler.Result.ReplyFailure;
+ }
+
+ return prepareChannelCommand(request, commandInstance);
}
- protected RequestHandler.Result handleSubsystem(Buffer buffer, boolean wantReply) throws IOException {
+ protected RequestHandler.Result handleSubsystem(String request, Buffer buffer, boolean wantReply) throws IOException {
String subsystem = buffer.getString();
if (log.isDebugEnabled()) {
log.debug("handleSubsystem({})[want-reply={}] sybsystem={}",
@@ -465,15 +526,42 @@
return RequestHandler.Result.ReplyFailure;
}
- command = NamedFactory.Utils.create(factories, subsystem);
- if (command == null) {
+ try {
+ commandInstance = NamedFactory.Utils.create(factories, subsystem);
+ } catch (RuntimeException | Error e) {
+ log.warn("handleSubsystem({}) Failed ({}) to create command for subsystem={}: {}",
+ this, e.getClass().getSimpleName(), subsystem, e.getMessage());
+ if (log.isDebugEnabled()) {
+ log.debug("handleSubsystem(" + this + ") subsystem=" + subsystem + " creation failure details", e);
+ }
+ return RequestHandler.Result.ReplyFailure;
+ }
+
+ if (commandInstance == null) {
log.warn("handleSubsystem({}) Unsupported subsystem: {}", this, subsystem);
return RequestHandler.Result.ReplyFailure;
}
- prepareCommand();
- // Launch command
- command.start(getEnvironment());
+ return prepareChannelCommand(request, commandInstance);
+ }
+
+ protected RequestHandler.Result prepareChannelCommand(String request, Command cmd) throws IOException {
+ Command command = prepareCommand(request, cmd);
+ if (command == null) {
+ log.warn("prepareChannelCommand({})[{}] no command prepared", this, request);
+ return RequestHandler.Result.ReplyFailure;
+ }
+
+ if (command != cmd) {
+ if (log.isDebugEnabled()) {
+ log.debug("prepareChannelCommand({})[{}] replaced original command", this, request);
+ }
+ commandInstance = command;
+ }
+
+ if (log.isDebugEnabled()) {
+ log.debug("prepareChannelCommand({})[{}] prepared command", this, request);
+ }
return RequestHandler.Result.ReplySuccess;
}
@@ -491,7 +579,23 @@
this.receiver = receiver;
}
- protected void prepareCommand() throws IOException {
+ /**
+ * Called by {@link #prepareChannelCommand(String, Command)} in order to set
+ * up the command's streams, session, file-system, exit callback, etc..
+ *
+ * @param requestType The request that caused the command to be created
+ * @param command The created {@link Command} - may be {@code null}
+ * @return The updated command instance - if {@code null} then the request that
+ * initially caused the creation of the command is failed and the original command
+ * (if any) destroyed (eventually). <B>Note:</B> if a different command instance
+ * than the input one is returned, then it is up to the implementor to take care
+ * of the wrapping or destruction of the original command instance.
+ * @throws IOException If failed to prepare the command
+ */
+ protected Command prepareCommand(String requestType, Command command) throws IOException {
+ if (command == null) {
+ return null;
+ }
// Add the user
Session session = getSession();
addEnvVariable(Environment.ENV_USER, session.getUsername());
@@ -515,13 +619,13 @@
((AsyncCommand) command).setIoOutputStream(asyncOut);
((AsyncCommand) command).setIoErrorStream(asyncErr);
} else {
- out = new ChannelOutputStream(this, remoteWindow, log, SshConstants.SSH_MSG_CHANNEL_DATA);
- err = new ChannelOutputStream(this, remoteWindow, log, SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA);
+ Window wRemote = getRemoteWindow();
+ out = new ChannelOutputStream(this, wRemote, log, SshConstants.SSH_MSG_CHANNEL_DATA);
+ err = new ChannelOutputStream(this, wRemote, log, SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA);
if (log.isTraceEnabled()) {
// Wrap in logging filters
- String channelId = toString();
- out = new LoggingFilterOutputStream(out, "OUT(" + channelId + ")", log);
- err = new LoggingFilterOutputStream(err, "ERR(" + channelId + ")", log);
+ out = new LoggingFilterOutputStream(out, "OUT(" + this + ")", log);
+ err = new LoggingFilterOutputStream(err, "ERR(" + this + ")", log);
}
command.setOutputStream(out);
command.setErrorStream(err);
@@ -534,7 +638,7 @@
setDataReceiver(recv);
((AsyncCommand) command).setIoInputStream(recv.getIn());
} else {
- PipeDataReceiver recv = new PipeDataReceiver(this, localWindow);
+ PipeDataReceiver recv = new PipeDataReceiver(this, getLocalWindow());
setDataReceiver(recv);
command.setInputStream(recv.getIn());
}
@@ -564,6 +668,8 @@
}
}
});
+
+ return command;
}
protected int getPtyModeValue(PtyMode mode) {
@@ -576,11 +682,20 @@
FactoryManager manager = ValidateUtils.checkNotNull(session.getFactoryManager(), "No session factory manager");
ForwardingFilter filter = manager.getTcpipForwardingFilter();
SshAgentFactory factory = manager.getAgentFactory();
- if ((factory == null) || (filter == null) || (!filter.canForwardAgent(session))) {
- if (log.isDebugEnabled()) {
- log.debug("handleAgentForwarding(" + this + ")[haveFactory=" + (factory != null) + ",haveFilter=" + (filter != null) + "] filtered out");
+ try {
+ if ((factory == null) || (filter == null) || (!filter.canForwardAgent(session))) {
+ if (log.isDebugEnabled()) {
+ log.debug("handleAgentForwarding(" + this + ")[haveFactory=" + (factory != null) + ",haveFilter=" + (filter != null) + "] filtered out");
+ }
+ return RequestHandler.Result.ReplyFailure;
}
- return RequestHandler.Result.ReplyFailure;
+ } catch (Error e) {
+ log.warn("handleAgentForwarding({}) failed ({}) to consult forwarding filter: {}",
+ this, e.getClass().getSimpleName(), e.getMessage());
+ if (log.isDebugEnabled()) {
+ log.debug("handleAgentForwarding(" + this + ") filter consultation failure details", e);
+ }
+ throw new RuntimeSshException(e);
}
String authSocket = service.initAgentForward();
@@ -595,14 +710,23 @@
String authCookie = buffer.getString();
int screenId = buffer.getInt();
- FactoryManager manager = session.getFactoryManager();
+ FactoryManager manager = ValidateUtils.checkNotNull(session.getFactoryManager(), "No factory manager");
ForwardingFilter filter = manager.getTcpipForwardingFilter();
- if ((filter == null) || (!filter.canForwardX11(session))) {
- if (log.isDebugEnabled()) {
- log.debug("handleX11Forwarding({}) single={}, protocol={}, cookie={}, screen={}, filter={}: filtered",
- this, singleConnection, authProtocol, authCookie, screenId, filter);
+ try {
+ if ((filter == null) || (!filter.canForwardX11(session))) {
+ if (log.isDebugEnabled()) {
+ log.debug("handleX11Forwarding({}) single={}, protocol={}, cookie={}, screen={}, filter={}: filtered",
+ this, singleConnection, authProtocol, authCookie, screenId, filter);
+ }
+ return RequestHandler.Result.ReplyFailure;
}
- return RequestHandler.Result.ReplyFailure;
+ } catch (Error e) {
+ log.warn("handleX11Forwarding({}) failed ({}) to consult forwarding filter: {}",
+ this, e.getClass().getSimpleName(), e.getMessage());
+ if (log.isDebugEnabled()) {
+ log.debug("handleX11Forwarding(" + this + ") filter consultation failure details", e);
+ }
+ throw new RuntimeSshException(e);
}
String display = service.createX11Display(singleConnection, authProtocol, authCookie, screenId);
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/global/CancelTcpipForwardHandler.java b/sshd-core/src/main/java/org/apache/sshd/server/global/CancelTcpipForwardHandler.java
index 23724b5..fd552e8 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/global/CancelTcpipForwardHandler.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/global/CancelTcpipForwardHandler.java
@@ -26,7 +26,6 @@
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.util.Int2IntFunction;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
/**
* Handler for cancel-tcpip-forward global request.
@@ -65,7 +64,7 @@
if (wantReply) {
Session session = connectionService.getSession();
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_REQUEST_SUCCESS, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_REQUEST_SUCCESS, Integer.SIZE / Byte.SIZE);
buffer.putInt(port);
session.writePacket(buffer);
}
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/global/OpenSshHostKeysHandler.java b/sshd-core/src/main/java/org/apache/sshd/server/global/OpenSshHostKeysHandler.java
index 92c49b9..5c4e080 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/global/OpenSshHostKeysHandler.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/global/OpenSshHostKeysHandler.java
@@ -36,7 +36,6 @@
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
import org.apache.sshd.common.util.buffer.keys.BufferPublicKeyParser;
@@ -96,7 +95,7 @@
}
// generate the required signatures
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_REQUEST_SUCCESS, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_REQUEST_SUCCESS);
Buffer buf = new ByteArrayBuffer();
byte[] sessionId = session.getSessionId();
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/global/TcpipForwardHandler.java b/sshd-core/src/main/java/org/apache/sshd/server/global/TcpipForwardHandler.java
index 3b4618b..868ecef 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/global/TcpipForwardHandler.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/global/TcpipForwardHandler.java
@@ -26,7 +26,6 @@
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.util.Int2IntFunction;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
/**
* Handler for tcpip-forward global request.
@@ -70,7 +69,7 @@
port = bound.getPort();
if (wantReply) {
Session session = connectionService.getSession();
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_REQUEST_SUCCESS, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_REQUEST_SUCCESS, Integer.SIZE / Byte.SIZE);
buffer.putInt(port);
session.writePacket(buffer);
}
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGEXServer.java b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGEXServer.java
index b05a763..1beaf73 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGEXServer.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGEXServer.java
@@ -126,7 +126,7 @@
log.debug("next({})[{}] send SSH_MSG_KEX_DH_GEX_GROUP", this, session);
}
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_GROUP, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_GROUP);
buffer.putMPInt(dh.getP());
buffer.putMPInt(dh.getG());
session.writePacket(buffer);
@@ -151,7 +151,7 @@
if (log.isDebugEnabled()) {
log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_GROUP", this, session);
}
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_GROUP, BufferUtils.clear(buffer));
+ buffer = session.createBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_GROUP);
buffer.putMPInt(dh.getP());
buffer.putMPInt(dh.getG());
session.writePacket(buffer);
@@ -222,10 +222,8 @@
if (log.isDebugEnabled()) {
log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_REPLY", this, session);
}
- buffer.clear();
- buffer.rpos(5);
- buffer.wpos(5);
- buffer.putByte(SshConstants.SSH_MSG_KEX_DH_GEX_REPLY);
+
+ buffer = session.prepareBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_REPLY, BufferUtils.clear(buffer));
buffer.putBytes(k_s);
buffer.putBytes(f);
buffer.putBytes(sigH);
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java
index dbc6cc8..dc4a414 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java
@@ -139,10 +139,8 @@
if (log.isDebugEnabled()) {
log.debug("next({})[{}] Send SSH_MSG_KEXDH_REPLY", this, session);
}
- buffer.clear();
- buffer.rpos(5);
- buffer.wpos(5);
- buffer.putByte(SshConstants.SSH_MSG_KEXDH_REPLY);
+
+ buffer = session.prepareBuffer(SshConstants.SSH_MSG_KEXDH_REPLY, BufferUtils.clear(buffer));
buffer.putBytes(k_s);
buffer.putBytes(f);
buffer.putBytes(sigH);
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/session/ServerUserAuthService.java b/sshd-core/src/main/java/org/apache/sshd/server/session/ServerUserAuthService.java
index 518b319..5880295 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/session/ServerUserAuthService.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/session/ServerUserAuthService.java
@@ -33,7 +33,6 @@
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.common.util.ValidateUtils;
import org.apache.sshd.common.util.buffer.Buffer;
-import org.apache.sshd.common.util.buffer.BufferUtils;
import org.apache.sshd.common.util.closeable.AbstractCloseable;
import org.apache.sshd.server.ServerAuthenticationManager;
import org.apache.sshd.server.ServerFactoryManager;
@@ -258,7 +257,8 @@
String lang = PropertyResolverUtils.getStringProperty(session,
ServerFactoryManager.WELCOME_BANNER_LANGUAGE,
ServerFactoryManager.DEFAULT_WELCOME_BANNER_LANGUAGE);
- buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_BANNER, welcomeBanner.length() + lang.length() + Long.SIZE);
+ buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_BANNER,
+ welcomeBanner.length() + GenericUtils.length(lang) + Long.SIZE);
buffer.putString(welcomeBanner);
buffer.putString(lang);
@@ -277,7 +277,6 @@
session.resetIdleTimeout();
log.info("Session {}@{} authenticated", username, session.getIoSession().getRemoteAddress());
} else {
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_USERAUTH_FAILURE, BufferUtils.clear(buffer));
StringBuilder sb = new StringBuilder();
for (List<String> l : authMethods) {
if (GenericUtils.size(l) > 0) {
@@ -292,6 +291,8 @@
if (log.isDebugEnabled()) {
log.debug("handleAuthenticationSuccess({}@{}) remaining methods={}", username, session, remaining);
}
+
+ buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_FAILURE, remaining.length() + Byte.SIZE);
buffer.putString(remaining);
buffer.putBoolean(true); // partial success ...
session.writePacket(buffer);
@@ -312,7 +313,6 @@
username, session, SshConstants.getCommandMessageName(cmd));
}
- buffer = session.prepareBuffer(SshConstants.SSH_MSG_USERAUTH_FAILURE, BufferUtils.clear(buffer));
StringBuilder sb = new StringBuilder((authMethods.size() + 1) * Byte.SIZE);
for (List<String> l : authMethods) {
if (GenericUtils.size(l) > 0) {
@@ -330,6 +330,8 @@
if (log.isDebugEnabled()) {
log.debug("handleAuthenticationFailure({}@{}) remaining methods: {}", username, session, remaining);
}
+
+ buffer = session.createBuffer(SshConstants.SSH_MSG_USERAUTH_FAILURE, remaining.length() + Byte.SIZE);
buffer.putString(remaining);
buffer.putBoolean(false); // no partial success ...
session.writePacket(buffer);
diff --git a/sshd-core/src/test/java/org/apache/sshd/common/auth/AuthenticationTest.java b/sshd-core/src/test/java/org/apache/sshd/common/auth/AuthenticationTest.java
index afaa8a5..f98ced2 100644
--- a/sshd-core/src/test/java/org/apache/sshd/common/auth/AuthenticationTest.java
+++ b/sshd-core/src/test/java/org/apache/sshd/common/auth/AuthenticationTest.java
@@ -570,6 +570,9 @@
assertFalse("Unexpected authentication success", future.isSuccess());
Throwable actual = future.getException();
+ if (actual instanceof IOException) {
+ actual = actual.getCause();
+ }
assertSame("Mismatched authentication failure reason", expected, actual);
} finally {
client.stop();
@@ -657,7 +660,7 @@
}
}
- @Test
+ @Test // see SSHD-620
public void testHostBasedAuthentication() throws Exception {
final String CLIENT_USERNAME = getClass().getSimpleName();
final String CLIENT_HOSTNAME = SshdSocketAddress.toAddressString(SshdSocketAddress.getFirstExternalNetwork4Address());
@@ -696,6 +699,63 @@
}
}
+ @Test // see SSHD-625
+ public void testRuntimeErrorsInAuthenticators() throws Exception {
+ final Error thrown = new OutOfMemoryError(getCurrentTestName());
+ final PasswordAuthenticator authPassword = sshd.getPasswordAuthenticator();
+ final AtomicInteger passCounter = new AtomicInteger(0);
+ sshd.setPasswordAuthenticator(new PasswordAuthenticator() {
+ @Override
+ public boolean authenticate(String username, String password, ServerSession session)
+ throws PasswordChangeRequiredException {
+ int count = passCounter.incrementAndGet();
+ if (count == 1) {
+ throw thrown;
+ }
+ return authPassword.authenticate(username, password, session);
+ }
+ });
+
+ final PublickeyAuthenticator authPubkey = sshd.getPublickeyAuthenticator();
+ final AtomicInteger pubkeyCounter = new AtomicInteger(0);
+ sshd.setPublickeyAuthenticator(new PublickeyAuthenticator() {
+ @Override
+ public boolean authenticate(String username, PublicKey key, ServerSession session) {
+ int count = pubkeyCounter.incrementAndGet();
+ if (count == 1) {
+ throw thrown;
+ }
+ return authPubkey.authenticate(username, key, session);
+ }
+ });
+ sshd.setKeyboardInteractiveAuthenticator(KeyboardInteractiveAuthenticator.NONE);
+
+ try (SshClient client = setupTestClient()) {
+ KeyPair kp = Utils.generateKeyPair("RSA", 1024);
+ client.start();
+ try {
+ for (int index = 1; index < 3; index++) {
+ try (ClientSession s = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(7L, TimeUnit.SECONDS).getSession()) {
+ s.addPasswordIdentity(getCurrentTestName());
+ s.addPublicKeyIdentity(kp);
+
+ AuthFuture auth = s.auth();
+ assertTrue("Failed to complete authentication on time", auth.await(11L, TimeUnit.SECONDS));
+ if (auth.isSuccess()) {
+ assertTrue("Premature authentication success", index > 1);
+ break;
+ }
+
+ assertEquals("Password authenticator not consulted", 1, passCounter.get());
+ assertEquals("Pubkey authenticator not consulted", 1, pubkeyCounter.get());
+ }
+ }
+ } finally {
+ client.stop();
+ }
+ }
+ }
+
private static void assertAuthenticationResult(String message, AuthFuture future, boolean expected) throws IOException {
assertTrue(message + ": failed to get result on time", future.await(5L, TimeUnit.SECONDS));
assertEquals(message + ": mismatched authentication result", expected, future.isSuccess());