/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| package org.apache.flink.api.common.functions.util; |
| |
| import org.apache.flink.annotation.Internal; |
| import org.apache.flink.annotation.PublicEvolving; |
| import org.apache.flink.annotation.VisibleForTesting; |
| import org.apache.flink.api.common.ExecutionConfig; |
| import org.apache.flink.api.common.TaskInfo; |
| import org.apache.flink.api.common.accumulators.AbstractAccumulatorRegistry; |
| import org.apache.flink.api.common.accumulators.Accumulator; |
| import org.apache.flink.api.common.accumulators.AccumulatorHelper; |
| import org.apache.flink.api.common.accumulators.DoubleCounter; |
| import org.apache.flink.api.common.accumulators.Histogram; |
| import org.apache.flink.api.common.accumulators.IntCounter; |
| import org.apache.flink.api.common.accumulators.LongCounter; |
| import org.apache.flink.api.common.cache.DistributedCache; |
| import org.apache.flink.api.common.functions.RuntimeContext; |
| import org.apache.flink.api.common.state.AggregatingState; |
| import org.apache.flink.api.common.state.AggregatingStateDescriptor; |
| import org.apache.flink.api.common.state.FoldingState; |
| import org.apache.flink.api.common.state.FoldingStateDescriptor; |
| import org.apache.flink.api.common.state.ListState; |
| import org.apache.flink.api.common.state.ListStateDescriptor; |
| import org.apache.flink.api.common.state.MapState; |
| import org.apache.flink.api.common.state.MapStateDescriptor; |
| import org.apache.flink.api.common.state.ReducingState; |
| import org.apache.flink.api.common.state.ReducingStateDescriptor; |
| import org.apache.flink.api.common.state.SortedMapState; |
| import org.apache.flink.api.common.state.SortedMapStateDescriptor; |
| import org.apache.flink.api.common.state.ValueState; |
| import org.apache.flink.api.common.state.ValueStateDescriptor; |
| import org.apache.flink.core.fs.Path; |
| import org.apache.flink.metrics.MetricGroup; |
| |
| import java.io.Serializable; |
| import java.util.Map; |
| import java.util.concurrent.CompletableFuture; |
| import java.util.concurrent.Future; |
| |
| import static org.apache.flink.util.Preconditions.checkNotNull; |
| |
| /** |
| * A standalone implementation of the {@link RuntimeContext}, created by runtime UDF operators. |
| */ |
| @Internal |
| public abstract class AbstractRuntimeUDFContext implements RuntimeContext { |
| |
| private final TaskInfo taskInfo; |
| |
| private final ClassLoader userCodeClassLoader; |
| |
| private final ExecutionConfig executionConfig; |
| |
| private final AbstractAccumulatorRegistry accumulatorRegistry; |
| |
| private final DistributedCache distributedCache; |
| |
| private final MetricGroup metrics; |
| |
| public AbstractRuntimeUDFContext(TaskInfo taskInfo, |
| ClassLoader userCodeClassLoader, |
| ExecutionConfig executionConfig, |
| AbstractAccumulatorRegistry accumulatorRegistry, |
| Map<String, Future<Path>> cpTasks, |
| MetricGroup metrics) { |
| this.taskInfo = checkNotNull(taskInfo); |
| this.userCodeClassLoader = userCodeClassLoader; |
| this.executionConfig = executionConfig; |
| this.distributedCache = new DistributedCache(checkNotNull(cpTasks)); |
| this.accumulatorRegistry = checkNotNull(accumulatorRegistry); |
| this.metrics = metrics; |
| } |
| |
| @Override |
| public ExecutionConfig getExecutionConfig() { |
| return executionConfig; |
| } |
| |
| @Override |
| public String getTaskName() { |
| return taskInfo.getTaskName(); |
| } |
| |
| @Override |
| public int getNumberOfParallelSubtasks() { |
| return taskInfo.getNumberOfParallelSubtasks(); |
| } |
| |
| @Override |
| public int getMaxNumberOfParallelSubtasks() { |
| return taskInfo.getMaxNumberOfParallelSubtasks(); |
| } |
| |
| @Override |
| public int getIndexOfThisSubtask() { |
| return taskInfo.getIndexOfThisSubtask(); |
| } |
| |
| @Override |
| public MetricGroup getMetricGroup() { |
| return metrics; |
| } |
| |
| @Override |
| public int getAttemptNumber() { |
| return taskInfo.getAttemptNumber(); |
| } |
| |
| @Override |
| public String getTaskNameWithSubtasks() { |
| return taskInfo.getTaskNameWithSubtasks(); |
| } |
| |
| @Override |
| public IntCounter getIntCounter(String name) { |
| return (IntCounter) getAccumulator(name, IntCounter.class); |
| } |
| |
| @Override |
| public LongCounter getLongCounter(String name) { |
| return (LongCounter) getAccumulator(name, LongCounter.class); |
| } |
| |
| @Override |
| public Histogram getHistogram(String name) { |
| return (Histogram) getAccumulator(name, Histogram.class); |
| } |
| |
| @Override |
| public DoubleCounter getDoubleCounter(String name) { |
| return (DoubleCounter) getAccumulator(name, DoubleCounter.class); |
| } |
| |
| @Override |
| public <V, A extends Serializable> void addAccumulator(String name, Accumulator<V, A> accumulator) { |
| accumulatorRegistry.addAccumulator(name, accumulator); |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Override |
| public <V, A extends Serializable> Accumulator<V, A> getAccumulator(String name) { |
| return (Accumulator<V, A>) accumulatorRegistry.getAccumulators().get(name); |
| } |
| |
| @Override |
| public Map<String, Accumulator<?, ?>> getAllAccumulators() { |
| return accumulatorRegistry.getAccumulators(); |
| } |
| |
| @Override |
| public <V, A extends Serializable> void addPreAggregatedAccumulator(String name, Accumulator<V, A> accumulator) { |
| accumulatorRegistry.addPreAggregatedAccumulator(name, accumulator); |
| } |
| |
| @SuppressWarnings("unchecked") |
| @Override |
| public <V, A extends Serializable> Accumulator<V, A> getPreAggregatedAccumulator(String name) { |
| return (Accumulator<V, A>) accumulatorRegistry.getPreAggregatedAccumulators().get(name); |
| } |
| |
| @Override |
| public void commitPreAggregatedAccumulator(String name) { |
| accumulatorRegistry.commitPreAggregatedAccumulator(name); |
| } |
| |
| @Override |
| public <V, A extends Serializable> CompletableFuture<Accumulator<V, A>> queryPreAggregatedAccumulator(String name) { |
| return accumulatorRegistry.queryPreAggregatedAccumulator(name); |
| } |
| |
| @Override |
| public ClassLoader getUserCodeClassLoader() { |
| return this.userCodeClassLoader; |
| } |
| |
| @Override |
| public DistributedCache getDistributedCache() { |
| return this.distributedCache; |
| } |
| |
| // -------------------------------------------------------------------------------------------- |
| |
| @SuppressWarnings("unchecked") |
| private <V, A extends Serializable> Accumulator<V, A> getAccumulator(String name, |
| Class<? extends Accumulator<V, A>> accumulatorClass) |
| { |
| Accumulator<?, ?> accumulator = accumulatorRegistry.getAccumulators().get(name); |
| |
| if (accumulator != null) { |
| AccumulatorHelper.compareAccumulatorTypes(name, accumulator.getClass(), accumulatorClass); |
| } else { |
| // Create new accumulator |
| try { |
| accumulator = accumulatorClass.newInstance(); |
| } |
| catch (Exception e) { |
| throw new RuntimeException("Cannot create accumulator " + accumulatorClass.getName()); |
| } |
| accumulatorRegistry.addAccumulator(name, accumulator); |
| } |
| return (Accumulator<V, A>) accumulator; |
| } |
| |
| @Override |
| @PublicEvolving |
| public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) { |
| throw new UnsupportedOperationException( |
| "This state is only accessible by functions executed on a KeyedStream"); |
| } |
| |
| @Override |
| @PublicEvolving |
| public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) { |
| throw new UnsupportedOperationException( |
| "This state is only accessible by functions executed on a KeyedStream"); |
| } |
| |
| @Override |
| @PublicEvolving |
| public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties) { |
| throw new UnsupportedOperationException( |
| "This state is only accessible by functions executed on a KeyedStream"); |
| } |
| |
| @Override |
| @PublicEvolving |
| public <IN, ACC, OUT> AggregatingState<IN, OUT> getAggregatingState(AggregatingStateDescriptor<IN, ACC, OUT> stateProperties) { |
| throw new UnsupportedOperationException( |
| "This state is only accessible by functions executed on a KeyedStream"); |
| } |
| |
| @Override |
| @PublicEvolving |
| @Deprecated |
| public <T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties) { |
| throw new UnsupportedOperationException( |
| "This state is only accessible by functions executed on a KeyedStream"); |
| } |
| |
| @Override |
| @PublicEvolving |
| public <UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties) { |
| throw new UnsupportedOperationException( |
| "This state is only accessible by functions executed on a KeyedStream"); |
| } |
| |
| @Override |
| public <UK, UV> SortedMapState<UK, UV> getSortedMapState(SortedMapStateDescriptor<UK, UV> stateProperties) { |
| throw new UnsupportedOperationException( |
| "This state is only accessible by functions executed on a KeyedStream"); |
| } |
| |
| @Internal |
| @VisibleForTesting |
| public String getAllocationIDAsString() { |
| return taskInfo.getAllocationIDAsString(); |
| } |
| } |