blob: 57f7e7f939fdbf289741b902713a7b74b8cd36f2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd.common.forward;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import com.jcraft.jsch.JSch;
import com.jcraft.jsch.JSchException;
import com.jcraft.jsch.Session;
import org.apache.commons.httpclient.HostConfiguration;
import org.apache.commons.httpclient.HttpClient;
import org.apache.commons.httpclient.HttpVersion;
import org.apache.commons.httpclient.MultiThreadedHttpConnectionManager;
import org.apache.commons.httpclient.methods.GetMethod;
import org.apache.mina.core.buffer.IoBuffer;
import org.apache.mina.core.service.IoAcceptor;
import org.apache.mina.core.service.IoHandlerAdapter;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.transport.socket.nio.NioSocketAcceptor;
import org.apache.sshd.common.util.net.SshdSocketAddress;
import org.apache.sshd.common.util.security.SecurityUtils;
import org.apache.sshd.core.CoreModuleProperties;
import org.apache.sshd.server.SshServer;
import org.apache.sshd.server.forward.AcceptAllForwardingFilter;
import org.apache.sshd.util.test.BaseTestSupport;
import org.apache.sshd.util.test.CoreTestSupportUtils;
import org.apache.sshd.util.test.JSchLogger;
import org.apache.sshd.util.test.SimpleUserInfo;
import org.junit.After;
import org.junit.Assume;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runners.MethodSorters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Port forwarding tests
*/
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class PortForwardingLoadTest extends BaseTestSupport {
private final Logger log;
@SuppressWarnings({ "checkstyle:anoninnerlength", "synthetic-access" })
private final PortForwardingEventListener serverSideListener = new PortForwardingEventListener() {
@Override
public void establishingExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress local, SshdSocketAddress remote,
boolean localForwarding)
throws IOException {
log.info("establishingExplicitTunnel(session={}, local={}, remote={}, localForwarding={})",
session, local, remote, localForwarding);
}
@Override
public void establishedExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress local,
SshdSocketAddress remote, boolean localForwarding, SshdSocketAddress boundAddress, Throwable reason)
throws IOException {
log.info("establishedExplicitTunnel(session={}, local={}, remote={}, bound={}, localForwarding={}): {}",
session, local, remote, boundAddress, localForwarding, reason);
}
@Override
public void tearingDownExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress address, boolean localForwarding,
SshdSocketAddress remoteAddress)
throws IOException {
log.info("tearingDownExplicitTunnel(session={}, address={}, localForwarding={}, remote={})",
session, address, localForwarding, remoteAddress);
}
@Override
public void tornDownExplicitTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress address, boolean localForwarding,
SshdSocketAddress remoteAddress, Throwable reason)
throws IOException {
log.info("tornDownExplicitTunnel(session={}, address={}, localForwarding={}, remote={}, reason={})",
session, address, localForwarding, remoteAddress, reason);
}
@Override
public void establishingDynamicTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress local)
throws IOException {
log.info("establishingDynamicTunnel(session={}, local={})", session, local);
}
@Override
public void establishedDynamicTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress local, SshdSocketAddress boundAddress,
Throwable reason)
throws IOException {
log.info("establishedDynamicTunnel(session={}, local={}, bound={}, reason={})", session, local, boundAddress,
reason);
}
@Override
public void tearingDownDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress address)
throws IOException {
log.info("tearingDownDynamicTunnel(session={}, address={})", session, address);
}
@Override
public void tornDownDynamicTunnel(
org.apache.sshd.common.session.Session session, SshdSocketAddress address, Throwable reason)
throws IOException {
log.info("tornDownDynamicTunnel(session={}, address={}, reason={})", session, address, reason);
}
};
private SshServer sshd;
private int sshPort;
private IoAcceptor acceptor;
public PortForwardingLoadTest() {
log = LoggerFactory.getLogger(getClass());
}
@BeforeClass
public static void jschInit() {
// FIXME inexplicably these tests fail without BC since SSHD-1004
Assume.assumeTrue("Requires BC security provider", SecurityUtils.isBouncyCastleRegistered());
JSchLogger.init();
}
@Before
public void setUp() throws Exception {
sshd = setupTestFullSupportServer();
sshd.setForwardingFilter(AcceptAllForwardingFilter.INSTANCE);
sshd.addPortForwardingEventListener(serverSideListener);
sshd.start();
sshPort = sshd.getPort();
NioSocketAcceptor acceptor = new NioSocketAcceptor();
acceptor.setHandler(new IoHandlerAdapter() {
@Override
public void messageReceived(IoSession session, Object message) throws Exception {
IoBuffer recv = (IoBuffer) message;
IoBuffer sent = IoBuffer.allocate(recv.remaining());
sent.put(recv);
sent.flip();
session.write(sent);
}
});
acceptor.setReuseAddress(true);
acceptor.bind(new InetSocketAddress(0));
log.info("setUp() echo address = {}", acceptor.getLocalAddress());
this.acceptor = acceptor;
}
@After
public void tearDown() throws Exception {
if (sshd != null) {
sshd.stop(true);
}
if (acceptor != null) {
acceptor.dispose(true);
}
}
@Test
@SuppressWarnings("checkstyle:nestedtrydepth")
public void testLocalForwardingPayload() throws Exception {
final int numIterations = 100;
final String payloadTmpData = "This is significantly longer Test Data. This is significantly "
+ "longer Test Data. This is significantly longer Test Data. This is significantly "
+ "longer Test Data. This is significantly longer Test Data. This is significantly "
+ "longer Test Data. This is significantly longer Test Data. This is significantly "
+ "longer Test Data. This is significantly longer Test Data. This is significantly "
+ "longer Test Data. ";
StringBuilder sb = new StringBuilder(payloadTmpData.length() * 1000);
for (int i = 0; i < 1000; i++) {
sb.append(payloadTmpData);
}
String payload = sb.toString();
final byte[] dataBytes = payload.getBytes(StandardCharsets.UTF_8);
final int reportPhase = dataBytes.length / 10;
log.info("{} using payload size={}", getCurrentTestName(), dataBytes.length);
AtomicInteger errors = new AtomicInteger();
Session session = createSession();
try (ServerSocket ss = new ServerSocket()) {
ss.setReuseAddress(true);
ss.bind(new InetSocketAddress((InetAddress) null, 0));
int forwardedPort = ss.getLocalPort();
int sinkPort = session.setPortForwardingL(0, TEST_LOCALHOST, forwardedPort);
log.info("{} forwardedPort={}, sinkPort={}", getCurrentTestName(), forwardedPort, sinkPort);
AtomicInteger conCount = new AtomicInteger(0);
Semaphore iterationsSignal = new Semaphore(0);
@SuppressWarnings("checkstyle:anoninnerlength")
Thread tAcceptor = new Thread(getCurrentTestName() + "Acceptor") {
@SuppressWarnings("synthetic-access")
@Override
public void run() {
try {
byte[] buf = new byte[8192];
log.info("Started...");
for (int i = 0; i < numIterations; ++i) {
try (Socket s = ss.accept()) {
int totalConns = conCount.incrementAndGet();
log.info("Accepted connection #{} from {}", totalConns, s.getRemoteSocketAddress());
try (InputStream sockIn = s.getInputStream();
ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
for (int readSize = 0, lastReport = 0; readSize < dataBytes.length;) {
int l = sockIn.read(buf);
if (l < 0) {
break;
}
baos.write(buf, 0, l);
readSize += l;
if ((readSize - lastReport) >= reportPhase) {
log.info("Read {}/{} bytes of iteration #{}", readSize, dataBytes.length, i);
lastReport = readSize;
}
}
assertPayloadEquals("Mismatched received data at iteration #" + i, dataBytes,
baos.toByteArray());
byte[] outBytes = baos.toByteArray();
try (InputStream inputCopy = new ByteArrayInputStream(outBytes);
OutputStream sockOut = s.getOutputStream()) {
for (int writeSize = 0, lastReport = 0; writeSize < outBytes.length;) {
int l = inputCopy.read(buf);
if (l < 0) {
break;
}
sockOut.write(buf, 0, l);
writeSize += l;
if ((writeSize - lastReport) >= reportPhase) {
log.info("Written {}/{} bytes of iteration #{}", writeSize, dataBytes.length,
i);
lastReport = writeSize;
}
}
}
}
}
log.info("Finished iteration {}/{}", i, numIterations);
iterationsSignal.release();
}
log.info("Done");
} catch (Exception e) {
log.error("Failed to complete run loop", e);
}
}
};
tAcceptor.start();
Thread.sleep(TimeUnit.SECONDS.toMillis(1L));
byte[] buf = new byte[8192];
for (int i = 0; i < numIterations; i++) {
log.debug("Iteration {}/{} started", i, numIterations);
try (Socket s = new Socket(TEST_LOCALHOST, sinkPort);
OutputStream sockOut = s.getOutputStream()) {
log.debug("Iteration {} connected to {}", i, s.getRemoteSocketAddress());
s.setSoTimeout((int) CoreModuleProperties.NIO2_MIN_WRITE_TIMEOUT.getRequiredDefault().toMillis());
sockOut.write(dataBytes);
sockOut.flush();
log.debug("Iteration {} awaiting echoed data", i);
try (InputStream sockIn = s.getInputStream();
ByteArrayOutputStream baos = new ByteArrayOutputStream(dataBytes.length)) {
for (int readSize = 0, lastReport = 0; readSize < dataBytes.length;) {
try {
int l = sockIn.read(buf);
if (l < 0) {
break;
}
baos.write(buf, 0, l);
readSize += l;
if ((readSize - lastReport) >= reportPhase) {
log.debug("Read {}/{} bytes of iteration #{}", readSize, dataBytes.length, i);
lastReport = readSize;
}
} catch (SocketTimeoutException e) {
throw new IOException(
"Error reading data at index " + readSize + "/" + dataBytes.length + " of iteration #"
+ i,
e);
}
}
assertPayloadEquals("Mismatched payload at iteration #" + i, dataBytes, baos.toByteArray());
}
} catch (Exception e) {
log.error("Error in iteration #" + i, e);
errors.incrementAndGet();
}
}
try {
assertTrue("Failed to await pending iterations=" + numIterations,
iterationsSignal.tryAcquire(numIterations, numIterations, TimeUnit.SECONDS));
} finally {
log.info("{} remove port forwarding for {}", getCurrentTestName(), sinkPort);
session.delPortForwardingL(sinkPort);
}
ss.close();
log.info("{} awaiting acceptor finish", getCurrentTestName());
tAcceptor.join(TimeUnit.SECONDS.toMillis(11L));
} finally {
session.disconnect();
}
assertEquals("Some errors occured", 0, errors.get());
}
private static void assertPayloadEquals(String message, byte[] expectedBytes, byte[] actualBytes) {
assertEquals(message + ": mismatched payload length", expectedBytes.length, actualBytes.length);
for (int index = 0; index < expectedBytes.length; index++) {
if (expectedBytes[index] == actualBytes[index]) {
continue;
}
int startPos = Math.max(0, index - Byte.SIZE);
int endPos = Math.min(startPos + Short.SIZE, expectedBytes.length);
if ((endPos - startPos) < Byte.SIZE) {
startPos = expectedBytes.length - Byte.SIZE;
endPos = expectedBytes.length;
}
String expected = new String(expectedBytes, startPos, endPos - startPos, StandardCharsets.UTF_8);
String actual = new String(actualBytes, startPos, endPos - startPos, StandardCharsets.UTF_8);
fail("Mismatched data around offset " + index + ": expected='" + expected + "', actual='" + actual + "'");
}
}
@Test
public void testRemoteForwardingPayload() throws Exception {
final int numIterations = 100;
final String payload = "This is significantly longer Test Data. This is significantly "
+ "longer Test Data. This is significantly longer Test Data. This is significantly "
+ "longer Test Data. This is significantly longer Test Data. This is significantly "
+ "longer Test Data. This is significantly longer Test Data. This is significantly "
+ "longer Test Data. ";
Session session = createSession();
try (ServerSocket ss = new ServerSocket()) {
ss.setReuseAddress(true);
ss.bind(new InetSocketAddress((InetAddress) null, 0));
int forwardedPort = ss.getLocalPort();
int sinkPort = CoreTestSupportUtils.getFreePort();
session.setPortForwardingR(sinkPort, TEST_LOCALHOST, forwardedPort);
final boolean started[] = new boolean[1];
started[0] = false;
final AtomicInteger conCount = new AtomicInteger(0);
Thread tWriter = new Thread(getCurrentTestName() + "Writer") {
@SuppressWarnings("synthetic-access")
@Override
public void run() {
started[0] = true;
try {
byte[] bytes = payload.getBytes(StandardCharsets.UTF_8);
for (int i = 0; i < numIterations; ++i) {
try (Socket s = ss.accept()) {
conCount.incrementAndGet();
try (OutputStream sockOut = s.getOutputStream()) {
sockOut.write(bytes);
sockOut.flush();
}
}
}
} catch (Exception e) {
log.error("Failed to complete run loop", e);
}
}
};
tWriter.start();
Thread.sleep(TimeUnit.SECONDS.toMillis(1L));
assertTrue("Server not started", started[0]);
final RuntimeException lenOK[] = new RuntimeException[numIterations];
final RuntimeException dataOK[] = new RuntimeException[numIterations];
byte b2[] = new byte[payload.length()];
byte b1[] = new byte[b2.length / 2];
for (int i = 0; i < numIterations; i++) {
final int ii = i;
try (Socket s = new Socket(TEST_LOCALHOST, sinkPort);
InputStream sockIn = s.getInputStream()) {
s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(10L));
int read1 = sockIn.read(b1);
String part1 = new String(b1, 0, read1, StandardCharsets.UTF_8);
Thread.sleep(50);
int read2 = sockIn.read(b2);
String part2 = new String(b2, 0, read2, StandardCharsets.UTF_8);
int totalRead = read1 + read2;
lenOK[ii] = (payload.length() == totalRead)
? null
: new IndexOutOfBoundsException(
"Mismatched length: expected=" + payload.length() + ", actual=" + totalRead);
String readData = part1 + part2;
dataOK[ii] = payload.equals(readData) ? null : new IllegalStateException("Mismatched content");
if (lenOK[ii] != null) {
throw lenOK[ii];
}
if (dataOK[ii] != null) {
throw dataOK[ii];
}
} catch (Exception e) {
if (e instanceof IOException) {
log.warn("I/O exception in iteration #" + i, e);
} else {
log.error("Failed to complete iteration #" + i, e);
}
}
}
int ok = 0;
for (int i = 0; i < numIterations; i++) {
ok += (lenOK[i] == null) ? 1 : 0;
}
log.info("Successful iterations: " + ok + " out of " + numIterations);
Thread.sleep(TimeUnit.SECONDS.toMillis(1L));
for (int i = 0; i < numIterations; i++) {
assertNull("Bad length at iteration " + i, lenOK[i]);
assertNull("Bad data at iteration " + i, dataOK[i]);
}
Thread.sleep(TimeUnit.SECONDS.toMillis(1L));
session.delPortForwardingR(forwardedPort);
ss.close();
tWriter.join(TimeUnit.SECONDS.toMillis(11L));
} finally {
session.disconnect();
}
}
@Test
public void testForwardingOnLoad() throws Exception {
// final String path = "/history/recent/troubles/";
// final String host = "www.bbc.co.uk";
// final String path = "";
// final String host = "www.bahn.de";
final String path = "";
final String host = TEST_LOCALHOST;
final int nbThread = 2;
final int nbDownloads = 2;
final int nbLoops = 2;
StringBuilder resp = new StringBuilder();
resp.append("<html><body>\n");
for (int i = 0; i < 1000; i++) {
resp.append("0123456789\n");
}
resp.append("</body></html>\n");
StringBuilder sb = new StringBuilder();
sb.append("HTTP/1.1 200 OK").append('\n');
sb.append("Content-Type: text/HTML").append('\n');
sb.append("Content-Length: ").append(resp.length()).append('\n');
sb.append('\n');
sb.append(resp);
NioSocketAcceptor acceptor = new NioSocketAcceptor();
acceptor.setHandler(new IoHandlerAdapter() {
@Override
public void messageReceived(IoSession session, Object message) throws Exception {
session.write(IoBuffer.wrap(sb.toString().getBytes(StandardCharsets.UTF_8)));
}
});
acceptor.setReuseAddress(true);
acceptor.bind(new InetSocketAddress(0));
int port = acceptor.getLocalAddress().getPort();
Session session = createSession();
try {
int forwardedPort1 = session.setPortForwardingL(0, host, port);
int forwardedPort2 = CoreTestSupportUtils.getFreePort();
session.setPortForwardingR(forwardedPort2, TEST_LOCALHOST, forwardedPort1);
outputDebugMessage("URL: http://localhost %s", forwardedPort2);
CountDownLatch latch = new CountDownLatch(nbThread * nbDownloads * nbLoops);
Thread[] threads = new Thread[nbThread];
List<Throwable> errors = new CopyOnWriteArrayList<>();
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(getCurrentTestName() + "[" + i + "]") {
@Override
@SuppressWarnings("synthetic-access")
public void run() {
for (int j = 0; j < nbLoops; j++) {
MultiThreadedHttpConnectionManager mgr = new MultiThreadedHttpConnectionManager();
HttpClient client = new HttpClient(mgr);
client.getHttpConnectionManager().getParams().setDefaultMaxConnectionsPerHost(100);
client.getHttpConnectionManager().getParams().setMaxTotalConnections(1000);
for (int i = 0; i < nbDownloads; i++) {
try {
checkHtmlPage(client, new URL("http://localhost:" + forwardedPort2 + path));
} catch (Throwable e) {
errors.add(e);
} finally {
latch.countDown();
log.debug("Remaining: " + latch.getCount());
}
}
mgr.shutdown();
}
}
};
}
for (Thread thread : threads) {
thread.start();
}
latch.await();
for (Throwable t : errors) {
log.warn("{}: {}", t.getClass().getSimpleName(), t.getMessage());
}
assertEquals(0, errors.size());
} finally {
session.disconnect();
}
}
protected Session createSession() throws JSchException {
JSch sch = new JSch();
Session session = sch.getSession("sshd", TEST_LOCALHOST, sshPort);
session.setUserInfo(new SimpleUserInfo("sshd"));
session.connect();
return session;
}
protected void checkHtmlPage(HttpClient client, URL url) throws IOException {
client.setHostConfiguration(new HostConfiguration());
client.getHostConfiguration().setHost(url.getHost(), url.getPort());
GetMethod get = new GetMethod("");
get.getParams().setVersion(HttpVersion.HTTP_1_1);
client.executeMethod(get);
String str = get.getResponseBodyAsString();
if (str.indexOf("</html>") <= 0) {
System.err.println(str);
}
assertTrue("Missing HTML close tag", str.indexOf("</html>") > 0);
get.releaseConnection();
// url.openConnection().setDefaultUseCaches(false);
// Reader reader = new BufferedReader(new InputStreamReader(url.openStream()));
// try {
// StringWriter sw = new StringWriter();
// char[] buf = new char[8192];
// while (true) {
// int len = reader.read(buf);
// if (len < 0) {
// break;
// }
// sw.write(buf, 0, len);
// }
// assertTrue(sw.toString().indexOf("</html>") > 0);
// } finally {
// reader.close();
// }
}
}