DRILL-6187: Exception in RPC communication between DataClient/ControlClient and respective servers when bit-to-bit security is on
This closes #1145
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/client/DrillClient.java b/exec/java-exec/src/main/java/org/apache/drill/exec/client/DrillClient.java
index 71acfb1..ec01ff3 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/client/DrillClient.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/client/DrillClient.java
@@ -217,9 +217,14 @@
* @throws RpcException
*/
public void connect() throws RpcException {
- connect(null, null);
+ connect(null, new Properties());
}
+ /**
+ * Start's a connection from client to server
+ * @param props - not null {@link Properties} filled with connection url parameters
+ * @throws RpcException
+ */
public void connect(Properties props) throws RpcException {
connect(null, props);
}
@@ -308,12 +313,18 @@
return endpointList;
}
+ /**
+ * Start's a connection from client to server
+ * @param connect - Zookeeper connection string provided at connection URL
+ * @param props - not null {@link Properties} filled with connection url parameters
+ * @throws RpcException
+ */
public synchronized void connect(String connect, Properties props) throws RpcException {
if (connected) {
return;
}
- properties = DrillProperties.createFromProperties(props);
+ properties = DrillProperties.createFromProperties(props);
final List<DrillbitEndpoint> endpoints = new ArrayList<>();
if (isDirectConnection) {
@@ -371,6 +382,15 @@
while (triedEndpointIndex < connectTriesVal) {
endpoint = endpoints.get(triedEndpointIndex);
+
+ // Set in both props and properties since props is passed to UserClient
+ // TODO: Logically here it's doing putIfAbsent, please change to use that api once JDK 8 is minimum required
+ // version
+ if (!properties.containsKey(DrillProperties.SERVICE_HOST)) {
+ properties.setProperty(DrillProperties.SERVICE_HOST, endpoint.getAddress());
+ props.setProperty(DrillProperties.SERVICE_HOST, endpoint.getAddress());
+ }
+
// Note: the properties member is a DrillProperties instance which lower cases names of
// properties. That does not work too well with properties that are mixed case.
// For user client severla properties are mixed case so we do not use the properties member
@@ -378,10 +398,6 @@
client = new UserClient(clientName, config, props, supportComplexTypes, allocator, eventLoopGroup, executor, endpoint);
logger.debug("Connecting to server {}:{}", endpoint.getAddress(), endpoint.getUserPort());
- if (!properties.containsKey(DrillProperties.SERVICE_HOST)) {
- properties.setProperty(DrillProperties.SERVICE_HOST, endpoint.getAddress());
- }
-
try {
connect(endpoint);
connected = true;
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/BitRpcUtility.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/BitRpcUtility.java
new file mode 100644
index 0000000..c71363d
--- /dev/null
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/BitRpcUtility.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.rpc;
+
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.Internal.EnumLite;
+import com.google.protobuf.MessageLite;
+import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint;
+import org.apache.drill.exec.rpc.security.AuthenticatorFactory;
+import org.apache.drill.exec.rpc.security.SaslProperties;
+import org.apache.hadoop.security.UserGroupInformation;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Utility class providing common methods shared between {@link org.apache.drill.exec.rpc.data.DataClient} and
+ * {@link org.apache.drill.exec.rpc.control.ControlClient}
+ */
+public final class BitRpcUtility {
+ private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(BitRpcUtility.class);
+
+ /**
+ * Method to do validation on the handshake message received from server side. Only used by BitClients NOT UserClient.
+ * Verify if rpc version of handshake message matches the supported RpcVersion and also validates the
+ * security configuration between client and server
+ * @param handshakeRpcVersion - rpc version received in handshake message
+ * @param remoteAuthMechs - authentication mechanisms supported by server
+ * @param rpcVersion - supported rpc version on client
+ * @param connection - client connection
+ * @param config - client connectin config
+ * @param client - data client or control client
+ * @return - Immutable list of authentication mechanisms supported by server or null
+ * @throws RpcException - exception is thrown if rpc version or authentication configuration mismatch is found
+ */
+ public static List<String> validateHandshake(int handshakeRpcVersion, List<String> remoteAuthMechs, int rpcVersion,
+ ClientConnection connection, BitConnectionConfig config,
+ BasicClient client) throws RpcException {
+
+ if (handshakeRpcVersion != rpcVersion) {
+ throw new RpcException(String.format("Invalid rpc version. Expected %d, actual %d.",
+ handshakeRpcVersion, rpcVersion));
+ }
+
+ if (remoteAuthMechs.size() != 0) { // remote requires authentication
+ client.setAuthComplete(false);
+ return ImmutableList.copyOf(remoteAuthMechs);
+ } else {
+ if (config.getAuthMechanismToUse() != null) { // local requires authentication
+ throw new RpcException(String.format("Remote Drillbit does not require auth, but auth is enabled in " +
+ "local Drillbit configuration. [Details: connection: (%s) and LocalAuthMechanism: (%s). Please check " +
+ "security configuration for bit-to-bit.", connection.getName(), config.getAuthMechanismToUse()));
+ }
+ }
+ return null;
+ }
+
+ /**
+ * Creates various instances needed to start the SASL handshake. This is called from
+ * {@link BasicClient#prepareSaslHandshake(RpcConnectionHandler, List)} only for
+ * {@link org.apache.drill.exec.rpc.data.DataClient} and {@link org.apache.drill.exec.rpc.control.ControlClient}
+ *
+ * @param connectionHandler - Connection handler used by client's to know about success/failure conditions.
+ * @param serverAuthMechanisms - List of auth mechanisms configured on server side
+ * @param connection - ClientConnection used for authentication
+ * @param config - ClientConnection config
+ * @param endpoint - Remote DrillbitEndpoint
+ * @param client - Either of DataClient/ControlClient instance
+ * @param saslRpcType - SASL_MESSAGE RpcType for Data and Control channel
+ */
+ public static <T extends EnumLite, CC extends ClientConnection, HS extends MessageLite, HR extends MessageLite>
+ void prepareSaslHandshake(final RpcConnectionHandler<CC> connectionHandler, List<String> serverAuthMechanisms,
+ CC connection, BitConnectionConfig config, DrillbitEndpoint endpoint,
+ final BasicClient<T, CC, HS, HR> client, T saslRpcType) {
+ try {
+ final Map<String, String> saslProperties = SaslProperties.getSaslProperties(connection.isEncryptionEnabled(),
+ connection.getMaxWrappedSize());
+ final UserGroupInformation ugi = UserGroupInformation.getLoginUser();
+ final AuthenticatorFactory factory = config.getAuthFactory(serverAuthMechanisms);
+ client.startSaslHandshake(connectionHandler, config.getSaslClientProperties(endpoint, saslProperties),
+ ugi, factory, saslRpcType);
+ } catch (final IOException e) {
+ logger.error("Failed while doing setup for starting sasl handshake for connection", connection.getName());
+ final Exception ex = new RpcException(String.format("Failed to initiate authentication to %s",
+ endpoint.getAddress()), e);
+ connectionHandler.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, ex);
+ }
+ }
+
+ // Suppress default constructor
+ private BitRpcUtility() {
+ }
+}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlClient.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlClient.java
index 1e0313a..1df5ff1 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlClient.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlClient.java
@@ -17,36 +17,26 @@
*/
package org.apache.drill.exec.rpc.control;
-import com.google.common.util.concurrent.SettableFuture;
+import com.google.common.collect.ImmutableList;
import com.google.protobuf.MessageLite;
-
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.socket.SocketChannel;
import io.netty.util.concurrent.GenericFutureListener;
-
import org.apache.drill.exec.memory.BufferAllocator;
import org.apache.drill.exec.proto.BitControl.BitControlHandshake;
import org.apache.drill.exec.proto.BitControl.RpcType;
import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint;
import org.apache.drill.exec.rpc.BasicClient;
-import org.apache.drill.exec.rpc.security.AuthenticationOutcomeListener;
+import org.apache.drill.exec.rpc.BitRpcUtility;
+import org.apache.drill.exec.rpc.FailingRequestHandler;
import org.apache.drill.exec.rpc.OutOfMemoryHandler;
import org.apache.drill.exec.rpc.ProtobufLengthDecoder;
import org.apache.drill.exec.rpc.ResponseSender;
-import org.apache.drill.exec.rpc.RpcCommand;
+import org.apache.drill.exec.rpc.RpcConnectionHandler;
import org.apache.drill.exec.rpc.RpcException;
-import org.apache.drill.exec.rpc.RpcOutcomeListener;
-import org.apache.drill.exec.rpc.FailingRequestHandler;
-import org.apache.drill.exec.rpc.security.SaslProperties;
-import org.apache.hadoop.security.UserGroupInformation;
-
-import javax.security.sasl.SaslClient;
-import javax.security.sasl.SaslException;
-import java.io.IOException;
-import java.util.Map;
-import java.util.concurrent.ExecutionException;
+import java.util.List;
public class ControlClient extends BasicClient<RpcType, ControlConnection, BitControlHandshake, BitControlHandshake> {
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ControlClient.class);
@@ -104,34 +94,16 @@
}
@Override
- protected void validateHandshake(BitControlHandshake handshake) throws RpcException {
- if (handshake.getRpcVersion() != ControlRpcConfig.RPC_VERSION) {
- throw new RpcException(String.format("Invalid rpc version. Expected %d, actual %d.",
- handshake.getRpcVersion(), ControlRpcConfig.RPC_VERSION));
- }
+ protected void prepareSaslHandshake(final RpcConnectionHandler<ControlConnection> connectionHandler,
+ List<String> serverAuthMechanisms) {
+ BitRpcUtility.prepareSaslHandshake(connectionHandler, serverAuthMechanisms, connection, config, remoteEndpoint,
+ this, RpcType.SASL_MESSAGE);
+ }
- if (handshake.getAuthenticationMechanismsCount() != 0) { // remote requires authentication
- final SaslClient saslClient;
- try {
- final Map<String, String> saslProperties = SaslProperties.getSaslProperties(connection.isEncryptionEnabled(),
- connection.getMaxWrappedSize());
-
- saslClient = config.getAuthFactory(handshake.getAuthenticationMechanismsList())
- .createSaslClient(UserGroupInformation.getLoginUser(),
- config.getSaslClientProperties(remoteEndpoint, saslProperties));
- } catch (final IOException e) {
- throw new RpcException(String.format("Failed to initiate authenticate to %s", remoteEndpoint.getAddress()), e);
- }
- if (saslClient == null) {
- throw new RpcException("Unexpected failure. Could not initiate SASL exchange.");
- }
- connection.setSaslClient(saslClient);
- } else {
- if (config.getAuthMechanismToUse() != null) { // local requires authentication
- throw new RpcException(String.format("Drillbit (%s) does not require auth, but auth is enabled.",
- remoteEndpoint.getAddress()));
- }
- }
+ @Override
+ protected List<String> validateHandshake(BitControlHandshake handshake) throws RpcException {
+ return BitRpcUtility.validateHandshake(handshake.getRpcVersion(), handshake.getAuthenticationMechanismsList(),
+ ControlRpcConfig.RPC_VERSION, connection, config, this);
}
@Override
@@ -140,86 +112,6 @@
}
@Override
- protected <M extends MessageLite> RpcCommand<M, ControlConnection>
- getInitialCommand(final RpcCommand<M, ControlConnection> command) {
- final RpcCommand<M, ControlConnection> initialCommand = super.getInitialCommand(command);
- if (config.getAuthMechanismToUse() == null) {
- return initialCommand;
- } else {
- return new AuthenticationCommand<>(initialCommand);
- }
- }
-
- private class AuthenticationCommand<M extends MessageLite> implements RpcCommand<M, ControlConnection> {
-
- private final RpcCommand<M, ControlConnection> command;
-
- AuthenticationCommand(RpcCommand<M, ControlConnection> command) {
- this.command = command;
- }
-
- @Override
- public void connectionAvailable(ControlConnection connection) {
- command.connectionFailed(FailureType.AUTHENTICATION, new SaslException("Should not reach here."));
- }
-
- @Override
- public void connectionSucceeded(final ControlConnection connection) {
- final UserGroupInformation loginUser;
- try {
- loginUser = UserGroupInformation.getLoginUser();
- } catch (final IOException e) {
- logger.debug("Unexpected failure trying to login.", e);
- command.connectionFailed(FailureType.AUTHENTICATION, e);
- return;
- }
-
- final SettableFuture<Void> future = SettableFuture.create();
- new AuthenticationOutcomeListener<>(ControlClient.this, connection, RpcType.SASL_MESSAGE,
- loginUser,
- new RpcOutcomeListener<Void>() {
- @Override
- public void failed(RpcException ex) {
- logger.debug("Authentication failed.", ex);
- future.setException(ex);
- }
-
- @Override
- public void success(Void value, ByteBuf buffer) {
- connection.changeHandlerTo(config.getMessageHandler());
- future.set(null);
- }
-
- @Override
- public void interrupted(InterruptedException e) {
- logger.debug("Authentication failed.", e);
- future.setException(e);
- }
- }).initiate(config.getAuthMechanismToUse());
-
-
- try {
- logger.trace("Waiting until authentication completes..");
- future.get();
- command.connectionSucceeded(connection);
- } catch (InterruptedException e) {
- command.connectionFailed(FailureType.AUTHENTICATION, e);
- // Preserve evidence that the interruption occurred so that code higher up on the call stack can learn of the
- // interruption and respond to it if it wants to.
- Thread.currentThread().interrupt();
- } catch (ExecutionException e) {
- command.connectionFailed(FailureType.AUTHENTICATION, e);
- }
- }
-
- @Override
- public void connectionFailed(FailureType type, Throwable t) {
- logger.debug("Authentication failed.", t);
- command.connectionFailed(FailureType.AUTHENTICATION, t);
- }
- }
-
- @Override
public ProtobufLengthDecoder getDecoder(BufferAllocator allocator) {
return new ControlProtobufLengthDecoder(allocator, OutOfMemoryHandler.DEFAULT_INSTANCE);
}
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlConnection.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlConnection.java
index 70189d7..c7d4d8e 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlConnection.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/control/ControlConnection.java
@@ -78,7 +78,7 @@
@Override
public boolean isActive() {
- return active;
+ return active && super.isActive();
}
@Override
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/data/DataClient.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/data/DataClient.java
index cba323e..267b483 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/data/DataClient.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/data/DataClient.java
@@ -17,7 +17,7 @@
*/
package org.apache.drill.exec.rpc.data;
-import com.google.common.util.concurrent.SettableFuture;
+import com.google.common.collect.ImmutableList;
import com.google.protobuf.MessageLite;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
@@ -29,21 +29,14 @@
import org.apache.drill.exec.proto.BitData.RpcType;
import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint;
import org.apache.drill.exec.rpc.BasicClient;
+import org.apache.drill.exec.rpc.BitRpcUtility;
import org.apache.drill.exec.rpc.OutOfMemoryHandler;
import org.apache.drill.exec.rpc.ProtobufLengthDecoder;
import org.apache.drill.exec.rpc.ResponseSender;
-import org.apache.drill.exec.rpc.RpcCommand;
+import org.apache.drill.exec.rpc.RpcConnectionHandler;
import org.apache.drill.exec.rpc.RpcException;
-import org.apache.drill.exec.rpc.RpcOutcomeListener;
-import org.apache.drill.exec.rpc.security.AuthenticationOutcomeListener;
-import org.apache.drill.exec.rpc.security.SaslProperties;
-import org.apache.hadoop.security.UserGroupInformation;
-import javax.security.sasl.SaslClient;
-import javax.security.sasl.SaslException;
-import java.io.IOException;
-import java.util.Map;
-import java.util.concurrent.ExecutionException;
+import java.util.List;
public class DataClient extends BasicClient<RpcType, DataClientConnection, BitClientHandshake, BitServerHandshake> {
private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DataClient.class);
@@ -103,114 +96,17 @@
}
@Override
- protected void validateHandshake(BitServerHandshake handshake) throws RpcException {
- if (handshake.getRpcVersion() != DataRpcConfig.RPC_VERSION) {
- throw new RpcException(String.format("Invalid rpc version. Expected %d, actual %d.",
- handshake.getRpcVersion(), DataRpcConfig.RPC_VERSION));
- }
-
- if (handshake.getAuthenticationMechanismsCount() != 0) { // remote requires authentication
- final SaslClient saslClient;
- try {
-
- final Map<String, String> saslProperties = SaslProperties.getSaslProperties(connection.isEncryptionEnabled(),
- connection.getMaxWrappedSize());
-
- saslClient = config.getAuthFactory(handshake.getAuthenticationMechanismsList())
- .createSaslClient(UserGroupInformation.getLoginUser(),
- config.getSaslClientProperties(remoteEndpoint, saslProperties));
- } catch (final IOException e) {
- throw new RpcException(String.format("Failed to initiate authenticate to %s", remoteEndpoint.getAddress()), e);
- }
- if (saslClient == null) {
- throw new RpcException("Unexpected failure. Could not initiate SASL exchange.");
- }
- connection.setSaslClient(saslClient);
- } else {
- if (config.getAuthMechanismToUse() != null) {
- throw new RpcException(String.format("Drillbit (%s) does not require auth, but auth is enabled.",
- remoteEndpoint.getAddress()));
- }
- }
+ protected void prepareSaslHandshake(final RpcConnectionHandler<DataClientConnection> connectionHandler, List<String> serverAuthMechanisms) {
+ BitRpcUtility.prepareSaslHandshake(connectionHandler, serverAuthMechanisms, connection, config, remoteEndpoint,
+ this, RpcType.SASL_MESSAGE);
}
- protected <M extends MessageLite> RpcCommand<M, DataClientConnection>
- getInitialCommand(final RpcCommand<M, DataClientConnection> command) {
- final RpcCommand<M, DataClientConnection> initialCommand = super.getInitialCommand(command);
- if (config.getAuthMechanismToUse() == null) {
- return initialCommand;
- } else {
- return new AuthenticationCommand<>(initialCommand);
- }
- }
-
- private class AuthenticationCommand<M extends MessageLite> implements RpcCommand<M, DataClientConnection> {
-
- private final RpcCommand<M, DataClientConnection> command;
-
- AuthenticationCommand(RpcCommand<M, DataClientConnection> command) {
- this.command = command;
- }
-
@Override
- public void connectionAvailable(DataClientConnection connection) {
- command.connectionFailed(FailureType.AUTHENTICATION, new SaslException("Should not reach here."));
+ protected List<String> validateHandshake(BitServerHandshake handshake) throws RpcException {
+ return BitRpcUtility.validateHandshake(handshake.getRpcVersion(), handshake.getAuthenticationMechanismsList(),
+ DataRpcConfig.RPC_VERSION, connection, config, this);
}
- @Override
- public void connectionSucceeded(final DataClientConnection connection) {
- final UserGroupInformation loginUser;
- try {
- loginUser = UserGroupInformation.getLoginUser();
- } catch (final IOException e) {
- logger.debug("Unexpected failure trying to login.", e);
- command.connectionFailed(FailureType.AUTHENTICATION, e);
- return;
- }
-
- final SettableFuture<Void> future = SettableFuture.create();
- new AuthenticationOutcomeListener<>(DataClient.this, connection, RpcType.SASL_MESSAGE,
- loginUser,
- new RpcOutcomeListener<Void>() {
- @Override
- public void failed(RpcException ex) {
- logger.debug("Authentication failed.", ex);
- future.setException(ex);
- }
-
- @Override
- public void success(Void value, ByteBuf buffer) {
- future.set(null);
- }
-
- @Override
- public void interrupted(InterruptedException e) {
- logger.debug("Authentication failed.", e);
- future.setException(e);
- }
- }).initiate(config.getAuthMechanismToUse());
-
- try {
- logger.trace("Waiting until authentication completes..");
- future.get();
- command.connectionSucceeded(connection);
- } catch (InterruptedException e) {
- command.connectionFailed(FailureType.AUTHENTICATION, e);
- // Preserve evidence that the interruption occurred so that code higher up on the call stack can learn of the
- // interruption and respond to it if it wants to.
- Thread.currentThread().interrupt();
- } catch (ExecutionException e) {
- command.connectionFailed(FailureType.AUTHENTICATION, e);
- }
- }
-
- @Override
- public void connectionFailed(FailureType type, Throwable t) {
- logger.debug("Authentication failed.", t);
- command.connectionFailed(FailureType.AUTHENTICATION, t);
- }
- }
-
@Override
public ProtobufLengthDecoder getDecoder(BufferAllocator allocator) {
return new DataProtobufLengthDecoder.Client(allocator, OutOfMemoryHandler.DEFAULT_INSTANCE);
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java
index 131febf..1504ce9 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/user/UserClient.java
@@ -17,28 +17,24 @@
*/
package org.apache.drill.exec.rpc.user;
-import java.io.IOException;
-import java.util.List;
-import java.util.Map;
-import java.util.Properties;
-import java.util.Set;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executor;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-
-import javax.net.ssl.SSLEngine;
-import javax.security.sasl.SaslClient;
-import javax.security.sasl.SaslException;
-
+import com.google.common.base.Strings;
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+import com.google.common.util.concurrent.AbstractCheckedFuture;
+import com.google.common.util.concurrent.CheckedFuture;
+import com.google.common.util.concurrent.SettableFuture;
+import com.google.protobuf.MessageLite;
+import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelPipeline;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import org.apache.drill.common.KerberosUtil;
import org.apache.drill.common.config.DrillConfig;
import org.apache.drill.common.config.DrillProperties;
import org.apache.drill.common.exceptions.DrillException;
import org.apache.drill.exec.client.InvalidConnectionInfoException;
-import org.apache.drill.exec.ssl.SSLConfig;
import org.apache.drill.exec.memory.BufferAllocator;
import org.apache.drill.exec.proto.CoordinationProtos.DrillbitEndpoint;
import org.apache.drill.exec.proto.GeneralRPCProtos.Ack;
@@ -76,28 +72,26 @@
import org.apache.drill.exec.rpc.RpcException;
import org.apache.drill.exec.rpc.RpcOutcomeListener;
import org.apache.drill.exec.rpc.security.AuthStringUtil;
-import org.apache.drill.exec.rpc.security.AuthenticationOutcomeListener;
import org.apache.drill.exec.rpc.security.AuthenticatorFactory;
import org.apache.drill.exec.rpc.security.ClientAuthenticatorProvider;
-import org.apache.drill.exec.rpc.security.plain.PlainFactory;
import org.apache.drill.exec.rpc.security.SaslProperties;
+import org.apache.drill.exec.rpc.security.plain.PlainFactory;
+import org.apache.drill.exec.ssl.SSLConfig;
import org.apache.drill.exec.ssl.SSLConfigBuilder;
import org.apache.hadoop.security.UserGroupInformation;
import org.slf4j.Logger;
-import com.google.common.base.Strings;
-import com.google.common.base.Throwables;
-import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Sets;
-import com.google.common.util.concurrent.AbstractCheckedFuture;
-import com.google.common.util.concurrent.CheckedFuture;
-import com.google.common.util.concurrent.SettableFuture;
-import com.google.protobuf.MessageLite;
-
-
-import io.netty.buffer.ByteBuf;
-import io.netty.channel.EventLoopGroup;
-import io.netty.channel.socket.SocketChannel;
+import javax.net.ssl.SSLEngine;
+import javax.security.sasl.SaslException;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
public class UserClient
extends BasicClient<RpcType, UserClient.UserToBitConnection, UserToBitHandshake, BitToUserHandshake> {
@@ -111,12 +105,11 @@
private RpcEndpointInfos serverInfos = null;
private Set<RpcType> supportedMethods = null;
- // these are used for authentication
- private volatile List<String> serverAuthMechanisms = null;
- private volatile boolean authComplete = true;
private SSLConfig sslConfig;
private DrillbitEndpoint endpoint;
+ private DrillProperties properties;
+
public UserClient(String clientName, DrillConfig config, Properties properties, boolean supportComplexTypes,
BufferAllocator allocator, EventLoopGroup eventLoopGroup, Executor eventExecutor,
DrillbitEndpoint endpoint) throws NonTransientRpcException {
@@ -133,6 +126,8 @@
throw new InvalidConnectionInfoException(e.getMessage());
}
+ // Keep a copy of properties in UserClient
+ this.properties = DrillProperties.createFromProperties(properties);
}
@Override protected void setupSSL(ChannelPipeline pipe,
@@ -195,30 +190,25 @@
SaslSupport.valueOf(Integer.parseInt(properties.getProperty(DrillProperties.TEST_SASL_LEVEL))));
}
- if (sslConfig.isUserSslEnabled()) {
- try {
- connect(hsBuilder.build(), endpoint)
- .checkedGet(sslConfig.getHandshakeTimeout(), TimeUnit.MILLISECONDS);
- } catch (TimeoutException e) {
- String msg = new StringBuilder().append(
- "Connecting to the server timed out. This is sometimes due to a mismatch in the SSL configuration between" +
- " client and server. [ Exception: ")
- .append(e.getMessage()).append("]").toString();
- throw new NonTransientRpcException(msg);
+ try {
+ if (sslConfig.isUserSslEnabled()) {
+ connect(hsBuilder.build(), endpoint).checkedGet(sslConfig.getHandshakeTimeout(), TimeUnit.MILLISECONDS);
+ } else {
+ connect(hsBuilder.build(), endpoint).checkedGet();
}
- } else {
- connect(hsBuilder.build(), endpoint).checkedGet();
- }
-
- // Validate if both client and server are compatible in their security requirements for the connection
- validateSaslCompatibility(properties);
-
- if (serverAuthMechanisms != null) {
- try {
- authenticate(properties).checkedGet();
- } catch (final SaslException e) {
- throw new NonTransientRpcException(e);
- }
+ } // Treat all authentication related exception as NonTransientException, since in those cases retry by client
+ // should not happen
+ catch(TimeoutException e) {
+ String msg = new StringBuilder().append("Connecting to the server timed out. This is sometimes due to a " +
+ "mismatch in the SSL configuration between" + " client and server. [ Exception: ").append(e.getMessage())
+ .append("]").toString();
+ throw new NonTransientRpcException(msg);
+ } catch (SaslException e) {
+ throw new NonTransientRpcException(e);
+ } catch (RpcException e) {
+ throw e;
+ } catch (Exception e) {
+ throw new RpcException(e);
}
}
@@ -226,14 +216,18 @@
* Validate that security requirements from client and Drillbit side is compatible. For example: It verifies if one
* side needs authentication / encryption then other side is also configured to support that security properties.
* @param properties - DrillClient connection parameters
+ * @param serverAuthMechs - list of auth mechanisms supported by server
* @throws NonTransientRpcException - When DrillClient security requirements doesn't match Drillbit side of security
* configurations.
*/
- private void validateSaslCompatibility(DrillProperties properties) throws NonTransientRpcException {
+ private void validateSaslCompatibility(DrillProperties properties, List<String> serverAuthMechs)
+ throws NonTransientRpcException {
final boolean clientNeedsEncryption = properties.containsKey(DrillProperties.SASL_ENCRYPT)
&& Boolean.parseBoolean(properties.getProperty(DrillProperties.SASL_ENCRYPT));
+ final boolean serverAuthConfigured = (serverAuthMechs != null);
+
// Check if client needs encryption and server is not configured for encryption.
if (clientNeedsEncryption && !connection.isEncryptionEnabled()) {
throw new NonTransientRpcException(
@@ -243,7 +237,7 @@
}
// Check if client needs encryption and server doesn't support any security mechanisms.
- if (clientNeedsEncryption && serverAuthMechanisms == null) {
+ if (clientNeedsEncryption && !serverAuthConfigured) {
throw new NonTransientRpcException(
"Client needs encrypted connection but server doesn't support any security mechanisms." +
" Please contact your administrator. [Warn: It may be due to wrong config or a security attack in" +
@@ -251,7 +245,7 @@
}
// Check if client needs authentication and server doesn't support any security mechanisms.
- if (clientNeedsAuthExceptPlain(properties) && serverAuthMechanisms == null) {
+ if (clientNeedsAuthExceptPlain(properties) && !serverAuthConfigured) {
throw new NonTransientRpcException(
"Client needs authentication but server doesn't support any security mechanisms." +
" Please contact your administrator. [Warn: It may be due to wrong config or a security attack in" +
@@ -280,15 +274,24 @@
return clientNeedsAuth;
}
- private CheckedFuture<Void, RpcException> connect(final UserToBitHandshake handshake,
+ private CheckedFuture<Void, IOException> connect(final UserToBitHandshake handshake,
final DrillbitEndpoint endpoint) {
final SettableFuture<Void> connectionSettable = SettableFuture.create();
- final CheckedFuture<Void, RpcException> connectionFuture =
- new AbstractCheckedFuture<Void, RpcException>(connectionSettable) {
- @Override protected RpcException mapException(Exception e) {
+ final CheckedFuture<Void, IOException> connectionFuture =
+ new AbstractCheckedFuture<Void, IOException>(connectionSettable) {
+ @Override protected IOException mapException(Exception e) {
+ if (e instanceof SaslException) {
+ return (SaslException)e;
+ } else if (e instanceof ExecutionException) {
+ final Throwable cause = Throwables.getRootCause(e);
+ if (cause instanceof SaslException) {
+ return (SaslException)cause;
+ }
+ }
return RpcException.mapException(e);
}
};
+
final RpcConnectionHandler<UserToBitConnection> connectionHandler =
new RpcConnectionHandler<UserToBitConnection>() {
@Override public void connectionSucceeded(UserToBitConnection connection) {
@@ -296,8 +299,21 @@
}
@Override public void connectionFailed(FailureType type, Throwable t) {
- connectionSettable
- .setException(new RpcException(String.format("%s : %s", type.name(), t.getMessage()), t));
+ // Don't wrap NonTransientRpcException inside RpcException, since called should not retry to connect in
+ // this case
+ if (t instanceof NonTransientRpcException || t instanceof SaslException) {
+ connectionSettable.setException(t);
+ } else if (t instanceof RpcException) {
+ final Throwable cause = t.getCause();
+ if (cause instanceof SaslException) {
+ connectionSettable.setException(cause);
+ return;
+ }
+ connectionSettable.setException(t);
+ } else {
+ connectionSettable.setException(
+ new RpcException(String.format("%s : %s", type.name(), t.getMessage()), t));
+ }
}
};
@@ -307,89 +323,16 @@
return connectionFuture;
}
- private CheckedFuture<Void, SaslException> authenticate(final DrillProperties properties) {
- final Map<String, String> propertiesMap = properties.stringPropertiesAsMap();
-
- // Set correct QOP property and Strength based on server needs encryption or not.
- // If ChunkMode is enabled then negotiate for buffer size equal to wrapChunkSize,
- // If ChunkMode is disabled then negotiate for MAX_WRAPPED_SIZE buffer size.
- propertiesMap.putAll(
- SaslProperties.getSaslProperties(connection.isEncryptionEnabled(), connection.getMaxWrappedSize()));
-
- final SettableFuture<Void> authSettable =
- SettableFuture.create(); // use handleAuthFailure to setException
- final CheckedFuture<Void, SaslException> authFuture =
- new AbstractCheckedFuture<Void, SaslException>(authSettable) {
-
- @Override protected SaslException mapException(Exception e) {
- if (e instanceof ExecutionException) {
- final Throwable cause = Throwables.getRootCause(e);
- if (cause instanceof SaslException) {
- return new SaslException(String.format("Authentication failed. [Details: %s, Error %s]",
- connection.getEncryptionCtxtString(), cause.getMessage()), cause);
- }
- }
- return new SaslException(String
- .format("Authentication failed unexpectedly. [Details: %s, Error %s]",
- connection.getEncryptionCtxtString(), e.getMessage()), e);
- }
- };
-
- final AuthenticatorFactory factory;
- final String mechanismName;
- final UserGroupInformation ugi;
- final SaslClient saslClient;
- try {
- factory = getAuthenticatorFactory(properties);
- mechanismName = factory.getSimpleName();
- logger.trace("Will try to authenticate to server using {} mechanism with encryption context {}",
- mechanismName, connection.getEncryptionCtxtString());
-
- // Update the thread context class loader to current class loader
- // See DRILL-6063 for detailed description
- final ClassLoader oldThreadCtxtCL = Thread.currentThread().getContextClassLoader();
- final ClassLoader newThreadCtxtCL = this.getClass().getClassLoader();
- Thread.currentThread().setContextClassLoader(newThreadCtxtCL);
-
- ugi = factory.createAndLoginUser(propertiesMap);
-
- // Reset the thread context class loader to original one
- Thread.currentThread().setContextClassLoader(oldThreadCtxtCL);
-
- saslClient = factory.createSaslClient(ugi, propertiesMap);
- if (saslClient == null) {
- throw new SaslException(String.format(
- "Cannot initiate authentication using %s mechanism. Insufficient "
- + "credentials or selected mechanism doesn't support configured security layers?",
- factory.getSimpleName()));
- }
- connection.setSaslClient(saslClient);
- } catch (final IOException e) {
- authSettable.setException(e);
- return authFuture;
- }
-
- logger.trace("Initiating SASL exchange.");
- new AuthenticationOutcomeListener<>(this, connection, RpcType.SASL_MESSAGE, ugi,
- new RpcOutcomeListener<Void>() {
- @Override public void failed(RpcException ex) {
- authSettable.setException(ex);
- }
-
- @Override public void success(Void value, ByteBuf buffer) {
- authComplete = true;
- authSettable.set(null);
- }
-
- @Override public void interrupted(InterruptedException e) {
- authSettable.setException(e);
- }
- }).initiate(mechanismName);
- return authFuture;
- }
-
- private AuthenticatorFactory getAuthenticatorFactory(final DrillProperties properties)
- throws SaslException {
+ /**
+ * Get's the authenticator factory for the mechanism required by client if it's supported on the server side too.
+ * Otherwise it throws {@link SaslException}
+ * @param properties - client connection properties
+ * @param serverAuthMechanisms - list of authentication mechanisms supported by server
+ * @return - {@link AuthenticatorFactory} for the mechanism required by client for authentication
+ * @throws SaslException - In case of failure
+ */
+ private AuthenticatorFactory getAuthenticatorFactory(final DrillProperties properties,
+ List<String> serverAuthMechanisms) throws SaslException {
final Set<String> mechanismSet = AuthStringUtil.asSet(serverAuthMechanisms);
// first, check if a certain mechanism must be used
@@ -421,7 +364,7 @@
throw new SaslException(String
.format("Server requires authentication using %s. Insufficient credentials?. " + "[Details: %s]. ",
- serverAuthMechanisms, connection.getEncryptionCtxtString()));
+ mechanismSet, connection.getEncryptionCtxtString()));
}
protected <SEND extends MessageLite, RECEIVE extends MessageLite> void send(
@@ -464,7 +407,7 @@
@Override protected void handle(UserToBitConnection connection, int rpcType, ByteBuf pBody, ByteBuf dBody,
ResponseSender sender) throws RpcException {
- if (!authComplete) {
+ if (!isAuthComplete()) {
// Remote should not be making any requests before authenticating, drop connection
throw new RpcException(String.format("Request of type %d is not allowed without authentication. "
+ "Remote on %s must authenticate before making requests. Connection dropped.", rpcType,
@@ -484,8 +427,45 @@
}
}
- @Override protected void validateHandshake(BitToUserHandshake inbound) throws RpcException {
+ @Override
+ protected void prepareSaslHandshake(final RpcConnectionHandler<UserToBitConnection> connectionHandler,
+ List<String> serverAuthMechanisms) {
+ try {
+ final Map<String, String> saslProperties = properties.stringPropertiesAsMap();
+
+ // Set correct QOP property and Strength based on server needs encryption or not.
+ // If ChunkMode is enabled then negotiate for buffer size equal to wrapChunkSize,
+ // If ChunkMode is disabled then negotiate for MAX_WRAPPED_SIZE buffer size.
+ saslProperties.putAll(
+ SaslProperties.getSaslProperties(connection.isEncryptionEnabled(), connection.getMaxWrappedSize()));
+
+ final AuthenticatorFactory factory = getAuthenticatorFactory(properties, serverAuthMechanisms);
+ final String mechanismName = factory.getSimpleName();
+ logger.trace("Will try to authenticate to server using {} mechanism with encryption context {}",
+ mechanismName, connection.getEncryptionCtxtString());
+
+ // Update the thread context class loader to current class loader
+ // See DRILL-6063 for detailed description
+ final ClassLoader oldThreadCtxtCL = Thread.currentThread().getContextClassLoader();
+ final ClassLoader newThreadCtxtCL = this.getClass().getClassLoader();
+ Thread.currentThread().setContextClassLoader(newThreadCtxtCL);
+ final UserGroupInformation ugi = factory.createAndLoginUser(saslProperties);
+ // Reset the thread context class loader to original one
+ Thread.currentThread().setContextClassLoader(oldThreadCtxtCL);
+
+ startSaslHandshake(connectionHandler, saslProperties, ugi, factory, RpcType.SASL_MESSAGE);
+ } catch (final IOException e) {
+ logger.error("Failed while doing setup for starting SASL handshake for connection", connection.getName());
+ final Exception ex = new RpcException(String.format("Failed to initiate authentication for connection %s",
+ connection.getName()), e);
+ connectionHandler.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, ex);
+ }
+ }
+
+ @Override protected List<String> validateHandshake(BitToUserHandshake inbound) throws RpcException {
// logger.debug("Handling handshake from bit to user. {}", inbound);
+ List<String> serverAuthMechanisms = null;
+
if (inbound.hasServerInfos()) {
serverInfos = inbound.getServerInfos();
}
@@ -494,9 +474,9 @@
switch (inbound.getStatus()) {
case SUCCESS:
break;
- case AUTH_REQUIRED: {
- authComplete = false;
+ case AUTH_REQUIRED:
serverAuthMechanisms = ImmutableList.copyOf(inbound.getAuthenticationMechanismsList());
+ setAuthComplete(false);
connection.setEncryption(inbound.hasEncrypted() && inbound.getEncrypted());
if (inbound.hasMaxWrappedSize()) {
@@ -506,7 +486,6 @@
.format("Server requires authentication with encryption context %s before proceeding.",
connection.getEncryptionCtxtString()));
break;
- }
case AUTH_FAILED:
case RPC_VERSION_MISMATCH:
case UNKNOWN_FAILURE:
@@ -516,6 +495,11 @@
logger.error(errMsg);
throw new NonTransientRpcException(errMsg);
}
+
+ // Before starting SASL handshake validate if both client and server are compatible in their security
+ // requirements for the connection
+ validateSaslCompatibility(properties, serverAuthMechanisms);
+ return serverAuthMechanisms;
}
@Override protected UserToBitConnection initRemoteConnection(SocketChannel channel) {
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/client/ConnectTriesPropertyTestClusterBits.java b/exec/java-exec/src/test/java/org/apache/drill/exec/client/ConnectTriesPropertyTestClusterBits.java
index 5c28af1..f6336e6 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/client/ConnectTriesPropertyTestClusterBits.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/client/ConnectTriesPropertyTestClusterBits.java
@@ -185,7 +185,7 @@
ClusterCoordinator.RegistrationHandle fakeEndPoint2Handle = remoteServiceSet.getCoordinator()
.register(fakeEndPoint2);
- client.connect(null);
+ client.connect(new Properties());
client.close();
// Remove the fake drillbits so that other tests are not affected
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/data/TestBitBitKerberos.java b/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/data/TestBitBitKerberos.java
index 0b00824..b4b54c6 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/data/TestBitBitKerberos.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/data/TestBitBitKerberos.java
@@ -69,6 +69,7 @@
import org.junit.experimental.categories.Category;
import org.mockito.Mockito;
+import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.List;
@@ -87,8 +88,6 @@
private static KerberosHelper krbHelper;
private static DrillConfig newConfig;
- private static BootStrapContext c1;
- private static FragmentManager manager;
private int port = 1234;
@BeforeClass
@@ -126,16 +125,12 @@
defaultRealm.set(null, KerberosUtil.getDefaultRealm());
updateTestCluster(1, newConfig);
-
- ScanResult result = ClassPathScanner.fromPrescan(newConfig);
- c1 = new BootStrapContext(newConfig, SystemOptionManager.createDefaultOptionDefinitions(), result);
- setupFragmentContextAndManager();
}
- private static void setupFragmentContextAndManager() {
+ private FragmentManager setupFragmentContextAndManager(BufferAllocator allocator) {
final FragmentContextImpl fcontext = mock(FragmentContextImpl.class);
- when(fcontext.getAllocator()).thenReturn(c1.getAllocator());
- manager = new MockFragmentManager(fcontext);
+ when(fcontext.getAllocator()).thenReturn(allocator);
+ return new MockFragmentManager(fcontext);
}
private static WritableBatch getRandomBatch(BufferAllocator allocator, int records) {
@@ -194,6 +189,25 @@
final WorkerBee bee = mock(WorkerBee.class);
final WorkEventBus workBus = mock(WorkEventBus.class);
+ newConfig = new DrillConfig(DrillConfig.create(cloneDefaultTestConfigProperties())
+ .withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
+ ConfigValueFactory.fromIterable(Lists.newArrayList("kerberos")))
+ .withValue(ExecConstants.BIT_AUTHENTICATION_ENABLED,
+ ConfigValueFactory.fromAnyRef(true))
+ .withValue(ExecConstants.BIT_AUTHENTICATION_MECHANISM,
+ ConfigValueFactory.fromAnyRef("kerberos"))
+ .withValue(ExecConstants.USE_LOGIN_PRINCIPAL,
+ ConfigValueFactory.fromAnyRef(true))
+ .withValue(BootStrapContext.SERVICE_PRINCIPAL,
+ ConfigValueFactory.fromAnyRef(krbHelper.SERVER_PRINCIPAL))
+ .withValue(BootStrapContext.SERVICE_KEYTAB_LOCATION,
+ ConfigValueFactory.fromAnyRef(krbHelper.serverKeytab.toString())));
+
+ final ScanResult result = ClassPathScanner.fromPrescan(newConfig);
+ final BootStrapContext c1 =
+ new BootStrapContext(newConfig, SystemOptionManager.createDefaultOptionDefinitions(), result);
+
+ final FragmentManager manager = setupFragmentContextAndManager(c1.getAllocator());
when(workBus.getFragmentManager(Mockito.<FragmentHandle>any())).thenReturn(manager);
DataConnectionConfig config = new DataConnectionConfig(c1.getAllocator(), c1,
@@ -205,60 +219,80 @@
DataConnectionManager connectionManager = new DataConnectionManager(ep, config);
DataTunnel tunnel = new DataTunnel(connectionManager);
AtomicLong max = new AtomicLong(0);
- for (int i = 0; i < 40; i++) {
- long t1 = System.currentTimeMillis();
- tunnel.sendRecordBatch(new TimingOutcome(max), new FragmentWritableBatch(false, QueryId.getDefaultInstance(), 1,
- 1, 1, 1, getRandomBatch(c1.getAllocator(), 5000)));
- System.out.println(System.currentTimeMillis() - t1);
+
+ try {
+ for (int i = 0; i < 40; i++) {
+ long t1 = System.currentTimeMillis();
+ tunnel.sendRecordBatch(new TimingOutcome(max),
+ new FragmentWritableBatch(false, QueryId.getDefaultInstance(), 1, 1, 1, 1,
+ getRandomBatch(c1.getAllocator(), 5000)));
+ System.out.println(System.currentTimeMillis() - t1);
+ }
+ System.out.println(String.format("Max time: %d", max.get()));
+ assertTrue(max.get() > 2700);
+ Thread.sleep(5000);
+ } catch (Exception | AssertionError e) {
+ fail();
+ } finally {
+ server.close();
+ connectionManager.close();
+ c1.close();
}
- System.out.println(String.format("Max time: %d", max.get()));
- assertTrue(max.get() > 2700);
- Thread.sleep(5000);
}
@Test
public void successEncryption() throws Exception {
+
final WorkerBee bee = mock(WorkerBee.class);
final WorkEventBus workBus = mock(WorkEventBus.class);
+ newConfig = new DrillConfig(DrillConfig.create(cloneDefaultTestConfigProperties())
+ .withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
+ ConfigValueFactory.fromIterable(Lists.newArrayList("kerberos")))
+ .withValue(ExecConstants.BIT_AUTHENTICATION_ENABLED,
+ ConfigValueFactory.fromAnyRef(true))
+ .withValue(ExecConstants.BIT_AUTHENTICATION_MECHANISM,
+ ConfigValueFactory.fromAnyRef("kerberos"))
+ .withValue(ExecConstants.BIT_ENCRYPTION_SASL_ENABLED,
+ ConfigValueFactory.fromAnyRef(true))
+ .withValue(ExecConstants.USE_LOGIN_PRINCIPAL,
+ ConfigValueFactory.fromAnyRef(true))
+ .withValue(BootStrapContext.SERVICE_PRINCIPAL,
+ ConfigValueFactory.fromAnyRef(krbHelper.SERVER_PRINCIPAL))
+ .withValue(BootStrapContext.SERVICE_KEYTAB_LOCATION,
+ ConfigValueFactory.fromAnyRef(krbHelper.serverKeytab.toString())));
+ final ScanResult result = ClassPathScanner.fromPrescan(newConfig);
+ final BootStrapContext c2 =
+ new BootStrapContext(newConfig, SystemOptionManager.createDefaultOptionDefinitions(), result);
+
+ final FragmentManager manager = setupFragmentContextAndManager(c2.getAllocator());
when(workBus.getFragmentManager(Mockito.<FragmentHandle>any())).thenReturn(manager);
- newConfig = new DrillConfig(
- config.withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
- ConfigValueFactory.fromIterable(Lists.newArrayList("kerberos")))
- .withValue(ExecConstants.BIT_AUTHENTICATION_ENABLED,
- ConfigValueFactory.fromAnyRef(true))
- .withValue(ExecConstants.BIT_AUTHENTICATION_MECHANISM,
- ConfigValueFactory.fromAnyRef("kerberos"))
- .withValue(ExecConstants.BIT_ENCRYPTION_SASL_ENABLED,
- ConfigValueFactory.fromAnyRef(true))
- .withValue(ExecConstants.USE_LOGIN_PRINCIPAL,
- ConfigValueFactory.fromAnyRef(true))
- .withValue(BootStrapContext.SERVICE_PRINCIPAL,
- ConfigValueFactory.fromAnyRef(krbHelper.SERVER_PRINCIPAL))
- .withValue(BootStrapContext.SERVICE_KEYTAB_LOCATION,
- ConfigValueFactory.fromAnyRef(krbHelper.serverKeytab.toString())));
-
- updateTestCluster(1, newConfig);
-
- DataConnectionConfig config = new DataConnectionConfig(c1.getAllocator(), c1,
- new DataServerRequestHandler(workBus, bee));
- DataServer server = new DataServer(config);
+ final DataConnectionConfig config =
+ new DataConnectionConfig(c2.getAllocator(), c2, new DataServerRequestHandler(workBus, bee));
+ final DataServer server = new DataServer(config);
port = server.bind(port, true);
DrillbitEndpoint ep = DrillbitEndpoint.newBuilder().setAddress("localhost").setDataPort(port).build();
- DataConnectionManager connectionManager = new DataConnectionManager(ep, config);
- DataTunnel tunnel = new DataTunnel(connectionManager);
+ final DataConnectionManager connectionManager = new DataConnectionManager(ep, config);
+ final DataTunnel tunnel = new DataTunnel(connectionManager);
AtomicLong max = new AtomicLong(0);
- for (int i = 0; i < 40; i++) {
- long t1 = System.currentTimeMillis();
- tunnel.sendRecordBatch(new TimingOutcome(max), new FragmentWritableBatch(false, QueryId.getDefaultInstance(), 1,
- 1, 1, 1, getRandomBatch(c1.getAllocator(), 5000)));
- System.out.println(System.currentTimeMillis() - t1);
+ try {
+ for (int i = 0; i < 40; i++) {
+ long t1 = System.currentTimeMillis();
+ tunnel.sendRecordBatch(new TimingOutcome(max),
+ new FragmentWritableBatch(false, QueryId.getDefaultInstance(), 1, 1, 1, 1,
+ getRandomBatch(c2.getAllocator(), 5000)));
+ System.out.println(System.currentTimeMillis() - t1);
+ }
+ System.out.println(String.format("Max time: %d", max.get()));
+ assertTrue(max.get() > 2700);
+ Thread.sleep(5000);
+ } finally {
+ server.close();
+ connectionManager.close();
+ c2.close();
}
- System.out.println(String.format("Max time: %d", max.get()));
- assertTrue(max.get() > 2700);
- Thread.sleep(5000);
}
@Test
@@ -268,10 +302,8 @@
final WorkerBee bee = mock(WorkerBee.class);
final WorkEventBus workBus = mock(WorkEventBus.class);
- when(workBus.getFragmentManager(Mockito.<FragmentHandle>any())).thenReturn(manager);
-
- newConfig = new DrillConfig(
- config.withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
+ newConfig = new DrillConfig(DrillConfig.create(cloneDefaultTestConfigProperties())
+ .withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
ConfigValueFactory.fromIterable(Lists.newArrayList("kerberos")))
.withValue(ExecConstants.BIT_AUTHENTICATION_ENABLED,
ConfigValueFactory.fromAnyRef(true))
@@ -288,33 +320,48 @@
.withValue(BootStrapContext.SERVICE_KEYTAB_LOCATION,
ConfigValueFactory.fromAnyRef(krbHelper.serverKeytab.toString())));
- updateTestCluster(1, newConfig);
+ final ScanResult result = ClassPathScanner.fromPrescan(newConfig);
+ final BootStrapContext c2 =
+ new BootStrapContext(newConfig, SystemOptionManager.createDefaultOptionDefinitions(), result);
- DataConnectionConfig config = new DataConnectionConfig(c1.getAllocator(), c1,
+ final FragmentManager manager = setupFragmentContextAndManager(c2.getAllocator());
+ when(workBus.getFragmentManager(Mockito.<FragmentHandle>any())).thenReturn(manager);
+
+ final DataConnectionConfig config = new DataConnectionConfig(c2.getAllocator(), c2,
new DataServerRequestHandler(workBus, bee));
- DataServer server = new DataServer(config);
+ final DataServer server = new DataServer(config);
port = server.bind(port, true);
- DrillbitEndpoint ep = DrillbitEndpoint.newBuilder().setAddress("localhost").setDataPort(port).build();
- DataConnectionManager connectionManager = new DataConnectionManager(ep, config);
- DataTunnel tunnel = new DataTunnel(connectionManager);
+ final DrillbitEndpoint ep = DrillbitEndpoint.newBuilder().setAddress("localhost").setDataPort(port).build();
+ final DataConnectionManager connectionManager = new DataConnectionManager(ep, config);
+ final DataTunnel tunnel = new DataTunnel(connectionManager);
AtomicLong max = new AtomicLong(0);
- for (int i = 0; i < 40; i++) {
- long t1 = System.currentTimeMillis();
- tunnel.sendRecordBatch(new TimingOutcome(max), new FragmentWritableBatch(false, QueryId.getDefaultInstance(), 1,
- 1, 1, 1, getRandomBatch(c1.getAllocator(), 5000)));
- System.out.println(System.currentTimeMillis() - t1);
+
+ try {
+ for (int i = 0; i < 40; i++) {
+ long t1 = System.currentTimeMillis();
+ tunnel.sendRecordBatch(new TimingOutcome(max),
+ new FragmentWritableBatch(false, QueryId.getDefaultInstance(), 1, 1, 1, 1,
+ getRandomBatch(c2.getAllocator(), 5000)));
+ System.out.println(System.currentTimeMillis() - t1);
+ }
+ System.out.println(String.format("Max time: %d", max.get()));
+ assertTrue(max.get() > 2700);
+ Thread.sleep(5000);
+ } catch (Exception | AssertionError ex) {
+ fail();
+ } finally {
+ server.close();
+ connectionManager.close();
+ c2.close();
}
- System.out.println(String.format("Max time: %d", max.get()));
- assertTrue(max.get() > 2700);
- Thread.sleep(5000);
}
@Test
public void failureEncryptionOnlyPlainMechanism() throws Exception {
try{
- newConfig = new DrillConfig(
- config.withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
+ newConfig = new DrillConfig(DrillConfig.create(cloneDefaultTestConfigProperties())
+ .withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
ConfigValueFactory.fromIterable(Lists.newArrayList("plain")))
.withValue(ExecConstants.BIT_AUTHENTICATION_ENABLED,
ConfigValueFactory.fromAnyRef(true))
@@ -452,7 +499,7 @@
} catch (InterruptedException e) {
}
- RawFragmentBatch rfb = batch.newRawFragmentBatch(c1.getAllocator());
+ RawFragmentBatch rfb = batch.newRawFragmentBatch(fragmentContext.getAllocator());
rfb.sendOk();
rfb.release();
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/security/KerberosHelper.java b/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/security/KerberosHelper.java
index 8ba4d18..79dbc36 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/security/KerberosHelper.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/security/KerberosHelper.java
@@ -131,6 +131,10 @@
kdc.exportPrincipal(principal, keytab);
}
+ /**
+ * Workspace is owned by test using this helper
+ * @throws Exception
+ */
public void stopKdc() throws Exception {
if (kdcStarted) {
logger.info("Stopping KDC on {}", kdcPort);
@@ -141,7 +145,6 @@
deleteIfExists(serverKeytab);
deleteIfExists(keytabDir);
deleteIfExists(kdcDir);
- deleteIfExists(workspace);
}
private void deleteIfExists(File file) throws IOException {
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/user/security/TestUserBitKerberos.java b/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/user/security/TestUserBitKerberos.java
index dbdbe3c..55f959c 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/user/security/TestUserBitKerberos.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/user/security/TestUserBitKerberos.java
@@ -175,7 +175,7 @@
// Check unencrypted counters value
assertTrue(1 == UserRpcMetrics.getInstance().getUnEncryptedConnectionCount());
- assertTrue(2 == ControlRpcMetrics.getInstance().getUnEncryptedConnectionCount());
+ assertTrue(0 == ControlRpcMetrics.getInstance().getUnEncryptedConnectionCount());
assertTrue(0 == DataRpcMetrics.getInstance().getUnEncryptedConnectionCount());
}
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/user/security/TestUserBitKerberosEncryption.java b/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/user/security/TestUserBitKerberosEncryption.java
index aa26fd6..640eb40 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/user/security/TestUserBitKerberosEncryption.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/rpc/user/security/TestUserBitKerberosEncryption.java
@@ -111,7 +111,22 @@
connectionProps.setProperty(DrillProperties.SERVICE_PRINCIPAL, krbHelper.SERVER_PRINCIPAL);
connectionProps.setProperty(DrillProperties.USER, krbHelper.CLIENT_PRINCIPAL);
connectionProps.setProperty(DrillProperties.KEYTAB, krbHelper.clientKeytab.getAbsolutePath());
- updateClient(connectionProps);
+
+ newConfig = new DrillConfig(DrillConfig.create(cloneDefaultTestConfigProperties())
+ .withValue(ExecConstants.USER_AUTHENTICATION_ENABLED,
+ ConfigValueFactory.fromAnyRef(true))
+ .withValue(ExecConstants.USER_AUTHENTICATOR_IMPL,
+ ConfigValueFactory.fromAnyRef(UserAuthenticatorTestImpl.TYPE))
+ .withValue(BootStrapContext.SERVICE_PRINCIPAL,
+ ConfigValueFactory.fromAnyRef(krbHelper.SERVER_PRINCIPAL))
+ .withValue(BootStrapContext.SERVICE_KEYTAB_LOCATION,
+ ConfigValueFactory.fromAnyRef(krbHelper.serverKeytab.toString()))
+ .withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
+ ConfigValueFactory.fromIterable(Lists.newArrayList("plain", "kerberos")))
+ .withValue(ExecConstants.USER_ENCRYPTION_SASL_ENABLED,
+ ConfigValueFactory.fromAnyRef(true)));
+
+ updateTestCluster(1, newConfig, connectionProps);
// Run few queries using the new client
testBuilder()
@@ -145,7 +160,22 @@
connectionProps.setProperty(DrillProperties.SERVICE_PRINCIPAL, krbHelper.SERVER_PRINCIPAL);
connectionProps.setProperty(DrillProperties.USER, krbHelper.CLIENT_PRINCIPAL);
connectionProps.setProperty(DrillProperties.KEYTAB, krbHelper.clientKeytab.getAbsolutePath());
- updateClient(connectionProps);
+
+ newConfig = new DrillConfig(DrillConfig.create(cloneDefaultTestConfigProperties())
+ .withValue(ExecConstants.USER_AUTHENTICATION_ENABLED,
+ ConfigValueFactory.fromAnyRef(true))
+ .withValue(ExecConstants.USER_AUTHENTICATOR_IMPL,
+ ConfigValueFactory.fromAnyRef(UserAuthenticatorTestImpl.TYPE))
+ .withValue(BootStrapContext.SERVICE_PRINCIPAL,
+ ConfigValueFactory.fromAnyRef(krbHelper.SERVER_PRINCIPAL))
+ .withValue(BootStrapContext.SERVICE_KEYTAB_LOCATION,
+ ConfigValueFactory.fromAnyRef(krbHelper.serverKeytab.toString()))
+ .withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
+ ConfigValueFactory.fromIterable(Lists.newArrayList("plain", "kerberos")))
+ .withValue(ExecConstants.USER_ENCRYPTION_SASL_ENABLED,
+ ConfigValueFactory.fromAnyRef(true)));
+
+ updateTestCluster(1, newConfig, connectionProps);
assertTrue(UserRpcMetrics.getInstance().getEncryptedConnectionCount() == 1);
assertTrue(UserRpcMetrics.getInstance().getUnEncryptedConnectionCount() == 0);
@@ -177,10 +207,24 @@
final Subject clientSubject = JaasKrbUtil.loginUsingKeytab(krbHelper.CLIENT_PRINCIPAL,
krbHelper.clientKeytab.getAbsoluteFile());
+ newConfig = new DrillConfig(DrillConfig.create(cloneDefaultTestConfigProperties())
+ .withValue(ExecConstants.USER_AUTHENTICATION_ENABLED,
+ ConfigValueFactory.fromAnyRef(true))
+ .withValue(ExecConstants.USER_AUTHENTICATOR_IMPL,
+ ConfigValueFactory.fromAnyRef(UserAuthenticatorTestImpl.TYPE))
+ .withValue(BootStrapContext.SERVICE_PRINCIPAL,
+ ConfigValueFactory.fromAnyRef(krbHelper.SERVER_PRINCIPAL))
+ .withValue(BootStrapContext.SERVICE_KEYTAB_LOCATION,
+ ConfigValueFactory.fromAnyRef(krbHelper.serverKeytab.toString()))
+ .withValue(ExecConstants.AUTHENTICATION_MECHANISMS,
+ ConfigValueFactory.fromIterable(Lists.newArrayList("plain", "kerberos")))
+ .withValue(ExecConstants.USER_ENCRYPTION_SASL_ENABLED,
+ ConfigValueFactory.fromAnyRef(true)));
+
Subject.doAs(clientSubject, new PrivilegedExceptionAction<Void>() {
@Override
public Void run() throws Exception {
- updateClient(connectionProps);
+ updateTestCluster(1, newConfig, connectionProps);
return null;
}
});
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/server/TestDrillbitResilience.java b/exec/java-exec/src/test/java/org/apache/drill/exec/server/TestDrillbitResilience.java
index 34f75f4..f86d698 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/server/TestDrillbitResilience.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/server/TestDrillbitResilience.java
@@ -29,6 +29,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Properties;
import org.apache.commons.math3.util.Pair;
import org.apache.drill.exec.work.foreman.FragmentsRunner;
@@ -205,7 +206,7 @@
// create a client
final DrillConfig drillConfig = zkHelper.getConfig();
- drillClient = QueryTestUtil.createClient(drillConfig, remoteServiceSet, 1, null);
+ drillClient = QueryTestUtil.createClient(drillConfig, remoteServiceSet, 1, new Properties());
clearAllInjections();
}
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/testing/TestResourceLeak.java b/exec/java-exec/src/test/java/org/apache/drill/exec/testing/TestResourceLeak.java
index 94c8ebf..1098dc4 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/testing/TestResourceLeak.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/testing/TestResourceLeak.java
@@ -87,7 +87,7 @@
bit = new Drillbit(config, serviceSet);
bit.run();
- client = QueryTestUtil.createClient(config, serviceSet, 2, null);
+ client = QueryTestUtil.createClient(config, serviceSet, 2, new Properties());
}
@Test
diff --git a/exec/rpc/pom.xml b/exec/rpc/pom.xml
index ea56574..5ed62ee 100644
--- a/exec/rpc/pom.xml
+++ b/exec/rpc/pom.xml
@@ -77,6 +77,10 @@
</exclusion>
</exclusions>
</dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-common</artifactId>
+ </dependency>
</dependencies>
diff --git a/exec/rpc/src/main/java/org/apache/drill/exec/rpc/BasicClient.java b/exec/rpc/src/main/java/org/apache/drill/exec/rpc/BasicClient.java
index 0f4ef1b..4395db3 100644
--- a/exec/rpc/src/main/java/org/apache/drill/exec/rpc/BasicClient.java
+++ b/exec/rpc/src/main/java/org/apache/drill/exec/rpc/BasicClient.java
@@ -17,6 +17,10 @@
*/
package org.apache.drill.exec.rpc;
+import com.google.common.base.Preconditions;
+import com.google.protobuf.Internal.EnumLite;
+import com.google.protobuf.MessageLite;
+import com.google.protobuf.Parser;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
@@ -32,16 +36,17 @@
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
-
-import java.util.concurrent.TimeUnit;
-
import org.apache.drill.exec.memory.BufferAllocator;
import org.apache.drill.exec.proto.GeneralRPCProtos.RpcMode;
+import org.apache.drill.exec.rpc.security.AuthenticationOutcomeListener;
+import org.apache.drill.exec.rpc.security.AuthenticatorFactory;
+import org.apache.hadoop.security.UserGroupInformation;
-import com.google.common.base.Preconditions;
-import com.google.protobuf.Internal.EnumLite;
-import com.google.protobuf.MessageLite;
-import com.google.protobuf.Parser;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
/**
*
@@ -69,6 +74,9 @@
private final IdlePingHandler pingHandler;
private ConnectionMultiListener.SSLHandshakeListener sslHandshakeListener = null;
+ // Determines if authentication is completed between client and server
+ private boolean authComplete = true;
+
public BasicClient(RpcConfig rpcMapping, ByteBufAllocator alloc, EventLoopGroup eventLoopGroup, T handshakeType,
Class<HR> responseClass, Parser<HR> handshakeParser) {
super(rpcMapping);
@@ -133,6 +141,19 @@
return false;
}
+ /**
+ * Set's the state for authentication complete.
+ * @param authComplete - state to set. True means authentication between client and server is completed, false
+ * means authentication is in progress.
+ */
+ protected void setAuthComplete(boolean authComplete) {
+ this.authComplete = authComplete;
+ }
+
+ protected boolean isAuthComplete() {
+ return authComplete;
+ }
+
// Save the SslChannel after the SSL handshake so it can be closed later
public void setSslChannel(Channel c) {
@@ -180,7 +201,67 @@
return (connection != null) && connection.isActive();
}
- protected abstract void validateHandshake(HR validateHandshake) throws RpcException;
+ protected abstract List<String> validateHandshake(HR validateHandshake) throws RpcException;
+
+ /**
+ * Creates various instances needed to start the SASL handshake. This is called from
+ * {@link BasicClient#validateHandshake(MessageLite)} if authentication is required from server side.
+ * @param connectionHandler - Connection handler used by client's to know about success/failure conditions.
+ * @param serverAuthMechanisms - List of auth mechanisms configured on server side
+ */
+ protected abstract void prepareSaslHandshake(final RpcConnectionHandler<CC> connectionHandler,
+ List<String> serverAuthMechanisms) throws RpcException;
+
+ /**
+ * Main method which starts the SASL handshake for all client channels (user/data/control) once it's determined
+ * after regular RPC handshake that authentication is required by server side. Once authentication is completed
+ * then only the underlying channel is made available to clients to send other RPC messages. Success and failure
+ * events are notified to the connection handler on which client waits.
+ * @param connectionHandler - Connection handler used by client's to know about success/failure conditions.
+ * @param saslProperties - SASL related properties needed to create SASL client.
+ * @param ugi - UserGroupInformation with logged in client side user
+ * @param authFactory - Authentication factory to use for this SASL handshake.
+ * @param rpcType - SASL_MESSAGE rpc type.
+ */
+ protected void startSaslHandshake(final RpcConnectionHandler<CC> connectionHandler,
+ Map<String, ?> saslProperties, UserGroupInformation ugi,
+ AuthenticatorFactory authFactory, T rpcType) {
+ final String mechanismName = authFactory.getSimpleName();
+ try {
+ final SaslClient saslClient = authFactory.createSaslClient(ugi, saslProperties);
+ if (saslClient == null) {
+ final Exception ex = new SaslException(String.format("Cannot initiate authentication using %s mechanism. " +
+ "Insufficient credentials or selected mechanism doesn't support configured security layers?", mechanismName));
+ connectionHandler.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, ex);
+ return;
+ }
+ connection.setSaslClient(saslClient);
+ } catch (final SaslException e) {
+ logger.error("Failed while creating SASL client for SASL handshake for connection", connection.getName());
+ connectionHandler.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, e);
+ return;
+ }
+
+ logger.debug("Initiating SASL exchange.");
+ new AuthenticationOutcomeListener<>(this, connection, rpcType, ugi,
+ new RpcOutcomeListener<Void>() {
+ @Override
+ public void failed(RpcException ex) {
+ connectionHandler.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, ex);
+ }
+
+ @Override
+ public void success(Void value, ByteBuf buffer) {
+ authComplete = true;
+ connectionHandler.connectionSucceeded(connection);
+ }
+
+ @Override
+ public void interrupted(InterruptedException ex) {
+ connectionHandler.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, ex);
+ }
+ }).initiate(mechanismName);
+ }
protected void finalizeConnection(HR handshake, CC connection) {
// no-op
@@ -204,12 +285,6 @@
allowInEventLoop, dataBodies);
}
- // the command itself must be "run" by the caller (to avoid calling inEventLoop)
- protected <M extends MessageLite> RpcCommand<M, CC>
- getInitialCommand(final RpcCommand<M, CC> command) {
- return command;
- }
-
protected void connectAsClient(RpcConnectionHandler<CC> connectionListener, HS handshakeValue,
String host, int port) {
ConnectionMultiListener<T, CC, HS, HR, BasicClient<T, CC, HS, HR>> cml;
diff --git a/exec/rpc/src/main/java/org/apache/drill/exec/rpc/ConnectionMultiListener.java b/exec/rpc/src/main/java/org/apache/drill/exec/rpc/ConnectionMultiListener.java
index 0cdca13..3fee5d7 100644
--- a/exec/rpc/src/main/java/org/apache/drill/exec/rpc/ConnectionMultiListener.java
+++ b/exec/rpc/src/main/java/org/apache/drill/exec/rpc/ConnectionMultiListener.java
@@ -28,6 +28,7 @@
import org.slf4j.Logger;
import java.net.SocketAddress;
+import java.util.List;
import java.util.concurrent.TimeUnit;
/**
@@ -151,12 +152,21 @@
public void success(HR value, ByteBuf buffer) {
// logger.debug("Handshake received. {}", value);
try {
- parent.validateHandshake(value);
+ final List<String> serverAuthMechanisms = parent.validateHandshake(value);
parent.finalizeConnection(value, parent.connection);
- connectionListener.connectionSucceeded(parent.connection);
- // logger.debug("Handshake completed succesfully.");
+
+ // If auth is required then start the SASL handshake
+ if (serverAuthMechanisms != null) {
+ parent.prepareSaslHandshake(connectionListener, serverAuthMechanisms);
+ } else {
+ connectionListener.connectionSucceeded(parent.connection);
+ logger.debug("Handshake completed successfully.");
+ }
+ } catch (NonTransientRpcException ex) {
+ logger.error("Failure while validating client and server sasl compatibility", ex);
+ connectionListener.connectionFailed(RpcConnectionHandler.FailureType.AUTHENTICATION, ex);
} catch (Exception ex) {
- logger.debug("Failure while validating handshake", ex);
+ logger.error("Failure while validating handshake", ex);
connectionListener.connectionFailed(RpcConnectionHandler.FailureType.HANDSHAKE_VALIDATION, ex);
}
}
diff --git a/exec/rpc/src/main/java/org/apache/drill/exec/rpc/ReconnectingConnection.java b/exec/rpc/src/main/java/org/apache/drill/exec/rpc/ReconnectingConnection.java
index a64a23b..3936170 100644
--- a/exec/rpc/src/main/java/org/apache/drill/exec/rpc/ReconnectingConnection.java
+++ b/exec/rpc/src/main/java/org/apache/drill/exec/rpc/ReconnectingConnection.java
@@ -78,7 +78,7 @@
} else {
// logger.debug("No connection active, opening client connection.");
BasicClient<?, C, HS, ?> client = getNewClient();
- ConnectionListeningFuture<T> future = new ConnectionListeningFuture<>(client.getInitialCommand(cmd));
+ ConnectionListeningFuture<T> future = new ConnectionListeningFuture<>(cmd);
client.connectAsClient(future, handshake, host, port);
future.waitAndRun();
// logger.debug("Connection available and active, command now being run inline.");
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticationOutcomeListener.java b/exec/rpc/src/main/java/org/apache/drill/exec/rpc/security/AuthenticationOutcomeListener.java
similarity index 100%
rename from exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticationOutcomeListener.java
rename to exec/rpc/src/main/java/org/apache/drill/exec/rpc/security/AuthenticationOutcomeListener.java
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticatorFactory.java b/exec/rpc/src/main/java/org/apache/drill/exec/rpc/security/AuthenticatorFactory.java
similarity index 100%
rename from exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/AuthenticatorFactory.java
rename to exec/rpc/src/main/java/org/apache/drill/exec/rpc/security/AuthenticatorFactory.java
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/SaslProperties.java b/exec/rpc/src/main/java/org/apache/drill/exec/rpc/security/SaslProperties.java
similarity index 100%
rename from exec/java-exec/src/main/java/org/apache/drill/exec/rpc/security/SaslProperties.java
rename to exec/rpc/src/main/java/org/apache/drill/exec/rpc/security/SaslProperties.java