blob: 27f9eb8ed8804558e4281f34aa6a6be39a5ec835 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.cql3.functions;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.net.InetAddress;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import com.google.common.base.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.datastax.driver.core.DataType;
import com.datastax.driver.core.TypeCodec;
import com.datastax.driver.core.UserType;
import org.apache.cassandra.config.Config;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.config.Schema;
import org.apache.cassandra.cql3.ColumnIdentifier;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.exceptions.FunctionExecutionException;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.schema.Functions;
import org.apache.cassandra.schema.KeyspaceMetadata;
import org.apache.cassandra.service.ClientWarn;
import org.apache.cassandra.service.MigrationManager;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.JVMStabilityInspector;
/**
* Base class for User Defined Functions.
*/
public abstract class UDFunction extends AbstractFunction implements ScalarFunction
{
protected static final Logger logger = LoggerFactory.getLogger(UDFunction.class);
static final ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
protected final List<ColumnIdentifier> argNames;
protected final String language;
protected final String body;
protected final TypeCodec<Object>[] argCodecs;
protected final TypeCodec<Object> returnCodec;
protected final boolean calledOnNullInput;
//
// Access to classes is controlled via allow and disallow lists.
//
// When a class is requested (both during compilation and runtime),
// the allowedPatterns array is searched first, whether the
// requested name matches one of the patterns. If not, nothing is
// returned from the class-loader - meaning ClassNotFoundException
// during runtime and "type could not resolved" during compilation.
//
// If an allowed pattern has been found, the disallowedPatterns
// array is searched for a match. If a match is found, class-loader
// rejects access. Otherwise the class/resource can be loaded.
//
private static final String[] allowedPatterns =
{
"com/datastax/driver/core/",
"com/google/common/reflect/TypeToken",
"java/io/IOException.class",
"java/io/Serializable.class",
"java/lang/",
"java/math/",
"java/net/InetAddress.class",
"java/net/Inet4Address.class",
"java/net/Inet6Address.class",
"java/net/UnknownHostException.class", // req'd by InetAddress
"java/net/NetworkInterface.class", // req'd by InetAddress
"java/net/SocketException.class", // req'd by InetAddress
"java/nio/Buffer.class",
"java/nio/ByteBuffer.class",
"java/text/",
"java/time/",
"java/util/",
"org/apache/cassandra/cql3/functions/JavaUDF.class",
"org/apache/cassandra/exceptions/",
};
// Only need to disallow a pattern, if it would otherwise be allowed via allowedPatterns
private static final String[] disallowedPatterns =
{
"com/datastax/driver/core/Cluster.class",
"com/datastax/driver/core/Metrics.class",
"com/datastax/driver/core/NettyOptions.class",
"com/datastax/driver/core/Session.class",
"com/datastax/driver/core/Statement.class",
"com/datastax/driver/core/TimestampGenerator.class", // indirectly covers ServerSideTimestampGenerator + ThreadLocalMonotonicTimestampGenerator
"java/lang/Compiler.class",
"java/lang/InheritableThreadLocal.class",
"java/lang/Package.class",
"java/lang/Process.class",
"java/lang/ProcessBuilder.class",
"java/lang/ProcessEnvironment.class",
"java/lang/ProcessImpl.class",
"java/lang/Runnable.class",
"java/lang/Runtime.class",
"java/lang/Shutdown.class",
"java/lang/Thread.class",
"java/lang/ThreadGroup.class",
"java/lang/ThreadLocal.class",
"java/lang/instrument/",
"java/lang/invoke/",
"java/lang/management/",
"java/lang/ref/",
"java/lang/reflect/",
"java/util/ServiceLoader.class",
"java/util/Timer.class",
"java/util/concurrent/",
"java/util/function/",
"java/util/jar/",
"java/util/logging/",
"java/util/prefs/",
"java/util/spi/",
"java/util/stream/",
"java/util/zip/",
};
static boolean secureResource(String resource)
{
while (resource.startsWith("/"))
resource = resource.substring(1);
for (String allowed : allowedPatterns)
if (resource.startsWith(allowed))
{
// resource is in allowedPatterns, let's see if it is not explicitly disallowed
for (String disallowed : disallowedPatterns)
if (resource.startsWith(disallowed))
{
logger.trace("access denied: resource {}", resource);
return false;
}
return true;
}
logger.trace("access denied: resource {}", resource);
return false;
}
// setup the UDF class loader with no parent class loader so that we have full control about what class/resource UDF uses
static final ClassLoader udfClassLoader = new UDFClassLoader();
protected UDFunction(FunctionName name,
List<ColumnIdentifier> argNames,
List<AbstractType<?>> argTypes,
AbstractType<?> returnType,
boolean calledOnNullInput,
String language,
String body)
{
this(name, argNames, argTypes, UDHelper.driverTypes(argTypes), returnType,
UDHelper.driverType(returnType), calledOnNullInput, language, body);
}
protected UDFunction(FunctionName name,
List<ColumnIdentifier> argNames,
List<AbstractType<?>> argTypes,
DataType[] argDataTypes,
AbstractType<?> returnType,
DataType returnDataType,
boolean calledOnNullInput,
String language,
String body)
{
super(name, argTypes, returnType);
assert new HashSet<>(argNames).size() == argNames.size() : "duplicate argument names";
this.argNames = argNames;
this.language = language;
this.body = body;
this.argCodecs = UDHelper.codecsFor(argDataTypes);
this.returnCodec = UDHelper.codecFor(returnDataType);
this.calledOnNullInput = calledOnNullInput;
}
public static UDFunction create(FunctionName name,
List<ColumnIdentifier> argNames,
List<AbstractType<?>> argTypes,
AbstractType<?> returnType,
boolean calledOnNullInput,
String language,
String body)
{
UDFunction.assertUdfsEnabled(language);
switch (language)
{
case "java":
return new JavaBasedUDFunction(name, argNames, argTypes, returnType, calledOnNullInput, body);
default:
return new ScriptBasedUDFunction(name, argNames, argTypes, returnType, calledOnNullInput, language, body);
}
}
/**
* It can happen that a function has been declared (is listed in the scheam) but cannot
* be loaded (maybe only on some nodes). This is the case for instance if the class defining
* the class is not on the classpath for some of the node, or after a restart. In that case,
* we create a "fake" function so that:
* 1) the broken function can be dropped easily if that is what people want to do.
* 2) we return a meaningful error message if the function is executed (something more precise
* than saying that the function doesn't exist)
*/
public static UDFunction createBrokenFunction(FunctionName name,
List<ColumnIdentifier> argNames,
List<AbstractType<?>> argTypes,
AbstractType<?> returnType,
boolean calledOnNullInput,
String language,
String body,
InvalidRequestException reason)
{
return new UDFunction(name, argNames, argTypes, returnType, calledOnNullInput, language, body)
{
protected ExecutorService executor()
{
return Executors.newSingleThreadExecutor();
}
public ByteBuffer executeUserDefined(int protocolVersion, List<ByteBuffer> parameters)
{
throw new InvalidRequestException(String.format("Function '%s' exists but hasn't been loaded successfully "
+ "for the following reason: %s. Please see the server log for details",
this,
reason.getMessage()));
}
};
}
public final ByteBuffer execute(int protocolVersion, List<ByteBuffer> parameters)
{
assertUdfsEnabled(language);
if (!isCallableWrtNullable(parameters))
return null;
long tStart = System.nanoTime();
parameters = makeEmptyParametersNull(parameters);
try
{
// Using async UDF execution is expensive (adds about 100us overhead per invocation on a Core-i7 MBPr).
ByteBuffer result = DatabaseDescriptor.enableUserDefinedFunctionsThreads()
? executeAsync(protocolVersion, parameters)
: executeUserDefined(protocolVersion, parameters);
Tracing.trace("Executed UDF {} in {}\u03bcs", name(), (System.nanoTime() - tStart) / 1000);
return result;
}
catch (InvalidRequestException e)
{
throw e;
}
catch (Throwable t)
{
logger.trace("Invocation of user-defined function '{}' failed", this, t);
if (t instanceof VirtualMachineError)
throw (VirtualMachineError) t;
throw FunctionExecutionException.create(this, t);
}
}
public static void assertUdfsEnabled(String language)
{
if (!DatabaseDescriptor.enableUserDefinedFunctions())
throw new InvalidRequestException("User-defined functions are disabled in cassandra.yaml - set enable_user_defined_functions=true to enable");
if (!"java".equalsIgnoreCase(language) && !DatabaseDescriptor.enableScriptedUserDefinedFunctions())
throw new InvalidRequestException("Scripted user-defined functions are disabled in cassandra.yaml - set enable_scripted_user_defined_functions=true to enable if you are aware of the security risks");
}
static void initializeThread()
{
// Get the TypeCodec stuff in Java Driver initialized.
// This is to get the classes loaded outside of the restricted sandbox's security context of a UDF.
TypeCodec.inet().format(InetAddress.getLoopbackAddress());
TypeCodec.ascii().format("");
}
private static final class ThreadIdAndCpuTime extends CompletableFuture<Object>
{
long threadId;
long cpuTime;
ThreadIdAndCpuTime()
{
// Looks weird?
// This call "just" links this class to java.lang.management - otherwise UDFs (script UDFs) might fail due to
// java.security.AccessControlException: access denied: ("java.lang.RuntimePermission" "accessClassInPackage.java.lang.management")
// because class loading would be deferred until setup() is executed - but setup() is called with
// limited privileges.
threadMXBean.getCurrentThreadCpuTime();
}
void setup()
{
this.threadId = Thread.currentThread().getId();
this.cpuTime = threadMXBean.getCurrentThreadCpuTime();
complete(null);
}
}
private ByteBuffer executeAsync(int protocolVersion, List<ByteBuffer> parameters)
{
ThreadIdAndCpuTime threadIdAndCpuTime = new ThreadIdAndCpuTime();
Future<ByteBuffer> future = executor().submit(() -> {
threadIdAndCpuTime.setup();
return executeUserDefined(protocolVersion, parameters);
});
try
{
if (DatabaseDescriptor.getUserDefinedFunctionWarnTimeout() > 0)
try
{
return future.get(DatabaseDescriptor.getUserDefinedFunctionWarnTimeout(), TimeUnit.MILLISECONDS);
}
catch (TimeoutException e)
{
// log and emit a warning that UDF execution took long
String warn = String.format("User defined function %s ran longer than %dms", this, DatabaseDescriptor.getUserDefinedFunctionWarnTimeout());
logger.warn(warn);
ClientWarn.instance.warn(warn);
}
// retry with difference of warn-timeout to fail-timeout
return future.get(DatabaseDescriptor.getUserDefinedFunctionFailTimeout() - DatabaseDescriptor.getUserDefinedFunctionWarnTimeout(), TimeUnit.MILLISECONDS);
}
catch (InterruptedException e)
{
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
catch (ExecutionException e)
{
Throwable c = e.getCause();
if (c instanceof RuntimeException)
throw (RuntimeException) c;
throw new RuntimeException(c);
}
catch (TimeoutException e)
{
// retry a last time with the difference of UDF-fail-timeout to consumed CPU time (just in case execution hit a badly timed GC)
try
{
//The threadIdAndCpuTime shouldn't take a long time to be set so this should return immediately
threadIdAndCpuTime.get(1, TimeUnit.SECONDS);
long cpuTimeMillis = threadMXBean.getThreadCpuTime(threadIdAndCpuTime.threadId) - threadIdAndCpuTime.cpuTime;
cpuTimeMillis /= 1000000L;
return future.get(Math.max(DatabaseDescriptor.getUserDefinedFunctionFailTimeout() - cpuTimeMillis, 0L),
TimeUnit.MILLISECONDS);
}
catch (InterruptedException e1)
{
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
catch (ExecutionException e1)
{
Throwable c = e.getCause();
if (c instanceof RuntimeException)
throw (RuntimeException) c;
throw new RuntimeException(c);
}
catch (TimeoutException e1)
{
TimeoutException cause = new TimeoutException(String.format("User defined function %s ran longer than %dms%s",
this,
DatabaseDescriptor.getUserDefinedFunctionFailTimeout(),
DatabaseDescriptor.getUserFunctionTimeoutPolicy() == Config.UserFunctionTimeoutPolicy.ignore
? "" : " - will stop Cassandra VM"));
FunctionExecutionException fe = FunctionExecutionException.create(this, cause);
JVMStabilityInspector.userFunctionTimeout(cause);
throw fe;
}
}
}
private List<ByteBuffer> makeEmptyParametersNull(List<ByteBuffer> parameters)
{
List<ByteBuffer> r = new ArrayList<>(parameters.size());
for (int i = 0; i < parameters.size(); i++)
{
ByteBuffer param = parameters.get(i);
r.add(UDHelper.isNullOrEmpty(argTypes.get(i), param)
? null : param);
}
return r;
}
protected abstract ExecutorService executor();
public boolean isCallableWrtNullable(List<ByteBuffer> parameters)
{
if (!calledOnNullInput)
for (int i = 0; i < parameters.size(); i++)
if (UDHelper.isNullOrEmpty(argTypes.get(i), parameters.get(i)))
return false;
return true;
}
protected abstract ByteBuffer executeUserDefined(int protocolVersion, List<ByteBuffer> parameters);
public boolean isAggregate()
{
return false;
}
public boolean isNative()
{
return false;
}
public boolean isCalledOnNullInput()
{
return calledOnNullInput;
}
public List<ColumnIdentifier> argNames()
{
return argNames;
}
public String body()
{
return body;
}
public String language()
{
return language;
}
/**
* Used by UDF implementations (both Java code generated by {@link JavaBasedUDFunction}
* and script executor {@link ScriptBasedUDFunction}) to convert the C*
* serialized representation to the Java object representation.
*
* @param protocolVersion the native protocol version used for serialization
* @param argIndex index of the UDF input argument
*/
protected Object compose(int protocolVersion, int argIndex, ByteBuffer value)
{
return compose(argCodecs, protocolVersion, argIndex, value);
}
protected static Object compose(TypeCodec<Object>[] codecs, int protocolVersion, int argIndex, ByteBuffer value)
{
return value == null ? null : UDHelper.deserialize(codecs[argIndex], protocolVersion, value);
}
/**
* Used by UDF implementations (both Java code generated by {@link JavaBasedUDFunction}
* and script executor {@link ScriptBasedUDFunction}) to convert the Java
* object representation for the return value to the C* serialized representation.
*
* @param protocolVersion the native protocol version used for serialization
*/
protected ByteBuffer decompose(int protocolVersion, Object value)
{
return decompose(returnCodec, protocolVersion, value);
}
protected static ByteBuffer decompose(TypeCodec<Object> codec, int protocolVersion, Object value)
{
return value == null ? null : UDHelper.serialize(codec, protocolVersion, value);
}
@Override
public boolean equals(Object o)
{
if (!(o instanceof UDFunction))
return false;
UDFunction that = (UDFunction)o;
return Objects.equal(name, that.name)
&& Objects.equal(argNames, that.argNames)
&& Functions.typesMatch(argTypes, that.argTypes)
&& Functions.typesMatch(returnType, that.returnType)
&& Objects.equal(language, that.language)
&& Objects.equal(body, that.body);
}
@Override
public int hashCode()
{
return Objects.hashCode(name, Functions.typeHashCode(argTypes), Functions.typeHashCode(returnType), returnType, language, body);
}
public void userTypeUpdated(String ksName, String typeName)
{
boolean updated = false;
for (int i = 0; i < argCodecs.length; i++)
{
DataType dataType = argCodecs[i].getCqlType();
if (dataType instanceof UserType)
{
UserType userType = (UserType) dataType;
if (userType.getKeyspace().equals(ksName) && userType.getTypeName().equals(typeName))
{
KeyspaceMetadata ksm = Schema.instance.getKSMetaData(ksName);
assert ksm != null;
org.apache.cassandra.db.marshal.UserType ut = ksm.types.get(ByteBufferUtil.bytes(typeName)).get();
DataType newUserType = UDHelper.driverType(ut);
argCodecs[i] = UDHelper.codecFor(newUserType);
argTypes.set(i, ut);
updated = true;
}
}
}
if (updated)
MigrationManager.announceNewFunction(this, true);
}
private static class UDFClassLoader extends ClassLoader
{
// insecureClassLoader is the C* class loader
static final ClassLoader insecureClassLoader = Thread.currentThread().getContextClassLoader();
public URL getResource(String name)
{
if (!secureResource(name))
return null;
return insecureClassLoader.getResource(name);
}
protected URL findResource(String name)
{
return getResource(name);
}
public Enumeration<URL> getResources(String name)
{
return Collections.emptyEnumeration();
}
protected Class<?> findClass(String name) throws ClassNotFoundException
{
if (!secureResource(name.replace('.', '/') + ".class"))
throw new ClassNotFoundException(name);
return insecureClassLoader.loadClass(name);
}
public Class<?> loadClass(String name) throws ClassNotFoundException
{
if (!secureResource(name.replace('.', '/') + ".class"))
throw new ClassNotFoundException(name);
return super.loadClass(name);
}
}
}