blob: ab72c9f5ff51a0ac2893d9b5ff3b6c50d9a519c4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.sling.commons.threads.impl;
import java.lang.ref.Reference;
import java.lang.reflect.Field;
import java.util.Arrays;
import org.apache.sling.commons.threads.impl.ThreadLocalChangeListener.Mode;
import org.slf4j.Logger;
/** Notifies a {@link ThreadLocalChangeListener} about changes on a thread local storage. In addition it removes all references to variables
 * being added to the thread local storage while the cleaner was running with its {@link #cleanup()} method.
 * <p>
 * A snapshot of the thread local maps of {@link Thread#currentThread()} is taken on construction and every later operation again works on
 * {@link Thread#currentThread()}, so {@link #cleanup()} is expected to be called on the same thread which created the instance.
 *
 * @see <a href="http://www.javaspecialists.eu/archive/Issue229.html">JavaSpecialist.eu - Cleaning ThreadLocals</a> */
public class ThreadLocalCleaner {
    private final ThreadLocalChangeListener listener;

    /* Reflection fields */

    /** this field is in class {@link Thread} and is of type {@code ThreadLocal.ThreadLocalMap}.
     * It is volatile and is assigned last in {@link #initReflectionFields()}, so a non-null value guarantees that all
     * other reflection fields below are initialized and visible. */
    private static volatile Field threadLocalsField;
    /** this field is in class {@link Thread} and is of type {@code ThreadLocal.ThreadLocalMap} */
    private static Field inheritableThreadLocalsField;
    private static Class<?> threadLocalMapClass;
    /** this field is in class {@code ThreadLocal.ThreadLocalMap} and contains an array of {@code ThreadLocal.ThreadLocalMap.Entry's} */
    private static Field tableField;
    private static Class<?> threadLocalMapEntryClass;
    /** this field is in class {@code ThreadLocal.ThreadLocalMap.Entry} and contains an object referencing the actual thread local
     * variable */
    private static Field threadLocalEntryValueField;
    /** this field is in the class {@code ThreadLocal.ThreadLocalMap} and contains the number of the entries */
    private static Field threadLocalMapSizeField;
    /** this field is in the class {@code ThreadLocal.ThreadLocalMap} and contains the next resize threshold */
    private static Field threadLocalMapThresholdField;
    /** set once if the reflective initialization failed; afterwards every new cleaner fails fast with this exception */
    private static volatile IllegalStateException reflectionException;

    /** Creates a cleaner and takes a snapshot of the thread local maps of the current thread.
     *
     * @param listener the listener to notify from {@link #cleanup()} about added and removed thread locals
     * @throws IllegalStateException if the JDK internals of {@link ThreadLocal} could not be accessed via reflection */
    public ThreadLocalCleaner(ThreadLocalChangeListener listener) {
        if (threadLocalsField == null || reflectionException != null) {
            initReflectionFields();
        }
        this.listener = listener;
        saveOldThreadLocals();
    }

    /** Resolves all reflection fields once. A failure is remembered and rethrown on every subsequent call.
     *
     * @throws IllegalStateException if any of the required JDK internal classes/fields is missing or not accessible */
    private static synchronized void initReflectionFields() throws IllegalStateException {
        // check if previous initialization led to an exception
        if (reflectionException != null) {
            throw reflectionException;
        }
        // check if initialized
        if (threadLocalsField != null) {
            return;
        }
        try {
            // resolve into locals first so that a failure half-way does not leave partially initialized static state
            Field threadLocals = field(Thread.class, "threadLocals");
            Field inheritableThreadLocals = field(Thread.class, "inheritableThreadLocals");
            Class<?> mapClass = inner(ThreadLocal.class, "ThreadLocalMap");
            Field table = field(mapClass, "table");
            Class<?> entryClass = inner(mapClass, "Entry");
            Field entryValue = field(entryClass, "value");
            Field mapSize = field(mapClass, "size");
            Field mapThreshold = field(mapClass, "threshold");

            inheritableThreadLocalsField = inheritableThreadLocals;
            threadLocalMapClass = mapClass;
            tableField = table;
            threadLocalMapEntryClass = entryClass;
            threadLocalEntryValueField = entryValue;
            threadLocalMapSizeField = mapSize;
            threadLocalMapThresholdField = mapThreshold;
            // this volatile write must come last: the constructor skips initialization (without locking) as soon as it
            // observes a non-null value, and the volatile write/read pair makes all plain writes above visible to it
            threadLocalsField = threadLocals;
        } catch (NoSuchFieldException | RuntimeException e) {
            // RuntimeException covers e.g. InaccessibleObjectException/SecurityException from setAccessible and the
            // IllegalStateException thrown by inner(), so that those failures are cached as well
            reflectionException = new IllegalStateException(
                    "Could not locate threadLocals field in class Thread. " +
                    "Will not be able to clear thread locals: " + e,
                    e);
            throw reflectionException;
        }
    }

