blob: d32c7e4207a8dec47cd26ba25eee7d3ca7bfe500 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.expansion.service;
import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.auto.value.AutoValue;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod;
import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods;
import org.apache.beam.model.pipeline.v1.ExternalTransforms.JavaClassLookupPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.SchemaApi;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils;
import org.apache.beam.sdk.coders.RowCoder;
import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider;
import org.apache.beam.sdk.schemas.JavaFieldSchema;
import org.apache.beam.sdk.schemas.NoSuchSchemaException;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.TypeName;
import org.apache.beam.sdk.schemas.SchemaRegistry;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.common.ReflectHelpers;
import org.apache.beam.sdk.values.PInput;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException;
import org.checkerframework.checker.nullness.qual.Nullable;
/**
* A transform provider that can be used to directly instantiate a transform using Java class name
* and builder methods.
*
* @param <InputT> input {@link PInput} type of the transform
* @param <OutputT> output {@link POutput} type of the transform
*/
@SuppressWarnings({"argument.type.incompatible", "assignment.type.incompatible"})
@SuppressFBWarnings("UWF_UNWRITTEN_PUBLIC_OR_PROTECTED_FIELD")
class JavaClassLookupTransformProvider<InputT extends PInput, OutputT extends POutput>
implements TransformProvider<PInput, POutput> {
public static final String ALLOW_LIST_VERSION = "v1";
private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault();
private final AllowList allowList;
public JavaClassLookupTransformProvider(AllowList allowList) {
if (!allowList.getVersion().equals(ALLOW_LIST_VERSION)) {
throw new IllegalArgumentException("Unknown allow-list version");
}
this.allowList = allowList;
}
/**
 * Builds the transform described by a {@link JavaClassLookupPayload}.
 *
 * <p>The payload names a class, an optional static constructor method, the constructor
 * arguments (as an encoded {@link Row}) and a list of builder methods to apply afterwards. The
 * class must be present in the allow-list of this provider.
 *
 * @param spec function spec whose payload is a serialized {@link JavaClassLookupPayload}
 * @return the instantiated transform with all builder methods of the payload applied
 * @throws IllegalArgumentException if the payload cannot be parsed, the class cannot be found, or
 *     the transform cannot be instantiated
 * @throws UnsupportedOperationException if the allow-list does not contain the requested class
 */
@Override
public PTransform<PInput, POutput> getTransform(FunctionSpec spec) {
  final JavaClassLookupPayload payload;
  try {
    payload = JavaClassLookupPayload.parseFrom(spec.getPayload());
  } catch (InvalidProtocolBufferException e) {
    throw new IllegalArgumentException(
        "Invalid payload type for URN " + getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP), e);
  }

  final String className = payload.getClassName();
  try {
    AllowedClass permittedClass = lookupAllowedClass(className);
    if (permittedClass == null) {
      throw new UnsupportedOperationException(
          "The provided allow list does not enable expanding a transform class by the name "
              + className
              + ".");
    }

    Class<PTransform<InputT, OutputT>> transformClass =
        (Class<PTransform<InputT, OutputT>>)
            ReflectHelpers.findClassLoader().loadClass(className);
    Row constructorRow =
        decodeRow(payload.getConstructorSchema(), payload.getConstructorPayload());

    Object instance;
    if (!payload.getConstructorMethod().isEmpty()) {
      // A named constructor method was requested: invoke it as a static factory.
      Method factory =
          findMappingConstructorMethod(transformClass.getMethods(), payload, permittedClass);
      Object[] arguments =
          getParameterValues(
              factory.getParameters(), constructorRow, factory.getGenericParameterTypes());
      instance = factory.invoke(null /* static */, arguments);
    } else {
      // No constructor method given: pick the public constructor matching the payload.
      Constructor<PTransform<InputT, OutputT>> constructor =
          findMappingConstructor(transformClass.getConstructors(), payload);
      Object[] arguments =
          getParameterValues(
              constructor.getParameters(),
              constructorRow,
              constructor.getGenericParameterTypes());
      instance = constructor.newInstance(arguments);
    }

    return applyBuilderMethods((PTransform<PInput, POutput>) instance, payload, permittedClass);
  } catch (ClassNotFoundException e) {
    throw new IllegalArgumentException("Could not find class " + className, e);
  } catch (InstantiationException
      | IllegalArgumentException
      | IllegalAccessException
      | InvocationTargetException e) {
    throw new IllegalArgumentException("Could not instantiate class " + className, e);
  }
}

