blob: 1ed2f3961ea0297e54f246d5f4a92a2052d305e1 [file] [log] [blame]
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.reef.examples.suspend;
import org.apache.reef.driver.client.JobMessageObserver;
import org.apache.reef.driver.context.ActiveContext;
import org.apache.reef.driver.context.ContextConfiguration;
import org.apache.reef.driver.evaluator.AllocatedEvaluator;
import org.apache.reef.driver.evaluator.EvaluatorDescriptor;
import org.apache.reef.driver.evaluator.EvaluatorRequest;
import org.apache.reef.driver.evaluator.EvaluatorRequestor;
import org.apache.reef.driver.task.*;
import org.apache.reef.tang.Configuration;
import org.apache.reef.tang.JavaConfigurationBuilder;
import org.apache.reef.tang.Tang;
import org.apache.reef.tang.annotations.Parameter;
import org.apache.reef.tang.annotations.Unit;
import org.apache.reef.tang.exceptions.BindException;
import org.apache.reef.wake.EventHandler;
import org.apache.reef.wake.remote.impl.ObjectSerializableCodec;
import org.apache.reef.wake.time.event.StartTime;
import org.apache.reef.wake.time.event.StopTime;
import javax.inject.Inject;
import javax.xml.bind.DatatypeConverter;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * Suspend/resume example job driver. Execute a simple task in all evaluators,
 * and send EvaluatorControlMessage suspend/resume events properly.
 */
public class SuspendDriver {
* Standard Java logger.
// NOTE(review): the lines beginning with "* " in this region read like Javadoc text whose
// comment delimiters are missing from this copy; they are left untouched here.
private static final Logger LOG = Logger.getLogger(SuspendDriver.class.getName());
* Number of evaluators to request.
// NOTE(review): not referenced by any line visible in this copy — confirm where it is consumed.
private static final int NUM_EVALUATORS = 2;
* String codec is used to encode the results driver sends to the client.
// The same codec also decodes the command strings arriving from the client.
private static final ObjectSerializableCodec<String> CODEC_STR = new ObjectSerializableCodec<>();
* Integer codec is used to decode the results driver gets from the tasks.
private static final ObjectSerializableCodec<Integer> CODEC_INT = new ObjectSerializableCodec<>();
* Job observer on the client.
* We use it to send results from the driver back to the client.
private final JobMessageObserver jobMessageObserver;
* Job driver uses EvaluatorRequestor to request Evaluators that will run the Tasks.
private final EvaluatorRequestor evaluatorRequestor;
* TANG Configuration of the Task.
// Merged with each evaluator's ContextConfiguration when an evaluator is allocated.
private final Configuration contextConfig;
* Map from task ID (a string) to the TaskRuntime instance (that can be suspended).
// The values are RunningTask instances; the map is wrapped by Collections.synchronizedMap.
private final Map<String, RunningTask> runningTasks =
Collections.synchronizedMap(new HashMap<String, RunningTask>());
* Map from task ID (a string) to the SuspendedTask instance (that can be resumed).
// Plain HashMap: the accesses visible in this file sit under synchronized (suspendedTasks).
private final Map<String, SuspendedTask> suspendedTasks = new HashMap<>();
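/**
 * Job driver constructor.
 * All parameters are injected from TANG automatically.
 *
 * @param jobMessageObserver is used to send messages from the driver to the client.
 * @param evaluatorRequestor is used to request Evaluators.
 * @param isLocal passed to the checkpoint service configuration as its IS_LOCAL setting.
 * @param numCycles number of cycles to run in the task.
 * @param delay delay in seconds between cycles in the task.
 */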
// NOTE(review): the constructor's declaration line (and any annotation on it) does not appear
// above this parameter list in this copy — confirm against the upstream file.
final JobMessageObserver jobMessageObserver,
final EvaluatorRequestor evaluatorRequestor,
@Parameter(Launch.Local.class) final boolean isLocal,
@Parameter(Launch.NumCycles.class) final int numCycles,
@Parameter(Launch.Delay.class) final int delay) {
this.jobMessageObserver = jobMessageObserver;
this.evaluatorRequestor = evaluatorRequestor;
try {
// Checkpoint service settings: local/remote flag from isLocal, files under /tmp with the
// "reef-checkpoint-" prefix, replication factor 3.
final Configuration checkpointServiceConfig = FSCheckPointServiceConfiguration.CONF
.set(FSCheckPointServiceConfiguration.IS_LOCAL, Boolean.toString(isLocal))
.set(FSCheckPointServiceConfiguration.PATH, "/tmp")
.set(FSCheckPointServiceConfiguration.PREFIX, "reef-checkpoint-")
.set(FSCheckPointServiceConfiguration.REPLICATION_FACTOR, "3")
// NOTE(review): the chain above has no terminating call or semicolon in this copy — a line
// appears to be missing before the next statement.
// Bind the cycle count and delay as named parameters (values passed as strings).
final JavaConfigurationBuilder cb = Tang.Factory.getTang().newConfigurationBuilder()
.bindNamedParameter(Launch.NumCycles.class, Integer.toString(numCycles))
.bindNamedParameter(Launch.Delay.class, Integer.toString(delay));
// NOTE(review): the right-hand side of this assignment is missing in this copy, and neither
// checkpointServiceConfig nor cb is consumed by any visible line.
this.contextConfig =;
} catch (final BindException ex) {
// Configuration errors are rethrown unchecked, keeping the original exception as the cause.
throw new RuntimeException(ex);
// NOTE(review): the closing braces of the catch block and constructor are not present in this copy.
/**
 * Receive notification that the Task is ready to run.
 */
// Handler for RunningTask events: logs the task, records it, and notifies the client.
final class RunningTaskHandler implements EventHandler<RunningTask> {
public void onNext(final RunningTask task) {
LOG.log(Level.INFO, "Running task: {0}", task.getId());
// Register the task under its ID so that a later "suspend" command can look it up.
runningTasks.put(task.getId(), task);
// Report the start to the client; the text is serialized with the String codec.
jobMessageObserver.sendMessageToClient(CODEC_STR.encode("start task: " + task.getId()));
// NOTE(review): the closing braces of the method and class are not present in this copy.
/**
 * Receive notification that the Task has completed successfully.
 */
// Handler for CompletedTask events: builds a completion message and checks whether any
// running or suspended tasks remain.
final class CompletedTaskHandler implements EventHandler<CompletedTask> {
public void onNext(final CompletedTask task) {
// Descriptor of the evaluator that hosted the task's context; used only in the message text.
final EvaluatorDescriptor e = task.getActiveContext().getEvaluatorDescriptor();
// NOTE(review): the trailing ";;" leaves an empty statement, and msg is not used by any
// visible line — a statement that consumed it appears to be missing.
final String msg = "Task completed " + task.getId() + " on node " + e;;
final boolean noTasks;
// Read both maps while holding the suspendedTasks monitor.
synchronized (suspendedTasks) {
// NOTE(review): no visible line removes the completed task from runningTasks — confirm
// against the upstream file.
LOG.log(Level.INFO, "Tasks running: {0} suspended: {1}", new Object[]{
runningTasks.size(), suspendedTasks.size()});
noTasks = runningTasks.isEmpty() && suspendedTasks.isEmpty();
// NOTE(review): the next line is truncated in this copy — the call that receives the
// "All tasks completed" text is missing, as are all closing braces.
if (noTasks) {"All tasks completed; shutting down.");
/**
 * Receive notification that the Task has been suspended.
 */
// Handler for SuspendedTask events: records the task so that it can be resumed later.
final class SuspendedTaskHandler implements EventHandler<SuspendedTask> {
public void onNext(final SuspendedTask task) {
// NOTE(review): the trailing ";;" leaves an empty statement, and msg is not used by any
// visible line — a statement that consumed it appears to be missing.
final String msg = "Task suspended: " + task.getId();;
// Store the task under its ID while holding the suspendedTasks monitor; the "resume"
// command removes it from the same map.
synchronized (suspendedTasks) {
suspendedTasks.put(task.getId(), task);
// NOTE(review): no visible line removes the task from runningTasks, and the closing braces
// are not present in this copy.
/**
 * Receive message from the Task.
 */
// Handler for TaskMessage events: decodes the integer payload and formats a report string.
final class TaskMessageHandler implements EventHandler<TaskMessage> {
public void onNext(final TaskMessage message) {
// The payload bytes are decoded with the Integer codec and unboxed to int.
final int result = CODEC_INT.decode(message.get());
// NOTE(review): the trailing ";;" leaves an empty statement, and msg is not used by any
// visible line — a statement that consumed it appears to be missing, as are the closing braces.
final String msg = "Task message " + message.getId() + ": " + result;;
/**
 * Receive notification that an Evaluator has been allocated,
 * and submit a new Context to that Evaluator.
 */
// Handler for AllocatedEvaluator events: builds a context configuration for the new evaluator.
final class AllocatedEvaluatorHandler implements EventHandler<AllocatedEvaluator> {
public void onNext(final AllocatedEvaluator eval) {
try {
LOG.log(Level.INFO, "Allocated Evaluator: {0}", eval.getId());
// The context identifier is the evaluator ID followed by "_context".
final Configuration thisContextConfiguration = ContextConfiguration.CONF.set(
ContextConfiguration.IDENTIFIER, eval.getId() + "_context").build();
// NOTE(review): the start of the next statement is missing in this copy. What remains
// merges the per-evaluator configuration with the shared contextConfig.
.newConfigurationBuilder(thisContextConfiguration, contextConfig).build());
} catch (final BindException ex) {
// Configuration errors are rethrown unchecked, keeping the original exception as the cause.
throw new RuntimeException(ex);
// NOTE(review): the closing braces are not present in this copy.
/**
 * Receive notification that a new Context is available.
 * Submit a new Task to that Context.
 */
// Handler for ActiveContext events: configures a task for the newly active context.
final class ActiveContextHandler implements EventHandler<ActiveContext> {
// Declared synchronized, so calls on the same handler instance do not overlap.
public synchronized void onNext(final ActiveContext context) {
LOG.log(Level.INFO, "Active Context: {0}", context.getId());
try {
// NOTE(review): the receiver of this call chain and its terminating call are missing in
// this copy. The visible settings are: task ID = context ID + "_task", task class
// SuspendTestTask, and SuspendTestTask.SuspendHandler as the suspend handler.
.set(TaskConfiguration.IDENTIFIER, context.getId() + "_task")
.set(TaskConfiguration.TASK, SuspendTestTask.class)
.set(TaskConfiguration.ON_SUSPEND, SuspendTestTask.SuspendHandler.class)
} catch (final BindException ex) {
// Log the failure with the context ID, then rethrow unchecked with the original cause.
LOG.log(Level.SEVERE, "Bad Task configuration for context: " + context.getId(), ex);
throw new RuntimeException(ex);
// NOTE(review): the closing braces are not present in this copy.
/**
 * Handle notifications from the client.
 */
// Handler for byte[] messages from the client: decodes a "<command> <taskId>" string and
// dispatches on the command ("suspend" or "resume").
// Throws IllegalArgumentException for a malformed command, an unknown command, or an unknown
// task ID; BindException is rethrown as RuntimeException.
final class ClientMessageHandler implements EventHandler<byte[]> {
public void onNext(final byte[] message) {
final String commandStr = CODEC_STR.decode(message);
LOG.log(Level.INFO, "Client message: {0}", commandStr);
// Limit 2: the whitespace pattern is applied at most once, so the remainder of the string
// (including any further whitespace) becomes the task ID.
final String[] split = commandStr.split("\\s+", 2);
if (split.length != 2) {
throw new IllegalArgumentException("Bad command: " + commandStr);
} else {
// NOTE(review): toLowerCase() without a Locale argument uses the default locale, and
// intern() is not required for a switch on strings — consider Locale.ROOT and dropping intern().
final String command = split[0].toLowerCase().intern();
final String taskId = split[1];
switch (command) {
case "suspend": {
// Look up the task recorded by RunningTaskHandler.
final RunningTask task = runningTasks.get(taskId);
// NOTE(review): the non-null branch is empty in this copy — the call that acts on the
// task appears to be missing.
if (task != null) {
} else {
throw new IllegalArgumentException("Suspend: Task not found: " + taskId);
case "resume": {
final SuspendedTask suspendedTask;
// Remove the entry while holding the suspendedTasks monitor; the local reference is used
// after the lock is released.
synchronized (suspendedTasks) {
suspendedTask = suspendedTasks.remove(taskId);
if (suspendedTask != null) {
try {
// NOTE(review): the receiver of this call and the rest of the statement are missing in
// this copy; only the task identifier setting is visible.
.set(TaskConfiguration.IDENTIFIER, taskId)
} catch (final BindException e) {
throw new RuntimeException(e);
} else {
throw new IllegalArgumentException("Resume: Task not found: " + taskId);
// NOTE(review): presumably the switch's default branch — its label and all closing braces
// are not present in this copy.
throw new IllegalArgumentException("Bad command: " + command);
/**
 * Job Driver is ready and the clock is set up: request the evaluators.
 */
// Handler for the StartTime event.
final class StartHandler implements EventHandler<StartTime> {
public void onNext(final StartTime time) {
LOG.log(Level.INFO, "StartTime: {0}", time);
// NOTE(review): only the log call is visible in this copy — no evaluator request and no
// closing braces appear. Confirm against the upstream file.
/**
 * Shutting down the job driver: close the evaluators.
 */
final class StopHandler implements EventHandler<StopTime> {
// Logs the StopTime event and notifies the client that it was received.
public void onNext(final StopTime time) {
LOG.log(Level.INFO, "StopTime: {0}", time);
// The notification text is serialized with the String codec.
jobMessageObserver.sendMessageToClient(CODEC_STR.encode("got StopTime"));
// NOTE(review): no further statements or closing braces are visible in this copy.