blob: 240e474709cafd169438f16d23a1f395a1e8519b [file] [log] [blame]
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.datastax.oss.driver.internal.core.channel;
import com.datastax.oss.driver.api.core.DefaultProtocolVersion;
import com.datastax.oss.driver.api.core.InvalidKeyspaceException;
import com.datastax.oss.driver.api.core.ProtocolVersion;
import com.datastax.oss.driver.api.core.UnsupportedProtocolVersionException;
import com.datastax.oss.driver.api.core.auth.AuthenticationException;
import com.datastax.oss.driver.api.core.auth.Authenticator;
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverExecutionProfile;
import com.datastax.oss.driver.api.core.connection.ConnectionInitException;
import com.datastax.oss.driver.api.core.type.codec.TypeCodecs;
import com.datastax.oss.driver.internal.core.context.InternalDriverContext;
import com.datastax.oss.driver.internal.core.util.ProtocolUtils;
import com.datastax.oss.driver.internal.core.util.concurrent.UncaughtExceptions;
import com.datastax.oss.protocol.internal.Message;
import com.datastax.oss.protocol.internal.ProtocolConstants;
import com.datastax.oss.protocol.internal.request.AuthResponse;
import com.datastax.oss.protocol.internal.request.Query;
import com.datastax.oss.protocol.internal.request.Register;
import com.datastax.oss.protocol.internal.request.Startup;
import com.datastax.oss.protocol.internal.response.AuthChallenge;
import com.datastax.oss.protocol.internal.response.AuthSuccess;
import com.datastax.oss.protocol.internal.response.Authenticate;
import com.datastax.oss.protocol.internal.response.Error;
import com.datastax.oss.protocol.internal.response.Ready;
import com.datastax.oss.protocol.internal.response.result.Rows;
import com.datastax.oss.protocol.internal.response.result.SetKeyspace;
import io.netty.channel.ChannelHandlerContext;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.util.List;
import net.jcip.annotations.NotThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Handles the sequence of internal requests that we send on a channel before it's ready to accept
 * user requests.
 *
 * <p>Requests are sent one at a time, each after the previous response has been processed:
 *
 * <ol>
 *   <li>{@code STARTUP};
 *   <li>if the server answers with {@code AUTHENTICATE}, an authentication exchange made of
 *       {@code AUTH_RESPONSE} messages until {@code AUTH_SUCCESS};
 *   <li>a query to {@code system.local} to read (and possibly check) the cluster name;
 *   <li>{@code USE <keyspace>} if a keyspace was configured in the channel options;
 *   <li>{@code REGISTER} if event types were requested in the channel options.
 * </ol>
 *
 * <p>Once the sequence succeeds, the heartbeat handler is inserted in the pipeline.
 */
@NotThreadSafe
class ProtocolInitHandler extends ConnectInitHandler {
  private static final Logger LOG = LoggerFactory.getLogger(ProtocolInitHandler.class);

  private static final Query CLUSTER_NAME_QUERY =
      new Query("SELECT cluster_name FROM system.local");

  private final InternalDriverContext context;
  // Timeout for each init request, read from CONNECTION_INIT_QUERY_TIMEOUT.
  private final long timeoutMillis;
  // Reported in UnsupportedProtocolVersionException if the server rejects the version at STARTUP.
  private final ProtocolVersion initialProtocolVersion;
  private final DriverChannelOptions options;
  // might be null if this is the first channel to this cluster
  private final String expectedClusterName;
  // Only added to the pipeline once initialization has succeeded.
  private final HeartbeatHandler heartbeatHandler;
  // Refined with the channel's identifier once the channel becomes active.
  private String logPrefix;
  // Captured when the real connection happens; used to modify the pipeline on success.
  private ChannelHandlerContext ctx;

  ProtocolInitHandler(
      InternalDriverContext context,
      ProtocolVersion protocolVersion,
      String expectedClusterName,
      DriverChannelOptions options,
      HeartbeatHandler heartbeatHandler) {

    this.context = context;

    DriverExecutionProfile defaultConfig = context.getConfig().getDefaultProfile();

    this.timeoutMillis =
        defaultConfig.getDuration(DefaultDriverOption.CONNECTION_INIT_QUERY_TIMEOUT).toMillis();
    this.initialProtocolVersion = protocolVersion;
    this.expectedClusterName = expectedClusterName;
    this.options = options;
    this.heartbeatHandler = heartbeatHandler;
    this.logPrefix = options.ownerLogPrefix + "|connecting...";
  }