/**
 * Returns the single allow-list entry whose class name equals {@code className}, or {@code null}
 * when there is no allow-list or no entry matches.
 *
 * @throws IllegalArgumentException if more than one entry matches
 */
private @Nullable AllowedClass lookupAllowedClass(String className) {
  if (this.allowList == null) {
    return null;
  }
  AllowedClass found = null;
  for (AllowedClass candidate : this.allowList.getAllowedClasses()) {
    if (!candidate.getClassName().equals(className)) {
      continue;
    }
    if (found != null) {
      throw new IllegalArgumentException(
          "Found two matching allowlist classes " + found + " and " + candidate);
    }
    found = candidate;
  }
  return found;
}
/**
 * Applies the builder methods listed in the payload, in order, to the given transform.
 *
 * <p>Each builder method is invoked on the result of the previous one, so the returned value is
 * the result of the last invocation (or the input transform when the payload lists none).
 *
 * @param transform the freshly constructed transform
 * @param payload payload holding the builder methods and their encoded arguments
 * @param allowListClass allow-list entry used to authorize each builder method
 * @return the transform produced by the last builder method
 * @throws IllegalArgumentException if a builder method cannot be invoked
 */
private PTransform<PInput, POutput> applyBuilderMethods(
    PTransform<PInput, POutput> transform,
    JavaClassLookupPayload payload,
    AllowedClass allowListClass) {
  PTransform<PInput, POutput> current = transform;
  for (BuilderMethod step : payload.getBuilderMethodsList()) {
    Method builder = getMethod(current, step, allowListClass);
    try {
      Row argumentRow = decodeRow(step.getSchema(), step.getPayload());
      Object[] arguments =
          getParameterValues(
              builder.getParameters(), argumentRow, builder.getGenericParameterTypes());
      current = (PTransform<PInput, POutput>) builder.invoke(current, arguments);
    } catch (IllegalAccessException | InvocationTargetException e) {
      throw new IllegalArgumentException(
          "Could not invoke the builder method "
              + step
              + " on transform "
              + current
              + " with parameter schema "
              + step.getSchema(),
          e);
    }
  }
  return current;
}
/**
 * Decides whether {@code method} is the builder method referred to by {@code nameFromPayload}.
 *
 * <p>A method matches when (a) it carries a {@link MultiLanguageBuilderMethod} annotation with
 * that name, (b) its Java name equals the payload name, or (c) its Java name has the form
 * {@code withXyz} and the payload name is {@code xyz}.
 *
 * @param method candidate method obtained through reflection
 * @param nameFromPayload builder method name requested by the payload
 * @param allowListClass allow-list entry of the transform class
 * @return {@code true} if the method matches and is allowed, {@code false} if it does not match
 * @throws RuntimeException if the method matches but is not listed in the allowed builder methods
 */
