// blob: e6c6226f06d1df365cc898250ef83e201248993b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.commons.weaver.privilizer;
import java.io.InputStream;
import java.lang.invoke.LambdaMetafactory;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.tuple.Pair;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.Handle;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.commons.AdviceAdapter;
import org.objectweb.asm.commons.GeneratorAdapter;
import org.objectweb.asm.commons.Method;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.FieldNode;
import org.objectweb.asm.tree.MethodNode;
/**
* {@link ClassVisitor} to import so-called "blueprint methods".
*/
class BlueprintingVisitor extends Privilizer.PrivilizerClassVisitor {
/**
 * Holder for the pieces of a parsed class that this visitor consults: its access flags, the internal name of its
 * superclass, and its declared fields and methods. The maps are stored as supplied (no defensive copy).
 */
static class TypeInfo {
    /** Class-level access flags. */
    final int access;
    /** Internal name of the superclass as reported by the class file. */
    final String superName;
    /** Declared fields keyed by field name. */
    final Map<String, FieldNode> fields;
    /** Declared methods keyed by name and descriptor. */
    final Map<Method, MethodNode> methods;

    TypeInfo(int access, String superName, Map<String, FieldNode> fields, Map<Method, MethodNode> methods) {
        this.methods = methods;
        this.fields = fields;
        this.superName = superName;
        this.access = access;
    }
}
/** ASM {@link Type} of {@link LambdaMetafactory}; its internal name is matched against bootstrap method owners. */
private static final Type LAMBDA_METAFACTORY = Type.getType(LambdaMetafactory.class);
/**
 * Build the lookup key used by the method registries of this visitor.
 * @param owner internal name of the declaring type
 * @param name method name
 * @param desc method descriptor
 * @return pair of owner {@link Type} and {@link Method}
 */
private static Pair<Type, Method> methodKey(String owner, String name, String desc) {
    final Type ownerType = Type.getObjectType(owner);
    final Method method = new Method(name, desc);
    return Pair.of(ownerType, method);
}
// types named as blueprints by the Privilizing configuration; the woven class itself may not be among them
private final Set<Type> blueprintTypes = new HashSet<>();
// blueprint methods eligible for import, keyed by (declaring type, method)
private final Map<Pair<Type, Method>, MethodNode> blueprintRegistry = new HashMap<>();
// methods imported so far, mapped to the generated name of their private copy
private final Map<Pair<Type, Method>, String> importedMethods = new HashMap<>();
// cache of parsed class information, one entry per type that has been read
private final Map<Type, TypeInfo> typeInfoCache = new HashMap<>();
// cache of field descriptions keyed by (owner type, field name)
private final Map<Pair<Type, String>, FieldAccess> fieldAccessMap = new HashMap<>();
// visitor that receives the buffered class once visiting has ended
private final ClassVisitor nextVisitor;
/**
 * Create a new {@link BlueprintingVisitor}. Every type listed by {@code config} is recorded as a blueprint type
 * and its methods (all of them, or only those whose names are listed) are registered as importable.
 * @param privilizer owner
 * @param nextVisitor wrapped
 * @param config annotation
 */
BlueprintingVisitor(@SuppressWarnings("PMD.UnusedFormalParameter") final Privilizer privilizer, //false positive
    final ClassVisitor nextVisitor,
    final Privilizing config) {
    privilizer.super(new ClassNode(Privilizer.ASM_VERSION));
    this.nextVisitor = nextVisitor;

    // register the blueprint methods of each configured type:
    for (final Privilizing.CallTo callTo : config.value()) {
        final Type blueprintType = Type.getType(callTo.value());
        blueprintTypes.add(blueprintType);

        final Set<String> requestedNames = new HashSet<>(Arrays.asList(callTo.methods()));
        // an empty name list means "every method of the blueprint type"
        final boolean importAll = requestedNames.isEmpty();

        for (final Map.Entry<Method, MethodNode> candidate : typeInfo(blueprintType).methods.entrySet()) {
            final Method method = candidate.getKey();
            if (importAll || requestedNames.contains(method.getName())) {
                blueprintRegistry.put(Pair.of(blueprintType, method), candidate.getValue());
            }
        }
    }
}
/**
 * Obtain the (cached) {@link TypeInfo} for the specified type, reading its class file on first use.
 * @param type to describe
 * @return {@link TypeInfo}
 */
private TypeInfo typeInfo(Type type) {
    return typeInfoCache.computeIfAbsent(type, this::loadTypeInfo);
}

/**
 * Read the class file of {@code type} and index its declared fields by name and its declared methods by
 * name/descriptor.
 */
private TypeInfo loadTypeInfo(final Type type) {
    final ClassNode classNode = read(type.getClassName());

    final Map<String, FieldNode> fieldsByName =
        classNode.fields.stream().collect(Collectors.toMap(field -> field.name, Function.identity()));

    final Map<Method, MethodNode> methodsBySignature = classNode.methods.stream()
        .collect(Collectors.toMap(method -> new Method(method.name, method.desc), Function.identity()));

    return new TypeInfo(classNode.access, classNode.superName, fieldsByName, methodsBySignature);
}
/**
 * Parse the class file of the named class into a {@link ClassNode}, skipping debug information and expanding
 * stack map frames.
 * @param className name of the class to read, as understood by the weave environment
 * @return populated {@link ClassNode}
 * @throws RuntimeException wrapping any failure to locate or parse the class file; the message names the class
 */
private ClassNode read(final String className) {
    final ClassNode result = new ClassNode(Privilizer.ASM_VERSION);
    try (InputStream bytecode = privilizer().env.getClassfile(className).getInputStream()) {
        new ClassReader(bytecode).accept(result, ClassReader.SKIP_DEBUG | ClassReader.EXPAND_FRAMES);
    } catch (final Exception e) {
        // keep the cause, but say which class could not be read so the failure is diagnosable
        throw new RuntimeException(String.format("Unable to read class %s", className), e);
    }
    return result;
}
/**
 * Reject a class that names itself as one of its own blueprint types, then continue visiting.
 */
@Override
@SuppressWarnings("PMD.UseVarargs") //overridden method
public void visit(final int version, final int access, final String name, final String signature,
    final String superName, final String[] interfaces) {
    final Type visited = Type.getObjectType(name);
    final boolean selfBlueprint = blueprintTypes.contains(visited);
    Validate.isTrue(!selfBlueprint, "Class %s cannot declare itself as a blueprint!", name);
    super.visit(version, access, name, signature, superName, interfaces);
}
/**
 * Wrap the visitor of each method so that invocations of registered blueprint methods are imported.
 */
@Override
@SuppressWarnings("PMD.UseVarargs") //overridden method
public MethodVisitor visitMethod(final int access, final String name, final String desc, final String signature,
    final String[] exceptions) {
    final MethodVisitor delegate = super.visitMethod(access, name, desc, signature, exceptions);

    // at the top level only explicitly registered blueprint methods qualify for import
    return new MethodInvocationHandler(delegate) {
        @Override
        boolean shouldImport(final Pair<Type, Method> candidate) {
            return blueprintRegistry.containsKey(candidate);
        }
    };
}
/**
 * Copy the specified blueprint method into the class being woven as a private static synthetic method and
 * return the name of the copy. Repeated requests for the same key return the previously generated name.
 * @param key declaring type and method to import
 * @return name of the imported method
 * @throws IllegalStateException if the declaring type does not itself declare the requested method
 */
private String importMethod(final Pair<Type, Method> key) {
    final String alreadyImported = importedMethods.get(key);
    if (alreadyImported != null) {
        return alreadyImported;
    }
    final Type blueprintType = key.getLeft();
    final Method blueprintMethod = key.getRight();

    // resolve the source before registering anything, so a failed lookup reports a clear error
    // (instead of a NullPointerException) and leaves no dangling entry in importedMethods:
    final MethodNode source = typeInfo(blueprintType).methods.get(blueprintMethod);
    Validate.validState(source != null, "Could not locate method %s.%s", blueprintType.getClassName(),
        blueprintMethod);

    final String result =
        new StringBuilder(blueprintType.getInternalName().replace('/', '_')).append("$$")
            .append(blueprintMethod.getName()).toString();

    // register before visiting the body so that recursive/mutual calls resolve to the same name:
    importedMethods.put(key, result);
    privilizer().env.debug("importing %s#%s as %s", blueprintType.getClassName(), blueprintMethod, result);
    final int access = Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC | Opcodes.ACC_SYNTHETIC;

    final String[] exceptions = source.exceptions.toArray(ArrayUtils.EMPTY_STRING_ARRAY);

    // first pass: collect the non-public fields accessed by the method
    final Set<FieldAccess> fieldAccesses = new LinkedHashSet<>();

    source.accept(new MethodVisitor(Privilizer.ASM_VERSION) {
        @Override
        public void visitFieldInsn(final int opcode, final String owner, final String name, final String desc) {
            final FieldAccess fieldAccess = fieldAccess(Type.getObjectType(owner), name);

            super.visitFieldInsn(opcode, owner, name, desc);
            if (!Modifier.isPublic(fieldAccess.access)) {
                fieldAccesses.add(fieldAccess);
            }
        }
    });

    final MethodNode withAccessibleAdvice =
        new MethodNode(access, result, source.desc, source.signature, exceptions);

    // spider own methods:
    MethodVisitor mv = new NestedMethodInvocationHandler(withAccessibleAdvice, key); //NOPMD

    if (!fieldAccesses.isEmpty()) {
        mv = new AccessibleAdvisor(mv, access, result, source.desc, new ArrayList<>(fieldAccesses));
    }
    // second pass: replay the body through the handlers into the new method node
    source.accept(mv);

    // private can only be called by other privileged methods, so no need to mark as privileged
    if (!Modifier.isPrivate(source.access)) {
        withAccessibleAdvice.visitAnnotation(Type.getType(Privileged.class).getDescriptor(), false).visitEnd();
    }
    withAccessibleAdvice.accept(this.cv);

    return result;
}
/**
 * Obtain the (cached) {@link FieldAccess} describing the named field as declared by {@code owner}.
 * @param owner declaring type
 * @param name field name
 * @return {@link FieldAccess}
 * @throws IllegalStateException if {@code owner} declares no such field
 */
private FieldAccess fieldAccess(final Type owner, final String name) {
    final Pair<Type, String> cacheKey = Pair.of(owner, name);
    return fieldAccessMap.computeIfAbsent(cacheKey, this::resolveFieldAccess);
}

/**
 * Look up the field identified by {@code key} among the declared fields of its owner type.
 */
private FieldAccess resolveFieldAccess(final Pair<Type, String> key) {
    final Type declaringType = key.getLeft();
    final String fieldName = key.getRight();
    final FieldNode fieldNode = typeInfo(declaringType).fields.get(fieldName);
    Validate.validState(fieldNode != null, "Could not locate %s.%s", declaringType.getClassName(), fieldName);
    return new FieldAccess(fieldNode.access, declaringType, fieldNode.name, Type.getType(fieldNode.desc));
}
/**
 * Finish visiting, then replay the buffered class into the wrapped visitor.
 */
@Override
public void visitEnd() {
    super.visitEnd();
    // the delegate of this visitor is the ClassNode created in the constructor
    final ClassNode buffered = (ClassNode) cv;
    buffered.accept(nextVisitor);
}
/**
 * {@link MethodVisitor} that redirects static invocations (direct calls and lambda implementation handles) of
 * importable methods to their imported copies in the class being woven. Subclasses decide what is importable.
 */
private abstract class MethodInvocationHandler extends MethodVisitor {
    MethodInvocationHandler(final MethodVisitor mvr) {
        super(Privilizer.ASM_VERSION, mvr);
    }