  @Override
  public void channelActive(ChannelHandlerContext ctx) throws Exception {
    super.channelActive(ctx);
    // Drop the first and last characters of the channel's string form (presumably the
    // surrounding brackets of Netty's "[id: ...]" representation).
    String channelId = ctx.channel().toString();
    this.logPrefix = options.ownerLogPrefix + "|" + channelId.substring(1, channelId.length() - 1);
  }

  @Override
  protected void onRealConnect(ChannelHandlerContext ctx) {
    LOG.debug("[{}] Starting channel initialization", logPrefix);
    this.ctx = ctx;
    new InitRequest(ctx).send();
  }

  @Override
  protected boolean setConnectSuccess() {
    boolean result = super.setConnectSuccess();
    if (result) {
      // add heartbeat to pipeline now that protocol is initialized.
      ctx.pipeline().addBefore("inflight", "heartbeat", heartbeatHandler);
    }
    return result;
  }

  /** The steps of the initialization sequence; see {@link InitRequest#getRequest()}. */
  private enum Step {
    STARTUP,
    GET_CLUSTER_NAME,
    SET_KEYSPACE,
    AUTH_RESPONSE,
    REGISTER,
  }

  private class InitRequest extends ChannelHandlerRequest {
    // This class is a finite-state automaton that sends a different query depending on the step
    // in the initialization sequence.
    private Step step;
    // Created when the server answers STARTUP with AUTHENTICATE.
    private Authenticator authenticator;
    // Token to send in the next AUTH_RESPONSE message.
    private ByteBuffer authResponseToken;

    InitRequest(ChannelHandlerContext ctx) {
      super(ctx, timeoutMillis);
      this.step = Step.STARTUP;
    }

    @Override
    String describe() {
      return "[" + logPrefix + "] init query " + step;
    }

    /** Builds the message to send for the current step. */
    @Override
    Message getRequest() {
      switch (step) {
        case STARTUP:
          return new Startup(context.getStartupOptions());
        case GET_CLUSTER_NAME:
          return CLUSTER_NAME_QUERY;
        case SET_KEYSPACE:
          return new Query("USE " + options.keyspace.asCql(false));
        case AUTH_RESPONSE:
          return new AuthResponse(authResponseToken);
        case REGISTER:
          return new Register(options.eventTypes);
        default:
          throw new AssertionError("unhandled step: " + step);
      }
    }

    /**
     * Processes the server's response to the current step, then either sends the next request,
     * completes the connection, or fails it.
     */
    @Override
    void onResponse(Message response) {
      LOG.debug(
          "[{}] step {} received response opcode={}",
          logPrefix,
          step,
          ProtocolUtils.opcodeString(response.opcode));
      try {
        if (step == Step.STARTUP && response instanceof Ready) {
          // The server did not send an authentication challenge: notify the auth provider, if
          // one is configured.
          context
              .getAuthProvider()
              .ifPresent(provider -> provider.onMissingChallenge(channel.remoteAddress()));
          step = Step.GET_CLUSTER_NAME;
          send();
        } else if (step == Step.STARTUP && response instanceof Authenticate) {
          Authenticate authenticate = (Authenticate) response;
          authenticator = buildAuthenticator(channel.remoteAddress(), authenticate.authenticator);
          // Authenticator callbacks run on the channel's event loop, since this class is not
          // thread-safe.
          authenticator
              .initialResponse()
              .whenCompleteAsync(
                  (token, error) -> {
                    if (error != null) {
                      failOnAuthenticatorError(error);
                    } else {
                      step = Step.AUTH_RESPONSE;
                      authResponseToken = token;
                      send();
                    }
                  },
                  channel.eventLoop())
              .exceptionally(UncaughtExceptions::log);
        } else if (step == Step.AUTH_RESPONSE && response instanceof AuthChallenge) {
          ByteBuffer challenge = ((AuthChallenge) response).token;
          authenticator
              .evaluateChallenge(challenge)
              .whenCompleteAsync(
                  (token, error) -> {
                    if (error != null) {
                      failOnAuthenticatorError(error);
                    } else {
                      step = Step.AUTH_RESPONSE;
                      authResponseToken = token;
                      send();
                    }
                  },
                  channel.eventLoop())
              .exceptionally(UncaughtExceptions::log);
        } else if (step == Step.AUTH_RESPONSE && response instanceof AuthSuccess) {
          ByteBuffer token = ((AuthSuccess) response).token;
          authenticator
              .onAuthenticationSuccess(token)
              .whenCompleteAsync(
                  (ignored, error) -> {
                    if (error != null) {
                      failOnAuthenticatorError(error);
                    } else {
                      step = Step.GET_CLUSTER_NAME;
                      send();
                    }
                  },
                  channel.eventLoop())
              .exceptionally(UncaughtExceptions::log);
        } else if (step == Step.AUTH_RESPONSE
            && response instanceof Error
            && ((Error) response).code == ProtocolConstants.ErrorCode.AUTH_ERROR) {
          fail(
              new AuthenticationException(
                  channel.remoteAddress(),
                  String.format("server replied '%s'", ((Error) response).message)));
        } else if (step == Step.GET_CLUSTER_NAME && response instanceof Rows) {
          onClusterNameResponse((Rows) response);
        } else if (step == Step.SET_KEYSPACE && response instanceof SetKeyspace) {
          registerOrSucceed();
        } else if (step == Step.REGISTER && response instanceof Ready) {
          setConnectSuccess();
        } else if (response instanceof Error) {
          Error error = (Error) response;
          // Testing for a specific string is a tad fragile, but Cassandra doesn't give us a more
          // precise error code.
          // C* 2.1 reports a server error instead of protocol error, see CASSANDRA-9451.
          if (step == Step.STARTUP
              && (error.code == ProtocolConstants.ErrorCode.PROTOCOL_ERROR
                  || error.code == ProtocolConstants.ErrorCode.SERVER_ERROR)
              && error.message.contains("Invalid or unsupported protocol version")) {
            fail(
                UnsupportedProtocolVersionException.forSingleAttempt(
                    channel.remoteAddress(), initialProtocolVersion));
          } else if (step == Step.SET_KEYSPACE
              && error.code == ProtocolConstants.ErrorCode.INVALID) {
            fail(new InvalidKeyspaceException(error.message));
          } else {
            failOnUnexpected(error);
          }
        } else {
          failOnUnexpected(response);
        }
      } catch (AuthenticationException e) {
        fail(e);
      } catch (Throwable t) {
        fail("Unexpected exception at step " + step, t);
      }
    }

