blob: 1ed2f3961ea0297e54f246d5f4a92a2052d305e1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.reef.examples.suspend;
import org.apache.reef.driver.client.JobMessageObserver;
import org.apache.reef.driver.context.ActiveContext;
import org.apache.reef.driver.context.ContextConfiguration;
import org.apache.reef.driver.evaluator.AllocatedEvaluator;
import org.apache.reef.driver.evaluator.EvaluatorDescriptor;
import org.apache.reef.driver.evaluator.EvaluatorRequest;
import org.apache.reef.driver.evaluator.EvaluatorRequestor;
import org.apache.reef.driver.task.*;
import org.apache.reef.io.checkpoint.fs.FSCheckPointServiceConfiguration;
import org.apache.reef.tang.Configuration;
import org.apache.reef.tang.JavaConfigurationBuilder;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.tang.annotations.Unit;
import org.apache.reef.tang.exceptions.BindException;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.remote.impl.ObjectSerializableCodec;
import org.apache.reef.wake.time.event.StartTime;
import org.apache.reef.wake.time.event.StopTime;

import javax.inject.Inject;
import javax.xml.bind.DatatypeConverter;

import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Suspend/resume example job driver. Executes a simple task in every allocated evaluator
 * and suspends/resumes those tasks in response to commands sent by the client.
 * <p>
 * Client commands have the form {@code "suspend <taskId>"} or {@code "resume <taskId>"}.
 */
@Unit
public class SuspendDriver {

  /**
   * Standard Java logger.
   */
  private static final Logger LOG = Logger.getLogger(SuspendDriver.class.getName());

  /**
   * Number of evaluators to request.
   */
  private static final int NUM_EVALUATORS = 2;

  /**
   * String codec is used to encode the results driver sends to the client
   * and to decode the commands the client sends to the driver.
   */
  private static final ObjectSerializableCodec<String> CODEC_STR = new ObjectSerializableCodec<>();

  /**
   * Integer codec is used to decode the results driver gets from the tasks.
   */
  private static final ObjectSerializableCodec<Integer> CODEC_INT = new ObjectSerializableCodec<>();

  /**
   * Job observer on the client.
   * We use it to send results from the driver back to the client.
   */
  private final JobMessageObserver jobMessageObserver;

  /**
   * Job driver uses EvaluatorRequestor to request Evaluators that will run the Tasks.
   */
  private final EvaluatorRequestor evaluatorRequestor;

  /**
   * TANG configuration merged into every submitted context: the file-system checkpoint
   * service configuration plus the NumCycles and Delay parameters read by the task.
   */
  private final Configuration contextConfig;

  /**
   * Map from task ID to the RunningTask instance (that can be suspended).
   * Single operations rely on the synchronized wrapper; updates that move a task between
   * this map and {@link #suspendedTasks} hold the {@code suspendedTasks} lock.
   */
  private final Map<String, RunningTask> runningTasks =
      Collections.synchronizedMap(new HashMap<String, RunningTask>());

  /**
   * Map from task ID to the SuspendedTask instance (that can be resumed).
   * Not thread-safe by itself: every access is guarded by {@code synchronized (suspendedTasks)}.
   */
  private final Map<String, SuspendedTask> suspendedTasks = new HashMap<>();

  /**
   * Job driver constructor.
   * All parameters are injected from TANG automatically.
   *
   * @param jobMessageObserver is used to send messages back to the client.
   * @param evaluatorRequestor is used to request Evaluators.
   * @param isLocal            value of the Launch.Local parameter; forwarded to the
   *                           checkpoint service as its {@code IS_LOCAL} setting.
   * @param numCycles          number of cycles to run in the task.
   * @param delay              delay in seconds between cycles in the task.
   * @throws RuntimeException wrapping a {@link BindException} if the configuration cannot be built.
   */
  @Inject
  SuspendDriver(
      final JobMessageObserver jobMessageObserver,
      final EvaluatorRequestor evaluatorRequestor,
      @Parameter(Launch.Local.class) final boolean isLocal,
      @Parameter(Launch.NumCycles.class) final int numCycles,
      @Parameter(Launch.Delay.class) final int delay) {

    this.jobMessageObserver = jobMessageObserver;
    this.evaluatorRequestor = evaluatorRequestor;

    try {
      final Configuration checkpointServiceConfig = FSCheckPointServiceConfiguration.CONF
          .set(FSCheckPointServiceConfiguration.IS_LOCAL, Boolean.toString(isLocal))
          .set(FSCheckPointServiceConfiguration.PATH, "/tmp")
          .set(FSCheckPointServiceConfiguration.PREFIX, "reef-checkpoint-")
          .set(FSCheckPointServiceConfiguration.REPLICATION_FACTOR, "3")
          .build();

      final JavaConfigurationBuilder cb = Tang.Factory.getTang().newConfigurationBuilder()
          .bindNamedParameter(Launch.NumCycles.class, Integer.toString(numCycles))
          .bindNamedParameter(Launch.Delay.class, Integer.toString(delay));

      cb.addConfiguration(checkpointServiceConfig);
      this.contextConfig = cb.build();

    } catch (final BindException ex) {
      throw new RuntimeException(ex);
    }
  }

