blob: 69b03a43d0d9fc6f3a2c17d667194227bd988059 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms.reflect;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkState;
import com.google.auto.value.AutoValue;
import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nullable;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
import org.apache.beam.sdk.transforms.DoFn.StateId;
import org.apache.beam.sdk.transforms.DoFn.TimerId;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.FieldAccessDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.StateDeclaration;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.TimerDeclaration;
import org.apache.beam.sdk.transforms.splittabledofn.HasDefaultTracker;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.common.ReflectHelpers;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.TypeParameter;
import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Predicates;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Maps;
import org.joda.time.Instant;
/** Utilities for working with {@link DoFnSignature}. See {@link #getSignature}. */
public class DoFnSignatures {
/** Private constructor: no instances of this class are created within this file's visible code. */
private DoFnSignatures() {}
// Cache of already-parsed signatures, keyed by DoFn class. Filled lazily by getSignature,
// which is synchronized. NOTE(review): any other access path would need its own locking,
// since LinkedHashMap is not thread-safe.
private static final Map<Class<?>, DoFnSignature> signatureCache = new LinkedHashMap<>();
// Parameter kinds accepted on a @ProcessElement method that has no RestrictionTracker
// parameter (see analyzeProcessElementMethod).
private static final ImmutableList<Class<? extends Parameter>>
ALLOWED_NON_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS =
ImmutableList.of(
Parameter.ProcessContextParameter.class,
Parameter.ElementParameter.class,
Parameter.SchemaElementParameter.class,
Parameter.TimestampParameter.class,
Parameter.OutputReceiverParameter.class,
Parameter.TaggedOutputReceiverParameter.class,
Parameter.WindowParameter.class,
Parameter.PaneInfoParameter.class,
Parameter.PipelineOptionsParameter.class,
Parameter.TimerParameter.class,
Parameter.StateParameter.class);
// Parameter kinds accepted on a @ProcessElement method that has a RestrictionTracker
// parameter (see analyzeProcessElementMethod). Note that state, timer, window and pane
// parameters are absent from this list.
private static final ImmutableList<Class<? extends Parameter>>
ALLOWED_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS =
ImmutableList.of(
Parameter.PipelineOptionsParameter.class,
Parameter.ElementParameter.class,
Parameter.TimestampParameter.class,
Parameter.OutputReceiverParameter.class,
Parameter.TaggedOutputReceiverParameter.class,
Parameter.ProcessContextParameter.class,
Parameter.RestrictionTrackerParameter.class);
// Parameter kinds accepted on an @OnTimer method (see analyzeOnTimerMethod).
private static final ImmutableList<Class<? extends Parameter>> ALLOWED_ON_TIMER_PARAMETERS =
ImmutableList.of(
Parameter.OnTimerContextParameter.class,
Parameter.TimestampParameter.class,
Parameter.TimeDomainParameter.class,
Parameter.WindowParameter.class,
Parameter.PipelineOptionsParameter.class,
Parameter.OutputReceiverParameter.class,
Parameter.TaggedOutputReceiverParameter.class,
Parameter.TimerParameter.class,
Parameter.StateParameter.class);
// Parameter kinds accepted on an @OnWindowExpiration method
// (see analyzeOnWindowExpirationMethod).
private static final Collection<Class<? extends Parameter>>
ALLOWED_ON_WINDOW_EXPIRATION_PARAMETERS =
ImmutableList.of(
Parameter.WindowParameter.class,
Parameter.PipelineOptionsParameter.class,
Parameter.OutputReceiverParameter.class,
Parameter.TaggedOutputReceiverParameter.class,
Parameter.StateParameter.class);
/**
 * Returns the {@link DoFnSignature} for the runtime class of the given {@link DoFn} instance.
 *
 * <p>This is a convenience wrapper that passes {@code fn.getClass()} to {@link #getSignature}.
 *
 * @param fn the {@link DoFn} instance whose class is analyzed
 * @return the {@link DoFnSignature} for the given {@link DoFn} instance
 */
public static <FnT extends DoFn<?, ?>> DoFnSignature signatureForDoFn(FnT fn) {
return getSignature(fn.getClass());
}
/**
 * Looks up the {@link DoFnSignature} for the given {@link DoFn} subclass, parsing the class the
 * first time it is requested and serving later requests from the cache.
 *
 * <p>The method is {@code synchronized}, so the lookup, the parse and the cache store happen
 * under one lock.
 *
 * @param fn the {@link DoFn} subclass to describe
 * @return the {@link DoFnSignature} for the given {@link DoFn} subclass
 */
public static synchronized <FnT extends DoFn<?, ?>> DoFnSignature getSignature(Class<FnT> fn) {
  DoFnSignature signature = signatureCache.get(fn);
  if (signature == null) {
    signature = parseSignature(fn);
    // Only remember real results, mirroring computeIfAbsent's handling of a null value.
    if (signature != null) {
      signatureCache.put(fn, signature);
    }
  }
  return signature;
}
/**
 * The context for a {@link DoFn} class, for use in analysis.
 *
 * <p>It contains much of the information that eventually becomes part of the {@link
 * DoFnSignature}, but in an intermediate state. Declarations are accumulated through the {@code
 * add*} methods and exposed through read-only views; each kind of declaration is keyed by the
 * value of its {@code id()}, so adding a declaration with an already-known id replaces the
 * earlier one.
 */
@VisibleForTesting
static class FnAnalysisContext {
  private final Map<String, StateDeclaration> stateDeclarations = new HashMap<>();
  private final Map<String, TimerDeclaration> timerDeclarations = new HashMap<>();
  private final Map<String, FieldAccessDeclaration> fieldAccessDeclarations = new HashMap<>();

  private FnAnalysisContext() {}

  /** Create an empty context, with no declarations. */
  public static FnAnalysisContext create() {
    return new FnAnalysisContext();
  }

  /** State parameters declared in this context, keyed by {@link StateId}. Unmodifiable. */
  public Map<String, StateDeclaration> getStateDeclarations() {
    return Collections.unmodifiableMap(stateDeclarations);
  }

  /** Timer parameters declared in this context, keyed by {@link TimerId}. Unmodifiable. */
  public Map<String, TimerDeclaration> getTimerDeclarations() {
    return Collections.unmodifiableMap(timerDeclarations);
  }

  /**
   * Field access declarations declared in this context, keyed by declaration id. Unmodifiable.
   *
   * <p>Never {@code null}: an empty map is returned when nothing has been declared. The view is
   * read-only so that, like the state and timer declarations, the context can only be changed
   * through its {@code add*} methods.
   */
  public Map<String, FieldAccessDeclaration> getFieldAccessDeclarations() {
    return Collections.unmodifiableMap(fieldAccessDeclarations);
  }

  /** Registers a single state declaration under its id. */
  public void addStateDeclaration(StateDeclaration decl) {
    stateDeclarations.put(decl.id(), decl);
  }

  /** Registers every state declaration in {@code decls}, in iteration order. */
  public void addStateDeclarations(Iterable<StateDeclaration> decls) {
    for (StateDeclaration decl : decls) {
      addStateDeclaration(decl);
    }
  }

  /** Registers a single timer declaration under its id. */
  public void addTimerDeclaration(TimerDeclaration decl) {
    timerDeclarations.put(decl.id(), decl);
  }

  /** Registers every timer declaration in {@code decls}, in iteration order. */
  public void addTimerDeclarations(Iterable<TimerDeclaration> decls) {
    for (TimerDeclaration decl : decls) {
      addTimerDeclaration(decl);
    }
  }

  /** Registers a single field access declaration under its id. */
  public void addFieldAccessDeclaration(FieldAccessDeclaration decl) {
    fieldAccessDeclarations.put(decl.id(), decl);
  }

  /** Registers every field access declaration in {@code decls}, in iteration order. */
  public void addFieldAccessDeclarations(Iterable<FieldAccessDeclaration> decls) {
    for (FieldAccessDeclaration decl : decls) {
      addFieldAccessDeclaration(decl);
    }
  }
}
/**
 * The context of analysis within a particular method.
 *
 * <p>It collects the parameters analyzed so far for one method; this is much of the information
 * that eventually becomes part of the {@link DoFnSignature.MethodWithExtraParameters}, but in an
 * intermediate state.
 */
private static class MethodAnalysisContext {
  private final Map<String, StateParameter> stateParameters = new HashMap<>();
  private final Map<String, TimerParameter> timerParameters = new HashMap<>();
  private final List<Parameter> extraParameters = new ArrayList<>();
  // Not assigned anywhere inside this class.
  @Nullable private TypeDescriptor<? extends BoundedWindow> windowT;

  private MethodAnalysisContext() {}

  /** Create an empty context, with no declarations. */
  public static MethodAnalysisContext create() {
    return new MethodAnalysisContext();
  }

  /** Tells whether any parameter recorded so far is an instance of {@code parameterClass}. */
  private boolean hasExtraParameterOfType(Class<? extends Parameter> parameterClass) {
    return extraParameters.stream().anyMatch(Predicates.instanceOf(parameterClass)::apply);
  }

  /** Indicates whether a {@link RestrictionTrackerParameter} is known in this context. */
  public boolean hasRestrictionTrackerParameter() {
    return hasExtraParameterOfType(RestrictionTrackerParameter.class);
  }

  /** Indicates whether a {@link WindowParameter} is known in this context. */
  public boolean hasWindowParameter() {
    return hasExtraParameterOfType(WindowParameter.class);
  }

  /** Indicates whether a {@link Parameter.PipelineOptionsParameter} is known in this context. */
  public boolean hasPipelineOptionsParamter() {
    return hasExtraParameterOfType(Parameter.PipelineOptionsParameter.class);
  }

  /** The window type, if any, used by this method. */
  @Nullable
  public TypeDescriptor<? extends BoundedWindow> getWindowType() {
    return windowT;
  }

  /** State parameters recorded in this context, keyed by {@link StateId}. Unmodifiable. */
  public Map<String, StateParameter> getStateParameters() {
    return Collections.unmodifiableMap(stateParameters);
  }

  /** Timer parameters recorded in this context, keyed by {@link TimerId}. Unmodifiable. */
  public Map<String, TimerParameter> getTimerParameters() {
    return Collections.unmodifiableMap(timerParameters);
  }

  /** Extra parameters in their entirety. Unmodifiable. */
  public List<Parameter> getExtraParameters() {
    return Collections.unmodifiableList(extraParameters);
  }

