blob: 58d50cf97211ad4c067efb384618c86d5ca3b31c [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.security.http;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.commons.lang.StringUtils;
import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Servlet filter that decorates responses with CORS (Cross-Origin Resource
 * Sharing) headers.
 *
 * <p>For every request carrying an {@code Origin} header the filter checks the
 * origin, the {@code Access-Control-Request-Method} header and the
 * {@code Access-Control-Request-Headers} header against its configuration.
 * When all three checks pass, the {@code Access-Control-Allow-*} and
 * {@code Access-Control-Max-Age} response headers are set. The request is
 * passed down the filter chain in either case; a failed check only means that
 * no CORS headers are added.
 *
 * <p>Supported init parameters: {@value #ALLOWED_ORIGINS},
 * {@value #ALLOWED_METHODS}, {@value #ALLOWED_HEADERS} (comma separated
 * lists) and {@value #MAX_AGE}.
 */
public class CrossOriginFilter implements Filter {

  private static final Logger LOG =
      LoggerFactory.getLogger(CrossOriginFilter.class);

  // HTTP CORS Request Headers
  static final String ORIGIN = "Origin";
  static final String ACCESS_CONTROL_REQUEST_METHOD =
      "Access-Control-Request-Method";
  static final String ACCESS_CONTROL_REQUEST_HEADERS =
      "Access-Control-Request-Headers";

  // HTTP CORS Response Headers
  static final String ACCESS_CONTROL_ALLOW_ORIGIN =
      "Access-Control-Allow-Origin";
  static final String ACCESS_CONTROL_ALLOW_CREDENTIALS =
      "Access-Control-Allow-Credentials";
  static final String ACCESS_CONTROL_ALLOW_METHODS =
      "Access-Control-Allow-Methods";
  static final String ACCESS_CONTROL_ALLOW_HEADERS =
      "Access-Control-Allow-Headers";
  static final String ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age";

  // Filter configuration
  public static final String ALLOWED_ORIGINS = "allowed-origins";
  public static final String ALLOWED_ORIGINS_DEFAULT = "*";
  public static final String ALLOWED_METHODS = "allowed-methods";
  public static final String ALLOWED_METHODS_DEFAULT = "GET,POST,HEAD";
  public static final String ALLOWED_HEADERS = "allowed-headers";
  public static final String ALLOWED_HEADERS_DEFAULT =
      "X-Requested-With,Content-Type,Accept,Origin";
  public static final String MAX_AGE = "max-age";
  public static final String MAX_AGE_DEFAULT = "1800";

  /** Separator of comma separated lists; swallows surrounding whitespace. */
  private static final Pattern LIST_SEPARATOR = Pattern.compile("\\s*,\\s*");

  /** Separator of the individual origins inside an Origin header value. */
  private static final Pattern ORIGIN_SEPARATOR = Pattern.compile("\\s+");

  private final List<String> allowedMethods = new ArrayList<String>();
  private final List<String> allowedHeaders = new ArrayList<String>();
  private final List<String> allowedOrigins = new ArrayList<String>();

  /**
   * Compiled form of every {@link #allowedOrigins} entry that contains a
   * {@code *} wildcard. Built once in {@link #init(FilterConfig)} so that no
   * regular expression has to be compiled while serving a request.
   */
  private final List<Pattern> allowedOriginPatterns =
      new ArrayList<Pattern>();

  private boolean allowAllOrigins = true;
  private String maxAge;

  /**
   * Reads the filter configuration, falling back to the {@code *_DEFAULT}
   * constants for parameters that are not set.
   *
   * @param filterConfig configuration supplied by the servlet container
   * @throws ServletException if a wildcard entry of {@value #ALLOWED_ORIGINS}
   *     cannot be turned into a valid regular expression
   */
  @Override
  public void init(FilterConfig filterConfig) throws ServletException {
    initializeAllowedMethods(filterConfig);
    initializeAllowedHeaders(filterConfig);
    initializeAllowedOrigins(filterConfig);
    initializeMaxAge(filterConfig);
  }

  /**
   * Adds the CORS response headers when the request qualifies and then
   * always continues with the rest of the filter chain.
   *
   * @param req the request; cast to {@link HttpServletRequest}
   * @param res the response; cast to {@link HttpServletResponse}
   * @param chain the remaining filter chain
   * @throws IOException propagated from the filter chain
   * @throws ServletException propagated from the filter chain
   */
  @Override
  public void doFilter(ServletRequest req, ServletResponse res,
      FilterChain chain)
      throws IOException, ServletException {
    doCrossFilter((HttpServletRequest) req, (HttpServletResponse) res);
    chain.doFilter(req, res);
  }

  /** Drops the configured method, header and origin lists. */
  @Override
  public void destroy() {
    allowedMethods.clear();
    allowedHeaders.clear();
    allowedOrigins.clear();
    allowedOriginPatterns.clear();
  }

  /**
   * Runs the origin, method and header checks and, if all of them pass,
   * writes the CORS headers to the response. Returns without touching the
   * response as soon as one check fails.
   */
  private void doCrossFilter(HttpServletRequest req, HttpServletResponse res) {
    String originsList = encodeHeader(req.getHeader(ORIGIN));
    if (!isCrossOrigin(originsList)) {
      LOG.debug("Header origin is null. Returning");
      return;
    }

    if (!areOriginsAllowed(originsList)) {
      LOG.debug("Header origins '{}' not allowed. Returning", originsList);
      return;
    }

    String accessControlRequestMethod =
        req.getHeader(ACCESS_CONTROL_REQUEST_METHOD);
    if (!isMethodAllowed(accessControlRequestMethod)) {
      LOG.debug("Access control method '{}' not allowed. Returning",
          accessControlRequestMethod);
      return;
    }

    String accessControlRequestHeaders =
        req.getHeader(ACCESS_CONTROL_REQUEST_HEADERS);
    if (!areHeadersAllowed(accessControlRequestHeaders)) {
      LOG.debug("Access control headers '{}' not allowed. Returning",
          accessControlRequestHeaders);
      return;
    }

    LOG.debug("Completed cross origin filter checks. Populating "
        + "HttpServletResponse");
    res.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN, originsList);
    res.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS, Boolean.TRUE.toString());
    res.setHeader(ACCESS_CONTROL_ALLOW_METHODS, getAllowedMethodsHeader());
    res.setHeader(ACCESS_CONTROL_ALLOW_HEADERS, getAllowedHeadersHeader());
    res.setHeader(ACCESS_CONTROL_MAX_AGE, maxAge);
  }

  /** @return the configured allowed headers joined with commas */
  @VisibleForTesting
  String getAllowedHeadersHeader() {
    return StringUtils.join(allowedHeaders, ',');
  }

  /** @return the configured allowed methods joined with commas */
  @VisibleForTesting
  String getAllowedMethodsHeader() {
    return StringUtils.join(allowedMethods, ',');
  }

  /** Loads {@value #ALLOWED_METHODS}, replacing any earlier value. */
  private void initializeAllowedMethods(FilterConfig filterConfig) {
    // Clear first so that a repeated init() does not accumulate entries.
    allowedMethods.clear();
    allowedMethods.addAll(readListParameter(filterConfig, ALLOWED_METHODS,
        ALLOWED_METHODS_DEFAULT));
    LOG.info("Allowed Methods: {}", getAllowedMethodsHeader());
  }

  /** Loads {@value #ALLOWED_HEADERS}, replacing any earlier value. */
  private void initializeAllowedHeaders(FilterConfig filterConfig) {
    allowedHeaders.clear();
    allowedHeaders.addAll(readListParameter(filterConfig, ALLOWED_HEADERS,
        ALLOWED_HEADERS_DEFAULT));
    LOG.info("Allowed Headers: {}", getAllowedHeadersHeader());
  }

  /**
   * Loads {@value #ALLOWED_ORIGINS}, replacing any earlier value, and
   * pre-compiles every entry that contains a {@code *} wildcard.
   *
   * @throws ServletException if a wildcard entry does not translate into a
   *     valid regular expression
   */
  private void initializeAllowedOrigins(FilterConfig filterConfig)
      throws ServletException {
    allowedOrigins.clear();
    allowedOriginPatterns.clear();
    allowedOrigins.addAll(readListParameter(filterConfig, ALLOWED_ORIGINS,
        ALLOWED_ORIGINS_DEFAULT));
    allowAllOrigins = allowedOrigins.contains("*");

    for (String allowedOrigin : allowedOrigins) {
      if (!allowedOrigin.contains("*")) {
        continue;
      }
      // Wildcard translation: dots are literal, '*' matches any sequence.
      String regex = allowedOrigin.replace(".", "\\.").replace("*", ".*");
      try {
        allowedOriginPatterns.add(Pattern.compile(regex));
      } catch (java.util.regex.PatternSyntaxException e) {
        throw new ServletException("Invalid " + ALLOWED_ORIGINS + " entry '"
            + allowedOrigin + "'", e);
      }
    }

    LOG.info("Allowed Origins: {}", StringUtils.join(allowedOrigins, ','));
    LOG.info("Allow All Origins: {}", allowAllOrigins);
  }

  /** Loads {@value #MAX_AGE}, defaulting to {@link #MAX_AGE_DEFAULT}. */
  private void initializeMaxAge(FilterConfig filterConfig) {
    maxAge = filterConfig.getInitParameter(MAX_AGE);
    if (maxAge == null) {
      maxAge = MAX_AGE_DEFAULT;
    }
    LOG.info("Max Age: {}", maxAge);
  }

  /**
   * Reads a comma separated init parameter.
   *
   * @param filterConfig configuration to read from
   * @param name name of the init parameter
   * @param defaultValue value used when the parameter is not set
   * @return the trimmed parameter value split on commas
   */
  private static List<String> readListParameter(FilterConfig filterConfig,
      String name, String defaultValue) {
    String value = filterConfig.getInitParameter(name);
    if (value == null) {
      value = defaultValue;
    }
    return Arrays.asList(LIST_SEPARATOR.split(value.trim()));
  }

  /**
   * Makes a request header value safe to echo into a response header.
   *
   * <p>Protects against HTTP response splitting: everything from the first
   * CR or LF onwards is discarded so the result can never span more than one
   * header line. The remaining text is trimmed.
   *
   * @param header raw header value, may be {@code null}
   * @return {@code null} if {@code header} is {@code null}; otherwise the
   *     trimmed text preceding the first CR/LF, which is the empty string
   *     when the value starts with a line break
   */
  static String encodeHeader(final String header) {
    if (header == null) {
      return null;
    }
    // Scan by hand rather than split(): split() drops trailing empty
    // strings, so a value made up only of CR/LF yields a zero-length array
    // and indexing element 0 would throw.
    int end = header.length();
    for (int i = 0; i < end; i++) {
      char c = header.charAt(i);
      if (c == '\n' || c == '\r') {
        end = i;
        break;
      }
    }
    return header.substring(0, end).trim();
  }

  /**
   * @param originsList encoded value of the Origin header
   * @return {@code true} if the request carried an Origin header
   */
  static boolean isCrossOrigin(String originsList) {
    return originsList != null;
  }

  /**
   * Checks a whitespace separated list of origins against the configuration.
   *
   * @param originsList value of the Origin header
   * @return {@code true} if all origins are allowed by configuration or if
   *     at least one of the listed origins matches an allowed entry, either
   *     literally or through a {@code *} wildcard; {@code false} otherwise,
   *     including for a {@code null} list when not all origins are allowed
   */
  @VisibleForTesting
  boolean areOriginsAllowed(String originsList) {
    if (allowAllOrigins) {
      return true;
    }
    if (originsList == null) {
      return false;
    }

    for (String origin : ORIGIN_SEPARATOR.split(originsList.trim())) {
      if (isOriginAllowed(origin)) {
        return true;
      }
    }
    return false;
  }

  /** Matches one origin against the literal and the wildcard entries. */
  private boolean isOriginAllowed(String origin) {
    for (String allowedOrigin : allowedOrigins) {
      // Wildcard entries are only matched through their compiled pattern.
      if (!allowedOrigin.contains("*") && allowedOrigin.equals(origin)) {
        return true;
      }
    }
    for (Pattern pattern : allowedOriginPatterns) {
      Matcher matcher = pattern.matcher(origin);
      if (matcher.matches()) {
        return true;
      }
    }
    return false;
  }

  /**
   * Checks the Access-Control-Request-Headers value against the configured
   * headers.
   *
   * <p>Names are compared ignoring case: HTTP header field names are
   * case-insensitive and browsers send this list in lower case, so an exact
   * comparison would reject e.g. {@code content-type} against a configured
   * {@code Content-Type}.
   *
   * @param accessControlRequestHeaders comma separated header names, may be
   *     {@code null}
   * @return {@code true} if the value is {@code null} or every listed header
   *     is configured as allowed
   */
  private boolean areHeadersAllowed(String accessControlRequestHeaders) {
    if (accessControlRequestHeaders == null) {
      return true;
    }
    String[] headers =
        LIST_SEPARATOR.split(accessControlRequestHeaders.trim());
    for (String header : headers) {
      if (!containsIgnoreCase(allowedHeaders, header)) {
        return false;
      }
    }
    return true;
  }

  /** @return whether {@code values} holds {@code candidate}, ignoring case */
  private static boolean containsIgnoreCase(List<String> values,
      String candidate) {
    for (String value : values) {
      if (value.equalsIgnoreCase(candidate)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Checks the Access-Control-Request-Method value against the configured
   * methods. The comparison is exact (case-sensitive).
   *
   * @param accessControlRequestMethod requested method, may be {@code null}
   * @return {@code true} if the value is {@code null} or is configured as
   *     allowed
   */
  private boolean isMethodAllowed(String accessControlRequestMethod) {
    if (accessControlRequestMethod == null) {
      return true;
    }
    return allowedMethods.contains(accessControlRequestMethod);
  }
}