// blob: ab4b86660a6ed8ca0818d9cd0c585062cf7b41c7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.impl;
import com.google.auto.value.AutoValue;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import java.nio.file.ProviderNotFoundException;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
import org.apache.beam.sdk.extensions.sql.udf.ScalarFn;
import org.apache.beam.sdk.extensions.sql.udf.UdfProvider;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.vendor.calcite.v1_28_0.org.apache.commons.codec.digest.DigestUtils;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Loads {@link UdfProvider} implementations from user-provided jars.
 *
 * <p>All UDFs are loaded and cached for each jar to mitigate IO costs. Both caches are static, so
 * they are shared by every {@code JavaUdfLoader} instance in the JVM.
 *
 * <p>NOTE(review): the caches are plain {@link HashMap}s and nothing in this class synchronizes
 * access to them; confirm that callers do not load jars from multiple threads concurrently.
 */
public class JavaUdfLoader {
  private static final Logger LOG = LoggerFactory.getLogger(JavaUdfLoader.class);

  /**
   * Maps the external jar location to the functions the jar defines. Static so it can persist
   * across multiple SQL transforms.
   */
  private static final Map<String, FunctionDefinitions> functionCache = new HashMap<>();

  /** Maps potentially remote jar paths to their local file copies. */
  private static final Map<String, File> jarCache = new HashMap<>();

  /**
   * Load a user-defined scalar function from the specified jar.
   *
   * @param functionPath the function name, split into its dot-separated components
   * @param jarPath location of the jar that should define the function
   * @return the scalar function registered under {@code functionPath}
   * @throws IllegalArgumentException if the jar defines no scalar function with that name
   * @throws RuntimeException wrapping the {@link IOException} if the jar cannot be loaded
   */
  public ScalarFn loadScalarFunction(List<String> functionPath, String jarPath) {
    String functionFullName = String.join(".", functionPath);
    try {
      FunctionDefinitions functionDefinitions = loadJar(jarPath);
      if (!functionDefinitions.scalarFunctions().containsKey(functionPath)) {
        throw new IllegalArgumentException(
            String.format(
                "No implementation of scalar function %s found in %s.%n"
                    + " 1. Create a class implementing %s and annotate it with @AutoService(%s.class).%n"
                    + " 2. Add function %s to the class's userDefinedScalarFunctions implementation.",
                functionFullName,
                jarPath,
                UdfProvider.class.getSimpleName(),
                UdfProvider.class.getSimpleName(),
                functionFullName));
      }
      return functionDefinitions.scalarFunctions().get(functionPath);
    } catch (IOException e) {
      throw new RuntimeException(
          String.format(
              "Failed to load user-defined scalar function %s from %s", functionFullName, jarPath),
          e);
    }
  }

  /**
   * Load a user-defined aggregate function from the specified jar.
   *
   * @param functionPath the function name, split into its dot-separated components
   * @param jarPath location of the jar that should define the function
   * @return the aggregate function registered under {@code functionPath}
   * @throws IllegalArgumentException if the jar defines no aggregate function with that name
   * @throws RuntimeException wrapping the {@link IOException} if the jar cannot be loaded
   */
  public AggregateFn loadAggregateFunction(List<String> functionPath, String jarPath) {
    String functionFullName = String.join(".", functionPath);
    try {
      FunctionDefinitions functionDefinitions = loadJar(jarPath);
      if (!functionDefinitions.aggregateFunctions().containsKey(functionPath)) {
        throw new IllegalArgumentException(
            String.format(
                "No implementation of aggregate function %s found in %s.%n"
                    + " 1. Create a class implementing %s and annotate it with @AutoService(%s.class).%n"
                    + " 2. Add function %s to the class's userDefinedAggregateFunctions implementation.",
                functionFullName,
                jarPath,
                UdfProvider.class.getSimpleName(),
                UdfProvider.class.getSimpleName(),
                functionFullName));
      }
      return functionDefinitions.aggregateFunctions().get(functionPath);
    } catch (IOException e) {
      throw new RuntimeException(
          String.format(
              "Failed to load user-defined aggregate function %s from %s",
              functionFullName, jarPath),
          e);
    }
  }

