KNOX-3023 - Include groups in a header in ConfigurableDispatch (#903)

diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java b/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java
index cbd310a..771e616 100644
--- a/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java
+++ b/gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java
@@ -120,4 +120,7 @@
 
   @Message( level = MessageLevel.DEBUG, text = "Malformed dispatch URL: {0}" )
   void malformedDispatchUrl(String url);
+
+  @Message( level = MessageLevel.ERROR, text = "No valid principal found" )
+  void noPrincipalFound(); // invoked by ConfigurableDispatch when no primary principal name is available for the actor ID header
 }
diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java b/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java
index 9bbcde7..6f1d240 100644
--- a/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java
+++ b/gateway-spi/src/main/java/org/apache/knox/gateway/dispatch/ConfigurableDispatch.java
@@ -22,20 +22,29 @@
 import org.apache.knox.gateway.audit.api.ActionOutcome;
 import org.apache.knox.gateway.config.Configure;
 import org.apache.knox.gateway.config.Default;
+import org.apache.knox.gateway.security.SubjectUtils;
 import org.apache.knox.gateway.util.StringUtils;
 
+import javax.security.auth.Subject;
 import javax.servlet.http.HttpServletRequest;
 import java.io.UnsupportedEncodingException;
 import java.net.URI;
 import java.net.URLDecoder;
 import java.nio.charset.StandardCharsets;
+
 import java.util.Collections;
-import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
 import java.util.Arrays;
-import java.util.Optional;
+import java.util.HashSet;
 import java.util.HashMap;
