| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.sdk.transforms.reflect; |
| |
| import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkState; |
| |
| import com.google.auto.value.AutoValue; |
| import java.lang.annotation.Annotation; |
| import java.lang.reflect.AnnotatedElement; |
| import java.lang.reflect.Field; |
| import java.lang.reflect.Method; |
| import java.lang.reflect.Modifier; |
| import java.lang.reflect.ParameterizedType; |
| import java.lang.reflect.Type; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collection; |
| import java.util.Collections; |
| import java.util.HashMap; |
| import java.util.LinkedHashMap; |
| import java.util.LinkedHashSet; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Optional; |
| import javax.annotation.Nullable; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.options.PipelineOptions; |
| import org.apache.beam.sdk.schemas.FieldAccessDescriptor; |
| import org.apache.beam.sdk.state.State; |
| import org.apache.beam.sdk.state.StateSpec; |
| import org.apache.beam.sdk.state.TimeDomain; |
| import org.apache.beam.sdk.state.Timer; |
| import org.apache.beam.sdk.state.TimerSpec; |
| import org.apache.beam.sdk.transforms.DoFn; |
| import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; |
| import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; |
| import org.apache.beam.sdk.transforms.DoFn.StateId; |
| import org.apache.beam.sdk.transforms.DoFn.TimerId; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.FieldAccessDeclaration; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration; |
| import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration; |
| import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker; |
| import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; |
| import org.apache.beam.sdk.transforms.windowing.BoundedWindow; |
| import org.apache.beam.sdk.transforms.windowing.PaneInfo; |
| import org.apache.beam.sdk.util.common.ReflectHelpers; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.Row; |
| import org.apache.beam.sdk.values.TypeDescriptor; |
| import org.apache.beam.sdk.values.TypeParameter; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Predicates; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap; |
| import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps; |
| import org.joda.time.Instant; |
| |
| /** Utilities for working with {@link DoFnSignature}. See {@link #getSignature}. */ |
| public class DoFnSignatures { |
| |
/** Utility class holding only static helpers; not meant to be instantiated. */
private DoFnSignatures() {
  // Intentionally empty.
}
| |
| private static final Map<Class<?>, DoFnSignature> signatureCache = new LinkedHashMap<>(); |
| |
| private static final ImmutableList<Class<? extends Parameter>> |
| ALLOWED_NON_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS = |
| ImmutableList.of( |
| Parameter.ProcessContextParameter.class, |
| Parameter.ElementParameter.class, |
| Parameter.SchemaElementParameter.class, |
| Parameter.TimestampParameter.class, |
| Parameter.OutputReceiverParameter.class, |
| Parameter.TaggedOutputReceiverParameter.class, |
| Parameter.WindowParameter.class, |
| Parameter.PaneInfoParameter.class, |
| Parameter.PipelineOptionsParameter.class, |
| Parameter.TimerParameter.class, |
| Parameter.StateParameter.class); |
| |
| private static final ImmutableList<Class<? extends Parameter>> |
| ALLOWED_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS = |
| ImmutableList.of( |
| Parameter.PipelineOptionsParameter.class, |
| Parameter.ElementParameter.class, |
| Parameter.TimestampParameter.class, |
| Parameter.OutputReceiverParameter.class, |
| Parameter.TaggedOutputReceiverParameter.class, |
| Parameter.ProcessContextParameter.class, |
| Parameter.RestrictionTrackerParameter.class); |
| |
| private static final ImmutableList<Class<? extends Parameter>> ALLOWED_ON_TIMER_PARAMETERS = |
| ImmutableList.of( |
| Parameter.OnTimerContextParameter.class, |
| Parameter.TimestampParameter.class, |
| Parameter.TimeDomainParameter.class, |
| Parameter.WindowParameter.class, |
| Parameter.PipelineOptionsParameter.class, |
| Parameter.OutputReceiverParameter.class, |
| Parameter.TaggedOutputReceiverParameter.class, |
| Parameter.TimerParameter.class, |
| Parameter.StateParameter.class); |
| |
| private static final Collection<Class<? extends Parameter>> |
| ALLOWED_ON_WINDOW_EXPIRATION_PARAMETERS = |
| ImmutableList.of( |
| Parameter.WindowParameter.class, |
| Parameter.PipelineOptionsParameter.class, |
| Parameter.OutputReceiverParameter.class, |
| Parameter.TaggedOutputReceiverParameter.class, |
| Parameter.StateParameter.class); |
| |
| /** @return the {@link DoFnSignature} for the given {@link DoFn} instance. */ |
| public static <FnT extends DoFn<?, ?>> DoFnSignature signatureForDoFn(FnT fn) { |
| return getSignature(fn.getClass()); |
| } |
| |
| /** @return the {@link DoFnSignature} for the given {@link DoFn} subclass. */ |
| public static synchronized <FnT extends DoFn<?, ?>> DoFnSignature getSignature(Class<FnT> fn) { |
| return signatureCache.computeIfAbsent(fn, k -> parseSignature(fn)); |
| } |
| |
| /** |
| * The context for a {@link DoFn} class, for use in analysis. |
| * |
| * <p>It contains much of the information that eventually becomes part of the {@link |
| * DoFnSignature}, but in an intermediate state. |
| */ |
| @VisibleForTesting |
| static class FnAnalysisContext { |
| |
| private final Map<String, StateDeclaration> stateDeclarations = new HashMap<>(); |
| private final Map<String, TimerDeclaration> timerDeclarations = new HashMap<>(); |
| private final Map<String, FieldAccessDeclaration> fieldAccessDeclarations = new HashMap<>(); |
| |
| private FnAnalysisContext() {} |
| |
| /** Create an empty context, with no declarations. */ |
| public static FnAnalysisContext create() { |
| return new FnAnalysisContext(); |
| } |
| |
| /** State parameters declared in this context, keyed by {@link StateId}. Unmodifiable. */ |
| public Map<String, StateDeclaration> getStateDeclarations() { |
| return Collections.unmodifiableMap(stateDeclarations); |
| } |
| |
| /** Timer parameters declared in this context, keyed by {@link TimerId}. Unmodifiable. */ |
| public Map<String, TimerDeclaration> getTimerDeclarations() { |
| return Collections.unmodifiableMap(timerDeclarations); |
| } |
| |
| /** Field access declaration declared in this context. */ |
| @Nullable |
| public Map<String, FieldAccessDeclaration> getFieldAccessDeclarations() { |
| return fieldAccessDeclarations; |
| } |
| |
| public void addStateDeclaration(StateDeclaration decl) { |
| stateDeclarations.put(decl.id(), decl); |
| } |
| |
| public void addStateDeclarations(Iterable<StateDeclaration> decls) { |
| for (StateDeclaration decl : decls) { |
| addStateDeclaration(decl); |
| } |
| } |
| |
| public void addTimerDeclaration(TimerDeclaration decl) { |
| timerDeclarations.put(decl.id(), decl); |
| } |
| |
| public void addTimerDeclarations(Iterable<TimerDeclaration> decls) { |
| for (TimerDeclaration decl : decls) { |
| addTimerDeclaration(decl); |
| } |
| } |
| |
| public void addFieldAccessDeclaration(FieldAccessDeclaration decl) { |
| fieldAccessDeclarations.put(decl.id(), decl); |
| } |
| |
| public void addFieldAccessDeclarations(Iterable<FieldAccessDeclaration> decls) { |
| for (FieldAccessDeclaration decl : decls) { |
| addFieldAccessDeclaration(decl); |
| } |
| } |
| } |
| |
| /** |
| * The context of analysis within a particular method. |
| * |
| * <p>It contains much of the information that eventually becomes part of the {@link |
| * DoFnSignature.MethodWithExtraParameters}, but in an intermediate state. |
| */ |
| private static class MethodAnalysisContext { |
| |
| private final Map<String, StateParameter> stateParameters = new HashMap<>(); |
| private final Map<String, TimerParameter> timerParameters = new HashMap<>(); |
| private final List<Parameter> extraParameters = new ArrayList<>(); |
| |
| @Nullable private TypeDescriptor<? extends BoundedWindow> windowT; |
| |
| private MethodAnalysisContext() {} |
| |
| /** Indicates whether a {@link RestrictionTrackerParameter} is known in this context. */ |
| public boolean hasRestrictionTrackerParameter() { |
| return extraParameters.stream() |
| .anyMatch(Predicates.instanceOf(RestrictionTrackerParameter.class)::apply); |
| } |
| |
| /** Indicates whether a {@link WindowParameter} is known in this context. */ |
| public boolean hasWindowParameter() { |
| return extraParameters.stream().anyMatch(Predicates.instanceOf(WindowParameter.class)::apply); |
| } |
| |
| /** Indicates whether a {@link Parameter.PipelineOptionsParameter} is known in this context. */ |
| public boolean hasPipelineOptionsParamter() { |
| return extraParameters.stream() |
| .anyMatch(Predicates.instanceOf(Parameter.PipelineOptionsParameter.class)::apply); |
| } |
| |
| /** The window type, if any, used by this method. */ |
| @Nullable |
| public TypeDescriptor<? extends BoundedWindow> getWindowType() { |
| return windowT; |
| } |
| |
| /** State parameters declared in this context, keyed by {@link StateId}. */ |
| public Map<String, StateParameter> getStateParameters() { |
| return Collections.unmodifiableMap(stateParameters); |
| } |
| |
| /** Timer parameters declared in this context, keyed by {@link TimerId}. */ |
| public Map<String, TimerParameter> getTimerParameters() { |
| return Collections.unmodifiableMap(timerParameters); |
| } |
| |
| /** Extra parameters in their entirety. Unmodifiable. */ |
| public List<Parameter> getExtraParameters() { |
| return Collections.unmodifiableList(extraParameters); |
| } |
| |
| public void setParameter(int index, Parameter parameter) { |
| extraParameters.set(index, parameter); |
| } |
| |
| /** |
| * Returns an {@link MethodAnalysisContext} like this one but including the provided {@link |
| * StateParameter}. |
| */ |
| public void addParameter(Parameter param) { |
| extraParameters.add(param); |
| |
| if (param instanceof StateParameter) { |
| StateParameter stateParameter = (StateParameter) param; |
| stateParameters.put(stateParameter.referent().id(), stateParameter); |
| } |
| if (param instanceof TimerParameter) { |
| TimerParameter timerParameter = (TimerParameter) param; |
| timerParameters.put(timerParameter.referent().id(), timerParameter); |
| } |
| } |
| |
| /** Create an empty context, with no declarations. */ |
| public static MethodAnalysisContext create() { |
| return new MethodAnalysisContext(); |
| } |
| } |
| |
| @AutoValue |
| abstract static class ParameterDescription { |
| public abstract Method getMethod(); |
| |
| public abstract int getIndex(); |
| |
| public abstract TypeDescriptor<?> getType(); |
| |
| public abstract List<Annotation> getAnnotations(); |
| |
| public static ParameterDescription of( |
| Method method, int index, TypeDescriptor<?> type, List<Annotation> annotations) { |
| return new AutoValue_DoFnSignatures_ParameterDescription(method, index, type, annotations); |
| } |
| |
| public static ParameterDescription of( |
| Method method, int index, TypeDescriptor<?> type, Annotation[] annotations) { |
| return new AutoValue_DoFnSignatures_ParameterDescription( |
| method, index, type, Arrays.asList(annotations)); |
| } |
| } |
| |
| /** Analyzes a given {@link DoFn} class and extracts its {@link DoFnSignature}. */ |
| private static DoFnSignature parseSignature(Class<? extends DoFn<?, ?>> fnClass) { |
| DoFnSignature.Builder signatureBuilder = DoFnSignature.builder(); |
| |
| ErrorReporter errors = new ErrorReporter(null, fnClass.getName()); |
| errors.checkArgument(DoFn.class.isAssignableFrom(fnClass), "Must be subtype of DoFn"); |
| signatureBuilder.setFnClass(fnClass); |
| |
| TypeDescriptor<? extends DoFn<?, ?>> fnT = TypeDescriptor.of(fnClass); |
| |
| // Extract the input and output type, and whether the fn is bounded. |
| TypeDescriptor<?> inputT = null; |
| TypeDescriptor<?> outputT = null; |
| for (TypeDescriptor<?> supertype : fnT.getTypes()) { |
| if (!supertype.getRawType().equals(DoFn.class)) { |
| continue; |
| } |
| Type[] args = ((ParameterizedType) supertype.getType()).getActualTypeArguments(); |
| inputT = TypeDescriptor.of(args[0]); |
| outputT = TypeDescriptor.of(args[1]); |
| } |
| errors.checkNotNull(inputT, "Unable to determine input type"); |
| |
| // Find the state and timer declarations in advance of validating |
| // method parameter lists |
| FnAnalysisContext fnContext = FnAnalysisContext.create(); |
| fnContext.addStateDeclarations(analyzeStateDeclarations(errors, fnClass).values()); |
| fnContext.addTimerDeclarations(analyzeTimerDeclarations(errors, fnClass).values()); |
| fnContext.addFieldAccessDeclarations(analyzeFieldAccessDeclaration(errors, fnClass).values()); |
| |
| Method processElementMethod = |
| findAnnotatedMethod(errors, DoFn.ProcessElement.class, fnClass, true); |
| Method startBundleMethod = findAnnotatedMethod(errors, DoFn.StartBundle.class, fnClass, false); |
| Method finishBundleMethod = |
| findAnnotatedMethod(errors, DoFn.FinishBundle.class, fnClass, false); |
| Method setupMethod = findAnnotatedMethod(errors, DoFn.Setup.class, fnClass, false); |
| Method teardownMethod = findAnnotatedMethod(errors, DoFn.Teardown.class, fnClass, false); |
| Method onWindowExpirationMethod = |
| findAnnotatedMethod(errors, DoFn.OnWindowExpiration.class, fnClass, false); |
| Method getInitialRestrictionMethod = |
| findAnnotatedMethod(errors, DoFn.GetInitialRestriction.class, fnClass, false); |
| Method splitRestrictionMethod = |
| findAnnotatedMethod(errors, DoFn.SplitRestriction.class, fnClass, false); |
| Method getRestrictionCoderMethod = |
| findAnnotatedMethod(errors, DoFn.GetRestrictionCoder.class, fnClass, false); |
| Method newTrackerMethod = findAnnotatedMethod(errors, DoFn.NewTracker.class, fnClass, false); |
| |
| Collection<Method> onTimerMethods = |
| declaredMethodsWithAnnotation(DoFn.OnTimer.class, fnClass, DoFn.class); |
| HashMap<String, DoFnSignature.OnTimerMethod> onTimerMethodMap = |
| Maps.newHashMapWithExpectedSize(onTimerMethods.size()); |
| for (Method onTimerMethod : onTimerMethods) { |
| String id = onTimerMethod.getAnnotation(DoFn.OnTimer.class).value(); |
| errors.checkArgument( |
| fnContext.getTimerDeclarations().containsKey(id), |
| "Callback %s is for undeclared timer %s", |
| onTimerMethod, |
| id); |
| |
| TimerDeclaration timerDecl = fnContext.getTimerDeclarations().get(id); |
| errors.checkArgument( |
| timerDecl.field().getDeclaringClass().equals(onTimerMethod.getDeclaringClass()), |
| "Callback %s is for timer %s declared in a different class %s." |
| + " Timer callbacks must be declared in the same lexical scope as their timer", |
| onTimerMethod, |
| id, |
| timerDecl.field().getDeclaringClass().getCanonicalName()); |
| |
| onTimerMethodMap.put( |
| id, analyzeOnTimerMethod(errors, fnT, onTimerMethod, id, inputT, outputT, fnContext)); |
| } |
| signatureBuilder.setOnTimerMethods(onTimerMethodMap); |
| |
| // Check the converse - that all timers have a callback. This could be relaxed to only |
| // those timers used in methods, once method parameter lists support timers. |
| for (TimerDeclaration decl : fnContext.getTimerDeclarations().values()) { |
| errors.checkArgument( |
| onTimerMethodMap.containsKey(decl.id()), |
| "No callback registered via %s for timer %s", |
| DoFn.OnTimer.class.getSimpleName(), |
| decl.id()); |
| } |
| |
| ErrorReporter processElementErrors = |
| errors.forMethod(DoFn.ProcessElement.class, processElementMethod); |
| DoFnSignature.ProcessElementMethod processElement = |
| analyzeProcessElementMethod( |
| processElementErrors, fnT, processElementMethod, inputT, outputT, fnContext); |
| signatureBuilder.setProcessElement(processElement); |
| |
| if (startBundleMethod != null) { |
| ErrorReporter startBundleErrors = errors.forMethod(DoFn.StartBundle.class, startBundleMethod); |
| signatureBuilder.setStartBundle( |
| analyzeStartBundleMethod(startBundleErrors, fnT, startBundleMethod, inputT, outputT)); |
| } |
| |
| if (finishBundleMethod != null) { |
| ErrorReporter finishBundleErrors = |
| errors.forMethod(DoFn.FinishBundle.class, finishBundleMethod); |
| signatureBuilder.setFinishBundle( |
| analyzeFinishBundleMethod(finishBundleErrors, fnT, finishBundleMethod, inputT, outputT)); |
| } |
| |
| if (setupMethod != null) { |
| signatureBuilder.setSetup( |
| analyzeLifecycleMethod(errors.forMethod(DoFn.Setup.class, setupMethod), setupMethod)); |
| } |
| |
| if (teardownMethod != null) { |
| signatureBuilder.setTeardown( |
| analyzeLifecycleMethod( |
| errors.forMethod(DoFn.Teardown.class, teardownMethod), teardownMethod)); |
| } |
| |
| if (onWindowExpirationMethod != null) { |
| signatureBuilder.setOnWindowExpiration( |
| analyzeOnWindowExpirationMethod( |
| errors, fnT, onWindowExpirationMethod, inputT, outputT, fnContext)); |
| } |
| |
| ErrorReporter getInitialRestrictionErrors; |
| if (getInitialRestrictionMethod != null) { |
| getInitialRestrictionErrors = |
| errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestrictionMethod); |
| signatureBuilder.setGetInitialRestriction( |
| analyzeGetInitialRestrictionMethod( |
| getInitialRestrictionErrors, fnT, getInitialRestrictionMethod, inputT)); |
| } |
| |
| if (splitRestrictionMethod != null) { |
| ErrorReporter splitRestrictionErrors = |
| errors.forMethod(DoFn.SplitRestriction.class, splitRestrictionMethod); |
| signatureBuilder.setSplitRestriction( |
| analyzeSplitRestrictionMethod( |
| splitRestrictionErrors, fnT, splitRestrictionMethod, inputT)); |
| } |
| |
| if (getRestrictionCoderMethod != null) { |
| ErrorReporter getRestrictionCoderErrors = |
| errors.forMethod(DoFn.GetRestrictionCoder.class, getRestrictionCoderMethod); |
| signatureBuilder.setGetRestrictionCoder( |
| analyzeGetRestrictionCoderMethod( |
| getRestrictionCoderErrors, fnT, getRestrictionCoderMethod)); |
| } |
| |
| if (newTrackerMethod != null) { |
| ErrorReporter newTrackerErrors = errors.forMethod(DoFn.NewTracker.class, newTrackerMethod); |
| signatureBuilder.setNewTracker( |
| analyzeNewTrackerMethod(newTrackerErrors, fnT, newTrackerMethod)); |
| } |
| |
| signatureBuilder.setIsBoundedPerElement(inferBoundedness(fnT, processElement, errors)); |
| |
| signatureBuilder.setStateDeclarations(fnContext.getStateDeclarations()); |
| signatureBuilder.setTimerDeclarations(fnContext.getTimerDeclarations()); |
| signatureBuilder.setFieldAccessDeclarations(fnContext.getFieldAccessDeclarations()); |
| |
| DoFnSignature signature = signatureBuilder.build(); |
| |
| // Additional validation for splittable DoFn's. |
| if (processElement.isSplittable()) { |
| verifySplittableMethods(signature, errors); |
| } else { |
| verifyUnsplittableMethods(errors, signature); |
| } |
| |
| return signature; |
| } |
| |
| /** |
| * Infers the boundedness of the {@link DoFn.ProcessElement} method (whether or not it performs a |
| * bounded amount of work per element) using the following criteria: |
| * |
| * <ol> |
| * <li>If the {@link DoFn} is not splittable, then it is bounded, it must not be annotated as |
| * {@link DoFn.BoundedPerElement} or {@link DoFn.UnboundedPerElement}, and {@link |
| * DoFn.ProcessElement} must return {@code void}. |
| * <li>If the {@link DoFn} (or any of its supertypes) is annotated as {@link |
| * DoFn.BoundedPerElement} or {@link DoFn.UnboundedPerElement}, use that. Only one of these |
| * must be specified. |
| * <li>If {@link DoFn.ProcessElement} returns {@link DoFn.ProcessContinuation}, assume it is |
| * unbounded. Otherwise (if it returns {@code void}), assume it is bounded. |
| * <li>If {@link DoFn.ProcessElement} returns {@code void}, but the {@link DoFn} is annotated |
| * {@link DoFn.UnboundedPerElement}, this is an error. |
| * </ol> |
| */ |
| private static PCollection.IsBounded inferBoundedness( |
| TypeDescriptor<? extends DoFn> fnT, |
| DoFnSignature.ProcessElementMethod processElement, |
| ErrorReporter errors) { |
| PCollection.IsBounded isBounded = null; |
| for (TypeDescriptor<?> supertype : fnT.getTypes()) { |
| if (supertype.getRawType().isAnnotationPresent(DoFn.BoundedPerElement.class) |
| || supertype.getRawType().isAnnotationPresent(DoFn.UnboundedPerElement.class)) { |
| errors.checkArgument( |
| isBounded == null, |
| "Both @%s and @%s specified", |
| DoFn.BoundedPerElement.class.getSimpleName(), |
| DoFn.UnboundedPerElement.class.getSimpleName()); |
| isBounded = |
| supertype.getRawType().isAnnotationPresent(DoFn.BoundedPerElement.class) |
| ? PCollection.IsBounded.BOUNDED |
| : PCollection.IsBounded.UNBOUNDED; |
| } |
| } |
| if (processElement.isSplittable()) { |
| if (isBounded == null) { |
| isBounded = |
| processElement.hasReturnValue() |
| ? PCollection.IsBounded.UNBOUNDED |
| : PCollection.IsBounded.BOUNDED; |
| } |
| } else { |
| errors.checkArgument( |
| isBounded == null, |
| "Non-splittable, but annotated as @" |
| + ((isBounded == PCollection.IsBounded.BOUNDED) |
| ? DoFn.BoundedPerElement.class.getSimpleName() |
| : DoFn.UnboundedPerElement.class.getSimpleName())); |
| checkState(!processElement.hasReturnValue(), "Should have been inferred splittable"); |
| isBounded = PCollection.IsBounded.BOUNDED; |
| } |
| return isBounded; |
| } |
| |
| /** |
| * Verifies properties related to methods of splittable {@link DoFn}: |
| * |
| * <ul> |
| * <li>Must declare the required {@link DoFn.GetInitialRestriction} and {@link DoFn.NewTracker} |
| * methods. |
| * <li>Types of restrictions and trackers must match exactly between {@link |
| * DoFn.ProcessElement}, {@link DoFn.GetInitialRestriction}, {@link DoFn.NewTracker}, {@link |
| * DoFn.GetRestrictionCoder}, {@link DoFn.SplitRestriction}. |
| * </ul> |
| */ |
| private static void verifySplittableMethods(DoFnSignature signature, ErrorReporter errors) { |
| DoFnSignature.ProcessElementMethod processElement = signature.processElement(); |
| DoFnSignature.GetInitialRestrictionMethod getInitialRestriction = |
| signature.getInitialRestriction(); |
| DoFnSignature.NewTrackerMethod newTracker = signature.newTracker(); |
| DoFnSignature.GetRestrictionCoderMethod getRestrictionCoder = signature.getRestrictionCoder(); |
| DoFnSignature.SplitRestrictionMethod splitRestriction = signature.splitRestriction(); |
| |
| ErrorReporter processElementErrors = |
| errors.forMethod(DoFn.ProcessElement.class, processElement.targetMethod()); |
| |
| List<String> missingRequiredMethods = new ArrayList<>(); |
| if (getInitialRestriction == null) { |
| missingRequiredMethods.add("@" + DoFn.GetInitialRestriction.class.getSimpleName()); |
| } |
| if (newTracker == null) { |
| if (getInitialRestriction != null |
| && getInitialRestriction |
| .restrictionT() |
| .isSubtypeOf(TypeDescriptor.of(HasDefaultTracker.class))) { |
| // no-op we are using the annotation @HasDefaultTracker |
| } else { |
| missingRequiredMethods.add("@" + DoFn.NewTracker.class.getSimpleName()); |
| } |
| } else { |
| ErrorReporter getInitialRestrictionErrors = |
| errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestriction.targetMethod()); |
| TypeDescriptor<?> restrictionT = getInitialRestriction.restrictionT(); |
| getInitialRestrictionErrors.checkArgument( |
| restrictionT.equals(newTracker.restrictionT()), |
| "Uses restriction type %s, but @%s method %s uses restriction type %s", |
| formatType(restrictionT), |
| DoFn.NewTracker.class.getSimpleName(), |
| format(newTracker.targetMethod()), |
| formatType(newTracker.restrictionT())); |
| } |
| |
| if (!missingRequiredMethods.isEmpty()) { |
| processElementErrors.throwIllegalArgument( |
| "Splittable, but does not define the following required methods: %s", |
| missingRequiredMethods); |
| } |
| |
| ErrorReporter getInitialRestrictionErrors = |
| errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestriction.targetMethod()); |
| TypeDescriptor<?> restrictionT = getInitialRestriction.restrictionT(); |
| processElementErrors.checkArgument( |
| processElement.trackerT().getRawType().equals(RestrictionTracker.class), |
| "Has tracker type %s, but the DoFn's tracker type must be of type RestrictionTracker.", |
| formatType(processElement.trackerT())); |
| |
| if (getRestrictionCoder != null) { |
| getInitialRestrictionErrors.checkArgument( |
| getRestrictionCoder.coderT().isSubtypeOf(coderTypeOf(restrictionT)), |
| "Uses restriction type %s, but @%s method %s returns %s " |
| + "which is not a subtype of %s", |
| formatType(restrictionT), |
| DoFn.GetRestrictionCoder.class.getSimpleName(), |
| format(getRestrictionCoder.targetMethod()), |
| formatType(getRestrictionCoder.coderT()), |
| formatType(coderTypeOf(restrictionT))); |
| } |
| |
| if (splitRestriction != null) { |
| getInitialRestrictionErrors.checkArgument( |
| splitRestriction.restrictionT().equals(restrictionT), |
| "Uses restriction type %s, but @%s method %s uses restriction type %s", |
| formatType(restrictionT), |
| DoFn.SplitRestriction.class.getSimpleName(), |
| format(splitRestriction.targetMethod()), |
| formatType(splitRestriction.restrictionT())); |
| } |
| } |
| |
| /** |
| * Verifies that a non-splittable {@link DoFn} does not declare any methods that only make sense |
| * for splittable {@link DoFn}: {@link DoFn.GetInitialRestriction}, {@link DoFn.SplitRestriction}, |
| * {@link DoFn.NewTracker}, {@link DoFn.GetRestrictionCoder}. |
| */ |
| private static void verifyUnsplittableMethods(ErrorReporter errors, DoFnSignature signature) { |
| List<String> forbiddenMethods = new ArrayList<>(); |
| if (signature.getInitialRestriction() != null) { |
| forbiddenMethods.add("@" + DoFn.GetInitialRestriction.class.getSimpleName()); |
| } |
| if (signature.splitRestriction() != null) { |
| forbiddenMethods.add("@" + DoFn.SplitRestriction.class.getSimpleName()); |
| } |
| if (signature.newTracker() != null) { |
| forbiddenMethods.add("@" + DoFn.NewTracker.class.getSimpleName()); |
| } |
| if (signature.getRestrictionCoder() != null) { |
| forbiddenMethods.add("@" + DoFn.GetRestrictionCoder.class.getSimpleName()); |
| } |
| errors.checkArgument( |
| forbiddenMethods.isEmpty(), "Non-splittable, but defines methods: %s", forbiddenMethods); |
| } |
| |
| /** |
| * Generates a {@link TypeDescriptor} for {@code DoFn<InputT, OutputT>.ProcessContext} given |
| * {@code InputT} and {@code OutputT}. |
| */ |
| private static <InputT, OutputT> |
| TypeDescriptor<DoFn<InputT, OutputT>.ProcessContext> doFnProcessContextTypeOf( |
| TypeDescriptor<InputT> inputT, TypeDescriptor<OutputT> outputT) { |
| return new TypeDescriptor<DoFn<InputT, OutputT>.ProcessContext>() {}.where( |
| new TypeParameter<InputT>() {}, inputT) |
| .where(new TypeParameter<OutputT>() {}, outputT); |
| } |
| |
| /** |
| * Generates a {@link TypeDescriptor} for {@code DoFn<InputT, OutputT>.StartBundleContext} given |
| * {@code InputT} and {@code OutputT}. |
| */ |
| private static <InputT, OutputT> |
| TypeDescriptor<DoFn<InputT, OutputT>.StartBundleContext> doFnStartBundleContextTypeOf( |
| TypeDescriptor<InputT> inputT, TypeDescriptor<OutputT> outputT) { |
| return new TypeDescriptor<DoFn<InputT, OutputT>.StartBundleContext>() {}.where( |
| new TypeParameter<InputT>() {}, inputT) |
| .where(new TypeParameter<OutputT>() {}, outputT); |
| } |
| |
| /** |
| * Generates a {@link TypeDescriptor} for {@code DoFn<InputT, OutputT>.FinishBundleContext} given |
| * {@code InputT} and {@code OutputT}. |
| */ |
| private static <InputT, OutputT> |
| TypeDescriptor<DoFn<InputT, OutputT>.FinishBundleContext> doFnFinishBundleContextTypeOf( |
| TypeDescriptor<InputT> inputT, TypeDescriptor<OutputT> outputT) { |
| return new TypeDescriptor<DoFn<InputT, OutputT>.FinishBundleContext>() {}.where( |
| new TypeParameter<InputT>() {}, inputT) |
| .where(new TypeParameter<OutputT>() {}, outputT); |
| } |
| |
| /** |
| * Generates a {@link TypeDescriptor} for {@code DoFn<InputT, OutputT>.Context} given {@code |
| * InputT} and {@code OutputT}. |
| */ |
| private static <InputT, OutputT> |
| TypeDescriptor<DoFn<InputT, OutputT>.OnTimerContext> doFnOnTimerContextTypeOf( |
| TypeDescriptor<InputT> inputT, TypeDescriptor<OutputT> outputT) { |
| return new TypeDescriptor<DoFn<InputT, OutputT>.OnTimerContext>() {}.where( |
| new TypeParameter<InputT>() {}, inputT) |
| .where(new TypeParameter<OutputT>() {}, outputT); |
| } |
| |
| @VisibleForTesting |
| static DoFnSignature.OnTimerMethod analyzeOnTimerMethod( |
| ErrorReporter errors, |
| TypeDescriptor<? extends DoFn<?, ?>> fnClass, |
| Method m, |
| String timerId, |
| TypeDescriptor<?> inputT, |
| TypeDescriptor<?> outputT, |
| FnAnalysisContext fnContext) { |
| errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); |
| |
| Type[] params = m.getGenericParameterTypes(); |
| |
| MethodAnalysisContext methodContext = MethodAnalysisContext.create(); |
| |
| boolean requiresStableInput = m.isAnnotationPresent(DoFn.RequiresStableInput.class); |
| |
| @Nullable TypeDescriptor<? extends BoundedWindow> windowT = getWindowType(fnClass, m); |
| |
| List<DoFnSignature.Parameter> extraParameters = new ArrayList<>(); |
| ErrorReporter onTimerErrors = errors.forMethod(DoFn.OnTimer.class, m); |
| for (int i = 0; i < params.length; ++i) { |
| Parameter parameter = |
| analyzeExtraParameter( |
| onTimerErrors, |
| fnContext, |
| methodContext, |
| fnClass, |
| ParameterDescription.of( |
| m, |
| i, |
| fnClass.resolveType(params[i]), |
| Arrays.asList(m.getParameterAnnotations()[i])), |
| inputT, |
| outputT); |
| |
| checkParameterOneOf(errors, parameter, ALLOWED_ON_TIMER_PARAMETERS); |
| |
| extraParameters.add(parameter); |
| } |
| |
| return DoFnSignature.OnTimerMethod.create( |
| m, timerId, requiresStableInput, windowT, extraParameters); |
| } |
| |
| @VisibleForTesting |
| static DoFnSignature.OnWindowExpirationMethod analyzeOnWindowExpirationMethod( |
| ErrorReporter errors, |
| TypeDescriptor<? extends DoFn<?, ?>> fnClass, |
| Method m, |
| TypeDescriptor<?> inputT, |
| TypeDescriptor<?> outputT, |
| FnAnalysisContext fnContext) { |
| errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); |
| |
| Type[] params = m.getGenericParameterTypes(); |
| |
| MethodAnalysisContext methodContext = MethodAnalysisContext.create(); |
| |
| boolean requiresStableInput = m.isAnnotationPresent(DoFn.RequiresStableInput.class); |
| |
| @Nullable TypeDescriptor<? extends BoundedWindow> windowT = getWindowType(fnClass, m); |
| |
| List<DoFnSignature.Parameter> extraParameters = new ArrayList<>(); |
| ErrorReporter onWindowExpirationErrors = errors.forMethod(DoFn.OnWindowExpiration.class, m); |
| for (int i = 0; i < params.length; ++i) { |
| Parameter parameter = |
| analyzeExtraParameter( |
| onWindowExpirationErrors, |
| fnContext, |
| methodContext, |
| fnClass, |
| ParameterDescription.of( |
| m, |
| i, |
| fnClass.resolveType(params[i]), |
| Arrays.asList(m.getParameterAnnotations()[i])), |
| inputT, |
| outputT); |
| |
| checkParameterOneOf(errors, parameter, ALLOWED_ON_WINDOW_EXPIRATION_PARAMETERS); |
| |
| extraParameters.add(parameter); |
| } |
| |
| return DoFnSignature.OnWindowExpirationMethod.create( |
| m, requiresStableInput, windowT, extraParameters); |
| } |
| |
| @VisibleForTesting |
| static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod( |
| ErrorReporter errors, |
| TypeDescriptor<? extends DoFn<?, ?>> fnClass, |
| Method m, |
| TypeDescriptor<?> inputT, |
| TypeDescriptor<?> outputT, |
| FnAnalysisContext fnContext) { |
| errors.checkArgument( |
| void.class.equals(m.getReturnType()) |
| || DoFn.ProcessContinuation.class.equals(m.getReturnType()), |
| "Must return void or %s", |
| DoFn.ProcessContinuation.class.getSimpleName()); |
| |
| MethodAnalysisContext methodContext = MethodAnalysisContext.create(); |
| |
| boolean requiresStableInput = m.isAnnotationPresent(DoFn.RequiresStableInput.class); |
| |
| Type[] params = m.getGenericParameterTypes(); |
| |
| TypeDescriptor<?> trackerT = getTrackerType(fnClass, m); |
| TypeDescriptor<? extends BoundedWindow> windowT = getWindowType(fnClass, m); |
| for (int i = 0; i < params.length; ++i) { |
| Parameter extraParam = |
| analyzeExtraParameter( |
| errors.forMethod(DoFn.ProcessElement.class, m), |
| fnContext, |
| methodContext, |
| fnClass, |
| ParameterDescription.of( |
| m, |
| i, |
| fnClass.resolveType(params[i]), |
| Arrays.asList(m.getParameterAnnotations()[i])), |
| inputT, |
| outputT); |
| |
| methodContext.addParameter(extraParam); |
| } |
| int schemaElementIndex = 0; |
| for (int i = 0; i < methodContext.getExtraParameters().size(); ++i) { |
| Parameter parameter = methodContext.getExtraParameters().get(i); |
| if (parameter instanceof SchemaElementParameter) { |
| SchemaElementParameter schemaParameter = (SchemaElementParameter) parameter; |
| schemaParameter = schemaParameter.toBuilder().setIndex(schemaElementIndex).build(); |
| methodContext.setParameter(i, schemaParameter); |
| ++schemaElementIndex; |
| } |
| } |
| |
| // The allowed parameters depend on whether this DoFn is splittable |
| if (methodContext.hasRestrictionTrackerParameter()) { |
| for (Parameter parameter : methodContext.getExtraParameters()) { |
| checkParameterOneOf(errors, parameter, ALLOWED_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS); |
| } |
| } else { |
| for (Parameter parameter : methodContext.getExtraParameters()) { |
| checkParameterOneOf(errors, parameter, ALLOWED_NON_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS); |
| } |
| } |
| |
| return DoFnSignature.ProcessElementMethod.create( |
| m, |
| methodContext.getExtraParameters(), |
| requiresStableInput, |
| trackerT, |
| windowT, |
| DoFn.ProcessContinuation.class.equals(m.getReturnType())); |
| } |
| |
| private static void checkParameterOneOf( |
| ErrorReporter errors, |
| Parameter parameter, |
| Collection<Class<? extends Parameter>> allowedParameterClasses) { |
| |
| for (Class<? extends Parameter> paramClass : allowedParameterClasses) { |
| if (paramClass.isAssignableFrom(parameter.getClass())) { |
| return; |
| } |
| } |
| |
| // If we get here, none matched |
| errors.throwIllegalArgument("Illegal parameter type: %s", parameter); |
| } |
| |
  /**
   * Classifies one parameter of a {@link DoFn} method and returns the matching {@link Parameter}.
   *
   * <p>Branches are tested in order. Annotation-driven parameters ({@code @FieldAccess},
   * {@code @Element}, {@code @Timestamp}) come first, then parameters recognized by their type. A
   * parameter matching none of them is rejected.
   *
   * <p>The duplicate checks (window, pipeline options, restriction tracker, timer id, state id)
   * consult {@code methodContext}. They therefore only see parameters the caller has already
   * registered there.
   *
   * @param methodErrors reporter scoped to the enclosing method
   * @param fnContext timer and state declarations collected for the whole {@link DoFn}
   * @param methodContext parameters collected so far for the enclosing method
   * @param fnClass the {@link DoFn} class being analyzed
   * @param param description of the parameter being analyzed
   * @param inputT the input element type of the {@link DoFn}
   * @param outputT the output element type of the {@link DoFn}
   * @return the classified parameter; the final {@code return null} is unreachable because
   *     {@code throwIllegalArgument} always throws
   * @throws IllegalArgumentException if the parameter is malformed or not supported
   */
  private static Parameter analyzeExtraParameter(
      ErrorReporter methodErrors,
      FnAnalysisContext fnContext,
      MethodAnalysisContext methodContext,
      TypeDescriptor<? extends DoFn<?, ?>> fnClass,
      ParameterDescription param,
      TypeDescriptor<?> inputT,
      TypeDescriptor<?> outputT) {

    // The fully parameterized context types a ProcessContext / OnTimerContext parameter must have.
    TypeDescriptor<?> expectedProcessContextT = doFnProcessContextTypeOf(inputT, outputT);
    TypeDescriptor<?> expectedOnTimerContextT = doFnOnTimerContextTypeOf(inputT, outputT);

    TypeDescriptor<?> paramT = param.getType();
    Class<?> rawType = paramT.getRawType();

    ErrorReporter paramErrors = methodErrors.forParameter(param);

    // Annotation-driven parameters are recognized before any type-based dispatch.
    String fieldAccessString = getFieldAccessId(param.getAnnotations());
    if (fieldAccessString != null) {
      // @FieldAccess("id"): a schema element selected by a declared field access id.
      return Parameter.schemaElementParameter(paramT, fieldAccessString, param.getIndex());
    } else if (hasElementAnnotation(param.getAnnotations())) {
      // @Element of exactly the input type is passed as-is; any other declared type is treated as
      // a schema element parameter with no field access id.
      return (paramT.equals(inputT))
          ? Parameter.elementParameter(paramT)
          : Parameter.schemaElementParameter(paramT, null, param.getIndex());
    } else if (hasTimestampAnnotation(param.getAnnotations())) {
      methodErrors.checkArgument(
          rawType.equals(Instant.class),
          "@Timestamp argument must have type org.joda.time.Instant.");
      return Parameter.timestampParameter();
    } else if (rawType.equals(TimeDomain.class)) {
      return Parameter.timeDomainParameter();
    } else if (rawType.equals(PaneInfo.class)) {
      return Parameter.paneInfoParameter();
    } else if (rawType.equals(DoFn.ProcessContext.class)) {
      paramErrors.checkArgument(
          paramT.equals(expectedProcessContextT),
          "ProcessContext argument must have type %s",
          formatType(expectedProcessContextT));
      return Parameter.processContext();
    } else if (rawType.equals(DoFn.OnTimerContext.class)) {
      paramErrors.checkArgument(
          paramT.equals(expectedOnTimerContextT),
          "OnTimerContext argument must have type %s",
          formatType(expectedOnTimerContextT));
      return Parameter.onTimerContext();
    } else if (BoundedWindow.class.isAssignableFrom(rawType)) {
      // Any BoundedWindow subtype is accepted, but at most one window parameter per method.
      methodErrors.checkArgument(
          !methodContext.hasWindowParameter(),
          "Multiple %s parameters",
          BoundedWindow.class.getSimpleName());
      return Parameter.boundedWindow((TypeDescriptor<? extends BoundedWindow>) paramT);
    } else if (rawType.equals(OutputReceiver.class)) {
      // It's a schema row receiver if it's an OutputReceiver<Row> _and_ the output type is not
      // already Row.
      boolean schemaRowReceiver =
          paramT.equals(outputReceiverTypeOf(TypeDescriptor.of(Row.class)))
              && !outputT.equals(TypeDescriptor.of(Row.class));
      if (!schemaRowReceiver) {
        TypeDescriptor<?> expectedReceiverT = outputReceiverTypeOf(outputT);
        paramErrors.checkArgument(
            paramT.equals(expectedReceiverT),
            "OutputReceiver should be parameterized by %s",
            outputT);
      }
      return Parameter.outputReceiverParameter(schemaRowReceiver);
    } else if (rawType.equals(MultiOutputReceiver.class)) {
      return Parameter.taggedOutputReceiverParameter();
    } else if (PipelineOptions.class.equals(rawType)) {
      methodErrors.checkArgument(
          !methodContext.hasPipelineOptionsParamter(),
          "Multiple %s parameters",
          PipelineOptions.class.getSimpleName());
      return Parameter.pipelineOptions();
    } else if (RestrictionTracker.class.isAssignableFrom(rawType)) {
      // Any RestrictionTracker subtype is accepted, but at most one per method.
      methodErrors.checkArgument(
          !methodContext.hasRestrictionTrackerParameter(),
          "Multiple %s parameters",
          RestrictionTracker.class.getSimpleName());
      return Parameter.restrictionTracker(paramT);

    } else if (rawType.equals(Timer.class)) {
      // The timer id comes from the @TimerId annotation on this parameter.
      String id = getTimerId(param.getAnnotations());

      paramErrors.checkArgument(
          id != null,
          "%s missing %s annotation",
          Timer.class.getSimpleName(),
          TimerId.class.getSimpleName());

      paramErrors.checkArgument(
          !methodContext.getTimerParameters().containsKey(id),
          "duplicate %s: \"%s\"",
          TimerId.class.getSimpleName(),
          id);

      TimerDeclaration timerDecl = fnContext.getTimerDeclarations().get(id);
      paramErrors.checkArgument(
          timerDecl != null,
          "reference to undeclared %s: \"%s\"",
          TimerId.class.getSimpleName(),
          id);

      // A timer may only be used from methods of the same class whose field declares it.
      paramErrors.checkArgument(
          timerDecl.field().getDeclaringClass().equals(param.getMethod().getDeclaringClass()),
          "%s %s declared in a different class %s."
              + " Timers may be referenced only in the lexical scope where they are declared.",
          TimerId.class.getSimpleName(),
          id,
          timerDecl.field().getDeclaringClass().getName());

      return Parameter.timerParameter(timerDecl);

    } else if (State.class.isAssignableFrom(rawType)) {
      // The state id comes from the @StateId annotation on this parameter.
      String id = getStateId(param.getAnnotations());
      paramErrors.checkArgument(
          id != null, "missing %s annotation", DoFn.StateId.class.getSimpleName());

      paramErrors.checkArgument(
          !methodContext.getStateParameters().containsKey(id),
          "duplicate %s: \"%s\"",
          DoFn.StateId.class.getSimpleName(),
          id);

      // By static typing this is already a well-formed State subclass
      TypeDescriptor<? extends State> stateType = (TypeDescriptor<? extends State>) param.getType();

      StateDeclaration stateDecl = fnContext.getStateDeclarations().get(id);
      paramErrors.checkArgument(
          stateDecl != null,
          "reference to undeclared %s: \"%s\"",
          DoFn.StateId.class.getSimpleName(),
          id);

      // The declared state type must be usable where the parameter's type is expected.
      paramErrors.checkArgument(
          stateDecl.stateType().isSubtypeOf(stateType),
          "data type of reference to %s %s must be a supertype of %s",
          StateId.class.getSimpleName(),
          id,
          formatType(stateDecl.stateType()));

      // State may only be used from methods of the same class whose field declares it.
      paramErrors.checkArgument(
          stateDecl.field().getDeclaringClass().equals(param.getMethod().getDeclaringClass()),
          "%s %s declared in a different class %s."
              + " State may be referenced only in the class where it is declared.",
          StateId.class.getSimpleName(),
          id,
          stateDecl.field().getDeclaringClass().getName());

      return Parameter.stateParameter(stateDecl);
    } else {
      // NOTE(review): this message lists only BoundedWindow and RestrictionTracker, although many
      // more parameter kinds are accepted above. It is left unchanged in case callers or tests
      // match on the exact text.
      List<String> allowedParamTypes =
          Arrays.asList(
              formatType(new TypeDescriptor<BoundedWindow>() {}),
              formatType(new TypeDescriptor<RestrictionTracker<?, ?>>() {}));
      paramErrors.throwIllegalArgument(
          "%s is not a valid context parameter. Should be one of %s",
          formatType(paramT), allowedParamTypes);
      // Unreachable
      return null;
    }
  }
