/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.openjpa.junit5.internal;

import org.apache.openjpa.conf.OpenJPAConfigurationImpl;
import org.apache.openjpa.enhance.AsmAdaptor;
import org.apache.openjpa.enhance.PCEnhancer;
import org.apache.openjpa.enhance.PersistenceCapable;
import org.apache.openjpa.junit5.OpenJPASupport;
import org.apache.openjpa.lib.log.LogFactory;
import org.apache.openjpa.lib.log.LogFactoryImpl;
import org.apache.openjpa.lib.log.SLF4JLogFactory;
import org.apache.openjpa.meta.MetaDataRepository;
import org.apache.openjpa.persistence.PersistenceMetaDataFactory;
import org.apache.xbean.asm7.AnnotationVisitor;
import org.apache.xbean.asm7.ClassReader;
import org.apache.xbean.asm7.Type;
import org.apache.xbean.asm7.shade.commons.EmptyVisitor;
import org.apache.xbean.finder.ClassLoaders;
import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.platform.commons.util.AnnotationUtils;
import serp.bytecode.BCClass;
import serp.bytecode.Project;

import javax.persistence.Embeddable;
import javax.persistence.Entity;
import javax.persistence.MappedSuperclass;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.FileVisitResult;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.SimpleFileVisitor;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Stream;

import static java.util.Arrays.asList;
import static org.apache.xbean.asm7.ClassReader.SKIP_CODE;
import static org.apache.xbean.asm7.ClassReader.SKIP_DEBUG;
import static org.apache.xbean.asm7.ClassReader.SKIP_FRAMES;

public class OpenJPAExtension implements BeforeAllCallback {
    private static final Logger LOGGER = Logger.getLogger(OpenJPAExtension.class.getName());

    /**
     * Enhances the JPA classes of the test when it is marked with {@link OpenJPASupport}.
     * <p>
     * While the enhancement runs, the thread context class loader is replaced by a dedicated
     * enhancing loader; it is switched back to that loader's parent once done, whatever the outcome.
     * In {@code auto} mode every directory found on the classpath of the parent loader is scanned,
     * otherwise only the explicitly listed entities are loaded through the enhancing loader.
     *
     * @param context the current extension context, used to look up the annotation.
     */
    @Override
    public void beforeAll(final ExtensionContext context) {
        final OpenJPASupport support = AnnotationUtils
                .findAnnotation(context.getElement(), OpenJPASupport.class)
                .orElse(null);
        if (support == null) { // nothing requested on this test
            return;
        }

        final Thread thread = Thread.currentThread();
        final ClassLoader contextLoader = thread.getContextClassLoader();
        final OpenJpaClassLoader enhancementClassLoader = new OpenJpaClassLoader(
                contextLoader, createLogFactory(contextLoader, support.logFactory()));
        thread.setContextClassLoader(enhancementClassLoader);
        try {
            if (!support.auto()) {
                // explicit mode: loading the class through the enhancing loader triggers the enhancement
                for (final String entity : support.entities()) {
                    try {
                        enhancementClassLoader.loadClass(entity);
                    } catch (final ClassNotFoundException notFound) {
                        throw new IllegalArgumentException(notFound);
                    }
                }
                return;
            }

            // auto mode: visit all exploded (directory) classpath entries
            try {
                for (final URL url : ClassLoaders.findUrls(enhancementClassLoader.getParent())) {
                    final File candidate = org.apache.xbean.finder.util.Files.toFile(url);
                    if (!candidate.isDirectory()) {
                        continue;
                    }
                    final Path dir = candidate.toPath();
                    LOGGER.fine(() -> "Enhancing folder '" + dir + "'");
                    enhanceDirectory(enhancementClassLoader, dir);
                }
            } catch (final IOException ioe) {
                throw new IllegalStateException(ioe);
            }
        } finally {
            thread.setContextClassLoader(enhancementClassLoader.getParent());
        }
    }

    /**
     * Instantiates the log factory used by the enhancer.
     *
     * @param classLoader the loader the enhancement runs against (currently not used by this method).
     * @param logFactory  the requested factory type; {@code LogFactory.class} itself means "default",
     *                    in which case SLF4J is tried first and the built-in implementation is the fallback.
     * @return a new log factory instance.
     * @throws IllegalStateException if the requested type can't be instantiated through its
     *                               public no-arg constructor.
     */
    private LogFactory createLogFactory(final ClassLoader classLoader, final Class<?> logFactory) {
        if (logFactory == LogFactory.class) {
            try {
                return new SLF4JLogFactory();
            } catch (final Error | Exception slf4jUnavailable) {
                // typically slf4j is not in the classpath: use the OpenJPA native logging
                return new LogFactoryImpl();
            }
        }

        final Class<? extends LogFactory> type = logFactory.asSubclass(LogFactory.class);
        try {
            return type.getConstructor().newInstance();
        } catch (final ReflectiveOperationException reflectionFailure) {
            // runtime exceptions are intentionally left untouched, only checked ones are wrapped
            throw new IllegalStateException(reflectionFailure);
        }
    }

