blob: 0d08497869cdb223788c6d49a0cefb9287e6ff3c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sshd.client;
import java.io.IOException;
import java.lang.reflect.Method;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.sshd.client.auth.AuthenticationIdentitiesProvider;
import org.apache.sshd.client.auth.BuiltinUserAuthFactories;
import org.apache.sshd.client.auth.UserAuth;
import org.apache.sshd.client.auth.keyboard.UserInteraction;
import org.apache.sshd.client.auth.password.PasswordIdentityProvider;
import org.apache.sshd.client.keyverifier.ServerKeyVerifier;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.client.session.ClientSessionImpl;
import org.apache.sshd.common.Factory;
import org.apache.sshd.common.NamedFactory;
import org.apache.sshd.common.NamedResource;
import org.apache.sshd.common.channel.Channel;
import org.apache.sshd.common.channel.ChannelListener;
import org.apache.sshd.common.forward.DefaultTcpipForwarderFactory;
import org.apache.sshd.common.io.IoSession;
import org.apache.sshd.common.io.IoWriteFuture;
import org.apache.sshd.common.keyprovider.KeyPairProvider;
import org.apache.sshd.common.random.JceRandomFactory;
import org.apache.sshd.common.random.Random;
import org.apache.sshd.common.random.SingletonRandomFactory;
import org.apache.sshd.common.session.Session;
import org.apache.sshd.common.session.SessionListener;
import org.apache.sshd.common.util.GenericUtils;
import org.apache.sshd.util.test.BaseTestSupport;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runners.MethodSorters;
import org.mockito.Mockito;
/**
 * Tests for {@link ClientAuthenticationManager}. It covers:
 * <ul>
 * <li>the default user-auth factory name-list methods;</li>
 * <li>adding and removing password and public-key identities on both {@link SshClient}
 * and {@link ClientSession};</li>
 * <li>the client-to-session fallback of authentication-related providers.</li>
 * </ul>
 *
 * @author <a href="mailto:dev@mina.apache.org">Apache MINA SSHD Project</a>
 */
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class ClientAuthenticationManagerTest extends BaseTestSupport {
    public ClientAuthenticationManagerTest() {
        super();
    }

    /**
     * Checks the name-list convenience methods of {@link ClientAuthenticationManager}.
     * The name list must start out empty. After the names of all
     * {@link BuiltinUserAuthFactories} are set, {@code setUserAuthFactories} must have been
     * called no more than once, with a list holding every built-in factory.
     */
    @Test
    public void testDefaultUserAuthFactoriesMethods() {
        AtomicReference<List<NamedFactory<UserAuth>>> factoriesHolder = new AtomicReference<>();
        // Only the user-auth factories accessors keep state. Every other getter returns null
        // and every other mutator throws, so any unexpected call fails the test.
        @SuppressWarnings("checkstyle:anoninnerlength")
        ClientAuthenticationManager manager = new ClientAuthenticationManager() {
            @Override
            public List<NamedFactory<UserAuth>> getUserAuthFactories() {
                return factoriesHolder.get();
            }

            @Override
            public void setUserAuthFactories(List<NamedFactory<UserAuth>> userAuthFactories) {
                // A second call would find a non-null previous value and fail here
                assertNull("Unexpected multiple invocation", factoriesHolder.getAndSet(userAuthFactories));
            }

            @Override
            public KeyPairProvider getKeyPairProvider() {
                return null;
            }

            @Override
            public void setKeyPairProvider(KeyPairProvider keyPairProvider) {
                throw new UnsupportedOperationException("setKeyPairProvider(" + keyPairProvider + ")");
            }

            @Override
            public UserInteraction getUserInteraction() {
                return null;
            }

            @Override
            public void setUserInteraction(UserInteraction userInteraction) {
                throw new UnsupportedOperationException("setUserInteraction(" + userInteraction + ")");
            }

            @Override
            public ServerKeyVerifier getServerKeyVerifier() {
                return null;
            }

            @Override
            public void setServerKeyVerifier(ServerKeyVerifier serverKeyVerifier) {
                throw new UnsupportedOperationException("setServerKeyVerifier(" + serverKeyVerifier + ")");
            }

            @Override
            public PasswordIdentityProvider getPasswordIdentityProvider() {
                return null;
            }

            @Override
            public void setPasswordIdentityProvider(PasswordIdentityProvider provider) {
                throw new UnsupportedOperationException("setPasswordIdentityProvider(" + provider + ")");
            }

            @Override
            public AuthenticationIdentitiesProvider getRegisteredIdentities() {
                return null;
            }

            @Override
            public void addPublicKeyIdentity(KeyPair key) {
                throw new UnsupportedOperationException("addPublicKeyIdentity(" + key + ")");
            }

            @Override
            public KeyPair removePublicKeyIdentity(KeyPair kp) {
                throw new UnsupportedOperationException("removePublicKeyIdentity(" + kp + ")");
            }

            @Override
            public void addPasswordIdentity(String password) {
                throw new UnsupportedOperationException("addPasswordIdentity(" + password + ")");
            }

            @Override
            public String removePasswordIdentity(String password) {
                throw new UnsupportedOperationException("removePasswordIdentity(" + password + ")");
            }
        };

        assertEquals("Mismatched initial factories list", "", manager.getUserAuthFactoriesNameList());

        String expected = NamedResource.Utils.getNames(BuiltinUserAuthFactories.VALUES);
        manager.setUserAuthFactoriesNameList(expected);
        assertEquals("Mismatched updated factories names", expected, manager.getUserAuthFactoriesNameList());

        List<NamedFactory<UserAuth>> factories = factoriesHolder.get();
        assertEquals("Mismatched factories count", BuiltinUserAuthFactories.VALUES.size(), GenericUtils.size(factories));
        for (BuiltinUserAuthFactories f : BuiltinUserAuthFactories.VALUES) {
            assertTrue("Missing factory=" + f.name(), factories.contains(f.create()));
        }
    }

    /**
     * Runs the identity add/remove checks against a {@link ClientSession} built on top of a
     * mocked {@link ClientFactoryManager}.
     *
     * @throws Exception if the session cannot be created or closed
     */
    @Test
    public void testAddRemoveClientSessionIdentities() throws Exception {
        try (ClientSession session = createMockClientSession()) {
            testClientAuthenticationManager(session);
        }
    }

    /**
     * Runs the identity add/remove checks against a default {@link SshClient}.
     *
     * @throws Exception if the client cannot be set up or closed
     */
    @Test
    public void testAddRemoveSshClientIdentities() throws Exception {
        try (SshClient client = SshClient.setUpDefaultClient()) {
            testClientAuthenticationManager(client);
        }
    }

    /**
     * Checks, for each authentication-related provider type, that a session created from a
     * real {@link SshClient} falls back to the client's provider when it has none of its own.
     *
     * @throws Exception if client/session setup or the reflective checks fail
     */
    @Test
    public void testClientProvidersPropagation() throws Exception {
        try (SshClient client = SshClient.setUpDefaultClient()) {
            client.setServiceFactories(SshClient.DEFAULT_SERVICE_FACTORIES);
            client.setUserAuthFactories(SshClient.DEFAULT_USER_AUTH_FACTORIES);

            try (ClientSession session = createMockClientSession(client)) {
                for (Class<?> provider : new Class<?>[] {
                    PasswordIdentityProvider.class,
                    ServerKeyVerifier.class,
                    UserInteraction.class,
                    KeyPairProvider.class
                }) {
                    testClientProvidersPropagation(provider, client, session);
                }
            }
        }
    }

    /**
     * Checks provider fallback for one provider type. It uses reflection to locate the
     * {@code get<Type>}/{@code set<Type>} accessor pair on {@link ClientAuthenticationManager},
     * then verifies three things in order:
     * <ol>
     * <li>a provider set on the client is returned by the session;</li>
     * <li>a provider set on the session overrides the client's one;</li>
     * <li>resetting the session provider to {@code null} exposes the client's one again.</li>
     * </ol>
     *
     * @param type the provider interface; its simple name selects the accessor pair
     * @param client the manager whose provider acts as the fallback
     * @param session the manager under test, expected to fall back to {@code client}
     * @throws Exception if the reflective lookup or invocation fails
     */
    private void testClientProvidersPropagation(Class<?> type, ClientAuthenticationManager client, ClientAuthenticationManager session) throws Exception {
        String baseName = type.getSimpleName();
        outputDebugMessage("testClientProvidersPropagation(%s)", baseName);
        assertTrue(baseName + ": not an interface", type.isInterface());

        Method getter = ClientAuthenticationManager.class.getMethod("get" + baseName);
        Method setter = ClientAuthenticationManager.class.getMethod("set" + baseName, type);

        Object clientProvider = Mockito.mock(type);
        setter.invoke(client, clientProvider);
        assertSame(baseName + ": mismatched client-only provider", clientProvider, getter.invoke(session));

        Object sessionProvider = Mockito.mock(type);
        setter.invoke(session, sessionProvider);
        assertSame(baseName + ": mismatched session override provider", sessionProvider, getter.invoke(session));

        // Wrap null in an array so it is passed as the single argument to the setter.
        // A bare null would be taken as the varargs array itself.
        setter.invoke(session, new Object[]{null});
        assertSame(baseName + ": mismatched nullified session provider", clientProvider, getter.invoke(session));
    }

    /**
     * Adds and removes a password identity and a public-key identity on the given manager.
     * For each identity it asserts three things:
     * <ul>
     * <li>it is not initially present;</li>
     * <li>removing it after adding returns the very same instance;</li>
     * <li>removing it a second time returns {@code null}.</li>
     * </ul>
     * A {@code null} manager is silently ignored.
     *
     * @param <M> the concrete manager type
     * @param manager the manager to exercise - may be {@code null}
     * @return the same {@code manager} instance
     */
    private <M extends ClientAuthenticationManager> M testClientAuthenticationManager(M manager) {
        if (manager == null) {
            return manager;
        }

        String password = getCurrentTestName();
        assertNull("Unexpected initial password identity", manager.removePasswordIdentity(password));
        manager.addPasswordIdentity(password);
        String removedPassword = manager.removePasswordIdentity(password);
        assertSame("Mismatched removed password identity", password, removedPassword);
        assertNull("Password identity not removed", manager.removePasswordIdentity(password));

        KeyPair keyPair = new KeyPair(Mockito.mock(PublicKey.class), Mockito.mock(PrivateKey.class));
        assertNull("Unexpected initial pubkey identity", manager.removePublicKeyIdentity(keyPair));
        manager.addPublicKeyIdentity(keyPair);
        KeyPair removedKeyPair = manager.removePublicKeyIdentity(keyPair);
        assertSame("Mismatched removed pubkey identity", keyPair, removedKeyPair);
        assertNull("Pubkey identity not removed", manager.removePublicKeyIdentity(keyPair));

        return manager;
    }

    /**
     * Creates a session backed by a Mockito-mocked {@link ClientFactoryManager}. The mock
     * stubs these collaborators:
     * <ul>
     * <li>the TCP/IP forwarder factory;</li>
     * <li>no-op session and channel listener proxies;</li>
     * <li>a singleton JCE random factory;</li>
     * <li>the default service and user-auth factories.</li>
     * </ul>
     * NOTE(review): presumably these are the collaborators the {@link ClientSessionImpl}
     * constructor queries - confirm if that constructor changes.
     *
     * @return a session whose identification/KEX sends and close are no-ops
     * @throws Exception if the session cannot be constructed
     */
    private ClientSession createMockClientSession() throws Exception {
        ClientFactoryManager client = Mockito.mock(ClientFactoryManager.class);
        Mockito.when(client.getTcpipForwarderFactory()).thenReturn(DefaultTcpipForwarderFactory.INSTANCE);
        Mockito.when(client.getSessionListenerProxy()).thenReturn(new SessionListener() {
            @Override
            public void sessionEvent(Session session, Event event) {
                // ignored
            }

            @Override
            public void sessionCreated(Session session) {
                // ignored
            }

            @Override
            public void sessionException(Session session, Throwable t) {
                // ignored
            }

            @Override
            public void sessionClosed(Session session) {
                // ignored
            }
        });
        Mockito.when(client.getChannelListenerProxy()).thenReturn(new ChannelListener() {
            @Override
            public void channelOpenSuccess(Channel channel) {
                // ignored
            }

            @Override
            public void channelOpenFailure(Channel channel, Throwable reason) {
                // ignored
            }

            @Override
            public void channelInitialized(Channel channel) {
                // ignored
            }

            @Override
            public void channelStateChanged(Channel channel, String hint) {
                // ignored
            }

            @Override
            public void channelClosed(Channel channel, Throwable reason) {
                // ignored
            }
        });

        Factory<Random> randomFactory = new SingletonRandomFactory(JceRandomFactory.INSTANCE);
        Mockito.when(client.getRandomFactory()).thenReturn(randomFactory);
        Mockito.when(client.getServiceFactories()).thenReturn(SshClient.DEFAULT_SERVICE_FACTORIES);
        Mockito.when(client.getUserAuthFactories()).thenReturn(SshClient.DEFAULT_USER_AUTH_FACTORIES);
        return createMockClientSession(client);
    }

    /**
     * Creates a {@link ClientSessionImpl} over a mocked {@link IoSession}, with these overrides:
     * <ul>
     * <li>sending the client identification returns {@code null};</li>
     * <li>sending KEX-init returns an empty byte array;</li>
     * <li>{@link ClientSession#close()} does nothing.</li>
     * </ul>
     *
     * @param client the factory manager the session is created for
     * @return the stubbed session
     * @throws Exception if the session cannot be constructed
     */
    private ClientSession createMockClientSession(ClientFactoryManager client) throws Exception {
        return new ClientSessionImpl(client, Mockito.mock(IoSession.class)) {
            @Override
            protected IoWriteFuture sendClientIdentification() {
                return null;
            }

            @Override
            protected byte[] sendKexInit() throws IOException {
                return GenericUtils.EMPTY_BYTE_ARRAY;
            }

            @Override
            public void close() throws IOException {
                // ignored
            }
        };
    }
}