| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sshd.client; |
| |
| import java.io.IOException; |
| import java.lang.reflect.Method; |
| import java.security.KeyPair; |
| import java.security.PrivateKey; |
| import java.security.PublicKey; |
| import java.util.List; |
| import java.util.concurrent.atomic.AtomicReference; |
| |
| import org.apache.sshd.client.auth.AuthenticationIdentitiesProvider; |
| import org.apache.sshd.client.auth.BuiltinUserAuthFactories; |
| import org.apache.sshd.client.auth.UserAuth; |
| import org.apache.sshd.client.auth.keyboard.UserInteraction; |
| import org.apache.sshd.client.auth.password.PasswordIdentityProvider; |
| import org.apache.sshd.client.keyverifier.ServerKeyVerifier; |
| import org.apache.sshd.client.session.ClientSession; |
| import org.apache.sshd.client.session.ClientSessionImpl; |
| import org.apache.sshd.common.Factory; |
| import org.apache.sshd.common.NamedFactory; |
| import org.apache.sshd.common.NamedResource; |
| import org.apache.sshd.common.channel.Channel; |
| import org.apache.sshd.common.channel.ChannelListener; |
| import org.apache.sshd.common.forward.DefaultTcpipForwarderFactory; |
| import org.apache.sshd.common.io.IoSession; |
| import org.apache.sshd.common.io.IoWriteFuture; |
| import org.apache.sshd.common.keyprovider.KeyPairProvider; |
| import org.apache.sshd.common.random.JceRandomFactory; |
| import org.apache.sshd.common.random.Random; |
| import org.apache.sshd.common.random.SingletonRandomFactory; |
| import org.apache.sshd.common.session.Session; |
| import org.apache.sshd.common.session.SessionListener; |
| import org.apache.sshd.common.util.GenericUtils; |
| import org.apache.sshd.util.test.BaseTestSupport; |
| import org.junit.FixMethodOrder; |
| import org.junit.Test; |
| import org.junit.runners.MethodSorters; |
| import org.mockito.Mockito; |
| |
| /** |
| * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a> |
| */ |
| @FixMethodOrder(MethodSorters.NAME_ASCENDING) |
| public class ClientAuthenticationManagerTest extends BaseTestSupport { |
| public ClientAuthenticationManagerTest() { |
| super(); |
| } |
| |
| @Test |
| public void testDefaultUserAuthFactoriesMethods() { |
| AtomicReference<List<NamedFactory<UserAuth>>> factoriesHolder = new AtomicReference<>(); |
| @SuppressWarnings("checkstyle:anoninnerlength") |
| ClientAuthenticationManager manager = new ClientAuthenticationManager() { |
| @Override |
| public List<NamedFactory<UserAuth>> getUserAuthFactories() { |
| return factoriesHolder.get(); |
| } |
| |
| @Override |
| public void setUserAuthFactories(List<NamedFactory<UserAuth>> userAuthFactories) { |
| assertNull("Unexpected multiple invocation", factoriesHolder.getAndSet(userAuthFactories)); |
| } |
| |
| @Override |
| public KeyPairProvider getKeyPairProvider() { |
| return null; |
| } |
| |
| @Override |
| public void setKeyPairProvider(KeyPairProvider keyPairProvider) { |
| throw new UnsupportedOperationException("setKeyPairProvider(" + keyPairProvider + ")"); |
| } |
| |
| @Override |
| public UserInteraction getUserInteraction() { |
| return null; |
| } |
| |
| @Override |
| public void setUserInteraction(UserInteraction userInteraction) { |
| throw new UnsupportedOperationException("setUserInteraction(" + userInteraction + ")"); |
| } |
| |
| @Override |
| public ServerKeyVerifier getServerKeyVerifier() { |
| return null; |
| } |
| |
| @Override |
| public void setServerKeyVerifier(ServerKeyVerifier serverKeyVerifier) { |
| throw new UnsupportedOperationException("setServerKeyVerifier(" + serverKeyVerifier + ")"); |
| } |
| |
| @Override |
| public PasswordIdentityProvider getPasswordIdentityProvider() { |
| return null; |
| } |
| |
| @Override |
| public void setPasswordIdentityProvider(PasswordIdentityProvider provider) { |
| throw new UnsupportedOperationException("setPasswordIdentityProvider(" + provider + ")"); |
| } |
| |
| @Override |
| public AuthenticationIdentitiesProvider getRegisteredIdentities() { |
| return null; |
| } |
| |
| @Override |
| public void addPublicKeyIdentity(KeyPair key) { |
| throw new UnsupportedOperationException("addPublicKeyIdentity(" + key + ")"); |
| } |
| |
| @Override |
| public KeyPair removePublicKeyIdentity(KeyPair kp) { |
| throw new UnsupportedOperationException("removePublicKeyIdentity(" + kp + ")"); |
| } |
| |
| @Override |
| public void addPasswordIdentity(String password) { |
| throw new UnsupportedOperationException("addPasswordIdentity(" + password + ")"); |
| } |
| |
| @Override |
| public String removePasswordIdentity(String password) { |
| throw new UnsupportedOperationException("removePasswordIdentity(" + password + ")"); |
| } |
| }; |
| assertEquals("Mismatched initial factories list", "", manager.getUserAuthFactoriesNameList()); |
| |
| String expected = NamedResource.Utils.getNames(BuiltinUserAuthFactories.VALUES); |
| manager.setUserAuthFactoriesNameList(expected); |
| assertEquals("Mismatched updated factories names", expected, manager.getUserAuthFactoriesNameList()); |
| |
| List<NamedFactory<UserAuth>> factories = factoriesHolder.get(); |
| assertEquals("Mismatched factories count", BuiltinUserAuthFactories.VALUES.size(), GenericUtils.size(factories)); |
| for (BuiltinUserAuthFactories f : BuiltinUserAuthFactories.VALUES) { |
| assertTrue("Missing factory=" + f.name(), factories.contains(f.create())); |
| } |
| } |
| |
| @Test |
| public void testAddRemoveClientSessionIdentities() throws Exception { |
| try (ClientSession session = createMockClientSession()) { |
| testClientAuthenticationManager(session); |
| } |
| } |
| |
| @Test |
| public void testAddRemoveSshClientIdentities() throws Exception { |
| try (SshClient client = SshClient.setUpDefaultClient()) { |
| testClientAuthenticationManager(client); |
| } |
| } |
| |
| @Test |
| public void testClientProvidersPropagation() throws Exception { |
| try (SshClient client = SshClient.setUpDefaultClient()) { |
| client.setServiceFactories(SshClient.DEFAULT_SERVICE_FACTORIES); |
| client.setUserAuthFactories(SshClient.DEFAULT_USER_AUTH_FACTORIES); |
| |
| try (ClientSession session = createMockClientSession(client)) { |
| for (Class<?> provider : new Class<?>[] { |
| PasswordIdentityProvider.class, |
| ServerKeyVerifier.class, |
| UserInteraction.class, |
| KeyPairProvider.class |
| }) { |
| testClientProvidersPropagation(provider, client, session); |
| } |
| } |
| } |
| } |
| |
| private void testClientProvidersPropagation(Class<?> type, ClientAuthenticationManager client, ClientAuthenticationManager session) throws Exception { |
| String baseName = type.getSimpleName(); |
| outputDebugMessage("testClientProvidersPropagation(%s)", baseName); |
| assertTrue(baseName + ": not an interface", type.isInterface()); |
| |
| Method getter = ClientAuthenticationManager.class.getMethod("get" + baseName); |
| Method setter = ClientAuthenticationManager.class.getMethod("set" + baseName, type); |
| Object clientProvider = Mockito.mock(type); |
| setter.invoke(client, clientProvider); |
| assertSame(baseName + ": mismatched client-only provider", clientProvider, getter.invoke(session)); |
| |
| Object sessionProvider = Mockito.mock(type); |
| setter.invoke(session, sessionProvider); |
| assertSame(baseName + ": mismatched session override provider", sessionProvider, getter.invoke(session)); |
| |
| setter.invoke(session, new Object[]{null}); |
| assertSame(baseName + ": mismatched nullified session provider", clientProvider, getter.invoke(session)); |
| } |
| |
| private <M extends ClientAuthenticationManager> M testClientAuthenticationManager(M manager) { |
| if (manager != null) { |
| String expected = getCurrentTestName(); |
| assertNull("Unexpected initial password identity", manager.removePasswordIdentity(expected)); |
| manager.addPasswordIdentity(expected); |
| |
| String actual = manager.removePasswordIdentity(expected); |
| assertSame("Mismatched removed password identity", expected, actual); |
| assertNull("Password identity not removed", manager.removePasswordIdentity(expected)); |
| } |
| |
| if (manager != null) { |
| KeyPair expected = new KeyPair(Mockito.mock(PublicKey.class), Mockito.mock(PrivateKey.class)); |
| assertNull("Unexpected initial pubket identity", manager.removePublicKeyIdentity(expected)); |
| manager.addPublicKeyIdentity(expected); |
| |
| KeyPair actual = manager.removePublicKeyIdentity(expected); |
| assertSame("Mismatched removed pubkey identity", expected, actual); |
| assertNull("Pubkey identity not removed", manager.removePublicKeyIdentity(expected)); |
| } |
| |
| return manager; |
| } |
| |
| private ClientSession createMockClientSession() throws Exception { |
| ClientFactoryManager client = Mockito.mock(ClientFactoryManager.class); |
| Mockito.when(client.getTcpipForwarderFactory()).thenReturn(DefaultTcpipForwarderFactory.INSTANCE); |
| Mockito.when(client.getSessionListenerProxy()).thenReturn(new SessionListener() { |
| @Override |
| public void sessionEvent(Session session, Event event) { |
| // ignored |
| } |
| |
| @Override |
| public void sessionCreated(Session session) { |
| // ignored |
| } |
| |
| @Override |
| public void sessionException(Session session, Throwable t) { |
| // ignored |
| } |
| |
| @Override |
| public void sessionClosed(Session session) { |
| // ignored |
| } |
| }); |
| Mockito.when(client.getChannelListenerProxy()).thenReturn(new ChannelListener() { |
| @Override |
| public void channelOpenSuccess(Channel channel) { |
| // ignored |
| } |
| |
| @Override |
| public void channelOpenFailure(Channel channel, Throwable reason) { |
| // ignored |
| } |
| |
| @Override |
| public void channelInitialized(Channel channel) { |
| // ignored |
| } |
| |
| @Override |
| public void channelStateChanged(Channel channel, String hint) { |
| // ignored |
| } |
| |
| @Override |
| public void channelClosed(Channel channel, Throwable reason) { |
| // ignored |
| } |
| }); |
| Factory<Random> randomFactory = new SingletonRandomFactory(JceRandomFactory.INSTANCE); |
| Mockito.when(client.getRandomFactory()).thenReturn(randomFactory); |
| |
| Mockito.when(client.getServiceFactories()).thenReturn(SshClient.DEFAULT_SERVICE_FACTORIES); |
| Mockito.when(client.getUserAuthFactories()).thenReturn(SshClient.DEFAULT_USER_AUTH_FACTORIES); |
| return createMockClientSession(client); |
| } |
| |
| private ClientSession createMockClientSession(ClientFactoryManager client) throws Exception { |
| return new ClientSessionImpl(client, Mockito.mock(IoSession.class)) { |
| @Override |
| protected IoWriteFuture sendClientIdentification() { |
| return null; |
| } |
| |
| @Override |
| protected byte[] sendKexInit() throws IOException { |
| return GenericUtils.EMPTY_BYTE_ARRAY; |
| } |
| |
| @Override |
| public void close() throws IOException { |
| // ignored |
| } |
| }; |
| } |
| } |