| |
| @Nullable |
| private static String getTimerId(List<Annotation> annotations) { |
| DoFn.TimerId stateId = findFirstOfType(annotations, DoFn.TimerId.class); |
| return stateId != null ? stateId.value() : null; |
| } |
| |
| @Nullable |
| private static String getStateId(List<Annotation> annotations) { |
| DoFn.StateId stateId = findFirstOfType(annotations, DoFn.StateId.class); |
| return stateId != null ? stateId.value() : null; |
| } |
| |
| @Nullable |
| private static String getFieldAccessId(List<Annotation> annotations) { |
| DoFn.FieldAccess access = findFirstOfType(annotations, DoFn.FieldAccess.class); |
| return access != null ? access.value() : null; |
| } |
| |
| @Nullable |
| static <T> T findFirstOfType(List<Annotation> annotations, Class<T> clazz) { |
| Optional<Annotation> annotation = |
| annotations.stream().filter(a -> a.annotationType().equals(clazz)).findFirst(); |
| return annotation.isPresent() ? (T) annotation.get() : null; |
| } |
| |
| private static boolean hasElementAnnotation(List<Annotation> annotations) { |
| return annotations.stream().anyMatch(a -> a.annotationType().equals(DoFn.Element.class)); |
| } |
| |
| private static boolean hasTimestampAnnotation(List<Annotation> annotations) { |
| return annotations.stream().anyMatch(a -> a.annotationType().equals(DoFn.Timestamp.class)); |
| } |
| |
| @Nullable |
| private static TypeDescriptor<?> getTrackerType(TypeDescriptor<?> fnClass, Method method) { |
| Type[] params = method.getGenericParameterTypes(); |
| for (Type param : params) { |
| TypeDescriptor<?> paramT = fnClass.resolveType(param); |
| if (RestrictionTracker.class.isAssignableFrom(paramT.getRawType())) { |
| return paramT; |
| } |
| } |
| return null; |
| } |
| |
| @Nullable |
| private static TypeDescriptor<? extends BoundedWindow> getWindowType( |
| TypeDescriptor<?> fnClass, Method method) { |
| Type[] params = method.getGenericParameterTypes(); |
| for (Type param : params) { |
| TypeDescriptor<?> paramT = fnClass.resolveType(param); |
| if (BoundedWindow.class.isAssignableFrom(paramT.getRawType())) { |
| return (TypeDescriptor<? extends BoundedWindow>) paramT; |
| } |
| } |
| return null; |
| } |
| |
| @VisibleForTesting |
| static DoFnSignature.BundleMethod analyzeStartBundleMethod( |
| ErrorReporter errors, |
| TypeDescriptor<? extends DoFn<?, ?>> fnT, |
| Method m, |
| TypeDescriptor<?> inputT, |
| TypeDescriptor<?> outputT) { |
| errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); |
| TypeDescriptor<?> expectedContextT = doFnStartBundleContextTypeOf(inputT, outputT); |
| Type[] params = m.getGenericParameterTypes(); |
| errors.checkArgument( |
| params.length == 0 |
| || (params.length == 1 && fnT.resolveType(params[0]).equals(expectedContextT)), |
| "Must take a single argument of type %s", |
| formatType(expectedContextT)); |
| return DoFnSignature.BundleMethod.create(m); |
| } |
| |
| @VisibleForTesting |
| static DoFnSignature.BundleMethod analyzeFinishBundleMethod( |
| ErrorReporter errors, |
| TypeDescriptor<? extends DoFn<?, ?>> fnT, |
| Method m, |
| TypeDescriptor<?> inputT, |
| TypeDescriptor<?> outputT) { |
| errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); |
| TypeDescriptor<?> expectedContextT = doFnFinishBundleContextTypeOf(inputT, outputT); |
| Type[] params = m.getGenericParameterTypes(); |
| errors.checkArgument( |
| params.length == 0 |
| || (params.length == 1 && fnT.resolveType(params[0]).equals(expectedContextT)), |
| "Must take a single argument of type %s", |
| formatType(expectedContextT)); |
| return DoFnSignature.BundleMethod.create(m); |
| } |
| |
| private static DoFnSignature.LifecycleMethod analyzeLifecycleMethod( |
| ErrorReporter errors, Method m) { |
| errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); |
| errors.checkArgument(m.getGenericParameterTypes().length == 0, "Must take zero arguments"); |
| return DoFnSignature.LifecycleMethod.create(m); |
| } |
| |
| @VisibleForTesting |
| static DoFnSignature.GetInitialRestrictionMethod analyzeGetInitialRestrictionMethod( |
| ErrorReporter errors, |
| TypeDescriptor<? extends DoFn> fnT, |
| Method m, |
| TypeDescriptor<?> inputT) { |
| // Method is of the form: |
| // @GetInitialRestriction |
| // RestrictionT getInitialRestriction(InputT element); |
| Type[] params = m.getGenericParameterTypes(); |
| errors.checkArgument( |
| params.length == 1 && fnT.resolveType(params[0]).equals(inputT), |
| "Must take a single argument of type %s", |
| formatType(inputT)); |
| return DoFnSignature.GetInitialRestrictionMethod.create( |
| m, fnT.resolveType(m.getGenericReturnType())); |
| } |
| |
| /** |
| * Generates a {@link TypeDescriptor} for {@code DoFn.OutputReceiver<OutputT>} given {@code |
| * OutputT}. |
| */ |
| private static <OutputT> TypeDescriptor<DoFn.OutputReceiver<OutputT>> outputReceiverTypeOf( |
| TypeDescriptor<OutputT> outputT) { |
| return new TypeDescriptor<DoFn.OutputReceiver<OutputT>>() {}.where( |
| new TypeParameter<OutputT>() {}, outputT); |
| } |
| |
| @VisibleForTesting |
| static DoFnSignature.SplitRestrictionMethod analyzeSplitRestrictionMethod( |
| ErrorReporter errors, |
| TypeDescriptor<? extends DoFn> fnT, |
| Method m, |
| TypeDescriptor<?> inputT) { |
| // Method is of the form: |
| // @SplitRestriction |
| // void splitRestriction(InputT element, RestrictionT restriction); |
| errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void"); |
| |
| Type[] params = m.getGenericParameterTypes(); |
| errors.checkArgument(params.length == 3, "Must have exactly 3 arguments"); |
| errors.checkArgument( |
| fnT.resolveType(params[0]).equals(inputT), |
| "First argument must be the element type %s", |
| formatType(inputT)); |
| |
| TypeDescriptor<?> restrictionT = fnT.resolveType(params[1]); |
| TypeDescriptor<?> receiverT = fnT.resolveType(params[2]); |
| TypeDescriptor<?> expectedReceiverT = outputReceiverTypeOf(restrictionT); |
| errors.checkArgument( |
| receiverT.equals(expectedReceiverT), |
| "Third argument must be %s, but is %s", |
| formatType(expectedReceiverT), |
| formatType(receiverT)); |
| |
| return DoFnSignature.SplitRestrictionMethod.create(m, restrictionT); |
| } |
| |
| private static ImmutableMap<String, TimerDeclaration> analyzeTimerDeclarations( |
| ErrorReporter errors, Class<?> fnClazz) { |
| Map<String, DoFnSignature.TimerDeclaration> declarations = new HashMap<>(); |
| for (Field field : declaredFieldsWithAnnotation(DoFn.TimerId.class, fnClazz, DoFn.class)) { |
| // TimerSpec fields may generally be private, but will be accessed via the signature |
| field.setAccessible(true); |
| String id = field.getAnnotation(DoFn.TimerId.class).value(); |
| validateTimerField(errors, declarations, id, field); |
| declarations.put(id, DoFnSignature.TimerDeclaration.create(id, field)); |
| } |
| |
| return ImmutableMap.copyOf(declarations); |
| } |
| |
| /** |
| * Returns successfully if the field is valid, otherwise throws an exception via its {@link |
| * ErrorReporter} parameter describing validation failures for the timer declaration. |
| */ |
| private static void validateTimerField( |
| ErrorReporter errors, Map<String, TimerDeclaration> declarations, String id, Field field) { |
| if (declarations.containsKey(id)) { |
| errors.throwIllegalArgument( |
| "Duplicate %s \"%s\", used on both of [%s] and [%s]", |
| DoFn.TimerId.class.getSimpleName(), |
| id, |
| field.toString(), |
| declarations.get(id).field().toString()); |
| } |
| |
| Class<?> timerSpecRawType = field.getType(); |
| if (!(timerSpecRawType.equals(TimerSpec.class))) { |
| errors.throwIllegalArgument( |
| "%s annotation on non-%s field [%s]", |
| DoFn.TimerId.class.getSimpleName(), TimerSpec.class.getSimpleName(), field.toString()); |
| } |
| |
| if (!Modifier.isFinal(field.getModifiers())) { |
| errors.throwIllegalArgument( |
| "Non-final field %s annotated with %s. Timer declarations must be final.", |
| field.toString(), DoFn.TimerId.class.getSimpleName()); |
| } |
| } |
| |
| /** Generates a {@link TypeDescriptor} for {@code Coder<T>} given {@code T}. */ |
| private static <T> TypeDescriptor<Coder<T>> coderTypeOf(TypeDescriptor<T> elementT) { |
| return new TypeDescriptor<Coder<T>>() {}.where(new TypeParameter<T>() {}, elementT); |
| } |
| |
| @VisibleForTesting |
| static DoFnSignature.GetRestrictionCoderMethod analyzeGetRestrictionCoderMethod( |
| ErrorReporter errors, TypeDescriptor<? extends DoFn> fnT, Method m) { |
| errors.checkArgument(m.getParameterTypes().length == 0, "Must have zero arguments"); |
| TypeDescriptor<?> resT = fnT.resolveType(m.getGenericReturnType()); |
| errors.checkArgument( |
| resT.isSubtypeOf(TypeDescriptor.of(Coder.class)), |
| "Must return a Coder, but returns %s", |
| formatType(resT)); |
| return DoFnSignature.GetRestrictionCoderMethod.create(m, resT); |
| } |
| |
| /** |
| * Generates a {@link TypeDescriptor} for {@code RestrictionTracker<RestrictionT>} given {@code |
| * RestrictionT}. |
| */ |
| private static <RestrictionT> |
| TypeDescriptor<RestrictionTracker<RestrictionT, ?>> restrictionTrackerTypeOf( |
| TypeDescriptor<RestrictionT> restrictionT) { |
| return new TypeDescriptor<RestrictionTracker<RestrictionT, ?>>() {}.where( |
| new TypeParameter<RestrictionT>() {}, restrictionT); |
| } |
| |
| @VisibleForTesting |
| static DoFnSignature.NewTrackerMethod analyzeNewTrackerMethod( |
| ErrorReporter errors, TypeDescriptor<? extends DoFn> fnT, Method m) { |
| // Method is of the form: |
| // @NewTracker |
| // TrackerT newTracker(RestrictionT restriction); |
| Type[] params = m.getGenericParameterTypes(); |
| errors.checkArgument(params.length == 1, "Must have a single argument"); |
| |
| TypeDescriptor<?> restrictionT = fnT.resolveType(params[0]); |
| TypeDescriptor<?> trackerT = fnT.resolveType(m.getGenericReturnType()); |
| TypeDescriptor<?> expectedTrackerT = restrictionTrackerTypeOf(restrictionT); |
| errors.checkArgument( |
| trackerT.isSubtypeOf(expectedTrackerT), |
| "Returns %s, but must return a subtype of %s", |
| formatType(trackerT), |
| formatType(expectedTrackerT)); |
| return DoFnSignature.NewTrackerMethod.create(m, restrictionT, trackerT); |
| } |
| |
| private static Collection<Method> declaredMethodsWithAnnotation( |
| Class<? extends Annotation> anno, Class<?> startClass, Class<?> stopClass) { |
| return declaredMembersWithAnnotation(anno, startClass, stopClass, GET_METHODS); |
| } |
| |
| private static Collection<Field> declaredFieldsWithAnnotation( |
| Class<? extends Annotation> anno, Class<?> startClass, Class<?> stopClass) { |
| return declaredMembersWithAnnotation(anno, startClass, stopClass, GET_FIELDS); |
| } |
| |
| private interface MemberGetter<MemberT> { |
| MemberT[] getMembers(Class<?> clazz); |
| } |
| |
| private static final MemberGetter<Method> GET_METHODS = Class::getDeclaredMethods; |
| |
| private static final MemberGetter<Field> GET_FIELDS = Class::getDeclaredFields; |
| |
| private static <MemberT extends AnnotatedElement> |
| Collection<MemberT> declaredMembersWithAnnotation( |
| Class<? extends Annotation> anno, |
| Class<?> startClass, |
| Class<?> stopClass, |
| MemberGetter<MemberT> getter) { |
| Collection<MemberT> matches = new ArrayList<>(); |
| |
| Class<?> clazz = startClass; |
| LinkedHashSet<Class<?>> interfaces = new LinkedHashSet<>(); |
| |
| // First, find all declared methods on the startClass and parents (up to stopClass) |
| while (clazz != null && !clazz.equals(stopClass)) { |
| for (MemberT member : getter.getMembers(clazz)) { |
| if (member.isAnnotationPresent(anno)) { |
| matches.add(member); |
| } |
| } |
| |
| // Add all interfaces, including transitive |
| for (TypeDescriptor<?> iface : TypeDescriptor.of(clazz).getInterfaces()) { |
| interfaces.add(iface.getRawType()); |
| } |
| |
| clazz = clazz.getSuperclass(); |
| } |
| |
| // Now, iterate over all the discovered interfaces |
| for (Class<?> iface : interfaces) { |
| for (MemberT member : getter.getMembers(iface)) { |
| if (member.isAnnotationPresent(anno)) { |
| matches.add(member); |
| } |
| } |
| } |
| return matches; |
| } |
| |
| private static Map<String, DoFnSignature.FieldAccessDeclaration> analyzeFieldAccessDeclaration( |
| ErrorReporter errors, Class<?> fnClazz) { |
| Map<String, FieldAccessDeclaration> fieldAccessDeclarations = new HashMap<>(); |
| for (Field field : declaredFieldsWithAnnotation(DoFn.FieldAccess.class, fnClazz, DoFn.class)) { |
| field.setAccessible(true); |
| DoFn.FieldAccess fieldAccessAnnotation = field.getAnnotation(DoFn.FieldAccess.class); |
| if (!Modifier.isFinal(field.getModifiers())) { |
| errors.throwIllegalArgument( |
| "Non-final field %s annotated with %s. Field access declarations must be final.", |
| field.toString(), DoFn.FieldAccess.class.getSimpleName()); |
| continue; |
| } |
| Class<?> fieldAccessRawType = field.getType(); |
| if (!fieldAccessRawType.equals(FieldAccessDescriptor.class)) { |
| errors.throwIllegalArgument( |
| "Field %s annotated with %s, but the value was not of type %s", |
| field.toString(), |
| DoFn.FieldAccess.class.getSimpleName(), |
| FieldAccessDescriptor.class.getSimpleName()); |
| } |
| fieldAccessDeclarations.put( |
| fieldAccessAnnotation.value(), |
| FieldAccessDeclaration.create(fieldAccessAnnotation.value(), field)); |
| } |
| return fieldAccessDeclarations; |
| } |
| |
| private static Map<String, DoFnSignature.StateDeclaration> analyzeStateDeclarations( |
| ErrorReporter errors, Class<?> fnClazz) { |
| |
| Map<String, DoFnSignature.StateDeclaration> declarations = new HashMap<>(); |
| |
| for (Field field : declaredFieldsWithAnnotation(DoFn.StateId.class, fnClazz, DoFn.class)) { |
| // StateSpec fields may generally be private, but will be accessed via the signature |
| field.setAccessible(true); |
| String id = field.getAnnotation(DoFn.StateId.class).value(); |
| |
| if (declarations.containsKey(id)) { |
| errors.throwIllegalArgument( |
| "Duplicate %s \"%s\", used on both of [%s] and [%s]", |
| DoFn.StateId.class.getSimpleName(), |
| id, |
| field.toString(), |
| declarations.get(id).field().toString()); |
| continue; |
| } |
| |
| Class<?> stateSpecRawType = field.getType(); |
| if (!(TypeDescriptor.of(stateSpecRawType).isSubtypeOf(TypeDescriptor.of(StateSpec.class)))) { |
| errors.throwIllegalArgument( |
| "%s annotation on non-%s field [%s] that has class %s", |
| DoFn.StateId.class.getSimpleName(), |
| StateSpec.class.getSimpleName(), |
| field.toString(), |
| stateSpecRawType.getName()); |
| continue; |
| } |
| |
| if (!Modifier.isFinal(field.getModifiers())) { |
| errors.throwIllegalArgument( |
| "Non-final field %s annotated with %s. State declarations must be final.", |
| field.toString(), DoFn.StateId.class.getSimpleName()); |
| continue; |
| } |
| |
| Type stateSpecType = field.getGenericType(); |
| |
| // A type descriptor for whatever type the @StateId-annotated class has, which |
| // must be some subtype of StateSpec |
| TypeDescriptor<? extends StateSpec<?>> stateSpecSubclassTypeDescriptor = |
| (TypeDescriptor) TypeDescriptor.of(stateSpecType); |
| |
| // A type descriptor for StateSpec, with the generic type parameters filled |
| // in according to the specialization of the subclass (or just straight params) |
| TypeDescriptor<StateSpec<?>> stateSpecTypeDescriptor = |
| (TypeDescriptor) stateSpecSubclassTypeDescriptor.getSupertype(StateSpec.class); |
| |
| // The type of the state, which may still have free type variables from the |
| // context |
| Type unresolvedStateType = |
| ((ParameterizedType) stateSpecTypeDescriptor.getType()).getActualTypeArguments()[0]; |
| |
| // By static typing this is already a well-formed State subclass |
| TypeDescriptor<? extends State> stateType = |
| (TypeDescriptor<? extends State>) |
| TypeDescriptor.of(fnClazz).resolveType(unresolvedStateType); |
| |
| declarations.put(id, DoFnSignature.StateDeclaration.create(id, field, stateType)); |
| } |
| |
| return ImmutableMap.copyOf(declarations); |
| } |
| |
| @Nullable |
| private static Method findAnnotatedMethod( |
| ErrorReporter errors, Class<? extends Annotation> anno, Class<?> fnClazz, boolean required) { |
| Collection<Method> matches = declaredMethodsWithAnnotation(anno, fnClazz, DoFn.class); |
| |
| if (matches.isEmpty()) { |
| errors.checkArgument(!required, "No method annotated with @%s found", anno.getSimpleName()); |
| return null; |
| } |
| |
| // If we have at least one match, then either it should be the only match |
| // or it should be an extension of the other matches (which came from parent |
| // classes). |
| Method first = matches.iterator().next(); |
| for (Method other : matches) { |
| errors.checkArgument( |
| first.getName().equals(other.getName()) |
| && Arrays.equals(first.getParameterTypes(), other.getParameterTypes()), |
| "Found multiple methods annotated with @%s. [%s] and [%s]", |
| anno.getSimpleName(), |
| format(first), |
| format(other)); |
| } |
| |
| ErrorReporter methodErrors = errors.forMethod(anno, first); |
| // We need to be able to call it. We require it is public. |
| methodErrors.checkArgument((first.getModifiers() & Modifier.PUBLIC) != 0, "Must be public"); |
| // And make sure its not static. |
| methodErrors.checkArgument((first.getModifiers() & Modifier.STATIC) == 0, "Must not be static"); |
| |
| return first; |
| } |
| |
| private static String format(Method method) { |
| return ReflectHelpers.METHOD_FORMATTER.apply(method); |
| } |
| |
| private static String formatType(TypeDescriptor<?> t) { |
| return ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(t.getType()); |
| } |
| |
| static class ErrorReporter { |
| private final String label; |
| |
| ErrorReporter(@Nullable ErrorReporter root, String label) { |
| this.label = (root == null) ? label : String.format("%s, %s", root.label, label); |
| } |
| |
| ErrorReporter forMethod(Class<? extends Annotation> annotation, Method method) { |
| return new ErrorReporter( |
| this, |
| String.format( |
| "@%s %s", |
| annotation.getSimpleName(), (method == null) ? "(absent)" : format(method))); |
| } |
| |
| ErrorReporter forParameter(ParameterDescription param) { |
| return new ErrorReporter( |
| this, |
| String.format( |
| "parameter of type %s at index %s", formatType(param.getType()), param.getIndex())); |
| } |
| |
| void throwIllegalArgument(String message, Object... args) { |
| throw new IllegalArgumentException(label + ": " + String.format(message, args)); |
| } |
| |
| public void checkArgument(boolean condition, String message, Object... args) { |
| if (!condition) { |
| throwIllegalArgument(message, args); |
| } |
| } |
| |
| public void checkNotNull(Object value, String message, Object... args) { |
| if (value == null) { |
| throwIllegalArgument(message, args); |
| } |
| } |
| } |
| |
| public static StateSpec<?> getStateSpecOrThrow( |
| StateDeclaration stateDeclaration, DoFn<?, ?> target) { |
| try { |
| Object fieldValue = stateDeclaration.field().get(target); |
| checkState( |
| fieldValue instanceof StateSpec, |
| "Malformed %s class %s: state declaration field %s does not have type %s.", |
| DoFn.class.getSimpleName(), |
| target.getClass().getName(), |
| stateDeclaration.field().getName(), |
| StateSpec.class); |
| |
| return (StateSpec<?>) stateDeclaration.field().get(target); |
| } catch (IllegalAccessException exc) { |
| throw new RuntimeException( |
| String.format( |
| "Malformed %s class %s: state declaration field %s is not accessible.", |
| DoFn.class.getSimpleName(), |
| target.getClass().getName(), |
| stateDeclaration.field().getName())); |
| } |
| } |
| |
| public static TimerSpec getTimerSpecOrThrow( |
| TimerDeclaration timerDeclaration, DoFn<?, ?> target) { |
| try { |
| Object fieldValue = timerDeclaration.field().get(target); |
| checkState( |
| fieldValue instanceof TimerSpec, |
| "Malformed %s class %s: timer declaration field %s does not have type %s.", |
| DoFn.class.getSimpleName(), |
| target.getClass().getName(), |
| timerDeclaration.field().getName(), |
| TimerSpec.class); |
| |
| return (TimerSpec) timerDeclaration.field().get(target); |
| } catch (IllegalAccessException exc) { |
| throw new RuntimeException( |
| String.format( |
| "Malformed %s class %s: timer declaration field %s is not accessible.", |
| DoFn.class.getSimpleName(), |
| target.getClass().getName(), |
| timerDeclaration.field().getName())); |
| } |
| } |
| } |