blob: 0184e58dcddda024bf1b473e67fcf4a2a5f75a55 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.SerializedListAccumulator;
import org.apache.flink.api.common.accumulators.SimpleAccumulator;
import org.apache.flink.api.common.io.RichOutputFormat;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.java.typeutils.GenericTypeInfo;
import org.apache.flink.configuration.Configuration;

import org.apache.commons.lang3.StringUtils;

import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

import static org.apache.flink.api.java.functions.FunctionAnnotation.SkipCodeAnalysis;
/**
 * Utility class that contains helper methods to work with Java APIs.
 */
@Internal
public final class Utils {

	/** Shared {@link Random} instance exposed to users of this utility class. */
	public static final Random RNG = new Random();

	/**
	 * Returns a short description of the code location that called into the API, formatted as
	 * {@code method(File.java:line)}.
	 *
	 * <p>Looks at stack depth 4. Assuming the usual JVM layout, where index 0 is
	 * {@code Thread#getStackTrace}, this skips both {@code getCallLocationName} overloads and the
	 * API method that invoked this helper. TODO(review): confirm on non-HotSpot JVMs.
	 *
	 * @return the formatted call location, or {@code "<unknown>"} if the stack is too shallow
	 */
	public static String getCallLocationName() {
		return getCallLocationName(4);
	}

	/**
	 * Returns the frame at the given index of the current thread's stack trace, formatted as
	 * {@code method(File.java:line)}.
	 *
	 * @param depth index into {@link Thread#getStackTrace()}
	 * @return the formatted frame, or {@code "<unknown>"} if there is no frame at {@code depth}
	 */
	public static String getCallLocationName(int depth) {
		StackTraceElement[] stackTrace = Thread.currentThread().getStackTrace();
		// Treat negative depths like too-deep ones instead of throwing an index exception.
		if (depth < 0 || stackTrace.length <= depth) {
			return "<unknown>";
		}
		StackTraceElement elem = stackTrace[depth];
		return String.format("%s(%s:%d)", elem.getMethodName(), elem.getFileName(), elem.getLineNumber());
	}

	// --------------------------------------------------------------------------------------------

	/**
	 * Utility sink function that counts elements and writes the count into an accumulator,
	 * from which it can be retrieved by the client. This sink is used by the
	 * {@link DataSet#count()} function.
	 *
	 * @param <T> Type of elements to count.
	 */
	@SkipCodeAnalysis
	public static class CountHelper<T> extends RichOutputFormat<T> {

		private static final long serialVersionUID = 1L;

		/** Name of the long counter that receives this instance's count. */
		private final String id;

		/** Number of records written to this instance so far. */
		private long counter;

		public CountHelper(String id) {
			this.id = id;
			this.counter = 0L;
		}

		@Override
		public void configure(Configuration parameters) {}

		@Override
		public void open(int taskNumber, int numTasks) {}

		@Override
		public void writeRecord(T record) {
			counter++;
		}

		@Override
		public void close() {
			// The count is published once, on close, rather than per record.
			getRuntimeContext().getLongCounter(id).add(counter);
		}
	}

	/**
	 * Utility sink function that collects elements into an accumulator,
	 * from which they can be retrieved by the client. This sink is used by the
	 * {@link DataSet#collect()} function.
	 *
	 * @param <T> Type of elements to collect.
	 */
	@SkipCodeAnalysis
	public static class CollectHelper<T> extends RichOutputFormat<T> {

		private static final long serialVersionUID = 1L;

		/** Name under which the accumulator is registered. */
		private final String id;

		/** Serializer used to store each record in the accumulator. */
		private final TypeSerializer<T> serializer;

		/** Created in {@link #open(int, int)}; holds the serialized records. */
		private SerializedListAccumulator<T> accumulator;

		public CollectHelper(String id, TypeSerializer<T> serializer) {
			this.id = id;
			this.serializer = serializer;
		}

		@Override
		public void configure(Configuration parameters) {}

		@Override
		public void open(int taskNumber, int numTasks) {
			this.accumulator = new SerializedListAccumulator<>();
		}

		@Override
		public void writeRecord(T record) throws IOException {
			accumulator.add(record, serializer);
		}

		@Override
		public void close() {
			// Important: should only be added in close method to minimize traffic of accumulators
			getRuntimeContext().addAccumulator(id, accumulator);
		}
	}

	/**
	 * Accumulator that holds a record count and a checksum (sum of record hash codes).
	 * {@link #add} and {@link #merge} add both components; long overflow wraps around.
	 * Instances are created and registered by {@link ChecksumHashCodeHelper}.
	 */
	public static class ChecksumHashCode implements SimpleAccumulator<ChecksumHashCode> {

		private static final long serialVersionUID = 1L;

		private long count;

		private long checksum;

		/** Creates an empty accumulator (count and checksum both zero). */
		public ChecksumHashCode() {}

		public ChecksumHashCode(long count, long checksum) {
			this.count = count;
			this.checksum = checksum;
		}

		public long getCount() {
			return count;
		}

		public long getChecksum() {
			return checksum;
		}

		@Override
		public void add(ChecksumHashCode value) {
			this.count += value.count;
			this.checksum += value.checksum;
		}

		/** Returns this instance itself, not a copy. */
		@Override
		public ChecksumHashCode getLocalValue() {
			return this;
		}

		@Override
		public void resetLocal() {
			this.count = 0;
			this.checksum = 0;
		}

		@Override
		public void merge(Accumulator<ChecksumHashCode, ChecksumHashCode> other) {
			this.add(other.getLocalValue());
		}

