blob: 534d7a03650db73afc210dd9fd3445ee69ab09f2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.util.common.worker;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting;
/** Static utility methods that provide {@link GroupingTable} implementations. */
public class GroupingTables {
/**
 * Creates a {@link GroupingTable} that collects every value seen for a key into a {@link List}.
 *
 * <p>The table is capped at {@code DEFAULT_MAX_GROUPING_TABLE_BYTES} estimated bytes.
 *
 * @param groupingKeyCreator produces the object used to group equal keys together
 * @param pairInfo splits input pairs and builds output pairs
 * @param keySizer estimates the size of a key
 * @param valueSizer estimates the size of a single value
 * @return a new buffering grouping table
 */
public static <K, V> GroupingTable<K, V, List<V>> buffering(
    GroupingKeyCreator<? super K> groupingKeyCreator,
    PairInfo pairInfo,
    SizeEstimator<? super K> keySizer,
    SizeEstimator<? super V> valueSizer) {
  final long tableByteLimit = DEFAULT_MAX_GROUPING_TABLE_BYTES;
  GroupingTable<K, V, List<V>> bufferedTable =
      new BufferingGroupingTable<>(
          tableByteLimit, groupingKeyCreator, pairInfo, keySizer, valueSizer);
  return bufferedTable;
}
/**
 * Returns a {@link GroupingTable} with a given max size that groups inputs into a {@link List}.
 *
 * @param maxTableSizeBytes estimated size, in bytes, at which the table starts flushing entries;
 *     must not be {@code null}
 * @param groupingKeyCreator produces the object used to group equal keys together
 * @param pairInfo splits input pairs and builds output pairs
 * @param keySizer estimates the size of a key
 * @param valueSizer estimates the size of a single value
 * @return a new buffering grouping table
 * @throws NullPointerException if {@code maxTableSizeBytes} is {@code null}
 */
public static <K, V> GroupingTable<K, V, List<V>> buffering(
    Long maxTableSizeBytes,
    GroupingKeyCreator<? super K> groupingKeyCreator,
    PairInfo pairInfo,
    SizeEstimator<? super K> keySizer,
    SizeEstimator<? super V> valueSizer) {
  // The table stores a primitive long, so a null limit would otherwise fail during implicit
  // unboxing with no indication of which argument was at fault.
  if (maxTableSizeBytes == null) {
    throw new NullPointerException("maxTableSizeBytes must not be null");
  }
  return new BufferingGroupingTable<>(
      maxTableSizeBytes.longValue(), groupingKeyCreator, pairInfo, keySizer, valueSizer);
}
/**
 * Creates a {@link GroupingTable} that collects values into a {@link List}, wrapping both size
 * estimators in {@link SamplingSizeEstimator SamplingSizeEstimators}.
 *
 * <p>The table is capped at {@code DEFAULT_MAX_GROUPING_TABLE_BYTES} estimated bytes. Each
 * wrapper is built with {@code sizeEstimatorSampleRate} as its minimum sample rate and {@code
 * 1.0} as its maximum.
 *
 * @param groupingKeyCreator produces the object used to group equal keys together
 * @param pairInfo splits input pairs and builds output pairs
 * @param keySizer underlying estimator for key sizes
 * @param valueSizer underlying estimator for value sizes
 * @param sizeEstimatorSampleRate minimum sample rate handed to the sampling wrappers
 * @return a new buffering grouping table
 */
public static <K, V> GroupingTable<K, V, List<V>> bufferingAndSampling(
    GroupingKeyCreator<? super K> groupingKeyCreator,
    PairInfo pairInfo,
    SizeEstimator<? super K> keySizer,
    SizeEstimator<? super V> valueSizer,
    double sizeEstimatorSampleRate) {
  final double maxSampleRate = 1.0;
  // Key wrapper first, then value wrapper, same as the argument order of the table constructor.
  SizeEstimator<? super K> sampledKeySizer =
      new SamplingSizeEstimator<>(keySizer, sizeEstimatorSampleRate, maxSampleRate);
  SizeEstimator<? super V> sampledValueSizer =
      new SamplingSizeEstimator<>(valueSizer, sizeEstimatorSampleRate, maxSampleRate);
  return new BufferingGroupingTable<>(
      DEFAULT_MAX_GROUPING_TABLE_BYTES,
      groupingKeyCreator,
      pairInfo,
      sampledKeySizer,
      sampledValueSizer);
}
/**
 * Creates a {@link GroupingTable} that folds the inputs for each key into an accumulator.
 *
 * <p>The table is capped at {@code DEFAULT_MAX_GROUPING_TABLE_BYTES} estimated bytes.
 *
 * @param groupingKeyCreator produces the object used to group equal keys together
 * @param pairInfo splits input pairs and builds output pairs
 * @param combineFn creates, updates and compacts the per-key accumulator
 * @param keySizer estimates the size of a key
 * @param accumulatorSizer estimates the size of an accumulator
 * @return a new combining grouping table
 */
public static <K, InputT, AccumT> GroupingTable<K, InputT, AccumT> combining(
    GroupingKeyCreator<? super K> groupingKeyCreator,
    PairInfo pairInfo,
    Combiner<? super K, InputT, AccumT, ?> combineFn,
    SizeEstimator<? super K> keySizer,
    SizeEstimator<? super AccumT> accumulatorSizer) {
  final long tableByteLimit = DEFAULT_MAX_GROUPING_TABLE_BYTES;
  GroupingTable<K, InputT, AccumT> combiningTable =
      new CombiningGroupingTable<>(
          tableByteLimit, groupingKeyCreator, pairInfo, combineFn, keySizer, accumulatorSizer);
  return combiningTable;
}
/**
 * Creates a {@link GroupingTable} that folds the inputs for each key into an accumulator,
 * wrapping both size estimators in {@link SamplingSizeEstimator SamplingSizeEstimators}.
 *
 * <p>The table is capped at {@code DEFAULT_MAX_GROUPING_TABLE_BYTES} estimated bytes. Each
 * wrapper is built with {@code sizeEstimatorSampleRate} as its minimum sample rate and {@code
 * 1.0} as its maximum.
 *
 * @param groupingKeyCreator produces the object used to group equal keys together
 * @param pairInfo splits input pairs and builds output pairs
 * @param combineFn creates, updates and compacts the per-key accumulator
 * @param keySizer underlying estimator for key sizes
 * @param accumulatorSizer underlying estimator for accumulator sizes
 * @param sizeEstimatorSampleRate minimum sample rate handed to the sampling wrappers
 * @return a new combining grouping table
 */
public static <K, InputT, AccumT> GroupingTable<K, InputT, AccumT> combiningAndSampling(
    GroupingKeyCreator<? super K> groupingKeyCreator,
    PairInfo pairInfo,
    Combiner<? super K, InputT, AccumT, ?> combineFn,
    SizeEstimator<? super K> keySizer,
    SizeEstimator<? super AccumT> accumulatorSizer,
    double sizeEstimatorSampleRate) {
  final double maxSampleRate = 1.0;
  // Key wrapper first, then accumulator wrapper, same as the constructor's argument order.
  SizeEstimator<? super K> sampledKeySizer =
      new SamplingSizeEstimator<>(keySizer, sizeEstimatorSampleRate, maxSampleRate);
  SizeEstimator<? super AccumT> sampledAccumulatorSizer =
      new SamplingSizeEstimator<>(accumulatorSizer, sizeEstimatorSampleRate, maxSampleRate);
  return new CombiningGroupingTable<>(
      DEFAULT_MAX_GROUPING_TABLE_BYTES,
      groupingKeyCreator,
      pairInfo,
      combineFn,
      sampledKeySizer,
      sampledAccumulatorSizer);
}
/**
 * Client-supplied strategy that turns a key into the object a grouping table uses to decide
 * which keys belong together.
 *
 * @param <K> the type of key this creator accepts
 */
public interface GroupingKeyCreator<K> {
  /**
   * Returns the grouping key for {@code key}.
   *
   * @throws Exception if the grouping key cannot be produced
   */
  Object createGroupingKey(K key) throws Exception;
}
/**
 * Client-supplied strategy that reports an estimated size for an element.
 *
 * @param <T> the type of element this estimator accepts
 */
public interface SizeEstimator<T> {
  /**
   * Returns the estimated size of {@code element}.
   *
   * @throws Exception if the size cannot be estimated
   */
  long estimateSize(T element) throws Exception;
}
/**
 * Client-supplied adapter for elements that are key/value or key/values pairs: it takes input
 * pairs apart and assembles output pairs.
 */
public interface PairInfo {
  /** Returns the key component of the given input pair. */
  Object getKeyFromInputPair(Object pair);