  /**
   * Replaces the parameter at position {@code index} in the extra-parameter list.
   *
   * <p>Only the list is touched; the state and timer lookup maps are left as they were.
   */
  public void setParameter(int index, Parameter parameter) {
    extraParameters.set(index, parameter);
  }

  /**
   * Appends {@code param} to the extra parameters of this context. State and timer parameters
   * are additionally indexed by the id of the declaration they refer to.
   */
  public void addParameter(Parameter param) {
    extraParameters.add(param);
    if (param instanceof StateParameter) {
      StateParameter asState = (StateParameter) param;
      stateParameters.put(asState.referent().id(), asState);
    }
    if (param instanceof TimerParameter) {
      TimerParameter asTimer = (TimerParameter) param;
      timerParameters.put(asTimer.referent().id(), asTimer);
    }
  }
}
/**
 * Describes one parameter of a reflected method: the method it belongs to, its position in the
 * parameter list, its type and the annotations attached to it.
 */
@AutoValue
abstract static class ParameterDescription {
  /** Builds a description from an array of annotations by viewing the array as a list. */
  public static ParameterDescription of(
      Method method, int index, TypeDescriptor<?> type, Annotation[] annotations) {
    return of(method, index, type, Arrays.asList(annotations));
  }

  /** Builds a description from a list of annotations. */
  public static ParameterDescription of(
      Method method, int index, TypeDescriptor<?> type, List<Annotation> annotations) {
    return new AutoValue_DoFnSignatures_ParameterDescription(method, index, type, annotations);
  }

  /** The method that declares the parameter. */
  public abstract Method getMethod();

  /** The position of the parameter within the method's parameter list. */
  public abstract int getIndex();

  /** The type of the parameter. */
  public abstract TypeDescriptor<?> getType();

  /** The annotations attached to the parameter. */
  public abstract List<Annotation> getAnnotations();
}
/**
 * Analyzes a given {@link DoFn} class and extracts its {@link DoFnSignature}.
 *
 * <p>The analysis runs in this order:
 *
 * <ol>
 *   <li>Resolve the input and output type arguments from the {@link DoFn} supertype.
 *   <li>Collect state, timer and field-access declarations into a {@link FnAnalysisContext}, so
 *       that they are known before any method parameter list is analyzed.
 *   <li>Locate the annotated lifecycle, processing and restriction methods.
 *   <li>Analyze each {@link DoFn.OnTimer} callback and check that callbacks and timer
 *       declarations match one-to-one.
 *   <li>Analyze {@link DoFn.ProcessElement} and every optional method that is present.
 *   <li>Build the signature and run the splittable / non-splittable consistency checks.
 * </ol>
 *
 * <p>Problems found along the way are reported through {@link ErrorReporter}.
 *
 * @param fnClass the {@link DoFn} subclass to analyze
 * @return the signature assembled from {@code fnClass}
 */
private static DoFnSignature parseSignature(Class<? extends DoFn<?, ?>> fnClass) {
DoFnSignature.Builder signatureBuilder = DoFnSignature.builder();
ErrorReporter errors = new ErrorReporter(null, fnClass.getName());
errors.checkArgument(DoFn.class.isAssignableFrom(fnClass), "Must be subtype of DoFn");
signatureBuilder.setFnClass(fnClass);
TypeDescriptor<? extends DoFn<?, ?>> fnT = TypeDescriptor.of(fnClass);
// Extract the input and output type, and whether the fn is bounded.
TypeDescriptor<?> inputT = null;
TypeDescriptor<?> outputT = null;
for (TypeDescriptor<?> supertype : fnT.getTypes()) {
// Only the DoFn supertype itself carries the <InputT, OutputT> arguments we need.
if (!supertype.getRawType().equals(DoFn.class)) {
continue;
}
Type[] args = ((ParameterizedType) supertype.getType()).getActualTypeArguments();
inputT = TypeDescriptor.of(args[0]);
outputT = TypeDescriptor.of(args[1]);
}
// inputT and outputT are assigned together above, so checking inputT covers both.
errors.checkNotNull(inputT, "Unable to determine input type");
// Find the state and timer declarations in advance of validating
// method parameter lists
FnAnalysisContext fnContext = FnAnalysisContext.create();
fnContext.addStateDeclarations(analyzeStateDeclarations(errors, fnClass).values());
fnContext.addTimerDeclarations(analyzeTimerDeclarations(errors, fnClass).values());
fnContext.addFieldAccessDeclarations(analyzeFieldAccessDeclaration(errors, fnClass).values());
// Locate the annotated methods. The @ProcessElement result is used below without a null
// check while all others are null-checked; presumably the trailing boolean marks the method
// as required — confirm against findAnnotatedMethod.
Method processElementMethod =
findAnnotatedMethod(errors, DoFn.ProcessElement.class, fnClass, true);
Method startBundleMethod = findAnnotatedMethod(errors, DoFn.StartBundle.class, fnClass, false);
Method finishBundleMethod =
findAnnotatedMethod(errors, DoFn.FinishBundle.class, fnClass, false);
Method setupMethod = findAnnotatedMethod(errors, DoFn.Setup.class, fnClass, false);
Method teardownMethod = findAnnotatedMethod(errors, DoFn.Teardown.class, fnClass, false);
Method onWindowExpirationMethod =
findAnnotatedMethod(errors, DoFn.OnWindowExpiration.class, fnClass, false);
Method getInitialRestrictionMethod =
findAnnotatedMethod(errors, DoFn.GetInitialRestriction.class, fnClass, false);
Method splitRestrictionMethod =
findAnnotatedMethod(errors, DoFn.SplitRestriction.class, fnClass, false);
Method getRestrictionCoderMethod =
findAnnotatedMethod(errors, DoFn.GetRestrictionCoder.class, fnClass, false);
Method newTrackerMethod = findAnnotatedMethod(errors, DoFn.NewTracker.class, fnClass, false);
// Timer callbacks: there may be several, one per timer id.
Collection<Method> onTimerMethods =
declaredMethodsWithAnnotation(DoFn.OnTimer.class, fnClass, DoFn.class);
HashMap<String, DoFnSignature.OnTimerMethod> onTimerMethodMap =
Maps.newHashMapWithExpectedSize(onTimerMethods.size());
for (Method onTimerMethod : onTimerMethods) {
String id = onTimerMethod.getAnnotation(DoFn.OnTimer.class).value();
// The callback must refer to a timer that was declared on the class.
errors.checkArgument(
fnContext.getTimerDeclarations().containsKey(id),
"Callback %s is for undeclared timer %s",
onTimerMethod,
id);
TimerDeclaration timerDecl = fnContext.getTimerDeclarations().get(id);
// The timer field and its callback must be declared by the same class.
errors.checkArgument(
timerDecl.field().getDeclaringClass().equals(onTimerMethod.getDeclaringClass()),
"Callback %s is for timer %s declared in a different class %s."
+ " Timer callbacks must be declared in the same lexical scope as their timer",
onTimerMethod,
id,
timerDecl.field().getDeclaringClass().getCanonicalName());
onTimerMethodMap.put(
id, analyzeOnTimerMethod(errors, fnT, onTimerMethod, id, inputT, outputT, fnContext));
}
signatureBuilder.setOnTimerMethods(onTimerMethodMap);
// Check the converse - that all timers have a callback. This could be relaxed to only
// those timers used in methods, once method parameter lists support timers.
for (TimerDeclaration decl : fnContext.getTimerDeclarations().values()) {
errors.checkArgument(
onTimerMethodMap.containsKey(decl.id()),
"No callback registered via %s for timer %s",
DoFn.OnTimer.class.getSimpleName(),
decl.id());
}
// @ProcessElement is analyzed unconditionally.
ErrorReporter processElementErrors =
errors.forMethod(DoFn.ProcessElement.class, processElementMethod);
DoFnSignature.ProcessElementMethod processElement =
analyzeProcessElementMethod(
processElementErrors, fnT, processElementMethod, inputT, outputT, fnContext);
signatureBuilder.setProcessElement(processElement);
// Every remaining method is optional and only analyzed when it was found.
if (startBundleMethod != null) {
ErrorReporter startBundleErrors = errors.forMethod(DoFn.StartBundle.class, startBundleMethod);
signatureBuilder.setStartBundle(
analyzeStartBundleMethod(startBundleErrors, fnT, startBundleMethod, inputT, outputT));
}
if (finishBundleMethod != null) {
ErrorReporter finishBundleErrors =
errors.forMethod(DoFn.FinishBundle.class, finishBundleMethod);
signatureBuilder.setFinishBundle(
analyzeFinishBundleMethod(finishBundleErrors, fnT, finishBundleMethod, inputT, outputT));
}
if (setupMethod != null) {
signatureBuilder.setSetup(
analyzeLifecycleMethod(errors.forMethod(DoFn.Setup.class, setupMethod), setupMethod));
}
if (teardownMethod != null) {
signatureBuilder.setTeardown(
analyzeLifecycleMethod(
errors.forMethod(DoFn.Teardown.class, teardownMethod), teardownMethod));
}
if (onWindowExpirationMethod != null) {
// The class-level reporter is passed here; analyzeOnWindowExpirationMethod derives the
// method-specific reporter itself.
signatureBuilder.setOnWindowExpiration(
analyzeOnWindowExpirationMethod(
errors, fnT, onWindowExpirationMethod, inputT, outputT, fnContext));
}
// Declared outside the if-block but only assigned and used inside it.
ErrorReporter getInitialRestrictionErrors;
if (getInitialRestrictionMethod != null) {
getInitialRestrictionErrors =
errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestrictionMethod);
signatureBuilder.setGetInitialRestriction(
analyzeGetInitialRestrictionMethod(
getInitialRestrictionErrors, fnT, getInitialRestrictionMethod, inputT));
}
if (splitRestrictionMethod != null) {
ErrorReporter splitRestrictionErrors =
errors.forMethod(DoFn.SplitRestriction.class, splitRestrictionMethod);
signatureBuilder.setSplitRestriction(
analyzeSplitRestrictionMethod(
splitRestrictionErrors, fnT, splitRestrictionMethod, inputT));
}
if (getRestrictionCoderMethod != null) {
ErrorReporter getRestrictionCoderErrors =
errors.forMethod(DoFn.GetRestrictionCoder.class, getRestrictionCoderMethod);
signatureBuilder.setGetRestrictionCoder(
analyzeGetRestrictionCoderMethod(
getRestrictionCoderErrors, fnT, getRestrictionCoderMethod));
}
if (newTrackerMethod != null) {
ErrorReporter newTrackerErrors = errors.forMethod(DoFn.NewTracker.class, newTrackerMethod);
signatureBuilder.setNewTracker(
analyzeNewTrackerMethod(newTrackerErrors, fnT, newTrackerMethod));
}
// Class-level facts: boundedness plus the declarations collected at the start.
signatureBuilder.setIsBoundedPerElement(inferBoundedness(fnT, processElement, errors));
signatureBuilder.setStateDeclarations(fnContext.getStateDeclarations());
signatureBuilder.setTimerDeclarations(fnContext.getTimerDeclarations());
signatureBuilder.setFieldAccessDeclarations(fnContext.getFieldAccessDeclarations());
DoFnSignature signature = signatureBuilder.build();
// Additional validation for splittable DoFn's.
if (processElement.isSplittable()) {
verifySplittableMethods(signature, errors);
} else {
verifyUnsplittableMethods(errors, signature);
}
return signature;
}
/**
 * Infers the boundedness of the {@link DoFn.ProcessElement} method (whether or not it performs a
 * bounded amount of work per element) using the following criteria:
 *
 * <ol>
 *   <li>The {@link DoFn} and all of its supertypes are scanned for {@link
 *       DoFn.BoundedPerElement} and {@link DoFn.UnboundedPerElement}. At most one type in the
 *       hierarchy may be annotated, and that type may carry only one of the two annotations;
 *       anything else is reported as an error.
 *   <li>If the {@link DoFn} is splittable and an annotation was found, use it. If no annotation
 *       was found, a {@link DoFn.ProcessElement} that has a return value is assumed unbounded,
 *       otherwise it is assumed bounded.
 *   <li>If the {@link DoFn} is not splittable, then it is bounded, it must not be annotated as
 *       {@link DoFn.BoundedPerElement} or {@link DoFn.UnboundedPerElement}, and {@link
 *       DoFn.ProcessElement} must not have a return value.
 * </ol>
 *
 * <p>NOTE(review): this method does not reject a splittable {@link DoFn} that is annotated
 * {@link DoFn.UnboundedPerElement} while its {@link DoFn.ProcessElement} returns {@code void};
 * if that combination should be an error, it has to be enforced elsewhere.
 *
 * @param fnT the type of the {@link DoFn} being analyzed
 * @param processElement the already-analyzed {@link DoFn.ProcessElement} method
 * @param errors reporter used for invalid annotation combinations
 * @return the inferred boundedness, never {@code null}
 */
