| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.sshd.common.forward; |
| |
| import java.io.ByteArrayInputStream; |
| import java.io.ByteArrayOutputStream; |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| import java.net.InetAddress; |
| import java.net.InetSocketAddress; |
| import java.net.ServerSocket; |
| import java.net.Socket; |
| import java.net.SocketTimeoutException; |
| import java.net.URL; |
| import java.nio.charset.StandardCharsets; |
| import java.util.List; |
| import java.util.concurrent.CopyOnWriteArrayList; |
| import java.util.concurrent.CountDownLatch; |
| import java.util.concurrent.Semaphore; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.atomic.AtomicInteger; |
| |
| import com.jcraft.jsch.JSch; |
| import com.jcraft.jsch.JSchException; |
| import com.jcraft.jsch.Session; |
| import org.apache.commons.httpclient.HostConfiguration; |
| import org.apache.commons.httpclient.HttpClient; |
| import org.apache.commons.httpclient.HttpVersion; |
| import org.apache.commons.httpclient.MultiThreadedHttpConnectionManager; |
| import org.apache.commons.httpclient.methods.GetMethod; |
| import org.apache.mina.core.buffer.IoBuffer; |
| import org.apache.mina.core.service.IoAcceptor; |
| import org.apache.mina.core.service.IoHandlerAdapter; |
| import org.apache.mina.core.session.IoSession; |
| import org.apache.mina.transport.socket.nio.NioSocketAcceptor; |
| import org.apache.sshd.common.util.net.SshdSocketAddress; |
| import org.apache.sshd.common.util.security.SecurityUtils; |
| import org.apache.sshd.core.CoreModuleProperties; |
| import org.apache.sshd.server.SshServer; |
| import org.apache.sshd.server.forward.AcceptAllForwardingFilter; |
| import org.apache.sshd.util.test.BaseTestSupport; |
| import org.apache.sshd.util.test.CoreTestSupportUtils; |
| import org.apache.sshd.util.test.JSchLogger; |
| import org.apache.sshd.util.test.SimpleUserInfo; |
| import org.junit.After; |
| import org.junit.Assume; |
| import org.junit.Before; |
| import org.junit.BeforeClass; |
| import org.junit.FixMethodOrder; |
| import org.junit.Test; |
| import org.junit.runners.MethodSorters; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| /** |
| * Port forwarding tests |
| */ |
| @FixMethodOrder(MethodSorters.NAME_ASCENDING) |
| public class PortForwardingLoadTest extends BaseTestSupport { |
| private final Logger log; |
| |
| @SuppressWarnings({ "checkstyle:anoninnerlength", "synthetic-access" }) |
| private final PortForwardingEventListener serverSideListener = new PortForwardingEventListener() { |
| @Override |
| public void establishingExplicitTunnel( |
| org.apache.sshd.common.session.Session session, SshdSocketAddress local, SshdSocketAddress remote, |
| boolean localForwarding) |
| throws IOException { |
| log.info("establishingExplicitTunnel(session={}, local={}, remote={}, localForwarding={})", |
| session, local, remote, localForwarding); |
| } |
| |
| @Override |
| public void establishedExplicitTunnel( |
| org.apache.sshd.common.session.Session session, SshdSocketAddress local, |
| SshdSocketAddress remote, boolean localForwarding, SshdSocketAddress boundAddress, Throwable reason) |
| throws IOException { |
| log.info("establishedExplicitTunnel(session={}, local={}, remote={}, bound={}, localForwarding={}): {}", |
| session, local, remote, boundAddress, localForwarding, reason); |
| } |
| |
| @Override |
| public void tearingDownExplicitTunnel( |
| org.apache.sshd.common.session.Session session, SshdSocketAddress address, boolean localForwarding, |
| SshdSocketAddress remoteAddress) |
| throws IOException { |
| log.info("tearingDownExplicitTunnel(session={}, address={}, localForwarding={}, remote={})", |
| session, address, localForwarding, remoteAddress); |
| } |
| |
| @Override |
| public void tornDownExplicitTunnel( |
| org.apache.sshd.common.session.Session session, SshdSocketAddress address, boolean localForwarding, |
| SshdSocketAddress remoteAddress, Throwable reason) |
| throws IOException { |
| log.info("tornDownExplicitTunnel(session={}, address={}, localForwarding={}, remote={}, reason={})", |
| session, address, localForwarding, remoteAddress, reason); |
| } |
| |
| @Override |
| public void establishingDynamicTunnel( |
| org.apache.sshd.common.session.Session session, SshdSocketAddress local) |
| throws IOException { |
| log.info("establishingDynamicTunnel(session={}, local={})", session, local); |
| } |
| |
| @Override |
| public void establishedDynamicTunnel( |
| org.apache.sshd.common.session.Session session, SshdSocketAddress local, SshdSocketAddress boundAddress, |
| Throwable reason) |
| throws IOException { |
| log.info("establishedDynamicTunnel(session={}, local={}, bound={}, reason={})", session, local, boundAddress, |
| reason); |
| } |
| |
| @Override |
| public void tearingDownDynamicTunnel(org.apache.sshd.common.session.Session session, SshdSocketAddress address) |
| throws IOException { |
| log.info("tearingDownDynamicTunnel(session={}, address={})", session, address); |
| } |
| |
| @Override |
| public void tornDownDynamicTunnel( |
| org.apache.sshd.common.session.Session session, SshdSocketAddress address, Throwable reason) |
| throws IOException { |
| log.info("tornDownDynamicTunnel(session={}, address={}, reason={})", session, address, reason); |
| } |
| }; |
| |
| private SshServer sshd; |
| private int sshPort; |
| private IoAcceptor acceptor; |
| |
| public PortForwardingLoadTest() { |
| log = LoggerFactory.getLogger(getClass()); |
| } |
| |
| @BeforeClass |
| public static void jschInit() { |
| // FIXME inexplicably these tests fail without BC since SSHD-1004 |
| Assume.assumeTrue("Requires BC security provider", SecurityUtils.isBouncyCastleRegistered()); |
| JSchLogger.init(); |
| } |
| |
| @Before |
| public void setUp() throws Exception { |
| sshd = setupTestFullSupportServer(); |
| sshd.setForwardingFilter(AcceptAllForwardingFilter.INSTANCE); |
| sshd.addPortForwardingEventListener(serverSideListener); |
| sshd.start(); |
| sshPort = sshd.getPort(); |
| |
| NioSocketAcceptor acceptor = new NioSocketAcceptor(); |
| acceptor.setHandler(new IoHandlerAdapter() { |
| @Override |
| public void messageReceived(IoSession session, Object message) throws Exception { |
| IoBuffer recv = (IoBuffer) message; |
| IoBuffer sent = IoBuffer.allocate(recv.remaining()); |
| sent.put(recv); |
| sent.flip(); |
| session.write(sent); |
| } |
| }); |
| acceptor.setReuseAddress(true); |
| acceptor.bind(new InetSocketAddress(0)); |
| log.info("setUp() echo address = {}", acceptor.getLocalAddress()); |
| this.acceptor = acceptor; |
| } |
| |
| @After |
| public void tearDown() throws Exception { |
| if (sshd != null) { |
| sshd.stop(true); |
| } |
| if (acceptor != null) { |
| acceptor.dispose(true); |
| } |
| } |
| |
| @Test |
| @SuppressWarnings("checkstyle:nestedtrydepth") |
| public void testLocalForwardingPayload() throws Exception { |
| final int numIterations = 100; |
| final String payloadTmpData = "This is significantly longer Test Data. This is significantly " |
| + "longer Test Data. This is significantly longer Test Data. This is significantly " |
| + "longer Test Data. This is significantly longer Test Data. This is significantly " |
| + "longer Test Data. This is significantly longer Test Data. This is significantly " |
| + "longer Test Data. This is significantly longer Test Data. This is significantly " |
| + "longer Test Data. "; |
| StringBuilder sb = new StringBuilder(payloadTmpData.length() * 1000); |
| for (int i = 0; i < 1000; i++) { |
| sb.append(payloadTmpData); |
| } |
| String payload = sb.toString(); |
| |
| final byte[] dataBytes = payload.getBytes(StandardCharsets.UTF_8); |
| final int reportPhase = dataBytes.length / 10; |
| log.info("{} using payload size={}", getCurrentTestName(), dataBytes.length); |
| |
| AtomicInteger errors = new AtomicInteger(); |
| |
| Session session = createSession(); |
| try (ServerSocket ss = new ServerSocket()) { |
| ss.setReuseAddress(true); |
| ss.bind(new InetSocketAddress((InetAddress) null, 0)); |
| int forwardedPort = ss.getLocalPort(); |
| int sinkPort = session.setPortForwardingL(0, TEST_LOCALHOST, forwardedPort); |
| log.info("{} forwardedPort={}, sinkPort={}", getCurrentTestName(), forwardedPort, sinkPort); |
| |
| AtomicInteger conCount = new AtomicInteger(0); |
| Semaphore iterationsSignal = new Semaphore(0); |
| @SuppressWarnings("checkstyle:anoninnerlength") |
| Thread tAcceptor = new Thread(getCurrentTestName() + "Acceptor") { |
| @SuppressWarnings("synthetic-access") |
| @Override |
| public void run() { |
| try { |
| byte[] buf = new byte[8192]; |
| log.info("Started..."); |
| for (int i = 0; i < numIterations; ++i) { |
| try (Socket s = ss.accept()) { |
| int totalConns = conCount.incrementAndGet(); |
| log.info("Accepted connection #{} from {}", totalConns, s.getRemoteSocketAddress()); |
| |
| try (InputStream sockIn = s.getInputStream(); |
| ByteArrayOutputStream baos = new ByteArrayOutputStream()) { |
| |
| for (int readSize = 0, lastReport = 0; readSize < dataBytes.length;) { |
| int l = sockIn.read(buf); |
| if (l < 0) { |
| break; |
| } |
| |
| baos.write(buf, 0, l); |
| readSize += l; |
| |
| if ((readSize - lastReport) >= reportPhase) { |
| log.info("Read {}/{} bytes of iteration #{}", readSize, dataBytes.length, i); |
| lastReport = readSize; |
| } |
| } |
| |
| assertPayloadEquals("Mismatched received data at iteration #" + i, dataBytes, |
| baos.toByteArray()); |
| |
| byte[] outBytes = baos.toByteArray(); |
| try (InputStream inputCopy = new ByteArrayInputStream(outBytes); |
| OutputStream sockOut = s.getOutputStream()) { |
| |
| for (int writeSize = 0, lastReport = 0; writeSize < outBytes.length;) { |
| int l = inputCopy.read(buf); |
| if (l < 0) { |
| break; |
| } |
| sockOut.write(buf, 0, l); |
| writeSize += l; |
| if ((writeSize - lastReport) >= reportPhase) { |
| log.info("Written {}/{} bytes of iteration #{}", writeSize, dataBytes.length, |
| i); |
| lastReport = writeSize; |
| } |
| } |
| } |
| } |
| } |
| log.info("Finished iteration {}/{}", i, numIterations); |
| iterationsSignal.release(); |
| } |
| log.info("Done"); |
| } catch (Exception e) { |
| log.error("Failed to complete run loop", e); |
| } |
| } |
| }; |
| tAcceptor.start(); |
| Thread.sleep(TimeUnit.SECONDS.toMillis(1L)); |
| |
| byte[] buf = new byte[8192]; |
| for (int i = 0; i < numIterations; i++) { |
| log.debug("Iteration {}/{} started", i, numIterations); |
| try (Socket s = new Socket(TEST_LOCALHOST, sinkPort); |
| OutputStream sockOut = s.getOutputStream()) { |
| |
| log.debug("Iteration {} connected to {}", i, s.getRemoteSocketAddress()); |
| s.setSoTimeout((int) CoreModuleProperties.NIO2_MIN_WRITE_TIMEOUT.getRequiredDefault().toMillis()); |
| |
| sockOut.write(dataBytes); |
| sockOut.flush(); |
| |
| log.debug("Iteration {} awaiting echoed data", i); |
| try (InputStream sockIn = s.getInputStream(); |
| ByteArrayOutputStream baos = new ByteArrayOutputStream(dataBytes.length)) { |
| for (int readSize = 0, lastReport = 0; readSize < dataBytes.length;) { |
| try { |
| int l = sockIn.read(buf); |
| if (l < 0) { |
| break; |
| } |
| |
| baos.write(buf, 0, l); |
| readSize += l; |
| |
| if ((readSize - lastReport) >= reportPhase) { |
| log.debug("Read {}/{} bytes of iteration #{}", readSize, dataBytes.length, i); |
| lastReport = readSize; |
| } |
| } catch (SocketTimeoutException e) { |
| throw new IOException( |
| "Error reading data at index " + readSize + "/" + dataBytes.length + " of iteration #" |
| + i, |
| e); |
| } |
| } |
| assertPayloadEquals("Mismatched payload at iteration #" + i, dataBytes, baos.toByteArray()); |
| } |
| } catch (Exception e) { |
| log.error("Error in iteration #" + i, e); |
| errors.incrementAndGet(); |
| } |
| } |
| |
| try { |
| assertTrue("Failed to await pending iterations=" + numIterations, |
| iterationsSignal.tryAcquire(numIterations, numIterations, TimeUnit.SECONDS)); |
| } finally { |
| log.info("{} remove port forwarding for {}", getCurrentTestName(), sinkPort); |
| session.delPortForwardingL(sinkPort); |
| } |
| |
| ss.close(); |
| log.info("{} awaiting acceptor finish", getCurrentTestName()); |
| tAcceptor.join(TimeUnit.SECONDS.toMillis(11L)); |
| } finally { |
| session.disconnect(); |
| } |
| |
| assertEquals("Some errors occured", 0, errors.get()); |
| } |
| |
| private static void assertPayloadEquals(String message, byte[] expectedBytes, byte[] actualBytes) { |
| assertEquals(message + ": mismatched payload length", expectedBytes.length, actualBytes.length); |
| |
| for (int index = 0; index < expectedBytes.length; index++) { |
| if (expectedBytes[index] == actualBytes[index]) { |
| continue; |
| } |
| |
| int startPos = Math.max(0, index - Byte.SIZE); |
| int endPos = Math.min(startPos + Short.SIZE, expectedBytes.length); |
| if ((endPos - startPos) < Byte.SIZE) { |
| startPos = expectedBytes.length - Byte.SIZE; |
| endPos = expectedBytes.length; |
| } |
| |
| String expected = new String(expectedBytes, startPos, endPos - startPos, StandardCharsets.UTF_8); |
| String actual = new String(actualBytes, startPos, endPos - startPos, StandardCharsets.UTF_8); |
| fail("Mismatched data around offset " + index + ": expected='" + expected + "', actual='" + actual + "'"); |
| } |
| } |
| |
| @Test |
| public void testRemoteForwardingPayload() throws Exception { |
| final int numIterations = 100; |
| final String payload = "This is significantly longer Test Data. This is significantly " |
| + "longer Test Data. This is significantly longer Test Data. This is significantly " |
| + "longer Test Data. This is significantly longer Test Data. This is significantly " |
| + "longer Test Data. This is significantly longer Test Data. This is significantly " |
| + "longer Test Data. "; |
| Session session = createSession(); |
| try (ServerSocket ss = new ServerSocket()) { |
| ss.setReuseAddress(true); |
| ss.bind(new InetSocketAddress((InetAddress) null, 0)); |
| int forwardedPort = ss.getLocalPort(); |
| int sinkPort = CoreTestSupportUtils.getFreePort(); |
| session.setPortForwardingR(sinkPort, TEST_LOCALHOST, forwardedPort); |
| final boolean started[] = new boolean[1]; |
| started[0] = false; |
| final AtomicInteger conCount = new AtomicInteger(0); |
| |
| Thread tWriter = new Thread(getCurrentTestName() + "Writer") { |
| @SuppressWarnings("synthetic-access") |
| @Override |
| public void run() { |
| started[0] = true; |
| try { |
| byte[] bytes = payload.getBytes(StandardCharsets.UTF_8); |
| for (int i = 0; i < numIterations; ++i) { |
| try (Socket s = ss.accept()) { |
| conCount.incrementAndGet(); |
| |
| try (OutputStream sockOut = s.getOutputStream()) { |
| sockOut.write(bytes); |
| sockOut.flush(); |
| } |
| } |
| } |
| } catch (Exception e) { |
| log.error("Failed to complete run loop", e); |
| } |
| } |
| }; |
| tWriter.start(); |
| Thread.sleep(TimeUnit.SECONDS.toMillis(1L)); |
| assertTrue("Server not started", started[0]); |
| |
| final RuntimeException lenOK[] = new RuntimeException[numIterations]; |
| final RuntimeException dataOK[] = new RuntimeException[numIterations]; |
| byte b2[] = new byte[payload.length()]; |
| byte b1[] = new byte[b2.length / 2]; |
| |
| for (int i = 0; i < numIterations; i++) { |
| final int ii = i; |
| try (Socket s = new Socket(TEST_LOCALHOST, sinkPort); |
| InputStream sockIn = s.getInputStream()) { |
| s.setSoTimeout((int) TimeUnit.SECONDS.toMillis(10L)); |
| |
| int read1 = sockIn.read(b1); |
| String part1 = new String(b1, 0, read1, StandardCharsets.UTF_8); |
| Thread.sleep(50); |
| |
| int read2 = sockIn.read(b2); |
| String part2 = new String(b2, 0, read2, StandardCharsets.UTF_8); |
| int totalRead = read1 + read2; |
| lenOK[ii] = (payload.length() == totalRead) |
| ? null |
| : new IndexOutOfBoundsException( |
| "Mismatched length: expected=" + payload.length() + ", actual=" + totalRead); |
| |
| String readData = part1 + part2; |
| dataOK[ii] = payload.equals(readData) ? null : new IllegalStateException("Mismatched content"); |
| if (lenOK[ii] != null) { |
| throw lenOK[ii]; |
| } |
| |
| if (dataOK[ii] != null) { |
| throw dataOK[ii]; |
| } |
| } catch (Exception e) { |
| if (e instanceof IOException) { |
| log.warn("I/O exception in iteration #" + i, e); |
| } else { |
| log.error("Failed to complete iteration #" + i, e); |
| } |
| } |
| } |
| int ok = 0; |
| for (int i = 0; i < numIterations; i++) { |
| ok += (lenOK[i] == null) ? 1 : 0; |
| } |
| log.info("Successful iterations: " + ok + " out of " + numIterations); |
| Thread.sleep(TimeUnit.SECONDS.toMillis(1L)); |
| for (int i = 0; i < numIterations; i++) { |
| assertNull("Bad length at iteration " + i, lenOK[i]); |
| assertNull("Bad data at iteration " + i, dataOK[i]); |
| } |
| Thread.sleep(TimeUnit.SECONDS.toMillis(1L)); |
| session.delPortForwardingR(forwardedPort); |
| ss.close(); |
| tWriter.join(TimeUnit.SECONDS.toMillis(11L)); |
| } finally { |
| session.disconnect(); |
| } |
| } |
| |
| @Test |
| public void testForwardingOnLoad() throws Exception { |
| // final String path = "/history/recent/troubles/"; |
| // final String host = "www.bbc.co.uk"; |
| // final String path = ""; |
| // final String host = "www.bahn.de"; |
| final String path = ""; |
| final String host = TEST_LOCALHOST; |
| final int nbThread = 2; |
| final int nbDownloads = 2; |
| final int nbLoops = 2; |
| |
| StringBuilder resp = new StringBuilder(); |
| resp.append("<html><body>\n"); |
| for (int i = 0; i < 1000; i++) { |
| resp.append("0123456789\n"); |
| } |
| resp.append("</body></html>\n"); |
| |
| StringBuilder sb = new StringBuilder(); |
| sb.append("HTTP/1.1 200 OK").append('\n'); |
| sb.append("Content-Type: text/HTML").append('\n'); |
| sb.append("Content-Length: ").append(resp.length()).append('\n'); |
| sb.append('\n'); |
| sb.append(resp); |
| NioSocketAcceptor acceptor = new NioSocketAcceptor(); |
| acceptor.setHandler(new IoHandlerAdapter() { |
| @Override |
| public void messageReceived(IoSession session, Object message) throws Exception { |
| session.write(IoBuffer.wrap(sb.toString().getBytes(StandardCharsets.UTF_8))); |
| } |
| }); |
| acceptor.setReuseAddress(true); |
| acceptor.bind(new InetSocketAddress(0)); |
| int port = acceptor.getLocalAddress().getPort(); |
| |
| Session session = createSession(); |
| try { |
| int forwardedPort1 = session.setPortForwardingL(0, host, port); |
| int forwardedPort2 = CoreTestSupportUtils.getFreePort(); |
| session.setPortForwardingR(forwardedPort2, TEST_LOCALHOST, forwardedPort1); |
| outputDebugMessage("URL: http://localhost %s", forwardedPort2); |
| |
| CountDownLatch latch = new CountDownLatch(nbThread * nbDownloads * nbLoops); |
| Thread[] threads = new Thread[nbThread]; |
| List<Throwable> errors = new CopyOnWriteArrayList<>(); |
| for (int i = 0; i < threads.length; i++) { |
| threads[i] = new Thread(getCurrentTestName() + "[" + i + "]") { |
| @Override |
| @SuppressWarnings("synthetic-access") |
| public void run() { |
| for (int j = 0; j < nbLoops; j++) { |
| MultiThreadedHttpConnectionManager mgr = new MultiThreadedHttpConnectionManager(); |
| HttpClient client = new HttpClient(mgr); |
| client.getHttpConnectionManager().getParams().setDefaultMaxConnectionsPerHost(100); |
| client.getHttpConnectionManager().getParams().setMaxTotalConnections(1000); |
| for (int i = 0; i < nbDownloads; i++) { |
| try { |
| checkHtmlPage(client, new URL("http://localhost:" + forwardedPort2 + path)); |
| } catch (Throwable e) { |
| errors.add(e); |
| } finally { |
| latch.countDown(); |
| log.debug("Remaining: " + latch.getCount()); |
| } |
| } |
| mgr.shutdown(); |
| } |
| } |
| }; |
| } |
| for (Thread thread : threads) { |
| thread.start(); |
| } |
| latch.await(); |
| for (Throwable t : errors) { |
| log.warn("{}: {}", t.getClass().getSimpleName(), t.getMessage()); |
| } |
| assertEquals(0, errors.size()); |
| } finally { |
| session.disconnect(); |
| } |
| } |
| |
| protected Session createSession() throws JSchException { |
| JSch sch = new JSch(); |
| Session session = sch.getSession("sshd", TEST_LOCALHOST, sshPort); |
| session.setUserInfo(new SimpleUserInfo("sshd")); |
| session.connect(); |
| return session; |
| } |
| |
| protected void checkHtmlPage(HttpClient client, URL url) throws IOException { |
| client.setHostConfiguration(new HostConfiguration()); |
| client.getHostConfiguration().setHost(url.getHost(), url.getPort()); |
| GetMethod get = new GetMethod(""); |
| get.getParams().setVersion(HttpVersion.HTTP_1_1); |
| client.executeMethod(get); |
| String str = get.getResponseBodyAsString(); |
| if (str.indexOf("</html>") <= 0) { |
| System.err.println(str); |
| } |
| assertTrue("Missing HTML close tag", str.indexOf("</html>") > 0); |
| get.releaseConnection(); |
| // url.openConnection().setDefaultUseCaches(false); |
| // Reader reader = new BufferedReader(new InputStreamReader(url.openStream())); |
| // try { |
| // StringWriter sw = new StringWriter(); |
| // char[] buf = new char[8192]; |
| // while (true) { |
| // int len = reader.read(buf); |
| // if (len < 0) { |
| // break; |
| // } |
| // sw.write(buf, 0, len); |
| // } |
| // assertTrue(sw.toString().indexOf("</html>") > 0); |
| // } finally { |
| // reader.close(); |
| // } |
| } |
| } |