    /**
     * Learn whether the specified method should be imported.
     * @param methodKey declaring type and method
     * @return {@code true} to import
     */
    abstract boolean shouldImport(Pair<Type, Method> methodKey);

    /**
     * Hook invoked for a lambda whose single implementation handle was left untouched; no-op by default.
     * @param handle implementation handle
     */
    protected void validateLambda(Handle handle) {
    }

    /**
     * Hook invoked for every method instruction that is not redirected; forwards the instruction unchanged.
     */
    protected void visitNonImportedMethodInsn(final int opcode, final String owner, final String name,
        final String desc, final boolean itf) {
        super.visitMethodInsn(opcode, owner, name, desc, itf);
    }

    @Override
    public void visitMethodInsn(final int opcode, final String owner, final String name, final String desc,
        final boolean itf) {
        // only static invocations are candidates for import
        final Pair<Type, Method> key = opcode == Opcodes.INVOKESTATIC ? methodKey(owner, name, desc) : null;

        if (key != null && shouldImport(key)) {
            final String importedName = importMethod(key);
            super.visitMethodInsn(opcode, className, importedName, desc, itf);
        } else {
            visitNonImportedMethodInsn(opcode, owner, name, desc, itf);
        }
    }

    @Override
    public void visitInvokeDynamicInsn(String name, String descriptor, Handle bootstrapMethodHandle,
        Object... bootstrapMethodArguments) {

        if (!isLambda(bootstrapMethodHandle)) {
            super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
            return;
        }
        // sole handle found among the bootstrap arguments, if any:
        Handle handle = null;
        // copy of the bootstrap arguments with the handle redirected; null while nothing was redirected:
        Object[] rewritten = null;

        for (int i = 0; i < bootstrapMethodArguments.length; i++) {
            if (!(bootstrapMethodArguments[i] instanceof Handle)) {
                continue;
            }
            if (handle != null) {
                // we don't know what to do with multiple handles; skip the whole thing:
                rewritten = null;
                break;
            }
            handle = (Handle) bootstrapMethodArguments[i];
            if (handle.getTag() != Opcodes.H_INVOKESTATIC) {
                continue;
            }
            final Pair<Type, Method> key = methodKey(handle.getOwner(), handle.getName(), handle.getDesc());
            if (shouldImport(key)) {
                final String importedName = importMethod(key);
                rewritten = bootstrapMethodArguments.clone();
                rewritten[i] = new Handle(handle.getTag(), className, importedName, handle.getDesc(), false);
            }
        }
        if (handle != null) {
            if (rewritten != null) {
                super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, rewritten);
                return;
            }
            validateLambda(handle);
        }
        super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments);
    }

    /**
     * Learn whether the bootstrap handle is the static {@code LambdaMetafactory.metafactory} method.
     */
    private boolean isLambda(Handle handle) {
        if (handle.getTag() != Opcodes.H_INVOKESTATIC) {
            return false;
        }
        final boolean ownedByMetafactory = LAMBDA_METAFACTORY.getInternalName().equals(handle.getOwner());
        return ownedByMetafactory && "metafactory".equals(handle.getName());
    }
}
/**
 * {@link MethodInvocationHandler} applied to the body of a blueprint method while it is being imported. Static
 * methods declared by the blueprint type or one of its supertypes are imported in turn; everything else the body
 * references must be accessible from the class being woven, otherwise an {@link IllegalStateException} is thrown.
 */