    /** This is only for debugging purposes. Logs all thread locals bound to the current thread at info level.
     * Reflection failures are logged as errors instead of being propagated.
     *
     * @param log the logger to write to */
    void dump(Logger log) {
        Thread thread = Thread.currentThread();
        try {
            Object threadLocals = threadLocalsField.get(thread);
            if (threadLocals == null) {
                // the map is created lazily, so a thread which never used a thread local has none
                log.info("Found no thread locals bound to thread {}", thread);
                return;
            }
            Reference<?>[] currentReferences = (Reference<?>[]) tableField.get(threadLocals);
            int size = (int) threadLocalMapSizeField.get(threadLocals);
            log.info("Found {} thread locals bound to thread {}", size, thread);
            for (Reference<?> curRef : currentReferences) {
                if (curRef != null) {
                    log.info("Found reference {} with value {}", curRef.get(), threadLocalEntryValueField.get(curRef));
                }
            }
        } catch (IllegalArgumentException | IllegalAccessException e) {
            log.error("Can not dump thread locals for thread {}: {}", thread, e, e);
        }
    }

    /** Notifies the listener about all thread locals added or removed on the current thread since this cleaner was created
     * and afterwards restores both thread local maps to the state captured at construction time.
     * The restore is performed even if the listener throws. */
    public void cleanup() {
        try {
            // the two diff calls are only to notify the listener, the actual cleanup is done by restoreOldThreadLocals
            diff(threadLocalsField, copyOfThreadLocals);
            diff(inheritableThreadLocalsField, copyOfInheritableThreadLocals);
        } finally {
            // a failing listener must not prevent removing the thread locals added in the meantime
            restoreOldThreadLocals();
        }
    }

    /** Notifies the {@link ThreadLocalChangeListener} about changes on thread local variables for the current thread.
     * Entries are compared by identity between the current table and the backup.
     *
     * @param field is a field of {@link Thread} containing a ThreadLocalMap
     * @param backup the copy of the map's table taken on construction, or {@code null} if the map did not exist then */
    private void diff(Field field, Reference<?>[] backup) {
        try {
            Thread thread = Thread.currentThread();
            Object threadLocals = field.get(thread);
            if (threadLocals == null) {
                // no map anymore: every previously known entry is reported as removed
                if (backup != null) {
                    for (Reference<?> reference : backup) {
                        changed(thread, reference, Mode.REMOVED);
                    }
                }
                return;
            }
            Reference<?>[] current = (Reference<?>[]) tableField.get(threadLocals);
            if (backup == null) {
                // no map before: every current entry is reported as added
                for (Reference<?> reference : current) {
                    changed(thread, reference, Mode.ADDED);
                }
                return;
            }
            // nested loops - both arrays *should* be relatively small
            for (Reference<?> curRef : current) {
                if (curRef != null && !containsIdentical(backup, curRef)) {
                    // could not find it in backup - added
                    changed(thread, curRef, Mode.ADDED);
                }
            }
            for (Reference<?> backupRef : backup) {
                if (backupRef != null && !containsIdentical(current, backupRef)) {
                    // could not find it in current - removed
                    changed(thread, backupRef, Mode.REMOVED);
                }
            }
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Access denied", e);
        }
    }

    /** @param refs the array to search
     * @param ref the reference to look for
     * @return {@code true} if {@code refs} contains exactly the object {@code ref} (identity, not {@code equals}) */
    private static boolean containsIdentical(Reference<?>[] refs, Reference<?> ref) {
        for (Reference<?> candidate : refs) {
            if (candidate == ref) {
                return true;
            }
        }
        return false;
    }

    /** Forwards a single change to the listener, passing the entry's thread local key and its value.
     *
     * @param thread the thread owning the entry
     * @param reference a {@code ThreadLocalMap.Entry}, may be {@code null}
     * @param mode whether the entry was added or removed
     * @throws IllegalAccessException if the entry value could not be read */
    private void changed(Thread thread, Reference<?> reference,
            ThreadLocalChangeListener.Mode mode)
            throws IllegalAccessException {
        // just skip null reference entries (may happen if array has been resized)
        if (reference != null) {
            listener.changed(mode,
                    thread, (ThreadLocal<?>) reference.get(),
                    threadLocalEntryValueField.get(reference));
        }
    }

