// blob: 0f7c0d9c2c7aa200bd6ab7a417c5b90d8150a1ce [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.transforms;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Objects;
import org.apache.beam.sdk.coders.AtomicCoder;
import org.apache.beam.sdk.coders.BigEndianLongCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.DoubleCoder;
import org.apache.beam.sdk.transforms.Combine.AccumulatingCombineFn.Accumulator;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
/**
* {@code PTransform}s for computing the arithmetic mean (a.k.a. average) of the elements in a
* {@code PCollection}, or the mean of the values associated with each key in a {@code PCollection}
* of {@code KV}s.
*
* <p>Example 1: get the mean of a {@code PCollection} of {@code Long}s.
*
* <pre>{@code
* PCollection<Long> input = ...;
* PCollection<Double> mean = input.apply(Mean.<Long>globally());
* }</pre>
*
* <p>Example 2: calculate the mean of the {@code Integer}s associated with each unique key (which
* is of type {@code String}).
*
* <pre>{@code
* PCollection<KV<String, Integer>> input = ...;
* PCollection<KV<String, Double>> meanPerKey =
* input.apply(Mean.<String, Integer>perKey());
* }</pre>
*/
public class Mean {
/** Not instantiable: this class serves purely as a namespace for its static members. */
private Mean() {
  // Namespace only
}
/**
 * Returns a {@code PTransform} that takes an input {@code PCollection<NumT>} and returns a {@code
 * PCollection<Double>} whose content is the arithmetic mean of the input {@code PCollection}'s
 * elements.
 *
 * <p>The mean is computed by the combine function returned by {@link #of()}, which yields {@code
 * Double.NaN} (not {@code 0}) when it is asked to combine zero elements.
 *
 * @param <NumT> the type of the {@code Number}s being combined
 * @return a global combine transform wrapping the combine function from {@link #of()}
 */
public static <NumT extends Number> Combine.Globally<NumT, Double> globally() {
  return Combine.globally(Mean.<NumT>of());
}
/**
 * Returns a {@code PTransform} that takes an input {@code PCollection<KV<K, NumT>>} and returns a
 * {@code PCollection<KV<K, Double>>} that contains an output element mapping each distinct key in
 * the input {@code PCollection} to the arithmetic mean of the values associated with that key.
 *
 * <p>See {@link Combine.PerKey} for how this affects timestamps and bucketing.
 *
 * @param <K> the type of the keys
 * @param <NumT> the type of the {@code Number}s being combined
 * @return a per-key combine transform wrapping the combine function from {@link #of()}
 */
public static <K, NumT extends Number> Combine.PerKey<K, NumT, Double> perKey() {
  Combine.AccumulatingCombineFn<NumT, CountSum<NumT>, Double> meanFn = Mean.of();
  return Combine.perKey(meanFn);
}
/**
 * Creates a {@code Combine.CombineFn} that computes the arithmetic mean (a.k.a. average) of an
 * {@code Iterable} of numbers of type {@code NumT}. The result is suitable as an argument to
 * {@link Combine#globally} or {@link Combine#perKey}.
 *
 * <p>The returned function produces {@code Double.NaN} when it combines zero elements.
 *
 * @param <NumT> the type of the {@code Number}s being combined
 * @return a new mean-computing combine function
 */
public static <NumT extends Number>
    Combine.AccumulatingCombineFn<NumT, CountSum<NumT>, Double> of() {
  return new MeanFn<NumT>();
}
/////////////////////////////////////////////////////////////////////////////
/**
 * Combining function that computes the mean over a collection of values of type {@code NumT},
 * using a {@link CountSum} as its accumulator.
 */
private static class MeanFn<NumT extends Number>
    extends Combine.AccumulatingCombineFn<NumT, CountSum<NumT>, Double> {

  /** Returns a new {@link CountSum} built with its no-argument constructor. */
  @Override
  public CountSum<NumT> createAccumulator() {
    return new CountSum<NumT>();
  }

  /**
   * Returns a new {@link CountSumCoder} for the accumulator. Neither {@code registry} nor {@code
   * inputCoder} is consulted.
   */
  @Override
  public Coder<CountSum<NumT>> getAccumulatorCoder(
      CoderRegistry registry, Coder<NumT> inputCoder) {
    return new CountSumCoder<NumT>();
  }
}
/**
 * Accumulator class for {@link MeanFn}: tracks how many elements have been added and their running
 * sum as a {@code double}.
 */
static class CountSum<NumT extends Number> implements Accumulator<NumT, CountSum<NumT>, Double> {
  long count = 0;
  double sum = 0.0;

