| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.sdk.expansion.service; |
| |
| import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; |
| |
| import com.fasterxml.jackson.annotation.JsonCreator; |
| import com.fasterxml.jackson.annotation.JsonProperty; |
| import com.google.auto.value.AutoValue; |
| import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; |
| import java.io.IOException; |
| import java.lang.annotation.Annotation; |
| import java.lang.reflect.Array; |
| import java.lang.reflect.Constructor; |
| import java.lang.reflect.InvocationTargetException; |
| import java.lang.reflect.Method; |
| import java.lang.reflect.ParameterizedType; |
| import java.lang.reflect.Type; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collection; |
| import java.util.List; |
| import java.util.stream.Collectors; |
| import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; |
| import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; |
| import org.apache.beam.model.pipeline.v1.ExternalTransforms.JavaClassLookupPayload; |
| import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; |
| import org.apache.beam.model.pipeline.v1.SchemaApi; |
| import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils; |
| import org.apache.beam.sdk.coders.RowCoder; |
| import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider; |
| import org.apache.beam.sdk.schemas.JavaFieldSchema; |
| import org.apache.beam.sdk.schemas.NoSuchSchemaException; |
| import org.apache.beam.sdk.schemas.Schema; |
| import org.apache.beam.sdk.schemas.Schema.Field; |
| import org.apache.beam.sdk.schemas.Schema.TypeName; |
| import org.apache.beam.sdk.schemas.SchemaRegistry; |
| import org.apache.beam.sdk.schemas.SchemaTranslation; |
| import org.apache.beam.sdk.transforms.PTransform; |
| import org.apache.beam.sdk.transforms.SerializableFunction; |
| import org.apache.beam.sdk.util.common.ReflectHelpers; |
| import org.apache.beam.sdk.values.PInput; |
| import org.apache.beam.sdk.values.POutput; |
| import org.apache.beam.sdk.values.Row; |
| import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; |
| import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException; |
| import org.checkerframework.checker.nullness.qual.Nullable; |
| |
| /** |
| * A transform provider that can be used to directly instantiate a transform using Java class name |
| * and builder methods. |
| * |
| * @param <InputT> input {@link PInput} type of the transform |
| * @param <OutputT> output {@link POutput} type of the transform |
| */ |
| @SuppressWarnings({"argument.type.incompatible", "assignment.type.incompatible"}) |
| @SuppressFBWarnings("UWF_UNWRITTEN_PUBLIC_OR_PROTECTED_FIELD") |
| class JavaClassLookupTransformProvider<InputT extends PInput, OutputT extends POutput> |
| implements TransformProvider<PInput, POutput> { |
| |
/** The only allow-list version this provider accepts. */
public static final String ALLOW_LIST_VERSION = "v1";

/** Registry used to look up schemas and from-row functions for parameter classes. */
private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault();

/** Classes, constructor methods and builder methods that may be used during expansion. */
private final AllowList allowList;

/**
 * Creates a provider restricted to the given allow list.
 *
 * @param allowList the allow list to consult when expanding transforms
 * @throws IllegalArgumentException if the allow list's version is not {@link
 *     #ALLOW_LIST_VERSION}
 */
public JavaClassLookupTransformProvider(AllowList allowList) {
  final String declaredVersion = allowList.getVersion();
  final boolean versionSupported = declaredVersion.equals(ALLOW_LIST_VERSION);
  if (!versionSupported) {
    throw new IllegalArgumentException("Unknown allow-list version");
  }
  this.allowList = allowList;
}
| |
| @Override |
| public PTransform<PInput, POutput> getTransform(FunctionSpec spec) { |
| JavaClassLookupPayload payload; |
| try { |
| payload = JavaClassLookupPayload.parseFrom(spec.getPayload()); |
| } catch (InvalidProtocolBufferException e) { |
| throw new IllegalArgumentException( |
| "Invalid payload type for URN " + getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP), e); |
| } |
| |
| String className = payload.getClassName(); |
| try { |
| AllowedClass allowlistClass = null; |
| if (this.allowList != null) { |
| for (AllowedClass cls : this.allowList.getAllowedClasses()) { |
| if (cls.getClassName().equals(className)) { |
| if (allowlistClass != null) { |
| throw new IllegalArgumentException( |
| "Found two matching allowlist classes " + allowlistClass + " and " + cls); |
| } |
| allowlistClass = cls; |
| } |
| } |
| } |
| if (allowlistClass == null) { |
| throw new UnsupportedOperationException( |
| "The provided allow list does not enable expanding a transform class by the name " |
| + className |
| + "."); |
| } |
| Class<PTransform<InputT, OutputT>> transformClass = |
| (Class<PTransform<InputT, OutputT>>) |
| ReflectHelpers.findClassLoader().loadClass(className); |
| PTransform<PInput, POutput> transform; |
| Row constructorRow = |
| decodeRow(payload.getConstructorSchema(), payload.getConstructorPayload()); |
| if (payload.getConstructorMethod().isEmpty()) { |
| Constructor<?>[] constructors = transformClass.getConstructors(); |
| Constructor<PTransform<InputT, OutputT>> constructor = |
| findMappingConstructor(constructors, payload); |
| Object[] parameterValues = |
| getParameterValues( |
| constructor.getParameters(), |
| constructorRow, |
| constructor.getGenericParameterTypes()); |
| transform = (PTransform<PInput, POutput>) constructor.newInstance(parameterValues); |
| } else { |
| Method[] methods = transformClass.getMethods(); |
| Method method = findMappingConstructorMethod(methods, payload, allowlistClass); |
| Object[] parameterValues = |
| getParameterValues( |
| method.getParameters(), constructorRow, method.getGenericParameterTypes()); |
| transform = (PTransform<PInput, POutput>) method.invoke(null /* static */, parameterValues); |
| } |
| return applyBuilderMethods(transform, payload, allowlistClass); |
| } catch (ClassNotFoundException e) { |
| throw new IllegalArgumentException("Could not find class " + className, e); |
| } catch (InstantiationException |
| | IllegalArgumentException |
| | IllegalAccessException |
| | InvocationTargetException e) { |
| throw new IllegalArgumentException("Could not instantiate class " + className, e); |
| } |
| } |
| |
| private PTransform<PInput, POutput> applyBuilderMethods( |
| PTransform<PInput, POutput> transform, |
| JavaClassLookupPayload payload, |
| AllowedClass allowListClass) { |
| for (BuilderMethod builderMethod : payload.getBuilderMethodsList()) { |
| Method method = getMethod(transform, builderMethod, allowListClass); |
| try { |
| Row builderMethodRow = decodeRow(builderMethod.getSchema(), builderMethod.getPayload()); |
| transform = |
| (PTransform<PInput, POutput>) |
| method.invoke( |
| transform, |
| getParameterValues( |
| method.getParameters(), |
| builderMethodRow, |
| method.getGenericParameterTypes())); |
| } catch (IllegalAccessException | InvocationTargetException e) { |
| throw new IllegalArgumentException( |
| "Could not invoke the builder method " |
| + builderMethod |
| + " on transform " |
| + transform |
| + " with parameter schema " |
| + builderMethod.getSchema(), |
| e); |
| } |
| } |
| |
| return transform; |
| } |
| |
| private boolean isBuilderMethodForName( |
| Method method, String nameFromPayload, AllowedClass allowListClass) { |
| // Lookup based on method annotations |
| for (Annotation annotation : method.getAnnotations()) { |
| if (annotation instanceof MultiLanguageBuilderMethod) { |
| if (nameFromPayload.equals(((MultiLanguageBuilderMethod) annotation).name())) { |
| if (allowListClass.getAllowedBuilderMethods().contains(nameFromPayload)) { |
| return true; |
| } else { |
| throw new RuntimeException( |
| "Builder method " + nameFromPayload + " has to be explicitly allowed"); |
| } |
| } |
| } |
| } |
| |
| // Lookup based on the method name. |
| boolean match = method.getName().equals(nameFromPayload); |
| String consideredMethodName = method.getName(); |
| |
| // We provide a simplification for common Java builder pattern naming convention where builder |
| // methods start with "with". In this case, for a builder method name in the form "withXyz", |
| // users may just use "xyz". If additional updates to the method name are needed the transform |
| // has to be updated by adding annotations. |
| if (!match && consideredMethodName.length() > 4 && consideredMethodName.startsWith("with")) { |
| consideredMethodName = |
| consideredMethodName.substring(4, 5).toLowerCase() + consideredMethodName.substring(5); |
| match = consideredMethodName.equals(nameFromPayload); |
| } |
| if (match && !allowListClass.getAllowedBuilderMethods().contains(consideredMethodName)) { |
| throw new RuntimeException( |
| "Builder method name " + consideredMethodName + " has to be explicitly allowed"); |
| } |
| return match; |
| } |
| |
| private Method getMethod( |
| PTransform<PInput, POutput> transform, |
| BuilderMethod builderMethod, |
| AllowedClass allowListClass) { |
| |
| Row builderMethodRow = decodeRow(builderMethod.getSchema(), builderMethod.getPayload()); |
| |
| List<Method> matchingMethods = |
| Arrays.stream(transform.getClass().getMethods()) |
| .filter(m -> isBuilderMethodForName(m, builderMethod.getName(), allowListClass)) |
| .filter(m -> parametersCompatible(m.getParameters(), builderMethodRow)) |
| .filter(m -> PTransform.class.isAssignableFrom(m.getReturnType())) |
| .collect(Collectors.toList()); |
| |
| if (matchingMethods.size() != 1) { |
| throw new RuntimeException( |
| "Expected to find exactly one matching method in transform " |
| + transform |
| + " for BuilderMethod" |
| + builderMethod |
| + " but found " |
| + matchingMethods.size()); |
| } |
| return matchingMethods.get(0); |
| } |
| |
| private static boolean isPrimitiveOrWrapperOrString(java.lang.Class<?> type) { |
| return ClassUtils.isPrimitiveOrWrapper(type) || type == String.class; |
| } |
| |
| private Schema getParameterSchema(Class<?> parameterClass) { |
| Schema parameterSchema; |
| try { |
| parameterSchema = SCHEMA_REGISTRY.getSchema(parameterClass); |
| } catch (NoSuchSchemaException e) { |
| |
| SCHEMA_REGISTRY.registerSchemaProvider(parameterClass, new JavaFieldSchema()); |
| try { |
| parameterSchema = SCHEMA_REGISTRY.getSchema(parameterClass); |
| } catch (NoSuchSchemaException e1) { |
| throw new RuntimeException(e1); |
| } |
| if (parameterSchema != null && parameterSchema.getFieldCount() == 0) { |
| throw new RuntimeException( |
| "Could not determine a valid schema for parameter class " + parameterClass); |
| } |
| } |
| return parameterSchema; |
| } |
| |
| private boolean parametersCompatible( |
| java.lang.reflect.Parameter[] methodParameters, Row constructorRow) { |
| Schema constructorSchema = constructorRow.getSchema(); |
| if (methodParameters.length != constructorSchema.getFieldCount()) { |
| return false; |
| } |
| |
| for (int i = 0; i < methodParameters.length; i++) { |
| java.lang.reflect.Parameter parameterFromReflection = methodParameters[i]; |
| Field parameterFromPayload = constructorSchema.getField(i); |
| |
| String paramNameFromReflection = parameterFromReflection.getName(); |
| if (!paramNameFromReflection.startsWith("arg") |
| && !paramNameFromReflection.equals(parameterFromPayload.getName())) { |
| // Parameter name through reflection is from the class file (not through synthesizing, |
| // hence we can validate names) |
| return false; |
| } |
| |
| Class<?> parameterClass = parameterFromReflection.getType(); |
| if (isPrimitiveOrWrapperOrString(parameterClass)) { |
| continue; |
| } |
| |
| // We perform additional validation for arrays and non-primitive types. |
| if (parameterClass.isArray()) { |
| Class<?> arrayFieldClass = parameterClass.getComponentType(); |
| if (parameterFromPayload.getType().getTypeName() != TypeName.ARRAY) { |
| throw new RuntimeException( |
| "Expected a schema with a single array field but received " |
| + parameterFromPayload.getType().getTypeName()); |
| } |
| |
| // Following is a best-effort validation that may not cover all cases. Idea is to resolve |
| // ambiguities as much as possible to determine an exact match for the given set of |
| // parameters. If there are ambiguities, the expansion will fail. |
| if (!isPrimitiveOrWrapperOrString(arrayFieldClass)) { |
| @Nullable Collection<Row> values = constructorRow.getArray(i); |
| Schema arrayFieldSchema = getParameterSchema(arrayFieldClass); |
| if (arrayFieldSchema == null) { |
| throw new RuntimeException("Could not determine a schema for type " + arrayFieldClass); |
| } |
| if (values != null) { |
| @Nullable Row firstItem = values.iterator().next(); |
| if (firstItem != null && !(firstItem.getSchema().assignableTo(arrayFieldSchema))) { |
| return false; |
| } |
| } |
| } |
| } else if (constructorRow.getValue(i) instanceof Row) { |
| @Nullable Row parameterRow = constructorRow.getRow(i); |
| Schema schema = getParameterSchema(parameterClass); |
| if (schema == null) { |
| throw new RuntimeException("Could not determine a schema for type " + parameterClass); |
| } |
| if (parameterRow != null && !parameterRow.getSchema().assignableTo(schema)) { |
| return false; |
| } |
| } |
| } |
| return true; |
| } |
| |
| private @Nullable Object getDecodedValueFromRow( |
| Class<?> type, Object valueFromRow, @Nullable Type genericType) { |
| if (isPrimitiveOrWrapperOrString(type)) { |
| if (!isPrimitiveOrWrapperOrString(valueFromRow.getClass())) { |
| throw new IllegalArgumentException( |
| "Expected a Java primitive value but received " + valueFromRow); |
| } |
| return valueFromRow; |
| } else if (type.isArray()) { |
| Class<?> arrayComponentClass = type.getComponentType(); |
| return getDecodedArrayValueFromRow(arrayComponentClass, valueFromRow); |
| } else if (Collection.class.isAssignableFrom(type)) { |
| List<Object> originalList = (List) valueFromRow; |
| List<Object> decodedList = new ArrayList<>(); |
| for (Object obj : originalList) { |
| if (genericType instanceof ParameterizedType) { |
| Class<?> elementType = |
| (Class<?>) ((ParameterizedType) genericType).getActualTypeArguments()[0]; |
| decodedList.add(getDecodedValueFromRow(elementType, obj, null)); |
| } else { |
| throw new RuntimeException("Could not determine the generic type of the list"); |
| } |
| } |
| return decodedList; |
| } else if (valueFromRow instanceof Row) { |
| Row row = (Row) valueFromRow; |
| SerializableFunction<Row, ?> fromRowFunc; |
| try { |
| fromRowFunc = SCHEMA_REGISTRY.getFromRowFunction(type); |
| } catch (NoSuchSchemaException e) { |
| throw new IllegalArgumentException( |
| "Could not determine the row function for class " + type, e); |
| } |
| return fromRowFunc.apply(row); |
| } |
| throw new RuntimeException("Could not decode the value from Row " + valueFromRow); |
| } |
| |
| private Object[] getParameterValues( |
| java.lang.reflect.Parameter[] parameters, Row constrtuctorRow, Type[] genericTypes) { |
| ArrayList<Object> parameterValues = new ArrayList<>(); |
| for (int i = 0; i < parameters.length; ++i) { |
| java.lang.reflect.Parameter parameter = parameters[i]; |
| Class<?> parameterClass = parameter.getType(); |
| Object parameterValue = |
| getDecodedValueFromRow(parameterClass, constrtuctorRow.getValue(i), genericTypes[i]); |
| parameterValues.add(parameterValue); |
| } |
| |
| return parameterValues.toArray(); |
| } |
| |
| private Object[] getDecodedArrayValueFromRow(Class<?> arrayComponentType, Object valueFromRow) { |
| List<Object> originalValues = (List<Object>) valueFromRow; |
| List<Object> decodedValues = new ArrayList<>(); |
| for (Object obj : originalValues) { |
| decodedValues.add(getDecodedValueFromRow(arrayComponentType, obj, null)); |
| } |
| |
| // We have to construct and return an array of the correct type. Otherwise Java reflection |
| // constructor/method invocations that use the returned value may consider the array as varargs |
| // (different parameters). |
| Object valueTypeArray = Array.newInstance(arrayComponentType, decodedValues.size()); |
| for (int i = 0; i < decodedValues.size(); i++) { |
| Array.set(valueTypeArray, i, arrayComponentType.cast(decodedValues.get(i))); |
| } |
| return (Object[]) valueTypeArray; |
| } |
| |
| private Constructor<PTransform<InputT, OutputT>> findMappingConstructor( |
| Constructor<?>[] constructors, JavaClassLookupPayload payload) { |
| Row constructorRow = decodeRow(payload.getConstructorSchema(), payload.getConstructorPayload()); |
| |
| List<Constructor<?>> mappingConstructors = |
| Arrays.stream(constructors) |
| .filter(c -> c.getParameterCount() == payload.getConstructorSchema().getFieldsCount()) |
| .filter(c -> parametersCompatible(c.getParameters(), constructorRow)) |
| .collect(Collectors.toList()); |
| if (mappingConstructors.size() != 1) { |
| throw new RuntimeException( |
| "Expected to find a single mapping constructor but found " + mappingConstructors.size()); |
| } |
| return (Constructor<PTransform<InputT, OutputT>>) mappingConstructors.get(0); |
| } |
| |
| private boolean isConstructorMethodForName( |
| Method method, String nameFromPayload, AllowedClass allowListClass) { |
| for (Annotation annotation : method.getAnnotations()) { |
| if (annotation instanceof MultiLanguageConstructorMethod) { |
| if (nameFromPayload.equals(((MultiLanguageConstructorMethod) annotation).name())) { |
| if (allowListClass.getAllowedConstructorMethods().contains(nameFromPayload)) { |
| return true; |
| } else { |
| throw new RuntimeException( |
| "Constructor method " + nameFromPayload + " needs to be explicitly allowed"); |
| } |
| } |
| } |
| } |
| if (method.getName().equals(nameFromPayload)) { |
| if (allowListClass.getAllowedConstructorMethods().contains(nameFromPayload)) { |
| return true; |
| } else { |
| throw new RuntimeException( |
| "Constructor method " + nameFromPayload + " needs to be explicitly allowed"); |
| } |
| } |
| return false; |
| } |
| |
| private Method findMappingConstructorMethod( |
| Method[] methods, JavaClassLookupPayload payload, AllowedClass allowListClass) { |
| |
| Row constructorRow = decodeRow(payload.getConstructorSchema(), payload.getConstructorPayload()); |
| |
| List<Method> mappingConstructorMethods = |
| Arrays.stream(methods) |
| .filter( |
| m -> isConstructorMethodForName(m, payload.getConstructorMethod(), allowListClass)) |
| .filter(m -> m.getParameterCount() == payload.getConstructorSchema().getFieldsCount()) |
| .filter(m -> parametersCompatible(m.getParameters(), constructorRow)) |
| .collect(Collectors.toList()); |
| |
| if (mappingConstructorMethods.size() != 1) { |
| throw new RuntimeException( |
| "Expected to find a single mapping constructor method but found " |
| + mappingConstructorMethods.size() |
| + " Payload was " |
| + payload); |
| } |
| return mappingConstructorMethods.get(0); |
| } |
| |
| @AutoValue |
| public abstract static class AllowList { |
| |
| public abstract String getVersion(); |
| |
| public abstract List<AllowedClass> getAllowedClasses(); |
| |
| @JsonCreator |
| static AllowList create( |
| @JsonProperty("version") String version, |
| @JsonProperty("allowedClasses") @javax.annotation.Nullable |
| List<AllowedClass> allowedClasses) { |
| if (allowedClasses == null) { |
| allowedClasses = new ArrayList<>(); |
| } |
| return new AutoValue_JavaClassLookupTransformProvider_AllowList(version, allowedClasses); |
| } |
| } |
| |
| @AutoValue |
| public abstract static class AllowedClass { |
| |
| public abstract String getClassName(); |
| |
| public abstract List<String> getAllowedBuilderMethods(); |
| |
| public abstract List<String> getAllowedConstructorMethods(); |
| |
| @JsonCreator |
| static AllowedClass create( |
| @JsonProperty("className") String className, |
| @JsonProperty("allowedBuilderMethods") @javax.annotation.Nullable |
| List<String> allowedBuilderMethods, |
| @JsonProperty("allowedConstructorMethods") @javax.annotation.Nullable |
| List<String> allowedConstructorMethods) { |
| if (allowedBuilderMethods == null) { |
| allowedBuilderMethods = new ArrayList<>(); |
| } |
| if (allowedConstructorMethods == null) { |
| allowedConstructorMethods = new ArrayList<>(); |
| } |
| return new AutoValue_JavaClassLookupTransformProvider_AllowedClass( |
| className, allowedBuilderMethods, allowedConstructorMethods); |
| } |
| } |
| |
| static Row decodeRow(SchemaApi.Schema schema, ByteString payload) { |
| Schema payloadSchema = SchemaTranslation.schemaFromProto(schema); |
| |
| if (payloadSchema.getFieldCount() == 0) { |
| return Row.withSchema(Schema.of()).build(); |
| } |
| |
| Row row; |
| try { |
| row = RowCoder.of(payloadSchema).decode(payload.newInput()); |
| } catch (IOException e) { |
| throw new RuntimeException("Error decoding payload", e); |
| } |
| return row; |
| } |
| } |