blob: 729ac8794d35e111827262eabc02826bccee9368 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.stream.io;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
public class LeakyBucketStreamThrottler implements StreamThrottler {
private final int maxBytesPerSecond;
private final BlockingQueue<Request> requestQueue = new LinkedBlockingQueue<Request>();
private final ScheduledExecutorService executorService;
private final AtomicBoolean shutdown = new AtomicBoolean(false);
public LeakyBucketStreamThrottler(final int maxBytesPerSecond) {
this.maxBytesPerSecond = maxBytesPerSecond;
executorService = Executors.newSingleThreadScheduledExecutor();
final Runnable task = new Drain();
executorService.scheduleAtFixedRate(task, 0, 1000, TimeUnit.MILLISECONDS);
}
@Override
public void close() {
this.shutdown.set(true);
executorService.shutdown();
try {
// Should not take more than 2 seconds because we run every second. If it takes more than
// 2 seconds, it is because the Runnable thread is blocking on a write; in this case,
// we will just ignore it and return
executorService.awaitTermination(2, TimeUnit.SECONDS);
} catch (InterruptedException e) {
}
}
@Override
public OutputStream newThrottledOutputStream(final OutputStream toWrap) {
return new OutputStream() {
@Override
public void write(final int b) throws IOException {
write(new byte[]{(byte) b}, 0, 1);
}
@Override
public void write(byte[] b) throws IOException {
write(b, 0, b.length);
}
@Override
public void write(byte[] b, int off, int len) throws IOException {
final InputStream in = new ByteArrayInputStream(b, off, len);
LeakyBucketStreamThrottler.this.copy(in, toWrap);
}
@Override
public void close() throws IOException {
toWrap.close();
}
@Override
public void flush() throws IOException {
toWrap.flush();
}
};
}
@Override
public InputStream newThrottledInputStream(final InputStream toWrap) {
return new InputStream() {
final ByteArrayOutputStream baos = new ByteArrayOutputStream();
@Override
public int read() throws IOException {
final ByteArrayOutputStream baos = new ByteArrayOutputStream(1);
LeakyBucketStreamThrottler.this.copy(toWrap, baos, 1L);
if (baos.size() < 1) {
return -1;
}
return baos.toByteArray()[0] & 0xFF;
}
@Override
public int read(final byte[] b) throws IOException {
if (b.length == 0) {
return 0;
}
return read(b, 0, b.length);
}
@Override
public int read(byte[] b, int off, int len) throws IOException {
if (len < 0) {
throw new IllegalArgumentException();
}
if (len == 0) {
return 0;
}
baos.reset();
final int copied = (int) LeakyBucketStreamThrottler.this.copy(toWrap, baos, len);
if (copied == 0) {
return -1;
}
System.arraycopy(baos.toByteArray(), 0, b, off, copied);
return copied;
}
@Override
public void close() throws IOException {
toWrap.close();
}
@Override
public int available() throws IOException {
return toWrap.available();
}
};
}
@Override
public long copy(final InputStream in, final OutputStream out) throws IOException {
return copy(in, out, -1);
}
@Override
public long copy(final InputStream in, final OutputStream out, final long maxBytes) throws IOException {
long totalBytesCopied = 0;
boolean finished = false;
while (!finished) {
final long requestMax = (maxBytes < 0) ? Long.MAX_VALUE : maxBytes - totalBytesCopied;
final Request request = new Request(in, out, requestMax);
boolean transferred = false;
while (!transferred) {
if (shutdown.get()) {
throw new IOException("Throttler shutdown");
}
try {
transferred = requestQueue.offer(request, 1000, TimeUnit.MILLISECONDS);
} catch (final InterruptedException e) {
throw new IOException("Interrupted", e);
}
}
final BlockingQueue<Response> responseQueue = request.getResponseQueue();
Response response = null;
while (response == null) {
try {
if (shutdown.get()) {
throw new IOException("Throttler shutdown");
}
response = responseQueue.poll(1000L, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
throw new IOException("Interrupted", e);
}
}
if (!response.isSuccess()) {
throw response.getError();
}
totalBytesCopied += response.getBytesCopied();
finished = (response.getBytesCopied() == 0) || (totalBytesCopied >= maxBytes && maxBytes > 0);
}
return totalBytesCopied;
}
/**
* This class is responsible for draining water from the leaky bucket. I.e., it actually moves the data
*/
private class Drain implements Runnable {
private final byte[] buffer;
public Drain() {
final int bufferSize = Math.min(4096, maxBytesPerSecond);
buffer = new byte[bufferSize];
}
@Override
public void run() {
final long start = System.currentTimeMillis();
int bytesTransferred = 0;
while (bytesTransferred < maxBytesPerSecond) {
final long maxMillisToWait = 1000 - (System.currentTimeMillis() - start);
if (maxMillisToWait < 1) {
return;
}
try {
final Request request = requestQueue.poll(maxMillisToWait, TimeUnit.MILLISECONDS);
if (request == null) {
return;
}
final BlockingQueue<Response> responseQueue = request.getResponseQueue();
final OutputStream out = request.getOutputStream();
final InputStream in = request.getInputStream();
try {
final long requestMax = request.getMaxBytesToCopy();
long maxBytesToTransfer;
if (requestMax < 0) {
maxBytesToTransfer = Math.min(buffer.length, maxBytesPerSecond - bytesTransferred);
} else {
maxBytesToTransfer = Math.min(requestMax,
Math.min(buffer.length, maxBytesPerSecond - bytesTransferred));
}
maxBytesToTransfer = Math.max(1L, maxBytesToTransfer);
final int bytesCopied = fillBuffer(in, maxBytesToTransfer);
out.write(buffer, 0, bytesCopied);
final Response response = new Response(true, bytesCopied);
responseQueue.put(response);
bytesTransferred += bytesCopied;
} catch (final IOException e) {
final Response response = new Response(e);
responseQueue.put(response);
}
} catch (InterruptedException e) {
}
}
}
private int fillBuffer(final InputStream in, final long maxBytes) throws IOException {
int bytesRead = 0;
int len;
while (bytesRead < maxBytes && (len = in.read(buffer, bytesRead, (int) Math.min(maxBytes - bytesRead, buffer.length - bytesRead))) > 0) {
bytesRead += len;
}
return bytesRead;
}
}
private static class Response {
private final boolean success;
private final IOException error;
private final int bytesCopied;
public Response(final boolean success, final int bytesCopied) {
this.success = success;
this.bytesCopied = bytesCopied;
this.error = null;
}
public Response(final IOException error) {
this.success = false;
this.error = error;
this.bytesCopied = -1;
}
public boolean isSuccess() {
return success;
}
public IOException getError() {
return error;
}
public int getBytesCopied() {
return bytesCopied;
}
}
private static class Request {
private final OutputStream out;
private final InputStream in;
private final long maxBytesToCopy;
private final BlockingQueue<Response> responseQueue;
public Request(final InputStream in, final OutputStream out, final long maxBytesToCopy) {
this.out = out;
this.in = in;
this.maxBytesToCopy = maxBytesToCopy;
this.responseQueue = new LinkedBlockingQueue<Response>(1);
}
public BlockingQueue<Response> getResponseQueue() {
return this.responseQueue;
}
public OutputStream getOutputStream() {
return out;
}
public InputStream getInputStream() {
return in;
}
public long getMaxBytesToCopy() {
return maxBytesToCopy;
}
@Override
public String toString() {
return "Request[maxBytes=" + maxBytesToCopy + "]";
}
}
}