| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.sshd; |
| |
| import java.io.ByteArrayOutputStream; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| import java.io.PipedInputStream; |
| import java.io.PipedOutputStream; |
| import java.lang.reflect.InvocationHandler; |
| import java.lang.reflect.Method; |
| import java.lang.reflect.Proxy; |
| import java.nio.charset.StandardCharsets; |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.EnumSet; |
| import java.util.List; |
| import java.util.concurrent.Semaphore; |
| import java.util.concurrent.TimeUnit; |
| import java.util.concurrent.atomic.AtomicBoolean; |
| import java.util.concurrent.atomic.AtomicInteger; |
| |
| import com.jcraft.jsch.JSch; |
| |
| import org.apache.sshd.client.SshClient; |
| import org.apache.sshd.client.channel.ChannelShell; |
| import org.apache.sshd.client.channel.ClientChannel; |
| import org.apache.sshd.client.channel.ClientChannelEvent; |
| import org.apache.sshd.client.session.ClientSession; |
| import org.apache.sshd.common.FactoryManager; |
| import org.apache.sshd.common.NamedFactory; |
| import org.apache.sshd.common.PropertyResolverUtils; |
| import org.apache.sshd.common.channel.Channel; |
| import org.apache.sshd.common.cipher.BuiltinCiphers; |
| import org.apache.sshd.common.future.KeyExchangeFuture; |
| import org.apache.sshd.common.kex.BuiltinDHFactories; |
| import org.apache.sshd.common.kex.KeyExchange; |
| import org.apache.sshd.common.session.Session; |
| import org.apache.sshd.common.session.SessionListener; |
| import org.apache.sshd.common.subsystem.sftp.SftpConstants; |
| import org.apache.sshd.common.util.SecurityUtils; |
| import org.apache.sshd.common.util.io.NullOutputStream; |
| import org.apache.sshd.server.SshServer; |
| import org.apache.sshd.util.test.BaseTestSupport; |
| import org.apache.sshd.util.test.JSchLogger; |
| import org.apache.sshd.util.test.OutputCountTrackingOutputStream; |
| import org.apache.sshd.util.test.SimpleUserInfo; |
| import org.apache.sshd.util.test.TeeOutputStream; |
| import org.junit.After; |
| import org.junit.Assume; |
| import org.junit.BeforeClass; |
| import org.junit.FixMethodOrder; |
| import org.junit.Test; |
| import org.junit.runners.MethodSorters; |
| |
| /** |
| * Test key exchange algorithms. |
| * |
| * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a> |
| */ |
| @FixMethodOrder(MethodSorters.NAME_ASCENDING) |
| public class KeyReExchangeTest extends BaseTestSupport { |
| |
| private SshServer sshd; |
| private int port; |
| |
| public KeyReExchangeTest() { |
| super(); |
| } |
| |
| @BeforeClass |
| public static void jschInit() { |
| JSchLogger.init(); |
| } |
| |
| @After |
| public void tearDown() throws Exception { |
| if (sshd != null) { |
| sshd.stop(true); |
| } |
| } |
| |
| protected void setUp(long bytesLimit, long timeLimit, long packetsLimit) throws Exception { |
| sshd = setupTestServer(); |
| if (bytesLimit > 0L) { |
| PropertyResolverUtils.updateProperty(sshd, FactoryManager.REKEY_BYTES_LIMIT, bytesLimit); |
| } |
| if (timeLimit > 0L) { |
| PropertyResolverUtils.updateProperty(sshd, FactoryManager.REKEY_TIME_LIMIT, timeLimit); |
| } |
| if (packetsLimit > 0L) { |
| PropertyResolverUtils.updateProperty(sshd, FactoryManager.REKEY_PACKETS_LIMIT, packetsLimit); |
| } |
| |
| sshd.start(); |
| port = sshd.getPort(); |
| } |
| |
| @Test |
| public void testSwitchToNoneCipher() throws Exception { |
| setUp(0L, 0L, 0L); |
| |
| sshd.getCipherFactories().add(BuiltinCiphers.none); |
| try (SshClient client = setupTestClient()) { |
| client.getCipherFactories().add(BuiltinCiphers.none); |
| client.start(); |
| |
| try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(7L, TimeUnit.SECONDS).getSession()) { |
| session.addPasswordIdentity(getCurrentTestName()); |
| session.auth().verify(5L, TimeUnit.SECONDS); |
| |
| outputDebugMessage("Request switch to none cipher for %s", session); |
| KeyExchangeFuture switchFuture = session.switchToNoneCipher(); |
| switchFuture.verify(5L, TimeUnit.SECONDS); |
| try (ClientChannel channel = session.createSubsystemChannel(SftpConstants.SFTP_SUBSYSTEM_NAME)) { |
| channel.open().verify(5L, TimeUnit.SECONDS); |
| } |
| } finally { |
| client.stop(); |
| } |
| } |
| } |
| |
| @Test // see SSHD-558 |
| public void testKexFutureExceptionPropagation() throws Exception { |
| setUp(0L, 0L, 0L); |
| sshd.getCipherFactories().add(BuiltinCiphers.none); |
| |
| try (SshClient client = setupTestClient()) { |
| client.getCipherFactories().add(BuiltinCiphers.none); |
| // replace the original KEX factories with wrapped ones that we can fail intentionally |
| List<NamedFactory<KeyExchange>> kexFactories = new ArrayList<>(); |
| final AtomicBoolean successfulInit = new AtomicBoolean(true); |
| final AtomicBoolean successfulNext = new AtomicBoolean(true); |
| final ClassLoader loader = getClass().getClassLoader(); |
| final Class<?>[] interfaces = {KeyExchange.class}; |
| for (final NamedFactory<KeyExchange> factory : client.getKeyExchangeFactories()) { |
| kexFactories.add(new NamedFactory<KeyExchange>() { |
| @Override |
| public String getName() { |
| return factory.getName(); |
| } |
| |
| @Override |
| public KeyExchange create() { |
| final KeyExchange proxiedInstance = factory.create(); |
| return (KeyExchange) Proxy.newProxyInstance(loader, interfaces, new InvocationHandler() { |
| @Override |
| public Object invoke(Object proxy, Method method, Object[] args) throws Throwable { |
| String name = method.getName(); |
| if ("init".equals(name) && (!successfulInit.get())) { |
| throw new UnsupportedOperationException("Intentionally failing 'init'"); |
| } else if ("next".equals(name) && (!successfulNext.get())) { |
| throw new UnsupportedOperationException("Intentionally failing 'next'"); |
| } else { |
| return method.invoke(proxiedInstance, args); |
| } |
| } |
| }); |
| } |
| }); |
| } |
| client.setKeyExchangeFactories(kexFactories); |
| client.start(); |
| |
| try { |
| try { |
| testKexFutureExceptionPropagation("init", successfulInit, client); |
| } finally { |
| successfulInit.set(true); |
| } |
| |
| try { |
| testKexFutureExceptionPropagation("next", successfulNext, client); |
| } finally { |
| successfulNext.set(true); |
| } |
| } finally { |
| client.stop(); |
| } |
| } |
| } |
| |
| private void testKexFutureExceptionPropagation(String failureType, AtomicBoolean successFlag, SshClient client) throws Exception { |
| try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(7L, TimeUnit.SECONDS).getSession()) { |
| session.addPasswordIdentity(getCurrentTestName()); |
| session.auth().verify(5L, TimeUnit.SECONDS); |
| |
| successFlag.set(false); |
| KeyExchangeFuture kexFuture = session.switchToNoneCipher(); |
| assertTrue(failureType + ": failed to complete KEX on time", kexFuture.await(7L, TimeUnit.SECONDS)); |
| assertNotNull(failureType + ": unexpected success", kexFuture.getException()); |
| } |
| } |
| |
| @Test |
| public void testReExchangeFromJschClient() throws Exception { |
| Assume.assumeTrue("DH Group Exchange not supported", SecurityUtils.isDHGroupExchangeSupported()); |
| setUp(0L, 0L, 0L); |
| |
| JSch.setConfig("kex", BuiltinDHFactories.Constants.DIFFIE_HELLMAN_GROUP_EXCHANGE_SHA1); |
| JSch sch = new JSch(); |
| com.jcraft.jsch.Session s = sch.getSession(getCurrentTestName(), TEST_LOCALHOST, port); |
| try { |
| s.setUserInfo(new SimpleUserInfo(getCurrentTestName())); |
| s.connect(); |
| |
| com.jcraft.jsch.Channel c = s.openChannel(Channel.CHANNEL_SHELL); |
| c.connect(); |
| try (OutputStream os = c.getOutputStream(); |
| InputStream is = c.getInputStream()) { |
| |
| String expected = "this is my command\n"; |
| byte[] bytes = expected.getBytes(StandardCharsets.UTF_8); |
| byte[] data = new byte[bytes.length + Long.SIZE]; |
| for (int i = 1; i <= 10; i++) { |
| os.write(bytes); |
| os.flush(); |
| |
| int len = is.read(data); |
| String str = new String(data, 0, len); |
| assertEquals("Mismatched data at iteration " + i, expected, str); |
| |
| outputDebugMessage("Request re-key #%d", i); |
| s.rekey(); |
| } |
| } finally { |
| c.disconnect(); |
| } |
| } finally { |
| s.disconnect(); |
| } |
| } |
| |
| @Test |
| public void testReExchangeFromSshdClient() throws Exception { |
| setUp(0L, 0L, 0L); |
| |
| try (SshClient client = setupTestClient()) { |
| client.start(); |
| |
| try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(7L, TimeUnit.SECONDS).getSession()) { |
| session.addPasswordIdentity(getCurrentTestName()); |
| session.auth().verify(5L, TimeUnit.SECONDS); |
| |
| final Semaphore pipedCount = new Semaphore(0, true); |
| try (ChannelShell channel = session.createShellChannel(); |
| ByteArrayOutputStream sent = new ByteArrayOutputStream(); |
| PipedOutputStream pipedIn = new PipedOutputStream(); |
| InputStream inPipe = new PipedInputStream(pipedIn); |
| OutputStream teeOut = new TeeOutputStream(sent, pipedIn); |
| ByteArrayOutputStream out = new ByteArrayOutputStream() { |
| private long writeCount; |
| |
| @Override |
| public void write(int b) { |
| super.write(b); |
| updateWriteCount(1L); |
| pipedCount.release(1); |
| } |
| |
| @Override |
| public void write(byte[] b, int off, int len) { |
| super.write(b, off, len); |
| updateWriteCount(len); |
| pipedCount.release(len); |
| } |
| |
| private void updateWriteCount(long delta) { |
| writeCount += delta; |
| outputDebugMessage("OUT write count=%d", writeCount); |
| } |
| }; |
| ByteArrayOutputStream err = new ByteArrayOutputStream()) { |
| |
| channel.setIn(inPipe); |
| channel.setOut(out); |
| channel.setErr(err); |
| channel.open(); |
| |
| teeOut.write("this is my command\n".getBytes(StandardCharsets.UTF_8)); |
| teeOut.flush(); |
| |
| StringBuilder sb = new StringBuilder(Byte.MAX_VALUE); |
| for (int i = 0; i < 10; i++) { |
| sb.append("0123456789"); |
| } |
| sb.append('\n'); |
| |
| byte[] data = sb.toString().getBytes(StandardCharsets.UTF_8); |
| for (int i = 1; i <= 10; i++) { |
| teeOut.write(data); |
| teeOut.flush(); |
| |
| KeyExchangeFuture kexFuture = session.reExchangeKeys(); |
| assertTrue("Failed to complete KEX on time at iteration " + i, kexFuture.await(5L, TimeUnit.SECONDS)); |
| assertNull("KEX exception signalled at iteration " + i, kexFuture.getException()); |
| } |
| teeOut.write("exit\n".getBytes(StandardCharsets.UTF_8)); |
| teeOut.flush(); |
| |
| Collection<ClientChannelEvent> result = |
| channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), TimeUnit.SECONDS.toMillis(15L)); |
| assertFalse("Timeout while waiting for channel closure", result.contains(ClientChannelEvent.TIMEOUT)); |
| |
| byte[] expected = sent.toByteArray(); |
| if (!pipedCount.tryAcquire(expected.length, 13L, TimeUnit.SECONDS)) { |
| fail("Failed to await sent data signal for len=" + expected.length + " (available=" + pipedCount.availablePermits() + ")"); |
| } |
| |
| assertArrayEquals("Mismatched sent data content", expected, out.toByteArray()); |
| } |
| } finally { |
| client.stop(); |
| } |
| } |
| } |
| |
| @Test |
| public void testReExchangeFromServerBySize() throws Exception { |
| final long bytesLImit = 10 * 1024L; |
| setUp(bytesLImit, 0L, 0L); |
| |
| try (SshClient client = setupTestClient()) { |
| client.start(); |
| |
| final Semaphore pipedCount = new Semaphore(0, true); |
| try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(7L, TimeUnit.SECONDS).getSession(); |
| ByteArrayOutputStream sent = new ByteArrayOutputStream(); |
| ByteArrayOutputStream out = new ByteArrayOutputStream() { |
| private long writeCount; |
| |
| @Override |
| public void write(int b) { |
| super.write(b); |
| updateWriteCount(1L); |
| pipedCount.release(1); |
| } |
| |
| @Override |
| public void write(byte[] b, int off, int len) { |
| super.write(b, off, len); |
| updateWriteCount(len); |
| pipedCount.release(len); |
| } |
| |
| private void updateWriteCount(long delta) { |
| writeCount += delta; |
| outputDebugMessage("OUT write count=%d", writeCount); |
| } |
| }) { |
| session.addPasswordIdentity(getCurrentTestName()); |
| session.auth().verify(5L, TimeUnit.SECONDS); |
| |
| byte[] sentData; |
| try (ChannelShell channel = session.createShellChannel(); |
| PipedOutputStream pipedIn = new PipedOutputStream(); |
| OutputStream teeOut = new TeeOutputStream(sent, pipedIn); |
| OutputStream err = new NullOutputStream(); |
| InputStream inPipe = new PipedInputStream(pipedIn)) { |
| |
| channel.setIn(inPipe); |
| channel.setOut(out); |
| channel.setErr(err); |
| channel.open(); |
| |
| teeOut.write("this is my command\n".getBytes(StandardCharsets.UTF_8)); |
| teeOut.flush(); |
| |
| StringBuilder sb = new StringBuilder(101 * 10); |
| for (int i = 0; i < 100; i++) { |
| sb.append("0123456789"); |
| } |
| sb.append('\n'); |
| |
| final AtomicInteger exchanges = new AtomicInteger(); |
| session.addSessionListener(new SessionListener() { |
| @Override |
| public void sessionCreated(Session session) { |
| // ignored |
| } |
| |
| @Override |
| public void sessionEvent(Session session, Event event) { |
| if (Event.KeyEstablished.equals(event)) { |
| int count = exchanges.incrementAndGet(); |
| outputDebugMessage("Key established for %s - count=%d", session, count); |
| } |
| } |
| |
| @Override |
| public void sessionException(Session session, Throwable t) { |
| // ignored |
| } |
| |
| @Override |
| public void sessionClosed(Session session) { |
| // ignored |
| } |
| }); |
| |
| byte[] data = sb.toString().getBytes(StandardCharsets.UTF_8); |
| for (long sentSize = 0L; sentSize < (bytesLImit + Byte.MAX_VALUE + data.length); sentSize += data.length) { |
| teeOut.write(data); |
| teeOut.flush(); |
| // no need to wait until the limit is reached if a re-key occurred |
| if (exchanges.get() > 0) { |
| outputDebugMessage("Stop sending after %d bytes - exchanges=%s", sentSize + data.length, exchanges); |
| break; |
| } |
| } |
| |
| teeOut.write("exit\n".getBytes(StandardCharsets.UTF_8)); |
| teeOut.flush(); |
| |
| Collection<ClientChannelEvent> result = |
| channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), TimeUnit.SECONDS.toMillis(15L)); |
| assertFalse("Timeout while waiting for channel closure", result.contains(ClientChannelEvent.TIMEOUT)); |
| |
| sentData = sent.toByteArray(); |
| if (!pipedCount.tryAcquire(sentData.length, 13L, TimeUnit.SECONDS)) { |
| fail("Failed to await sent data signal for len=" + sentData.length + " (available=" + pipedCount.availablePermits() + ")"); |
| } |
| assertTrue("Expected rekeying", exchanges.get() > 0); |
| } |
| |
| byte[] outData = out.toByteArray(); |
| assertEquals("Mismatched sent data length", sentData.length, outData.length); |
| assertArrayEquals("Mismatched sent data content", sentData, outData); |
| } finally { |
| client.stop(); |
| } |
| } |
| } |
| |
| @Test |
| public void testReExchangeFromServerByTime() throws Exception { |
| final long timeLimit = TimeUnit.SECONDS.toMillis(2L); |
| setUp(0L, timeLimit, 0L); |
| |
| try (SshClient client = setupTestClient()) { |
| client.start(); |
| |
| final Semaphore pipedCount = new Semaphore(0, true); |
| try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(7L, TimeUnit.SECONDS).getSession(); |
| ByteArrayOutputStream sent = new ByteArrayOutputStream(); |
| ByteArrayOutputStream out = new ByteArrayOutputStream() { |
| private long writeCount; |
| |
| @Override |
| public void write(int b) { |
| super.write(b); |
| updateWriteCount(1L); |
| pipedCount.release(1); |
| } |
| |
| @Override |
| public void write(byte[] b, int off, int len) { |
| super.write(b, off, len); |
| updateWriteCount(len); |
| pipedCount.release(len); |
| } |
| |
| private void updateWriteCount(long delta) { |
| writeCount += delta; |
| outputDebugMessage("OUT write count=%d", writeCount); |
| } |
| }) { |
| session.addPasswordIdentity(getCurrentTestName()); |
| session.auth().verify(5L, TimeUnit.SECONDS); |
| |
| byte[] sentData; |
| try (ChannelShell channel = session.createShellChannel(); |
| PipedOutputStream pipedIn = new PipedOutputStream(); |
| OutputStream teeOut = new TeeOutputStream(sent, pipedIn); |
| OutputStream err = new NullOutputStream(); |
| InputStream inPipe = new PipedInputStream(pipedIn)) { |
| |
| channel.setIn(inPipe); |
| channel.setOut(out); |
| channel.setErr(err); |
| channel.open(); |
| |
| teeOut.write("this is my command\n".getBytes(StandardCharsets.UTF_8)); |
| teeOut.flush(); |
| |
| StringBuilder sb = new StringBuilder(101 * 10); |
| for (int i = 0; i < 100; i++) { |
| sb.append("0123456789"); |
| } |
| sb.append('\n'); |
| |
| final AtomicInteger exchanges = new AtomicInteger(); |
| session.addSessionListener(new SessionListener() { |
| @Override |
| public void sessionCreated(Session session) { |
| // ignored |
| } |
| |
| @Override |
| public void sessionEvent(Session session, Event event) { |
| if (Event.KeyEstablished.equals(event)) { |
| int count = exchanges.incrementAndGet(); |
| outputDebugMessage("Key established for %s - count=%d", session, count); |
| } |
| } |
| |
| @Override |
| public void sessionException(Session session, Throwable t) { |
| // ignored |
| } |
| |
| @Override |
| public void sessionClosed(Session session) { |
| // ignored |
| } |
| }); |
| |
| byte[] data = getCurrentTestName().getBytes(StandardCharsets.UTF_8); |
| final long maxWaitNanos = TimeUnit.MILLISECONDS.toNanos(3L * timeLimit); |
| final long minWaitValue = 10L; |
| final long minWaitNanos = TimeUnit.MILLISECONDS.toNanos(minWaitValue); |
| for (long timePassed = 0L, sentSize = 0L; timePassed < maxWaitNanos; timePassed++) { |
| long nanoStart = System.nanoTime(); |
| teeOut.write(data); |
| teeOut.write('\n'); |
| teeOut.flush(); |
| |
| long nanoEnd = System.nanoTime(); |
| long nanoDuration = nanoEnd - nanoStart; |
| |
| timePassed += nanoDuration; |
| sentSize += data.length + 1; |
| |
| // no need to wait until the timeout expires if a re-key occurred |
| if (exchanges.get() > 0) { |
| outputDebugMessage("Stop sending after %d nanos and size=%d - exchanges=%s", |
| timePassed, sentSize, exchanges); |
| break; |
| } |
| |
| if ((timePassed < maxWaitNanos) && (nanoDuration < minWaitNanos)) { |
| Thread.sleep(minWaitValue); |
| } |
| } |
| |
| teeOut.write("exit\n".getBytes(StandardCharsets.UTF_8)); |
| teeOut.flush(); |
| |
| Collection<ClientChannelEvent> result = |
| channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), TimeUnit.SECONDS.toMillis(15L)); |
| assertFalse("Timeout while waiting for channel closure", result.contains(ClientChannelEvent.TIMEOUT)); |
| |
| sentData = sent.toByteArray(); |
| if (!pipedCount.tryAcquire(sentData.length, 13L, TimeUnit.SECONDS)) { |
| fail("Failed to await sent data signal for len=" + sentData.length + " (available=" + pipedCount.availablePermits() + ")"); |
| } |
| |
| assertTrue("Expected rekeying", exchanges.get() > 0); |
| } |
| |
| byte[] outData = out.toByteArray(); |
| assertEquals("Mismatched sent data length", sentData.length, outData.length); |
| assertArrayEquals("Mismatched sent data content", sentData, outData); |
| } finally { |
| client.stop(); |
| } |
| } |
| } |
| |
| @Test // see SSHD-601 |
| public void testReExchangeFromServerByPackets() throws Exception { |
| final int packetsLimit = 135; |
| setUp(0L, 0L, packetsLimit); |
| |
| try (SshClient client = setupTestClient()) { |
| client.start(); |
| |
| final Semaphore pipedCount = new Semaphore(0, true); |
| try (ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port).verify(7L, TimeUnit.SECONDS).getSession(); |
| ByteArrayOutputStream sent = new ByteArrayOutputStream(); |
| ByteArrayOutputStream out = new ByteArrayOutputStream() { |
| private long writeCount; |
| |
| @Override |
| public void write(int b) { |
| super.write(b); |
| updateWriteCount(1L); |
| pipedCount.release(1); |
| } |
| |
| @Override |
| public void write(byte[] b, int off, int len) { |
| super.write(b, off, len); |
| updateWriteCount(len); |
| pipedCount.release(len); |
| } |
| |
| private void updateWriteCount(long delta) { |
| writeCount += delta; |
| outputDebugMessage("OUT write count=%d", writeCount); |
| } |
| }) { |
| session.addPasswordIdentity(getCurrentTestName()); |
| session.auth().verify(5L, TimeUnit.SECONDS); |
| |
| byte[] sentData; |
| try (ChannelShell channel = session.createShellChannel(); |
| PipedOutputStream pipedIn = new PipedOutputStream(); |
| OutputStream sentTracker = new OutputCountTrackingOutputStream(sent) { |
| @Override |
| protected long updateWriteCount(long delta) { |
| long result = super.updateWriteCount(delta); |
| outputDebugMessage("SENT write count=%d", result); |
| return result; |
| } |
| }; |
| OutputStream teeOut = new TeeOutputStream(sentTracker, pipedIn); |
| OutputStream stderr = new NullOutputStream(); |
| OutputStream stdout = new OutputCountTrackingOutputStream(out) { |
| @Override |
| protected long updateWriteCount(long delta) { |
| long result = super.updateWriteCount(delta); |
| outputDebugMessage("OUT write count=%d", result); |
| return result; |
| } |
| }; |
| InputStream inPipe = new PipedInputStream(pipedIn)) { |
| |
| channel.setIn(inPipe); |
| channel.setOut(stdout); |
| channel.setErr(stderr); |
| channel.open(); |
| |
| teeOut.write("this is my command\n".getBytes(StandardCharsets.UTF_8)); |
| teeOut.flush(); |
| |
| final AtomicInteger exchanges = new AtomicInteger(); |
| session.addSessionListener(new SessionListener() { |
| @Override |
| public void sessionCreated(Session session) { |
| // ignored |
| } |
| |
| @Override |
| public void sessionEvent(Session session, Event event) { |
| if (Event.KeyEstablished.equals(event)) { |
| int count = exchanges.incrementAndGet(); |
| outputDebugMessage("Key established for %s - count=%d", session, count); |
| } |
| } |
| |
| @Override |
| public void sessionException(Session session, Throwable t) { |
| // ignored |
| } |
| |
| @Override |
| public void sessionClosed(Session session) { |
| // ignored |
| } |
| }); |
| |
| byte[] data = (getClass().getName() + "#" + getCurrentTestName() + "\n").getBytes(StandardCharsets.UTF_8); |
| for (int index = 0; index < (packetsLimit * 2); index++) { |
| teeOut.write(data); |
| teeOut.flush(); |
| |
| // no need to wait until the packets limit is reached if a re-key occurred |
| if (exchanges.get() > 0) { |
| outputDebugMessage("Stop sending after %d packets and %d bytes - exchanges=%s", |
| index + 11L, (index + 1L) * data.length, exchanges); |
| break; |
| } |
| } |
| |
| teeOut.write("exit\n".getBytes(StandardCharsets.UTF_8)); |
| teeOut.flush(); |
| |
| Collection<ClientChannelEvent> result = |
| channel.waitFor(EnumSet.of(ClientChannelEvent.CLOSED), TimeUnit.SECONDS.toMillis(15L)); |
| assertFalse("Timeout while waiting for channel closure", result.contains(ClientChannelEvent.TIMEOUT)); |
| |
| sentData = sent.toByteArray(); |
| if (!pipedCount.tryAcquire(sentData.length, 13L, TimeUnit.SECONDS)) { |
| fail("Failed to await sent data signal for len=" + sentData.length + " (available=" + pipedCount.availablePermits() + ")"); |
| } |
| |
| assertTrue("Expected rekeying", exchanges.get() > 0); |
| } |
| |
| byte[] outData = out.toByteArray(); |
| assertEquals("Mismatched sent data length", sentData.length, outData.length); |
| assertArrayEquals("Mismatched sent data content", sentData, outData); |
| } finally { |
| client.stop(); |
| } |
| } |
| } |
| } |