private static PCollection.IsBounded inferBoundedness(
    TypeDescriptor<? extends DoFn> fnT,
    DoFnSignature.ProcessElementMethod processElement,
    ErrorReporter errors) {
  PCollection.IsBounded isBounded = null;
  for (TypeDescriptor<?> supertype : fnT.getTypes()) {
    Class<?> rawType = supertype.getRawType();
    boolean annotatedBounded = rawType.isAnnotationPresent(DoFn.BoundedPerElement.class);
    boolean annotatedUnbounded = rawType.isAnnotationPresent(DoFn.UnboundedPerElement.class);
    if (!annotatedBounded && !annotatedUnbounded) {
      continue;
    }
    // Reject a second annotated type in the hierarchy, and also a single type that carries
    // both annotations at once (previously that case silently resolved to BOUNDED).
    errors.checkArgument(
        isBounded == null && !(annotatedBounded && annotatedUnbounded),
        "Both @%s and @%s specified",
        DoFn.BoundedPerElement.class.getSimpleName(),
        DoFn.UnboundedPerElement.class.getSimpleName());
    isBounded =
        annotatedBounded ? PCollection.IsBounded.BOUNDED : PCollection.IsBounded.UNBOUNDED;
  }
  if (processElement.isSplittable()) {
    if (isBounded == null) {
      // No explicit annotation: fall back to the shape of the return type.
      isBounded =
          processElement.hasReturnValue()
              ? PCollection.IsBounded.UNBOUNDED
              : PCollection.IsBounded.BOUNDED;
    }
  } else {
    errors.checkArgument(
        isBounded == null,
        "Non-splittable, but annotated as @"
            + ((isBounded == PCollection.IsBounded.BOUNDED)
                ? DoFn.BoundedPerElement.class.getSimpleName()
                : DoFn.UnboundedPerElement.class.getSimpleName()));
    checkState(!processElement.hasReturnValue(), "Should have been inferred splittable");
    isBounded = PCollection.IsBounded.BOUNDED;
  }
  return isBounded;
}
/**
 * Verifies properties related to methods of splittable {@link DoFn}:
 *
 * <ul>
 *   <li>Must declare the required {@link DoFn.GetInitialRestriction} method, and a {@link
 *       DoFn.NewTracker} method unless the restriction type is a subtype of {@link
 *       HasDefaultTracker}.
 *   <li>The restriction types used by {@link DoFn.GetInitialRestriction}, {@link
 *       DoFn.NewTracker} and {@link DoFn.SplitRestriction} must be equal, and the coder returned
 *       by {@link DoFn.GetRestrictionCoder} must be a subtype of the coder type for that
 *       restriction.
 *   <li>The tracker type of {@link DoFn.ProcessElement} must have raw type {@link
 *       RestrictionTracker}.
 * </ul>
 *
 * <p>Missing required methods are reported before any type comparison is attempted, so a
 * {@link DoFn} that declares {@link DoFn.NewTracker} without {@link DoFn.GetInitialRestriction}
 * gets the "does not define the following required methods" error rather than a {@link
 * NullPointerException}.
 *
 * @param signature the fully built signature of a splittable {@link DoFn}
 * @param errors the class-level reporter from which method-level reporters are derived
 */
private static void verifySplittableMethods(DoFnSignature signature, ErrorReporter errors) {
  DoFnSignature.ProcessElementMethod processElement = signature.processElement();
  DoFnSignature.GetInitialRestrictionMethod getInitialRestriction =
      signature.getInitialRestriction();
  DoFnSignature.NewTrackerMethod newTracker = signature.newTracker();
  DoFnSignature.GetRestrictionCoderMethod getRestrictionCoder = signature.getRestrictionCoder();
  DoFnSignature.SplitRestrictionMethod splitRestriction = signature.splitRestriction();

  ErrorReporter processElementErrors =
      errors.forMethod(DoFn.ProcessElement.class, processElement.targetMethod());

  // Step 1: make sure every required method exists before touching any of them.
  List<String> missingRequiredMethods = new ArrayList<>();
  if (getInitialRestriction == null) {
    missingRequiredMethods.add("@" + DoFn.GetInitialRestriction.class.getSimpleName());
  }
  if (newTracker == null) {
    // @NewTracker may be omitted when the restriction type is a HasDefaultTracker.
    boolean restrictionHasDefaultTracker =
        getInitialRestriction != null
            && getInitialRestriction
                .restrictionT()
                .isSubtypeOf(TypeDescriptor.of(HasDefaultTracker.class));
    if (!restrictionHasDefaultTracker) {
      missingRequiredMethods.add("@" + DoFn.NewTracker.class.getSimpleName());
    }
  }
  if (!missingRequiredMethods.isEmpty()) {
    processElementErrors.throwIllegalArgument(
        "Splittable, but does not define the following required methods: %s",
        missingRequiredMethods);
  }

  // Step 2: type consistency. getInitialRestriction is known to be present from here on.
  ErrorReporter getInitialRestrictionErrors =
      errors.forMethod(DoFn.GetInitialRestriction.class, getInitialRestriction.targetMethod());
  TypeDescriptor<?> restrictionT = getInitialRestriction.restrictionT();

  if (newTracker != null) {
    getInitialRestrictionErrors.checkArgument(
        restrictionT.equals(newTracker.restrictionT()),
        "Uses restriction type %s, but @%s method %s uses restriction type %s",
        formatType(restrictionT),
        DoFn.NewTracker.class.getSimpleName(),
        format(newTracker.targetMethod()),
        formatType(newTracker.restrictionT()));
  }

  processElementErrors.checkArgument(
      processElement.trackerT().getRawType().equals(RestrictionTracker.class),
      "Has tracker type %s, but the DoFn's tracker type must be of type RestrictionTracker.",
      formatType(processElement.trackerT()));

  if (getRestrictionCoder != null) {
    getInitialRestrictionErrors.checkArgument(
        getRestrictionCoder.coderT().isSubtypeOf(coderTypeOf(restrictionT)),
        "Uses restriction type %s, but @%s method %s returns %s "
            + "which is not a subtype of %s",
        formatType(restrictionT),
        DoFn.GetRestrictionCoder.class.getSimpleName(),
        format(getRestrictionCoder.targetMethod()),
        formatType(getRestrictionCoder.coderT()),
        formatType(coderTypeOf(restrictionT)));
  }

  if (splitRestriction != null) {
    getInitialRestrictionErrors.checkArgument(
        splitRestriction.restrictionT().equals(restrictionT),
        "Uses restriction type %s, but @%s method %s uses restriction type %s",
        formatType(restrictionT),
        DoFn.SplitRestriction.class.getSimpleName(),
        format(splitRestriction.targetMethod()),
        formatType(splitRestriction.restrictionT()));
  }
}
/**
 * Checks that a non-splittable {@link DoFn} declares none of the methods that are meaningful
 * only for a splittable one: {@link DoFn.GetInitialRestriction}, {@link DoFn.SplitRestriction},
 * {@link DoFn.NewTracker} and {@link DoFn.GetRestrictionCoder}. All offending annotations are
 * collected and reported together in a single error.
 */
private static void verifyUnsplittableMethods(ErrorReporter errors, DoFnSignature signature) {
  // Parallel arrays: splittableOnlyMethods[i] was declared through splittableOnlyAnnotations[i].
  Object[] splittableOnlyMethods = {
    signature.getInitialRestriction(),
    signature.splitRestriction(),
    signature.newTracker(),
    signature.getRestrictionCoder()
  };
  Class<?>[] splittableOnlyAnnotations = {
    DoFn.GetInitialRestriction.class,
    DoFn.SplitRestriction.class,
    DoFn.NewTracker.class,
    DoFn.GetRestrictionCoder.class
  };

  List<String> forbiddenMethods = new ArrayList<>();
  for (int i = 0; i < splittableOnlyMethods.length; i++) {
    if (splittableOnlyMethods[i] != null) {
      forbiddenMethods.add("@" + splittableOnlyAnnotations[i].getSimpleName());
    }
  }

  boolean nothingForbidden = forbiddenMethods.isEmpty();
  errors.checkArgument(
      nothingForbidden, "Non-splittable, but defines methods: %s", forbiddenMethods);
}
/**
 * Builds the {@link TypeDescriptor} of {@code DoFn<InputT, OutputT>.ProcessContext} with both
 * type variables replaced by the supplied {@code inputT} and {@code outputT}.
 */
