blob: bb43718f78be3802e6343c4e877dbeef1c3374f3 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collection;
import java.util.Deque;
import java.util.EnumSet;
import java.util.LinkedList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.channel.ClientChannel;
import org.apache.sshd.client.channel.ClientChannelEvent;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.common.io.IoInputStream;
import org.apache.sshd.common.io.IoOutputStream;
import org.apache.sshd.common.io.WritePendingException;
import org.apache.sshd.common.util.buffer.Buffer;
import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
import org.apache.sshd.common.util.io.IoUtils;
import org.apache.sshd.common.util.io.NoCloseOutputStream;
import org.apache.sshd.common.util.logging.AbstractLoggingBean;
import org.apache.sshd.common.util.threads.ThreadUtils;
import org.apache.sshd.server.Environment;
import org.apache.sshd.server.ExitCallback;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.channel.ChannelSession;
import org.apache.sshd.server.command.AsyncCommand;
import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider;
import org.apache.sshd.util.test.BaseTestSupport;
import org.junit.After;
import org.junit.Before;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runners.MethodSorters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* <p>
* This test simulates heavy traffic coming from the server towards the client, making sure the traffic does not get
* stuck, especially if the server receives a window adjust message while it tries to transfer all the data.
* </p>
* {@link AsyncInPendingWrapper} in this test serves as a handler for {@link WritePendingException}, which can occur
* when sending too many messages one after another.
*/
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class WindowAdjustTest extends BaseTestSupport {
/** Marker byte the server sends as its last message; the client closes the channel when it sees it. */
public static final byte END_FILE = '#';
/** Number of times the big message is sent by the server before the {@link #END_FILE} marker. */
public static final int BIG_MSG_SEND_COUNT = 10000;
/** Server under test; assigned in {@link #setUp()} and stopped in {@link #tearDown()}. */
private SshServer sshServer;
/** Port reported by the test server after it has been started. */
private int port;
public WindowAdjustTest() {
super();
}
/**
 * Starts an SSH server whose shell floods the client with the contents of the {@code /big-msg.txt}
 * class-path resource, sent {@link #BIG_MSG_SEND_COUNT} times and followed by {@link #END_FILE}.
 *
 * @throws IllegalStateException if the {@code /big-msg.txt} resource cannot be found
 * @throws Exception             if the resource cannot be read or the server fails to start
 */
@Before
public void setUp() throws Exception {
    sshServer = setupTestServer();

    final byte[] msg;
    // Close the resource stream as soon as its contents have been copied into memory.
    try (InputStream resource = getClass().getResourceAsStream("/big-msg.txt")) {
        if (resource == null) {
            // getResourceAsStream signals a missing resource with null rather than an exception
            throw new IllegalStateException("Missing test resource: /big-msg.txt");
        }
        msg = IoUtils.toByteArray(resource);
    }

    sshServer.setShellFactory(
            channel -> new FloodingAsyncCommand(msg, BIG_MSG_SEND_COUNT, END_FILE));
    sshServer.setKeyPairProvider(new SimpleGeneratorHostKeyProvider());
    sshServer.start();
    port = sshServer.getPort();
}
/**
 * Stops the test server and then calls {@code close(true)} on it. Does nothing when no server was created.
 *
 * @throws Exception if stopping or closing the server fails
 */
@After
public void tearDown() throws Exception {
    if (sshServer == null) {
        // setUp never got as far as creating the server
        return;
    }

    sshServer.stop();
    sshServer.close(true);
}
/**
 * Opens a shell channel against the flooding server and waits up to two minutes for the channel to report
 * {@code CLOSED}. The channel's output is consumed by a {@link VerifyingOutputStream}, which closes the
 * channel once it sees {@link #END_FILE}. The test fails if the wait ends with a {@code TIMEOUT} event.
 *
 * @throws Exception if connecting, authenticating or opening the channel fails
 */
@Test(timeout = 6L * 60L * 1000L)
public void testTrafficHeavyLoad() throws Exception {
    try (SshClient client = setupTestClient()) {
        client.start();

        try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
                .verify(CONNECT_TIMEOUT)
                .getSession()) {
            session.addPasswordIdentity(getCurrentTestName());
            session.auth().verify(AUTH_TIMEOUT);

            try (ClientChannel shell = session.createShellChannel()) {
                shell.setOut(new VerifyingOutputStream(shell, END_FILE));
                shell.setErr(new NoCloseOutputStream(System.err));
                shell.open().verify(OPEN_TIMEOUT);

                // Block until the channel is closed or the wait expires
                EnumSet<ClientChannelEvent> awaited = EnumSet.of(ClientChannelEvent.CLOSED);
                long maxWaitMillis = TimeUnit.MINUTES.toMillis(2L);
                Collection<ClientChannelEvent> events = shell.waitFor(awaited, maxWaitMillis);

                boolean timedOut = events.contains(ClientChannelEvent.TIMEOUT);
                assertFalse("Timeout while waiting for channel closure", timedOut);
            }
        } finally {
            client.stop();
        }
    }
}
/**
 * Output stream that discards all incoming data but watches for the EOF marker byte: when the marker is
 * written (as a single byte, or as the last byte of a chunk) the client channel is closed, which ends the test.
 * Any write attempted while the channel is closing or closed fails with an {@link IOException}.
 */