+import java.util.Optional;
+import java.util.List;
+import java.util.Collection;
+import java.util.Locale;
+import java.util.ArrayList;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.regex.Pattern;
 import java.util.stream.Collectors;
 
 /**
@@ -50,6 +59,21 @@
   private Set<String> responseExcludeSetCookieHeaderDirectives = super.getOutboundResponseExcludedSetCookieHeaderDirectives();
   private Boolean removeUrlEncoding = false;
 
+  private boolean shouldIncludePrincipalAndGroups; // when true, copyRequestHeaderFields adds the actor ID/groups headers
+  private String actorIdHeaderName = DEFAULT_AUTH_ACTOR_ID_HEADER_NAME; // header that carries the primary principal name
+  private String actorGroupsHeaderPrefix = DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX; // prefix of the numbered groups headers
+  private String groupFilterPattern = DEFAULT_GROUP_FILTER_PATTERN; // regex source text; groupPattern is its compiled form
+
+  static final String DEFAULT_AUTH_ACTOR_ID_HEADER_NAME = "X-Knox-Actor-ID";
+  static final String DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX = "X-Knox-Actor-Groups";
+  static final String DEFAULT_GROUP_FILTER_PATTERN = ".*"; // default regex for the group filter
+  static final String DEFAULT_ARE_USERS_GROUPS_HEADER_INCLUDED = "false"; // @Default value of setShouldIncludePrincipalAndGroups
+
+  protected static final int MAX_HEADER_LENGTH = 1000; // size threshold used when splitting group names across headers
+  protected static final String ACTOR_GROUPS_HEADER_FORMAT = "%s-%d"; // <prefix>-<1-based index>
+  protected Pattern groupPattern = Pattern.compile(DEFAULT_GROUP_FILTER_PATTERN); // replaced by setGroupFilterPattern
+
+
   private Set<String> convertCommaDelimitedHeadersToSet(String headers) {
     return headers == null ?  Collections.emptySet(): new HashSet<>(Arrays.asList(headers.split("\\s*,\\s*")));
   }
@@ -123,6 +147,27 @@
     this.removeUrlEncoding = Boolean.parseBoolean(removeUrlEncoding);
   }
 
+  @Configure
+  public void setShouldIncludePrincipalAndGroups(@Default(DEFAULT_ARE_USERS_GROUPS_HEADER_INCLUDED) boolean shouldIncludePrincipalAndGroups) {
+    this.shouldIncludePrincipalAndGroups = shouldIncludePrincipalAndGroups; // gates the actor headers in copyRequestHeaderFields
+  }
+
+  @Configure
+  public void setActorIdHeaderName(@Default(DEFAULT_AUTH_ACTOR_ID_HEADER_NAME) String actorIdHeaderName) {
+    this.actorIdHeaderName = actorIdHeaderName; // name of the outbound header that carries the primary principal name
+  }
+
+  @Configure
+  public void setActorGroupsHeaderPrefix(@Default(DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX) String actorGroupsHeaderPrefix) {
+    this.actorGroupsHeaderPrefix = actorGroupsHeaderPrefix; // groups headers are named "<prefix>-<n>", n starting at 1
+  }
+
+  @Configure
+  public void setGroupFilterPattern(@Default(DEFAULT_GROUP_FILTER_PATTERN) String groupFilterPattern) {
+    this.groupFilterPattern = groupFilterPattern;
+    groupPattern = Pattern.compile(this.groupFilterPattern); // a group name must match the whole regex (Matcher.matches) to be forwarded
+  }
+
   @Override
   public void copyRequestHeaderFields(HttpUriRequest outboundRequest,
                                       HttpServletRequest inboundRequest) {
@@ -133,6 +178,61 @@
     if(MapUtils.isNotEmpty(extraHeaders)){
       extraHeaders.forEach(outboundRequest::addHeader);
     }
+
+    /* If we need to add user and groups to outbound request */
+    if(shouldIncludePrincipalAndGroups) {
+      Map<String, String> groups = addPrincipalAndGroups();
+      if(MapUtils.isNotEmpty(groups)){
+        groups.forEach(outboundRequest::addHeader);
+      }
+    }
+  }
+
+  /**
+   * Builds the actor ID header (empty value when there is no primary principal) and the numbered
+   * actor groups headers for those groups of the current subject whose names match groupPattern.
+   */
+  private Map<String, String> addPrincipalAndGroups() {
+    final Map<String, String> headers = new ConcurrentHashMap<>();
+    final Subject subject = SubjectUtils.getCurrentSubject();
+
+    final String primaryPrincipalName = subject == null ? null : SubjectUtils.getPrimaryPrincipalName(subject);
+    if (primaryPrincipalName == null) {
+      LOG.noPrincipalFound();
+    }
+    headers.put(actorIdHeaderName, primaryPrincipalName == null ? "" : primaryPrincipalName);
+
+    // Populate actor groups headers; getGroupStrings yields no chunks for an empty set, so no guard is needed
+    final Set<String> matchingGroupNames = subject == null ? Collections.emptySet()
+            : SubjectUtils.getGroupPrincipals(subject).stream().map(group -> group.getName())
+            .filter(groupName -> groupPattern.matcher(groupName).matches()).collect(Collectors.toSet());
+    final List<String> groupStrings = getGroupStrings(matchingGroupNames);
+    for (int i = 0; i < groupStrings.size(); i++) {
+      headers.put(String.format(Locale.ROOT, ACTOR_GROUPS_HEADER_FORMAT, actorGroupsHeaderPrefix, i + 1), groupStrings.get(i));
+    }
+    return headers;
+  }
+
+  /**
+   * Packs group names into comma-joined chunks no longer than MAX_HEADER_LENGTH; a longer single name is emitted alone.
+   */
+  private List<String> getGroupStrings(final Collection<String> groupNames) {
+    final List<String> groupStrings = new ArrayList<>();
+    StringBuilder sb = new StringBuilder();
+    for (String groupName : groupNames) {
+      if (sb.length() > 0 && sb.length() + 1 + groupName.length() > MAX_HEADER_LENGTH) { // +1 = comma; never flush an empty chunk
+        groupStrings.add(sb.toString());
+        sb = new StringBuilder();
+      }
+      if (sb.length() > 0) {
+        sb.append(',');
+      }
+      sb.append(groupName);
+    }
+    if (sb.length() > 0) {
+      groupStrings.add(sb.toString());
+    }
+    return groupStrings;
   }
 
   @Override
@@ -180,4 +280,5 @@
 
     return super.getDispatchUrl(request);
   }
+
 }
diff --git a/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java b/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java
index 8f38fd5..0386ac9 100644
--- a/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java
+++ b/gateway-spi/src/test/java/org/apache/knox/gateway/dispatch/ConfigurableDispatchTest.java
@@ -18,6 +18,7 @@
 package org.apache.knox.gateway.dispatch;
 
 import static org.apache.knox.gateway.dispatch.AbstractGatewayDispatch.REQUEST_ID_HEADER_NAME;