    /**
     * Walks {@code dir} recursively and asks the enhancing loader to process every {@code .class} file found.
     * <p>
     * The class name handed to the loader is the dotted binary name, built from the segments of the
     * path relative to {@code dir}, so it does not depend on the file separator of the platform.
     *
     * @param enhancementClassLoader the loader doing the actual enhancement.
     * @param dir                    the root folder of the classes (a classpath directory entry).
     * @throws IOException           if the directory can't be traversed.
     * @throws IllegalStateException if the enhancement of a class fails with a {@link ClassNotFoundException}.
     */
    private void enhanceDirectory(final OpenJpaClassLoader enhancementClassLoader, final Path dir) throws IOException {
        final String suffix = ".class";
        Files.walkFileTree(dir, new SimpleFileVisitor<Path>() {
            @Override
            public FileVisitResult visitFile(final Path file, final BasicFileAttributes attrs) throws IOException {
                if (file.getFileName().toString().endsWith(suffix)) {
                    // join the path segments with '.' instead of relying on Path#toString
                    // which uses the platform separator ('\' on windows)
                    final StringBuilder className = new StringBuilder();
                    for (final Path segment : dir.relativize(file)) {
                        if (className.length() > 0) {
                            className.append('.');
                        }
                        className.append(segment.toString());
                    }
                    className.setLength(className.length() - suffix.length());
                    try {
                        enhancementClassLoader.handleEnhancement(className.toString());
                    } catch (final ClassNotFoundException e) {
                        throw new IllegalStateException(e);
                    }
                }
                return super.visitFile(file, attrs);
            }
        });
    }

    /**
     * Common base of the loaders of this extension: routes class loading to {@link #doLoadClass(String, boolean)}
     * except for names starting with {@code java}, {@code sun} or {@code jdk} (and {@code null}) which
     * always use the standard {@link ClassLoader} behavior, and provides helpers to read the raw
     * bytes of a class file.
     */
    private static abstract class BaseClassLoader extends ClassLoader {
        private BaseClassLoader(final ClassLoader parent) {
            super(parent);
        }

        /**
         * Loading strategy of the concrete loader, only called for names not excluded by {@link #loadClass(String, boolean)}.
         */
        protected abstract Class<?> doLoadClass(String name, boolean resolve) throws ClassNotFoundException;

        @Override
        protected Class<?> loadClass(final String name, final boolean resolve) throws ClassNotFoundException {
            final boolean excluded = name == null
                    || name.startsWith("java")
                    || name.startsWith("sun")
                    || name.startsWith("jdk");
            if (excluded) {
                return defaultLoadClass(name, resolve);
            }
            return doLoadClass(name, resolve);
        }

        /**
         * Plain {@link ClassLoader#loadClass(String, boolean)} lookup, bypassing {@link #doLoadClass(String, boolean)}.
         */
        protected Class<?> defaultLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
            return super.loadClass(name, resolve);
        }

        /**
         * Reads the bytecode of a class from its resource.
         *
         * @param name the class name.
         * @return the content of the class file or {@code null} if the resource is not found or lives in a jar.
         * @throws IllegalStateException if the resource can't be read.
         */
        protected byte[] loadBytes(final String name) {
            final URL location = findUrl(name);
            if (location == null) {
                return null;
            }
            if ("jar".equals(location.getProtocol())) { // assume done in build
                return null;
            }

            final byte[] chunk = new byte[4096];
            final ByteArrayOutputStream content = new ByteArrayOutputStream(chunk.length);
            try (final InputStream in = location.openStream()) {
                for (int count = in.read(chunk); count >= 0; count = in.read(chunk)) {
                    content.write(chunk, 0, count);
                }
            } catch (final IOException ioe) {
                throw new IllegalStateException(ioe);
            }
            return content.toByteArray();
        }

