KNOX-2149 - Added JWT OIDC Verification based on JWKS Urls and extract custom claim

Closes #216

Signed-off-by: Kevin Risden <krisden@apache.org>
diff --git a/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java b/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java
index 6e92241..190af6d 100644
--- a/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java
+++ b/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/AbstractJWTFilter.java
@@ -27,6 +27,7 @@
 import java.util.Date;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Locale;
 import java.util.Set;
 
 import javax.security.auth.Subject;
@@ -52,8 +53,8 @@
 import org.apache.knox.gateway.i18n.messages.MessagesFactory;
 import org.apache.knox.gateway.provider.federation.jwt.JWTMessages;
 import org.apache.knox.gateway.security.PrimaryPrincipal;
-import org.apache.knox.gateway.services.ServiceType;
 import org.apache.knox.gateway.services.GatewayServices;
+import org.apache.knox.gateway.services.ServiceType;
 import org.apache.knox.gateway.services.security.token.JWTokenAuthority;
 import org.apache.knox.gateway.services.security.token.TokenServiceException;
 import org.apache.knox.gateway.services.security.token.TokenStateService;
@@ -87,6 +88,8 @@
   protected RSAPublicKey publicKey;
   private String expectedIssuer;
   private String expectedSigAlg;
+  protected String expectedPrincipalClaim;
+  protected String expectedJWKSUrl;
 
   private TokenStateService tokenStateService;
 
