blob: cc57dca53c2e11e486bde15a2f385e2d3c96827c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.submarine.server.experiment;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import javax.ws.rs.core.Response.Status;
import com.google.common.annotations.VisibleForTesting;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import org.apache.submarine.commons.utils.SubmarineConfiguration;
import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
import org.apache.submarine.server.SubmarineServer;
import org.apache.submarine.server.SubmitterManager;
import org.apache.submarine.server.api.experiment.Experiment;
import org.apache.submarine.server.api.experiment.ExperimentId;
import org.apache.submarine.server.api.Submitter;
import org.apache.submarine.server.api.experiment.ExperimentLog;
import org.apache.submarine.server.api.experiment.TensorboardInfo;
import org.apache.submarine.server.api.experiment.MlflowInfo;
import org.apache.submarine.server.api.experiment.ServeRequest;
import org.apache.submarine.server.api.experiment.ServeResponse;
import org.apache.submarine.server.api.spec.ExperimentSpec;
import org.apache.submarine.server.experiment.database.ExperimentEntity;
import org.apache.submarine.server.experiment.database.ExperimentService;
import org.apache.submarine.server.rest.RestConstants;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.mlflow.tracking.MlflowClient;
/**
 * Manages the experiment lifecycle (create, get, list, patch, delete) and the related
 * log / tensorboard / mlflow / serve operations.
 *
 * <p>Experiment specs are persisted through {@link ExperimentService}. Runtime state is fetched
 * from the {@link Submitter} and merged into the returned {@link Experiment}. A single shared
 * instance is obtained via {@link #getInstance()}, so per-request data is kept in local variables
 * rather than in fields.
 */
public class ExperimentManager {
  private static final Logger LOG = LoggerFactory.getLogger(ExperimentManager.class);

  /** MLflow tracking server whose experiment is removed when a Submarine experiment is deleted. */
  private static final String MLFLOW_TRACKING_URI = "http://submarine-mlflow-service:5000";

  /**
   * Shared Gson instance (Gson instances are thread-safe). HTML escaping is disabled so specs are
   * serialized verbatim; the flag does not affect deserialization.
   */
  private static final Gson GSON = new GsonBuilder().disableHtmlEscaping().create();

  private static volatile ExperimentManager manager;

  /** Per-server sequence number used when generating experiment ids. */
  private final AtomicInteger experimentCounter = new AtomicInteger(0);

  /**
   * Used to cache the specs by the experiment id.
   * key: the string of experiment id
   * value: Experiment object
   *
   * <p>NOTE(review): this map is not read or written anywhere in this class yet.
   */
  private final ConcurrentMap<String, Experiment> cachedExperimentMap = new ConcurrentHashMap<>();

  private final Submitter submitter;

  private final ExperimentService experimentService;

  /**
   * Get the singleton instance, creating it lazily on first use.
   *
   * @return object
   */
  public static ExperimentManager getInstance() {
    if (manager == null) {
      synchronized (ExperimentManager.class) {
        if (manager == null) {
          manager = new ExperimentManager(SubmitterManager.loadSubmitter(), new ExperimentService());
        }
      }
    }
    return manager;
  }

  @VisibleForTesting
  protected ExperimentManager(Submitter submitter, ExperimentService experimentService) {
    this.submitter = submitter;
    this.experimentService = experimentService;
  }