class NestedMethodInvocationHandler extends MethodInvocationHandler {
    /** Key of the blueprint method whose body is being visited. */
    final Pair<Type, Method> methodKey;
    /** Type declaring the blueprint method being visited. */
    final Type owner;

    NestedMethodInvocationHandler(final MethodVisitor mvr, final Pair<Type,Method> methodKey) {
        super(mvr);
        this.methodKey = methodKey;
        this.owner = methodKey.getLeft();
    }

    @Override
    protected void visitNonImportedMethodInsn(int opcode, String owner, String name, String desc, boolean itf) {
        final Type ownerType = Type.getObjectType(owner);
        final Method m = new Method(name, desc);

        if (isAccessible(ownerType) && isAccessible(ownerType, m)) {
            super.visitNonImportedMethodInsn(opcode, owner, name, desc, itf);
        } else {
            throw new IllegalStateException(String.format("Blueprint method %s.%s calls inaccessible method %s.%s",
                this.owner, methodKey.getRight(), owner, m));
        }
    }

    @Override
    protected void validateLambda(Handle handle) {
        super.validateLambda(handle);
        final Type ownerType = Type.getObjectType(handle.getOwner());
        final Method m = new Method(handle.getName(), handle.getDesc());

        if (!(isAccessible(ownerType) && isAccessible(ownerType, m))) {
            throw new IllegalStateException(
                String.format("Blueprint method %s.%s utilizes inaccessible method reference %s::%s", owner,
                    methodKey.getRight(), handle.getOwner(), m));
        }
    }

