blob: 963751e5ab0ed71e7d30f519a3a3dd77865a011b [file] [log] [blame]
package storm.trident.operation.impl;
import backtype.storm.tuple.Fields;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import storm.trident.JoinType;
import storm.trident.operation.GroupedMultiReducer;
import storm.trident.operation.TridentCollector;
import storm.trident.operation.TridentMultiReducerContext;
import storm.trident.operation.impl.JoinerMultiReducer.JoinState;
import storm.trident.tuple.ComboList;
import storm.trident.tuple.TridentTuple;
/**
 * Grouped multi-reducer that joins the tuples of several input streams ("sides")
 * sharing the same group.
 *
 * <p>For every group a {@link JoinState} buffers the tuples received per side. As soon as
 * each side holds at least one tuple, every newly arriving tuple is emitted combined with
 * all buffered tuples of the other sides. When the group completes and some side is still
 * empty, sides declared {@link JoinType#OUTER} are padded with one all-null row; if that
 * leaves no side empty, the whole cross product is emitted once.
 */
public class JoinerMultiReducer implements GroupedMultiReducer<JoinState> {
    /** Join type of each side, indexed by stream index. */
    List<JoinType> _types;
    /** Non-group fields contributed by each side, indexed by stream index. */
    List<Fields> _sideFields;
    /** Number of fields in the group tuple. */
    int _numGroupFields;
    /** Builds the emitted values out of the group tuple followed by one row per side; set in prepare. */
    ComboList.Factory _factory;

    /**
     * @param types          join type of each side
     * @param numGroupFields number of fields in the group tuple
     * @param sides          fields contributed by each side
     */
    public JoinerMultiReducer(List<JoinType> types, int numGroupFields, List<Fields> sides) {
        _numGroupFields = numGroupFields;
        _types = types;
        _sideFields = sides;
    }

    /**
     * Creates the {@link ComboList.Factory} from the width of the group tuple followed by
     * the width of every side.
     */
    @Override
    public void prepare(Map conf, TridentMultiReducerContext context) {
        int[] widths = new int[_sideFields.size() + 1];
        widths[0] = _numGroupFields;
        int slot = 1;
        for (Fields fields : _sideFields) {
            widths[slot++] = fields.size();
        }
        _factory = new ComboList.Factory(widths);
    }

    /** Creates the per-group state with one empty buffer per join side. */
    @Override
    public JoinState init(TridentCollector collector, TridentTuple group) {
        return new JoinState(_types.size(), group);
    }

    /**
     * Buffers {@code input} on its side. Once every side has received at least one tuple,
     * emits {@code input} combined with every buffered tuple combination of the other sides.
     */
    @Override
    public void execute(JoinState state, int streamIndex, TridentTuple group, TridentTuple input, TridentCollector collector) {
        List<List> buffer = state.sides[streamIndex];
        boolean firstTupleForSide = buffer.isEmpty();
        buffer.add(input);
        if (firstTupleForSide) {
            state.numSidesReceived++;
        }

        if (state.numSidesReceived != state.sides.length) {
            // Some side is still empty, so nothing can be joined yet.
            return;
        }
        // Only the new tuple is used for its own side; combinations of the older tuples
        // were emitted by earlier calls.
        emitCrossJoin(state, collector, streamIndex, input);
    }

    /**
     * Pads every still-empty OUTER side with a single all-null row. If the group had an
     * empty side on entry and the padding filled all of them, emits the full cross product.
     */
    @Override
    public void complete(JoinState state, TridentTuple group, TridentCollector collector) {
        List<List>[] sides = state.sides;
        // When no side was empty, execute() has already emitted every combination.
        boolean hadEmptySide = state.numSidesReceived < sides.length;

        for (int i = 0; i < sides.length; i++) {
            List<List> side = sides[i];
            if (!side.isEmpty() || _types.get(i) != JoinType.OUTER) {
                continue;
            }
            side.add(makeNullList(_sideFields.get(i).size()));
            state.numSidesReceived++;
        }

        if (hadEmptySide && state.numSidesReceived == sides.length) {
            emitCrossJoin(state, collector, -1, null);
        }
    }

    @Override
    public void cleanup() {
    }

    /** Returns a new list holding {@code size} null elements. */
    private List<Object> makeNullList(int size) {
        List<Object> nulls = new ArrayList<Object>(size);
        while (nulls.size() < size) {
            nulls.add(null);
        }
        return nulls;
    }

    /**
     * Emits one value per combination of buffered rows, each prefixed with the group tuple.
     *
     * @param overrideIndex side that contributes only {@code overrideTuple} instead of its
     *                      buffered rows, or -1 to iterate the buffers of all sides
     * @param overrideTuple row used for side {@code overrideIndex}
     */
    private void emitCrossJoin(JoinState state, TridentCollector collector, int overrideIndex, TridentTuple overrideTuple) {
        List<List>[] sides = state.sides;
        int[] cursor = state.indices;
        // The cursor array is reused between calls; restart at the first combination.
        for (int i = 0; i < cursor.length; i++) {
            cursor[i] = 0;
        }

        int numSides = sides.length;
        do {
            List[] row = new List[numSides + 1];
            row[0] = state.group;
            for (int s = 0; s < numSides; s++) {
                row[s + 1] = (s == overrideIndex) ? overrideTuple : sides[s].get(cursor[s]);
            }
            collector.emit(_factory.create(row));
        } while (increment(sides, cursor, cursor.length - 1, overrideIndex));
    }

    /**
     * Advances {@code indices} to the next combination, odometer style, starting at
     * position {@code j} and carrying towards position 0. The position equal to
     * {@code overrideIndex} is skipped and never modified.
     *
     * @return false if there is no further combination
     */
    //TODO: DRY this code up with what's in ChainedAggregatorImpl
    private boolean increment(List[] lengths, int[] indices, int j, int overrideIndex) {
        for (int pos = j; pos >= 0; pos--) {
            if (pos == overrideIndex) {
                continue;
            }
            indices[pos]++;
            if (indices[pos] < lengths[pos].size()) {
                return true;
            }
            // This position wrapped around; reset it and carry into the next one.
            indices[pos] = 0;
        }
        return false;
    }

    /** Per-group join state: the rows buffered for each side plus scratch data for emitting. */
    public static class JoinState {
        /** Buffered rows, one list per side. */
        List<List>[] sides;
        /** Number of sides that currently hold at least one row. */
        int numSidesReceived = 0;
        /** Scratch cursor (one position per side) used while walking the cross product. */
        int[] indices;
        /** The group tuple this state belongs to. */
        TridentTuple group;

        /**
         * @param numSides number of sides taking part in the join
         * @param group    group tuple placed in front of every emitted row
         */
        public JoinState(int numSides, TridentTuple group) {
            this.group = group;
            indices = new int[numSides];
            sides = new List[numSides];
            for (int s = 0; s < sides.length; s++) {
                sides[s] = new ArrayList<List>();
            }
        }
    }
}