blob: 7d7781ef2aeacd896cd344404762afc117cd7b06 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.util
import org.apache.http.conn.ssl.DefaultHostnameVerifier
import org.glassfish.jersey.client.ClientConfig
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.mockito.Mockito
import sun.security.tools.keytool.CertAndKeyGen
import sun.security.x509.X500Name
import javax.net.ssl.HostnameVerifier
import javax.net.ssl.SSLContext
import javax.net.ssl.SSLPeerUnverifiedException
import javax.servlet.http.HttpServletRequest
import javax.ws.rs.client.Client
import javax.ws.rs.core.UriBuilderException
import java.security.cert.X509Certificate
@RunWith(JUnit4.class)
class WebUtilsGroovyTest extends GroovyTestCase {
static final String PCP_HEADER = "X-ProxyContextPath"
static final String FC_HEADER = "X-Forwarded-Context"
static final String FP_HEADER = "X-Forwarded-Prefix"
static final String ALLOWED_PATH = "/some/context/path"
HttpServletRequest mockRequest(Map keys) {
HttpServletRequest mockRequest = [
getContextPath: { ->
"default/path"
},
getHeader : { String k ->
switch (k) {
case PCP_HEADER:
return keys["proxy"]
break
case FC_HEADER:
return keys["forward"]
break
case FP_HEADER:
return keys["prefix"]
break
default:
return ""
}
}] as HttpServletRequest
mockRequest
}
@Test
void testShouldDetermineCorrectContextPathWhenPresent() throws Exception {
// Arrange
final String CORRECT_CONTEXT_PATH = ALLOWED_PATH
final String WRONG_CONTEXT_PATH = "this/is/a/bad/path"
// Variety of requests with different ordering of context paths (the correct one is always "some/context/path"
HttpServletRequest proxyRequest = mockRequest([proxy: CORRECT_CONTEXT_PATH])
HttpServletRequest forwardedRequest = mockRequest([forward: CORRECT_CONTEXT_PATH])
HttpServletRequest prefixRequest = mockRequest([prefix: CORRECT_CONTEXT_PATH])
HttpServletRequest proxyBeforeForwardedRequest = mockRequest([proxy: CORRECT_CONTEXT_PATH, forward: WRONG_CONTEXT_PATH])
HttpServletRequest proxyBeforePrefixRequest = mockRequest([proxy: CORRECT_CONTEXT_PATH, prefix: WRONG_CONTEXT_PATH])
HttpServletRequest forwardBeforePrefixRequest = mockRequest([forward: CORRECT_CONTEXT_PATH, prefix: WRONG_CONTEXT_PATH])
List<HttpServletRequest> requests = [proxyRequest, forwardedRequest, prefixRequest, proxyBeforeForwardedRequest,
proxyBeforePrefixRequest, forwardBeforePrefixRequest]
// Act
requests.each { HttpServletRequest request ->
String determinedContextPath = WebUtils.determineContextPath(request)
// Assert
assert determinedContextPath == CORRECT_CONTEXT_PATH
}
}
@Test
void testShouldDetermineCorrectContextPathWhenAbsent() throws Exception {
// Arrange
final String CORRECT_CONTEXT_PATH = ""
// Variety of requests with different ordering of non-existent context paths (the correct one is always ""
HttpServletRequest proxyRequest = mockRequest([proxy: ""])
HttpServletRequest proxySpacesRequest = mockRequest([proxy: " "])
HttpServletRequest forwardedRequest = mockRequest([forward: ""])
HttpServletRequest forwardedSpacesRequest = mockRequest([forward: " "])
HttpServletRequest prefixRequest = mockRequest([prefix: ""])
HttpServletRequest prefixSpacesRequest = mockRequest([prefix: " "])
HttpServletRequest proxyBeforeForwardedOrPrefixRequest = mockRequest([proxy: "", forward: "", prefix: ""])
HttpServletRequest proxyBeforeForwardedOrPrefixSpacesRequest = mockRequest([proxy: " ", forward: " ", prefix: " "])
List<HttpServletRequest> requests = [proxyRequest, proxySpacesRequest, forwardedRequest, forwardedSpacesRequest, prefixRequest, prefixSpacesRequest,
proxyBeforeForwardedOrPrefixRequest, proxyBeforeForwardedOrPrefixSpacesRequest]
// Act
requests.each { HttpServletRequest request ->
String determinedContextPath = WebUtils.determineContextPath(request)
// Assert
assert determinedContextPath == CORRECT_CONTEXT_PATH
}
}
@Test
void testShouldNormalizeContextPath() throws Exception {
// Arrange
final String CORRECT_CONTEXT_PATH = ALLOWED_PATH
final String TRIMMED_PATH = ALLOWED_PATH[1..-1] // Trims leading /
// Variety of different context paths (the correct one is always "/some/context/path")
List<String> contextPaths = ["/$TRIMMED_PATH", "/" + TRIMMED_PATH, TRIMMED_PATH, TRIMMED_PATH + "/"]
// Act
contextPaths.each { String contextPath ->
String normalizedContextPath = WebUtils.normalizeContextPath(contextPath)
// Assert
assert normalizedContextPath == CORRECT_CONTEXT_PATH
}
}
@Test
void testVerifyContextPathShouldAllowContextPathHeaderIfInAllowList() throws Exception {
WebUtils.verifyContextPath(Arrays.asList(ALLOWED_PATH), ALLOWED_PATH)
}
@Test
void testVerifyContextPathShouldAllowContextPathHeaderIfInMultipleAllowLists() throws Exception {
WebUtils.verifyContextPath(Arrays.asList(ALLOWED_PATH, ALLOWED_PATH.reverse()), ALLOWED_PATH)
}
@Test
void testVerifyContextPathShouldAllowContextPathHeaderIfBlank() throws Exception {
def emptyContextPaths = ["", " ", "\t", null]
emptyContextPaths.each { String contextPath ->
WebUtils.verifyContextPath(Arrays.asList(ALLOWED_PATH), contextPath)
}
}
@Test
void testVerifyContextPathShouldBlockContextPathHeaderIfNotAllowed() throws Exception {
def invalidContextPaths = ["/other/path", "localhost", "/../trying/to/escape"]
invalidContextPaths.each { String contextPath ->
shouldFail(UriBuilderException) {
WebUtils.verifyContextPath(Arrays.asList(ALLOWED_PATH), contextPath)
}
}
}
@Test
void testHostnameVerifierType() {
// Arrange
SSLContext sslContext = Mockito.mock(SSLContext.class)
final ClientConfig clientConfig = new ClientConfig()
// Act
Client client = WebUtils.createClient(clientConfig, sslContext)
HostnameVerifier hostnameVerifier = client.getHostnameVerifier()
// Assert
assertTrue(hostnameVerifier instanceof DefaultHostnameVerifier)
}
@Test
void testHostnameVerifierWildcard() {
// Arrange
final String EXPECTED_DN = "CN=*.apache.com,OU=Security,O=Apache,ST=CA,C=US"
final String hostname = "nifi.apache.com"
X509Certificate cert = generateCertificate(EXPECTED_DN)
SSLContext sslContext = Mockito.mock(SSLContext.class)
final ClientConfig clientConfig = new ClientConfig()
// Act
Client client = WebUtils.createClient(clientConfig, sslContext)
DefaultHostnameVerifier hostnameVerifier = (DefaultHostnameVerifier) client.getHostnameVerifier()
// Verify
hostnameVerifier.verify(hostname, cert)
}
@Test
void testHostnameVerifierDNWildcardFourthLevelDomain() {
// Arrange
final String EXPECTED_DN = "CN=*.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
final String clientHostname = "client.nifi.apache.org"
final String serverHostname = "server.nifi.apache.org"
X509Certificate cert = generateCertificate(EXPECTED_DN)
SSLContext sslContext = Mockito.mock(SSLContext.class)
final ClientConfig clientConfig = new ClientConfig()
// Act
Client client = WebUtils.createClient(clientConfig, sslContext)
DefaultHostnameVerifier hostnameVerifier = client.getHostnameVerifier()
// Verify
hostnameVerifier.verify(clientHostname, cert)
hostnameVerifier.verify(serverHostname, cert)
}
@Test(expected = SSLPeerUnverifiedException)
void testHostnameVerifierDomainLevelMismatch() {
// Arrange
final String EXPECTED_DN = "CN=*.nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
final String hostname = "nifi.apache.org"
X509Certificate cert = generateCertificate(EXPECTED_DN)
SSLContext sslContext = Mockito.mock(SSLContext.class)
final ClientConfig clientConfig = new ClientConfig()
// Act
Client client = WebUtils.createClient(clientConfig, sslContext)
DefaultHostnameVerifier hostnameVerifier = client.getHostnameVerifier()
// Verify
hostnameVerifier.verify(hostname, cert)
}
@Test(expected = SSLPeerUnverifiedException)
void testHostnameVerifierEmptyHostname() {
// Arrange
final String EXPECTED_DN = "CN=nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
final String hostname = ""
X509Certificate cert = generateCertificate(EXPECTED_DN)
SSLContext sslContext = Mockito.mock(SSLContext.class)
final ClientConfig clientConfig = new ClientConfig()
// Act
Client client = WebUtils.createClient(clientConfig, sslContext)
DefaultHostnameVerifier hostnameVerifier = client.getHostnameVerifier()
// Verify
hostnameVerifier.verify(hostname, cert)
}
@Test(expected = SSLPeerUnverifiedException)
void testHostnameVerifierDifferentSubdomain() {
// Arrange
final String EXPECTED_DN = "CN=nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
final String hostname = "egg.apache.org"
X509Certificate cert = generateCertificate(EXPECTED_DN)
SSLContext sslContext = Mockito.mock(SSLContext.class)
final ClientConfig clientConfig = new ClientConfig()
// Act
Client client = WebUtils.createClient(clientConfig, sslContext)
DefaultHostnameVerifier hostnameVerifier = client.getHostnameVerifier()
// Verify
hostnameVerifier.verify(hostname, cert)
}
@Test(expected = SSLPeerUnverifiedException)
void testHostnameVerifierDifferentTLD() {
// Arrange
final String EXPECTED_DN = "CN=nifi.apache.org,OU=Security,O=Apache,ST=CA,C=US"
final String hostname = "nifi.apache.com"
X509Certificate cert = generateCertificate(EXPECTED_DN)
SSLContext sslContext = Mockito.mock(SSLContext.class)
final ClientConfig clientConfig = new ClientConfig()
// Act
Client client = WebUtils.createClient(clientConfig, sslContext)
DefaultHostnameVerifier hostnameVerifier = client.getHostnameVerifier()
// Verify
hostnameVerifier.verify(hostname, cert)
}
@Test
void testHostnameVerifierWildcardTLD() {
// Arrange
final String EXPECTED_DN = "CN=nifi.apache.*,OU=Security,O=Apache,ST=CA,C=US"
final String comTLDhostname = "nifi.apache.com"
final String orgTLDHostname = "nifi.apache.org"
X509Certificate cert = generateCertificate(EXPECTED_DN)
SSLContext sslContext = Mockito.mock(SSLContext.class)
final ClientConfig clientConfig = new ClientConfig()
// Act
Client client = WebUtils.createClient(clientConfig, sslContext)
DefaultHostnameVerifier hostnameVerifier = client.getHostnameVerifier()
// Verify
hostnameVerifier.verify(comTLDhostname, cert)
hostnameVerifier.verify(orgTLDHostname, cert)
}
@Test
void testHostnameVerifierWildcardDomain() {
// Arrange
final String EXPECTED_DN = "CN=nifi.*.com,OU=Security,O=Apache,ST=CA,C=US"
final String hostname = "nifi.apache.com"
X509Certificate cert = generateCertificate(EXPECTED_DN)
SSLContext sslContext = Mockito.mock(SSLContext.class)
final ClientConfig clientConfig = new ClientConfig()
// Act
Client client = WebUtils.createClient(clientConfig, sslContext)
DefaultHostnameVerifier hostnameVerifier = client.getHostnameVerifier()
// Verify
hostnameVerifier.verify(hostname, cert)
}
X509Certificate generateCertificate(String DN) {
CertAndKeyGen certGenerator = new CertAndKeyGen("RSA", "SHA256WithRSA", null)
certGenerator.generate(2048)
long validityPeriod = (long) 365 * 24 * 60 * 60 // 1 YEAR
X509Certificate cert = certGenerator.getSelfCertificate(new X500Name(DN), validityPeriod)
return cert
}
}