  /**
   * Creates a temporary local copy of the file at {@code inputPath}, and returns a handle to the
   * local copy.
   *
   * <p>If copying or checksumming fails, the partially written temporary file is deleted before
   * the failure is propagated, so failed downloads do not accumulate in the temp directory.
   *
   * @param inputPath non-empty path of the file to copy, resolved through {@link FileSystems}
   * @param mimeType MIME type passed to {@link FileSystems#create} for the local copy
   * @return the temporary file holding the copy
   * @throws IllegalArgumentException if {@code inputPath} is empty
   * @throws IOException if the source cannot be opened or the copy cannot be written or read back
   */
  private File downloadFile(String inputPath, String mimeType) throws IOException {
    Preconditions.checkArgument(!inputPath.isEmpty(), "Path cannot be empty.");
    ResourceId inputResource = FileSystems.matchNewResource(inputPath, false /* is directory */);
    try (ReadableByteChannel inputChannel = FileSystems.open(inputResource)) {
      File outputFile = File.createTempFile("sql-udf-", inputResource.getFilename());
      try {
        ResourceId outputResource =
            FileSystems.matchNewResource(outputFile.getAbsolutePath(), false /* is directory */);
        try (WritableByteChannel outputChannel = FileSystems.create(outputResource, mimeType)) {
          ByteStreams.copy(inputChannel, outputChannel);
        }
        // Compute and log checksum.
        try (InputStream inputStream = new FileInputStream(outputFile)) {
          LOG.info(
              "Copied {} to {} with md5 hash {}.",
              inputPath,
              outputFile.getAbsolutePath(),
              DigestUtils.md5Hex(inputStream));
        }
        return outputFile;
      } catch (IOException | RuntimeException e) {
        // The copy is unusable and will never be cached, so do not leave it behind.
        if (!outputFile.delete()) {
          LOG.warn("Failed to delete incomplete temporary file {}.", outputFile.getAbsolutePath());
        }
        throw e;
      }
    }
  }

  /**
   * Returns the local copy of the jar at {@code inputJarPath}, downloading it on first use and
   * reusing the cached copy afterwards.
   */
  private File getLocalJar(String inputJarPath) throws IOException {
    if (!jarCache.containsKey(inputJarPath)) {
      jarCache.put(inputJarPath, downloadFile(inputJarPath, "application/java-archive"));
    }
    return jarCache.get(inputJarPath);
  }

  /**
   * This code creates a classloader, which needs permission if a security manager is installed. If
   * this code might be invoked by code that does not have security permissions, then the
   * classloader creation needs to occur inside a doPrivileged block.
   */
  private URLClassLoader createUrlClassLoader(URL[] urls) {
    // The explicit PrivilegedAction target type selects the non-throwing doPrivileged overload.
    return AccessController.doPrivileged(
        (PrivilegedAction<URLClassLoader>) () -> new URLClassLoader(urls));
  }

  /** Creates a class loader whose only URL is the local copy of the jar at {@code inputJarPath}. */
  private ClassLoader createClassLoader(String inputJarPath) throws IOException {
    File tmpJar = getLocalJar(inputJarPath);
    URL url = tmpJar.toURI().toURL();
    return createUrlClassLoader(new URL[] {url});
  }

  /**
   * Creates a class loader over the local copies of all the given jars.
   *
   * @param inputJarPaths locations of the jars to make loadable; each is downloaded at most once
   * @return a class loader whose URLs are the local copies, in the order given
   * @throws IOException if any jar cannot be copied locally
   */
  public ClassLoader createClassLoader(List<String> inputJarPaths) throws IOException {
    List<URL> urls = new ArrayList<>();
    for (String inputJar : inputJarPaths) {
      urls.add(getLocalJar(inputJar).toURI().toURL());
    }
    return createUrlClassLoader(urls.toArray(new URL[0]));
  }

  /** Returns the {@link UdfProvider}s that {@link ServiceLoader} finds via {@code classLoader}. */
  @VisibleForTesting
  Iterator<UdfProvider> getUdfProviders(ClassLoader classLoader) throws IOException {
    return ServiceLoader.load(UdfProvider.class, classLoader).iterator();
  }