private static <InputT, OutputT>
    TypeDescriptor<DoFn<InputT, OutputT>.ProcessContext> doFnProcessContextTypeOf(
        TypeDescriptor<InputT> inputT, TypeDescriptor<OutputT> outputT) {
  // Start from the descriptor that still mentions the type variables, then bind each in turn.
  TypeDescriptor<DoFn<InputT, OutputT>.ProcessContext> unresolved =
      new TypeDescriptor<DoFn<InputT, OutputT>.ProcessContext>() {};
  TypeDescriptor<DoFn<InputT, OutputT>.ProcessContext> inputBound =
      unresolved.where(new TypeParameter<InputT>() {}, inputT);
  return inputBound.where(new TypeParameter<OutputT>() {}, outputT);
}
/**
 * Builds the {@link TypeDescriptor} of {@code DoFn<InputT, OutputT>.StartBundleContext} with
 * both type variables replaced by the supplied {@code inputT} and {@code outputT}.
 */
private static <InputT, OutputT>
    TypeDescriptor<DoFn<InputT, OutputT>.StartBundleContext> doFnStartBundleContextTypeOf(
        TypeDescriptor<InputT> inputT, TypeDescriptor<OutputT> outputT) {
  // Start from the descriptor that still mentions the type variables, then bind each in turn.
  TypeDescriptor<DoFn<InputT, OutputT>.StartBundleContext> unresolved =
      new TypeDescriptor<DoFn<InputT, OutputT>.StartBundleContext>() {};
  TypeDescriptor<DoFn<InputT, OutputT>.StartBundleContext> inputBound =
      unresolved.where(new TypeParameter<InputT>() {}, inputT);
  return inputBound.where(new TypeParameter<OutputT>() {}, outputT);
}
/**
 * Builds the {@link TypeDescriptor} of {@code DoFn<InputT, OutputT>.FinishBundleContext} with
 * both type variables replaced by the supplied {@code inputT} and {@code outputT}.
 */
private static <InputT, OutputT>
    TypeDescriptor<DoFn<InputT, OutputT>.FinishBundleContext> doFnFinishBundleContextTypeOf(
        TypeDescriptor<InputT> inputT, TypeDescriptor<OutputT> outputT) {
  // Start from the descriptor that still mentions the type variables, then bind each in turn.
  TypeDescriptor<DoFn<InputT, OutputT>.FinishBundleContext> unresolved =
      new TypeDescriptor<DoFn<InputT, OutputT>.FinishBundleContext>() {};
  TypeDescriptor<DoFn<InputT, OutputT>.FinishBundleContext> inputBound =
      unresolved.where(new TypeParameter<InputT>() {}, inputT);
  return inputBound.where(new TypeParameter<OutputT>() {}, outputT);
}
/**
 * Builds the {@link TypeDescriptor} of {@code DoFn<InputT, OutputT>.OnTimerContext} with both
 * type variables replaced by the supplied {@code inputT} and {@code outputT}.
 */
private static <InputT, OutputT>
    TypeDescriptor<DoFn<InputT, OutputT>.OnTimerContext> doFnOnTimerContextTypeOf(
        TypeDescriptor<InputT> inputT, TypeDescriptor<OutputT> outputT) {
  // Start from the descriptor that still mentions the type variables, then bind each in turn.
  TypeDescriptor<DoFn<InputT, OutputT>.OnTimerContext> unresolved =
      new TypeDescriptor<DoFn<InputT, OutputT>.OnTimerContext>() {};
  TypeDescriptor<DoFn<InputT, OutputT>.OnTimerContext> inputBound =
      unresolved.where(new TypeParameter<InputT>() {}, inputT);
  return inputBound.where(new TypeParameter<OutputT>() {}, outputT);
}
/**
 * Analyzes one {@link DoFn.OnTimer} callback.
 *
 * <p>The method must return {@code void}. Each of its parameters is analyzed as an extra
 * parameter and must be one of the kinds listed in {@code ALLOWED_ON_TIMER_PARAMETERS}.
 *
 * @param errors reporter for the enclosing {@link DoFn} class
 * @param fnClass type of the {@link DoFn}, used to resolve generic parameter types
 * @param m the callback method
 * @param timerId id of the timer the callback belongs to
 * @param inputT input element type of the {@link DoFn}
 * @param outputT output element type of the {@link DoFn}
 * @param fnContext class-level declarations gathered so far
 * @return the description of the callback
 */
@VisibleForTesting
static DoFnSignature.OnTimerMethod analyzeOnTimerMethod(
    ErrorReporter errors,
    TypeDescriptor<? extends DoFn<?, ?>> fnClass,
    Method m,
    String timerId,
    TypeDescriptor<?> inputT,
    TypeDescriptor<?> outputT,
    FnAnalysisContext fnContext) {
  errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void");

  Type[] genericParameterTypes = m.getGenericParameterTypes();
  MethodAnalysisContext methodContext = MethodAnalysisContext.create();
  boolean requiresStableInput = m.isAnnotationPresent(DoFn.RequiresStableInput.class);
  @Nullable TypeDescriptor<? extends BoundedWindow> windowT = getWindowType(fnClass, m);
  ErrorReporter onTimerErrors = errors.forMethod(DoFn.OnTimer.class, m);

  // Fetched once: the reflection call hands back the same annotations on every invocation.
  Annotation[][] parameterAnnotations = m.getParameterAnnotations();
  List<DoFnSignature.Parameter> analyzedParameters = new ArrayList<>();
  int position = 0;
  for (Type genericParameterType : genericParameterTypes) {
    ParameterDescription description =
        ParameterDescription.of(
            m,
            position,
            fnClass.resolveType(genericParameterType),
            Arrays.asList(parameterAnnotations[position]));
    Parameter analyzed =
        analyzeExtraParameter(
            onTimerErrors, fnContext, methodContext, fnClass, description, inputT, outputT);
    checkParameterOneOf(errors, analyzed, ALLOWED_ON_TIMER_PARAMETERS);
    analyzedParameters.add(analyzed);
    position++;
  }

  return DoFnSignature.OnTimerMethod.create(
      m, timerId, requiresStableInput, windowT, analyzedParameters);
}
/**
 * Analyzes the {@link DoFn.OnWindowExpiration} method.
 *
 * <p>The method must return {@code void}. Each of its parameters is analyzed as an extra
 * parameter and must be one of the kinds listed in {@code
 * ALLOWED_ON_WINDOW_EXPIRATION_PARAMETERS}.
 *
 * @param errors reporter for the enclosing {@link DoFn} class
 * @param fnClass type of the {@link DoFn}, used to resolve generic parameter types
 * @param m the annotated method
 * @param inputT input element type of the {@link DoFn}
 * @param outputT output element type of the {@link DoFn}
 * @param fnContext class-level declarations gathered so far
 * @return the description of the method
 */
@VisibleForTesting
static DoFnSignature.OnWindowExpirationMethod analyzeOnWindowExpirationMethod(
    ErrorReporter errors,
    TypeDescriptor<? extends DoFn<?, ?>> fnClass,
    Method m,
    TypeDescriptor<?> inputT,
    TypeDescriptor<?> outputT,
    FnAnalysisContext fnContext) {
  errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void");

  Type[] genericParameterTypes = m.getGenericParameterTypes();
  MethodAnalysisContext methodContext = MethodAnalysisContext.create();
  boolean requiresStableInput = m.isAnnotationPresent(DoFn.RequiresStableInput.class);
  @Nullable TypeDescriptor<? extends BoundedWindow> windowT = getWindowType(fnClass, m);
  ErrorReporter onWindowExpirationErrors = errors.forMethod(DoFn.OnWindowExpiration.class, m);

  // Fetched once: the reflection call hands back the same annotations on every invocation.
  Annotation[][] parameterAnnotations = m.getParameterAnnotations();
  List<DoFnSignature.Parameter> analyzedParameters = new ArrayList<>();
  int position = 0;
  for (Type genericParameterType : genericParameterTypes) {
    ParameterDescription description =
        ParameterDescription.of(
            m,
            position,
            fnClass.resolveType(genericParameterType),
            Arrays.asList(parameterAnnotations[position]));
    Parameter analyzed =
        analyzeExtraParameter(
            onWindowExpirationErrors,
            fnContext,
            methodContext,
            fnClass,
            description,
            inputT,
            outputT);
    checkParameterOneOf(errors, analyzed, ALLOWED_ON_WINDOW_EXPIRATION_PARAMETERS);
    analyzedParameters.add(analyzed);
    position++;
  }

  return DoFnSignature.OnWindowExpirationMethod.create(
      m, requiresStableInput, windowT, analyzedParameters);
}
/**
 * Analyzes the {@link DoFn.ProcessElement} method.
 *
 * <p>The method must return {@code void} or {@link DoFn.ProcessContinuation}. Every parameter is
 * analyzed as an extra parameter and recorded in a per-method context; schema element
 * parameters are then renumbered consecutively from zero in declaration order. Finally each
 * parameter is checked against the list of kinds allowed for a splittable method (one that has
 * a restriction tracker parameter) or a non-splittable one.
 *
 * @param errors reporter used for return-type and parameter-kind violations
 * @param fnClass type of the {@link DoFn}, used to resolve generic parameter types
 * @param m the annotated method
 * @param inputT input element type of the {@link DoFn}
 * @param outputT output element type of the {@link DoFn}
 * @param fnContext class-level declarations gathered so far
 * @return the description of the method
 */