    /**
     * Handles the result of {@code CLUSTER_NAME_QUERY}.
     *
     * <p>It checks the returned name against the expected one (if any), records it on the
     * channel when nothing was expected, and then moves on to the next step.
     */
    private void onClusterNameResponse(Rows rows) {
      List<ByteBuffer> row = rows.getData().poll();
      if (row == null) {
        // Without this check, an empty result would fail with a NullPointerException in
        // getString and be reported only as a generic "Unexpected exception".
        fail("Cluster name query on system.local returned no rows", null);
        return;
      }
      String actualClusterName = getString(row, 0);
      if (expectedClusterName != null && !expectedClusterName.equals(actualClusterName)) {
        fail(
            new ClusterNameMismatchException(
                channel.remoteAddress(), actualClusterName, expectedClusterName));
      } else {
        if (expectedClusterName == null) {
          // Store the actual name so that it can be retrieved from the factory
          channel.attr(DriverChannel.CLUSTER_NAME_KEY).set(actualClusterName);
        }
        if (options.keyspace != null) {
          step = Step.SET_KEYSPACE;
          send();
        } else {
          registerOrSucceed();
        }
      }
    }

    /** Sends REGISTER if event types were requested; otherwise completes the initialization. */
    private void registerOrSucceed() {
      if (!options.eventTypes.isEmpty()) {
        step = Step.REGISTER;
        send();
      } else {
        setConnectSuccess();
      }
    }

    /** Fails the initialization because one of the authenticator's futures completed with an error. */
    private void failOnAuthenticatorError(Throwable error) {
      fail(
          new AuthenticationException(
              channel.remoteAddress(), "authenticator threw an exception", error));
    }

    /**
     * Wraps {@code cause} in a {@link ConnectionInitException} carrying {@code message}. If
     * {@code message} is null, {@code cause} is reported as-is.
     */
    @Override
    void fail(String message, Throwable cause) {
      Throwable finalException =
          (message == null) ? cause : new ConnectionInitException(message, cause);
      setConnectFailure(finalException);
    }

    /**
     * Creates an authenticator from the configured auth provider.
     *
     * @throws AuthenticationException if the server requires authentication but no auth provider
     *     is configured.
     */
    private Authenticator buildAuthenticator(SocketAddress address, String authenticator) {
      return context
          .getAuthProvider()
          .map(p -> p.newAuthenticator(address, authenticator))
          .orElseThrow(
              () ->
                  new AuthenticationException(
                      address,
                      String.format(
                          "Host %s requires authentication (%s), but no authenticator configured",
                          address, authenticator)));
    }

    @Override
    public String toString() {
      return "init query " + step;
    }
  }

  /** Decodes column {@code i} of {@code row} as a text value. */
  private String getString(List<ByteBuffer> row, int i) {
    return TypeCodecs.TEXT.decode(row.get(i), DefaultProtocolVersion.DEFAULT);
  }
}