blob: eb9ccf82683bd08c43432f666432f647f5d9122e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.aries.rsa.provider.fastbin.tcp;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.aries.rsa.provider.fastbin.io.ProtocolCodec;
import org.apache.aries.rsa.provider.fastbin.io.Service;
import org.apache.aries.rsa.provider.fastbin.io.Transport;
import org.apache.aries.rsa.provider.fastbin.io.TransportListener;
import org.fusesource.hawtdispatch.DispatchQueue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A pool of up to {@link #poolSize} {@link Transport}s, all connected to the same {@link #uri}.
 * <p>
 * Requests are handed to an idle transport when one exists; otherwise they are parked in
 * {@link #pending} until a transport signals that it can accept more data (see
 * {@link Listener#onRefill(Transport)}). A transport that stays idle for {@link #evictionDelay}
 * milliseconds, with no request still awaiting a response, is stopped and removed from the pool.
 * <p>
 * The collections held by this class are not synchronized. The pool mutates them from tasks
 * submitted to {@link #queue}, and every transport is bound to that same queue.
 * NOTE(review): this assumes transport listener callbacks and subclass calls to
 * {@link #onDone(Object)} also run on {@link #queue} — confirm.
 */
public abstract class TransportPool implements Service {

    protected static final Logger LOGGER = LoggerFactory.getLogger(TransportPool.class);

    /** Default maximum number of transports kept open at once. */
    public static final int DEFAULT_POOL_SIZE = 2;
    /** Default idle time, in milliseconds, after which an unused transport is evicted. */
    public static final long DEFAULT_EVICTION_DELAY = TimeUnit.MINUTES.toMillis(5);

    /** Address every transport in this pool connects to. */
    protected final String uri;
    /** Queue on which pool state is mutated and to which every transport is bound. */
    protected final DispatchQueue queue;
    /** Requests waiting for a transport with spare capacity, oldest first. */
    protected final LinkedList<Pair> pending = new LinkedList<>();
    /** Live transports and their bookkeeping. */
    protected final Map<Transport, TransportState> transports = new HashMap<>();
    /** Whether the pool currently accepts new requests. */
    protected AtomicBoolean running = new AtomicBoolean(false);
    /** Maximum number of transports opened concurrently. */
    protected int poolSize;
    /** Idle time in milliseconds before eviction; a value {@code <= 0} disables eviction. */
    protected long evictionDelay;

    /**
     * Creates a pool with {@link #DEFAULT_POOL_SIZE} and {@link #DEFAULT_EVICTION_DELAY}.
     *
     * @param uri   address every pooled transport connects to
     * @param queue queue that owns the pool state and the transports
     */
    public TransportPool(String uri, DispatchQueue queue) {
        this(uri, queue, DEFAULT_POOL_SIZE, DEFAULT_EVICTION_DELAY);
    }

    /**
     * Creates a pool.
     *
     * @param uri           address every pooled transport connects to
     * @param queue         queue that owns the pool state and the transports
     * @param poolSize      maximum number of transports opened concurrently
     * @param evictionDelay idle time in milliseconds before a transport is evicted;
     *                      {@code <= 0} disables eviction
     */
    public TransportPool(String uri, DispatchQueue queue, int poolSize, long evictionDelay) {
        this.uri = uri;
        this.queue = queue;
        this.poolSize = poolSize;
        this.evictionDelay = evictionDelay;
    }

    /** Creates a new, not yet started transport for {@code uri}. */
    protected abstract Transport createTransport(String uri) throws Exception;

    /** Creates the codec installed on each newly created transport. */
    protected abstract ProtocolCodec createCodec();

    /** Handles a command received on any pooled transport. */
    protected abstract void onCommand(Object command);

    /**
     * Notifies that the request identified by {@code id} failed, because its transport failed
     * or disconnected, or because the pool was stopped.
     */
    protected abstract void onFailure(Object id, Throwable throwable);

    /**
     * Marks the request identified by {@code id} as no longer in flight, so that its transport
     * becomes eligible for idle eviction again.
     *
     * @param id identifier previously passed to {@link #offer(Object, Object)}
     */
    protected void onDone(Object id) {
        for (TransportState state : transports.values()) {
            if (state.inflight.remove(id)) {
                break;
            }
        }
    }

    /**
     * Asynchronously sends {@code data} on an idle transport, or queues it until one is available.
     *
     * @param data the command to send
     * @param id   identifier reported back through {@link #onFailure(Object, Throwable)} or
     *             released through {@link #onDone(Object)}
     * @throws IllegalStateException if the pool is not running
     */
    public void offer(final Object data, final Object id) {
        if (!running.get()) {
            throw new IllegalStateException("Transport pool stopped");
        }
        queue.execute(new Runnable() {
            public void run() {
                if (!running.get()) {
                    // The pool was stopped after the check above but before this task ran.
                    // Fail the request here rather than opening a new transport on a stopped
                    // pool or parking it in a queue that nothing will drain any more.
                    onFailure(id, new IOException("Transport pool stopped"));
                    return;
                }
                Transport transport = getIdleTransport();
                if (transport == null) {
                    pending.add(new Pair(data, id));
                    return;
                }
                TransportState state = transports.get(transport);
                if (!doOffer(transport, data, id)) {
                    // Never drop a rejected command: retry it on the next refill.
                    pending.add(new Pair(data, id));
                }
                if (transport.full()) {
                    state.time = 0L;
                }
            }
        });
    }

    /**
     * Records {@code id} as in flight on {@code transport} and hands it the command.
     *
     * @return {@code true} if the transport accepted the command; if it did not, {@code id} is
     *         no longer tracked as in flight and the caller must keep the command
     */
    protected boolean doOffer(Transport transport, Object command, Object id) {
        TransportState state = transports.get(transport);
        state.inflight.add(id);
        boolean accepted = transport.offer(command);
        if (!accepted) {
            state.inflight.remove(id);
        }
        return accepted;
    }

    /**
     * Returns a transport currently marked idle, or {@code null} if there is none. In that case,
     * a new transport is started if the pool has not reached {@link #poolSize}. The new transport
     * is not returned; it picks up {@link #pending} work once it connects.
     */
    protected Transport getIdleTransport() {
        for (Map.Entry<Transport, TransportState> entry : transports.entrySet()) {
            if (entry.getValue().time > 0) {
                return entry.getKey();
            }
        }
        if (transports.size() < poolSize) {
            try {
                startNewTransport();
            } catch (Exception e) {
                LOGGER.info("Unable to start new transport", e);
            }
        }
        return null;
    }

    /** Marks the pool as running. */
    public void start() throws Exception {
        start(null);
    }

    /**
     * Marks the pool as running. Transports are opened lazily by {@link #offer(Object, Object)}.
     *
     * @param onComplete optional callback, run once the pool is running; may be {@code null}
     */
    public void start(Runnable onComplete) throws Exception {
        running.set(true);
        runIfNotNull(onComplete);
    }

    /** Stops the pool; see {@link #stop(Runnable)}. */
    public void stop() {
        stop(null);
    }

    /**
     * Stops the pool. Every in-flight and pending request is failed, and every transport is
     * stopped. If the pool is already stopped, {@code onComplete} runs immediately.
     *
     * @param onComplete optional callback, run after every transport has stopped and all pending
     *                   requests have been failed; may be {@code null}
     */
    public void stop(final Runnable onComplete) {
        if (!running.compareAndSet(true, false)) {
            runIfNotNull(onComplete);
            return;
        }
        queue.execute(new Runnable() {
            public void run() {
                if (transports.isEmpty()) {
                    // No transport will ever call countDown, so finish right away.
                    failPendingAndComplete(onComplete);
                    return;
                }
                final AtomicInteger latch = new AtomicInteger(transports.size());
                final Runnable countDown = new Runnable() {
                    public void run() {
                        if (latch.decrementAndGet() == 0) {
                            failPendingAndComplete(onComplete);
                        }
                    }
                };
                while (!transports.isEmpty()) {
                    Transport transport = transports.keySet().iterator().next();
                    TransportState state = transports.remove(transport);
                    if (state != null) {
                        for (Object id : state.inflight) {
                            onFailure(id, new IOException("Transport stopped"));
                        }
                    }
                    transport.stop(countDown);
                }
            }
        });
    }

    /** Fails every request still waiting in {@link #pending}, then runs {@code onComplete}. */
    private void failPendingAndComplete(Runnable onComplete) {
        while (!pending.isEmpty()) {
            Pair p = pending.removeFirst();
            onFailure(p.id, new IOException("Transport stopped"));
        }
        runIfNotNull(onComplete);
    }

    /** Runs {@code callback} if it is not {@code null}. */
    private static void runIfNotNull(Runnable callback) {
        if (callback != null) {
            callback.run();
        }
    }

    /**
     * Creates, configures, registers and starts a new transport. If {@code start()} fails, the
     * transport is unregistered again and the exception is rethrown.
     */
    protected void startNewTransport() throws Exception {
        LOGGER.debug("Creating new transport for: {}", this.uri);
        Transport transport = createTransport(this.uri);
        transport.setDispatchQueue(queue);
        transport.setProtocolCodec(createCodec());
        transport.setTransportListener(new Listener());
        // Registered before start() so listener callbacks can find the transport's state.
        transports.put(transport, new TransportState());
        try {
            transport.start();
        } catch (Exception e) {
            // A transport that never started would otherwise hold a pool slot forever.
            transports.remove(transport);
            throw e;
        }
    }

    /** A command waiting for a transport, together with its request identifier. */
    protected static class Pair {
        Object command;
        Object id;

        public Pair(Object command, Object id) {
            this.command = command;
            this.id = id;
        }
    }

    /** Per-transport bookkeeping. */
    protected static class TransportState {
        /**
         * Timestamp (from {@code System.currentTimeMillis()}) of the refill that last marked the
         * transport idle; {@code 0} while it is full or not yet connected.
         */
        long time;
        /** Identifiers of requests sent on this transport and not yet released via onDone. */
        final Set<Object> inflight;

        public TransportState() {
            time = 0;
            inflight = new HashSet<>();
        }
    }

    /** Listener installed on every pooled transport. */
    protected class Listener implements TransportListener {

        /** Forwards a received command to {@link TransportPool#onCommand(Object)}. */
        public void onTransportCommand(Transport transport, Object command) {
            TransportPool.this.onCommand(command);
        }

        /**
         * Drains {@link #pending} into {@code transport} while it has capacity, then updates its
         * idle timestamp and, if it is not full, schedules idle eviction.
         */
        public void onRefill(final Transport transport) {
            TransportState state = transports.get(transport);
            if (state == null) {
                // Already evicted, failed or removed by stop(): nothing to refill.
                return;
            }
            while (!pending.isEmpty() && !transport.full()) {
                Pair pair = pending.removeFirst();
                if (!doOffer(transport, pair.command, pair.id)) {
                    // Keep the command at the head of the queue and wait for the next refill.
                    pending.addFirst(pair);
                    break;
                }
            }
            if (transport.full()) {
                state.time = 0L;
            } else {
                final long time = System.currentTimeMillis();
                state.time = time;
                if (evictionDelay > 0) {
                    queue.executeAfter(evictionDelay, TimeUnit.MILLISECONDS, new Runnable() {
                        public void run() {
                            TransportState current = transports.get(transport);
                            // Evict only if no refill happened since this task was scheduled
                            // and no request is still awaiting a response on this transport;
                            // stopping it otherwise would silently drop those responses.
                            if (current != null && current.time == time && current.inflight.isEmpty()) {
                                transports.remove(transport);
                                transport.stop();
                            }
                        }
                    });
                }
            }
        }

        /**
         * Removes a failed transport, fails its in-flight requests and stops it. If it was the
         * last transport, every pending request is failed as well.
         */
        public void onTransportFailure(Transport transport, IOException error) {
            if (!transport.isDisposed()) {
                LOGGER.info("Transport failure", error);
                TransportState state = transports.remove(transport);
                if (state != null) {
                    for (Object id : state.inflight) {
                        onFailure(id, error);
                    }
                }
                transport.stop();
                if (transports.isEmpty()) {
                    while (!pending.isEmpty()) {
                        Pair p = pending.removeFirst();
                        onFailure(p.id, error);
                    }
                }
            }
        }

        /** Resumes reading on a newly connected transport and feeds it pending requests. */
        public void onTransportConnected(Transport transport) {
            transport.resumeRead();
            onRefill(transport);
        }

        /** Treats a disconnect as a transport failure. */
        public void onTransportDisconnected(Transport transport) {
            onTransportFailure(transport, new IOException("Transport disconnected"));
        }
    }
}