blob: 0fe24795a3bfca9e3bcb0726a6b9ecf5c08f92af [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.karaf.service.interceptor.impl.runtime.proxy;
import static java.util.Collections.emptyEnumeration;
import static java.util.Collections.emptyList;
import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toMap;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Proxy;
import java.net.URL;
import java.security.ProtectionDomain;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
import org.apache.karaf.service.interceptor.impl.runtime.hook.InterceptorInstance;
import org.apache.karaf.service.interceptor.impl.runtime.invoker.InterceptorInvocationContext;
import org.osgi.framework.Bundle;
import org.osgi.framework.ServiceReference;
/**
 * Builds proxies whose method calls are routed through an
 * {@link InterceptorInvocationContext}.
 *
 * <p>When every requested type is an interface a {@link Proxy JDK dynamic proxy} is used;
 * otherwise a proxy class is generated through {@code AsmProxyFactory}. In both cases the
 * proxy is defined in a dedicated {@link ProxyClassLoader}.</p>
 */
public class ProxyFactory {
    private static final Class<?>[] EMPTY_CLASSES = new Class<?>[0];

    /**
     * Creates a proxy exposing {@code classes} and dispatching each call to the interceptors
     * registered for the invoked method.
     *
     * @param ref                   the service reference handed to the invocation context.
     * @param classes               the types the proxy must expose, must not be empty.
     * @param interceptors          the available interceptor instances.
     * @param interceptorsPerMethod for each intercepted method, the ordered list of bindings to apply.
     * @param <T>                   the type the caller expects the proxy to be.
     * @return the proxy instance.
     * @throws IllegalArgumentException if {@code classes} is empty, or if it contains two
     *                                  unrelated non-interface classes.
     */
    public <T> T create(final ServiceReference<T> ref, final List<Class<?>> classes,
                        final List<InterceptorInstance<?>> interceptors,
                        final Map<Method, List<Class<?>>> interceptorsPerMethod) {
        if (classes.isEmpty()) {
            throw new IllegalArgumentException("Can't proxy an empty list of type: " + ref);
        }

        // Resolve each binding to the first interceptor instance declaring it.
        // A binding without any matching instance is kept as a null entry in the list.
        final Map<Method, List<InterceptorInstance<?>>> interceptorInstancePerMethod = interceptorsPerMethod.entrySet().stream()
                .collect(toMap(Map.Entry::getKey, m -> m.getValue().stream()
                        .map(binding -> interceptors.stream().filter(i -> i.getBinding() == binding).findFirst().orElse(null))
                        .collect(toList())));

        final ProxyClassLoader loader = new ProxyClassLoader(Thread.currentThread().getContextClassLoader(), ref.getBundle());

        if (classes.stream().allMatch(Class::isInterface)) {
            // T is erased at runtime; the caller is responsible for requesting types compatible with T.
            @SuppressWarnings("unchecked")
            final T proxyInstance = (T) Proxy.newProxyInstance(
                    loader,
                    classes.toArray(EMPTY_CLASSES),
                    (proxy, method, args) -> doInvoke(ref, method, args, interceptorInstancePerMethod));
            return proxyInstance;
        }

        // At least one concrete class is involved: a JDK proxy cannot subclass it, generate one instead.
        final AsmProxyFactory asm = new AsmProxyFactory();
        final Class<?> proxyClass = asm.createProxyClass(
                loader,
                getProxyClassName(classes),
                classes.stream().sorted(this::compareClasses).toArray(Class<?>[]::new),
                findInterceptedMethods(classes));
        return asm.create(proxyClass, (method, args) -> doInvoke(ref, method, args, interceptorInstancePerMethod));
    }

    /**
     * Runs the interceptor chain registered for {@code method}; methods without any
     * registered interceptor go through an empty chain.
     */
    private <T> Object doInvoke(final ServiceReference<T> ref,
                                final Method method, final Object[] args,
                                final Map<Method, List<InterceptorInstance<?>>> interceptorsPerMethod) throws Exception {
        final List<InterceptorInstance<?>> methodInterceptors = interceptorsPerMethod.getOrDefault(method, emptyList());
        return new InterceptorInvocationContext<>(ref, methodInterceptors, method, args).proceed();
    }

    /**
     * Orders types so that subtypes come before their supertypes and classes come before
     * interfaces; unrelated interfaces are ordered by name.
     *
     * @throws IllegalArgumentException if both types are non-interface classes and neither
     *                                  is assignable to the other.
     */
    private int compareClasses(final Class<?> c1, final Class<?> c2) {
        if (c1 == c2) {
            return 0;
        }
        if (c1.isAssignableFrom(c2)) { // c1 is a supertype of c2: sort it after
            return 1;
        }
        if (c2.isAssignableFrom(c1)) {
            return -1;
        }
        if (c1.isInterface() && !c2.isInterface()) {
            return 1;
        }
        if (c2.isInterface() && !c1.isInterface()) {
            return -1;
        }
        if (!c1.isInterface() && !c2.isInterface()) {
            throw new IllegalArgumentException("No common class between " + c1 + " and " + c2);
        }
        return c1.getName().compareTo(c2.getName()); // just to be deterministic
    }

    /**
     * Collects the distinct public methods of the given types: {@link Class#getMethods()} for
     * interfaces, the declared methods of the whole hierarchy (excluding {@link Object}) for classes.
     */
    private Method[] findInterceptedMethods(final List<Class<?>> classes) {
        return classes.stream()
                .flatMap(c -> c.isInterface() ? Stream.of(c.getMethods()) : findMethods(c))
                .distinct()
                .filter(method -> Modifier.isPublic(method.getModifiers())) // todo: enable protected? not that scr friendly but doable
                .toArray(Method[]::new);
    }

    /**
     * Streams the methods declared by {@code clazz} and, recursively, by its superclasses,
     * stopping before {@link Object}.
     */
    private Stream<Method> findMethods(final Class<?> clazz) {
        return clazz == null || Object.class == clazz ?
                Stream.empty() :
                Stream.concat(Stream.of(clazz.getDeclaredMethods()), findMethods(clazz.getSuperclass()));
    }

    /**
     * Derives the proxy class name from the first type's name, a fixed marker, and the
     * sanitized names of the remaining types joined by {@code "__"}.
     */
    private String getProxyClassName(final List<Class<?>> classes) {
        return classes.iterator().next().getName() + "$$KarafInterceptorProxy" +
                classes.stream().skip(1).map(c -> c.getName().replace(".", "_").replace("$", "")).collect(joining("__"));
    }

    /**
     * Class loader in which proxy classes are defined. Classes and resources are looked up in
     * the given bundle when there is one, otherwise through regular parent delegation.
     */
    static class ProxyClassLoader extends ClassLoader {
        private final Bundle bundle;
        private final Map<String, Class<?>> classes = new ConcurrentHashMap<>();

        /**
         * @param parent the parent loader used when no bundle is available.
         * @param bundle the bundle to load from, may be {@code null}.
         */
        ProxyClassLoader(final ClassLoader parent, final Bundle bundle) {
            super(parent);
            this.bundle = bundle;
        }

        /**
         * Looks up, in order: the proxy classes registered in this loader, the bundle, and, only
         * for {@code org.apache.karaf.service.interceptor.*} names the bundle cannot see, the
         * loader of this class. Without a bundle the standard delegation applies.
         */
        @Override
        protected Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException {
            final Class<?> clazz = classes.get(name);
            if (clazz != null) {
                return clazz;
            }
            if (bundle != null) {
                try {
                    return bundle.loadClass(name);
                } catch (final ClassNotFoundException cnfe) {
                    // the proxied bundle does not necessarily import the interceptor packages
                    if (name != null && name.startsWith("org.apache.karaf.service.interceptor.")) {
                        return getClass().getClassLoader().loadClass(name);
                    }
                    throw cnfe;
                }
            }
            return super.loadClass(name, resolve);
        }

        /**
         * Resolves the resource from the bundle, or through the default lookup when this
         * loader has no bundle (consistent with {@link #loadClass(String, boolean)}).
         */
        @Override
        public URL getResource(final String name) {
            if (bundle == null) {
                return super.getResource(name);
            }
            return bundle.getResource(name);
        }

        /**
         * Resolves the resources from the bundle, or through the default lookup when this
         * loader has no bundle. Never returns {@code null}.
         */
        @Override
        public Enumeration<URL> getResources(final String name) throws IOException {
            if (bundle == null) {
                return super.getResources(name);
            }
            final Enumeration<URL> resources = bundle.getResources(name);
            if (resources == null) {
                // Bundle#getResources may answer null; ClassLoader callers expect an enumeration
                return emptyEnumeration();
            }
            return resources;
        }

        /**
         * Opens the resource located by {@link #getResource(String)}.
         *
         * @return the stream, or {@code null} if the resource is not found.
         * @throws IllegalStateException if the resource exists but cannot be opened.
         */
        @Override
        public InputStream getResourceAsStream(final String name) {
            return ofNullable(getResource(name)).map(u -> {
                try {
                    return u.openStream();
                } catch (final IOException e) {
                    throw new IllegalStateException(e);
                }
            }).orElse(null);
        }

        /**
         * Returns the proxy class registered under {@code proxyClassName}, defining it from
         * {@code proxyBytes} on first use.
         *
         * @param proxyClassName   the proxy class name, {@code '/'} or {@code '.'} separated.
         * @param proxyBytes       the class file bytes.
         * @param pck              the package used as a model for the proxy package, may be {@code null}.
         * @param protectionDomain the domain providing the seal location for sealed packages, may be {@code null}.
         * @param <T>              the type the caller expects the class to represent.
         * @return the registered class.
         */
        public <T> Class<T> getOrRegister(final String proxyClassName, final byte[] proxyBytes,
                                          final Package pck, final ProtectionDomain protectionDomain) {
            final String key = proxyClassName.replace('/', '.');
            Class<?> existing = classes.get(key);
            if (existing == null) {
                synchronized (this) {
                    existing = classes.get(key); // re-check now that the lock is held
                    if (existing == null) {
                        definePackageFor(pck, protectionDomain);
                        // defineClass expects the dotted binary name, so use the normalized key
                        existing = super.defineClass(key, proxyBytes, 0, proxyBytes.length);
                        resolveClass(existing);
                        classes.put(key, existing);
                    }
                }
            }
            // T is chosen by the caller and erased at runtime, the cast cannot be checked here.
            @SuppressWarnings("unchecked")
            final Class<T> registered = (Class<T>) existing;
            return registered;
        }

        /**
         * Defines in this loader a package copying the name and version metadata of
         * {@code model}, unless {@code model} is {@code null} or the package is already known.
         * The package is sealed only if the model is sealed and the protection domain exposes
         * a code source location.
         */
        private void definePackageFor(final Package model, final ProtectionDomain protectionDomain) {
            if (model == null || getPackage(model.getName()) != null) {
                return;
            }
            URL sealBase = null;
            if (model.isSealed() && protectionDomain != null && protectionDomain.getCodeSource() != null) {
                sealBase = protectionDomain.getCodeSource().getLocation(); // may still be null: not sealed
            }
            definePackage(
                    model.getName(),
                    model.getSpecificationTitle(),
                    model.getSpecificationVersion(),
                    model.getSpecificationVendor(),
                    model.getImplementationTitle(),
                    model.getImplementationVersion(),
                    model.getImplementationVendor(),
                    sealBase);
        }
    }
}