private boolean isBuilderMethodForName(
    Method method, String nameFromPayload, AllowedClass allowListClass) {
  // Lookup based on method annotations
  for (Annotation annotation : method.getAnnotations()) {
    if (annotation instanceof MultiLanguageBuilderMethod) {
      if (nameFromPayload.equals(((MultiLanguageBuilderMethod) annotation).name())) {
        if (allowListClass.getAllowedBuilderMethods().contains(nameFromPayload)) {
          return true;
        } else {
          throw new RuntimeException(
              "Builder method " + nameFromPayload + " has to be explicitly allowed");
        }
      }
    }
  }
  // Lookup based on the method name.
  boolean match = method.getName().equals(nameFromPayload);
  String consideredMethodName = method.getName();
  // We provide a simplification for common Java builder pattern naming convention where builder
  // methods start with "with". In this case, for a builder method name in the form "withXyz",
  // users may just use "xyz". If additional updates to the method name are needed the transform
  // has to be updated by adding annotations.
  if (!match && consideredMethodName.length() > 4 && consideredMethodName.startsWith("with")) {
    // Lower-case with Locale.ROOT: Java identifiers must not be folded with the JVM default
    // locale (under a Turkish locale "I" becomes a dotless "ı", so "withInput" would never match
    // "input").
    consideredMethodName =
        consideredMethodName.substring(4, 5).toLowerCase(java.util.Locale.ROOT)
            + consideredMethodName.substring(5);
    match = consideredMethodName.equals(nameFromPayload);
  }
  if (match && !allowListClass.getAllowedBuilderMethods().contains(consideredMethodName)) {
    throw new RuntimeException(
        "Builder method name " + consideredMethodName + " has to be explicitly allowed");
  }
  return match;
}
/**
 * Finds the unique public method of the transform that corresponds to the given builder method.
 *
 * <p>A candidate must match by name (and be allowed), accept the parameters described by the
 * builder method's row, and declare a return type assignable to {@link PTransform}.
 *
 * @param transform transform whose class is searched
 * @param builderMethod builder method description from the payload
 * @param allowListClass allow-list entry of the transform class
 * @return the single matching method
 * @throws RuntimeException if zero or more than one method matches
 */
private Method getMethod(
    PTransform<PInput, POutput> transform,
    BuilderMethod builderMethod,
    AllowedClass allowListClass) {
  Row builderMethodRow = decodeRow(builderMethod.getSchema(), builderMethod.getPayload());

  List<Method> candidates = new ArrayList<>();
  for (Method candidate : transform.getClass().getMethods()) {
    if (!isBuilderMethodForName(candidate, builderMethod.getName(), allowListClass)) {
      continue;
    }
    if (!parametersCompatible(candidate.getParameters(), builderMethodRow)) {
      continue;
    }
    if (!PTransform.class.isAssignableFrom(candidate.getReturnType())) {
      continue;
    }
    candidates.add(candidate);
  }

  if (candidates.size() == 1) {
    return candidates.get(0);
  }
  throw new RuntimeException(
      "Expected to find exactly one matching method in transform "
          + transform
          + " for BuilderMethod"
          + builderMethod
          + " but found "
          + candidates.size());
}
/**
 * Returns whether {@code type} is a Java primitive, a primitive wrapper, or {@link String}.
 */
private static boolean isPrimitiveOrWrapperOrString(java.lang.Class<?> type) {
  if (ClassUtils.isPrimitiveOrWrapper(type)) {
    return true;
  }
  return String.class == type;
}
/**
 * Looks up the schema of a parameter class in the shared schema registry.
 *
 * <p>If the registry has no schema for the class, a {@link JavaFieldSchema} provider is
 * registered for it and the lookup is retried.
 *
 * @param parameterClass class of a constructor or method parameter
 * @return the schema the registry reports for the class
 * @throws RuntimeException if no schema is available even after registering the provider, or if
 *     the schema obtained that way has no fields
 */
private Schema getParameterSchema(Class<?> parameterClass) {
  try {
    return SCHEMA_REGISTRY.getSchema(parameterClass);
  } catch (NoSuchSchemaException ignored) {
    // Not registered yet; fall through and infer a schema from the class' fields.
  }

  SCHEMA_REGISTRY.registerSchemaProvider(parameterClass, new JavaFieldSchema());
  final Schema inferredSchema;
  try {
    inferredSchema = SCHEMA_REGISTRY.getSchema(parameterClass);
  } catch (NoSuchSchemaException e1) {
    throw new RuntimeException(e1);
  }
  if (inferredSchema != null && inferredSchema.getFieldCount() == 0) {
    throw new RuntimeException(
        "Could not determine a valid schema for parameter class " + parameterClass);
  }
  return inferredSchema;
}
/**
 * Checks, on a best-effort basis, whether reflected parameters can be populated from a row.
 *
 * <p>The parameter count must equal the field count of the row. Parameter names are compared
 * with the field names unless the reflected name starts with {@code "arg"}. Primitive, wrapper
 * and {@link String} parameters are not validated further; for arrays of other types and for row
 * valued fields the row schema must be assignable to the schema of the Java type.
 *
 * @param methodParameters parameters of the candidate constructor or method
 * @param constructorRow decoded argument row from the payload
 * @return {@code true} if no incompatibility was detected
 * @throws RuntimeException if an array parameter is paired with a non-array field, or if no
 *     schema can be determined for a parameter type
 */
