/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.cassandra.cql3.functions;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.*;
import java.nio.ByteBuffer;
import java.security.*;
import java.security.cert.Certificate;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
import com.google.common.reflect.TypeToken;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.datastax.driver.core.TypeCodec;
import org.apache.cassandra.concurrent.NamedThreadFactory;
import org.apache.cassandra.cql3.ColumnIdentifier;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.security.SecurityThreadGroup;
import org.apache.cassandra.security.ThreadAwareSecurityManager;
import org.eclipse.jdt.core.compiler.IProblem;
import org.eclipse.jdt.internal.compiler.*;
import org.eclipse.jdt.internal.compiler.Compiler;
import org.eclipse.jdt.internal.compiler.classfmt.ClassFileReader;
import org.eclipse.jdt.internal.compiler.classfmt.ClassFormatException;
import org.eclipse.jdt.internal.compiler.env.ICompilationUnit;
import org.eclipse.jdt.internal.compiler.env.INameEnvironment;
import org.eclipse.jdt.internal.compiler.env.NameEnvironmentAnswer;
import org.eclipse.jdt.internal.compiler.impl.CompilerOptions;
import org.eclipse.jdt.internal.compiler.problem.DefaultProblemFactory;

public final class JavaBasedUDFunction extends UDFunction
{
    // Root package of all generated UDF classes; the constructor appends a per-function sub-package.
    private static final String BASE_PACKAGE = "org.apache.cassandra.cql3.udf.gen";

    // Matches "java.lang." prefixes; javaSourceName() strips them from type names.
    private static final Pattern JAVA_LANG_PREFIX = Pattern.compile("\\bjava\\.lang\\.");

    static final Logger logger = LoggerFactory.getLogger(JavaBasedUDFunction.class);

    // Incremented for every generated name (see generateClassName()).
    private static final AtomicInteger classSequence = new AtomicInteger();

    // use a JVM standard ExecutorService as DebuggableThreadPoolExecutor references internal
    // classes, which triggers AccessControlException from the UDF sandbox
    private static final UDFExecutorService executor =
        new UDFExecutorService(new NamedThreadFactory("UserDefinedFunctions",
                                                      Thread.MIN_PRIORITY,
                                                      udfClassLoader,
                                                      new SecurityThreadGroup("UserDefinedFunctions", null, UDFunction::initializeThread)),
                               "userfunction");

    // Receives the class bytes produced by ECJ and defines the generated UDF classes.
    private static final EcjTargetClassLoader targetClassLoader = new EcjTargetClassLoader();

    // Configured with disallowed calls and classes in the static initializer.
    private static final UDFByteCodeVerifier udfByteCodeVerifier = new UDFByteCodeVerifier();

    // Assigned in the static initializer; used by EcjTargetClassLoader when defining UDF classes.
    private static final ProtectionDomain protectionDomain;

    // ECJ collaborators shared by all compilations; compilerOptions is assigned in the static initializer.
    private static final IErrorHandlingPolicy errorHandlingPolicy = DefaultErrorHandlingPolicies.proceedWithAllProblems();
    private static final IProblemFactory problemFactory = new DefaultProblemFactory(Locale.ENGLISH);
    private static final CompilerOptions compilerOptions;

    /**
     * Poor man's template - just a text file split at '#' chars.
     * Each string at an even index is a constant string (just copied),
     * each string at an odd index is an 'instruction'.
     */
    private static final String[] javaSourceTemplate;

