blob: 99b11825bd9527d1f398cc64d3236f789b39e8b4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.drill.exec.store;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.drill.common.collections.Collectors;
import org.apache.drill.common.exceptions.UserException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Preconditions;
import com.google.common.base.Stopwatch;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
/**
* Allows parallel executions of tasks in a simplified way. Also maintains and
* reports timings of task completion.
* <p>
* TODO: look at switching to fork join.
*
* @param <V>
* The type of the value that will be returned when the task is executed.
*/
public abstract class TimedCallable<V> implements Callable<V> {
private static final Logger logger = LoggerFactory.getLogger(TimedCallable.class);

/**
 * Time budget, in milliseconds, granted per task when computing the overall wait in
 * {@code run}; also the threshold above which {@code call} warns about a slow task.
 */
private static final long TIMEOUT_PER_RUNNABLE_IN_MSECS = 15000;

/** {@link System#nanoTime()} reading captured when {@code call} begins; remains 0 until then. */
private volatile long startTime = 0;

/** Elapsed nanoseconds measured by {@code call}; remains -1 until the task has finished. */
private volatile long executionTime = -1;
/**
 * Unwraps completed futures into their results, while counting how many futures were not
 * cancelled and remembering the failures that were encountered.
 * <p>
 * The first failure becomes the primary throwable; every later failure is attached to it
 * as a suppressed exception.
 */
private static class FutureMapper<V> implements Function<Future<V>, V> {
  /** Number of futures that were not cancelled, whether they succeeded or failed. */
  int count;
  /** First failure seen (later ones are suppressed inside it), or null if none occurred. */
  Throwable throwable = null;

  /**
   * Records a failure, either as the primary throwable or as a suppressed one.
   *
   * @param t the failure to record
   */
  private void setThrowable(Throwable t) {
    if (throwable == null) {
      throwable = t;
    } else if (t != null && t != throwable) {
      // Throwable.addSuppressed rejects null and self-suppression; skipping those cases keeps
      // the bookkeeping from failing when two tasks report the very same exception instance.
      throwable.addSuppressed(t);
    }
  }

  /**
   * Extracts the result of a future that has already completed.
   *
   * @param future a future for which {@code isDone()} is true
   * @return the future's result, or null if it was cancelled or failed
   */
  @Override
  public V apply(Future<V> future) {
    Preconditions.checkState(future.isDone());
    if (future.isCancelled()) {
      setThrowable(new CancellationException());
      return null;
    }
    count++;
    try {
      return future.get();
    } catch (InterruptedException e) {
      // there is no wait as we are getting result from the completed/done future
      // Keep the interrupt status visible to callers further up the stack.
      Thread.currentThread().interrupt();
      logger.error("Unexpected exception", e);
      throw UserException.internalError(e)
          .message("Unexpected exception")
          .build(logger);
    } catch (ExecutionException e) {
      setThrowable(e.getCause());
    }
    return null;
  }
}
/**
 * Accumulates start-up and execution timings over a batch of tasks and writes a summary
 * to a logger. Start offsets are measured relative to the moment this object is created.
 */
private static class Statistics<V> implements Consumer<TimedCallable<V>> {
  /** {@link System#nanoTime()} reading taken at construction; the origin for start offsets. */
  final long start = System.nanoTime();
  /** Stopwatch obtained at construction; its elapsed time is reported as the total execution time. */
  final Stopwatch watch = Stopwatch.createStarted();
  long totalExecution;
  long maxExecution;
  int count;
  int startedCount;
  private int doneCount;
  // measure thread creation times
  long earliestStart;
  long latestStart;
  long totalStart;

  /**
   * Folds the timings of a single task into the running totals.
   *
   * @param task the task to inspect
   */
  @Override
  public void accept(TimedCallable<V> task) {
    count++;
    final long rawStart = task.getStartTime(TimeUnit.NANOSECONDS);
    final long threadStart = rawStart - start;
    // A task's start time stays 0 until it begins running. System.nanoTime() readings may be
    // negative, so the offset alone cannot tell an unstarted task from a started one; the raw
    // value is therefore checked as well.
    if (rawStart != 0 && threadStart >= 0) {
      startedCount++;
      earliestStart = Math.min(earliestStart, threadStart);
      latestStart = Math.max(latestStart, threadStart);
      totalStart += threadStart;
      long executionTime = task.getExecutionTime(TimeUnit.NANOSECONDS);
      if (executionTime != -1) {
        doneCount++;
        totalExecution += executionTime;
        maxExecution = Math.max(maxExecution, executionTime);
      } else {
        logger.info("Task {} started at {} did not finish", task, threadStart);
      }
    } else {
      logger.info("Task {} never commenced execution", task);
    }
  }

