blob: 2a384573d2b8c5455f3efb4aa2dfec7419e48323 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more contributor license
* agreements. See the NOTICE file distributed with this work for additional information regarding
* copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance with the License. You may obtain a
* copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable
* law or agreed to in writing, software distributed under the License is distributed on an "AS IS"
* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License
* for the specific language governing permissions and limitations under the License.
*/
package org.apache.crunch.lib.join;
import java.io.Serializable;
import java.util.Random;
import org.apache.crunch.DoFn;
import org.apache.crunch.Emitter;
import org.apache.crunch.MapFn;
import org.apache.crunch.PTable;
import org.apache.crunch.Pair;
import org.apache.crunch.types.PTableType;
import org.apache.crunch.types.PTypeFamily;
/**
 * JoinStrategy that splits the key space up into shards.
 * <p>
 * This strategy is useful when there are multiple values per key on at least one side of the join,
 * and a large proportion of the values are mapped to a small number of keys.
 * <p>
 * Using this strategy will increase the number of keys being joined, but can increase performance
 * by spreading processing of a single key over multiple reduce groups.
 * <p>
 * Every value on the left side of the join is replicated to each shard of its key, while every
 * value on the right side is sent to exactly one (pseudo-randomly chosen) shard of its key. The
 * left-hand table therefore grows by a factor equal to the number of shards, so the side with the
 * most values per key should generally be passed as the right-hand table.
 * <p>
 * Left outer and full outer joins are not supported: because left-side values are replicated to
 * every shard, a left value would be emitted as unmatched from any shard that happened to receive
 * no right-side values for its key, even when other shards held matches.
 * <p>
 * A custom {@link ShardingStrategy} can be provided so that only certain keys are sharded, or
 * keys can be sharded in accordance with how many values are mapped to them.
 */
public class ShardedJoinStrategy<K, U, V> implements JoinStrategy<K, U, V> {

  /**
   * Determines over how many shards a key will be split in a sharded join.
   * <p>
   * It is essential that implementations of this class are deterministic: the left and right sides
   * of the join consult the strategy independently, and the left side is only replicated to the
   * shards {@code 0 .. getNumShards(key) - 1}, which must be exactly the shards the right side may
   * be sent to.
   */
  public static interface ShardingStrategy<K> extends Serializable {

    /**
     * Retrieve the number of shards over which the given key should be split.
     * @param key key for which shards are to be determined
     * @return number of shards for the given key, must be greater than 0
     */
    int getNumShards(K key);
  }

  /** Join performed on the sharded {@code (key, shard)} keys. */
  private final JoinStrategy<Pair<K, Integer>, U, V> wrappedJoinStrategy;
  private final ShardingStrategy<K> shardingStrategy;

  /**
   * Instantiate with a constant number of shards to use for all keys.
   *
   * @param numShards number of shards to use
   * @throws IllegalArgumentException if {@code numShards} is less than 1
   */
  public ShardedJoinStrategy(int numShards) {
    this(new ConstantShardingStrategy<K>(numShards));
  }

  /**
   * Instantiate with a constant number of shards to use for all keys.
   *
   * @param numShards number of shards to use
   * @param numReducers the amount of reducers to run the join with
   * @throws IllegalArgumentException if {@code numShards} or {@code numReducers} is less than 1
   */
  public ShardedJoinStrategy(int numShards, int numReducers) {
    this(new ConstantShardingStrategy<K>(numShards), numReducers);
  }

  /**
   * Instantiate with a custom sharding strategy.
   *
   * @param shardingStrategy strategy to be used for sharding
   * @throws NullPointerException if {@code shardingStrategy} is null
   */
  public ShardedJoinStrategy(ShardingStrategy<K> shardingStrategy) {
    this.shardingStrategy = checkStrategy(shardingStrategy);
    this.wrappedJoinStrategy = new DefaultJoinStrategy<Pair<K, Integer>, U, V>();
  }

  /**
   * Instantiate with a custom sharding strategy and a specified number of reducers.
   *
   * @param shardingStrategy strategy to be used for sharding
   * @param numReducers the amount of reducers to run the join with
   * @throws IllegalArgumentException if {@code numReducers} is less than 1
   * @throws NullPointerException if {@code shardingStrategy} is null
   */
  public ShardedJoinStrategy(ShardingStrategy<K> shardingStrategy, int numReducers) {
    if (numReducers < 1) {
      throw new IllegalArgumentException("Num reducers must be > 0, got " + numReducers);
    }
    this.shardingStrategy = checkStrategy(shardingStrategy);
    this.wrappedJoinStrategy = new DefaultJoinStrategy<Pair<K, Integer>, U, V>(numReducers);
  }

