blob: 1fab5ce402f56e203050cd43abd0880f29896ce7 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.metron.rest.config;
import static junit.framework.TestCase.assertNull;
import static junit.framework.TestCase.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.interfaces.RSAPublicKey;
import java.text.ParseException;
import java.util.Arrays;
import java.util.Date;
import javax.servlet.FilterChain;
import javax.servlet.ServletResponse;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.Test;
import org.springframework.ldap.core.AttributesMapper;
import org.springframework.ldap.core.LdapTemplate;
import org.springframework.ldap.query.LdapQuery;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContext;
/**
 * Unit tests for {@link KnoxSSOAuthenticationFilter}.
 *
 * <p>Most tests wrap the filter in a Mockito spy so that individual steps (cookie extraction,
 * JWT parsing, validation, authentication lookup, security-context access) can be stubbed
 * independently of the method under test.
 */
public class KnoxSSOAuthenticationFilterTest {

  /**
   * Creates a spied filter with the default test configuration and a mocked {@link LdapTemplate}.
   *
   * @return a Mockito spy wrapping a filter configured with search base {@code "userSearchBase"}
   */
  private static KnoxSSOAuthenticationFilter spyFilter() {
    return spyFilter("userSearchBase", mock(LdapTemplate.class));
  }

  /**
   * Creates a spied filter with a mocked key path, key string {@code "knoxKeyString"} and cookie
   * name {@code "knoxCookie"}.
   *
   * @param userSearchBase LDAP user search base passed to the filter
   * @param ldapTemplate LDAP template passed to the filter (the constructor rejects {@code null})
   * @return a Mockito spy wrapping the new filter so individual methods can be stubbed
   */
  private static KnoxSSOAuthenticationFilter spyFilter(String userSearchBase, LdapTemplate ldapTemplate) {
    return spy(new KnoxSSOAuthenticationFilter(userSearchBase,
        mock(Path.class),
        "knoxKeyString",
        "knoxCookie",
        ldapTemplate
    ));
  }

  @Test
  public void shouldThrowExceptionOnMissingLdapTemplate() {
    IllegalStateException e =
        assertThrows(
            IllegalStateException.class,
            () ->
                new KnoxSSOAuthenticationFilter(
                    "userSearchBase", mock(Path.class), "knoxKeyString", "knoxCookie", null));
    assertEquals("KnoxSSO requires LDAP. You must add 'ldap' to the active profiles.", e.getMessage());
  }

  @Test
  public void doFilterShouldProperlySetAuthentication() throws Exception {
    KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter = spyFilter();
    HttpServletRequest request = mock(HttpServletRequest.class);
    ServletResponse response = mock(ServletResponse.class);
    FilterChain chain = mock(FilterChain.class);
    SignedJWT signedJWT = mock(SignedJWT.class);
    JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().subject("userName").build();
    Authentication authentication = mock(Authentication.class);
    SecurityContext securityContext = mock(SecurityContext.class);

    when(request.getHeader("Authorization")).thenReturn(null);
    doReturn("serializedJWT").when(knoxSSOAuthenticationFilter).getJWTFromCookie(request);
    doReturn(signedJWT).when(knoxSSOAuthenticationFilter).parseJWT(any());
    when(signedJWT.getJWTClaimsSet()).thenReturn(jwtClaimsSet);
    doReturn(true).when(knoxSSOAuthenticationFilter).isValid(signedJWT, "userName");
    doReturn(authentication).when(knoxSSOAuthenticationFilter).getAuthentication("userName", request);
    doReturn(securityContext).when(knoxSSOAuthenticationFilter).getSecurityContext();

    knoxSSOAuthenticationFilter.doFilter(request, response, chain);

    verify(securityContext).setAuthentication(authentication);
    verify(chain).doFilter(request, response);
    verifyNoMoreInteractions(chain, securityContext);
  }

  @Test
  public void doFilterShouldContinueOnBasicAuthenticationHeader() throws Exception {
    KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter = spyFilter();
    HttpServletRequest request = mock(HttpServletRequest.class);
    ServletResponse response = mock(ServletResponse.class);
    FilterChain chain = mock(FilterChain.class);

    when(request.getHeader("Authorization")).thenReturn("Basic ");

    knoxSSOAuthenticationFilter.doFilter(request, response, chain);

    // With a Basic auth header the Knox cookie must not be consulted at all.
    verify(knoxSSOAuthenticationFilter, times(0)).getJWTFromCookie(request);
    verify(chain).doFilter(request, response);
    verifyNoMoreInteractions(chain);
  }