  /**
   * Resets all counters and recomputes them from the given tasks.
   *
   * @param tasks the tasks to summarize
   * @return this instance, to allow chaining
   */
  Statistics<V> collect(final List<TimedCallable<V>> tasks) {
    totalExecution = maxExecution = 0;
    count = startedCount = doneCount = 0;
    // Start from the largest value so the first started task always lowers the minimum.
    earliestStart = Long.MAX_VALUE;
    latestStart = totalStart = 0;
    tasks.forEach(this);
    return this;
  }

  /**
   * Writes the collected figures to the given logger at debug level.
   *
   * @param activity name of the activity, used as a prefix of each message
   * @param logger the logger receiving the summary
   * @param parallelism number of threads that was used, reported as-is
   */
  void log(final String activity, final Logger logger, int parallelism) {
    if (startedCount > 0) {
      logger.debug("{}: started {} out of {} using {} threads. (start time: min {} ms, avg {} ms, max {} ms).",
          activity, startedCount, count, parallelism,
          TimeUnit.NANOSECONDS.toMillis(earliestStart),
          TimeUnit.NANOSECONDS.toMillis(totalStart) / startedCount,
          TimeUnit.NANOSECONDS.toMillis(latestStart));
    } else {
      logger.debug("{}: started {} out of {} using {} threads.", activity, startedCount, count, parallelism);
    }
    if (doneCount > 0) {
      // The "total" figure is the stopwatch's elapsed time, not the sum of per-task times.
      logger.debug("{}: completed {} out of {} using {} threads (execution time: total {} ms, avg {} ms, max {} ms).",
          activity, doneCount, count, parallelism, watch.elapsed(TimeUnit.MILLISECONDS),
          TimeUnit.NANOSECONDS.toMillis(totalExecution) / doneCount, TimeUnit.NANOSECONDS.toMillis(maxExecution));
    } else {
      logger.debug("{}: completed {} out of {} using {} threads", activity, doneCount, count, parallelism);
    }
  }
}
/**
 * Runs {@link #runInner()} and records when the task began and how long it took.
 * <p>
 * The start time is stored before the work begins and the elapsed time is stored once it
 * ends, whether it ended normally or with an exception. A task whose duration exceeds
 * {@code TIMEOUT_PER_RUNNABLE_IN_MSECS} is reported at warn level.
 *
 * @return whatever {@link #runInner()} returns
 * @throws Exception any exception raised by {@link #runInner()}; an
 *         {@link InterruptedException} is logged before being rethrown
 */
@Override
public final V call() throws Exception {
  final long begin = System.nanoTime();
  startTime = begin;
  try {
    logger.debug("Started execution of '{}' task at {} ms", this, TimeUnit.NANOSECONDS.toMillis(begin));
    return runInner();
  } catch (InterruptedException interruption) {
    logger.warn("Task '{}' interrupted", this, interruption);
    throw interruption;
  } finally {
    final long elapsedNanos = System.nanoTime() - begin;
    if (logger.isWarnEnabled()) {
      final long elapsedMillis = TimeUnit.NANOSECONDS.toMillis(elapsedNanos);
      if (elapsedMillis <= TIMEOUT_PER_RUNNABLE_IN_MSECS) {
        logger.debug("Task '{}' execution time is {} ms", this, elapsedMillis);
      } else {
        logger.warn("Task '{}' execution time {} ms exceeds timeout {} ms.", this, elapsedMillis, TIMEOUT_PER_RUNNABLE_IN_MSECS);
      }
    }
    // Published last so that a non-negative execution time implies the task is over.
    executionTime = elapsedNanos;
  }
}
/**
 * Performs the actual work of this task. It is invoked by {@code call()}, which returns
 * this method's result and propagates any exception it throws.
 *
 * @return the result of the task
 * @throws Exception if the work fails
 */
protected abstract V runInner() throws Exception;
/**
 * Reports the recorded start time, which is stored in nanoseconds, in the requested unit.
 *
 * @param unit the unit to convert the stored value into
 * @return the converted start time
 */
private long getStartTime(TimeUnit unit) {
  final long startNanos = startTime;
  return unit.convert(startNanos, TimeUnit.NANOSECONDS);
}
/**
 * Reports the recorded execution time, which is stored in nanoseconds, in the requested unit.
 *
 * @param unit the unit to convert the stored value into
 * @return the converted execution time
 */
private long getExecutionTime(TimeUnit unit) {
  final long elapsedNanos = executionTime;
  return unit.convert(elapsedNanos, TimeUnit.NANOSECONDS);
}
/**
 * Execute the list of tasks with the given parallelization. At end, return values and, if debug logging is
 * enabled on the provided logger, report completion time stats to it. The whole batch is given
 * {@code TIMEOUT_PER_RUNNABLE_IN_MSECS} milliseconds for each round of {@code parallelism} tasks. If the timeout
 * is exceeded, existing/pending tasks will be cancelled and a {@link UserException} is thrown.
 * <p>
 * The pool is always shut down before this method returns or throws, and up to 5 seconds are spent waiting
 * for tasks that are still running to terminate.
 *
 * @param activity Name of activity for reporting in logger.
 * @param logger The logger to use to report results.
 * @param tasks List of callable that should be executed and timed. If this list has one item, task will be
 *              completed in-thread. Each callable must handle {@link InterruptedException}s.
 * @param parallelism The number of threads that should be run to complete this task.
 * @return The list of outcome objects.
 * @throws IOException If any task failed. The first failure is rethrown, wrapped in an IOException unless it
 *                     already is one, since this was built for storage system tasks initially.
 * @throws UserException If the calling thread is interrupted while waiting, the tasks cannot be submitted,
 *                       or not every task completed within the timeout.
 */
public static <V> List<V> run(final String activity, final Logger logger, final List<TimedCallable<V>> tasks, int parallelism) throws IOException {
  Preconditions.checkArgument(!Preconditions.checkNotNull(tasks).isEmpty(), "list of tasks is empty");
  Preconditions.checkArgument(parallelism > 0);
  parallelism = Math.min(parallelism, tasks.size());
  final ExecutorService threadPool = parallelism == 1 ? MoreExecutors.newDirectExecutorService()
      : Executors.newFixedThreadPool(parallelism, new ThreadFactoryBuilder().setNameFormat(activity + "-%d").build());
  // One timeout slice per round of tasks: ceil(tasks.size() / parallelism) rounds.
  final long timeout = TIMEOUT_PER_RUNNABLE_IN_MSECS * ((tasks.size() - 1)/parallelism + 1);
  final FutureMapper<V> futureMapper = new FutureMapper<>();
  final Statistics<V> statistics = logger.isDebugEnabled() ? new Statistics<>() : null;
  // The interrupt status is restored only after cleanup, so that the bounded wait for
  // termination below is not cut short by a flag that is already set.
  boolean interrupted = false;
  final List<V> results;
  try {
    results = Collectors.toList(threadPool.invokeAll(tasks, timeout, TimeUnit.MILLISECONDS), futureMapper);
  } catch (InterruptedException e) {
    interrupted = true;
    final String errMsg = String.format("Interrupted while waiting for activity '%s' tasks to be done.", activity);
    logger.error(errMsg, e);
    throw UserException.resourceError(e)
        .message(errMsg)
        .build(logger);
  } catch (RejectedExecutionException e) {
    final String errMsg = String.format("Failure while submitting activity '%s' tasks for execution.", activity);
    logger.error(errMsg, e);
    throw UserException.internalError(e)
        .message(errMsg)
        .build(logger);
  } finally {
    // Cleanup only: nothing is thrown from here, so an exception raised above is never replaced.
    List<Runnable> notStartedTasks = threadPool.shutdownNow();
    if (!notStartedTasks.isEmpty()) {
      logger.error("{} activity '{}' tasks never commenced execution.", notStartedTasks.size(), activity);
    }
    try {
      // Wait for 5s for currently running threads to terminate. Above call (threadPool.shutdownNow()) interrupts
      // any running threads. If the tasks are handling the interrupts properly they should be able to
      // wrap up and terminate. If not waiting for 5s here gives a chance to identify and log any potential
      // thread leaks.
      if (!threadPool.awaitTermination(5000, TimeUnit.MILLISECONDS)) {
        logger.error("Detected run away tasks in activity '{}'.", activity);
      }
    } catch (final InterruptedException e) {
      interrupted = true;
      logger.warn("Interrupted while waiting for pending threads in activity '{}' to terminate.", activity);
    }
    if (statistics != null) {
      statistics.collect(tasks).log(activity, logger, parallelism);
    }
    if (interrupted) {
      Thread.currentThread().interrupt();
    }
  }
  // Reached only when every future was mapped; a short count means some were cancelled by the timeout.
  if (futureMapper.count != tasks.size()) {
    final String errMsg = String.format("Waited for %d ms, but only %d tasks for '%s' are complete." +
        " Total number of tasks %d, parallelism %d.", timeout, futureMapper.count, activity, tasks.size(), parallelism);
    logger.error(errMsg, futureMapper.throwable);
    throw UserException.resourceError(futureMapper.throwable)
        .message(errMsg)
        .build(logger);
  }
  if (futureMapper.throwable != null) {
    throw (futureMapper.throwable instanceof IOException) ?
        (IOException)futureMapper.throwable : new IOException(futureMapper.throwable);
  }
  return results;
}
}