  /** Returns the value component of the given input pair. */
  Object getValueFromInputPair(Object pair);

  /** Builds an output pair from the given key and value. */
  Object makeOutputPair(Object key, Object value);
}
/**
 * Client-supplied operations for combining the values of a key into an accumulator.
 *
 * @param <K> key type
 * @param <InputT> type of the values being added
 * @param <AccumT> type of the accumulator
 * @param <OutputT> type extracted from a finished accumulator
 */
public interface Combiner<K, InputT, AccumT, OutputT> {
  /** Returns a new accumulator for {@code key}. */
  AccumT createAccumulator(K key);

  /** Returns the accumulator that results from adding {@code value} to {@code accumulator}. */
  AccumT add(K key, AccumT accumulator, InputT value);

  /** Returns a single accumulator that merges all of {@code accumulators}. */
  AccumT merge(K key, Iterable<AccumT> accumulators);

  /** Returns a compacted form of {@code accumulator}; it may be the same instance. */
  AccumT compact(K key, AccumT accumulator);

  /** Returns the output derived from {@code accumulator}. */
  OutputT extract(K key, AccumT accumulator);
}
/**
 * Default limit, in estimated bytes, that a grouping table may reach before it has to start
 * flushing entries. Used by the factory methods that do not take an explicit maximum size.
 */
static final long DEFAULT_MAX_GROUPING_TABLE_BYTES = 100_000_000L;
/**
 * Number of bytes in a JVM word, computed once at class initialization by {@code
 * getBytesPerJvmWord()}.
 */
static final int BYTES_PER_JVM_WORD = getBytesPerJvmWord();
/**
 * The number of bytes of overhead to store an entry in the grouping table (a {@code
 * HashMap<StructuralByteArray, KeyAndValues>}), ignoring the actual number of bytes in the keys
 * and values:
 *
 * <ul>
 *   <li>an array element (1 word),
 *   <li>a HashMap.Entry (4 words),
 *   <li>a StructuralByteArray (1 word),
 *   <li>a backing array (guessed at 1 word for the length),
 *   <li>a KeyAndValues (2 words),
 *   <li>an ArrayList (2 words),
 *   <li>a backing array (1 word),
 *   <li>per-object overhead (JVM-specific, guessed at 2 words * 6 objects).
 * </ul>
 *
 * <p>The items above add up to the 24 words used below.
 *
 * <p>NOTE(review): the table field is declared as {@code Map<Object, GroupingTableEntry>}; the
 * type names in this breakdown do not match it - confirm the estimate still applies.
 */
static final int PER_KEY_OVERHEAD = 24 * BYTES_PER_JVM_WORD;
/**
 * Skeleton {@link GroupingTable} that supplies {@link #put} and {@link #flush} and keeps a
 * running estimate of the table's size.
 *
 * <p>Subclasses only decide how an individual entry accumulates values, by overriding {@link
 * #createTableEntry}.
 */
@VisibleForTesting
public abstract static class GroupingTableBase<K, InputT, AccumT>
    implements GroupingTable<K, InputT, AccumT> {
  // Keep the table relatively full to increase the chance of collisions.
  private static final double TARGET_LOAD = 0.9;

  private final GroupingKeyCreator<? super K> groupingKeyCreator;
  private final PairInfo pairInfo;
  private final Map<Object, GroupingTableEntry<K, InputT, AccumT>> table = new HashMap<>();
  private long maxSize;
  // Running estimate: the size of every entry plus PER_KEY_OVERHEAD for each of them.
  private long size = 0;

  private GroupingTableBase(
      long maxSize, GroupingKeyCreator<? super K> groupingKeyCreator, PairInfo pairInfo) {
    this.maxSize = maxSize;
    this.groupingKeyCreator = groupingKeyCreator;
    this.pairInfo = pairInfo;
  }

  /** One slot of the table: a key together with whatever has been accumulated for it. */
  interface GroupingTableEntry<K, InputT, AccumT> {
    public K getKey();

    public AccumT getValue();

    public void add(InputT value) throws Exception;

    public long getSize();

    public void compact() throws Exception;
  }

  /** Creates the (still empty) entry that will accumulate values for {@code key}. */
  public abstract GroupingTableEntry<K, InputT, AccumT> createTableEntry(K key) throws Exception;

  /**
   * Splits {@code pair} with the configured {@link PairInfo} and adds it to this table, possibly
   * flushing some entries to {@code receiver} if the table is full.
   */
  @SuppressWarnings("unchecked")
  @Override
  public void put(Object pair, Receiver receiver) throws Exception {
    K key = (K) pairInfo.getKeyFromInputPair(pair);
    InputT value = (InputT) pairInfo.getValueFromInputPair(pair);
    put(key, value, receiver);
  }

  /**
   * Adds {@code value} under {@code key}, possibly flushing some entries to {@code receiver} if
   * the table is full.
   */
  public void put(K key, InputT value, Receiver receiver) throws Exception {
    Object groupingKey = groupingKeyCreator.createGroupingKey(key);
    GroupingTableEntry<K, InputT, AccumT> slot = table.get(groupingKey);
    if (slot != null) {
      // Take out the entry's old contribution; its new size is added back below.
      size -= slot.getSize();
    } else {
      slot = createTableEntry(key);
      table.put(groupingKey, slot);
      size += PER_KEY_OVERHEAD;
    }
    slot.add(value);
    size += slot.getSize();

    if (size >= maxSize) {
      evictDownToTargetLoad(receiver);
    }
  }

  /**
   * Removes entries in the table's iteration order, emitting each to {@code receiver}, until the
   * tracked size is below {@code TARGET_LOAD * maxSize} or the table is empty.
   */
  private void evictDownToTargetLoad(Receiver receiver) throws Exception {
    final long targetSize = (long) (TARGET_LOAD * maxSize);
    Iterator<GroupingTableEntry<K, InputT, AccumT>> remaining = table.values().iterator();
    while (size >= targetSize && remaining.hasNext()) {
      GroupingTableEntry<K, InputT, AccumT> evicted = remaining.next();
      remaining.remove();
      // Subtract before emitting: output() compacts the entry, which may alter its size.
      size -= evicted.getSize() + PER_KEY_OVERHEAD;
      output(evicted, receiver);
    }
    if (size >= targetSize) {
      // Ran out of entries while still over target. Should never happen, but sizes may be
      // estimates...
      size = 0;
    }
  }

  /**
   * Compacts {@code entry} and sends it to {@code receiver} as an output pair. Neither removes
   * the entry from the table nor adjusts this table's size.
   */
  private void output(GroupingTableEntry<K, InputT, AccumT> entry, Receiver receiver)
      throws Exception {
    entry.compact();
    receiver.process(pairInfo.makeOutputPair(entry.getKey(), entry.getValue()));
  }

  /** Emits every entry in this table to {@code output}, then empties the table. */
  @Override
  public void flush(Receiver output) throws Exception {
    Iterator<GroupingTableEntry<K, InputT, AccumT>> pending = table.values().iterator();
    while (pending.hasNext()) {
      output(pending.next(), output);
    }
    table.clear();
    size = 0;
  }

  @VisibleForTesting
  public void setMaxSize(long maxSize) {
    this.maxSize = maxSize;
  }

  @VisibleForTesting
  public long size() {
    return size;
  }
}
/** A grouping table whose entries keep every inserted value, in insertion order, in a list. */
private static class BufferingGroupingTable<K, V> extends GroupingTableBase<K, V, List<V>> {
  public final SizeEstimator<? super K> keySizer;
  public final SizeEstimator<? super V> valueSizer;