  /**
   * Receive notification that the Task is ready to run:
   * register it as running and tell the client.
   */
  final class RunningTaskHandler implements EventHandler<RunningTask> {
    @Override
    public void onNext(final RunningTask task) {
      LOG.log(Level.INFO, "Running task: {0}", task.getId());
      runningTasks.put(task.getId(), task);
      jobMessageObserver.sendMessageToClient(CODEC_STR.encode("start task: " + task.getId()));
    }
  }

  /**
   * Receive notification that the Task has completed successfully:
   * report it to the client, forget it, and close its context.
   */
  final class CompletedTaskHandler implements EventHandler<CompletedTask> {
    @Override
    public void onNext(final CompletedTask task) {

      final EvaluatorDescriptor e = task.getActiveContext().getEvaluatorDescriptor();
      final String msg = "Task completed " + task.getId() + " on node " + e;
      LOG.info(msg);

      jobMessageObserver.sendMessageToClient(CODEC_STR.encode(msg));
      task.getActiveContext().close();

      final boolean noTasks;
      // Remove under the suspendedTasks lock (as SuspendedTaskHandler does) so the two
      // map sizes below are read as one consistent snapshot.
      synchronized (suspendedTasks) {
        runningTasks.remove(task.getId());
        LOG.log(Level.INFO, "Tasks running: {0} suspended: {1}", new Object[]{
            runningTasks.size(), suspendedTasks.size()});
        noTasks = runningTasks.isEmpty() && suspendedTasks.isEmpty();
      }

      // Informational only: no explicit shutdown call is made here.
      if (noTasks) {
        LOG.info("All tasks completed; shutting down.");
      }
    }
  }

  /**
   * Receive notification that the Task has been suspended:
   * move it from the running map to the suspended map and tell the client.
   */
  final class SuspendedTaskHandler implements EventHandler<SuspendedTask> {
    @Override
    public void onNext(final SuspendedTask task) {
      final String msg = "Task suspended: " + task.getId();
      LOG.info(msg);
      synchronized (suspendedTasks) {
        suspendedTasks.put(task.getId(), task);
        runningTasks.remove(task.getId());
      }
      jobMessageObserver.sendMessageToClient(CODEC_STR.encode(msg));
    }
  }

  /**
   * Receive message from the Task: decode it as an integer and forward it to the client.
   */
  final class TaskMessageHandler implements EventHandler<TaskMessage> {
    @Override
    public void onNext(final TaskMessage message) {
      final int result = CODEC_INT.decode(message.get());
      final String msg = "Task message " + message.getId() + ": " + result;
      LOG.info(msg);
      jobMessageObserver.sendMessageToClient(CODEC_STR.encode(msg));
    }
  }

  /**
   * Receive notification that an Evaluator had been allocated,
   * and submit a new Context to that Evaluator. The context configuration combines a
   * per-evaluator identifier ({@code <evaluatorId>_context}) with {@link #contextConfig}.
   */
  final class AllocatedEvaluatorHandler implements EventHandler<AllocatedEvaluator> {
    @Override
    public void onNext(final AllocatedEvaluator eval) {
      try {

        LOG.log(Level.INFO, "Allocated Evaluator: {0}", eval.getId());

        final Configuration thisContextConfiguration = ContextConfiguration.CONF.set(
            ContextConfiguration.IDENTIFIER, eval.getId() + "_context").build();

        eval.submitContext(Tang.Factory.getTang()
            .newConfigurationBuilder(thisContextConfiguration, contextConfig).build());

      } catch (final BindException ex) {
        throw new RuntimeException(ex);
      }
    }
  }

