| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.s4.example.model; |
| |
| import net.jcip.annotations.ThreadSafe; |
| |
| import org.apache.s4.base.Event; |
| import org.apache.s4.core.App; |
| import org.apache.s4.core.ProcessingElement; |
| import org.apache.s4.core.Stream; |
| import org.apache.s4.model.Model; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| |
| @ThreadSafe |
| final public class ModelPE extends ProcessingElement { |
| |
| private static final Logger logger = LoggerFactory.getLogger(ModelPE.class); |
| |
| private long numVectors; |
| private Model model; |
| private Stream<ObsEvent> distanceStream; |
| private Stream<ResultEvent> resultStream; |
| private int modelId; |
| private double logPriorProb; |
| private long obsCount = 0; |
| private long totalCount = 0; |
| private int iteration = 0; |
| |
| public ModelPE(App app) { |
| super(app); |
| } |
| |
| /** |
| * @param numVectors the numVectors to set |
| */ |
| public void setNumVectors(long numVectors) { |
| this.numVectors = numVectors; |
| } |
| |
| /** |
| * @return the number of training vectors. |
| */ |
| public long getNumVectors() { |
| return numVectors; |
| } |
| |
| /** |
| * @param model |
| * the model to set |
| */ |
| public void setModel(Model model) { |
| this.model = model; |
| } |
| |
| /** |
| * @return the model |
| */ |
| public Model getModel() { |
| return model; |
| } |
| |
| /** |
| * Set the output streams. |
| * |
| * @param distanceStream |
| * sends an {@link ObsEvent} to the {@link MaximizerPE}. |
| * @param resultStream |
| * sends a {@link ResultEvent} to the {@link MetricsPE}. |
| */ |
| public void setStream(Stream<ObsEvent> distanceStream, |
| Stream<ResultEvent> resultStream) { |
| |
| /* Init prototype. */ |
| this.distanceStream = distanceStream; |
| this.resultStream = resultStream; |
| } |
| |
| /** |
| * @return number of observation vectors used in training iteration. |
| */ |
| public long getObsCount() { |
| return obsCount; |
| } |
| |
| /** |
| * @return current iteration. |
| */ |
| public int getIteration() { |
| return iteration; |
| } |
| |
| private void updateStats(ObsEvent event) { |
| |
| logger.trace("TRAINING: ModelID: {}, {}", modelId, event.toString()); |
| model.update(event.getObsVector()); |
| |
| obsCount++; |
| |
| /* Log info. */ |
| if (obsCount % 10000 == 0) { |
| logger.info("Trained model using {} events with class id {}", |
| obsCount, modelId); |
| } |
| } |
| |
| private void estimateModel() { |
| |
| model.estimate(); |
| |
| double prob = (double) obsCount / numVectors; |
| logPriorProb = Math.log(prob); |
| logger.info("Prior prob: {}", prob); |
| logger.info("Update params for model {} is: {}", modelId, |
| model.toString()); |
| |
| obsCount = 0; |
| totalCount = 0; |
| model.clearStatistics(); |
| |
| /* Ready to start next iteration. */ |
| iteration++; |
| } |
| |
| public void onEvent(Event event) { |
| |
| ObsEvent inEvent = (ObsEvent) event; |
| float[] obs = inEvent.getObsVector(); |
| |
| /* Estimate model parameters using the training data. */ |
| if (inEvent.isTraining()) { |
| |
| /* |
| * Ignore events with negative index. They are just used to create |
| * the PE. |
| */ |
| if (inEvent.getIndex() < 0) { |
| return; |
| } |
| |
| if (++totalCount == numVectors) { |
| |
| /* End of training stream. */ |
| estimateModel(); |
| |
| /* Could send ack here. */ |
| |
| return; |
| } |
| |
| /* Check if the event belongs to this class. */ |
| if (inEvent.getClassId() == modelId) { |
| |
| updateStats(inEvent); |
| |
| } else { |
| |
| /* Not needed to compute the mean vector. */ |
| return; |
| } |
| |
| } else { // scoring |
| |
| if (inEvent.getHypId() < 0) { |
| /* Score observed vector and send it to the maximizer. */ |
| float dist = (float) (model.logProb(obs) + logPriorProb); |
| ObsEvent outEvent = new ObsEvent(inEvent.getIndex(), obs, dist, |
| inEvent.getClassId(), modelId, false); |
| |
| logger.trace(inEvent.getIndex() + " " + inEvent.getClassId() |
| + " " + modelId + " " + model.logProb(obs) + " " |
| + logPriorProb + " " + dist); |
| |
| distanceStream.put(outEvent); |
| |
| } else { |
| |
| /* Send out result. */ |
| if (resultStream != null) { |
| ResultEvent resultEvent = new ResultEvent( |
| inEvent.getIndex(), inEvent.getClassId(), |
| inEvent.getHypId()); |
| |
| resultStream.put(resultEvent); |
| } |
| } |
| } |
| } |
| |
| @Override |
| protected void onCreate() { |
| |
| this.modelId = Integer.parseInt(getId()); |
| |
| /* |
| * Initialize model. When a new PE instance is created we use the |
| * reference to the model in the PE prototype (initial value in variable |
| * model) to create a new model for this PE instance (final value in |
| * variable model). |
| */ |
| model = model.create(); |
| } |
| |
| @Override |
| protected void onRemove() { |
| |
| } |
| } |