blob: a252ff482c1d395d633f764358862d941c6f60dd [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.s4.example.model;
import net.jcip.annotations.ThreadSafe;
import org.apache.s4.base.Event;
import org.apache.s4.core.App;
import org.apache.s4.core.ProcessingElement;
import org.apache.s4.core.Stream;
import org.apache.s4.model.Model;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Processing element that trains and scores a single {@link Model}.
 *
 * <p>
 * The PE id is parsed as an integer model id when the instance is created.
 * Incoming {@link ObsEvent}s are handled in one of two modes:
 * <ul>
 * <li><b>Training</b>: every training event with a non-negative index is
 * counted; events whose class id equals this PE's model id also update the
 * model statistics. Once {@code numVectors} training events have been counted,
 * the model is estimated, the log prior probability is computed and the
 * counters are reset for the next iteration.</li>
 * <li><b>Scoring</b>: events without a hypothesis (negative hyp id) are scored
 * and forwarded on the distance stream; events that already carry a hypothesis
 * are turned into a {@link ResultEvent} and forwarded on the result stream, if
 * one is set.</li>
 * </ul>
 *
 * <p>
 * NOTE(review): the class is annotated {@code @ThreadSafe}, yet the counters
 * and the model are mutated without any synchronization in this class. Confirm
 * that the framework never delivers events to the same instance concurrently.
 */
@ThreadSafe
public final class ModelPE extends ProcessingElement {

    private static final Logger logger = LoggerFactory.getLogger(ModelPE.class);

    /** Number of training events that make up one training iteration. */
    private long numVectors;

    /** Model owned by this PE; replaced by a fresh copy in {@link #onCreate()}. */
    private Model model;

    private Stream<ObsEvent> distanceStream;
    private Stream<ResultEvent> resultStream;

    /** Class id handled by this PE, parsed from the PE id. */
    private int modelId;

    /** Natural log of (obsCount / numVectors) from the last estimation. */
    private double logPriorProb;

    /** Training events of this PE's class seen in the current iteration. */
    private long obsCount = 0;

    /** Training events of any class seen in the current iteration. */
    private long totalCount = 0;

    /** Number of completed training iterations. */
    private int iteration = 0;

    public ModelPE(App app) {
        super(app);
    }

    /**
     * @param numVectors
     *            the number of training vectors per iteration
     */
    public void setNumVectors(long numVectors) {
        this.numVectors = numVectors;
    }

    /**
     * @return the number of training vectors.
     */
    public long getNumVectors() {
        return numVectors;
    }

    /**
     * @param model
     *            the model to set
     */
    public void setModel(Model model) {
        this.model = model;
    }

    /**
     * @return the model
     */
    public Model getModel() {
        return model;
    }

    /**
     * Set the output streams.
     *
     * @param distanceStream
     *            sends an {@link ObsEvent} to the {@link MaximizerPE}.
     * @param resultStream
     *            sends a {@link ResultEvent} to the {@link MetricsPE}.
     */
    public void setStream(Stream<ObsEvent> distanceStream,
            Stream<ResultEvent> resultStream) {

        /* Init prototype. */
        this.distanceStream = distanceStream;
        this.resultStream = resultStream;
    }

    /**
     * @return number of observation vectors used in training iteration.
     */
    public long getObsCount() {
        return obsCount;
    }

    /**
     * @return current iteration.
     */
    public int getIteration() {
        return iteration;
    }

    /**
     * Folds one training observation of this PE's class into the model
     * statistics and bumps the per-class counter.
     */
    private void updateStats(ObsEvent event) {

        /* Pass the event itself so toString() only runs when trace is on. */
        logger.trace("TRAINING: ModelID: {}, {}", modelId, event);
        model.update(event.getObsVector());
        obsCount++;

        /* Log info. */
        if (obsCount % 10000 == 0) {
            logger.info("Trained model using {} events with class id {}",
                    obsCount, modelId);
        }
    }

    /**
     * Ends a training iteration: estimates the model, derives the log prior
     * from the fraction of training vectors that belonged to this class, and
     * resets counters and statistics for the next iteration.
     */
    private void estimateModel() {

        model.estimate();

        double prob = (double) obsCount / numVectors;
        logPriorProb = Math.log(prob);
        logger.info("Prior prob: {}", prob);
        logger.info("Update params for model {} is: {}", modelId, model);

        obsCount = 0;
        totalCount = 0;
        model.clearStatistics();

        /* Ready to start next iteration. */
        iteration++;
    }

    /**
     * Entry point for incoming events. The event is expected to be an
     * {@link ObsEvent}; it is routed to training or scoring depending on
     * {@link ObsEvent#isTraining()}.
     *
     * @param event
     *            the incoming event, cast to {@link ObsEvent}
     */
    public void onEvent(Event event) {

        ObsEvent inEvent = (ObsEvent) event;

        if (inEvent.isTraining()) {
            handleTrainingEvent(inEvent);
        } else {
            handleScoringEvent(inEvent);
        }
    }

    /** Counts a training event and, if it is of this class, trains on it. */
    private void handleTrainingEvent(ObsEvent inEvent) {

        /*
         * Ignore events with negative index. They are just used to create the
         * PE.
         */
        if (inEvent.getIndex() < 0) {
            return;
        }

        /* Every training vector counts toward the end-of-stream check. */
        totalCount++;

        /*
         * Only vectors of this class contribute to the model. This must happen
         * before the end-of-stream check so that the final training vector is
         * not silently dropped when it belongs to this class.
         */
        if (inEvent.getClassId() == modelId) {
            updateStats(inEvent);
        }

        if (totalCount == numVectors) {

            /* End of training stream. */
            estimateModel();

            /* Could send ack here. */
        }
    }

    /** Scores an observation or forwards an already-labelled result. */
    private void handleScoringEvent(ObsEvent inEvent) {

        if (inEvent.getHypId() < 0) {

            /* Score observed vector and send it to the maximizer. */
            float[] obs = inEvent.getObsVector();

            /* Evaluate the model once and reuse the value for logging. */
            double logProb = model.logProb(obs);
            float dist = (float) (logProb + logPriorProb);

            ObsEvent outEvent = new ObsEvent(inEvent.getIndex(), obs, dist,
                    inEvent.getClassId(), modelId, false);

            /* Avoid building the message when trace logging is disabled. */
            if (logger.isTraceEnabled()) {
                logger.trace(inEvent.getIndex() + " " + inEvent.getClassId()
                        + " " + modelId + " " + logProb + " "
                        + logPriorProb + " " + dist);
            }

            distanceStream.put(outEvent);

        } else if (resultStream != null) {

            /* Send out result. */
            ResultEvent resultEvent = new ResultEvent(inEvent.getIndex(),
                    inEvent.getClassId(), inEvent.getHypId());
            resultStream.put(resultEvent);
        }
    }

    @Override
    protected void onCreate() {

        this.modelId = Integer.parseInt(getId());

        /*
         * Initialize model. When a new PE instance is created we use the
         * reference to the model in the PE prototype (initial value in variable
         * model) to create a new model for this PE instance (final value in
         * variable model).
         */
        model = model.create();
    }

    @Override
    protected void onRemove() {
    }
}