  /** Creates an empty accumulator with a count of zero and a sum of zero. */
  public CountSum() {
    this(0, 0);
  }

  /**
   * Creates an accumulator holding the given state.
   *
   * @param count the number of elements represented by this accumulator
   * @param sum the sum of those elements
   */
  public CountSum(long count, double sum) {
    this.count = count;
    this.sum = sum;
  }

  /** Adds one element: increments the count and adds {@code element.doubleValue()} to the sum. */
  @Override
  public void addInput(NumT element) {
    count++;
    sum += element.doubleValue();
  }

  /** Folds another accumulator's count and sum into this one. */
  @Override
  public void mergeAccumulator(CountSum<NumT> accumulator) {
    count += accumulator.count;
    sum += accumulator.sum;
  }

  /** Returns {@code sum / count}, or {@code Double.NaN} when no elements have been counted. */
  @Override
  public Double extractOutput() {
    return count == 0 ? Double.NaN : sum / count;
  }

  /**
   * Two accumulators are equal when their counts match and their sums are the same {@code double}
   * value as defined by {@link Double#compare(double, double)}.
   */
  @Override
  public boolean equals(Object other) {
    if (this == other) {
      return true;
    }
    if (!(other instanceof CountSum)) {
      return false;
    }
    CountSum<?> otherCountSum = (CountSum<?>) other;
    // Double.compare rather than == keeps equals consistent with hashCode(): == treats 0.0 and
    // -0.0 as equal although they hash differently, and treats NaN as unequal to itself, which
    // would make an accumulator with a NaN sum unequal to itself.
    return count == otherCountSum.count && Double.compare(sum, otherCountSum.sum) == 0;
  }

  @Override
  public int hashCode() {
    return Objects.hash(count, sum);
  }

  @Override
  public String toString() {
    return MoreObjects.toStringHelper(this).add("count", count).add("sum", sum).toString();
  }
}
/**
 * Coder for {@link CountSum} accumulators. The count is written first using a {@link
 * BigEndianLongCoder}, followed by the sum using a {@link DoubleCoder}; decoding reads the two
 * fields back in the same order.
 */
static class CountSumCoder<NumT extends Number> extends AtomicCoder<CountSum<NumT>> {
  private static final Coder<Long> COUNT_CODER = BigEndianLongCoder.of();
  private static final Coder<Double> SUM_CODER = DoubleCoder.of();

  /** Writes {@code value.count} and then {@code value.sum} to {@code outStream}. */
  @Override
  public void encode(CountSum<NumT> value, OutputStream outStream)
      throws CoderException, IOException {
    COUNT_CODER.encode(value.count, outStream);
    SUM_CODER.encode(value.sum, outStream);
  }

  /** Reads a count and then a sum from {@code inStream} and wraps them in a new accumulator. */
  @Override
  public CountSum<NumT> decode(InputStream inStream) throws CoderException, IOException {
    long decodedCount = COUNT_CODER.decode(inStream);
    double decodedSum = SUM_CODER.decode(inStream);
    return new CountSum<>(decodedCount, decodedSum);
  }

  /** Delegates the determinism check to the count coder and then to the sum coder. */
  @Override
  public void verifyDeterministic() throws NonDeterministicException {
    COUNT_CODER.verifyDeterministic();
    SUM_CODER.verifyDeterministic();
  }
}
}