  /**
   * Joins {@code left} and {@code right} by first expanding each key into {@code (key, shard)}
   * pairs, joining on those pairs with the wrapped strategy, and finally dropping the shard index.
   *
   * @throws UnsupportedOperationException for {@link JoinType#LEFT_OUTER_JOIN} and
   *     {@link JoinType#FULL_OUTER_JOIN}
   */
  @Override
  public PTable<K, Pair<U, V>> join(PTable<K, U> left, PTable<K, V> right, JoinType joinType) {
    // Left values are replicated to every shard, so a shard that simply received no right values
    // for a key would wrongly emit the left values as unmatched; these join types are rejected.
    if (joinType == JoinType.FULL_OUTER_JOIN || joinType == JoinType.LEFT_OUTER_JOIN) {
      throw new UnsupportedOperationException("Join type " + joinType + " not supported by ShardedJoinStrategy");
    }
    PTypeFamily ptf = left.getTypeFamily();
    PTableType<Pair<K, Integer>, U> shardedLeftType = ptf.tableOf(ptf.pairs(left.getKeyType(), ptf.ints()), left.getValueType());
    PTableType<Pair<K, Integer>, V> shardedRightType = ptf.tableOf(ptf.pairs(right.getKeyType(), ptf.ints()), right.getValueType());
    PTableType<K, Pair<U, V>> outputType = ptf.tableOf(left.getKeyType(), ptf.pairs(left.getValueType(), right.getValueType()));
    PTable<Pair<K, Integer>, U> shardedLeft = left.parallelDo("Pre-shard left", new PreShardLeftSideFn<K, U>(shardingStrategy), shardedLeftType);
    PTable<Pair<K, Integer>, V> shardedRight = right.parallelDo("Pre-shard right", new PreShardRightSideFn<K, V>(shardingStrategy), shardedRightType);
    PTable<Pair<K, Integer>, Pair<U, V>> shardedJoined = wrappedJoinStrategy.join(shardedLeft, shardedRight, joinType);
    return shardedJoined.parallelDo("Unshard", new UnshardFn<K, U, V>(), outputType);
  }

  /** Rejects a null sharding strategy at construction time rather than inside a running task. */
  private static <T> ShardingStrategy<T> checkStrategy(ShardingStrategy<T> shardingStrategy) {
    if (shardingStrategy == null) {
      throw new NullPointerException("shardingStrategy must not be null");
    }
    return shardingStrategy;
  }

  /** Validates a shard count returned by a {@link ShardingStrategy} for the given key. */
  private static int checkNumShards(int numShards, Object key) {
    if (numShards < 1) {
      throw new IllegalArgumentException("Num shards must be > 0, got " + numShards + " for " + key);
    }
    return numShards;
  }

  /** Emits each left-side value once for every shard of its key. */
  private static class PreShardLeftSideFn<K, U> extends DoFn<Pair<K, U>, Pair<Pair<K, Integer>, U>> {
    private final ShardingStrategy<K> shardingStrategy;

    public PreShardLeftSideFn(ShardingStrategy<K> shardingStrategy) {
      this.shardingStrategy = shardingStrategy;
    }

    @Override
    public void process(Pair<K, U> input, Emitter<Pair<Pair<K, Integer>, U>> emitter) {
      K key = input.first();
      int numShards = checkNumShards(shardingStrategy.getNumShards(key), key);
      for (int i = 0; i < numShards; i++) {
        emitter.emit(Pair.of(Pair.of(key, i), input.second()));
      }
    }
  }

  /** Assigns each right-side value to a single pseudo-randomly chosen shard of its key. */
  private static class PreShardRightSideFn<K, V> extends MapFn<Pair<K, V>, Pair<Pair<K, Integer>, V>> {
    private final ShardingStrategy<K> shardingStrategy;
    private transient Random random;

    public PreShardRightSideFn(ShardingStrategy<K> shardingStrategy) {
      this.shardingStrategy = shardingStrategy;
    }

    @Override
    public void initialize() {
      // Seeding with the task id makes the shard assignment repeatable for a given task.
      random = new Random(getTaskAttemptID().getTaskID().getId());
    }

    @Override
    public Pair<Pair<K, Integer>, V> map(Pair<K, V> input) {
      K key = input.first();
      V value = input.second();
      int numShards = checkNumShards(shardingStrategy.getNumShards(key), key);
      return Pair.of(Pair.of(key, random.nextInt(numShards)), value);
    }
  }

  /** Drops the shard index from the joined {@code (key, shard)} keys. */
  private static class UnshardFn<K, U, V> extends MapFn<Pair<Pair<K, Integer>, Pair<U, V>>, Pair<K, Pair<U, V>>> {
    @Override
    public Pair<K, Pair<U, V>> map(Pair<Pair<K, Integer>, Pair<U, V>> input) {
      return Pair.of(input.first().first(), input.second());
    }
  }

  /**
   * Sharding strategy that returns the same number of shards for all keys.
   */
  private static class ConstantShardingStrategy<K> implements ShardingStrategy<K> {
    private final int numShards;

    /**
     * @param numShards number of shards to return for every key
     * @throws IllegalArgumentException if {@code numShards} is less than 1
     */
    public ConstantShardingStrategy(int numShards) {
      // Fail fast on the client instead of inside a pipeline task.
      if (numShards < 1) {
        throw new IllegalArgumentException("Num shards must be > 0, got " + numShards);
      }
      this.numShards = numShards;
    }

    @Override
    public int getNumShards(K key) {
      return numShards;
    }
  }
}