KNOX-3023 - Include groups in a header in ConfigurableDispatch (#903)
diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java b/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java
index cbd310a..771e616 100644
--- a/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java
+++ b/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java
@@ -120,4 +120,7 @@
@Message( level = MessageLevel.DEBUG, text = "Malformed dispatch URL: {0}" )
void malformedDispatchUrl(String url);
+
+  /**
+   * Logged when no primary principal name can be resolved from the current Subject
+   * while building the actor headers for a dispatched request.
+   */
+  @Message( text = "No valid principal found", level = MessageLevel.ERROR )
+  void noPrincipalFound();
}
diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java b/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java
index 9bbcde7..6f1d240 100644
--- a/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java
+++ b/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java
@@ -22,20 +22,29 @@
import org.apache.knox.gateway.audit.api.ActionOutcome;
import org.apache.knox.gateway.config.Configure;
import org.apache.knox.gateway.config.Default;
+import org.apache.knox.gateway.security.SubjectUtils;
import org.apache.knox.gateway.util.StringUtils;
+import javax.security.auth.Subject;
import javax.servlet.http.HttpServletRequest;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
+
import java.util.Collections;
-import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.Arrays;
-import java.util.Optional;
+import java.util.HashSet;
import java.util.HashMap;
+import java.util.Optional;
+import java.util.List;
+import java.util.Collection;
+import java.util.Locale;
+import java.util.ArrayList;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
@@ -50,6 +59,21 @@
private Set<String> responseExcludeSetCookieHeaderDirectives = super.getOutboundResponseExcludedSetCookieHeaderDirectives();
private Boolean removeUrlEncoding = false;
+  /** Default header name carrying the authenticated user's primary principal name. */
+  static final String DEFAULT_AUTH_ACTOR_ID_HEADER_NAME = "X-Knox-Actor-ID";
+  /** Default prefix for the numbered group headers (the first one is X-Knox-Actor-Groups-1). */
+  static final String DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX = "X-Knox-Actor-Groups";
+  /** Default group filter: every group name matches. */
+  static final String DEFAULT_GROUP_FILTER_PATTERN = ".*";
+  /** Principal/group headers are only forwarded when explicitly enabled. */
+  static final String DEFAULT_ARE_USERS_GROUPS_HEADER_INCLUDED = "false";
+
+  /** Length budget used when packing comma-separated group names into one header value. */
+  protected static final int MAX_HEADER_LENGTH = 1000;
+  /** Builds a numbered groups header name from the configured prefix and a 1-based index. */
+  protected static final String ACTOR_GROUPS_HEADER_FORMAT = "%s-%d";
+
+  private boolean shouldIncludePrincipalAndGroups;
+  private String actorIdHeaderName = DEFAULT_AUTH_ACTOR_ID_HEADER_NAME;
+  private String actorGroupsHeaderPrefix = DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX;
+  private String groupFilterPattern = DEFAULT_GROUP_FILTER_PATTERN;
+  /** Compiled group filter; a group is forwarded only if its whole name matches. */
+  protected Pattern groupPattern = Pattern.compile(DEFAULT_GROUP_FILTER_PATTERN);
+
private Set<String> convertCommaDelimitedHeadersToSet(String headers) {
return headers == null ? Collections.emptySet(): new HashSet<>(Arrays.asList(headers.split("\\s*,\\s*")));
}
@@ -123,6 +147,27 @@
this.removeUrlEncoding = Boolean.parseBoolean(removeUrlEncoding);
}
+  /**
+   * Enables or disables forwarding of the actor ID and actor groups headers on
+   * outbound requests. Disabled by default.
+   *
+   * @param shouldIncludePrincipalAndGroups {@code true} to add the actor headers
+   */
+  @Configure
+  public void setShouldIncludePrincipalAndGroups(
+      @Default(DEFAULT_ARE_USERS_GROUPS_HEADER_INCLUDED) boolean shouldIncludePrincipalAndGroups) {
+    this.shouldIncludePrincipalAndGroups = shouldIncludePrincipalAndGroups;
+  }
+
+  /**
+   * Sets the name of the header that carries the primary principal name
+   * (defaults to {@code X-Knox-Actor-ID}).
+   *
+   * @param actorIdHeaderName outbound header name for the actor ID
+   */
+  @Configure
+  public void setActorIdHeaderName(
+      @Default(DEFAULT_AUTH_ACTOR_ID_HEADER_NAME) String actorIdHeaderName) {
+    this.actorIdHeaderName = actorIdHeaderName;
+  }
+
+  /**
+   * Sets the prefix of the numbered group headers; the N-th header is named
+   * {@code prefix-N} (defaults to {@code X-Knox-Actor-Groups}).
+   *
+   * @param actorGroupsHeaderPrefix outbound header name prefix for actor groups
+   */
+  @Configure
+  public void setActorGroupsHeaderPrefix(
+      @Default(DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX) String actorGroupsHeaderPrefix) {
+    this.actorGroupsHeaderPrefix = actorGroupsHeaderPrefix;
+  }
+
+  /**
+   * Sets the regular expression a group name must match in full to be forwarded
+   * in the actor groups headers (defaults to {@code .*}, i.e. all groups).
+   *
+   * @param groupFilterPattern the group filter regular expression
+   * @throws java.util.regex.PatternSyntaxException if the expression is invalid; in that
+   *         case the previously configured filter stays in effect
+   */
+  @Configure
+  public void setGroupFilterPattern(@Default(DEFAULT_GROUP_FILTER_PATTERN) String groupFilterPattern) {
+    // Compile before assigning so an invalid expression cannot leave the string form
+    // and the compiled Pattern out of sync.
+    final Pattern compiled = Pattern.compile(groupFilterPattern);
+    this.groupFilterPattern = groupFilterPattern;
+    this.groupPattern = compiled;
+  }
+
@Override
public void copyRequestHeaderFields(HttpUriRequest outboundRequest,
HttpServletRequest inboundRequest) {
@@ -133,6 +178,61 @@
if(MapUtils.isNotEmpty(extraHeaders)){
extraHeaders.forEach(outboundRequest::addHeader);
}
+
+    // Forward the authenticated actor ID and the filtered group headers only when enabled.
+    if (shouldIncludePrincipalAndGroups) {
+      final Map<String, String> actorHeaders = addPrincipalAndGroups();
+      if (actorHeaders != null && !actorHeaders.isEmpty()) {
+        actorHeaders.forEach(outboundRequest::addHeader);
+      }
+    }
+ }
+
+  /**
+   * Builds the actor headers for the current request.
+   *
+   * <p>The primary principal name of the current Subject is put under
+   * {@code actorIdHeaderName}. If no principal can be resolved, the value is empty and
+   * {@code noPrincipalFound} is logged. Group names that fully match {@code groupPattern}
+   * are deduplicated, sorted, and packed by {@code getGroupStrings} into numbered headers
+   * named {@code actorGroupsHeaderPrefix-1}, {@code -2}, and so on.
+   *
+   * @return header name to value map; always contains the actor ID header
+   */
+  private Map<String, String> addPrincipalAndGroups() {
+    final Map<String, String> headers = new ConcurrentHashMap<>();
+    final Subject subject = SubjectUtils.getCurrentSubject();
+
+    final String primaryPrincipalName = subject == null ? null : SubjectUtils.getPrimaryPrincipalName(subject);
+    if (primaryPrincipalName == null) {
+      LOG.noPrincipalFound();
+      headers.put(actorIdHeaderName, "");
+    } else {
+      headers.put(actorIdHeaderName, primaryPrincipalName);
+    }
+
+    // Sorting makes the packed header values deterministic for a given set of groups.
+    final List<String> matchingGroupNames = subject == null ? Collections.emptyList()
+        : SubjectUtils.getGroupPrincipals(subject).stream()
+            .map(group -> group.getName())
+            .filter(groupName -> groupPattern.matcher(groupName).matches())
+            .distinct()
+            .sorted()
+            .collect(Collectors.toList());
+
+    final List<String> groupStrings = getGroupStrings(matchingGroupNames);
+    for (int i = 0; i < groupStrings.size(); i++) {
+      // Header numbering is 1-based: <prefix>-1, <prefix>-2, ...
+      headers.put(String.format(Locale.ROOT, ACTOR_GROUPS_HEADER_FORMAT, actorGroupsHeaderPrefix, i + 1),
+          groupStrings.get(i));
+    }
+    return headers;
+  }
+
+  /**
+   * Packs group names into comma-separated values.
+   *
+   * <p>A new value is started whenever appending the next name, together with its
+   * separating comma, would make the current value longer than {@code MAX_HEADER_LENGTH}.
+   * A single name longer than the limit is emitted as its own value rather than split, and
+   * no empty value is ever produced.
+   *
+   * @param groupNames group names to pack, in the order they should appear
+   * @return the packed values; an empty list when there are no names
+   */
+  private List<String> getGroupStrings(final Collection<String> groupNames) {
+    if (groupNames.isEmpty()) {
+      return Collections.emptyList();
+    }
+    final List<String> groupStrings = new ArrayList<>();
+    final StringBuilder sb = new StringBuilder();
+    for (String groupName : groupNames) {
+      // Flush only a non-empty chunk, and count the ',' that would precede this name.
+      if (sb.length() > 0 && sb.length() + 1 + groupName.length() > MAX_HEADER_LENGTH) {
+        groupStrings.add(sb.toString());
+        sb.setLength(0);
+      }
+      if (sb.length() > 0) {
+        sb.append(',');
+      }
+      sb.append(groupName);
+    }
+    if (sb.length() > 0) {
+      groupStrings.add(sb.toString());
+    }
+    return groupStrings;
}
@Override
@@ -180,4 +280,5 @@
return super.getDispatchUrl(request);
}
+
}
diff --git a/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java b/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java
index 8f38fd5..0386ac9 100644
--- a/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java
+++ b/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java
@@ -18,6 +18,7 @@
package org.apache.knox.gateway.dispatch;
import static org.apache.knox.gateway.dispatch.AbstractGatewayDispatch.REQUEST_ID_HEADER_NAME;
+import static org.apache.knox.gateway.dispatch.ConfigurableDispatch.DEFAULT_AUTH_ACTOR_ID_HEADER_NAME;
import static org.apache.knox.gateway.dispatch.DefaultDispatch.SET_COOKIE;
import static org.apache.knox.gateway.dispatch.DefaultDispatch.WWW_AUTHENTICATE;
import static org.hamcrest.CoreMatchers.containsString;
@@ -26,12 +27,15 @@
import static org.junit.Assert.assertThat;
import java.net.URI;
+import java.security.PrivilegedActionException;
+import java.security.PrivilegedExceptionAction;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
+import javax.security.auth.Subject;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@@ -41,6 +45,8 @@
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.message.BasicHeader;
+import org.apache.knox.gateway.security.GroupPrincipal;
+import org.apache.knox.gateway.security.PrimaryPrincipal;
import org.apache.knox.test.TestUtils;
import org.apache.knox.test.mock.MockHttpServletResponse;
import org.apache.logging.log4j.CloseableThreadContext;
@@ -316,7 +322,7 @@
assertThat(outboundRequestHeaders[3].getName(), is("c"));
}
- @Test( timeout = TestUtils.SHORT_TIMEOUT )
+ @Test( timeout = TestUtils.LONG_TIMEOUT )
public void testRequestExcludeAndAppendHeadersConfig() {
ConfigurableDispatch dispatch = new ConfigurableDispatch();
dispatch.setRequestAppendHeaders("a : b ; c : d");
@@ -724,4 +730,47 @@
assertThat(outboundResponse.getHeader(REQUEST_ID_HEADER_NAME), nullValue());
}
+  /**
+   * Make sure the X-Knox-Actor-ID header carries the primary principal and that the
+   * X-Knox-Actor-Groups-1 header carries all of the user's groups for authenticated users.
+   */
+  @Test
+  public void testGroupHeaders() throws PrivilegedActionException {
+    Subject subject = new Subject();
+    subject.getPrincipals().add(new PrimaryPrincipal("knoxui"));
+    subject.getPrincipals().add(new GroupPrincipal("knox"));
+    subject.getPrincipals().add(new GroupPrincipal("admin"));
+
+    ConfigurableDispatch dispatch = new ConfigurableDispatch();
+    final String headerReqID = "1234567890ABCD";
+    dispatch.setShouldIncludePrincipalAndGroups(true);
+
+    Map<String, String> headers = new HashMap<>();
+    headers.put(REQUEST_ID_HEADER_NAME, headerReqID);
+    headers.put(HttpHeaders.ACCEPT, "abc");
+    headers.put("TEST", "test");
+
+    HttpServletRequest inboundRequest = EasyMock.createNiceMock(HttpServletRequest.class);
+    EasyMock.expect(inboundRequest.getHeaderNames()).andReturn(Collections.enumeration(headers.keySet())).anyTimes();
+    Capture<String> capturedArgument = Capture.newInstance();
+    EasyMock.expect(inboundRequest.getHeader(EasyMock.capture(capturedArgument)))
+        .andAnswer(() -> headers.get(capturedArgument.getValue())).anyTimes();
+    EasyMock.replay(inboundRequest);
+
+    HttpUriRequest outboundRequest = new HttpGet();
+
+    // Cast needed: Subject.doAs is overloaded for PrivilegedAction and PrivilegedExceptionAction.
+    Subject.doAs(subject, (PrivilegedExceptionAction<Object>) () -> {
+      dispatch.copyRequestHeaderFields(outboundRequest, inboundRequest);
+      return null;
+    });
+
+    Header[] outboundRequestHeaders = outboundRequest.getAllHeaders();
+    assertThat(outboundRequestHeaders.length, is(5));
+    assertThat(outboundRequest.getHeaders(DEFAULT_AUTH_ACTOR_ID_HEADER_NAME)[0].getValue(), is("knoxui"));
+
+    // Both groups fit in one header; compare as a sorted list because value order is not part of the contract.
+    final String groupsPrefix = ConfigurableDispatch.DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX;
+    assertThat(outboundRequest.getHeaders(groupsPrefix + "-1").length, is(1));
+    assertThat(outboundRequest.getHeaders(groupsPrefix + "-2").length, is(0));
+    final String[] groups = outboundRequest.getFirstHeader(groupsPrefix + "-1").getValue().split(",");
+    Arrays.sort(groups);
+    assertThat(Arrays.asList(groups), is(Arrays.asList("admin", "knox")));
+  }
+
}