@@ -225,8 +228,15 @@
   }
 
   protected Subject createSubjectFromToken(JWT token) {
-    final String principal = token.getSubject();
+    String principal = token.getSubject();
+    String claimvalue = null;
+    if (expectedPrincipalClaim != null) {
+      claimvalue = token.getClaim(expectedPrincipalClaim);
+    }
 
+    if (claimvalue != null) {
+      principal = claimvalue.toLowerCase(Locale.ROOT);
+    }
     @SuppressWarnings("rawtypes")
     HashSet emptySet = new HashSet();
     Set<Principal> principals = new HashSet<>();
@@ -248,11 +258,12 @@
       throws IOException, ServletException {
     boolean verified = false;
     try {
-      if (publicKey == null) {
-        verified = authority.verifyToken(token);
-      }
-      else {
+      if (publicKey != null) {
         verified = authority.verifyToken(token, publicKey);
+      } else if (expectedJWKSUrl != null) {
+        verified = authority.verifyToken(token, expectedJWKSUrl, expectedSigAlg);
+      } else {
+        verified = authority.verifyToken(token);
       }
     } catch (TokenServiceException e) {
       log.unableToVerifyToken(e);
diff --git a/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/JWTFederationFilter.java b/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/JWTFederationFilter.java
index 8d49f7f..e4a7a25 100644
--- a/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/JWTFederationFilter.java
+++ b/gateway-provider-security-jwt/src/main/java/org/apache/knox/gateway/provider/federation/jwt/filter/JWTFederationFilter.java
@@ -38,6 +38,8 @@
   public static final String KNOX_TOKEN_AUDIENCES = "knox.token.audiences";
   public static final String TOKEN_VERIFICATION_PEM = "knox.token.verification.pem";
   private static final String KNOX_TOKEN_QUERY_PARAM_NAME = "knox.token.query.param.name";
+  public static final String TOKEN_PRINCIPAL_CLAIM = "knox.token.principal.claim";
+  public static final String JWKS_URL = "knox.token.jwks.url";
   private static final String BEARER = "Bearer ";
   private String paramName = "knoxtoken";
 
@@ -56,7 +58,16 @@
     if (queryParamName != null) {
       paramName = queryParamName;
     }
-
+    //  JWKSUrl
+    String oidcjwksurl = filterConfig.getInitParameter(JWKS_URL);
+    if (oidcjwksurl != null) {
+      expectedJWKSUrl = oidcjwksurl;
+    }
+    // expected claim
+    String oidcPrincipalclaim = filterConfig.getInitParameter(TOKEN_PRINCIPAL_CLAIM);
+    if (oidcPrincipalclaim != null) {
+      expectedPrincipalClaim = oidcPrincipalclaim;
+    }
     // token verification pem
     String verificationPEM = filterConfig.getInitParameter(TOKEN_VERIFICATION_PEM);
     // setup the public key of the token issuer for verification
diff --git a/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/AbstractJWTFilterTest.java b/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/AbstractJWTFilterTest.java
index e46d2b9..eb23cd0 100644
--- a/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/AbstractJWTFilterTest.java
+++ b/gateway-provider-security-jwt/src/test/java/org/apache/knox/gateway/provider/federation/AbstractJWTFilterTest.java
@@ -1070,6 +1070,11 @@
       JWSVerifier verifier = new RSASSAVerifier(publicKey);
       return token.verify(verifier);
     }
+
+    @Override
+    public boolean verifyToken(JWT token, String jwksurl, String algorithm) {
+     return false;
+    }
   }
 
   protected static class TestFilterChain implements FilterChain {
diff --git a/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/DefaultTokenAuthorityService.java b/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/DefaultTokenAuthorityService.java
index cb452c6..50ffb62 100644
--- a/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/DefaultTokenAuthorityService.java
+++ b/gateway-server/src/main/java/org/apache/knox/gateway/services/token/impl/DefaultTokenAuthorityService.java
@@ -17,6 +17,8 @@
  */
 package org.apache.knox.gateway.services.token.impl;
 
+import java.net.MalformedURLException;
+import java.net.URL;
 import java.security.Key;
 import java.security.KeyStore;
 import java.security.KeyStoreException;
@@ -27,11 +29,12 @@
 import java.security.cert.Certificate;
 import java.security.interfaces.RSAPrivateKey;
 import java.security.interfaces.RSAPublicKey;
-import java.util.Map;
-import java.util.Set;
-import java.util.List;
+import java.text.ParseException;
 import java.util.ArrayList;
 import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
 
 import javax.security.auth.Subject;
 
@@ -49,10 +52,22 @@
 import org.apache.knox.gateway.services.security.token.impl.JWT;
 import org.apache.knox.gateway.services.security.token.impl.JWTToken;
 
+import com.nimbusds.jose.JOSEException;
+import com.nimbusds.jose.JWSAlgorithm;
 import com.nimbusds.jose.JWSSigner;
 import com.nimbusds.jose.JWSVerifier;
 import com.nimbusds.jose.crypto.RSASSASigner;
 import com.nimbusds.jose.crypto.RSASSAVerifier;
+import com.nimbusds.jose.jwk.source.JWKSource;
+import com.nimbusds.jose.jwk.source.RemoteJWKSet;
+import com.nimbusds.jose.proc.BadJOSEException;
+import com.nimbusds.jose.proc.JWSKeySelector;
+import com.nimbusds.jose.proc.JWSVerificationKeySelector;
+import com.nimbusds.jose.proc.SecurityContext;
+import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
+import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
+import com.nimbusds.jwt.proc.DefaultJWTProcessor;
+import com.nimbusds.jwt.proc.JWTClaimsSetVerifier;
 
 public class DefaultTokenAuthorityService implements JWTokenAuthority, Service {
   private static final GatewayResources RESOURCES = ResourcesFactory.get(GatewayResources.class);
@@ -212,6 +227,32 @@
   }
 
   @Override
+  public boolean verifyToken(JWT token, String jwksurl, String algorithm) throws TokenServiceException {
+    boolean verified = false;
+    try {
+      if (algorithm != null && jwksurl != null) {
+        JWSAlgorithm expectedJWSAlg = JWSAlgorithm.parse(algorithm);
+        JWKSource<SecurityContext> keySource = new RemoteJWKSet<>(new URL(jwksurl));
+        JWSKeySelector<SecurityContext> keySelector = new JWSVerificationKeySelector<>(expectedJWSAlg, keySource);
+
+        // Create a JWT processor for the access tokens
+        ConfigurableJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
+        jwtProcessor.setJWSKeySelector(keySelector);
+        JWTClaimsSetVerifier<SecurityContext> claimsVerifier = new DefaultJWTClaimsVerifier<>();
+        jwtProcessor.setJWTClaimsSetVerifier(claimsVerifier);
+
+        // Process the token
+        SecurityContext ctx = null; // optional context parameter, not required here
+        jwtProcessor.process(token.toString(), ctx);
+        verified = true;
+      }
+    } catch (BadJOSEException | JOSEException | ParseException | MalformedURLException e) {
+      throw new TokenServiceException("Cannot verify token.", e);
+    }
+    return verified;
+  }
+
+  @Override
   public void init(GatewayConfig config, Map<String, String> options)
       throws ServiceLifecycleException {
     if (as == null || ks == null) {
diff --git a/gateway-service-knoxsso/src/test/java/org/apache/knox/gateway/service/knoxsso/WebSSOResourceTest.java b/gateway-service-knoxsso/src/test/java/org/apache/knox/gateway/service/knoxsso/WebSSOResourceTest.java
index b7f143e..c87aa58 100644
--- a/gateway-service-knoxsso/src/test/java/org/apache/knox/gateway/service/knoxsso/WebSSOResourceTest.java
+++ b/gateway-service-knoxsso/src/test/java/org/apache/knox/gateway/service/knoxsso/WebSSOResourceTest.java
@@ -1071,5 +1071,10 @@
       JWSVerifier verifier = new RSASSAVerifier(publicKey);
       return token.verify(verifier);
     }
+
+    @Override
+    public boolean verifyToken(JWT token, String jwksurl, String algorithm) {
+     return false;
+    }
   }
 }
diff --git a/gateway-service-knoxtoken/src/test/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceResourceTest.java b/gateway-service-knoxtoken/src/test/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceResourceTest.java
index e2eed03..bbe6fdd 100644
--- a/gateway-service-knoxtoken/src/test/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceResourceTest.java
+++ b/gateway-service-knoxtoken/src/test/java/org/apache/knox/gateway/service/knoxtoken/TokenServiceResourceTest.java
@@ -1202,5 +1202,10 @@
       JWSVerifier verifier = new RSASSAVerifier(publicKey);
       return token.verify(verifier);
     }
+
+    @Override
+    public boolean verifyToken(JWT token, String jwksurl, String algorithm) {
+     return false;
+    }
   }
 }
diff --git a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/JWTokenAuthority.java b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/JWTokenAuthority.java
index 2f71c2b..5ba10d1 100644
--- a/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/JWTokenAuthority.java
+++ b/gateway-spi/src/main/java/org/apache/knox/gateway/services/security/token/JWTokenAuthority.java
@@ -38,8 +38,9 @@
 
   boolean verifyToken(JWT token) throws TokenServiceException;
 
-  boolean verifyToken(JWT token, RSAPublicKey publicKey)
-      throws TokenServiceException;
+  boolean verifyToken(JWT token, RSAPublicKey publicKey) throws TokenServiceException;
+
+  boolean verifyToken(JWT token, String jwksurl ,String algorithm ) throws TokenServiceException;
 
   JWT issueToken(Principal p, String algorithm, long expires) throws TokenServiceException;
 
@@ -52,4 +53,4 @@
   JWT issueToken(Principal p, List<String> audiences, String algorithm, long expires,
                  String signingKeystoreName, String signingKeystoreAlias, char[] signingKeystorePassphrase)
       throws TokenServiceException;
-}
\ No newline at end of file
+}