blob: 3bb4c812da5fffb112921e15def092fb630d302c [file] [log] [blame]
/*
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.airavata.helix.impl.workflow;
import org.apache.airavata.common.exception.AiravataException;
import org.apache.airavata.common.exception.ApplicationSettingsException;
import org.apache.airavata.common.utils.ServerSettings;
import org.apache.airavata.common.utils.ThriftUtils;
import org.apache.airavata.helix.core.AbstractTask;
import org.apache.airavata.helix.core.OutPort;
import org.apache.airavata.helix.impl.task.AiravataTask;
import org.apache.airavata.helix.impl.task.cancel.CancelCompletingTask;
import org.apache.airavata.helix.impl.task.cancel.RemoteJobCancellationTask;
import org.apache.airavata.helix.impl.task.cancel.WorkflowCancellationTask;
import org.apache.airavata.helix.impl.task.env.EnvSetupTask;
import org.apache.airavata.helix.impl.task.staging.InputDataStagingTask;
import org.apache.airavata.helix.impl.task.submission.DefaultJobSubmissionTask;
import org.apache.airavata.messaging.core.*;
import org.apache.airavata.model.experiment.ExperimentModel;
import org.apache.airavata.model.messaging.event.*;
import org.apache.airavata.model.process.ProcessModel;
import org.apache.airavata.model.process.ProcessWorkflow;
import org.apache.airavata.model.status.ProcessState;
import org.apache.airavata.model.status.ProcessStatus;
import org.apache.airavata.model.task.TaskModel;
import org.apache.airavata.model.task.TaskTypes;
import org.apache.airavata.patform.monitoring.CountMonitor;
import org.apache.airavata.patform.monitoring.MonitoringServer;
import org.apache.airavata.registry.api.RegistryService;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.stream.Collectors;
/**
 * Workflow manager that builds and launches the Helix "pre" workflow of a process (environment
 * setup, input data staging and job submission) as well as the cancellation workflow of a process.
 *
 * <p>Both are triggered by messages received through the process-launch subscriber: a
 * {@code LAUNCHPROCESS} message launches the pre workflow and a {@code TERMINATEPROCESS} message
 * launches the cancel workflow.
 */
public class PreWorkflowManager extends WorkflowManager {

    private static final Logger logger = LoggerFactory.getLogger(PreWorkflowManager.class);
    private static final CountMonitor prewfCounter = new CountMonitor("pre_wf_counter");

    private Subscriber subscriber;

    /**
     * Creates the manager from the {@code pre.workflow.manager.name} and
     * {@code pre.workflow.manager.loadbalance.clusters} server settings.
     *
     * @throws ApplicationSettingsException if a required setting cannot be read
     */
    public PreWorkflowManager() throws ApplicationSettingsException {
        super(ServerSettings.getSetting("pre.workflow.manager.name"),
                Boolean.parseBoolean(ServerSettings.getSetting("pre.workflow.manager.loadbalance.clusters")));
    }

    /**
     * Initializes the inherited workflow components and then subscribes to process messages.
     *
     * @throws Exception if component initialization or the subscription fails
     */
    public void startServer() throws Exception {
        super.initComponents();
        initLaunchSubscriber();
    }

    /**
     * Currently a no-op.
     *
     * <p>NOTE(review): the launch subscriber is not released here — confirm whether shutdown should
     * close it.
     */
    public void stopServer() {
    }

    /**
     * Subscribes {@link ProcessLaunchMessageHandler} to {@code PROCESS_LAUNCH} messages routed with
     * the configured RabbitMQ process exchange name.
     */
    private void initLaunchSubscriber() throws AiravataException {
        List<String> routingKeys = new ArrayList<>();
        routingKeys.add(ServerSettings.getRabbitmqProcessExchangeName());
        this.subscriber = MessagingFactory.getSubscriber(new ProcessLaunchMessageHandler(), routingKeys, Type.PROCESS_LAUNCH);
    }