private static class VerifyingOutputStream extends OutputStream {
    private final Logger log;
    private final ClientChannel channel;
    private final byte eofSignal;

    /**
     * @param channel   the channel to close once the marker is seen
     * @param eofSignal the marker byte value
     */
    VerifyingOutputStream(ClientChannel channel, final byte eofSignal) {
        this.log = LoggerFactory.getLogger(getClass());
        this.channel = channel;
        this.eofSignal = eofSignal;
    }

    /**
     * Checks a single written byte against the EOF marker.
     *
     * @param  b           the value whose low-order 8 bits are the byte being written
     * @throws IOException if the channel is already closing or closed
     */
    @Override
    public void write(int b) throws IOException {
        if (channel.isClosed() || channel.isClosing()) {
            throw new IOException("Channel (" + channel + ") is closing / closed on write single byte");
        }

        // Only the low-order 8 bits of the argument are the written byte, so mask both sides
        // before comparing (same comparison as the bulk write below).
        if ((b & 0xFF) == (eofSignal & 0xFF)) {
            log.info("Closing channel ({}) due to single byte EOF", channel);
            channel.close(true);
        }
    }

    /**
     * Checks the last byte of the written chunk against the EOF marker; the rest of the chunk is ignored.
     *
     * @param  b           the data
     * @param  off         start offset of the chunk in {@code b}
     * @param  len         number of bytes in the chunk; nothing is inspected when it is not positive
     * @throws IOException if the channel is already closing or closed
     */
    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        if (channel.isClosed() || channel.isClosing()) {
            throw new IOException("Channel (" + channel + ") is closing / closed on write " + len + " bytes");
        }

        if (len <= 0) {
            return;
        }

        int lastPos = off + len - 1;
        if ((b[lastPos] & 0xFF) == (eofSignal & 0xFF)) {
            log.info("Closing channel ({}) due to last byte EOF", channel);
            channel.close(true);
        }
    }
}
/**
 * Shell command that pushes a fixed message to the client a configured number of times from a dedicated
 * single-thread executor, then sends a one-byte EOF marker. Only the asynchronous output stream is used;
 * every other stream and the exit callback are ignored.
 */
public static final class FloodingAsyncCommand extends AbstractLoggingBean implements AsyncCommand {
    private static final AtomicInteger POOL_COUNT = new AtomicInteger(0);

    private final AtomicReference<ExecutorService> executorHolder = new AtomicReference<>();
    private final AtomicReference<Future<?>> futureHolder = new AtomicReference<>();
    private AsyncInPendingWrapper pendingWrapper;
    private byte[] msg;
    private int sendCount;
    private byte eofSignal;

    /**
     * @param msg       payload of each message
     * @param sendCount how many times {@code msg} is sent
     * @param eofSignal marker byte sent after the last message
     */
    public FloodingAsyncCommand(final byte[] msg, final int sendCount, final byte eofSignal) {
        this.msg = msg;
        this.sendCount = sendCount;
        this.eofSignal = eofSignal;
    }

    @Override
    public void setIoInputStream(IoInputStream in) {
        // not used by this command
    }

    @Override
    public void setIoOutputStream(IoOutputStream out) {
        // All outgoing traffic goes through the wrapper so that pending writes get queued
        pendingWrapper = new AsyncInPendingWrapper(out);
    }

    @Override
    public void setIoErrorStream(IoOutputStream err) {
        // not used by this command
    }

    @Override
    public void setInputStream(InputStream in) {
        // not used by this command
    }

    @Override
    public void setOutputStream(OutputStream out) {
        // not used by this command
    }

    @Override
    public void setErrorStream(OutputStream err) {
        // not used by this command
    }

    @Override
    public void setExitCallback(ExitCallback callback) {
        // not used by this command
    }

    /**
     * Creates a fresh single-thread executor and submits the flooding task to it.
     */
    @Override
    public void start(ChannelSession channel, Environment env) throws IOException {
        log.info("Starting");

        String poolName = getClass().getSimpleName() + "-" + POOL_COUNT.incrementAndGet();
        ExecutorService executor = ThreadUtils.newSingleThreadExecutor(poolName);
        executorHolder.set(executor);

        Future<?> task = executor.submit((Runnable) this::flood);
        futureHolder.set(task);

        log.info("Started");
    }