  /**
   * Create experiment.
   *
   * <p>Generates a new id, lower-cases the experiment name, temporarily injects the job id,
   * tracking URI and log dir as env vars for the submitter, then persists the spec (without
   * those injected env vars).
   *
   * @param spec spec
   * @return object
   * @throws SubmarineRuntimeException the service error
   */
  public Experiment createExperiment(ExperimentSpec spec) throws SubmarineRuntimeException {
    checkSpec(spec);
    // Submarine sdk will get experimentID and JDBC URL from environment variables in each worker,
    // and then log experiment metrics and parameters to submarine server
    ExperimentId id = generateExperimentId();
    String url = getSQLAlchemyURL();
    // Locale.ROOT: the result must not depend on the server's default locale.
    spec.getMeta().setName(spec.getMeta().getName().toLowerCase(Locale.ROOT));
    spec.getMeta().setExperimentId(id.toString().replace('_', '-'));

    spec.getMeta().getEnvVars().put(RestConstants.JOB_ID, id.toString());
    spec.getMeta().getEnvVars().put(RestConstants.SUBMARINE_TRACKING_URI, url);
    spec.getMeta().getEnvVars().put(RestConstants.LOG_DIR_KEY, RestConstants.LOG_DIR_VALUE);
    LOG.info("Create experiment: {}", spec.getMeta().getExperimentId());

    Experiment experiment;
    try {
      experiment = submitter.createExperiment(spec);
    } finally {
      // Always strip the injected variables (the tracking URI embeds the DB password) so they
      // neither leak into the caller's spec on failure nor get persisted below.
      removeInjectedEnvVars(spec);
    }
    experiment.setExperimentId(id);
    experiment.setSpec(spec);
    experimentService.insert(buildEntityFromExperiment(experiment));
    return experiment;
  }

  /**
   * Get experiment, refreshed with the state reported by the submitter.
   *
   * @param id experiment id
   * @return object
   * @throws SubmarineRuntimeException the service error (404 if the id is unknown)
   */
  public Experiment getExperiment(String id) throws SubmarineRuntimeException {
    ExperimentEntity entity = checkExperimentId(id);
    Experiment experiment = buildExperimentFromEntity(entity);
    Experiment foundExperiment = submitter.findExperiment(experiment.getSpec());
    experiment.rebuild(foundExperiment);
    return experiment;
  }

  /**
   * List experiments. Stored experiments that the submitter can no longer find are deleted
   * from the database and omitted from the result.
   *
   * @param status status, if null will return all experiments (compared case-insensitively)
   * @return list
   * @throws SubmarineRuntimeException the service error
   */
  public List<Experiment> listExperimentsByStatus(String status) throws SubmarineRuntimeException {
    List<Experiment> experimentList = new ArrayList<>();
    for (ExperimentEntity entity : experimentService.selectAll()) {
      Experiment experiment = buildExperimentFromEntity(entity);
      Experiment foundExperiment;
      try {
        foundExperiment = submitter.findExperiment(experiment.getSpec());
      } catch (SubmarineRuntimeException e) {
        LOG.warn("Submitter can not find experiment: {}, will delete it", entity.getId());
        experimentService.delete(entity.getId());
        continue;
      }
      LOG.info("Found experiment: {}, status: {}", entity.getId(), foundExperiment.getStatus());
      if (matchesStatus(status, foundExperiment)) {
        experiment.rebuild(foundExperiment);
        experimentList.add(experiment);
      }
    }
    LOG.info("List experiment: {}", experimentList.size());
    return experimentList;
  }

  /**
   * Patch the experiment: submits the new spec and stores it in place of the old one.
   *
   * @param id experiment id
   * @param newSpec spec
   * @return object
   * @throws SubmarineRuntimeException the service error
   */
  public Experiment patchExperiment(String id, ExperimentSpec newSpec)
      throws SubmarineRuntimeException {
    ExperimentEntity entity = checkExperimentId(id);
    checkSpec(newSpec);
    Experiment experiment = buildExperimentFromEntity(entity);
    Experiment patchExperiment = submitter.patchExperiment(newSpec);
    // update spec in returned experiment
    experiment.setSpec(newSpec);
    // update entity and commit
    entity.setExperimentSpec(GSON.toJson(newSpec));
    experimentService.update(entity);
    // patch new information in experiment
    experiment.rebuild(patchExperiment);
    return experiment;
  }

  /**
   * Delete experiment from the submitter and the database, then (best effort) remove the MLflow
   * experiment whose name equals the experiment id.
   *
   * @param id experiment id
   * @return object
   * @throws SubmarineRuntimeException the service error
   */
  public Experiment deleteExperiment(String id) throws SubmarineRuntimeException {
    ExperimentEntity entity = checkExperimentId(id);
    Experiment experiment = buildExperimentFromEntity(entity);
    Experiment deletedExperiment = submitter.deleteExperiment(experiment.getSpec());
    experimentService.delete(id);
    experiment.rebuild(deletedExperiment);
    deleteMlflowExperiment(id);
    return experiment;
  }

