blob: aa8206227fd722e4b8945e9aa617adfac1208b52 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.knox.gateway.provider.federation.jwt.filter;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.knox.gateway.services.security.token.UnknownTokenException;
import org.apache.knox.gateway.services.security.token.impl.JWTToken;
import org.apache.knox.gateway.util.CertificateUtils;
import org.apache.knox.gateway.services.security.token.impl.JWT;
import javax.security.auth.Subject;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.text.ParseException;
import java.util.Base64;
import java.util.Locale;
/**
 * Servlet filter that extracts a token from an incoming request, validates it and, on success,
 * continues the filter chain with a {@link Subject} derived from that token.
 *
 * <p>The token is looked up, in order, in:
 * <ol>
 *   <li>the {@code Authorization} header using the {@code Bearer} scheme (treated as a JWT);</li>
 *   <li>the {@code Authorization} header using the {@code Basic} scheme, where the user name
 *       {@code Token} designates a JWT and the user name {@code Passcode} designates a passcode,
 *       carried in the password part;</li>
 *   <li>a request parameter whose name is configured via {@link #KNOX_TOKEN_QUERY_PARAM_NAME}
 *       (treated as a JWT).</li>
 * </ol>
 *
 * <p>Validation, subject creation and chain continuation are delegated to methods that are not
 * declared in this class (presumably inherited from {@code AbstractJWTFilter}).
 */
public class JWTFederationFilter extends AbstractJWTFilter {

  /** Kind of credential found on the wire. */
  public enum TokenType {
    JWT, Passcode;
  }

  /** Init parameter holding the expected token audiences. */
  public static final String KNOX_TOKEN_AUDIENCES = "knox.token.audiences";
  /** Init parameter holding the PEM used to build the token verification public key. */
  public static final String TOKEN_VERIFICATION_PEM = "knox.token.verification.pem";
  /** Init parameter holding the name of the request parameter that may carry the token. */
  public static final String KNOX_TOKEN_QUERY_PARAM_NAME = "knox.token.query.param.name";
  /** Init parameter holding the expected principal claim. */
  public static final String TOKEN_PRINCIPAL_CLAIM = "knox.token.principal.claim";
  /** Init parameter holding the JWKS URL. */
  public static final String JWKS_URL = "knox.token.jwks.url";

  /** {@code Authorization} header prefix of the Bearer scheme (including the separating space). */
  public static final String BEARER = "Bearer ";
  /** {@code Authorization} header prefix of the Basic scheme. */
  public static final String BASIC = "Basic";
  /** Basic-auth user name announcing that the password part is a JWT. */
  public static final String TOKEN = "Token";
  /** Basic-auth user name announcing that the password part is a passcode. */
  public static final String PASSCODE = "Passcode";

  /** Name of the request parameter to read the token from; {@code null} when not configured. */
  private String paramName;

  /**
   * Reads the filter configuration. Each optional init parameter overrides the corresponding
   * setting only when it is present.
   *
   * @param filterConfig the filter configuration supplied by the container
   * @throws ServletException if initialization fails
   */
  @Override
  public void init( FilterConfig filterConfig ) throws ServletException {
    super.init(filterConfig);

    // expected audiences or null
    final String expectedAudiences = filterConfig.getInitParameter(KNOX_TOKEN_AUDIENCES);
    if (expectedAudiences != null) {
      audiences = parseExpectedAudiences(expectedAudiences);
    }

    // query param name for finding the provided knoxtoken
    final String queryParamName = filterConfig.getInitParameter(KNOX_TOKEN_QUERY_PARAM_NAME);
    if (queryParamName != null) {
      paramName = queryParamName;
    }

    // JWKS URL
    final String oidcjwksurl = filterConfig.getInitParameter(JWKS_URL);
    if (oidcjwksurl != null) {
      expectedJWKSUrl = oidcjwksurl;
    }

    // expected principal claim
    final String oidcPrincipalclaim = filterConfig.getInitParameter(TOKEN_PRINCIPAL_CLAIM);
    if (oidcPrincipalclaim != null) {
      expectedPrincipalClaim = oidcPrincipalclaim;
    }

    // set up the public key of the token issuer for verification
    final String verificationPEM = filterConfig.getInitParameter(TOKEN_VERIFICATION_PEM);
    if (verificationPEM != null) {
      publicKey = CertificateUtils.parseRSAPublicKey(verificationPEM);
    }

    configureExpectedParameters(filterConfig);
  }

  /** Nothing to release. */
  @Override
  public void destroy() {
  }