    /** @param c the class containing the field
     * @param name the name of the field
     * @return the field from the given class with the given name (made accessible)
     * @throws NoSuchFieldException if the class does not declare such a field */
    private static Field field(Class<?> c, String name)
            throws NoSuchFieldException {
        Field field = c.getDeclaredField(name);
        field.setAccessible(true);
        return field;
    }

    /** @param clazz the class containing the inner class
     * @param name the name of the inner class
     * @return the class with the given name, declared as inner class of the given class
     * @throws IllegalStateException if no such inner class is declared */
    private static Class<?> inner(Class<?> clazz, String name) {
        for (Class<?> c : clazz.getDeclaredClasses()) {
            if (c.getSimpleName().equals(name)) {
                return c;
            }
        }
        throw new IllegalStateException(
                "Could not find inner class " + name + " in " + clazz);
    }

    /* Snapshot of the current thread's maps taken on construction; null values mean the map did not exist. */
    private Reference<?>[] copyOfThreadLocals;
    private Integer copyOfThreadLocalsSize;
    private Integer copyOfThreadLocalsThreshold;
    private Reference<?>[] copyOfInheritableThreadLocals;
    private Integer copyOfInheritableThreadLocalsSize;
    private Integer copyOfInheritableThreadLocalsThreshold;

    /** Captures table, size and threshold of both thread local maps of the current thread. */
    private void saveOldThreadLocals() {
        copyOfThreadLocals = copy(threadLocalsField);
        copyOfThreadLocalsSize = size(threadLocalsField, threadLocalMapSizeField);
        copyOfThreadLocalsThreshold = size(threadLocalsField, threadLocalMapThresholdField);
        copyOfInheritableThreadLocals = copy(inheritableThreadLocalsField);
        copyOfInheritableThreadLocalsSize = size(inheritableThreadLocalsField, threadLocalMapSizeField);
        copyOfInheritableThreadLocalsThreshold = size(inheritableThreadLocalsField, threadLocalMapThresholdField);
    }

    /** @param field a field of {@link Thread} containing a ThreadLocalMap
     * @return a shallow copy of the map's entry table (the entries themselves are shared), or {@code null} if there is no map */
    private static Reference<?>[] copy(Field field) {
        try {
            Thread thread = Thread.currentThread();
            Object threadLocals = field.get(thread);
            if (threadLocals == null)
                return null;
            Reference<?>[] table = (Reference<?>[]) tableField.get(threadLocals);
            return Arrays.copyOf(table, table.length);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Access denied", e);
        }
    }

    /** @param field a field of {@link Thread} containing a ThreadLocalMap
     * @param sizeField an int field of {@code ThreadLocalMap} (size or threshold)
     * @return the value of {@code sizeField} for the current thread's map, or {@code null} if there is no map */
    private static Integer size(Field field, Field sizeField) {
        try {
            Thread thread = Thread.currentThread();
            Object threadLocals = field.get(thread);
            if (threadLocals == null)
                return null;
            return (Integer) sizeField.get(threadLocals);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Access denied", e);
        }
    }

    /** Puts the snapshot taken by {@link #saveOldThreadLocals()} back into both maps of the current thread. */
    private void restoreOldThreadLocals() {
        restore(inheritableThreadLocalsField, copyOfInheritableThreadLocals,
                copyOfInheritableThreadLocalsSize, copyOfInheritableThreadLocalsThreshold);
        restore(threadLocalsField, copyOfThreadLocals,
                copyOfThreadLocalsSize, copyOfThreadLocalsThreshold);
    }

    /** Restores one thread local map of the current thread.
     *
     * @param field a field of {@link Thread} containing a ThreadLocalMap
     * @param value the saved table; {@code null} means there was no map, so the map is dropped entirely
     * @param size the saved entry count (only used if {@code value} is not null)
     * @param threshold the saved resize threshold (only used if {@code value} is not null) */
    private static void restore(Field field, Object value, Integer size, Integer threshold) {
        try {
            Thread thread = Thread.currentThread();
            if (value == null) {
                field.set(thread, null);
            } else {
                final Object threadLocals = field.get(thread);
                tableField.set(threadLocals, value);
                threadLocalMapSizeField.set(threadLocals, size);
                threadLocalMapThresholdField.set(threadLocals, threshold);
            }
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Access denied", e);
        }
    }
}