  /**
   * Adds one function to {@code registry}, keyed by its name split on {@code '.'}.
   *
   * @param registry the functions collected so far for the jar
   * @param functionName dot-separated function name as reported by the provider
   * @param implementation the function implementation to register
   * @param kind function kind used in the error message, e.g. {@code "scalar"}
   * @param jarPath jar being loaded; used only in the error message
   * @throws IllegalArgumentException if the same function path was already registered
   */
  private static <T> void registerFunction(
      Map<List<String>, T> registry,
      String functionName,
      T implementation,
      String kind,
      String jarPath) {
    List<String> functionPath = ImmutableList.copyOf(functionName.split("\\."));
    if (registry.containsKey(functionPath)) {
      throw new IllegalArgumentException(
          String.format(
              "Found multiple definitions of %s function %s in %s.", kind, functionName, jarPath));
    }
    registry.put(functionPath, implementation);
  }

  /**
   * Loads every function defined by the {@link UdfProvider}s in the jar at {@code jarPath},
   * returning the cached result if the jar was loaded before.
   *
   * @throws IllegalArgumentException if two providers in the jar define the same function
   * @throws ProviderNotFoundException if the jar contains no {@link UdfProvider} implementation
   * @throws IOException if the jar cannot be copied locally
   */
  private FunctionDefinitions loadJar(String jarPath) throws IOException {
    if (functionCache.containsKey(jarPath)) {
      LOG.debug("Using cached function definitions from {}", jarPath);
      return functionCache.get(jarPath);
    }

    ClassLoader classLoader = createClassLoader(jarPath);
    Map<List<String>, ScalarFn> scalarFunctions = new HashMap<>();
    Map<List<String>, AggregateFn> aggregateFunctions = new HashMap<>();
    Iterator<UdfProvider> providers = getUdfProviders(classLoader);
    int providersCount = 0;
    while (providers.hasNext()) {
      providersCount++;
      UdfProvider provider = providers.next();
      provider
          .userDefinedScalarFunctions()
          .forEach(
              (functionName, implementation) ->
                  registerFunction(
                      scalarFunctions, functionName, implementation, "scalar", jarPath));
      provider
          .userDefinedAggregateFunctions()
          .forEach(
              (functionName, implementation) ->
                  registerFunction(
                      aggregateFunctions, functionName, implementation, "aggregate", jarPath));
    }
    if (providersCount == 0) {
      throw new ProviderNotFoundException(
          String.format(
              "No %s implementation found in %s. Create a class implementing %s and annotate it "
                  + "with @AutoService(%s.class).",
              UdfProvider.class.getSimpleName(),
              jarPath,
              UdfProvider.class.getSimpleName(),
              UdfProvider.class.getSimpleName()));
    }
    // Report both kinds of function; aggregate functions are loaded by the same pass.
    LOG.info(
        "Loaded {} implementations of {} from {} with {} scalar function(s) and {} aggregate"
            + " function(s).",
        providersCount,
        UdfProvider.class.getSimpleName(),
        jarPath,
        scalarFunctions.size(),
        aggregateFunctions.size());
    FunctionDefinitions userFunctionDefinitions =
        FunctionDefinitions.newBuilder()
            .setScalarFunctions(ImmutableMap.copyOf(scalarFunctions))
            .setAggregateFunctions(ImmutableMap.copyOf(aggregateFunctions))
            .build();
    functionCache.put(jarPath, userFunctionDefinitions);
    return userFunctionDefinitions;
  }

  /** Holds user defined function definitions. */
  @AutoValue
  abstract static class FunctionDefinitions {
    /** Scalar functions keyed by their name split into dot-separated components. */
    abstract ImmutableMap<List<String>, ScalarFn> scalarFunctions();

    /** Aggregate functions keyed by their name split into dot-separated components. */
    abstract ImmutableMap<List<String>, AggregateFn> aggregateFunctions();

    /** Builder for {@link FunctionDefinitions}. */
    @AutoValue.Builder
    abstract static class Builder {
      abstract Builder setScalarFunctions(ImmutableMap<List<String>, ScalarFn> value);

      abstract Builder setAggregateFunctions(ImmutableMap<List<String>, AggregateFn> value);

      abstract FunctionDefinitions build();
    }

    /** Returns a builder whose scalar and aggregate function maps both start out empty. */
    static Builder newBuilder() {
      return new AutoValue_JavaUdfLoader_FunctionDefinitions.Builder()
          .setScalarFunctions(ImmutableMap.of())
          .setAggregateFunctions(ImmutableMap.of());
    }
  }
}