    /**
     * Cancels the flooding task if it has not finished and shuts the executor down if it is still running.
     */
    @Override
    public void destroy(ChannelSession channel) {
        log.info("Destroying");

        Future<?> task = futureHolder.getAndSet(null);
        if (task != null && !task.isDone()) {
            log.info("Cancelling");
            task.cancel(true);
        }

        ExecutorService executor = executorHolder.getAndSet(null);
        if (executor != null && !executor.isShutdown()) {
            log.info("Shutdown");
            executor.shutdownNow();
        }
    }

    /**
     * Task body: writes the message {@code sendCount} times followed by the EOF marker. Any
     * {@link IOException} is logged and rethrown wrapped in a {@link RuntimeException}.
     */
    private void flood() {
        log.info("Start heavy load sending " + sendCount + " messages of " + msg.length + " bytes");

        int sent = 0;
        while (sent < sendCount) {
            try {
                pendingWrapper.write(new ByteArrayBuffer(msg));
            } catch (IOException e) {
                log.error("Failed ({}) to send message #{}/{}: {}",
                        e.getClass().getSimpleName(), sent + 1, sendCount, e.getMessage());
                throw new RuntimeException(e);
            }
            sent++;
        }

        log.info("Sending EOF signal");
        byte[] marker = { eofSignal };
        try {
            pendingWrapper.write(new ByteArrayBuffer(marker));
        } catch (IOException e) {
            log.error("Failed ({}) to send EOF message after {} messages: {}",
                    e.getClass().getSimpleName(), sendCount, e.getMessage());
            throw new RuntimeException(e);
        }
    }
}
/**
 * Wrapper for the asynchronous output stream that catches {@link WritePendingException} and queues the
 * affected messages for a later retry (they are sent after the previous messages were fully transferred).
 * <p>
 * The queue is accessed both by the thread calling {@link #write(Object)} and by the write-completion
 * listener, and {@link LinkedList} is not thread-safe, so every access to it is made while holding this
 * object's monitor.
 * </p>
 */
private static class AsyncInPendingWrapper extends AbstractLoggingBean {
    private final IoOutputStream asyncIn;

    // Order has to be preserved for queued writes
    private final Deque<Buffer> pending = new LinkedList<>();

    AsyncInPendingWrapper(IoOutputStream out) {
        this.asyncIn = out;
    }

    /**
     * Sends the message right away, or appends it to the queue when earlier messages are still waiting.
     * The message is silently dropped when the underlying stream is missing, closing or closed.
     *
     * @param  msg         the message; must be a {@link Buffer}
     * @throws IOException if the underlying stream rejects the write for a reason other than a pending write
     */
    public synchronized void write(Object msg) throws IOException {
        if ((asyncIn == null) || asyncIn.isClosed() || asyncIn.isClosing()) {
            return;
        }

        Buffer byteBufferMsg = (Buffer) msg;
        if (!pending.isEmpty()) {
            // Something is already waiting - keep the ordering by queueing behind it
            queueRequest(byteBufferMsg);
            return;
        }

        writeWithPendingDetection(byteBufferMsg, false);
    }

    /**
     * Attempts the write and registers a listener that continues with the queue once it completes.
     * Must be called while holding this object's monitor.
     *
     * @param msg        the message to write
     * @param wasPending {@code true} if {@code msg} is the current head of the queue
     */
    private void writeWithPendingDetection(Buffer msg, boolean wasPending) throws IOException {
        try {
            asyncIn.writeBuffer(msg).addListener(future -> {
                if (future.isWritten()) {
                    onWritten(wasPending);
                } else {
                    Throwable t = future.getException();
                    log.warn("Failed to write message", t);
                }
            });
        } catch (final WritePendingException e) {
            // A queued message simply stays at the head of the queue and is retried on the next
            // completion; a fresh one has to be queued first.
            if (!wasPending) {
                queueRequest(msg);
            }
        }
    }

    /**
     * Completion callback: drops the just-written head of the queue (if that is what was written) and sends
     * the next queued message. Synchronized because it may run on a different thread than {@link #write}.
     */
    private synchronized void onWritten(boolean wasPending) {
        if (wasPending) {
            pending.remove();
        }

        try {
            writePendingIfAny();
        } catch (IOException e) {
            log.error("Failed ({}) to re-write pending: {}", e.getClass().getSimpleName(), e.getMessage());
        }
    }

    /** Writes the head of the queue, if any, leaving it queued until its write completes. */
    private synchronized void writePendingIfAny() throws IOException {
        Buffer msg = pending.peek();
        if (msg == null) {
            return;
        }

        writeWithPendingDetection(msg, true);
    }

    /** Rewinds the buffer's read position and appends it to the queue. Caller must hold the monitor. */
    private void queueRequest(Buffer msg) {
        msg.rpos(0);
        pending.add(msg);
    }
}
}