SLING-4676 - Clean up threads or refresh threads when put back into the pool

git-svn-id: https://svn.apache.org/repos/asf/sling/trunk@1716601 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/pom.xml b/pom.xml
index e400284..95a65fc 100644
--- a/pom.xml
+++ b/pom.xml
@@ -92,5 +92,15 @@
             <groupId>org.slf4j</groupId>
             <artifactId>slf4j-api</artifactId>
         </dependency>
+        <dependency>
+            <groupId>junit</groupId>
+            <artifactId>junit</artifactId>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-simple</artifactId>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 </project>
diff --git a/src/main/java/org/apache/sling/commons/threads/ModifiableThreadPoolConfig.java b/src/main/java/org/apache/sling/commons/threads/ModifiableThreadPoolConfig.java
index 7f62afd..bfe789d 100644
--- a/src/main/java/org/apache/sling/commons/threads/ModifiableThreadPoolConfig.java
+++ b/src/main/java/org/apache/sling/commons/threads/ModifiableThreadPoolConfig.java
@@ -19,6 +19,7 @@
 import aQute.bnd.annotation.ProviderType;
 
 import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
 
 /**
  * This is a modifiable thread pool configuration that can be instantiated
@@ -45,6 +46,8 @@
     public static final String PROPERTY_MAX_POOL_SIZE = "maxPoolSize";
     /** Configuration property for the queue size. */
     public static final String PROPERTY_QUEUE_SIZE = "queueSize";
+    /** Configuration property for the max thread age. */
+    public static final String PROPERTY_MAX_THREAD_AGE = "maxThreadAge";
     /** Configuration property for the keep alive time. */
     public static final String PROPERTY_KEEP_ALIVE_TIME = "keepAliveTime";
     /** Configuration property for the block policy. */
@@ -69,6 +72,9 @@
     /** The queue size */
     private int queueSize = -1;
 
+    /** Max age of a thread in milliseconds */
+    private long maxThreadAge = TimeUnit.MINUTES.toMillis(5);
+
     /** The keep alive time. */
     private long  keepAliveTime = 60000L;
 
@@ -85,7 +91,7 @@
     private  ThreadFactory factory;
 
     /** Thread priority. */
-    private  ThreadPriority   priority = ThreadPriority.NORM;
+    private  ThreadPriority priority = ThreadPriority.NORM;
 
     /** Create daemon threads? */
     private  boolean isDaemon = false;
@@ -106,6 +112,7 @@
             this.minPoolSize = copy.getMinPoolSize();
             this.maxPoolSize = copy.getMaxPoolSize();
             this.queueSize = copy.getQueueSize();
+            this.maxThreadAge = copy.getMaxThreadAge();
             this.keepAliveTime = copy.getKeepAliveTime();
             this.blockPolicy = copy.getBlockPolicy();
             this.shutdownGraceful = copy.isShutdownGraceful();
@@ -161,6 +168,22 @@
         this.queueSize = queueSize;
     }
 
+
+    /**
+     * @see org.apache.sling.commons.threads.ThreadPoolConfig#getMaxThreadAge()
+     */
+    public long getMaxThreadAge() {
+        return maxThreadAge;
+    }
+
+    /**
+     * Set the max thread age.
+     * @param maxThreadAge New max thread age in milliseconds.
+     */
+    public void setMaxThreadAge(final long maxThreadAge) {
+        this.maxThreadAge = maxThreadAge;
+    }
+
     /**
      * @see org.apache.sling.commons.threads.ThreadPoolConfig#getKeepAliveTime()
      */
@@ -282,6 +305,7 @@
             return this.minPoolSize == o.minPoolSize
                 && this.maxPoolSize == o.maxPoolSize
                 && this.queueSize == o.queueSize
+                && this.maxThreadAge == o.maxThreadAge
                 && this.keepAliveTime == o.keepAliveTime
                 && this.blockPolicy.equals(o.blockPolicy)
                 && this.shutdownGraceful == o.shutdownGraceful