@VisibleForTesting
static DoFnSignature.ProcessElementMethod analyzeProcessElementMethod(
    ErrorReporter errors,
    TypeDescriptor<? extends DoFn<?, ?>> fnClass,
    Method m,
    TypeDescriptor<?> inputT,
    TypeDescriptor<?> outputT,
    FnAnalysisContext fnContext) {
  boolean returnsVoid = void.class.equals(m.getReturnType());
  errors.checkArgument(
      returnsVoid || DoFn.ProcessContinuation.class.equals(m.getReturnType()),
      "Must return void or %s",
      DoFn.ProcessContinuation.class.getSimpleName());

  MethodAnalysisContext methodContext = MethodAnalysisContext.create();
  boolean requiresStableInput = m.isAnnotationPresent(DoFn.RequiresStableInput.class);
  Type[] genericParameterTypes = m.getGenericParameterTypes();
  TypeDescriptor<?> trackerT = getTrackerType(fnClass, m);
  TypeDescriptor<? extends BoundedWindow> windowT = getWindowType(fnClass, m);

  // Pass 1: analyze every declared parameter and record it in the method context.
  for (int position = 0; position < genericParameterTypes.length; ++position) {
    ParameterDescription description =
        ParameterDescription.of(
            m,
            position,
            fnClass.resolveType(genericParameterTypes[position]),
            Arrays.asList(m.getParameterAnnotations()[position]));
    methodContext.addParameter(
        analyzeExtraParameter(
            errors.forMethod(DoFn.ProcessElement.class, m),
            fnContext,
            methodContext,
            fnClass,
            description,
            inputT,
            outputT));
  }

  // Pass 2: give schema element parameters consecutive indices, starting at zero.
  int nextSchemaElementIndex = 0;
  int parameterCount = methodContext.getExtraParameters().size();
  for (int position = 0; position < parameterCount; ++position) {
    Parameter candidate = methodContext.getExtraParameters().get(position);
    if (!(candidate instanceof SchemaElementParameter)) {
      continue;
    }
    SchemaElementParameter renumbered =
        ((SchemaElementParameter) candidate)
            .toBuilder()
            .setIndex(nextSchemaElementIndex)
            .build();
    methodContext.setParameter(position, renumbered);
    ++nextSchemaElementIndex;
  }

  // Pass 3: the allowed parameters depend on whether this DoFn is splittable.
  Collection<Class<? extends Parameter>> allowedParameterClasses =
      methodContext.hasRestrictionTrackerParameter()
          ? ALLOWED_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS
          : ALLOWED_NON_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS;
  for (Parameter parameter : methodContext.getExtraParameters()) {
    checkParameterOneOf(errors, parameter, allowedParameterClasses);
  }

  boolean returnsContinuation = DoFn.ProcessContinuation.class.equals(m.getReturnType());
  return DoFnSignature.ProcessElementMethod.create(
      m,
      methodContext.getExtraParameters(),
      requiresStableInput,
      trackerT,
      windowT,
      returnsContinuation);
}
/**
 * Verifies that {@code parameter} is an instance of at least one of the permitted parameter
 * classes, and reports an illegal-argument error through {@code errors} when it is not.
 */
private static void checkParameterOneOf(
    ErrorReporter errors,
    Parameter parameter,
    Collection<Class<? extends Parameter>> allowedParameterClasses) {
  boolean permitted =
      allowedParameterClasses.stream()
          .anyMatch(candidate -> candidate.isAssignableFrom(parameter.getClass()));
  if (!permitted) {
    // No allowed class accepted this parameter.
    errors.throwIllegalArgument("Illegal parameter type: %s", parameter);
  }
}
/**
 * Classifies a single method parameter as one of the supported {@link Parameter} kinds.
 *
 * <p>The parameter's annotations and raw type are tested by an ordered chain of checks; the first
 * check that matches decides the result, so the order of the branches is significant. Validation
 * failures are reported through {@code methodErrors} or through a parameter-scoped reporter
 * derived from it.
 *
 * @param methodErrors reporter scoped to the method being analyzed
 * @param fnContext source of the timer and state declarations that {@code Timer} and {@code
 *     State} parameters are resolved against
 * @param methodContext consulted to reject duplicate window, pipeline-options, restriction
 *     tracker, timer and state parameters
 * @param fnClass not referenced in this method body
 * @param param the parameter under analysis (type, index, annotations and owning method)
 * @param inputT compared against the type of an element-annotated parameter and used to build
 *     the expected context types
 * @param outputT used to build the expected context types and the expected output receiver type
 * @return the classified parameter; a parameter matching no check is reported through {@code
 *     throwIllegalArgument}
 */
private static Parameter analyzeExtraParameter(
ErrorReporter methodErrors,
FnAnalysisContext fnContext,
MethodAnalysisContext methodContext,
TypeDescriptor<? extends DoFn<?, ?>> fnClass,
ParameterDescription param,
TypeDescriptor<?> inputT,
TypeDescriptor<?> outputT) {
TypeDescriptor<?> expectedProcessContextT = doFnProcessContextTypeOf(inputT, outputT);
TypeDescriptor<?> expectedOnTimerContextT = doFnOnTimerContextTypeOf(inputT, outputT);
TypeDescriptor<?> paramT = param.getType();
Class<?> rawType = paramT.getRawType();
ErrorReporter paramErrors = methodErrors.forParameter(param);
// Annotation-driven kinds are checked before any dispatch on the raw type.
String fieldAccessString = getFieldAccessId(param.getAnnotations());
if (fieldAccessString != null) {
return Parameter.schemaElementParameter(paramT, fieldAccessString, param.getIndex());
} else if (hasElementAnnotation(param.getAnnotations())) {
// An element-annotated parameter whose type differs from inputT becomes a schema element
// parameter with no field access id.
return (paramT.equals(inputT))
? Parameter.elementParameter(paramT)
: Parameter.schemaElementParameter(paramT, null, param.getIndex());
} else if (hasTimestampAnnotation(param.getAnnotations())) {
methodErrors.checkArgument(
rawType.equals(Instant.class),
"@Timestamp argument must have type org.joda.time.Instant.");
return Parameter.timestampParameter();
} else if (rawType.equals(TimeDomain.class)) {
return Parameter.timeDomainParameter();
} else if (rawType.equals(PaneInfo.class)) {
return Parameter.paneInfoParameter();
} else if (rawType.equals(DoFn.ProcessContext.class)) {
// The raw type matched; the fully parameterized type must match as well.
paramErrors.checkArgument(
paramT.equals(expectedProcessContextT),
"ProcessContext argument must have type %s",
formatType(expectedProcessContextT));
return Parameter.processContext();
} else if (rawType.equals(DoFn.OnTimerContext.class)) {
paramErrors.checkArgument(
paramT.equals(expectedOnTimerContextT),
"OnTimerContext argument must have type %s",
formatType(expectedOnTimerContextT));
return Parameter.onTimerContext();
} else if (BoundedWindow.class.isAssignableFrom(rawType)) {
methodErrors.checkArgument(
!methodContext.hasWindowParameter(),
"Multiple %s parameters",
BoundedWindow.class.getSimpleName());
// The cast is backed by the isAssignableFrom check guarding this branch.
return Parameter.boundedWindow((TypeDescriptor<? extends BoundedWindow>) paramT);
} else if (rawType.equals(OutputReceiver.class)) {
// It's a schema row receiver if it's an OutputReceiver<Row> _and_ the output type is not
// already Row.
boolean schemaRowReceiver =
paramT.equals(outputReceiverTypeOf(TypeDescriptor.of(Row.class)))
&& !outputT.equals(TypeDescriptor.of(Row.class));
if (!schemaRowReceiver) {
TypeDescriptor<?> expectedReceiverT = outputReceiverTypeOf(outputT);
paramErrors.checkArgument(
paramT.equals(expectedReceiverT),
"OutputReceiver should be parameterized by %s",
outputT);
}
return Parameter.outputReceiverParameter(schemaRowReceiver);
} else if (rawType.equals(MultiOutputReceiver.class)) {
return Parameter.taggedOutputReceiverParameter();
} else if (PipelineOptions.class.equals(rawType)) {
methodErrors.checkArgument(
!methodContext.hasPipelineOptionsParamter(),
"Multiple %s parameters",
PipelineOptions.class.getSimpleName());
return Parameter.pipelineOptions();
} else if (RestrictionTracker.class.isAssignableFrom(rawType)) {
methodErrors.checkArgument(
!methodContext.hasRestrictionTrackerParameter(),
"Multiple %s parameters",
RestrictionTracker.class.getSimpleName());
return Parameter.restrictionTracker(paramT);
} else if (rawType.equals(Timer.class)) {
// m.getParameters() is not available until Java 8
String id = getTimerId(param.getAnnotations());
// Checks run in order: annotation present, id not already used in this method, id declared
// on the DoFn, and declared in the same class as the method.
paramErrors.checkArgument(
id != null,
"%s missing %s annotation",
Timer.class.getSimpleName(),
TimerId.class.getSimpleName());
paramErrors.checkArgument(
!methodContext.getTimerParameters().containsKey(id),
"duplicate %s: \"%s\"",
TimerId.class.getSimpleName(),
id);
TimerDeclaration timerDecl = fnContext.getTimerDeclarations().get(id);
// The dereferences of timerDecl below rely on this check throwing when it is null.
paramErrors.checkArgument(
timerDecl != null,
"reference to undeclared %s: \"%s\"",
TimerId.class.getSimpleName(),
id);
paramErrors.checkArgument(
timerDecl.field().getDeclaringClass().equals(param.getMethod().getDeclaringClass()),
"%s %s declared in a different class %s."
+ " Timers may be referenced only in the lexical scope where they are declared.",
TimerId.class.getSimpleName(),
id,
timerDecl.field().getDeclaringClass().getName());
return Parameter.timerParameter(timerDecl);
} else if (State.class.isAssignableFrom(rawType)) {
// m.getParameters() is not available until Java 8
String id = getStateId(param.getAnnotations());
paramErrors.checkArgument(
id != null, "missing %s annotation", DoFn.StateId.class.getSimpleName());
paramErrors.checkArgument(
!methodContext.getStateParameters().containsKey(id),
"duplicate %s: \"%s\"",
DoFn.StateId.class.getSimpleName(),
id);
// By static typing this is already a well-formed State subclass
TypeDescriptor<? extends State> stateType = (TypeDescriptor<? extends State>) param.getType();
StateDeclaration stateDecl = fnContext.getStateDeclarations().get(id);
// The dereferences of stateDecl below rely on this check throwing when it is null.
paramErrors.checkArgument(
stateDecl != null,
"reference to undeclared %s: \"%s\"",
DoFn.StateId.class.getSimpleName(),
id);
// The declared state type must be a subtype of the type the parameter asks for.
paramErrors.checkArgument(
stateDecl.stateType().isSubtypeOf(stateType),
"data type of reference to %s %s must be a supertype of %s",
StateId.class.getSimpleName(),
id,
formatType(stateDecl.stateType()));
paramErrors.checkArgument(
stateDecl.field().getDeclaringClass().equals(param.getMethod().getDeclaringClass()),
"%s %s declared in a different class %s."
+ " State may be referenced only in the class where it is declared.",
StateId.class.getSimpleName(),
id,
stateDecl.field().getDeclaringClass().getName());
return Parameter.stateParameter(stateDecl);
} else {
// No check matched. Note that the message lists only two of the accepted types.
List<String> allowedParamTypes =
Arrays.asList(
formatType(new TypeDescriptor<BoundedWindow>() {}),
formatType(new TypeDescriptor<RestrictionTracker<?, ?>>() {}));
paramErrors.throwIllegalArgument(
"%s is not a valid context parameter. Should be one of %s",
formatType(paramT), allowedParamTypes);
// Unreachable
return null;
}
}
/**
 * Returns the value of the first {@link DoFn.TimerId} annotation in {@code annotations}, or
 * {@code null} when none is present.
 */