private boolean parametersCompatible(
    java.lang.reflect.Parameter[] methodParameters, Row constructorRow) {
  Schema constructorSchema = constructorRow.getSchema();
  if (methodParameters.length != constructorSchema.getFieldCount()) {
    return false;
  }
  for (int i = 0; i < methodParameters.length; i++) {
    java.lang.reflect.Parameter parameterFromReflection = methodParameters[i];
    Field parameterFromPayload = constructorSchema.getField(i);
    String paramNameFromReflection = parameterFromReflection.getName();
    if (!paramNameFromReflection.startsWith("arg")
        && !paramNameFromReflection.equals(parameterFromPayload.getName())) {
      // Parameter name through reflection is from the class file (not through synthesizing,
      // hence we can validate names)
      return false;
    }
    Class<?> parameterClass = parameterFromReflection.getType();
    if (isPrimitiveOrWrapperOrString(parameterClass)) {
      continue;
    }
    // We perform additional validation for arrays and non-primitive types.
    if (parameterClass.isArray()) {
      Class<?> arrayFieldClass = parameterClass.getComponentType();
      if (parameterFromPayload.getType().getTypeName() != TypeName.ARRAY) {
        throw new RuntimeException(
            "Expected a schema with a single array field but received "
                + parameterFromPayload.getType().getTypeName());
      }
      // Following is a best-effort validation that may not cover all cases. Idea is to resolve
      // ambiguities as much as possible to determine an exact match for the given set of
      // parameters. If there are ambiguities, the expansion will fail.
      if (!isPrimitiveOrWrapperOrString(arrayFieldClass)) {
        @Nullable Collection<Row> values = constructorRow.getArray(i);
        Schema arrayFieldSchema = getParameterSchema(arrayFieldClass);
        if (arrayFieldSchema == null) {
          throw new RuntimeException("Could not determine a schema for type " + arrayFieldClass);
        }
        // Only the first element is inspected. An empty array has no element to inspect, so it
        // is treated as compatible instead of failing with NoSuchElementException.
        if (values != null && !values.isEmpty()) {
          @Nullable Row firstItem = values.iterator().next();
          if (firstItem != null && !(firstItem.getSchema().assignableTo(arrayFieldSchema))) {
            return false;
          }
        }
      }
    } else if (constructorRow.getValue(i) instanceof Row) {
      @Nullable Row parameterRow = constructorRow.getRow(i);
      Schema schema = getParameterSchema(parameterClass);
      if (schema == null) {
        throw new RuntimeException("Could not determine a schema for type " + parameterClass);
      }
      if (parameterRow != null && !parameterRow.getSchema().assignableTo(schema)) {
        return false;
      }
    }
  }
  return true;
}
/**
 * Converts a value read from a {@link Row} into an object usable as an argument of the given
 * Java type.
 *
 * <ul>
 *   <li>{@code null} is passed through for non-primitive types.
 *   <li>Primitive, wrapper and {@link String} values are returned unchanged.
 *   <li>Arrays are converted element by element into an array of the component type.
 *   <li>Collections are converted element by element into a {@link List}; the element type is
 *       taken from the first type argument of {@code genericType}.
 *   <li>{@link Row} values are converted through the from-row function of the schema registry.
 * </ul>
 *
 * @param type declared Java type of the target parameter or element
 * @param valueFromRow value obtained from the row, may be {@code null}
 * @param genericType generic declaration of the target, needed only for collections
 * @return the decoded value, or {@code null} if {@code valueFromRow} is {@code null}
 * @throws IllegalArgumentException if {@code null} is supplied for a primitive type, if a
 *     non-primitive value is supplied for a primitive/wrapper/String type, or if no from-row
 *     function exists for {@code type}
 * @throws RuntimeException if the element type of a non-empty collection cannot be determined or
 *     the value cannot be decoded
 */
