Merged master into LENS-581
diff --git a/lens-ml-lib/data/lr/lr_train.data b/lens-ml-lib/data/lr/lr_train.data
new file mode 100644
index 0000000..8ecf84c
--- /dev/null
+++ b/lens-ml-lib/data/lr/lr_train.data
@@ -0,0 +1,11 @@
+1 2 3 0 0
+0 3 5 1 1
+1 3 5 1 1
+0 6 7 0 0
+0 5 1 1 1
+0 3 1 1 1
+1 8 0 0 0
+0 1 1 0 0
+0 3 1 0 1
+1 3 0 0 0
+1 4 9 1 0
\ No newline at end of file
diff --git a/lens-ml-lib/pom.xml b/lens-ml-lib/pom.xml
index 67215e1..1f4f9d8 100644
--- a/lens-ml-lib/pom.xml
+++ b/lens-ml-lib/pom.xml
@@ -136,6 +136,15 @@
<artifactId>jersey-test-framework-core</artifactId>
</dependency>
+ <dependency>
+ <groupId>commons-dbutils</groupId>
+ <artifactId>commons-dbutils</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>mysql</groupId>
+ <artifactId>mysql-connector-java</artifactId>
+ <version>5.1.6</version>
+ </dependency>
</dependencies>
<build>
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
index 6dd0ecf..3f4bee9 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLClient.java
@@ -20,41 +20,57 @@
import java.io.Closeable;
import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.net.URI;
-import java.net.URISyntaxException;
-import java.util.Arrays;
-import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import javax.ws.rs.core.Form;
-
+import org.apache.lens.api.APIResult;
import org.apache.lens.api.LensSessionHandle;
-import org.apache.lens.ml.algo.api.MLAlgo;
-import org.apache.lens.ml.algo.api.MLModel;
-import org.apache.lens.ml.api.LensML;
-import org.apache.lens.ml.api.MLTestReport;
-import org.apache.lens.ml.api.ModelMetadata;
-import org.apache.lens.ml.api.TestReport;
+import org.apache.lens.ml.api.*;
+import org.apache.lens.ml.server.MLService;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.ServiceProvider;
+import org.apache.lens.server.api.ServiceProviderFactory;
import org.apache.lens.server.api.error.LensException;
-import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.fs.FileSystem;
-import org.apache.hadoop.fs.Path;
-
-import lombok.extern.slf4j.Slf4j;
-
-/**
- * Client side implementation of LensML
- */
-@Slf4j
public class LensMLClient implements LensML, Closeable {
+ private static final Log LOG = LogFactory.getLog(LensMLClient.class);
+ private static final HiveConf HIVE_CONF;
- /** The client. */
+ static {
+ HIVE_CONF = new HiveConf();
+ // Add default config so that we know the service provider implementation
+ HIVE_CONF.addResource("lensserver-default.xml");
+ HIVE_CONF.addResource("lens-site.xml");
+ }
+
+ /**
+ * The ml service.
+ */
+ MLService mlService;
+ /**
+ * The service provider.
+ */
+ ServiceProvider serviceProvider;
+ /**
+ * The service provider factory.
+ */
+ ServiceProviderFactory serviceProviderFactory;
+ /**
+ * The client.
+ */
private LensMLJerseyClient client;
+ /**
+ * Instantiates a new ML service resource.
+ */
+ public LensMLClient() {
+
+ }
+
public LensMLClient(String password) {
this(new LensClientConfig(), password);
}
@@ -72,234 +88,183 @@
}
public LensMLClient(LensClient lensClient) {
- client = new LensMLJerseyClient(lensClient.getConnection(), lensClient
- .getConnection().getSessionHandle());
+ client = new LensMLJerseyClient(lensClient.getConnection(), lensClient.getConnection().getSessionHandle());
+ serviceProviderFactory = getServiceProviderFactory(HIVE_CONF);
}
- /**
- * Get list of available machine learning algorithms
- *
- * @return
- */
- @Override
- public List<String> getAlgorithms() {
- return client.getAlgoNames();
- }
-
- /**
- * Get user friendly information about parameters accepted by the algorithm.
- *
- * @param algorithm the algorithm
- * @return map of param key to its help message
- */
- @Override
- public Map<String, String> getAlgoParamDescription(String algorithm) {
- List<String> paramDesc = client.getParamDescriptionOfAlgo(algorithm);
- // convert paramDesc to map
- Map<String, String> paramDescMap = new LinkedHashMap<String, String>();
- for (String str : paramDesc) {
- String[] keyHelp = StringUtils.split(str, ":");
- paramDescMap.put(keyHelp[0].trim(), keyHelp[1].trim());
+ private ServiceProvider getServiceProvider() {
+ if (serviceProvider == null) {
+ serviceProvider = serviceProviderFactory.getServiceProvider();
}
- return paramDescMap;
+ return serviceProvider;
}
/**
- * Get a algo object instance which could be used to generate a model of the given algorithm.
+ * Gets the service provider factory.
*
- * @param algorithm the algorithm
- * @return the algo for name
- * @throws LensException the lens exception
+ * @param conf the conf
+ * @return the service provider factory
*/
- @Override
- public MLAlgo getAlgoForName(String algorithm) throws LensException {
- throw new UnsupportedOperationException("MLAlgo cannot be accessed from client");
- }
-
- /**
- * Create a model using the given HCatalog table as input. The arguments should contain information needeed to
- * generate the model.
- *
- * @param table the table
- * @param algorithm the algorithm
- * @param args the args
- * @return Unique ID of the model created after training is complete
- * @throws LensException the lens exception
- */
- @Override
- public String train(String table, String algorithm, String[] args) throws LensException {
- Form trainParams = new Form();
- trainParams.param("table", table);
- for (int i = 0; i < args.length; i += 2) {
- trainParams.param(args[i], args[i + 1]);
- }
- return client.trainModel(algorithm, trainParams);
- }
-
- /**
- * Get model IDs for the given algorithm.
- *
- * @param algorithm the algorithm
- * @return the models
- * @throws LensException the lens exception
- */
- @Override
- public List<String> getModels(String algorithm) throws LensException {
- return client.getModelsForAlgorithm(algorithm);
- }
-
- /**
- * Get a model instance given the algorithm name and model ID.
- *
- * @param algorithm the algorithm
- * @param modelId the model id
- * @return the model
- * @throws LensException the lens exception
- */
- @Override
- public MLModel getModel(String algorithm, String modelId) throws LensException {
- ModelMetadata metadata = client.getModelMetadata(algorithm, modelId);
- String modelPathURI = metadata.getModelPath();
-
- ObjectInputStream in = null;
+ private ServiceProviderFactory getServiceProviderFactory(HiveConf conf) {
+ Class<?> spfClass = conf.getClass(LensConfConstants.SERVICE_PROVIDER_FACTORY, ServiceProviderFactory.class);
try {
- URI modelURI = new URI(modelPathURI);
- Path modelPath = new Path(modelURI);
- FileSystem fs = FileSystem.get(modelURI, client.getConf());
- in = new ObjectInputStream(fs.open(modelPath));
- MLModel<?> model = (MLModel) in.readObject();
- return model;
- } catch (IOException e) {
- throw new LensException(e);
- } catch (URISyntaxException e) {
- throw new LensException(e);
- } catch (ClassNotFoundException e) {
- throw new LensException(e);
- } finally {
- if (in != null) {
- try {
- in.close();
- } catch (IOException e) {
- log.error("Error closing stream.", e);
- }
- }
+ return (ServiceProviderFactory) spfClass.newInstance();
+ } catch (InstantiationException e) {
+ throw new RuntimeException(e);
+ } catch (IllegalAccessException e) {
+ throw new RuntimeException(e);
}
-
}
- /**
- * Get the FS location where model instance is saved.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return the model path
- */
- @Override
- public String getModelPath(String algorithm, String modelID) {
- ModelMetadata metadata = client.getModelMetadata(algorithm, modelID);
- return metadata.getModelPath();
+ private MLService getMlService() {
+ if (mlService == null) {
+ mlService = (MLService) getServiceProvider().getService(MLService.NAME);
+ }
+ return mlService;
}
- /**
- * Evaluate model by running it against test data contained in the given table.
- *
- * @param session the session
- * @param table the table
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return Test report object containing test output table, and various evaluation metrics
- * @throws LensException the lens exception
- */
- @Override
- public MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
- String outputTable) throws LensException {
- String reportID = client.testModel(table, algorithm, modelID, outputTable);
- return getTestReport(algorithm, reportID);
- }
-
- /**
- * Get test reports for an algorithm.
- *
- * @param algorithm the algorithm
- * @return the test reports
- * @throws LensException the lens exception
- */
- @Override
- public List<String> getTestReports(String algorithm) throws LensException {
- return client.getTestReportsOfAlgorithm(algorithm);
- }
-
- /**
- * Get a test report by ID.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the test report
- * @throws LensException the lens exception
- */
- @Override
- public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
- TestReport report = client.getTestReport(algorithm, reportID);
- MLTestReport mlTestReport = new MLTestReport();
- mlTestReport.setAlgorithm(report.getAlgorithm());
- mlTestReport.setFeatureColumns(Arrays.asList(report.getFeatureColumns().split("\\,+")));
- mlTestReport.setLensQueryID(report.getQueryID());
- mlTestReport.setLabelColumn(report.getLabelColumn());
- mlTestReport.setModelID(report.getModelID());
- mlTestReport.setOutputColumn(report.getOutputColumn());
- mlTestReport.setPredictionResultColumn(report.getOutputColumn());
- mlTestReport.setQueryID(report.getQueryID());
- mlTestReport.setReportID(report.getReportID());
- mlTestReport.setTestTable(report.getTestTable());
- return mlTestReport;
- }
-
- /**
- * Online predict call given a model ID, algorithm name and sample feature values.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param features the features
- * @return prediction result
- * @throws LensException the lens exception
- */
- @Override
- public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
- return getModel(algorithm, modelID).predict(features);
- }
-
- /**
- * Permanently delete a model instance.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @throws LensException the lens exception
- */
- @Override
- public void deleteModel(String algorithm, String modelID) throws LensException {
- client.deleteModel(algorithm, modelID);
- }
-
- /**
- * Permanently delete a test report instance.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @throws LensException the lens exception
- */
- @Override
- public void deleteTestReport(String algorithm, String reportID) throws LensException {
- client.deleteTestReport(algorithm, reportID);
- }
-
- /**
- * Close connection
- */
@Override
public void close() throws IOException {
client.close();
}
+ @Override
+ public List<Algo> getAlgos() {
+ return getMlService().getAlgos();
+ }
+
+ public List<String> getAlgoNames() {
+ return null;
+ }
+
+ @Override
+ public Algo getAlgo(String name) throws LensException {
+ return null;
+ }
+
+ @Override
+ public void createDataSet(String name, String dataTable, String dataBase) throws LensException {
+
+ client.createDataSet(name, dataTable, dataBase);
+ }
+
+ public void createDataSet(DataSet dataSet) throws LensException {
+
+ }
+
+ public void test() {
+ client.test();
+ }
+
+ @Override
+ public String createDataSetFromQuery(String name, String query) {
+ return null;
+ }
+
+ @Override
+ public DataSet getDataSet(String name) throws LensException {
+ return client.getDataSet(name);
+ }
+
+ @Override
+ public void createModel(String name, String algo, Map<String, String> algoParams, List<Feature> features,
+ Feature label, LensSessionHandle lensSessionHandle) throws LensException {
+ APIResult result = client.createModel(name, algo, algoParams, features, label);
+ }
+
+ @Override
+ public void createModel(Model model) throws LensException {
+
+ }
+
+ @Override
+ public Model getModel(String modelId) throws LensException {
+ return client.getModel(modelId);
+ }
+
+ @Override
+ public String trainModel(String modelId, String dataSetName, LensSessionHandle lensSessionHandle)
+ throws LensException {
+ return client.tranModel(modelId, dataSetName, lensSessionHandle);
+ }
+
+ @Override
+ public ModelInstance getModelInstance(String modelInstanceId) throws LensException {
+ return client.getModelInstance(modelInstanceId);
+ }
+
+ @Override
+ public List<ModelInstance> getAllModelInstances(String modelId) {
+ return null;
+ }
+
+ @Override
+ public String evaluate(String modelInstanceId, String dataSetName, LensSessionHandle lensSessionHandle)
+ throws LensException {
+ return client.evaluate(modelInstanceId, dataSetName, lensSessionHandle);
+ }
+
+ @Override
+ public Evaluation getEvaluation(String evalId) throws LensException {
+ return client.getEvaluation(evalId);
+ }
+
+ @Override
+ public String predict(String modelInstanceId, String dataSetName, LensSessionHandle lensSessionHandle)
+ throws LensException {
+ return client.predict(modelInstanceId, dataSetName, lensSessionHandle);
+ }
+
+ @Override
+ public boolean cancelModelInstance(String modelInstanceId, LensSessionHandle lensSessionHandle) throws LensException {
+ return client.cancelModelInstance(modelInstanceId, lensSessionHandle);
+ }
+
+ @Override
+ public boolean cancelEvaluation(String evalId, LensSessionHandle lensSessionHandle) throws LensException {
+ return client.cancelEvaluation(evalId, lensSessionHandle);
+ }
+
+ @Override
+  public boolean cancelPrediction(String predictionId, LensSessionHandle lensSessionHandle) throws LensException {
+    return client.cancelPrediction(predictionId, lensSessionHandle);
+ }
+
+ @Override
+ public Prediction getPrediction(String predictionId) throws LensException {
+ return client.getPrediction(predictionId);
+ }
+
+ @Override
+ public String predict(String modelInstanceId, Map<String, String> featureVector) throws LensException {
+ return getMlService().predict(modelInstanceId, featureVector);
+ }
+
+ @Override
+ public void deleteModel(String modelId) throws LensException {
+ client.deleteModel(modelId);
+ }
+
+ @Override
+ public void deleteDataSet(String dataSetName) throws LensException {
+ client.deleteDataSet(dataSetName);
+ }
+
+ @Override
+ public void deleteModelInstance(String modelInstanceId) throws LensException {
+ client.deleteModelInstance(modelInstanceId);
+ }
+
+ @Override
+ public void deleteEvaluation(String evaluationId) throws LensException {
+ client.deleteEvaluation(evaluationId);
+ }
+
+ @Override
+ public void deletePrediction(String predictionId) throws LensException {
+ client.deletePrediction(predictionId);
+ }
+
public LensSessionHandle getSessionHandle() {
return client.getSessionHandle();
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
index 2ccdf2a..f693cb7 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/client/LensMLJerseyClient.java
@@ -21,23 +21,19 @@
import java.util.List;
import java.util.Map;
-import javax.ws.rs.NotFoundException;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
-import javax.ws.rs.core.Form;
import javax.ws.rs.core.MediaType;
+import org.apache.lens.api.APIResult;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.StringList;
-import org.apache.lens.ml.api.ModelMetadata;
-import org.apache.lens.ml.api.TestReport;
+import org.apache.lens.ml.api.*;
+import org.apache.lens.server.api.error.LensException;
-import org.apache.hadoop.conf.Configuration;
-
-import org.glassfish.jersey.media.multipart.FormDataBodyPart;
-import org.glassfish.jersey.media.multipart.FormDataContentDisposition;
-import org.glassfish.jersey.media.multipart.FormDataMultiPart;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import lombok.extern.slf4j.Slf4j;
@@ -50,13 +46,25 @@
*/
@Slf4j
public class LensMLJerseyClient {
- /** The Constant LENS_ML_RESOURCE_PATH. */
+
+ /**
+ * The Constant LENS_ML_RESOURCE_PATH.
+ */
public static final String LENS_ML_RESOURCE_PATH = "lens.ml.resource.path";
- /** The Constant DEFAULT_ML_RESOURCE_PATH. */
+ /**
+ * The Constant DEFAULT_ML_RESOURCE_PATH.
+ */
public static final String DEFAULT_ML_RESOURCE_PATH = "ml";
- /** The connection. */
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(LensMLJerseyClient.class);
+
+ /**
+ * The connection.
+ */
private final LensConnection connection;
private final LensSessionHandle sessionHandle;
@@ -82,6 +90,14 @@
this.sessionHandle = sessionHandle;
}
+ public void close() {
+ try {
+ connection.close();
+ } catch (Exception exc) {
+ LOG.error("Error closing connection", exc);
+ }
+ }
+
protected WebTarget getMLWebTarget() {
Client client = connection.buildClient();
LensConnectionParams connParams = connection.getLensConnectionParams();
@@ -90,171 +106,9 @@
return client.target(baseURI).path(mlURI);
}
- /**
- * Gets the model metadata.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return the model metadata
- */
- public ModelMetadata getModelMetadata(String algorithm, String modelID) {
- try {
- return getMLWebTarget().path("models").path(algorithm).path(modelID).request().get(ModelMetadata.class);
- } catch (NotFoundException exc) {
- return null;
- }
- }
-
- /**
- * Delete model.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- */
- public void deleteModel(String algorithm, String modelID) {
- getMLWebTarget().path("models").path(algorithm).path(modelID).request().delete();
- }
-
- /**
- * Gets the models for algorithm.
- *
- * @param algorithm the algorithm
- * @return the models for algorithm
- */
- public List<String> getModelsForAlgorithm(String algorithm) {
- try {
- StringList models = getMLWebTarget().path("models").path(algorithm).request().get(StringList.class);
- return models == null ? null : models.getElements();
- } catch (NotFoundException exc) {
- return null;
- }
- }
-
public List<String> getAlgoNames() {
- StringList algoNames = getMLWebTarget().path("algos").request().get(StringList.class);
- return algoNames == null ? null : algoNames.getElements();
- }
-
- /**
- * Train model.
- *
- * @param algorithm the algorithm
- * @param params the params
- * @return the string
- */
- public String trainModel(String algorithm, Form params) {
- return getMLWebTarget().path(algorithm).path("train").request(MediaType.APPLICATION_JSON_TYPE)
- .post(Entity.entity(params, MediaType.APPLICATION_FORM_URLENCODED_TYPE), String.class);
- }
-
- /**
- * Test model.
- *
- * @param table the table
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param outputTable the output table name
- * @return the string
- */
- public String testModel(String table, String algorithm, String modelID, String outputTable) {
- WebTarget modelTestTarget = getMLWebTarget().path("test").path(table).path(algorithm).path(modelID);
-
- FormDataMultiPart mp = new FormDataMultiPart();
-
- LensSessionHandle sessionHandle = this.sessionHandle == null ? connection.getSessionHandle() : this.sessionHandle;
-
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("sessionid").build(), sessionHandle,
- MediaType.APPLICATION_XML_TYPE));
-
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("outputTable").build(), outputTable));
- return modelTestTarget.request().post(Entity.entity(mp, MediaType.MULTIPART_FORM_DATA_TYPE), String.class);
- }
-
- /**
- * Gets the test reports of algorithm.
- *
- * @param algorithm the algorithm
- * @return the test reports of algorithm
- */
- public List<String> getTestReportsOfAlgorithm(String algorithm) {
- try {
- StringList list = getMLWebTarget().path("reports").path(algorithm).request().get(StringList.class);
- return list == null ? null : list.getElements();
- } catch (NotFoundException exc) {
- return null;
- }
- }
-
- /**
- * Gets the test report.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the test report
- */
- public TestReport getTestReport(String algorithm, String reportID) {
- try {
- return getMLWebTarget().path("reports").path(algorithm).path(reportID).request().get(TestReport.class);
- } catch (NotFoundException exc) {
- return null;
- }
- }
-
- /**
- * Delete test report.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the string
- */
- public String deleteTestReport(String algorithm, String reportID) {
- return getMLWebTarget().path("reports").path(algorithm).path(reportID).request().delete(String.class);
- }
-
- /**
- * Predict single.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param features the features
- * @return the string
- */
- public String predictSingle(String algorithm, String modelID, Map<String, String> features) {
- WebTarget target = getMLWebTarget().path("predict").path(algorithm).path(modelID);
-
- for (Map.Entry<String, String> entry : features.entrySet()) {
- target.queryParam(entry.getKey(), entry.getValue());
- }
-
- return target.request().get(String.class);
- }
-
- /**
- * Gets the param description of algo.
- *
- * @param algorithm the algorithm
- * @return the param description of algo
- */
- public List<String> getParamDescriptionOfAlgo(String algorithm) {
- try {
- StringList paramHelp = getMLWebTarget().path("algos").path(algorithm).request(MediaType.APPLICATION_XML)
- .get(StringList.class);
- return paramHelp.getElements();
- } catch (NotFoundException exc) {
- return null;
- }
- }
-
- public Configuration getConf() {
- return connection.getLensConnectionParams().getConf();
- }
-
- public void close() {
- try {
- connection.close();
- } catch (Exception exc) {
- log.error("Error closing connection", exc);
- }
+ StringList algoNames = getMLWebTarget().path("algonames").request().get(StringList.class);
+ return algoNames.getElements() == null ? null : algoNames.getElements();
}
public LensSessionHandle getSessionHandle() {
@@ -263,4 +117,123 @@
}
return connection.getSessionHandle();
}
+
+
+ public void test() {
+
+ Evaluation result = getMLWebTarget().path("evaluation")
+ .request(MediaType.APPLICATION_XML)
+ .get(Evaluation.class);
+
+ }
+
+ public APIResult createDataSet(String dataSetName, String dataTableName, String dataBase) {
+ WebTarget target = getMLWebTarget();
+ DataSet dataSet = new DataSet(dataSetName, dataTableName, dataBase);
+ APIResult result = target.path("dataset")
+ .request(MediaType.APPLICATION_XML)
+ .post(Entity.xml(dataSet), APIResult.class);
+ return result;
+ }
+
+ public DataSet getDataSet(String dataSetName) {
+ WebTarget target = getMLWebTarget();
+ DataSet dataSet = target.path("dataset").queryParam("dataSetName", dataSetName).request(MediaType
+ .APPLICATION_XML).get(DataSet.class);
+ return dataSet;
+ }
+
+ public APIResult createModel(String name, String algo, Map<String, String> algoParams,
+ List<Feature> features, Feature label) {
+ Model model = new Model(name, new AlgoSpec(algo, algoParams), features, label);
+ WebTarget target = getMLWebTarget();
+ APIResult result =
+ target.path("models").request(MediaType.APPLICATION_XML).post(Entity.xml(model), APIResult.class);
+ return result;
+ }
+
+ public Model getModel(String modelName) {
+ WebTarget target = getMLWebTarget();
+ Model model = target.path("models").queryParam("modelName", modelName).request(MediaType.APPLICATION_XML)
+ .get(Model.class);
+ return model;
+ }
+
+ public String tranModel(String modelId, String dataSetName, LensSessionHandle lensSessionHandle) {
+ WebTarget target = getMLWebTarget();
+ return target.path("train").queryParam("modelId", modelId).queryParam("dataSetName",
+ dataSetName).queryParam("lensSessionHandle", lensSessionHandle.toString()).request(MediaType.APPLICATION_XML)
+ .get(String.class);
+ }
+
+ ModelInstance getModelInstance(String modelInstanceId) {
+ WebTarget target = getMLWebTarget();
+ return target.path("modelinstance/" + modelInstanceId).request(MediaType.APPLICATION_XML).get(ModelInstance.class);
+ }
+
+ boolean cancelModelInstance(String modelInstanceId, LensSessionHandle lensSessionHandle) {
+ WebTarget target = getMLWebTarget();
+ return target.path("modelinstance/" + modelInstanceId).queryParam("lensSessionHandle", lensSessionHandle.toString())
+ .request(MediaType.APPLICATION_XML).delete(Boolean.class);
+ }
+
+ String evaluate(String modelInstanceId, String dataSetName, LensSessionHandle lensSessionHandle) {
+ WebTarget target = getMLWebTarget();
+ return target.path("evaluate").queryParam("modelInstanceId", modelInstanceId).queryParam("dataSetName",
+ dataSetName).queryParam("lensSessionHandle", lensSessionHandle.toString()).request(MediaType.APPLICATION_XML)
+ .get(String.class);
+ }
+
+ Evaluation getEvaluation(String evalId) {
+ WebTarget target = getMLWebTarget();
+ return target.path("evaluation/" + evalId).request(MediaType.APPLICATION_XML).get(Evaluation.class);
+ }
+
+ boolean cancelEvaluation(String evalId, LensSessionHandle lensSessionHandle) {
+ WebTarget target = getMLWebTarget();
+ APIResult result = target.path("evaluation/" + evalId).queryParam("lensSessionHandle", lensSessionHandle.toString())
+ .request(MediaType.APPLICATION_XML).delete(APIResult.class);
+ return result.getStatus() == APIResult.Status.SUCCEEDED;
+ }
+
+ String predict(String modelInstanceId, String dataSetName, LensSessionHandle lensSessionHandle) {
+ WebTarget target = getMLWebTarget();
+ return target.path("predict").queryParam("modelInstanceId", modelInstanceId).queryParam("dataSetName",
+ dataSetName).queryParam("lensSessionHandle", lensSessionHandle.toString()).request(MediaType.APPLICATION_XML)
+ .get(String.class);
+ }
+
+ Prediction getPrediction(String predictionId) {
+ WebTarget target = getMLWebTarget();
+ return target.path("prediction/" + predictionId).request(MediaType.APPLICATION_XML).get(Prediction.class);
+ }
+
+ boolean cancelPrediction(String predictionId, LensSessionHandle lensSessionHandle) {
+ WebTarget target = getMLWebTarget();
+ return target.path("prediction/" + predictionId).queryParam("lensSessionHandle", lensSessionHandle.toString())
+ .request(MediaType.APPLICATION_XML).delete(Boolean.class);
+
+ }
+
+ public void deleteModel(String modelId) throws LensException {
+ WebTarget target = getMLWebTarget();
+ target.path("model/" + modelId).request(MediaType.APPLICATION_XML).delete();
+ }
+
+ public void deleteDataSet(String dataSetName) throws LensException {
+ WebTarget target = getMLWebTarget();
+ }
+
+ public void deleteModelInstance(String modelInstanceId) throws LensException {
+ WebTarget target = getMLWebTarget();
+ }
+
+ public void deleteEvaluation(String evaluationId) throws LensException {
+ WebTarget target = getMLWebTarget();
+ }
+
+ public void deletePrediction(String predictionId) throws LensException {
+ WebTarget target = getMLWebTarget();
+ }
+
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/Algorithm.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/Algorithm.java
index 29bde29..187c524 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/Algorithm.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/Algorithm.java
@@ -18,29 +18,31 @@
*/
package org.apache.lens.ml.algo.api;
-import java.lang.annotation.ElementType;
-import java.lang.annotation.Retention;
-import java.lang.annotation.RetentionPolicy;
-import java.lang.annotation.Target;
+import java.util.List;
-/**
- * The Interface Algorithm.
- */
-@Retention(RetentionPolicy.RUNTIME)
-@Target(ElementType.TYPE)
-public @interface Algorithm {
+import org.apache.lens.api.LensConf;
+import org.apache.lens.ml.api.AlgoParam;
+import org.apache.lens.ml.api.DataSet;
+import org.apache.lens.ml.api.Model;
+import org.apache.lens.server.api.error.LensException;
+
+public interface Algorithm {
+
+ String getName();
+
+ String getDescription();
+
+ List<AlgoParam> getParams();
/**
- * Name.
+ * Configure.
*
- * @return the string
+ * @param configuration the configuration
*/
- String name();
+ void configure(LensConf configuration);
- /**
- * Description.
- *
- * @return the string
- */
- String description();
+ LensConf getConf();
+
+ TrainedModel train(Model model, DataSet dataTable) throws LensException;
+
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLAlgo.java
deleted file mode 100644
index 65373c6..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLAlgo.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.algo.api;
-
-import org.apache.lens.api.LensConf;
-import org.apache.lens.server.api.error.LensException;
-
-/**
- * The Interface MLAlgo.
- */
-public interface MLAlgo {
- String getName();
-
- String getDescription();
-
- /**
- * Configure.
- *
- * @param configuration the configuration
- */
- void configure(LensConf configuration);
-
- LensConf getConf();
-
- /**
- * Train.
- *
- * @param conf the conf
- * @param db the db
- * @param table the table
- * @param modelId the model id
- * @param params the params
- * @return the ML model
- * @throws LensException the lens exception
- */
- MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException;
-}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLDriver.java
index d2a2748..9814008 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLDriver.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLDriver.java
@@ -43,7 +43,7 @@
* @return the algo instance
* @throws LensException the lens exception
*/
- MLAlgo getAlgoInstance(String algo) throws LensException;
+ Algorithm getAlgoInstance(String algo) throws LensException;
/**
* Inits the.
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLModel.java
deleted file mode 100644
index 73717ac..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/MLModel.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.algo.api;
-
-import java.io.Serializable;
-import java.util.Date;
-import java.util.List;
-
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-import lombok.Setter;
-import lombok.ToString;
-
-/**
- * Instantiates a new ML model.
- */
-@NoArgsConstructor
-@ToString
-public abstract class MLModel<PREDICTION> implements Serializable {
-
- /** The id. */
- @Getter
- @Setter
- private String id;
-
- /** The created at. */
- @Getter
- @Setter
- private Date createdAt;
-
- /** The algo name. */
- @Getter
- @Setter
- private String algoName;
-
- /** The table. */
- @Getter
- @Setter
- private String table;
-
- /** The params. */
- @Getter
- @Setter
- private List<String> params;
-
- /** The label column. */
- @Getter
- @Setter
- private String labelColumn;
-
- /** The feature columns. */
- @Getter
- @Setter
- private List<String> featureColumns;
-
- /**
- * Predict.
- *
- * @param args the args
- * @return the prediction
- */
- public abstract PREDICTION predict(Object... args);
-}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/TrainedModel.java
similarity index 61%
copy from lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java
copy to lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/TrainedModel.java
index e0d13c0..2f23912 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/TrainedModel.java
@@ -18,36 +18,13 @@
*/
package org.apache.lens.ml.algo.api;
-import java.lang.annotation.ElementType;
-import java.lang.annotation.Retention;
-import java.lang.annotation.RetentionPolicy;
-import java.lang.annotation.Target;
+import java.io.Serializable;
+import java.util.Map;
-/**
- * The Interface AlgoParam.
- */
-@Retention(RetentionPolicy.RUNTIME)
-@Target(ElementType.FIELD)
-public @interface AlgoParam {
+import org.apache.lens.server.api.error.LensException;
- /**
- * Name.
- *
- * @return the string
- */
- String name();
+public interface TrainedModel<PREDICTION> extends Serializable {
- /**
- * Help.
- *
- * @return the string
- */
- String help();
+ PREDICTION predict(Map<String, String> featureVector) throws LensException;
- /**
- * Default value.
- *
- * @return the string
- */
- String defaultValue() default "None";
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/AlgoArgParser.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/AlgoArgParser.java
index 00f20fc..2f808f2 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/AlgoArgParser.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/AlgoArgParser.java
@@ -19,51 +19,39 @@
package org.apache.lens.ml.algo.lib;
import java.lang.reflect.Field;
-import java.util.ArrayList;
import java.util.HashMap;
-import java.util.List;
import java.util.Map;
-import org.apache.lens.ml.algo.api.AlgoParam;
-import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.api.AlgoParam;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import lombok.extern.slf4j.Slf4j;
/**
- * The Class AlgoArgParser.
+ * AlgoArgParser class. Parses and sets algo params.
*/
@Slf4j
public final class AlgoArgParser {
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(AlgoArgParser.class);
+
private AlgoArgParser() {
}
/**
- * The Class CustomArgParser.
+ * Extracts all the variables annotated with @AlgoParam checks if any input key matches the name
+ * if so replaces the value.
*
- * @param <E> the element type
+ * @param algo
+ * @param algoParameters
*/
- public abstract static class CustomArgParser<E> {
-
- /**
- * Parses the.
- *
- * @param value the value
- * @return the e
- */
- public abstract E parse(String value);
- }
-
- /**
- * Extracts feature names. If the algo has any parameters associated with @AlgoParam annotation, those are set
- * as well.
- *
- * @param algo the algo
- * @param args the args
- * @return List of feature column names.
- */
- public static List<String> parseArgs(MLAlgo algo, String[] args) {
- List<String> featureColumns = new ArrayList<String>();
- Class<? extends MLAlgo> algoClass = algo.getClass();
+ public static void parseArgs(Algorithm algo, Map<String, String> algoParameters) {
+ Class<? extends Algorithm> algoClass = algo.getClass();
// Get param fields
Map<String, Field> fieldMap = new HashMap<String, Field>();
@@ -75,14 +63,11 @@
}
}
- for (int i = 0; i < args.length; i += 2) {
- String key = args[i].trim();
- String value = args[i + 1].trim();
-
+ for (Map.Entry<String, String> entry : algoParameters.entrySet()) {
+ String key = entry.getKey();
+ String value = entry.getValue();
try {
- if ("feature".equalsIgnoreCase(key)) {
- featureColumns.add(value);
- } else if (fieldMap.containsKey(key)) {
+ if (fieldMap.containsKey(key)) {
Field f = fieldMap.get(key);
if (String.class.equals(f.getType())) {
f.set(algo, value);
@@ -101,14 +86,29 @@
CustomArgParser<?> parser = clz.newInstance();
f.set(algo, parser.parse(value));
} else {
- log.warn("Ignored param " + key + "=" + value + " as no parser found");
+ LOG.warn("Ignored param " + key + "=" + value + " as no parser found");
}
}
}
} catch (Exception exc) {
- log.error("Error while setting param " + key + " to " + value + " for algo " + algo, exc);
+ LOG.error("Error while setting param " + key + " to " + value + " for algo " + algo, exc);
}
}
- return featureColumns;
+ }
+
+ /**
+ * The Class CustomArgParser.
+ *
+ * @param <E> the element type
+ */
+ public abstract static class CustomArgParser<E> {
+
+ /**
+ * Parses the.
+ *
+ * @param value the value
+ * @return the e
+ */
+ public abstract E parse(String value);
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/Algorithms.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/Algorithms.java
index ad37403..87ed8ee 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/Algorithms.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/Algorithms.java
@@ -25,7 +25,6 @@
import java.util.Map;
import org.apache.lens.ml.algo.api.Algorithm;
-import org.apache.lens.ml.algo.api.MLAlgo;
import org.apache.lens.server.api.error.LensException;
/**
@@ -33,50 +32,44 @@
*/
public class Algorithms {
- /** The algorithm classes. */
- private final Map<String, Class<? extends MLAlgo>> algorithmClasses
- = new HashMap<String, Class<? extends MLAlgo>>();
+ /**
+ * The algorithm classes.
+ */
+
+ private final Map<String, Class<? extends Algorithm>> algorithmClasses
+ = new HashMap<String, Class<? extends Algorithm>>();
/**
- * Register.
+ * Registers algorithm
*
- * @param algoClass the algo class
+ * @param name
+ * @param algoClass
*/
- public void register(Class<? extends MLAlgo> algoClass) {
- if (algoClass != null && algoClass.getAnnotation(Algorithm.class) != null) {
- algorithmClasses.put(algoClass.getAnnotation(Algorithm.class).name(), algoClass);
- } else {
- throw new IllegalArgumentException("Not a valid algorithm class: " + algoClass);
+ public void register(String name, Class<? extends Algorithm> algoClass) {
+ if (algoClass != null) {
+ algorithmClasses.put(name, algoClass);
}
}
- /**
- * Gets the algo for name.
- *
- * @param name the name
- * @return the algo for name
- * @throws LensException the lens exception
- */
- public MLAlgo getAlgoForName(String name) throws LensException {
- Class<? extends MLAlgo> algoClass = algorithmClasses.get(name);
+ public Algorithm getAlgoForName(String name) throws LensException {
+ Class<? extends Algorithm> algoClass = algorithmClasses.get(name);
+
if (algoClass == null) {
return null;
}
- Algorithm algoAnnotation = algoClass.getAnnotation(Algorithm.class);
- String description = algoAnnotation.description();
try {
- Constructor<? extends MLAlgo> algoConstructor = algoClass.getConstructor(String.class, String.class);
- return algoConstructor.newInstance(name, description);
- } catch (Exception exc) {
- throw new LensException("Unable to get algo: " + name, exc);
+ Constructor<? extends Algorithm> constructor = algoClass.getConstructor();
+ return constructor.newInstance();
+ } catch (Exception e) {
+ throw new LensException("Unable to get Algorithm " + name, e);
}
}
/**
- * Checks if is algo supported.
+ * Checks if algorithm is supported
*
- * @param name the name
- * @return true, if is algo supported
+ * @param name
+ * @return
*/
public boolean isAlgoSupported(String name) {
return algorithmClasses.containsKey(name);
@@ -85,5 +78,4 @@
public List<String> getAlgorithmNames() {
return new ArrayList<String>(algorithmClasses.keySet());
}
-
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ClassifierBaseModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ClassifierBaseModel.java
index a960a4a..6fb2101 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ClassifierBaseModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ClassifierBaseModel.java
@@ -18,13 +18,12 @@
*/
package org.apache.lens.ml.algo.lib;
-import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.algo.api.TrainedModel;
/**
- * Return a single double value as a prediction. This is useful in classifiers where the classifier returns a single
- * class label as a prediction.
+ * The Class ClassifierBaseModel
*/
-public abstract class ClassifierBaseModel extends MLModel<Double> {
+public abstract class ClassifierBaseModel implements TrainedModel<Double> {
/**
* Gets the feature vector.
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ForecastingModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ForecastingModel.java
index 16a6180..7176797 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ForecastingModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/lib/ForecastingModel.java
@@ -19,21 +19,24 @@
package org.apache.lens.ml.algo.lib;
import java.util.List;
+import java.util.Map;
-import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.algo.api.TrainedModel;
+import org.apache.lens.server.api.error.LensException;
/**
* The Class ForecastingModel.
*/
-public class ForecastingModel extends MLModel<MultiPrediction> {
+public class ForecastingModel implements TrainedModel<MultiPrediction> {
/*
* (non-Javadoc)
*
- * @see org.apache.lens.ml.MLModel#predict(java.lang.Object[])
+ * @see org.apache.lens.ml.TrainedModel#predict(java.lang.Object[])
*/
+
@Override
- public MultiPrediction predict(Object... args) {
+ public MultiPrediction predict(Map<String, String> featureVector) throws LensException {
return new ForecastingPredictions(null);
}
@@ -42,7 +45,9 @@
*/
public static class ForecastingPredictions implements MultiPrediction {
- /** The values. */
+ /**
+ * The values.
+ */
private final List<LabelledPrediction> values;
/**
@@ -65,10 +70,14 @@
*/
public static class ForecastingLabel implements LabelledPrediction<Long, Double> {
- /** The timestamp. */
+ /**
+ * The timestamp.
+ */
private final Long timestamp;
- /** The value. */
+ /**
+ * The value.
+ */
private final double value;
/**
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkAlgo.java
index 3936693..4316a40 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkAlgo.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkAlgo.java
@@ -22,58 +22,82 @@
import java.util.*;
import org.apache.lens.api.LensConf;
-import org.apache.lens.ml.algo.api.AlgoParam;
import org.apache.lens.ml.algo.api.Algorithm;
-import org.apache.lens.ml.algo.api.MLAlgo;
-import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.algo.api.TrainedModel;
+import org.apache.lens.ml.api.AlgoParam;
+import org.apache.lens.ml.api.DataSet;
+import org.apache.lens.ml.api.Feature;
+import org.apache.lens.ml.api.Model;
import org.apache.lens.server.api.error.LensException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
-import lombok.extern.slf4j.Slf4j;
-
/**
* The Class BaseSparkAlgo.
*/
-@Slf4j
-public abstract class BaseSparkAlgo implements MLAlgo {
+public abstract class BaseSparkAlgo implements Algorithm {
- /** The name. */
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(BaseSparkAlgo.class);
+
+ /**
+ * The name.
+ */
private final String name;
- /** The description. */
+ /**
+ * The description.
+ */
private final String description;
- /** The spark context. */
+ /**
+ * The spark context.
+ */
protected JavaSparkContext sparkContext;
- /** The params. */
+ /**
+ * The params.
+ */
protected Map<String, String> params;
- /** The conf. */
+ /**
+ * The conf.
+ */
protected transient LensConf conf;
- /** The training fraction. */
+ /**
+ * The training fraction.
+ */
@AlgoParam(name = "trainingFraction", help = "% of dataset to be used for training", defaultValue = "0")
protected double trainingFraction;
-
- /** The use training fraction. */
- private boolean useTrainingFraction;
-
- /** The label. */
- @AlgoParam(name = "label", help = "Name of column which is used as a training label for supervised learning")
- protected String label;
-
- /** The partition filter. */
+ /**
+ * The label.
+ */
+ @AlgoParam(name = "label", help = "column name, feature name which is used as a training label for supervised "
+ + "learning")
+ protected Feature label;
+ /**
+ * The partition filter.
+ */
@AlgoParam(name = "partition", help = "Partition filter used to create create HCatInputFormats")
protected String partitionFilter;
-
- /** The features. */
- @AlgoParam(name = "feature", help = "Column name(s) which are to be used as sample features")
- protected List<String> features;
+ /**
+ * The features.
+ */
+ @AlgoParam(name = "feature", help = "sample features containing feature name and column name")
+ protected List<Feature> features;
+ /**
+ * The use training fraction.
+ */
+ private boolean useTrainingFraction;
/**
* Instantiates a new base spark algo.
@@ -108,33 +132,38 @@
/*
* (non-Javadoc)
*
- * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
- * java.lang.String, java.lang.String[])
+ * @see org.apache.lens.ml.TrainedModel#train(Model model, String dataTable)
*/
@Override
- public MLModel<?> train(LensConf conf, String db, String table, String modelId, String... params)
- throws LensException {
- parseParams(params);
+ public TrainedModel train(Model model, DataSet dataTable) throws LensException {
+ parseParams(model.getAlgoSpec().getAlgoParams());
+ features = model.getFeatureSpec();
+ String database = dataTable.getDbName();
+ if (database.isEmpty()) {
+ if (SessionState.get() != null) {
+ database = SessionState.get().getCurrentDatabase();
+ } else {
+ database = "default";
+ }
+ }
+
TableTrainingSpec.TableTrainingSpecBuilder builder = TableTrainingSpec.newBuilder().hiveConf(toHiveConf(conf))
- .database(db).table(table).partitionFilter(partitionFilter).featureColumns(features).labelColumn(label);
-
+ .database(database).table(dataTable.getTableName()).partitionFilter(partitionFilter)
+ .featureColumns(model.getFeatureSpec())
+ .labelColumn(model.getLabelSpec());
if (useTrainingFraction) {
builder.trainingFraction(trainingFraction);
}
TableTrainingSpec spec = builder.build();
- log.info("Training with {} features", features.size());
+ LOG.info("Training " + " with " + features.size() + " features");
spec.createRDDs(sparkContext);
RDD<LabeledPoint> trainingRDD = spec.getTrainingRDD();
- BaseSparkClassificationModel<?> model = trainInternal(modelId, trainingRDD);
- model.setTable(table);
- model.setParams(Arrays.asList(params));
- model.setLabelColumn(label);
- model.setFeatureColumns(features);
- return model;
+ BaseSparkClassificationModel<?> trainedModel = trainInternal(trainingRDD);
+ return trainedModel;
}
/**
@@ -156,29 +185,13 @@
*
* @param args the args
*/
- public void parseParams(String[] args) {
- if (args.length % 2 != 0) {
- throw new IllegalArgumentException("Invalid number of params " + args.length);
- }
- params = new LinkedHashMap<String, String>();
+ public void parseParams(Map<String, String> args) {
- for (int i = 0; i < args.length; i += 2) {
- if ("f".equalsIgnoreCase(args[i]) || "feature".equalsIgnoreCase(args[i])) {
- if (features == null) {
- features = new ArrayList<String>();
- }
- features.add(args[i + 1]);
- } else if ("l".equalsIgnoreCase(args[i]) || "label".equalsIgnoreCase(args[i])) {
- label = args[i + 1];
- } else {
- params.put(args[i].replaceAll("\\-+", ""), args[i + 1]);
- }
- }
+ params = new HashMap();
- if (params.containsKey("trainingFraction")) {
- // Get training Fraction
- String trainingFractionStr = params.get("trainingFraction");
+ if (args.containsKey("trainingFraction")) {
+ String trainingFractionStr = args.get("trainingFraction");
try {
trainingFraction = Double.parseDouble(trainingFractionStr);
useTrainingFraction = true;
@@ -186,12 +199,11 @@
throw new IllegalArgumentException("Invalid training fraction", nfe);
}
}
-
- if (params.containsKey("partition") || params.containsKey("p")) {
- partitionFilter = params.containsKey("partition") ? params.get("partition") : params.get("p");
+ if (args.containsKey("partition") || args.containsKey("p")) {
+ partitionFilter = args.containsKey("partition") ? args.get("partition") : args.get("p");
}
- parseAlgoParams(params);
+ parseAlgoParams(args);
}
/**
@@ -206,7 +218,7 @@
try {
return Double.parseDouble(params.get(param));
} catch (NumberFormatException nfe) {
- log.warn("Couldn't parse param value: {} as double.", param);
+ LOG.warn("Couldn't parse param value: " + param + " as double.");
}
}
return defaultVal;
@@ -224,7 +236,7 @@
try {
return Integer.parseInt(params.get(param));
} catch (NumberFormatException nfe) {
- log.warn("Couldn't parse param value: {} as integer.", param);
+ LOG.warn("Couldn't parse param value: " + param + " as integer.");
}
}
return defaultVal;
@@ -241,12 +253,8 @@
public Map<String, String> getArgUsage() {
Map<String, String> usage = new LinkedHashMap<String, String>();
Class<?> clz = this.getClass();
- // Put class name and description as well as part of the usage
- Algorithm algorithm = clz.getAnnotation(Algorithm.class);
- if (algorithm != null) {
- usage.put("Algorithm Name", algorithm.name());
- usage.put("Algorithm Description", algorithm.description());
- }
+ usage.put("Algorithm Name", name);
+ usage.put("Algorithm Description", description);
// Get all algo params including base algo params
while (clz != null) {
@@ -275,11 +283,22 @@
/**
* Train internal.
*
- * @param modelId the model id
* @param trainingRDD the training rdd
* @return the base spark classification model
* @throws LensException the lens exception
*/
- protected abstract BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ protected abstract BaseSparkClassificationModel trainInternal(RDD<LabeledPoint> trainingRDD)
throws LensException;
+
+ @Override
+ public List<AlgoParam> getParams() {
+ ArrayList<AlgoParam> paramList = new ArrayList();
+ for (Field field : this.getClass().getDeclaredFields()) {
+ AlgoParam param = field.getAnnotation(AlgoParam.class);
+ if (param != null) {
+ paramList.add(param);
+ }
+ }
+ return paramList;
+ }
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkClassificationModel.java
index 806dc1f..945d97d 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkClassificationModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/BaseSparkClassificationModel.java
@@ -18,48 +18,56 @@
*/
package org.apache.lens.ml.algo.spark;
+import java.util.List;
+import java.util.Map;
+
import org.apache.lens.ml.algo.lib.ClassifierBaseModel;
+import org.apache.lens.ml.api.Feature;
+import org.apache.lens.server.api.error.LensException;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.linalg.Vectors;
/**
- * The Class BaseSparkClassificationModel.
+ * The class BaseSparkClassificationModel
*
- * @param <MODEL> the generic type
+ * @param <MODEL>
*/
public class BaseSparkClassificationModel<MODEL extends ClassificationModel> extends ClassifierBaseModel {
- /** The model id. */
- private final String modelId;
-
- /** The spark model. */
- private final MODEL sparkModel;
/**
- * Instantiates a new base spark classification model.
- *
- * @param modelId the model id
- * @param model the model
+ * The spark model.
*/
- public BaseSparkClassificationModel(String modelId, MODEL model) {
- this.modelId = modelId;
+ private final MODEL sparkModel;
+ private List<Feature> featureList;
+
+ /**
+ * initializes BaseSparkClassificationModel
+ *
+ * @param featureList
+ * @param model
+ */
+ public BaseSparkClassificationModel(List<Feature> featureList, MODEL model) {
this.sparkModel = model;
+ this.featureList = featureList;
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLModel#predict(java.lang.Object[])
- */
- @Override
- public Double predict(Object... args) {
- return sparkModel.predict(Vectors.dense(getFeatureVector(args)));
- }
@Override
- public String getId() {
- return modelId;
+ public Double predict(Map<String, String> featureVector) throws LensException {
+ String[] featureArray = new String[featureList.size()];
+ int i = 0;
+ for (Feature feature : featureList) {
+ String featureValue = featureVector.get(feature.getName());
+ if (featureValue == null || featureValue.isEmpty()) {
+ throw new LensException("Error while predicting: input featureVector doesn't contain all required features : "
+ + "Feature Name: " + feature.getName());
+ } else {
+ featureArray[i++] = featureVector.get(feature.getName());
+ }
+ }
+ return sparkModel.predict(Vectors.dense(getFeatureVector(featureArray)));
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/ColumnFeatureFunction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/ColumnFeatureFunction.java
index 900792e..9837565 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/ColumnFeatureFunction.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/ColumnFeatureFunction.java
@@ -24,9 +24,7 @@
import org.apache.spark.mllib.regression.LabeledPoint;
import com.google.common.base.Preconditions;
-
import lombok.extern.slf4j.Slf4j;
-
import scala.Tuple2;
/**
@@ -36,19 +34,29 @@
@Slf4j
public class ColumnFeatureFunction extends FeatureFunction {
- /** The feature value mappers. */
+ /**
+ * The feature value mappers.
+ */
private final FeatureValueMapper[] featureValueMappers;
- /** The feature positions. */
+ /**
+ * The feature positions.
+ */
private final int[] featurePositions;
- /** The label column pos. */
+ /**
+ * The label column pos.
+ */
private final int labelColumnPos;
- /** The num features. */
+ /**
+ * The num features.
+ */
private final int numFeatures;
- /** The default labeled point. */
+ /**
+ * The default labeled point.
+ */
private final LabeledPoint defaultLabeledPoint;
/**
@@ -62,11 +70,12 @@
* @param defaultLabel default lable to be used for null records
*/
public ColumnFeatureFunction(int[] featurePositions, FeatureValueMapper[] valueMappers, int labelColumnPos,
- int numFeatures, double defaultLabel) {
+ int numFeatures, double defaultLabel) {
Preconditions.checkNotNull(valueMappers, "Value mappers argument is required");
Preconditions.checkNotNull(featurePositions, "Feature positions are required");
- Preconditions.checkArgument(valueMappers.length == featurePositions.length,
- "Mismatch between number of value mappers and feature positions");
+ Preconditions.checkArgument(valueMappers.length == featurePositions.length + 1,
+ "Mismatch between number of value mappers and feature positions. There should be value mappers for features and "
+ + "one additional mapper for label");
this.featurePositions = featurePositions;
this.featureValueMappers = valueMappers;
@@ -96,7 +105,8 @@
features[i] = featureValueMappers[i].call(record.get(featurePos));
}
- double label = featureValueMappers[labelColumnPos].call(record.get(labelColumnPos));
+ //Feature mapper for label is stored after label mappers at position numFeatures.
+ double label = featureValueMappers[numFeatures].call(record.get(labelColumnPos));
return new LabeledPoint(label, Vectors.dense(features));
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/HiveTableRDD.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/HiveTableRDD.java
index fd5651e..56af76e 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/HiveTableRDD.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/HiveTableRDD.java
@@ -27,10 +27,14 @@
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import lombok.extern.slf4j.Slf4j;
+
/**
* Create a JavaRDD based on a Hive table using HCatInputFormat.
*/
+@Slf4j
public final class HiveTableRDD {
+
private HiveTableRDD() {
}
@@ -46,7 +50,9 @@
* @throws IOException Signals that an I/O exception has occurred.
*/
public static JavaPairRDD<WritableComparable, HCatRecord> createHiveTableRDD(JavaSparkContext javaSparkContext,
- Configuration conf, String db, String table, String partitionFilter) throws IOException {
+ Configuration conf, String db,
+ String table, String partitionFilter)
+ throws IOException {
HCatInputFormat.setInput(conf, db, table, partitionFilter);
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java
index 9ac62ce..d13d244 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/SparkMLDriver.java
@@ -24,10 +24,11 @@
import java.util.List;
import org.apache.lens.api.LensConf;
-import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.Algorithm;
import org.apache.lens.ml.algo.api.MLDriver;
import org.apache.lens.ml.algo.lib.Algorithms;
import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
+import org.apache.lens.ml.algo.spark.kmeans.KMeansAlgo;
import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
@@ -45,37 +46,29 @@
@Slf4j
public class SparkMLDriver implements MLDriver {
- /** The owns spark context. */
- private boolean ownsSparkContext = true;
-
/**
- * The Enum SparkMasterMode.
+ * The algorithms.
*/
- private enum SparkMasterMode {
- // Embedded mode used in tests
- /** The embedded. */
- EMBEDDED,
- // Yarn client and Yarn cluster modes are used when deploying the app to Yarn cluster
- /** The yarn client. */
- YARN_CLIENT,
-
- /** The yarn cluster. */
- YARN_CLUSTER
- }
-
- /** The algorithms. */
private final Algorithms algorithms = new Algorithms();
-
- /** The client mode. */
+ /**
+ * If the driver owns spark's context.
+ */
+ private boolean ownsSparkContext = true;
+ /**
+ * Spark's client mode.
+ */
private SparkMasterMode clientMode = SparkMasterMode.EMBEDDED;
-
- /** The is started. */
+ /**
+ * If the driver is started.
+ */
private boolean isStarted;
-
- /** The spark conf. */
+ /**
+ * The spark conf.
+ */
private SparkConf sparkConf;
-
- /** The spark context. */
+ /**
+ * The spark context.
+ */
private JavaSparkContext sparkContext;
/**
@@ -98,20 +91,22 @@
return algorithms.isAlgoSupported(name);
}
- /*
- * (non-Javadoc)
+ /**
+ * Returns Algorithm Instance for it's name
*
- * @see org.apache.lens.ml.MLDriver#getAlgoInstance(java.lang.String)
+ * @param name
+ * @return
+ * @throws LensException
*/
@Override
- public MLAlgo getAlgoInstance(String name) throws LensException {
+ public Algorithm getAlgoInstance(String name) throws LensException {
checkStarted();
if (!isAlgoSupported(name)) {
return null;
}
- MLAlgo algo = null;
+ Algorithm algo = null;
try {
algo = algorithms.getAlgoForName(name);
if (algo instanceof BaseSparkAlgo) {
@@ -127,16 +122,18 @@
* Register algos.
*/
private void registerAlgos() {
- algorithms.register(NaiveBayesAlgo.class);
- algorithms.register(SVMAlgo.class);
- algorithms.register(LogisticRegressionAlgo.class);
- algorithms.register(DecisionTreeAlgo.class);
+ algorithms.register("spark_logistic_regression", LogisticRegressionAlgo.class);
+ algorithms.register("spark_naive_bayes", NaiveBayesAlgo.class);
+ algorithms.register("spark_k_means", KMeansAlgo.class);
+ algorithms.register("spark_svm", SVMAlgo.class);
+ algorithms.register("spark_decision_tree", DecisionTreeAlgo.class);
}
- /*
- * (non-Javadoc)
+ /**
+ * Initializes the driver
*
- * @see org.apache.lens.ml.MLDriver#init(org.apache.lens.api.LensConf)
+ * @param conf the conf
+ * @throws LensException
*/
@Override
public void init(LensConf conf) throws LensException {
@@ -280,4 +277,25 @@
return sparkContext;
}
+ /**
+ * Supported algorithms.
+ */
+ private enum SparkMasterMode {
+ // Embedded mode used in tests
+ /**
+ * The embedded.
+ */
+ EMBEDDED,
+ // Yarn client and Yarn cluster modes are used when deploying the app to Yarn cluster
+ /**
+ * The yarn client.
+ */
+ YARN_CLIENT,
+
+ /**
+ * The yarn cluster.
+ */
+ YARN_CLUSTER
+ }
+
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java
index 5b7c48b..4362945 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/TableTrainingSpec.java
@@ -23,8 +23,11 @@
import java.util.ArrayList;
import java.util.List;
+import org.apache.lens.ml.api.Feature;
import org.apache.lens.server.api.error.LensException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hive.hcatalog.data.HCatRecord;
@@ -39,69 +42,86 @@
import org.apache.spark.rdd.RDD;
import com.google.common.base.Preconditions;
-
import lombok.Getter;
import lombok.ToString;
-import lombok.extern.slf4j.Slf4j;
/**
* The Class TableTrainingSpec.
*/
@ToString
-@Slf4j
public class TableTrainingSpec implements Serializable {
- /** The training rdd. */
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(TableTrainingSpec.class);
+ /**
+ * The label pos.
+ */
+ int labelPos;
+ /**
+ * The feature positions.
+ */
+ int[] featurePositions;
+ /**
+ * The num features.
+ */
+ int numFeatures;
+ /**
+ * The labeled rdd.
+ */
+ transient JavaRDD<LabeledPoint> labeledRDD;
+ /**
+ * The training rdd.
+ */
@Getter
private transient RDD<LabeledPoint> trainingRDD;
-
- /** The testing rdd. */
+ /**
+ * The testing rdd.
+ */
@Getter
private transient RDD<LabeledPoint> testingRDD;
-
- /** The database. */
+ /**
+ * The database.
+ */
@Getter
private String database;
-
- /** The table. */
+ /**
+ * The table.
+ */
@Getter
private String table;
- /** The partition filter. */
+ // By default all samples are considered for training
+ /**
+ * The partition filter.
+ */
@Getter
private String partitionFilter;
-
- /** The feature columns. */
+ /**
+ * The feature columns.
+ */
@Getter
- private List<String> featureColumns;
-
- /** The label column. */
+ private List<Feature> featureColumns;
+ /**
+ * The label column.
+ */
@Getter
- private String labelColumn;
-
- /** The conf. */
+ private Feature labelColumn;
+ /**
+ * The conf.
+ */
@Getter
private transient HiveConf conf;
-
- // By default all samples are considered for training
- /** The split training. */
+ /**
+ * The split training.
+ */
private boolean splitTraining;
-
- /** The training fraction. */
+ /**
+ * The training fraction.
+ */
private double trainingFraction = 1.0;
- /** The label pos. */
- int labelPos;
-
- /** The feature positions. */
- int[] featurePositions;
-
- /** The num features. */
- int numFeatures;
-
- /** The labeled rdd. */
- transient JavaRDD<LabeledPoint> labeledRDD;
-
/**
* New builder.
*
@@ -112,11 +132,141 @@
}
/**
+ * Validate.
+ *
+ * @return true, if successful
+ */
+ boolean validate() {
+ List<HCatFieldSchema> columns;
+ try {
+ HCatInputFormat.setInput(conf, database == null ? "default" : database, table, partitionFilter);
+ HCatSchema tableSchema = HCatInputFormat.getTableSchema(conf);
+ columns = tableSchema.getFields();
+ } catch (IOException exc) {
+ LOG.error("Error getting table info " + toString(), exc);
+ return false;
+ }
+
+ LOG.info(table + " columns " + columns.toString());
+
+ boolean valid = false;
+ if (columns != null && !columns.isEmpty()) {
+ // Check labeled column
+ List<String> columnNamesInTable = new ArrayList<String>();
+ List<String> columnNamesInFeatureList = new ArrayList<String>();
+ for (HCatFieldSchema col : columns) {
+ columnNamesInTable.add(col.getName());
+ }
+
+ for (Feature feature : featureColumns) {
+ columnNamesInFeatureList.add(feature.getDataColumn());
+ }
+
+ String labelColumnName = labelColumn.getDataColumn();
+
+ // Need at least one feature column and one label column
+ valid = columnNamesInTable.contains(labelColumnName) && columnNamesInTable.size() > 1;
+
+ if (valid) {
+ labelPos = columnNamesInTable.indexOf(labelColumnName);
+
+ // Check feature columns
+ if (featureColumns == null || featureColumns.isEmpty()) {
+ // feature columns are not provided, so all columns except label column are feature columns
+ featurePositions = new int[columnNamesInTable.size() - 1];
+ int p = 0;
+ for (int i = 0; i < columnNamesInTable.size(); i++) {
+ if (i == labelPos) {
+ continue;
+ }
+ featurePositions[p++] = i;
+ }
+
+ columnNamesInTable.remove(labelPos);
+ featureColumns = new ArrayList<Feature>();
+ for (String featureName : columnNamesInTable) {
+ featureColumns.add(new Feature(featureName, null, null, featureName));
+ }
+ } else {
+ // Feature columns were provided, verify all feature columns are present in the table
+ valid = columnNamesInTable.containsAll(columnNamesInFeatureList);
+ if (valid) {
+ // Get feature positions
+ featurePositions = new int[columnNamesInFeatureList.size()];
+ for (int i = 0; i < columnNamesInFeatureList.size(); i++) {
+ featurePositions[i] = columnNamesInTable.indexOf(featureColumns.get(i).getDataColumn());
+ }
+ }
+ }
+ numFeatures = featureColumns.size();
+ }
+ }
+
+ return valid;
+ }
+
+ /**
+ * Creates the RDDs.
+ *
+ * @param sparkContext the spark context
+ * @throws LensException the lens exception
+ */
+ public void createRDDs(JavaSparkContext sparkContext) throws LensException {
+ // Validate the spec
+ if (!validate()) {
+ throw new LensException("Table spec not valid: " + toString());
+ }
+
+ LOG.info("Creating RDDs with spec " + toString());
+
+ // Get the RDD for table
+ JavaPairRDD<WritableComparable, HCatRecord> tableRDD;
+ try {
+ tableRDD = HiveTableRDD.createHiveTableRDD(sparkContext, conf, database, table, partitionFilter);
+ } catch (IOException e) {
+ throw new LensException(e);
+ }
+
+ // Map into trainable RDD
+ // TODO: Figure out a way to use custom value mappers
+ FeatureValueMapper[] valueMappers = new FeatureValueMapper[numFeatures + 1];
+ final DoubleValueMapper doubleMapper = new DoubleValueMapper();
+ for (int i = 0; i < numFeatures; i++) {
+ valueMappers[i] = doubleMapper;
+ }
+ valueMappers[numFeatures] = doubleMapper; //label mapper
+ ColumnFeatureFunction trainPrepFunction = new ColumnFeatureFunction(featurePositions, valueMappers, labelPos,
+ numFeatures, 0);
+ labeledRDD = tableRDD.map(trainPrepFunction);
+
+ if (splitTraining) {
+ // We have to split the RDD between a training RDD and a testing RDD
+ LOG.info("Splitting RDD for table " + database + "." + table + " with split fraction " + trainingFraction);
+ JavaRDD<DataSample> sampledRDD = labeledRDD.map(new Function<LabeledPoint, DataSample>() {
+ @Override
+ public DataSample call(LabeledPoint v1) throws Exception {
+ return new DataSample(v1);
+ }
+ });
+
+ trainingRDD = sampledRDD.filter(new TrainingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
+ testingRDD = sampledRDD.filter(new TestingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
+ } else {
+ LOG.info("Using same RDD for train and test");
+ trainingRDD = labeledRDD.rdd();
+ testingRDD = trainingRDD;
+ }
+ LOG.info("Generated RDDs");
+ }
+
+ /**
* The Class TableTrainingSpecBuilder.
*/
public static class TableTrainingSpecBuilder {
- /** The spec. */
+ /**
+ * The spec.
+ */
final TableTrainingSpec spec;
/**
@@ -176,7 +326,7 @@
* @param labelColumn the label column
* @return the table training spec builder
*/
- public TableTrainingSpecBuilder labelColumn(String labelColumn) {
+ public TableTrainingSpecBuilder labelColumn(Feature labelColumn) {
spec.labelColumn = labelColumn;
return this;
}
@@ -187,7 +337,7 @@
* @param featureColumns the feature columns
* @return the table training spec builder
*/
- public TableTrainingSpecBuilder featureColumns(List<String> featureColumns) {
+ public TableTrainingSpecBuilder featureColumns(List<Feature> featureColumns) {
spec.featureColumns = featureColumns;
return this;
}
@@ -221,10 +371,14 @@
*/
public static class DataSample implements Serializable {
- /** The labeled point. */
+ /**
+ * The labeled point.
+ */
private final LabeledPoint labeledPoint;
- /** The sample. */
+ /**
+ * The sample.
+ */
private final double sample;
/**
@@ -243,7 +397,9 @@
*/
public static class TrainingFilter implements Function<DataSample, Boolean> {
- /** The training fraction. */
+ /**
+ * The training fraction.
+ */
private double trainingFraction;
/**
@@ -271,7 +427,9 @@
*/
public static class TestingFilter implements Function<DataSample, Boolean> {
- /** The training fraction. */
+ /**
+ * The training fraction.
+ */
private double trainingFraction;
/**
@@ -310,122 +468,4 @@
}
}
- /**
- * Validate.
- *
- * @return true, if successful
- */
- boolean validate() {
- List<HCatFieldSchema> columns;
- try {
- HCatInputFormat.setInput(conf, database == null ? "default" : database, table, partitionFilter);
- HCatSchema tableSchema = HCatInputFormat.getTableSchema(conf);
- columns = tableSchema.getFields();
- } catch (IOException exc) {
- log.error("Error getting table info {}", toString(), exc);
- return false;
- }
-
- log.info("{} columns {}", table, columns.toString());
-
- boolean valid = false;
- if (columns != null && !columns.isEmpty()) {
- // Check labeled column
- List<String> columnNames = new ArrayList<String>();
- for (HCatFieldSchema col : columns) {
- columnNames.add(col.getName());
- }
-
- // Need at least one feature column and one label column
- valid = columnNames.contains(labelColumn) && columnNames.size() > 1;
-
- if (valid) {
- labelPos = columnNames.indexOf(labelColumn);
-
- // Check feature columns
- if (featureColumns == null || featureColumns.isEmpty()) {
- // feature columns are not provided, so all columns except label column are feature columns
- featurePositions = new int[columnNames.size() - 1];
- int p = 0;
- for (int i = 0; i < columnNames.size(); i++) {
- if (i == labelPos) {
- continue;
- }
- featurePositions[p++] = i;
- }
-
- columnNames.remove(labelPos);
- featureColumns = columnNames;
- } else {
- // Feature columns were provided, verify all feature columns are present in the table
- valid = columnNames.containsAll(featureColumns);
- if (valid) {
- // Get feature positions
- featurePositions = new int[featureColumns.size()];
- for (int i = 0; i < featureColumns.size(); i++) {
- featurePositions[i] = columnNames.indexOf(featureColumns.get(i));
- }
- }
- }
- numFeatures = featureColumns.size();
- }
- }
-
- return valid;
- }
-
- /**
- * Creates the rd ds.
- *
- * @param sparkContext the spark context
- * @throws LensException the lens exception
- */
- public void createRDDs(JavaSparkContext sparkContext) throws LensException {
- // Validate the spec
- if (!validate()) {
- throw new LensException("Table spec not valid: " + toString());
- }
-
- log.info("Creating RDDs with spec {}", toString());
-
- // Get the RDD for table
- JavaPairRDD<WritableComparable, HCatRecord> tableRDD;
- try {
- tableRDD = HiveTableRDD.createHiveTableRDD(sparkContext, conf, database, table, partitionFilter);
- } catch (IOException e) {
- throw new LensException(e);
- }
-
- // Map into trainable RDD
- // TODO: Figure out a way to use custom value mappers
- FeatureValueMapper[] valueMappers = new FeatureValueMapper[numFeatures];
- final DoubleValueMapper doubleMapper = new DoubleValueMapper();
- for (int i = 0; i < numFeatures; i++) {
- valueMappers[i] = doubleMapper;
- }
-
- ColumnFeatureFunction trainPrepFunction = new ColumnFeatureFunction(featurePositions, valueMappers, labelPos,
- numFeatures, 0);
- labeledRDD = tableRDD.map(trainPrepFunction);
-
- if (splitTraining) {
- // We have to split the RDD between a training RDD and a testing RDD
- log.info("Splitting RDD for table {}.{} with split fraction {}", database, table, trainingFraction);
- JavaRDD<DataSample> sampledRDD = labeledRDD.map(new Function<LabeledPoint, DataSample>() {
- @Override
- public DataSample call(LabeledPoint v1) throws Exception {
- return new DataSample(v1);
- }
- });
-
- trainingRDD = sampledRDD.filter(new TrainingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
- testingRDD = sampledRDD.filter(new TestingFilter(trainingFraction)).map(new GetLabeledPoint()).rdd();
- } else {
- log.info("Using same RDD for train and test");
- trainingRDD = labeledRDD.rdd();
- testingRDD = trainingRDD;
- }
- log.info("Generated RDDs");
- }
-
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java
index 7810a9a..4dd9f50 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeAlgo.java
@@ -20,10 +20,9 @@
import java.util.Map;
-import org.apache.lens.ml.algo.api.AlgoParam;
-import org.apache.lens.ml.algo.api.Algorithm;
import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+import org.apache.lens.ml.api.AlgoParam;
import org.apache.lens.server.api.error.LensException;
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -41,19 +40,27 @@
/**
* The Class DecisionTreeAlgo.
*/
-@Algorithm(name = "spark_decision_tree", description = "Spark Decision Tree classifier algo")
public class DecisionTreeAlgo extends BaseSparkAlgo {
- /** The algo. */
+ static final String DESCRIPTION = "Spark decision tree algo";
+ static final String NAME = "spark_decision_tree";
+
+ /**
+ * The algo.
+ */
@AlgoParam(name = "algo", help = "Decision tree algorithm. Allowed values are 'classification' and 'regression'")
private Enumeration.Value algo;
- /** The decision tree impurity. */
+ /**
+ * The decision tree impurity.
+ */
@AlgoParam(name = "impurity", help = "Impurity measure used by the decision tree. "
+ "Allowed values are 'gini', 'entropy' and 'variance'")
private Impurity decisionTreeImpurity;
- /** The max depth. */
+ /**
+ * The max depth.
+ */
@AlgoParam(name = "maxDepth", help = "Max depth of the decision tree. Integer values expected.",
defaultValue = "100")
private int maxDepth;
@@ -68,6 +75,10 @@
super(name, description);
}
+ public DecisionTreeAlgo() {
+ super(NAME, DESCRIPTION);
+ }
+
/*
* (non-Javadoc)
*
@@ -100,9 +111,10 @@
* @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
*/
@Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ protected BaseSparkClassificationModel trainInternal(RDD<LabeledPoint> trainingRDD)
throws LensException {
DecisionTreeModel model = DecisionTree$.MODULE$.train(trainingRDD, algo, decisionTreeImpurity, maxDepth);
- return new DecisionTreeClassificationModel(modelId, new SparkDecisionTreeModel(model));
+ return new DecisionTreeClassificationModel(features, new SparkDecisionTreeModel(model));
}
+
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java
index 27c32f4..2797f9a 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/DecisionTreeClassificationModel.java
@@ -18,7 +18,10 @@
*/
package org.apache.lens.ml.algo.spark.dt;
+import java.util.List;
+
import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+import org.apache.lens.ml.api.Feature;
/**
* The Class DecisionTreeClassificationModel.
@@ -26,12 +29,10 @@
public class DecisionTreeClassificationModel extends BaseSparkClassificationModel<SparkDecisionTreeModel> {
/**
- * Instantiates a new decision tree classification model.
- *
- * @param modelId the model id
- * @param model the model
+ * @param featureList
+ * @param model
*/
- public DecisionTreeClassificationModel(String modelId, SparkDecisionTreeModel model) {
- super(modelId, model);
+ public DecisionTreeClassificationModel(List<Feature> featureList, SparkDecisionTreeModel model) {
+ super(featureList, model);
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java
index e561a8d..c3082c9 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/dt/SparkDecisionTreeModel.java
@@ -31,7 +31,9 @@
*/
public class SparkDecisionTreeModel implements ClassificationModel {
- /** The model. */
+ /**
+ * The model.
+ */
private final DecisionTreeModel model;
/**
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java
index be9af18..3c5e287 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansAlgo.java
@@ -18,15 +18,18 @@
*/
package org.apache.lens.ml.algo.spark.kmeans;
+import java.util.ArrayList;
import java.util.List;
import org.apache.lens.api.LensConf;
-import org.apache.lens.ml.algo.api.AlgoParam;
import org.apache.lens.ml.algo.api.Algorithm;
-import org.apache.lens.ml.algo.api.MLAlgo;
-import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.algo.api.TrainedModel;
import org.apache.lens.ml.algo.lib.AlgoArgParser;
import org.apache.lens.ml.algo.spark.HiveTableRDD;
+import org.apache.lens.ml.api.AlgoParam;
+import org.apache.lens.ml.api.DataSet;
+import org.apache.lens.ml.api.Feature;
+import org.apache.lens.ml.api.Model;
import org.apache.lens.server.api.error.LensException;
import org.apache.hadoop.hive.conf.HiveConf;
@@ -46,54 +49,71 @@
import scala.Tuple2;
+
/**
* The Class KMeansAlgo.
*/
-@Algorithm(name = "spark_kmeans_algo", description = "Spark MLLib KMeans algo")
-public class KMeansAlgo implements MLAlgo {
+public class KMeansAlgo implements Algorithm {
- /** The conf. */
+ static final String description = "Spark K means algo";
+ static final String name = "spark_k_means";
+
+ /**
+ * The conf.
+ */
private transient LensConf conf;
- /** The spark context. */
+ /**
+ * The spark context.
+ */
private JavaSparkContext sparkContext;
- /** The part filter. */
+ /**
+ * The part filter.
+ */
@AlgoParam(name = "partition", help = "Partition filter to be used while constructing table RDD")
private String partFilter = null;
- /** The k. */
+ /**
+ * The k.
+ */
@AlgoParam(name = "k", help = "Number of cluster")
private int k;
- /** The max iterations. */
+ /**
+ * The max iterations.
+ */
@AlgoParam(name = "maxIterations", help = "Maximum number of iterations", defaultValue = "100")
private int maxIterations = 100;
- /** The runs. */
+ /**
+ * The runs.
+ */
@AlgoParam(name = "runs", help = "Number of parallel run", defaultValue = "1")
private int runs = 1;
- /** The initialization mode. */
+ /**
+ * The initialization mode.
+ */
@AlgoParam(name = "initializationMode",
help = "initialization model, either \"random\" or \"k-means||\" (default).", defaultValue = "k-means||")
private String initializationMode = "k-means||";
@Override
public String getName() {
- return getClass().getAnnotation(Algorithm.class).name();
+ return name;
}
@Override
public String getDescription() {
- return getClass().getAnnotation(Algorithm.class).description();
+ return description;
}
/*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
- */
+ * (non-Javadoc)
+ *
+ * @see org.apache.lens.ml.MLAlgo#configure(org.apache.lens.api.LensConf)
+ */
@Override
public void configure(LensConf configuration) {
this.conf = configuration;
@@ -104,51 +124,6 @@
return conf;
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLAlgo#train(org.apache.lens.api.LensConf, java.lang.String, java.lang.String,
- * java.lang.String, java.lang.String[])
- */
- @Override
- public MLModel train(LensConf conf, String db, String table, String modelId, String... params) throws LensException {
- List<String> features = AlgoArgParser.parseArgs(this, params);
- final int[] featurePositions = new int[features.size()];
- final int NUM_FEATURES = features.size();
-
- JavaPairRDD<WritableComparable, HCatRecord> rdd = null;
- try {
- // Map feature names to positions
- Table tbl = Hive.get(toHiveConf(conf)).getTable(db, table);
- List<FieldSchema> allCols = tbl.getAllCols();
- int f = 0;
- for (int i = 0; i < tbl.getAllCols().size(); i++) {
- String colName = allCols.get(i).getName();
- if (features.contains(colName)) {
- featurePositions[f++] = i;
- }
- }
-
- rdd = HiveTableRDD.createHiveTableRDD(sparkContext, toHiveConf(conf), db, table, partFilter);
- JavaRDD<Vector> trainableRDD = rdd.map(new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() {
- @Override
- public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
- HCatRecord hCatRecord = v1._2();
- double[] arr = new double[NUM_FEATURES];
- for (int i = 0; i < NUM_FEATURES; i++) {
- Object val = hCatRecord.get(featurePositions[i]);
- arr[i] = val == null ? 0d : (Double) val;
- }
- return Vectors.dense(arr);
- }
- });
-
- KMeansModel model = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
- return new KMeansClusteringModel(modelId, model);
- } catch (Exception e) {
- throw new LensException("KMeans algo failed for " + db + "." + table, e);
- }
- }
/**
* To hive conf.
@@ -163,4 +138,57 @@
}
return hiveConf;
}
+
+ @Override
+ public List<AlgoParam> getParams() {
+ return null;
+ }
+
+ @Override
+ public TrainedModel train(Model model, DataSet trainingDataSet) throws LensException {
+ AlgoArgParser.parseArgs(this, model.getAlgoSpec().getAlgoParams());
+ List<String> features = new ArrayList<>();
+ for (Feature feature : model.getFeatureSpec()) {
+ features.add(feature.getDataColumn());
+ }
+
+ final int[] featurePositions = new int[features.size()];
+ final int NUM_FEATURES = features.size();
+
+ JavaPairRDD<WritableComparable, HCatRecord> rdd = null;
+
+ try {
+ // Map feature names to positions
+ Table tbl = Hive.get(toHiveConf(conf)).getTable(trainingDataSet.getDbName(), trainingDataSet.getTableName());
+ List<FieldSchema> allCols = tbl.getAllCols();
+ int f = 0;
+ for (int i = 0; i < tbl.getAllCols().size(); i++) {
+ String colName = allCols.get(i).getName();
+ if (features.contains(colName)) {
+ featurePositions[f++] = i;
+ }
+ }
+
+ rdd = HiveTableRDD.createHiveTableRDD(sparkContext, toHiveConf(conf), trainingDataSet.getDbName(),
+ trainingDataSet.getTableName(), partFilter);
+ JavaRDD<Vector> trainableRDD = rdd.map(new Function<Tuple2<WritableComparable, HCatRecord>, Vector>() {
+ @Override
+ public Vector call(Tuple2<WritableComparable, HCatRecord> v1) throws Exception {
+ HCatRecord hCatRecord = v1._2();
+ double[] arr = new double[NUM_FEATURES];
+ for (int i = 0; i < NUM_FEATURES; i++) {
+ Object val = hCatRecord.get(featurePositions[i]);
+ arr[i] = val == null ? 0d : (Double) val;
+ }
+ return Vectors.dense(arr);
+ }
+ });
+
+ KMeansModel kMeansModel = KMeans.train(trainableRDD.rdd(), k, maxIterations, runs, initializationMode);
+ return new KMeansClusteringModel(model.getFeatureSpec(), kMeansModel);
+ } catch (Exception e) {
+ throw new LensException(
+ "KMeans algo failed for " + trainingDataSet.getDbName() + "." + trainingDataSet.getTableName(), e);
+ }
+ }
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java
index 62dc536..6341268 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/kmeans/KMeansClusteringModel.java
@@ -18,7 +18,12 @@
*/
package org.apache.lens.ml.algo.spark.kmeans;
-import org.apache.lens.ml.algo.api.MLModel;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lens.ml.algo.api.TrainedModel;
+import org.apache.lens.ml.api.Feature;
+import org.apache.lens.server.api.error.LensException;
import org.apache.spark.mllib.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vectors;
@@ -26,42 +31,39 @@
/**
* The Class KMeansClusteringModel.
*/
-public class KMeansClusteringModel extends MLModel<Integer> {
-
- /** The model. */
- private final KMeansModel model;
-
- /** The model id. */
- private final String modelId;
+public class KMeansClusteringModel implements TrainedModel<Integer> {
/**
- * Instantiates a new k means clustering model.
- *
- * @param modelId the model id
- * @param model the model
+ * The model.
*/
- public KMeansClusteringModel(String modelId, KMeansModel model) {
- this.model = model;
- this.modelId = modelId;
+ private final KMeansModel kMeansModel;
+
+ private List<Feature> featureList;
+
+ /**
+ * Instantiates a new k means clustering model.
+ *
+ * @param model the model
+ */
+ public KMeansClusteringModel(List<Feature> featureList, KMeansModel model) {
+ this.kMeansModel = model;
+ this.featureList = featureList;
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.MLModel#predict(java.lang.Object[])
- */
+
@Override
- public Integer predict(Object... args) {
- // Convert the params to array of double
- double[] arr = new double[args.length];
- for (int i = 0; i < args.length; i++) {
- if (args[i] != null) {
- arr[i] = (Double) args[i];
+ public Integer predict(Map<String, String> featureVector) throws LensException {
+ double[] featureArray = new double[featureList.size()];
+ int i = 0;
+ for (Feature feature : featureList) {
+ String featureValue = featureVector.get(feature.getName());
+ if (featureValue == null || featureValue.isEmpty()) {
+ throw new LensException("Error while predicting: input featureVector doesn't contain all required features : "
+ + "Feature Name: " + feature.getName());
} else {
- arr[i] = 0d;
+ featureArray[i++] = Double.parseDouble(featureValue);
}
}
-
- return model.predict(Vectors.dense(arr));
+ return kMeansModel.predict(Vectors.dense(featureArray));
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java
index c2f97af..7e21871 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogisticRegressionAlgo.java
@@ -20,10 +20,9 @@
import java.util.Map;
-import org.apache.lens.ml.algo.api.AlgoParam;
-import org.apache.lens.ml.algo.api.Algorithm;
import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+import org.apache.lens.ml.api.AlgoParam;
import org.apache.lens.server.api.error.LensException;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
@@ -34,18 +33,24 @@
/**
* The Class LogisticRegressionAlgo.
*/
-@Algorithm(name = "spark_logistic_regression", description = "Spark logistic regression algo")
public class LogisticRegressionAlgo extends BaseSparkAlgo {
- /** The iterations. */
+ static final String DESCRIPTION = "Spark logistic regression algo";
+ static final String NAME = "spark_logistic_regression";
+
+ /**
+ * The iterations.
+ */
@AlgoParam(name = "iterations", help = "Max number of iterations", defaultValue = "100")
private int iterations;
-
- /** The step size. */
+ /**
+ * The step size.
+ */
@AlgoParam(name = "stepSize", help = "Step size", defaultValue = "1.0d")
private double stepSize;
-
- /** The min batch fraction. */
+ /**
+ * The min batch fraction.
+ */
@AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
private double minBatchFraction;
@@ -59,6 +64,14 @@
super(name, description);
}
+ public LogisticRegressionAlgo(String name) {
+ super(name, DESCRIPTION);
+ }
+
+ public LogisticRegressionAlgo() {
+ super(NAME, DESCRIPTION);
+ }
+
/*
* (non-Javadoc)
*
@@ -77,10 +90,11 @@
* @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
*/
@Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ protected BaseSparkClassificationModel trainInternal(RDD<LabeledPoint> trainingRDD)
throws LensException {
+
LogisticRegressionModel lrModel = LogisticRegressionWithSGD.train(trainingRDD, iterations, stepSize,
minBatchFraction);
- return new LogitRegressionClassificationModel(modelId, lrModel);
+ return new LogitRegressionClassificationModel(features, lrModel);
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java
index a4206e5..d6e07a5 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/lr/LogitRegressionClassificationModel.java
@@ -18,7 +18,10 @@
*/
package org.apache.lens.ml.algo.spark.lr;
+import java.util.List;
+
import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+import org.apache.lens.ml.api.Feature;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
@@ -28,12 +31,10 @@
public class LogitRegressionClassificationModel extends BaseSparkClassificationModel<LogisticRegressionModel> {
/**
- * Instantiates a new logit regression classification model.
- *
- * @param modelId the model id
- * @param model the model
+ * @param featureList
+ * @param model
*/
- public LogitRegressionClassificationModel(String modelId, LogisticRegressionModel model) {
- super(modelId, model);
+ public LogitRegressionClassificationModel(List<Feature> featureList, LogisticRegressionModel model) {
+ super(featureList, model);
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java
index c484dfe..e9709d3 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesAlgo.java
@@ -20,23 +20,27 @@
import java.util.Map;
-import org.apache.lens.ml.algo.api.AlgoParam;
-import org.apache.lens.ml.algo.api.Algorithm;
import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+import org.apache.lens.ml.api.AlgoParam;
import org.apache.lens.server.api.error.LensException;
import org.apache.spark.mllib.classification.NaiveBayes;
+import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
/**
* The Class NaiveBayesAlgo.
*/
-@Algorithm(name = "spark_naive_bayes", description = "Spark Naive Bayes classifier algo")
public class NaiveBayesAlgo extends BaseSparkAlgo {
- /** The lambda. */
+ static final String DESCRIPTION = "Spark naive bayes algo";
+ static final String NAME = "spark_naive_bayes";
+
+ /**
+ * The lambda.
+ */
@AlgoParam(name = "lambda", help = "Lambda parameter for naive bayes learner", defaultValue = "1.0d")
private double lambda = 1.0;
@@ -50,6 +54,10 @@
super(name, description);
}
+ public NaiveBayesAlgo() {
+ super(NAME, DESCRIPTION);
+ }
+
/*
* (non-Javadoc)
*
@@ -66,8 +74,9 @@
* @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
*/
@Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ protected BaseSparkClassificationModel trainInternal(RDD<LabeledPoint> trainingRDD)
throws LensException {
- return new NaiveBayesClassificationModel(modelId, NaiveBayes.train(trainingRDD, lambda));
+ NaiveBayesModel naiveBayesModel = NaiveBayes.train(trainingRDD, lambda);
+ return new NaiveBayesClassificationModel(features, naiveBayesModel);
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java
index 26d39df..edbcba4 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/nb/NaiveBayesClassificationModel.java
@@ -18,7 +18,10 @@
*/
package org.apache.lens.ml.algo.spark.nb;
+import java.util.List;
+
import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+import org.apache.lens.ml.api.Feature;
import org.apache.spark.mllib.classification.NaiveBayesModel;
@@ -28,12 +31,10 @@
public class NaiveBayesClassificationModel extends BaseSparkClassificationModel<NaiveBayesModel> {
/**
- * Instantiates a new naive bayes classification model.
- *
- * @param modelId the model id
- * @param model the model
+ * @param featureList
+ * @param model
*/
- public NaiveBayesClassificationModel(String modelId, NaiveBayesModel model) {
- super(modelId, model);
+ public NaiveBayesClassificationModel(List<Feature> featureList, NaiveBayesModel model) {
+ super(featureList, model);
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java
index 4b14d66..8e3083d 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMAlgo.java
@@ -20,10 +20,9 @@
import java.util.Map;
-import org.apache.lens.ml.algo.api.AlgoParam;
-import org.apache.lens.ml.algo.api.Algorithm;
import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+import org.apache.lens.ml.api.AlgoParam;
import org.apache.lens.server.api.error.LensException;
import org.apache.spark.mllib.classification.SVMModel;
@@ -34,22 +33,31 @@
/**
* The Class SVMAlgo.
*/
-@Algorithm(name = "spark_svm", description = "Spark SVML classifier algo")
public class SVMAlgo extends BaseSparkAlgo {
- /** The min batch fraction. */
+ static final String DESCRIPTION = "Spark SVM algo";
+ static final String NAME = "spark_svm";
+ /**
+ * The min batch fraction.
+ */
@AlgoParam(name = "minBatchFraction", help = "Fraction for batched learning", defaultValue = "1.0d")
private double minBatchFraction;
- /** The reg param. */
+ /**
+ * The reg param.
+ */
@AlgoParam(name = "regParam", help = "regularization parameter for gradient descent", defaultValue = "1.0d")
private double regParam;
- /** The step size. */
+ /**
+ * The step size.
+ */
@AlgoParam(name = "stepSize", help = "Iteration step size", defaultValue = "1.0d")
private double stepSize;
- /** The iterations. */
+ /**
+ * The iterations.
+ */
@AlgoParam(name = "iterations", help = "Number of iterations", defaultValue = "100")
private int iterations;
@@ -63,6 +71,10 @@
super(name, description);
}
+ public SVMAlgo() {
+ super(NAME, DESCRIPTION);
+ }
+
/*
* (non-Javadoc)
*
@@ -82,9 +94,9 @@
* @see org.apache.lens.ml.spark.algos.BaseSparkAlgo#trainInternal(java.lang.String, org.apache.spark.rdd.RDD)
*/
@Override
- protected BaseSparkClassificationModel trainInternal(String modelId, RDD<LabeledPoint> trainingRDD)
+ protected BaseSparkClassificationModel trainInternal(RDD<LabeledPoint> trainingRDD)
throws LensException {
SVMModel svmModel = SVMWithSGD.train(trainingRDD, iterations, stepSize, regParam, minBatchFraction);
- return new SVMClassificationModel(modelId, svmModel);
+ return new SVMClassificationModel(features, svmModel);
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java
index 433c0f9..d793d4a 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/spark/svm/SVMClassificationModel.java
@@ -18,7 +18,10 @@
*/
package org.apache.lens.ml.algo.spark.svm;
+import java.util.List;
+
import org.apache.lens.ml.algo.spark.BaseSparkClassificationModel;
+import org.apache.lens.ml.api.Feature;
import org.apache.spark.mllib.classification.SVMModel;
@@ -28,12 +31,10 @@
public class SVMClassificationModel extends BaseSparkClassificationModel<SVMModel> {
/**
- * Instantiates a new SVM classification model.
- *
- * @param modelId the model id
- * @param model the model
+ * @param featureList
+ * @param model
*/
- public SVMClassificationModel(String modelId, SVMModel model) {
- super(modelId, model);
+ public SVMClassificationModel(List<Feature> featureList, SVMModel model) {
+ super(featureList, model);
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Algo.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Algo.java
new file mode 100644
index 0000000..3986cf7
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Algo.java
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.util.List;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Setter;
+
+@AllArgsConstructor
+public class Algo {
+
+ @Getter
+ @Setter
+ String name;
+
+ @Getter
+ @Setter
+ String description;
+
+ @Getter
+ @Setter
+ List<AlgoParam> params;
+
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/AlgoParam.java
similarity index 86%
rename from lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java
rename to lens-ml-lib/src/main/java/org/apache/lens/ml/api/AlgoParam.java
index e0d13c0..f3f4d68 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/algo/api/AlgoParam.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/AlgoParam.java
@@ -16,13 +16,15 @@
* specific language governing permissions and limitations
* under the License.
*/
-package org.apache.lens.ml.algo.api;
+package org.apache.lens.ml.api;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
+import javax.xml.bind.annotation.XmlElement;
+
/**
* The Interface AlgoParam.
*/
@@ -35,19 +37,21 @@
*
* @return the string
*/
- String name();
+ @XmlElement String name();
/**
* Help.
*
* @return the string
*/
- String help();
+ @XmlElement String help();
/**
* Default value.
*
* @return the string
*/
- String defaultValue() default "None";
+ @XmlElement String defaultValue() default "None";
+
+
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/AlgoParameter.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/AlgoParameter.java
new file mode 100644
index 0000000..a913e43
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/AlgoParameter.java
@@ -0,0 +1,98 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.io.StringWriter;
+import java.lang.annotation.Annotation;
+
+import javax.xml.bind.JAXBContext;
+import javax.xml.bind.JAXBException;
+import javax.xml.bind.Marshaller;
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/**
+ * A concrete, JAXB-serializable implementation of the {@link AlgoParam} annotation interface.
+ */
+@XmlRootElement
+@AllArgsConstructor
+@NoArgsConstructor
+public class AlgoParameter implements AlgoParam {
+
+
+ private static final JAXBContext JAXB_CONTEXT;
+
+ static {
+ try {
+ JAXB_CONTEXT = JAXBContext.newInstance(AlgoParameter.class);
+ } catch (JAXBException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Getter
+ @Setter
+ @XmlElement
+ String name;
+ @Getter
+ @Setter
+ @XmlElement
+ String help;
+ @Getter
+ @Setter
+ @XmlElement
+ String defaultValue;
+
+ @Override
+ public String name() {
+ return name;
+ }
+
+ @Override
+ public String help() {
+ return help;
+ }
+
+ @Override
+ public String defaultValue() {
+ return defaultValue;
+ }
+
+ @Override
+ public Class<? extends Annotation> annotationType() {
+ return AlgoParam.class; // Annotation contract requires the annotation type, never null
+ }
+
+ @Override
+ public String toString() {
+ try {
+ StringWriter stringWriter = new StringWriter();
+ Marshaller marshaller = JAXB_CONTEXT.createMarshaller();
+ marshaller.marshal(this, stringWriter);
+ return stringWriter.toString();
+ } catch (JAXBException e) {
+ return e.getMessage();
+ }
+ }
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/AlgoSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/AlgoSpec.java
new file mode 100644
index 0000000..76e384f
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/AlgoSpec.java
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.util.Map;
+
+import javax.xml.bind.annotation.*;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/**
+ * The AlgoSpec class. Represents a particular instance of an Algorithm run, since it
+ * contains the exact algoParams with which the Algorithm was run.
+ */
+@AllArgsConstructor
+@NoArgsConstructor
+@XmlRootElement
+@XmlAccessorType(XmlAccessType.FIELD)
+public class AlgoSpec {
+ @Getter
+ @Setter
+ @XmlElement
+ private String algo;
+ @Getter
+ @Setter
+ @XmlElementWrapper
+ private Map<String, String> algoParams;
+
+ /*private static final JAXBContext JAXB_CONTEXT;
+
+ /*static {
+ try {
+ JAXB_CONTEXT = JAXBContext.newInstance(AlgoSpec.class);
+ } catch (JAXBException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public String toString() {
+ try {
+ StringWriter stringWriter = new StringWriter();
+ Marshaller marshaller = JAXB_CONTEXT.createMarshaller();
+ marshaller.marshal(this, stringWriter);
+ return stringWriter.toString();
+ } catch (JAXBException e) {
+ return e.getMessage();
+ }
+ }*/
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/DataSet.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/DataSet.java
new file mode 100644
index 0000000..b5dd7b9
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/DataSet.java
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import javax.xml.bind.annotation.XmlAccessType;
+import javax.xml.bind.annotation.XmlAccessorType;
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/**
+ * Contains meta data for a data set. A data set is identified by name, tableName and database.
+ */
+@AllArgsConstructor
+@NoArgsConstructor
+@XmlRootElement
+@XmlAccessorType(XmlAccessType.FIELD)
+public class DataSet {
+ @Getter
+ @Setter
+ @XmlElement
+ private String dsName;
+
+ @Getter
+ @Setter
+ @XmlElement
+ private String tableName;
+
+ @Getter
+ @Setter
+ @XmlElement
+ private String dbName;
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Evaluation.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Evaluation.java
new file mode 100644
index 0000000..e273b5a
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Evaluation.java
@@ -0,0 +1,80 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.util.Date;
+
+import javax.xml.bind.annotation.XmlAccessType;
+import javax.xml.bind.annotation.XmlAccessorType;
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import org.apache.lens.api.LensSessionHandle;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/**
+ * Contains meta data for an Evaluation. Evaluation captures metadata of the process of evaluating data contained in
+ * inputDataSetName against modelInstanceId.
+ */
+@NoArgsConstructor
+@AllArgsConstructor
+@XmlRootElement
+@XmlAccessorType(XmlAccessType.FIELD)
+public class Evaluation implements MLProcess {
+
+ @XmlElement
+ @Getter
+ @Setter
+ String id;
+
+ @XmlElement
+ @Getter
+ @Setter
+ Date startTime;
+
+ @XmlElement
+ @Getter
+ @Setter
+ Date finishTime;
+
+ @XmlElement
+ @Getter
+ @Setter
+ Status status;
+
+ @Getter
+ @Setter
+ @XmlElement
+ LensSessionHandle lensSessionHandle;
+
+ @Getter
+ @Setter
+ @XmlElement
+ String modelInstanceId;
+
+ @Getter
+ @Setter
+ @XmlElement
+ String inputDataSetName;
+
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Feature.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Feature.java
new file mode 100644
index 0000000..97e931d
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Feature.java
@@ -0,0 +1,68 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.io.Serializable;
+
+import javax.xml.bind.annotation.XmlAccessType;
+import javax.xml.bind.annotation.XmlAccessorType;
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/**
+ * Feature class. Equivalent of a feature of a Machine Learning model.
+ */
+
+@AllArgsConstructor
+@XmlRootElement
+@NoArgsConstructor
+@XmlAccessorType(XmlAccessType.FIELD)
+public class Feature implements Serializable {
+
+ @Getter
+ @Setter
+ @XmlElement
+ private String name;
+ @Getter
+ @Setter
+ @XmlElement
+ private String description;
+ @Getter
+ @Setter
+ @XmlElement
+ private Type type;
+ /**
+ * Name of the column of the table to which this feature is mapped.
+ */
+ @Getter
+ @Setter
+ @XmlElement
+ private String dataColumn;
+
+ @XmlRootElement
+ public enum Type {
+ Categorical,
+ Continuous
+ }
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java
index 23b5437..697c2d2 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/LensML.java
@@ -22,140 +22,187 @@
import java.util.Map;
import org.apache.lens.api.LensSessionHandle;
-import org.apache.lens.ml.algo.api.MLAlgo;
-import org.apache.lens.ml.algo.api.MLModel;
import org.apache.lens.server.api.error.LensException;
/**
- * Lens's machine learning interface used by client code as well as Lens ML service.
+ * LensML interface to train and evaluate Machine learning models.
*/
public interface LensML {
- /** Name of ML service */
+ /**
+ * Name of ML service
+ */
String NAME = "ml";
/**
- * Get list of available machine learning algorithms
+ * @return the list of supported Algos.
+ */
+ List<Algo> getAlgos();
+
+ /**
+ * Get Algo given an algo name.
*
+ * @param name of the Algo
+ * @return Algo definition
+ */
+ Algo getAlgo(String name) throws LensException;
+
+ /**
+ * Creates a data set from an existing table
+ *
+ * @param name the name of the data set
+ * @param dataTable the data table
+ * @throws LensException the lens exception
*/
- List<String> getAlgorithms();
+ void createDataSet(String name, String dataTable, String dataBase) throws LensException;
+
+ void createDataSet(DataSet dataSet) throws LensException;
/**
- * Get user friendly information about parameters accepted by the algorithm.
+ * Creates a data set from a query
*
- * @param algorithm the algorithm
- * @return map of param key to its help message
+ * @param name the name of the data set
+ * @param query query
+ * @return
*/
- Map<String, String> getAlgoParamDescription(String algorithm);
+ String createDataSetFromQuery(String name, String query);
/**
- * Get a algo object instance which could be used to generate a model of the given algorithm.
+ * Returns the data set given the name
*
- * @param algorithm the algorithm
- * @return the algo for name
- * @throws LensException the lens exception
+ * @param name the name of the data set
+ * @return
*/
- MLAlgo getAlgoForName(String algorithm) throws LensException;
+ DataSet getDataSet(String name) throws LensException;
+
+ void deleteDataSet(String dataSetName) throws LensException;
/**
- * Create a model using the given HCatalog table as input. The arguments should contain information needeed to
- * generate the model.
+ * Creates a Model with a chosen Algo and its parameters, feature list and the label.
*
- * @param table the table
- * @param algorithm the algorithm
- * @param args the args
- * @return Unique ID of the model created after training is complete
- * @throws LensException the lens exception
+ * @param name the name of Model
+ * @param algo the name of the Algo
+ * @param algoParams the algo parameters
+ * @param features list of features
+ * @param label the label to use
+ * @throws LensException if the model cannot be created
*/
- String train(String table, String algorithm, String[] args) throws LensException;
+ void createModel(String name, String algo, Map<String, String> algoParams,
+ List<Feature> features, Feature label, LensSessionHandle lensSessionHandle) throws LensException;
+
+ void createModel(Model model) throws LensException;
/**
- * Get model IDs for the given algorithm.
+ * Get Model given a modelId
*
- * @param algorithm the algorithm
- * @return the models
- * @throws LensException the lens exception
- */
- List<String> getModels(String algorithm) throws LensException;
-
- /**
- * Get a model instance given the algorithm name and model ID.
- *
- * @param algorithm the algorithm
- * @param modelId the model id
+ * @param modelId the id of the model
* @return the model
- * @throws LensException the lens exception
*/
- MLModel getModel(String algorithm, String modelId) throws LensException;
+ Model getModel(String modelId) throws LensException;
+
+ void deleteModel(String modelId) throws LensException;
/**
- * Get the FS location where model instance is saved.
+ * Train a model. This calls returns immediately after triggering the training
+ * asynchronously, the readiness of the ModelInstance should be checked based on its status.
*
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return the model path
+ * @param modelId the model id
+ * @param dataSetName data set name to use
+ * @return ModelInstance id the handle to the ModelInstance instance
*/
- String getModelPath(String algorithm, String modelID);
+ String trainModel(String modelId, String dataSetName, LensSessionHandle lensSessionHandle) throws LensException;
/**
- * Evaluate model by running it against test data contained in the given table.
+ * Get Trained Model
*
- * @param session the session
- * @param table the table
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return Test report object containing test output table, and various evaluation metrics
- * @throws LensException the lens exception
+ * @param modelInstanceId the id of the ModelInstance
+ * @return Trained model
*/
- MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
- String outputTable) throws LensException;
+ ModelInstance getModelInstance(String modelInstanceId) throws LensException;
/**
- * Get test reports for an algorithm.
+ * Cancels the creation of modelInstance.
*
- * @param algorithm the algorithm
- * @return the test reports
- * @throws LensException the lens exception
+ * @param modelInstanceId
+ * @return true on successful cancellation false otherwise.
*/
- List<String> getTestReports(String algorithm) throws LensException;
+ boolean cancelModelInstance(String modelInstanceId, LensSessionHandle lensSessionHandle) throws LensException;
+
+ void deleteModelInstance(String modelInstanceId) throws LensException;
+
/**
- * Get a test report by ID.
+ * Get the list of ModelInstance for a given model
*
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the test report
- * @throws LensException the lens exception
+ * @param modelId the model id
+ * @return List of trained models
*/
- MLTestReport getTestReport(String algorithm, String reportID) throws LensException;
+ List<ModelInstance> getAllModelInstances(String modelId);
/**
- * Online predict call given a model ID, algorithm name and sample feature values.
+ * Evaluate a ModelInstance. This calls returns immediately after triggering the training
+ * asynchronously, the readiness of the Evaluation should be checked based on its status.
*
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param features the features
- * @return prediction result
- * @throws LensException the lens exception
+ * @param modelInstanceId the trained model id
+ * @param dataSetName the data to use to evaluate
+ * @return the evaluationId
*/
- Object predict(String algorithm, String modelID, Object[] features) throws LensException;
+ String evaluate(String modelInstanceId, String dataSetName, LensSessionHandle lensSessionHandle) throws LensException;
/**
- * Permanently delete a model instance.
+ * Get Evaluation
*
- * @param algorithm the algorithm
- * @param modelID the model id
- * @throws LensException the lens exception
+ * @param evalId
+ * @return the evaluation
*/
- void deleteModel(String algorithm, String modelID) throws LensException;
+ Evaluation getEvaluation(String evalId) throws LensException;
+
+ void deleteEvaluation(String evaluationId) throws LensException;
/**
- * Permanently delete a test report instance.
+ * Cancels the Evaluation
*
- * @param algorithm the algorithm
- * @param reportID the report id
- * @throws LensException the lens exception
+ * @param evalId
+ * @return true on successful cancellation false otherwise.
*/
- void deleteTestReport(String algorithm, String reportID) throws LensException;
+ boolean cancelEvaluation(String evalId, LensSessionHandle lensSessionHandle) throws LensException;
+
+ /**
+ * Batch predicts for a given data set. This calls returns immediately after triggering the prediction
+ * asynchronously, the readiness of the prediction should be checked based on its status.
+ *
+ * @param modelInstanceId
+ * @param dataSetName
+ * @return prediction id
+ */
+ String predict(String modelInstanceId, String dataSetName, LensSessionHandle lensSessionHandle) throws LensException;
+
+ /**
+ * Get BatchPrediction information
+ *
+ * @param predictionId
+ * @return
+ */
+ Prediction getPrediction(String predictionId) throws LensException;
+
+ void deletePrediction(String predictionId) throws LensException;
+
+
+ /**
+ * Cancels the Prediction
+ *
+ * @param predictionId
+ * @return true on successful cancellation false otherwise.
+ */
+ boolean cancelPrediction(String predictionId, LensSessionHandle lensSessionHandle) throws LensException;
+
+ /**
+ * Predict for a given feature vector
+ *
+ * @param modelInstanceId
+ * @param featureVector the key is feature name.
+ * @return predicted value
+ */
+ String predict(String modelInstanceId, Map<String, String> featureVector) throws LensException;
+
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLConfConstants.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLConfConstants.java
new file mode 100644
index 0000000..d69508c
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLConfConstants.java
@@ -0,0 +1,90 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import org.apache.lens.server.api.error.LensException;
+
+/**
+ * The class MLConfConstants.
+ */
+public class MLConfConstants {
+
+ public static final String SERVER_PFX = "lens.server.ml.";
+ /**
+ * Minimum thread pool size for worker threads performing ML processes, i.e. running training,
+ * prediction or evaluation.
+ */
+ public static final int DEFAULT_EXECUTOR_POOL_MIN_THREADS = 3;
+ public static final String EXECUTOR_POOL_MIN_THREADS = SERVER_PFX + "executor.pool.min.threads";
+ /**
+ * Maximum thread pool size for worker threads performing ML processes, i.e. running training,
+ * prediction or evaluation.
+ */
+ public static final int DEFAULT_EXECUTOR_POOL_MAX_THREADS = 100;
+ public static final String EXECUTOR_POOL_MAX_THREADS = SERVER_PFX + "executor.pool.max.threads";
+ /**
+ * keep alive time for threads in the MLProcess thread pool
+ */
+ public static final int DEFAULT_CREATOR_POOL_KEEP_ALIVE_MILLIS = 60000;
+
+
+
+
+
+ /**
+ * This is the time an MLProcess will stay in the cache after it has finished, after which
+ * requests for that MLProcess will be served from the meta store.
+ */
+ public static final long DEFAULT_ML_PROCESS_CACHE_LIFE = 1000 * 60 * 10; //10 min.
+ public static final String ML_PROCESS_CACHE_LIFE = SERVER_PFX + "mlprocess.cache.life";
+ /**
+ * This is the maximum time allowed for an MLProcess to run, after which it will be killed by the
+ * MLProcessPurger thread.
+ */
+ public static final long DEFAULT_ML_PROCESS_MAX_LIFE = 1000 * 60 * 60 * 10; // 10 hours.
+ public static final String ML_PROCESS_MAX_LIFE = SERVER_PFX + "mlprocess.max.life";
+ /**
+ * The Constant UDF_NAME.
+ */
+ public static final String UDF_NAME = "predict";
+ /**
+ * prefix for output table which gets created for any prediction.
+ */
+ public static final String PREDICTION_OUTPUT_TABLE_PREFIX = "prediction_";
+ /**
+ * prefix for output table which gets created for any evaluation.
+ */
+ public static final String EVALUATION_OUTPUT_TABLE_PREFIX = "evaluation_";
+ public static final String ML_META_STORE_DB_DRIVER_NAME = ""; // FIXME(review): empty config key — all these "" keys collide; define real keys under SERVER_PFX
+ public static final String DEFAULT_ML_META_STORE_DB_DRIVER_NAME = "com.mysql.jdbc.Driver";
+ public static final String ML_META_STORE_DB_JDBC_URL = ""; // FIXME(review): empty config key
+ public static final String DEFAULT_ML_META_STORE_DB_JDBC_URL = "jdbc:mysql://localhost/lens_ml";
+ public static final String ML_META_STORE_DB_JDBC_USER = ""; // FIXME(review): empty config key
+ public static final String DEFAULT_ML_META_STORE_DB_USER = "root";
+ public static final String ML_META_STORE_DB_JDBC_PASS = ""; // FIXME(review): empty config key
+ public static final String DEFAULT_ML_META_STORE_DB_PASS = "";
+ public static final String ML_META_STORE_DB_VALIDATION_QUERY = ""; // FIXME(review): empty config key
+ public static final String DEFAULT_ML_META_STORE_DB_VALIDATION_QUERY = "select 1 from datasets";
+ public static final String ML_META_STORE_DB_SIZE = ""; // FIXME(review): empty config key
+ public static final int DEFAULT_ML_META_STORE_DB_SIZE = 10;
+
+ private MLConfConstants() throws LensException {
+ throw new LensException("Can't instantiate MLConfConstants");
+ }
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLProcess.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLProcess.java
new file mode 100644
index 0000000..8f76435
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLProcess.java
@@ -0,0 +1,50 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.util.Date;
+
+import org.apache.lens.api.LensSessionHandle;
+
+/**
+ * Interface MLProcess for Process which go through ML LifeCycle of Submitting, Polling and Completion.
+ */
+
+public interface MLProcess {
+
+ String getId();
+
+ void setId(String id);
+
+ Date getStartTime();
+
+ void setStartTime(Date time);
+
+ Date getFinishTime();
+
+ void setFinishTime(Date time);
+
+ Status getStatus();
+
+ void setStatus(Status status);
+
+ LensSessionHandle getLensSessionHandle();
+
+ void setLensSessionHandle(LensSessionHandle lensSessionHandle);
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java
deleted file mode 100644
index 965161a..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/MLTestReport.java
+++ /dev/null
@@ -1,95 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.api;
-
-import java.io.Serializable;
-import java.util.List;
-
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-import lombok.Setter;
-import lombok.ToString;
-
-/**
- * Instantiates a new ML test report.
- */
-@NoArgsConstructor
-@ToString
-public class MLTestReport implements Serializable {
-
- /** The test table. */
- @Getter
- @Setter
- private String testTable;
-
- /** The output table. */
- @Getter
- @Setter
- private String outputTable;
-
- /** The output column. */
- @Getter
- @Setter
- private String outputColumn;
-
- /** The label column. */
- @Getter
- @Setter
- private String labelColumn;
-
- /** The feature columns. */
- @Getter
- @Setter
- private List<String> featureColumns;
-
- /** The algorithm. */
- @Getter
- @Setter
- private String algorithm;
-
- /** The model id. */
- @Getter
- @Setter
- private String modelID;
-
- /** The report id. */
- @Getter
- @Setter
- private String reportID;
-
- /** The query id. */
- @Getter
- @Setter
- private String queryID;
-
- /** The test output path. */
- @Getter
- @Setter
- private String testOutputPath;
-
- /** The prediction result column. */
- @Getter
- @Setter
- private String predictionResultColumn;
-
- /** The lens query id. */
- @Getter
- @Setter
- private String lensQueryID;
-}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Model.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Model.java
new file mode 100644
index 0000000..2a873ef
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Model.java
@@ -0,0 +1,82 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.util.List;
+
+import javax.xml.bind.annotation.XmlAccessType;
+import javax.xml.bind.annotation.XmlAccessorType;
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/**
+ * The Model class. Contains meta data for model creation: the algorithm to use, the list of features and the label.
+ * This doesn't contain the actual data for training the model (which is separate and stored in the model instance class).
+ */
+@AllArgsConstructor
+@NoArgsConstructor
+@XmlRootElement
+@XmlAccessorType(XmlAccessType.FIELD)
+public class Model {
+
+ @Getter
+ @Setter
+ @XmlElement
+ private String name;
+
+ @Getter
+ @Setter
+ private AlgoSpec algoSpec;
+
+ @Getter
+ @Setter
+ @XmlElement
+ private List<Feature> featureSpec;
+
+ @Getter
+ @Setter
+ private Feature labelSpec;
+
+ /*private static final JAXBContext JAXB_CONTEXT;
+
+ static {
+ try {
+ JAXB_CONTEXT = JAXBContext.newInstance(Model.class);
+ } catch (JAXBException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public String toString() {
+ try {
+ StringWriter stringWriter = new StringWriter();
+ Marshaller marshaller = JAXB_CONTEXT.createMarshaller();
+ marshaller.marshal(this, stringWriter);
+ return stringWriter.toString();
+ } catch (JAXBException e) {
+ return e.getMessage();
+ }
+ }*/
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelInstance.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelInstance.java
new file mode 100644
index 0000000..a3f25ca
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelInstance.java
@@ -0,0 +1,91 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.util.Date;
+
+import javax.xml.bind.annotation.XmlAccessType;
+import javax.xml.bind.annotation.XmlAccessorType;
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import org.apache.lens.api.LensSessionHandle;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/**
+ * The Model Instance class. Contains meta data for a TrainedModel, i.e. information about the algorithm used for
+ * training and the specification of features and label. A model instance captures meta data for the process of
+ * training the Model modelId using data from dataSet.
+ */
+@AllArgsConstructor
+@XmlRootElement
+@NoArgsConstructor
+@XmlAccessorType(XmlAccessType.FIELD)
+public class ModelInstance implements MLProcess {
+
+ @XmlElement
+ @Getter
+ @Setter
+ String id;
+
+ @XmlElement
+ @Getter
+ @Setter
+ Date startTime;
+
+ @XmlElement
+ @Getter
+ @Setter
+ Date finishTime;
+
+ @XmlElement
+ @Getter
+ @Setter
+ Status status;
+
+ @Getter
+ @Setter
+ @XmlElement
+ LensSessionHandle lensSessionHandle;
+
+ @Getter
+ @Setter
+ @XmlElement
+ String modelId;
+
+ @Getter
+ @Setter
+ @XmlElement
+ String dataSetName;
+
+ @Getter
+ @Setter
+ @XmlElement
+ String path;
+
+ @Getter
+ @Setter
+ @XmlElement
+ String defaultEvaluationId;
+
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java
deleted file mode 100644
index 3f7dff1..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/ModelMetadata.java
+++ /dev/null
@@ -1,118 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.api;
-
-import javax.xml.bind.annotation.XmlElement;
-import javax.xml.bind.annotation.XmlRootElement;
-
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-
-/**
- * The Class ModelMetadata.
- */
-@XmlRootElement
-/**
- * Instantiates a new model metadata.
- *
- * @param modelID
- * the model id
- * @param table
- * the table
- * @param algorithm
- * the algorithm
- * @param params
- * the params
- * @param createdAt
- * the created at
- * @param modelPath
- * the model path
- * @param labelColumn
- * the label column
- * @param features
- * the features
- */
-@AllArgsConstructor
-/**
- * Instantiates a new model metadata.
- */
-@NoArgsConstructor
-public class ModelMetadata {
-
- /** The model id. */
- @XmlElement
- @Getter
- private String modelID;
-
- /** The table. */
- @XmlElement
- @Getter
- private String table;
-
- /** The algorithm. */
- @XmlElement
- @Getter
- private String algorithm;
-
- /** The params. */
- @XmlElement
- @Getter
- private String params;
-
- /** The created at. */
- @XmlElement
- @Getter
- private String createdAt;
-
- /** The model path. */
- @XmlElement
- @Getter
- private String modelPath;
-
- /** The label column. */
- @XmlElement
- @Getter
- private String labelColumn;
-
- /** The features. */
- @XmlElement
- @Getter
- private String features;
-
- /*
- * (non-Javadoc)
- *
- * @see java.lang.Object#toString()
- */
- @Override
- public String toString() {
- StringBuilder builder = new StringBuilder();
-
- builder.append("Algorithm: ").append(algorithm).append('\n');
- builder.append("Model ID: ").append(modelID).append('\n');
- builder.append("Training table: ").append(table).append('\n');
- builder.append("Features: ").append(features).append('\n');
- builder.append("Labelled Column: ").append(labelColumn).append('\n');
- builder.append("Training params: ").append(params).append('\n');
- builder.append("Created on: ").append(createdAt).append('\n');
- builder.append("Model saved at: ").append(modelPath).append('\n');
- return builder.toString();
- }
-}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Prediction.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Prediction.java
new file mode 100644
index 0000000..16c433d
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Prediction.java
@@ -0,0 +1,88 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import java.util.Date;
+
+import javax.xml.bind.annotation.XmlAccessType;
+import javax.xml.bind.annotation.XmlAccessorType;
+import javax.xml.bind.annotation.XmlElement;
+import javax.xml.bind.annotation.XmlRootElement;
+
+import org.apache.lens.api.LensSessionHandle;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.NoArgsConstructor;
+import lombok.Setter;
+
+/*
+ * Batch prediction Instance
+ */
+
+/**
+ * Contains meta data for a Prediction. A Prediction captures the meta data of the process of batch predicting data of
+ * inputDataSet against the ModelInstance modelInstanceId and populating results in outputDataSet.
+ */
+@AllArgsConstructor
+@XmlRootElement
+@XmlAccessorType(XmlAccessType.FIELD)
+@NoArgsConstructor
+public class Prediction implements MLProcess {
+ @XmlElement
+ @Getter
+ @Setter
+ String id;
+
+ @XmlElement
+ @Getter
+ @Setter
+ Date startTime;
+
+ @XmlElement
+ @Getter
+ @Setter
+ Date finishTime;
+
+ @XmlElement
+ @Getter
+ @Setter
+ Status status;
+
+ @Getter
+ @Setter
+ @XmlElement
+ LensSessionHandle lensSessionHandle;
+
+ @Getter
+ @Setter
+ @XmlElement
+ String modelInstanceId;
+
+ @Getter
+ @Setter
+ @XmlElement
+ String inputDataSet;
+
+ @Getter
+ @Setter
+ @XmlElement
+ String outputDataSet;
+
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Status.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Status.java
new file mode 100644
index 0000000..8436e28
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/Status.java
@@ -0,0 +1,34 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.api;
+
+import javax.xml.bind.annotation.XmlRootElement;
+
+/**
+ * Status of the batch jobs for training, evaluation and prediction
+ */
+
+@XmlRootElement
+public enum Status {
+ SUBMITTED,
+ RUNNING,
+ FAILED,
+ CANCELLED,
+ COMPLETED
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java
deleted file mode 100644
index 294fef3..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/api/TestReport.java
+++ /dev/null
@@ -1,125 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.api;
-
-import javax.xml.bind.annotation.XmlElement;
-import javax.xml.bind.annotation.XmlRootElement;
-
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import lombok.NoArgsConstructor;
-
-/**
- * The Class TestReport.
- */
-@XmlRootElement
-/**
- * Instantiates a new test report.
- *
- * @param testTable
- * the test table
- * @param outputTable
- * the output table
- * @param outputColumn
- * the output column
- * @param labelColumn
- * the label column
- * @param featureColumns
- * the feature columns
- * @param algorithm
- * the algorithm
- * @param modelID
- * the model id
- * @param reportID
- * the report id
- * @param queryID
- * the query id
- */
-@AllArgsConstructor
-/**
- * Instantiates a new test report.
- */
-@NoArgsConstructor
-public class TestReport {
-
- /** The test table. */
- @XmlElement
- @Getter
- private String testTable;
-
- /** The output table. */
- @XmlElement
- @Getter
- private String outputTable;
-
- /** The output column. */
- @XmlElement
- @Getter
- private String outputColumn;
-
- /** The label column. */
- @XmlElement
- @Getter
- private String labelColumn;
-
- /** The feature columns. */
- @XmlElement
- @Getter
- private String featureColumns;
-
- /** The algorithm. */
- @XmlElement
- @Getter
- private String algorithm;
-
- /** The model id. */
- @XmlElement
- @Getter
- private String modelID;
-
- /** The report id. */
- @XmlElement
- @Getter
- private String reportID;
-
- /** The query id. */
- @XmlElement
- @Getter
- private String queryID;
-
- /*
- * (non-Javadoc)
- *
- * @see java.lang.Object#toString()
- */
- @Override
- public String toString() {
- StringBuilder builder = new StringBuilder();
- builder.append("Input test table: ").append(testTable).append('\n');
- builder.append("Algorithm: ").append(algorithm).append('\n');
- builder.append("Report id: ").append(reportID).append('\n');
- builder.append("Model id: ").append(modelID).append('\n');
- builder.append("Lens Query id: ").append(queryID).append('\n');
- builder.append("Feature columns: ").append(featureColumns).append('\n');
- builder.append("Labelled column: ").append(labelColumn).append('\n');
- builder.append("Predicted column: ").append(outputColumn).append('\n');
- builder.append("Test output table: ").append(outputTable).append('\n');
- return builder.toString();
- }
-}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
deleted file mode 100644
index d444a32..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MLDBUtils.java
+++ /dev/null
@@ -1,105 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.dao;
-
-import org.apache.lens.ml.algo.api.MLModel;
-import org.apache.lens.ml.api.MLTestReport;
-import org.apache.lens.ml.impl.MLTask;
-
-public class MLDBUtils {
-
- /**
- * Create table to store test report data
- */
- public void createTestReportTable() {
-
- }
-
- /**
- * Create table to store ML task workflow data
- */
- public void createMLTaskTable() {
-
- }
-
- /**
- * Create table to save ML Models
- */
- public void createMLModelTable() {
-
- }
-
- /**
- * Insert an ML Task into ml task table
- *
- * @param task
- */
- public void saveMLTask(MLTask task) {
-
- }
-
- /**
- * Get ML Task given its id
- *
- * @param taskID
- * @return
- */
- public MLTask getMLTask(String taskID) {
- return null;
- }
-
- /**
- * Insert test report into test report table
- *
- * @param testReport
- */
- public void saveTestReport(MLTestReport testReport) {
-
- }
-
- /**
- * Get test report given its ID
- *
- * @param testReportID
- * @return
- */
- public MLTestReport getTestReport(String testReportID) {
- return null;
- }
-
- /**
- * Insert model metadata into model table
- *
- * @param mlModel
- */
- public void saveMLModel(MLModel<?> mlModel) {
-
- }
-
- /**
- * Get model metadata given ID
- *
- * @param modelID
- * @return
- */
- public MLModel<?> getMLModel(String modelID) {
- return null;
- }
-
-}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MetaStoreClient.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MetaStoreClient.java
new file mode 100644
index 0000000..fb5cd84
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MetaStoreClient.java
@@ -0,0 +1,232 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.dao;
+
+import java.sql.SQLException;
+import java.util.Date;
+import java.util.List;
+
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.ml.api.*;
+import org.apache.lens.server.api.error.LensException;
+
+public interface MetaStoreClient {
+
+ void init();
+
+ /**
+ * @param dataSet
+ * @throws SQLException
+ */
+
+ void createDataSet(DataSet dataSet) throws SQLException;
+
+ /**
+ * Gets a data set by name.
+ *
+ * @param name
+ * @return
+ */
+ DataSet getDataSet(String name) throws SQLException;
+
+ void deleteDataSet(String dataSetName) throws SQLException;
+
+ /**
+ * @param name
+ * @param algo
+ * @param algoSpec
+ * @param features
+ * @param label
+ *
+ * @throws LensException
+ */
+ void createModel(String name, String algo, AlgoSpec algoSpec,
+ List<Feature> features, Feature label) throws LensException;
+
+
+ void createModel(Model model) throws LensException;
+
+ /**
+ * Retrieves Model
+ *
+ * @param modelId
+ * @return
+ * @throws LensException If model not present or meta store error
+ */
+ Model getModel(String modelId) throws LensException;
+
+ void deleteModel(String modelId) throws SQLException;
+
+ /**
+ * creates model instance
+ *
+ * @param startTime
+ * @param finishTime
+ * @param status
+ * @param lensSessionHandle
+ * @param modelId
+ * @param dataSet
+ * @param path
+ * @param evaluationId
+ * @return
+ * @throws LensException
+ */
+ String createModelInstance(Date startTime, Date finishTime, Status status, LensSessionHandle lensSessionHandle,
+ String modelId, String dataSet, String path, String evaluationId)
+ throws LensException;
+
+ /**
+ * Updates the model instance
+ *
+ * @param modelInstance
+ * @throws LensException If modelInstance not already present in DB or meta store error.
+ */
+ void updateModelInstance(ModelInstance modelInstance) throws LensException;
+
+ void deleteModelInstance(String modelInstanceId) throws SQLException;
+
+ /**
+ * Return list of all ModelInstances in meta store having Status other than COMPLETED, FAILED, CANCELLED
+ *
+ * @return
+ * @throws LensException On meta store error.
+ */
+ List<ModelInstance> getIncompleteModelInstances() throws LensException;
+
+ /**
+ * Return list of all Evaluation in meta store having Status other than COMPLETED, FAILED, CANCELLED
+ *
+ * @return
+ * @throws LensException On meta store error.
+ */
+ List<Evaluation> getIncompleteEvaluations() throws LensException;
+
+ /**
+ * Return list of all Prediction in meta store having Status other than COMPLETED, FAILED, CANCELLED
+ *
+ * @return
+ * @throws LensException On meta store error.
+ */
+ List<Prediction> getIncompletePredictions() throws LensException;
+
+ /**
+ * Returns all ModelInstances for modelId
+ *
+ * @param modelId
+ * @return
+ * @throws LensException If modelId is not present or meta store error.
+ */
+ List<ModelInstance> getModelInstances(String modelId) throws SQLException;
+
+ /**
+ * Returns all Evaluations for modelInstanceId
+ *
+ * @param modelInstanceId
+ * @return
+ * @throws SQLException
+ */
+ List<Evaluation> getEvaluations(String modelInstanceId) throws SQLException;
+
+ void deleteEvaluation(String evaluationId) throws SQLException;
+
+ /**
+ * Returns all Prediction for modelInstanceId
+ *
+ * @param modelInstanceId
+ * @return
+ * @throws SQLException
+ */
+ List<Prediction> getPredictions(String modelInstanceId) throws SQLException;
+
+ /**
+ * @param modelInstanceId
+ * @return
+ * @throws LensException If modelInstanceId is not present
+ */
+ ModelInstance getModelInstance(String modelInstanceId) throws LensException;
+
+ /**
+ * Creates Prediction
+ *
+ * @param startTime
+ * @param finishTime
+ * @param status
+ * @param lensSessionHandle
+ * @param modelInstanceId
+ * @param inputDataSet
+ * @param outputDataSet
+ * @return predictionId
+ * @throws LensException
+ */
+ String createPrediction(Date startTime, Date finishTime, Status status, LensSessionHandle lensSessionHandle,
+ String modelInstanceId, String inputDataSet, String outputDataSet) throws LensException;
+
+ void deletePrediction(String predictionId) throws SQLException;
+
+ /**
+ * gets Prediction
+ *
+ * @param predictionId
+ * @return
+ * @throws LensException If prediction Id is not present or meta store error
+ */
+ Prediction getPrediction(String predictionId) throws LensException;
+
+ /**
+ * Updates Evaluation
+ *
+ * @param evaluation
+ * @throws LensException If evaluation not already present in DB or meta store error.
+ */
+ void updateEvaluation(Evaluation evaluation) throws LensException;
+
+ /**
+ * Creates Evaluation
+ *
+ * @param startTime
+ * @param finishTime
+ * @param status
+ * @param lensSessionHandle
+ * @param modelInstanceId
+ * @param inputDataSetName
+ * @return evaluationId
+ * @throws LensException
+ */
+ String createEvaluation(Date startTime, Date finishTime, Status status, LensSessionHandle lensSessionHandle,
+ String
+ modelInstanceId, String inputDataSetName) throws LensException;
+
+ /**
+ * gets evaluation
+ *
+ * @param evaluationId
+ * @return
+ * @throws LensException If evaluationId is not present in meta store.
+ */
+ Evaluation getEvaluation(String evaluationId) throws LensException;
+
+ /**
+ * Updates prediction
+ *
+ * @param prediction
+ * @throws LensException If evaluation not already present in DB or meta store error.
+ */
+ void updatePrediction(Prediction prediction) throws LensException;
+
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MetaStoreClientImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MetaStoreClientImpl.java
new file mode 100644
index 0000000..9905161
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/dao/MetaStoreClientImpl.java
@@ -0,0 +1,678 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.dao;
+
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.*;
+
+import javax.sql.DataSource;
+
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.ml.api.*;
+import org.apache.lens.server.api.error.LensException;
+
+import org.apache.commons.dbutils.DbUtils;
+import org.apache.commons.dbutils.QueryRunner;
+import org.apache.commons.dbutils.ResultSetHandler;
+import org.apache.commons.dbutils.handlers.BeanHandler;
+import org.apache.commons.lang.StringUtils;
+
+import org.codehaus.jettison.json.JSONObject;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import lombok.extern.slf4j.Slf4j;
+
+/**
+ * MetaStore implementation for JDBC.
+ */
+@Slf4j
+public class MetaStoreClientImpl implements MetaStoreClient {
+
+  /** Maps rows of the {@code features} table to {@link Feature} beans (empty list when no rows). */
+  ResultSetHandler<List<Feature>> featureListHandler = new ResultSetHandler<List<Feature>>() {
+    @Override
+    public List<Feature> handle(ResultSet rs) throws SQLException {
+      // Fix: diamond operator instead of the raw ArrayList (unchecked-assignment warning).
+      List<Feature> featureList = new ArrayList<>();
+      while (rs.next()) {
+        String name = rs.getString("featureName");
+        String description = rs.getString("description");
+        Feature.Type type = Feature.Type.valueOf(rs.getString("valueType"));
+        String dataColumn = rs.getString("dataColumn");
+        featureList.add(new Feature(name, description, type, dataColumn));
+      }
+      return featureList;
+    }
+  };
+  /** Maps the first row of a {@code models} query to a {@link Model}; null when the result is empty. */
+  ResultSetHandler<Model> modelHandler = new ResultSetHandler<Model>() {
+    @Override
+    public Model handle(ResultSet resultSet) throws SQLException {
+      if (!resultSet.next()) {
+        return null;
+      }
+      // Label feature is stored inline on the model row (note: some columns are spelled "lable" in the schema).
+      Feature label = new Feature(
+        resultSet.getString("labelName"),
+        resultSet.getString("labelDescription"),
+        Feature.Type.valueOf(resultSet.getString("lableType")),
+        resultSet.getString("lableDataColumn"));
+      // The algoParams column holds the parameter map serialized as JSON.
+      Map<String, String> algoParams;
+      try {
+        algoParams = new ObjectMapper().readValue(resultSet.getString("algoParams"), Map.class);
+      } catch (Exception e) {
+        throw new SQLException("Error Parsing algoParamJson", e);
+      }
+      AlgoSpec algoSpec = new AlgoSpec(resultSet.getString("algoName"), algoParams);
+      // Feature list is loaded separately (see getModel), hence null here.
+      return new Model(resultSet.getString("modelName"), algoSpec, null, label);
+    }
+  };
+  /** Maps rows of the algo-parameter table to {@link AlgoParameter} beans. */
+  ResultSetHandler<List<AlgoParameter>> algoParamListHandler = new ResultSetHandler<List<AlgoParameter>>() {
+    @Override
+    public List<AlgoParameter> handle(ResultSet rs) throws SQLException {
+      List<AlgoParameter> params = new ArrayList<>();
+      while (rs.next()) {
+        params.add(new AlgoParameter(
+          rs.getString("paramName"),
+          rs.getString("helpText"),
+          rs.getString("defaultValue")));
+      }
+      return params;
+    }
+  };
+  /** Maps ALL rows of an {@code evaluations} query to {@link Evaluation} beans. */
+  ResultSetHandler<List<Evaluation>> evaluationListHandler = new ResultSetHandler<List<Evaluation>>() {
+    @Override
+    public List<Evaluation> handle(ResultSet rs) throws SQLException {
+      List<Evaluation> evaluations = new ArrayList<>();
+      // Fix: original used `if (rs.next())`, silently returning at most one row from a list handler.
+      while (rs.next()) {
+        String id = rs.getString("id");
+        Date startTime = rs.getTimestamp("startTime");
+        Date finishTime = rs.getTimestamp("finishTime");
+        Status status = Status.valueOf(rs.getString("status"));
+        LensSessionHandle lensSessionHandle = LensSessionHandle.valueOf(rs.getString("lensSessionHandle"));
+        String modelInstanceId = rs.getString("modelInstanceId");
+        String inputDataSet = rs.getString("inputDataSetName");
+
+        Evaluation evaluation = new Evaluation(id, startTime, finishTime, status, lensSessionHandle, modelInstanceId,
+          inputDataSet);
+        evaluations.add(evaluation);
+      }
+      return evaluations;
+    }
+  };
+  /** Maps the first row of a {@code modelInstances} query to a {@link ModelInstance}; null when empty. */
+  ResultSetHandler<ModelInstance> modelInstanceHandler = new ResultSetHandler<ModelInstance>() {
+    @Override
+    public ModelInstance handle(ResultSet rs) throws SQLException {
+      if (!rs.next()) {
+        return null;
+      }
+      return new ModelInstance(
+        rs.getString("id"),
+        rs.getTimestamp("startTime"),
+        rs.getTimestamp("finishTime"),
+        Status.valueOf(rs.getString("status")),
+        LensSessionHandle.valueOf(rs.getString("lensSessionHandle")),
+        rs.getString("modelName"),
+        rs.getString("dataSetName"),
+        rs.getString("path"),
+        rs.getString("defaultEvaluationId"));
+    }
+  };
+  /** Maps ALL rows of a {@code modelInstances} query to {@link ModelInstance} beans. */
+  ResultSetHandler<List<ModelInstance>> modelInstanceListHandler = new ResultSetHandler<List<ModelInstance>>() {
+    @Override
+    public List<ModelInstance> handle(ResultSet rs) throws SQLException {
+      List<ModelInstance> modelInstances = new ArrayList<>();
+      // Fix: original used `if (rs.next())`, silently returning at most one row from a list handler.
+      while (rs.next()) {
+        String id = rs.getString("id");
+        Date startTime = rs.getTimestamp("startTime");
+        Date finishTime = rs.getTimestamp("finishTime");
+        Status status = Status.valueOf(rs.getString("status"));
+        LensSessionHandle lensSessionHandle = LensSessionHandle.valueOf(rs.getString("lensSessionHandle"));
+        String modelName = rs.getString("modelName");
+        String dataSetName = rs.getString("dataSetName");
+        String path = rs.getString("path");
+        String defaultEvaluationId = rs.getString("defaultEvaluationId");
+
+        ModelInstance modelInstance = new ModelInstance(id, startTime, finishTime, status, lensSessionHandle, modelName,
+          dataSetName, path,
+          defaultEvaluationId);
+        modelInstances.add(modelInstance);
+      }
+      return modelInstances;
+    }
+  };
+  /** Maps the first row of a {@code predictions} query to a {@link Prediction}; null when empty. */
+  ResultSetHandler<Prediction> predictionResultSetHandler = new ResultSetHandler<Prediction>() {
+    @Override
+    public Prediction handle(ResultSet rs) throws SQLException {
+      if (!rs.next()) {
+        return null;
+      }
+      return new Prediction(
+        rs.getString("id"),
+        rs.getTimestamp("startTime"),
+        rs.getTimestamp("finishTime"),
+        Status.valueOf(rs.getString("status")),
+        LensSessionHandle.valueOf(rs.getString("lensSessionHandle")),
+        rs.getString("modelInstanceId"),
+        rs.getString("inputDataSet"),
+        rs.getString("outputDataSet"));
+    }
+  };
+  /** Maps ALL rows of a {@code predictions} query to {@link Prediction} beans. */
+  ResultSetHandler<List<Prediction>> predictionListHandler = new ResultSetHandler<List<Prediction>>() {
+    @Override
+    public List<Prediction> handle(ResultSet rs) throws SQLException {
+      List<Prediction> predictions = new ArrayList<>();
+      // Fix: original used `if (rs.next())`, silently returning at most one row from a list handler.
+      while (rs.next()) {
+        String id = rs.getString("id");
+        Date startTime = rs.getTimestamp("startTime");
+        Date finishTime = rs.getTimestamp("finishTime");
+        Status status = Status.valueOf(rs.getString("status"));
+        LensSessionHandle lensSessionHandle = LensSessionHandle.valueOf(rs.getString("lensSessionHandle"));
+        String modelInstanceId = rs.getString("modelInstanceId");
+        String inputDataSet = rs.getString("inputDataSet");
+        String outputDataSet = rs.getString("outputDataSet");
+
+        Prediction prediction = new Prediction(id, startTime, finishTime, status, lensSessionHandle, modelInstanceId,
+          inputDataSet,
+          outputDataSet);
+        predictions.add(prediction);
+      }
+      return predictions;
+    }
+  };
+  /** Maps the first row of an {@code evaluations} query to an {@link Evaluation}; null when empty. */
+  ResultSetHandler<Evaluation> evaluationResultSetHandler = new ResultSetHandler<Evaluation>() {
+    @Override
+    public Evaluation handle(ResultSet rs) throws SQLException {
+      if (!rs.next()) {
+        return null;
+      }
+      return new Evaluation(
+        rs.getString("id"),
+        rs.getTimestamp("startTime"),
+        rs.getTimestamp("finishTime"),
+        Status.valueOf(rs.getString("status")),
+        LensSessionHandle.valueOf(rs.getString("lensSessionHandle")),
+        rs.getString("modelInstanceId"),
+        rs.getString("inputDataSetName"));
+    }
+  };
+  /**
+   * JDBC data source backing this meta store; every query and update runs against it.
+   */
+  private DataSource ds;
+
+  public MetaStoreClientImpl(DataSource dataSource) {
+    this.ds = dataSource;
+  }
+
+  /**
+   * Creates the meta store tables (datasets, models, modelInstances, predictions,
+   * evaluations, features) if they do not exist. Best-effort: failures are logged,
+   * not rethrown. NOTE(review): the DDL uses MySQL-specific syntax (backticks,
+   * enum, ON UPDATE CURRENT_TIMESTAMP) — confirm target database.
+   */
+  public void init() {
+    String dataSourceCreateSql = "CREATE TABLE IF NOT EXISTS `datasets` (\n"
+      + "  `dsName` varchar(255) NOT NULL,\n"
+      + "  `tableName` varchar(255) NOT NULL,\n"
+      + "  `dbName` varchar(255) NOT NULL,\n"
+      + "  PRIMARY KEY (`dsName`)\n"
+      + ");";
+
+    String modelCreateSql = "CREATE TABLE IF NOT EXISTS `models` (\n"
+      + "  `modelName` varchar(255) NOT NULL,\n"
+      + "  `algoName` varchar(255) NOT NULL,\n"
+      + "  `labelName` varchar(255) NOT NULL,\n"
+      + "  `labelDescription` varchar(1000) DEFAULT NULL,\n"
+      + "  `lableType` varchar(255) NOT NULL,\n"
+      + "  `lableDataColumn` varchar(255) DEFAULT NULL,\n"
+      + "  `algoParams` text,\n"
+      + "  PRIMARY KEY (`modelName`)\n"
+      + ");";
+
+    String modelInstanceCreateSql = "CREATE TABLE IF NOT EXISTS `modelInstances` (\n"
+      + "  `id` varchar(255) NOT NULL DEFAULT '',\n"
+      + "  `startTime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n"
+      + "  `finishTime` timestamp NOT NULL DEFAULT '0000-00-00 00:00:00',\n"
+      + "  `status` enum('SUBMITTED','RUNNING','FAILED','CANCELLED','COMPLETED') DEFAULT NULL,\n"
+      + "  `lensSessionHandle` varchar(255) NOT NULL,\n"
+      + "  `modelName` varchar(255) DEFAULT NULL,\n"
+      + "  `dataSetName` varchar(255) NOT NULL,\n"
+      + "  `path` varchar(255) DEFAULT NULL,\n"
+      + "  `defaultEvaluationId` varchar(255) DEFAULT NULL,\n"
+      + "  PRIMARY KEY (`id`),\n"
+      + "  KEY `modelName` (`modelName`),\n"
+      + "  KEY `dataSetName` (`dataSetName`),\n"
+      + "  CONSTRAINT `modelinstances_ibfk_1` FOREIGN KEY (`modelName`) REFERENCES `models` (`modelName`),\n"
+      + "  CONSTRAINT `modelinstances_ibfk_2` FOREIGN KEY (`dataSetName`) REFERENCES `datasets` (`dsName`)\n"
+      + ");";
+
+    String evaluationsCreateSql = "CREATE TABLE IF NOT EXISTS `evaluations` (\n"
+      + "  `id` varchar(255) NOT NULL DEFAULT '',\n"
+      + "  `startTime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n"
+      + "  `finishTime` timestamp NOT NULL DEFAULT '0000-00-00 00:00:00',\n"
+      + "  `status` enum('SUBMITTED','RUNNING','FAILED','CANCELLED','COMPLETED') DEFAULT NULL,\n"
+      + "  `lensSessionHandle` varchar(255) NOT NULL,\n"
+      + "  `modelInstanceId` varchar(255) DEFAULT NULL,\n"
+      + "  `inputDataSetName` varchar(255) DEFAULT NULL,\n"
+      + "  PRIMARY KEY (`id`),\n"
+      + "  KEY `modelInstanceId` (`modelInstanceId`),\n"
+      + "  KEY `inputDataSetName` (`inputDataSetName`),\n"
+      + "  CONSTRAINT `evaluations_ibfk_1` FOREIGN KEY (`modelInstanceId`) REFERENCES `modelinstances` (`id`),\n"
+      + "  CONSTRAINT `evaluations_ibfk_2` FOREIGN KEY (`inputDataSetName`) REFERENCES `datasets` (`dsName`)\n"
+      + ");\n";
+
+    String predictionCreateSql = "CREATE TABLE IF NOT EXISTS `predictions` (\n"
+      + "  `id` varchar(255) NOT NULL DEFAULT '',\n"
+      + "  `startTime` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,\n"
+      + "  `finishTime` timestamp NOT NULL DEFAULT '0000-00-00 00:00:00',\n"
+      + "  `status` enum('SUBMITTED','RUNNING','FAILED','CANCELLED','COMPLETED') DEFAULT NULL,\n"
+      + "  `lensSessionHandle` varchar(255) NOT NULL,\n"
+      + "  `modelInstanceId` varchar(255) DEFAULT NULL,\n"
+      + "  `inputDataSet` varchar(255) NOT NULL,\n"
+      + "  `outputDataSet` varchar(255) NOT NULL,\n"
+      + "  PRIMARY KEY (`id`),\n"
+      + "  KEY `modelInstanceId` (`modelInstanceId`),\n"
+      + "  KEY `inputDataSet` (`inputDataSet`),\n"
+      + "  KEY `outputDataSet` (`outputDataSet`),\n"
+      + "  CONSTRAINT `predictions_ibfk_1` FOREIGN KEY (`modelInstanceId`) REFERENCES `modelinstances` (`id`),\n"
+      + "  CONSTRAINT `predictions_ibfk_2` FOREIGN KEY (`inputDataSet`) REFERENCES `datasets` (`dsName`),\n"
+      + "  CONSTRAINT `predictions_ibfk_3` FOREIGN KEY (`outputDataSet`) REFERENCES `datasets` (`dsName`)\n"
+      + ") ;";
+
+    String featureCreateSql = "CREATE TABLE IF NOT EXISTS `features` (\n"
+      + "  `featureName` varchar(255) NOT NULL,\n"
+      + "  `description` varchar(1000) DEFAULT NULL,\n"
+      + "  `valueType` enum('Categorical','Continuous') NOT NULL,\n"
+      + "  `dataColumn` varchar(255) NOT NULL,\n"
+      + "  `modelName` varchar(255) NOT NULL DEFAULT '',\n"
+      + "  PRIMARY KEY (`featureName`,`modelName`),\n"
+      + "  KEY `modelName` (`modelName`),\n"
+      + "  CONSTRAINT `features_ibfk_1` FOREIGN KEY (`modelName`) REFERENCES `models` (`modelName`)\n"
+      + ") ;";
+
+    QueryRunner runner = new QueryRunner(ds);
+    try {
+      runner.update(dataSourceCreateSql);
+      runner.update(modelCreateSql);
+      runner.update(modelInstanceCreateSql);
+      runner.update(predictionCreateSql);
+      runner.update(evaluationsCreateSql);
+      runner.update(featureCreateSql);
+    } catch (SQLException e) {
+      // Fix: log the exception too — the original dropped the stack trace entirely.
+      log.error("Error creating ML meta store", e);
+    }
+
+  }
+
+ /*
+ public List<AlgoParam> getAlgoParams(String modelId) throws LensException{
+
+
+ QueryRunner runner = new QueryRunner(ds);
+
+ String sql = "select * from algoparams where id = ?";
+
+ try{
+ Object[] runner.query(sql, modelId);
+ } catch(SQLException e){
+ log.error("SQL exception while executing query. ", e);
+ throw new LensException(e);
+ }
+ }*/
+
+  /**
+   * Inserts a new dataset row; rejects a name that is already registered.
+   * NOTE(review): check-then-insert is racy under concurrent callers — confirm single-writer assumption.
+   */
+  @Override
+  public void createDataSet(DataSet dataSet) throws SQLException {
+    String sql = "insert into datasets (dsName, tableName, dbName) values (?,?,?)";
+    if (getDataSet(dataSet.getDsName()) != null) {
+      throw new SQLException("Dataset with same name already exists.");
+    }
+    new QueryRunner(ds).update(sql, dataSet.getDsName(), dataSet.getTableName(), dataSet.getDbName());
+  }
+
+  /** Fetches a dataset row by name via a bean mapping; returns null when absent. */
+  @Override
+  public DataSet getDataSet(String name) throws SQLException {
+    String sql = "select * from datasets where dsName = ?";
+    ResultSetHandler<DataSet> beanMapper = new BeanHandler<DataSet>(DataSet.class);
+    return new QueryRunner(ds).query(sql, beanMapper, name);
+  }
+
+  /** Removes the dataset row with the given name; no-op when it does not exist. */
+  public void deleteDataSet(String dataSetName) throws SQLException {
+    String sql = "delete from datasets where dsName = ?";
+    new QueryRunner(ds).update(sql, dataSetName);
+  }
+
+  /**
+   * Persists a model row plus one {@code features} row per feature, in a single transaction.
+   *
+   * @param name     model name (primary key of {@code models})
+   * @param algo     NOTE(review): unused — {@code algoSpec.getAlgo()} is stored instead; confirm intent
+   * @param algoSpec algorithm name and parameter map (params serialized to JSON)
+   * @param features feature specifications to persist
+   * @param label    label feature, stored inline on the model row
+   * @throws LensException on any SQL failure (the transaction is rolled back)
+   */
+  @Override
+  public void createModel(String name, String algo, AlgoSpec algoSpec, List<Feature> features, Feature label)
+    throws LensException {
+    String modelSql = "INSERT INTO models VALUES (?, ?, ?, ?, ?, ?, ?)";
+    String featuresSql = "INSERT INTO features VALUES (?, ?, ?, ?, ?)";
+
+    Connection con = null;
+    try {
+      con = ds.getConnection();
+      con.setAutoCommit(false);
+      QueryRunner runner = new QueryRunner(ds);
+
+      JSONObject algoParamJson = new JSONObject(algoSpec.getAlgoParams());
+      runner.update(con, modelSql, name, algoSpec.getAlgo(), label.getName(), label.getDescription(), label.getType()
+        .toString(), label.getDataColumn(), algoParamJson.toString());
+
+      for (Feature feature : features) {
+        runner.update(con, featuresSql, feature.getName(), feature.getDescription(), feature.getType().toString(),
+          feature.getDataColumn(), name);
+      }
+
+      DbUtils.commitAndClose(con);
+    } catch (SQLException e) {
+      // Fix: the original leaked the connection on failure; roll back and close it quietly.
+      DbUtils.rollbackAndCloseQuietly(con);
+      throw new LensException("Error while creating Model, Id: " + name, e);
+    }
+
+  }
+
+  /**
+   * Persists a {@link Model} (model row + one {@code features} row per feature) transactionally.
+   *
+   * @param model the model to persist; its label and algo spec are stored inline on the model row
+   * @throws LensException on any SQL failure (the transaction is rolled back)
+   */
+  public void createModel(Model model) throws LensException {
+    String modelSql = "INSERT INTO models VALUES (?, ?, ?, ?, ?, ?, ?)";
+    String featuresSql = "INSERT INTO features VALUES (?, ?, ?, ?, ?)";
+
+    Connection con = null;
+    try {
+      con = ds.getConnection();
+      con.setAutoCommit(false);
+      QueryRunner runner = new QueryRunner(ds);
+
+      JSONObject algoParamJson = new JSONObject(model.getAlgoSpec().getAlgoParams());
+      runner.update(con, modelSql, model.getName(), model.getAlgoSpec().getAlgo(), model.getLabelSpec().getName(),
+        model.getLabelSpec().getDescription(), model.getLabelSpec().getType().toString(),
+        model.getLabelSpec().getDataColumn(),
+        algoParamJson.toString());
+
+      for (Feature feature : model.getFeatureSpec()) {
+        runner.update(con, featuresSql, feature.getName(), feature.getDescription(), feature.getType().toString(),
+          feature.getDataColumn(), model.getName());
+      }
+
+      DbUtils.commitAndClose(con);
+    } catch (SQLException e) {
+      // Fix: the original leaked the connection on failure; roll back and close it quietly.
+      DbUtils.rollbackAndCloseQuietly(con);
+      throw new LensException("Error while creating Model, Id: " + model.getName(), e);
+    }
+  }
+
+  /**
+   * Loads a model and its feature list from the meta store.
+   *
+   * @param modelName name of the model to load
+   * @return the model with its feature spec populated
+   * @throws LensException when the model does not exist or on SQL error
+   */
+  @Override
+  public Model getModel(String modelName) throws LensException {
+
+    String featureSql = "select * from features where modelName = ?";
+    String modelSql = "select * from models where modelName = ? ";
+    QueryRunner runner = new QueryRunner(ds);
+    try {
+      Model model = runner.query(modelSql, modelHandler, modelName);
+      // Fix: modelHandler returns null for a missing row; the original then NPE'd on setFeatureSpec.
+      if (model == null) {
+        throw new LensException("Model not found: " + modelName);
+      }
+      List<Feature> featureList = runner.query(featureSql, featureListHandler, modelName);
+      model.setFeatureSpec(featureList);
+      return model;
+    } catch (SQLException e) {
+      log.error("SQL exception while executing query. ", e);
+      throw new LensException(e);
+    }
+  }
+
+  /**
+   * Deletes a model and all of its feature rows.
+   *
+   * @param modelId name of the model to delete
+   * @throws SQLException on meta store error
+   */
+  public void deleteModel(String modelId) throws SQLException {
+    // Fix: the original (a) never deleted the model row when it had no features, and
+    // (b) built the feature IN-list by string concatenation with unquoted names —
+    // broken SQL and an injection risk. Deleting all features by modelName is
+    // equivalent, parameterized, and always removes the model row.
+    String featureDeleteSql = "delete from features where modelName = ?";
+    String modelDeleteSql = "delete from models where modelName = ?";
+    QueryRunner runner = new QueryRunner(ds);
+    runner.update(featureDeleteSql, modelId);
+    runner.update(modelDeleteSql, modelId);
+  }
+
+ @Override
+ public String createModelInstance(Date startTime, Date finishTime, Status status,
+ LensSessionHandle lensSessionHandle, String modelName, String dataSetName,
+ String path, String defaultEvaluationId) throws LensException {
+
+ String modelInstanceSql = "INSERT INTO modelInstances (id, startTime, finishTime, status, "
+ + "lensSessionHandle, modelName, dataSetName, path, defaultEvaluationId) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
+
+ String modelInstanceId;
+ Connection con;
+ try {
+ QueryRunner runner = new QueryRunner(ds);
+ ModelInstance alreadyExisting = null;
+ do {
+ modelInstanceId = UUID.randomUUID().toString();
+ alreadyExisting = getModelInstance(modelInstanceId);
+ } while (alreadyExisting != null);
+
+ runner.update(modelInstanceSql, modelInstanceId, startTime, finishTime, status
+ .toString(), lensSessionHandle.toString(), modelName, dataSetName, path, defaultEvaluationId);
+ } catch (SQLException e) {
+ log.error("Error while creating ModelInstance for Model, " + modelName);
+ throw new LensException("Error while creating ModelInstance for Model, " + modelName);
+ }
+ return modelInstanceId;
+ }
+
+  /** Removes the model-instance row with the given id; no-op when absent. */
+  public void deleteModelInstance(String modelInstanceId) throws SQLException {
+    String sql = "delete from modelInstances where id = ?";
+    new QueryRunner(ds).update(sql, modelInstanceId);
+  }
+
+ @Override
+ public void updateModelInstance(ModelInstance modelInstance) throws LensException {
+ String modelInstanceSql = "UPDATE modelInstances SET startTime = ?, finishTime = ?, status =?, "
+ + "lensSessionHandle = ?, modelName = ?, dataSetName = ?, path = ?, defaultEvaluationId = ? where id = ?";
+ try {
+ QueryRunner runner = new QueryRunner(ds);
+ runner.update(modelInstanceSql, modelInstance.getStartTime(), modelInstance.getFinishTime(), modelInstance
+ .getStatus().toString(), modelInstance.getLensSessionHandle().toString(), modelInstance.getModelId(),
+ modelInstance.getDataSetName(), modelInstance.getPath(), modelInstance.getDefaultEvaluationId(),
+ modelInstance.getId());
+ } catch (SQLException sq) {
+ throw new LensException("Error while updating Model Instance. Id: " + modelInstance);
+ }
+ }
+
+  // NOTE(review): the three getIncomplete* methods are unimplemented stubs that
+  // return null rather than an empty list — callers must null-check. Consider
+  // implementing them or returning Collections.emptyList(); confirm callers first.
+  @Override
+  public List<ModelInstance> getIncompleteModelInstances() throws LensException {
+    return null;
+  }
+
+  @Override
+  public List<Evaluation> getIncompleteEvaluations() throws LensException {
+    return null;
+  }
+
+  @Override
+  public List<Prediction> getIncompletePredictions() throws LensException {
+    return null;
+  }
+
+  /** Lists every model instance trained from the given model (empty list when none). */
+  @Override
+  public List<ModelInstance> getModelInstances(String modelId) throws SQLException {
+    String sql = "SELECT * FROM modelInstances where modelName = ?";
+    return new QueryRunner(ds).query(sql, modelInstanceListHandler, modelId);
+  }
+
+  /** Lists every evaluation run against the given model instance (empty list when none). */
+  @Override
+  public List<Evaluation> getEvaluations(String modelInstanceId) throws SQLException {
+    String sql = "SELECT * FROM evaluations where modelInstanceId = ?";
+    return new QueryRunner(ds).query(sql, evaluationListHandler, modelInstanceId);
+  }
+
+  /** Lists every prediction run against the given model instance (empty list when none). */
+  @Override
+  public List<Prediction> getPredictions(String modelInstanceId) throws SQLException {
+    String sql = "SELECT * FROM predictions where modelInstanceId = ?";
+    return new QueryRunner(ds).query(sql, predictionListHandler, modelInstanceId);
+  }
+
+  /** Loads a single model instance by id; returns null when absent. */
+  @Override
+  public ModelInstance getModelInstance(String modelInstanceId) throws LensException {
+    String modelInstanceGetSql = "SELECT * FROM modelInstances WHERE id = ?";
+    try {
+      return new QueryRunner(ds).query(modelInstanceGetSql, modelInstanceHandler, modelInstanceId);
+    } catch (Exception e) {
+      log.error("Error while reading modelInstance, Id: " + modelInstanceId);
+      throw new LensException("Error while reading modelInstance, Id: " + modelInstanceId, e);
+    }
+  }
+
+  /**
+   * Inserts a new prediction row under a freshly generated UUID.
+   *
+   * @return the generated prediction id
+   * @throws LensException on SQL error
+   */
+  @Override
+  public String createPrediction(Date startTime, Date finishTime, Status status,
+                                 LensSessionHandle lensSessionHandle, String modelInstanceId, String inputDataSet,
+                                 String outputDataSet) throws LensException {
+    String predictionInsertSql = "INSERT INTO predictions (id, startTime, finishTime, status, lensSessionHandle,"
+      + " modelInstanceId, inputDataSet, outputDataSet) VALUES (?, ?, ?, ?, ?, ?, ?, ?)";
+    String predictionId;
+    try {
+      // Draw UUIDs until one is unused, then insert.
+      do {
+        predictionId = UUID.randomUUID().toString();
+      } while (getPrediction(predictionId) != null);
+
+      new QueryRunner(ds).update(predictionInsertSql, predictionId, startTime, finishTime, status
+        .toString(), lensSessionHandle.toString(), modelInstanceId, inputDataSet, outputDataSet);
+    } catch (SQLException e) {
+      log.error("Error while creating Prediction for ModelInstance, " + modelInstanceId);
+      throw new LensException("Error while creating Prediction for ModelInstance, " + modelInstanceId, e);
+    }
+    return predictionId;
+  }
+
+  /** Loads a single prediction by id; returns null when absent. */
+  @Override
+  public Prediction getPrediction(String predictionId) throws LensException {
+    String modelInstanceGetSql = "SELECT * FROM predictions WHERE id = ?";
+    try {
+      return new QueryRunner(ds).query(modelInstanceGetSql, predictionResultSetHandler, predictionId);
+    } catch (Exception e) {
+      log.error("Error while reading prediction, Id: " + predictionId);
+      throw new LensException("Error while reading prediction, Id: " + predictionId, e);
+    }
+  }
+
+ @Override
+ public void updateEvaluation(Evaluation evaluation) throws LensException {
+ String evaluationUpdateSql = "UPDATE evaluations SET startTime = ?, finishTime = ?, status =?, "
+ + "lensSessionHandle = ?, modelInstanceId = ?, inputDataSetName = ? where id = ?";
+ try {
+ QueryRunner runner = new QueryRunner(ds);
+ runner.update(evaluationUpdateSql, evaluation.getStartTime(), evaluation.getFinishTime(), evaluation
+ .getStatus().toString(), evaluation.getLensSessionHandle().toString(), evaluation.getModelInstanceId(),
+ evaluation.getInputDataSetName(), evaluation.getId());
+ } catch (SQLException sq) {
+ throw new LensException("Error while updating Evaluation. Id: " + evaluation.getId());
+ }
+ }
+
+  /**
+   * Inserts a new evaluation row under a freshly generated UUID.
+   *
+   * @return the generated evaluation id
+   * @throws LensException on SQL error
+   */
+  @Override
+  public String createEvaluation(Date startTime, Date finishTime, Status status,
+                                 LensSessionHandle lensSessionHandle, String modelInstanceId,
+                                 String inputDataSetName) throws LensException {
+    String evaluationInsertSql = "INSERT INTO evaluations (id, startTime, finishTime, status, lensSessionHandle,"
+      + " modelInstanceId, inputDataSetName) VALUES (?, ?, ?, ?, ?, ?, ?)";
+    String evaluationId;
+    try {
+      // Draw UUIDs until one is unused, then insert.
+      do {
+        evaluationId = UUID.randomUUID().toString();
+      } while (getEvaluation(evaluationId) != null);
+
+      new QueryRunner(ds).update(evaluationInsertSql, evaluationId, startTime, finishTime, status
+        .toString(), lensSessionHandle.toString(), modelInstanceId, inputDataSetName);
+    } catch (SQLException e) {
+      log.error("Error while creating Evaluation for ModelInstance, " + modelInstanceId);
+      throw new LensException("Error while creating Evaluation for ModelInstance, " + modelInstanceId, e);
+    }
+    return evaluationId;
+  }
+
+  /** Removes the evaluation row with the given id; no-op when absent. */
+  public void deleteEvaluation(String evaluationId) throws SQLException {
+    String sql = "delete from evaluations where id = ?";
+    new QueryRunner(ds).update(sql, evaluationId);
+  }
+
+  /** Loads a single evaluation by id; returns null when absent. */
+  @Override
+  public Evaluation getEvaluation(String evaluationId) throws LensException {
+    String evaluationGetSql = "SELECT * FROM evaluations WHERE id = ?";
+    try {
+      return new QueryRunner(ds).query(evaluationGetSql, evaluationResultSetHandler, evaluationId);
+    } catch (Exception e) {
+      log.error("Error while reading evaluation, Id: " + evaluationId);
+      throw new LensException("Error while reading evaluation, Id: " + evaluationId, e);
+    }
+  }
+
+ @Override
+ public void updatePrediction(Prediction prediction) throws LensException {
+ String predictionUpdateSql = "UPDATE predictions SET startTime = ?, finishTime = ?, status =?, "
+ + "lensSessionHandle = ?, modelInstanceId = ?, inputDataSet = ?, outputDataSet = ? where id = ?";
+ try {
+ QueryRunner runner = new QueryRunner(ds);
+ runner.update(predictionUpdateSql, prediction.getStartTime(), prediction.getFinishTime(), prediction
+ .getStatus().toString(), prediction.getLensSessionHandle().toString(), prediction.getModelInstanceId(),
+ prediction.getInputDataSet(), prediction.getOutputDataSet(), prediction.getId());
+ } catch (SQLException sq) {
+ throw new LensException("Error while updating Prediction. Id: " + prediction.getId());
+ }
+ }
+
+  /** Removes the prediction row with the given id; no-op when absent. */
+  public void deletePrediction(String predictionId) throws SQLException {
+    String sql = "delete from predictions where id = ?";
+    new QueryRunner(ds).update(sql, predictionId);
+  }
+
+  /**
+   * Creates the table.
+   *
+   * NOTE(review): this private helper is not called anywhere in this class —
+   * init() issues its DDL directly; consider routing init() through it or removing it.
+   *
+   * @param sql the sql
+   * @throws java.sql.SQLException the SQL exception
+   */
+  private void createTable(String sql) throws SQLException {
+    QueryRunner runner = new QueryRunner(ds);
+    runner.update(sql);
+  }
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/BatchPredictSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/BatchPredictSpec.java
new file mode 100644
index 0000000..1c9a4ce
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/BatchPredictSpec.java
@@ -0,0 +1,366 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.lens.ml.api.Feature;
+
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.api.FieldSchema;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.metadata.Table;
+
+import lombok.Getter;
+
+/**
+ * BatchPredictSpec class. Contains table specification for input table for prediction in batch mode. Returns the
+ * HIVE query which can be used to run the prediction job.
+ */
+public class BatchPredictSpec {
+
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(BatchPredictSpec.class);
+
+ /**
+ * The db.
+ */
+ private String db;
+
+ /**
+ * The table containing input data.
+ */
+ private String inputTable;
+
+ // TODO use partition condition
+ /**
+ * The partition filter.
+ */
+ private String partitionFilter;
+
+ /**
+ * The feature columns.
+ */
+ private List<Feature> featureColumns;
+
+ /**
+ * The output column.
+ */
+ private String outputColumn;
+
+ /**
+ * The output table.
+ */
+ private String outputTable;
+
+ /**
+ * The conf.
+ */
+ private transient HiveConf conf;
+
+ /**
+ * The algorithm.
+ */
+ private String algorithm;
+
+ /**
+ * The model id.
+ */
+ private String modelID;
+
+ /**
+ * The model instance id.
+ */
+ private String modelInstanceId;
+
+ @Getter
+ private boolean outputTableExists;
+
+ /**
+ * A unique testId which is predictionId
+ */
+ @Getter
+ private String testID;
+
+ // Populated by validate(); maps input-table column name -> its Hive schema.
+ private HashMap<String, FieldSchema> columnNameToFieldSchema;
+
+ /**
+ * New builder.
+ *
+ * @return the table testing spec builder
+ */
+ public static TableTestingSpecBuilder newBuilder() {
+ return new TableTestingSpecBuilder();
+ }
+
+ /**
+ * Validate.
+ *
+ * Checks that the input table exists, that every feature column is present in
+ * it, and that the output column/table names are set. As a side effect it
+ * populates columnNameToFieldSchema and outputTableExists.
+ *
+ * @return true, if successful
+ */
+ public boolean validate() {
+ List<FieldSchema> columns;
+ try {
+ Hive metastoreClient = Hive.get(conf);
+ Table tbl = (db == null) ? metastoreClient.getTable(inputTable) : metastoreClient.getTable(db, inputTable);
+ columns = tbl.getAllCols();
+ columnNameToFieldSchema = new HashMap<String, FieldSchema>();
+
+ for (FieldSchema fieldSchema : columns) {
+ columnNameToFieldSchema.put(fieldSchema.getName(), fieldSchema);
+ }
+
+ // Check if output table exists
+ Table outTbl = metastoreClient.getTable(db == null ? "default" : db, outputTable, false);
+ outputTableExists = (outTbl != null);
+ } catch (HiveException exc) {
+ LOG.error("Error getting table info " + toString(), exc);
+ return false;
+ }
+
+ // Check if labeled column and feature columns are contained in the table
+ List<String> testTableColumns = new ArrayList<String>(columns.size());
+ for (FieldSchema column : columns) {
+ testTableColumns.add(column.getName());
+ }
+
+ List<String> inputColumnNames = new ArrayList<String>();
+ for (Feature feature : featureColumns) {
+ inputColumnNames.add(feature.getDataColumn());
+ }
+
+ if (!testTableColumns.containsAll(inputColumnNames)) {
+ LOG.info("Invalid feature columns: " + inputColumnNames + ". Actual columns in table:" + testTableColumns);
+ return false;
+ }
+
+
+ if (StringUtils.isBlank(outputColumn)) {
+ LOG.info("Output column is required");
+ return false;
+ }
+
+ if (StringUtils.isBlank(outputTable)) {
+ LOG.info("Output table is required");
+ return false;
+ }
+ return true;
+ }
+
+ /**
+ * Builds the HIVE INSERT query that runs the predict UDF over the input table
+ * and writes results into a partition of the output table keyed by testID.
+ * The UDF arguments are emitted as ('feature_name', feature_column) pairs, as
+ * expected by the predict UDF's key/value feature list.
+ *
+ * @return the query string, or null if validate() fails
+ */
+ public String getTestQuery() {
+ if (!validate()) {
+ return null;
+ }
+
+ // We always insert a dynamic partition
+ StringBuilder q = new StringBuilder("INSERT OVERWRITE TABLE " + outputTable + " PARTITION (part_testid='" + testID
+ + "') SELECT ");
+ List<String> featureNameList = new ArrayList<String>();
+ List<String> featureMapBuilder = new ArrayList<String>();
+ for (Feature feature : featureColumns) {
+ featureNameList.add(feature.getDataColumn());
+ featureMapBuilder.add("'" + feature.getDataColumn() + "'");
+ featureMapBuilder.add(feature.getDataColumn());
+ }
+ String featureCols = StringUtils.join(featureNameList, ",");
+ String featureMapString = StringUtils.join(featureMapBuilder, ",");
+ q.append(featureCols).append(",").append("predict(").append("'").append(algorithm)
+ .append("', ").append("'").append(modelID).append("', ").append("'").append(modelInstanceId).append("', ")
+ .append(featureMapString).append(") ").append(outputColumn)
+ .append(" FROM ").append(inputTable);
+
+ return q.toString();
+ }
+
+ /**
+ * Builds the CREATE TABLE statement for the output table: all feature columns
+ * (with their input-table types), the output column as string, partitioned by
+ * part_testid.
+ *
+ * NOTE(review): relies on columnNameToFieldSchema, which is only populated by
+ * validate(); calling this before a successful validate() throws NPE — confirm
+ * callers always validate first.
+ *
+ * @return the CREATE TABLE IF NOT EXISTS query string
+ */
+ public String getCreateOutputTableQuery() {
+ StringBuilder createTableQuery = new StringBuilder("CREATE TABLE IF NOT EXISTS ").append(outputTable).append("(");
+ // Output table contains feature columns, label column, output column
+ List<String> outputTableColumns = new ArrayList<String>();
+ for (Feature featureCol : featureColumns) {
+ outputTableColumns.add(featureCol.getDataColumn() + " "
+ + columnNameToFieldSchema.get(featureCol.getDataColumn()).getType());
+ }
+
+
+ outputTableColumns.add(outputColumn + " string");
+
+ createTableQuery.append(StringUtils.join(outputTableColumns, ", "));
+
+ // Append partition column
+ createTableQuery.append(") PARTITIONED BY (part_testid string)");
+
+ return createTableQuery.toString();
+ }
+
+ /**
+ * The Class TableTestingSpecBuilder.
+ */
+ public static class TableTestingSpecBuilder {
+
+ /**
+ * The spec.
+ */
+ private final BatchPredictSpec spec;
+
+ /**
+ * Instantiates a new table testing spec builder.
+ */
+ public TableTestingSpecBuilder() {
+ spec = new BatchPredictSpec();
+ }
+
+ /**
+ * Database.
+ *
+ * @param database the database
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder database(String database) {
+ spec.db = database;
+ return this;
+ }
+
+ /**
+ * Set the input table
+ *
+ * @param table the table
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder inputTable(String table) {
+ spec.inputTable = table;
+ return this;
+ }
+
+ /**
+ * Partition filter for input table
+ *
+ * @param partFilter the part filter
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder partitionFilter(String partFilter) {
+ spec.partitionFilter = partFilter;
+ return this;
+ }
+
+ /**
+ * Feature columns.
+ *
+ * @param featureColumns the feature columns
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder featureColumns(List<Feature> featureColumns) {
+ spec.featureColumns = featureColumns;
+ return this;
+ }
+
+ /**
+ * Output column.
+ *
+ * @param outputColumn the output column
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder outputColumn(String outputColumn) {
+ spec.outputColumn = outputColumn;
+ return this;
+ }
+
+ /**
+ * Output table.
+ *
+ * @param table the table
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder outputTable(String table) {
+ spec.outputTable = table;
+ return this;
+ }
+
+ /**
+ * Hive conf.
+ *
+ * @param conf the conf
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder hiveConf(HiveConf conf) {
+ spec.conf = conf;
+ return this;
+ }
+
+ /**
+ * Algorithm.
+ *
+ * @param algorithm the algorithm
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder algorithm(String algorithm) {
+ spec.algorithm = algorithm;
+ return this;
+ }
+
+ /**
+ * Model id.
+ *
+ * @param modelID the model id
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder modelID(String modelID) {
+ spec.modelID = modelID;
+ return this;
+ }
+
+ /**
+ * modelInstanceID
+ *
+ * @param modelInstanceId
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder modelInstanceID(String modelInstanceId) {
+ spec.modelInstanceId = modelInstanceId;
+ return this;
+ }
+
+ /**
+ * Builds the.
+ *
+ * @return the table testing spec
+ */
+ public BatchPredictSpec build() {
+ return spec;
+ }
+
+ /**
+ * Set the unique test id
+ *
+ * @param testID
+ * @return the table testing spec builder
+ */
+ public TableTestingSpecBuilder testID(String testID) {
+ spec.testID = testID;
+ return this;
+ }
+ }
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java
index 2addb20..5e41fc3 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/HiveMLUDF.java
@@ -19,9 +19,14 @@
package org.apache.lens.ml.impl;
import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
-import org.apache.lens.ml.algo.api.MLModel;
+import org.apache.lens.ml.algo.api.TrainedModel;
+import org.apache.lens.ml.api.MLConfConstants;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.MapredContext;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
@@ -36,34 +41,39 @@
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.mapred.JobConf;
-import lombok.extern.slf4j.Slf4j;
-
/**
* Generic UDF to laod ML Models saved in HDFS and apply the model on list of columns passed as argument.
+ * The feature list is expected to be key value pair. i.e. feature_name, feature_value
*/
@Description(name = "predict",
value = "_FUNC_(algorithm, modelID, features...) - Run prediction algorithm with given "
+ "algorithm name, model ID and input feature columns")
-@Slf4j
public final class HiveMLUDF extends GenericUDF {
+
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(HiveMLUDF.class);
+ /**
+ * The conf.
+ */
+ private JobConf conf;
+ /**
+ * The soi.
+ */
+ private StringObjectInspector soi;
+ /**
+ * The doi.
+ */
+ private LazyDoubleObjectInspector doi;
+ /**
+ * The model.
+ */
+ private TrainedModel model;
+
private HiveMLUDF() {
}
- /** The Constant UDF_NAME. */
- public static final String UDF_NAME = "predict";
-
- /** The conf. */
- private JobConf conf;
-
- /** The soi. */
- private StringObjectInspector soi;
-
- /** The doi. */
- private LazyDoubleObjectInspector doi;
-
- /** The model. */
- private MLModel model;
-
/**
* Currently we only support double as the return value.
*
@@ -73,12 +83,22 @@
*/
@Override
public ObjectInspector initialize(ObjectInspector[] objectInspectors) throws UDFArgumentException {
- // We require algo name, model id and at least one feature
- if (objectInspectors.length < 3) {
- throw new UDFArgumentLengthException("Algo name, model ID and at least one feature should be passed to "
- + UDF_NAME);
+ // We require algo name, model id, modelInstance id and at least one feature name value pair
+ String usage = "algo_name model_id, modelInstance_id [feature_name, feature_value]+ .";
+ if (objectInspectors.length < 5) {
+ throw new UDFArgumentLengthException(
+ "Algo name, model ID, modelInstance ID and at least one feature name value pair should be passed to "
+ + MLConfConstants.UDF_NAME + ". " + usage);
}
- log.info("{} initialized", UDF_NAME);
+
+ int numberOfFeatures = objectInspectors.length;
+ if (numberOfFeatures % 2 == 0) {
+ throw new UDFArgumentException(
+ "The feature list should be even in length since it's key value pair. i.e. feature_name, feature_value" + ". "
+ + usage);
+ }
+
+ LOG.info(MLConfConstants.UDF_NAME + " initialized");
return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
}
@@ -92,22 +112,36 @@
public Object evaluate(DeferredObject[] deferredObjects) throws HiveException {
String algorithm = soi.getPrimitiveJavaObject(deferredObjects[0].get());
String modelId = soi.getPrimitiveJavaObject(deferredObjects[1].get());
+ String modelInstanceId = soi.getPrimitiveJavaObject(deferredObjects[2].get());
+ Map<String, String> features = new HashMap();
- Double[] features = new Double[deferredObjects.length - 2];
- for (int i = 2; i < deferredObjects.length; i++) {
- LazyDouble lazyDouble = (LazyDouble) deferredObjects[i].get();
- features[i - 2] = (lazyDouble == null) ? 0d : doi.get(lazyDouble);
+ for (int i = 3; i < deferredObjects.length; i += 2) {
+ try {
+ String key = soi.getPrimitiveJavaObject(deferredObjects[i].get());
+ LazyDouble lazyDouble = (LazyDouble) deferredObjects[i + 1].get();
+ Double value = (lazyDouble == null) ? 0d : doi.get(lazyDouble);
+ LOG.debug("key: " + key + ", value " + value);
+ features.put(key, String.valueOf(value));
+ } catch (Exception e) {
+ LOG.error("Error Parsing feature pair");
+ throw new HiveException(e.getMessage());
+ }
}
try {
if (model == null) {
- model = ModelLoader.loadModel(conf, algorithm, modelId);
+ model = ModelLoader.loadModel(conf, algorithm, modelId, modelInstanceId);
}
} catch (IOException e) {
throw new HiveException(e);
}
- return model.predict(features);
+ try {
+ Object object = model.predict(features);
+ return object;
+ } catch (Exception e) {
+ throw new HiveException(e);
+ }
}
/*
@@ -117,7 +151,7 @@
*/
@Override
public String getDisplayString(String[] strings) {
- return UDF_NAME;
+ return MLConfConstants.UDF_NAME;
}
/*
@@ -131,6 +165,7 @@
conf = context.getJobConf();
soi = PrimitiveObjectInspectorFactory.javaStringObjectInspector;
doi = LazyPrimitiveObjectInspectorFactory.LAZY_DOUBLE_OBJECT_INSPECTOR;
- log.info("{} configured. Model base dir path: {}", UDF_NAME, conf.get(ModelLoader.MODEL_PATH_BASE_DIR));
+ LOG.info(
+ MLConfConstants.UDF_NAME + " configured. Model base dir path: " + conf.get(ModelLoader.MODEL_PATH_BASE_DIR));
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java
index e090e68..228adad 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensMLImpl.java
@@ -19,73 +19,77 @@
package org.apache.lens.ml.impl;
import java.io.IOException;
-import java.io.ObjectOutputStream;
+import java.sql.SQLException;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
-import javax.ws.rs.client.Client;
-import javax.ws.rs.client.ClientBuilder;
-import javax.ws.rs.client.Entity;
-import javax.ws.rs.client.WebTarget;
-import javax.ws.rs.core.GenericType;
-import javax.ws.rs.core.MediaType;
-
import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensSessionHandle;
-import org.apache.lens.api.query.LensQuery;
import org.apache.lens.api.query.QueryHandle;
-import org.apache.lens.api.query.QueryStatus;
-import org.apache.lens.api.result.LensAPIResult;
-import org.apache.lens.ml.algo.api.MLAlgo;
+import org.apache.lens.ml.algo.api.Algorithm;
import org.apache.lens.ml.algo.api.MLDriver;
-import org.apache.lens.ml.algo.api.MLModel;
-import org.apache.lens.ml.algo.spark.BaseSparkAlgo;
+import org.apache.lens.ml.algo.api.TrainedModel;
import org.apache.lens.ml.algo.spark.SparkMLDriver;
-import org.apache.lens.ml.api.LensML;
-import org.apache.lens.ml.api.MLTestReport;
-import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.ml.api.*;
+import org.apache.lens.ml.dao.MetaStoreClient;
+import org.apache.lens.ml.dao.MetaStoreClientImpl;
import org.apache.lens.server.api.error.LensException;
import org.apache.lens.server.api.session.SessionService;
-import org.apache.commons.io.IOUtils;
-import org.apache.hadoop.fs.FileStatus;
-import org.apache.hadoop.fs.FileSystem;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.ql.session.SessionState;
import org.apache.spark.api.java.JavaSparkContext;
-import org.glassfish.jersey.media.multipart.FormDataBodyPart;
-import org.glassfish.jersey.media.multipart.FormDataContentDisposition;
-import org.glassfish.jersey.media.multipart.FormDataMultiPart;
-import org.glassfish.jersey.media.multipart.MultiPartFeature;
-
-import lombok.extern.slf4j.Slf4j;
-
-/**
- * The Class LensMLImpl.
- */
-@Slf4j
public class LensMLImpl implements LensML {
- /** The drivers. */
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(LensMLImpl.class);
+
+ /**
+ * Check if the predict UDF has been registered for a user
+ */
+ private final Map<LensSessionHandle, Boolean> predictUdfStatus;
+
+ /**
+ * The drivers.
+ */
protected List<MLDriver> drivers;
- /** The conf. */
+ /**
+ * The metaStoreClient
+ */
+ MetaStoreClient metaStoreClient;
+
+ Map<String, Algo> algorithms = new HashMap<String, Algo>();
+
+ /**
+ * The conf.
+ */
private HiveConf conf;
- /** The spark context. */
+ /**
+ * The spark context.
+ */
private JavaSparkContext sparkContext;
- /** Check if the predict UDF has been registered for a user */
- private final Map<LensSessionHandle, Boolean> predictUdfStatus;
- /** Background thread to periodically check if we need to clear expire status for a session */
+ /**
+ * Background thread to periodically check if we need to clear expire status for a session
+ */
private ScheduledExecutorService udfStatusExpirySvc;
/**
+ * Life Cycle manager of Model Instance creation.
+ */
+ private MLProcessLifeCycleManager mlProcessLifeCycleManager;
+
+ /**
* Instantiates a new lens ml impl.
*
* @param conf the conf
@@ -108,155 +112,10 @@
this.sparkContext = jsc;
}
- public List<String> getAlgorithms() {
- List<String> algos = new ArrayList<String>();
- for (MLDriver driver : drivers) {
- algos.addAll(driver.getAlgoNames());
- }
- return algos;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
- */
- public MLAlgo getAlgoForName(String algorithm) throws LensException {
- for (MLDriver driver : drivers) {
- if (driver.isAlgoSupported(algorithm)) {
- return driver.getAlgoInstance(algorithm);
- }
- }
- throw new LensException("Algo not supported " + algorithm);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
- */
- public String train(String table, String algorithm, String[] args) throws LensException {
- MLAlgo algo = getAlgoForName(algorithm);
-
- String modelId = UUID.randomUUID().toString();
-
- log.info("Begin training model " + modelId + ", algo=" + algorithm + ", table=" + table + ", params="
- + Arrays.toString(args));
-
- String database = null;
- if (SessionState.get() != null) {
- database = SessionState.get().getCurrentDatabase();
- } else {
- database = "default";
- }
-
- MLModel model = algo.train(toLensConf(conf), database, table, modelId, args);
-
- log.info("Done training model: " + modelId);
-
- model.setCreatedAt(new Date());
- model.setAlgoName(algorithm);
-
- Path modelLocation = null;
- try {
- modelLocation = persistModel(model);
- log.info("Model saved: " + modelId + ", algo: " + algorithm + ", path: " + modelLocation);
- return model.getId();
- } catch (IOException e) {
- throw new LensException("Error saving model " + modelId + " for algo " + algorithm, e);
- }
- }
-
/**
- * Gets the algo dir.
+ * Initialises LensMLImpl. Registers drives, checks if there are previously interrupted MLprocess.
*
- * @param algoName the algo name
- * @return the algo dir
- * @throws IOException Signals that an I/O exception has occurred.
- */
- private Path getAlgoDir(String algoName) throws IOException {
- String modelSaveBaseDir = conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT);
- return new Path(new Path(modelSaveBaseDir), algoName);
- }
-
- /**
- * Persist model.
- *
- * @param model the model
- * @return the path
- * @throws IOException Signals that an I/O exception has occurred.
- */
- private Path persistModel(MLModel model) throws IOException {
- // Get model save path
- Path algoDir = getAlgoDir(model.getAlgoName());
- FileSystem fs = algoDir.getFileSystem(conf);
-
- if (!fs.exists(algoDir)) {
- fs.mkdirs(algoDir);
- }
-
- Path modelSavePath = new Path(algoDir, model.getId());
- ObjectOutputStream outputStream = null;
-
- try {
- outputStream = new ObjectOutputStream(fs.create(modelSavePath, false));
- outputStream.writeObject(model);
- outputStream.flush();
- } catch (IOException io) {
- log.error("Error saving model " + model.getId() + " reason: " + io.getMessage(), io);
- throw io;
- } finally {
- IOUtils.closeQuietly(outputStream);
- }
- return modelSavePath;
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModels(java.lang.String)
- */
- public List<String> getModels(String algorithm) throws LensException {
- try {
- Path algoDir = getAlgoDir(algorithm);
- FileSystem fs = algoDir.getFileSystem(conf);
- if (!fs.exists(algoDir)) {
- return null;
- }
-
- List<String> models = new ArrayList<String>();
-
- for (FileStatus stat : fs.listStatus(algoDir)) {
- models.add(stat.getPath().getName());
- }
-
- if (models.isEmpty()) {
- return null;
- }
-
- return models;
- } catch (IOException ioex) {
- throw new LensException(ioex);
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModel(java.lang.String, java.lang.String)
- */
- public MLModel getModel(String algorithm, String modelId) throws LensException {
- try {
- return ModelLoader.loadModel(conf, algorithm, modelId);
- } catch (IOException e) {
- throw new LensException(e);
- }
- }
-
- /**
- * Inits the.
- *
- * @param hiveConf the hive conf
+ * @param hiveConf
*/
public synchronized void init(HiveConf hiveConf) {
this.conf = hiveConf;
@@ -268,7 +127,7 @@
throw new RuntimeException("No ML Drivers specified in conf");
}
- log.info("Loading drivers " + Arrays.toString(driverClasses));
+ LOG.info("Loading drivers " + Arrays.toString(driverClasses));
drivers = new ArrayList<MLDriver>(driverClasses.length);
for (String driverClass : driverClasses) {
@@ -276,12 +135,12 @@
try {
cls = Class.forName(driverClass);
} catch (ClassNotFoundException e) {
- log.error("Driver class not found " + driverClass, e);
+ LOG.error("Driver class not found " + driverClass);
continue;
}
if (!MLDriver.class.isAssignableFrom(cls)) {
- log.warn("Not a driver class " + driverClass);
+ LOG.warn("Not a driver class " + driverClass);
continue;
}
@@ -290,21 +149,24 @@
MLDriver driver = mlDriverClass.newInstance();
driver.init(toLensConf(conf));
drivers.add(driver);
- log.info("Added driver " + driverClass);
+ LOG.info("Added driver " + driverClass);
} catch (Exception e) {
- log.error("Failed to create driver " + driverClass + " reason: " + e.getMessage(), e);
+ LOG.error("Failed to create driver " + driverClass + " reason: " + e.getMessage(), e);
}
}
if (drivers.isEmpty()) {
throw new RuntimeException("No ML drivers loaded");
}
- log.info("Inited ML service");
+ metaStoreClient = new MetaStoreClientImpl(MLUtils.createMLMetastoreConnectionPool(hiveConf));
+ metaStoreClient.init();
+ mlProcessLifeCycleManager = new MLProcessLifeCycleManager(conf, metaStoreClient, drivers);
+
+ mlProcessLifeCycleManager.init();
+
+ LOG.info("Inited ML service");
}
- /**
- * Start.
- */
public synchronized void start() {
for (MLDriver driver : drivers) {
try {
@@ -312,362 +174,247 @@
((SparkMLDriver) driver).useSparkContext(sparkContext);
}
driver.start();
+ registerAlgorithms(driver);
} catch (LensException e) {
- log.error("Failed to start driver " + driver, e);
+ LOG.error("Failed to start driver " + driver, e);
}
}
+ mlProcessLifeCycleManager.start();
+
udfStatusExpirySvc = Executors.newSingleThreadScheduledExecutor();
udfStatusExpirySvc.scheduleAtFixedRate(new UDFStatusExpiryRunnable(), 60, 60, TimeUnit.SECONDS);
- log.info("Started ML service");
+ LOG.info("Started ML service");
}
- /**
- * Stop.
- */
public synchronized void stop() {
for (MLDriver driver : drivers) {
try {
driver.stop();
} catch (LensException e) {
- log.error("Failed to stop driver " + driver, e);
+ LOG.error("Failed to stop driver " + driver, e);
}
}
drivers.clear();
udfStatusExpirySvc.shutdownNow();
- log.info("Stopped ML service");
+
+ mlProcessLifeCycleManager.stop();
+
+ LOG.info("Stopped ML service");
}
- public synchronized HiveConf getHiveConf() {
- return conf;
- }
-
- /**
- * Clear models.
- */
- public void clearModels() {
- ModelLoader.clearCache();
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModelPath(java.lang.String, java.lang.String)
- */
- public String getModelPath(String algorithm, String modelID) {
- return ModelLoader.getModelLocation(conf, algorithm, modelID).toString();
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#testModel(org.apache.lens.api.LensSessionHandle, java.lang.String, java.lang.String,
- * java.lang.String)
- */
@Override
- public MLTestReport testModel(LensSessionHandle session, String table, String algorithm, String modelID,
- String outputTable) throws LensException {
+ public List<Algo> getAlgos() {
+ List<Algo> allAlgos = new ArrayList<Algo>();
+ allAlgos.addAll(algorithms.values());
+ return allAlgos;
+ }
+
+ @Override
+ public Algo getAlgo(String name) throws LensException {
+ return algorithms.get(name);
+ }
+
+ @Override
+ public void createDataSet(String name, String dataTable, String dataBase) throws LensException {
+ createDataSet(new DataSet(name, dataTable, dataBase));
+ }
+
+ @Override
+ public void createDataSet(DataSet dataSet) throws LensException {
+ try {
+ metaStoreClient.createDataSet(dataSet);
+ } catch (SQLException e) {
+ throw new LensException("Error while creating DataSet name " + dataSet.getDsName());
+ }
+ }
+
+ @Override
+ public String createDataSetFromQuery(String name, String query) {
return null;
}
- /**
- * Test a model in embedded mode.
- *
- * @param sessionHandle the session handle
- * @param table the table
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param queryApiUrl the query api url
- * @return the ML test report
- * @throws LensException the lens exception
- */
- public MLTestReport testModelRemote(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
- String queryApiUrl, String outputTable) throws LensException {
- return testModel(sessionHandle, table, algorithm, modelID, new RemoteQueryRunner(sessionHandle, queryApiUrl),
- outputTable);
- }
-
- /**
- * Evaluate a model. Evaluation is done on data selected table from an input table. The model is run as a UDF and its
- * output is inserted into a table with a partition. Each evaluation is given a unique ID. The partition label is
- * associated with this unique ID.
- * <p></p>
- * <p>
- * This call also required a query runner. Query runner is responsible for executing the evaluation query against Lens
- * server.
- * </p>
- *
- * @param sessionHandle the session handle
- * @param table the table
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param queryRunner the query runner
- * @param outputTable table where test output will be written
- * @return the ML test report
- * @throws LensException the lens exception
- */
- public MLTestReport testModel(final LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
- QueryRunner queryRunner, String outputTable) throws LensException {
- if (sessionHandle == null) {
- throw new NullPointerException("Null session not allowed");
- }
- // check if algorithm exists
- if (!getAlgorithms().contains(algorithm)) {
- throw new LensException("No such algorithm " + algorithm);
- }
-
- MLModel<?> model;
+ @Override
+ public DataSet getDataSet(String name) throws LensException {
try {
- model = ModelLoader.loadModel(conf, algorithm, modelID);
- } catch (IOException e) {
- throw new LensException(e);
- }
-
- if (model == null) {
- throw new LensException("Model not found: " + modelID + " algorithm=" + algorithm);
- }
-
- String database = null;
-
- if (SessionState.get() != null) {
- database = SessionState.get().getCurrentDatabase();
- }
-
- String testID = UUID.randomUUID().toString().replace("-", "_");
- final String testTable = outputTable;
- final String testResultColumn = "prediction_result";
-
- // TODO support error metric UDAFs
- TableTestingSpec spec = TableTestingSpec.newBuilder().hiveConf(conf)
- .database(database == null ? "default" : database).inputTable(table).featureColumns(model.getFeatureColumns())
- .outputColumn(testResultColumn).lableColumn(model.getLabelColumn()).algorithm(algorithm).modelID(modelID)
- .outputTable(testTable).testID(testID).build();
-
- String testQuery = spec.getTestQuery();
- if (testQuery == null) {
- throw new LensException("Invalid test spec. " + "table=" + table + " algorithm=" + algorithm + " modelID="
- + modelID);
- }
-
- if (!spec.isOutputTableExists()) {
- log.info("Output table '" + testTable + "' does not exist for test algorithm = " + algorithm + " modelid="
- + modelID + ", Creating table using query: " + spec.getCreateOutputTableQuery());
- // create the output table
- String createOutputTableQuery = spec.getCreateOutputTableQuery();
- queryRunner.runQuery(createOutputTableQuery);
- log.info("Table created " + testTable);
- }
-
- // Check if ML UDF is registered in this session
- registerPredictUdf(sessionHandle, queryRunner);
-
- log.info("Running evaluation query " + testQuery);
- queryRunner.setQueryName("model_test_" + modelID);
- QueryHandle testQueryHandle = queryRunner.runQuery(testQuery);
-
- MLTestReport testReport = new MLTestReport();
- testReport.setReportID(testID);
- testReport.setAlgorithm(algorithm);
- testReport.setFeatureColumns(model.getFeatureColumns());
- testReport.setLabelColumn(model.getLabelColumn());
- testReport.setModelID(model.getId());
- testReport.setOutputColumn(testResultColumn);
- testReport.setOutputTable(testTable);
- testReport.setTestTable(table);
- testReport.setQueryID(testQueryHandle.toString());
-
- // Save test report
- persistTestReport(testReport);
- log.info("Saved test report " + testReport.getReportID());
- return testReport;
- }
-
- /**
- * Persist test report.
- *
- * @param testReport the test report
- * @throws LensException the lens exception
- */
- private void persistTestReport(MLTestReport testReport) throws LensException {
- log.info("saving test report " + testReport.getReportID());
- try {
- ModelLoader.saveTestReport(conf, testReport);
- log.info("Saved report " + testReport.getReportID());
- } catch (IOException e) {
- log.error("Error saving report " + testReport.getReportID() + " reason: " + e.getMessage(), e);
+ return metaStoreClient.getDataSet(name);
+ } catch (SQLException ex) {
+ throw new LensException("Error while reading DataSet. Name: " + name);
}
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getTestReports(java.lang.String)
- */
- public List<String> getTestReports(String algorithm) throws LensException {
- Path reportBaseDir = new Path(conf.get(ModelLoader.TEST_REPORT_BASE_DIR, ModelLoader.TEST_REPORT_BASE_DIR_DEFAULT));
- FileSystem fs = null;
-
- try {
- fs = reportBaseDir.getFileSystem(conf);
- if (!fs.exists(reportBaseDir)) {
- return null;
- }
-
- Path algoDir = new Path(reportBaseDir, algorithm);
- if (!fs.exists(algoDir)) {
- return null;
- }
-
- List<String> reports = new ArrayList<String>();
- for (FileStatus stat : fs.listStatus(algoDir)) {
- reports.add(stat.getPath().getName());
- }
- return reports;
- } catch (IOException e) {
- log.error("Error reading report list for " + algorithm, e);
- return null;
- }
+ @Override
+ public void createModel(String name, String algoName, Map<String, String> algoParams, List<Feature> features,
+ Feature label, LensSessionHandle lensSessionHandle) throws LensException {
+ AlgoSpec algoSpec = new AlgoSpec(algoName, algoParams);
+ metaStoreClient.createModel(name, algoName, algoSpec, features, label);
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getTestReport(java.lang.String, java.lang.String)
- */
- public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
- try {
- return ModelLoader.loadReport(conf, algorithm, reportID);
- } catch (IOException e) {
- throw new LensException(e);
- }
+ @Override
+ public void createModel(Model model) throws LensException {
+ metaStoreClient.createModel(model);
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#predict(java.lang.String, java.lang.String, java.lang.Object[])
- */
- public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
- // Load the model instance
- MLModel<?> model = getModel(algorithm, modelID);
- return model.predict(features);
+ @Override
+ public Model getModel(String modelId) throws LensException {
+ return metaStoreClient.getModel(modelId);
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#deleteModel(java.lang.String, java.lang.String)
- */
- public void deleteModel(String algorithm, String modelID) throws LensException {
- try {
- ModelLoader.deleteModel(conf, algorithm, modelID);
- log.info("DELETED model " + modelID + " algorithm=" + algorithm);
- } catch (IOException e) {
- log.error(
- "Error deleting model file. algorithm=" + algorithm + " model=" + modelID + " reason: " + e.getMessage(), e);
- throw new LensException("Unable to delete model " + modelID + " for algorithm " + algorithm, e);
- }
+ @Override
+ public String trainModel(String modelId, String dataSetName, LensSessionHandle lensSessionHandle)
+ throws LensException {
+ String modelInstanceId = metaStoreClient.createModelInstance(new Date(), null, Status.SUBMITTED,
+ lensSessionHandle,
+ modelId, dataSetName, "", null);
+ ModelInstance modelInstance = metaStoreClient.getModelInstance(modelInstanceId);
+ mlProcessLifeCycleManager.addProcess(modelInstance);
+ return modelInstance.getId();
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#deleteTestReport(java.lang.String, java.lang.String)
- */
- public void deleteTestReport(String algorithm, String reportID) throws LensException {
- try {
- ModelLoader.deleteTestReport(conf, algorithm, reportID);
- log.info("DELETED report=" + reportID + " algorithm=" + algorithm);
- } catch (IOException e) {
- log.error("Error deleting report " + reportID + " algorithm=" + algorithm + " reason: " + e.getMessage(), e);
- throw new LensException("Unable to delete report " + reportID + " for algorithm " + algorithm, e);
+ @Override
+ public ModelInstance getModelInstance(String modelInstanceId) throws LensException {
+ ModelInstance modelInstance = (ModelInstance) mlProcessLifeCycleManager.getMLProcess(modelInstanceId);
+ if (modelInstance != null) {
+ return modelInstance;
}
+ return metaStoreClient.getModelInstance(modelInstanceId);
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getAlgoParamDescription(java.lang.String)
- */
- public Map<String, String> getAlgoParamDescription(String algorithm) {
- MLAlgo algo = null;
- try {
- algo = getAlgoForName(algorithm);
- } catch (LensException e) {
- log.error("Error getting algo description : " + algorithm, e);
- return null;
- }
- if (algo instanceof BaseSparkAlgo) {
- return ((BaseSparkAlgo) algo).getArgUsage();
- }
+ @Override
+ public boolean cancelModelInstance(String modelInstanceId, LensSessionHandle lensSessionHandle) throws LensException {
+ return mlProcessLifeCycleManager.cancelProcess(modelInstanceId, lensSessionHandle);
+ }
+
+ @Override
+ public List<ModelInstance> getAllModelInstances(String modelId) {
return null;
}
- /**
- * Submit model test query to a remote Lens server.
- */
- class RemoteQueryRunner extends QueryRunner {
+ @Override
+ public String evaluate(String modelInstanceId, String inputDataSetName, LensSessionHandle lensSessionHandle)
+ throws LensException {
- /** The query api url. */
- final String queryApiUrl;
-
- /**
- * Instantiates a new remote query runner.
- *
- * @param sessionHandle the session handle
- * @param queryApiUrl the query api url
- */
- public RemoteQueryRunner(LensSessionHandle sessionHandle, String queryApiUrl) {
- super(sessionHandle);
- this.queryApiUrl = queryApiUrl;
+ DataSet inputDataSet = getDataSet(inputDataSetName);
+ if (inputDataSet == null) {
+ throw new LensException("Input DataSet does not exist. Name: " + inputDataSetName);
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.TestQueryRunner#runQuery(java.lang.String)
- */
- @Override
- public QueryHandle runQuery(String query) throws LensException {
- // Create jersey client for query endpoint
- Client client = ClientBuilder.newBuilder().register(MultiPartFeature.class).build();
- WebTarget target = client.target(queryApiUrl);
- final FormDataMultiPart mp = new FormDataMultiPart();
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("sessionid").build(), sessionHandle,
- MediaType.APPLICATION_XML_TYPE));
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("query").build(), query));
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("operation").build(), "execute"));
+ String evaluationId = metaStoreClient.createEvaluation(new Date(), null, Status.SUBMITTED, lensSessionHandle,
+ modelInstanceId, inputDataSetName);
+ Evaluation evaluation = metaStoreClient.getEvaluation(evaluationId);
+ mlProcessLifeCycleManager.addProcess(evaluation);
+ return evaluation.getId();
+ }
- LensConf lensConf = new LensConf();
- lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, false + "");
- lensConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, false + "");
- mp.bodyPart(new FormDataBodyPart(FormDataContentDisposition.name("conf").fileName("conf").build(), lensConf,
- MediaType.APPLICATION_XML_TYPE));
+ @Override
+ public Evaluation getEvaluation(String evalId) throws LensException {
+ return metaStoreClient.getEvaluation(evalId);
+ }
- final QueryHandle handle = target.request().post(Entity.entity(mp, MediaType.MULTIPART_FORM_DATA_TYPE),
- new GenericType<LensAPIResult<QueryHandle>>() {}).getData();
+ @Override
+ public boolean cancelEvaluation(String evalId, LensSessionHandle lensSessionHandle) throws LensException {
+ return mlProcessLifeCycleManager.cancelProcess(evalId, lensSessionHandle);
+ }
- LensQuery ctx = target.path(handle.toString()).queryParam("sessionid", sessionHandle).request()
- .get(LensQuery.class);
+ @Override
+ public String predict(String modelInstanceId, String dataSetName, LensSessionHandle lensSessionHandle)
+ throws LensException {
+ String id = UUID.randomUUID().toString();
+ DataSet dataSet = getDataSet(dataSetName);
+ if (dataSet == null) {
+ throw new LensException("DataSet not available: " + dataSetName);
+ }
+ String outputDataSetName = MLConfConstants.PREDICTION_OUTPUT_TABLE_PREFIX + id.replace("-", "_");
+ createDataSet(outputDataSetName, outputDataSetName, dataSet.getDbName());
+ String predictionId = metaStoreClient.createPrediction(new Date(), null, Status.SUBMITTED, lensSessionHandle,
+ modelInstanceId,
+ dataSetName, outputDataSetName);
+ Prediction prediction = metaStoreClient.getPrediction(predictionId);
+ mlProcessLifeCycleManager.addProcess(prediction);
+ return prediction.getId();
+ }
- QueryStatus stat = ctx.getStatus();
- while (!stat.finished()) {
- ctx = target.path(handle.toString()).queryParam("sessionid", sessionHandle).request().get(LensQuery.class);
- stat = ctx.getStatus();
- try {
- Thread.sleep(500);
- } catch (InterruptedException e) {
- throw new LensException(e);
- }
+ @Override
+ public Prediction getPrediction(String predictionId) throws LensException {
+ Prediction prediction = (Prediction) mlProcessLifeCycleManager.getMLProcess(predictionId);
+ if (prediction != null) {
+ return prediction;
+ }
+ return metaStoreClient.getPrediction(predictionId);
+ }
+
+ @Override
+ public boolean cancelPrediction(String predictionId, LensSessionHandle lensSessionHandle) throws LensException {
+ return mlProcessLifeCycleManager.cancelProcess(predictionId, lensSessionHandle);
+ }
+
+ @Override
+ public String predict(String modelInstanceId, Map<String, String> featureVector) throws LensException {
+ ModelInstance modelInstance;
+ Model model;
+ try {
+ modelInstance = metaStoreClient.getModelInstance(modelInstanceId);
+ if (modelInstance == null) {
+ throw new LensException("Invalid modelInstance Id.");
}
-
- if (stat.getStatus() != QueryStatus.Status.SUCCESSFUL) {
- throw new LensException("Query failed " + ctx.getQueryHandle().getHandleId() + " reason:"
- + stat.getErrorMessage());
+ if (modelInstance.getStatus() != Status.COMPLETED) {
+ throw new LensException("Prediction is allowed only on modelInstances which has completed training "
+ + "successfully. Current modelInstance status : " + modelInstance.getStatus());
}
+ model = metaStoreClient.getModel(modelInstance.getModelId());
+ } catch (Exception e) {
+ throw new LensException("Error Reading modelInstanceId :" + modelInstanceId, e);
+ }
- return ctx.getQueryHandle();
+ try {
+ TrainedModel trainedModel = ModelLoader
+ .loadModel(conf, model.getAlgoSpec().getAlgo(), model.getName(), modelInstance.getId());
+ Object trainingResult = trainedModel.predict(featureVector);
+ return trainingResult.toString();
+ } catch (Exception e) {
+ throw new LensException("Error while training model for modelInstanceId: " + modelInstanceId, e);
+ }
+ }
+
+ @Override
+ public void deleteDataSet(String dataSetName) throws LensException {
+
+ }
+
+ @Override
+ public void deleteModel(String modelId) throws LensException {
+
+ }
+
+ @Override
+ public void deleteModelInstance(String modelInstanceId) throws LensException {
+
+ }
+
+ @Override
+ public void deleteEvaluation(String evaluationId) throws LensException {
+
+ }
+
+ @Override
+ public void deletePrediction(String predictionId) throws LensException {
+
+ }
+
+ /**
+ * Register all available algorithms to cache.
+ *
+ * @param driver
+ */
+ void registerAlgorithms(MLDriver driver) {
+ for (String algoName : driver.getAlgoNames()) {
+ try {
+ final Algorithm algorithm = driver.getAlgoInstance(algoName);
+ algorithms.put(algoName, new Algo(algorithm.getName(), algorithm.getDescription(), algorithm.getParams()));
+ } catch (Exception e) {
+ LOG.error("Couldn't register algorithm " + algoName);
+ }
}
}
@@ -683,29 +430,55 @@
return lensConf;
}
- protected void registerPredictUdf(LensSessionHandle sessionHandle, QueryRunner queryRunner) throws LensException {
- if (isUdfRegisterd(sessionHandle)) {
+ /**
+ * Gets the algo dir.
+ *
+ * @param algoName the algo name
+ * @return the algo dir
+ * @throws java.io.IOException Signals that an I/O exception has occurred.
+ */
+ private Path getAlgoDir(String algoName) throws IOException {
+ String modelSaveBaseDir = conf.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT);
+ return new Path(new Path(modelSaveBaseDir), algoName);
+ }
+
+ /**
+ * Registers predict UDF for a given LensSession
+ *
+ * @param sessionHandle
+ * @param lensQueryRunner
+ * @throws LensException
+ */
+ protected void registerPredictUdf(LensSessionHandle sessionHandle, LensQueryRunner lensQueryRunner)
+ throws LensException {
+ if (isUdfRegistered(sessionHandle)) {
// Already registered, nothing to do
return;
}
- log.info("Registering UDF for session " + sessionHandle.getPublicId().toString());
+ LOG.info("Registering UDF for session " + sessionHandle.getPublicId().toString());
- String regUdfQuery = "CREATE TEMPORARY FUNCTION " + HiveMLUDF.UDF_NAME + " AS '" + HiveMLUDF.class
+ String regUdfQuery = "CREATE TEMPORARY FUNCTION " + MLConfConstants.UDF_NAME + " AS '" + HiveMLUDF.class
.getCanonicalName() + "'";
- queryRunner.setQueryName("register_predict_udf_" + sessionHandle.getPublicId().toString());
- QueryHandle udfQuery = queryRunner.runQuery(regUdfQuery);
- log.info("udf query handle is " + udfQuery);
+ lensQueryRunner.setQueryName("register_predict_udf_" + sessionHandle.getPublicId().toString());
+ QueryHandle udfQuery = lensQueryRunner.runQuery(regUdfQuery);
+ LOG.info("udf query handle is " + udfQuery);
predictUdfStatus.put(sessionHandle, true);
- log.info("Predict UDF registered for session " + sessionHandle.getPublicId().toString());
+ LOG.info("Predict UDF registered for session " + sessionHandle.getPublicId().toString());
}
- protected boolean isUdfRegisterd(LensSessionHandle sessionHandle) {
+ /**
+ * Checks if predict UDF is registered for a given LensSession
+ *
+ * @param sessionHandle
+ * @return true if the predict UDF has already been registered for the given session, false otherwise
+ */
+ protected boolean isUdfRegistered(LensSessionHandle sessionHandle) {
return predictUdfStatus.containsKey(sessionHandle);
}
/**
- * Periodically check if sessions have been closed, and clear UDF registered status.
+ * Clears predictUDFStatus for all closed LensSessions.
*/
private class UDFStatusExpiryRunnable implements Runnable {
public void run() {
@@ -715,12 +488,12 @@
List<LensSessionHandle> sessions = new ArrayList<LensSessionHandle>(predictUdfStatus.keySet());
for (LensSessionHandle sessionHandle : sessions) {
if (!sessionService.isOpen(sessionHandle)) {
- log.info("Session closed, removing UDF status: " + sessionHandle);
+ LOG.info("Session closed, removing UDF status: " + sessionHandle);
predictUdfStatus.remove(sessionHandle);
}
}
} catch (Exception exc) {
- log.warn("Error clearing UDF statuses", exc);
+ LOG.warn("Error clearing UDF statuses", exc);
}
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensQueryRunner.java
similarity index 79%
rename from lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java
rename to lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensQueryRunner.java
index 2eeee29..db91c29 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/QueryRunner.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/LensQueryRunner.java
@@ -28,12 +28,15 @@
/**
* Run a query against a Lens server.
*/
-public abstract class QueryRunner {
+public abstract class LensQueryRunner {
- /** The session handle. */
+ /**
+ * The session handle.
+ */
protected final LensSessionHandle sessionHandle;
- @Getter @Setter
+ @Getter
+ @Setter
protected String queryName;
/**
@@ -41,7 +44,7 @@
*
* @param sessionHandle the session handle
*/
- public QueryRunner(LensSessionHandle sessionHandle) {
+ public LensQueryRunner(LensSessionHandle sessionHandle) {
this.sessionHandle = sessionHandle;
}
@@ -52,5 +55,9 @@
* @return the query handle
* @throws LensException the lens exception
*/
- public abstract QueryHandle runQuery(String query) throws LensException;
+ public abstract QueryHandle runQuery(String query, MLProcessLifeCycleManager.MLProcessContext mlProcessContext) throws
+ LensException;
+
+ public abstract QueryHandle runQuery(String query) throws
+ LensException;
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLProcessLifeCycleManager.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLProcessLifeCycleManager.java
new file mode 100644
index 0000000..74bdb69
--- /dev/null
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLProcessLifeCycleManager.java
@@ -0,0 +1,786 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml.impl;
+
+import java.io.IOException;
+import java.util.Date;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.*;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.lens.api.LensConf;
+import org.apache.lens.api.LensSessionHandle;
+import org.apache.lens.api.query.LensQuery;
+import org.apache.lens.api.query.QueryHandle;
+import org.apache.lens.api.query.QueryStatus;
+import org.apache.lens.ml.algo.api.Algorithm;
+import org.apache.lens.ml.algo.api.MLDriver;
+import org.apache.lens.ml.algo.api.TrainedModel;
+import org.apache.lens.ml.api.*;
+import org.apache.lens.ml.dao.MetaStoreClient;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.api.error.LensException;
+import org.apache.lens.server.api.query.QueryExecutionService;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.ql.session.SessionState;
+
+import lombok.Getter;
+import lombok.Setter;
+
+/**
+ * MLProcessLifeCycleManager class. Responsible for Life Cycle management of a MLProcess i.e. ModelInstance,
+ * Prediction, Evaluation.
+ */
+public class MLProcessLifeCycleManager {
+
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(MLProcessLifeCycleManager.class);
+ /**
+ * Prefix for all MLProcess worker threads.
+ */
+ private static final String ML_PROCESS_LIFECYCLE_THREAD_PREFIX = "MLProcess-";
+ /**
+ * Check if the predict UDF has been registered for a user
+ */
+ private final Map<LensSessionHandle, Boolean> predictUdfStatus;
+ /**
+ * Runnable for thread responsible for purging completed MLProcesses and killing threads which have exceeded
+ * maximum life time.
+ */
+ ProcessPurger processPurgerRunnable = new ProcessPurger();
+ Thread processPurger = new Thread(processPurgerRunnable, "MLProcessPurger");
+ /**
+ * Runnable for thread responsible for submitting incoming MLProcesses to the Queue.
+ */
+ MLProcessSubmitter mlProcessSubmitterRunnable = new MLProcessSubmitter();
+ Thread mlProcessSubmitter = new Thread(mlProcessSubmitterRunnable, "MLProcessSubmitter");
+ /**
+ * boolean for checking whether this LifeCycle is running or not.
+ */
+ boolean stopped;
+ /**
+ * The meta store client.
+ */
+ MetaStoreClient metaStoreClient;
+
+ /**
+ * All registered drivers.
+ */
+ List<MLDriver> drivers;
+
+ /**
+ * Map for storing MLProcesses which are submitted, or executing. Once a process is finished purger thread will
+ * remove it from this Map after the configured time or if a process exceeds its maximum life time.
+ */
+ ConcurrentMap<String, MLProcessContext> allProcesses = new ConcurrentHashMap<String, MLProcessContext>();
+ private HiveConf conf;
+ /**
+ * All accepted MLProcesses are put into this queue. MLProcessSubmitter thread waits on this queue. It fetches
+ * MLProcesses from here and starts their execution.
+ */
+ private BlockingQueue<MLProcessContext> submittedQueue = new LinkedBlockingQueue<MLProcessContext>();
+ /**
+ * Executor pool for MLProcess worker threads. i.e. EvaluationCreator, ModelCreator, PredictionCreator.
+ */
+ private ExecutorService executorPool;
+
+ public MLProcessLifeCycleManager(HiveConf conf, MetaStoreClient metaStoreClient, List<MLDriver> drivers) {
+ this.conf = conf;
+ this.metaStoreClient = metaStoreClient;
+ this.predictUdfStatus = new ConcurrentHashMap<LensSessionHandle, Boolean>();
+ this.drivers = drivers;
+ }
+
+ /**
+ * Initializes the MLProcessLifeCycleManager. Restores previously incomplete MLProcesses
+ */
+ public void init() {
+
+ try {
+ for (MLProcess process : metaStoreClient.getIncompleteEvaluations()) {
+ MLProcessContext mlProcessContext = new MLProcessContext(process);
+ submittedQueue.add(mlProcessContext);
+ }
+ LOG.info("Restored old incomplete Evaluations.");
+ } catch (Exception e) {
+ LOG.error("Error while restoring previous incomplete Evaluations.");
+ }
+
+ try {
+ for (MLProcess process : metaStoreClient.getIncompleteModelInstances()) {
+ MLProcessContext mlProcessContext = new MLProcessContext(process);
+ submittedQueue.add(mlProcessContext);
+ }
+ LOG.info("Restored old incomplete ModelInstances.");
+ } catch (Exception e) {
+ LOG.error("Error while restoring previous incomplete ModelInstance.");
+ }
+
+ try {
+ for (MLProcess process : metaStoreClient.getIncompletePredictions()) {
+ MLProcessContext mlProcessContext = new MLProcessContext(process);
+ submittedQueue.add(mlProcessContext);
+ }
+ LOG.info("Restored old incomplete Predictions.");
+ } catch (Exception e) {
+ LOG.error("Error while restoring previous incomplete Predictions.");
+ }
+ LOG.info("Initialized MLProcessLifeCycle");
+ }
+
+ /**
+ * Starts the MLProcessLifeCycle Manager
+ */
+ public void start() {
+ stopped = false;
+ startExecutorPool();
+ mlProcessSubmitter.start();
+ processPurger.start();
+ LOG.info("Started MLProcessLifeCycle");
+ }
+
+ /**
+ * Stop the ML Process Life Cycle Manager
+ */
+ public void stop() {
+ executorPool.shutdown();
+ stopped = true;
+ LOG.info("Stopped MLProcessLifeCycle");
+ }
+
+ public MLProcess getMLProcess(String id) {
+ if (allProcesses.containsKey(id)) {
+ return allProcesses.get(id).getMlProcess();
+ }
+ return null;
+ }
+
+ private void startExecutorPool() {
+ int minPoolSize =
+ conf.getInt(MLConfConstants.EXECUTOR_POOL_MIN_THREADS, MLConfConstants.DEFAULT_EXECUTOR_POOL_MIN_THREADS);
+ int maxPoolSize = conf.getInt(MLConfConstants.EXECUTOR_POOL_MAX_THREADS, MLConfConstants
+ .DEFAULT_EXECUTOR_POOL_MAX_THREADS);
+
+ final ThreadFactory defaultFactory = Executors.defaultThreadFactory();
+ final AtomicInteger thId = new AtomicInteger();
+
+ ThreadFactory threadFactory = new ThreadFactory() {
+ @Override
+ public Thread newThread(Runnable r) {
+ Thread th = defaultFactory.newThread(r);
+ th.setName(ML_PROCESS_LIFECYCLE_THREAD_PREFIX + thId.incrementAndGet());
+ return th;
+ }
+ };
+
+ LOG.debug("starting executor pool");
+ ThreadPoolExecutor executorPool =
+ new ThreadPoolExecutor(minPoolSize, maxPoolSize, MLConfConstants.DEFAULT_CREATOR_POOL_KEEP_ALIVE_MILLIS,
+ TimeUnit.MILLISECONDS,
+ new LinkedBlockingQueue<Runnable>(), threadFactory);
+ this.executorPool = executorPool;
+ }
+
+ /**
+ * Accepts MLProcesses and adds them to submitted queue.
+ *
+ * @param mlProcess
+ * @throws LensException
+ */
+ public void addProcess(MLProcess mlProcess) throws LensException {
+ MLProcessContext mlProcessContext = new MLProcessContext(mlProcess);
+ allProcesses.put(mlProcess.getId(), mlProcessContext);
+ submittedQueue.add(mlProcessContext);
+ LOG.debug("MLProcess submitted, Id: " + mlProcess.getId());
+ }
+
+ /**
+ * Cancels the execution of MLProcess by killing the Executor Thread.
+ *
+ * @param processId
+ * @return
+ */
+ public boolean cancelProcess(String processId, LensSessionHandle lensSessionHandle) throws LensException {
+
+ MLProcessContext mlProcessContext = allProcesses.get(processId);
+ if (mlProcessContext == null) {
+ return false;
+ }
+
+ if (mlProcessContext.isFinished()) {
+ return false;
+ }
+
+ updateProcess(mlProcessContext.getMlProcess(), Status.CANCELLED);
+
+ QueryExecutionService queryService;
+ try {
+ queryService = (QueryExecutionService) MLUtils.getServiceProvider().getService("query");
+
+ } catch (Exception e) {
+ throw new LensException("Error while getting Service Provider");
+ }
+
+ if (mlProcessContext.getCurrentQueryHandle() != null) {
+ queryService.cancelQuery(lensSessionHandle, mlProcessContext.getCurrentQueryHandle());
+ }
+ return true;
+ }
+
+ public void updateProcess(MLProcess mlProcess, Status newStatus) throws LensException {
+ synchronized (mlProcess) {
+ mlProcess.setFinishTime(new Date());
+ mlProcess.setStatus(newStatus);
+ if (mlProcess instanceof ModelInstance) {
+ metaStoreClient.updateModelInstance((ModelInstance) mlProcess);
+ } else if (mlProcess instanceof Prediction) {
+ metaStoreClient.updatePrediction((Prediction) mlProcess);
+ } else if (mlProcess instanceof Evaluation) {
+ metaStoreClient.updateEvaluation((Evaluation) mlProcess);
+ }
+ }
+ }
+
+ /**
+ * Sets MLProcess status. Also updates finish time.
+ *
+ * @param status
+ * @param ctx
+ */
+ public void setProcessStatusAndFinishTime(Status status, MLProcessContext ctx) {
+ synchronized (ctx) {
+ ctx.getMlProcess().setStatus(status);
+ ctx.getMlProcess().setFinishTime(new Date());
+ }
+ }
+
+  /**
+   * Registers the Hive predict UDF for a given Lens Session, if not already done.
+   * Runs a CREATE TEMPORARY FUNCTION query through the supplied runner and records
+   * the session in {@code predictUdfStatus}.
+   *
+   * NOTE(review): the check-then-register sequence is not atomic; two threads on the
+   * same session could both submit the registration query. Harmless only if the
+   * registration query is idempotent — confirm.
+   *
+   * @param sessionHandle   session to register the UDF in
+   * @param lensQueryRunner runner used to execute the registration query
+   * @param ctx             process context the registration query handle is attached to
+   * @throws LensException if the registration query fails
+   */
+  protected void registerPredictUdf(LensSessionHandle sessionHandle, LensQueryRunner lensQueryRunner,
+    MLProcessContext ctx)
+    throws
+    LensException {
+    if (isUdfRegistered(sessionHandle)) {
+      // Already registered, nothing to do
+      return;
+    }
+
+    LOG.info("Registering UDF for session " + sessionHandle.getPublicId().toString());
+
+    // Temporary function: lives only for the lifetime of the session.
+    String regUdfQuery = "CREATE TEMPORARY FUNCTION " + MLConfConstants.UDF_NAME + " AS '" + HiveMLUDF.class
+      .getCanonicalName() + "'";
+    lensQueryRunner.setQueryName("register_predict_udf_" + sessionHandle.getPublicId().toString());
+    QueryHandle udfQuery = lensQueryRunner.runQuery(regUdfQuery, ctx);
+    LOG.info("udf query handle is " + udfQuery);
+    predictUdfStatus.put(sessionHandle, true);
+    LOG.info("Predict UDF registered for session " + sessionHandle.getPublicId().toString());
+  }
+
+  /** @return true if the predict UDF has already been registered for this session. */
+  protected boolean isUdfRegistered(LensSessionHandle sessionHandle) {
+    return predictUdfStatus.containsKey(sessionHandle);
+  }
+
+  /**
+   * Returns a configured Algorithm instance from the first registered driver that
+   * supports the given algorithm name.
+   *
+   * @param name algorithm name to look up
+   * @return a configured algorithm instance
+   * @throws LensException if no registered driver supports the name
+   */
+  public Algorithm getAlgoForName(String name) throws LensException {
+    for (MLDriver driver : drivers) {
+      if (driver.isAlgoSupported(name)) {
+        Algorithm algorithm = driver.getAlgoInstance(name);
+        algorithm.configure(toLensConf(conf));
+        return algorithm;
+      }
+    }
+    throw new LensException("Algorithm not supported " + name);
+  }
+
+  /** Copies every property of the HiveConf into a fresh LensConf. */
+  private LensConf toLensConf(HiveConf conf) {
+    LensConf lensConf = new LensConf();
+    lensConf.getProperties().putAll(conf.getValByRegex(".*"));
+    return lensConf;
+  }
+
+  /**
+   * Runs the generated predict-UDF query for the given process, logging and
+   * rethrowing (with cause) on failure.
+   *
+   * @param ctx         process context the query handle is attached to
+   * @param testQuery   generated hive query invoking the predict UDF
+   * @param queryRunner runner used to execute the query
+   * @throws LensException if the query fails
+   */
+  private void runPredictUDF(MLProcessContext ctx, String testQuery,
+    DirectQueryRunner queryRunner) throws LensException {
+    try {
+      LOG.info("Running Prediction UDF" + ctx.getMlProcess().getId());
+      queryRunner.runQuery(testQuery, ctx);
+    } catch (LensException e) {
+      // Fix: log the throwable itself — the old code appended e.getMessage() with no
+      // separator and lost the stack trace.
+      LOG.error("Error while running MLProcess. Id: " + ctx.getMlProcess().getId()
+        + ". Unable to run predict UDF", e);
+      throw new LensException(
+        "Error while running MLProcess. Id: " + ctx.getMlProcess().getId() + ". Unable to run predict UDF", e);
+    }
+  }
+
+  /**
+   * In-memory context for a submitted ML process: the process itself, the executor
+   * Future running it, and the handle of the currently executing query (kept so the
+   * process can be cancelled).
+   */
+  public class MLProcessContext {
+    // The wrapped ML process (ModelInstance, Evaluation or Prediction).
+    @Getter
+    @Setter
+    MLProcess mlProcess;
+    // Future of the executor task running this process; used for cancellation.
+    @Getter
+    @Setter
+    Future thread;
+    // Handle of the query currently running on behalf of this process, if any.
+    @Getter
+    @Setter
+    QueryHandle currentQueryHandle;
+
+    public MLProcessContext(MLProcess mlProcess) {
+      this.mlProcess = mlProcess;
+    }
+
+    /**
+     * An ML process is finished if it has status among FAILED, COMPLETED or CANCELLED.
+     *
+     * @return true when the process is in a terminal state
+     */
+    boolean isFinished() {
+      Status status = mlProcess.getStatus();
+      if (status == Status.FAILED || status == Status.COMPLETED || status == Status.CANCELLED) {
+        return true;
+      }
+      return false;
+    }
+  }
+
+  /**
+   * MLProcessSubmitter thread. Takes contexts off the submitted queue and starts the
+   * matching executor (ModelExecutor / EvaluationExecutor / PredictionExecutor) on
+   * the shared pool until the service is stopped.
+   */
+  class MLProcessSubmitter implements Runnable {
+    @Override
+    public void run() {
+
+      LOG.info("Started Submitter Thread.");
+      try {
+        while (!stopped) {
+          // Blocks until a process is submitted.
+          MLProcessContext ctx = submittedQueue.take();
+          synchronized (ctx) {
+            // Only accept the process with SUBMITTED status. they might be cancelled.
+            // NOTE(review): no status check actually happens here; EvaluationExecutor
+            // re-checks SUBMITTED itself — confirm the other executors do too.
+            MLProcess mlProcess = ctx.getMlProcess();
+            Runnable creatorThread = null;
+            if (mlProcess instanceof ModelInstance) {
+              creatorThread = new ModelExecutor(ctx.getMlProcess().getId());
+            } else if (mlProcess instanceof Evaluation) {
+              creatorThread = new EvaluationExecutor(ctx.getMlProcess().getId());
+            } else if (mlProcess instanceof Prediction) {
+              creatorThread = new PredictionExecutor(ctx.getMlProcess().getId());
+            }
+            if (creatorThread != null) {
+              // Record the Future so the purger/canceller can interrupt the task.
+              Future future = executorPool.submit(creatorThread);
+              ctx.setThread(future);
+            }
+
+          }
+        }
+      } catch (InterruptedException ex) {
+        LOG.error("Submitter has been interrupted, exiting" + ex.getMessage());
+        return;
+      } catch (Exception e) {
+        LOG.error("Error in submitter", e);
+      }
+      LOG.info("Submitter exited");
+    }
+  }
+
+ /**
+ * Worker Thread for running an Evaluation process. It generates the target Hive Query which uses the predict udf
+ * for predicting. Makes sure the outputTable is present, UDF is registered for current session. Finally runs the
+ * generated hive query through Lens Server. On successful completion it sets the status of MLProcess to COMPLETED
+ * otherwise FAILED.
+ */
+ private class EvaluationExecutor implements Runnable {
+ String evaluationId;
+
+ public EvaluationExecutor(String evaluationId) {
+ this.evaluationId = evaluationId;
+ }
+
+ @Override
+ public void run() {
+ MLProcessContext ctx = null;
+ Status finalProcessStatus = Status.FAILED;
+ Evaluation evaluation = null;
+ try {
+ ctx = allProcesses.get(evaluationId);
+ if (ctx != null) {
+ if (ctx.getMlProcess().getStatus() != Status.SUBMITTED) {
+ LOG.info("Process with status other than SUBMITTED submitted");
+ return;
+ }
+ }
+ evaluation = (Evaluation) ctx.getMlProcess();
+ updateProcess(evaluation, Status.RUNNING);
+
+ DataSet inputDataSet = metaStoreClient.getDataSet(evaluation.getInputDataSetName());
+ LensSessionHandle sessionHandle = evaluation.getLensSessionHandle();
+
+ ModelInstance modelInstance = metaStoreClient.getModelInstance(evaluation.getModelInstanceId());
+ Model model = metaStoreClient.getModel(modelInstance.getModelId());
+
+ final String testResultColumn = "prediction_result";
+ String outputTableName =
+ (MLConfConstants.EVALUATION_OUTPUT_TABLE_PREFIX + evaluation.getId()).replace("-", "_");
+ TableTestingSpec spec = TableTestingSpec.newBuilder().hiveConf(conf)
+ .database(inputDataSet.getDbName() == null ? "default" : inputDataSet.getDbName())
+ .inputTable(evaluation.getInputDataSetName())
+ .featureColumns(model.getFeatureSpec())
+ .outputColumn(testResultColumn).lableColumn(model.getLabelSpec())
+ .algorithm(model.getAlgoSpec().getAlgo()).modelID(model.getName())
+ .modelInstanceID(modelInstance.getId())
+ .outputTable(outputTableName).testID(evaluationId).build();
+
+ String evaluationQuery = spec.getTestQuery();
+ if (evaluationQuery == null) {
+ throw new LensException("Error while creating query.");
+ }
+
+ DirectQueryRunner queryRunner = new DirectQueryRunner(sessionHandle);
+
+ if (ctx.getMlProcess().getStatus() != Status.CANCELLED) {
+ createOutputTable(spec, ctx, evaluation, queryRunner);
+ }
+
+ if (ctx.getMlProcess().getStatus() != Status.CANCELLED) {
+ registerPredictUdf(sessionHandle, queryRunner, ctx);
+ }
+
+ if (ctx.getMlProcess().getStatus() != Status.CANCELLED) {
+ runPredictUDF(ctx, evaluationQuery, queryRunner);
+ }
+
+ finalProcessStatus = Status.COMPLETED;
+
+ } catch (Exception e) {
+ LOG.error("Error while Running Evaluation, Id:" + evaluationId);
+ finalProcessStatus = Status.FAILED;
+ } finally {
+ try {
+ if (ctx.getMlProcess().getStatus() != Status.CANCELLED) {
+ updateProcess(evaluation, finalProcessStatus);
+ }
+ } catch (Exception e) {
+ LOG.error("Error While updating Evaluation state, Id: " + evaluationId);
+ }
+ }
+
+ LOG.info("exiting evaluation creator!");
+
+ }
+
+ private void createOutputTable(TableTestingSpec spec, MLProcessContext ctx, Evaluation evaluation,
+ DirectQueryRunner queryRunner)
+ throws LensException {
+ ModelInstance modelInstance = metaStoreClient.getModelInstance(evaluation.getModelInstanceId());
+ Model model = metaStoreClient.getModel(modelInstance.getModelId());
+ String createOutputTableQuery = spec.getCreateOutputTableQuery();
+ LOG.error("Error while creating output table: for evaluation id: " + ctx.getMlProcess().getId() + " Create "
+ + "table query: " + spec.getCreateOutputTableQuery());
+ queryRunner.runQuery(createOutputTableQuery, ctx);
+ }
+ }
+
+ /**
+ * Worker Thread for Creation of Model Instances. It launches the job for training a Model against the inputTable. On
+ * successful completion it sets the status of MLProcess to COMPLETED otherwise FAILED.
+ */
+ private class ModelExecutor implements Runnable {
+ String id;
+
+ public ModelExecutor(String id) {
+ this.id = id;
+ }
+
+ @Override
+ public void run() {
+ MLProcessContext ctx;
+ try {
+ ctx = allProcesses.get(id);
+ } catch (NullPointerException ex) {
+ LOG.error("");
+ return;
+ }
+ ModelInstance modelInstance = null;
+ Status finalStatus = Status.FAILED;
+ try {
+ modelInstance = (ModelInstance) ctx.getMlProcess();
+ Model model = metaStoreClient.getModel(modelInstance.getModelId());
+ DataSet dataSet = metaStoreClient.getDataSet(modelInstance.getDataSetName());
+
+ Algorithm algorithm = getAlgoForName(model.getAlgoSpec().getAlgo());
+ TrainedModel trainedModel;
+ trainedModel = algorithm.train(model, dataSet);
+
+
+ Path modelLocation = MLUtils.persistModel(trainedModel, model, modelInstance.getId());
+ LOG.info("ModelInstance saved: " + modelInstance.getId() + ", algo: " + algorithm + ", path: "
+ + modelLocation);
+
+ //setProcessStatusAndFinishTime(Status.COMPLETED, ctx);
+ finalStatus = Status.COMPLETED;
+
+ } catch (IOException ex) {
+ LOG.error("Error saving modelInstance ID: " + ctx.getMlProcess().getId());
+ } catch (LensException ex) {
+ LOG.error("Error training modelInstance ID: " + ctx.getMlProcess().getId());
+ } catch (Exception e) {
+ LOG.error(e.getMessage());
+ } finally {
+ try {
+ updateProcess(ctx.getMlProcess(), finalStatus);
+ } catch (Exception e) {
+ LOG.error("Error occurred while updating final status for modelInstance: " + modelInstance.getId());
+ }
+ }
+ }
+ }
+
+ /**
+ * Worker Thread for Batch Prediction process. It generates the target Hive Query which uses the predict udf
+ * for prediction. Makes sure the outputTable is present, UDF is registered for current session. Finally runs the
+ * generated hive query through Lens Server. On successful completion it sets the status of MLProcess to COMPLETED
+ * otherwise FAILED.
+ */
+ private class PredictionExecutor implements Runnable {
+ String predictionId;
+
+ public PredictionExecutor(String predictionId) {
+ this.predictionId = predictionId;
+ }
+
+ @Override
+ public void run() {
+ MLProcessContext ctx = null;
+
+ try {
+ ctx = allProcesses.get(predictionId);
+
+ String database = null;
+ if (SessionState.get() != null) {
+ database = SessionState.get().getCurrentDatabase();
+ }
+ Prediction prediction = (Prediction) ctx.getMlProcess();
+
+ LensSessionHandle sessionHandle = prediction.getLensSessionHandle();
+
+ ModelInstance modelInstance = metaStoreClient.getModelInstance(prediction.getModelInstanceId());
+ Model model = metaStoreClient.getModel(modelInstance.getModelId());
+
+ final String testResultColumn = "prediction_result";
+
+ BatchPredictSpec spec = BatchPredictSpec.newBuilder().hiveConf(conf)
+ .database(database == null ? "default" : database).inputTable(prediction.getInputDataSet())
+ .featureColumns(model.getFeatureSpec())
+ .outputColumn(testResultColumn).algorithm(model.getAlgoSpec().getAlgo())
+ .modelID(model.getName()).modelInstanceID(modelInstance.getId())
+ .outputTable(prediction.getOutputDataSet()).testID(predictionId).build();
+
+ String testQuery = spec.getTestQuery();
+ if (testQuery == null) {
+ setProcessStatusAndFinishTime(Status.FAILED, ctx);
+ LOG.error("Error while running prediction. Id: " + ctx.getMlProcess().getId());
+ return;
+ } else {
+ DirectQueryRunner queryRunner = new DirectQueryRunner(sessionHandle);
+ if (!spec.isOutputTableExists()) {
+ try {
+ String createOutputTableQuery = spec.getCreateOutputTableQuery();
+ LOG.info("Output table '" + prediction.getOutputDataSet()
+ + "' does not exist for predicting algorithm = " + model.getAlgoSpec().getAlgo()
+ + " modelId="
+ + model.getName() + " modelInstanceId= " + modelInstance.getId()
+ + ", Creating table using query: " + spec.getCreateOutputTableQuery());
+ queryRunner.runQuery(createOutputTableQuery, ctx);
+ } catch (LensException e) {
+ LOG.error(
+ "Error while running prediction. Id: " + ctx.getMlProcess().getId() + "Unable to create output table"
+ + e.getMessage());
+ throw new LensException(
+ "Error while running prediction. Id: " + ctx.getMlProcess().getId() + "Unable to create output table"
+ , e);
+ }
+
+ registerPredictUdf(sessionHandle, queryRunner, ctx);
+
+ runPredictUDF(ctx, testQuery, queryRunner);
+
+ setProcessStatusAndFinishTime(Status.COMPLETED, ctx);
+ }
+ }
+ } catch (Exception e) {
+ if (ctx == null) {
+ LOG.error("Error while running prediction. Id: " + predictionId + ", " + e.getMessage());
+ return;
+ }
+ setProcessStatusAndFinishTime(Status.FAILED, ctx);
+ LOG.error("Error while running prediction. Id: " + predictionId + ", " + e.getMessage());
+ } finally {
+ updateMetastore((Prediction) ctx.getMlProcess());
+ }
+
+ LOG.info("exiting prediction creator!");
+ }
+
+ void updateMetastore(Prediction prediction) {
+ if (prediction != null) {
+ try {
+ metaStoreClient.updatePrediction(prediction);
+ } catch (Exception e) {
+ LOG.error("Error updating prediction status in metastore: Id: " + prediction.getId());
+ }
+ }
+ }
+ }
+
+ /**
+ * DirectQueryRunner class which runs query against the same lens server where ML Service is running.
+ */
+ private class DirectQueryRunner extends LensQueryRunner {
+
+ /**
+ * Instantiates a new direct query runner.
+ *
+ * @param sessionHandle the session handle
+ */
+ public DirectQueryRunner(LensSessionHandle sessionHandle) {
+ super(sessionHandle);
+ }
+
+ /**
+ * @param testQuery
+ * @return
+ * @throws LensException
+ */
+ @Override
+ public QueryHandle runQuery(String testQuery, MLProcessContext mlProcessContext) throws LensException {
+ // Run the query in query executions service
+ QueryExecutionService queryService;
+ try {
+ queryService = (QueryExecutionService) MLUtils.getServiceProvider().getService("query");
+
+ } catch (Exception e) {
+ throw new LensException("Error while getting Service Provider");
+ }
+
+ LensConf queryConf = new LensConf();
+ queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, false + "");
+ queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, false + "");
+ QueryHandle testQueryHandle = queryService.executeAsync(sessionHandle, testQuery, queryConf, queryName);
+ mlProcessContext.setCurrentQueryHandle(testQueryHandle);
+ // Wait for test query to complete
+ LensQuery query = queryService.getQuery(sessionHandle, testQueryHandle);
+
+ LOG.info("Submitted query " + testQueryHandle.getHandleId());
+ while (!query.getStatus().finished()) {
+ try {
+ Thread.sleep(2000);
+ } catch (InterruptedException e) {
+ throw new LensException(e);
+ }
+
+ query = queryService.getQuery(sessionHandle, testQueryHandle);
+ }
+
+ if (query.getStatus().getStatus() != QueryStatus.Status.SUCCESSFUL) {
+ throw new LensException("Failed to run test query: " + testQueryHandle.getHandleId() + " reason= "
+ + query.getStatus().getErrorMessage());
+ }
+
+ return testQueryHandle;
+ }
+
+ @Override
+ public QueryHandle runQuery(String query) throws LensException {
+ return runQuery(query, new MLProcessContext(null));
+ }
+ }
+
+  /**
+   * Process Purger Thread. Removes processes from in memory cache after MLConfConstants.ML_PROCESS_CACHE_LIFE time.
+   * Also kills a process if it exceeds MLConfConstants.ML_PROCESS_MAX_LIFE.
+   *
+   * NOTE(review): as written, the cache-life branch only re-persists the process; no
+   * allProcesses.remove(key) is visible here despite the class comment — confirm
+   * removal happens elsewhere or is missing.
+   */
+  private class ProcessPurger implements Runnable {
+    @Override
+    public void run() {
+      Set<String> keys = allProcesses.keySet();
+      for (String key : keys) {
+        MLProcessContext ctx = allProcesses.get(key);
+        MLProcess mlProcess = ctx.getMlProcess();
+        long maxQueryLife = conf.getLong(MLConfConstants.ML_PROCESS_MAX_LIFE, MLConfConstants
+          .DEFAULT_ML_PROCESS_MAX_LIFE);
+
+        if (ctx.isFinished()) {
+          long cacheLife = conf.getLong(MLConfConstants.ML_PROCESS_CACHE_LIFE, MLConfConstants
+            .DEFAULT_ML_PROCESS_CACHE_LIFE);
+          // Persist once the finished process has outlived its cache window.
+          if ((new Date().getTime() - mlProcess.getFinishTime().getTime()) > cacheLife) {
+            try {
+              updateMLProcess(mlProcess);
+            } catch (Exception e) {
+              LOG.error("Error while persisting MLProcess to meta store, Id: " + mlProcess.getId());
+            }
+          }
+        } else if ((new Date().getTime() - mlProcess.getFinishTime().getTime()) > maxQueryLife) {
+          // NOTE(review): this branch runs for UNFINISHED processes yet reads
+          // getFinishTime(); if finish time is unset until completion this may NPE
+          // or never trigger — presumably a start time was intended. Confirm.
+          // Kill the thread
+          try {
+            Future thread = ctx.getThread();
+            if (!thread.isDone()) {
+              thread.cancel(true);
+            }
+            mlProcess.setFinishTime(new Date());
+            mlProcess.setStatus(Status.FAILED);
+            updateMLProcess(mlProcess);
+          } catch (LensException e) {
+            LOG.error("Error while persisting MLProcess to meta store, Id: " + mlProcess.getId());
+          } catch (Exception e) {
+            LOG.error("Error while cancelling MLProcess, Id: " + mlProcess.getId());
+          }
+        }
+      }
+    }
+
+    /** Persists the process to the meta store, dispatching on its concrete type. */
+    void updateMLProcess(MLProcess mlProcess) throws LensException {
+      if (mlProcess instanceof Prediction) {
+        metaStoreClient.updatePrediction((Prediction) mlProcess);
+      } else if (mlProcess instanceof Evaluation) {
+        metaStoreClient.updateEvaluation((Evaluation) mlProcess);
+      } else if (mlProcess instanceof ModelInstance) {
+        metaStoreClient.updateModelInstance((ModelInstance) mlProcess);
+      }
+    }
+  }
+}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java
deleted file mode 100644
index 91840e9..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLRunner.java
+++ /dev/null
@@ -1,167 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.impl;
-
-import java.io.File;
-import java.io.FileInputStream;
-import java.util.*;
-
-import org.apache.lens.client.LensClient;
-import org.apache.lens.client.LensClientConfig;
-import org.apache.lens.client.LensMLClient;
-
-import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.metastore.TableType;
-import org.apache.hadoop.hive.metastore.api.FieldSchema;
-import org.apache.hadoop.hive.ql.metadata.Hive;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.metadata.Table;
-import org.apache.hadoop.hive.ql.plan.AddPartitionDesc;
-import org.apache.hadoop.hive.serde.serdeConstants;
-import org.apache.hadoop.mapred.TextInputFormat;
-
-import lombok.extern.slf4j.Slf4j;
-
-@Slf4j
-public class MLRunner {
-
- private LensMLClient mlClient;
- private String algoName;
- private String database;
- private String trainTable;
- private String trainFile;
- private String testTable;
- private String testFile;
- private String outputTable;
- private String[] features;
- private String labelColumn;
- private HiveConf conf;
-
- public void init(LensMLClient mlClient, String confDir) throws Exception {
- File dir = new File(confDir);
- File propFile = new File(dir, "ml.properties");
- Properties props = new Properties();
- props.load(new FileInputStream(propFile));
- String feat = props.getProperty("features");
- String trainFile = confDir + File.separator + "train.data";
- String testFile = confDir + File.separator + "test.data";
- init(mlClient, props.getProperty("algo"), props.getProperty("database"),
- props.getProperty("traintable"), trainFile,
- props.getProperty("testtable"), testFile,
- props.getProperty("outputtable"), feat.split(","),
- props.getProperty("labelcolumn"));
- }
-
- public void init(LensMLClient mlClient, String algoName,
- String database, String trainTable, String trainFile,
- String testTable, String testFile, String outputTable, String[] features,
- String labelColumn) {
- this.mlClient = mlClient;
- this.algoName = algoName;
- this.database = database;
- this.trainTable = trainTable;
- this.trainFile = trainFile;
- this.testTable = testTable;
- this.testFile = testFile;
- this.outputTable = outputTable;
- this.features = features;
- this.labelColumn = labelColumn;
- //hive metastore settings are loaded via lens-site.xml, so loading LensClientConfig
- //is required
- this.conf = new HiveConf(new LensClientConfig(), MLRunner.class);
- }
-
- public MLTask train() throws Exception {
- log.info("Starting train & eval");
-
- createTable(trainTable, trainFile);
- createTable(testTable, testFile);
- MLTask.Builder taskBuilder = new MLTask.Builder();
- taskBuilder.algorithm(algoName).hiveConf(conf).labelColumn(labelColumn).outputTable(outputTable)
- .client(mlClient).trainingTable(trainTable).testTable(testTable);
-
- // Add features
- for (String feature : features) {
- taskBuilder.addFeatureColumn(feature);
- }
- MLTask task = taskBuilder.build();
- log.info("Created task {}", task.toString());
- task.run();
- return task;
- }
-
- public void createTable(String tableName, String dataFile) throws HiveException {
-
- File filedataFile = new File(dataFile);
- Path dataFilePath = new Path(filedataFile.toURI());
- Path partDir = dataFilePath.getParent();
-
- // Create table
- List<FieldSchema> columns = new ArrayList<FieldSchema>();
-
- // Label is optional. Not used for unsupervised models.
- // If present, label will be the first column, followed by features
- if (labelColumn != null) {
- columns.add(new FieldSchema(labelColumn, "double", "Labelled Column"));
- }
-
- for (String feature : features) {
- columns.add(new FieldSchema(feature, "double", "Feature " + feature));
- }
-
- Table tbl = Hive.get(conf).newTable(database + "." + tableName);
- tbl.setTableType(TableType.MANAGED_TABLE);
- tbl.getTTable().getSd().setCols(columns);
- // tbl.getTTable().getParameters().putAll(new HashMap<String, String>());
- tbl.setInputFormatClass(TextInputFormat.class);
- tbl.setSerdeParam(serdeConstants.LINE_DELIM, "\n");
- tbl.setSerdeParam(serdeConstants.FIELD_DELIM, " ");
-
- List<FieldSchema> partCols = new ArrayList<FieldSchema>(1);
- partCols.add(new FieldSchema("dummy_partition_col", "string", ""));
- tbl.setPartCols(partCols);
-
- Hive.get(conf).dropTable(database, tableName, false, true);
- Hive.get(conf).createTable(tbl, true);
- log.info("Created table {}", tableName);
-
- // Add partition for the data file
- AddPartitionDesc partitionDesc = new AddPartitionDesc(database, tableName,
- false);
- Map<String, String> partSpec = new HashMap<String, String>();
- partSpec.put("dummy_partition_col", "dummy_val");
- partitionDesc.addPartition(partSpec, partDir.toUri().toString());
- Hive.get(conf).createPartitions(partitionDesc);
- log.info("{}: Added partition {}", tableName, partDir.toUri().toString());
- }
-
- public static void main(String[] args) throws Exception {
- if (args.length < 1) {
- System.out.println("Usage: " + MLRunner.class.getName() + " <ml-conf-dir>");
- System.exit(-1);
- }
- String confDir = args[0];
- LensMLClient client = new LensMLClient(new LensClient());
- MLRunner runner = new MLRunner();
- runner.init(client, confDir);
- runner.train();
- System.out.println("Created the Model successfully. Output Table: " + runner.outputTable);
- }
-}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java
deleted file mode 100644
index a3695ba..0000000
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLTask.java
+++ /dev/null
@@ -1,284 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.lens.ml.impl;
-
-import java.util.*;
-
-import org.apache.lens.client.LensMLClient;
-import org.apache.lens.ml.api.LensML;
-import org.apache.lens.ml.api.MLTestReport;
-
-import org.apache.hadoop.hive.conf.HiveConf;
-
-import lombok.Getter;
-import lombok.ToString;
-import lombok.extern.slf4j.Slf4j;
-
-/**
- * Run a complete cycle of train and test (evaluation) for an ML algorithm
- */
-@ToString
-@Slf4j
-public class MLTask implements Runnable {
-
- public enum State {
- RUNNING, SUCCESSFUL, FAILED
- }
-
- @Getter
- private State taskState;
-
- /**
- * Name of the algo/algorithm.
- */
- @Getter
- private String algorithm;
-
- /**
- * Name of the table containing training data.
- */
- @Getter
- private String trainingTable;
-
- /**
- * Name of the table containing test data. Optional, if not provided trainingTable itself is
- * used for testing
- */
- @Getter
- private String testTable;
-
- /**
- * Training table partition spec
- */
- @Getter
- private String partitionSpec;
-
- /**
- * Name of the column which is a label for supervised algorithms.
- */
- @Getter
- private String labelColumn;
-
- /**
- * Names of columns which are features in the training data.
- */
- @Getter
- private List<String> featureColumns;
-
- /**
- * Configuration for the example.
- */
- @Getter
- private HiveConf configuration;
-
- private LensML ml;
- private String taskID;
-
- /**
- * ml client
- */
- @Getter
- private LensMLClient mlClient;
-
- /**
- * Output table name
- */
- @Getter
- private String outputTable;
-
- /**
- * Extra params passed to the training algorithm
- */
- @Getter
- private Map<String, String> extraParams;
-
- @Getter
- private String modelID;
-
- @Getter
- private String reportID;
-
- /**
- * Use ExampleTask.Builder to create an instance
- */
- private MLTask() {
- // Use builder to construct the example
- extraParams = new HashMap<String, String>();
- taskID = UUID.randomUUID().toString();
- }
-
- /**
- * Builder to create an example task
- */
- public static class Builder {
- private MLTask task;
-
- public Builder() {
- task = new MLTask();
- }
-
- public Builder trainingTable(String trainingTable) {
- task.trainingTable = trainingTable;
- return this;
- }
-
- public Builder testTable(String testTable) {
- task.testTable = testTable;
- return this;
- }
-
- public Builder algorithm(String algorithm) {
- task.algorithm = algorithm;
- return this;
- }
-
- public Builder labelColumn(String labelColumn) {
- task.labelColumn = labelColumn;
- return this;
- }
-
- public Builder client(LensMLClient client) {
- task.mlClient = client;
- return this;
- }
-
- public Builder addFeatureColumn(String featureColumn) {
- if (task.featureColumns == null) {
- task.featureColumns = new ArrayList<String>();
- }
- task.featureColumns.add(featureColumn);
- return this;
- }
-
- public Builder hiveConf(HiveConf hiveConf) {
- task.configuration = hiveConf;
- return this;
- }
-
-
-
- public Builder extraParam(String param, String value) {
- task.extraParams.put(param, value);
- return this;
- }
-
- public Builder partitionSpec(String partitionSpec) {
- task.partitionSpec = partitionSpec;
- return this;
- }
-
- public Builder outputTable(String outputTable) {
- task.outputTable = outputTable;
- return this;
- }
-
- public MLTask build() {
- MLTask builtTask = task;
- task = null;
- return builtTask;
- }
-
- }
-
- @Override
- public void run() {
- taskState = State.RUNNING;
- log.info("Starting {}", taskID);
- try {
- runTask();
- taskState = State.SUCCESSFUL;
- log.info("Complete {}", taskID);
- } catch (Exception e) {
- taskState = State.FAILED;
- log.info("Error running task {}", taskID, e);
- }
- }
-
- /**
- * Train an ML model, with specified algorithm and input data. Do model evaluation using the evaluation data and print
- * evaluation result
- *
- * @throws Exception
- */
- private void runTask() throws Exception {
- if (mlClient != null) {
- // Connect to a remote Lens server
- ml = mlClient;
- log.info("Working in client mode. Lens session handle {}", mlClient.getSessionHandle().getPublicId());
- } else {
- // In server mode session handle has to be passed by the user as a request parameter
- ml = MLUtils.getMLService();
- log.info("Working in Lens server");
- }
-
- String[] algoArgs = buildTrainingArgs();
- log.info("Starting task {} algo args: {} ", taskID, Arrays.toString(algoArgs));
-
- modelID = ml.train(trainingTable, algorithm, algoArgs);
- printModelMetadata(taskID, modelID);
-
- log.info("Starting test {}", taskID);
- testTable = (testTable != null) ? testTable : trainingTable;
- MLTestReport testReport = ml.testModel(mlClient.getSessionHandle(), testTable, algorithm, modelID, outputTable);
- reportID = testReport.getReportID();
- printTestReport(taskID, testReport);
- saveTask();
- }
-
- // Save task metadata to DB
- private void saveTask() {
- log.info("Saving task details to DB");
- }
-
- private void printTestReport(String exampleID, MLTestReport testReport) {
- StringBuilder builder = new StringBuilder("Example: ").append(exampleID);
- builder.append("\n\t");
- builder.append("EvaluationReport: ").append(testReport.toString());
- System.out.println(builder.toString());
- }
-
- private String[] buildTrainingArgs() {
- List<String> argList = new ArrayList<String>();
- argList.add("label");
- argList.add(labelColumn);
-
- // Add all the features
- for (String featureCol : featureColumns) {
- argList.add("feature");
- argList.add(featureCol);
- }
-
- // Add extra params
- for (String param : extraParams.keySet()) {
- argList.add(param);
- argList.add(extraParams.get(param));
- }
-
- return argList.toArray(new String[argList.size()]);
- }
-
- // Get the model instance and print its metadat to stdout
- private void printModelMetadata(String exampleID, String modelID) throws Exception {
- StringBuilder builder = new StringBuilder("Example: ").append(exampleID);
- builder.append("\n\t");
- builder.append("Model: ");
- builder.append(ml.getModel(algorithm, modelID).toString());
- System.out.println(builder.toString());
- }
-}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java
index 9c96d9b..69a5756 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/MLUtils.java
@@ -18,22 +18,33 @@
*/
package org.apache.lens.ml.impl;
-import org.apache.lens.ml.algo.api.Algorithm;
-import org.apache.lens.ml.algo.api.MLAlgo;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+
+import org.apache.lens.ml.algo.api.TrainedModel;
+import org.apache.lens.ml.api.MLConfConstants;
+import org.apache.lens.ml.api.Model;
import org.apache.lens.ml.server.MLService;
import org.apache.lens.ml.server.MLServiceImpl;
import org.apache.lens.server.api.LensConfConstants;
import org.apache.lens.server.api.ServiceProvider;
import org.apache.lens.server.api.ServiceProviderFactory;
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.conf.HiveConf;
-public final class MLUtils {
- private MLUtils() {
- }
+import org.datanucleus.store.rdbms.datasource.dbcp.BasicDataSource;
+public final class MLUtils {
private static final HiveConf HIVE_CONF;
+ private static final Log LOG = LogFactory.getLog(MLUtils.class);
+
static {
HIVE_CONF = new HiveConf();
// Add default config so that we know the service provider implementation
@@ -41,12 +52,7 @@
HIVE_CONF.addResource("lens-site.xml");
}
- public static String getAlgoName(Class<? extends MLAlgo> algoClass) {
- Algorithm annotation = algoClass.getAnnotation(Algorithm.class);
- if (annotation != null) {
- return annotation.name();
- }
- throw new IllegalArgumentException("Algo should be decorated with annotation - " + Algorithm.class.getName());
+ private MLUtils() {
}
public static MLServiceImpl getMLService() throws Exception {
@@ -59,4 +65,50 @@
ServiceProviderFactory spf = spfClass.newInstance();
return spf.getServiceProvider();
}
+
+ public static Path persistModel(TrainedModel trainedModel, Model model, String modelInstanceId) throws IOException {
+ Path algoDir = getAlgoDir(model.getAlgoSpec().getAlgo());
+ FileSystem fs = algoDir.getFileSystem(HIVE_CONF);
+
+ if (!fs.exists(algoDir)) {
+ fs.mkdirs(algoDir);
+ }
+
+ Path modelSavePath = new Path(algoDir, new Path(model.getName(), modelInstanceId));
+ ObjectOutputStream outputStream = null;
+
+ try {
+ outputStream = new ObjectOutputStream(fs.create(modelSavePath, false));
+ outputStream.writeObject(trainedModel);
+ outputStream.flush();
+ } catch (IOException io) {
+ LOG.error("Error saving model " + modelInstanceId + " reason: " + io.getMessage(), io);
+ throw io;
+ } finally {
+ IOUtils.closeQuietly(outputStream);
+ }
+ return modelSavePath;
+ }
+
+ public static Path getAlgoDir(String algoName) throws IOException {
+ String modelSaveBaseDir = HIVE_CONF.get(ModelLoader.MODEL_PATH_BASE_DIR, ModelLoader.MODEL_PATH_BASE_DIR_DEFAULT);
+ return new Path(new Path(modelSaveBaseDir), algoName);
+ }
+
+ public static BasicDataSource createMLMetastoreConnectionPool(Configuration conf) {
+ BasicDataSource tmp = new BasicDataSource();
+ tmp.setDriverClassName(conf.get(MLConfConstants.ML_META_STORE_DB_DRIVER_NAME,
+ MLConfConstants.DEFAULT_ML_META_STORE_DB_DRIVER_NAME));
+ tmp.setUrl(conf.get(MLConfConstants.ML_META_STORE_DB_JDBC_URL, MLConfConstants.DEFAULT_ML_META_STORE_DB_JDBC_URL));
+ tmp
+ .setUsername(conf.get(MLConfConstants.ML_META_STORE_DB_JDBC_USER, MLConfConstants.DEFAULT_ML_META_STORE_DB_USER));
+ tmp
+ .setPassword(conf.get(MLConfConstants.ML_META_STORE_DB_JDBC_PASS, MLConfConstants.DEFAULT_ML_META_STORE_DB_PASS));
+ //tmp.setValidationQuery(conf.get(MLConfConstants.ML_META_STORE_DB_VALIDATION_QUERY,
+ // MLConfConstants.DEFAULT_ML_META_STORE_DB_VALIDATION_QUERY));
+ tmp.setInitialSize(conf.getInt(MLConfConstants.ML_META_STORE_DB_SIZE, MLConfConstants
+ .DEFAULT_ML_META_STORE_DB_SIZE));
+ tmp.setDefaultAutoCommit(true);
+ return tmp;
+ }
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java
index 8a69545..fa80a56 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/ModelLoader.java
@@ -18,15 +18,17 @@
*/
package org.apache.lens.ml.impl;
-import java.io.*;
+import java.io.IOException;
+import java.io.ObjectInputStream;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
-import org.apache.lens.ml.algo.api.MLModel;
-import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.ml.algo.api.TrainedModel;
import org.apache.commons.io.IOUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
@@ -34,23 +36,24 @@
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
-import lombok.extern.slf4j.Slf4j;
-/**
- * Load ML models from a FS location.
- */
-@Slf4j
-public final class ModelLoader {
- private ModelLoader() {
- }
-
- /** The Constant MODEL_PATH_BASE_DIR. */
+public class ModelLoader {
+ /**
+ * The Constant MODEL_PATH_BASE_DIR.
+ */
public static final String MODEL_PATH_BASE_DIR = "lens.ml.model.basedir";
- /** The Constant MODEL_PATH_BASE_DIR_DEFAULT. */
+ /**
+ * The Constant MODEL_PATH_BASE_DIR_DEFAULT.
+ */
public static final String MODEL_PATH_BASE_DIR_DEFAULT = "file:///tmp";
-
- /** The Constant TEST_REPORT_BASE_DIR. */
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(ModelLoader.class);
+ /**
+ * The Constant TEST_REPORT_BASE_DIR.
+ */
public static final String TEST_REPORT_BASE_DIR = "lens.ml.test.basedir";
/** The Constant TEST_REPORT_BASE_DIR_DEFAULT. */
@@ -60,13 +63,20 @@
/** The Constant MODEL_CACHE_SIZE. */
public static final long MODEL_CACHE_SIZE = 10;
- /** The Constant MODEL_CACHE_TIMEOUT. */
+ // Model cache settings
+ /**
+ * The Constant MODEL_CACHE_TIMEOUT.
+ */
public static final long MODEL_CACHE_TIMEOUT = 3600000L; // one hour
-
- /** The model cache. */
- private static Cache<Path, MLModel> modelCache = CacheBuilder.newBuilder().maximumSize(MODEL_CACHE_SIZE)
+ /**
+ * The model cache.
+ */
+ private static Cache<Path, TrainedModel> modelCache = CacheBuilder.newBuilder().maximumSize(MODEL_CACHE_SIZE)
.expireAfterAccess(MODEL_CACHE_TIMEOUT, TimeUnit.MILLISECONDS).build();
+ private ModelLoader() {
+ }
+
/**
* Gets the model location.
*
@@ -75,29 +85,28 @@
* @param modelID the model id
* @return the model location
*/
- public static Path getModelLocation(Configuration conf, String algorithm, String modelID) {
+ public static Path getModelLocation(Configuration conf, String algorithm, String modelID, String modelInstanceId) {
String modelDataBaseDir = conf.get(MODEL_PATH_BASE_DIR, MODEL_PATH_BASE_DIR_DEFAULT);
// Model location format - <modelDataBaseDir>/<algorithm>/modelID
- return new Path(new Path(new Path(modelDataBaseDir), algorithm), modelID);
+ return new Path(new Path(new Path(new Path(modelDataBaseDir), algorithm), modelID), modelInstanceId);
}
/**
- * Load model.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return the ML model
- * @throws IOException Signals that an I/O exception has occurred.
+ * @param conf the configuration used to resolve the model base directory
+ * @param algorithm the algorithm name under which the model was persisted
+ * @param modelID the model id
+ * @return the trained model instance, loaded from the file system or served from the model cache
+ * @throws IOException if the model path does not exist or the model cannot be deserialized
*/
- public static MLModel loadModel(Configuration conf, String algorithm, String modelID) throws IOException {
- final Path modelPath = getModelLocation(conf, algorithm, modelID);
- log.info("Loading model for algorithm: {} modelID: {} At path: {}", algorithm, modelID,
- modelPath.toUri().toString());
+ public static TrainedModel loadModel(Configuration conf, String algorithm, final String modelID,
+ String modelInstanceId) throws IOException {
+ final Path modelPath = getModelLocation(conf, algorithm, modelID, modelInstanceId);
+ LOG.info("Loading model for algorithm: " + algorithm + " modelID: " + modelID + " At path: "
+ + modelPath.toUri().toString());
try {
- return modelCache.get(modelPath, new Callable<MLModel>() {
+ return modelCache.get(modelPath, new Callable<TrainedModel>() {
@Override
- public MLModel call() throws Exception {
+ public TrainedModel call() throws Exception {
FileSystem fs = modelPath.getFileSystem(new HiveConf());
if (!fs.exists(modelPath)) {
throw new IOException("Model path not found " + modelPath.toString());
@@ -106,8 +115,8 @@
ObjectInputStream ois = null;
try {
ois = new ObjectInputStream(fs.open(modelPath));
- MLModel model = (MLModel) ois.readObject();
- log.info("Loaded model {} from location {}", model.getId(), modelPath);
+ TrainedModel model = (TrainedModel) ois.readObject();
+ LOG.info("Loaded model " + modelID + " from location " + modelPath);
return model;
} catch (ClassNotFoundException e) {
throw new IOException(e);
@@ -129,86 +138,6 @@
}
/**
- * Gets the test report path.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param report the report
- * @return the test report path
- */
- public static Path getTestReportPath(Configuration conf, String algorithm, String report) {
- String testReportDir = conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT);
- return new Path(new Path(testReportDir, algorithm), report);
- }
-
- /**
- * Save test report.
- *
- * @param conf the conf
- * @param report the report
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static void saveTestReport(Configuration conf, MLTestReport report) throws IOException {
- Path reportDir = new Path(conf.get(TEST_REPORT_BASE_DIR, TEST_REPORT_BASE_DIR_DEFAULT));
- FileSystem fs = reportDir.getFileSystem(conf);
-
- if (!fs.exists(reportDir)) {
- log.info("Creating test report dir {}", reportDir.toUri().toString());
- fs.mkdirs(reportDir);
- }
-
- Path algoDir = new Path(reportDir, report.getAlgorithm());
-
- if (!fs.exists(algoDir)) {
- log.info("Creating algorithm report dir {}", algoDir.toUri().toString());
- fs.mkdirs(algoDir);
- }
-
- ObjectOutputStream reportOutputStream = null;
- Path reportSaveLocation;
- try {
- reportSaveLocation = new Path(algoDir, report.getReportID());
- reportOutputStream = new ObjectOutputStream(fs.create(reportSaveLocation));
- reportOutputStream.writeObject(report);
- reportOutputStream.flush();
- } catch (IOException ioexc) {
- log.error("Error saving test report {}", report.getReportID(), ioexc);
- throw ioexc;
- } finally {
- IOUtils.closeQuietly(reportOutputStream);
- }
- log.info("Saved report {} at location {}", report.getReportID(), reportSaveLocation.toUri());
- }
-
- /**
- * Load report.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the ML test report
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static MLTestReport loadReport(Configuration conf, String algorithm, String reportID) throws IOException {
- Path reportLocation = getTestReportPath(conf, algorithm, reportID);
- FileSystem fs = reportLocation.getFileSystem(conf);
- ObjectInputStream reportStream = null;
- MLTestReport report = null;
-
- try {
- reportStream = new ObjectInputStream(fs.open(reportLocation));
- report = (MLTestReport) reportStream.readObject();
- } catch (IOException ioex) {
- log.error("Error reading report {}", reportLocation, ioex);
- } catch (ClassNotFoundException e) {
- throw new IOException(e);
- } finally {
- IOUtils.closeQuietly(reportStream);
- }
- return report;
- }
-
- /**
* Delete model.
*
* @param conf the conf
@@ -216,22 +145,10 @@
* @param modelID the model id
* @throws IOException Signals that an I/O exception has occurred.
*/
- public static void deleteModel(HiveConf conf, String algorithm, String modelID) throws IOException {
- Path modelLocation = getModelLocation(conf, algorithm, modelID);
+ public static void deleteModel(HiveConf conf, String algorithm, String modelID, String modelInstanceId)
+ throws IOException {
+ Path modelLocation = getModelLocation(conf, algorithm, modelID, modelInstanceId);
FileSystem fs = modelLocation.getFileSystem(conf);
fs.delete(modelLocation, false);
}
-
- /**
- * Delete test report.
- *
- * @param conf the conf
- * @param algorithm the algorithm
- * @param reportID the report id
- * @throws IOException Signals that an I/O exception has occurred.
- */
- public static void deleteTestReport(HiveConf conf, String algorithm, String reportID) throws IOException {
- Path reportPath = getTestReportPath(conf, algorithm, reportID);
- reportPath.getFileSystem(conf).delete(reportPath, false);
- }
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java
index 470c977..36ccb68 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/impl/TableTestingSpec.java
@@ -22,7 +22,11 @@
import java.util.HashMap;
import java.util.List;
+import org.apache.lens.ml.api.Feature;
+
import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
import org.apache.hadoop.hive.ql.metadata.Hive;
@@ -30,45 +34,71 @@
import org.apache.hadoop.hive.ql.metadata.Table;
import lombok.Getter;
-import lombok.extern.slf4j.Slf4j;
/**
* Table specification for running test on a table.
*/
-@Slf4j
public class TableTestingSpec {
- /** The db. */
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(TableTestingSpec.class);
+
+ /**
+ * The db.
+ */
private String db;
- /** The table containing input data. */
+ /**
+ * The table containing input data.
+ */
private String inputTable;
// TODO use partition condition
- /** The partition filter. */
+ /**
+ * The partition filter.
+ */
private String partitionFilter;
- /** The feature columns. */
- private List<String> featureColumns;
+ /**
+ * The feature columns.
+ */
+ private List<Feature> featureColumns;
- /** The label column. */
- private String labelColumn;
+ /**
+ * The label column.
+ */
+ private Feature labelColumn;
- /** The output column. */
+ /**
+ * The output column.
+ */
private String outputColumn;
- /** The output table. */
+ /**
+ * The output table.
+ */
private String outputTable;
- /** The conf. */
+ /**
+ * The conf.
+ */
private transient HiveConf conf;
- /** The algorithm. */
+ /**
+ * The algorithm.
+ */
private String algorithm;
- /** The model id. */
+ /**
+ * The model id.
+ */
private String modelID;
+ /** The model instance id. */
+ private String modelInstanceId;
+
@Getter
private boolean outputTableExists;
@@ -78,11 +108,128 @@
private HashMap<String, FieldSchema> columnNameToFieldSchema;
/**
+ * New builder.
+ *
+ * @return the table testing spec builder
+ */
+ public static TableTestingSpecBuilder newBuilder() {
+ return new TableTestingSpecBuilder();
+ }
+
+ /**
+ * Validate.
+ *
+ * @return true, if successful
+ */
+ public boolean validate() {
+ List<FieldSchema> columns;
+ try {
+ Hive metastoreClient = Hive.get(conf);
+ Table tbl = (db == null) ? metastoreClient.getTable(inputTable) : metastoreClient.getTable(db, inputTable);
+ columns = tbl.getAllCols();
+ columnNameToFieldSchema = new HashMap<String, FieldSchema>();
+
+ for (FieldSchema fieldSchema : columns) {
+ columnNameToFieldSchema.put(fieldSchema.getName(), fieldSchema);
+ }
+
+ // Check if output table exists
+ Table outTbl = metastoreClient.getTable(db == null ? "default" : db, outputTable, false);
+ outputTableExists = (outTbl != null);
+ } catch (HiveException exc) {
+ LOG.error("Error getting table info " + toString(), exc);
+ return false;
+ }
+
+ // Check if labeled column and feature columns are contained in the table
+ List<String> testTableColumns = new ArrayList<String>(columns.size());
+ for (FieldSchema column : columns) {
+ testTableColumns.add(column.getName());
+ }
+
+ List<String> inputColumnNames = new ArrayList<String>();
+ for (Feature feature : featureColumns) {
+ inputColumnNames.add(feature.getDataColumn());
+ }
+
+ if (!testTableColumns.containsAll(inputColumnNames)) {
+ LOG.info("Invalid feature columns: " + inputColumnNames + ". Actual columns in table:" + testTableColumns);
+ return false;
+ }
+
+ if (!testTableColumns.contains(labelColumn.getDataColumn())) {
+ LOG.info(
+ "Invalid label column: " + labelColumn.getDataColumn() + ". Actual columns in table:" + testTableColumns);
+ return false;
+ }
+
+ if (StringUtils.isBlank(outputColumn)) {
+ LOG.info("Output column is required");
+ return false;
+ }
+
+ if (StringUtils.isBlank(outputTable)) {
+ LOG.info("Output table is required");
+ return false;
+ }
+ return true;
+ }
+
+ public String getTestQuery() {
+ if (!validate()) {
+ return null;
+ }
+
+ // We always insert a dynamic partition
+ StringBuilder q = new StringBuilder("INSERT OVERWRITE TABLE " + outputTable + " PARTITION (part_testid='" + testID
+ + "') SELECT ");
+ List<String> featureNameList = new ArrayList<String>();
+ List<String> featureMapBuilder = new ArrayList<String>();
+ for (Feature feature : featureColumns) {
+ featureNameList.add(feature.getDataColumn());
+ featureMapBuilder.add("'" + feature.getDataColumn() + "'");
+ featureMapBuilder.add(feature.getDataColumn());
+ }
+ String featureCols = StringUtils.join(featureNameList, ",");
+ String featureMapString = StringUtils.join(featureMapBuilder, ",");
+ q.append(featureCols).append(",").append(labelColumn.getDataColumn()).append(", ").append("predict(").append("'")
+ .append(algorithm)
+ .append("', ").append("'").append(modelID).append("', ").append("'").append(modelInstanceId).append("', ")
+ .append(featureMapString).append(") ").append(outputColumn)
+ .append(" FROM ").append(inputTable);
+
+ return q.toString();
+ }
+
+ public String getCreateOutputTableQuery() {
+ StringBuilder createTableQuery = new StringBuilder("CREATE TABLE IF NOT EXISTS ").append(outputTable).append("(");
+ // Output table contains feature columns, label column, output column
+ List<String> outputTableColumns = new ArrayList<String>();
+ for (Feature featureCol : featureColumns) {
+ outputTableColumns.add(featureCol.getDataColumn() + " "
+ + columnNameToFieldSchema.get(featureCol.getDataColumn()).getType());
+ }
+
+ outputTableColumns.add(labelColumn.getDataColumn() + " "
+ + columnNameToFieldSchema.get(labelColumn.getDataColumn()).getType());
+ outputTableColumns.add(outputColumn + " string");
+
+ createTableQuery.append(StringUtils.join(outputTableColumns, ", "));
+
+ // Append partition column
+ createTableQuery.append(") PARTITIONED BY (part_testid string)");
+
+ return createTableQuery.toString();
+ }
+
+ /**
* The Class TableTestingSpecBuilder.
*/
public static class TableTestingSpecBuilder {
- /** The spec. */
+ /**
+ * The spec.
+ */
private final TableTestingSpec spec;
/**
@@ -131,7 +278,7 @@
* @param featureColumns the feature columns
* @return the table testing spec builder
*/
- public TableTestingSpecBuilder featureColumns(List<String> featureColumns) {
+ public TableTestingSpecBuilder featureColumns(List<Feature> featureColumns) {
spec.featureColumns = featureColumns;
return this;
}
@@ -142,7 +289,7 @@
* @param labelColumn the label column
* @return the table testing spec builder
*/
- public TableTestingSpecBuilder lableColumn(String labelColumn) {
+ public TableTestingSpecBuilder lableColumn(Feature labelColumn) {
spec.labelColumn = labelColumn;
return this;
}
@@ -202,6 +349,11 @@
return this;
}
+ public TableTestingSpecBuilder modelInstanceID(String modelInstanceId) {
+ spec.modelInstanceId = modelInstanceId;
+ return this;
+ }
+
/**
* Builds the.
*
@@ -222,101 +374,4 @@
return this;
}
}
-
- /**
- * New builder.
- *
- * @return the table testing spec builder
- */
- public static TableTestingSpecBuilder newBuilder() {
- return new TableTestingSpecBuilder();
- }
-
- /**
- * Validate.
- *
- * @return true, if successful
- */
- public boolean validate() {
- List<FieldSchema> columns;
- try {
- Hive metastoreClient = Hive.get(conf);
- Table tbl = (db == null) ? metastoreClient.getTable(inputTable) : metastoreClient.getTable(db, inputTable);
- columns = tbl.getAllCols();
- columnNameToFieldSchema = new HashMap<String, FieldSchema>();
-
- for (FieldSchema fieldSchema : columns) {
- columnNameToFieldSchema.put(fieldSchema.getName(), fieldSchema);
- }
-
- // Check if output table exists
- Table outTbl = metastoreClient.getTable(db == null ? "default" : db, outputTable, false);
- outputTableExists = (outTbl != null);
- } catch (HiveException exc) {
- log.error("Error getting table info {}", toString(), exc);
- return false;
- }
-
- // Check if labeled column and feature columns are contained in the table
- List<String> testTableColumns = new ArrayList<String>(columns.size());
- for (FieldSchema column : columns) {
- testTableColumns.add(column.getName());
- }
-
- if (!testTableColumns.containsAll(featureColumns)) {
- log.info("Invalid feature columns: {}. Actual columns in table:{}", featureColumns, testTableColumns);
- return false;
- }
-
- if (!testTableColumns.contains(labelColumn)) {
- log.info("Invalid label column: {}. Actual columns in table:{}", labelColumn, testTableColumns);
- return false;
- }
-
- if (StringUtils.isBlank(outputColumn)) {
- log.info("Output column is required");
- return false;
- }
-
- if (StringUtils.isBlank(outputTable)) {
- log.info("Output table is required");
- return false;
- }
- return true;
- }
-
- public String getTestQuery() {
- if (!validate()) {
- return null;
- }
-
- // We always insert a dynamic partition
- StringBuilder q = new StringBuilder("INSERT OVERWRITE TABLE " + outputTable + " PARTITION (part_testid='" + testID
- + "') SELECT ");
- String featureCols = StringUtils.join(featureColumns, ",");
- q.append(featureCols).append(",").append(labelColumn).append(", ").append("predict(").append("'").append(algorithm)
- .append("', ").append("'").append(modelID).append("', ").append(featureCols).append(") ").append(outputColumn)
- .append(" FROM ").append(inputTable);
-
- return q.toString();
- }
-
- public String getCreateOutputTableQuery() {
- StringBuilder createTableQuery = new StringBuilder("CREATE TABLE IF NOT EXISTS ").append(outputTable).append("(");
- // Output table contains feature columns, label column, output column
- List<String> outputTableColumns = new ArrayList<String>();
- for (String featureCol : featureColumns) {
- outputTableColumns.add(featureCol + " " + columnNameToFieldSchema.get(featureCol).getType());
- }
-
- outputTableColumns.add(labelColumn + " " + columnNameToFieldSchema.get(labelColumn).getType());
- outputTableColumns.add(outputColumn + " string");
-
- createTableQuery.append(StringUtils.join(outputTableColumns, ", "));
-
- // Append partition column
- createTableQuery.append(") PARTITIONED BY (part_testid string)");
-
- return createTableQuery.toString();
- }
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java
index e6e3c02..641ec5c 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLApp.java
@@ -29,14 +29,8 @@
@ApplicationPath("/ml")
public class MLApp extends Application {
-
private final Set<Class<?>> classes;
- /**
- * Pass additional classes when running in test mode
- *
- * @param additionalClasses
- */
public MLApp(Class<?>... additionalClasses) {
classes = new HashSet<Class<?>>();
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java
index fcbc9ea..de4277d 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceImpl.java
@@ -21,41 +21,42 @@
import java.util.List;
import java.util.Map;
-import org.apache.lens.api.LensConf;
import org.apache.lens.api.LensSessionHandle;
-import org.apache.lens.api.query.LensQuery;
-import org.apache.lens.api.query.QueryHandle;
-import org.apache.lens.api.query.QueryStatus;
-import org.apache.lens.ml.algo.api.MLAlgo;
-import org.apache.lens.ml.algo.api.MLModel;
-import org.apache.lens.ml.api.MLTestReport;
+import org.apache.lens.ml.api.*;
import org.apache.lens.ml.impl.LensMLImpl;
-import org.apache.lens.ml.impl.ModelLoader;
-import org.apache.lens.ml.impl.QueryRunner;
import org.apache.lens.server.api.LensConfConstants;
import org.apache.lens.server.api.ServiceProvider;
import org.apache.lens.server.api.ServiceProviderFactory;
import org.apache.lens.server.api.error.LensException;
-import org.apache.lens.server.api.query.QueryExecutionService;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hive.service.CompositeService;
-import lombok.extern.slf4j.Slf4j;
-
/**
* The Class MLServiceImpl.
*/
-@Slf4j
public class MLServiceImpl extends CompositeService implements MLService {
- /** The ml. */
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(MLServiceImpl.class);
+
+ /**
+ * The ml.
+ */
private LensMLImpl ml;
- /** The service provider. */
+ /**
+ * The service provider.
+ */
private ServiceProvider serviceProvider;
- /** The service provider factory. */
+ /**
+ * The service provider factory.
+ */
private ServiceProviderFactory serviceProviderFactory;
/**
@@ -75,48 +76,131 @@
}
@Override
- public List<String> getAlgorithms() {
- return ml.getAlgorithms();
+ public List<Algo> getAlgos() {
+ return ml.getAlgos();
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getAlgoForName(java.lang.String)
- */
@Override
- public MLAlgo getAlgoForName(String algorithm) throws LensException {
- return ml.getAlgoForName(algorithm);
+ public Algo getAlgo(String name) throws LensException {
+ return ml.getAlgo(name);
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#train(java.lang.String, java.lang.String, java.lang.String[])
- */
@Override
- public String train(String table, String algorithm, String[] args) throws LensException {
- return ml.train(table, algorithm, args);
+ public void createDataSet(String name, String dataTable, String dataBase) throws LensException {
+ ml.createDataSet(name, dataTable, dataBase);
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModels(java.lang.String)
- */
- @Override
- public List<String> getModels(String algorithm) throws LensException {
- return ml.getModels(algorithm);
+ public void createDataSet(DataSet dataSet) throws LensException {
+ ml.createDataSet(dataSet);
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModel(java.lang.String, java.lang.String)
- */
@Override
- public MLModel getModel(String algorithm, String modelId) throws LensException {
- return ml.getModel(algorithm, modelId);
+ public String createDataSetFromQuery(String name, String query) {
+ return ml.createDataSetFromQuery(name, query);
+ }
+
+ @Override
+ public DataSet getDataSet(String name) throws LensException {
+ return ml.getDataSet(name);
+ }
+
+ @Override
+ public void createModel(String name, String algo, Map<String, String> algoParams, List<Feature> features,
+ Feature label, LensSessionHandle lensSessionHandle) throws LensException {
+ ml.createModel(name, algo, algoParams, features, label, lensSessionHandle);
+ }
+
+ @Override
+ public void createModel(Model model) throws LensException {
+ ml.createModel(model);
+ }
+
+ @Override
+ public boolean cancelModelInstance(String modelInstanceId, LensSessionHandle lensSessionHandle) throws LensException {
+ return ml.cancelModelInstance(modelInstanceId, lensSessionHandle);
+ }
+
+ @Override
+ public boolean cancelEvaluation(String evalId, LensSessionHandle lensSessionHandle) throws LensException {
+ return ml.cancelEvaluation(evalId, lensSessionHandle);
+ }
+
+ @Override
+ public boolean cancelPrediction(String predicitonId, LensSessionHandle lensSessionHandle) throws LensException {
+ return ml.cancelPrediction(predicitonId, lensSessionHandle);
+ }
+
+ @Override
+ public Model getModel(String modelId) throws LensException {
+ return ml.getModel(modelId);
+ }
+
+ @Override
+ public String trainModel(String modelId, String dataSetName, LensSessionHandle lensSessionHandle)
+ throws LensException {
+ return ml.trainModel(modelId, dataSetName, lensSessionHandle);
+ }
+
+ @Override
+ public ModelInstance getModelInstance(String modelInstanceId) throws LensException {
+ return ml.getModelInstance(modelInstanceId);
+ }
+
+ @Override
+ public List<ModelInstance> getAllModelInstances(String modelId) {
+ return ml.getAllModelInstances(modelId);
+ }
+
+ @Override
+ public String evaluate(String modelInstanceId, String dataSetName, LensSessionHandle lensSessionHandle)
+ throws LensException {
+ return ml.evaluate(modelInstanceId, dataSetName, lensSessionHandle);
+ }
+
+ @Override
+ public Evaluation getEvaluation(String evalId) throws LensException {
+ return ml.getEvaluation(evalId);
+ }
+
+ @Override
+ public String predict(String modelInstanceId, String dataSetName, LensSessionHandle lensSessionHandle)
+ throws LensException {
+ return ml.predict(modelInstanceId, dataSetName, lensSessionHandle);
+ }
+
+ @Override
+ public Prediction getPrediction(String predictionId) throws LensException {
+ return ml.getPrediction(predictionId);
+ }
+
+ @Override
+ public String predict(String modelInstanceId, Map<String, String> featureVector) throws LensException {
+ return ml.predict(modelInstanceId, featureVector);
+ }
+
+ @Override
+ public void deleteDataSet(String dataSetName) throws LensException {
+ ml.deleteDataSet(dataSetName);
+ }
+
+ @Override
+ public void deleteModel(String modelId) throws LensException {
+ ml.deleteModel(modelId);
+ }
+
+ @Override
+ public void deleteModelInstance(String modelInstanceId) throws LensException {
+ ml.deleteModelInstance(modelInstanceId);
+ }
+
+ @Override
+ public void deleteEvaluation(String evaluationId) throws LensException {
+ ml.deleteEvaluation(evaluationId);
+ }
+
+ @Override
+ public void deletePrediction(String predictionId) throws LensException {
+ ml.deletePrediction(predictionId);
}
private ServiceProvider getServiceProvider() {
@@ -143,30 +227,25 @@
}
}
- /*
- * (non-Javadoc)
- *
- * @see org.apache.hive.service.CompositeService#init(org.apache.hadoop.hive.conf.HiveConf)
- */
@Override
public synchronized void init(HiveConf hiveConf) {
ml = new LensMLImpl(hiveConf);
ml.init(hiveConf);
super.init(hiveConf);
serviceProviderFactory = getServiceProviderFactory(hiveConf);
- log.info("Inited ML service");
+ LOG.info("Inited ML service");
}
/*
- * (non-Javadoc)
- *
- * @see org.apache.hive.service.CompositeService#start()
- */
+ * (non-Javadoc)
+ *
+ * @see org.apache.hive.service.CompositeService#start()
+ */
@Override
public synchronized void start() {
ml.start();
super.start();
- log.info("Started ML service");
+ LOG.info("Started ML service");
}
/*
@@ -178,147 +257,6 @@
public synchronized void stop() {
ml.stop();
super.stop();
- log.info("Stopped ML service");
- }
-
- /**
- * Clear models.
- */
- public void clearModels() {
- ModelLoader.clearCache();
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getModelPath(java.lang.String, java.lang.String)
- */
- @Override
- public String getModelPath(String algorithm, String modelID) {
- return ml.getModelPath(algorithm, modelID);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#testModel(org.apache.lens.api.LensSessionHandle, java.lang.String, java.lang.String,
- * java.lang.String)
- */
- @Override
- public MLTestReport testModel(LensSessionHandle sessionHandle, String table, String algorithm, String modelID,
- String outputTable) throws LensException {
- return ml.testModel(sessionHandle, table, algorithm, modelID, new DirectQueryRunner(sessionHandle), outputTable);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getTestReports(java.lang.String)
- */
- @Override
- public List<String> getTestReports(String algorithm) throws LensException {
- return ml.getTestReports(algorithm);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getTestReport(java.lang.String, java.lang.String)
- */
- @Override
- public MLTestReport getTestReport(String algorithm, String reportID) throws LensException {
- return ml.getTestReport(algorithm, reportID);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#predict(java.lang.String, java.lang.String, java.lang.Object[])
- */
- @Override
- public Object predict(String algorithm, String modelID, Object[] features) throws LensException {
- return ml.predict(algorithm, modelID, features);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#deleteModel(java.lang.String, java.lang.String)
- */
- @Override
- public void deleteModel(String algorithm, String modelID) throws LensException {
- ml.deleteModel(algorithm, modelID);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#deleteTestReport(java.lang.String, java.lang.String)
- */
- @Override
- public void deleteTestReport(String algorithm, String reportID) throws LensException {
- ml.deleteTestReport(algorithm, reportID);
- }
-
- /**
- * Run the test model query directly in the current lens server process.
- */
- private class DirectQueryRunner extends QueryRunner {
-
- /**
- * Instantiates a new direct query runner.
- *
- * @param sessionHandle the session handle
- */
- public DirectQueryRunner(LensSessionHandle sessionHandle) {
- super(sessionHandle);
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.TestQueryRunner#runQuery(java.lang.String)
- */
- @Override
- public QueryHandle runQuery(String testQuery) throws LensException {
- // Run the query in query executions service
- QueryExecutionService queryService = getServiceProvider().getService(QueryExecutionService.NAME);
-
- LensConf queryConf = new LensConf();
- queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_SET, false + "");
- queryConf.addProperty(LensConfConstants.QUERY_PERSISTENT_RESULT_INDRIVER, false + "");
-
- QueryHandle testQueryHandle = queryService.executeAsync(sessionHandle, testQuery, queryConf, queryName);
-
- // Wait for test query to complete
- LensQuery query = queryService.getQuery(sessionHandle, testQueryHandle);
- log.info("Submitted query {}", testQueryHandle.getHandleId());
- while (!query.getStatus().finished()) {
- try {
- Thread.sleep(500);
- } catch (InterruptedException e) {
- throw new LensException(e);
- }
-
- query = queryService.getQuery(sessionHandle, testQueryHandle);
- }
-
- if (query.getStatus().getStatus() != QueryStatus.Status.SUCCESSFUL) {
- throw new LensException("Failed to run test query: " + testQueryHandle.getHandleId() + " reason= "
- + query.getStatus().getErrorMessage());
- }
-
- return testQueryHandle;
- }
- }
-
- /*
- * (non-Javadoc)
- *
- * @see org.apache.lens.ml.LensML#getAlgoParamDescription(java.lang.String)
- */
- @Override
- public Map<String, String> getAlgoParamDescription(String algorithm) {
- return ml.getAlgoParamDescription(algorithm);
+ LOG.info("Stopped ML service");
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java
index 53bac7d..ed0d398 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/ml/server/MLServiceResource.java
@@ -18,58 +18,41 @@
*/
package org.apache.lens.ml.server;
-import static org.apache.commons.lang.StringUtils.isBlank;
-
import java.util.ArrayList;
import java.util.List;
-import java.util.Map;
-import java.util.Set;
import javax.ws.rs.*;
-import javax.ws.rs.core.*;
+import javax.ws.rs.core.MediaType;
+import org.apache.lens.api.APIResult;
import org.apache.lens.api.LensSessionHandle;
import org.apache.lens.api.StringList;
-import org.apache.lens.ml.algo.api.MLModel;
-import org.apache.lens.ml.api.MLTestReport;
-import org.apache.lens.ml.api.ModelMetadata;
-import org.apache.lens.ml.api.TestReport;
-import org.apache.lens.ml.impl.ModelLoader;
+import org.apache.lens.ml.api.*;
import org.apache.lens.server.api.LensConfConstants;
import org.apache.lens.server.api.ServiceProvider;
import org.apache.lens.server.api.ServiceProviderFactory;
import org.apache.lens.server.api.error.LensException;
-import org.apache.commons.lang.StringUtils;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
-import org.glassfish.jersey.media.multipart.FormDataParam;
-
-import lombok.extern.slf4j.Slf4j;
-
/**
* Machine Learning service.
*/
@Path("/ml")
@Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML})
-@Slf4j
public class MLServiceResource {
- /** The ml service. */
- MLService mlService;
-
- /** The service provider. */
- ServiceProvider serviceProvider;
-
- /** The service provider factory. */
- ServiceProviderFactory serviceProviderFactory;
-
- private static final HiveConf HIVE_CONF;
-
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(MLServiceResource.class);
/**
* Message indicating if ML service is up
*/
public static final String ML_UP_MESSAGE = "ML service is up";
+ private static final HiveConf HIVE_CONF;
static {
HIVE_CONF = new HiveConf();
@@ -79,6 +62,19 @@
}
/**
+ * The ml service.
+ */
+ MLService mlService;
+ /**
+ * The service provider.
+ */
+ ServiceProvider serviceProvider;
+ /**
+ * The service provider factory.
+ */
+ ServiceProviderFactory serviceProviderFactory;
+
+ /**
* Instantiates a new ML service resource.
*/
public MLServiceResource() {
@@ -127,286 +123,124 @@
}
/**
- * Get a list of algos available
+ * Get the list of algos available.
*
* @return
*/
@GET
@Path("algos")
+ public List<Algo> getAlgos() {
+ List<Algo> algos = getMlService().getAlgos();
+ return algos;
+ }
+
+ @GET
+ @Path("algonames")
public StringList getAlgoNames() {
- List<String> algos = getMlService().getAlgorithms();
- StringList result = new StringList(algos);
+ List<Algo> algos = getMlService().getAlgos();
+ ArrayList<String> stringArrayList = new ArrayList();
+ for (Algo algo : algos) {
+ stringArrayList.add(algo.getName());
+ }
+
+ StringList result = new StringList(stringArrayList);
return result;
}
- /**
- * Gets the human readable param description of an algorithm
- *
- * @param algorithm the algorithm
- * @return the param description
- */
- @GET
- @Path("algos/{algorithm}")
- public StringList getParamDescription(@PathParam("algorithm") String algorithm) {
- Map<String, String> paramDesc = getMlService().getAlgoParamDescription(algorithm);
- if (paramDesc == null) {
- throw new NotFoundException("Param description not found for " + algorithm);
- }
-
- List<String> descriptions = new ArrayList<String>();
- for (String key : paramDesc.keySet()) {
- descriptions.add(key + " : " + paramDesc.get(key));
- }
- return new StringList(descriptions);
- }
-
- /**
- * Get model ID list for a given algorithm.
- *
- * @param algorithm algorithm name
- * @return the models for algo
- * @throws LensException the lens exception
- */
- @GET
- @Path("models/{algorithm}")
- public StringList getModelsForAlgo(@PathParam("algorithm") String algorithm) throws LensException {
- List<String> models = getMlService().getModels(algorithm);
- if (models == null || models.isEmpty()) {
- throw new NotFoundException("No models found for algorithm " + algorithm);
- }
- return new StringList(models);
- }
-
- /**
- * Get metadata of the model given algorithm and model ID.
- *
- * @param algorithm algorithm name
- * @param modelID model ID
- * @return model metadata
- * @throws LensException the lens exception
- */
- @GET
- @Path("models/{algorithm}/{modelID}")
- public ModelMetadata getModelMetadata(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
- throws LensException {
- MLModel model = getMlService().getModel(algorithm, modelID);
- if (model == null) {
- throw new NotFoundException("Model not found " + modelID + ", algo=" + algorithm);
- }
-
- ModelMetadata meta = new ModelMetadata(model.getId(), model.getTable(), model.getAlgoName(), StringUtils.join(
- model.getParams(), ' '), model.getCreatedAt().toString(), getMlService().getModelPath(algorithm, modelID),
- model.getLabelColumn(), StringUtils.join(model.getFeatureColumns(), ","));
- return meta;
- }
-
- /**
- * Delete a model given model ID and algorithm name.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @return confirmation text
- * @throws LensException the lens exception
- */
- @DELETE
- @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
- @Path("models/{algorithm}/{modelID}")
- public String deleteModel(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID)
- throws LensException {
- getMlService().deleteModel(algorithm, modelID);
- return "DELETED model=" + modelID + " algorithm=" + algorithm;
- }
-
- /**
- * Train a model given an algorithm name and algorithm parameters
- * <p>
- * Following parameters are mandatory and must be passed as part of the form
- * </p>
- * <ol>
- * <li>table - input Hive table to load training data from</li>
- * <li>label - name of the labelled column</li>
- * <li>feature - one entry per feature column. At least one feature column is required</li>
- * </ol>
- * <p></p>
- *
- * @param algorithm algorithm name
- * @param form form data
- * @return if model is successfully trained, the model ID will be returned
- * @throws LensException the lens exception
- */
@POST
- @Consumes(MediaType.APPLICATION_FORM_URLENCODED)
- @Path("{algorithm}/train")
- public String train(@PathParam("algorithm") String algorithm, MultivaluedMap<String, String> form)
- throws LensException {
-
- // Check if algo is valid
- if (getMlService().getAlgoForName(algorithm) == null) {
- throw new NotFoundException("Algo for algo: " + algorithm + " not found");
- }
-
- if (isBlank(form.getFirst("table"))) {
- throw new BadRequestException("table parameter is rquired");
- }
-
- String table = form.getFirst("table");
-
- if (isBlank(form.getFirst("label"))) {
- throw new BadRequestException("label parameter is required");
- }
-
- // Check features
- List<String> featureNames = form.get("feature");
- if (featureNames.size() < 1) {
- throw new BadRequestException("At least one feature is required");
- }
-
- List<String> algoArgs = new ArrayList<String>();
- Set<Map.Entry<String, List<String>>> paramSet = form.entrySet();
-
- for (Map.Entry<String, List<String>> e : paramSet) {
- String p = e.getKey();
- List<String> values = e.getValue();
- if ("algorithm".equals(p) || "table".equals(p)) {
- continue;
- } else if ("feature".equals(p)) {
- for (String feature : values) {
- algoArgs.add("feature");
- algoArgs.add(feature);
- }
- } else if ("label".equals(p)) {
- algoArgs.add("label");
- algoArgs.add(values.get(0));
- } else {
- algoArgs.add(p);
- algoArgs.add(values.get(0));
- }
- }
- log.info("Training table {} with algo {} params={}", table, algorithm, algoArgs.toString());
- String modelId = getMlService().train(table, algorithm, algoArgs.toArray(new String[]{}));
- log.info("Done training {} modelid = {}", table, modelId);
- return modelId;
+ @Path("dataset")
+ public APIResult createDataSet(DataSet dataSet) throws LensException {
+ getMlService().createDataSet(dataSet);
+ return new APIResult(APIResult.Status.SUCCEEDED, "");
}
- /**
- * Clear model cache (for admin use).
- *
- * @return OK if the cache was cleared
- */
- @DELETE
- @Path("clearModelCache")
- @Produces(MediaType.TEXT_PLAIN)
- public Response clearModelCache() {
- ModelLoader.clearCache();
- log.info("Cleared model cache");
- return Response.ok("Cleared cache", MediaType.TEXT_PLAIN_TYPE).build();
+ @GET
+ @Path("dataset")
+ public DataSet getDataSet(@QueryParam("dataSetName") String dataSetName) throws LensException {
+ return getMlService().getDataSet(dataSetName);
}
- /**
- * Run a test on a model for an algorithm.
- *
- * @param algorithm algorithm name
- * @param modelID model ID
- * @param table Hive table to run test on
- * @param session Lens session ID. This session ID will be used to run the test query
- * @return Test report ID
- * @throws LensException the lens exception
- */
@POST
- @Path("test/{table}/{algorithm}/{modelID}")
- @Consumes(MediaType.MULTIPART_FORM_DATA)
- public String test(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
- @PathParam("table") String table, @FormDataParam("sessionid") LensSessionHandle session,
- @FormDataParam("outputTable") String outputTable) throws LensException {
- MLTestReport testReport = getMlService().testModel(session, table, algorithm, modelID, outputTable);
- return testReport.getReportID();
+ @Path("models")
+ public APIResult createModel(Model model) throws LensException {
+ getMlService().createModel(model);
+ return new APIResult(APIResult.Status.SUCCEEDED, "");
}
- /**
- * Get list of reports for a given algorithm.
- *
- * @param algoritm the algoritm
- * @return the reports for algorithm
- * @throws LensException the lens exception
- */
@GET
- @Path("reports/{algorithm}")
- public StringList getReportsForAlgorithm(@PathParam("algorithm") String algoritm) throws LensException {
- List<String> reports = getMlService().getTestReports(algoritm);
- if (reports == null || reports.isEmpty()) {
- throw new NotFoundException("No test reports found for " + algoritm);
- }
- return new StringList(reports);
+ @Path("models")
+ public Model getModel(@QueryParam("modelName") String modelName) throws LensException {
+ return getMlService().getModel(modelName);
}
- /**
- * Get a single test report given the algorithm name and report id.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the test report
- * @throws LensException the lens exception
- */
@GET
- @Path("reports/{algorithm}/{reportID}")
- public TestReport getTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
- throws LensException {
- MLTestReport report = getMlService().getTestReport(algorithm, reportID);
-
- if (report == null) {
- throw new NotFoundException("Test report: " + reportID + " not found for algorithm " + algorithm);
- }
-
- TestReport result = new TestReport(report.getTestTable(), report.getOutputTable(), report.getOutputColumn(),
- report.getLabelColumn(), StringUtils.join(report.getFeatureColumns(), ","), report.getAlgorithm(),
- report.getModelID(), report.getReportID(), report.getLensQueryID());
- return result;
+ @Path("train")
+ public String trainModel(@QueryParam("modelId") String modelId, @QueryParam("dataSetName") String dataSetName,
+ @QueryParam("lensSessionHandle") LensSessionHandle
+ lensSessionHandle)
+ throws
+ LensException {
+ return getMlService().trainModel(modelId, dataSetName, lensSessionHandle);
}
- /**
- * DELETE a report given the algorithm name and report ID.
- *
- * @param algorithm the algorithm
- * @param reportID the report id
- * @return the string
- * @throws LensException the lens exception
- */
+ @GET
+ @Path("modelinstance/{modelInstanceId}")
+ public ModelInstance getModelInstance(@PathParam("modelInstanceId") String modelInstanceId) throws LensException {
+ return getMlService().getModelInstance(modelInstanceId);
+ }
+
@DELETE
- @Path("reports/{algorithm}/{reportID}")
- @Consumes({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML, MediaType.TEXT_PLAIN})
- public String deleteTestReport(@PathParam("algorithm") String algorithm, @PathParam("reportID") String reportID)
+ @Path("modelinstance/{modelInstanceId}")
+ public boolean cancelModelInstance(@PathParam("modelInstanceId") String modelInstanceId,
+ @QueryParam("lensSessionHandle") LensSessionHandle lensSessionHandle)
throws LensException {
- getMlService().deleteTestReport(algorithm, reportID);
- return "DELETED report=" + reportID + " algorithm=" + algorithm;
+ return getMlService().cancelModelInstance(modelInstanceId, lensSessionHandle);
}
- /**
- * Predict.
- *
- * @param algorithm the algorithm
- * @param modelID the model id
- * @param uriInfo the uri info
- * @return the string
- * @throws LensException the lens exception
- */
@GET
- @Path("/predict/{algorithm}/{modelID}")
- @Produces({MediaType.APPLICATION_ATOM_XML, MediaType.APPLICATION_JSON})
- public String predict(@PathParam("algorithm") String algorithm, @PathParam("modelID") String modelID,
- @Context UriInfo uriInfo) throws LensException {
- // Load the model instance
- MLModel<?> model = getMlService().getModel(algorithm, modelID);
+ @Path("predict")
+ public String predict(@QueryParam("modelInstanceId") String modelInstanceId, @QueryParam("dataSetName") String
+ dataSetName, @QueryParam("lensSessionHandle") LensSessionHandle lensSessionHandle) throws LensException {
+ return getMlService().predict(modelInstanceId, dataSetName, lensSessionHandle);
+ }
- // Get input feature names
- MultivaluedMap<String, String> params = uriInfo.getQueryParameters();
- String[] features = new String[model.getFeatureColumns().size()];
- // Assuming that feature name parameters are same
- int i = 0;
- for (String feature : model.getFeatureColumns()) {
- features[i++] = params.getFirst(feature);
+ @GET
+ @Path("prediction/{predictionId}")
+ public Prediction getPrediction(@PathParam("predictionId") String predictionId) throws LensException {
+ return getMlService().getPrediction(predictionId);
+ }
+
+ @DELETE
+ @Path("prediction/{predictionId}")
+ public boolean cancelPrediction(@PathParam("predictionId") String predictionId,
+ @QueryParam("lensSessionHandle") LensSessionHandle lensSessionHandle)
+ throws LensException {
+ return getMlService().cancelPrediction(predictionId, lensSessionHandle);
+ }
+
+ @GET
+ @Path("evaluate")
+ public String evaluate(@QueryParam("modelInstanceId") String modelInstanceId, @QueryParam("dataSetName") String
+ dataSetName, @QueryParam("lensSessionHandle") LensSessionHandle lensSessionHandle)
+ throws LensException {
+ return getMlService().evaluate(modelInstanceId, dataSetName, lensSessionHandle);
+ }
+
+ @GET
+ @Path("evaluation/{evalId}")
+ public Evaluation getEvaluation(@PathParam("evalId") String evalId) throws LensException {
+ return getMlService().getEvaluation(evalId);
+ }
+
+ @DELETE
+ @Path("evaluation/{evalId}")
+ public APIResult cancelEvaluation(@PathParam("evalId") String evalId,
+ @QueryParam("lensSessionHandle") LensSessionHandle lensSessionHandle)
+ throws LensException {
+ boolean result = getMlService().cancelEvaluation(evalId, lensSessionHandle);
+ if (result) {
+ return new APIResult(APIResult.Status.SUCCEEDED, "");
}
-
- // TODO needs a 'prediction formatter'
- return getMlService().predict(algorithm, modelID, features).toString();
+ return new APIResult(APIResult.Status.FAILED, "");
}
}
diff --git a/lens-ml-lib/src/main/java/org/apache/lens/rdd/LensRDDClient.java b/lens-ml-lib/src/main/java/org/apache/lens/rdd/LensRDDClient.java
index b4f43ec..9d7ee78 100644
--- a/lens-ml-lib/src/main/java/org/apache/lens/rdd/LensRDDClient.java
+++ b/lens-ml-lib/src/main/java/org/apache/lens/rdd/LensRDDClient.java
@@ -32,6 +32,8 @@
import org.apache.lens.ml.algo.spark.HiveTableRDD;
import org.apache.lens.server.api.error.LensException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.TableType;
import org.apache.hadoop.hive.metastore.api.FieldSchema;
@@ -48,8 +50,6 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.RDD;
-import lombok.extern.slf4j.Slf4j;
-
/**
* <p>
* Create RDD from a Lens query. User can poll returned query handle with isReadyForRDD() until the RDD is ready to be
@@ -76,26 +76,19 @@
* JavaRDD<ResultRow> rdd = client.createLensRDD("SELECT msr1 from TEST_CUBE WHERE ...", conf);
* </pre>
*/
-@Slf4j
public class LensRDDClient {
+ /**
+ * The Constant LOG.
+ */
+ public static final Log LOG = LogFactory.getLog(LensRDDClient.class);
// Default input format for table created from Lens result set
- /** The Constant INPUT_FORMAT. */
- private static final String INPUT_FORMAT = TextInputFormat.class.getName();
- // Default output format
- /** The Constant OUTPUT_FORMAT. */
- private static final String OUTPUT_FORMAT = TextOutputFormat.class.getName();
- // Name of partition column and its value. There is always exactly one partition in the table created from
- // Result set.
- /** The Constant TEMP_TABLE_PART_COL. */
- private static final String TEMP_TABLE_PART_COL = "dummy_partition_column";
-
- /** The Constant TEMP_TABLE_PART_VAL. */
- private static final String TEMP_TABLE_PART_VAL = "placeholder_value";
-
- /** The Constant HIVE_CONF. */
+ /**
+ * The Constant HIVE_CONF.
+ */
protected static final HiveConf HIVE_CONF = new HiveConf();
+ // Configure an embedded, local Hive metastore for this client (see static block below)
static {
HIVE_CONF.setVar(HiveConf.ConfVars.METASTOREURIS, "");
HIVE_CONF.set("javax.jdo.option.ConnectionURL", "jdbc:derby:;databaseName=./metastore_db;create=true");
@@ -103,11 +96,32 @@
HIVE_CONF.setBoolean("hive.metastore.local", true);
HIVE_CONF.set("hive.metastore.warehouse.dir", "file://${user.dir}/warehouse");
}
-
- /** The spark context. */
+ // Constants for the temp table created from the Lens result set: I/O formats plus the
+ // single placeholder partition (that table always has exactly one partition).
+ /**
+ * The Constant INPUT_FORMAT.
+ */
+ private static final String INPUT_FORMAT = TextInputFormat.class.getName();
+ /**
+ * The Constant OUTPUT_FORMAT.
+ */
+ private static final String OUTPUT_FORMAT = TextOutputFormat.class.getName();
+ /**
+ * The Constant TEMP_TABLE_PART_COL.
+ */
+ private static final String TEMP_TABLE_PART_COL = "dummy_partition_column";
+ /**
+ * The Constant TEMP_TABLE_PART_VAL.
+ */
+ private static final String TEMP_TABLE_PART_VAL = "placeholder_value";
+ /**
+ * The spark context.
+ */
private final JavaSparkContext sparkContext; // Spark context
- /** The lens client. */
+ /**
+ * The lens client.
+ */
private LensClient lensClient; // Lens client instance. Initialized lazily.
/**
@@ -230,7 +244,7 @@
try {
rdd = HiveTableRDD.createHiveTableRDD(sparkContext, HIVE_CONF, "default", tempTableName, TEMP_TABLE_PART_COL
+ "='" + TEMP_TABLE_PART_VAL + "'");
- log.info("Created RDD {} for table {}", rdd.name(), tempTableName);
+ LOG.info("Created RDD " + rdd.name() + " for table " + tempTableName);
} catch (IOException e) {
throw new LensException("Error creating RDD for table " + tempTableName, e);
}
@@ -267,7 +281,7 @@
tbl.getPartCols().add(new FieldSchema(TEMP_TABLE_PART_COL, "string", "default"));
hiveClient.createTable(tbl);
- log.info("Table {} created", tableName);
+ LOG.info("Table " + tableName + " created");
// Add partition to the table
AddPartitionDesc partitionDesc = new AddPartitionDesc("default", tableName, false);
@@ -275,7 +289,7 @@
partSpec.put(TEMP_TABLE_PART_COL, TEMP_TABLE_PART_VAL);
partitionDesc.addPartition(partSpec, dataLocation);
hiveClient.createPartitions(partitionDesc);
- log.info("Created partition in {} for data in {}", tableName, dataLocation);
+ LOG.info("Created partition in " + tableName + " for data in " + dataLocation);
return tableName;
}
@@ -305,7 +319,7 @@
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
- log.warn("Interrupted while waiting for query", e);
+ LOG.warn("Interrupted while waiting for query", e);
break;
}
}
@@ -317,13 +331,19 @@
*/
public static class LensRDDResult implements Serializable {
- /** The result rdd. */
+ /**
+ * The result rdd.
+ */
private transient RDD<List<Object>> resultRDD;
- /** The lens query. */
+ /**
+ * The lens query.
+ */
private QueryHandle lensQuery;
- /** The temp table name. */
+ /**
+ * The temp table name.
+ */
private String tempTableName;
/**
@@ -368,7 +388,7 @@
JavaPairRDD<WritableComparable, HCatRecord> javaPairRDD = HiveTableRDD.createHiveTableRDD(sparkContext,
HIVE_CONF, "default", tempTableName, TEMP_TABLE_PART_COL + "='" + TEMP_TABLE_PART_VAL + "'");
resultRDD = javaPairRDD.map(new HCatRecordToObjectListMapper()).rdd();
- log.info("Created RDD {} for table {}", resultRDD.name(), tempTableName);
+ LOG.info("Created RDD " + resultRDD.name() + " for table " + tempTableName);
} catch (IOException e) {
throw new LensException("Error creating RDD for table " + tempTableName, e);
}
@@ -390,7 +410,7 @@
try {
hiveClient = Hive.get(HIVE_CONF);
hiveClient.dropTable("default." + tempTableName);
- log.info("Dropped temp table {}", tempTableName);
+ LOG.info("Dropped temp table " + tempTableName);
} catch (HiveException e) {
throw new LensException(e);
}
diff --git a/lens-ml-lib/src/test/java/org/apache/lens/ml/ExampleUtils.java b/lens-ml-lib/src/test/java/org/apache/lens/ml/ExampleUtils.java
index ac3a55e..0eb8708 100644
--- a/lens-ml-lib/src/test/java/org/apache/lens/ml/ExampleUtils.java
+++ b/lens-ml-lib/src/test/java/org/apache/lens/ml/ExampleUtils.java
@@ -41,6 +41,7 @@
*/
@Slf4j
public final class ExampleUtils {
+
private ExampleUtils() {
}
@@ -56,7 +57,8 @@
* @throws HiveException the hive exception
*/
public static void createTable(HiveConf conf, String database, String tableName, String sampleDataFile,
- String labelColumn, Map<String, String> tableParams, String... features) throws HiveException {
+ String labelColumn, Map<String, String> tableParams, String... features)
+ throws HiveException {
Path dataFilePath = new Path(sampleDataFile);
Path partDir = dataFilePath.getParent();
@@ -64,16 +66,16 @@
// Create table
List<FieldSchema> columns = new ArrayList<FieldSchema>();
+ for (String feature : features) {
+ columns.add(new FieldSchema(feature, "double", "Feature " + feature));
+ }
+
// Label is optional. Not used for unsupervised models.
// If present, label will be the first column, followed by features
if (labelColumn != null) {
columns.add(new FieldSchema(labelColumn, "double", "Labelled Column"));
}
- for (String feature : features) {
- columns.add(new FieldSchema(feature, "double", "Feature " + feature));
- }
-
Table tbl = Hive.get(conf).newTable(database + "." + tableName);
tbl.setTableType(TableType.MANAGED_TABLE);
tbl.getTTable().getSd().setCols(columns);
diff --git a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
index 51344ce..2fe846b 100644
--- a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
+++ b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLResource.java
@@ -18,51 +18,9 @@
*/
package org.apache.lens.ml;
-import java.io.File;
-import java.net.URI;
-import java.util.*;
-
-import javax.ws.rs.client.WebTarget;
-import javax.ws.rs.core.Application;
-import javax.ws.rs.core.UriBuilder;
-
-import org.apache.lens.client.LensClient;
-import org.apache.lens.client.LensClientConfig;
-import org.apache.lens.client.LensMLClient;
-import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
-import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
-import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
-import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
-import org.apache.lens.ml.impl.MLTask;
-import org.apache.lens.ml.impl.MLUtils;
-import org.apache.lens.ml.server.MLApp;
-import org.apache.lens.ml.server.MLServiceResource;
-import org.apache.lens.server.LensJerseyTest;
-import org.apache.lens.server.api.LensConfConstants;
-import org.apache.lens.server.query.QueryServiceResource;
-import org.apache.lens.server.session.SessionResource;
-
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.metastore.api.Database;
-import org.apache.hadoop.hive.ql.metadata.Hive;
-import org.apache.hadoop.hive.ql.metadata.HiveException;
-import org.apache.hadoop.hive.ql.metadata.Partition;
-import org.apache.hadoop.hive.ql.metadata.Table;
-
-import org.glassfish.jersey.client.ClientConfig;
-import org.glassfish.jersey.media.multipart.MultiPartFeature;
-import org.testng.Assert;
-import org.testng.annotations.AfterTest;
-import org.testng.annotations.BeforeMethod;
-import org.testng.annotations.BeforeTest;
-import org.testng.annotations.Test;
-
-import lombok.extern.slf4j.Slf4j;
-
-@Slf4j
-@Test
-public class TestMLResource extends LensJerseyTest {
-
+public class TestMLResource {
+ /*
+ private static final Log LOG = LogFactory.getLog(TestMLResource.class);
private static final String TEST_DB = "default";
private WebTarget mlTarget;
@@ -70,7 +28,7 @@
@Override
protected int getTestPort() {
- return 10002;
+ return 10003;
}
@Override
@@ -98,7 +56,7 @@
LensClientConfig lensClientConfig = new LensClientConfig();
lensClientConfig.setLensDatabase(TEST_DB);
lensClientConfig.set(LensConfConstants.SERVER_BASE_URL,
- "http://localhost:" + getTestPort() + "/lensapi");
+ "http://localhost:" + getTestPort() + "/lensapi");
LensClient client = new LensClient(lensClientConfig);
mlClient = new LensMLClient(client);
}
@@ -112,7 +70,7 @@
hive.dropDatabase(TEST_DB);
} catch (Exception exc) {
// Ignore drop db exception
- log.error("Exception while dropping database.", exc);
+ ////LOG.error(exc.getMessage());
}
mlClient.close();
}
@@ -130,120 +88,7 @@
@Test
public void testGetAlgos() throws Exception {
- List<String> algoNames = mlClient.getAlgorithms();
- Assert.assertNotNull(algoNames);
-
- Assert.assertTrue(
- algoNames.contains(MLUtils.getAlgoName(NaiveBayesAlgo.class)),
- MLUtils.getAlgoName(NaiveBayesAlgo.class));
-
- Assert.assertTrue(algoNames.contains(MLUtils.getAlgoName(SVMAlgo.class)),
- MLUtils.getAlgoName(SVMAlgo.class));
-
- Assert.assertTrue(
- algoNames.contains(MLUtils.getAlgoName(LogisticRegressionAlgo.class)),
- MLUtils.getAlgoName(LogisticRegressionAlgo.class));
-
- Assert.assertTrue(
- algoNames.contains(MLUtils.getAlgoName(DecisionTreeAlgo.class)),
- MLUtils.getAlgoName(DecisionTreeAlgo.class));
+ mlClient.test();
}
-
- @Test
- public void testGetAlgoParams() throws Exception {
- Map<String, String> params = mlClient.getAlgoParamDescription(MLUtils
- .getAlgoName(DecisionTreeAlgo.class));
- Assert.assertNotNull(params);
- Assert.assertFalse(params.isEmpty());
-
- for (String key : params.keySet()) {
- log.info("## Param " + key + " help = " + params.get(key));
- }
- }
-
- @Test
- public void trainAndEval() throws Exception {
- log.info("Starting train & eval");
- final String algoName = MLUtils.getAlgoName(NaiveBayesAlgo.class);
- HiveConf conf = new HiveConf();
- String tableName = "naivebayes_training_table";
- String sampleDataFilePath = "data/naive_bayes/naive_bayes_train.data";
-
- File sampleDataFile = new File(sampleDataFilePath);
- URI sampleDataFileURI = sampleDataFile.toURI();
-
- String labelColumn = "label";
- String[] features = { "feature_1", "feature_2", "feature_3" };
- String outputTable = "naivebayes_eval_table";
-
- log.info("Creating training table from file "
- + sampleDataFileURI.toString());
-
- Map<String, String> tableParams = new HashMap<String, String>();
- try {
- ExampleUtils.createTable(conf, TEST_DB, tableName,
- sampleDataFileURI.toString(), labelColumn, tableParams, features);
- } catch (HiveException exc) {
- log.error("Hive exception encountered.", exc);
- }
- MLTask.Builder taskBuilder = new MLTask.Builder();
-
- taskBuilder.algorithm(algoName).hiveConf(conf).labelColumn(labelColumn)
- .outputTable(outputTable).client(mlClient).trainingTable(tableName);
-
- // Add features
- taskBuilder.addFeatureColumn("feature_1").addFeatureColumn("feature_2")
- .addFeatureColumn("feature_3");
-
- MLTask task = taskBuilder.build();
-
- log.info("Created task " + task.toString());
- task.run();
- Assert.assertEquals(task.getTaskState(), MLTask.State.SUCCESSFUL);
-
- String firstModelID = task.getModelID();
- String firstReportID = task.getReportID();
- Assert.assertNotNull(firstReportID);
- Assert.assertNotNull(firstModelID);
-
- taskBuilder = new MLTask.Builder();
- taskBuilder.algorithm(algoName).hiveConf(conf).labelColumn(labelColumn)
- .outputTable(outputTable).client(mlClient).trainingTable(tableName);
-
- taskBuilder.addFeatureColumn("feature_1").addFeatureColumn("feature_2")
- .addFeatureColumn("feature_3");
-
- MLTask anotherTask = taskBuilder.build();
-
- log.info("Created second task " + anotherTask.toString());
- anotherTask.run();
-
- String secondModelID = anotherTask.getModelID();
- String secondReportID = anotherTask.getReportID();
- Assert.assertNotNull(secondModelID);
- Assert.assertNotNull(secondReportID);
-
- Hive metastoreClient = Hive.get(conf);
- Table outputHiveTable = metastoreClient.getTable(outputTable);
- List<Partition> partitions = metastoreClient.getPartitions(outputHiveTable);
-
- Assert.assertNotNull(partitions);
-
- int i = 0;
- Set<String> partReports = new HashSet<String>();
- for (Partition part : partitions) {
- log.info("@@PART#" + i + " " + part.getSpec().toString());
- partReports.add(part.getSpec().get("part_testid"));
- }
-
- // Verify partitions created for each run
- Assert.assertTrue(partReports.contains(firstReportID), firstReportID
- + " first partition not there");
- Assert.assertTrue(partReports.contains(secondReportID), secondReportID
- + " second partition not there");
-
- log.info("Completed task run");
-
- }
-
+*/
}
diff --git a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLRunner.java b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLRunner.java
index ef3d53e..e52c31b 100644
--- a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLRunner.java
+++ b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLRunner.java
@@ -18,40 +18,10 @@
*/
package org.apache.lens.ml;
-import java.net.URI;
-
-import javax.ws.rs.core.Application;
-import javax.ws.rs.core.UriBuilder;
-
-import org.apache.lens.client.LensClient;
-import org.apache.lens.client.LensClientConfig;
-import org.apache.lens.client.LensMLClient;
-import org.apache.lens.ml.impl.MLRunner;
-import org.apache.lens.ml.impl.MLTask;
-import org.apache.lens.ml.server.MLApp;
-import org.apache.lens.server.LensJerseyTest;
-import org.apache.lens.server.api.LensConfConstants;
-import org.apache.lens.server.metastore.MetastoreResource;
-import org.apache.lens.server.query.QueryServiceResource;
-import org.apache.lens.server.session.SessionResource;
-
-import org.apache.hadoop.hive.conf.HiveConf;
-import org.apache.hadoop.hive.metastore.api.Database;
-import org.apache.hadoop.hive.ql.metadata.Hive;
-
-import org.glassfish.jersey.client.ClientConfig;
-import org.glassfish.jersey.media.multipart.MultiPartFeature;
-import org.testng.Assert;
-import org.testng.annotations.AfterTest;
-import org.testng.annotations.BeforeTest;
-import org.testng.annotations.Test;
-
-import lombok.extern.slf4j.Slf4j;
-
-
-@Test
-@Slf4j
-public class TestMLRunner extends LensJerseyTest {
+//@Test
+public class TestMLRunner {
+ /*
+ private static final Log LOG = LogFactory.getLog(TestMLRunner.class);
private static final String TEST_DB = TestMLRunner.class.getSimpleName();
private LensMLClient mlClient;
@@ -101,7 +71,7 @@
@Test
public void trainAndEval() throws Exception {
- log.info("Starting train & eval");
+ LOG.info("Starting train & eval");
String algoName = "spark_naive_bayes";
String database = "default";
String trainTable = "naivebayes_training_table";
@@ -125,7 +95,7 @@
@Test
public void trainAndEvalFromDir() throws Exception {
- log.info("Starting train & eval from Dir");
+ LOG.info("Starting train & eval from Dir");
MLRunner runner = new MLRunner();
runner.init(mlClient, "data/naive_bayes");
MLTask task = runner.train();
@@ -135,4 +105,5 @@
Assert.assertNotNull(modelID);
Assert.assertNotNull(reportID);
}
+ */
}
diff --git a/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLServices.java b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLServices.java
new file mode 100644
index 0000000..18a4f2f
--- /dev/null
+++ b/lens-ml-lib/src/test/java/org/apache/lens/ml/TestMLServices.java
@@ -0,0 +1,323 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.lens.ml;
+
+import java.io.File;
+import java.net.URI;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import javax.ws.rs.client.WebTarget;
+import javax.ws.rs.core.Application;
+import javax.ws.rs.core.UriBuilder;
+
+import org.apache.lens.client.LensClient;
+import org.apache.lens.client.LensClientConfig;
+import org.apache.lens.client.LensMLClient;
+import org.apache.lens.ml.algo.spark.dt.DecisionTreeAlgo;
+import org.apache.lens.ml.algo.spark.kmeans.KMeansAlgo;
+import org.apache.lens.ml.algo.spark.lr.LogisticRegressionAlgo;
+import org.apache.lens.ml.algo.spark.nb.NaiveBayesAlgo;
+import org.apache.lens.ml.algo.spark.svm.SVMAlgo;
+import org.apache.lens.ml.api.*;
+import org.apache.lens.ml.server.MLApp;
+import org.apache.lens.ml.server.MLServiceResource;
+import org.apache.lens.server.LensJerseyTest;
+import org.apache.lens.server.api.LensConfConstants;
+import org.apache.lens.server.query.QueryServiceResource;
+import org.apache.lens.server.session.SessionResource;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.hadoop.hive.conf.HiveConf;
+import org.apache.hadoop.hive.metastore.api.Database;
+import org.apache.hadoop.hive.ql.metadata.Hive;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.metadata.Partition;
+import org.apache.hadoop.hive.ql.metadata.Table;
+
+import org.glassfish.jersey.client.ClientConfig;
+import org.glassfish.jersey.media.multipart.MultiPartFeature;
+import org.testng.Assert;
+import org.testng.annotations.AfterTest;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.BeforeTest;
+import org.testng.annotations.Test;
+
+
+@Test
+public class TestMLServices extends LensJerseyTest {
+ private static final Log LOG = LogFactory.getLog(TestMLServices.class);
+ private static final String TEST_DB = "default";
+
+ private WebTarget mlTarget;
+ private LensMLClient mlClient;
+
+ @Override
+ protected int getTestPort() {
+ return 10002;
+ }
+
+ @Override
+ protected Application configure() {
+ return new MLApp(SessionResource.class, QueryServiceResource.class);
+ }
+
+ @Override
+ protected void configureClient(ClientConfig config) {
+ config.register(MultiPartFeature.class);
+ }
+
+ @Override
+ protected URI getBaseUri() {
+ return UriBuilder.fromUri("http://localhost/").port(getTestPort()).path("/lensapi").build();
+ }
+
+ @BeforeTest
+ public void setUp() throws Exception {
+ LOG.info("Testing started!");
+ super.setUp();
+ Hive hive = Hive.get(new HiveConf());
+ Database db = new Database();
+ db.setName(TEST_DB);
+ //hive.createDatabase(db, true);
+ LensClientConfig lensClientConfig = new LensClientConfig();
+ lensClientConfig.setLensDatabase(TEST_DB);
+ lensClientConfig.set(LensConfConstants.SERVER_BASE_URL,
+ "http://localhost:" + 10002 + "/lensapi");
+ LensClient client = new LensClient(lensClientConfig);
+ mlClient = new LensMLClient(client);
+ }
+
+ @AfterTest
+ public void tearDown() throws Exception {
+ super.tearDown();
+ Hive hive = Hive.get(new HiveConf());
+
+ try {
+ hive.dropDatabase(TEST_DB);
+ } catch (Exception exc) {
 + // Ignore failures when dropping the test database; it may not exist.
+ LOG.error(exc.getMessage());
+ }
+ mlClient.close();
+ }
+
+ @BeforeMethod
+ public void setMLTarget() {
+ mlTarget = target().path("ml");
+ }
+
+ @Test
+ public void testMLResourceUp() throws Exception {
+ String mlUpMsg = mlTarget.request().get(String.class);
+ Assert.assertEquals(mlUpMsg, MLServiceResource.ML_UP_MESSAGE);
+ }
+
+
+ public void testGetAlgos() throws Exception {
+
+ LOG.info("Testing registered algorithm.");
+
+ List<Algo> algos = mlClient.getAlgos();
+
+ Assert.assertNotNull(algos);
+
+ ArrayList<String> algoNames = new ArrayList();
+ for (Algo algo : algos) {
+ algoNames.add(algo.getName());
+ }
+
+ Assert.assertTrue(
+ algoNames.contains(new NaiveBayesAlgo().getName()),
+ new NaiveBayesAlgo().getName());
+
+ Assert.assertTrue(algoNames.contains(new SVMAlgo().getName()),
+ new SVMAlgo().getName());
+
+ Assert.assertTrue(
+ algoNames.contains(new LogisticRegressionAlgo().getName()),
+ new LogisticRegressionAlgo().getName());
+
+ Assert.assertTrue(
+ algoNames.contains(new DecisionTreeAlgo().getName()),
+ new DecisionTreeAlgo().getName());
+
+ Assert.assertTrue(
+ algoNames.contains(new KMeansAlgo().getName()),
+ new KMeansAlgo().getName());
+ }
+
+
+ @Test
+ public void trainAndEval() throws Exception {
+ LOG.info("Starting train and eval");
+ HiveConf conf = new HiveConf();
+
+ String sampleDataFilePath = "data/lr/lr_train.data";
+
+ File sampleDataFile = new File(sampleDataFilePath);
+ URI sampleDataFileURI = sampleDataFile.toURI();
+
+ ArrayList<Feature> features = new ArrayList();
+ Feature feature1 = new Feature("feature1", "description", Feature.Type.Categorical, "feature1");
+ Feature feature2 = new Feature("feature2", "description", Feature.Type.Categorical, "feature2");
+ Feature feature3 = new Feature("feature3", "description", Feature.Type.Categorical, "feature3");
+ Feature label = new Feature("label", "description", Feature.Type.Categorical, "label");
+
+ features.add(feature1);
+ features.add(feature2);
+ features.add(feature3);
+
+ String tableName = "lr_table_6";
+
+ String labelColumn = "label";
+ String[] featureNames = {"feature1", "feature2", "feature3"};
+ Map<String, String> tableParams = new HashMap<String, String>();
+ try {
+ ExampleUtils.createTable(conf, TEST_DB, tableName,
+ sampleDataFileURI.toString(), labelColumn, tableParams, featureNames);
+ } catch (HiveException exc) {
+ LOG.error(exc.getLocalizedMessage());
+ }
+
+ mlClient.createDataSet("lr_table_6", "lr_table_6", TEST_DB);
+
+
+ mlClient.createModel("lr_model", "spark_logistic_regression", new HashMap<String, String>(), features,
+ label, mlClient.getSessionHandle());
+ LOG.info("model created with Id: " + "lr_model");
+
+ Model model = mlClient.getModel("lr_model");
+ Assert.assertTrue(model != null, "Null model returned after creation");
+
+ String modelInstanceId = mlClient.trainModel("lr_model", "lr_table_6", mlClient.getSessionHandle());
+ LOG.info("Model Instance created with Id:" + modelInstanceId);
+
+ ModelInstance modelInstance;
+
+ do {
+ modelInstance = mlClient.getModelInstance(modelInstanceId);
+ Thread.sleep(2000);
+ } while (!(modelInstance.getStatus() == Status.COMPLETED || modelInstance.getStatus() == Status.FAILED));
+
+ Assert.assertTrue(modelInstance.getStatus() == Status.COMPLETED, "Training model failed");
+
+ Map<String, String> featureMap = new HashMap();
+ featureMap.put("feature1", "1");
+ featureMap.put("feature2", "0");
+ featureMap.put("feature3", "4");
+
+ String predictedValue = mlClient.predict(modelInstanceId, featureMap);
+ LOG.info("Predicting :" + predictedValue);
+
+ Assert.assertTrue(predictedValue.equals("1.0"), "Predicted value incorrect");
+
+
+ String predictionId = mlClient.predict(modelInstanceId, "lr_table_6", mlClient.getSessionHandle());
+ Prediction prediction;
+ do {
+ prediction = mlClient.getPrediction(predictionId);
+ Thread.sleep(2000);
+ } while (!(prediction.getStatus() == Status.COMPLETED || prediction.getStatus() == Status.FAILED));
+
+ Assert.assertTrue(prediction != null, "Prediction failed");
+
+ Hive metastoreClient = Hive.get(conf);
+ Table outputHiveTableForPrediction = metastoreClient.getTable(prediction.getOutputDataSet());
+ List<Partition> predictionPartitions = metastoreClient.getPartitions(outputHiveTableForPrediction);
+
+ Assert.assertNotNull(predictionPartitions);
+
+ String evaluationId = mlClient.evaluate(modelInstanceId, "lr_table_6", mlClient.getSessionHandle());
+ LOG.info("Evaluation started with id:" + evaluationId);
+
+ Evaluation evaluation;
+ do {
+ evaluation = mlClient.getEvaluation(evaluationId);
+ Thread.sleep(2000);
+ } while (!(evaluation.getStatus() == Status.COMPLETED || evaluation.getStatus() == Status.FAILED));
+
+ Assert.assertTrue(evaluation.getStatus() == Status.COMPLETED, "Evaluating model failed");
+
+ Table outputHiveTableForEvaluation = metastoreClient.getTable(prediction.getOutputDataSet());
+ List<Partition> evaluationPartitions = metastoreClient.getPartitions(outputHiveTableForEvaluation);
+
+ Assert.assertNotNull(evaluationPartitions);
+
 + // Testing cancellation
+
+ evaluationId = mlClient.evaluate(modelInstanceId, "lr_table_6", mlClient.getSessionHandle());
+ LOG.info("Evaluation started with id:" + evaluationId);
+ boolean result = mlClient.cancelEvaluation(evaluationId, mlClient.getSessionHandle());
+
+ Assert.assertTrue(result);
+
+ evaluation = mlClient.getEvaluation(evaluationId);
+ Assert
+ .assertTrue(evaluation.getStatus() == Status.CANCELLED, "Cancellation failed" + evaluation.getStatus().name());
+ }
+
+ public void testDb() throws Exception {
+ //mlClient.createDataSet("lr_table_6", "lr_table_6", TEST_DB);
+
+ //DataSet ds = mlClient.getDataSet("lr_table_6");
+
+ }
+
+ @Test
+ public void testMetaStore() throws Exception {
+ //mlClient.createDataSet("lr_table_6", "lr_table_6", TEST_DB);
+ /*Map<String, String> algoParams = new HashMap();
+ algoParams.put("apn1", "apv1");
+ algoParams.put("apn2", "apv2");
+ List<Feature> features = new ArrayList<>();
+ features.add(new Feature("f1", "fd1", Feature.Type.Categorical, "dc1"));
+ features.add(new Feature("f2", "fd2", Feature.Type.Categorical, "dc2"));
+ //mlClient.createModel("model1", "algo1", algoParams, features, new Feature("lf","lfd", Feature.Type.Categorical,
+ "lddc"),mlClient.getSessionHandle());
+
+ Model model = mlClient.getModel("model1");
+
+ String id = mlClient*/
+
+ /*ArrayList<Feature> features = new ArrayList();
+ Feature feature1 = new Feature("feature1", "description", Feature.Type.Categorical, "feature1");
+ Feature feature2 = new Feature("feature2", "description", Feature.Type.Categorical, "feature2");
+ Feature feature3 = new Feature("feature3", "description", Feature.Type.Categorical, "feature3");
+ Feature label = new Feature("label", "description", Feature.Type.Categorical, "label");
+
+ features.add(feature1);
+ features.add(feature2);
+ features.add(feature3);
+ Map param = new HashMap<String, String>();
+ param.put("one","uno");
+ param.put("two","rindu");
+ mlClient.createModel("lr_model", "spark_logistic_regression", param, features,
+ label, mlClient.getSessionHandle());
+
+
+ Model ml = mlClient.getModel("lr_model");
+*/
+ }
+
+}
+
diff --git a/lens-ml-lib/src/test/resources/lens-site.xml b/lens-ml-lib/src/test/resources/lens-site.xml
index 9be7850..854f861 100644
--- a/lens-ml-lib/src/test/resources/lens-site.xml
+++ b/lens-ml-lib/src/test/resources/lens-site.xml
@@ -23,6 +23,7 @@
<?xml-stylesheet type="text/xsl" href="configuration.xsl"?>
<configuration>
+
<property>
<name>lens.server.drivers</name>
<value>hive:org.apache.lens.driver.hive.HiveDriver</value>
diff --git a/lens-ml-lib/src/test/resources/log4j.properties b/lens-ml-lib/src/test/resources/log4j.properties
new file mode 100644
index 0000000..afadc2f
--- /dev/null
+++ b/lens-ml-lib/src/test/resources/log4j.properties
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+lensserver.root.logger=INFO, ROOT
+lensserver.request.logger=INFO, REQUEST
+lensserver.audit.logger=INFO, CONSOLE
+lensserver.querystatus.logger=INFO, CONSOLE
+
+log4j.rootLogger=${lensserver.root.logger}
+log4j.logger.org.apache.lens.server.LensServer.request=${lensserver.request.logger}
+log4j.additivity.org.apache.lens.server.LensServer.request=false
+log4j.logger.org.apache.lens.server.LensServer.audit=${lensserver.audit.logger}
+log4j.additivity.org.apache.lens.server.LensServer.audit=false
+log4j.logger.org.apache.lens.server.query.QueryExecutionServiceImpl$QueryStatusLogger=${lensserver.querystatus.logger}
+log4j.additivity.org.apache.lens.server.query.QueryExecutionServiceImpl$QueryStatusLogger=false
+log4j.logger.org.apache.lens.server.stats.event.query.QueryExecutionStatistics=DEBUG, QueryExecutionStatistics
+log4j.additivity.org.apache.lens.server.stats.event.query.QueryExecutionStatistics=false
+
+
+# CONSOLE is set to be a ConsoleAppender.
+log4j.appender.CONSOLE=org.apache.log4j.ConsoleAppender
+
+# CONSOLE uses PatternLayout.
+log4j.appender.CONSOLE.layout=org.apache.log4j.PatternLayout
+log4j.appender.CONSOLE.layout.ConversionPattern=%d [%t] %-5p %c - %m%n
+
+log4j.appender.ROOT=org.apache.log4j.RollingFileAppender
+log4j.appender.ROOT.File=${lens.log.dir}/lensserver.log
+log4j.appender.ROOT.layout=org.apache.log4j.PatternLayout
+log4j.appender.ROOT.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} [%t] %-5p %c %x - %m%n
+
+log4j.appender.ROOT.MaxFileSize=100000KB
+# Keep 20 backup files
+log4j.appender.ROOT.MaxBackupIndex=20
+
+
+log4j.appender.AUDIT=org.apache.log4j.RollingFileAppender
+log4j.appender.AUDIT.File=${lens.log.dir}/lensserver-audit.log
+log4j.appender.AUDIT.layout=org.apache.log4j.PatternLayout
+log4j.appender.AUDIT.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} [%t] %-5p %c %x - %m%n
+
+log4j.appender.AUDIT.MaxFileSize=100000KB
+# Keep 20 backup files
+log4j.appender.AUDIT.MaxBackupIndex=20
+
+log4j.appender.REQUEST=org.apache.log4j.RollingFileAppender
+log4j.appender.REQUEST.File=${lens.log.dir}/lensserver-requests.log
+log4j.appender.REQUEST.layout=org.apache.log4j.PatternLayout
+log4j.appender.REQUEST.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} [%t] %-5p %c %x - %m%n
+
+log4j.appender.REQUEST.MaxFileSize=100000KB
+# Keep 20 backup files
+log4j.appender.REQUEST.MaxBackupIndex=20
+
+log4j.appender.STATUS=org.apache.log4j.RollingFileAppender
+log4j.appender.STATUS.File=${lens.log.dir}/lensserver-query-status.log
+log4j.appender.STATUS.layout=org.apache.log4j.PatternLayout
+log4j.appender.STATUS.layout.ConversionPattern=%d{dd MMM yyyy HH:mm:ss,SSS} [%t] %-5p %c %x - %m%n
+
+log4j.appender.STATUS.MaxFileSize=100000KB
+# Keep 20 backup files
+log4j.appender.STATUS.MaxBackupIndex=20
+
+
+# Add query-statistics logger with hourly rollup
+log4j.appender.QueryExecutionStatistics=org.apache.log4j.DailyRollingFileAppender
+log4j.appender.QueryExecutionStatistics.DatePattern='.'yyyy-MM-dd-HH
+log4j.appender.QueryExecutionStatistics.File=${lens.log.dir}/query-stats.log
+log4j.appender.QueryExecutionStatistics.layout=org.apache.lens.server.stats.store.log.StatisticsLogLayout