| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.commons.weaver.privilizer; |
| |
| import java.io.InputStream; |
| import java.lang.invoke.LambdaMetafactory; |
| import java.lang.reflect.Modifier; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.BitSet; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.LinkedHashSet; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Optional; |
| import java.util.Set; |
| import java.util.function.Function; |
| import java.util.stream.Collectors; |
| import java.util.stream.Stream; |
| |
| import org.apache.commons.lang3.ArrayUtils; |
| import org.apache.commons.lang3.StringUtils; |
| import org.apache.commons.lang3.Validate; |
| import org.apache.commons.lang3.tuple.Pair; |
| import org.objectweb.asm.ClassReader; |
| import org.objectweb.asm.ClassVisitor; |
| import org.objectweb.asm.Handle; |
| import org.objectweb.asm.Label; |
| import org.objectweb.asm.MethodVisitor; |
| import org.objectweb.asm.Opcodes; |
| import org.objectweb.asm.Type; |
| import org.objectweb.asm.commons.AdviceAdapter; |
| import org.objectweb.asm.commons.GeneratorAdapter; |
| import org.objectweb.asm.commons.Method; |
| import org.objectweb.asm.tree.ClassNode; |
| import org.objectweb.asm.tree.FieldNode; |
| import org.objectweb.asm.tree.MethodNode; |
| |
| /** |
| * {@link ClassVisitor} to import so-called "blueprint methods". |
| */ |
| class BlueprintingVisitor extends Privilizer.PrivilizerClassVisitor { |
| static class TypeInfo { |
| final int access; |
| final String superName; |
| final Map<String, FieldNode> fields; |
| final Map<Method, MethodNode> methods; |
| |
| TypeInfo(int access, String superName, Map<String, FieldNode> fields, Map<Method, MethodNode> methods) { |
| super(); |
| this.access = access; |
| this.superName = superName; |
| this.fields = fields; |
| this.methods = methods; |
| } |
| } |
| |
/** ASM type of {@link LambdaMetafactory}; compared against bootstrap method owners to recognize lambdas. */
private static final Type LAMBDA_METAFACTORY = Type.getType(LambdaMetafactory.class);
| |
| private static Pair<Type, Method> methodKey(String owner, String name, String desc) { |
| return Pair.of(Type.getObjectType(owner), new Method(name, desc)); |
| } |
| |
/** Types named as blueprints by the configuration; a class may not name itself (see {@code visit}). */
private final Set<Type> blueprintTypes = new HashSet<>();
/** Blueprint methods eligible for import, keyed by declaring type and method signature. */
private final Map<Pair<Type, Method>, MethodNode> blueprintRegistry = new HashMap<>();

/** Methods already imported (or being imported), mapped to the name of the generated copy. */
private final Map<Pair<Type, Method>, String> importedMethods = new HashMap<>();

/** Cache of class-file details by type, populated lazily by {@code typeInfo}. */
private final Map<Type, TypeInfo> typeInfoCache = new HashMap<>();
/** Cache of field descriptions by owner type and field name, populated lazily by {@code fieldAccess}. */
private final Map<Pair<Type, String>, FieldAccess> fieldAccessMap = new HashMap<>();

/** Visitor that receives the buffered class once this visitor reaches {@code visitEnd}. */
private final ClassVisitor nextVisitor;
| |
/**
 * Create a new {@link BlueprintingVisitor}. Output is buffered in a {@link ClassNode}; every blueprint type
 * named by {@code config} is recorded, and its methods (all of them when no method names are given,
 * otherwise only those whose name is listed) are registered as importable.
 * @param privilizer owner
 * @param nextVisitor wrapped
 * @param config annotation
 */
BlueprintingVisitor(@SuppressWarnings("PMD.UnusedFormalParameter") final Privilizer privilizer, //false positive
    final ClassVisitor nextVisitor,
    final Privilizing config) {
    privilizer.super(new ClassNode(Privilizer.ASM_VERSION));
    this.nextVisitor = nextVisitor;

    // load up blueprint methods:
    for (final Privilizing.CallTo callTo : config.value()) {
        final Type blueprintType = Type.getType(callTo.value());
        blueprintTypes.add(blueprintType);

        final Set<String> requestedNames = new HashSet<>(Arrays.asList(callTo.methods()));
        // an empty list of names means "every method of the blueprint type"
        final boolean importAll = requestedNames.isEmpty();

        for (final Map.Entry<Method, MethodNode> declared : typeInfo(blueprintType).methods.entrySet()) {
            final Method signature = declared.getKey();
            if (importAll || requestedNames.contains(signature.getName())) {
                blueprintRegistry.put(Pair.of(blueprintType, signature), declared.getValue());
            }
        }
    }
}
| |
| private TypeInfo typeInfo(Type type) { |
| return typeInfoCache.computeIfAbsent(type, k -> { |
| final ClassNode cn = read(k.getClassName()); |
| |
| return new TypeInfo(cn.access, cn.superName, |
| cn.fields.stream().collect(Collectors.toMap(f -> f.name, Function.identity())), |
| cn.methods.stream().collect(Collectors.toMap(m -> new Method(m.name, m.desc), Function.identity()))); |
| }); |
| } |
| |
| private ClassNode read(final String className) { |
| final ClassNode result = new ClassNode(Privilizer.ASM_VERSION); |
| try (InputStream bytecode = privilizer().env.getClassfile(className).getInputStream()) { |
| new ClassReader(bytecode).accept(result, ClassReader.SKIP_DEBUG | ClassReader.EXPAND_FRAMES); |
| } catch (final Exception e) { |
| throw new RuntimeException(e); |
| } |
| return result; |
| } |
| |
| @Override |
| @SuppressWarnings("PMD.UseVarargs") //overridden method |
| public void visit(final int version, final int access, final String name, final String signature, |
| final String superName, final String[] interfaces) { |
| Validate.isTrue(!blueprintTypes.contains(Type.getObjectType(name)), |
| "Class %s cannot declare itself as a blueprint!", name); |
| super.visit(version, access, name, signature, superName, interfaces); |
| } |
| |
| @Override |
| @SuppressWarnings("PMD.UseVarargs") //overridden method |
| public MethodVisitor visitMethod(final int access, final String name, final String desc, final String signature, |
| final String[] exceptions) { |
| final MethodVisitor toWrap = super.visitMethod(access, name, desc, signature, exceptions); |
| return new MethodInvocationHandler(toWrap) { |
| @Override |
| boolean shouldImport(final Pair<Type, Method> methodKey) { |
| return blueprintRegistry.containsKey(methodKey); |
| } |
| }; |
| } |
| |
| private String importMethod(final Pair<Type, Method> key) { |
| if (importedMethods.containsKey(key)) { |
| return importedMethods.get(key); |
| } |
| final String result = |
| new StringBuilder(key.getLeft().getInternalName().replace('/', '_')).append("$$") |
| .append(key.getRight().getName()).toString(); |
| importedMethods.put(key, result); |
| privilizer().env.debug("importing %s#%s as %s", key.getLeft().getClassName(), key.getRight(), result); |
| final int access = Opcodes.ACC_PRIVATE + Opcodes.ACC_STATIC + Opcodes.ACC_SYNTHETIC; |
| |
| final MethodNode source = typeInfo(key.getLeft()).methods.get(key.getRight()); |
| |
| final String[] exceptions = source.exceptions.toArray(ArrayUtils.EMPTY_STRING_ARRAY); |
| |
| // non-public fields accessed |
| final Set<FieldAccess> fieldAccesses = new LinkedHashSet<>(); |
| |
| source.accept(new MethodVisitor(Privilizer.ASM_VERSION) { |
| @Override |
| public void visitFieldInsn(final int opcode, final String owner, final String name, final String desc) { |
| final FieldAccess fieldAccess = fieldAccess(Type.getObjectType(owner), name); |
| |
| super.visitFieldInsn(opcode, owner, name, desc); |
| if (!Modifier.isPublic(fieldAccess.access)) { |
| fieldAccesses.add(fieldAccess); |
| } |
| } |
| }); |
| |
| final MethodNode withAccessibleAdvice = |
| new MethodNode(access, result, source.desc, source.signature, exceptions); |
| |
| // spider own methods: |
| MethodVisitor mv = new NestedMethodInvocationHandler(withAccessibleAdvice, key); //NOPMD |
| |
| if (!fieldAccesses.isEmpty()) { |
| mv = new AccessibleAdvisor(mv, access, result, source.desc, new ArrayList<>(fieldAccesses)); |
| } |
| source.accept(mv); |
| |
| // private can only be called by other privileged methods, so no need to mark as privileged |
| if (!Modifier.isPrivate(source.access)) { |
| withAccessibleAdvice.visitAnnotation(Type.getType(Privileged.class).getDescriptor(), false).visitEnd(); |
| } |
| withAccessibleAdvice.accept(this.cv); |
| |
| return result; |
| } |
| |
| private FieldAccess fieldAccess(final Type owner, final String name) { |
| return fieldAccessMap.computeIfAbsent(Pair.of(owner, name), k -> { |
| final FieldNode fieldNode = typeInfo(k.getLeft()).fields.get(k.getRight()); |
| Validate.validState(fieldNode != null, "Could not locate %s.%s", k.getLeft().getClassName(), k.getRight()); |
| return new FieldAccess(fieldNode.access, k.getLeft(), fieldNode.name, Type.getType(fieldNode.desc)); |
| }); |
| } |
| |
| @Override |
| public void visitEnd() { |
| super.visitEnd(); |
| ((ClassNode) cv).accept(nextVisitor); |
| } |
| |
| private abstract class MethodInvocationHandler extends MethodVisitor { |
| MethodInvocationHandler(final MethodVisitor mvr) { |
| super(Privilizer.ASM_VERSION, mvr); |
| } |
| |
| @Override |
| public void visitMethodInsn(final int opcode, final String owner, final String name, final String desc, |
| final boolean itf) { |
| if (opcode == Opcodes.INVOKESTATIC) { |
| final Pair<Type, Method> methodKey = methodKey(owner, name, desc); |
| if (shouldImport(methodKey)) { |
| final String importedName = importMethod(methodKey); |
| super.visitMethodInsn(opcode, className, importedName, desc, itf); |
| return; |
| } |
| } |
| visitNonImportedMethodInsn(opcode, owner, name, desc, itf); |
| } |
| |
| protected void visitNonImportedMethodInsn(final int opcode, final String owner, final String name, |
| final String desc, final boolean itf) { |
| super.visitMethodInsn(opcode, owner, name, desc, itf); |
| } |
| |
| @Override |
| public void visitInvokeDynamicInsn(String name, String descriptor, Handle bootstrapMethodHandle, |
| Object... bootstrapMethodArguments) { |
| |
| if (isLambda(bootstrapMethodHandle)) { |
| Object[] args = bootstrapMethodArguments; |
| |
| Handle handle = null; |
| |
| for (int i = 0; i < args.length; i++) { |
| if (bootstrapMethodArguments[i] instanceof Handle) { |
| if (handle != null) { |
| // we don't know what to do with multiple handles; skip the whole thing: |
| args = bootstrapMethodArguments; |
| break; |
| } |
| handle = (Handle) args[i]; |
| |
| if (handle.getTag() == Opcodes.H_INVOKESTATIC) { |
| final Pair<Type, Method> methodKey = |
| methodKey(handle.getOwner(), handle.getName(), handle.getDesc()); |
| |
| if (shouldImport(methodKey)) { |
| final String importedName = importMethod(methodKey); |
| args = bootstrapMethodArguments.clone(); |
| args[i] = new Handle(handle.getTag(), className, importedName, handle.getDesc(), false); |
| } |
| } |
| } |
| } |
| if (handle != null) { |
| if (args == bootstrapMethodArguments) { |
| validateLambda(handle); |
| } else { |
| super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, args); |
| return; |
| } |
| } |
| } |
| super.visitInvokeDynamicInsn(name, descriptor, bootstrapMethodHandle, bootstrapMethodArguments); |
| } |
| |
| protected void validateLambda(Handle handle) { |
| } |
| |
| abstract boolean shouldImport(Pair<Type, Method> methodKey); |
| |
| private boolean isLambda(Handle handle) { |
| return handle.getTag() == Opcodes.H_INVOKESTATIC |
| && LAMBDA_METAFACTORY.getInternalName().equals(handle.getOwner()) |
| && "metafactory".equals(handle.getName()); |
| } |
| } |
| |
| class NestedMethodInvocationHandler extends MethodInvocationHandler { |
| final Pair<Type, Method> methodKey; |
| final Type owner; |
| |
| NestedMethodInvocationHandler(final MethodVisitor mvr, final Pair<Type,Method> methodKey) { |
| super(mvr); |
| this.methodKey = methodKey; |
| this.owner = methodKey.getLeft(); |
| } |
| |
| @Override |
| protected void visitNonImportedMethodInsn(int opcode, String owner, String name, String desc, boolean itf) { |
| final Type ownerType = Type.getObjectType(owner); |
| final Method m = new Method(name, desc); |
| |
| if (isAccessible(ownerType) && isAccessible(ownerType, m)) { |
| super.visitNonImportedMethodInsn(opcode, owner, name, desc, itf); |
| } else { |
| throw new IllegalStateException(String.format("Blueprint method %s.%s calls inaccessible method %s.%s", |
| this.owner, methodKey.getRight(), owner, m)); |
| } |
| } |
| |
| @Override |
| protected void validateLambda(Handle handle) { |
| super.validateLambda(handle); |
| final Type ownerType = Type.getObjectType(handle.getOwner()); |
| final Method m = new Method(handle.getName(), handle.getDesc()); |
| |
| if (!(isAccessible(ownerType) && isAccessible(ownerType, m))) { |
| throw new IllegalStateException( |
| String.format("Blueprint method %s.%s utilizes inaccessible method reference %s::%s", owner, |
| methodKey.getRight(), handle.getOwner(), m)); |
| } |
| } |
| |
| @Override |
| boolean shouldImport(final Pair<Type, Method> methodKey) { |
| // call anything called within a class hierarchy: |
| final Type called = methodKey.getLeft(); |
| // "I prefer the short cut": |
| if (called.equals(owner)) { |
| return true; |
| } |
| try { |
| final Class<?> inner = load(called); |
| final Class<?> outer = load(owner); |
| return inner.isAssignableFrom(outer); |
| } catch (final ClassNotFoundException e) { |
| return false; |
| } |
| } |
| |
| private Class<?> load(final Type type) throws ClassNotFoundException { |
| return privilizer().env.classLoader.loadClass(type.getClassName()); |
| } |
| |
| private boolean isAccessible(Type type) { |
| final TypeInfo typeInfo = typeInfo(type); |
| return isAccessible(type, typeInfo.access); |
| } |
| |
| private boolean isAccessible(Type type, Method m) { |
| Type t = type; |
| while (t != null) { |
| final TypeInfo typeInfo = typeInfo(t); |
| final MethodNode methodNode = typeInfo.methods.get(m); |
| if (methodNode == null) { |
| t = Optional.ofNullable(typeInfo.superName).map(Type::getObjectType).orElse(null); |
| continue; |
| } |
| return isAccessible(type, methodNode.access); |
| } |
| throw new IllegalStateException(String.format("Cannot find method %s.%s", type, m)); |
| } |
| |
| private boolean isAccessible(Type type, int access) { |
| if (Modifier.isPublic(access)) { |
| return true; |
| } |
| if (Modifier.isProtected(access) || Modifier.isPrivate(access)) { |
| return false; |
| } |
| return Stream.of(target, type).map(Type::getInternalName).map(n -> StringUtils.substringBeforeLast(n, "/")) |
| .distinct().count() == 1; |
| } |
| } |
| |
| /** |
| * For every non-public referenced field of an imported method, replaces with reflective calls. Additionally, for |
| * every such field that is not accessible, sets the field's accessibility and clears it as the method exits. |
| */ |
| private class AccessibleAdvisor extends AdviceAdapter { |
| final Type bitSetType = Type.getType(BitSet.class); |
| final Type classType = Type.getType(Class.class); |
| final Type fieldType = Type.getType(java.lang.reflect.Field.class); |
| final Type fieldArrayType = Type.getType(java.lang.reflect.Field[].class); |
| final Type stringType = Type.getType(String.class); |
| |
| final List<FieldAccess> fieldAccesses; |
| final Label begin = new Label(); |
| int localFieldArray; |
| int bitSet; |
| int fieldCounter; |
| |
| AccessibleAdvisor(final MethodVisitor mvr, final int access, final String name, final String desc, |
| final List<FieldAccess> fieldAccesses) { |
| super(Privilizer.ASM_VERSION, mvr, access, name, desc); |
| this.fieldAccesses = fieldAccesses; |
| } |
| |
| @Override |
| protected void onMethodEnter() { |
| localFieldArray = newLocal(fieldArrayType); |
| bitSet = newLocal(bitSetType); |
| fieldCounter = newLocal(Type.INT_TYPE); |
| |
| // create localFieldArray |
| push(fieldAccesses.size()); |
| newArray(fieldArrayType.getElementType()); |
| storeLocal(localFieldArray); |
| |
| // create bitSet |
| newInstance(bitSetType); |
| dup(); |
| push(fieldAccesses.size()); |
| invokeConstructor(bitSetType, Method.getMethod("void <init>(int)")); |
| storeLocal(bitSet); |
| |
| // populate localFieldArray |
| push(0); |
| storeLocal(fieldCounter); |
| for (final FieldAccess access : fieldAccesses) { |
| prehandle(access); |
| iinc(fieldCounter, 1); |
| } |
| mark(begin); |
| } |
| |
| private void prehandle(final FieldAccess access) { |
| // push owner.class literal |
| visitLdcInsn(access.owner); |
| push(access.name); |
| final Label next = new Label(); |
| invokeVirtual(classType, new Method("getDeclaredField", fieldType, new Type[] { stringType })); |
| |
| dup(); |
| // store the field at localFieldArray[fieldCounter]: |
| loadLocal(localFieldArray); |
| swap(); |
| loadLocal(fieldCounter); |
| swap(); |
| arrayStore(fieldArrayType.getElementType()); |
| |
| dup(); |
| invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("boolean isAccessible()")); |
| |
| final Label setAccessible = new Label(); |
| // if false, setAccessible: |
| ifZCmp(EQ, setAccessible); |
| |
| // else pop field instance |
| pop(); |
| // and record that he was already accessible: |
| loadLocal(bitSet); |
| loadLocal(fieldCounter); |
| invokeVirtual(bitSetType, Method.getMethod("void set(int)")); |
| goTo(next); |
| |
| mark(setAccessible); |
| push(true); |
| invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("void setAccessible(boolean)")); |
| |
| mark(next); |
| } |
| |
| @Override |
| public void visitFieldInsn(final int opcode, final String owner, final String name, final String desc) { |
| final Pair<Type, String> key = Pair.of(Type.getObjectType(owner), name); |
| final FieldAccess fieldAccess = fieldAccessMap.get(key); |
| Validate.isTrue(fieldAccesses.contains(fieldAccess), "Cannot find field %s", key); |
| final int fieldIndex = fieldAccesses.indexOf(fieldAccess); |
| visitInsn(NOP); |
| loadLocal(localFieldArray); |
| push(fieldIndex); |
| arrayLoad(fieldArrayType.getElementType()); |
| checkCast(fieldType); |
| |
| final Method access; |
| if (opcode == PUTSTATIC) { |
| // value should have been at top of stack on entry; position the field under the value: |
| swap(); |
| // add null object for static field deref and swap under value: |
| push((String) null); |
| swap(); |
| if (fieldAccess.type.getSort() < Type.ARRAY) { |
| // box value: |
| valueOf(fieldAccess.type); |
| } |
| access = Method.getMethod("void set(Object, Object)"); |
| } else { |
| access = Method.getMethod("Object get(Object)"); |
| // add null object for static field deref: |
| push((String) null); |
| } |
| |
| invokeVirtual(fieldType, access); |
| |
| if (opcode == GETSTATIC) { |
| checkCast(Privilizer.wrap(fieldAccess.type)); |
| if (fieldAccess.type.getSort() < Type.ARRAY) { |
| unbox(fieldAccess.type); |
| } |
| } |
| } |
| |
| @Override |
| public void visitMaxs(final int maxStack, final int maxLocals) { |
| // put try-finally around the whole method |
| final Label fny = mark(); |
| // null exception type signifies finally block: |
| final Type exceptionType = null; |
| catchException(begin, fny, exceptionType); |
| onFinally(); |
| throwException(); |
| super.visitMaxs(maxStack, maxLocals); |
| } |
| |
| @Override |
| protected void onMethodExit(final int opcode) { |
| if (opcode != ATHROW) { |
| onFinally(); |
| } |
| } |
| |
| private void onFinally() { |
| // loop over fields and return any non-null element to being inaccessible: |
| push(0); |
| storeLocal(fieldCounter); |
| |
| final Label test = mark(); |
| final Label increment = new Label(); |
| final Label endFinally = new Label(); |
| |
| loadLocal(fieldCounter); |
| push(fieldAccesses.size()); |
| ifCmp(Type.INT_TYPE, GeneratorAdapter.GE, endFinally); |
| |
| loadLocal(bitSet); |
| loadLocal(fieldCounter); |
| invokeVirtual(bitSetType, Method.getMethod("boolean get(int)")); |
| |
| // if true, increment: |
| ifZCmp(NE, increment); |
| |
| loadLocal(localFieldArray); |
| loadLocal(fieldCounter); |
| arrayLoad(fieldArrayType.getElementType()); |
| push(false); |
| invokeVirtual(fieldArrayType.getElementType(), Method.getMethod("void setAccessible(boolean)")); |
| |
| mark(increment); |
| iinc(fieldCounter, 1); |
| goTo(test); |
| mark(endFinally); |
| } |
| } |
| } |