    /**
     * Builds the pre workflow of a process from its task DAG and launches it.
     *
     * <p>Task ids are taken, in order, from the comma separated task DAG of the process and chained
     * one after another. Only {@code ENV_SETUP}, {@code JOB_SUBMISSION} and the {@code DATA_STAGING}
     * tasks that come before the job submission are included; other tasks are skipped.
     *
     * @param processId id of the process to launch
     * @param forceRun  force-run flag applied to the job submission task (environment setup and
     *                  input staging tasks are always force-run)
     * @return the name of the launched workflow, which is also registered for the process
     * @throws Exception if the process/experiment cannot be fetched, the process has no task DAG or
     *                   task list, or the workflow cannot be launched
     */
    private String createAndLaunchPreWorkflow(String processId, boolean forceRun) throws Exception {
        prewfCounter.inc();

        RegistryService.Client registryClient = getRegistryClientPool().getResource();
        ProcessModel processModel;
        ExperimentModel experimentModel;
        try {
            processModel = registryClient.getProcess(processId);
            experimentModel = registryClient.getExperiment(processModel.getExperimentId());
        } catch (Exception e) {
            logger.error("Failed to fetch experiment or process from registry associated with process id {}", processId, e);
            getRegistryClientPool().returnBrokenResource(registryClient);
            throw new Exception("Failed to fetch experiment or process from registry associated with process id " + processId, e);
        }
        // Returned outside the try so that a failure while returning cannot also mark it broken.
        getRegistryClientPool().returnResource(registryClient);

        String taskDag = processModel.getTaskDag();
        if (taskDag == null) {
            throw new Exception("No task DAG is defined for process " + processId);
        }
        List<TaskModel> taskList = processModel.getTasks();
        if (taskList == null) {
            throw new Exception("No tasks are defined for process " + processId);
        }

        final List<AiravataTask> allTasks = new ArrayList<>();
        // Once the job submission task is seen, later DATA_STAGING tasks are left out of this
        // workflow (presumably output staging is handled by a later workflow — confirm).
        boolean jobSubmissionFound = false;

        for (String taskId : taskDag.split(",")) {
            TaskModel taskModel = taskList.stream()
                    .filter(candidate -> taskId.equals(candidate.getTaskId()))
                    .findFirst()
                    .orElse(null);
            if (taskModel == null) {
                logger.warn("Task {} in the task DAG of process {} has no matching task model; skipping it", taskId, processId);
                continue;
            }

            AiravataTask airavataTask = null;
            if (taskModel.getTaskType() == TaskTypes.ENV_SETUP) {
                airavataTask = new EnvSetupTask();
                airavataTask.setForceRunTask(true);
            } else if (taskModel.getTaskType() == TaskTypes.JOB_SUBMISSION) {
                airavataTask = new DefaultJobSubmissionTask();
                airavataTask.setForceRunTask(forceRun);
                jobSubmissionFound = true;
            } else if (taskModel.getTaskType() == TaskTypes.DATA_STAGING && !jobSubmissionFound) {
                airavataTask = new InputDataStagingTask();
                airavataTask.setForceRunTask(true);
            }

            if (airavataTask == null) {
                continue;
            }

            airavataTask.setGatewayId(experimentModel.getGatewayId());
            airavataTask.setExperimentId(experimentModel.getExperimentId());
            airavataTask.setProcessId(processModel.getProcessId());
            airavataTask.setTaskId(taskModel.getTaskId());
            airavataTask.setRetryCount(taskModel.getMaxRetry());
            appendToChain(allTasks, airavataTask);
        }

        String workflowName = getWorkflowOperator().launchWorkflow(processId + "-PRE-" + UUID.randomUUID().toString(),
                new ArrayList<>(allTasks), true, false);
        registerWorkflowForProcess(processId, workflowName, "PRE");
        return workflowName;
    }