private @Nullable Object getDecodedValueFromRow(
    Class<?> type, @Nullable Object valueFromRow, @Nullable Type genericType) {
  if (valueFromRow == null) {
    // Nullable fields legitimately carry null; only primitives cannot represent it.
    if (type.isPrimitive()) {
      throw new IllegalArgumentException(
          "Cannot use a null value for a parameter of primitive type " + type);
    }
    return null;
  }
  if (isPrimitiveOrWrapperOrString(type)) {
    if (!isPrimitiveOrWrapperOrString(valueFromRow.getClass())) {
      throw new IllegalArgumentException(
          "Expected a Java primitive value but received " + valueFromRow);
    }
    return valueFromRow;
  } else if (type.isArray()) {
    Class<?> arrayComponentClass = type.getComponentType();
    return getDecodedArrayValueFromRow(arrayComponentClass, valueFromRow);
  } else if (Collection.class.isAssignableFrom(type)) {
    List<Object> originalList = (List<Object>) valueFromRow;
    List<Object> decodedList = new ArrayList<>(originalList.size());
    if (originalList.isEmpty()) {
      // Nothing to decode, so the element type is not needed.
      return decodedList;
    }
    if (!(genericType instanceof ParameterizedType)) {
      throw new RuntimeException("Could not determine the generic type of the list");
    }
    Type elementType = ((ParameterizedType) genericType).getActualTypeArguments()[0];
    Class<?> elementClass;
    if (elementType instanceof Class) {
      elementClass = (Class<?>) elementType;
    } else if (elementType instanceof ParameterizedType) {
      // Nested generics such as List<List<String>>: decode against the raw type and keep the
      // parameterized type so that inner collections can resolve their own element type.
      elementClass = (Class<?>) ((ParameterizedType) elementType).getRawType();
    } else {
      throw new RuntimeException("Could not determine the generic type of the list");
    }
    for (Object obj : originalList) {
      decodedList.add(getDecodedValueFromRow(elementClass, obj, elementType));
    }
    return decodedList;
  } else if (valueFromRow instanceof Row) {
    Row row = (Row) valueFromRow;
    SerializableFunction<Row, ?> fromRowFunc;
    try {
      fromRowFunc = SCHEMA_REGISTRY.getFromRowFunction(type);
    } catch (NoSuchSchemaException e) {
      throw new IllegalArgumentException(
          "Could not determine the row function for class " + type, e);
    }
    return fromRowFunc.apply(row);
  }
  throw new RuntimeException("Could not decode the value from Row " + valueFromRow);
}
/**
 * Decodes one argument per reflected parameter from the given row.
 *
 * <p>The field at index {@code i} of the row is decoded against the type of parameter {@code i}
 * and its generic type.
 *
 * @param parameters parameters of the constructor or method to be invoked
 * @param constrtuctorRow row holding the encoded argument values, in parameter order
 * @param genericTypes generic parameter types, parallel to {@code parameters}
 * @return the decoded arguments, in parameter order
 */
private Object[] getParameterValues(
    java.lang.reflect.Parameter[] parameters, Row constrtuctorRow, Type[] genericTypes) {
  Object[] decodedArguments = new Object[parameters.length];
  int index = 0;
  for (java.lang.reflect.Parameter parameter : parameters) {
    decodedArguments[index] =
        getDecodedValueFromRow(
            parameter.getType(), constrtuctorRow.getValue(index), genericTypes[index]);
    index++;
  }
  return decodedArguments;
}
/**
 * Converts a list value read from a row into a Java array of the given component type.
 *
 * <p>The result is declared as {@link Object} rather than {@code Object[]} because arrays of
 * primitives (for example {@code int[]}) are not instances of {@code Object[]}.
 *
 * @param arrayComponentType component type of the array to create; may be a primitive type
 * @param valueFromRow list of element values read from the row
 * @return a new array whose runtime component type is {@code arrayComponentType}
 * @throws ClassCastException if a decoded element is not an instance of a non-primitive
 *     component type
 */
