blob: eae9e3c2e207e6622904227ff0c58adc397702f4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.test.accumulators;
import org.apache.flink.api.common.JobExecutionResult;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.AccumulatorHelper;
import org.apache.flink.api.common.accumulators.DoubleCounter;
import org.apache.flink.api.common.accumulators.Histogram;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.test.util.JavaProgramTestBase;
import org.apache.flink.types.StringValue;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* Test for the basic functionality of accumulators. We cannot test all different kinds of plans
* here (iterative, etc.).
*
* <p>TODO Test conflict when different UDFs write to accumulator with same name but with different
* type. The conflict will occur in JobManager while merging.
*/
@SuppressWarnings("serial")
public class AccumulatorITCase extends JavaProgramTestBase {

    /** Three lines containing 1, 2 and 3 occurrences of a word, respectively. */
    private static final String INPUT = "one\n" + "two two\n" + "three three three\n";

    /** Expected word-count output for {@link #INPUT}. */
    private static final String EXPECTED = "one 1\ntwo 2\nthree 3\n";

    private String dataPath;
    private String resultPath;

    /** Execution result captured in {@link #testProgram()} and inspected in {@link #postSubmit()}. */
    private JobExecutionResult result;

    @Override
    protected void preSubmit() throws Exception {
        dataPath = createTempFile("datapoints.txt", INPUT);
        resultPath = getTempFilePath("result");
    }

    @Override
    protected void postSubmit() throws Exception {
        compareResultsByLinesInMemory(EXPECTED, resultPath);

        // Test accumulator results
        System.out.println("Accumulator results:");
        JobExecutionResult res = this.result;
        System.out.println(AccumulatorHelper.getResultsFormatted(res.getAllAccumulatorResults()));

        // "num-lines" counts one increment per input line, across all subtasks.
        Assert.assertEquals(Integer.valueOf(3), res.getAccumulatorResult("num-lines"));
        Assert.assertEquals(Integer.valueOf(3), res.getIntCounterResult("num-lines"));

        // Each parallel subtask adds 0.5 in open() and 0.5 in close(), so the
        // merged total equals the parallelism.
        Assert.assertEquals(
                Double.valueOf(getParallelism()), res.getAccumulatorResult("open-close-counter"));

        // Test histogram (words per line distribution): one line each with 1, 2 and 3 words.
        Map<Integer, Integer> dist = new HashMap<>();
        dist.put(1, 1);
        dist.put(2, 1);
        dist.put(3, 1);
        Assert.assertEquals(dist, res.getAccumulatorResult("words-per-line"));

        // Test distinct words (custom accumulator)
        Set<StringValue> distinctWords = new HashSet<>();
        distinctWords.add(new StringValue("one"));
        distinctWords.add(new StringValue("two"));
        distinctWords.add(new StringValue("three"));
        Assert.assertEquals(distinctWords, res.getAccumulatorResult("distinct-words"));
    }

    @Override
    protected void testProgram() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<String> input = env.readTextFile(dataPath);

        input.flatMap(new TokenizeLine())
                .groupBy(0)
                .reduceGroup(new CountWords())
                .writeAsCsv(resultPath, "\n", " ");