@Nullable
private static String getTimerId(List<Annotation> annotations) {
  DoFn.TimerId timerId = findFirstOfType(annotations, DoFn.TimerId.class);
  if (timerId == null) {
    return null;
  }
  return timerId.value();
}
/**
 * Returns the value of the first {@link DoFn.StateId} annotation in {@code annotations}, or
 * {@code null} when none is present.
 */
@Nullable
private static String getStateId(List<Annotation> annotations) {
  return Optional.ofNullable(findFirstOfType(annotations, DoFn.StateId.class))
      .map(DoFn.StateId::value)
      .orElse(null);
}
/**
 * Returns the value of the first {@link DoFn.FieldAccess} annotation in {@code annotations}, or
 * {@code null} when none is present.
 */
@Nullable
private static String getFieldAccessId(List<Annotation> annotations) {
  DoFn.FieldAccess fieldAccess = findFirstOfType(annotations, DoFn.FieldAccess.class);
  String accessId = null;
  if (fieldAccess != null) {
    accessId = fieldAccess.value();
  }
  return accessId;
}
/**
 * Returns the first annotation in {@code annotations} whose annotation type is exactly {@code
 * clazz}, or {@code null} when there is none.
 *
 * <p>The match is converted with {@link Class#cast}, so the conversion is checked at runtime
 * rather than relying on an unchecked generic cast.
 *
 * @param annotations the annotations to search, in order
 * @param clazz the annotation type to look for
 * @return the first matching annotation, or {@code null}
 */
@Nullable
static <T> T findFirstOfType(List<Annotation> annotations, Class<T> clazz) {
  return annotations.stream()
      .filter(a -> a.annotationType().equals(clazz))
      .findFirst()
      .map(clazz::cast)
      .orElse(null);
}
/** Returns whether any annotation in {@code annotations} is a {@link DoFn.Element}. */
private static boolean hasElementAnnotation(List<Annotation> annotations) {
  for (Annotation candidate : annotations) {
    if (candidate.annotationType().equals(DoFn.Element.class)) {
      return true;
    }
  }
  return false;
}
/** Returns whether any annotation in {@code annotations} is a {@link DoFn.Timestamp}. */
private static boolean hasTimestampAnnotation(List<Annotation> annotations) {
  return annotations.stream()
      .map(Annotation::annotationType)
      .anyMatch(DoFn.Timestamp.class::equals);
}
/**
 * Returns the resolved type of the first parameter of {@code method} whose raw type is
 * assignable to {@link RestrictionTracker}, or {@code null} when there is no such parameter.
 */
@Nullable
private static TypeDescriptor<?> getTrackerType(TypeDescriptor<?> fnClass, Method method) {
  return Arrays.stream(method.getGenericParameterTypes())
      .<TypeDescriptor<?>>map(paramType -> fnClass.resolveType(paramType))
      .filter(resolved -> RestrictionTracker.class.isAssignableFrom(resolved.getRawType()))
      .findFirst()
      .orElse(null);
}
/**
 * Returns the resolved type of the first parameter of {@code method} whose raw type is
 * assignable to {@link BoundedWindow}, or {@code null} when there is no such parameter.
 */
@Nullable
private static TypeDescriptor<? extends BoundedWindow> getWindowType(
    TypeDescriptor<?> fnClass, Method method) {
  Type[] genericParams = method.getGenericParameterTypes();
  for (int i = 0; i < genericParams.length; i++) {
    TypeDescriptor<?> resolved = fnClass.resolveType(genericParams[i]);
    if (!BoundedWindow.class.isAssignableFrom(resolved.getRawType())) {
      continue;
    }
    // The raw-type check above backs this cast.
    return (TypeDescriptor<? extends BoundedWindow>) resolved;
  }
  return null;
}
/**
 * Validates a start-bundle method: it must return void and take either no arguments or exactly
 * one argument of the start-bundle context type derived from {@code inputT} and {@code outputT}.
 */
@VisibleForTesting
static DoFnSignature.BundleMethod analyzeStartBundleMethod(
    ErrorReporter errors,
    TypeDescriptor<? extends DoFn<?, ?>> fnT,
    Method m,
    TypeDescriptor<?> inputT,
    TypeDescriptor<?> outputT) {
  boolean returnsVoid = void.class.equals(m.getReturnType());
  errors.checkArgument(returnsVoid, "Must return void");

  TypeDescriptor<?> contextT = doFnStartBundleContextTypeOf(inputT, outputT);
  Type[] paramTypes = m.getGenericParameterTypes();
  boolean acceptableParams;
  if (paramTypes.length == 0) {
    acceptableParams = true;
  } else if (paramTypes.length == 1) {
    acceptableParams = fnT.resolveType(paramTypes[0]).equals(contextT);
  } else {
    acceptableParams = false;
  }
  errors.checkArgument(
      acceptableParams, "Must take a single argument of type %s", formatType(contextT));
  return DoFnSignature.BundleMethod.create(m);
}
/**
 * Validates a finish-bundle method: it must return void and take either no arguments or exactly
 * one argument of the finish-bundle context type derived from {@code inputT} and {@code
 * outputT}.
 */
@VisibleForTesting
static DoFnSignature.BundleMethod analyzeFinishBundleMethod(
    ErrorReporter errors,
    TypeDescriptor<? extends DoFn<?, ?>> fnT,
    Method m,
    TypeDescriptor<?> inputT,
    TypeDescriptor<?> outputT) {
  errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void");

  TypeDescriptor<?> contextT = doFnFinishBundleContextTypeOf(inputT, outputT);
  Type[] paramTypes = m.getGenericParameterTypes();
  boolean takesNothing = paramTypes.length == 0;
  boolean takesOnlyContext =
      paramTypes.length == 1 && fnT.resolveType(paramTypes[0]).equals(contextT);
  errors.checkArgument(
      takesNothing || takesOnlyContext,
      "Must take a single argument of type %s",
      formatType(contextT));
  return DoFnSignature.BundleMethod.create(m);
}
/** Validates a lifecycle method: it must return void and declare no parameters. */
private static DoFnSignature.LifecycleMethod analyzeLifecycleMethod(
    ErrorReporter errors, Method m) {
  Class<?> returnType = m.getReturnType();
  errors.checkArgument(void.class.equals(returnType), "Must return void");

  int parameterCount = m.getGenericParameterTypes().length;
  errors.checkArgument(parameterCount == 0, "Must take zero arguments");

  return DoFnSignature.LifecycleMethod.create(m);
}
/**
 * Validates a get-initial-restriction method and records its resolved return type as the
 * restriction type.
 */
@VisibleForTesting
static DoFnSignature.GetInitialRestrictionMethod analyzeGetInitialRestrictionMethod(
    ErrorReporter errors,
    TypeDescriptor<? extends DoFn> fnT,
    Method m,
    TypeDescriptor<?> inputT) {
  // Expected shape:
  //   @GetInitialRestriction
  //   RestrictionT getInitialRestriction(InputT element);
  Type[] paramTypes = m.getGenericParameterTypes();
  boolean takesOnlyElement =
      paramTypes.length == 1 && fnT.resolveType(paramTypes[0]).equals(inputT);
  errors.checkArgument(
      takesOnlyElement, "Must take a single argument of type %s", formatType(inputT));

  TypeDescriptor<?> restrictionT = fnT.resolveType(m.getGenericReturnType());
  return DoFnSignature.GetInitialRestrictionMethod.create(m, restrictionT);
}
/**
 * Builds the {@link TypeDescriptor} of {@code DoFn.OutputReceiver<OutputT>} by substituting the
 * supplied descriptor for the {@code OutputT} type variable.
 */
private static <OutputT> TypeDescriptor<DoFn.OutputReceiver<OutputT>> outputReceiverTypeOf(
    TypeDescriptor<OutputT> outputT) {
  TypeDescriptor<DoFn.OutputReceiver<OutputT>> receiverTemplate =
      new TypeDescriptor<DoFn.OutputReceiver<OutputT>>() {};
  TypeParameter<OutputT> outputVariable = new TypeParameter<OutputT>() {};
  return receiverTemplate.where(outputVariable, outputT);
}
/**
 * Validates a split-restriction method and records the resolved type of its second argument as
 * the restriction type.
 */
@VisibleForTesting
static DoFnSignature.SplitRestrictionMethod analyzeSplitRestrictionMethod(
    ErrorReporter errors,
    TypeDescriptor<? extends DoFn> fnT,
    Method m,
    TypeDescriptor<?> inputT) {
  // Expected shape:
  //   @SplitRestriction
  //   void splitRestriction(
  //       InputT element, RestrictionT restriction, OutputReceiver<RestrictionT> receiver);
  errors.checkArgument(void.class.equals(m.getReturnType()), "Must return void");

  Type[] paramTypes = m.getGenericParameterTypes();
  errors.checkArgument(paramTypes.length == 3, "Must have exactly 3 arguments");

  TypeDescriptor<?> elementT = fnT.resolveType(paramTypes[0]);
  errors.checkArgument(
      elementT.equals(inputT), "First argument must be the element type %s", formatType(inputT));

  // The receiver must be parameterized by whatever type the second argument resolves to.
  TypeDescriptor<?> restrictionT = fnT.resolveType(paramTypes[1]);
  TypeDescriptor<?> actualReceiverT = fnT.resolveType(paramTypes[2]);
  TypeDescriptor<?> requiredReceiverT = outputReceiverTypeOf(restrictionT);
  boolean receiverMatches = actualReceiverT.equals(requiredReceiverT);
  errors.checkArgument(
      receiverMatches,
      "Third argument must be %s, but is %s",
      formatType(requiredReceiverT),
      formatType(actualReceiverT));

  return DoFnSignature.SplitRestrictionMethod.create(m, restrictionT);
}
/**
 * Collects every {@link DoFn.TimerId}-annotated field reachable from {@code fnClazz}, validates
 * each one, and returns the declarations keyed by timer id.
 */