  @Test
  public void doFilterShouldContinueOnParseException() throws Exception {
    KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter = spyFilter();
    HttpServletRequest request = mock(HttpServletRequest.class);
    ServletResponse response = mock(ServletResponse.class);
    FilterChain chain = mock(FilterChain.class);

    when(request.getHeader("Authorization")).thenReturn(null);
    doReturn("serializedJWT").when(knoxSSOAuthenticationFilter).getJWTFromCookie(request);
    doThrow(new ParseException("parse exception", 0)).when(knoxSSOAuthenticationFilter).parseJWT(any());

    knoxSSOAuthenticationFilter.doFilter(request, response, chain);

    // An unparseable token yields no user name, so no authentication lookup may happen for
    // any user (matching specific arguments here would let other calls slip through).
    verify(knoxSSOAuthenticationFilter, times(0)).getAuthentication(any(), any());
    verify(chain).doFilter(request, response);
    verifyNoMoreInteractions(chain);
  }

  @Test
  public void doFilterShouldContinueOnInvalidToken() throws Exception {
    KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter = spyFilter();
    HttpServletRequest request = mock(HttpServletRequest.class);
    ServletResponse response = mock(ServletResponse.class);
    FilterChain chain = mock(FilterChain.class);
    SignedJWT signedJWT = mock(SignedJWT.class);
    JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().subject("userName").build();

    when(request.getHeader("Authorization")).thenReturn(null);
    doReturn("serializedJWT").when(knoxSSOAuthenticationFilter).getJWTFromCookie(request);
    doReturn(signedJWT).when(knoxSSOAuthenticationFilter).parseJWT(any());
    when(signedJWT.getJWTClaimsSet()).thenReturn(jwtClaimsSet);
    doReturn(false).when(knoxSSOAuthenticationFilter).isValid(signedJWT, "userName");

    knoxSSOAuthenticationFilter.doFilter(request, response, chain);

    // An invalid token must not trigger an authentication lookup for any user.
    verify(knoxSSOAuthenticationFilter, times(0)).getAuthentication(any(), any());
    verify(chain).doFilter(request, response);
    verifyNoMoreInteractions(chain);
  }

  @Test
  public void isValidShouldProperlyValidateToken() throws Exception {
    KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter = spyFilter();
    SignedJWT jwtToken = mock(SignedJWT.class);
    {
      // Should be invalid on a null user name
      assertFalse(knoxSSOAuthenticationFilter.isValid(jwtToken, null));
    }
    {
      // Should be invalid on expired token
      Date expiredDate = new Date(System.currentTimeMillis() - 10000);
      JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().expirationTime(expiredDate).build();
      when(jwtToken.getJWTClaimsSet()).thenReturn(jwtClaimsSet);
      assertFalse(knoxSSOAuthenticationFilter.isValid(jwtToken, "userName"));
    }
    {
      // Should be invalid when the current time is before notBeforeTime
      Date notBeforeDate = new Date(System.currentTimeMillis() + 10000);
      JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().notBeforeTime(notBeforeDate).build();
      when(jwtToken.getJWTClaimsSet()).thenReturn(jwtClaimsSet);
      assertFalse(knoxSSOAuthenticationFilter.isValid(jwtToken, "userName"));
    }
    {
      // Should be valid if user name is present and token is within time constraints
      Date expiredDate = new Date(System.currentTimeMillis() + 10000);
      Date notBeforeDate = new Date(System.currentTimeMillis() - 10000);
      JWTClaimsSet jwtClaimsSet = new JWTClaimsSet.Builder().expirationTime(expiredDate).notBeforeTime(notBeforeDate).build();
      when(jwtToken.getJWTClaimsSet()).thenReturn(jwtClaimsSet);
      doReturn(true).when(knoxSSOAuthenticationFilter).validateSignature(jwtToken);
      assertTrue(knoxSSOAuthenticationFilter.isValid(jwtToken, "userName"));
    }
  }

