SLING-4676 - Clean up threads or refresh threads when put back into the pool
git-svn-id: https://svn.apache.org/repos/asf/sling/trunk@1716601 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/pom.xml b/pom.xml
index e400284..95a65fc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -92,5 +92,15 @@
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-simple</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</project>
diff --git a/src/main/java/org/apache/sling/commons/threads/ModifiableThreadPoolConfig.java b/src/main/java/org/apache/sling/commons/threads/ModifiableThreadPoolConfig.java
index 7f62afd..bfe789d 100644
--- a/src/main/java/org/apache/sling/commons/threads/ModifiableThreadPoolConfig.java
+++ b/src/main/java/org/apache/sling/commons/threads/ModifiableThreadPoolConfig.java
@@ -19,6 +19,7 @@
import aQute.bnd.annotation.ProviderType;
import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
/**
* This is a modifiable thread pool configuration that can be instantiated
@@ -45,6 +46,8 @@
public static final String PROPERTY_MAX_POOL_SIZE = "maxPoolSize";
/** Configuration property for the queue size. */
public static final String PROPERTY_QUEUE_SIZE = "queueSize";
+ /** Configuration property for the max thread age. */
+ public static final String PROPERTY_MAX_THREAD_AGE = "maxThreadAge";
/** Configuration property for the keep alive time. */
public static final String PROPERTY_KEEP_ALIVE_TIME = "keepAliveTime";
/** Configuration property for the block policy. */
@@ -69,6 +72,9 @@
/** The queue size */
private int queueSize = -1;
+ /** Max age of a thread in milliseconds */
+ private long maxThreadAge = TimeUnit.MINUTES.toMillis(5);
+
/** The keep alive time. */
private long keepAliveTime = 60000L;
@@ -85,7 +91,7 @@
private ThreadFactory factory;
/** Thread priority. */
- private ThreadPriority priority = ThreadPriority.NORM;
+ private ThreadPriority priority = ThreadPriority.NORM;
/** Create daemon threads? */
private boolean isDaemon = false;
@@ -106,6 +112,7 @@
this.minPoolSize = copy.getMinPoolSize();
this.maxPoolSize = copy.getMaxPoolSize();
this.queueSize = copy.getQueueSize();
+ this.maxThreadAge = copy.getMaxThreadAge();
this.keepAliveTime = copy.getKeepAliveTime();
this.blockPolicy = copy.getBlockPolicy();
this.shutdownGraceful = copy.isShutdownGraceful();
@@ -161,6 +168,22 @@
this.queueSize = queueSize;
}
+
+ /**
+ * @see org.apache.sling.commons.threads.ThreadPoolConfig#getMaxThreadAge()
+ */
+ public long getMaxThreadAge() {
+ return maxThreadAge;
+ }
+
+ /**
+ * Set the max thread age.
+ * @param maxThreadAge New max thread age in milliseconds.
+ */
+ public void setMaxThreadAge(final long maxThreadAge) {
+ this.maxThreadAge = maxThreadAge;
+ }
+
/**
* @see org.apache.sling.commons.threads.ThreadPoolConfig#getKeepAliveTime()
*/
@@ -282,6 +305,7 @@
return this.minPoolSize == o.minPoolSize
&& this.maxPoolSize == o.maxPoolSize
&& this.queueSize == o.queueSize
+ && this.maxThreadAge == o.maxThreadAge
&& this.keepAliveTime == o.keepAliveTime
&& this.blockPolicy.equals(o.blockPolicy)
&& this.shutdownGraceful == o.shutdownGraceful
diff --git a/src/main/java/org/apache/sling/commons/threads/ThreadPoolConfig.java b/src/main/java/org/apache/sling/commons/threads/ThreadPoolConfig.java
index 74cbf76..6bb641f 100644
--- a/src/main/java/org/apache/sling/commons/threads/ThreadPoolConfig.java
+++ b/src/main/java/org/apache/sling/commons/threads/ThreadPoolConfig.java
@@ -59,6 +59,12 @@
int getQueueSize();
/**
+ * Return the maximum age before a thread is retired.
+ * @return The maximum age of a thread in milliseconds.
+ */
+ long getMaxThreadAge();
+
+ /**
* Return the keep alive time.
* @return The keep alive time.
*/
diff --git a/src/main/java/org/apache/sling/commons/threads/impl/DefaultThreadPool.java b/src/main/java/org/apache/sling/commons/threads/impl/DefaultThreadPool.java
index 81b8740..4f5a1a2 100644
--- a/src/main/java/org/apache/sling/commons/threads/impl/DefaultThreadPool.java
+++ b/src/main/java/org/apache/sling/commons/threads/impl/DefaultThreadPool.java
@@ -126,8 +126,11 @@
handler = new ThreadPoolExecutor.CallerRunsPolicy();
break;
}
- this.executor = new ThreadPoolExecutor(this.configuration.getMinPoolSize(),
+
+ this.executor = new ThreadExpiringThreadPool(this.configuration.getMinPoolSize(),
this.configuration.getMaxPoolSize(),
+ this.configuration.getMaxThreadAge(),
+ TimeUnit.MILLISECONDS,
this.configuration.getKeepAliveTime(),
TimeUnit.MILLISECONDS,
queue,
@@ -204,7 +207,7 @@
logger.warn("Running commands have not terminated within "
+ this.configuration.getShutdownWaitTimeMs()
+ "ms. Will shut them down by interruption");
- this.executor.shutdownNow();
+ this.executor.shutdownNow(); // TODO: shouldn't this be outside the if statement?!
}
}
} catch (final InterruptedException ie) {
diff --git a/src/main/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPool.java b/src/main/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPool.java
new file mode 100644
index 0000000..ec5bab4
--- /dev/null
+++ b/src/main/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPool.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sling.commons.threads.impl;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.RejectedExecutionHandler;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * An extension of ThreadPoolExecutor, which keeps track of the age
+ * of the worker threads and expires them when they get older than
+ * a specified max-age.
+ * <br/>
+ * To be precise, a thread is expired when it finishes processing
+ * a task and its max-age has been exceeded at that time. I.e. if a
+ * thread is idle past its expiry, it may still process a single
+ * task before it is expired.
+ */
+public class ThreadExpiringThreadPool extends ThreadPoolExecutor {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ThreadExpiringThreadPool.class);
+
+ /**
+ * Map from thread-id to the time (in milliseconds) when a thread was first used to
+ * process a task. This is used to look determine when a thread is to be expired.
+ */
+ private final ConcurrentHashMap<Long, Long> threadStartTimes;
+
+ /**
+ * Thread max-age in milliseconds.
+ */
+ private final long maxThreadAge;
+
+ /**
+ * Convenience flag indicating whether threads expire or not.
+ * This is equivalent to {@code maxThreadAge >= 0}.
+ */
+ private final boolean enableThreadExpiry;
+
+ /**
+ * Marker exception object thrown to terminate threads that have
+ * reached or exceeded their max-age. This exception is intentionally
+ * used for (minimal) control flow, i.e. the {@code ThreadPoolExecutor}
+ * will dispose of any thread that threw an exception and create a new
+ * one in its stead. This exception should never show up in any logs,
+ * otherwise it is a bug.
+ */
+ private final RuntimeException expiredThreadException;
+
+ public ThreadExpiringThreadPool(
+ final int corePoolSize,
+ final int maximumPoolSize,
+ final long maxThreadAge,
+ final TimeUnit maxThreadAgeUnit,
+ final long keepAliveTime,
+ final TimeUnit keepAliveTimeUnit,
+ final BlockingQueue<Runnable> workQueue,
+ final ThreadFactory threadFactory,
+ final RejectedExecutionHandler handler
+ ) {
+ super(corePoolSize, maximumPoolSize, keepAliveTime, keepAliveTimeUnit, workQueue, threadFactory, handler);
+ this.threadStartTimes = new ConcurrentHashMap<Long, Long>(maximumPoolSize);
+ this.maxThreadAge = TimeUnit.MILLISECONDS.convert(maxThreadAge, maxThreadAgeUnit);
+ this.enableThreadExpiry = maxThreadAge >= 0;
+ this.expiredThreadException = new RuntimeException("Kill old thread");
+ }
+
+ @Override
+ protected void beforeExecute(final Thread thread, final Runnable runnable) {
+ if (enableThreadExpiry) {
+ recordStartTime(thread);
+ }
+ super.beforeExecute(thread, runnable);
+ }
+
+ private void recordStartTime(final Thread thread) {
+ final long threadId = thread.getId();
+ if (threadStartTimes.putIfAbsent(threadId, System.currentTimeMillis()) == null) {
+ LOG.debug("{} used for the first time.", thread);
+
+ // The uncaught exception handler makes sure that the exception
+ // signalling the death of a thread is swallowed. All other
+ // Throwables are handed to the originalHandler.
+ final Thread.UncaughtExceptionHandler originalHandler = thread.getUncaughtExceptionHandler();
+ thread.setUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler() {
+ @Override
+ public void uncaughtException(final Thread thread, final Throwable throwable) {
+ // first reset the original uncaught exception handler - just as a precaution
+ thread.setUncaughtExceptionHandler(originalHandler);
+
+ // ignore expected exception thrown to terminate the thread
+ if (throwable == expiredThreadException) {
+ return;
+ }
+
+ // delegate any other exceptions to the original uncaught exception handler
+ if (originalHandler != null) {
+ originalHandler.uncaughtException(thread, throwable);
+ }
+ }
+ });
+ }
+ }
+
+ @Override
+ protected void afterExecute(final Runnable runnable, final Throwable throwable) {
+ super.afterExecute(runnable, throwable);
+ if (throwable == null && enableThreadExpiry) {
+ checkMaxThreadAge(Thread.currentThread());
+ }
+ }
+
+ private void checkMaxThreadAge(final Thread thread) {
+ final long now = System.currentTimeMillis();
+ final long threadId = thread.getId();
+ final Long started = threadStartTimes.get(threadId);
+ if (started != null && now >= started + maxThreadAge) {
+ final long delta = now - (started + maxThreadAge);
+ LOG.debug("{} exceeded its max age by {}ms and will be replaced.", thread, delta);
+ threadStartTimes.remove(threadId);
+
+ // throw marker exception to kill this thread and thus trigger creation of a new one
+ throw expiredThreadException;
+ }
+ }
+}
diff --git a/src/main/java/org/apache/sling/commons/threads/impl/ThreadPoolMBeanImpl.java b/src/main/java/org/apache/sling/commons/threads/impl/ThreadPoolMBeanImpl.java
index fe96b97..bed13a1 100644
--- a/src/main/java/org/apache/sling/commons/threads/impl/ThreadPoolMBeanImpl.java
+++ b/src/main/java/org/apache/sling/commons/threads/impl/ThreadPoolMBeanImpl.java
@@ -100,6 +100,10 @@
}
}
+ public long getMaxThreadAge() {
+ return this.entry.getConfig().getMaxThreadAge();
+ }
+
public long getKeepAliveTime() {
return this.entry.getConfig().getKeepAliveTime();
}
diff --git a/src/main/java/org/apache/sling/commons/threads/jmx/ThreadPoolMBean.java b/src/main/java/org/apache/sling/commons/threads/jmx/ThreadPoolMBean.java
index 1674111..e933b91 100644
--- a/src/main/java/org/apache/sling/commons/threads/jmx/ThreadPoolMBean.java
+++ b/src/main/java/org/apache/sling/commons/threads/jmx/ThreadPoolMBean.java
@@ -83,6 +83,13 @@
long getExecutorTaskCount();
/**
+ * Return the configured max thread age.
+ *
+ * @return The configured max thread age.
+ */
+ long getMaxThreadAge();
+
+ /**
* Return the configured keep alive time.
*
* @return The configured keep alive time.
diff --git a/src/main/java/org/apache/sling/commons/threads/jmx/package-info.java b/src/main/java/org/apache/sling/commons/threads/jmx/package-info.java
index 2e9107f..53dcb2e 100644
--- a/src/main/java/org/apache/sling/commons/threads/jmx/package-info.java
+++ b/src/main/java/org/apache/sling/commons/threads/jmx/package-info.java
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-@Version("1.0.1")
+@Version("1.1.0")
package org.apache.sling.commons.threads.jmx;
import aQute.bnd.annotation.Version;
\ No newline at end of file
diff --git a/src/main/java/org/apache/sling/commons/threads/package-info.java b/src/main/java/org/apache/sling/commons/threads/package-info.java
index b7b9b97..3e2b736 100644
--- a/src/main/java/org/apache/sling/commons/threads/package-info.java
+++ b/src/main/java/org/apache/sling/commons/threads/package-info.java
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-@Version("3.2.1")
+@Version("3.3.0")
package org.apache.sling.commons.threads;
import aQute.bnd.annotation.Version;
\ No newline at end of file
diff --git a/src/main/resources/OSGI-INF/metatype/metatype.properties b/src/main/resources/OSGI-INF/metatype/metatype.properties
index 4108496..eccce0e 100644
--- a/src/main/resources/OSGI-INF/metatype/metatype.properties
+++ b/src/main/resources/OSGI-INF/metatype/metatype.properties
@@ -32,6 +32,10 @@
queueSize.name=Queue Size
queueSize.description=The queue size or -1 for an unlimited queue size.
+maxThreadAge.name=Max Thread Age
+maxThreadAge.description=Milliseconds before a pooled thread is replaced (-1 to disable expiry). \
+ Useful to avoid memory leaks by accumulation of ThreadLocals.
+
keepAliveTime.name=Keep Alive Time
keepAliveTime.description=The keep alive time.
diff --git a/src/main/resources/OSGI-INF/metatype/metatype.xml b/src/main/resources/OSGI-INF/metatype/metatype.xml
index a6c3f73..f1f7f4d 100644
--- a/src/main/resources/OSGI-INF/metatype/metatype.xml
+++ b/src/main/resources/OSGI-INF/metatype/metatype.xml
@@ -37,6 +37,9 @@
<metatype:AD id="queueSize"
type="Integer" default="-1" name="%queueSize.name"
description="%queueSize.description" />
+ <metatype:AD id="maxThreadAge"
+ type="Long" default="300000" name="%maxThreadAge.name"
+ description="%maxThreadAge.description" />
<metatype:AD id="keepAliveTime"
type="Long" default="60000" name="%keepAliveTime.name"
description="%keepAliveTime.description" />
diff --git a/src/test/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPoolTest.java b/src/test/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPoolTest.java
new file mode 100644
index 0000000..80c0e4f
--- /dev/null
+++ b/src/test/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPoolTest.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sling.commons.threads.impl;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExternalResource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.RejectedExecutionHandler;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static java.util.Arrays.asList;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+public class ThreadExpiringThreadPoolTest {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ThreadExpiringThreadPoolTest.class);
+
+ private static final int MAX_THREAD_AGE_MS = 15; // let threads expire after this many ms
+
+ @Rule
+ public ThreadPoolContext context = new ThreadPoolContext();
+
+ @Test
+ public void shouldCreateNewThreadAfterExpiry() throws InterruptedException, ExecutionException {
+ final TrackingThreadFactory threadFactory = context.getThreadFactory();
+ final ThreadExpiringThreadPool pool = context.getPool();
+
+ assertThat(threadFactory.getThreadCount(), is(0));
+
+ assertExecutionByThread(pool, "test-thread-0");
+ assertExecutionByThread(pool, "test-thread-0");
+ assertExecutionByThread(pool, "test-thread-0");
+ assertThat(threadFactory.getThreadCount(), is(1));
+
+ letThreadsDie();
+
+ // thread executes one more task after expiring
+ assertExecutionByThread(pool, "test-thread-0");
+ assertExecutionByThread(pool, "test-thread-1");
+ assertThat(threadFactory.getThreadCount(), is(2));
+
+ assertActiveThreads(threadFactory, "test-thread-1");
+ assertExpiredThreads(threadFactory, "test-thread-0");
+ }
+
+ @Test
+ public void shouldCreateNewThreadAfterExpiryForFailingTasks() throws InterruptedException, ExecutionException {
+ final TrackingThreadFactory threadFactory = context.getThreadFactory();
+ final ThreadExpiringThreadPool pool = context.getPool();
+
+ assertThat(threadFactory.getThreadCount(), is(0));
+
+ assertFailingSubmitThreadName(pool, "test-thread-0");
+ assertFailingSubmitThreadName(pool, "test-thread-0");
+ assertFailingSubmitThreadName(pool, "test-thread-0");
+ assertThat(threadFactory.getThreadCount(), is(1));
+
+ letThreadsDie();
+
+ // thread executes one more task after expiring
+ assertFailingSubmitThreadName(pool, "test-thread-0");
+ assertFailingSubmitThreadName(pool, "test-thread-1");
+ assertThat(threadFactory.getThreadCount(), is(2));
+
+ assertActiveThreads(threadFactory, "test-thread-1");
+ assertExpiredThreads(threadFactory, "test-thread-0");
+ }
+
+ @Test
+ public void shouldLetMultipleThreadsDieAfterExpiry()
+ throws ExecutionException, InterruptedException {
+
+ final TrackingThreadFactory threadFactory = context.getThreadFactory();
+ final ThreadExpiringThreadPool pool = context.getPool();
+ pool.setCorePoolSize(3);
+ pool.setMaximumPoolSize(3);
+
+ assertParallelExecutionsByThread(pool, "test-thread-0", "test-thread-1", "test-thread-2");
+ assertThat(threadFactory.getThreadCount(), is(3));
+
+ letThreadsDie();
+ // thread executes one more task after expiring
+ executeParallelTasks(pool, 3);
+
+ assertParallelExecutionsByThread(pool, "test-thread-3", "test-thread-4", "test-thread-5");
+ assertThat(threadFactory.getThreadCount(), is(6));
+
+ assertActiveThreads(threadFactory, "test-thread-3", "test-thread-4", "test-thread-5");
+ assertExpiredThreads(threadFactory, "test-thread-0", "test-thread-1", "test-thread-2");
+ }
+
+ private void assertActiveThreads(final TrackingThreadFactory factory, final String... names) {
+ assertThat("Active threads", factory.getActiveThreads(), equalTo(asSet(names)));
+ }
+
+ private void assertExpiredThreads(final TrackingThreadFactory factory, final String... names) {
+ assertThat("Expired threads", factory.getExpiredThreads(), equalTo(asSet(names)));
+ }
+
+ private Set<String> asSet(final String... items) {
+ return new HashSet<String>(asList(items));
+ }
+
+ private void assertParallelExecutionsByThread(final ExecutorService pool, final String... expectedThreads)
+ throws InterruptedException {
+
+ final Task[] tasks = executeParallelTasks(pool, 3);
+ final List<String> threadNames = new ArrayList<String>();
+ for (final Task task : tasks) {
+ threadNames.add(task.executedBy);
+ }
+ for (final String expectedThread : expectedThreads) {
+ assertTrue("No task was executed by " + expectedThread,
+ threadNames.remove(expectedThread));
+ assertFalse("Multiple tasks were executed by " + expectedThread,
+ threadNames.contains(expectedThread));
+ }
+ }
+
+ private Task[] executeParallelTasks(final ExecutorService pool, final int number)
+ throws InterruptedException {
+ final Task[] tasks = new Task[number];
+ final CountDownLatch latch = new CountDownLatch(number);
+ for (int i = 0; i < tasks.length; i++) {
+ tasks[i] = new Task(latch);
+ pool.execute(tasks[i]);
+ }
+ pool.awaitTermination(MAX_THREAD_AGE_MS, TimeUnit.MILLISECONDS);
+ return tasks;
+ }
+
+ private void assertExecutionByThread(final ExecutorService pool, final String expectedThread)
+ throws ExecutionException, InterruptedException {
+ final Task task = new Task();
+ pool.submit(task).get();
+ assertEquals("Thread name", expectedThread, task.executedBy);
+ }
+
+ private void assertFailingSubmitThreadName(final ExecutorService pool, final String expectedThread)
+ throws ExecutionException, InterruptedException {
+ final Task task = new ExceptionTask();
+ try {
+ pool.submit(task).get();
+ } catch (ExecutionException e) {
+ if (!e.getCause().getMessage().startsWith("ExceptionTask #")) {
+ LOG.error("Unexpected exception: ", e);
+ fail("Unexpected exception: " + e.getMessage());
+ }
+ }
+ assertEquals("Thread name", expectedThread, task.executedBy);
+ }
+
+ private void letThreadsDie() throws InterruptedException {
+ TimeUnit.MILLISECONDS.sleep(MAX_THREAD_AGE_MS * 2);
+ }
+
+ private static class Task implements Runnable {
+
+ private static int counter = 0;
+
+ protected final int count;
+
+ private final CountDownLatch mayFinish;
+
+ protected String executedBy;
+
+ Task() {
+ this(new CountDownLatch(0));
+ }
+
+ Task(final CountDownLatch latch) {
+ this.mayFinish = latch;
+ this.count = counter++;
+ }
+
+ @Override
+ public void run() {
+ mayFinish.countDown();
+ final Thread thread = Thread.currentThread();
+ try {
+ mayFinish.await();
+ } catch (InterruptedException e) {
+ thread.interrupt();
+ }
+ LOG.info("{} #{} running in thread {}",
+ new Object[] {getClass().getSimpleName(), count, thread});
+ executedBy = thread.getName();
+ }
+ }
+
+ private static class ExceptionTask extends Task {
+ @Override
+ public void run() {
+ super.run();
+ throw new RuntimeException("ExceptionTask #" + count);
+ }
+ }
+
+ private static class TrackingThreadFactory implements ThreadFactory {
+
+ private final ThreadGroup group;
+
+ private final AtomicInteger threadCount = new AtomicInteger(0);
+
+ private final List<Thread> threadHistory = new CopyOnWriteArrayList<Thread>();
+
+ public TrackingThreadFactory() {
+ group = Thread.currentThread().getThreadGroup();
+ }
+
+ public int getThreadCount() {
+ return threadHistory.size();
+ }
+
+ public Set<String> getActiveThreads() {
+ final HashSet<String> active = new HashSet<String>();
+ for (final Thread thread : threadHistory) {
+ if (thread.isAlive()) {
+ active.add(thread.getName());
+ }
+ }
+ return active;
+ }
+
+ public Set<String> getExpiredThreads() {
+ final HashSet<String> expired = new HashSet<String>();
+ for (final Thread thread : threadHistory) {
+ if (!thread.isAlive()) {
+ expired.add(thread.getName());
+ }
+ }
+ return expired;
+ }
+
+ @Override
+ public Thread newThread(final Runnable r) {
+ final Thread thread = new Thread(group, r, "test-thread-" + threadCount.getAndIncrement());
+ thread.setDaemon(false);
+ thread.setPriority(Thread.NORM_PRIORITY);
+ threadHistory.add(thread);
+ LOG.info("Created thread {}", thread.getName());
+ return thread;
+ }
+ }
+
+ public static class ThreadPoolContext extends ExternalResource {
+
+ public TrackingThreadFactory getThreadFactory() {
+ return threadFactory;
+ }
+
+ public ThreadExpiringThreadPool getPool() {
+ return pool;
+ }
+
+ private TrackingThreadFactory threadFactory;
+
+ private ThreadExpiringThreadPool pool;
+
+ @Override
+ protected void before() throws Throwable {
+ Task.counter = 0; // reset counter
+ final BlockingQueue<Runnable> queue = new ArrayBlockingQueue<Runnable>(20);
+ final RejectedExecutionHandler rejectionHandler = new ThreadPoolExecutor.AbortPolicy();
+ threadFactory = new TrackingThreadFactory();
+ pool = new ThreadExpiringThreadPool(
+ 1, 1,
+ MAX_THREAD_AGE_MS, TimeUnit.MILLISECONDS,
+ 1000, TimeUnit.MILLISECONDS,
+ queue, threadFactory, rejectionHandler);
+ }
+
+ @Override
+ protected void after() {
+ threadFactory = null;
+ pool = null;
+ }
+ }
+}
+