private Object getDecodedArrayValueFromRow(Class<?> arrayComponentType, Object valueFromRow) {
  List<Object> originalValues = (List<Object>) valueFromRow;
  List<Object> decodedValues = new ArrayList<>();
  for (Object obj : originalValues) {
    decodedValues.add(getDecodedValueFromRow(arrayComponentType, obj, null));
  }
  // We have to construct and return an array of the correct type. Otherwise Java reflection
  // constructor/method invocations that use the returned value may consider the array as varargs
  // (different parameters).
  Object valueTypeArray = Array.newInstance(arrayComponentType, decodedValues.size());
  boolean primitiveComponent = arrayComponentType.isPrimitive();
  for (int i = 0; i < decodedValues.size(); i++) {
    Object decoded = decodedValues.get(i);
    // Class.cast always fails for primitive classes (a boxed value is never an "instance" of
    // int.class); Array.set unwraps boxed values itself, so the cast is skipped for primitives.
    Array.set(valueTypeArray, i, primitiveComponent ? decoded : arrayComponentType.cast(decoded));
  }
  return valueTypeArray;
}
/**
 * Selects the single constructor whose parameters match the constructor row of the payload.
 *
 * @param constructors candidate constructors of the transform class
 * @param payload payload holding the constructor schema and encoded arguments
 * @return the only constructor with a matching parameter count and compatible parameters
 * @throws RuntimeException if zero or more than one constructor matches
 */
private Constructor<PTransform<InputT, OutputT>> findMappingConstructor(
    Constructor<?>[] constructors, JavaClassLookupPayload payload) {
  Row constructorRow = decodeRow(payload.getConstructorSchema(), payload.getConstructorPayload());

  List<Constructor<?>> matches = new ArrayList<>();
  for (Constructor<?> candidate : constructors) {
    if (candidate.getParameterCount() != payload.getConstructorSchema().getFieldsCount()) {
      continue;
    }
    if (parametersCompatible(candidate.getParameters(), constructorRow)) {
      matches.add(candidate);
    }
  }

  if (matches.size() == 1) {
    return (Constructor<PTransform<InputT, OutputT>>) matches.get(0);
  }
  throw new RuntimeException(
      "Expected to find a single mapping constructor but found " + matches.size());
}
/**
 * Decides whether {@code method} is the constructor method referred to by the payload.
 *
 * <p>A method matches when it carries a {@link MultiLanguageConstructorMethod} annotation with
 * the requested name or when its Java name equals the requested name.
 *
 * @param method candidate method obtained through reflection
 * @param nameFromPayload constructor method name requested by the payload
 * @param allowListClass allow-list entry of the transform class
 * @return {@code true} if the method matches and is allowed, {@code false} if it does not match
 * @throws RuntimeException if the method matches but the name is not in the allowed constructor
 *     methods
 */
private boolean isConstructorMethodForName(
    Method method, String nameFromPayload, AllowedClass allowListClass) {
  boolean nameMatches = false;
  for (Annotation annotation : method.getAnnotations()) {
    if (annotation instanceof MultiLanguageConstructorMethod
        && nameFromPayload.equals(((MultiLanguageConstructorMethod) annotation).name())) {
      nameMatches = true;
      break;
    }
  }
  if (!nameMatches) {
    nameMatches = method.getName().equals(nameFromPayload);
  }
  if (!nameMatches) {
    return false;
  }

  // Both ways of matching are subject to the same allow-list check.
  if (!allowListClass.getAllowedConstructorMethods().contains(nameFromPayload)) {
    throw new RuntimeException(
        "Constructor method " + nameFromPayload + " needs to be explicitly allowed");
  }
  return true;
}
/**
 * Selects the single static method that serves as the constructor method named in the payload.
 *
 * <p>A candidate must match the requested name (and be allowed), be {@code static}, have as
 * many parameters as the constructor schema has fields, and have parameters compatible with the
 * constructor row.
 *
 * @param methods candidate methods of the transform class
 * @param payload payload holding the constructor method name, schema and encoded arguments
 * @param allowListClass allow-list entry of the transform class
 * @return the only matching method
 * @throws RuntimeException if zero or more than one method matches, or if a method matches by
 *     name but is not allowed
 */
