blob: 0fb778bbec202042f1315197ca69c056e7ad9ca1 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.store;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import com.google.common.base.Stopwatch;
import com.google.common.collect.Lists;
/**
 * Class used to allow parallel executions of tasks in a simplified way. Also maintains and reports timings of task
 * completion.
 * TODO: look at switching to fork join.
 *
 * @param <V> The type of value that will be returned when the task is executed.
 */
public abstract class TimedRunnable<V> implements Runnable {

  /** Exception thrown by {@link #runInner()}, or null if none was thrown. */
  private volatile Exception e;
  /**
   * Error thrown by {@link #runInner()}, or null if none was thrown. Recorded because a task executed through
   * {@link ExecutorService#submit(Runnable)} would otherwise have its Error captured in an unread Future.
   */
  private volatile Error error;
  /** Nanoseconds spent in {@link #runInner()}, measured whether or not it completed normally. */
  private volatile long timeNanos;
  /** Value returned by {@link #runInner()}; only meaningful when no exception or error was recorded. */
  private volatile V value;

  /**
   * Executes {@link #runInner()} and records its result, elapsed time and any failure. Exceptions are captured and
   * reported later by {@link #getValue()}. Errors are recorded (so {@link #getValue()} can rethrow them) and also
   * rethrown immediately to the thread running this task.
   */
  @Override
  public final void run() {
    long start = System.nanoTime();
    try {
      value = runInner();
    } catch (Exception e) {
      this.e = e;
    } catch (Error err) {
      // Keep a reference so the failure is not lost when this runs inside a pool's FutureTask.
      this.error = err;
      throw err;
    } finally {
      timeNanos = System.nanoTime() - start;
    }
  }

  /**
   * Performs the actual work of this task.
   *
   * @return the task's result
   * @throws Exception any failure; it is captured by {@link #run()} and surfaced through {@link #getValue()}
   */
  protected abstract V runInner() throws Exception ;

  /**
   * Converts a non-IOException failure raised by {@link #runInner()} into an IOException.
   *
   * @param e the exception thrown by {@link #runInner()}
   * @return the IOException that {@link #getValue()} should throw
   */
  protected abstract IOException convertToIOException(Exception e);

  /**
   * @return nanoseconds spent in the last {@link #run()} invocation
   */
  public long getTimeSpentNanos(){
    return timeNanos;
  }

  /**
   * Returns the value produced by {@link #runInner()}, or rethrows its failure.
   *
   * @return the computed value
   * @throws IOException the original IOException, or the result of {@link #convertToIOException(Exception)} for any
   *     other exception
   * @throws Error if {@link #runInner()} threw an Error
   */
  public final V getValue() throws IOException {
    final Error err = error;
    if (err != null) {
      throw err;
    }
    final Exception ex = e;
    if (ex != null) {
      if (ex instanceof IOException) {
        throw (IOException) ex;
      }
      throw convertToIOException(ex);
    }
    return value;
  }

  /**
   * Wraps a Runnable so that the given latch is counted down once it finishes, whether it completes normally or
   * throws.
   */
  private static class LatchedRunnable implements Runnable {
    final CountDownLatch latch;
    final Runnable runnable;

    public LatchedRunnable(CountDownLatch latch, Runnable runnable){
      this.latch = latch;
      this.runnable = runnable;
    }

    @Override
    public void run() {
      try {
        runnable.run();
      } finally {
        latch.countDown();
      }
    }
  }

  /**
   * Execute the list of runnables with the given parallelization. At end, return values and report completion time
   * stats to provided logger.
   *
   * @param activity Name of activity for reporting in logger.
   * @param logger The logger to use to report results.
   * @param runnables List of runnables that should be executed and timed. If this list has one item, the task is
   *     completed in-thread; if it is empty, nothing is executed and an empty list is returned.
   * @param parallelism The number of threads that should be run to complete this task. It is capped at the number of
   *     runnables, and values below 1 are treated as 1.
   * @return The list of outcome objects, in the same order as the successful runnables.
   * @throws IOException All exceptions are coerced to IOException since this was built for storage system tasks
   *     initially. When several tasks fail, the first failure is thrown with the others attached as suppressed.
   */
  public static <V> List<V> run(final String activity, final Logger logger, final List<TimedRunnable<V>> runnables,
      int parallelism) throws IOException {
    Stopwatch watch = new Stopwatch().start();
    if (runnables.isEmpty()) {
      // Nothing to do; Executors.newFixedThreadPool(0) would throw IllegalArgumentException.
      parallelism = 0;
    } else if (runnables.size() == 1) {
      parallelism = 1;
      runnables.get(0).run();
    } else {
      // A fixed pool requires at least one thread.
      parallelism = Math.max(1, Math.min(parallelism, runnables.size()));
      final CountDownLatch latch = new CountDownLatch(runnables.size());
      final ExecutorService threadPool = Executors.newFixedThreadPool(parallelism);
      try {
        for (TimedRunnable<V> runnable : runnables) {
          threadPool.submit(new LatchedRunnable(latch, runnable));
        }
      } finally {
        // No more tasks will be submitted; lets the pool threads exit once the queue drains.
        threadPool.shutdown();
      }
      try {
        latch.await();
      } catch (InterruptedException e) {
        // Cancel tasks that have not started and preserve the interrupt status for the caller.
        threadPool.shutdownNow();
        Thread.currentThread().interrupt();
        // TODO interrupted exception.
        throw new RuntimeException(e);
      }
    }

    List<V> values = Lists.newArrayList();
    long sum = 0;
    long max = 0;
    long count = 0;
    IOException excep = null;
    for (final TimedRunnable<V> reader : runnables) {
      try {
        values.add(reader.getValue());
        sum += reader.getTimeSpentNanos();
        count++;
        max = Math.max(max, reader.getTimeSpentNanos());
      } catch (IOException e) {
        if (excep == null) {
          excep = e;
        } else if (excep != e) {
          // Throwable.addSuppressed rejects self-suppression with IllegalArgumentException.
          excep.addSuppressed(e);
        }
      }
    }

    if (logger.isInfoEnabled()) {
      // Average over successful tasks only; report 0 instead of NaN when none succeeded.
      double avg = count == 0 ? 0.0d : (sum / 1000.0 / 1000.0) / (count * 1.0d);
      logger.info(
          String.format("%s: Executed %d out of %d using %d threads. "
              + "Time: %dms total, %fms avg, %dms max.",
              activity, count, runnables.size(), parallelism, watch.elapsed(TimeUnit.MILLISECONDS), avg,
              max / 1000 / 1000));
    }

    if (excep != null) {
      throw excep;
    }
    return values;
  }
}