    /*
     * One-time class initialisation:
     *  1. registers the method calls and classes that UDF byte code must not reference,
     *  2. builds the ECJ compiler options (Java 1.8 source and target),
     *  3. loads the JavaSourceUDF.txt template and splits it at '#' characters,
     *  4. creates the ProtectionDomain (no permissions) used for generated classes.
     */
    static
    {
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "forName");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "getClassLoader");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "getResource");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "getResourceAsStream");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "clearAssertionStatus");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getResource");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getResourceAsStream");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getResources");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemClassLoader");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemResource");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemResourceAsStream");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemResources");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "loadClass");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "setClassAssertionStatus");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "setDefaultAssertionStatus");
        udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "setPackageAssertionStatus");
        udfByteCodeVerifier.addDisallowedMethodCall("java/nio/ByteBuffer", "allocateDirect");
        for (String ia : new String[]{"java/net/InetAddress", "java/net/Inet4Address", "java/net/Inet6Address"})
        {
            // static method, probably performing DNS lookups (despite SecurityManager)
            udfByteCodeVerifier.addDisallowedMethodCall(ia, "getByAddress");
            udfByteCodeVerifier.addDisallowedMethodCall(ia, "getAllByName");
            udfByteCodeVerifier.addDisallowedMethodCall(ia, "getByName");
            udfByteCodeVerifier.addDisallowedMethodCall(ia, "getLocalHost");
            // instance methods, probably performing DNS lookups (despite SecurityManager)
            udfByteCodeVerifier.addDisallowedMethodCall(ia, "getHostName");
            udfByteCodeVerifier.addDisallowedMethodCall(ia, "getCanonicalHostName");
            // ICMP PING
            udfByteCodeVerifier.addDisallowedMethodCall(ia, "isReachable");
        }
        udfByteCodeVerifier.addDisallowedClass("java/net/NetworkInterface");
        udfByteCodeVerifier.addDisallowedClass("java/net/SocketException");

        // ECJ settings: keep line numbers (needed to map problems back to the UDF body),
        // omit the source file attribute, ignore deprecation, compile as Java 1.8
        Map<String, String> settings = new HashMap<>();
        settings.put(CompilerOptions.OPTION_LineNumberAttribute,
                     CompilerOptions.GENERATE);
        settings.put(CompilerOptions.OPTION_SourceFileAttribute,
                     CompilerOptions.DISABLED);
        settings.put(CompilerOptions.OPTION_ReportDeprecation,
                     CompilerOptions.IGNORE);
        settings.put(CompilerOptions.OPTION_Source,
                     CompilerOptions.VERSION_1_8);
        settings.put(CompilerOptions.OPTION_TargetPlatform,
                     CompilerOptions.VERSION_1_8);

        compilerOptions = new CompilerOptions(settings);
        compilerOptions.parseLiteralExpressionsAsConstants = true;

        try (InputStream input = JavaBasedUDFunction.class.getResource("JavaSourceUDF.txt").openConnection().getInputStream())
        {
            ByteArrayOutputStream output = new ByteArrayOutputStream();
            FBUtilities.copy(input, output, Long.MAX_VALUE);
            // Decode with an explicit charset instead of the platform default, so the template is
            // read identically on every JVM. UnsupportedEncodingException is an IOException and is
            // therefore covered by the catch clause below.
            String template = output.toString("UTF-8");

            // split the template at '#': even indexes are literal text, odd indexes are instructions
            StringTokenizer st = new StringTokenizer(template, "#");
            javaSourceTemplate = new String[st.countTokens()];
            for (int i = 0; st.hasMoreElements(); i++)
                javaSourceTemplate[i] = st.nextToken();
        }
        catch (IOException e)
        {
            throw new RuntimeException(e);
        }

        CodeSource codeSource;
        try
        {
            // synthetic "udf://localhost/java" code source; its stream handler never opens a connection
            codeSource = new CodeSource(new URL("udf", "localhost", 0, "/java", new URLStreamHandler()
            {
                protected URLConnection openConnection(URL u)
                {
                    return null;
                }
            }), (Certificate[])null);
        }
        catch (MalformedURLException e)
        {
            throw new RuntimeException(e);
        }

        protectionDomain = new ProtectionDomain(codeSource, ThreadAwareSecurityManager.noPermissions, targetClassLoader, null);
    }

    private final JavaUDF javaUDF;

    /**
     * Builds a Java UDF: generates the Java source from the template, compiles it with ECJ,
     * verifies the produced byte code, loads the class and instantiates it.
     *
     * @param name              function name; also used to derive the generated package, class and method names
     * @param argNames          argument names, used as parameter names in the generated source
     * @param argTypes          argument types
     * @param returnType        return type
     * @param calledOnNullInput passed to the super class and to {@code UDHelper.typeTokens}
     * @param body              user provided Java source, inserted at the template's {@code body} instruction
     * @throws InvalidRequestException if compilation or byte-code validation fails, if the generated class
     *                                 declares unexpected methods or constructors, or if it cannot be instantiated
     */
    JavaBasedUDFunction(FunctionName name, List<ColumnIdentifier> argNames, List<AbstractType<?>> argTypes,
                        AbstractType<?> returnType, boolean calledOnNullInput, String body)
    {
        super(name, argNames, argTypes, UDHelper.driverTypes(argTypes),
              returnType, UDHelper.driverType(returnType), calledOnNullInput, "java", body);

        // javaParamTypes is just the Java representation for argTypes resp. argDataTypes
        TypeToken<?>[] javaParamTypes = UDHelper.typeTokens(argCodecs, calledOnNullInput);
        // javaReturnType is just the Java representation for returnType resp. returnTypeCodec
        TypeToken<?> javaReturnType = returnCodec.getJavaType();

        // put each UDF in a separate package to prevent cross-UDF code access
        String pkgName = BASE_PACKAGE + '.' + generateClassName(name, 'p');
        String clsName = generateClassName(name, 'C');

        String executeInternalName = generateClassName(name, 'x');

        StringBuilder javaSourceBuilder = new StringBuilder();
        int lineOffset = 1;
        for (int i = 0; i < javaSourceTemplate.length; i++)
        {
            String s = javaSourceTemplate[i];

            // strings at odd indexes are 'instructions'
            if ((i & 1) == 1)
            {
                switch (s)
                {
                    case "package_name":
                        s = pkgName;
                        break;
                    case "class_name":
                        s = clsName;
                        break;
                    case "body":
                        // remember how many generated lines precede the user's code, so that
                        // compiler problems can be reported relative to the UDF body
                        lineOffset = countNewlines(javaSourceBuilder);
                        s = body;
                        break;
                    case "arguments":
                        s = generateArguments(javaParamTypes, argNames, false);
                        break;
                    case "arguments_aggregate":
                        s = generateArguments(javaParamTypes, argNames, true);
                        break;
                    case "argument_list":
                        s = generateArgumentList(javaParamTypes, argNames);
                        break;
                    case "return_type":
                        s = javaSourceName(javaReturnType);
                        break;
                    case "execute_internal_name":
                        s = executeInternalName;
                        break;
                }
            }

            javaSourceBuilder.append(s);
        }

        String targetClassName = pkgName + '.' + clsName;

        String javaSource = javaSourceBuilder.toString();

        logger.trace("Compiling Java source UDF '{}' as class '{}' using source:\n{}", name, targetClassName, javaSource);

        try
        {
            EcjCompilationUnit compilationUnit = new EcjCompilationUnit(javaSource, targetClassName);

            // the compilation unit also acts as name environment and as compiler requestor
            Compiler compiler = new Compiler(compilationUnit,
                                             errorHandlingPolicy,
                                             compilerOptions,
                                             compilationUnit,
                                             problemFactory);
            compiler.compile(new ICompilationUnit[]{ compilationUnit });

            if (compilationUnit.problemList != null && !compilationUnit.problemList.isEmpty())
            {
                boolean fullSource = false;
                StringBuilder problems = new StringBuilder();
                for (IProblem problem : compilationUnit.problemList)
                {
                    long ln = problem.getSourceLineNumber() - lineOffset;
                    if (ln < 1L)
                    {
                        if (problem.isError())
                        {
                            // if generated source around UDF source provided by the user is buggy,
                            // this code is appended.
                            problems.append("GENERATED SOURCE ERROR: line ")
                                    .append(problem.getSourceLineNumber())
                                    .append(" (in generated source): ")
                                    .append(problem.getMessage())
                                    .append('\n');
                            fullSource = true;
                        }
                    }
                    else
                    {
                        problems.append("Line ")
                                .append(Long.toString(ln))
                                .append(": ")
                                .append(problem.getMessage())
                                .append('\n');
                    }
                }

                if (fullSource)
                    throw new InvalidRequestException("Java source compilation failed:\n" + problems + "\n generated source:\n" + javaSource);
                else
                    throw new InvalidRequestException("Java source compilation failed:\n" + problems);
            }

            // Verify the UDF bytecode against use of probably dangerous code
            Set<String> errors = udfByteCodeVerifier.verify(targetClassName, targetClassLoader.classData(targetClassName));
            String validDeclare = "not allowed method declared: " + executeInternalName + '(';
            for (Iterator<String> i = errors.iterator(); i.hasNext();)
            {
                String error = i.next();
                // we generate a random name of the private, internal execute method, which is detected by the byte-code verifier
                if (error.startsWith(validDeclare))
                    i.remove();
            }
            if (!errors.isEmpty())
                throw new InvalidRequestException("Java UDF validation failed: " + errors);

            // Load the class and create a new instance of it
            Thread thread = Thread.currentThread();
            ClassLoader orig = thread.getContextClassLoader();
            try
            {
                thread.setContextClassLoader(UDFunction.udfClassLoader);
                // Execute UDF initialization from UDF class loader

                Class<?> cls = Class.forName(targetClassName, false, targetClassLoader);

                // Count only non-synthetic methods, so code coverage instrumentation doesn't cause a miscount
                int nonSyntheticMethodCount = 0;
                for (Method m : cls.getDeclaredMethods())
                {
                    if (!m.isSynthetic())
                    {
                        nonSyntheticMethodCount += 1;
                    }
                }

                if (nonSyntheticMethodCount != 3 || cls.getDeclaredConstructors().length != 1)
                    throw new InvalidRequestException("Check your source to not define additional Java methods or constructors");
                MethodType methodType = MethodType.methodType(void.class)
                                                  .appendParameterTypes(TypeCodec.class, TypeCodec[].class, UDFContext.class);
                MethodHandle ctor = MethodHandles.lookup().findConstructor(cls, methodType);
                this.javaUDF = (JavaUDF) ctor.invokeWithArguments(returnCodec, argCodecs, udfContext);
            }
            finally
            {
                thread.setContextClassLoader(orig);
            }
        }
        catch (InvocationTargetException e)
        {
            // in case of an ITE, use the cause
            throw new InvalidRequestException(String.format("Could not compile function '%s' from Java source: %s", name, e.getCause()));
        }
        catch (InvalidRequestException | VirtualMachineError e)
        {
            throw e;
        }
        catch (Throwable e)
        {
            logger.error(String.format("Could not compile function '%s' from Java source:%n%s", name, javaSource), e);
            throw new InvalidRequestException(String.format("Could not compile function '%s' from Java source: %s", name, e));
        }
        finally
        {
            // The class loader only drops the compiled class bytes when the class actually gets
            // loaded. If we bailed out before that (e.g. byte-code validation failed), the bytes
            // would stay in the static map forever - so remove them here. No-op on success.
            targetClassLoader.classes.remove(targetClassName);
        }
    }

    /**
     * @return the executor service shared by all Java UDFs (the static {@code executor} field of this class)
     */
    protected ExecutorService executor()
    {
        return JavaBasedUDFunction.executor;
    }

    /**
     * Runs the compiled UDF by delegating to {@code executeImpl} of the generated class instance.
     *
     * @param protocolVersion protocol version handed through to the generated code
     * @param params          serialized arguments
     * @return whatever the generated {@code executeImpl} returns
     */
    protected ByteBuffer executeUserDefined(ProtocolVersion protocolVersion, List<ByteBuffer> params)
    {
        final ByteBuffer result = javaUDF.executeImpl(protocolVersion, params);
        return result;
    }

    /**
     * Runs the compiled UDF as part of an aggregate by delegating to {@code executeAggregateImpl}
     * of the generated class instance.
     *
     * @param protocolVersion protocol version handed through to the generated code
     * @param firstParam      first argument, passed on as an object
     * @param params          remaining arguments in serialized form
     * @return whatever the generated {@code executeAggregateImpl} returns
     */
    protected Object executeAggregateUserDefined(ProtocolVersion protocolVersion, Object firstParam, List<ByteBuffer> params)
    {
        final Object newState = javaUDF.executeAggregateImpl(protocolVersion, firstParam, params);
        return newState;
    }

    /**
     * Counts the line-feed ({@code '\n'}) characters in the given source.
     *
     * @param javaSource source text built so far
     * @return number of {@code '\n'} characters, {@code 0} for an empty builder
     */
    private static int countNewlines(StringBuilder javaSource)
    {
        int count = 0;
        // hop from one line feed to the next instead of inspecting every character
        int pos = javaSource.indexOf("\n");
        while (pos >= 0)
        {
            count++;
            pos = javaSource.indexOf("\n", pos + 1);
        }
        return count;
    }

    /**
     * Derives a valid Java identifier from a function name.
     * <p>
     * The result is {@code prefix}, followed by the function name (characters that are not valid
     * identifier parts are replaced by their hex code), followed by {@code _<random>_<sequence>}.
     * </p>
     *
     * @param name   function name to derive the identifier from
     * @param prefix first character of the generated identifier
     * @return the generated identifier; differs on every call as it embeds the class sequence
     */
    private static String generateClassName(FunctionName name, char prefix)
    {
        String fqn = name.toString();

        StringBuilder ident = new StringBuilder(fqn.length() + 10).append(prefix);
        for (char ch : fqn.toCharArray())
        {
            if (!Character.isJavaIdentifierPart(ch))
                // a char widens to its unsigned 16 bit value, i.e. the same as ((short) ch) & 0xffff
                ident.append(Integer.toHexString(ch));
            else
                ident.append(ch);
        }

        int randomPart = ThreadLocalRandom.current().nextInt() & 0xffffff;
        int sequencePart = classSequence.incrementAndGet();
        return ident.append('_')
                    .append(randomPart)
                    .append('_')
                    .append(sequencePart)
                    .toString();
    }

    /**
     * Returns the textual form of the given type with every {@code java.lang.} prefix removed.
     *
     * @param type type to render
     * @return {@code type.toString()} without {@code java.lang.} prefixes
     */
    @VisibleForTesting
    public static String javaSourceName(TypeToken<?> type)
    {
        return JAVA_LANG_PREFIX.matcher(type.toString()).replaceAll("");
    }

    /**
     * Generates the formal parameter list ({@code <type> <name>, <type> <name>, ...}) of the
     * generated internal execute method.
     *
     * @param paramTypes Java types of the arguments
     * @param argNames   argument names, same order as {@code paramTypes}
     * @return comma separated parameter declarations, empty string if there are no arguments
     */
    private static String generateArgumentList(TypeToken<?>[] paramTypes, List<ColumnIdentifier> argNames)
    {
        StringJoiner declarations = new StringJoiner(", ");
        for (int idx = 0; idx < paramTypes.length; idx++)
            declarations.add(javaSourceName(paramTypes[idx]) + ' ' + argNames.get(idx));
        return declarations.toString();
    }

    /**
     * Generates the Java source snippet with the arguments used to call the UDF implementation
     * function, i.e. the {@code private #return_type# #execute_internal_name#(#argument_list#)}
     * function (see the {@code JavaSourceUDF.txt} template file for details).
     * <p>
     * The snippet is generated for both {@code executeImpl} and {@code executeAggregateImpl}.
     * Every argument is cast to its Java type and read from {@code params} via a
     * {@code super.compose...} call. For aggregates the first argument is the state variable,
     * which is passed as {@code firstParam} and is not deserialized.
     * </p>
     * <p>
     * Example output for {@code executeImpl}:
     * {@code (double) super.compose_double(protocolVersion, 0, params.get(0)), (double) super.compose_double(protocolVersion, 1, params.get(1))}
     * </p>
     * <p>
     * Example output for {@code executeAggregateImpl}:
     * {@code firstParam, (double) super.compose_double(protocolVersion, 1, params.get(0))}
     * </p>
     *
     * @param paramTypes   Java types of the arguments
     * @param argNames     argument names, only used for comments emitted when trace logging is enabled
     * @param forAggregate whether to generate the snippet for {@code executeAggregateImpl}
     * @return the generated source snippet
     */
    private static String generateArguments(TypeToken<?>[] paramTypes, List<ColumnIdentifier> argNames, boolean forAggregate)
    {
        StringBuilder out = new StringBuilder(64 * paramTypes.length);
        int idx = 0;
        for (TypeToken<?> paramType : paramTypes)
        {
            // separator between two arguments
            if (idx != 0)
                out.append(",\n");

            // comments are only generated when someone is going to read the source
            if (logger.isTraceEnabled())
                out.append("            /* parameter '").append(argNames.get(idx)).append("' */\n");

            // cast to Java type
            out.append("            (").append(javaSourceName(paramType)).append(") ");

            boolean isAggregateState = forAggregate && idx == 0;
            if (isAggregateState)
            {
                // the state variable of an aggregation (1st arg to state + final function and
                // return value from state function) is not re-serialized
                out.append("firstParam");
            }
            else
            {
                // for aggregates the state is not part of 'params', so the list index is shifted by one
                int paramsIndex = forAggregate ? idx - 1 : idx;
                out.append(composeMethod(paramType))
                   .append("(protocolVersion, ")
                   .append(idx)
                   .append(", params.get(")
                   .append(paramsIndex)
                   .append("))");
            }
            idx++;
        }
        return out.toString();
    }

    /**
     * Picks the compose method to call for an argument of the given type.
     *
     * @param type Java type of the argument
     * @return {@code super.compose_<primitive name>} for primitive types, {@code super.compose} otherwise
     */
    private static String composeMethod(TypeToken<?> type)
    {
        if (!type.isPrimitive())
            return "super.compose";
        return "super.compose_" + type.getRawType().getName();
    }

    // Compiling a Java source UDF is a small, self-contained task. That allows a single class to
    // play all roles ECJ asks for: the compilation unit (source), the requestor (receives the
    // compilation result) and the name environment (type and package lookup).
    static final class EcjCompilationUnit implements ICompilationUnit, ICompilerRequestor, INameEnvironment
    {
        // problems reported by ECJ; stays null as long as no result with errors was accepted
        List<IProblem> problemList;
        private final String className;
        private final char[] sourceCode;

        EcjCompilationUnit(String sourceCode, String className)
        {
            this.className = className;
            this.sourceCode = sourceCode.toCharArray();
        }

        // joins the name segments with '.'
        private static StringBuilder joinSegments(char[][] segments)
        {
            StringBuilder joined = new StringBuilder();
            for (int idx = 0; idx < segments.length; idx++)
            {
                if (idx > 0)
                    joined.append('.');
                joined.append(segments[idx]);
            }
            return joined;
        }

        // ---- ICompilationUnit ----

        @Override
        public char[] getFileName()
        {
            // there is no file - the source text itself is handed out
            return sourceCode;
        }

        @Override
        public char[] getContents()
        {
            return sourceCode;
        }

        @Override
        public char[] getMainTypeName()
        {
            // simple name = everything after the last dot
            int lastDot = className.lastIndexOf('.');
            String simpleName = className;
            if (lastDot > 0)
                simpleName = className.substring(lastDot + 1);
            return simpleName.toCharArray();
        }

        @Override
        public char[][] getPackageName()
        {
            // all dot separated segments except the last one (the class name)
            StringTokenizer segments = new StringTokenizer(className, ".");
            int packageDepth = segments.countTokens() - 1;
            char[][] pkg = new char[packageDepth][];
            int idx = 0;
            while (idx < packageDepth)
                pkg[idx++] = segments.nextToken().toCharArray();
            return pkg;
        }

        @Override
        public boolean ignoreOptionalProblems()
        {
            return false;
        }

        // ---- ICompilerRequestor ----

        @Override
        public void acceptResult(CompilationResult result)
        {
            if (!result.hasErrors())
            {
                // hand the generated byte code over to the target class loader
                for (ClassFile classFile : result.getClassFiles())
                    targetClassLoader.addClass(className, classFile.getBytes());
                return;
            }

            // collect the problems, the caller turns them into an error message
            IProblem[] problems = result.getProblems();
            if (problemList == null)
                problemList = new ArrayList<>(problems.length);
            Collections.addAll(problemList, problems);
        }

        // ---- INameEnvironment ----

        @Override
        public NameEnvironmentAnswer findType(char[][] compoundTypeName)
        {
            return findType(joinSegments(compoundTypeName).toString());
        }

        @Override
        public NameEnvironmentAnswer findType(char[] typeName, char[][] packageName)
        {
            StringBuilder qualified = joinSegments(packageName);
            if (packageName.length > 0)
                qualified.append('.');
            qualified.append(typeName);
            return findType(qualified.toString());
        }

        // Resolves a fully qualified class name: either the UDF class being compiled or a class
        // whose .class resource is available from the UDF class loader. Returns null if unknown.
        private NameEnvironmentAnswer findType(String className)
        {
            if (className.equals(this.className))
                return new NameEnvironmentAnswer(this, null);

            String classResource = className.replace('.', '/') + ".class";

            try (InputStream is = UDFunction.udfClassLoader.getResourceAsStream(classResource))
            {
                if (is == null)
                    return null;

                byte[] classBytes = ByteStreams.toByteArray(is);
                char[] fileName = className.toCharArray();
                ClassFileReader reader = new ClassFileReader(classBytes, fileName, true);
                return new NameEnvironmentAnswer(reader, null);
            }
            catch (IOException | ClassFormatException exc)
            {
                throw new RuntimeException(exc);
            }
        }

        // A name is considered a package if it is neither the UDF class itself nor a class
        // resource known to the UDF class loader.
        private boolean isPackage(String result)
        {
            if (result.equals(this.className))
                return false;

            String classResource = result.replace('.', '/') + ".class";
            try (InputStream is = UDFunction.udfClassLoader.getResourceAsStream(classResource))
            {
                return is == null;
            }
            catch (IOException e)
            {
                // only close() can fail here, which means the resource exists - so it is a class
                return false;
            }
        }

        @Override
        public boolean isPackage(char[][] parentPackageName, char[] packageName)
        {
            boolean hasParent = parentPackageName != null && parentPackageName.length > 0;
            StringBuilder candidate = parentPackageName == null
                                      ? new StringBuilder()
                                      : joinSegments(parentPackageName);

            // a segment starting with an upper case letter below something that is not a package
            // itself (i.e. below a class) cannot be a package
            if (Character.isUpperCase(packageName[0]) && !isPackage(candidate.toString()))
                return false;

            if (hasParent)
                candidate.append('.');
            candidate.append(packageName);

            return isPackage(candidate.toString());
        }

        @Override
        public void cleanup()
        {
            // nothing to clean up
        }
    }

    /**
     * Class loader that defines the classes generated by ECJ. It delegates to
     * {@code UDFunction.udfClassLoader} as its parent and defines UDF classes inside the shared
     * {@code protectionDomain}.
     */
    static final class EcjTargetClassLoader extends SecureClassLoader
    {
        EcjTargetClassLoader()
        {
            super(UDFunction.udfClassLoader);
        }

        // This map is usually empty.
        // It only contains data *during* UDF compilation but not during runtime.
        //
        // addClass() is invoked by ECJ (via EcjCompilationUnit.acceptResult()) after successful
        // compilation of the generated Java source.
        // findClass(targetClassName) is triggered by the Class.forName() call in the
        // JavaBasedUDFunction constructor after ECJ returned from successful compilation.
        //
        private final Map<String, byte[]> classes = new ConcurrentHashMap<>();

        /**
         * Stores the byte code of a freshly compiled class until it gets loaded.
         */
        void addClass(String className, byte[] classData)
        {
            classes.put(className, classData);
        }

        /**
         * @return the stored byte code for {@code className} or {@code null} if there is none
         */
        byte[] classData(String className)
        {
            return classes.get(className);
        }

        /**
         * Defines the class from the stored byte code, if present, and otherwise asks the parent
         * class loader.
         */
        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException
        {
            // remove the class binary - it's only used once - so it's wasting heap
            byte[] classData = classes.remove(name);

            if (classData != null)
                return defineClass(name, classData, 0, classData.length, protectionDomain);

            return getParent().loadClass(name);
        }

        /**
         * @return always {@code ThreadAwareSecurityManager.noPermissions}, whatever the code source is
         */
        @Override
        protected PermissionCollection getPermissions(CodeSource codesource)
        {
            return ThreadAwareSecurityManager.noPermissions;
        }
    }
}