    @Override
    boolean shouldImport(final Pair<Type, Method> methodKey) {
        // call anything called within a class hierarchy:
        final Type called = methodKey.getLeft();
        // "I prefer the short cut":
        if (called.equals(owner)) {
            return true;
        }
        try {
            final Class<?> inner = load(called);
            final Class<?> outer = load(owner);
            return inner.isAssignableFrom(outer);
        } catch (final ClassNotFoundException e) {
            // a type that cannot be loaded is treated as outside the hierarchy
            return false;
        }
    }

    private Class<?> load(final Type type) throws ClassNotFoundException {
        return privilizer().env.classLoader.loadClass(type.getClassName());
    }

    /**
     * Learn whether the specified type can be referenced from the class being woven.
     */
    private boolean isAccessible(Type type) {
        if (type.getSort() == Type.ARRAY) {
            // arrays have no class file of their own; they are as accessible as their element type
            final Type elementType = type.getElementType();
            return elementType.getSort() != Type.OBJECT || isAccessible(elementType);
        }
        final TypeInfo typeInfo = typeInfo(type);
        return isAccessible(type, typeInfo.access);
    }

    /**
     * Learn whether method {@code m}, invoked on {@code type}, can be called from the class being woven. The
     * method is searched in {@code type} and then up its superclass chain.
     * @throws IllegalStateException if no class in the chain declares the method
     */
    private boolean isAccessible(Type type, Method m) {
        if (type.getSort() == Type.ARRAY) {
            // methods invoked on arrays are the public methods inherited from Object (clone is public on arrays)
            return true;
        }
        Type t = type;
        while (t != null) {
            final TypeInfo typeInfo = typeInfo(t);
            final MethodNode methodNode = typeInfo.methods.get(m);
            if (methodNode == null) {
                t = Optional.ofNullable(typeInfo.superName).map(Type::getObjectType).orElse(null);
                continue;
            }
            // package-private access is governed by the package of the declaring class, not of the type
            // the invocation happened to name:
            return isAccessible(t, methodNode.access);
        }
        throw new IllegalStateException(String.format("Cannot find method %s.%s", type, m));
    }

