| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.zookeeper.server; |
| |
| import static org.jboss.netty.buffer.ChannelBuffers.dynamicBuffer; |
| |
| import java.io.IOException; |
| import java.net.InetAddress; |
| import java.net.InetSocketAddress; |
| import java.security.KeyManagementException; |
| import java.security.NoSuchAlgorithmException; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.concurrent.Executors; |
| |
| import javax.net.ssl.SSLContext; |
| import javax.net.ssl.SSLEngine; |
| import javax.net.ssl.SSLPeerUnverifiedException; |
| import javax.net.ssl.SSLSession; |
| import javax.net.ssl.X509KeyManager; |
| import javax.net.ssl.X509TrustManager; |
| |
| import org.apache.zookeeper.KeeperException; |
| import org.apache.zookeeper.common.ZKConfig; |
| import org.apache.zookeeper.common.X509Exception; |
| import org.apache.zookeeper.common.X509Exception.SSLContextException; |
| import org.apache.zookeeper.common.X509Util; |
| import org.apache.zookeeper.server.auth.ProviderRegistry; |
| import org.apache.zookeeper.server.auth.X509AuthenticationProvider; |
| import org.jboss.netty.bootstrap.ServerBootstrap; |
| import org.jboss.netty.buffer.ChannelBuffer; |
| import org.jboss.netty.buffer.ChannelBuffers; |
| import org.jboss.netty.channel.Channel; |
| import org.jboss.netty.channel.ChannelFuture; |
| import org.jboss.netty.channel.ChannelFutureListener; |
| import org.jboss.netty.channel.ChannelHandler.Sharable; |
| import org.jboss.netty.channel.ChannelHandlerContext; |
| import org.jboss.netty.channel.ChannelPipeline; |
| import org.jboss.netty.channel.ChannelPipelineFactory; |
| import org.jboss.netty.channel.ChannelStateEvent; |
| import org.jboss.netty.channel.Channels; |
| import org.jboss.netty.channel.ExceptionEvent; |
| import org.jboss.netty.channel.MessageEvent; |
| import org.jboss.netty.channel.SimpleChannelHandler; |
| import org.jboss.netty.channel.WriteCompletionEvent; |
| import org.jboss.netty.channel.group.ChannelGroup; |
| import org.jboss.netty.channel.group.DefaultChannelGroup; |
| import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory; |
| import org.jboss.netty.handler.ssl.SslHandler; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| public class NettyServerCnxnFactory extends ServerCnxnFactory { |
| private static final Logger LOG = LoggerFactory.getLogger(NettyServerCnxnFactory.class); |
| |
| ServerBootstrap bootstrap; |
| Channel parentChannel; |
| ChannelGroup allChannels = new DefaultChannelGroup("zkServerCnxns"); |
| HashMap<InetAddress, Set<NettyServerCnxn>> ipMap = |
| new HashMap<InetAddress, Set<NettyServerCnxn>>( ); |
| InetSocketAddress localAddress; |
| int maxClientCnxns = 60; |
| |
| /** |
| * This is an inner class since we need to extend SimpleChannelHandler, but |
| * NettyServerCnxnFactory already extends ServerCnxnFactory. By making it inner |
| * this class gets access to the member variables and methods. |
| */ |
| @Sharable |
| class CnxnChannelHandler extends SimpleChannelHandler { |
| |
| @Override |
| public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e) |
| throws Exception |
| { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("Channel closed " + e); |
| } |
| allChannels.remove(ctx.getChannel()); |
| } |
| |
| @Override |
| public void channelConnected(ChannelHandlerContext ctx, |
| ChannelStateEvent e) throws Exception |
| { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("Channel connected " + e); |
| } |
| |
| NettyServerCnxn cnxn = new NettyServerCnxn(ctx.getChannel(), |
| zkServer, NettyServerCnxnFactory.this); |
| ctx.setAttachment(cnxn); |
| |
| if (secure) { |
| SslHandler sslHandler = ctx.getPipeline().get(SslHandler.class); |
| ChannelFuture handshakeFuture = sslHandler.handshake(); |
| handshakeFuture.addListener(new CertificateVerifier(sslHandler, cnxn)); |
| } else { |
| allChannels.add(ctx.getChannel()); |
| addCnxn(cnxn); |
| } |
| } |
| |
| @Override |
| public void channelDisconnected(ChannelHandlerContext ctx, |
| ChannelStateEvent e) throws Exception |
| { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("Channel disconnected " + e); |
| } |
| NettyServerCnxn cnxn = (NettyServerCnxn) ctx.getAttachment(); |
| if (cnxn != null) { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("Channel disconnect caused close " + e); |
| } |
| cnxn.close(); |
| } |
| } |
| |
| @Override |
| public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) |
| throws Exception |
| { |
| LOG.warn("Exception caught " + e, e.getCause()); |
| NettyServerCnxn cnxn = (NettyServerCnxn) ctx.getAttachment(); |
| if (cnxn != null) { |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("Closing " + cnxn); |
| cnxn.close(); |
| } |
| } |
| } |
| |
| @Override |
| public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) |
| throws Exception |
| { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("message received called " + e.getMessage()); |
| } |
| try { |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("New message " + e.toString() |
| + " from " + ctx.getChannel()); |
| } |
| NettyServerCnxn cnxn = (NettyServerCnxn)ctx.getAttachment(); |
| synchronized(cnxn) { |
| processMessage(e, cnxn); |
| } |
| } catch(Exception ex) { |
| LOG.error("Unexpected exception in receive", ex); |
| throw ex; |
| } |
| } |
| |
| private void processMessage(MessageEvent e, NettyServerCnxn cnxn) { |
| if (LOG.isDebugEnabled()) { |
| LOG.debug(Long.toHexString(cnxn.sessionId) + " queuedBuffer: " |
| + cnxn.queuedBuffer); |
| } |
| |
| if (e instanceof NettyServerCnxn.ResumeMessageEvent) { |
| LOG.debug("Received ResumeMessageEvent"); |
| if (cnxn.queuedBuffer != null) { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("processing queue " |
| + Long.toHexString(cnxn.sessionId) |
| + " queuedBuffer 0x" |
| + ChannelBuffers.hexDump(cnxn.queuedBuffer)); |
| } |
| cnxn.receiveMessage(cnxn.queuedBuffer); |
| if (!cnxn.queuedBuffer.readable()) { |
| LOG.debug("Processed queue - no bytes remaining"); |
| cnxn.queuedBuffer = null; |
| } else { |
| LOG.debug("Processed queue - bytes remaining"); |
| } |
| } else { |
| LOG.debug("queue empty"); |
| } |
| cnxn.channel.setReadable(true); |
| } else { |
| ChannelBuffer buf = (ChannelBuffer)e.getMessage(); |
| if (LOG.isTraceEnabled()) { |
| LOG.trace(Long.toHexString(cnxn.sessionId) |
| + " buf 0x" |
| + ChannelBuffers.hexDump(buf)); |
| } |
| |
| if (cnxn.throttled) { |
| LOG.debug("Received message while throttled"); |
| // we are throttled, so we need to queue |
| if (cnxn.queuedBuffer == null) { |
| LOG.debug("allocating queue"); |
| cnxn.queuedBuffer = dynamicBuffer(buf.readableBytes()); |
| } |
| cnxn.queuedBuffer.writeBytes(buf); |
| if (LOG.isTraceEnabled()) { |
| LOG.trace(Long.toHexString(cnxn.sessionId) |
| + " queuedBuffer 0x" |
| + ChannelBuffers.hexDump(cnxn.queuedBuffer)); |
| } |
| } else { |
| LOG.debug("not throttled"); |
| if (cnxn.queuedBuffer != null) { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace(Long.toHexString(cnxn.sessionId) |
| + " queuedBuffer 0x" |
| + ChannelBuffers.hexDump(cnxn.queuedBuffer)); |
| } |
| cnxn.queuedBuffer.writeBytes(buf); |
| if (LOG.isTraceEnabled()) { |
| LOG.trace(Long.toHexString(cnxn.sessionId) |
| + " queuedBuffer 0x" |
| + ChannelBuffers.hexDump(cnxn.queuedBuffer)); |
| } |
| |
| cnxn.receiveMessage(cnxn.queuedBuffer); |
| if (!cnxn.queuedBuffer.readable()) { |
| LOG.debug("Processed queue - no bytes remaining"); |
| cnxn.queuedBuffer = null; |
| } else { |
| LOG.debug("Processed queue - bytes remaining"); |
| } |
| } else { |
| cnxn.receiveMessage(buf); |
| if (buf.readable()) { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("Before copy " + buf); |
| } |
| cnxn.queuedBuffer = dynamicBuffer(buf.readableBytes()); |
| cnxn.queuedBuffer.writeBytes(buf); |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("Copy is " + cnxn.queuedBuffer); |
| LOG.trace(Long.toHexString(cnxn.sessionId) |
| + " queuedBuffer 0x" |
| + ChannelBuffers.hexDump(cnxn.queuedBuffer)); |
| } |
| } |
| } |
| } |
| } |
| } |
| |
| @Override |
| public void writeComplete(ChannelHandlerContext ctx, |
| WriteCompletionEvent e) throws Exception |
| { |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("write complete " + e); |
| } |
| } |
| |
| private final class CertificateVerifier |
| implements ChannelFutureListener { |
| private final SslHandler sslHandler; |
| private final NettyServerCnxn cnxn; |
| |
| CertificateVerifier(SslHandler sslHandler, NettyServerCnxn cnxn) { |
| this.sslHandler = sslHandler; |
| this.cnxn = cnxn; |
| } |
| |
| /** |
| * Only allow the connection to stay open if certificate passes auth |
| */ |
| public void operationComplete(ChannelFuture future) |
| throws SSLPeerUnverifiedException { |
| if (future.isSuccess()) { |
| LOG.debug("Successful handshake with session 0x{}", |
| Long.toHexString(cnxn.sessionId)); |
| SSLEngine eng = sslHandler.getEngine(); |
| SSLSession session = eng.getSession(); |
| cnxn.setClientCertificateChain(session.getPeerCertificates()); |
| |
| String authProviderProp |
| = System.getProperty(ZKConfig.SSL_AUTHPROVIDER, "x509"); |
| |
| X509AuthenticationProvider authProvider = |
| (X509AuthenticationProvider) |
| ProviderRegistry.getProvider(authProviderProp); |
| |
| if (authProvider == null) { |
| LOG.error("Auth provider not found: {}", authProviderProp); |
| cnxn.close(); |
| return; |
| } |
| |
| if (KeeperException.Code.OK != |
| authProvider.handleAuthentication(cnxn, null)) { |
| LOG.error("Authentication failed for session 0x{}", |
| Long.toHexString(cnxn.sessionId)); |
| cnxn.close(); |
| return; |
| } |
| |
| allChannels.add(future.getChannel()); |
| addCnxn(cnxn); |
| } else { |
| LOG.error("Unsuccessful handshake with session 0x{}", |
| Long.toHexString(cnxn.sessionId)); |
| cnxn.close(); |
| } |
| } |
| } |
| } |
| |
| CnxnChannelHandler channelHandler = new CnxnChannelHandler(); |
| |
| NettyServerCnxnFactory() { |
| bootstrap = new ServerBootstrap( |
| new NioServerSocketChannelFactory( |
| Executors.newCachedThreadPool(), |
| Executors.newCachedThreadPool())); |
| // parent channel |
| bootstrap.setOption("reuseAddress", true); |
| // child channels |
| bootstrap.setOption("child.tcpNoDelay", true); |
| /* set socket linger to off, so that socket close does not block */ |
| bootstrap.setOption("child.soLinger", -1); |
| bootstrap.setPipelineFactory(new ChannelPipelineFactory() { |
| @Override |
| public ChannelPipeline getPipeline() throws Exception { |
| ChannelPipeline p = Channels.pipeline(); |
| if (secure) { |
| initSSL(p); |
| } |
| p.addLast("servercnxnfactory", channelHandler); |
| |
| return p; |
| } |
| }); |
| } |
| |
| private synchronized void initSSL(ChannelPipeline p) |
| throws X509Exception, KeyManagementException, NoSuchAlgorithmException { |
| String authProviderProp = System.getProperty(ZKConfig.SSL_AUTHPROVIDER); |
| SSLContext sslContext; |
| if (authProviderProp == null) { |
| sslContext = X509Util.createSSLContext(); |
| } else { |
| sslContext = SSLContext.getInstance("TLSv1"); |
| X509AuthenticationProvider authProvider = |
| (X509AuthenticationProvider)ProviderRegistry.getProvider( |
| System.getProperty(ZKConfig.SSL_AUTHPROVIDER, |
| "x509")); |
| |
| if (authProvider == null) |
| { |
| LOG.error("Auth provider not found: {}", authProviderProp); |
| throw new SSLContextException( |
| "Could not create SSLContext with specified auth provider: " + |
| authProviderProp); |
| } |
| |
| sslContext.init(new X509KeyManager[] { authProvider.getKeyManager() }, |
| new X509TrustManager[] { authProvider.getTrustManager() }, |
| null); |
| } |
| |
| SSLEngine sslEngine = sslContext.createSSLEngine(); |
| sslEngine.setUseClientMode(false); |
| sslEngine.setNeedClientAuth(true); |
| |
| p.addLast("ssl", new SslHandler(sslEngine)); |
| LOG.info("SSL handler added for channel: {}", p.getChannel()); |
| } |
| |
| @Override |
| public void closeAll() { |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("closeAll()"); |
| } |
| // clear all the connections on which we are selecting |
| int length = cnxns.size(); |
| for (ServerCnxn cnxn : cnxns) { |
| try { |
| // This will remove the cnxn from cnxns |
| cnxn.close(); |
| } catch (Exception e) { |
| LOG.warn("Ignoring exception closing cnxn sessionid 0x" |
| + Long.toHexString(cnxn.getSessionId()), e); |
| } |
| } |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("allChannels size:" + allChannels.size() + " cnxns size:" |
| + length); |
| } |
| } |
| |
| @Override |
| public boolean closeSession(long sessionId) { |
| if (LOG.isDebugEnabled()) { |
| LOG.debug("closeSession sessionid:0x" + sessionId); |
| } |
| for (ServerCnxn cnxn : cnxns) { |
| if (cnxn.getSessionId() == sessionId) { |
| try { |
| cnxn.close(); |
| } catch (Exception e) { |
| LOG.warn("exception during session close", e); |
| } |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| @Override |
| public void configure(InetSocketAddress addr, int maxClientCnxns, boolean secure) |
| throws IOException |
| { |
| configureSaslLogin(); |
| localAddress = addr; |
| this.maxClientCnxns = maxClientCnxns; |
| this.secure = secure; |
| } |
| |
| /** {@inheritDoc} */ |
| public int getMaxClientCnxnsPerHost() { |
| return maxClientCnxns; |
| } |
| |
| /** {@inheritDoc} */ |
| public void setMaxClientCnxnsPerHost(int max) { |
| maxClientCnxns = max; |
| } |
| |
| @Override |
| public int getLocalPort() { |
| return localAddress.getPort(); |
| } |
| |
| boolean killed; |
| @Override |
| public void join() throws InterruptedException { |
| synchronized(this) { |
| while(!killed) { |
| wait(); |
| } |
| } |
| } |
| |
| @Override |
| public void shutdown() { |
| LOG.info("shutdown called " + localAddress); |
| if (login != null) { |
| login.shutdown(); |
| } |
| // null if factory never started |
| if (parentChannel != null) { |
| parentChannel.close().awaitUninterruptibly(); |
| closeAll(); |
| allChannels.close().awaitUninterruptibly(); |
| bootstrap.releaseExternalResources(); |
| } |
| |
| if (zkServer != null) { |
| zkServer.shutdown(); |
| } |
| synchronized(this) { |
| killed = true; |
| notifyAll(); |
| } |
| } |
| |
| @Override |
| public void start() { |
| LOG.info("binding to port " + localAddress); |
| parentChannel = bootstrap.bind(localAddress); |
| } |
| |
| public void reconfigure(InetSocketAddress addr) |
| { |
| Channel oldChannel = parentChannel; |
| LOG.info("binding to port " + addr); |
| parentChannel = bootstrap.bind(addr); |
| localAddress = addr; |
| oldChannel.close(); |
| } |
| |
| @Override |
| public void startup(ZooKeeperServer zks, boolean startServer) |
| throws IOException, InterruptedException { |
| start(); |
| setZooKeeperServer(zks); |
| if (startServer) { |
| zks.startdata(); |
| zks.startup(); |
| } |
| } |
| |
| @Override |
| public Iterable<ServerCnxn> getConnections() { |
| return cnxns; |
| } |
| |
| @Override |
| public InetSocketAddress getLocalAddress() { |
| return localAddress; |
| } |
| |
| private void addCnxn(NettyServerCnxn cnxn) { |
| cnxns.add(cnxn); |
| synchronized (ipMap){ |
| InetAddress addr = |
| ((InetSocketAddress)cnxn.channel.getRemoteAddress()) |
| .getAddress(); |
| Set<NettyServerCnxn> s = ipMap.get(addr); |
| if (s == null) { |
| s = new HashSet<NettyServerCnxn>(); |
| } |
| s.add(cnxn); |
| ipMap.put(addr,s); |
| } |
| } |
| |
| @Override |
| public void resetAllConnectionStats() { |
| // No need to synchronize since cnxns is backed by a ConcurrentHashMap |
| for(ServerCnxn c : cnxns){ |
| c.resetStats(); |
| } |
| } |
| |
| @Override |
| public Iterable<Map<String, Object>> getAllConnectionInfo(boolean brief) { |
| HashSet<Map<String,Object>> info = new HashSet<Map<String,Object>>(); |
| // No need to synchronize since cnxns is backed by a ConcurrentHashMap |
| for (ServerCnxn c : cnxns) { |
| info.add(c.getConnectionInfo(brief)); |
| } |
| return info; |
| } |
| |
| } |