blob: bbcff0cd460e5b7faf66b32b600a4c96aa93ea9f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.web.security.saml.impl;
import org.apache.nifi.util.NiFiProperties;
import org.apache.nifi.util.StringUtils;
import org.apache.nifi.web.security.saml.NiFiSAMLContextProvider;
import org.apache.nifi.web.security.saml.SAMLConfiguration;
import org.apache.nifi.web.security.saml.SAMLConfigurationFactory;
import org.apache.nifi.web.security.saml.SAMLEndpoints;
import org.apache.nifi.web.security.saml.SAMLService;
import org.opensaml.common.SAMLException;
import org.opensaml.common.SAMLRuntimeException;
import org.opensaml.common.binding.decoding.URIComparator;
import org.opensaml.saml2.core.Attribute;
import org.opensaml.saml2.core.LogoutRequest;
import org.opensaml.saml2.core.LogoutResponse;
import org.opensaml.saml2.metadata.Endpoint;
import org.opensaml.saml2.metadata.EntityDescriptor;
import org.opensaml.saml2.metadata.provider.MetadataProvider;
import org.opensaml.saml2.metadata.provider.MetadataProviderException;
import org.opensaml.xml.XMLObject;
import org.opensaml.xml.encryption.DecryptionException;
import org.opensaml.xml.schema.XSString;
import org.opensaml.xml.schema.impl.XSAnyImpl;
import org.opensaml.xml.validation.ValidationException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.saml.SAMLConstants;
import org.springframework.security.saml.SAMLCredential;
import org.springframework.security.saml.SAMLLogoutProcessingFilter;
import org.springframework.security.saml.SAMLProcessingFilter;
import org.springframework.security.saml.context.SAMLMessageContext;
import org.springframework.security.saml.key.KeyManager;
import org.springframework.security.saml.log.SAMLLogger;
import org.springframework.security.saml.metadata.ExtendedMetadata;
import org.springframework.security.saml.metadata.ExtendedMetadataDelegate;
import org.springframework.security.saml.metadata.MetadataGenerator;
import org.springframework.security.saml.metadata.MetadataManager;
import org.springframework.security.saml.metadata.MetadataMemoryProvider;
import org.springframework.security.saml.processor.SAMLProcessor;
import org.springframework.security.saml.util.DefaultURLComparator;
import org.springframework.security.saml.util.SAMLUtil;
import org.springframework.security.saml.websso.SingleLogoutProfile;
import org.springframework.security.saml.websso.WebSSOProfile;
import org.springframework.security.saml.websso.WebSSOProfileConsumer;
import org.springframework.security.saml.websso.WebSSOProfileOptions;
import org.springframework.security.web.authentication.logout.LogoutHandler;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Timer;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Standard {@link SAMLService} implementation built on Spring Security SAML and OpenSAML.
 * <p>
 * Lifecycle:
 * <ol>
 *     <li>{@link #initialize()} creates the {@link SAMLConfiguration} from {@link NiFiProperties}.</li>
 *     <li>{@link #initializeServiceProvider(String)} generates the service provider metadata for a base URL
 *     and registers it with the configuration's {@link MetadataManager}.</li>
 *     <li>{@link #shutdown()} releases the configuration's background timer and metadata manager and resets
 *     the service to its uninitialized state.</li>
 * </ol>
 * Every SAML operation first calls {@link #verifyReadyForSamlOperations()}, which requires SAML to be enabled
 * and both initialization steps to have completed.
 */
public class StandardSAMLService implements SAMLService {

    private static final Logger LOGGER = LoggerFactory.getLogger(StandardSAMLService.class);

    private final NiFiProperties properties;
    private final SAMLConfigurationFactory samlConfigurationFactory;

    private final AtomicBoolean initialized = new AtomicBoolean(false);
    private final AtomicBoolean spMetadataInitialized = new AtomicBoolean(false);
    private final AtomicReference<String> spBaseUrl = new AtomicReference<>(null);

    // Used to match the inbound message transport against the local entity's endpoints
    private final URIComparator uriComparator = new DefaultURLComparator();

    // Assigned by initialize() and cleared by shutdown(); both are synchronized on this instance
    private SAMLConfiguration samlConfiguration;

    public StandardSAMLService(final SAMLConfigurationFactory samlConfigurationFactory, final NiFiProperties properties) {
        this.properties = properties;
        this.samlConfigurationFactory = samlConfigurationFactory;
    }

    /**
     * Creates the SAML configuration when SAML is enabled and the service is not already initialized.
     *
     * @throws RuntimeException if the configuration factory fails
     */
    @Override
    public synchronized void initialize() {
        // this method will always be called so if SAML is not configured just return, don't throw an exception
        if (!properties.isSamlEnabled()) {
            return;
        }

        // already initialized so return
        if (initialized.get()) {
            return;
        }

        try {
            LOGGER.info("Initializing SAML Service...");
            samlConfiguration = samlConfigurationFactory.create(properties);
            initialized.set(true);
            LOGGER.info("Finished initializing SAML Service");
        } catch (Exception e) {
            throw new RuntimeException("Unable to initialize SAML configuration due to: " + e.getMessage(), e);
        }
    }

    /**
     * Releases the background task timer and metadata manager of the current configuration and resets the
     * service to its uninitialized state. Does nothing when SAML is not enabled. Failures while releasing
     * individual resources are logged and do not stop the rest of the shutdown.
     * <p>
     * Synchronized like {@link #initialize()} and {@link #initializeServiceProvider(String)} so that a shutdown
     * cannot interleave with an in-progress initialization.
     */
    @Override
    public synchronized void shutdown() {
        // this method will always be called so if SAML is not configured just return, don't throw an exception
        if (!properties.isSamlEnabled()) {
            return;
        }

        LOGGER.info("Shutting down SAML Service...");

        // Mark the service as not ready first so new operations fail fast in verifyReadyForSamlOperations()
        // instead of using a configuration that is being torn down
        initialized.set(false);
        spMetadataInitialized.set(false);
        spBaseUrl.set(null);

        if (samlConfiguration != null) {
            try {
                final Timer backgroundTimer = samlConfiguration.getBackgroundTaskTimer();
                backgroundTimer.purge();
                backgroundTimer.cancel();
            } catch (final Exception e) {
                LOGGER.warn("Error shutting down background timer: {}", e.getMessage(), e);
            }

            try {
                final MetadataManager metadataManager = samlConfiguration.getMetadataManager();
                metadataManager.destroy();
            } catch (final Exception e) {
                LOGGER.warn("Error shutting down metadata manager: {}", e.getMessage(), e);
            }
        }

        samlConfiguration = null;
        LOGGER.info("Finished shutting down SAML Service");
    }

    @Override
    public boolean isSamlEnabled() {
        return properties.isSamlEnabled();
    }

    @Override
    public boolean isServiceProviderInitialized() {
        return spMetadataInitialized.get();
    }

    /**
     * Generates and registers the service provider metadata for the given base URL. If the service provider
     * is already initialized, the call is logged and otherwise ignored, even when a different base URL is given.
     *
     * @param baseUrl the base URL used to build the service provider endpoint URLs; must not be blank
     * @throws IllegalStateException if SAML is not enabled
     * @throws IllegalArgumentException if {@code baseUrl} is blank
     * @throws RuntimeException if the metadata cannot be generated or registered
     */
    @Override
    public synchronized void initializeServiceProvider(final String baseUrl) {
        if (!isSamlEnabled()) {
            throw new IllegalStateException(SAML_SUPPORT_IS_NOT_CONFIGURED);
        }

        if (StringUtils.isBlank(baseUrl)) {
            throw new IllegalArgumentException("baseUrl is required when initializing the service provider");
        }

        if (isServiceProviderInitialized()) {
            LOGGER.info("Service provider already initialized with baseUrl = '{}'", spBaseUrl.get());
            return;
        }

        LOGGER.info("Initializing SAML service provider with baseUrl = '{}'", baseUrl);
        try {
            initializeServiceProviderMetadata(baseUrl);
            spBaseUrl.set(baseUrl);
            spMetadataInitialized.set(true);
        } catch (Exception e) {
            throw new RuntimeException("Unable to initialize SAML service provider: " + e.getMessage(), e);
        }
        LOGGER.info("Done initializing SAML service provider");
    }

    /**
     * Returns the service provider's entity descriptor, serialized by {@link SAMLUtil#getMetadataAsString}.
     *
     * @throws RuntimeException if the metadata cannot be obtained
     */
    @Override
    public String getServiceProviderMetadata() {
        verifyReadyForSamlOperations();

        try {
            final KeyManager keyManager = samlConfiguration.getKeyManager();
            final MetadataManager metadataManager = samlConfiguration.getMetadataManager();

            final String spEntityId = samlConfiguration.getSpEntityId();
            final EntityDescriptor descriptor = metadataManager.getEntityDescriptor(spEntityId);
            return SAMLUtil.getMetadataAsString(metadataManager, keyManager, descriptor, null);
        } catch (Exception e) {
            throw new RuntimeException("Unable to obtain SAML service provider metadata", e);
        }
    }

    @Override
    public long getAuthExpiration() {
        verifyReadyForSamlOperations();
        return samlConfiguration.getAuthExpiration();
    }

    /**
     * Sends a WebSSO authentication request to the identity provider using a copy of the configured
     * {@link WebSSOProfileOptions}, with the given relay state applied.
     *
     * @throws IllegalStateException if the SAML message context cannot be created
     * @throws RuntimeException if sending the request fails
     */
    @Override
    public void initiateLogin(final HttpServletRequest request, final HttpServletResponse response, final String relayState) {
        verifyReadyForSamlOperations();

        final SAMLLogger samlLogger = samlConfiguration.getLogger();
        final NiFiSAMLContextProvider contextProvider = samlConfiguration.getContextProvider();

        final SAMLMessageContext context;
        try {
            context = contextProvider.getLocalAndPeerEntity(request, response, Collections.emptyMap());
        } catch (final MetadataProviderException e) {
            throw new IllegalStateException("Unable to create SAML Message Context: " + e.getMessage(), e);
        }

        // Generate options for the current SSO request; clone so the shared configured options are not mutated
        final WebSSOProfileOptions options = samlConfiguration.getWebSSOProfileOptions().clone();
        options.setRelayState(relayState);

        // Send WebSSO AuthN request
        final WebSSOProfile webSSOProfile = samlConfiguration.getWebSSOProfile();
        try {
            webSSOProfile.sendAuthenticationRequest(context, options);
            samlLogger.log(SAMLConstants.AUTH_N_REQUEST, SAMLConstants.SUCCESS, context);
        } catch (Exception e) {
            samlLogger.log(SAMLConstants.AUTH_N_REQUEST, SAMLConstants.FAILURE, context);
            throw new RuntimeException("Unable to initiate SAML authentication request: " + e.getMessage(), e);
        }
    }

    /**
     * Retrieves the inbound SAML message from the request and processes it as a WebSSO authentication response.
     *
     * @return the credential produced by the configured {@link WebSSOProfileConsumer}
     * @throws IllegalStateException if the message context cannot be created or the profile is unsupported
     * @throws RuntimeException if the message cannot be loaded, validated, or decrypted
     */
    @Override
    public SAMLCredential processLogin(final HttpServletRequest request, final HttpServletResponse response, final Map<String, String> parameters) {
        verifyReadyForSamlOperations();

        LOGGER.info("Attempting SAML2 authentication using profile {}", SAMLConstants.SAML2_WEBSSO_PROFILE_URI);

        final SAMLMessageContext context;
        try {
            final NiFiSAMLContextProvider contextProvider = samlConfiguration.getContextProvider();
            context = contextProvider.getLocalEntity(request, response, parameters);
        } catch (MetadataProviderException e) {
            throw new IllegalStateException("Unable to create SAML Message Context: " + e.getMessage(), e);
        }

        final SAMLProcessor samlProcessor = samlConfiguration.getProcessor();
        try {
            samlProcessor.retrieveMessage(context);
        } catch (Exception e) {
            throw new RuntimeException("Unable to load SAML message: " + e.getMessage(), e);
        }

        // Override set values
        context.setCommunicationProfileId(SAMLConstants.SAML2_WEBSSO_PROFILE_URI);
        try {
            context.setLocalEntityEndpoint(getLocalEntityEndpoint(context));
        } catch (SAMLException e) {
            throw new RuntimeException(e.getMessage(), e);
        }

        if (!SAMLConstants.SAML2_WEBSSO_PROFILE_URI.equals(context.getCommunicationProfileId())) {
            throw new IllegalStateException("Unsupported profile encountered in the context: " + context.getCommunicationProfileId());
        }

        final SAMLLogger samlLogger = samlConfiguration.getLogger();
        final WebSSOProfileConsumer webSSOProfileConsumer = samlConfiguration.getWebSSOProfileConsumer();

        try {
            final SAMLCredential credential = webSSOProfileConsumer.processAuthenticationResponse(context);
            LOGGER.debug("SAML Response contains successful authentication for NameID: {}", credential.getNameID().getValue());
            samlLogger.log(SAMLConstants.AUTH_N_RESPONSE, SAMLConstants.SUCCESS, context);
            return credential;
        } catch (SAMLException | SAMLRuntimeException e) {
            LOGGER.error("Error validating SAML message", e);
            samlLogger.log(SAMLConstants.AUTH_N_RESPONSE, SAMLConstants.FAILURE, context, e);
            throw new RuntimeException("Error validating SAML message: " + e.getMessage(), e);
        } catch (org.opensaml.xml.security.SecurityException | ValidationException e) {
            LOGGER.error("Error validating signature", e);
            samlLogger.log(SAMLConstants.AUTH_N_RESPONSE, SAMLConstants.FAILURE, context, e);
            throw new RuntimeException("Error validating SAML message signature: " + e.getMessage(), e);
        } catch (DecryptionException e) {
            LOGGER.error("Error decrypting SAML message", e);
            samlLogger.log(SAMLConstants.AUTH_N_RESPONSE, SAMLConstants.FAILURE, context, e);
            throw new RuntimeException("Error decrypting SAML message: " + e.getMessage(), e);
        }
    }

    /**
     * Determines the user identity for the credential.
     * <p>
     * When an identity attribute name is configured, the first non-blank {@link XSString} or {@link XSAnyImpl}
     * value of the first matching attribute is used. Otherwise, or when no such value is found, the NameID value
     * is used.
     *
     * @throws IllegalArgumentException if {@code credential} is null
     */
    @Override
    public String getUserIdentity(final SAMLCredential credential) {
        verifyReadyForSamlOperations();

        if (credential == null) {
            throw new IllegalArgumentException("SAML Credential is required");
        }

        String userIdentity = null;

        final String identityAttributeName = samlConfiguration.getIdentityAttributeName();
        if (StringUtils.isBlank(identityAttributeName)) {
            userIdentity = credential.getNameID().getValue();
            LOGGER.info("No identity attribute specified, using NameID for user identity: {}", userIdentity);
        } else {
            LOGGER.debug("Looking for SAML attribute {} ...", identityAttributeName);

            final List<Attribute> attributes = credential.getAttributes();
            if (attributes == null || attributes.isEmpty()) {
                userIdentity = credential.getNameID().getValue();
                LOGGER.warn("No attributes returned in SAML response, using NameID for user identity: {}", userIdentity);
            } else {
                for (final Attribute attribute : attributes) {
                    if (!identityAttributeName.equals(attribute.getName())) {
                        LOGGER.trace("Skipping SAML attribute {}", attribute.getName());
                        continue;
                    }

                    for (final XMLObject value : attribute.getAttributeValues()) {
                        final String candidate = getAttributeValueAsString(value);
                        // Never accept an empty identity; keep looking and fall back to NameID if needed
                        if (!StringUtils.isBlank(candidate)) {
                            userIdentity = candidate;
                            break;
                        }
                    }

                    if (userIdentity != null) {
                        LOGGER.info("Found user identity {} in attribute {}", userIdentity, attribute.getName());
                        break;
                    }
                }
            }

            if (userIdentity == null) {
                userIdentity = credential.getNameID().getValue();
                LOGGER.warn("No attribute found named {}, using NameID for user identity: {}", identityAttributeName, userIdentity);
            }
        }

        return userIdentity;
    }

    /**
     * Collects the non-blank {@link XSString} and {@link XSAnyImpl} values of every attribute whose name equals
     * the configured group attribute name.
     *
     * @return the group names; an empty set when no group attribute name is configured
     * @throws IllegalArgumentException if {@code credential} is null
     */
    @Override
    public Set<String> getUserGroups(final SAMLCredential credential) {
        verifyReadyForSamlOperations();

        if (credential == null) {
            throw new IllegalArgumentException("SAML Credential is required");
        }

        final String userIdentity = credential.getNameID().getValue();

        final String groupAttributeName = samlConfiguration.getGroupAttributeName();
        if (StringUtils.isBlank(groupAttributeName)) {
            LOGGER.warn("Cannot obtain groups for {} because no group attribute name has been configured", userIdentity);
            return Collections.emptySet();
        }

        final Set<String> groups = new HashSet<>();

        if (credential.getAttributes() != null) {
            for (final Attribute attribute : credential.getAttributes()) {
                if (!groupAttributeName.equals(attribute.getName())) {
                    LOGGER.debug("Skipping SAML attribute {}", attribute.getName());
                    continue;
                }

                for (final XMLObject value : attribute.getAttributeValues()) {
                    final String groupName = getAttributeValueAsString(value);
                    // Skip null/blank values so they are never added as group names
                    if (StringUtils.isBlank(groupName)) {
                        continue;
                    }
                    LOGGER.debug("Found group {} for {}", groupName, userIdentity);
                    groups.add(groupName);
                }
            }
        }

        return groups;
    }

    /**
     * Returns the textual content of a SAML attribute value.
     *
     * @param value the attribute value; may be null
     * @return the value of an {@link XSString}, the text content of an {@link XSAnyImpl}, or {@code null} for any
     *         other (or null) value
     */
    private static String getAttributeValueAsString(final XMLObject value) {
        if (value instanceof XSString) {
            return ((XSString) value).getValue();
        }
        if (value instanceof XSAnyImpl) {
            return ((XSAnyImpl) value).getTextContent();
        }
        LOGGER.debug("Value was not XSString and XSAnyImpl, but was {}", value == null ? null : value.getClass().getCanonicalName());
        return null;
    }

    /**
     * Sends a single logout request for the given credential to the identity provider.
     *
     * @throws IllegalStateException if the SAML message context cannot be created
     * @throws RuntimeException if sending the request fails
     */
    @Override
    public void initiateLogout(final HttpServletRequest request, final HttpServletResponse response, final SAMLCredential credential) {
        verifyReadyForSamlOperations();

        final SAMLMessageContext context;
        try {
            final NiFiSAMLContextProvider contextProvider = samlConfiguration.getContextProvider();
            context = contextProvider.getLocalAndPeerEntity(request, response, Collections.emptyMap());
        } catch (MetadataProviderException e) {
            throw new IllegalStateException("Unable to create SAML Message Context: " + e.getMessage(), e);
        }

        final SAMLLogger samlLogger = samlConfiguration.getLogger();
        final SingleLogoutProfile singleLogoutProfile = samlConfiguration.getSingleLogoutProfile();

        try {
            singleLogoutProfile.sendLogoutRequest(context, credential);
            samlLogger.log(SAMLConstants.LOGOUT_REQUEST, SAMLConstants.SUCCESS, context);
        } catch (Exception e) {
            samlLogger.log(SAMLConstants.LOGOUT_REQUEST, SAMLConstants.FAILURE, context);
            throw new RuntimeException("Unable to initiate SAML logout request: " + e.getMessage(), e);
        }
    }

    /**
     * Retrieves the inbound SAML message received at the single logout endpoint. A {@link LogoutResponse} is
     * processed. A {@link LogoutRequest} (IDP-initiated logout) is rejected with
     * {@link UnsupportedOperationException}. Any other message type is logged and ignored.
     *
     * @throws IllegalStateException if the SAML message context cannot be created
     * @throws RuntimeException if the message cannot be loaded or the logout response is invalid
     */
    @Override
    public void processLogout(final HttpServletRequest request, final HttpServletResponse response, final Map<String, String> parameters) {
        verifyReadyForSamlOperations();

        final SAMLMessageContext context;
        try {
            final NiFiSAMLContextProvider contextProvider = samlConfiguration.getContextProvider();
            context = contextProvider.getLocalAndPeerEntity(request, response, parameters);
        } catch (MetadataProviderException e) {
            throw new IllegalStateException("Unable to create SAML Message Context: " + e.getMessage(), e);
        }

        final SAMLProcessor samlProcessor = samlConfiguration.getProcessor();
        try {
            samlProcessor.retrieveMessage(context);
        } catch (Exception e) {
            throw new RuntimeException("Unable to load SAML message: " + e.getMessage(), e);
        }

        // Override set values
        context.setCommunicationProfileId(SAMLConstants.SAML2_SLO_PROFILE_URI);
        try {
            context.setLocalEntityEndpoint(getLocalEntityEndpoint(context));
        } catch (SAMLException e) {
            throw new RuntimeException(e.getMessage(), e);
        }

        // Determine if the incoming SAML message is a response to a logout we initiated, or a request initiated by the IDP
        final Object inboundMessage = context.getInboundSAMLMessage();
        if (inboundMessage instanceof LogoutResponse) {
            processLogoutResponse(context);
        } else if (inboundMessage instanceof LogoutRequest) {
            processLogoutRequest(context);
        } else {
            // Previously ignored silently; surface it so misrouted or unexpected messages are visible
            LOGGER.warn("Ignoring unexpected SAML message received for single logout: {}",
                    inboundMessage == null ? null : inboundMessage.getClass().getName());
        }
    }

    // Validates a LogoutResponse from the IDP and records success/failure with the SAML logger
    private void processLogoutResponse(final SAMLMessageContext context) {
        final SAMLLogger samlLogger = samlConfiguration.getLogger();
        final SingleLogoutProfile logoutProfile = samlConfiguration.getSingleLogoutProfile();

        try {
            logoutProfile.processLogoutResponse(context);
            samlLogger.log(SAMLConstants.LOGOUT_RESPONSE, SAMLConstants.SUCCESS, context);
        } catch (Exception e) {
            LOGGER.error("Received logout response is invalid", e);
            samlLogger.log(SAMLConstants.LOGOUT_RESPONSE, SAMLConstants.FAILURE, context, e);
            throw new RuntimeException("Received logout response is invalid: " + e.getMessage(), e);
        }
    }

    // IDP-initiated logout is not supported; always throws
    private void processLogoutRequest(final SAMLMessageContext context) {
        throw new UnsupportedOperationException("Apache NiFi currently does not support IDP initiated logout");
    }

    // Selects the local entity endpoint that matches the inbound binding and message transport
    private Endpoint getLocalEntityEndpoint(final SAMLMessageContext context) throws SAMLException {
        return SAMLUtil.getEndpoint(
                context.getLocalEntityRoleMetadata().getEndpoints(),
                context.getInboundSAMLBinding(),
                context.getInboundMessageTransport(),
                uriComparator);
    }

    /**
     * Generates service provider metadata for {@code baseUrl}, wraps it in an in-memory metadata provider, and
     * registers it as the hosted SP with the configuration's {@link MetadataManager}.
     */
    private void initializeServiceProviderMetadata(final String baseUrl) throws MetadataProviderException {
        // Create filters so MetadataGenerator can get URLs, but we don't actually use the filters, the filter
        // paths are the URLs from AccessResource that match up with the corresponding SAML endpoint
        final SAMLProcessingFilter ssoProcessingFilter = new SAMLProcessingFilter();
        ssoProcessingFilter.setFilterProcessesUrl(SAMLEndpoints.LOGIN_CONSUMER);

        // No-op handler: the logout filter is only needed for its URL, never for actual processing
        final LogoutHandler noOpLogoutHandler = (request, response, authentication) -> { };
        final SAMLLogoutProcessingFilter sloProcessingFilter = new SAMLLogoutProcessingFilter("/nifi", noOpLogoutHandler);
        sloProcessingFilter.setFilterProcessesUrl(SAMLEndpoints.SINGLE_LOGOUT_CONSUMER);

        // Create the MetadataGenerator...
        final MetadataGenerator metadataGenerator = new MetadataGenerator();
        metadataGenerator.setEntityId(samlConfiguration.getSpEntityId());
        metadataGenerator.setEntityBaseURL(baseUrl);
        metadataGenerator.setExtendedMetadata(samlConfiguration.getExtendedMetadata());
        metadataGenerator.setIncludeDiscoveryExtension(false);
        metadataGenerator.setKeyManager(samlConfiguration.getKeyManager());
        metadataGenerator.setSamlWebSSOFilter(ssoProcessingFilter);
        metadataGenerator.setSamlLogoutProcessingFilter(sloProcessingFilter);
        metadataGenerator.setRequestSigned(samlConfiguration.isRequestSigningEnabled());
        metadataGenerator.setWantAssertionSigned(samlConfiguration.isWantAssertionsSigned());

        // Generate service provider metadata...
        final EntityDescriptor descriptor = metadataGenerator.generateMetadata();
        final ExtendedMetadata extendedMetadata = metadataGenerator.generateExtendedMetadata();

        // Create the MetadataProvider to hold SP metadata
        final MetadataMemoryProvider memoryProvider = new MetadataMemoryProvider(descriptor);
        memoryProvider.initialize();

        final MetadataProvider spMetadataProvider = new ExtendedMetadataDelegate(memoryProvider, extendedMetadata);

        // Update the MetadataManager with the service provider MetadataProvider
        final MetadataManager metadataManager = samlConfiguration.getMetadataManager();
        metadataManager.addMetadataProvider(spMetadataProvider);
        metadataManager.setHostedSPName(descriptor.getEntityID());
        metadataManager.refreshMetadata();
    }

    // Throws IllegalStateException unless SAML is enabled, the service is initialized, and the SP is initialized
    private void verifyReadyForSamlOperations() {
        if (!isSamlEnabled()) {
            throw new IllegalStateException(SAML_SUPPORT_IS_NOT_CONFIGURED);
        }

        if (!initialized.get()) {
            throw new IllegalStateException("StandardSAMLService has not been initialized");
        }

        if (!isServiceProviderInitialized()) {
            throw new IllegalStateException("Service Provider is not initialized");
        }
    }
}