  private BufferingGroupingTable(
      long maxSize,
      GroupingKeyCreator<? super K> groupingKeyCreator,
      PairInfo pairInfo,
      SizeEstimator<? super K> keySizer,
      SizeEstimator<? super V> valueSizer) {
    super(maxSize, groupingKeyCreator, pairInfo);
    this.keySizer = keySizer;
    this.valueSizer = valueSizer;
  }

  /**
   * Creates an entry whose size starts at the estimated key size and grows by one JVM word plus
   * the estimated value size for each value added.
   */
  @Override
  public GroupingTableEntry<K, V, List<V>> createTableEntry(final K key) throws Exception {
    final long keyBytes = keySizer.estimateSize(key);
    return new GroupingTableEntry<K, V, List<V>>() {
      private final List<V> buffered = new ArrayList<>();
      private long estimatedBytes = keyBytes;

      @Override
      public K getKey() {
        return key;
      }

      @Override
      public List<V> getValue() {
        return buffered;
      }

      @Override
      public void add(V value) throws Exception {
        buffered.add(value);
        long valueBytes = valueSizer.estimateSize(value);
        estimatedBytes += BYTES_PER_JVM_WORD + valueBytes;
      }

      @Override
      public long getSize() {
        return estimatedBytes;
      }

      @Override
      public void compact() {
        // A plain buffer has nothing to compact.
      }
    };
  }
}
/** A grouping table whose entries fold each added value into an accumulator via a combiner. */
private static class CombiningGroupingTable<K, InputT, AccumT>
    extends GroupingTableBase<K, InputT, AccumT> {
  private final Combiner<? super K, InputT, AccumT, ?> combiner;
  private final SizeEstimator<? super K> keySizer;
  private final SizeEstimator<? super AccumT> accumulatorSizer;

  private CombiningGroupingTable(
      long maxSize,
      GroupingKeyCreator<? super K> groupingKeyCreator,
      PairInfo pairInfo,
      Combiner<? super K, InputT, AccumT, ?> combineFn,
      SizeEstimator<? super K> keySizer,
      SizeEstimator<? super AccumT> accumulatorSizer) {
    super(maxSize, groupingKeyCreator, pairInfo);
    this.combiner = combineFn;
    this.keySizer = keySizer;
    this.accumulatorSizer = accumulatorSizer;
  }

  /**
   * Creates an entry holding a fresh accumulator for {@code key}. The entry's size is the
   * estimated key size plus the most recently estimated accumulator size.
   */
  @Override
  public GroupingTableEntry<K, InputT, AccumT> createTableEntry(final K key) throws Exception {
    // Size the key first, then create the accumulator.
    final long keyBytes = keySizer.estimateSize(key);
    final AccumT initial = combiner.createAccumulator(key);
    return new GroupingTableEntry<K, InputT, AccumT>() {
      private AccumT current = initial;
      // Stays 0 until the first add(); never used before a value is added...
      private long currentBytes = 0;

      @Override
      public K getKey() {
        return key;
      }

      @Override
      public AccumT getValue() {
        return current;
      }

      @Override
      public void add(InputT value) throws Exception {
        current = combiner.add(key, current, value);
        currentBytes = accumulatorSizer.estimateSize(current);
      }

      @Override
      public long getSize() {
        return keyBytes + currentBytes;
      }

      @Override
      public void compact() throws Exception {
        AccumT compacted = combiner.compact(key, current);
        if (compacted == current) {
          // Same instance handed back: keep the existing size estimate.
          return;
        }
        current = compacted;
        currentBytes = accumulatorSizer.estimateSize(compacted);
      }
    };
  }
}
/**
 * Returns the number of bytes in a JVM word, derived from the {@code sun.arch.data.model} system
 * property (a bit count). Returns 8 when the property is missing or is not a decimal integer.
 */
private static int getBytesPerJvmWord() {
  final String bits = System.getProperty("sun.arch.data.model");
  if (bits == null) {
    // Property not set: the JVM word size is unknown. Assume 64-bit.
    return 8;
  }
  int bytesPerWord;
  try {
    bytesPerWord = Integer.parseInt(bits) / 8;
  } catch (NumberFormatException unparseable) {
    // The JVM word size is unknown. Assume 64-bit.
    bytesPerWord = 8;
  }
  return bytesPerWord;
}
////////////////////////////////////////////////////////////////////////////
// Size sampling.
/**
 * Implements size estimation by adaptively delegating to an underlying (potentially more
 * expensive) estimator for some elements and returning the average value for others.
 *
 * <p>Every call either measures the element with the underlying estimator (a "sample") or
 * returns the rounded-up mean of all samples taken so far. Instances keep mutable running
 * statistics and use no synchronization.
 */
@VisibleForTesting
static class SamplingSizeEstimator<T> implements SizeEstimator<T> {
  /**
   * The degree of confidence required in our expected value predictions before we allow
   * under-sampling.
   *
   * <p>The value of 3.0 is a confidence interval of about 99.7% for a high-degree-of-freedom
   * t-distribution.
   */
  public static final double CONFIDENCE_INTERVAL_SIGMA = 3;

