| package backtype.storm.task; |
| |
| import backtype.storm.Constants; |
| import backtype.storm.generated.Grouping; |
| import backtype.storm.topology.IRichBolt; |
| import backtype.storm.topology.OutputFieldsDeclarer; |
| import backtype.storm.tuple.Fields; |
| import backtype.storm.tuple.Tuple; |
| import backtype.storm.utils.Utils; |
| import java.io.Serializable; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.HashMap; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Map; |
| import org.apache.log4j.Logger; |
| import static backtype.storm.utils.Utils.get; |
| import static backtype.storm.utils.Utils.tuple; |
| |
| |
| public class CoordinatedBolt implements IRichBolt { |
| public static Logger LOG = Logger.getLogger(CoordinatedBolt.class); |
| |
| public static interface FinishedCallback { |
| void finishedId(Object id); |
| } |
| |
| public static class SourceArgs implements Serializable { |
| public boolean singleCount; |
| |
| protected SourceArgs(boolean singleCount) { |
| this.singleCount = singleCount; |
| } |
| |
| public static SourceArgs single() { |
| return new SourceArgs(true); |
| } |
| |
| public static SourceArgs all() { |
| return new SourceArgs(false); |
| } |
| } |
| |
| public class CoordinatedOutputCollector extends OutputCollector { |
| IOutputCollector _delegate; |
| |
| public CoordinatedOutputCollector(IOutputCollector delegate) { |
| _delegate = delegate; |
| } |
| |
| public List<Integer> emit(int stream, List<Tuple> anchors, List<Object> tuple) { |
| List<Integer> tasks = _delegate.emit(stream, anchors, tuple); |
| updateTaskCounts(tuple.get(0), tasks); |
| return tasks; |
| } |
| |
| public void emitDirect(int task, int stream, List<Tuple> anchors, List<Object> tuple) { |
| updateTaskCounts(tuple.get(0), Arrays.asList(task)); |
| _delegate.emitDirect(task, stream, anchors, tuple); |
| } |
| |
| public void ack(Tuple tuple) { |
| _delegate.ack(tuple); |
| Object id = tuple.getValue(0); |
| synchronized(_tracked) { |
| _tracked.get(id).receivedTuples++; |
| } |
| checkFinishId(id); |
| } |
| |
| public void fail(Tuple tuple) { |
| _delegate.fail(tuple); |
| Object id = tuple.getValue(0); |
| synchronized(_tracked) { |
| _tracked.get(id).receivedTuples++; |
| } |
| checkFinishId(id); |
| } |
| |
| public void reportError(Throwable error) { |
| _delegate.reportError(error); |
| } |
| |
| |
| private void updateTaskCounts(Object id, List<Integer> tasks) { |
| Map<Integer, Integer> taskEmittedTuples = _tracked.get(id).taskEmittedTuples; |
| for(Integer task: tasks) { |
| int newCount = get(taskEmittedTuples, task, 0) + 1; |
| taskEmittedTuples.put(task, newCount); |
| } |
| } |
| } |
| |
| private SourceArgs _sourceArgs; |
| private IRichBolt _delegate; |
| private Integer _numSourceReports; |
| private List<Integer> _countOutTasks = new ArrayList<Integer>();; |
| private OutputCollector _collector; |
| private Map<Object, TrackingInfo> _tracked = new HashMap<Object, TrackingInfo>(); |
| private boolean _allOut; |
| |
| public static class TrackingInfo { |
| int reportCount = 0; |
| int expectedTupleCount = 0; |
| int receivedTuples = 0; |
| Map<Integer, Integer> taskEmittedTuples = new HashMap<Integer, Integer>(); |
| |
| @Override |
| public String toString() { |
| return "reportCount: " + reportCount + "\n" + |
| "expectedTupleCount: " + expectedTupleCount + "\n" + |
| "receivedTuples: " + receivedTuples + "\n" + |
| taskEmittedTuples.toString(); |
| } |
| } |
| |
| |
| public CoordinatedBolt(IRichBolt delegate, SourceArgs sourceArgs) { |
| this(delegate, sourceArgs, false); |
| } |
| |
| public CoordinatedBolt(IRichBolt delegate) { |
| this(delegate, null); |
| } |
| |
| /** |
| * allOut indicates whether counts should be sent to all out tasks or just to those it sent tuples to |
| */ |
| public CoordinatedBolt(IRichBolt delegate, SourceArgs sourceArgs, boolean allOut) { |
| _sourceArgs = sourceArgs; |
| _delegate = delegate; |
| _allOut = allOut; |
| } |
| |
| public CoordinatedBolt(IRichBolt delegate, boolean allOut) { |
| this(delegate, null, allOut); |
| } |
| |
| |
| public void prepare(Map config, TopologyContext context, OutputCollector collector) { |
| _collector = collector; |
| _delegate.prepare(config, context, new CoordinatedOutputCollector(collector)); |
| for(Integer component: Utils.get(context.getThisTargets(), |
| Constants.COORDINATED_STREAM_ID, |
| new HashMap<Integer, Grouping>()) |
| .keySet()) { |
| for(Integer task: context.getComponentTasks(component)) { |
| _countOutTasks.add(task); |
| } |
| } |
| if(_sourceArgs!=null) { |
| if(_sourceArgs.singleCount) { |
| _numSourceReports = 1; |
| } else { |
| int sourceComponent = context.getThisSources().keySet().iterator().next().get_componentId(); |
| _numSourceReports = context.getComponentTasks(sourceComponent).size(); |
| } |
| } |
| } |
| |
| private void checkFinishId(Object id) { |
| synchronized(_tracked) { |
| TrackingInfo track = _tracked.get(id); |
| if(track!=null && |
| (_sourceArgs==null |
| || |
| track.reportCount==_numSourceReports && |
| track.expectedTupleCount == track.receivedTuples)) { |
| if(_delegate instanceof FinishedCallback) { |
| ((FinishedCallback)_delegate).finishedId(id); |
| } |
| Iterator<Integer> outTasks; |
| if(_allOut) outTasks = _countOutTasks.iterator(); |
| else outTasks = track.taskEmittedTuples.keySet().iterator(); |
| while(outTasks.hasNext()) { |
| int task = outTasks.next(); |
| int numTuples = get(track.taskEmittedTuples, task, 0); |
| _collector.emitDirect(task, Constants.COORDINATED_STREAM_ID, tuple(id, numTuples)); |
| } |
| //TODO: need a thread that clears this out occassionally (or wait until have a map type that does this automatically) |
| _tracked.remove(id); |
| } |
| } |
| } |
| |
| public void execute(Tuple tuple) { |
| Object id = tuple.getValue(0); |
| TrackingInfo track; |
| synchronized(_tracked) { |
| track = _tracked.get(id); |
| if(track==null) { |
| track = new TrackingInfo(); |
| _tracked.put(id, track); |
| } |
| } |
| |
| if(_sourceArgs!=null && tuple.getSourceStreamId()==Constants.COORDINATED_STREAM_ID) { |
| int count = (Integer) tuple.getValue(1); |
| synchronized(_tracked) { |
| track.reportCount++; |
| track.expectedTupleCount+=count; |
| } |
| checkFinishId(id); |
| } else { |
| _delegate.execute(tuple); |
| } |
| } |
| |
| public void cleanup() { |
| _delegate.cleanup(); |
| } |
| |
| public void declareOutputFields(OutputFieldsDeclarer declarer) { |
| _delegate.declareOutputFields(declarer); |
| declarer.declareStream(Constants.COORDINATED_STREAM_ID, true, new Fields("id", "count")); |
| } |
| |
| } |