extracting directory enhancement logic in a runnable for junit5 module to be able to reuse it easily
diff --git a/openjpa-junit5/src/main/java/org/apache/openjpa/junit5/internal/OpenJPADirectoriesEnhancer.java b/openjpa-junit5/src/main/java/org/apache/openjpa/junit5/internal/OpenJPADirectoriesEnhancer.java
new file mode 100644
index 0000000..7c62942
--- /dev/null
+++ b/openjpa-junit5/src/main/java/org/apache/openjpa/junit5/internal/OpenJPADirectoriesEnhancer.java
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.openjpa.junit5.internal;
+
+import org.apache.openjpa.conf.OpenJPAConfigurationImpl;
+import org.apache.openjpa.enhance.AsmAdaptor;
+import org.apache.openjpa.enhance.PCEnhancer;
+import org.apache.openjpa.enhance.PersistenceCapable;
+import org.apache.openjpa.lib.log.LogFactory;
+import org.apache.openjpa.lib.log.LogFactoryImpl;
+import org.apache.openjpa.lib.log.SLF4JLogFactory;
+import org.apache.openjpa.meta.MetaDataRepository;
+import org.apache.openjpa.persistence.PersistenceMetaDataFactory;
+import org.apache.xbean.asm7.AnnotationVisitor;
+import org.apache.xbean.asm7.ClassReader;
+import org.apache.xbean.asm7.Type;
+import org.apache.xbean.asm7.shade.commons.EmptyVisitor;
+import org.apache.xbean.finder.ClassLoaders;
+import serp.bytecode.BCClass;
+import serp.bytecode.Project;
+
+import javax.persistence.Embeddable;
+import javax.persistence.Entity;
+import javax.persistence.MappedSuperclass;
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URL;
+import java.nio.file.FileVisitResult;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.SimpleFileVisitor;
+import java.nio.file.StandardOpenOption;
+import java.nio.file.attribute.BasicFileAttributes;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.logging.Logger;
+import java.util.stream.Stream;
+
+import static java.util.Arrays.asList;
+import static org.apache.xbean.asm7.ClassReader.SKIP_CODE;
+import static org.apache.xbean.asm7.ClassReader.SKIP_DEBUG;
+import static org.apache.xbean.asm7.ClassReader.SKIP_FRAMES;
+
+public class OpenJPADirectoriesEnhancer implements Runnable {
+    private static final Logger LOGGER = Logger.getLogger(OpenJPADirectoriesEnhancer.class.getName());
+    public static final StackTraceElement[] NO_STACK_TRACE = new StackTraceElement[0];
+
+    private final boolean auto;
+    private final String[] entities;
+    private final Class<?> logFactory;
+
+    public OpenJPADirectoriesEnhancer(final boolean auto, final String[] entities, final Class<?> logFactory) {
+        this.auto = auto;
+        this.entities = entities;
+        this.logFactory = logFactory;
+    }
+
+    @Override
+    public void run() {
+        final ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
+        final OpenJpaClassLoader enhancementClassLoader = new OpenJpaClassLoader(
+                classLoader, createLogFactory(classLoader));
+        final Thread thread = Thread.currentThread();
+        thread.setContextClassLoader(enhancementClassLoader);
+        try {
+            if (auto) {
+                try {
+                    ClassLoaders.findUrls(enhancementClassLoader.getParent()).stream()
+                            .map(org.apache.xbean.finder.util.Files::toFile)
+                            .filter(File::isDirectory)
+                            .map(File::toPath)
+                            .forEach(dir -> {
+                                LOGGER.fine(() -> "Enhancing folder '" + dir + "'");
+                                try {
+                                    enhanceDirectory(enhancementClassLoader, dir);
+                                } catch (final IOException e) {
+                                    throw new IllegalStateException(e);
+                                }
+                            });
+                } catch (final IOException e) {
+                    throw new IllegalStateException(e);
+                }
+            } else {
+                Stream.of(entities).forEach(e -> {
+                    try {
+                        enhancementClassLoader.loadClass(e);
+                    } catch (final ClassNotFoundException e1) {
+                        throw new IllegalArgumentException(e1);
+                    }
+                });
+            }
+        } finally {
+            thread.setContextClassLoader(enhancementClassLoader.getParent());
+        }
+    }
+
+    private LogFactory createLogFactory(final ClassLoader classLoader) {
+        try {
+            if (logFactory == null || logFactory == LogFactory.class) {
+                try {
+                    return new SLF4JLogFactory();
+                } catch (final Error | Exception e) {
+                    return new LogFactoryImpl();
+                }
+            }
+            return logFactory.asSubclass(LogFactory.class).getConstructor().newInstance();
+        } catch (final RuntimeException e) {
+            throw e;
+        } catch (final Exception e) {
+            throw new IllegalStateException(e);
+        }
+    }
+
+    private void enhanceDirectory(final OpenJpaClassLoader enhancementClassLoader, final Path dir) throws IOException {
+        Files.walkFileTree(dir, new SimpleFileVisitor<Path>() {
+            @Override
+            public FileVisitResult visitFile(final Path file, final BasicFileAttributes attrs) throws IOException {
+                if (file.getFileName().toString().endsWith(".class")) {
+                    final String relativeName = dir.relativize(file).toString();
+                    try {
+                        enhancementClassLoader.handleEnhancement(
+                                relativeName.substring(0, relativeName.length() - ".class".length()));
+                    } catch (final ClassNotFoundException e) {
+                        throw new IllegalStateException(e);
+                    }
+                }
+                return super.visitFile(file, attrs);
+            }
+        });
+    }
+
+    private static abstract class BaseClassLoader extends ClassLoader {
+        private BaseClassLoader(final ClassLoader parent) {
+            super(parent);
+        }
+
+        protected abstract Class<?> doLoadClass(String name, boolean resolve) throws ClassNotFoundException;
+
+        @Override
+        protected Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException {
+            if (name != null && !name.startsWith("java") && !name.startsWith("sun") && !name.startsWith("jdk")) {
+                return doLoadClass(name, resolve);
+            }
+            return defaultLoadClass(name, resolve);
+        }
+
+        protected Class<?> defaultLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
+            return super.loadClass(name, resolve);
+        }
+
+        protected byte[] loadBytes(final String name) {
+            final URL url = findUrl(name);
+            if (url == null || "jar".equals(url.getProtocol()) /*assume done in build*/) {
+                return null;
+            }
+            byte[] buffer = new byte[4096];
+            final ByteArrayOutputStream inMem = new ByteArrayOutputStream(buffer.length);
+            try (final InputStream is = url.openStream()) {
+                int read;
+                while ((read = is.read(buffer)) >= 0) {
+                    if (read > 0) {
+                        inMem.write(buffer, 0, read);
+                    }
+                }
+            } catch (final IOException e) {
+                throw new IllegalStateException(e);
+            }
+            return inMem.toByteArray();
+        }
+
+        protected URL findUrl(final String name) {
+            return getResource(name.replace('.', '/') + ".class");
+        }
+    }
+
+    private static class OpenJpaClassLoader extends BaseClassLoader {
+        private static final String PERSITENCE_CAPABLE = Type.getDescriptor(PersistenceCapable.class);
+        private static final String ENTITY = Type.getDescriptor(Entity.class);
+        private static final String EMBEDDABLE = Type.getDescriptor(Embeddable.class);
+        private static final String MAPPED_SUPERCLASS = Type.getDescriptor(MappedSuperclass.class);
+
+        private final MetaDataRepository repos;
+        private final ClassLoader tmpLoader;
+        private final Collection<String> alreadyEnhanced = new ArrayList<>();
+
+        private OpenJpaClassLoader(final ClassLoader parent, final LogFactory logFactory) {
+            super(parent);
+
+            final OpenJPAConfigurationImpl conf = new OpenJPAConfigurationImpl();
+            conf.setLogFactory(logFactory);
+
+            tmpLoader = new CompanionLoader(parent);
+            repos = new MetaDataRepository();
+            repos.setConfiguration(conf);
+            repos.setMetaDataFactory(new PersistenceMetaDataFactory());
+        }
+
+        @Override
+        protected synchronized Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
+            final Class<?> clazz = findLoadedClass(name);
+            if (clazz != null) {
+                if (resolve) {
+                    resolveClass(clazz);
+                }
+                return clazz;
+            }
+            handleEnhancement(name);
+            return defaultLoadClass(name, resolve);
+        }
+
+        private void handleEnhancement(final String name) throws ClassNotFoundException {
+            final byte[] enhanced = ensureEnhancedIfNeeded(name);
+            if (enhanced != null && alreadyEnhanced.add(name)) {
+                // we could do that but test classes will be loaded with parent loader
+                // so just rewrite the class on the fly assuming it was not yet read
+                try {
+                    Files.write(findTarget(name), enhanced, StandardOpenOption.TRUNCATE_EXISTING);
+                    LOGGER.info(() -> "Enhanced '" + name + "'");
+                } catch (final IOException e) {
+                    throw new ClassNotFoundException(e.getMessage(), e);
+                }
+            }
+        }
+
+        private Path findTarget(final String name) {
+            final URL url = findUrl(name);
+            if (!"file".equals(url.getProtocol())) {
+                throw new IllegalStateException("Only file urls are supported today: " + url);
+            }
+            return Paths.get(url.getPath());
+        }
+
+        private byte[] enhance(final byte[] classBytes) {
+            final Thread thread = Thread.currentThread();
+            final ClassLoader old = thread.getContextClassLoader();
+            thread.setContextClassLoader(tmpLoader);
+            try (final InputStream stream = new ByteArrayInputStream(classBytes)) {
+                final PCEnhancer enhancer = new PCEnhancer(
+                        repos.getConfiguration(),
+                        new Project().loadClass(stream, tmpLoader),
+                        repos, tmpLoader);
+                if (enhancer.run() == PCEnhancer.ENHANCE_NONE) {
+                    return null;
+                }
+                final BCClass pcb = enhancer.getPCBytecode();
+                return AsmAdaptor.toByteArray(pcb, pcb.toByteArray());
+            } catch (final IOException e) {
+                throw new IllegalStateException(e);
+            } finally {
+                thread.setContextClassLoader(old);
+            }
+        }
+
+        private boolean isJpaButNotEnhanced(final byte[] classBytes) {
+            try (final InputStream stream = new ByteArrayInputStream(classBytes)) {
+                final ClassReader reader = new ClassReader(stream);
+                reader.accept(new EmptyVisitor() {
+                    @Override
+                    public void visit(final int version, final int access, final String name,
+                                      final String signature, final String superName, final String[] interfaces) {
+                        if (interfaces != null && asList(interfaces).contains(PERSITENCE_CAPABLE)) {
+                            throw new AlreadyEnhanced(); // exit
+                        }
+                        super.visit(version, access, name, signature, superName, interfaces);
+                    }
+
+                    @Override
+                    public AnnotationVisitor visitAnnotation(final String descriptor, final boolean visible) {
+                        if (ENTITY.equals(descriptor) ||
+                                EMBEDDABLE.equals(descriptor) ||
+                                MAPPED_SUPERCLASS.equals(descriptor)) {
+                            throw new MissingEnhancement(); // we already went into visit() so we miss the enhancement
+                        }
+                        return new EmptyVisitor().visitAnnotation(descriptor, visible);
+                    }
+                }, SKIP_DEBUG + SKIP_CODE + SKIP_FRAMES);
+                return false;
+            } catch (final IOException e) {
+                throw new IllegalStateException(e);
+            } catch (final AlreadyEnhanced alreadyEnhanced) {
+                return false;
+            } catch (final MissingEnhancement alreadyEnhanced) {
+                return true;
+            }
+        }
+
+        private byte[] ensureEnhancedIfNeeded(final String name) {
+            final byte[] classBytes = loadBytes(name);
+            if (classBytes == null) {
+                return null;
+            }
+            if (isJpaButNotEnhanced(classBytes)) {
+                final byte[] enhanced = enhance(classBytes);
+                if (enhanced != null) {
+                    return enhanced;
+                }
+                LOGGER.info("'" + name + "' already enhanced");
+            }
+            return null;
+        }
+    }
+
+    private static class CompanionLoader extends BaseClassLoader {
+        private CompanionLoader(final ClassLoader parent) {
+            super(parent);
+        }
+
+        @Override
+        protected Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
+            final Class<?> clazz = findLoadedClass(name);
+            if (clazz != null) {
+                if (resolve) {
+                    resolveClass(clazz);
+                }
+                return clazz;
+            }
+            final byte[] content = loadBytes(name);
+            if (content != null) {
+                final Class<?> value = super.defineClass(name, content, 0, content.length);
+                if (resolve) {
+                    resolveClass(value);
+                }
+                return value;
+            }
+            return defaultLoadClass(name, resolve);
+        }
+    }
+
+    private static class MissingEnhancement extends RuntimeException {
+        private MissingEnhancement() {
+            setStackTrace(NO_STACK_TRACE);
+        }
+    }
+
+    private static class AlreadyEnhanced extends RuntimeException {
+        private AlreadyEnhanced() {
+            setStackTrace(NO_STACK_TRACE);
+        }
+    }
+}
diff --git a/openjpa-junit5/src/main/java/org/apache/openjpa/junit5/internal/OpenJPAExtension.java b/openjpa-junit5/src/main/java/org/apache/openjpa/junit5/internal/OpenJPAExtension.java
index 2478cad..2cacd5e 100644
--- a/openjpa-junit5/src/main/java/org/apache/openjpa/junit5/internal/OpenJPAExtension.java
+++ b/openjpa-junit5/src/main/java/org/apache/openjpa/junit5/internal/OpenJPAExtension.java
@@ -18,333 +18,19 @@
  */
 package org.apache.openjpa.junit5.internal;
 