  /**
   * Receive notification that a new Context is available.
   * Submit a new suspendable Task ({@code <contextId>_task}) to that Context.
   */
  final class ActiveContextHandler implements EventHandler<ActiveContext> {
    @Override
    public synchronized void onNext(final ActiveContext context) {
      LOG.log(Level.INFO, "Active Context: {0}", context.getId());
      try {
        context.submitTask(TaskConfiguration.CONF
            .set(TaskConfiguration.IDENTIFIER, context.getId() + "_task")
            .set(TaskConfiguration.TASK, SuspendTestTask.class)
            .set(TaskConfiguration.ON_SUSPEND, SuspendTestTask.SuspendHandler.class)
            .build());
      } catch (final BindException ex) {
        LOG.log(Level.SEVERE, "Bad Task configuration for context: " + context.getId(), ex);
        throw new RuntimeException(ex);
      }
    }
  }

  /**
   * Handle notifications from the client.
   * Expected format: {@code "<command> <taskId>"}, where command is {@code suspend} or
   * {@code resume} (case-insensitive).
   *
   * @throws IllegalArgumentException on a malformed or unknown command, or an unknown task ID.
   */
  final class ClientMessageHandler implements EventHandler<byte[]> {
    @Override
    public void onNext(final byte[] message) {

      final String commandStr = CODEC_STR.decode(message);
      LOG.log(Level.INFO, "Client message: {0}", commandStr);

      // Trim first so leading/trailing whitespace neither yields an empty command
      // nor becomes part of the task ID.
      final String[] split = commandStr.trim().split("\\s+", 2);
      if (split.length != 2) {
        throw new IllegalArgumentException("Bad command: " + commandStr);
      }

      // Locale.ROOT keeps the match locale-independent (e.g. Turkish dotted/dotless i).
      final String command = split[0].toLowerCase(Locale.ROOT);
      final String taskId = split[1];

      switch (command) {
        case "suspend":
          suspendTask(taskId);
          break;
        case "resume":
          resumeTask(taskId);
          break;
        default:
          throw new IllegalArgumentException("Bad command: " + command);
      }
    }

    /** Ask the running task with the given ID to suspend itself. */
    private void suspendTask(final String taskId) {
      final RunningTask task = runningTasks.get(taskId);
      if (task == null) {
        throw new IllegalArgumentException("Suspend: Task not found: " + taskId);
      }
      task.suspend();
    }

    /** Resubmit a suspended task to its context, passing its suspend message as the memento. */
    private void resumeTask(final String taskId) {
      final SuspendedTask suspendedTask;
      synchronized (suspendedTasks) {
        suspendedTask = suspendedTasks.remove(taskId);
      }
      if (suspendedTask == null) {
        throw new IllegalArgumentException("Resume: Task not found: " + taskId);
      }
      try {
        // TaskConfiguration.TASK is a required option, so the task class must be set again
        // or build() fails. ON_SUSPEND is re-bound so the resumed task can be suspended again.
        suspendedTask.getActiveContext().submitTask(TaskConfiguration.CONF
            .set(TaskConfiguration.IDENTIFIER, taskId)
            .set(TaskConfiguration.TASK, SuspendTestTask.class)
            .set(TaskConfiguration.ON_SUSPEND, SuspendTestTask.SuspendHandler.class)
            .set(TaskConfiguration.MEMENTO,
                DatatypeConverter.printBase64Binary(suspendedTask.get()))
            .build());
      } catch (final BindException e) {
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * Job Driver is ready and the clock is set up: request the evaluators.
   */
  final class StartHandler implements EventHandler<StartTime> {
    @Override
    public void onNext(final StartTime time) {
      LOG.log(Level.INFO, "StartTime: {0}", time);
      evaluatorRequestor.submit(EvaluatorRequest.newBuilder()
          .setMemory(128).setNumberOfCores(1).setNumber(NUM_EVALUATORS).build());
    }
  }

  /**
   * Shutting down the job driver: log the stop time and notify the client.
   */
  final class StopHandler implements EventHandler<StopTime> {
    @Override
    public void onNext(final StopTime time) {
      LOG.log(Level.INFO, "StopTime: {0}", time);
      jobMessageObserver.sendMessageToClient(CODEC_STR.encode("got StopTime"));
    }
  }
}