  /**
   * Authenticates the request using the token found by {@link #getWireToken(ServletRequest)}.
   *
   * <p>Responds with {@code 401 Unauthorized} when no token is present, when a JWT value is
   * absent or cannot be parsed, or when the token is unknown. When validation fails, this method
   * sends nothing itself; any response is left to the {@code validateToken} call.
   *
   * @param request  the incoming request; must be an {@link HttpServletRequest}
   * @param response the outgoing response; must be an {@link HttpServletResponse}
   * @param chain    the filter chain to continue once a security context is established
   * @throws IOException      if writing the error response or continuing the chain fails
   * @throws ServletException if continuing the chain fails
   */
  @Override
  public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
      throws IOException, ServletException {
    final Pair<TokenType, String> wireToken = getWireToken(request);
    final HttpServletRequest httpRequest = (HttpServletRequest) request;
    final HttpServletResponse httpResponse = (HttpServletResponse) response;

    if (wireToken == null) {
      // no token provided in header or request parameter
      httpResponse.sendError(HttpServletResponse.SC_UNAUTHORIZED);
      return;
    }

    final TokenType tokenType = wireToken.getLeft();
    final String tokenValue = wireToken.getRight();

    if (TokenType.JWT.equals(tokenType)) {
      if (tokenValue == null) {
        // An absent value (e.g. empty Basic password) can never be a parseable JWT; reject it
        // explicitly instead of relying on how the parser treats null.
        httpResponse.sendError(HttpServletResponse.SC_UNAUTHORIZED);
        return;
      }
      try {
        final JWT token = new JWTToken(tokenValue);
        if (validateToken(httpRequest, httpResponse, chain, token)) {
          final Subject subject = createSubjectFromToken(token);
          continueWithEstablishedSecurityContext(subject, httpRequest, httpResponse, chain);
        }
      } catch (ParseException | UnknownTokenException ex) {
        httpResponse.sendError(HttpServletResponse.SC_UNAUTHORIZED);
      }
    } else if (TokenType.Passcode.equals(tokenType)) {
      // Validate the token based on the server-managed metadata. A null passcode is passed
      // through unchanged so that validateToken decides how to report it.
      if (validateToken(httpRequest, httpResponse, chain, tokenValue)) {
        try {
          final Subject subject = createSubjectFromTokenIdentifier(tokenValue);
          continueWithEstablishedSecurityContext(subject, httpRequest, httpResponse, chain);
        } catch (UnknownTokenException e) {
          httpResponse.sendError(HttpServletResponse.SC_UNAUTHORIZED);
        }
      }
    }
  }

  /**
   * Locates the token carried by the request.
   *
   * <p>The {@code Authorization} header is inspected first; the scheme name is matched
   * case-insensitively. If the header yields nothing usable and a query parameter name has been
   * configured, that request parameter is consulted and its value is treated as a JWT.
   *
   * @param request the incoming request; must be an {@link HttpServletRequest}
   * @return the token type and value, or {@code null} if the request carries no token; the value
   *         may be {@code null} when Basic credentials were sent with an empty password
   */
  public Pair<TokenType, String> getWireToken(final ServletRequest request) {
    Pair<TokenType, String> parsed = null;

    final String header = ((HttpServletRequest)request).getHeader("Authorization");
    if (header != null) {
      // HTTP authentication scheme names are case-insensitive
      final String lowerHeader = header.toLowerCase(Locale.ROOT);
      if (lowerHeader.startsWith(BEARER.toLowerCase(Locale.ROOT))) {
        // what follows the bearer designator should be the JWT token being used
        // to request or as an access token
        parsed = Pair.of(TokenType.JWT, header.substring(BEARER.length()));
      } else if (lowerHeader.startsWith(BASIC.toLowerCase(Locale.ROOT))) {
        // what follows the Basic designator should be the JWT token or the unique token ID
        // being used to request or as an access token
        parsed = parseFromHTTPBasicCredentials(header);
      }
    }

    // Fall back to the request parameter, but only when a parameter name was configured.
    if (parsed == null && this.paramName != null) {
      final String token = request.getParameter(this.paramName);
      if (token != null) {
        parsed = Pair.of(TokenType.JWT, token);
      }
    }

    return parsed;
  }

  /**
   * Extracts a token from a Basic {@code Authorization} header value.
   *
   * <p>The credentials are Base64-decoded and split at the first {@code ':'}. Only the user names
   * {@code Token} (JWT) and {@code Passcode} (passcode) are recognised, case-insensitively.
   *
   * @param header the full header value, starting with the Basic scheme name
   * @return the token type and value (value is {@code null} for an empty password), or
   *         {@code null} when the credentials are malformed or use any other user name
   */
  private Pair<TokenType, String> parseFromHTTPBasicCredentials(final String header) {
    final String base64Credentials = header.substring(BASIC.length()).trim();

    final byte[] credDecoded;
    try {
      credDecoded = Base64.getDecoder().decode(base64Credentials);
    } catch (IllegalArgumentException ignored) {
      // Client sent something that is not valid Base64: treat as "no token in this header"
      // so the caller answers with an authentication failure rather than an unchecked exception.
      return null;
    }

    final String credentials = new String(credDecoded, StandardCharsets.UTF_8);
    final String[] values = credentials.split(":", 2);
    if (values.length < 2) {
      // No ':' separator, hence no password part to read a token from.
      return null;
    }

    final String username = values[0];
    final String passcode = values[1].isEmpty() ? null : values[1];

    Pair<TokenType, String> parsed = null;
    if (TOKEN.equalsIgnoreCase(username)) {
      parsed = Pair.of(TokenType.JWT, passcode);
    } else if (PASSCODE.equalsIgnoreCase(username)) {
      parsed = Pair.of(TokenType.Passcode, passcode);
    }
    return parsed;
  }

  /**
   * Reports a validation failure to the client.
   *
   * @param request  the request being processed (unused here)
   * @param response the response on which the error is sent
   * @param status   the HTTP status code to send
   * @param error    the error message, or {@code null} to send the status without a message
   * @throws IOException if sending the error fails
   */
  @Override
  protected void handleValidationError(HttpServletRequest request, HttpServletResponse response, int status,
                                       String error) throws IOException {
    if (error != null) {
      response.sendError(status, error);
    } else {
      response.sendError(status);
    }
  }
}