-import org.apache.openjpa.conf.OpenJPAConfigurationImpl;
-import org.apache.openjpa.enhance.AsmAdaptor;
-import org.apache.openjpa.enhance.PCEnhancer;
-import org.apache.openjpa.enhance.PersistenceCapable;
 import org.apache.openjpa.junit5.OpenJPASupport;
-import org.apache.openjpa.lib.log.LogFactory;
-import org.apache.openjpa.lib.log.LogFactoryImpl;
-import org.apache.openjpa.lib.log.SLF4JLogFactory;
-import org.apache.openjpa.meta.MetaDataRepository;
-import org.apache.openjpa.persistence.PersistenceMetaDataFactory;
-import org.apache.xbean.asm7.AnnotationVisitor;
-import org.apache.xbean.asm7.ClassReader;
-import org.apache.xbean.asm7.Type;
-import org.apache.xbean.asm7.shade.commons.EmptyVisitor;
-import org.apache.xbean.finder.ClassLoaders;
 import org.junit.jupiter.api.extension.BeforeAllCallback;
 import org.junit.jupiter.api.extension.ExtensionContext;
 import org.junit.platform.commons.util.AnnotationUtils;
-import serp.bytecode.BCClass;
-import serp.bytecode.Project;
 
-import javax.persistence.Embeddable;
-import javax.persistence.Entity;
-import javax.persistence.MappedSuperclass;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
-import java.io.File;
-import java.io.IOException;
-import java.io.InputStream;
-import java.lang.reflect.InvocationTargetException;
-import java.net.URL;
-import java.nio.file.FileVisitResult;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.Paths;
-import java.nio.file.SimpleFileVisitor;
-import java.nio.file.StandardOpenOption;
-import java.nio.file.attribute.BasicFileAttributes;
-import java.util.ArrayList;
-import java.util.Collection;
 import java.util.logging.Logger;