private Method findMappingConstructorMethod(
    Method[] methods, JavaClassLookupPayload payload, AllowedClass allowListClass) {
  Row constructorRow = decodeRow(payload.getConstructorSchema(), payload.getConstructorPayload());
  List<Method> mappingConstructorMethods =
      Arrays.stream(methods)
          .filter(
              m -> isConstructorMethodForName(m, payload.getConstructorMethod(), allowListClass))
          // The selected method is invoked without a receiver, so instance methods can never
          // work; excluding them here avoids a NullPointerException at invocation time. This
          // runs after the name filter so that allow-list violations are still reported.
          .filter(m -> java.lang.reflect.Modifier.isStatic(m.getModifiers()))
          .filter(m -> m.getParameterCount() == payload.getConstructorSchema().getFieldsCount())
          .filter(m -> parametersCompatible(m.getParameters(), constructorRow))
          .collect(Collectors.toList());
  if (mappingConstructorMethods.size() != 1) {
    throw new RuntimeException(
        "Expected to find a single mapping constructor method but found "
            + mappingConstructorMethods.size()
            + " Payload was "
            + payload);
  }
  return mappingConstructorMethods.get(0);
}
/**
 * Versioned list of the classes that may be instantiated through this provider. Instances can be
 * created from JSON properties {@code version} and {@code allowedClasses}.
 */
@AutoValue
public abstract static class AllowList {
  /** Returns the version string of this allow-list. */
  public abstract String getVersion();

  /** Returns the allow-list entries, one per permitted class. */
  public abstract List<AllowedClass> getAllowedClasses();

  /**
   * Creates an allow-list; a missing {@code allowedClasses} value is replaced by an empty list.
   */
  @JsonCreator
  static AllowList create(
      @JsonProperty("version") String version,
      @JsonProperty("allowedClasses") @javax.annotation.Nullable
          List<AllowedClass> allowedClasses) {
    List<AllowedClass> classes = allowedClasses == null ? new ArrayList<>() : allowedClasses;
    return new AutoValue_JavaClassLookupTransformProvider_AllowList(version, classes);
  }
}
/**
 * Allow-list entry for one class: its name plus the builder and constructor methods that may be
 * invoked on it. Instances can be created from JSON properties {@code className}, {@code
 * allowedBuilderMethods} and {@code allowedConstructorMethods}.
 */
@AutoValue
public abstract static class AllowedClass {
  /** Returns the name of the permitted class. */
  public abstract String getClassName();

  /** Returns the names of the builder methods that may be invoked. */
  public abstract List<String> getAllowedBuilderMethods();

  /** Returns the names of the constructor methods that may be invoked. */
  public abstract List<String> getAllowedConstructorMethods();

  /** Creates an entry; missing method lists are replaced by empty lists. */
  @JsonCreator
  static AllowedClass create(
      @JsonProperty("className") String className,
      @JsonProperty("allowedBuilderMethods") @javax.annotation.Nullable
          List<String> allowedBuilderMethods,
      @JsonProperty("allowedConstructorMethods") @javax.annotation.Nullable
          List<String> allowedConstructorMethods) {
    List<String> builderMethods =
        allowedBuilderMethods == null ? new ArrayList<>() : allowedBuilderMethods;
    List<String> constructorMethods =
        allowedConstructorMethods == null ? new ArrayList<>() : allowedConstructorMethods;
    return new AutoValue_JavaClassLookupTransformProvider_AllowedClass(
        className, builderMethods, constructorMethods);
  }
}
/**
 * Decodes an encoded row using the given proto schema.
 *
 * <p>When the schema has no fields the payload is not read and an empty row is returned.
 *
 * @param schema proto representation of the row schema
 * @param payload bytes of the row encoded with {@link RowCoder}
 * @return the decoded row
 * @throws RuntimeException if decoding the payload fails with an {@link IOException}
 */
static Row decodeRow(SchemaApi.Schema schema, ByteString payload) {
  Schema payloadSchema = SchemaTranslation.schemaFromProto(schema);
  if (payloadSchema.getFieldCount() != 0) {
    try {
      return RowCoder.of(payloadSchema).decode(payload.newInput());
    } catch (IOException e) {
      throw new RuntimeException("Error decoding payload", e);
    }
  }
  return Row.withSchema(Schema.of()).build();
}
}