blob: 3f12a1e1d6d480eef85a64edba71a5276dd5e13f [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.rpc.user;
import org.apache.drill.common.concurrent.AbstractCheckedFuture;
import org.apache.drill.common.concurrent.CheckedFuture;
import com.google.common.base.Strings;
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Sets;
import com.google.common.util.concurrent.SettableFuture;
import com.google.protobuf.MessageLite;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import org.apache.drill.common.KerberosUtil;
import org.apache.drill.common.config.DrillConfig;
import org.apache.drill.common.config.DrillProperties;
import org.apache.drill.common.exceptions.DrillException;
import org.apache.drill.exec.client.InvalidConnectionInfoException;
import org.apache.drill.exec.memory.BufferAllocator;
import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint;
import org.apache.drill.exec.proto.GeneralRPCProtos.Ack;
import org.apache.drill.exec.proto.UserBitShared.QueryData;
import org.apache.drill.exec.proto.UserBitShared.QueryId;
import org.apache.drill.exec.proto.UserBitShared.QueryResult;
import org.apache.drill.exec.proto.UserBitShared.SaslMessage;
import org.apache.drill.exec.proto.UserBitShared.UserCredentials;
import org.apache.drill.exec.proto.UserProtos.BitToUserHandshake;
import org.apache.drill.exec.proto.UserProtos.CreatePreparedStatementResp;
import org.apache.drill.exec.proto.UserProtos.GetCatalogsResp;
import org.apache.drill.exec.proto.UserProtos.GetColumnsResp;
import org.apache.drill.exec.proto.UserProtos.GetQueryPlanFragments;
import org.apache.drill.exec.proto.UserProtos.GetSchemasResp;
import org.apache.drill.exec.proto.UserProtos.GetServerMetaResp;
import org.apache.drill.exec.proto.UserProtos.GetTablesResp;
import org.apache.drill.exec.proto.UserProtos.QueryPlanFragments;
import org.apache.drill.exec.proto.UserProtos.RpcEndpointInfos;
import org.apache.drill.exec.proto.UserProtos.RpcType;
import org.apache.drill.exec.proto.UserProtos.RunQuery;
import org.apache.drill.exec.proto.UserProtos.SaslSupport;
import org.apache.drill.exec.proto.UserProtos.UserToBitHandshake;
import org.apache.drill.exec.rpc.AbstractClientConnection;
import org.apache.drill.exec.rpc.Acks;
import org.apache.drill.exec.rpc.BasicClient;
import org.apache.drill.exec.rpc.ConnectionMultiListener;
import org.apache.drill.exec.rpc.DrillRpcFuture;
import org.apache.drill.exec.rpc.NonTransientRpcException;
import org.apache.drill.exec.rpc.OutOfMemoryHandler;
import org.apache.drill.exec.rpc.ProtobufLengthDecoder;
import org.apache.drill.exec.rpc.Response;
import org.apache.drill.exec.rpc.ResponseSender;
import org.apache.drill.exec.rpc.RpcConnectionHandler;
import org.apache.drill.exec.rpc.RpcConstants;
import org.apache.drill.exec.rpc.RpcException;
import org.apache.drill.exec.rpc.RpcOutcomeListener;
import org.apache.drill.exec.rpc.security.AuthStringUtil;
import org.apache.drill.exec.rpc.security.AuthenticatorFactory;
import org.apache.drill.exec.rpc.security.ClientAuthenticatorProvider;
import org.apache.drill.exec.rpc.security.SaslProperties;
import org.apache.drill.exec.rpc.security.plain.PlainFactory;
import org.apache.drill.exec.ssl.SSLConfig;
import org.apache.drill.exec.ssl.SSLConfigBuilder;
import org.apache.hadoop.security.UserGroupInformation;
import org.slf4j.Logger;
import javax.net.ssl.SSLEngine;
import javax.security.sasl.SaslException;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
public class UserClient
extends BasicClient<RpcType, UserClient.UserToBitConnection, UserToBitHandshake, BitToUserHandshake> {
private static final Logger logger = org.slf4j.LoggerFactory.getLogger(UserClient.class);
private final BufferAllocator allocator;
private final QueryResultHandler queryResultHandler = new QueryResultHandler();
private final String clientName;
private final boolean supportComplexTypes;
private RpcEndpointInfos serverInfos = null;
private Set<RpcType> supportedMethods = null;
private SSLConfig sslConfig;
private DrillbitEndpoint endpoint;
private DrillProperties properties;
public UserClient(String clientName, DrillConfig config, Properties properties, boolean supportComplexTypes,
BufferAllocator allocator, EventLoopGroup eventLoopGroup, Executor eventExecutor,
DrillbitEndpoint endpoint) throws NonTransientRpcException {
super(UserRpcConfig.getMapping(config, eventExecutor), allocator.getAsByteBufAllocator(),
eventLoopGroup, RpcType.HANDSHAKE, BitToUserHandshake.class, BitToUserHandshake.PARSER);
this.endpoint = endpoint; // save the endpoint; it might be needed by SSL init.
this.clientName = clientName;
this.allocator = allocator;
this.supportComplexTypes = supportComplexTypes;
try {
this.sslConfig = new SSLConfigBuilder().properties(properties).mode(SSLConfig.Mode.CLIENT)
.initializeSSLContext(true).validateKeyStore(false).build();
} catch (DrillException e) {
throw new InvalidConnectionInfoException(e.getMessage());
}
// Keep a copy of properties in UserClient
this.properties = DrillProperties.createFromProperties(properties);
}
@Override protected void setupSSL(ChannelPipeline pipe,
ConnectionMultiListener.SSLHandshakeListener sslHandshakeListener) {
String peerHost = endpoint.getAddress();
int peerPort = endpoint.getUserPort();
SSLEngine sslEngine = sslConfig.createSSLEngine(allocator, peerHost, peerPort);
// Add SSL handler into pipeline
SslHandler sslHandler = new SslHandler(sslEngine);
sslHandler.setHandshakeTimeoutMillis(sslConfig.getHandshakeTimeout());
// Add a listener for SSL Handshake complete. The Drill client handshake will be enabled only
// after this is done.
sslHandler.handshakeFuture().addListener(sslHandshakeListener);
pipe.addFirst(RpcConstants.SSL_HANDLER, sslHandler);
logger.debug(sslConfig.toString());
}
@Override protected boolean isSslEnabled() {
return sslConfig.isUserSslEnabled();
}
public RpcEndpointInfos getServerInfos() {
return serverInfos;
}
public Set<RpcType> getSupportedMethods() {
return supportedMethods;
}
public void submitQuery(UserResultsListener resultsListener, RunQuery query) {
send(queryResultHandler.getWrappedListener(resultsListener), RpcType.RUN_QUERY, query, QueryId.class);
}
/**
* Connects, and if required, authenticates. This method blocks until both operations are complete.
*
* @param endpoint endpoint to connect to
* @param properties properties
* @param credentials credentials
* @throws RpcException if either connection or authentication fails
*/
public void connect(final DrillbitEndpoint endpoint, final DrillProperties properties,
final UserCredentials credentials) throws RpcException {
final UserToBitHandshake.Builder hsBuilder =
UserToBitHandshake.newBuilder()
.setRpcVersion(UserRpcConfig.RPC_VERSION)
.setSupportListening(true)
.setSupportComplexTypes(supportComplexTypes)
.setSupportTimeout(true).setCredentials(credentials)
.setClientInfos(UserRpcUtils.getRpcEndpointInfos(clientName))
.setSaslSupport(SaslSupport.SASL_PRIVACY)
.setProperties(properties.serializeForServer());
// Only used for testing purpose
if (properties.containsKey(DrillProperties.TEST_SASL_LEVEL)) {
hsBuilder.setSaslSupport(
SaslSupport.valueOf(Integer.parseInt(properties.getProperty(DrillProperties.TEST_SASL_LEVEL))));
}
try {
if (sslConfig.isUserSslEnabled()) {
connect(hsBuilder.build(), endpoint).checkedGet(sslConfig.getHandshakeTimeout(), TimeUnit.MILLISECONDS);
} else {
connect(hsBuilder.build(), endpoint).checkedGet();
}
} // Treat all authentication related exception as NonTransientException, since in those cases retry by client
// should not happen
catch(TimeoutException e) {
String msg = new StringBuilder().append("Connecting to the server timed out. This is sometimes due to a " +
"mismatch in the SSL configuration between" + " client and server. [ Exception: ").append(e.getMessage())
.append("]").toString();
throw new NonTransientRpcException(msg);
} catch (SaslException e) {
throw new NonTransientRpcException(e);
} catch (RpcException e) {
throw e;
} catch (Exception e) {
throw new RpcException(e);
}
}
/**
* Validate that security requirements from client and Drillbit side is compatible. For example: It verifies if one
* side needs authentication / encryption then other side is also configured to support that security properties.
* @param properties - DrillClient connection parameters
* @param serverAuthMechs - list of auth mechanisms supported by server
* @throws NonTransientRpcException - When DrillClient security requirements doesn't match Drillbit side of security
* configurations.
*/
private void validateSaslCompatibility(DrillProperties properties, List<String> serverAuthMechs)
throws NonTransientRpcException {
final boolean clientNeedsEncryption = properties.containsKey(DrillProperties.SASL_ENCRYPT)
&& Boolean.parseBoolean(properties.getProperty(DrillProperties.SASL_ENCRYPT));
final boolean serverAuthConfigured = (serverAuthMechs != null);
// Check if client needs encryption and server is not configured for encryption.
if (clientNeedsEncryption && !connection.isEncryptionEnabled()) {
throw new NonTransientRpcException(
"Client needs encrypted connection but server is not configured for encryption." +
" Please contact your administrator. [Warn: It may be due to wrong config or a security attack in" +
" progress.]");
}
// Check if client needs encryption and server doesn't support any security mechanisms.
if (clientNeedsEncryption && !serverAuthConfigured) {
throw new NonTransientRpcException(
"Client needs encrypted connection but server doesn't support any security mechanisms." +
" Please contact your administrator. [Warn: It may be due to wrong config or a security attack in" +
" progress.]");
}
// Check if client needs authentication and server doesn't support any security mechanisms.
if (clientNeedsAuthExceptPlain(properties) && !serverAuthConfigured) {
throw new NonTransientRpcException(
"Client needs authentication but server doesn't support any security mechanisms." +
" Please contact your administrator. [Warn: It may be due to wrong config or a security attack in" +
" progress.]");
}
}
/**
* Determine if client needs authenticated connection or not. It checks for following as an indication of
* requiring authentication from client side:
* 1) Auth mechanism except PLAIN is provided in connection properties with DrillProperties.AUTH_MECHANISM parameter.
* 2) A service principal is provided in connection properties in which case we treat AUTH_MECHANISM as Kerberos
* @param props - DrillClient connection parameters
* @return - true - If any of above 2 checks is successful.
* - false - If all the above 2 checks fails.
*/
private boolean clientNeedsAuthExceptPlain(DrillProperties props) {
boolean clientNeedsAuth = false;
final String authMechanism = props.getProperty(DrillProperties.AUTH_MECHANISM);
if (!Strings.isNullOrEmpty(authMechanism) && !authMechanism.equalsIgnoreCase(PlainFactory.SIMPLE_NAME)) {
clientNeedsAuth = true;
}
clientNeedsAuth |= !Strings.isNullOrEmpty(props.getProperty(DrillProperties.SERVICE_PRINCIPAL));
return clientNeedsAuth;
}
private CheckedFuture<Void, IOException> connect(final UserToBitHandshake handshake,
final DrillbitEndpoint endpoint) {
final SettableFuture<Void> connectionSettable = SettableFuture.create();
final CheckedFuture<Void, IOException> connectionFuture =
new AbstractCheckedFuture<Void, IOException>(connectionSettable) {
@Override protected IOException mapException(Exception e) {
if (e instanceof SaslException) {
return (SaslException)e;
} else if (e instanceof ExecutionException) {
final Throwable cause = Throwables.getRootCause(e);
if (cause instanceof SaslException) {
return (SaslException)cause;
}
}
return RpcException.mapException(e);
}
};
final RpcConnectionHandler<UserToBitConnection> connectionHandler =
new RpcConnectionHandler<UserToBitConnection>() {
@Override public void connectionSucceeded(UserToBitConnection connection) {
connectionSettable.set(null);
}
@Override public void connectionFailed(FailureType type, Throwable t) {
// Don't wrap NonTransientRpcException inside RpcException, since called should not retry to connect in
// this case
if (t instanceof NonTransientRpcException || t instanceof SaslException) {
connectionSettable.setException(t);
} else if (t instanceof RpcException) {
final Throwable cause = t.getCause();
if (cause instanceof SaslException) {
connectionSettable.setException(cause);
return;
}
connectionSettable.setException(t);
} else {
connectionSettable.setException(
new RpcException(String.format("%s : %s", type.name(), t.getMessage()), t));
}
}
};
connectAsClient(queryResultHandler.getWrappedConnectionHandler(connectionHandler), handshake,
endpoint.getAddress(), endpoint.getUserPort());
return connectionFuture;
}
/**
* Get's the authenticator factory for the mechanism required by client if it's supported on the server side too.
* Otherwise it throws {@link SaslException}
* @param properties - client connection properties
* @param serverAuthMechanisms - list of authentication mechanisms supported by server
* @return - {@link AuthenticatorFactory} for the mechanism required by client for authentication
* @throws SaslException - In case of failure
*/
private AuthenticatorFactory getAuthenticatorFactory(final DrillProperties properties,
List<String> serverAuthMechanisms) throws SaslException {
final Set<String> mechanismSet = AuthStringUtil.asSet(serverAuthMechanisms);
// first, check if a certain mechanism must be used
String authMechanism = properties.getProperty(DrillProperties.AUTH_MECHANISM);
if (authMechanism != null) {
if (!ClientAuthenticatorProvider.getInstance().containsFactory(authMechanism)) {
throw new SaslException(String.format("Unknown mechanism: %s", authMechanism));
}
if (!mechanismSet.contains(authMechanism.toUpperCase())) {
throw new SaslException(String
.format("Server does not support authentication using: %s. [Details: %s]", authMechanism,
connection.getEncryptionCtxtString()));
}
return ClientAuthenticatorProvider.getInstance().getAuthenticatorFactory(authMechanism);
}
// check if Kerberos is supported, and the service principal is provided
if (mechanismSet.contains(KerberosUtil.KERBEROS_SIMPLE_NAME) && properties
.containsKey(DrillProperties.SERVICE_PRINCIPAL)) {
return ClientAuthenticatorProvider.getInstance()
.getAuthenticatorFactory(KerberosUtil.KERBEROS_SIMPLE_NAME);
}
// check if username/password is supported, and username/password are provided
if (mechanismSet.contains(PlainFactory.SIMPLE_NAME) && properties.containsKey(DrillProperties.USER)
&& !Strings.isNullOrEmpty(properties.getProperty(DrillProperties.PASSWORD))) {
return ClientAuthenticatorProvider.getInstance().getAuthenticatorFactory(PlainFactory.SIMPLE_NAME);
}
throw new SaslException(String
.format("Server requires authentication using %s. Insufficient credentials?. " + "[Details: %s]. ",
mechanismSet, connection.getEncryptionCtxtString()));
}
protected <SEND extends MessageLite, RECEIVE extends MessageLite> void send(
RpcOutcomeListener<RECEIVE> listener, RpcType rpcType, SEND protobufBody, Class<RECEIVE> clazz,
boolean allowInEventLoop, ByteBuf... dataBodies) {
super.send(listener, connection, rpcType, protobufBody, clazz, allowInEventLoop, dataBodies);
}
@Override protected MessageLite getResponseDefaultInstance(int rpcType) throws RpcException {
switch (rpcType) {
case RpcType.ACK_VALUE:
return Ack.getDefaultInstance();
case RpcType.HANDSHAKE_VALUE:
return BitToUserHandshake.getDefaultInstance();
case RpcType.QUERY_HANDLE_VALUE:
return QueryId.getDefaultInstance();
case RpcType.QUERY_RESULT_VALUE:
return QueryResult.getDefaultInstance();
case RpcType.QUERY_DATA_VALUE:
return QueryData.getDefaultInstance();
case RpcType.QUERY_PLAN_FRAGMENTS_VALUE:
return QueryPlanFragments.getDefaultInstance();
case RpcType.CATALOGS_VALUE:
return GetCatalogsResp.getDefaultInstance();
case RpcType.SCHEMAS_VALUE:
return GetSchemasResp.getDefaultInstance();
case RpcType.TABLES_VALUE:
return GetTablesResp.getDefaultInstance();
case RpcType.COLUMNS_VALUE:
return GetColumnsResp.getDefaultInstance();
case RpcType.PREPARED_STATEMENT_VALUE:
return CreatePreparedStatementResp.getDefaultInstance();
case RpcType.SASL_MESSAGE_VALUE:
return SaslMessage.getDefaultInstance();
case RpcType.SERVER_META_VALUE:
return GetServerMetaResp.getDefaultInstance();
}
throw new RpcException(String.format("Unable to deal with RpcType of %d", rpcType));
}
@Override protected void handle(UserToBitConnection connection, int rpcType, ByteBuf pBody, ByteBuf dBody,
ResponseSender sender) throws RpcException {
if (!isAuthComplete()) {
// Remote should not be making any requests before authenticating, drop connection
throw new RpcException(String.format("Request of type %d is not allowed without authentication. "
+ "Remote on %s must authenticate before making requests. Connection dropped.", rpcType,
connection.getRemoteAddress()));
}
switch (rpcType) {
case RpcType.QUERY_DATA_VALUE:
queryResultHandler.batchArrived(connection, pBody, dBody);
sender.send(new Response(RpcType.ACK, Acks.OK));
break;
case RpcType.QUERY_RESULT_VALUE:
queryResultHandler.resultArrived(pBody);
sender.send(new Response(RpcType.ACK, Acks.OK));
break;
default:
throw new RpcException(String.format("Unknown Rpc Type %d. ", rpcType));
}
}
@Override
protected void prepareSaslHandshake(final RpcConnectionHandler<UserToBitConnection> connectionHandler,
List<String> serverAuthMechanisms) {
try {
final Map<String, String> saslProperties = properties.stringPropertiesAsMap();
// Set correct QOP property and Strength based on server needs encryption or not.
// If ChunkMode is enabled then negotiate for buffer size equal to wrapChunkSize,
// If ChunkMode is disabled then negotiate for MAX_WRAPPED_SIZE buffer size.
saslProperties.putAll(
SaslProperties.getSaslProperties(connection.isEncryptionEnabled(), connection.getMaxWrappedSize()));
final AuthenticatorFactory factory = getAuthenticatorFactory(properties, serverAuthMechanisms);
final String mechanismName = factory.getSimpleName();
logger.trace("Will try to authenticate to server using {} mechanism with encryption context {}",
mechanismName, connection.getEncryptionCtxtString());
// Update the thread context class loader to current class loader
// See DRILL-6063 for detailed description
final ClassLoader oldThreadCtxtCL = Thread.currentThread().getContextClassLoader();
final ClassLoader newThreadCtxtCL = this.getClass().getClassLoader();
Thread.currentThread().setContextClassLoader(newThreadCtxtCL);
final UserGroupInformation ugi = factory.createAndLoginUser(saslProperties);
// Reset the thread context class loader to original one
Thread.currentThread().setContextClassLoader(oldThreadCtxtCL);
startSaslHandshake(connectionHandler, saslProperties, ugi, factory, RpcType.SASL_MESSAGE);
} catch (IOException e) {
logger.error("Failed while doing setup for starting SASL handshake for connection {}", connection.getName());
final Exception ex = new RpcException(String.format("Failed to initiate authentication for connection %s",
connection.getName()), e);
connectionHandler.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, ex);
}
}
@Override protected List<String> validateHandshake(BitToUserHandshake inbound) throws RpcException {
// logger.debug("Handling handshake from bit to user. {}", inbound);
List<String> serverAuthMechanisms = null;
if (inbound.hasServerInfos()) {
serverInfos = inbound.getServerInfos();
}
supportedMethods = Sets.immutableEnumSet(inbound.getSupportedMethodsList());
switch (inbound.getStatus()) {
case SUCCESS:
break;
case AUTH_REQUIRED:
serverAuthMechanisms = ImmutableList.copyOf(inbound.getAuthenticationMechanismsList());
setAuthComplete(false);
connection.setEncryption(inbound.hasEncrypted() && inbound.getEncrypted());
if (inbound.hasMaxWrappedSize()) {
connection.setMaxWrappedSize(inbound.getMaxWrappedSize());
}
logger.trace(String
.format("Server requires authentication with encryption context %s before proceeding.",
connection.getEncryptionCtxtString()));
break;
case AUTH_FAILED:
case RPC_VERSION_MISMATCH:
case UNKNOWN_FAILURE:
final String errMsg = String
.format("Status: %s, Error Id: %s, Error message: %s", inbound.getStatus(),
inbound.getErrorId(), inbound.getErrorMessage());
logger.error(errMsg);
throw new NonTransientRpcException(errMsg);
}
// Before starting SASL handshake validate if both client and server are compatible in their security
// requirements for the connection
validateSaslCompatibility(properties, serverAuthMechanisms);
return serverAuthMechanisms;
}
@Override protected UserToBitConnection initRemoteConnection(SocketChannel channel) {
super.initRemoteConnection(channel);
return new UserToBitConnection(channel);
}
public class UserToBitConnection extends AbstractClientConnection {
UserToBitConnection(SocketChannel channel) {
// by default connection is not set for encryption. After receiving handshake msg from server we set the
// isEncryptionEnabled, useChunkMode and chunkModeSize correctly.
super(channel, "user client");
}
@Override public BufferAllocator getAllocator() {
return allocator;
}
@Override protected Logger getLogger() {
return logger;
}
@Override public void incConnectionCounter() {
// no-op
}
@Override public void decConnectionCounter() {
// no-op
}
}
@Override public ProtobufLengthDecoder getDecoder(BufferAllocator allocator) {
return new UserProtobufLengthDecoder(allocator, OutOfMemoryHandler.DEFAULT_INSTANCE);
}
/**
* planQuery is an API to plan a query without query execution
*
* @param req - data necessary to plan query
* @return list of PlanFragments that can later on be submitted for execution
*/
public DrillRpcFuture<QueryPlanFragments> planQuery(GetQueryPlanFragments req) {
return send(RpcType.GET_QUERY_PLAN_FRAGMENTS, req, QueryPlanFragments.class);
}
}