-import java.util.stream.Stream;
-
-import static java.util.Arrays.asList;
-import static org.apache.xbean.asm7.ClassReader.SKIP_CODE;
-import static org.apache.xbean.asm7.ClassReader.SKIP_DEBUG;
-import static org.apache.xbean.asm7.ClassReader.SKIP_FRAMES;
 
 public class OpenJPAExtension implements BeforeAllCallback {
     private static final Logger LOGGER = Logger.getLogger(OpenJPAExtension.class.getName());
 
     @Override
     public void beforeAll(final ExtensionContext context) {
-        AnnotationUtils.findAnnotation(context.getElement(), OpenJPASupport.class).ifPresent(s -> {
-            final ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
-            final OpenJpaClassLoader enhancementClassLoader = new OpenJpaClassLoader(
-                    classLoader, createLogFactory(classLoader, s.logFactory()));
-            final Thread thread = Thread.currentThread();
-            thread.setContextClassLoader(enhancementClassLoader);
-            try {
-                if (s.auto()) {
-                    try {
-                        ClassLoaders.findUrls(enhancementClassLoader.getParent()).stream()
-                                .map(org.apache.xbean.finder.util.Files::toFile)
-                                .filter(File::isDirectory)
-                                .map(File::toPath)
-                                .forEach(dir -> {
-                                    LOGGER.fine(() -> "Enhancing folder '" + dir + "'");
-                                    try {
-                                        enhanceDirectory(enhancementClassLoader, dir);
-                                    } catch (final IOException e) {
-                                        throw new IllegalStateException(e);
-                                    }
-                                });
-                    } catch (final IOException e) {
-                        throw new IllegalStateException(e);
-                    }
-                } else {
-                    Stream.of(s.entities()).forEach(e -> {
-                        try {
-                            enhancementClassLoader.loadClass(e);
-                        } catch (final ClassNotFoundException e1) {
-                            throw new IllegalArgumentException(e1);
-                        }
-                    });
-                }
-            } finally {
-                thread.setContextClassLoader(enhancementClassLoader.getParent());
-            }
-        });
-    }
-
-    private LogFactory createLogFactory(final ClassLoader classLoader, final Class<?> logFactory) {
-        try {
-            if (logFactory == LogFactory.class) {
-                try {
-                    return new SLF4JLogFactory();
-                } catch (final Error | Exception e) {
-                    return new LogFactoryImpl();
-                }
-            }
-            return logFactory.asSubclass(LogFactory.class).getConstructor().newInstance();
-        } catch (final RuntimeException e) {
-            throw e;
-        } catch (final Exception e) {
-            throw new IllegalStateException(e);
-        }
-    }
-
-    private void enhanceDirectory(final OpenJpaClassLoader enhancementClassLoader, final Path dir) throws IOException {
-        Files.walkFileTree(dir, new SimpleFileVisitor<Path>() {
-            @Override
-            public FileVisitResult visitFile(final Path file, final BasicFileAttributes attrs) throws IOException {
-                if (file.getFileName().toString().endsWith(".class")) {
-                    final String relativeName = dir.relativize(file).toString();
-                    try {
-                        enhancementClassLoader.handleEnhancement(
-                                relativeName.substring(0, relativeName.length() - ".class".length()));
-                    } catch (final ClassNotFoundException e) {
-                        throw new IllegalStateException(e);
-                    }
-                }
-                return super.visitFile(file, attrs);
-            }
-        });
-    }
-
-    private static abstract class BaseClassLoader extends ClassLoader {
-        private BaseClassLoader(final ClassLoader parent) {
-            super(parent);
-        }
-
-        protected abstract Class<?> doLoadClass(String name, boolean resolve) throws ClassNotFoundException;
-
-        @Override
-        protected Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException {
-            if (name != null && !name.startsWith("java") && !name.startsWith("sun") && !name.startsWith("jdk")) {
-                return doLoadClass(name, resolve);
-            }
-            return defaultLoadClass(name, resolve);
-        }
-
-        protected Class<?> defaultLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
-            return super.loadClass(name, resolve);
-        }
-
-        protected byte[] loadBytes(final String name) {
-            final URL url = findUrl(name);
-            if (url == null || "jar".equals(url.getProtocol()) /*assume done in build*/) {
-                return null;
-            }
-            byte[] buffer = new byte[4096];
-            final ByteArrayOutputStream inMem = new ByteArrayOutputStream(buffer.length);
-            try (final InputStream is = url.openStream()) {
-                int read;
-                while ((read = is.read(buffer)) >= 0) {
-                    if (read > 0) {
-                        inMem.write(buffer, 0, read);
-                    }
-                }
-            } catch (final IOException e) {
-                throw new IllegalStateException(e);
-            }
-            return inMem.toByteArray();
-        }
-
-        protected URL findUrl(final String name) {
-            return getResource(name.replace('.', '/') + ".class");
-        }
-    }
-
-    private static class OpenJpaClassLoader extends BaseClassLoader {
-        private static final String PERSITENCE_CAPABLE = Type.getDescriptor(PersistenceCapable.class);
-        private static final String ENTITY = Type.getDescriptor(Entity.class);
-        private static final String EMBEDDABLE = Type.getDescriptor(Embeddable.class);
-        private static final String MAPPED_SUPERCLASS = Type.getDescriptor(MappedSuperclass.class);
-
-        private final MetaDataRepository repos;
-        private final ClassLoader tmpLoader;
-        private final Collection<String> alreadyEnhanced = new ArrayList<>();
-
-        private OpenJpaClassLoader(final ClassLoader parent, final LogFactory logFactory) {
-            super(parent);
-
-            final OpenJPAConfigurationImpl conf = new OpenJPAConfigurationImpl();
-            conf.setLogFactory(logFactory);
-
-            tmpLoader = new CompanionLoader(parent);
-            repos = new MetaDataRepository();
-            repos.setConfiguration(conf);
-            repos.setMetaDataFactory(new PersistenceMetaDataFactory());
-        }
-
-        @Override
-        protected synchronized Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
-            final Class<?> clazz = findLoadedClass(name);
-            if (clazz != null) {
-                if (resolve) {
-                    resolveClass(clazz);
-                }
-                return clazz;
-            }
-            handleEnhancement(name);
-            return defaultLoadClass(name, resolve);
-        }
-
-        private void handleEnhancement(final String name) throws ClassNotFoundException {
-            final byte[] enhanced = ensureEnhancedIfNeeded(name);
-            if (enhanced != null && alreadyEnhanced.add(name)) {
-                // we could do that but test classes will be loaded with parent loader
-                // so just rewrite the class on the fly assuming it was not yet read
-                try {
-                    Files.write(findTarget(name), enhanced, StandardOpenOption.TRUNCATE_EXISTING);
-                    LOGGER.info(() -> "Enhanced '" + name + "'");
-                } catch (final IOException e) {
-                    throw new ClassNotFoundException(e.getMessage(), e);
-                }
-            }
-        }
-
-        private Path findTarget(final String name) {
-            final URL url = findUrl(name);
-            if (!"file".equals(url.getProtocol())) {
-                throw new IllegalStateException("Only file urls are supported today: " + url);
-            }
-            return Paths.get(url.getPath());
-        }
-
-        private byte[] enhance(final byte[] classBytes) {
-            final Thread thread = Thread.currentThread();
-            final ClassLoader old = thread.getContextClassLoader();
-            thread.setContextClassLoader(tmpLoader);
-            try (final InputStream stream = new ByteArrayInputStream(classBytes)) {
-                final PCEnhancer enhancer = new PCEnhancer(
-                        repos.getConfiguration(),
-                        new Project().loadClass(stream, tmpLoader),
-                        repos, tmpLoader);
-                if (enhancer.run() == PCEnhancer.ENHANCE_NONE) {
-                    return null;
-                }
-                final BCClass pcb = enhancer.getPCBytecode();
-                return AsmAdaptor.toByteArray(pcb, pcb.toByteArray());
-            } catch (final IOException e) {
-                throw new IllegalStateException(e);
-            } finally {
-                thread.setContextClassLoader(old);
-            }
-        }
-
-        private boolean isJpaButNotEnhanced(final byte[] classBytes) {
-            try (final InputStream stream = new ByteArrayInputStream(classBytes)) {
-                final ClassReader reader = new ClassReader(stream);
-                reader.accept(new EmptyVisitor() {
-                    @Override
-                    public void visit(final int version, final int access, final String name,
-                                      final String signature, final String superName, final String[] interfaces) {
-                        if (interfaces != null && asList(interfaces).contains(PERSITENCE_CAPABLE)) {
-                            throw new AlreadyEnhanced(); // exit
-                        }
-                        super.visit(version, access, name, signature, superName, interfaces);
-                    }
-
-                    @Override
-                    public AnnotationVisitor visitAnnotation(final String descriptor, final boolean visible) {
-                        if (ENTITY.equals(descriptor) ||
-                                EMBEDDABLE.equals(descriptor) ||
-                                MAPPED_SUPERCLASS.equals(descriptor)) {
-                            throw new MissingEnhancement(); // we already went into visit() so we miss the enhancement
-                        }
-                        return new EmptyVisitor().visitAnnotation(descriptor, visible);
-                    }
-                }, SKIP_DEBUG + SKIP_CODE + SKIP_FRAMES);
-                return false;
-            } catch (final IOException e) {
-                throw new IllegalStateException(e);
-            } catch (final AlreadyEnhanced alreadyEnhanced) {
-                return false;
-            } catch (final MissingEnhancement alreadyEnhanced) {
-                return true;
-            }
-        }
-
-        private byte[] ensureEnhancedIfNeeded(final String name) {
-            final byte[] classBytes = loadBytes(name);
-            if (classBytes == null) {
-                return null;
-            }
-            if (isJpaButNotEnhanced(classBytes)) {
-                final byte[] enhanced = enhance(classBytes);
-                if (enhanced != null) {
-                    return enhanced;
-                }
-                LOGGER.info("'" + name + "' already enhanced");
-            }
-            return null;
-        }
-    }
-
-    private static class CompanionLoader extends BaseClassLoader {
-        private CompanionLoader(final ClassLoader parent) {
-            super(parent);
-        }
-
-        @Override
-        protected Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
-            final Class<?> clazz = findLoadedClass(name);
-            if (clazz != null) {
-                if (resolve) {
-                    resolveClass(clazz);
-                }
-                return clazz;
-            }
-            final byte[] content = loadBytes(name);
-            if (content != null) {
-                final Class<?> value = super.defineClass(name, content, 0, content.length);
-                if (resolve) {
-                    resolveClass(value);
-                }
-                return value;
-            }
-            return defaultLoadClass(name, resolve);
-        }
-    }
-
-    private static class MissingEnhancement extends RuntimeException {
-    }
-
-    private static class AlreadyEnhanced extends RuntimeException {
+        AnnotationUtils.findAnnotation(context.getElement(), OpenJPASupport.class)
+                .ifPresent(s -> new OpenJPADirectoriesEnhancer(s.auto(), s.entities(), s.logFactory()).run());
     }
 }