    /**
     * Builds and launches the cancellation workflow of a process.
     *
     * <p>The workflow consists of one {@link WorkflowCancellationTask} per workflow registered with
     * the process, followed by a {@link RemoteJobCancellationTask} and a
     * {@link CancelCompletingTask}, all chained in that order.
     *
     * @param processId id of the process to cancel
     * @param gateway   gateway id set on the remote-job cancellation and completing tasks
     * @return the name of the launched cancellation workflow
     * @throws Exception if the process cannot be fetched or the workflow cannot be launched
     */
    private String createAndLaunchCancelWorkflow(String processId, String gateway) throws Exception {
        RegistryService.Client registryClient = getRegistryClientPool().getResource();
        ProcessModel processModel;
        try {
            processModel = registryClient.getProcess(processId);
        } catch (Exception e) {
            logger.error("Failed to fetch process from registry associated with process id {}", processId, e);
            getRegistryClientPool().returnBrokenResource(registryClient);
            throw new Exception("Failed to fetch process from registry associated with process id " + processId, e);
        }
        // Returned outside the try so that a failure while returning cannot also mark it broken.
        getRegistryClientPool().returnResource(registryClient);

        String experimentId = processModel.getExperimentId();
        final List<AbstractTask> allTasks = new ArrayList<>();

        List<ProcessWorkflow> processWorkflows = processModel.getProcessWorkflows();
        if (processWorkflows == null || processWorkflows.isEmpty()) {
            logger.warn("No workflow registered with process {} to cancel", processId);
        } else {
            for (ProcessWorkflow processWorkflow : processWorkflows) {
                String wf = processWorkflow.getWorkflowId();
                logger.info("Creating cancellation task for workflow {} of process {}", wf, processId);
                WorkflowCancellationTask wfct = new WorkflowCancellationTask();
                wfct.setTaskId(UUID.randomUUID().toString());
                wfct.setCancellingWorkflowName(wf);
                appendToChain(allTasks, wfct);
            }
        }

        RemoteJobCancellationTask rjct = new RemoteJobCancellationTask();
        rjct.setTaskId(UUID.randomUUID().toString());
        rjct.setExperimentId(experimentId);
        rjct.setProcessId(processId);
        rjct.setGatewayId(gateway);
        rjct.setSkipTaskStatusPublish(true);
        appendToChain(allTasks, rjct);

        CancelCompletingTask cct = new CancelCompletingTask();
        cct.setTaskId(UUID.randomUUID().toString());
        cct.setExperimentId(experimentId);
        cct.setProcessId(processId);
        cct.setGatewayId(gateway);
        cct.setSkipTaskStatusPublish(true);
        appendToChain(allTasks, cct);

        String workflow = getWorkflowOperator().launchWorkflow(processId + "-CANCEL-" + UUID.randomUUID().toString(), allTasks, true, false);
        logger.info("Started launching workflow {} to cancel process {}", workflow, processId);
        return workflow;
    }

    /**
     * Makes {@code task} the successor of the current last task of {@code chain} (if any) and then
     * appends it, so the list always describes a linear chain of tasks.
     */
    private static <T extends AbstractTask> void appendToChain(List<T> chain, T task) {
        if (!chain.isEmpty()) {
            chain.get(chain.size() - 1).setNextTask(new OutPort(task.getTaskId(), task));
        }
        chain.add(task);
    }

    /**
     * Entry point. Optionally starts the monitoring server, controlled by the
     * {@code pre.workflow.manager.monitoring.*} settings, and then starts the pre workflow manager.
     */
    public static void main(String[] args) throws Exception {
        if (ServerSettings.getBooleanSetting("pre.workflow.manager.monitoring.enabled")) {
            MonitoringServer monitoringServer = new MonitoringServer(
                    ServerSettings.getSetting("pre.workflow.manager.monitoring.host"),
                    ServerSettings.getIntSetting("pre.workflow.manager.monitoring.port"));
            monitoringServer.start();
            Runtime.getRuntime().addShutdownHook(new Thread(monitoringServer::stop));
        }

        PreWorkflowManager preWorkflowManager = new PreWorkflowManager();
        preWorkflowManager.startServer();
    }

    /**
     * Handles process launch and terminate messages by launching the matching workflow.
     *
     * <p>A message is acknowledged when its workflow was launched, when it cannot be decoded, or
     * when its type is unknown.
     */
    private class ProcessLaunchMessageHandler implements MessageHandler {