    /**
     * Evaluate access flags of a member or class declared by {@code type}: public is accessible, protected and
     * private are not, package-private only when {@code type} shares the package of the class being woven.
     */
    private boolean isAccessible(Type type, int access) {
        if (Modifier.isPublic(access)) {
            return true;
        }
        if (Modifier.isProtected(access) || Modifier.isPrivate(access)) {
            return false;
        }
        return Stream.of(target, type).map(this::packageName).distinct().count() == 1;
    }

    /**
     * Internal-form package name of the specified type; the empty string for the default package.
     */
    private String packageName(Type type) {
        final String internalName = type.getInternalName();
        return internalName.indexOf('/') < 0 ? StringUtils.EMPTY
            : StringUtils.substringBeforeLast(internalName, "/");
    }
}
/**
 * For every non-public referenced field of an imported method, replaces the field instruction with reflective
 * calls; accesses to public fields are emitted unchanged. On method entry each such field is looked up and, if
 * not already accessible, made accessible; on every exit path the fields that were made accessible here are
 * reset to inaccessible.
 */
private class AccessibleAdvisor extends AdviceAdapter {
    final Type bitSetType = Type.getType(BitSet.class);
    final Type classType = Type.getType(Class.class);
    final Type fieldType = Type.getType(java.lang.reflect.Field.class);
    final Type fieldArrayType = Type.getType(java.lang.reflect.Field[].class);
    final Type stringType = Type.getType(String.class);
    /** Non-public fields accessed by the method; list position == index into the generated Field[] local. */
    final List<FieldAccess> fieldAccesses;
    /** Start of the region protected by the generated try/finally. */
    final Label begin = new Label();
    /** Local holding the Field[] of reflected fields. */
    int localFieldArray;
    /** Local holding a BitSet whose set bits mark fields that were already accessible on entry. */
    int bitSet;
    /** Local int used as loop/array index by the generated code. */
    int fieldCounter;

    AccessibleAdvisor(final MethodVisitor mvr, final int access, final String name, final String desc,
        final List<FieldAccess> fieldAccesses) {
        super(Privilizer.ASM_VERSION, mvr, access, name, desc);
        this.fieldAccesses = fieldAccesses;
    }

    @Override
    protected void onMethodEnter() {
        localFieldArray = newLocal(fieldArrayType);
        bitSet = newLocal(bitSetType);
        fieldCounter = newLocal(Type.INT_TYPE);

        // create localFieldArray
        push(fieldAccesses.size());
        newArray(fieldArrayType.getElementType());
        storeLocal(localFieldArray);

        // create bitSet
        newInstance(bitSetType);
        dup();
        push(fieldAccesses.size());
        invokeConstructor(bitSetType, Method.getMethod("void <init>(int)"));
        storeLocal(bitSet);

        // populate localFieldArray
        push(0);
        storeLocal(fieldCounter);
        for (final FieldAccess access : fieldAccesses) {
            prehandle(access);
            iinc(fieldCounter, 1);
        }
        // the try/finally region starts only after all fields have been resolved
        mark(begin);
    }

    /**
     * Generate code that reflects one field, stores it at localFieldArray[fieldCounter] and either records that
     * it was already accessible or makes it accessible.
     */
    private void prehandle(final FieldAccess access) {
        // push owner.class literal
        visitLdcInsn(access.owner);
        push(access.name);
        final Label next = new Label();
        invokeVirtual(classType, new Method("getDeclaredField", fieldType, new Type[] { stringType }));

        dup();
        // store the field at localFieldArray[fieldCounter]:
        loadLocal(localFieldArray);
        swap();
        loadLocal(fieldCounter);
        swap();
        arrayStore(fieldArrayType.getElementType());

        dup();
        invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("boolean isAccessible()"));

