blob: 61314e1a3fa454844d411df51256d71b81bfe76b [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.s4.example.model;
import java.util.concurrent.TimeUnit;
import org.apache.s4.core.App;
import org.apache.s4.core.ProcessingElement;
import org.apache.s4.core.Stream;
import org.apache.s4.model.Model;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class MyApp extends App {
private static final Logger logger = LoggerFactory.getLogger(MyApp.class);
final private int numClasses;
final private long numVectors;
final private int outputInterval;
final private TimeUnit timeUnit;
final private Model model;
private Stream<ObsEvent> obsStream;
private ModelPE modelPE;
private MetricsPE metricsPE;
Stream<ObsEvent> assignmentStream;
private String report;
MyApp(int numClasses, long numVectors, Model model, int outputInterval, TimeUnit timeUnit) {
super();
this.numClasses = numClasses;
this.numVectors = numVectors;
this.model = model;
this.outputInterval = outputInterval;
this.timeUnit = timeUnit;
}
public void injectToAll(ObsEvent obsEvent) {
logger.trace("Inject: " + obsEvent.toString());
obsStream.put(obsEvent);
}
public void injectByKey(ObsEvent obsEvent) {
logger.trace("Inject: " + obsEvent.toString());
assignmentStream.put(obsEvent);
}
@Override
protected void onStart() {
}
/* init() visibility is protected. Normally the Server class only should call init. */
public void initApp() {
init();
}
@Override
protected void onInit() {
metricsPE = createPE(MetricsPE.class);
Stream<ResultEvent> resultStream = createStream("Result Stream", new ResultKeyFinder(), metricsPE);
modelPE = createPE(ModelPE.class);
modelPE.setModel(model);
modelPE.setNumVectors(numVectors);
assignmentStream = createStream("Assignment Stream", new ClassIDKeyFinder(), modelPE);
MaximizerPE maximizerPE = createPE(MaximizerPE.class);
maximizerPE.setNumClasses(numClasses);
maximizerPE.setAssignmentStream(assignmentStream);
Stream<ObsEvent> distanceStream = createStream("Distance Stream", new ObsIndexKeyFinder(), maximizerPE);
/*
* There is a loop in this graph so we need to set the stream at the end. Is there a cleaner way to do this?
*/
modelPE.setStream(distanceStream, resultStream);
// modelPE.setOutputIntervalInEvents(10); // output every 10 events
metricsPE.setTimerInterval(outputInterval, timeUnit); // output every 5
// seconds
// obsStream = new Stream<ObsEvent>(this, "Observation Stream", new
// ClassIDKeyFinder(), modelPE);
obsStream = createStream("Observation Stream", modelPE);
}
/** @return true if modelPE is initialized. */
public boolean isInited() {
if (modelPE.getRemoteInstances().size() == numClasses)
return true;
else
return false;
}
/**
* Check if a training iteration is completed.
*
* @param iteration
* @return true when a training iteration is completed.
*/
public boolean isTrained(int iteration) {
int sum = 0;
for (ProcessingElement pe : modelPE.getRemoteInstances().values()) {
sum += ((ModelPE) pe).getIteration();
}
if (sum < iteration * numClasses)
return false;
else
return true;
}
/**
* Check if a test is complete.
*
* @param number
* of test vectors.
* @return true when testing is complete.
*/
public boolean isTested(long numTestVectors) {
long count = 0;
MetricsPE mpe = null;
for (ProcessingElement pe : metricsPE.getRemoteInstances().values()) {
mpe = (MetricsPE) pe;
count = mpe.getCount();
}
if (count < numTestVectors) {
return false;
} else {
report = mpe.toString();
return true;
}
}
/**
* @return a metrics report.
*/
public String getReport() {
return report;
}
/* close() visibility is protected. Normally the Server class only should call init. */
public void closeApp() {
close();
}
public long getObsCount() {
return modelPE.getObsCount();
}
@Override
protected void onClose() {
}
}