blob: c76555bc4f7f0bf4d37feb0e8d03848f757c3f5d [file] [log] [blame]
package backtype.storm.task;
import backtype.storm.Constants;
import backtype.storm.generated.Grouping;
import backtype.storm.topology.IRichBolt;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Tuple;
import backtype.storm.utils.Utils;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import static backtype.storm.utils.Utils.get;
import static backtype.storm.utils.Utils.tuple;
/**
 * Wraps an {@link IRichBolt} and tracks, per request id, how many tuples the wrapped bolt has
 * emitted to each downstream task and how many input tuples it has acked or failed.
 *
 * <p>The first value of every incoming and outgoing tuple is treated as the request id. When an
 * id is complete, this bolt emits an {@code (id, count)} tuple on
 * {@link Constants#COORDINATED_STREAM_ID} via {@code emitDirect}, telling each target task how
 * many tuples it was sent for that id. The targets are either all tasks of the components
 * subscribed to the coordinated stream ({@code allOut == true}) or only the tasks this bolt
 * actually sent tuples to for that id.
 *
 * <p>An id is considered complete:
 * <ul>
 *   <li>when no {@link SourceArgs} were given: on the first ack or fail for that id;</li>
 *   <li>otherwise: once the expected number of coordination reports has arrived and the number
 *       of acked/failed tuples equals the sum of the counts carried by those reports.</li>
 * </ul>
 * If the wrapped bolt implements {@link FinishedCallback}, it is notified just before the counts
 * are emitted, while the tracking lock is held.
 */
public class CoordinatedBolt implements IRichBolt {
    public static Logger LOG = Logger.getLogger(CoordinatedBolt.class);

    /**
     * Optional interface for the wrapped bolt. It is called once an id is complete, before the
     * per-task counts for that id are emitted on the coordinated stream.
     */
    public static interface FinishedCallback {
        void finishedId(Object id);
    }

    /**
     * Describes how many coordination reports this bolt waits for per id before that id can
     * complete.
     */
    public static class SourceArgs implements Serializable {
        /** If true, expect exactly one report; otherwise one report per task of the source component. */
        public boolean singleCount;

        protected SourceArgs(boolean singleCount) {
            this.singleCount = singleCount;
        }

        /** Expect a single coordination report per id. */
        public static SourceArgs single() {
            return new SourceArgs(true);
        }

        /** Expect one coordination report per task of the source component. */
        public static SourceArgs all() {
            return new SourceArgs(false);
        }
    }

    /**
     * Collector handed to the wrapped bolt. It forwards every call to the real collector. Under
     * the {@code _tracked} lock it also records per-task emit counts and acked/failed input
     * counts for the id found in the first field of each tuple.
     */
    public class CoordinatedOutputCollector extends OutputCollector {
        IOutputCollector _delegate;

        public CoordinatedOutputCollector(IOutputCollector delegate) {
            _delegate = delegate;
        }

        public List<Integer> emit(int stream, List<Tuple> anchors, List<Object> tuple) {
            List<Integer> tasks = _delegate.emit(stream, anchors, tuple);
            updateTaskCounts(tuple.get(0), tasks);
            return tasks;
        }

        public void emitDirect(int task, int stream, List<Tuple> anchors, List<Object> tuple) {
            updateTaskCounts(tuple.get(0), Arrays.asList(task));
            _delegate.emitDirect(task, stream, anchors, tuple);
        }

        public void ack(Tuple tuple) {
            _delegate.ack(tuple);
            recordReceived(tuple.getValue(0));
        }

        public void fail(Tuple tuple) {
            _delegate.fail(tuple);
            recordReceived(tuple.getValue(0));
        }

        public void reportError(Throwable error) {
            _delegate.reportError(error);
        }

        /**
         * Counts one acked or failed input tuple for {@code id}, then checks whether the id is
         * now complete. Assumes {@code id} is still tracked; an untracked id raises
         * NullPointerException, as before.
         */
        private void recordReceived(Object id) {
            synchronized(_tracked) {
                _tracked.get(id).receivedTuples++;
            }
            checkFinishId(id);
        }

        /**
         * Adds one to the emitted-tuple count of each task in {@code tasks} for {@code id}.
         *
         * <p>This holds the {@code _tracked} lock like every other access to tracking state,
         * because {@code checkFinishId} iterates the same map under that lock. Java intrinsic
         * locks are reentrant, so emits made from {@link FinishedCallback#finishedId} (which runs
         * under the lock) still proceed.
         */
        private void updateTaskCounts(Object id, List<Integer> tasks) {
            synchronized(_tracked) {
                Map<Integer, Integer> taskEmittedTuples = _tracked.get(id).taskEmittedTuples;
                for(Integer task: tasks) {
                    int newCount = get(taskEmittedTuples, task, 0) + 1;
                    taskEmittedTuples.put(task, newCount);
                }
            }
        }
    }

    // null means this bolt does not wait for coordination reports from upstream.
    private SourceArgs _sourceArgs;
    private IRichBolt _delegate;
    // Number of coordination reports to wait for per id; set in prepare() only when _sourceArgs != null.
    private Integer _numSourceReports;
    // Every task of every component subscribed to the coordinated stream; filled in prepare().
    private List<Integer> _countOutTasks = new ArrayList<Integer>();
    // The real (unwrapped) collector, used to emit the count tuples themselves.
    private OutputCollector _collector;
    // Per-id tracking state; all access is guarded by synchronizing on this map.
    private Map<Object, TrackingInfo> _tracked = new HashMap<Object, TrackingInfo>();
    // If true, send counts to every task in _countOutTasks; otherwise only to tasks that received tuples.
    private boolean _allOut;