        @Override
        public void onMessage(MessageContext messageContext) {
            logger.info(" Message Received with message id {} and with message type: {}",
                    messageContext.getMessageId(), messageContext.getType());

            if (messageContext.getType().equals(MessageType.LAUNCHPROCESS)) {
                handleLaunch(messageContext);
            } else if (messageContext.getType().equals(MessageType.TERMINATEPROCESS)) {
                handleTerminate(messageContext);
            } else {
                logger.warn("Unknown message type");
                subscriber.sendAck(messageContext.getDeliveryTag());
            }
        }

        /** Decodes a {@link ProcessSubmitEvent}, launches the pre workflow and publishes STARTED. */
        private void handleLaunch(MessageContext messageContext) {
            ProcessSubmitEvent event = new ProcessSubmitEvent();
            TBase messageEvent = messageContext.getEvent();
            try {
                byte[] bytes = ThriftUtils.serializeThriftObject(messageEvent);
                ThriftUtils.createThriftFromBytes(bytes, event);
            } catch (TException e) {
                logger.error("Failed to fetch process submit event", e);
                // An undecodable message can never succeed: drop it instead of launching with empty ids.
                subscriber.sendAck(messageContext.getDeliveryTag());
                return;
            }

            String processId = event.getProcessId();
            String experimentId = event.getExperimentId();
            String gateway = event.getGatewayId();
            logger.info("Received process launch message for process {} of experiment {} in gateway {}",
                    processId, experimentId, gateway);

            try {
                logger.info("Launching the pre workflow for process {} of experiment {} in gateway {}",
                        processId, experimentId, gateway);
                String workflowName = createAndLaunchPreWorkflow(processId, false);
                logger.info("Completed launching the pre workflow {} for process {} of experiment {} in gateway {}",
                        workflowName, processId, experimentId, gateway);
                // NOTE(review): if publishing fails after a successful launch, the message stays
                // unacked and may lead to a second launch — confirm this is acceptable.
                publishProcessStatus(processId, experimentId, gateway, ProcessState.STARTED);
                subscriber.sendAck(messageContext.getDeliveryTag());
            } catch (Exception e) {
                logger.error("Failed to launch the pre workflow for process {} in gateway {}", processId, gateway, e);
                // Deliberately not acknowledged, presumably so the broker can redeliver it.
                // NOTE(review): confirm against the Subscriber's redelivery semantics.
            }
        }

        /** Decodes a {@link ProcessTerminateEvent} and launches the cancellation workflow. */
        private void handleTerminate(MessageContext messageContext) {
            ProcessTerminateEvent event = new ProcessTerminateEvent();
            TBase messageEvent = messageContext.getEvent();
            try {
                byte[] bytes = ThriftUtils.serializeThriftObject(messageEvent);
                ThriftUtils.createThriftFromBytes(bytes, event);
            } catch (TException e) {
                logger.error("Failed to fetch process cancellation event", e);
                // An undecodable message can never succeed: drop it instead of cancelling with empty ids.
                subscriber.sendAck(messageContext.getDeliveryTag());
                return;
            }

            String processId = event.getProcessId();
            String gateway = event.getGatewayId();
            logger.info("Received process cancel message for process {} in gateway {}", processId, gateway);

            try {
                logger.info("Launching the process cancel workflow for process {} in gateway {}", processId, gateway);
                String workflowName = createAndLaunchCancelWorkflow(processId, gateway);
                logger.info("Completed process cancel workflow {} for process {} in gateway {}",
                        workflowName, processId, gateway);
                subscriber.sendAck(messageContext.getDeliveryTag());
            } catch (Exception e) {
                logger.error("Failed to launch process cancel workflow for process {} in gateway {}", processId, gateway, e);
                // Deliberately not acknowledged, presumably so the broker can redeliver it.
                // NOTE(review): confirm against the Subscriber's redelivery semantics.
            }
        }
    }
}