| package storm.trident.operation.impl; |
| |
| import backtype.storm.tuple.Fields; |
| import java.util.ArrayList; |
| import java.util.List; |
| import java.util.Map; |
| import storm.trident.JoinType; |
| import storm.trident.operation.GroupedMultiReducer; |
| import storm.trident.operation.TridentCollector; |
| import storm.trident.operation.TridentMultiReducerContext; |
| import storm.trident.operation.impl.JoinerMultiReducer.JoinState; |
| import storm.trident.tuple.ComboList; |
| import storm.trident.tuple.TridentTuple; |
| |
| public class JoinerMultiReducer implements GroupedMultiReducer<JoinState> { |
| |
| List<JoinType> _types; |
| List<Fields> _sideFields; |
| int _numGroupFields; |
| ComboList.Factory _factory; |
| |
| |
| public JoinerMultiReducer(List<JoinType> types, int numGroupFields, List<Fields> sides) { |
| _types = types; |
| _sideFields = sides; |
| _numGroupFields = numGroupFields; |
| } |
| |
| @Override |
| public void prepare(Map conf, TridentMultiReducerContext context) { |
| int[] sizes = new int[_sideFields.size() + 1]; |
| sizes[0] = _numGroupFields; |
| for(int i=0; i<_sideFields.size(); i++) { |
| sizes[i+1] = _sideFields.get(i).size(); |
| } |
| _factory = new ComboList.Factory(sizes); |
| } |
| |
| @Override |
| public JoinState init(TridentCollector collector, TridentTuple group) { |
| return new JoinState(_types.size(), group); |
| } |
| |
| @Override |
| public void execute(JoinState state, int streamIndex, TridentTuple group, TridentTuple input, TridentCollector collector) { |
| //TODO: do the inner join incrementally, emitting the cross join with this tuple, against all other sides |
| //TODO: only do cross join if at least one tuple in each side |
| List<List> side = state.sides[streamIndex]; |
| if(side.isEmpty()) { |
| state.numSidesReceived++; |
| } |
| |
| side.add(input); |
| if(state.numSidesReceived == state.sides.length) { |
| emitCrossJoin(state, collector, streamIndex, input); |
| } |
| } |
| |
| @Override |
| public void complete(JoinState state, TridentTuple group, TridentCollector collector) { |
| List<List>[] sides = state.sides; |
| boolean wasEmpty = state.numSidesReceived < sides.length; |
| for(int i=0; i<sides.length; i++) { |
| if(sides[i].isEmpty() && _types.get(i) == JoinType.OUTER) { |
| state.numSidesReceived++; |
| sides[i].add(makeNullList(_sideFields.get(i).size())); |
| } |
| } |
| if(wasEmpty && state.numSidesReceived == sides.length) { |
| emitCrossJoin(state, collector, -1, null); |
| } |
| } |
| |
| @Override |
| public void cleanup() { |
| } |
| |
| private List<Object> makeNullList(int size) { |
| List<Object> ret = new ArrayList(size); |
| for(int i=0; i<size; i++) { |
| ret.add(null); |
| } |
| return ret; |
| } |
| |
| private void emitCrossJoin(JoinState state, TridentCollector collector, int overrideIndex, TridentTuple overrideTuple) { |
| List<List>[] sides = state.sides; |
| int[] indices = state.indices; |
| for(int i=0; i<indices.length; i++) { |
| indices[i] = 0; |
| } |
| |
| boolean keepGoing = true; |
| //emit cross-join of all emitted tuples |
| while(keepGoing) { |
| List[] combined = new List[sides.length+1]; |
| combined[0] = state.group; |
| for(int i=0; i<sides.length; i++) { |
| if(i==overrideIndex) { |
| combined[i+1] = overrideTuple; |
| } else { |
| combined[i+1] = sides[i].get(indices[i]); |
| } |
| } |
| collector.emit(_factory.create(combined)); |
| keepGoing = increment(sides, indices, indices.length - 1, overrideIndex); |
| } |
| } |
| |
| |
| //return false if can't increment anymore |
| //TODO: DRY this code up with what's in ChainedAggregatorImpl |
| private boolean increment(List[] lengths, int[] indices, int j, int overrideIndex) { |
| if(j==-1) return false; |
| if(j==overrideIndex) { |
| return increment(lengths, indices, j-1, overrideIndex); |
| } |
| indices[j]++; |
| if(indices[j] >= lengths[j].size()) { |
| indices[j] = 0; |
| return increment(lengths, indices, j-1, overrideIndex); |
| } |
| return true; |
| } |
| |
| public static class JoinState { |
| List<List>[] sides; |
| int numSidesReceived = 0; |
| int[] indices; |
| TridentTuple group; |
| |
| public JoinState(int numSides, TridentTuple group) { |
| sides = new List[numSides]; |
| indices = new int[numSides]; |
| this.group = group; |
| for(int i=0; i<numSides; i++) { |
| sides[i] = new ArrayList<List>(); |
| } |
| } |
| } |
| |
| } |