blob: 0bacbfde9522720e9ced55b067788c0a00b9f10a [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.submarine.server.submitter.k8s.model.xgboostjob;
import com.google.gson.annotations.SerializedName;
import io.kubernetes.client.custom.V1Patch;
import io.kubernetes.client.openapi.ApiException;
import io.kubernetes.client.openapi.models.V1PodTemplateSpec;
import io.kubernetes.client.openapi.models.V1Status;
import io.kubernetes.client.util.generic.options.CreateOptions;
import io.kubernetes.client.util.generic.options.PatchOptions;
import org.apache.submarine.commons.utils.exception.SubmarineRuntimeException;
import org.apache.submarine.server.api.common.CustomResourceType;
import org.apache.submarine.server.api.exception.InvalidSpecException;
import org.apache.submarine.server.api.experiment.Experiment;
import org.apache.submarine.server.api.spec.ExperimentSpec;
import org.apache.submarine.server.api.spec.ExperimentTaskSpec;
import org.apache.submarine.server.submitter.k8s.client.K8sClient;
import org.apache.submarine.server.submitter.k8s.model.mljob.MLJob;
import org.apache.submarine.server.submitter.k8s.model.mljob.MLJobReplicaSpec;
import org.apache.submarine.server.submitter.k8s.parser.ExperimentSpecParser;
import org.apache.submarine.server.submitter.k8s.util.JsonUtils;
import org.apache.submarine.server.submitter.k8s.util.MLJobConverter;
import org.apache.submarine.server.submitter.k8s.util.YamlUtils;
import java.util.HashMap;
import java.util.Map;
public class XGBoostJob extends MLJob {
public static final String CRD_XGBOOST_KIND_V1 = "XGBoostJob";
public static final String CRD_XGBOOST_PLURAL_V1 = "xgboostjobs";
public static final String CRD_XGBOOST_GROUP_V1 = "kubeflow.org";
public static final String CRD_XGBOOST_VERSION_V1 = "v1";
public static final String CRD_XGBOOST_API_VERSION_V1 = CRD_XGBOOST_GROUP_V1 +
"/" + CRD_XGBOOST_VERSION_V1;
@SerializedName("spec")
private XGBoostJobSpec spec;
public XGBoostJob(ExperimentSpec experimentSpec) {
super(experimentSpec);
setApiVersion(CRD_XGBOOST_API_VERSION_V1);
setKind(CRD_XGBOOST_KIND_V1);
setPlural(CRD_XGBOOST_PLURAL_V1);
setVersion(CRD_XGBOOST_VERSION_V1);
setGroup(CRD_XGBOOST_GROUP_V1);
// set spec
setSpec(parseXGBoostJobSpec(experimentSpec));
}
private XGBoostJobSpec parseXGBoostJobSpec(ExperimentSpec experimentSpec)
throws InvalidSpecException {
XGBoostJobSpec xGBoostJobSpec = new XGBoostJobSpec();
Map<XGBoostJobReplicaType, MLJobReplicaSpec> replicaSpecMap = new HashMap<>();
for (Map.Entry<String, ExperimentTaskSpec> entry : experimentSpec.getSpec().entrySet()) {
String replicaType = entry.getKey();
ExperimentTaskSpec taskSpec = entry.getValue();
if (XGBoostJobReplicaType.isSupportedReplicaType(replicaType)) {
MLJobReplicaSpec replicaSpec = new MLJobReplicaSpec();
replicaSpec.setReplicas(taskSpec.getReplicas());
V1PodTemplateSpec podTemplateSpec = ExperimentSpecParser.parseTemplateSpec(taskSpec, experimentSpec);
replicaSpec.setTemplate(podTemplateSpec);
replicaSpecMap.put(XGBoostJobReplicaType.valueOf(replicaType), replicaSpec);
} else {
throw new InvalidSpecException("Unrecognized replica type name: " +
entry.getKey() +
", it should be " +
String.join(",", XGBoostJobReplicaType.names()) +
" for XGBoost experiment.");
}
}
xGBoostJobSpec.setReplicaSpecs(replicaSpecMap);
return xGBoostJobSpec;
}
/**
* Get the job spec which contains all the info for XGBoostJob.
* @return job spec
*/
public XGBoostJobSpec getSpec() {
return spec;
}
/**
* Set the spec, the entry of the XGBoostJob
* @param spec job spec
*/
public void setSpec(XGBoostJobSpec spec) {
this.spec = spec;
}
@Override
public CustomResourceType getResourceType() {
return CustomResourceType.XGBoost;
}
@Override
public Experiment read(K8sClient api) {
try {
XGBoostJob xgBoostJob = api.getXGBoostJobClient()
.get(getMetadata().getNamespace(), getMetadata().getName())
.throwsApiException().getObject();
if (LOG.isDebugEnabled()) {
LOG.debug("Get XGBoostJob resource: \n{}", YamlUtils.toPrettyYaml(xgBoostJob));
}
return parseExperimentResponseObject(xgBoostJob, XGBoostJob.class);
} catch (ApiException e) {
throw new SubmarineRuntimeException(e.getCode(), e.getMessage());
}
}
@Override
public Experiment create(K8sClient api) {
try {
if (LOG.isDebugEnabled()) {
LOG.debug("Create XGBoostJob resource: \n{}", YamlUtils.toPrettyYaml(this));
}
XGBoostJob xgBoostJob = api.getXGBoostJobClient()
.create(getMetadata().getNamespace(), this, new CreateOptions())
.throwsApiException().getObject();
return parseExperimentResponseObject(xgBoostJob, XGBoostJob.class);
} catch (ApiException e) {
LOG.error("K8s submitter: parse XGBoostJob object failed by " + e.getMessage(), e);
throw new SubmarineRuntimeException(e.getCode(), "K8s submitter: parse XGBoostJob object failed by " +
e.getMessage());
}
}
@Override
public Experiment replace(K8sClient api) {
try {
// Using apply yaml patch, field manager must be set, and it must be forced.
// https://kubernetes.io/docs/reference/using-api/server-side-apply/#field-management
PatchOptions patchOptions = new PatchOptions();
patchOptions.setFieldManager(getExperimentId());
patchOptions.setForce(true);
if (LOG.isDebugEnabled()) {
LOG.debug("Patch XGBoostJob resource: \n{}", YamlUtils.toPrettyYaml(this));
}
XGBoostJob xgBoostJob = api.getXGBoostJobClient()
.patch(getMetadata().getNamespace(), getMetadata().getName(),
V1Patch.PATCH_FORMAT_APPLY_YAML,
new V1Patch(JsonUtils.toJson(this)),
patchOptions)
.throwsApiException().getObject();
return parseExperimentResponseObject(xgBoostJob, XGBoostJob.class);
} catch (ApiException e) {
throw new SubmarineRuntimeException(e.getCode(), e.getMessage());
}
}
@Override
public Experiment delete(K8sClient api) {
try {
if (LOG.isDebugEnabled()) {
LOG.debug("Delete XGBoostJob resource in namespace: {} and name: {}",
this.getMetadata().getNamespace(), this.getMetadata().getName());
}
V1Status status = api.getXGBoostJobClient()
.delete(getMetadata().getNamespace(), getMetadata().getName(),
MLJobConverter.toDeleteOptionsFromMLJob(this))
.throwsApiException().getStatus();
return parseExperimentResponseStatus(status);
} catch (ApiException e) {
throw new SubmarineRuntimeException(e.getCode(), e.getMessage());
}
}
}