		@Override
		public ChecksumHashCode clone() {
			return new ChecksumHashCode(count, checksum);
		}

		@Override
		public boolean equals(Object obj) {
			if (obj instanceof ChecksumHashCode) {
				ChecksumHashCode other = (ChecksumHashCode) obj;
				return this.count == other.count && this.checksum == other.checksum;
			} else {
				return false;
			}
		}

		@Override
		public int hashCode() {
			// Derived only from the fields compared in equals(), so equal objects hash equally.
			return (int) (this.count + this.checksum);
		}

		@Override
		public String toString() {
			return String.format("ChecksumHashCode 0x%016x, count %d", this.checksum, this.count);
		}
	}

	/**
	 * {@link RichOutputFormat} that counts records and sums their hash codes. On close it
	 * registers the result as a {@link ChecksumHashCode} accumulator.
	 *
	 * @param <T> Type of elements whose hash codes are summed.
	 */
	@SkipCodeAnalysis
	public static class ChecksumHashCodeHelper<T> extends RichOutputFormat<T> {

		private static final long serialVersionUID = 1L;

		/** Name under which the resulting accumulator is registered. */
		private final String id;

		private long counter;

		private long checksum;

		public ChecksumHashCodeHelper(String id) {
			this.id = id;
			this.counter = 0L;
			this.checksum = 0L;
		}

		@Override
		public void configure(Configuration parameters) {}

		@Override
		public void open(int taskNumber, int numTasks) {}

		@Override
		public void writeRecord(T record) throws IOException {
			counter++;
			// convert 32-bit integer to non-negative long
			checksum += record.hashCode() & 0xffffffffL;
		}

		@Override
		public void close() throws IOException {
			ChecksumHashCode update = new ChecksumHashCode(counter, checksum);
			getRuntimeContext().addAccumulator(id, update);
		}
	}

	// --------------------------------------------------------------------------------------------

	/**
	 * Debugging utility to understand the hierarchy of serializers created by the Java API.
	 * Tested in GroupReduceITCase.testGroupByGenericType()
	 *
	 * @param ti the type information to describe
	 * @return a multi-line, indented description of {@code ti} and its nested types
	 */
	public static <T> String getSerializerTree(TypeInformation<T> ti) {
		return getSerializerTree(ti, 0);
	}

	private static <T> String getSerializerTree(TypeInformation<T> ti, int indent) {
		StringBuilder out = new StringBuilder();
		appendSerializerTree(ti, indent, out);
		return out.toString();
	}

	/**
	 * Appends the description of {@code ti} to {@code out}. Composite types list their fields,
	 * generic types list their declared fields via reflection, and all other types use
	 * {@code toString()}.
	 */
	private static void appendSerializerTree(TypeInformation<?> ti, int indent, StringBuilder out) {
		if (ti instanceof CompositeType) {
			out.append(StringUtils.repeat(' ', indent)).append(ti.getClass().getSimpleName()).append('\n');
			CompositeType<?> cti = (CompositeType<?>) ti;
			String[] fieldNames = cti.getFieldNames();
			for (int i = 0; i < cti.getArity(); i++) {
				out.append(StringUtils.repeat(' ', indent + 2)).append(fieldNames[i]).append(':');
				// NOTE(review): the nested type is rendered with the parent's indent (not indent + 2),
				// so deeper composite levels are not indented further - confirm this is intended.
				appendSerializerTree(cti.getTypeAt(i), indent, out);
			}
		} else if (ti instanceof GenericTypeInfo) {
			out.append(StringUtils.repeat(' ', indent))
				.append("GenericTypeInfo (")
				.append(ti.getTypeClass().getSimpleName())
				.append(")\n");
			appendGenericTypeTree(ti.getTypeClass(), indent + 4, new HashSet<Class<?>>(), out);
		} else {
			out.append(StringUtils.repeat(' ', indent)).append(ti.toString()).append('\n');
		}
	}

	/**
	 * Appends one line per non-static, non-transient field declared directly by {@code type}
	 * (inherited fields are not listed). It then recurses into each non-primitive field type.
	 *
	 * @param type the class whose declared fields are listed
	 * @param indent number of spaces before each field line at this level
	 * @param path classes currently being expanded on the recursion path. A field whose type is
	 *     already on the path is marked {@code (recursive)} and not expanded again. Without this
	 *     guard, self-referential types would recurse until {@link StackOverflowError}.
	 * @param out the builder the lines are appended to
	 */
	private static void appendGenericTypeTree(Class<?> type, int indent, Set<Class<?>> path, StringBuilder out) {
		path.add(type);
		for (Field field : type.getDeclaredFields()) {
			int modifiers = field.getModifiers();
			if (Modifier.isStatic(modifiers) || Modifier.isTransient(modifiers)) {
				continue;
			}
			Class<?> fieldType = field.getType();
			boolean expandable = !fieldType.isPrimitive();
			boolean recursive = expandable && path.contains(fieldType);
			out.append(StringUtils.repeat(' ', indent))
				.append(field.getName())
				.append(':')
				.append(fieldType.getName())
				.append(fieldType.isEnum() ? " (is enum)" : "")
				.append(recursive ? " (recursive)" : "")
				.append('\n');
			if (expandable && !recursive) {
				appendGenericTypeTree(fieldType, indent + 4, path, out);
			}
		}
		// Remove on the way out so sibling fields of the same type are still expanded.
		path.remove(type);
	}

	/**
	 * Private constructor to prevent instantiation.
	 */
	private Utils() {
		throw new RuntimeException("Utils is a utility class and must not be instantiated.");
	}
}