private static ImmutableMap<String, TimerDeclaration> analyzeTimerDeclarations(
    ErrorReporter errors, Class<?> fnClazz) {
  Map<String, DoFnSignature.TimerDeclaration> timersById = new HashMap<>();
  Collection<Field> timerFields =
      declaredFieldsWithAnnotation(DoFn.TimerId.class, fnClazz, DoFn.class);
  for (Field timerField : timerFields) {
    // The field may be private; it is read later through the signature.
    timerField.setAccessible(true);
    DoFn.TimerId annotation = timerField.getAnnotation(DoFn.TimerId.class);
    String timerId = annotation.value();
    validateTimerField(errors, timersById, timerId, timerField);
    timersById.put(timerId, DoFnSignature.TimerDeclaration.create(timerId, timerField));
  }
  return ImmutableMap.copyOf(timersById);
}
/**
 * Checks one {@link DoFn.TimerId}-annotated field: its id must not already be declared, its
 * type must be exactly {@link TimerSpec}, and it must be final. The first violated rule is
 * reported through {@code errors}; the method returns normally when all rules hold.
 */
private static void validateTimerField(
    ErrorReporter errors, Map<String, TimerDeclaration> declarations, String id, Field field) {
  final String annotationName = DoFn.TimerId.class.getSimpleName();
  final String fieldDescription = field.toString();

  if (declarations.containsKey(id)) {
    errors.throwIllegalArgument(
        "Duplicate %s \"%s\", used on both of [%s] and [%s]",
        annotationName,
        id,
        fieldDescription,
        declarations.get(id).field().toString());
  }

  if (!TimerSpec.class.equals(field.getType())) {
    errors.throwIllegalArgument(
        "%s annotation on non-%s field [%s]",
        annotationName, TimerSpec.class.getSimpleName(), fieldDescription);
  }

  boolean declaredFinal = Modifier.isFinal(field.getModifiers());
  if (!declaredFinal) {
    errors.throwIllegalArgument(
        "Non-final field %s annotated with %s. Timer declarations must be final.",
        fieldDescription, annotationName);
  }
}
/**
 * Builds the {@link TypeDescriptor} of {@code Coder<T>} by substituting the supplied descriptor
 * for the {@code T} type variable.
 */
private static <T> TypeDescriptor<Coder<T>> coderTypeOf(TypeDescriptor<T> elementT) {
  TypeDescriptor<Coder<T>> coderTemplate = new TypeDescriptor<Coder<T>>() {};
  TypeParameter<T> elementVariable = new TypeParameter<T>() {};
  return coderTemplate.where(elementVariable, elementT);
}
/**
 * Validates a get-restriction-coder method: it must take no arguments and its resolved return
 * type must be a subtype of {@link Coder}.
 */
@VisibleForTesting
static DoFnSignature.GetRestrictionCoderMethod analyzeGetRestrictionCoderMethod(
    ErrorReporter errors, TypeDescriptor<? extends DoFn> fnT, Method m) {
  boolean takesNothing = m.getParameterTypes().length == 0;
  errors.checkArgument(takesNothing, "Must have zero arguments");

  TypeDescriptor<?> returnedT = fnT.resolveType(m.getGenericReturnType());
  boolean returnsCoder = returnedT.isSubtypeOf(TypeDescriptor.of(Coder.class));
  errors.checkArgument(
      returnsCoder, "Must return a Coder, but returns %s", formatType(returnedT));

  return DoFnSignature.GetRestrictionCoderMethod.create(m, returnedT);
}
/**
 * Builds the {@link TypeDescriptor} of {@code RestrictionTracker<RestrictionT, ?>} by
 * substituting the supplied descriptor for the {@code RestrictionT} type variable.
 */
private static <RestrictionT>
    TypeDescriptor<RestrictionTracker<RestrictionT, ?>> restrictionTrackerTypeOf(
        TypeDescriptor<RestrictionT> restrictionT) {
  TypeDescriptor<RestrictionTracker<RestrictionT, ?>> trackerTemplate =
      new TypeDescriptor<RestrictionTracker<RestrictionT, ?>>() {};
  TypeParameter<RestrictionT> restrictionVariable = new TypeParameter<RestrictionT>() {};
  return trackerTemplate.where(restrictionVariable, restrictionT);
}
/**
 * Validates a new-tracker method: it must take exactly one argument, and its resolved return
 * type must be a subtype of the restriction tracker type built from that argument's type.
 */
@VisibleForTesting
static DoFnSignature.NewTrackerMethod analyzeNewTrackerMethod(
    ErrorReporter errors, TypeDescriptor<? extends DoFn> fnT, Method m) {
  // Expected shape:
  //   @NewTracker
  //   TrackerT newTracker(RestrictionT restriction);
  Type[] paramTypes = m.getGenericParameterTypes();
  errors.checkArgument(paramTypes.length == 1, "Must have a single argument");

  TypeDescriptor<?> restrictionT = fnT.resolveType(paramTypes[0]);
  TypeDescriptor<?> returnedTrackerT = fnT.resolveType(m.getGenericReturnType());
  TypeDescriptor<?> requiredTrackerT = restrictionTrackerTypeOf(restrictionT);
  boolean trackerCompatible = returnedTrackerT.isSubtypeOf(requiredTrackerT);
  errors.checkArgument(
      trackerCompatible,
      "Returns %s, but must return a subtype of %s",
      formatType(returnedTrackerT),
      formatType(requiredTrackerT));

  return DoFnSignature.NewTrackerMethod.create(m, restrictionT, returnedTrackerT);
}
/**
 * Returns the methods annotated with {@code anno}, delegating to {@code
 * declaredMembersWithAnnotation} with the {@code GET_METHODS} getter.
 */
private static Collection<Method> declaredMethodsWithAnnotation(
Class<? extends Annotation> anno, Class<?> startClass, Class<?> stopClass) {
return declaredMembersWithAnnotation(anno, startClass, stopClass, GET_METHODS);
}
/**
 * Returns the fields annotated with {@code anno}, delegating to {@code
 * declaredMembersWithAnnotation} with the {@code GET_FIELDS} getter.
 */
private static Collection<Field> declaredFieldsWithAnnotation(
Class<? extends Annotation> anno, Class<?> startClass, Class<?> stopClass) {
return declaredMembersWithAnnotation(anno, startClass, stopClass, GET_FIELDS);
}
/**
 * Extracts an array of members of one kind from a class.
 *
 * <p>Marked {@link FunctionalInterface} because instances are supplied as method references; the
 * annotation makes the compiler reject a second abstract method.
 *
 * @param <MemberT> the kind of member this getter returns
 */
@FunctionalInterface
private interface MemberGetter<MemberT> {
  /** Returns the members of {@code clazz} that this getter extracts. */
  MemberT[] getMembers(Class<?> clazz);
}
/** Getter bound to {@link Class#getDeclaredMethods()}. */
private static final MemberGetter<Method> GET_METHODS = Class::getDeclaredMethods;
/** Getter bound to {@link Class#getDeclaredFields()}. */
private static final MemberGetter<Field> GET_FIELDS = Class::getDeclaredFields;
/**
 * Collects the members annotated with {@code anno} that {@code getter} yields for {@code
 * startClass}, for each of its superclasses up to but excluding {@code stopClass}, and for the
 * interfaces discovered along the way.
 *
 * <p>Members found on classes come first, in walk order; members found on interfaces follow, in
 * the order the interfaces were first discovered.
 */
private static <MemberT extends AnnotatedElement>
    Collection<MemberT> declaredMembersWithAnnotation(
        Class<? extends Annotation> anno,
        Class<?> startClass,
        Class<?> stopClass,
        MemberGetter<MemberT> getter) {
  Collection<MemberT> annotated = new ArrayList<>();
  LinkedHashSet<Class<?>> discoveredInterfaces = new LinkedHashSet<>();

  // Pass 1: walk the superclass chain, remembering every interface seen on the way.
  for (Class<?> current = startClass;
      current != null && !current.equals(stopClass);
      current = current.getSuperclass()) {
    Arrays.stream(getter.getMembers(current))
        .filter(member -> member.isAnnotationPresent(anno))
        .forEach(annotated::add);
    for (TypeDescriptor<?> ifaceT : TypeDescriptor.of(current).getInterfaces()) {
      discoveredInterfaces.add(ifaceT.getRawType());
    }
  }

  // Pass 2: scan each interface that pass 1 recorded.
  for (Class<?> iface : discoveredInterfaces) {
    Arrays.stream(getter.getMembers(iface))
        .filter(member -> member.isAnnotationPresent(anno))
        .forEach(annotated::add);
  }
  return annotated;
}
/**
 * Collects every {@link DoFn.FieldAccess}-annotated field reachable from {@code fnClazz} and
 * returns the declarations keyed by the annotation value. Each field must be final and of type
 * {@link FieldAccessDescriptor}; violations are reported through {@code errors}.
 */
private static Map<String, DoFnSignature.FieldAccessDeclaration> analyzeFieldAccessDeclaration(
    ErrorReporter errors, Class<?> fnClazz) {
  Map<String, FieldAccessDeclaration> declarationsById = new HashMap<>();
  Collection<Field> annotatedFields =
      declaredFieldsWithAnnotation(DoFn.FieldAccess.class, fnClazz, DoFn.class);
  for (Field candidate : annotatedFields) {
    candidate.setAccessible(true);
    DoFn.FieldAccess annotation = candidate.getAnnotation(DoFn.FieldAccess.class);

    boolean declaredFinal = Modifier.isFinal(candidate.getModifiers());
    if (!declaredFinal) {
      errors.throwIllegalArgument(
          "Non-final field %s annotated with %s. Field access declarations must be final.",
          candidate.toString(), DoFn.FieldAccess.class.getSimpleName());
      continue;
    }

    if (!FieldAccessDescriptor.class.equals(candidate.getType())) {
      errors.throwIllegalArgument(
          "Field %s annotated with %s, but the value was not of type %s",
          candidate.toString(),
          DoFn.FieldAccess.class.getSimpleName(),
          FieldAccessDescriptor.class.getSimpleName());
    }

    String accessId = annotation.value();
    declarationsById.put(accessId, FieldAccessDeclaration.create(accessId, candidate));
  }
  return declarationsById;
}
/**
 * Collects every {@link DoFn.StateId}-annotated field reachable from {@code fnClazz} and
 * returns the declarations keyed by state id.
 *
 * <p>Each field must carry a unique id, have a type that is a subtype of {@link StateSpec}, and
 * be final; violations are reported through {@code errors}. The state type stored in a
 * declaration is the first type argument of the field's {@code StateSpec} supertype, resolved
 * against {@code fnClazz}.
 */
