blob: 42e3d15c02fccbf3146148cf65423e2ca5d2ecff [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.rpc.security;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.protobuf.ByteString;
import com.google.protobuf.Internal.EnumLite;
import com.google.protobuf.InvalidProtocolBufferException;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import org.apache.drill.exec.proto.UserBitShared.SaslMessage;
import org.apache.drill.exec.proto.UserBitShared.SaslStatus;
import org.apache.drill.exec.rpc.RequestHandler;
import org.apache.drill.exec.rpc.Response;
import org.apache.drill.exec.rpc.ResponseSender;
import org.apache.drill.exec.rpc.RpcException;
import org.apache.drill.exec.rpc.ServerConnection;
import org.apache.hadoop.security.UserGroupInformation;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import java.io.IOException;
import java.lang.reflect.UndeclaredThrowableException;
import java.security.PrivilegedExceptionAction;
import java.util.EnumMap;
import java.util.Map;
import static com.google.common.base.Preconditions.checkNotNull;
/**
* Handles SASL exchange, on the server-side.
*
* @param <S> Server connection type
* @param <T> RPC type
*/
public class ServerAuthenticationHandler<S extends ServerConnection<S>, T extends EnumLite>
implements RequestHandler<S> {
private static final org.slf4j.Logger logger =
org.slf4j.LoggerFactory.getLogger(ServerAuthenticationHandler.class);
private static final ImmutableMap<SaslStatus, SaslResponseProcessor> RESPONSE_PROCESSORS;
static {
final Map<SaslStatus, SaslResponseProcessor> map = new EnumMap<>(SaslStatus.class);
map.put(SaslStatus.SASL_START, new SaslStartProcessor());
map.put(SaslStatus.SASL_IN_PROGRESS, new SaslInProgressProcessor());
map.put(SaslStatus.SASL_SUCCESS, new SaslSuccessProcessor());
map.put(SaslStatus.SASL_FAILED, new SaslFailedProcessor());
RESPONSE_PROCESSORS = Maps.immutableEnumMap(map);
}
private final RequestHandler<S> requestHandler;
private final int saslRequestTypeValue;
private final T saslResponseType;
public ServerAuthenticationHandler(final RequestHandler<S> requestHandler, final int saslRequestTypeValue,
final T saslResponseType) {
this.requestHandler = requestHandler;
this.saslRequestTypeValue = saslRequestTypeValue;
this.saslResponseType = saslResponseType;
}
@Override
public void handle(S connection, int rpcType, ByteBuf pBody, ByteBuf dBody, ResponseSender sender)
throws RpcException {
final String remoteAddress = connection.getRemoteAddress().toString();
// exchange involves server "challenges" and client "responses" (initiated by client)
if (saslRequestTypeValue == rpcType) {
final SaslMessage saslResponse;
try {
saslResponse = SaslMessage.PARSER.parseFrom(new ByteBufInputStream(pBody));
} catch (final InvalidProtocolBufferException e) {
handleAuthFailure(connection, sender, e, saslResponseType);
return;
}
logger.trace("Received SASL message {} from {}", saslResponse.getStatus(), remoteAddress);
final SaslResponseProcessor processor = RESPONSE_PROCESSORS.get(saslResponse.getStatus());
if (processor == null) {
logger.info("Unknown message type from client from {}. Will stop authentication.", remoteAddress);
handleAuthFailure(connection, sender, new SaslException("Received unexpected message"),
saslResponseType);
return;
}
final SaslResponseContext<S, T> context = new SaslResponseContext<>(saslResponse, connection, sender,
requestHandler, saslResponseType);
try {
processor.process(context);
} catch (final Exception e) {
handleAuthFailure(connection, sender, e, saslResponseType);
}
} else {
// this handler only handles messages of SASL_MESSAGE_VALUE type
// the response type for this request type is likely known from UserRpcConfig,
// but the client should not be making any requests before authenticating.
// drop connection
throw new RpcException(
String.format("Request of type %d is not allowed without authentication. Client on %s must authenticate " +
"before making requests. Connection dropped. [Details: %s]",
rpcType, remoteAddress, connection.getEncryptionCtxtString()));
}
}
private static class SaslResponseContext<S extends ServerConnection<S>, T extends EnumLite> {
final SaslMessage saslResponse;
final S connection;
final ResponseSender sender;
final RequestHandler<S> requestHandler;
final T saslResponseType;
SaslResponseContext(SaslMessage saslResponse, S connection, ResponseSender sender,
RequestHandler<S> requestHandler, T saslResponseType) {
this.saslResponse = checkNotNull(saslResponse);
this.connection = checkNotNull(connection);
this.sender = checkNotNull(sender);
this.requestHandler = checkNotNull(requestHandler);
this.saslResponseType = checkNotNull(saslResponseType);
}
}
private interface SaslResponseProcessor {
/**
* Process response from client, and if there are no exceptions, send response using
* {@link SaslResponseContext#sender}. Otherwise, throw the exception.
*
* @param context response context
*/
<S extends ServerConnection<S>, T extends EnumLite>
void process(SaslResponseContext<S, T> context) throws Exception;
}
private static class SaslStartProcessor implements SaslResponseProcessor {
@Override
public <S extends ServerConnection<S>, T extends EnumLite>
void process(SaslResponseContext<S, T> context) throws Exception {
context.connection.initSaslServer(context.saslResponse.getMechanism());
// assume #evaluateResponse must be called at least once
RESPONSE_PROCESSORS.get(SaslStatus.SASL_IN_PROGRESS).process(context);
}
}
private static class SaslInProgressProcessor implements SaslResponseProcessor {
@Override
public <S extends ServerConnection<S>, T extends EnumLite>
void process(SaslResponseContext<S, T> context) throws Exception {
final SaslMessage.Builder challenge = SaslMessage.newBuilder();
final SaslServer saslServer = context.connection.getSaslServer();
final byte[] challengeBytes = evaluateResponse(saslServer, context.saslResponse.getData().toByteArray());
if (saslServer.isComplete()) {
challenge.setStatus(SaslStatus.SASL_SUCCESS);
if (challengeBytes != null) {
challenge.setData(ByteString.copyFrom(challengeBytes));
}
handleSuccess(context, challenge, saslServer);
} else {
challenge.setStatus(SaslStatus.SASL_IN_PROGRESS)
.setData(ByteString.copyFrom(challengeBytes));
context.sender.send(new Response(context.saslResponseType, challenge.build()));
}
}
}
// only when client succeeds first
private static class SaslSuccessProcessor implements SaslResponseProcessor {
@Override
public <S extends ServerConnection<S>, T extends EnumLite>
void process(SaslResponseContext<S, T> context) throws Exception {
// at this point, #isComplete must be false; so try once, fail otherwise
final SaslServer saslServer = context.connection.getSaslServer();
evaluateResponse(saslServer, context.saslResponse.getData().toByteArray()); // discard challenge
if (saslServer.isComplete()) {
final SaslMessage.Builder challenge = SaslMessage.newBuilder();
challenge.setStatus(SaslStatus.SASL_SUCCESS);
handleSuccess(context, challenge, saslServer);
} else {
final S connection = context.connection;
logger.info("Failed to authenticate client from {} with encryption context:{}",
connection.getRemoteAddress().toString(), connection.getEncryptionCtxtString());
throw new SaslException(String.format("Client allegedly succeeded authentication but server did not. " +
"Suspicious? [Details: %s]", connection.getEncryptionCtxtString()));
}
}
}
private static class SaslFailedProcessor implements SaslResponseProcessor {
@Override
public <S extends ServerConnection<S>, T extends EnumLite>
void process(SaslResponseContext<S, T> context) throws Exception {
final S connection = context.connection;
logger.info("Client from {} failed authentication with encryption context:{} graciously, and does not want to " +
"continue.", connection.getRemoteAddress().toString(), connection.getEncryptionCtxtString());
throw new SaslException(String.format("Client graciously failed authentication. [Details: %s]",
connection.getEncryptionCtxtString()));
}
}
private static byte[] evaluateResponse(final SaslServer saslServer,
final byte[] responseBytes) throws SaslException {
try {
return UserGroupInformation.getLoginUser()
.doAs(new PrivilegedExceptionAction<byte[]>() {
@Override
public byte[] run() throws Exception {
return saslServer.evaluateResponse(responseBytes);
}
});
} catch (final UndeclaredThrowableException e) {
throw new SaslException(String.format("Unexpected failure trying to authenticate using %s",
saslServer.getMechanismName()), e.getCause());
} catch (final IOException | InterruptedException e) {
if (e instanceof SaslException) {
throw (SaslException) e;
} else {
throw new SaslException(String.format("Unexpected failure trying to authenticate using %s",
saslServer.getMechanismName()), e);
}
}
}
private static <S extends ServerConnection<S>, T extends EnumLite>
void handleSuccess(final SaslResponseContext<S, T> context, final SaslMessage.Builder challenge,
final SaslServer saslServer) throws IOException {
final S connection = context.connection;
connection.changeHandlerTo(context.requestHandler);
connection.finalizeSaslSession();
// Check the negotiated property before sending the response back to client
try {
final String negotiatedQOP = saslServer.getNegotiatedProperty(Sasl.QOP).toString();
final String expectedQOP = (connection.isEncryptionEnabled())
? SaslProperties.QualityOfProtection.PRIVACY.getSaslQop()
: SaslProperties.QualityOfProtection.AUTHENTICATION.getSaslQop();
if (!(negotiatedQOP.equals(expectedQOP))) {
throw new SaslException(String.format("Mismatch in negotiated QOP value: %s and Expected QOP value: %s",
negotiatedQOP, expectedQOP));
}
// Update the rawWrapSendSize with the negotiated rawSendSize since we cannot call encode with more than the
// negotiated size of buffer
if (connection.isEncryptionEnabled()) {
final int negotiatedRawSendSize = Integer.parseInt(
saslServer.getNegotiatedProperty(Sasl.RAW_SEND_SIZE).toString());
if (negotiatedRawSendSize <= 0) {
throw new SaslException(String.format("Negotiated rawSendSize: %d is invalid. Please check the configured " +
"value of encryption.sasl.max_wrapped_size. It might be configured to a very small value.",
negotiatedRawSendSize));
}
connection.setWrapSizeLimit(negotiatedRawSendSize);
}
} catch (IllegalStateException | NumberFormatException e) {
throw new SaslException(String.format("Unexpected failure while retrieving negotiated property values (%s)",
e.getMessage()), e);
}
if (logger.isTraceEnabled()) {
logger.trace("Authenticated {} successfully using {} from {} with encryption context {}",
saslServer.getAuthorizationID(), saslServer.getMechanismName(), connection.getRemoteAddress().toString(),
connection.getEncryptionCtxtString());
}
// All checks have passed let's send the response back to client before adding handlers.
context.sender.send(new Response(context.saslResponseType, challenge.build()));
if (connection.isEncryptionEnabled()) {
connection.addSecurityHandlers();
} else {
// Encryption is not required hence we don't need to hold on to saslServer object.
connection.disposeSaslServer();
}
}
private static final SaslMessage SASL_FAILED_MESSAGE =
SaslMessage.newBuilder().setStatus(SaslStatus.SASL_FAILED).build();
private static <S extends ServerConnection<S>, T extends EnumLite>
void handleAuthFailure(final S connection, final ResponseSender sender,
final Exception e, final T saslResponseType) throws RpcException {
final String remoteAddress = connection.getRemoteAddress().toString();
logger.debug("Authentication using mechanism {} with encryption context {} failed from client {} due to {}",
connection.getSaslServer().getMechanismName(), connection.getEncryptionCtxtString(), remoteAddress, e);
// inform the client that authentication failed, and no more
sender.send(new Response(saslResponseType, SASL_FAILED_MESSAGE));
// drop connection
throw new RpcException(e);
}
}