  @Test
  public void validateSignatureShouldProperlyValidateToken() throws Exception {
    KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter = spyFilter();
    SignedJWT jwtToken = mock(SignedJWT.class);
    {
      // Should be invalid if algorithm is not RS256
      JWSHeader jwsHeader = new JWSHeader(JWSAlgorithm.ES384);
      when(jwtToken.getHeader()).thenReturn(jwsHeader);
      assertFalse(knoxSSOAuthenticationFilter.validateSignature(jwtToken));
    }
    {
      // Should be invalid if state is not SIGNED
      JWSHeader jwsHeader = new JWSHeader(JWSAlgorithm.RS256);
      when(jwtToken.getHeader()).thenReturn(jwsHeader);
      when(jwtToken.getState()).thenReturn(JWSObject.State.UNSIGNED);
      assertFalse(knoxSSOAuthenticationFilter.validateSignature(jwtToken));
    }
    {
      // Should be invalid if signature is null (the mock returns null until stubbed below)
      JWSHeader jwsHeader = new JWSHeader(JWSAlgorithm.RS256);
      when(jwtToken.getHeader()).thenReturn(jwsHeader);
      when(jwtToken.getState()).thenReturn(JWSObject.State.SIGNED);
      assertFalse(knoxSSOAuthenticationFilter.validateSignature(jwtToken));
    }
    {
      Base64URL signature = mock(Base64URL.class);
      when(jwtToken.getSignature()).thenReturn(signature);
      RSASSAVerifier rsaSSAVerifier = mock(RSASSAVerifier.class);
      doReturn(rsaSSAVerifier).when(knoxSSOAuthenticationFilter).getRSASSAVerifier();
      {
        // Should be invalid if token verify throws an exception
        when(jwtToken.verify(rsaSSAVerifier)).thenThrow(new JOSEException("verify exception"));
        assertFalse(knoxSSOAuthenticationFilter.validateSignature(jwtToken));
      }
      {
        // Should be invalid if RSA verification fails
        doReturn(false).when(jwtToken).verify(rsaSSAVerifier);
        assertFalse(knoxSSOAuthenticationFilter.validateSignature(jwtToken));
      }
      {
        // Should be valid if RSA verification succeeds
        doReturn(true).when(jwtToken).verify(rsaSSAVerifier);
        assertTrue(knoxSSOAuthenticationFilter.validateSignature(jwtToken));
      }
    }
  }

  @Test
  public void getJWTFromCookieShouldProperlyReturnToken() {
    KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter = spyFilter();
    HttpServletRequest request = mock(HttpServletRequest.class);
    {
      // Should be null if the request has no cookies (the mock's getCookies() returns null)
      assertNull(knoxSSOAuthenticationFilter.getJWTFromCookie(request));
    }
    {
      // Should be null if the cookie array is empty
      when(request.getCookies()).thenReturn(new Cookie[0]);
      assertNull(knoxSSOAuthenticationFilter.getJWTFromCookie(request));
    }
    {
      // Should be null if Knox cookie is missing
      Cookie cookie = new Cookie("someCookie", "someValue");
      when(request.getCookies()).thenReturn(new Cookie[]{cookie});
      assertNull(knoxSSOAuthenticationFilter.getJWTFromCookie(request));
    }
    {
      // Should return token from knoxCookie
      Cookie cookie = new Cookie("knoxCookie", "token");
      when(request.getCookies()).thenReturn(new Cookie[]{cookie});
      assertEquals("token", knoxSSOAuthenticationFilter.getJWTFromCookie(request));
    }
  }

  @Test
  public void getKnoxKeyShouldProperlyReturnKnoxKey() throws Exception {
    {
      // The configured key string is returned directly
      KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter = spyFilter();
      assertEquals("knoxKeyString", knoxSSOAuthenticationFilter.getKnoxKey());
    }
    {
      // Without a key string, the key is read from the configured key file
      Path knoxKeyPath = Paths.get("./target/knoxKeyFile");
      File knoxKeyFile = knoxKeyPath.toFile();
      try {
        FileUtils.writeStringToFile(knoxKeyFile, "knoxKeyFileKeyString", StandardCharsets.UTF_8);
        KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter = spy(new KnoxSSOAuthenticationFilter("userSearchBase",
            knoxKeyPath,
            null,
            "knoxCookie",
            mock(LdapTemplate.class)
        ));
        assertEquals("knoxKeyFileKeyString", knoxSSOAuthenticationFilter.getKnoxKey());
      } finally {
        // Remove the fixture so it does not leak into later runs or other tests.
        FileUtils.deleteQuietly(knoxKeyFile);
      }
    }
  }

  // Unchecked: any(AttributesMapper.class) matches the raw AttributesMapper type.
  @SuppressWarnings("unchecked")
  @Test
  public void getAuthenticationShouldProperlyPopulateAuthentication() {
    LdapTemplate ldapTemplate = mock(LdapTemplate.class);
    KnoxSSOAuthenticationFilter knoxSSOAuthenticationFilter =
        spyFilter("ou=people,dc=hadoop,dc=apache,dc=org", ldapTemplate);
    HttpServletRequest request = mock(HttpServletRequest.class);

    when(ldapTemplate.search(any(LdapQuery.class), any(AttributesMapper.class))).thenReturn(Arrays.asList("USER", "ADMIN"));

    Authentication authentication = knoxSSOAuthenticationFilter.getAuthentication("userName", request);

    // Roles returned by the LDAP search are expected to be exposed with a "ROLE_" prefix.
    Object[] grantedAuthorities = authentication.getAuthorities().toArray();
    assertEquals("ROLE_USER", grantedAuthorities[0].toString());
    assertEquals("ROLE_ADMIN", grantedAuthorities[1].toString());
    assertEquals("userName", authentication.getName());
  }
}