  /**
   * The desired size of our confidence interval (relative to the measured expected value).
   *
   * <p>The value of 0.25 is plus or minus 25%.
   */
  public static final double CONFIDENCE_INTERVAL_SIZE = 0.25;

  /** Default number of elements that must be measured before elements are skipped. */
  public static final long DEFAULT_MIN_SAMPLED = 20;

  private final SizeEstimator<T> underlying;
  private final double minSampleRate;
  private final double maxSampleRate;
  private final long minSampled;
  private final Random random;

  // Number of elements passed to estimateSize, sampled or not.
  private long totalElements = 0;
  // Number of elements actually measured by the underlying estimator.
  private long sampledElements = 0;
  // Sum and sum of squares of the measured sizes.
  private long sampledSum = 0;
  private double sampledSumSquares = 0;
  // Rounded-up mean of the measured sizes; returned for elements that are not sampled.
  private long estimate;
  // Number of elements to skip before the next sample is taken.
  private long nextSample = 0;

  /**
   * Creates an estimator with {@link #DEFAULT_MIN_SAMPLED} as the minimum sample count and a
   * fresh {@link Random}.
   */
  private SamplingSizeEstimator(
      SizeEstimator<T> underlying, double minSampleRate, double maxSampleRate) {
    this(underlying, minSampleRate, maxSampleRate, DEFAULT_MIN_SAMPLED, new Random());
  }

