KNOX-2140 - RequestUpdateHandler.ForwardedRequest#getRequestURL needs to return a valid URL (#206)
Signed-off-by: Kevin Risden <krisden@apache.org>
diff --git a/gateway-server/src/main/java/org/apache/knox/gateway/filter/PortMappingHelperHandler.java b/gateway-server/src/main/java/org/apache/knox/gateway/filter/PortMappingHelperHandler.java
index f917f8f..31de490 100644
--- a/gateway-server/src/main/java/org/apache/knox/gateway/filter/PortMappingHelperHandler.java
+++ b/gateway-server/src/main/java/org/apache/knox/gateway/filter/PortMappingHelperHandler.java
@@ -29,9 +29,9 @@
import java.util.Map;
/**
- * This is a helper handler that adjusts the "target" patch of the request.
- * Used when Topology Port Mapping feature is used.
- * See KNOX-928
+ * This is a helper handler that adjusts the "target" patch
+ * of the request when Topology Port Mapping feature is
+ * enabled. See KNOX-928.
* <p>
* This class also handles the Default Topology Feature
* where, any one of the topologies can be set to "default"
@@ -39,31 +39,33 @@
* will not need /gateway/{topology} context.
* Basically Topology Port Mapping for standard port.
* Backwards compatible to Default Topology Feature.
- *
*/
public class PortMappingHelperHandler extends HandlerWrapper {
-
- private static final GatewayMessages LOG = MessagesFactory
- .get(GatewayMessages.class);
+ private static final GatewayMessages LOG = MessagesFactory.get(GatewayMessages.class);
private final GatewayConfig config;
- private final Map<String, Integer> topologyPortMap;
-
- private String defaultTopologyRedirectContext;
+ private final String defaultTopologyRedirectContext;
public PortMappingHelperHandler(final GatewayConfig config) {
-
this.config = config;
- this.topologyPortMap = config.getGatewayPortMappings();
+ this.defaultTopologyRedirectContext = getDefaultTopologyRedirectContext(config);
+ }
- //Set up context for default topology feature.
- String defaultTopologyName = config.getDefaultTopologyName();
+ /**
+ * Set up context for default topology feature.
+ * @param config GatewayConfig object to read from
+ * @return default topology redirect context as a string
+ */
+ private String getDefaultTopologyRedirectContext(final GatewayConfig config) {
+ final String defaultTopologyName = config.getDefaultTopologyName();
// default topology feature can also be enabled using port mapping feature
// config e.g. gateway.port.mapping.{defaultTopologyName}
- if(defaultTopologyName == null && config.getGatewayPortMappings().values().contains(config.getGatewayPort())) {
+ String defaultTopologyRedirectContext = null;
+ if(defaultTopologyName == null &&
+ config.getGatewayPortMappings().containsValue(config.getGatewayPort())) {
for(final Map.Entry<String, Integer> entry: config.getGatewayPortMappings().entrySet()) {
- if(entry.getValue().intValue() == config.getGatewayPort()) {
+ if(entry.getValue().equals(config.getGatewayPort())) {
defaultTopologyRedirectContext = "/" + config.getGatewayPath() + "/" + entry.getKey();
break;
}
@@ -78,66 +80,74 @@
}
}
if (defaultTopologyRedirectContext != null) {
- LOG.defaultTopologySetup(defaultTopologyName,
- defaultTopologyRedirectContext);
+ LOG.defaultTopologySetup(defaultTopologyName, defaultTopologyRedirectContext);
}
+ return defaultTopologyRedirectContext;
}
@Override
public void handle(final String target, final Request baseRequest,
final HttpServletRequest request, final HttpServletResponse response)
throws IOException, ServletException {
-
- String newTarget = target;
- String baseURI = baseRequest.getRequestURI();
+ final String baseURI = baseRequest.getRequestURI();
final int port = baseRequest.getLocalPort();
- RequestUpdateHandler.ForwardedRequest newRequest;
- // If Port Mapping feature enabled
- if (config.isGatewayPortMappingEnabled() && topologyPortMap.containsValue(port)) {
-
- final String topologyName = topologyPortMap.entrySet()
- .stream()
- .filter(e -> e.getValue().equals(port))
- .map(Map.Entry::getKey)
- .findFirst()
- .orElse(null);
- final String gatewayTopologyContext =
- "/" + config.getGatewayPath() + "/" + topologyName;
-
- if(!target.contains(gatewayTopologyContext)) {
- newTarget = gatewayTopologyContext + target;
- }
-
- // if the request does not contain /{gatewayName}/{topologyName}
- if(!baseRequest.getRequestURI().contains(gatewayTopologyContext)) {
- newRequest = new RequestUpdateHandler.ForwardedRequest(
- request, gatewayTopologyContext, newTarget);
-
- baseRequest.setPathInfo(gatewayTopologyContext + baseRequest.getPathInfo());
- baseRequest.setURIPathQuery(gatewayTopologyContext + baseRequest.getRequestURI());
-
- LOG.topologyPortMappingUpdateRequest(target, newTarget);
- super.handle(newTarget, baseRequest, newRequest, response);
- }
- else {
- super.handle(newTarget, baseRequest, request, response);
- }
- }
- //Backwards compatibility for default topology feature
- else if (defaultTopologyRedirectContext != null && !baseURI
- .startsWith("/" + config.getGatewayPath())) {
- newTarget = defaultTopologyRedirectContext + target;
-
- newRequest = new RequestUpdateHandler.ForwardedRequest(
- request, defaultTopologyRedirectContext, newTarget);
-
- LOG.defaultTopologyForward(target, newTarget);
- super.handle(newTarget, baseRequest, newRequest, response);
-
+ if (config.isGatewayPortMappingEnabled()
+ && config.getGatewayPortMappings().containsValue(port)) {
+ // If Port Mapping feature enabled
+ handlePortMapping(target, baseRequest, request, response, port);
+ } else if (defaultTopologyRedirectContext != null &&
+ !baseURI.startsWith("/" + config.getGatewayPath())) {
+ //Backwards compatibility for default topology feature
+ handleDefaultTopologyMapping(target, baseRequest, request, response);
} else {
- /* case where topology port mapping is not enabled (or improperly configured) and no default topology is configured */
+ // case where topology port mapping is not enabled (or improperly configured)
+ // and no default topology is configured
+ super.handle(target, baseRequest, request, response);
+ }
+ }
+
+ private void handlePortMapping(final String target, final Request baseRequest,
+ final HttpServletRequest request,
+ final HttpServletResponse response, final int port)
+ throws IOException, ServletException {
+ final String topologyName = config.getGatewayPortMappings().entrySet()
+ .stream()
+ .filter(e -> e.getValue().equals(port))
+ .map(Map.Entry::getKey)
+ .findFirst()
+ .orElse(null);
+ final String gatewayTopologyContext = "/" + config.getGatewayPath() + "/" + topologyName;
+ String newTarget = target;
+
+ if(!target.contains(gatewayTopologyContext)) {
+ newTarget = gatewayTopologyContext + target;
+ }
+
+ // if the request does not contain /{gatewayName}/{topologyName}
+ if(!baseRequest.getRequestURI().contains(gatewayTopologyContext)) {
+ RequestUpdateHandler.ForwardedRequest newRequest = new RequestUpdateHandler.ForwardedRequest(
+ request, gatewayTopologyContext);
+
+ baseRequest.setPathInfo(gatewayTopologyContext + baseRequest.getPathInfo());
+ baseRequest.setURIPathQuery(gatewayTopologyContext + baseRequest.getRequestURI());
+
+ LOG.topologyPortMappingUpdateRequest(target, newTarget);
+ super.handle(newTarget, baseRequest, newRequest, response);
+ } else {
super.handle(newTarget, baseRequest, request, response);
}
}
+
+ private void handleDefaultTopologyMapping(final String target, final Request baseRequest,
+ final HttpServletRequest request,
+ final HttpServletResponse response)
+ throws IOException, ServletException {
+ RequestUpdateHandler.ForwardedRequest newRequest = new RequestUpdateHandler.ForwardedRequest(
+ request, defaultTopologyRedirectContext);
+
+ final String newTarget = defaultTopologyRedirectContext + target;
+ LOG.defaultTopologyForward(target, newTarget);
+ super.handle(newTarget, baseRequest, newRequest, response);
+ }
}
diff --git a/gateway-server/src/main/java/org/apache/knox/gateway/filter/RequestUpdateHandler.java b/gateway-server/src/main/java/org/apache/knox/gateway/filter/RequestUpdateHandler.java
index 7691422..fc0e631 100644
--- a/gateway-server/src/main/java/org/apache/knox/gateway/filter/RequestUpdateHandler.java
+++ b/gateway-server/src/main/java/org/apache/knox/gateway/filter/RequestUpdateHandler.java
@@ -29,6 +29,7 @@
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
+import java.util.Locale;
/**
* This handler will be ONLY registered with a specific connector listening on a
@@ -42,9 +43,7 @@
*
*/
public class RequestUpdateHandler extends ScopedHandler {
-
- private static final GatewayMessages LOG = MessagesFactory
- .get(GatewayMessages.class);
+ private static final GatewayMessages LOG = MessagesFactory.get(GatewayMessages.class);
private String redirectContext;
@@ -63,7 +62,6 @@
}
redirectContext = "/" + config.getGatewayPath() + "/" + topologyName;
-
}
@Override
@@ -78,22 +76,20 @@
final HttpServletRequest request, final HttpServletResponse response)
throws IOException, ServletException {
- final String newTarget = redirectContext + target;
-
RequestUpdateHandler.ForwardedRequest newRequest = new RequestUpdateHandler.ForwardedRequest(
- request, redirectContext, newTarget);
+ request, redirectContext);
// if the request already has the /{gatewaypath}/{topology} part then skip
if (!StringUtils.startsWithIgnoreCase(target, redirectContext)) {
baseRequest.setPathInfo(redirectContext + baseRequest.getPathInfo());
baseRequest.setURIPathQuery(redirectContext + baseRequest.getRequestURI());
+ final String newTarget = redirectContext + target;
LOG.topologyPortMappingUpdateRequest(target, newTarget);
nextHandle(newTarget, baseRequest, newRequest, response);
} else {
nextHandle(target, baseRequest, newRequest, response);
}
-
}
/**
@@ -101,32 +97,47 @@
* needed.
*/
static class ForwardedRequest extends HttpServletRequestWrapper {
+ private final String contextPath;
+ private final String requestURL;
- private String newURL;
- private String contextpath;
-
- ForwardedRequest(final HttpServletRequest request,
- final String contextpath, final String newURL) {
+ ForwardedRequest(final HttpServletRequest request, final String contextPath) {
super(request);
- this.newURL = newURL;
- this.contextpath = contextpath;
+ this.contextPath = contextPath;
+ this.requestURL = generateRequestURL();
+ }
+
+ /**
+ * Handle the case where getServerPort returns -1
+ * @return requestURL
+ */
+ private String generateRequestURL() {
+ if (getRequest().getServerPort() != -1) {
+ return String.format(Locale.ROOT, "%s://%s:%s%s",
+ getRequest().getScheme(),
+ getRequest().getServerName(),
+ getRequest().getServerPort(),
+ getRequestURI());
+ } else {
+ return String.format(Locale.ROOT, "%s://%s%s",
+ getRequest().getScheme(),
+ getRequest().getServerName(),
+ getRequestURI());
+ }
}
@Override
public StringBuffer getRequestURL() {
- return new StringBuffer(newURL);
+ return new StringBuffer(this.requestURL);
}
@Override
public String getRequestURI() {
- return contextpath + super.getRequestURI();
+ return this.contextPath + super.getRequestURI();
}
@Override
public String getContextPath() {
- return this.contextpath;
+ return this.contextPath;
}
-
}
-
}
diff --git a/gateway-server/src/test/java/org/apache/knox/gateway/filter/ForwardedRequestTest.java b/gateway-server/src/test/java/org/apache/knox/gateway/filter/ForwardedRequestTest.java
new file mode 100644
index 0000000..f72e86f
--- /dev/null
+++ b/gateway-server/src/test/java/org/apache/knox/gateway/filter/ForwardedRequestTest.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with this
+ * work for additional information regarding copyright ownership. The ASF
+ * licenses this file to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package org.apache.knox.gateway.filter;
+
+import org.easymock.EasyMock;
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.Locale;
+
+public class ForwardedRequestTest {
+ @Test
+ public void testForwardedRequestNoContextDefaultPort() {
+ String scheme = "http";
+ String host = "localhost";
+ String context = "/abc";
+ String requestURL = String.format(Locale.ROOT, "%s://%s%s", scheme, host, context);
+
+ HttpServletRequest httpServletRequest = EasyMock.mock(HttpServletRequest.class);
+ EasyMock.expect(httpServletRequest.getScheme()).andReturn(scheme);
+ EasyMock.expect(httpServletRequest.getServerName()).andReturn(host);
+ EasyMock.expect(httpServletRequest.getServerPort()).andReturn(-1);
+ EasyMock.expect(httpServletRequest.getRequestURI()).andReturn(context).anyTimes();
+ EasyMock.expect(httpServletRequest.getRequestURL()).andReturn(new StringBuffer(requestURL));
+ EasyMock.replay(httpServletRequest);
+
+ String contextPath = "";
+ RequestUpdateHandler.ForwardedRequest forwardedRequest =
+ new RequestUpdateHandler.ForwardedRequest(httpServletRequest, contextPath);
+ Assert.assertEquals(httpServletRequest.getRequestURL().toString(),
+ forwardedRequest.getRequestURL().toString());
+ Assert.assertEquals(httpServletRequest.getRequestURI(), forwardedRequest.getRequestURI());
+ Assert.assertEquals(contextPath, forwardedRequest.getContextPath());
+ }
+
+ @Test
+ public void testForwardedRequestNoContext() {
+ String scheme = "http";
+ String host = "localhost";
+ int port = 8443;
+ String context = "/abc";
+ String requestURL = String.format(Locale.ROOT, "%s://%s:%s%s", scheme, host, port, context);
+
+ HttpServletRequest httpServletRequest = EasyMock.mock(HttpServletRequest.class);
+ EasyMock.expect(httpServletRequest.getScheme()).andReturn(scheme);
+ EasyMock.expect(httpServletRequest.getServerName()).andReturn(host);
+ EasyMock.expect(httpServletRequest.getServerPort()).andReturn(port).anyTimes();
+ EasyMock.expect(httpServletRequest.getRequestURI()).andReturn(context).anyTimes();
+ EasyMock.expect(httpServletRequest.getRequestURL()).andReturn(new StringBuffer(requestURL));
+ EasyMock.replay(httpServletRequest);
+
+ String contextPath = "";
+ RequestUpdateHandler.ForwardedRequest forwardedRequest =
+ new RequestUpdateHandler.ForwardedRequest(httpServletRequest, contextPath);
+ Assert.assertEquals(httpServletRequest.getRequestURL().toString(),
+ forwardedRequest.getRequestURL().toString());
+ Assert.assertEquals(contextPath, forwardedRequest.getContextPath());
+ }
+
+ @Test
+ public void testForwardedRequestWithContextDefaultPort() {
+ String scheme = "http";
+ String host = "localhost";
+ String context = "/abc";
+ String requestURL = String.format(Locale.ROOT, "%s://%s%s", scheme, host, context);
+
+ HttpServletRequest httpServletRequest = EasyMock.mock(HttpServletRequest.class);
+ EasyMock.expect(httpServletRequest.getScheme()).andReturn(scheme);
+ EasyMock.expect(httpServletRequest.getServerName()).andReturn(host);
+ EasyMock.expect(httpServletRequest.getServerPort()).andReturn(-1);
+ EasyMock.expect(httpServletRequest.getRequestURI()).andReturn(context).anyTimes();
+ EasyMock.expect(httpServletRequest.getRequestURL()).andReturn(new StringBuffer(requestURL));
+ EasyMock.replay(httpServletRequest);
+
+ String contextPath = "/mycontext";
+ RequestUpdateHandler.ForwardedRequest forwardedRequest =
+ new RequestUpdateHandler.ForwardedRequest(httpServletRequest, contextPath);
+ Assert.assertEquals(
+ String.format(Locale.ROOT, "%s://%s%s", scheme, host, contextPath + context),
+ forwardedRequest.getRequestURL().toString());
+ Assert.assertEquals(contextPath + context, forwardedRequest.getRequestURI());
+ Assert.assertEquals(contextPath, forwardedRequest.getContextPath());
+ }
+
+ @Test
+ public void testForwardedRequestWithContext() {
+ String scheme = "http";
+ String host = "localhost";
+ int port = 8443;
+ String context = "/abc";
+ String requestURL = String.format(Locale.ROOT, "%s://%s:%s%s", scheme, host, port, context);
+
+ HttpServletRequest httpServletRequest = EasyMock.mock(HttpServletRequest.class);
+ EasyMock.expect(httpServletRequest.getScheme()).andReturn(scheme);
+ EasyMock.expect(httpServletRequest.getServerName()).andReturn(host);
+ EasyMock.expect(httpServletRequest.getServerPort()).andReturn(port).anyTimes();
+ EasyMock.expect(httpServletRequest.getRequestURI()).andReturn(context).anyTimes();
+ EasyMock.expect(httpServletRequest.getRequestURL()).andReturn(new StringBuffer(requestURL));
+ EasyMock.replay(httpServletRequest);
+
+ String contextPath = "/mycontext";
+ RequestUpdateHandler.ForwardedRequest forwardedRequest =
+ new RequestUpdateHandler.ForwardedRequest(httpServletRequest, contextPath);
+ Assert.assertEquals(
+ String.format(Locale.ROOT, "%s://%s:%s%s", scheme, host, port, contextPath + context),
+ forwardedRequest.getRequestURL().toString());
+ Assert.assertEquals(contextPath + context, forwardedRequest.getRequestURI());
+ Assert.assertEquals(contextPath, forwardedRequest.getContextPath());
+ }
+}