| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.beam.sdk.transforms; |
| |
| import java.io.Serializable; |
| import org.apache.beam.sdk.coders.Coder; |
| import org.apache.beam.sdk.transforms.display.DisplayData; |
| import org.apache.beam.sdk.values.PCollection; |
| import org.apache.beam.sdk.values.PCollectionList; |
| import org.apache.beam.sdk.values.PCollectionTuple; |
| import org.apache.beam.sdk.values.TupleTag; |
| import org.apache.beam.sdk.values.TupleTagList; |
| |
| /** |
|  * {@code Partition} takes a {@code PCollection<T>} and a |
|  * {@code PartitionFn}, uses the {@code PartitionFn} to split the |
|  * elements of the input {@code PCollection} into {@code N} partitions, and |
|  * returns a {@code PCollectionList<T>} that bundles {@code N} |
|  * {@code PCollection<T>}s containing the split elements. |
|  * |
|  * <p>Example of use: |
|  * <pre> {@code |
|  * PCollection<Student> students = ...; |
|  * // Split students up into 10 partitions, by percentile: |
|  * PCollectionList<Student> studentsByPercentile = |
|  *     students.apply(Partition.of(10, new PartitionFn<Student>() { |
|  *         public int partitionFor(Student student, int numPartitions) { |
|  *             return student.getPercentile()  // 0..99 |
|  *                  * numPartitions / 100; |
|  *         }})); |
|  * for (int i = 0; i < 10; i++) { |
|  *   PCollection<Student> partition = studentsByPercentile.get(i); |
|  *   ... |
|  * } |
|  * } </pre> |
|  * |
|  * <p>By default, the {@code Coder} of each of the |
|  * {@code PCollection}s in the output {@code PCollectionList} is the |
|  * same as the {@code Coder} of the input {@code PCollection}. |
|  * |
|  * <p>Each output element has the same timestamp and is in the same windows |
|  * as its corresponding input element, and each output {@code PCollection} |
|  * has the same |
|  * {@link org.apache.beam.sdk.transforms.windowing.WindowFn} |
|  * associated with it as the input. |
|  * |
|  * <p>If the {@code PartitionFn} returns an index outside |
|  * {@code [0..N-1]} for some element, processing that element throws an |
|  * {@link IndexOutOfBoundsException}. |
|  * |
|  * @param <T> the type of the elements of the input and output |
|  * {@code PCollection}s |
|  */ |
| public class Partition<T> extends PTransform<PCollection<T>, PCollectionList<T>> { |
| |
|   /** |
|    * A function object that chooses an output partition for an element. |
|    * |
|    * @param <T> the type of the elements being partitioned |
|    */ |
|   public interface PartitionFn<T> extends Serializable { |
|     /** |
|      * Chooses the partition into which to put the given element. |
|      * |
|      * @param elem the element to be partitioned |
|      * @param numPartitions the total number of partitions ({@code >= 1}) |
|      * @return index of the selected partition (in the range |
|      * {@code [0..numPartitions-1]}) |
|      */ |
|     int partitionFor(T elem, int numPartitions); |
|   } |
| |
|   /** |
|    * Returns a new {@code Partition} {@code PTransform} that divides |
|    * its input {@code PCollection} into the given number of partitions, |
|    * using the given partitioning function. |
|    * |
|    * @param numPartitions the number of partitions to divide the input |
|    * {@code PCollection} into |
|    * @param partitionFn the function to invoke on each element to |
|    * choose its output partition; must not be {@code null} |
|    * @throws IllegalArgumentException if {@code numPartitions <= 0} |
|    * @throws NullPointerException if {@code partitionFn} is {@code null} |
|    */ |
|   public static <T> Partition<T> of( |
|       int numPartitions, PartitionFn<? super T> partitionFn) { |
|     return new Partition<>(new PartitionDoFn<T>(numPartitions, partitionFn)); |
|   } |
| |
|   ///////////////////////////////////////////////////////////////////////////// |
| |
|   /** |
|    * Applies the partitioning {@code DoFn} to the input as a multi-output |
|    * {@code ParDo} and collects the per-partition outputs, in partition |
|    * order, into a {@code PCollectionList}. |
|    * |
|    * @param in the {@code PCollection} to split |
|    * @return a list with one {@code PCollection<T>} per partition, each |
|    * carrying the coder of {@code in} |
|    */ |
|   @Override |
|   public PCollectionList<T> expand(PCollection<T> in) { |
|     final TupleTagList outputTags = partitionDoFn.getOutputTags(); |
| |
|     // The DoFn only ever writes to the per-partition tags, so the main |
|     // output is a throwaway tag of type Void. |
|     PCollectionTuple outputs = in.apply( |
|         ParDo |
|             .of(partitionDoFn) |
|             .withOutputTags(new TupleTag<Void>(){}, outputTags)); |
| |
|     PCollectionList<T> pcs = PCollectionList.empty(in.getPipeline()); |
|     Coder<T> coder = in.getCoder(); |
| |
|     for (TupleTag<?> outputTag : outputTags.getAll()) { |
|       // All the tuple tags are actually TupleTag<T> |
|       // And all the collections are actually PCollection<T> |
|       @SuppressWarnings("unchecked") |
|       TupleTag<T> typedOutputTag = (TupleTag<T>) outputTag; |
|       // Every partition reuses the input's coder. |
|       pcs = pcs.and(outputs.get(typedOutputTag).setCoder(coder)); |
|     } |
|     return pcs; |
|   } |
| |
|   /** |
|    * Adds this transform's display data, including the display data of the |
|    * underlying partitioning {@code DoFn} under the name {@code "partitionFn"}. |
|    * |
|    * @param builder the builder to populate |
|    */ |
|   @Override |
|   public void populateDisplayData(DisplayData.Builder builder) { |
|     super.populateDisplayData(builder); |
|     builder.include("partitionFn", partitionDoFn); |
|   } |
| |
|   // The DoFn that performs the actual routing. It is marked transient, so it |
|   // is not written when this transform itself is Java-serialized. |
|   // NOTE(review): presumably only needed at pipeline-construction time - confirm. |
|   private final transient PartitionDoFn<T> partitionDoFn; |
| |
|   /** Creates a {@code Partition} backed by the given routing {@code DoFn}. */ |
|   private Partition(PartitionDoFn<T> partitionDoFn) { |
|     this.partitionDoFn = partitionDoFn; |
|   } |
| |
|   /** |
|    * A {@code DoFn} that asks a {@code PartitionFn} for the partition index |
|    * of each element and emits the element to the output tag at that index. |
|    * It never writes to its main ({@code Void}) output. |
|    * |
|    * @param <X> the type of the elements being partitioned |
|    */ |
|   private static class PartitionDoFn<X> extends DoFn<X, Void> { |
|     // Number of partitions; validated to be > 0 in the constructor. |
|     private final int numPartitions; |
|     // User-supplied function choosing the partition; validated non-null. |
|     private final PartitionFn<? super X> partitionFn; |
|     // One output tag per partition; position in the list == partition index. |
|     private final TupleTagList outputTags; |
| |
|     /** |
|      * Constructs a PartitionDoFn. |
|      * |
|      * @param numPartitions the number of partitions (and output tags) to create |
|      * @param partitionFn the function that selects a partition per element |
|      * @throws IllegalArgumentException if {@code numPartitions <= 0} |
|      * @throws NullPointerException if {@code partitionFn} is {@code null} |
|      */ |
|     public PartitionDoFn(int numPartitions, PartitionFn<? super X> partitionFn) { |
|       if (numPartitions <= 0) { |
|         throw new IllegalArgumentException("numPartitions must be > 0"); |
|       } |
|       // Fail fast here instead of with an anonymous NullPointerException |
|       // later, when display data is populated or elements are processed. |
|       if (partitionFn == null) { |
|         throw new NullPointerException("partitionFn must not be null"); |
|       } |
| |
|       this.numPartitions = numPartitions; |
|       this.partitionFn = partitionFn; |
| |
|       // TupleTagList.and(...) returns a list, so rebuild the reference on each step. |
|       TupleTagList buildOutputTags = TupleTagList.empty(); |
|       for (int partition = 0; partition < numPartitions; partition++) { |
|         buildOutputTags = buildOutputTags.and(new TupleTag<X>()); |
|       } |
|       outputTags = buildOutputTags; |
|     } |
| |
|     /** Returns the output tags, one per partition, in partition order. */ |
|     public TupleTagList getOutputTags() { |
|       return outputTags; |
|     } |
| |
|     /** |
|      * Routes the current element to the output selected by the |
|      * {@code PartitionFn}. |
|      * |
|      * @param c the context supplying the element and receiving the output |
|      * @throws IndexOutOfBoundsException if the {@code PartitionFn} returns an |
|      * index outside {@code [0..numPartitions-1]} |
|      */ |
|     @ProcessElement |
|     public void processElement(ProcessContext c) { |
|       X input = c.element(); |
|       int partition = partitionFn.partitionFor(input, numPartitions); |
|       if (partition < 0 || partition >= numPartitions) { |
|         throw new IndexOutOfBoundsException( |
|             "Partition function returned out of bounds index: " |
|             + partition + " not in [0.." + numPartitions + ")"); |
|       } |
|       // Tags were created as TupleTag<X> in the constructor, so the cast is safe. |
|       @SuppressWarnings("unchecked") |
|       TupleTag<X> typedTag = (TupleTag<X>) outputTags.get(partition); |
|       c.output(typedTag, input); |
|     } |
| |
|     /** |
|      * Adds the partition count and the class of the {@code PartitionFn} |
|      * to the display data. |
|      * |
|      * @param builder the builder to populate |
|      */ |
|     @Override |
|     public void populateDisplayData(DisplayData.Builder builder) { |
|       super.populateDisplayData(builder); |
|       builder |
|           .add(DisplayData.item("numPartitions", numPartitions) |
|               .withLabel("Partition Count")) |
|           .add(DisplayData.item("partitionFn", partitionFn.getClass()) |
|               .withLabel("Partition Function")); |
|     } |
|   } |
| } |