        final Label setAccessible = new Label();
        // if false, setAccessible:
        ifZCmp(EQ, setAccessible);

        // else pop field instance
        pop();
        // and record that he was already accessible:
        loadLocal(bitSet);
        loadLocal(fieldCounter);
        invokeVirtual(bitSetType, Method.getMethod("void set(int)"));
        goTo(next);

        mark(setAccessible);
        push(true);
        invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("void setAccessible(boolean)"));

        mark(next);
    }

    /**
     * Replace an access to a non-public field with the equivalent reflective get/set on the pre-resolved
     * {@code Field}; forward accesses to public fields untouched.
     */
    @Override
    public void visitFieldInsn(final int opcode, final String owner, final String name, final String desc) {
        final Pair<Type, String> key = Pair.of(Type.getObjectType(owner), name);
        final FieldAccess fieldAccess = fieldAccessMap.get(key);
        final int fieldIndex = fieldAccesses.indexOf(fieldAccess);

        if (fieldIndex < 0) {
            // only non-public fields are tracked for reflective access; a public field needs no help:
            Validate.isTrue(fieldAccess != null && Modifier.isPublic(fieldAccess.access), "Cannot find field %s",
                key);
            super.visitFieldInsn(opcode, owner, name, desc);
            return;
        }
        final boolean staticField = opcode == GETSTATIC || opcode == PUTSTATIC;
        final boolean write = opcode == PUTSTATIC || opcode == PUTFIELD;
        final boolean primitive = fieldAccess.type.getSort() < Type.ARRAY;

        if (write && primitive) {
            // box the value before any stack shuffling: a long/double occupies two slots and cannot be
            // moved with SWAP/DUP_X2, whereas its boxed form is an ordinary single-slot reference
            valueOf(fieldAccess.type);
        }
        visitInsn(NOP);
        loadLocal(localFieldArray);
        push(fieldIndex);
        arrayLoad(fieldArrayType.getElementType());
        checkCast(fieldType);

        if (write) {
            if (staticField) {
                // stack: value, field -> field, value
                swap();
                // add null object for static field deref and swap under value:
                push((String) null);
                swap();
            } else {
                // stack: target, value, field -> field, target, value
                dupX2();
                pop();
            }
            invokeVirtual(fieldType, Method.getMethod("void set(Object, Object)"));
        } else {
            if (staticField) {
                // add null object for static field deref:
                push((String) null);
            } else {
                // stack: target, field -> field, target
                swap();
            }
            invokeVirtual(fieldType, Method.getMethod("Object get(Object)"));
            checkCast(Privilizer.wrap(fieldAccess.type));
            if (primitive) {
                unbox(fieldAccess.type);
            }
        }
    }

    @Override
    public void visitMaxs(final int maxStack, final int maxLocals) {
        // put try-finally around the whole method
        final Label fny = mark();
        // null exception type signifies finally block:
        final Type exceptionType = null;
        catchException(begin, fny, exceptionType);
        onFinally();
        throwException();
        super.visitMaxs(maxStack, maxLocals);
    }

    @Override
    protected void onMethodExit(final int opcode) {
        // thrown exceptions are covered by the catch-all handler emitted in visitMaxs
        if (opcode != ATHROW) {
            onFinally();
        }
    }

    /**
     * Generate the cleanup loop that resets to inaccessible every field that was made accessible on entry.
     */
    private void onFinally() {
        // loop over fields and return any whose bit is clear (i.e. made accessible here) to being inaccessible:
        push(0);
        storeLocal(fieldCounter);
        final Label test = mark();
        final Label increment = new Label();
        final Label endFinally = new Label();

        loadLocal(fieldCounter);
        push(fieldAccesses.size());
        ifCmp(Type.INT_TYPE, GeneratorAdapter.GE, endFinally);

        loadLocal(bitSet);
        loadLocal(fieldCounter);
        invokeVirtual(bitSetType, Method.getMethod("boolean get(int)"));

        // if true, increment:
        ifZCmp(NE, increment);

        loadLocal(localFieldArray);
        loadLocal(fieldCounter);
        arrayLoad(fieldArrayType.getElementType());
        push(false);
        invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("void setAccessible(boolean)"));

        mark(increment);
        iinc(fieldCounter, 1);
        goTo(test);
        mark(endFinally);
    }
}
}