  /**
   * @param underlying estimator used to measure sampled elements
   * @param minSampleRate lower bound applied to the computed sample rate
   * @param maxSampleRate upper bound applied to the computed sample rate
   * @param minSampled number of elements that must be measured before any element is skipped
   * @param random source of randomness for spacing out samples
   */
  @VisibleForTesting
  SamplingSizeEstimator(
      SizeEstimator<T> underlying,
      double minSampleRate,
      double maxSampleRate,
      long minSampled,
      Random random) {
    this.underlying = underlying;
    this.minSampleRate = minSampleRate;
    this.maxSampleRate = maxSampleRate;
    this.minSampled = minSampled;
    this.random = random;
  }

  /**
   * Returns the size reported by the underlying estimator if this element is sampled, otherwise
   * the current running estimate.
   *
   * @throws Exception if the underlying estimator throws
   */
  @Override
  public long estimateSize(T element) throws Exception {
    if (sampleNow()) {
      return recordSample(underlying.estimateSize(element));
    } else {
      return estimate;
    }
  }

  /** Counts the element and reports whether the skip countdown has run out. */
  private boolean sampleNow() {
    totalElements++;
    return --nextSample < 0;
  }

  /**
   * Folds a measured size into the running statistics, refreshes the estimate and decides how
   * many elements to skip before the next sample.
   *
   * @return {@code value}, unchanged
   */
  private long recordSample(long value) {
    sampledElements += 1;
    sampledSum += value;
    // Square in double arithmetic: a long product silently wraps around for large sizes
    // (|value| above roughly 3.04e9), which would corrupt the variance estimate.
    sampledSumSquares += (double) value * value;
    estimate = (long) Math.ceil(sampledSum / (double) sampledElements);
    long target = desiredSampleSize();
    if (sampledElements < minSampled || sampledElements < target) {
      // Sample immediately.
      nextSample = 0;
    } else {
      double rate =
          cap(
              minSampleRate,
              maxSampleRate,
              Math.max(
                  1.0 / (totalElements - minSampled + 1), // slowly ramp down
                  target / (double) totalElements)); // "future" target
      // Uses the geometric distribution to return the likely distance between
      // successive independent trials of a fixed probability p. This gives the
      // same uniform distribution of branching on Math.random() < p, but with
      // one random number generation per success rather than one
      // per test, which can be a significant savings if p is small.
      nextSample =
          rate == 1.0 ? 0 : (long) Math.floor(Math.log(random.nextDouble()) / Math.log(1 - rate));
    }
    return value;
  }