  /**
   * Best-effort removal of the MLflow experiment named {@code name}. Failures are logged and
   * never propagated, because the Submarine experiment has already been deleted at this point.
   */
  private void deleteMlflowExperiment(String name) {
    try {
      MlflowClient mlflowClient = new MlflowClient(MLFLOW_TRACKING_URI);
      Optional<org.mlflow.api.proto.Service.Experiment> mlflowExperiment =
          mlflowClient.getExperimentByName(name);
      if (mlflowExperiment.isPresent()) {
        mlflowClient.deleteExperiment(mlflowExperiment.get().getExperimentId());
      } else {
        LOG.debug("No MLflow experiment named {}, nothing to delete", name);
      }
    } catch (Exception e) {
      // Broad catch is deliberate: this is external cleanup at an API boundary.
      LOG.warn("Failed to delete MLflow experiment: {}", name, e);
    }
  }

  /**
   * List experiment logs. Stored experiments that the submitter can no longer find are skipped.
   *
   * @param status status, if null will return all experiment logs (compared case-insensitively)
   * @return log list
   * @throws SubmarineRuntimeException the service error
   */
  public List<ExperimentLog> listExperimentLogsByStatus(String status)
      throws SubmarineRuntimeException {
    List<ExperimentLog> experimentLogList = new ArrayList<>();
    for (ExperimentEntity entity : experimentService.selectAll()) {
      Experiment experiment = buildExperimentFromEntity(entity);
      Experiment foundExperiment;
      try {
        foundExperiment = submitter.findExperiment(experiment.getSpec());
      } catch (SubmarineRuntimeException e) {
        // One stale entity must not make the whole listing fail.
        LOG.warn("Submitter can not find experiment: {}, skip its log", entity.getId());
        continue;
      }
      LOG.info("Found experiment: {}, status: {}", entity.getId(), foundExperiment.getStatus());
      if (matchesStatus(status, foundExperiment)) {
        experiment.rebuild(foundExperiment);
        experimentLogList.add(submitter.getExperimentLogName(
            experiment.getSpec(),
            experiment.getExperimentId().toString()
        ));
      }
    }
    return experimentLogList;
  }

  /**
   * Get experiment log.
   *
   * @param id experiment id
   * @return object
   * @throws SubmarineRuntimeException the service error
   */
  public ExperimentLog getExperimentLog(String id) throws SubmarineRuntimeException {
    ExperimentEntity entity = checkExperimentId(id);
    Experiment experiment = buildExperimentFromEntity(entity);
    Experiment foundExperiment = submitter.findExperiment(experiment.getSpec());
    experiment.rebuild(foundExperiment);
    return submitter.getExperimentLog(
        experiment.getSpec(),
        experiment.getExperimentId().toString()
    );
  }

  /**
   * Get tensorboard meta data.
   *
   * @return tensorboardinfo
   * @throws SubmarineRuntimeException the service error
   */
  public TensorboardInfo getTensorboardInfo() throws SubmarineRuntimeException {
    return submitter.getTensorboardInfo();
  }

  /**
   * Get mlflow meta data.
   *
   * @return mlflowinfo
   * @throws SubmarineRuntimeException the service error
   */
  public MlflowInfo getMLflowInfo() throws SubmarineRuntimeException {
    return submitter.getMlflowInfo();
  }

  /**
   * Create serve.
   *
   * @param spec spec
   * @return object
   * @throws SubmarineRuntimeException the service error
   */
  public ServeResponse createServe(ServeRequest spec) throws SubmarineRuntimeException {
    // TODO(byronhsu): use mlflow api to make sure the model exists. Otherwise, raise exception.
    return submitter.createServe(spec);
  }

  /**
   * Delete serve.
   *
   * @param spec spec
   * @return object
   * @throws SubmarineRuntimeException the service error
   */
  public ServeResponse deleteServe(ServeRequest spec) throws SubmarineRuntimeException {
    return submitter.deleteServe(spec);
  }