private static Map<String, DoFnSignature.StateDeclaration> analyzeStateDeclarations(
    ErrorReporter errors, Class<?> fnClazz) {
  final String annotationName = DoFn.StateId.class.getSimpleName();
  Map<String, DoFnSignature.StateDeclaration> statesById = new HashMap<>();

  for (Field stateField : declaredFieldsWithAnnotation(DoFn.StateId.class, fnClazz, DoFn.class)) {
    // The field may be private; it is read later through the signature.
    stateField.setAccessible(true);
    String stateId = stateField.getAnnotation(DoFn.StateId.class).value();

    if (statesById.containsKey(stateId)) {
      errors.throwIllegalArgument(
          "Duplicate %s \"%s\", used on both of [%s] and [%s]",
          annotationName,
          stateId,
          stateField.toString(),
          statesById.get(stateId).field().toString());
      continue;
    }

    Class<?> rawFieldType = stateField.getType();
    boolean isStateSpec =
        TypeDescriptor.of(rawFieldType).isSubtypeOf(TypeDescriptor.of(StateSpec.class));
    if (!isStateSpec) {
      errors.throwIllegalArgument(
          "%s annotation on non-%s field [%s] that has class %s",
          annotationName,
          StateSpec.class.getSimpleName(),
          stateField.toString(),
          rawFieldType.getName());
      continue;
    }

    if (!Modifier.isFinal(stateField.getModifiers())) {
      errors.throwIllegalArgument(
          "Non-final field %s annotated with %s. State declarations must be final.",
          stateField.toString(), annotationName);
      continue;
    }

    // Descriptor of the field's declared generic type, some subtype of StateSpec.
    Type genericFieldType = stateField.getGenericType();
    TypeDescriptor<? extends StateSpec<?>> declaredSpecT =
        (TypeDescriptor) TypeDescriptor.of(genericFieldType);

    // The same type viewed as StateSpec, with type arguments carried over from the subtype.
    TypeDescriptor<StateSpec<?>> stateSpecT =
        (TypeDescriptor) declaredSpecT.getSupertype(StateSpec.class);

    // First type argument of StateSpec; it may still mention type variables of the DoFn.
    ParameterizedType parameterizedSpec = (ParameterizedType) stateSpecT.getType();
    Type unresolvedStateType = parameterizedSpec.getActualTypeArguments()[0];

    // By static typing this is already a well-formed State subclass.
    TypeDescriptor<? extends State> resolvedStateT =
        (TypeDescriptor<? extends State>)
            TypeDescriptor.of(fnClazz).resolveType(unresolvedStateType);

    statesById.put(
        stateId, DoFnSignature.StateDeclaration.create(stateId, stateField, resolvedStateT));
  }
  return ImmutableMap.copyOf(statesById);
}
/**
 * Finds the method annotated with {@code anno} among the methods reachable from {@code
 * fnClazz}.
 *
 * <p>When several annotated methods are found they must all share the name and parameter types
 * of the first one. The first match must be public and non-static.
 *
 * @return the first annotated method, or {@code null} when none exists and {@code required} is
 *     false
 */
@Nullable
private static Method findAnnotatedMethod(
    ErrorReporter errors, Class<? extends Annotation> anno, Class<?> fnClazz, boolean required) {
  Collection<Method> candidates = declaredMethodsWithAnnotation(anno, fnClazz, DoFn.class);
  if (candidates.isEmpty()) {
    errors.checkArgument(!required, "No method annotated with @%s found", anno.getSimpleName());
    return null;
  }

  // Every match must agree with the first one on name and parameter types.
  Method primary = candidates.iterator().next();
  for (Method candidate : candidates) {
    boolean sameSignature =
        primary.getName().equals(candidate.getName())
            && Arrays.equals(primary.getParameterTypes(), candidate.getParameterTypes());
    errors.checkArgument(
        sameSignature,
        "Found multiple methods annotated with @%s. [%s] and [%s]",
        anno.getSimpleName(),
        format(primary),
        format(candidate));
  }

  ErrorReporter methodErrors = errors.forMethod(anno, primary);
  int modifiers = primary.getModifiers();
  // The method has to be callable, so it must be public...
  methodErrors.checkArgument(Modifier.isPublic(modifiers), "Must be public");
  // ...and it must be an instance method.
  methodErrors.checkArgument(!Modifier.isStatic(modifiers), "Must not be static");
  return primary;
}
/** Renders {@code method} as text using {@code ReflectHelpers.METHOD_FORMATTER}. */
private static String format(Method method) {
return ReflectHelpers.METHOD_FORMATTER.apply(method);
}
/**
 * Renders the type wrapped by {@code t} as text using {@code
 * ReflectHelpers.TYPE_SIMPLE_DESCRIPTION}.
 */
private static String formatType(TypeDescriptor<?> t) {
return ReflectHelpers.TYPE_SIMPLE_DESCRIPTION.apply(t.getType());
}
/**
 * Raises {@link IllegalArgumentException}s whose messages are prefixed with a label describing
 * what is being analyzed. Derived reporters append a method or parameter description to the
 * label of the reporter they were created from.
 */
static class ErrorReporter {
  private final String label;

  /**
   * Creates a reporter labelled {@code label}, prefixed by the label of {@code root} when one is
   * given.
   */
  ErrorReporter(@Nullable ErrorReporter root, String label) {
    if (root == null) {
      this.label = label;
    } else {
      this.label = String.format("%s, %s", root.label, label);
    }
  }

  /** Returns a reporter scoped to {@code method}, which may be {@code null} when absent. */
  ErrorReporter forMethod(Class<? extends Annotation> annotation, Method method) {
    String annotationName = annotation.getSimpleName();
    String methodText = (method == null) ? "(absent)" : format(method);
    return new ErrorReporter(this, String.format("@%s %s", annotationName, methodText));
  }

  /** Returns a reporter scoped to {@code param}, identified by its type and index. */
  ErrorReporter forParameter(ParameterDescription param) {
    String scope =
        String.format(
            "parameter of type %s at index %s", formatType(param.getType()), param.getIndex());
    return new ErrorReporter(this, scope);
  }

  /** Always throws an {@link IllegalArgumentException} carrying the label and the message. */
  void throwIllegalArgument(String message, Object... args) {
    String detail = String.format(message, args);
    throw new IllegalArgumentException(label + ": " + detail);
  }

  /** Throws through {@link #throwIllegalArgument} unless {@code condition} holds. */
  public void checkArgument(boolean condition, String message, Object... args) {
    if (condition) {
      return;
    }
    throwIllegalArgument(message, args);
  }

  /** Throws through {@link #throwIllegalArgument} when {@code value} is {@code null}. */
  public void checkNotNull(Object value, String message, Object... args) {
    if (value != null) {
      return;
    }
    throwIllegalArgument(message, args);
  }
}
/**
 * Reads the {@link StateSpec} held by the field of {@code stateDeclaration} on {@code target}.
 *
 * <p>The field is read reflectively once and that value is both validated and returned.
 *
 * @param stateDeclaration the declaration whose field is read
 * @param target the {@link DoFn} instance to read the field from
 * @return the value of the field, cast to {@link StateSpec}
 * @throws IllegalStateException if the field value is not a {@link StateSpec}
 * @throws RuntimeException if the field is not accessible; the underlying {@link
 *     IllegalAccessException} is attached as the cause
 */
public static StateSpec<?> getStateSpecOrThrow(
    StateDeclaration stateDeclaration, DoFn<?, ?> target) {
  try {
    Object fieldValue = stateDeclaration.field().get(target);
    checkState(
        fieldValue instanceof StateSpec,
        "Malformed %s class %s: state declaration field %s does not have type %s.",
        DoFn.class.getSimpleName(),
        target.getClass().getName(),
        stateDeclaration.field().getName(),
        StateSpec.class);
    return (StateSpec<?>) fieldValue;
  } catch (IllegalAccessException exc) {
    // Keep the original exception as the cause so the access failure stays diagnosable.
    throw new RuntimeException(
        String.format(
            "Malformed %s class %s: state declaration field %s is not accessible.",
            DoFn.class.getSimpleName(),
            target.getClass().getName(),
            stateDeclaration.field().getName()),
        exc);
  }
}
/**
 * Reads the {@link TimerSpec} held by the field of {@code timerDeclaration} on {@code target}.
 *
 * <p>The field is read reflectively once and that value is both validated and returned.
 *
 * @param timerDeclaration the declaration whose field is read
 * @param target the {@link DoFn} instance to read the field from
 * @return the value of the field, cast to {@link TimerSpec}
 * @throws IllegalStateException if the field value is not a {@link TimerSpec}
 * @throws RuntimeException if the field is not accessible; the underlying {@link
 *     IllegalAccessException} is attached as the cause
 */
public static TimerSpec getTimerSpecOrThrow(
    TimerDeclaration timerDeclaration, DoFn<?, ?> target) {
  try {
    Object fieldValue = timerDeclaration.field().get(target);
    checkState(
        fieldValue instanceof TimerSpec,
        "Malformed %s class %s: timer declaration field %s does not have type %s.",
        DoFn.class.getSimpleName(),
        target.getClass().getName(),
        timerDeclaration.field().getName(),
        TimerSpec.class);
    return (TimerSpec) fieldValue;
  } catch (IllegalAccessException exc) {
    // Keep the original exception as the cause so the access failure stays diagnosable.
    throw new RuntimeException(
        String.format(
            "Malformed %s class %s: timer declaration field %s is not accessible.",
            DoFn.class.getSimpleName(),
            target.getClass().getName(),
            timerDeclaration.field().getName()),
        exc);
  }
}
}