// blob: 7e8320a21ca701d1f0fc646e46229bef1ff75772 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.cloud;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.invoke.MethodHandles;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.net.URI;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSocketFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A small TCP (optionally SSL) proxy that listens on a local port and forwards every accepted
 * connection to a target host/port. Traffic can be paused, resumed, half-closed or cut entirely,
 * which makes it useful for blocking traffic on a specified port.
 *
 * <p>Kindly borrowed the idea and base implementation from the ActiveMQ project.
 *
 * <p>Typical life cycle: construct (binds the listening socket), {@link #open(URI)} (starts
 * accepting), then any mix of {@link #pause()}/{@link #goOn()}, {@link #halfClose()},
 * {@link #close()} and {@link #reopen()}.
 */
public class SocketProxy {

  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

  /** Timeout, in milliseconds, applied to {@code accept()} on the listening socket. */
  public static final int ACCEPT_TIMEOUT_MILLIS = 100;

  // should be as large as the HttpShardHandlerFactory socket timeout ... or larger?
  /** Read timeout, in milliseconds, applied to the source socket of every {@link Bridge.Pump}. */
  public static final int PUMP_SOCKET_TIMEOUT_MS = 100 * 1000;

  private URI proxyUrl;
  private URI target;
  private Acceptor acceptor;
  private ServerSocket serverSocket;

  // Replaced on every (re)open and read by waitUntilClosed() from other threads, hence volatile.
  private volatile CountDownLatch closed = new CountDownLatch(1);

  /**
   * Currently registered bridges. All access inside this class is guarded by synchronizing on the
   * list itself; the field stays public and non-final for backward compatibility.
   */
  public List<Bridge> connections = new LinkedList<Bridge>();

  private final int listenPort;
  private int receiveBufferSize = -1;
  private boolean pauseAtStart = false;
  private int acceptBacklog = 50;
  private final boolean usesSSL;

  /** Creates a plain (non-SSL) proxy bound to an ephemeral port. */
  public SocketProxy() throws Exception {
    this(0, false);
  }

  /**
   * Creates a proxy bound to an ephemeral port.
   *
   * @param useSSL whether listening and outbound sockets come from the default SSL factories
   */
  public SocketProxy(boolean useSSL) throws Exception {
    this(0, useSSL);
  }

  /**
   * Creates a proxy and binds its listening socket. No connection is accepted until
   * {@link #open(URI)} is called.
   *
   * @param port local port to listen on; {@code 0} lets the system pick a free port
   * @param useSSL whether listening and outbound sockets come from the default SSL factories
   * @throws Exception if the listening socket cannot be created or bound
   */
  public SocketProxy(int port, boolean useSSL) throws Exception {
    this.usesSSL = useSSL;
    this.serverSocket = bindServerSocket(port);
    this.listenPort = serverSocket.getLocalPort();
  }

  /**
   * Sets the forwarding target, derives the proxy URL from it and starts the acceptor thread.
   *
   * @param uri target whose host and port receive the forwarded traffic
   */
  public void open(URI uri) throws Exception {
    target = uri;
    proxyUrl = urlFromSocket(target, serverSocket);
    doOpen();
  }

  @Override
  public String toString() {
    return "SocketProxy: port=" + listenPort + "; target=" + target;
  }

  /**
   * Sets the receive buffer size applied to sockets created after this call (a re-bound listening
   * socket, accepted sockets and outbound sockets). Values {@code <= 0} leave the default.
   */
  public void setReceiveBufferSize(int receiveBufferSize) {
    this.receiveBufferSize = receiveBufferSize;
  }

  /**
   * Replaces the target URI. A running acceptor keeps the target it was created with; the new
   * value is picked up by the next acceptor started via {@link #reopen()}.
   */
  public void setTarget(URI tcpBrokerUri) {
    target = tcpBrokerUri;
  }

  /** Starts a fresh acceptor thread on the current listening socket. */
  private void doOpen() throws Exception {
    // Arm the latch before the acceptor exists so a concurrent close() can never count down a
    // latch that is about to be thrown away.
    closed = new CountDownLatch(1);
    acceptor = new Acceptor(serverSocket, target);
    if (pauseAtStart) {
      acceptor.pause();
    }
    new Thread(null, acceptor, "SocketProxy-Acceptor-" + serverSocket.getLocalPort()).start();
  }

  /** Returns the local port the proxy was bound to at construction time. */
  public int getListenPort() {
    return listenPort;
  }

  /**
   * Creates an unbound listening socket, configures it and binds it to {@code port}. Shared by
   * the constructor and {@link #reopen()} so both honour the receive buffer size and accept
   * backlog. The socket is closed again if configuration or binding fails.
   */
  private ServerSocket bindServerSocket(int port) throws Exception {
    ServerSocket socket = createServerSocket(usesSSL);
    try {
      socket.setReuseAddress(true);
      if (receiveBufferSize > 0) {
        socket.setReceiveBufferSize(receiveBufferSize);
      }
      socket.bind(new InetSocketAddress(port), acceptBacklog);
    } catch (Exception e) {
      try {
        socket.close();
      } catch (IOException closeFailure) {
        log.debug("exception on close of: {}", socket, closeFailure);
      }
      throw e;
    }
    return socket;
  }

  private ServerSocket createServerSocket(boolean useSSL) throws Exception {
    if (useSSL) {
      return SSLServerSocketFactory.getDefault().createServerSocket();
    }
    return new ServerSocket();
  }

  private Socket createSocket(boolean useSSL) throws Exception {
    if (useSSL) {
      return SSLSocketFactory.getDefault().createSocket();
    }
    return new Socket();
  }

  /** Closes {@code socket} if it is non-null, logging (not propagating) any failure. */
  private static void closeQuietly(Socket socket) {
    if (socket == null) {
      return;
    }
    try {
      socket.close();
    } catch (IOException e) {
      log.debug("exception on close of: {}", socket, e);
    }
  }

  /**
   * Arms {@code gate} with a fresh one-count latch unless it is already armed. Replacing an armed
   * latch would strand threads waiting on the old one, because goOn() only releases the latch
   * that is current at the time it is called.
   */
  private static void engagePause(AtomicReference<CountDownLatch> gate) {
    CountDownLatch current = gate.get();
    if (current.getCount() == 0) {
      gate.compareAndSet(current, new CountDownLatch(1));
    }
  }

  /** Returns the proxy URL (the target URI with the proxy's port), or null before open(URI). */
  public URI getUrl() {
    return proxyUrl;
  }

  /**
   * Closes all proxy connections and the acceptor. Safe to call before {@link #open(URI)}, in
   * which case only the listening socket is closed.
   */
  public void close() {
    List<Bridge> snapshot;
    synchronized (this.connections) {
      snapshot = new ArrayList<Bridge>(this.connections);
    }
    log.warn("Closing {} connections to: {}, target: {}", snapshot.size(), getUrl(), target);
    for (Bridge con : snapshot) {
      closeConnection(con);
    }
    if (acceptor != null) {
      acceptor.close();
    } else {
      // Never opened: there is no acceptor, but the constructor already bound the port.
      try {
        serverSocket.close();
      } catch (IOException e) {
        log.debug("exception on close of: {}", serverSocket, e);
      }
    }
    closed.countDown();
  }

  /** Closes the receive side of all proxy connections, leaving the acceptor open. */
  public void halfClose() {
    List<Bridge> snapshot;
    synchronized (this.connections) {
      snapshot = new ArrayList<Bridge>(this.connections);
    }
    if (log.isInfoEnabled()) {
      log.info("halfClose, numConnections= {}", snapshot.size());
    }
    for (Bridge con : snapshot) {
      halfCloseConnection(con);
    }
  }

  /**
   * Waits until the proxy has been closed.
   *
   * @param timeoutSeconds maximum time to wait, in seconds
   * @return {@code true} if the proxy was closed within the timeout, {@code false} otherwise
   * @throws InterruptedException if the calling thread is interrupted while waiting
   */
  public boolean waitUntilClosed(long timeoutSeconds) throws InterruptedException {
    return closed.await(timeoutSeconds, TimeUnit.SECONDS);
  }

  /**
   * Called after a close to restart the acceptor on the same port. Failures to re-bind are logged
   * and otherwise ignored.
   *
   * @throws IllegalStateException if {@link #open(URI)} has never been called
   */
  public void reopen() {
    // Checked outside the try block: inside it, the catch-all below would swallow the exception.
    if (proxyUrl == null) {
      throw new IllegalStateException("Can not call reopen before open(URI uri).");
    }
    if (log.isInfoEnabled()) {
      log.info("Re-opening connectivity to {}", getUrl());
    }
    try {
      // listenPort is the port that open(URI) recorded in proxyUrl.
      serverSocket = bindServerSocket(listenPort);
      doOpen();
    } catch (Exception e) {
      log.warn("exception on reopen url:{} ", getUrl(), e);
    }
  }

  /**
   * Pauses accepting new connections and data transfer through existing proxy connections. All
   * sockets remain open.
   */
  public void pause() {
    synchronized (connections) {
      if (log.isInfoEnabled()) {
        log.info("pause, numConnections={}", connections.size());
      }
      acceptor.pause();
      for (Bridge con : connections) {
        con.pause();
      }
    }
  }

  /** Continues after a {@link #pause()}. */
  public void goOn() {
    synchronized (connections) {
      if (log.isInfoEnabled()) {
        log.info("goOn, numConnections={}", connections.size());
      }
      for (Bridge con : connections) {
        con.goOn();
      }
    }
    acceptor.goOn();
  }

  private void closeConnection(Bridge c) {
    try {
      c.close();
    } catch (Exception e) {
      log.debug("exception on close of: {}", c, e);
    }
  }

  private void halfCloseConnection(Bridge c) {
    try {
      c.halfClose();
    } catch (Exception e) {
      log.debug("exception on half close of: {}", c, e);
    }
  }

  public boolean isPauseAtStart() {
    return pauseAtStart;
  }

  /** If set, the acceptor created by the next open/reopen starts in the paused state. */
  public void setPauseAtStart(boolean pauseAtStart) {
    this.pauseAtStart = pauseAtStart;
  }

  public int getAcceptBacklog() {
    return acceptBacklog;
  }

  /** Sets the accept backlog used the next time the listening socket is bound. */
  public void setAcceptBacklog(int acceptBacklog) {
    this.acceptBacklog = acceptBacklog;
  }

  /** Builds a copy of {@code uri} whose port is the local port of {@code serverSocket}. */
  private URI urlFromSocket(URI uri, ServerSocket serverSocket) throws Exception {
    int localPort = serverSocket.getLocalPort();
    return new URI(
        uri.getScheme(),
        uri.getUserInfo(),
        uri.getHost(),
        localPort,
        uri.getPath(),
        uri.getQuery(),
        uri.getFragment());
  }

  /**
   * One proxied connection: the accepted (receive) socket, the outbound (send) socket to the
   * target, and the two {@link Pump} threads copying bytes in each direction.
   */
  public class Bridge {
    private final Socket receiveSocket;
    private final Socket sendSocket;
    private Pump requestThread;
    private Pump responseThread;

    /**
     * Connects to the target and starts both pump threads.
     *
     * @param socket the accepted socket; it is closed if the outbound connection fails
     * @param target URI whose host and port are connected to
     * @throws Exception if the outbound socket cannot be created, configured or connected
     */
    public Bridge(Socket socket, URI target) throws Exception {
      receiveSocket = socket;
      Socket outbound = null;
      try {
        outbound = createSocket(usesSSL);
        if (receiveBufferSize > 0) {
          outbound.setReceiveBufferSize(receiveBufferSize);
        }
        outbound.connect(new InetSocketAddress(target.getHost(), target.getPort()));
        if (log.isInfoEnabled()) {
          log.info(
              "proxy connection {}, receiveBufferSize={}",
              outbound,
              outbound.getReceiveBufferSize());
        }
      } catch (Exception e) {
        // The bridge is never registered on failure, so nobody else would close these sockets.
        closeQuietly(outbound);
        closeQuietly(receiveSocket);
        throw e;
      }
      sendSocket = outbound;
      linkWithThreads(receiveSocket, sendSocket);
    }

    public void goOn() {
      responseThread.goOn();
      requestThread.goOn();
    }

    public void pause() {
      requestThread.pause();
      responseThread.pause();
    }

    /** Unregisters this bridge and closes both of its sockets. */
    public void close() throws Exception {
      synchronized (connections) {
        connections.remove(this);
      }
      try {
        receiveSocket.close();
      } finally {
        // Close the outbound side even if closing the inbound side failed.
        sendSocket.close();
      }
    }

    /** Closes only the receive (accepted) socket. */
    public void halfClose() throws Exception {
      receiveSocket.close();
    }

    private void linkWithThreads(Socket source, Socket dest) {
      requestThread = new Pump("Request", source, dest);
      requestThread.start();
      responseThread = new Pump("Response", dest, source);
      responseThread.start();
    }

    /** Thread that copies bytes from one socket to another until EOF, timeout or failure. */
    public class Pump extends Thread {

      protected Socket src;
      private final Socket destination;
      private final AtomicReference<CountDownLatch> pause =
          new AtomicReference<CountDownLatch>(new CountDownLatch(0));

      public Pump(String kind, Socket source, Socket dest) {
        super("SocketProxy-" + kind + "-" + source.getPort() + ":" + dest.getPort());
        src = source;
        destination = dest;
      }

      /** Makes the pump block before its next write until {@link #goOn()} is called. */
      public void pause() {
        engagePause(pause);
      }

      /** Releases a pump blocked by {@link #pause()}. */
      public void goOn() {
        pause.get().countDown();
      }

      @Override
      public void run() {
        byte[] buf = new byte[1024];
        try {
          src.setSoTimeout(PUMP_SOCKET_TIMEOUT_MS);
        } catch (SocketException e) {
          // Constant-first comparison: getMessage() may be null.
          if ("Socket is closed".equals(e.getMessage())) {
            log.warn("Failed to set socket timeout on {} due to: ", src, e);
            return;
          }
          log.error("Failed to set socket timeout on {} due to ", src, e);
          throw new RuntimeException(e);
        }

        InputStream in = null;
        OutputStream out = null;
        try {
          in = src.getInputStream();
          out = destination.getOutputStream();
          while (true) {
            int len = -1;
            try {
              len = in.read(buf);
            } catch (SocketTimeoutException ste) {
              // len stays -1, so a read timeout ends the pump exactly like EOF does.
              log.warn("Error when reading from {}", src, ste);
            }
            if (len == -1) {
              log.debug("read eof from: {}", src);
              break;
            }
            pause.get().await();
            if (len > 0) {
              out.write(buf, 0, len);
            }
          }
        } catch (Exception e) {
          if (e instanceof InterruptedException) {
            Thread.currentThread().interrupt();
          }
          if (log.isDebugEnabled()) {
            log.debug("read/write failed, reason: {}", e.getLocalizedMessage());
          }
          try {
            if (!receiveSocket.isClosed()) {
              // for halfClose, on read/write failure if we close the
              // remote end will see a close at the same time.
              Bridge.this.close();
            }
          } catch (Exception closeFailure) {
            log.debug("exception on close of: {}", Bridge.this, closeFailure);
          }
        } finally {
          if (in != null) {
            try {
              in.close();
            } catch (Exception exc) {
              log.debug("Error when closing InputStream on socket: {}", src, exc);
            }
          }
          if (out != null) {
            try {
              out.close();
            } catch (Exception exc) {
              log.debug("Error when closing OutputStream on socket: {}", destination, exc);
            }
          }
        }
      }
    }
  }

  /** Accept loop: turns every connection accepted on the listening socket into a {@link Bridge}. */
  public class Acceptor implements Runnable {

    private final ServerSocket socket;
    private final URI target;
    private final AtomicReference<CountDownLatch> pause =
        new AtomicReference<CountDownLatch>(new CountDownLatch(0));

    public Acceptor(ServerSocket serverSocket, URI uri) {
      socket = serverSocket;
      target = uri;
      try {
        // A short accept timeout lets run() re-check the closed/paused state regularly.
        socket.setSoTimeout(ACCEPT_TIMEOUT_MILLIS);
      } catch (SocketException e) {
        log.warn("Failed to set accept timeout on {}", socket, e);
      }
    }

    /** Makes the accept loop block until {@link #goOn()} is called. */
    public void pause() {
      engagePause(pause);
    }

    /** Releases an accept loop blocked by {@link #pause()}. */
    public void goOn() {
      pause.get().countDown();
    }

    @Override
    public void run() {
      try {
        while (!socket.isClosed()) {
          pause.get().await();
          try {
            Socket source = socket.accept();
            // Re-check: a pause may have been requested while blocked in accept().
            pause.get().await();
            if (receiveBufferSize > 0) {
              source.setReceiveBufferSize(receiveBufferSize);
            }
            if (log.isInfoEnabled()) {
              log.info(
                  "accepted {}, receiveBufferSize: {}", source, source.getReceiveBufferSize());
            }
            synchronized (connections) {
              connections.add(new Bridge(source, target));
            }
          } catch (SocketTimeoutException expected) {
            // accept() timed out; loop around to re-check the closed/paused state.
          }
        }
      } catch (Exception e) {
        // NOTE(review): any failure, including a Bridge that cannot connect to the target, ends
        // the accept loop; this matches the original behaviour - confirm that it is intended.
        if (e instanceof InterruptedException) {
          Thread.currentThread().interrupt();
        }
        if (log.isDebugEnabled()) {
          log.debug("acceptor: finished for reason: {}", e.getLocalizedMessage());
        }
      }
    }

    /** Closes the listening socket, signals the closed latch and releases a paused loop. */
    public void close() {
      try {
        socket.close();
      } catch (IOException e) {
        log.debug("exception on close of: {}", socket, e);
      } finally {
        // Always signal waiters and release the loop, even if closing the socket failed.
        closed.countDown();
        goOn();
      }
    }
  }
}