    /** Mutable bookkeeping for a single id. */
    public static class TrackingInfo {
        // Coordination reports received for this id.
        int reportCount = 0;
        // Sum of the counts carried by those reports.
        int expectedTupleCount = 0;
        // Input tuples the wrapped bolt has acked or failed for this id.
        int receivedTuples = 0;
        // Downstream task id -> number of tuples emitted to it for this id.
        Map<Integer, Integer> taskEmittedTuples = new HashMap<Integer, Integer>();

        @Override
        public String toString() {
            return "reportCount: " + reportCount + "\n" +
                   "expectedTupleCount: " + expectedTupleCount + "\n" +
                   "receivedTuples: " + receivedTuples + "\n" +
                   taskEmittedTuples.toString();
        }
    }

    public CoordinatedBolt(IRichBolt delegate, SourceArgs sourceArgs) {
        this(delegate, sourceArgs, false);
    }

    public CoordinatedBolt(IRichBolt delegate) {
        this(delegate, null);
    }

    /**
     * @param delegate   the bolt whose output is being counted
     * @param sourceArgs how many upstream coordination reports to wait for, or null to wait for none
     * @param allOut     whether counts should be sent to all out tasks or just to those it sent tuples to
     */
    public CoordinatedBolt(IRichBolt delegate, SourceArgs sourceArgs, boolean allOut) {
        _sourceArgs = sourceArgs;
        _delegate = delegate;
        _allOut = allOut;
    }

    public CoordinatedBolt(IRichBolt delegate, boolean allOut) {
        this(delegate, null, allOut);
    }

    /**
     * Prepares the wrapped bolt with a counting collector and collects the tasks subscribed to
     * the coordinated stream. It also works out how many coordination reports to expect per id.
     */
    public void prepare(Map config, TopologyContext context, OutputCollector collector) {
        _collector = collector;
        _delegate.prepare(config, context, new CoordinatedOutputCollector(collector));
        for(Integer component: Utils.get(context.getThisTargets(),
                                         Constants.COORDINATED_STREAM_ID,
                                         new HashMap<Integer, Grouping>())
                                   .keySet()) {
            for(Integer task: context.getComponentTasks(component)) {
                _countOutTasks.add(task);
            }
        }
        if(_sourceArgs!=null) {
            if(_sourceArgs.singleCount) {
                _numSourceReports = 1;
            } else {
                // Uses whichever source component the iterator yields first.
                // NOTE(review): assumes all inputs come from a single component — confirm.
                int sourceComponent = context.getThisSources().keySet().iterator().next().get_componentId();
                _numSourceReports = context.getComponentTasks(sourceComponent).size();
            }
        }
    }

    /**
     * If {@code id} is tracked and complete, this notifies the wrapped bolt (when it implements
     * {@link FinishedCallback}). It then emits the per-task counts on the coordinated stream and
     * stops tracking the id. Otherwise it does nothing.
     */
    private void checkFinishId(Object id) {
        synchronized(_tracked) {
            TrackingInfo track = _tracked.get(id);
            if(track!=null &&
               (_sourceArgs==null
                ||
                (track.reportCount==_numSourceReports &&
                 track.expectedTupleCount == track.receivedTuples))) {
                if(_delegate instanceof FinishedCallback) {
                    ((FinishedCallback)_delegate).finishedId(id);
                }
                Iterable<Integer> outTasks;
                if(_allOut) {
                    outTasks = _countOutTasks;
                } else {
                    outTasks = track.taskEmittedTuples.keySet();
                }
                for(int task: outTasks) {
                    int numTuples = get(track.taskEmittedTuples, task, 0);
                    // Emitted through the raw collector, so count tuples are not themselves counted.
                    _collector.emitDirect(task, Constants.COORDINATED_STREAM_ID, tuple(id, numTuples));
                }
                //TODO: ids that never complete are never removed; need a thread that clears them out occasionally (or a map type that expires entries automatically)
                _tracked.remove(id);
            }
        }
    }

    /**
     * Starts tracking the tuple's id if it is new. Coordination reports (only honored when
     * {@code _sourceArgs} is set) update the expected counts. All other tuples are passed to the
     * wrapped bolt.
     */
    public void execute(Tuple tuple) {
        Object id = tuple.getValue(0);
        TrackingInfo track;
        synchronized(_tracked) {
            track = _tracked.get(id);
            if(track==null) {
                track = new TrackingInfo();
                _tracked.put(id, track);
            }
        }
        // Stream ids are ints in this API (see emit(int stream, ...)), so == compares values.
        if(_sourceArgs!=null && tuple.getSourceStreamId()==Constants.COORDINATED_STREAM_ID) {
            int count = (Integer) tuple.getValue(1);
            synchronized(_tracked) {
                track.reportCount++;
                track.expectedTupleCount+=count;
            }
            checkFinishId(id);
        } else {
            _delegate.execute(tuple);
        }
    }

    public void cleanup() {
        _delegate.cleanup();
    }

    /**
     * Declares the wrapped bolt's outputs plus the coordinated stream carrying
     * {@code (id, count)} tuples. The {@code true} argument is presumably the direct-stream flag,
     * matching the emitDirect calls in checkFinishId.
     */
    public void declareOutputFields(OutputFieldsDeclarer declarer) {
        _delegate.declareOutputFields(declarer);
        declarer.declareStream(Constants.COORDINATED_STREAM_ID, true, new Fields("id", "count"));
    }
}