blob: 14694fbc90d2776714055691838e0c8f815e73aa [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.livy.rsc.rpc;
import java.io.Closeable;
import java.io.IOException;
import java.net.BindException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.SocketException;
import java.security.SecureRandom;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.ScheduledFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.livy.rsc.RSCConf;
import org.apache.livy.rsc.Utils;
import static org.apache.livy.rsc.RSCConf.Entry.*;
/**
 * An RPC server. The server matches remote clients based on a secret that is generated on
 * the server - the secret needs to be given to the client through some other mechanism for
 * this to work.
 *
 * <p>The server binds to the first free port of the range configured by
 * {@code LAUNCHER_PORT_RANGE} (format {@code "start~end"}, or a single port number) and
 * authenticates incoming connections with SASL, using the client ID and secret registered
 * through {@link #registerClient(String, String, ClientCallback)}.
 */
public class RpcServer implements Closeable {

  private static final Logger LOG = LoggerFactory.getLogger(RpcServer.class);
  private static final SecureRandom RND = new SecureRandom();

  /** Separator between the first and last port of {@code LAUNCHER_PORT_RANGE}. */
  private static final String PORT_DELIMITER = "~";

  private final String address;
  private Channel channel;
  private final EventLoopGroup group;
  private final int port;
  private final ConcurrentMap<String, ClientInfo> pendingClients;
  private final RSCConf config;
  private final String portRange;

  /**
   * Creates the RPC server and binds it to the first available port of the configured range.
   *
   * <p>If construction fails, the event loop group (and any channel already bound) is shut
   * down before the exception propagates.
   *
   * @param lconf The default RSC configs
   * @throws IOException if no port in the configured range can be bound, or if the local
   *     address cannot be determined
   * @throws InterruptedException if interrupted while waiting for a bind to complete
   */
  public RpcServer(RSCConf lconf) throws IOException, InterruptedException {
    this.config = lconf;
    this.portRange = config.get(LAUNCHER_PORT_RANGE);
    // Created before binding so that the map exists as soon as the server can accept
    // connections.
    this.pendingClients = new ConcurrentHashMap<>();
    this.group = new NioEventLoopGroup(
        this.config.getInt(RPC_MAX_THREADS),
        Utils.newDaemonThreadFactory("RPC-Handler-%d"));

    boolean initialized = false;
    try {
      this.channel = bindToFirstAvailablePort();
      this.port = ((InetSocketAddress) channel.localAddress()).getPort();
      LOG.info("Connected to the port {}", this.port);

      String configuredAddress = config.get(RPC_SERVER_ADDRESS);
      this.address = configuredAddress != null ? configuredAddress : config.findLocalAddress();
      initialized = true;
    } finally {
      if (!initialized) {
        // Do not leak the bound socket or the event loop threads when construction fails.
        if (channel != null) {
          channel.close();
        }
        group.shutdownGracefully();
      }
    }
  }

  /**
   * Tries each port of the configured range in ascending order and returns the channel bound
   * to the first port that could be bound.
   *
   * @return the bound server channel
   * @throws IOException if every port in the range fails to bind; the last bind failure, if
   *     any, is attached as the cause
   * @throws InterruptedException if interrupted while waiting for a bind to complete
   */
  private Channel bindToFirstAvailablePort() throws IOException, InterruptedException {
    int[] range = getPortNumberAndRange();
    int startPort = range[0];
    int endPort = range[1];
    SocketException lastFailure = null;
    for (int candidate = startPort; candidate <= endPort; candidate++) {
      try {
        return getChannel(candidate);
      } catch (SocketException e) {
        // BindException is a SocketException; remember it and move on to the next port.
        lastFailure = e;
        LOG.debug("RPC not able to connect port {} {}", candidate, e.getMessage());
      }
    }
    throw new IOException("Unable to connect to provided ports " + this.portRange, lastFailure);
  }

  /**
   * Parses the configured port range into its first and last port.
   *
   * <p>Accepts {@code "start~end"} or a single port number, which is treated as a one-port
   * range. Whitespace around each number is ignored.
   *
   * @return a two-element array {@code {startPort, endPort}}
   * @throws ArrayIndexOutOfBoundsException if the range contains no port at all
   *     (e.g. {@code "~"})
   * @throws NumberFormatException if a port is not a valid integer
   */
  private int[] getPortNumberAndRange() throws ArrayIndexOutOfBoundsException,
      NumberFormatException {
    String[] split = this.portRange.split(PORT_DELIMITER);
    try {
      int startPort = Integer.parseInt(split[0].trim());
      int endPort = split.length > 1 ? Integer.parseInt(split[1].trim()) : startPort;
      return new int[] { startPort, endPort };
    } catch (ArrayIndexOutOfBoundsException e) {
      LOG.error("Port Range format is not correct {}", this.portRange);
      throw e;
    } catch (NumberFormatException e) {
      LOG.error("Port are not in numeric format {}", this.portRange);
      throw e;
    }
  }

  /**
   * Binds a server channel to the given port and blocks until the bind completes.
   *
   * <p>Each accepted connection gets a SASL server handler and an {@link Rpc} instance. The
   * connection is closed if the handshake has not finished within
   * {@code RPC_CLIENT_HANDSHAKE_TIMEOUT}.
   *
   * @param portNumber the port to bind to
   * @return the bound channel
   * @throws BindException if the port cannot be bound
   * @throws InterruptedException if interrupted while waiting for the bind
   */
  private Channel getChannel(int portNumber) throws BindException, InterruptedException {
    return new ServerBootstrap()
        .group(group)
        .channel(NioServerSocketChannel.class)
        .childHandler(new ChannelInitializer<SocketChannel>() {
          @Override
          public void initChannel(SocketChannel ch) throws Exception {
            SaslServerHandler saslHandler = new SaslServerHandler(config);
            final Rpc newRpc = Rpc.createServer(saslHandler, config, ch, group);
            saslHandler.rpc = newRpc;

            Runnable cancelTask = new Runnable() {
              @Override
              public void run() {
                LOG.warn("Timed out waiting for hello from client.");
                newRpc.close();
              }
            };
            saslHandler.cancelTask = group.schedule(cancelTask,
                config.getTimeAsMs(RPC_CLIENT_HANDSHAKE_TIMEOUT),
                TimeUnit.MILLISECONDS);
          }
        })
        .option(ChannelOption.SO_BACKLOG, 1)
        .option(ChannelOption.SO_REUSEADDR, true)
        .childOption(ChannelOption.SO_KEEPALIVE, true)
        .bind(portNumber)
        .sync()
        .channel();
  }

  /**
   * Tells the RPC server to expect connections from clients.
   *
   * @param clientId An identifier for the client. Must be unique.
   * @param secret The secret the client will send to the server to identify itself.
   * @param callback The callback for when a new client successfully connects with the given
   *                 credentials.
   * @throws IllegalStateException if {@code clientId} is already registered
   */
  public void registerClient(String clientId, String secret, ClientCallback callback) {
    final ClientInfo client = new ClientInfo(clientId, secret, callback);
    if (pendingClients.putIfAbsent(clientId, client) != null) {
      throw new IllegalStateException(
          String.format("Client '%s' already registered.", clientId));
    }
  }

  /**
   * Stop waiting for connections for a given client ID.
   *
   * @param clientId The client ID to forget.
   */
  public void unregisterClient(String clientId) {
    pendingClients.remove(clientId);
  }

  /**
   * Creates a secret for identifying a client connection.
   *
   * <p>The secret is {@code RPC_SECRET_RANDOM_BITS / 8} bytes from a {@link SecureRandom},
   * encoded as lowercase hexadecimal with exactly two characters per byte.
   *
   * @return the secret
   */
  public String createSecret() {
    byte[] secret = new byte[config.getInt(RPC_SECRET_RANDOM_BITS) / 8];
    RND.nextBytes(secret);
    StringBuilder sb = new StringBuilder(secret.length * 2);
    for (byte b : secret) {
      // Split into unsigned nibbles. Integer.toHexString(b) would sign-extend negative bytes
      // into 8 hex digits and leave 0xa-0xf unpadded.
      sb.append(Character.forDigit((b >> 4) & 0xF, 16));
      sb.append(Character.forDigit(b & 0xF, 16));
    }
    return sb.toString();
  }

  public String getAddress() {
    return address;
  }

  public int getPort() {
    return port;
  }

  public EventLoopGroup getEventLoopGroup() {
    return group;
  }

  /**
   * Closes the server channel, forgets all pending clients and shuts down the event loop
   * group. The group is shut down even if closing the channel throws.
   */
  @Override
  public void close() {
    try {
      channel.close();
      pendingClients.clear();
    } finally {
      group.shutdownGracefully();
    }
  }

  /**
   * A callback that can be registered to be notified when new clients are created and
   * successfully authenticate against the server.
   */
  public interface ClientCallback {

    /**
     * Called when a new client successfully connects.
     *
     * @param client The RPC instance for the new client.
     * @return The RpcDispatcher to be used for the client.
     */
    RpcDispatcher onNewClient(Rpc client);

    /**
     * Called when a new client successfully completed SASL authentication.
     *
     * @param client The RPC instance for the new client.
     */
    void onSaslComplete(Rpc client);

  }

  /**
   * Server side of the SASL handshake for a single connection.
   *
   * <p>The first message must carry a client ID registered through
   * {@link RpcServer#registerClient}. Its secret is then supplied as the SASL password.
   */
  private class SaslServerHandler extends SaslHandler implements CallbackHandler {

    private final SaslServer server;
    private Rpc rpc;
    /** Handshake timeout task; assigned right after the handler is created. */
    private ScheduledFuture<?> cancelTask;
    private String clientId;
    private ClientInfo client;

    SaslServerHandler(RSCConf config) throws IOException {
      super(config);
      this.server = Sasl.createSaslServer(config.get(SASL_MECHANISMS), Rpc.SASL_PROTOCOL,
          Rpc.SASL_REALM, config.getSaslOptions(), this);
    }

    @Override
    protected boolean isComplete() {
      return server.isComplete();
    }

    @Override
    protected String getNegotiatedProperty(String name) {
      return (String) server.getNegotiatedProperty(name);
    }

    @Override
    protected Rpc.SaslMessage update(Rpc.SaslMessage challenge) throws IOException {
      if (clientId == null) {
        // First message of the handshake: identify which registered client is connecting.
        Utils.checkArgument(challenge.clientId != null,
            "Missing client ID in SASL handshake.");
        clientId = challenge.clientId;
        client = pendingClients.get(clientId);
        Utils.checkArgument(client != null,
            "Unexpected client ID '%s' in SASL handshake.", clientId);
      }

      return new Rpc.SaslMessage(server.evaluateResponse(challenge.payload));
    }

    @Override
    public byte[] wrap(byte[] data, int offset, int len) throws IOException {
      return server.wrap(data, offset, len);
    }

    @Override
    public byte[] unwrap(byte[] data, int offset, int len) throws IOException {
      return server.unwrap(data, offset, len);
    }

    @Override
    public void dispose() throws IOException {
      if (!server.isComplete()) {
        onError(new SaslException("Server closed before SASL negotiation finished."));
      }
      server.dispose();
    }

    @Override
    protected void onComplete() throws Exception {
      cancelHandshakeTimeout();

      RpcDispatcher dispatcher = null;
      try {
        dispatcher = client.callback.onNewClient(rpc);
      } catch (Exception e) {
        LOG.warn("Client callback threw an exception.", e);
      }

      if (dispatcher != null) {
        rpc.setDispatcher(dispatcher);
      }

      client.callback.onSaslComplete(rpc);
    }

    @Override
    protected void onError(Throwable error) {
      cancelHandshakeTimeout();
    }

    /** Cancels the handshake timeout task, if it has been scheduled. */
    private void cancelHandshakeTimeout() {
      if (cancelTask != null) {
        cancelTask.cancel(true);
      }
    }

    @Override
    public void handle(Callback[] callbacks) {
      Utils.checkState(client != null, "Handshake not initialized yet.");
      for (Callback cb : callbacks) {
        if (cb instanceof NameCallback) {
          ((NameCallback) cb).setName(clientId);
        } else if (cb instanceof PasswordCallback) {
          ((PasswordCallback) cb).setPassword(client.secret.toCharArray());
        } else if (cb instanceof AuthorizeCallback) {
          ((AuthorizeCallback) cb).setAuthorized(true);
        } else if (cb instanceof RealmCallback) {
          RealmCallback rb = (RealmCallback) cb;
          rb.setText(rb.getDefaultText());
        }
      }
    }

  }

  /** Credentials and callback registered for a client that is expected to connect. */
  private static class ClientInfo {

    final String id;
    final String secret;
    final ClientCallback callback;

    private ClientInfo(String id, String secret, ClientCallback callback) {
      this.id = id;
      this.secret = secret;
      this.callback = callback;
    }

  }

}