blob: 9b04308347c9b2782e5829f9c1439faad0aa5983 [file] [log] [blame]
package org.apache.nifi.pql.evaluation.accumulation;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.nifi.pql.evaluation.Accumulator;
import org.apache.nifi.pql.evaluation.RecordEvaluator;
import org.apache.nifi.pql.groups.Group;
import org.apache.nifi.provenance.ProvenanceEventRecord;
public class AverageAccumulator implements Accumulator<Double> {
private final long id;
private final RecordEvaluator<Long> evaluator;
private final String label;
private final Map<Group, Values> values = new LinkedHashMap<>();
public AverageAccumulator(final long id, final RecordEvaluator<Long> extractor, final String label) {
this.id = id;
this.evaluator = extractor;
this.label = label;
}
@Override
public Map<Group, List<Double>> getValues() {
final Map<Group, List<Double>> avgs = new HashMap<>(values.size());
for ( final Map.Entry<Group, Values> entry : values.entrySet() ) {
final Values values = entry.getValue();
final double avg = (double) values.getSum() / (double) values.getCount();
avgs.put(entry.getKey(), Collections.singletonList(avg));
}
return avgs;
}
@Override
public List<Double> getValues(final Group group) {
final Values v = values.get(group);
if ( v == null ) {
return Collections.emptyList();
}
final double d = v.getSum() / v.getCount();
return Collections.singletonList(d);
}
public Double accumulate(final ProvenanceEventRecord record, final Group group) {
final Long val = evaluator.evaluate(record);
if ( val != null ) {
Values v = values.get(group);
if ( v == null ) {
v = new Values();
values.put(group, v);
}
v.increment(val.longValue());
return (double) v.getSum() / (double) v.getCount();
}
return null;
}
@Override
public String getLabel() {
return label;
}
@Override
public boolean isAggregateFunction() {
return true;
}
@Override
public Class<Double> getReturnType() {
return Double.class;
}
@Override
public AverageAccumulator clone() {
return new AverageAccumulator(id, evaluator, label);
}
@Override
public long getId() {
return id;
}
@Override
public void reset() {
values.clear();
}
private static class Values {
private long count;
private long sum;
public void increment(final long sum) {
count++;
this.sum += sum;
}
public long getCount() {
return count;
}
public long getSum() {
return sum;
}
}
}