| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.giraph; |
| |
| import org.apache.giraph.aggregators.LongSumAggregator; |
| import org.apache.giraph.bsp.BspService; |
| import org.apache.giraph.conf.GiraphConfiguration; |
| import org.apache.giraph.conf.GiraphConstants; |
| import org.apache.giraph.conf.ImmutableClassesGiraphConfiguration; |
| import org.apache.giraph.edge.Edge; |
| import org.apache.giraph.edge.EdgeFactory; |
| import org.apache.giraph.examples.SimpleSuperstepComputation; |
| import org.apache.giraph.graph.BasicComputation; |
| import org.apache.giraph.graph.Vertex; |
| import org.apache.giraph.job.GiraphJob; |
| import org.apache.giraph.master.DefaultMasterCompute; |
| import org.apache.giraph.worker.DefaultWorkerContext; |
| import org.apache.giraph.zk.ZooKeeperExt; |
| import org.apache.giraph.zk.ZooKeeperManager; |
| import org.apache.hadoop.fs.FileStatus; |
| import org.apache.hadoop.fs.Path; |
| import org.apache.hadoop.io.FloatWritable; |
| import org.apache.hadoop.io.IntWritable; |
| import org.apache.hadoop.io.LongWritable; |
| import org.apache.hadoop.io.Writable; |
| import org.apache.log4j.Logger; |
| import org.apache.zookeeper.CreateMode; |
| import org.apache.zookeeper.KeeperException; |
| import org.apache.zookeeper.ZooDefs; |
| import org.junit.Assert; |
| import org.junit.Test; |
| |
| import java.io.DataInput; |
| import java.io.DataOutput; |
| import java.io.IOException; |
| import java.util.List; |
| |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertTrue; |
| import static org.junit.Assert.fail; |
| |
| /** |
| * Tests that worker context and master computation |
| * are properly saved and loaded back at checkpoint. |
| */ |
| public class TestCheckpointing extends BspCase { |
| |
| /** Class logger */ |
| private static final Logger LOG = |
| Logger.getLogger(TestCheckpointing.class); |
| /** ID to be used with test job */ |
| public static final String TEST_JOB_ID = "test_job"; |
| |
| private static SuperstepCallback SUPERSTEP_CALLBACK; |
| |
/**
 * Creates the test case, handing this class's fully qualified name to the
 * {@link BspCase} constructor.
 */
public TestCheckpointing() {
  super(TestCheckpointing.class.getName());
}
| |
| @Test |
| public void testBspCheckpoint() throws InterruptedException, IOException, ClassNotFoundException { |
| testBspCheckpoint(false); |
| } |
| |
| @Test |
| public void testAsyncMessageStoreCheckpoint() throws InterruptedException, IOException, ClassNotFoundException { |
| testBspCheckpoint(true); |
| } |
| |
| public void testBspCheckpoint(boolean useAsyncMessageStore) |
| throws IOException, InterruptedException, ClassNotFoundException { |
| Path checkpointsDir = getTempPath("checkpointing"); |
| GiraphConfiguration conf = new GiraphConfiguration(); |
| if (useAsyncMessageStore) { |
| GiraphConstants.ASYNC_MESSAGE_STORE_THREADS_COUNT.set(conf, 2); |
| } |
| |
| SUPERSTEP_CALLBACK = null; |
| |
| GiraphConstants.CLEANUP_CHECKPOINTS_AFTER_SUCCESS.set(conf, false); |
| conf.setCheckpointFrequency(2); |
| |
| long idSum = runOriginalJob(checkpointsDir, conf); |
| assertEquals(10, idSum); |
| |
| SUPERSTEP_CALLBACK = new SuperstepCallback() { |
| @Override |
| public void superstep(long superstep, |
| ImmutableClassesGiraphConfiguration<LongWritable, IntWritable, FloatWritable> conf) { |
| if (superstep < 2) { |
| Assert.fail("Restarted JOB should not be executed on superstep " + superstep); |
| } |
| } |
| }; |
| |
| runRestartedJob(checkpointsDir, conf, idSum, 2); |
| |
| |
| } |
| |
| private void runRestartedJob(Path checkpointsDir, GiraphConfiguration conf, long idSum, long restartFrom) throws IOException, InterruptedException, ClassNotFoundException { |
| Path outputPath; |
| LOG.info("testBspCheckpoint: Restarting from the latest superstep " + |
| "with checkpoint path = " + checkpointsDir); |
| outputPath = getTempPath("checkpointing_restarted"); |
| |
| GiraphConstants.RESTART_JOB_ID.set(conf, TEST_JOB_ID); |
| conf.set("mapred.job.id", "restarted_test_job"); |
| if (restartFrom >= 0) { |
| conf.set(GiraphConstants.RESTART_SUPERSTEP, Long.toString(restartFrom)); |
| } |
| |
| GiraphJob restartedJob = prepareJob(getCallingMethodName() + "Restarted", |
| conf, outputPath); |
| |
| GiraphConstants.CHECKPOINT_DIRECTORY.set(restartedJob.getConfiguration(), |
| checkpointsDir.toString()); |
| |
| assertTrue(restartedJob.run(true)); |
| |
| |
| if (!runningInDistributedMode()) { |
| long idSumRestarted = |
| CheckpointVertexWorkerContext |
| .getFinalSum(); |
| LOG.info("testBspCheckpoint: idSumRestarted = " + |
| idSumRestarted); |
| assertEquals(idSum, idSumRestarted); |
| } |
| } |
| |
| private long runOriginalJob(Path checkpointsDir, GiraphConfiguration conf) throws IOException, InterruptedException, ClassNotFoundException { |
| Path outputPath = getTempPath("checkpointing_original"); |
| conf.setComputationClass( |
| CheckpointComputation.class); |
| conf.setWorkerContextClass( |
| CheckpointVertexWorkerContext.class); |
| conf.setMasterComputeClass( |
| CheckpointVertexMasterCompute.class); |
| conf.setVertexInputFormatClass(SimpleSuperstepComputation.SimpleSuperstepVertexInputFormat.class); |
| conf.setVertexOutputFormatClass(SimpleSuperstepComputation.SimpleSuperstepVertexOutputFormat.class); |
| conf.set("mapred.job.id", TEST_JOB_ID); |
| GiraphJob job = prepareJob(getCallingMethodName(), conf, outputPath); |
| |
| GiraphConfiguration configuration = job.getConfiguration(); |
| GiraphConstants.CHECKPOINT_DIRECTORY.set(configuration, checkpointsDir.toString()); |
| |
| assertTrue(job.run(true)); |
| |
| long idSum = 0; |
| if (!runningInDistributedMode()) { |
| FileStatus fileStatus = getSinglePartFileStatus(job.getConfiguration(), |
| outputPath); |
| idSum = CheckpointVertexWorkerContext |
| .getFinalSum(); |
| LOG.info("testBspCheckpoint: idSum = " + idSum + |
| " fileLen = " + fileStatus.getLen()); |
| } |
| return idSum; |
| } |
| |
| |
| /** |
| * Actual computation. |
| */ |
| public static class CheckpointComputation extends |
| BasicComputation<LongWritable, IntWritable, FloatWritable, |
| FloatWritable> { |
| @Override |
| public void compute( |
| Vertex<LongWritable, IntWritable, FloatWritable> vertex, |
| Iterable<FloatWritable> messages) throws IOException { |
| CheckpointVertexWorkerContext workerContext = getWorkerContext(); |
| assertEquals(getSuperstep() + 1, workerContext.testValue); |
| |
| if (getSuperstep() > 4) { |
| vertex.voteToHalt(); |
| return; |
| } |
| |
| aggregate(LongSumAggregator.class.getName(), |
| new LongWritable(vertex.getId().get())); |
| |
| float msgValue = 0.0f; |
| for (FloatWritable message : messages) { |
| float curMsgValue = message.get(); |
| msgValue += curMsgValue; |
| } |
| |
| int vertexValue = vertex.getValue().get(); |
| vertex.setValue(new IntWritable(vertexValue + (int) msgValue)); |
| for (Edge<LongWritable, FloatWritable> edge : vertex.getEdges()) { |
| FloatWritable newEdgeValue = new FloatWritable(edge.getValue().get() + |
| (float) vertexValue); |
| Edge<LongWritable, FloatWritable> newEdge = |
| EdgeFactory.create(edge.getTargetVertexId(), newEdgeValue); |
| vertex.addEdge(newEdge); |
| sendMessage(edge.getTargetVertexId(), newEdgeValue); |
| |
| } |
| } |
| } |
| |
| @Test |
| public void testManualCheckpointAtTheBeginning() |
| throws InterruptedException, IOException, ClassNotFoundException { |
| testManualCheckpoint(0); |
| } |
| |
| @Test |
| public void testManualCheckpoint() |
| throws InterruptedException, IOException, ClassNotFoundException { |
| testManualCheckpoint(2); |
| } |
| |
| |
| private void testManualCheckpoint(final int checkpointSuperstep) |
| throws IOException, InterruptedException, ClassNotFoundException { |
| Path checkpointsDir = getTempPath("checkpointing"); |
| GiraphConfiguration conf = new GiraphConfiguration(); |
| |
| SUPERSTEP_CALLBACK = new SuperstepCallback() { |
| |
| @Override |
| public void superstep(long superstep, ImmutableClassesGiraphConfiguration<LongWritable, IntWritable, FloatWritable> conf) { |
| if (superstep == checkpointSuperstep) { |
| try { |
| ZooKeeperExt zooKeeperExt = new ZooKeeperExt(conf.getZookeeperList(), |
| conf.getZooKeeperSessionTimeout(), |
| conf.getZookeeperOpsMaxAttempts(), |
| conf.getZookeeperOpsRetryWaitMsecs(), |
| TestCheckpointing.this); |
| String basePath = ZooKeeperManager.getBasePath(conf) + BspService.BASE_DIR + "/" + conf.get("mapred.job.id"); |
| zooKeeperExt.createExt( |
| basePath + BspService.FORCE_CHECKPOINT_USER_FLAG, |
| null, |
| ZooDefs.Ids.OPEN_ACL_UNSAFE, |
| CreateMode.PERSISTENT, |
| true); |
| } catch (IOException | InterruptedException | KeeperException e) { |
| throw new RuntimeException(e); |
| } |
| } else if (superstep > checkpointSuperstep) { |
| Assert.fail("Job should be stopped by now " + superstep); |
| } |
| } |
| }; |
| |
| try { |
| runOriginalJob(checkpointsDir, conf); |
| fail("Original job should fail after checkpointing"); |
| } catch (Exception e) { |
| LOG.info("Original job failed, that's OK " + e); |
| } |
| |
| SUPERSTEP_CALLBACK = new SuperstepCallback() { |
| @Override |
| public void superstep(long superstep, |
| ImmutableClassesGiraphConfiguration<LongWritable, IntWritable, FloatWritable> conf) { |
| if (superstep < checkpointSuperstep) { |
| Assert.fail("Restarted JOB should not be executed on superstep " + superstep); |
| } |
| } |
| }; |
| |
| runRestartedJob(checkpointsDir, conf, 10, -1); |
| } |
| |
| /** |
| * Worker context associated. |
| */ |
| public static class CheckpointVertexWorkerContext |
| extends DefaultWorkerContext { |
| /** User can access this after the application finishes if local */ |
| private static long FINAL_SUM; |
| |
| private int testValue; |
| |
| public static long getFinalSum() { |
| return FINAL_SUM; |
| } |
| |
| @Override |
| public void postSuperstep() { |
| super.postSuperstep(); |
| sendMessageToMyself(new LongWritable(getSuperstep())); |
| } |
| |
| /** |
| * Send message to all workers (except this worker) |
| * |
| * @param message Message to send |
| */ |
| private void sendMessageToMyself(Writable message) { |
| sendMessageToWorker(message, getMyWorkerIndex()); |
| } |
| |
| @Override |
| public void postApplication() { |
| setFinalSum(this.<LongWritable>getAggregatedValue( |
| LongSumAggregator.class.getName()).get()); |
| LOG.info("FINAL_SUM=" + FINAL_SUM); |
| } |
| |
| /** |
| * Set the final sum |
| * |
| * @param value sum |
| */ |
| private static void setFinalSum(long value) { |
| FINAL_SUM = value; |
| } |
| |
| @Override |
| public void preSuperstep() { |
| assertEquals(getSuperstep(), testValue++); |
| if (getSuperstep() > 0) { |
| List<Writable> messages = getAndClearMessagesFromOtherWorkers(); |
| assertEquals(1, messages.size()); |
| assertEquals(getSuperstep() - 1, ((LongWritable)(messages.get(0))).get()); |
| } |
| } |
| |
| @Override |
| public void readFields(DataInput dataInput) throws IOException { |
| super.readFields(dataInput); |
| testValue = dataInput.readInt(); |
| } |
| |
| @Override |
| public void write(DataOutput dataOutput) throws IOException { |
| super.write(dataOutput); |
| dataOutput.writeInt(testValue); |
| } |
| } |
| |
| /** |
| * Master compute |
| */ |
| public static class CheckpointVertexMasterCompute extends |
| DefaultMasterCompute { |
| |
| private int testValue = 0; |
| |
| @Override |
| public void compute() { |
| long superstep = getSuperstep(); |
| if (SUPERSTEP_CALLBACK != null) { |
| SUPERSTEP_CALLBACK.superstep(getSuperstep(), getConf()); |
| } |
| assertEquals(superstep, testValue++); |
| } |
| |
| @Override |
| public void initialize() throws InstantiationException, |
| IllegalAccessException { |
| registerAggregator(LongSumAggregator.class.getName(), |
| LongSumAggregator.class); |
| } |
| |
| @Override |
| public void readFields(DataInput in) throws IOException { |
| super.readFields(in); |
| testValue = in.readInt(); |
| } |
| |
| @Override |
| public void write(DataOutput out) throws IOException { |
| super.write(out); |
| out.writeInt(testValue); |
| } |
| } |
| |
| private static interface SuperstepCallback { |
| |
| public void superstep(long superstep, |
| ImmutableClassesGiraphConfiguration<LongWritable, |
| IntWritable, FloatWritable> conf); |
| |
| } |
| |
| } |