diff --git a/src/main/java/org/apache/sling/commons/threads/ThreadPoolConfig.java b/src/main/java/org/apache/sling/commons/threads/ThreadPoolConfig.java
index 74cbf76..6bb641f 100644
--- a/src/main/java/org/apache/sling/commons/threads/ThreadPoolConfig.java
+++ b/src/main/java/org/apache/sling/commons/threads/ThreadPoolConfig.java
@@ -59,6 +59,12 @@
     int getQueueSize();
 
     /**
+     * Return the maximum age before a thread is retired.
+     * @return The maximum age of a thread in milliseconds.
+     */
+    long getMaxThreadAge();
+
+    /**
      * Return the keep alive time.
      * @return The keep alive time.
      */
diff --git a/src/main/java/org/apache/sling/commons/threads/impl/DefaultThreadPool.java b/src/main/java/org/apache/sling/commons/threads/impl/DefaultThreadPool.java
index 81b8740..4f5a1a2 100644
--- a/src/main/java/org/apache/sling/commons/threads/impl/DefaultThreadPool.java
+++ b/src/main/java/org/apache/sling/commons/threads/impl/DefaultThreadPool.java
@@ -126,8 +126,11 @@
                 handler = new ThreadPoolExecutor.CallerRunsPolicy();
                 break;
         }
-        this.executor = new ThreadPoolExecutor(this.configuration.getMinPoolSize(),
+
+        this.executor = new ThreadExpiringThreadPool(this.configuration.getMinPoolSize(),
                 this.configuration.getMaxPoolSize(),
+                this.configuration.getMaxThreadAge(),
+                TimeUnit.MILLISECONDS,
                 this.configuration.getKeepAliveTime(),
                 TimeUnit.MILLISECONDS,
                 queue,
@@ -204,7 +207,7 @@
                         logger.warn("Running commands have not terminated within "
                             + this.configuration.getShutdownWaitTimeMs()
                             + "ms. Will shut them down by interruption");
-                        this.executor.shutdownNow();
+                        this.executor.shutdownNow(); // TODO: shouldn't this be outside the if statement?!
                     }
                 }
             } catch (final InterruptedException ie) {
diff --git a/src/main/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPool.java b/src/main/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPool.java
new file mode 100644
index 0000000..ec5bab4
--- /dev/null
+++ b/src/main/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPool.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sling.commons.threads.impl;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.RejectedExecutionHandler;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * An extension of ThreadPoolExecutor, which keeps track of the age
+ * of the worker threads and expires them when they get older than
+ * a specified max-age.
+ * <br/>
+ * To be precise, a thread is expired when it finishes processing
+ * a task and its max-age has been exceeded at that time. I.e. if a
+ * thread is idle past its expiry, it may still process a single
+ * task before it is expired.
+ */
+public class ThreadExpiringThreadPool extends ThreadPoolExecutor {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ThreadExpiringThreadPool.class);
+
+    /**
+     * Map from thread-id to the time (in milliseconds) when a thread was first used to
+     * process a task. This is used to look determine when a thread is to be expired.
+     */
+    private final ConcurrentHashMap<Long, Long> threadStartTimes;
+
+    /**
+     * Thread max-age in milliseconds.
+     */
+    private final long maxThreadAge;
+
+    /**
+     * Convenience flag indicating whether threads expire or not.
+     * This is equivalent to {@code maxThreadAge >= 0}.
+     */
+    private final boolean enableThreadExpiry;
+
+    /**
+     * Marker exception object thrown to terminate threads that have
+     * reached or exceeded their max-age. This exception is intentionally
+     * used for (minimal) control flow, i.e. the {@code ThreadPoolExecutor}
+     * will dispose of any thread that threw an exception and create a new
+     * one in its stead. This exception should never show up in any logs,
+     * otherwise it is a bug.
+     */
+    private final RuntimeException expiredThreadException;
+
+    public ThreadExpiringThreadPool(
+            final int corePoolSize,
+            final int maximumPoolSize,
+            final long maxThreadAge,
+            final TimeUnit maxThreadAgeUnit,
+            final long keepAliveTime,
+            final TimeUnit keepAliveTimeUnit,
+            final BlockingQueue<Runnable> workQueue,
+            final ThreadFactory threadFactory,
+            final RejectedExecutionHandler handler
+    ) {
+        super(corePoolSize, maximumPoolSize, keepAliveTime, keepAliveTimeUnit, workQueue, threadFactory, handler);
+        this.threadStartTimes = new ConcurrentHashMap<Long, Long>(maximumPoolSize);
+        this.maxThreadAge = TimeUnit.MILLISECONDS.convert(maxThreadAge, maxThreadAgeUnit);
+        this.enableThreadExpiry = maxThreadAge >= 0;
+        this.expiredThreadException = new RuntimeException("Kill old thread");
+    }
+
+    @Override
+    protected void beforeExecute(final Thread thread, final Runnable runnable) {
+        if (enableThreadExpiry) {
+            recordStartTime(thread);
+        }
+        super.beforeExecute(thread, runnable);
+    }
+
+    private void recordStartTime(final Thread thread) {
+        final long threadId = thread.getId();
+        if (threadStartTimes.putIfAbsent(threadId, System.currentTimeMillis()) == null) {
+            LOG.debug("{} used for the first time.", thread);
+
+            // The uncaught exception handler makes sure that the exception
+            // signalling the death of a thread is swallowed. All other
+            // Throwables are handed to the originalHandler.
+            final Thread.UncaughtExceptionHandler originalHandler = thread.getUncaughtExceptionHandler();
+            thread.setUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler() {
+                @Override
+                public void uncaughtException(final Thread thread, final Throwable throwable) {
+                    // first reset the original uncaught exception handler - just as a precaution
+                    thread.setUncaughtExceptionHandler(originalHandler);
+
+                    // ignore expected exception thrown to terminate the thread
+                    if (throwable == expiredThreadException) {
+                        return;
+                    }
+
+                    // delegate any other exceptions to the original uncaught exception handler
+                    if (originalHandler != null) {
+                        originalHandler.uncaughtException(thread, throwable);
+                    }
+                }
+            });
+        }
+    }
+
+    @Override
+    protected void afterExecute(final Runnable runnable, final Throwable throwable) {
+        super.afterExecute(runnable, throwable);
+        if (throwable == null && enableThreadExpiry) {
+            checkMaxThreadAge(Thread.currentThread());
+        }
+    }
+
+    private void checkMaxThreadAge(final Thread thread) {
+        final long now = System.currentTimeMillis();
+        final long threadId = thread.getId();
+        final Long started = threadStartTimes.get(threadId);
+        if (started != null && now >= started + maxThreadAge) {
+            final long delta = now - (started + maxThreadAge);
+            LOG.debug("{} exceeded its max age by {}ms and will be replaced.", thread, delta);
+            threadStartTimes.remove(threadId);
+
+            // throw marker exception to kill this thread and thus trigger creation of a new one
+            throw expiredThreadException;
+        }
+    }
+}
diff --git a/src/main/java/org/apache/sling/commons/threads/impl/ThreadPoolMBeanImpl.java b/src/main/java/org/apache/sling/commons/threads/impl/ThreadPoolMBeanImpl.java
index fe96b97..bed13a1 100644
--- a/src/main/java/org/apache/sling/commons/threads/impl/ThreadPoolMBeanImpl.java
+++ b/src/main/java/org/apache/sling/commons/threads/impl/ThreadPoolMBeanImpl.java
@@ -100,6 +100,10 @@
         }
     }
 
