blob: 48c0a977087f5503914c2e009254843d4e05df78 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package storm.trident.operation.impl;
import backtype.storm.tuple.Fields;
import java.util.List;
import java.util.Map;
import storm.trident.operation.Aggregator;
import storm.trident.operation.TridentCollector;
import storm.trident.operation.TridentOperationContext;
import storm.trident.tuple.ComboList;
import storm.trident.tuple.TridentTuple;
import storm.trident.tuple.TridentTupleView;
import storm.trident.tuple.TridentTupleView.ProjectionFactory;
/**
 * An {@link Aggregator} that drives several child aggregators over the same stream of tuples.
 *
 * <p>Each child aggregator receives its own projection of every incoming tuple, selected by the
 * {@link Fields} registered for it at construction time. When more than one child aggregator is
 * configured, their outputs are captured during {@link #complete} and the cross-join (cartesian
 * product) of those outputs is emitted, one combined tuple per combination, built by the supplied
 * {@link ComboList.Factory}. With a single child aggregator no cross-join is performed; its tuples
 * are expected to have been emitted directly.
 */
public class ChainedAggregatorImpl implements Aggregator<ChainedResult> {
    Aggregator[] _aggs;
    ProjectionFactory[] _inputFactories;
    ComboList.Factory _fact;
    Fields[] _inputFields;

    /**
     * Creates a chained aggregator.
     *
     * @param aggs        the child aggregators; their captured outputs are combined in array order
     * @param inputFields the input fields projected for each child; one entry per aggregator
     * @param fact        factory used to build each emitted tuple from one captured output per child
     * @throws IllegalArgumentException if {@code aggs} and {@code inputFields} differ in length
     */
    public ChainedAggregatorImpl(Aggregator[] aggs, Fields[] inputFields, ComboList.Factory fact) {
        _aggs = aggs;
        _inputFields = inputFields;
        _fact = fact;
        if (_aggs.length != _inputFields.length) {
            throw new IllegalArgumentException("Require input fields for each aggregator");
        }
    }

    /**
     * Builds one projection factory per child aggregator and prepares each child with a context
     * scoped to that projection.
     *
     * @param conf    the topology configuration, passed through to each child
     * @param context the operation context used to create the projections
     */
    public void prepare(Map conf, TridentOperationContext context) {
        _inputFactories = new ProjectionFactory[_inputFields.length];
        for (int i = 0; i < _inputFields.length; i++) {
            _inputFactories[i] = context.makeProjectionFactory(_inputFields[i]);
            _aggs[i].prepare(conf, new TridentOperationContext(context, _inputFactories[i]));
        }
    }

    /**
     * Creates the per-batch state and initializes every child aggregator with its own collector.
     *
     * @param batchId   the batch being aggregated, passed through to each child
     * @param collector the downstream collector
     * @return state holding one child state object and one child collector per aggregator
     */
    public ChainedResult init(Object batchId, TridentCollector collector) {
        ChainedResult initted = new ChainedResult(collector, _aggs.length);
        for (int i = 0; i < _aggs.length; i++) {
            initted.objs[i] = _aggs[i].init(batchId, initted.collectors[i]);
        }
        return initted;
    }

    /**
     * Feeds one tuple to every child aggregator, each receiving its own projection of the tuple.
     *
     * @param val       the per-batch state created by {@link #init}
     * @param tuple     the incoming tuple; cast to {@link TridentTupleView} for projection
     * @param collector the current downstream collector
     */
    public void aggregate(ChainedResult val, TridentTuple tuple, TridentCollector collector) {
        // NOTE(review): presumably refreshes the collector used when tuples are forwarded directly
        // (single-aggregator case) -- confirm against ChainedResult.
        val.setFollowThroughCollector(collector);
        for (int i = 0; i < _aggs.length; i++) {
            TridentTuple projected = _inputFactories[i].create((TridentTupleView) tuple);
            _aggs[i].aggregate(val.objs[i], projected, val.collectors[i]);
        }
    }

    /**
     * Completes every child aggregator and, when there is more than one, emits the cross-join of
     * everything they captured.
     *
     * @param val       the per-batch state created by {@link #init}
     * @param collector the downstream collector that receives the combined tuples
     */
    public void complete(ChainedResult val, TridentCollector collector) {
        val.setFollowThroughCollector(collector);
        for (int i = 0; i < _aggs.length; i++) {
            _aggs[i].complete(val.objs[i], val.collectors[i]);
        }
        if (_aggs.length > 1) { // otherwise, tuples were emitted directly
            emitCrossJoin(val, collector);
        }
    }

    /**
     * Emits every combination formed by picking one captured output from each child aggregator.
     * If any child captured nothing, the cross-join is empty and nothing is emitted.
     */
    private void emitCrossJoin(ChainedResult val, TridentCollector collector) {
        CaptureCollector[] capturers = new CaptureCollector[_aggs.length];
        for (int i = 0; i < capturers.length; i++) {
            capturers[i] = (CaptureCollector) val.collectors[i];
            if (capturers[i].captured.size() == 0) {
                // A cartesian product with an empty side has no elements; indexing into the empty
                // capture below would throw IndexOutOfBoundsException.
                return;
            }
        }
        // Array elements start at 0, which is the first combination.
        int[] indices = new int[capturers.length];
        do {
            List[] combined = new List[capturers.length];
            for (int i = 0; i < capturers.length; i++) {
                combined[i] = capturers[i].captured.get(indices[i]);
            }
            collector.emit(_fact.create(combined));
        } while (increment(capturers, indices));
    }

    /**
     * Advances {@code indices} to the next combination, odometer-style with the last position
     * changing fastest.
     *
     * @return {@code false} once every combination has been produced
     */
    private boolean increment(CaptureCollector[] capturers, int[] indices) {
        for (int j = indices.length - 1; j >= 0; j--) {
            indices[j]++;
            if (indices[j] < capturers[j].captured.size()) {
                return true;
            }
            indices[j] = 0; // wrap this position and carry into the one to its left
        }
        return false;
    }

    /** Releases resources held by every child aggregator. */
    public void cleanup() {
        for (Aggregator a : _aggs) {
            a.cleanup();
        }
    }
}