blob: dd51a6839c37037e877baf0230a20552c062842b [file] [log] [blame]
/**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/
package org.apache.bookkeeper.proto;
import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.bookkeeper.auth.BookieAuthProvider;
import org.apache.bookkeeper.auth.ClientAuthProvider;
import org.apache.bookkeeper.client.BKException;
import org.apache.bookkeeper.proto.BookkeeperInternalCallbacks.GenericCallback;
import org.apache.bookkeeper.proto.BookkeeperProtocol.AuthMessage;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.ChannelStateEvent;
import org.jboss.netty.channel.DefaultExceptionEvent;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelHandler;
import org.jboss.netty.channel.local.LocalChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class AuthHandler {
static final Logger LOG = LoggerFactory.getLogger(AuthHandler.class);
static class ServerSideHandler extends SimpleChannelHandler {
volatile boolean authenticated = false;
final BookieAuthProvider.Factory authProviderFactory;
BookieAuthProvider authProvider;
ServerSideHandler(BookieAuthProvider.Factory authProviderFactory) {
this.authProviderFactory = authProviderFactory;
authProvider = null;
}
@Override
public void channelOpen(ChannelHandlerContext ctx,
ChannelStateEvent e) throws Exception {
LOG.info("Channel open {}", ctx.getChannel());
SocketAddress remote = ctx.getChannel().getRemoteAddress();
if (remote instanceof InetSocketAddress) {
authProvider = authProviderFactory.newProvider((InetSocketAddress)remote,
new AuthHandshakeCompleteCallback());
} else if (ctx.getChannel() instanceof LocalChannel) {
authProvider = authProviderFactory.newProvider(new InetSocketAddress(Inet4Address.getLocalHost(), 0),
new AuthHandshakeCompleteCallback());
} else {
LOG.error("Unknown channel ({}) or socket type {} for {}",
new Object[] { ctx.getChannel(), remote != null ? remote.getClass() : null, remote });
}
super.channelOpen(ctx, e);
}
@Override
public void messageReceived(ChannelHandlerContext ctx,
MessageEvent e)
throws Exception {
if (authProvider == null) {
// close the channel, authProvider should only be
// null if the other end of line is an InetSocketAddress
// anything else is strange, and we don't want to deal
// with it
ctx.getChannel().close();
return;
}
Object event = e.getMessage();
if (authenticated) {
super.messageReceived(ctx, e);
} else if (event instanceof BookieProtocol.AuthRequest) { // pre-PB-client
BookieProtocol.AuthRequest req = (BookieProtocol.AuthRequest)event;
assert (req.getOpCode() == BookieProtocol.AUTH);
if (checkAuthPlugin(req.getAuthMessage(), ctx.getChannel())) {
authProvider.process(req.getAuthMessage(),
new AuthResponseCallbackLegacy(req, ctx.getChannel()));
} else {
ctx.getChannel().close();
}
} else if (event instanceof BookieProtocol.Request) {
BookieProtocol.Request req = (BookieProtocol.Request)event;
if (req.getOpCode() == BookieProtocol.ADDENTRY) {
ctx.getChannel().write(
new BookieProtocol.AddResponse(
req.getProtocolVersion(), BookieProtocol.EUA,
req.getLedgerId(), req.getEntryId()));
} else if (req.getOpCode() == BookieProtocol.READENTRY) {
ctx.getChannel().write(
new BookieProtocol.ReadResponse(
req.getProtocolVersion(), BookieProtocol.EUA,
req.getLedgerId(), req.getEntryId()));
} else {
ctx.getChannel().close();
}
} else if (event instanceof BookkeeperProtocol.Request) { // post-PB-client
BookkeeperProtocol.Request req = (BookkeeperProtocol.Request)event;
if (req.getHeader().getOperation() == BookkeeperProtocol.OperationType.AUTH
&& req.hasAuthRequest()
&& checkAuthPlugin(req.getAuthRequest(), ctx.getChannel())) {
authProvider.process(req.getAuthRequest(),
new AuthResponseCallback(req, ctx.getChannel()));
} else {
BookkeeperProtocol.Response.Builder builder
= BookkeeperProtocol.Response.newBuilder()
.setHeader(req.getHeader())
.setStatus(BookkeeperProtocol.StatusCode.EUA);
ctx.getChannel().write(builder.build());
}
} else {
// close the channel, junk coming over it
ctx.getChannel().close();
}
}
private boolean checkAuthPlugin(AuthMessage am, final Channel src) {
if (!am.hasAuthPluginName()
|| !am.getAuthPluginName().equals(authProviderFactory.getPluginName())) {
LOG.error("Received message from incompatible auth plugin. Local = {},"
+ " Remote = {}, Channel = {}",
authProviderFactory.getPluginName(), am.getAuthPluginName());
return false;
}
return true;
}
static class AuthResponseCallbackLegacy implements GenericCallback<AuthMessage> {
final BookieProtocol.AuthRequest req;
final Channel channel;
AuthResponseCallbackLegacy(BookieProtocol.AuthRequest req, Channel channel) {
this.req = req;
this.channel = channel;
}
public void operationComplete(int rc, AuthMessage newam) {
if (rc != BKException.Code.OK) {
LOG.error("Error processing auth message, closing connection");
channel.close();
return;
}
channel.write(new BookieProtocol.AuthResponse(req.getProtocolVersion(),
newam));
}
}
static class AuthResponseCallback implements GenericCallback<AuthMessage> {
final BookkeeperProtocol.Request req;
final Channel channel;
AuthResponseCallback(BookkeeperProtocol.Request req, Channel channel) {
this.req = req;
this.channel = channel;
}
public void operationComplete(int rc, AuthMessage newam) {
BookkeeperProtocol.Response.Builder builder
= BookkeeperProtocol.Response.newBuilder()
.setHeader(req.getHeader());
if (rc != BKException.Code.OK) {
LOG.error("Error processing auth message, closing connection");
builder.setStatus(BookkeeperProtocol.StatusCode.EUA);
channel.write(builder.build());
channel.close();
return;
} else {
builder.setStatus(BookkeeperProtocol.StatusCode.EOK)
.setAuthResponse(newam);
channel.write(builder.build());
}
}
}
class AuthHandshakeCompleteCallback implements GenericCallback<Void> {
@Override
public void operationComplete(int rc, Void v) {
if (rc == BKException.Code.OK) {
authenticated = true;
} else {
LOG.debug("Authentication failed on server side");
}
}
}
}
static class ClientSideHandler extends SimpleChannelHandler {
volatile boolean authenticated = false;
final ClientAuthProvider.Factory authProviderFactory;
ClientAuthProvider authProvider;
AtomicLong transactionIdGenerator;
Queue<MessageEvent> waitingForAuth = new ConcurrentLinkedQueue<MessageEvent>();
ClientSideHandler(ClientAuthProvider.Factory authProviderFactory,
AtomicLong transactionIdGenerator) {
this.authProviderFactory = authProviderFactory;
this.transactionIdGenerator = transactionIdGenerator;
authProvider = null;
}
@Override
public void channelConnected(ChannelHandlerContext ctx,
ChannelStateEvent e)
throws Exception {
SocketAddress remote = ctx.getChannel().getRemoteAddress();
if (remote instanceof InetSocketAddress) {
authProvider = authProviderFactory.newProvider((InetSocketAddress)remote,
new AuthHandshakeCompleteCallback(ctx));
} else if (ctx.getChannel() instanceof LocalChannel) {
authProvider = authProviderFactory.newProvider(new InetSocketAddress(Inet4Address.getLocalHost(), 0),
new AuthHandshakeCompleteCallback(ctx));
} else {
LOG.error("Unknown channel ({}) or socket type {} for {}",
new Object[] { ctx.getChannel(), remote != null ? remote.getClass() : null, remote });
}
if (authProvider != null) {
authProvider.init(new AuthRequestCallback(ctx));
}
super.channelConnected(ctx, e);
}
@Override
public void messageReceived(ChannelHandlerContext ctx,
MessageEvent e)
throws Exception {
assert (authProvider != null);
Object event = e.getMessage();
if (authenticated) {
super.messageReceived(ctx, e);
} else if (event instanceof BookkeeperProtocol.Response) {
BookkeeperProtocol.Response resp = (BookkeeperProtocol.Response)event;
if (resp.getHeader().getOperation() == BookkeeperProtocol.OperationType.AUTH) {
if (resp.getStatus() != BookkeeperProtocol.StatusCode.EOK) {
authenticationError(ctx, resp.getStatus().getNumber());
} else {
assert (resp.hasAuthResponse());
BookkeeperProtocol.AuthMessage am = resp.getAuthResponse();
authProvider.process(am, new AuthRequestCallback(ctx));
}
} else {
// else just drop the message,
// we're not authenticated so nothing should be coming through
}
}
}
@Override
public void writeRequested(ChannelHandlerContext ctx,
MessageEvent e)
throws Exception {
synchronized (this) {
if (authenticated) {
super.writeRequested(ctx, e);
} else if (e.getMessage() instanceof BookkeeperProtocol.Request) {
// let auth messages through, queue the rest
BookkeeperProtocol.Request req = (BookkeeperProtocol.Request)e.getMessage();
if (req.getHeader().getOperation()
== BookkeeperProtocol.OperationType.AUTH) {
super.writeRequested(ctx, e);
} else {
waitingForAuth.add(e);
}
} // else just drop
}
}
long newTxnId() {
return transactionIdGenerator.incrementAndGet();
}
void authenticationError(ChannelHandlerContext ctx, int errorCode) {
LOG.error("Error processing auth message, erroring connection {}", errorCode);
ctx.sendUpstream(new DefaultExceptionEvent(ctx.getChannel(),
new AuthenticationException(
"Auth failed with error " + errorCode)));
}
class AuthRequestCallback implements GenericCallback<AuthMessage> {
Channel channel;
ChannelHandlerContext ctx;
AuthRequestCallback(ChannelHandlerContext ctx) {
this.channel = ctx.getChannel();
this.ctx = ctx;
}
public void operationComplete(int rc, AuthMessage newam) {
if (rc != BKException.Code.OK) {
authenticationError(ctx, rc);
return;
}
BookkeeperProtocol.BKPacketHeader header
= BookkeeperProtocol.BKPacketHeader.newBuilder()
.setVersion(BookkeeperProtocol.ProtocolVersion.VERSION_THREE)
.setOperation(BookkeeperProtocol.OperationType.AUTH)
.setTxnId(newTxnId()).build();
BookkeeperProtocol.Request.Builder builder
= BookkeeperProtocol.Request.newBuilder()
.setHeader(header)
.setAuthRequest(newam);
channel.write(builder.build());
}
}
class AuthHandshakeCompleteCallback implements GenericCallback<Void> {
ChannelHandlerContext ctx;
AuthHandshakeCompleteCallback(ChannelHandlerContext ctx) {
this.ctx = ctx;
}
@Override
public void operationComplete(int rc, Void v) {
if (rc == BKException.Code.OK) {
synchronized (this) {
authenticated = true;
MessageEvent e = waitingForAuth.poll();
while (e != null) {
ctx.sendDownstream(e);
e = waitingForAuth.poll();
}
}
} else {
authenticationError(ctx, rc);
LOG.debug("Authentication failed on server side");
}
}
}
}
static class AuthenticationException extends IOException {
AuthenticationException(String reason) {
super(reason);
}
}
}