+    public long getMaxThreadAge() {
+        return this.entry.getConfig().getMaxThreadAge();
+    }
+
     public long getKeepAliveTime() {
         return this.entry.getConfig().getKeepAliveTime();
     }
diff --git a/src/main/java/org/apache/sling/commons/threads/jmx/ThreadPoolMBean.java b/src/main/java/org/apache/sling/commons/threads/jmx/ThreadPoolMBean.java
index 1674111..e933b91 100644
--- a/src/main/java/org/apache/sling/commons/threads/jmx/ThreadPoolMBean.java
+++ b/src/main/java/org/apache/sling/commons/threads/jmx/ThreadPoolMBean.java
@@ -83,6 +83,13 @@
     long getExecutorTaskCount();
 
     /**
+     * Return the configured max thread age.
+     *
+     * @return The configured max thread age.
+     */
+    long getMaxThreadAge();
+
+    /**
      * Return the configured keep alive time.
      * 
      * @return The configured keep alive time.
diff --git a/src/main/java/org/apache/sling/commons/threads/jmx/package-info.java b/src/main/java/org/apache/sling/commons/threads/jmx/package-info.java
index 2e9107f..53dcb2e 100644
--- a/src/main/java/org/apache/sling/commons/threads/jmx/package-info.java
+++ b/src/main/java/org/apache/sling/commons/threads/jmx/package-info.java
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-@Version("1.0.1")
+@Version("1.1.0")
 package org.apache.sling.commons.threads.jmx;
 
 import aQute.bnd.annotation.Version;
\ No newline at end of file
diff --git a/src/main/java/org/apache/sling/commons/threads/package-info.java b/src/main/java/org/apache/sling/commons/threads/package-info.java
index b7b9b97..3e2b736 100644
--- a/src/main/java/org/apache/sling/commons/threads/package-info.java
+++ b/src/main/java/org/apache/sling/commons/threads/package-info.java
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-@Version("3.2.1")
+@Version("3.3.0")
 package org.apache.sling.commons.threads;
 
 import aQute.bnd.annotation.Version;
\ No newline at end of file
diff --git a/src/main/resources/OSGI-INF/metatype/metatype.properties b/src/main/resources/OSGI-INF/metatype/metatype.properties
index 4108496..eccce0e 100644
--- a/src/main/resources/OSGI-INF/metatype/metatype.properties
+++ b/src/main/resources/OSGI-INF/metatype/metatype.properties
@@ -32,6 +32,10 @@
 queueSize.name=Queue Size
 queueSize.description=The queue size or -1 for an unlimited queue size.
 
+maxThreadAge.name=Max Thread Age
+maxThreadAge.description=Milliseconds before a pooled thread is replaced (-1 to disable expiry). \
+  Useful to avoid memory leaks by accumulation of ThreadLocals.
+
 keepAliveTime.name=Keep Alive Time
 keepAliveTime.description=The keep alive time.
 
diff --git a/src/main/resources/OSGI-INF/metatype/metatype.xml b/src/main/resources/OSGI-INF/metatype/metatype.xml
index a6c3f73..f1f7f4d 100644
--- a/src/main/resources/OSGI-INF/metatype/metatype.xml
+++ b/src/main/resources/OSGI-INF/metatype/metatype.xml
@@ -37,6 +37,9 @@
         <metatype:AD id="queueSize"
             type="Integer" default="-1" name="%queueSize.name"
             description="%queueSize.description" />
+        <metatype:AD id="maxThreadAge"
+            type="Long" default="300000" name="%maxThreadAge.name"
+            description="%maxThreadAge.description" />
         <metatype:AD id="keepAliveTime"
             type="Long" default="60000" name="%keepAliveTime.name"
             description="%keepAliveTime.description" />
diff --git a/src/test/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPoolTest.java b/src/test/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPoolTest.java
new file mode 100644
index 0000000..80c0e4f
--- /dev/null
+++ b/src/test/java/org/apache/sling/commons/threads/impl/ThreadExpiringThreadPoolTest.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.sling.commons.threads.impl;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExternalResource;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.RejectedExecutionHandler;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static java.util.Arrays.asList;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+public class ThreadExpiringThreadPoolTest {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ThreadExpiringThreadPoolTest.class);
+
+    private static final int MAX_THREAD_AGE_MS = 15; // let threads expire after this many ms
+
+    @Rule
+    public ThreadPoolContext context = new ThreadPoolContext();
+
+    @Test
+    public void shouldCreateNewThreadAfterExpiry() throws InterruptedException, ExecutionException {
+        final TrackingThreadFactory threadFactory = context.getThreadFactory();
+        final ThreadExpiringThreadPool pool = context.getPool();
+
+        assertThat(threadFactory.getThreadCount(), is(0));
+
+        assertExecutionByThread(pool, "test-thread-0");
+        assertExecutionByThread(pool, "test-thread-0");
+        assertExecutionByThread(pool, "test-thread-0");
+        assertThat(threadFactory.getThreadCount(), is(1));
+
+        letThreadsDie();
+
+        // thread executes one more task after expiring
+        assertExecutionByThread(pool, "test-thread-0");
+        assertExecutionByThread(pool, "test-thread-1");
+        assertThat(threadFactory.getThreadCount(), is(2));
+
+        assertActiveThreads(threadFactory, "test-thread-1");
+        assertExpiredThreads(threadFactory, "test-thread-0");
+    }
+
+    @Test
+    public void shouldCreateNewThreadAfterExpiryForFailingTasks() throws InterruptedException, ExecutionException {
+        final TrackingThreadFactory threadFactory = context.getThreadFactory();
+        final ThreadExpiringThreadPool pool = context.getPool();
+
+        assertThat(threadFactory.getThreadCount(), is(0));
+
+        assertFailingSubmitThreadName(pool, "test-thread-0");
+        assertFailingSubmitThreadName(pool, "test-thread-0");
+        assertFailingSubmitThreadName(pool, "test-thread-0");
+        assertThat(threadFactory.getThreadCount(), is(1));
+
+        letThreadsDie();
+
+        // thread executes one more task after expiring
+        assertFailingSubmitThreadName(pool, "test-thread-0");
+        assertFailingSubmitThreadName(pool, "test-thread-1");
+        assertThat(threadFactory.getThreadCount(), is(2));
+
+        assertActiveThreads(threadFactory, "test-thread-1");
+        assertExpiredThreads(threadFactory, "test-thread-0");
+    }
+
+    @Test
+    public void shouldLetMultipleThreadsDieAfterExpiry()
+            throws ExecutionException, InterruptedException {
+
+        final TrackingThreadFactory threadFactory = context.getThreadFactory();
+        final ThreadExpiringThreadPool pool = context.getPool();
+        pool.setCorePoolSize(3);
+        pool.setMaximumPoolSize(3);
+
+        assertParallelExecutionsByThread(pool, "test-thread-0", "test-thread-1", "test-thread-2");
+        assertThat(threadFactory.getThreadCount(), is(3));
+
+        letThreadsDie();
+        // thread executes one more task after expiring
+        executeParallelTasks(pool, 3);
+
+        assertParallelExecutionsByThread(pool, "test-thread-3", "test-thread-4", "test-thread-5");
+        assertThat(threadFactory.getThreadCount(), is(6));
+
+        assertActiveThreads(threadFactory, "test-thread-3", "test-thread-4", "test-thread-5");
+        assertExpiredThreads(threadFactory, "test-thread-0", "test-thread-1", "test-thread-2");
+    }
+
+    private void assertActiveThreads(final TrackingThreadFactory factory, final String... names) {
+        assertThat("Active threads", factory.getActiveThreads(), equalTo(asSet(names)));
+    }
+
+    private void assertExpiredThreads(final TrackingThreadFactory factory, final String... names) {
+        assertThat("Expired threads", factory.getExpiredThreads(), equalTo(asSet(names)));
+    }
+
+    private Set<String> asSet(final String... items) {
+        return new HashSet<String>(asList(items));
+    }
+
+    private void assertParallelExecutionsByThread(final ExecutorService pool, final String... expectedThreads)
+            throws InterruptedException {
+
+        final Task[] tasks = executeParallelTasks(pool, 3);
+        final List<String> threadNames = new ArrayList<String>();
+        for (final Task task : tasks) {
+            threadNames.add(task.executedBy);
+        }
+        for (final String expectedThread : expectedThreads) {
+            assertTrue("No task was executed by " + expectedThread,
+                    threadNames.remove(expectedThread));
+            assertFalse("Multiple tasks were executed by " + expectedThread,
+                    threadNames.contains(expectedThread));
+        }
+    }
+
+    private Task[] executeParallelTasks(final ExecutorService pool, final int number)
+            throws InterruptedException {
+        final Task[] tasks = new Task[number];
+        final CountDownLatch latch = new CountDownLatch(number);
+        for (int i = 0; i < tasks.length; i++) {
+            tasks[i] = new Task(latch);
+            pool.execute(tasks[i]);
+        }
+        pool.awaitTermination(MAX_THREAD_AGE_MS, TimeUnit.MILLISECONDS);
+        return tasks;
+    }
+
+    private void assertExecutionByThread(final ExecutorService pool, final String expectedThread)
+            throws ExecutionException, InterruptedException {
+        final Task task = new Task();
+        pool.submit(task).get();
+        assertEquals("Thread name", expectedThread, task.executedBy);
+    }
+
+    private void assertFailingSubmitThreadName(final ExecutorService pool, final String expectedThread)
+            throws ExecutionException, InterruptedException {
+        final Task task = new ExceptionTask();
+        try {
+            pool.submit(task).get();
+        } catch (ExecutionException e) {
+            if (!e.getCause().getMessage().startsWith("ExceptionTask #")) {
+                LOG.error("Unexpected exception: ", e);
+                fail("Unexpected exception: " + e.getMessage());
+            }
+        }
+        assertEquals("Thread name", expectedThread, task.executedBy);
+    }
+
+    private void letThreadsDie() throws InterruptedException {
+        TimeUnit.MILLISECONDS.sleep(MAX_THREAD_AGE_MS * 2);
+    }
+
+    private static class Task implements Runnable {
+
+        private static int counter = 0;
+
+        protected final int count;
+
+        private final CountDownLatch mayFinish;
+
+        protected String executedBy;
+
+        Task() {
+            this(new CountDownLatch(0));
+        }
+
+        Task(final CountDownLatch latch) {
+            this.mayFinish = latch;
+            this.count = counter++;
+        }
+
+        @Override
+        public void run() {
+            mayFinish.countDown();
+            final Thread thread = Thread.currentThread();
+            try {
+                mayFinish.await();
+            } catch (InterruptedException e) {
+                thread.interrupt();
+            }
+            LOG.info("{} #{} running in thread {}",
+                    new Object[] {getClass().getSimpleName(), count, thread});
+            executedBy = thread.getName();
+        }
+    }
+
+    private static class ExceptionTask extends Task {
+        @Override
+        public void run() {
+            super.run();
+            throw new RuntimeException("ExceptionTask #" + count);
+        }
+    }
+
+    private static class TrackingThreadFactory implements ThreadFactory {
+
+        private final ThreadGroup group;
+
+        private final AtomicInteger threadCount = new AtomicInteger(0);
+
+        private final List<Thread> threadHistory = new CopyOnWriteArrayList<Thread>();
+
+        public TrackingThreadFactory() {
+            group = Thread.currentThread().getThreadGroup();
+        }
+
+        public int getThreadCount() {
+            return threadHistory.size();
+        }
+
+        public Set<String> getActiveThreads() {
+            final HashSet<String> active = new HashSet<String>();
+            for (final Thread thread : threadHistory) {
+                if (thread.isAlive()) {
+                    active.add(thread.getName());
+                }
+            }
+            return active;
+        }
+
+        public Set<String> getExpiredThreads() {
+            final HashSet<String> expired = new HashSet<String>();
+            for (final Thread thread : threadHistory) {
+                if (!thread.isAlive()) {
+                    expired.add(thread.getName());
+                }
+            }
+            return expired;
+        }
+
+        @Override
+        public Thread newThread(final Runnable r) {
+            final Thread thread = new Thread(group, r, "test-thread-" + threadCount.getAndIncrement());
+            thread.setDaemon(false);
+            thread.setPriority(Thread.NORM_PRIORITY);
+            threadHistory.add(thread);
+            LOG.info("Created thread {}", thread.getName());
+            return thread;
+        }
+    }
+
+    public static class ThreadPoolContext extends ExternalResource {
+
+        public TrackingThreadFactory getThreadFactory() {
+            return threadFactory;
+        }
+
+        public ThreadExpiringThreadPool getPool() {
+            return pool;
+        }
+
+        private TrackingThreadFactory threadFactory;
+
+        private ThreadExpiringThreadPool pool;
+
+        @Override
+        protected void before() throws Throwable {
+            Task.counter = 0; // reset counter
+            final BlockingQueue<Runnable> queue = new ArrayBlockingQueue<Runnable>(20);
+            final RejectedExecutionHandler rejectionHandler = new ThreadPoolExecutor.AbortPolicy();
+            threadFactory = new TrackingThreadFactory();
+            pool = new ThreadExpiringThreadPool(
+                    1, 1,
+                    MAX_THREAD_AGE_MS, TimeUnit.MILLISECONDS,
+                    1000, TimeUnit.MILLISECONDS,
+                    queue, threadFactory, rejectionHandler);
+        }
+
+        @Override
+        protected void after() {
+            threadFactory = null;
+            pool = null;
+        }
+    }
+}
+