  /** Rejects a missing spec with HTTP 400 (previously reported as 200 OK). */
  private void checkSpec(ExperimentSpec spec) throws SubmarineRuntimeException {
    if (spec == null) {
      throw new SubmarineRuntimeException(Status.BAD_REQUEST.getStatusCode(),
          "Invalid experiment spec.");
    }
  }

  /**
   * Looks up the experiment entity, failing with HTTP 404 if it does not exist.
   *
   * @param id experiment id
   * @return the stored entity (never null), so callers need not select it a second time
   */
  private ExperimentEntity checkExperimentId(String id) throws SubmarineRuntimeException {
    ExperimentEntity entity = experimentService.select(id);
    if (entity == null) {
      throw new SubmarineRuntimeException(Status.NOT_FOUND.getStatusCode(), "Not found experiment.");
    }
    return entity;
  }

  /** Returns true when {@code status} is null (no filter) or equals the experiment's status. */
  private static boolean matchesStatus(String status, Experiment experiment) {
    return status == null || status.equalsIgnoreCase(experiment.getStatus());
  }

  /** Removes the env vars that {@link #createExperiment} injects for the submitter. */
  private static void removeInjectedEnvVars(ExperimentSpec spec) {
    spec.getMeta().getEnvVars().remove(RestConstants.JOB_ID);
    spec.getMeta().getEnvVars().remove(RestConstants.SUBMARINE_TRACKING_URI);
    spec.getMeta().getEnvVars().remove(RestConstants.LOG_DIR_KEY);
  }

  /**
   * Builds a {@code mysql+pymysql://user:password@hostAndDatabase} URL from the configured JDBC
   * URL. Everything after {@code //} and before an optional {@code ?} query string is kept.
   *
   * <p>NOTE(review): user name and password are not percent-encoded, so credentials containing
   * characters such as {@code @} or {@code :} produce an unparsable URL.
   */
  private String getSQLAlchemyURL() {
    SubmarineConfiguration conf = SubmarineConfiguration.getInstance();
    String jdbcUrl = conf.getJdbcUrl();
    int schemeEnd = jdbcUrl.indexOf("//");
    int start = schemeEnd < 0 ? 0 : schemeEnd + 2;
    int queryStart = jdbcUrl.indexOf('?', start);
    // A JDBC URL without a query string is valid; previously indexOf returned -1 and crashed.
    int end = queryStart < 0 ? jdbcUrl.length() : queryStart;
    String hostAndDatabase = jdbcUrl.substring(start, end);
    String jdbcUserName = conf.getJdbcUserName();
    String jdbcPassword = conf.getJdbcPassword();
    return "mysql+pymysql://" + jdbcUserName + ":" + jdbcPassword + "@" + hostAndDatabase;
  }

  /**
   * Generates a new experiment id from the server start timestamp and a per-server counter.
   *
   * @return a fresh experiment id
   */
  public ExperimentId generateExperimentId() {
    return ExperimentId.newInstance(SubmarineServer.getServerTimeStamp(),
        experimentCounter.incrementAndGet());
  }

  /**
   * Create a new experiment instance from entity, and fill in
   * 1. experimentId
   * 2. spec
   *
   * @param entity stored entity whose spec column holds the JSON-serialized {@link ExperimentSpec}
   * @return Experiment
   */
  private Experiment buildExperimentFromEntity(ExperimentEntity entity) {
    Experiment experiment = new Experiment();
    experiment.setExperimentId(ExperimentId.fromString(entity.getId()));
    experiment.setSpec(GSON.fromJson(entity.getExperimentSpec(), ExperimentSpec.class));
    return experiment;
  }

  /**
   * Create an ExperimentEntity instance from experiment.
   *
   * @param experiment experiment whose id and spec are copied into the entity
   * @return ExperimentEntity
   */
  private ExperimentEntity buildEntityFromExperiment(Experiment experiment) {
    ExperimentEntity entity = new ExperimentEntity();
    entity.setId(experiment.getExperimentId().toString());
    entity.setExperimentSpec(GSON.toJson(experiment.getSpec()));
    return entity;
  }
}