        // Keep the execution result so postSubmit() can inspect the accumulators.
        this.result = env.execute();
    }

    /**
     * Tokenizes each input line into (word, 1) pairs while exercising every accumulator
     * registration path: convenience getters, explicit registration of a built-in
     * accumulator, and registration of a custom accumulator.
     */
    private static class TokenizeLine extends RichFlatMapFunction<String, Tuple2<String, Integer>> {

        // Needs to be instantiated later since the runtime context is not yet
        // initialized at this place
        private IntCounter cntNumLines;
        private Histogram wordsPerLineDistribution;

        // This counter will be added without convenience functions
        private DoubleCounter openCloseCounter = new DoubleCounter();

        private SetAccumulator<StringValue> distinctWords;

        @Override
        public void open(Configuration parameters) {

            // Add counters using convenience functions
            this.cntNumLines = getRuntimeContext().getIntCounter("num-lines");
            this.wordsPerLineDistribution = getRuntimeContext().getHistogram("words-per-line");

            // Add built-in accumulator without convenience function
            getRuntimeContext().addAccumulator("open-close-counter", this.openCloseCounter);

            // Add custom counter
            this.distinctWords = new SetAccumulator<>();
            this.getRuntimeContext().addAccumulator("distinct-words", distinctWords);

            // Create counter and test increment
            IntCounter simpleCounter = getRuntimeContext().getIntCounter("simple-counter");
            simpleCounter.add(1);
            // JUnit's contract is assertEquals(expected, actual); the expected value (1)
            // must come first, otherwise a failure prints a misleading message.
            Assert.assertEquals(1, simpleCounter.getLocalValue().intValue());

            // Test if we get the same counter
            IntCounter simpleCounter2 = getRuntimeContext().getIntCounter("simple-counter");
            Assert.assertEquals(simpleCounter.getLocalValue(), simpleCounter2.getLocalValue());

            // Should fail if we request it with different type
            try {
                @SuppressWarnings("unused")
                DoubleCounter simpleCounter3 =
                        getRuntimeContext().getDoubleCounter("simple-counter");
                // DoubleSumAggregator longAggregator3 = (DoubleSumAggregator)
                // getRuntimeContext().getAggregator("custom",
                // DoubleSumAggregator.class);
                Assert.fail(
                        "Should not be able to obtain previously created counter with different type");
            } catch (UnsupportedOperationException ex) {
                // expected!
            }

            // Test counter used in open() and closed()
            this.openCloseCounter.add(0.5);
        }

        @Override
        public void flatMap(String value, Collector<Tuple2<String, Integer>> out) {
            this.cntNumLines.add(1);
            int wordsPerLine = 0;

            for (String token : value.toLowerCase().split("\\W+")) {
                distinctWords.add(new StringValue(token));
                out.collect(new Tuple2<>(token, 1));
                ++wordsPerLine;
            }
            wordsPerLineDistribution.add(wordsPerLine);
        }

        @Override
        public void close() throws Exception {
            // Test counter used in open and close only
            this.openCloseCounter.add(0.5);
            // 0.5 from open() + 0.5 here = 1.0 per subtask.
            Assert.assertEquals(1, this.openCloseCounter.getLocalValue().intValue());
        }
    }

    /**
     * Sums up the counts per word and tracks how often reduce() and combine() are
     * invoked via two additional counters.
     */
    private static class CountWords
            extends RichGroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>
            implements GroupCombineFunction<Tuple2<String, Integer>, Tuple2<String, Integer>> {

        private IntCounter reduceCalls;
        private IntCounter combineCalls;

        @Override
        public void open(Configuration parameters) {
            this.reduceCalls = getRuntimeContext().getIntCounter("reduce-calls");
            this.combineCalls = getRuntimeContext().getIntCounter("combine-calls");
        }

        @Override
        public void reduce(
                Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
            reduceCalls.add(1);
            reduceInternal(values, out);
        }

        @Override
        public void combine(
                Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
            combineCalls.add(1);
            reduceInternal(values, out);
        }

        /** Shared sum logic: emits (key, sum of counts) for the group. */
        private void reduceInternal(
                Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
            int sum = 0;
            String key = null;
            for (Tuple2<String, Integer> e : values) {
                key = e.f0;
                sum += e.f1;
            }
            out.collect(new Tuple2<>(key, sum));
        }
    }

    /** Custom accumulator that collects the set of distinct values added to it. */
    public static class SetAccumulator<T> implements Accumulator<T, HashSet<T>> {

        private static final long serialVersionUID = 1L;

        private HashSet<T> set = new HashSet<>();

        @Override
        public void add(T value) {
            this.set.add(value);
        }

        @Override
        public HashSet<T> getLocalValue() {
            return this.set;
        }

        @Override
        public void resetLocal() {
            this.set.clear();
        }

        @Override
        public void merge(Accumulator<T, HashSet<T>> other) {
            // build union
            this.set.addAll(other.getLocalValue());
        }

        @Override
        public Accumulator<T, HashSet<T>> clone() {
            SetAccumulator<T> result = new SetAccumulator<>();
            result.set.addAll(set);
            return result;
        }
    }
}