+import static org.apache.knox.gateway.dispatch.ConfigurableDispatch.DEFAULT_AUTH_ACTOR_ID_HEADER_NAME;
 import static org.apache.knox.gateway.dispatch.DefaultDispatch.SET_COOKIE;
 import static org.apache.knox.gateway.dispatch.DefaultDispatch.WWW_AUTHENTICATE;
 import static org.hamcrest.CoreMatchers.containsString;
@@ -26,12 +27,15 @@
 import static org.junit.Assert.assertThat;
 
 import java.net.URI;
+import java.security.PrivilegedActionException;
+import java.security.PrivilegedExceptionAction;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.UUID;
 
+import javax.security.auth.Subject;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
@@ -41,6 +45,8 @@
 import org.apache.http.client.methods.HttpGet;
 import org.apache.http.client.methods.HttpUriRequest;
 import org.apache.http.message.BasicHeader;
+import org.apache.knox.gateway.security.GroupPrincipal;
+import org.apache.knox.gateway.security.PrimaryPrincipal;
 import org.apache.knox.test.TestUtils;
 import org.apache.knox.test.mock.MockHttpServletResponse;
 import org.apache.logging.log4j.CloseableThreadContext;
@@ -316,7 +322,7 @@
     assertThat(outboundRequestHeaders[3].getName(), is("c"));
   }
 
-  @Test( timeout = TestUtils.SHORT_TIMEOUT )
+  @Test( timeout = TestUtils.LONG_TIMEOUT )
   public void testRequestExcludeAndAppendHeadersConfig() {
     ConfigurableDispatch dispatch = new ConfigurableDispatch();
     dispatch.setRequestAppendHeaders("a : b ; c : d");
@@ -724,4 +730,47 @@
     assertThat(outboundResponse.getHeader(REQUEST_ID_HEADER_NAME), nullValue());
   }
 
+  /**
+   * Make sure X-Knox-Actor-ID and X-Knox-Actor-Groups-1 headers are added for
+   * authenticated users and carry the principal name and the group names.
+   */
+  @Test
+  public void testGroupHeaders() throws PrivilegedActionException {
+    Subject subject = new Subject();
+    subject.getPrincipals().add(new PrimaryPrincipal("knoxui"));
+    subject.getPrincipals().add(new GroupPrincipal("knox"));
+    subject.getPrincipals().add(new GroupPrincipal("admin"));
+
+    ConfigurableDispatch dispatch = new ConfigurableDispatch();
+    final String headerReqID = "1234567890ABCD";
+    dispatch.setShouldIncludePrincipalAndGroups(true);
+
+    Map<String, String> headers = new HashMap<>();
+    headers.put(REQUEST_ID_HEADER_NAME, headerReqID);
+    headers.put(HttpHeaders.ACCEPT, "abc");
+    headers.put("TEST", "test");
+
+    HttpServletRequest inboundRequest = EasyMock.createNiceMock(HttpServletRequest.class);
+    EasyMock.expect(inboundRequest.getHeaderNames()).andReturn(Collections.enumeration(headers.keySet())).anyTimes();
+    Capture<String> capturedArgument = Capture.newInstance();
+    EasyMock.expect(inboundRequest.getHeader(EasyMock.capture(capturedArgument)))
+            .andAnswer(() -> headers.get(capturedArgument.getValue())).anyTimes();
+    EasyMock.replay(inboundRequest);
+
+    HttpUriRequest outboundRequest = new HttpGet();
+
+    Subject.doAs(subject, (PrivilegedExceptionAction<Object>) () -> {
+      dispatch.copyRequestHeaderFields(outboundRequest, inboundRequest);
+      return null;
+    });
+
+    Header[] outboundRequestHeaders = outboundRequest.getAllHeaders();
+    assertThat(outboundRequestHeaders.length, is(5));
+    assertThat(outboundRequest.getHeaders(DEFAULT_AUTH_ACTOR_ID_HEADER_NAME)[0].getValue(), is("knoxui"));
+    final String groupsHeaderName = ConfigurableDispatch.DEFAULT_AUTH_ACTOR_GROUPS_HEADER_PREFIX + "-1";
+    final String groupsHeaderValue = outboundRequest.getHeaders(groupsHeaderName)[0].getValue();
+    assertThat(groupsHeaderValue, containsString("knox")); // groups come from a Set, so their order is not fixed
+    assertThat(groupsHeaderValue, containsString("admin"));
+  }
+
 }