        /**
         * Resolves the {@code .class} resource of a class name ({@code '.'} are replaced by {@code '/'}).
         *
         * @return the resource url as returned by {@link #getResource(String)}, potentially {@code null}.
         */
        protected URL findUrl(final String name) {
            final String resource = name.replace('.', '/') + ".class";
            return getResource(resource);
        }
    }

    /**
     * Loader which, before loading a class, checks whether its class file is a JPA type missing the
     * OpenJPA enhancement and, if so, enhances it and overwrites the class file with the enhanced bytecode.
     * The class itself is then loaded through the standard {@link ClassLoader} lookup.
     */
    private static class OpenJpaClassLoader extends BaseClassLoader {
        // ClassVisitor#visit exposes the interfaces as internal names ("a/b/C"), not as descriptors ("La/b/C;")
        private static final String PERSISTENCE_CAPABLE = Type.getInternalName(PersistenceCapable.class);
        // ClassVisitor#visitAnnotation exposes the annotation type as a descriptor
        private static final String ENTITY = Type.getDescriptor(Entity.class);
        private static final String EMBEDDABLE = Type.getDescriptor(Embeddable.class);
        private static final String MAPPED_SUPERCLASS = Type.getDescriptor(MappedSuperclass.class);

        private final MetaDataRepository repos;
        private final ClassLoader tmpLoader;
        // a set so that add() reports whether the name was already handled
        private final Set<String> alreadyEnhanced = new HashSet<>();

        /**
         * @param parent     the loader providing the classes (and class files) to enhance.
         * @param logFactory the log factory set on the OpenJPA configuration used by the enhancer.
         */
        private OpenJpaClassLoader(final ClassLoader parent, final LogFactory logFactory) {
            super(parent);

            final OpenJPAConfigurationImpl conf = new OpenJPAConfigurationImpl();
            conf.setLogFactory(logFactory);

            tmpLoader = new CompanionLoader(parent);
            repos = new MetaDataRepository();
            repos.setConfiguration(conf);
            repos.setMetaDataFactory(new PersistenceMetaDataFactory());
        }

        /**
         * Returns the class if this loader already loaded it, otherwise ensures its class file is
         * enhanced (when needed) and then uses the default lookup.
         */
        @Override
        protected synchronized Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
            final Class<?> clazz = findLoadedClass(name);
            if (clazz != null) {
                if (resolve) {
                    resolveClass(clazz);
                }
                return clazz;
            }
            handleEnhancement(name);
            return defaultLoadClass(name, resolve);
        }

        /**
         * Enhances the class if needed and replaces its class file by the enhanced version.
         * A name is written at most once by this loader.
         *
         * @param name the class to process.
         * @throws ClassNotFoundException if the enhanced bytecode can't be written.
         */
        private void handleEnhancement(final String name) throws ClassNotFoundException {
            final byte[] enhanced = ensureEnhancedIfNeeded(name);
            if (enhanced != null && alreadyEnhanced.add(name)) {
                // we could do that but test classes will be loaded with parent loader
                // so just rewrite the class on the fly assuming it was not yet read
                try {
                    Files.write(findTarget(name), enhanced, StandardOpenOption.TRUNCATE_EXISTING);
                    LOGGER.info(() -> "Enhanced '" + name + "'");
                } catch (final IOException e) {
                    throw new ClassNotFoundException(e.getMessage(), e);
                }
            }
        }

        /**
         * Finds the file backing a class.
         *
         * @param name the class name.
         * @return the path of the class file.
         * @throws IllegalStateException if the class resource is missing or is not a {@code file:} url.
         */
        private Path findTarget(final String name) {
            final URL url = findUrl(name);
            if (url == null) {
                throw new IllegalStateException("No class file found for '" + name + "'");
            }
            if (!"file".equals(url.getProtocol())) {
                throw new IllegalStateException("Only file urls are supported today: " + url);
            }
            try {
                // go through the URI to decode escaped characters (spaces, ...)
                // and to get a valid path for windows drive urls (file:/C:/...)
                return Paths.get(url.toURI());
            } catch (final URISyntaxException e) {
                // url not properly escaped: use its raw path
                return new File(url.getPath()).toPath();
            }
        }

        /**
         * Runs the OpenJPA enhancer on the given bytecode with the companion loader set as thread
         * context class loader (the previous one is restored afterwards).
         *
         * @param classBytes the original bytecode.
         * @return the enhanced bytecode or {@code null} if the enhancer reported {@code ENHANCE_NONE}.
         */
        private byte[] enhance(final byte[] classBytes) {
            final Thread thread = Thread.currentThread();
            final ClassLoader old = thread.getContextClassLoader();
            thread.setContextClassLoader(tmpLoader);
            try (final InputStream stream = new ByteArrayInputStream(classBytes)) {
                final PCEnhancer enhancer = new PCEnhancer(
                        repos.getConfiguration(),
                        new Project().loadClass(stream, tmpLoader),
                        repos, tmpLoader);
                if (enhancer.run() == PCEnhancer.ENHANCE_NONE) {
                    return null;
                }
                final BCClass pcb = enhancer.getPCBytecode();
                return AsmAdaptor.toByteArray(pcb, pcb.toByteArray());
            } catch (final IOException e) {
                throw new IllegalStateException(e);
            } finally {
                thread.setContextClassLoader(old);
            }
        }

        /**
         * Scans the class header and class level annotations (code, debug data and frames are skipped).
         *
         * @param classBytes the bytecode to inspect.
         * @return {@code true} if the class has {@code @Entity}, {@code @Embeddable} or
         * {@code @MappedSuperclass} and does not declare {@link PersistenceCapable} in its interfaces.
         */
        private boolean isJpaButNotEnhanced(final byte[] classBytes) {
            try {
                new ClassReader(classBytes).accept(new EmptyVisitor() {
                    @Override
                    public void visit(final int version, final int access, final String name,
                                      final String signature, final String superName, final String[] interfaces) {
                        if (interfaces != null && asList(interfaces).contains(PERSISTENCE_CAPABLE)) {
                            throw new AlreadyEnhanced(); // exit
                        }
                        super.visit(version, access, name, signature, superName, interfaces);
                    }

                    @Override
                    public AnnotationVisitor visitAnnotation(final String descriptor, final boolean visible) {
                        if (ENTITY.equals(descriptor) ||
                                EMBEDDABLE.equals(descriptor) ||
                                MAPPED_SUPERCLASS.equals(descriptor)) {
                            throw new MissingEnhancement(); // we already went into visit() so we miss the enhancement
                        }
                        return new EmptyVisitor().visitAnnotation(descriptor, visible);
                    }
                }, SKIP_DEBUG | SKIP_CODE | SKIP_FRAMES);
                return false; // not a jpa class
            } catch (final AlreadyEnhanced enhancedMarker) {
                return false;
            } catch (final MissingEnhancement missingMarker) {
                return true;
            }
        }

        /**
         * @param name the class name.
         * @return the enhanced bytecode, or {@code null} when the class file is not readable as a
         * plain file, is not a JPA class needing enhancement, or the enhancer had nothing to do.
         */
        private byte[] ensureEnhancedIfNeeded(final String name) {
            final byte[] classBytes = loadBytes(name);
            if (classBytes == null) {
                return null;
            }
            if (isJpaButNotEnhanced(classBytes)) {
                final byte[] enhanced = enhance(classBytes);
                if (enhanced != null) {
                    return enhanced;
                }
                LOGGER.info(() -> "'" + name + "' already enhanced");
            }
            return null;
        }
    }

    /**
     * Throw-away loader used during the enhancement: it defines itself the classes whose bytes are
     * available through {@code loadBytes} and uses the default lookup for all the others.
     */
    private static class CompanionLoader extends BaseClassLoader {
        private CompanionLoader(final ClassLoader parent) {
            super(parent);
        }

        @Override
        protected Class<?> doLoadClass(final String name, final boolean resolve) throws ClassNotFoundException {
            Class<?> loaded = findLoadedClass(name);
            if (loaded == null) {
                final byte[] bytecode = loadBytes(name);
                if (bytecode == null) { // no readable class file: standard delegation
                    return defaultLoadClass(name, resolve);
                }
                loaded = defineClass(name, bytecode, 0, bytecode.length);
            }
            if (resolve) {
                resolveClass(loaded);
            }
            return loaded;
        }
    }

    /**
     * Marker thrown while scanning a class to abort the visit as soon as a JPA annotation is met,
     * it is only used as a control flow signal and is caught by the scanning method.
     */
    private static class MissingEnhancement extends RuntimeException {
        private MissingEnhancement() {
            // pure signal: no message, no cause, no suppression and no (costly) stack trace
            super(null, null, false, false);
        }
    }

    /**
     * Marker thrown while scanning a class to abort the visit as soon as the class is known to be
     * enhanced, it is only used as a control flow signal and is caught by the scanning method.
     */
    private static class AlreadyEnhanced extends RuntimeException {
        private AlreadyEnhanced() {
            // pure signal: no message, no cause, no suppression and no (costly) stack trace
            super(null, null, false, false);
        }
    }
}