  /** Clamps {@code value} to at least {@code min}, then to at most {@code max}. */
  private static double cap(double min, double max, double value) {
    return Math.min(max, Math.max(min, value));
  }

  /**
   * Returns the number of samples wanted for the configured confidence interval, computed from
   * the sample mean and standard deviation as {@code ceil((CONFIDENCE_INTERVAL_SIGMA * stddev /
   * (CONFIDENCE_INTERVAL_SIZE * mean))^2)}.
   *
   * <p>With a single sample the variance divisor {@code sampledElements - 1} is zero, so the
   * result is not a meaningful estimate at that point.
   */
  private long desiredSampleSize() {
    // We have no a-priori information on the actual distribution of data
    // sizes, so compute our desired sample as if it were normal.
    // Yes this formula is unstable for small stddev, but we only care about large stddev.
    double mean = sampledSum / (double) sampledElements;
    double sumSquareDiff =
        (sampledSumSquares - (2 * mean * sampledSum) + (sampledElements * mean * mean));
    double stddev = Math.sqrt(sumSquareDiff / (sampledElements - 1));
    double sqrtDesiredSamples =
        (CONFIDENCE_INTERVAL_SIGMA * stddev) / (CONFIDENCE_INTERVAL_SIZE * mean);
    return (long) Math.ceil(sqrtDesiredSamples * sqrtDesiredSamples);
  }
}
}