blob: c4beaf9c2b64078bcdd5dd08d452e9b8541fd282 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.simulator.asm;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.ClassDefinition;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.lang.instrument.Instrumentation;
import java.lang.instrument.UnmodifiableClassException;
import java.security.ProtectionDomain;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.function.BiFunction;
import java.util.regex.Pattern;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import static org.apache.cassandra.simulator.asm.Flag.DETERMINISTIC;
import static org.apache.cassandra.simulator.asm.Flag.LOCK_SUPPORT;
import static org.apache.cassandra.simulator.asm.Flag.NO_PROXY_METHODS;
import static org.apache.cassandra.simulator.asm.Flag.SYSTEM_CLOCK;
import static org.apache.cassandra.simulator.asm.InterceptClasses.BYTECODE_VERSION;
import static org.objectweb.asm.Opcodes.ALOAD;
import static org.objectweb.asm.Opcodes.GETFIELD;
import static org.objectweb.asm.Opcodes.GETSTATIC;
import static org.objectweb.asm.Opcodes.INVOKESPECIAL;
import static org.objectweb.asm.Opcodes.INVOKESTATIC;
import static org.objectweb.asm.Opcodes.INVOKEVIRTUAL;
import static org.objectweb.asm.Opcodes.IRETURN;
import static org.objectweb.asm.Opcodes.RETURN;
// checkstyle: suppress below 'blockSystemPropertyUsage'
/**
* A mechanism for weaving classes loaded by the bootstrap classloader that we cannot override.
* The design supports weaving of the internals of these classes, and in future we may want to
* weave LockSupport or the internals of other blocking concurrency primitives.
*
* Ultimately this wasn't necessary for the initial functionality, but we have maintained
* the layout so that it will be easier to enable such functionality in future should it be needed.
*
* To this end, the asm package and simulator-asm.jar is as self-contained set of classes for performing
* simulator byteweaving, and simulator-bootstrap.jar contains a self-contained class and interface for
* replacing important system methods.
*/
public class InterceptAgent
{
public static void premain(final String agentArgs, final Instrumentation instrumentation) throws UnmodifiableClassException, ClassNotFoundException, IOException
{
setup(agentArgs, instrumentation);
}
public void agentmain(final String agentArgs, final Instrumentation instrumentation) throws UnmodifiableClassException, ClassNotFoundException, IOException
{
setup(agentArgs, instrumentation);
}
private static void setup(final String agentArgs, final Instrumentation instrumentation) throws UnmodifiableClassException, ClassNotFoundException, IOException
{
instrumentation.addTransformer(new ClassFileTransformer()
{
@Override
public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined, ProtectionDomain protectionDomain, byte[] bytecode) throws IllegalClassFormatException
{
if (className == null)
return null;
if (className.equals("java/lang/Object"))
return transformObject(bytecode);
if (className.equals("java/lang/Enum"))
return transformEnum(bytecode);
if (className.equals("java/util/Random"))
return transformRandom(bytecode);
if (className.equals("java/util/concurrent/ThreadLocalRandom"))
return transformThreadLocalRandom(bytecode);
if (className.startsWith("java/util/concurrent/ConcurrentHashMap"))
return transformConcurrent(className, bytecode, DETERMINISTIC, NO_PROXY_METHODS);
if (className.startsWith("java/util/concurrent/locks"))
return transformConcurrent(className, bytecode, SYSTEM_CLOCK, LOCK_SUPPORT, NO_PROXY_METHODS);
return null;
}
});
Pattern reloadPattern = Pattern.compile("java\\.(lang\\.Enum|util\\.concurrent\\.(locks\\..*|ConcurrentHashMap)|util\\.(concurrent\\.ThreadLocal)?Random|lang\\.Object)");
List<ClassDefinition> redefine = new ArrayList<>();
for (Class<?> loadedClass : instrumentation.getAllLoadedClasses())
{
if (reloadPattern.matcher(loadedClass.getName()).matches())
redefine.add(new ClassDefinition(loadedClass, readDefinition(loadedClass)));
}
if (!redefine.isEmpty())
instrumentation.redefineClasses(redefine.toArray(new ClassDefinition[0]));
}
private static byte[] readDefinition(Class<?> clazz) throws IOException
{
return readDefinition(clazz.getName().replaceAll("\\.", "/"));
}
private static byte[] readDefinition(String className) throws IOException
{
byte[] bytes = new byte[1024];
try (InputStream in = ClassLoader.getSystemResourceAsStream(className + ".class"))
{
int count = 0;
while (true)
{
int add = in.read(bytes, count, bytes.length - count);
if (add < 0)
break;
if (add == 0)
bytes = Arrays.copyOf(bytes, bytes.length * 2);
count += add;
}
return Arrays.copyOf(bytes, count);
}
}
/**
* We don't want Object.toString() to invoke our overridden identityHashCode by virtue of invoking some overridden hashCode()
* So we overwrite Object.toString() to replace calls to Object.hashCode() with direct calls to System.identityHashCode()
*/
private static byte[] transformObject(byte[] bytes)
{
class ObjectVisitor extends ClassVisitor
{
public ObjectVisitor(int api, ClassVisitor classVisitor)
{
super(api, classVisitor);
}
@Override
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions)
{
if (descriptor.equals("()Ljava/lang/String;") && name.equals("toString"))
return Utils.deterministicToString(super.visitMethod(access, name, descriptor, signature, exceptions));
else
return super.visitMethod(access, name, descriptor, signature, exceptions);
}
}
return transform(bytes, ObjectVisitor::new);
}
/**
* We want Enum to have a deterministic hashCode() so we simply forward calls to ordinal()
*/
private static byte[] transformEnum(byte[] bytes)
{
class EnumVisitor extends ClassVisitor
{
public EnumVisitor(int api, ClassVisitor classVisitor)
{
super(api, classVisitor);
}
@Override
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions)
{
if (descriptor.equals("()I") && name.equals("hashCode"))
{
MethodVisitor visitor = super.visitMethod(access, name, descriptor, signature, exceptions);
visitor.visitLabel(new Label());
visitor.visitIntInsn(ALOAD, 0);
visitor.visitFieldInsn(GETFIELD, "java/lang/Enum", "ordinal", "I");
visitor.visitInsn(IRETURN);
visitor.visitLabel(new Label());
visitor.visitMaxs(1, 1);
visitor.visitEnd();
return new MethodVisitor(BYTECODE_VERSION) {};
}
else
{
return super.visitMethod(access, name, descriptor, signature, exceptions);
}
}
}
return transform(bytes, EnumVisitor::new);
}
/**
* We want Random to be initialised deterministically, so we modify the default constructor to fetch
* some deterministically generated seed to pass to its seed constructor
*/
private static byte[] transformRandom(byte[] bytes)
{
class RandomVisitor extends ClassVisitor
{
public RandomVisitor(int api, ClassVisitor classVisitor)
{
super(api, classVisitor);
}
@Override
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions)
{
if (descriptor.equals("()V") && name.equals("<init>"))
{
MethodVisitor visitor = super.visitMethod(access, name, descriptor, signature, exceptions);
visitor.visitLabel(new Label());
visitor.visitIntInsn(ALOAD, 0);
visitor.visitMethodInsn(INVOKESTATIC, "org/apache/cassandra/simulator/systems/InterceptorOfSystemMethods$Global", "randomSeed", "()J", false);
visitor.visitMethodInsn(INVOKESPECIAL, "java/util/Random", "<init>", "(J)V", false);
visitor.visitInsn(RETURN);
visitor.visitLabel(new Label());
visitor.visitMaxs(3, 1);
visitor.visitEnd();
return new MethodVisitor(BYTECODE_VERSION) {};
}
else
{
return super.visitMethod(access, name, descriptor, signature, exceptions);
}
}
}
return transform(bytes, RandomVisitor::new);
}
/**
* We require ThreadLocalRandom to be deterministic, so we modify its initialisation method to invoke a
* global deterministic random value generator
*/
private static byte[] transformThreadLocalRandom(byte[] bytes)
{
class ThreadLocalRandomVisitor extends ClassVisitor
{
// CassandraRelevantProperties is not available to us here
final boolean determinismCheck = System.getProperty("cassandra.test.simulator.determinismcheck", "none").matches("relaxed|strict");
public ThreadLocalRandomVisitor(int api, ClassVisitor classVisitor)
{
super(api, classVisitor);
}
String unsafeDescriptor;
String unsafeFieldName;
@Override
public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value)
{
if (descriptor.equals("Lsun/misc/Unsafe;") || descriptor.equals("Ljdk/internal/misc/Unsafe;"))
{
unsafeFieldName = name;
unsafeDescriptor = descriptor;
}
return super.visitField(access, name, descriptor, signature, value);
}
@Override
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions)
{
if (descriptor.equals("()V") && name.equals("localInit"))
{
if (unsafeFieldName == null)
{
String version = System.getProperty("java.version");
if (version.startsWith("11.")) { unsafeFieldName = "U"; unsafeDescriptor = "Ljdk/internal/misc/Unsafe;"; }
else if (version.startsWith("1.8")) { unsafeFieldName = "UNSAFE"; unsafeDescriptor = "Lsun/misc/Unsafe;"; }
else throw new AssertionError("Unsupported Java Version");
}
MethodVisitor visitor = super.visitMethod(access, name, descriptor, signature, exceptions);
visitor.visitLabel(new Label());
visitor.visitIntInsn(ALOAD, 0);
visitor.visitFieldInsn(GETSTATIC, "java/util/concurrent/ThreadLocalRandom", unsafeFieldName, unsafeDescriptor);
visitor.visitMethodInsn(INVOKESTATIC, "java/lang/Thread", "currentThread", "()Ljava/lang/Thread;", false);
visitor.visitFieldInsn(GETSTATIC, "java/util/concurrent/ThreadLocalRandom", "SEED", "J");
visitor.visitMethodInsn(INVOKESTATIC, "org/apache/cassandra/simulator/systems/InterceptorOfSystemMethods$Global", "randomSeed", "()J", false);
String unsafeClass = Utils.descriptorToClassName(unsafeDescriptor);
visitor.visitMethodInsn(INVOKEVIRTUAL, unsafeClass, "putLong", "(Ljava/lang/Object;JJ)V", false);
visitor.visitFieldInsn(GETSTATIC, "java/util/concurrent/ThreadLocalRandom", unsafeFieldName, unsafeDescriptor);
visitor.visitMethodInsn(INVOKESTATIC, "java/lang/Thread", "currentThread", "()Ljava/lang/Thread;", false);
visitor.visitFieldInsn(GETSTATIC, "java/util/concurrent/ThreadLocalRandom", "PROBE", "J");
visitor.visitLdcInsn(0);
visitor.visitMethodInsn(INVOKEVIRTUAL, unsafeClass, "putInt", "(Ljava/lang/Object;JI)V", false);
visitor.visitInsn(RETURN);
visitor.visitLabel(new Label());
visitor.visitMaxs(6, 1);
visitor.visitEnd();
return new MethodVisitor(BYTECODE_VERSION) {};
}
else
{
MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
if (determinismCheck && (name.equals("nextSeed") || name.equals("nextSecondarySeed")))
mv = new ThreadLocalRandomCheckTransformer(api, mv);
return mv;
}
}
}
return transform(bytes, ThreadLocalRandomVisitor::new);
}
private static byte[] transform(byte[] bytes, BiFunction<Integer, ClassWriter, ClassVisitor> constructor)
{
ClassWriter out = new ClassWriter(0);
ClassReader in = new ClassReader(bytes);
ClassVisitor transform = constructor.apply(BYTECODE_VERSION, out);
in.accept(transform, 0);
return out.toByteArray();
}
private static byte[] transformConcurrent(String className, byte[] bytes, Flag flag, Flag ... flags)
{
ClassTransformer transformer = new ClassTransformer(BYTECODE_VERSION, className, EnumSet.of(flag, flags), null);
transformer.readAndTransform(bytes);
if (!transformer.isTransformed())
return null;
return transformer.toBytes();
}
}