blob: 6664dbf02b38b4c5dd5ca01cee1fea03298bb604 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.common.util;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.RejectedExecutionHandler;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.MDC;
public class ExecutorUtil {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private static volatile List<InheritableThreadLocalProvider> providers = new ArrayList<>();
public synchronized static void addThreadLocalProvider(InheritableThreadLocalProvider provider) {
for (InheritableThreadLocalProvider p : providers) {//this is to avoid accidental multiple addition of providers in tests
if (p.getClass().equals(provider.getClass())) return;
}
List<InheritableThreadLocalProvider> copy = new ArrayList<>(providers);
copy.add(provider);
providers = copy;
}
/** Any class which wants to carry forward the threadlocal values to the threads run
* by threadpools must implement this interface and the implementation should be
* registered here
*/
public interface InheritableThreadLocalProvider {
/**This is invoked in the parent thread which submitted a task.
* copy the necessary Objects to the ctx. The object that is passed is same
* across all three methods
*/
public void store(AtomicReference<?> ctx);
/**This is invoked in the Threadpool thread. set the appropriate values in the threadlocal
* of this thread. */
public void set(AtomicReference<?> ctx);
/**This method is invoked in the threadpool thread after the execution
* clean all the variables set in the set method
*/
public void clean(AtomicReference<?> ctx);
}
public static void shutdownAndAwaitTermination(ExecutorService pool) {
if (pool == null) return;
pool.shutdown(); // Disable new tasks from being submitted
awaitTermination(pool);
}
public static void shutdownNowAndAwaitTermination(ExecutorService pool) {
if (pool == null) return;
pool.shutdownNow(); // Disable new tasks from being submitted; interrupt existing tasks
awaitTermination(pool);
}
public static void awaitTermination(ExecutorService pool) {
boolean shutdown = false;
while (!shutdown) {
try {
// Wait a while for existing tasks to terminate
shutdown = pool.awaitTermination(60, TimeUnit.SECONDS);
} catch (InterruptedException ie) {
// Preserve interrupt status
Thread.currentThread().interrupt();
}
}
}
/**
* See {@link java.util.concurrent.Executors#newFixedThreadPool(int, ThreadFactory)}
*/
public static ExecutorService newMDCAwareFixedThreadPool(int nThreads, ThreadFactory threadFactory) {
return new MDCAwareThreadPoolExecutor(nThreads, nThreads,
0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<Runnable>(),
threadFactory);
}
/**
* See {@link java.util.concurrent.Executors#newSingleThreadExecutor(ThreadFactory)}
*/
public static ExecutorService newMDCAwareSingleThreadExecutor(ThreadFactory threadFactory) {
return new MDCAwareThreadPoolExecutor(1, 1,
0L, TimeUnit.MILLISECONDS,
new LinkedBlockingQueue<>(),
threadFactory);
}
/**
* Create a cached thread pool using a named thread factory
*/
public static ExecutorService newMDCAwareCachedThreadPool(String name) {
return newMDCAwareCachedThreadPool(new SolrNamedThreadFactory(name));
}
/**
* See {@link java.util.concurrent.Executors#newCachedThreadPool(ThreadFactory)}
*/
public static ExecutorService newMDCAwareCachedThreadPool(ThreadFactory threadFactory) {
return new MDCAwareThreadPoolExecutor(0, Integer.MAX_VALUE,
60L, TimeUnit.SECONDS,
new SynchronousQueue<>(),
threadFactory);
}
public static ExecutorService newMDCAwareCachedThreadPool(int maxThreads, ThreadFactory threadFactory) {
return new MDCAwareThreadPoolExecutor(0, maxThreads,
60L, TimeUnit.SECONDS,
new LinkedBlockingQueue<>(maxThreads),
threadFactory);
}
@SuppressForbidden(reason = "class customizes ThreadPoolExecutor so it can be used instead")
public static class MDCAwareThreadPoolExecutor extends ThreadPoolExecutor {
private static final int MAX_THREAD_NAME_LEN = 512;
private final boolean enableSubmitterStackTrace;
public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, RejectedExecutionHandler handler) {
super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, handler);
this.enableSubmitterStackTrace = true;
}
public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue) {
super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue);
this.enableSubmitterStackTrace = true;
}
public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory) {
this(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory, true);
}
public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, ThreadFactory threadFactory, boolean enableSubmitterStackTrace) {
super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, threadFactory);
this.enableSubmitterStackTrace = enableSubmitterStackTrace;
}
public MDCAwareThreadPoolExecutor(int corePoolSize, int maximumPoolSize, long keepAliveTime, TimeUnit unit, BlockingQueue<Runnable> workQueue, RejectedExecutionHandler handler) {
super(corePoolSize, maximumPoolSize, keepAliveTime, unit, workQueue, handler);
this.enableSubmitterStackTrace = true;
}
@Override
public void execute(final Runnable command) {
final Map<String, String> submitterContext = MDC.getCopyOfContextMap();
StringBuilder contextString = new StringBuilder();
if (submitterContext != null) {
Collection<String> values = submitterContext.values();
for (String value : values) {
contextString.append(value).append(' ');
}
if (contextString.length() > 1) {
contextString.setLength(contextString.length() - 1);
}
}
String ctxStr = contextString.toString().replace("/", "//");
final String submitterContextStr = ctxStr.length() <= MAX_THREAD_NAME_LEN ? ctxStr : ctxStr.substring(0, MAX_THREAD_NAME_LEN);
final Exception submitterStackTrace = enableSubmitterStackTrace ? new Exception("Submitter stack trace") : null;
final List<InheritableThreadLocalProvider> providersCopy = providers;
@SuppressWarnings({"rawtypes"})
final ArrayList<AtomicReference> ctx = providersCopy.isEmpty() ? null : new ArrayList<>(providersCopy.size());
if (ctx != null) {
for (int i = 0; i < providers.size(); i++) {
@SuppressWarnings({"rawtypes"})
AtomicReference reference = new AtomicReference();
ctx.add(reference);
providersCopy.get(i).store(reference);
}
}
super.execute(() -> {
isServerPool.set(Boolean.TRUE);
if (ctx != null) {
for (int i = 0; i < providersCopy.size(); i++) providersCopy.get(i).set(ctx.get(i));
}
Map<String, String> threadContext = MDC.getCopyOfContextMap();
final Thread currentThread = Thread.currentThread();
final String oldName = currentThread.getName();
if (submitterContext != null && !submitterContext.isEmpty()) {
MDC.setContextMap(submitterContext);
currentThread.setName(oldName + "-processing-" + submitterContextStr);
} else {
MDC.clear();
}
try {
command.run();
} catch (Throwable t) {
if (t instanceof OutOfMemoryError) {
throw t;
}
if (enableSubmitterStackTrace) {
log.error("Uncaught exception {} thrown by thread: {}", t, currentThread.getName(), submitterStackTrace);
} else {
log.error("Uncaught exception {} thrown by thread: {}", t, currentThread.getName());
}
throw t;
} finally {
isServerPool.remove();
if (threadContext != null && !threadContext.isEmpty()) {
MDC.setContextMap(threadContext);
} else {
MDC.clear();
}
if (ctx != null) {
for (int i = 0; i < providersCopy.size(); i++) providersCopy.get(i).clean(ctx.get(i));
}
currentThread.setName(oldName);
}
});
}
}
private static final ThreadLocal<Boolean> isServerPool = new ThreadLocal<>();
/// this tells whether a thread is owned/run by solr or not.
public static boolean isSolrServerThread() {
return Boolean.TRUE.equals(isServerPool.get());
}
public static void setServerThreadFlag(Boolean flag) {
if (flag == null) isServerPool.remove();
else isServerPool.set(flag);
}
}