SAMZA-2591: Async Commit [2/3]: Task Commit api changes and async commit (#1490)

Introduce new state backend APIs for blobstore and kafka changelog
Change the task commit lifecycle to separate snapshot, upload and cleanup phases
Make the TaskInstance commit upload and cleanup phases nonblocking
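
For illustration, the intended flow through the new commit phases (driven by TaskStorageCommitManager, added below)
is roughly the following sketch. The TaskInstance wiring, error handling, and the checkpoint-topic write between
upload and cleanUp are elided, and CheckpointId.create() is an assumed factory method:

    import java.util.Map;
    import java.util.concurrent.CompletableFuture;
    import org.apache.samza.checkpoint.CheckpointId;
    import org.apache.samza.storage.TaskStorageCommitManager;

    class CommitFlowSketch {
      static CompletableFuture<Void> commitOnce(TaskStorageCommitManager commitManager) {
        CheckpointId checkpointId = CheckpointId.create();
        // Phase 1: synchronous to processing - flush stores and capture per-store state checkpoint markers
        Map<String, Map<String, String>> snapshotSCMs = commitManager.snapshot(checkpointId);
        // Phase 2: asynchronous to processing - upload the snapshots to each configured state backend
        return commitManager.upload(checkpointId, snapshotSCMs)
            // Phase 3: after the checkpoint is durably written (elided), clean up older local and remote state
            .thenCompose(uploadSCMs -> commitManager.cleanUp(checkpointId, uploadSCMs));
      }
    }
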
diff --git a/gradle/dependency-versions.gradle b/gradle/dependency-versions.gradle
index e289af0..ebd2d38 100644
--- a/gradle/dependency-versions.gradle
+++ b/gradle/dependency-versions.gradle
@@ -49,7 +49,7 @@
   yarnVersion = "2.7.1"
   zkClientVersion = "0.11"
   zookeeperVersion = "3.4.13"
-  failsafeVersion = "1.1.0"
+  failsafeVersion = "2.4.0"
   jlineVersion = "3.8.2"
   jnaVersion = "4.5.1"
   couchbaseClientVersion = "2.7.2"
diff --git a/samza-api/src/main/java/org/apache/samza/storage/StateBackendFactory.java b/samza-api/src/main/java/org/apache/samza/storage/StateBackendFactory.java
new file mode 100644
index 0000000..9946d2a
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/storage/StateBackendFactory.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import java.io.File;
+import java.util.concurrent.ExecutorService;
+import org.apache.samza.config.Config;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.util.Clock;
+
+
+/**
+ * Factory to build the Samza {@link TaskBackupManager}, {@link TaskRestoreManager} and {@link TaskStorageAdmin}
+ * for a particular state storage backend, which are used to durably back up the Samza task state.
+ */
+public interface StateBackendFactory {
+  TaskBackupManager getBackupManager(JobContext jobContext,
+      ContainerContext containerContext,
+      TaskModel taskModel,
+      ExecutorService backupExecutor,
+      MetricsRegistry taskInstanceMetricsRegistry,
+      Config config,
+      Clock clock,
+      File loggedStoreBaseDir,
+      File nonLoggedStoreBaseDir);
+
+  TaskRestoreManager getRestoreManager(JobContext jobContext,
+      ContainerContext containerContext,
+      TaskModel taskModel,
+      ExecutorService restoreExecutor,
+      MetricsRegistry metricsRegistry,
+      Config config,
+      Clock clock,
+      File loggedStoreBaseDir,
+      File nonLoggedStoreBaseDir,
+      KafkaChangelogRestoreParams kafkaChangelogRestoreParams);
+
+  TaskStorageAdmin getAdmin();
+}
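
A hedged sketch of how a container might wire in a configured factory; only the getBackupManager signature above
comes from this patch, while factoryClassName, the reflective loading, and the surrounding context objects are
assumptions:

    // Sketch: instantiate a configured StateBackendFactory and build a per-task backup manager.
    StateBackendFactory factory =
        (StateBackendFactory) Class.forName(factoryClassName).getDeclaredConstructor().newInstance();
    TaskBackupManager backupManager = factory.getBackupManager(jobContext, containerContext, taskModel,
        backupExecutor, metricsRegistry, config, clock, loggedStoreBaseDir, nonLoggedStoreBaseDir);
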
diff --git a/samza-api/src/main/java/org/apache/samza/storage/TaskBackupManager.java b/samza-api/src/main/java/org/apache/samza/storage/TaskBackupManager.java
new file mode 100644
index 0000000..a00d715
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/storage/TaskBackupManager.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+
+import javax.annotation.Nullable;
+import org.apache.samza.checkpoint.Checkpoint;
+import org.apache.samza.checkpoint.CheckpointId;
+
+
+/**
+ * <p>
+ * TaskBackupManager is the interface that must be implemented for any remote system that Samza persists its state to
+ * during the task commit operation.
+ * {@link #snapshot(CheckpointId)} will be invoked synchronously to task processing and captures a snapshot of the
+ * store state to be persisted for the commit. {@link #upload(CheckpointId, Map)} then uses the snapshotted state
+ * to persist to the underlying backup system and runs asynchronously to task processing.
+ * </p>
+ * The interface will be invoked in the following order:
+ * <ul>
+ *   <li>Snapshot will be called before Upload.</li>
+ *   <li>persistToFilesystem will be called after Upload is completed.</li>
+ *   <li>Cleanup is only called after Upload and persistToFilesystem have successfully completed.</li>
+ * </ul>
+ */
+public interface TaskBackupManager {
+
+  /**
+   * Initializes the TaskBackupManager instance.
+   *
+   * @param checkpoint last recorded checkpoint from the CheckpointManager or null if no last checkpoint was found
+   */
+  void init(@Nullable Checkpoint checkpoint);
+
+  /**
+   *  Snapshot is used to capture the current state of the stores in order to persist it to the backup manager in the
+   *  {@link #upload(CheckpointId, Map)} phase. Performs the part of the commit operation that is synchronous to
+   *  processing. Returns the per store name state checkpoint markers to be used in upload.
+   *
+   * @param checkpointId {@link CheckpointId} of the current commit
+   * @return a map of store name to state checkpoint markers for stores managed by this state backend
+   */
+  Map<String, String> snapshot(CheckpointId checkpointId);
+
+  /**
+   * Upload is used to persist the state provided by the {@link #snapshot(CheckpointId)} to the
+   * underlying backup system. This is the part of the commit operation that is asynchronous to message processing,
+   * and it returns a {@link CompletableFuture} containing the successfully uploaded state checkpoint markers.
+   *
+   * @param checkpointId {@link CheckpointId} of the current commit
+   * @param stateCheckpointMarkers the map of store name to state checkpoint markers returned by
+   *                               {@link #snapshot(CheckpointId)}
+   * @return a {@link CompletableFuture} containing a map of store name to state checkpoint markers
+   *         after the upload is complete
+   */
+  CompletableFuture<Map<String, String>> upload(CheckpointId checkpointId, Map<String, String> stateCheckpointMarkers);
+
+  /**
+   * Cleans up any local or remote state for checkpoint information that is older than the provided checkpointId.
+   * This operation is required to be idempotent.
+   *
+   * @param checkpointId the {@link CheckpointId} of the last successfully committed checkpoint
+   * @param stateCheckpointMarkers a map of store name to state checkpoint markers returned by
+   *                               {@link #upload(CheckpointId, Map)}
+   */
+  CompletableFuture<Void> cleanUp(CheckpointId checkpointId, Map<String, String> stateCheckpointMarkers);
+
+  /**
+   * Shuts down the backup manager and cleans up any allocated resources.
+   */
+  void close();
+
+}
\ No newline at end of file
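
As a point of reference for the contract above, a minimal hypothetical implementation (not part of this patch) that
uploads nothing and simply echoes its snapshot markers back could look like:

    import java.util.Collections;
    import java.util.Map;
    import java.util.concurrent.CompletableFuture;
    import javax.annotation.Nullable;
    import org.apache.samza.checkpoint.Checkpoint;
    import org.apache.samza.checkpoint.CheckpointId;
    import org.apache.samza.storage.TaskBackupManager;

    // Hypothetical no-op backend: nothing is uploaded or deleted, markers are passed through unchanged.
    public class NoOpTaskBackupManager implements TaskBackupManager {
      @Override
      public void init(@Nullable Checkpoint checkpoint) { }

      @Override
      public Map<String, String> snapshot(CheckpointId checkpointId) {
        return Collections.emptyMap(); // no local state to capture
      }

      @Override
      public CompletableFuture<Map<String, String>> upload(CheckpointId checkpointId,
          Map<String, String> stateCheckpointMarkers) {
        // nothing to persist remotely; complete immediately with the snapshot markers
        return CompletableFuture.completedFuture(stateCheckpointMarkers);
      }

      @Override
      public CompletableFuture<Void> cleanUp(CheckpointId checkpointId,
          Map<String, String> stateCheckpointMarkers) {
        return CompletableFuture.completedFuture(null); // no older state to remove
      }

      @Override
      public void close() { }
    }
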
diff --git a/samza-api/src/main/java/org/apache/samza/storage/TaskRestoreManager.java b/samza-api/src/main/java/org/apache/samza/storage/TaskRestoreManager.java
new file mode 100644
index 0000000..999325e
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/storage/TaskRestoreManager.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import org.apache.samza.checkpoint.Checkpoint;
+
+
+/**
+ * Helper interface for restoring task state.
+ */
+public interface TaskRestoreManager {
+
+  /**
+   * Initialize state resources such as store directories.
+   */
+  void init(Checkpoint checkpoint);
+
+  /**
+   * Restore state from checkpoints, state snapshots and changelogs.
+   * Currently, store restoration happens on a separate thread pool within {@code ContainerStorageManager}. In case of
+   * interrupt/shutdown signals from {@code SamzaContainer}, {@code ContainerStorageManager} may interrupt the restore
+   * thread.
+   *
+   * Note: Typically, interrupt signals don't bubble up as {@link InterruptedException} unless the restore thread is
+   * waiting on IO/network. In case of busy looping, implementors are expected to check the interrupt status of the
+   * thread periodically and shutdown gracefully before throwing {@link InterruptedException} upstream.
+   * {@code SamzaContainer} will not wait for clean up, and the interrupt signal is a best-effort attempt by the
+   * container to notify that it is shutting down.
+   */
+  void restore() throws InterruptedException;
+
+  /**
+   * Closes all initiated resources, including storage engines.
+   */
+  void close();
+}
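
To make the interrupt-handling note above concrete, a hypothetical restore() implementation that busy-loops over
changelog batches might check the interrupt status like the sketch below (hasMoreMessagesToRestore and
restoreNextBatch are placeholder helpers, not part of this patch):

    @Override
    public void restore() throws InterruptedException {
      while (hasMoreMessagesToRestore()) {
        if (Thread.currentThread().isInterrupted()) {
          // Container requested shutdown: release partially initialized resources before surfacing the interrupt.
          close();
          throw new InterruptedException("State restore interrupted by container shutdown");
        }
        restoreNextBatch(); // apply the next batch of changelog messages to the local store
      }
    }
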
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala b/samza-api/src/main/java/org/apache/samza/storage/TaskStorageAdmin.java
similarity index 62%
rename from samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
rename to samza-api/src/main/java/org/apache/samza/storage/TaskStorageAdmin.java
index 50d6418..205077b 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/TaskStorageManager.scala
+++ b/samza-api/src/main/java/org/apache/samza/storage/TaskStorageAdmin.java
@@ -17,21 +17,14 @@
  * under the License.
  */
 
-package org.apache.samza.storage
+package org.apache.samza.storage;
 
-import org.apache.samza.checkpoint.CheckpointId
-import org.apache.samza.system.SystemStreamPartition
+/**
+ * Creates and validates resources for the StateBackendFactory.
+ */
+public interface TaskStorageAdmin {
 
-trait TaskStorageManager {
+  void createResources();
 
-  def getStore(storeName: String): Option[StorageEngine]
-
-  def flush(): Map[SystemStreamPartition, Option[String]]
-
-  def checkpoint(checkpointId: CheckpointId, newestChangelogOffsets: Map[SystemStreamPartition, Option[String]]): Unit
-
-  def removeOldCheckpoints(checkpointId: CheckpointId): Unit
-
-  def stop(): Unit
-
-}
\ No newline at end of file
+  void validateResources();
+}
diff --git a/samza-api/src/main/java/org/apache/samza/system/SystemFactory.java b/samza-api/src/main/java/org/apache/samza/system/SystemFactory.java
index 08c1b49..2841bb1 100644
--- a/samza-api/src/main/java/org/apache/samza/system/SystemFactory.java
+++ b/samza-api/src/main/java/org/apache/samza/system/SystemFactory.java
@@ -73,7 +73,7 @@
    *
    * @param systemName The name of the system to create admin for.
    * @param config The config to create admin with.
-   * @param adminLabel a string to provide info the admin instance.
+   * @param adminLabel a string to provide info for the admin instance.
    * @return A SystemAdmin
    */
   default SystemAdmin getAdmin(String systemName, Config config, String adminLabel) {
diff --git a/samza-core/src/main/java/org/apache/samza/storage/KafkaChangelogStateBackendFactory.java b/samza-core/src/main/java/org/apache/samza/storage/KafkaChangelogStateBackendFactory.java
new file mode 100644
index 0000000..e3230f6
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/storage/KafkaChangelogStateBackendFactory.java
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import com.google.common.annotations.VisibleForTesting;
+import java.io.File;
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.stream.Collectors;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.StorageConfig;
+import org.apache.samza.config.TaskConfig;
+import org.apache.samza.context.ContainerContext;
+import org.apache.samza.context.JobContext;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskMode;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.system.SSPMetadataCache;
+import org.apache.samza.system.StreamMetadataCache;
+import org.apache.samza.system.SystemAdmins;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.util.Clock;
+
+
+/**
+ * Class used to provide the {@link TaskRestoreManager} and the {@link TaskBackupManager} for the Kafka changelog
+ * state backend.
+ */
+public class KafkaChangelogStateBackendFactory implements StateBackendFactory {
+  private StreamMetadataCache streamCache;
+  /*
+   * This keeps track of the changelog SSPs that are associated with the whole container. This is used so that we can
+   * prefetch the metadata about all of the changelog SSPs associated with the container whenever we need the
+   * metadata about some of the changelog SSPs.
+   * An example use case is when Samza writes offset files for stores ({@link TaskStorageManager}). Each task is
+   * responsible for its own offset file, but if we can do prefetching, then most tasks will already have cached
+   * metadata by the time they need the offset metadata.
+   * Note: By using all changelog streams to build the sspsToPrefetch, any fetches done for persisted stores will
+   * include the ssps for non-persisted stores, so this is slightly suboptimal. However, this does not increase the
+   * actual number of calls to the {@link SystemAdmin}, and we can decouple this logic from the per-task objects (e.g.
+   * {@link TaskStorageManager}).
+   */
+  private SSPMetadataCache sspCache;
+
+  @Override
+  public TaskBackupManager getBackupManager(JobContext jobContext,
+      ContainerContext containerContext,
+      TaskModel taskModel,
+      ExecutorService backupExecutor,
+      MetricsRegistry metricsRegistry,
+      Config config,
+      Clock clock,
+      File loggedStoreBaseDir,
+      File nonLoggedStoreBaseDir) {
+    SystemAdmins systemAdmins = new SystemAdmins(config);
+    StorageConfig storageConfig = new StorageConfig(config);
+    Map<String, SystemStream> storeChangelogs = storageConfig.getStoreChangelogs();
+
+    if (new TaskConfig(config).getTransactionalStateCheckpointEnabled()) {
+      return new KafkaTransactionalStateTaskBackupManager(taskModel.getTaskName(), storeChangelogs,
+          systemAdmins, taskModel.getChangelogPartition());
+    } else {
+      return new KafkaNonTransactionalStateTaskBackupManager(taskModel.getTaskName(), storeChangelogs,
+          systemAdmins, taskModel.getChangelogPartition());
+    }
+  }
+
+  @Override
+  public TaskRestoreManager getRestoreManager(JobContext jobContext,
+      ContainerContext containerContext,
+      TaskModel taskModel,
+      ExecutorService restoreExecutor,
+      MetricsRegistry metricsRegistry,
+      Config config,
+      Clock clock,
+      File loggedStoreBaseDir,
+      File nonLoggedStoreBaseDir,
+      KafkaChangelogRestoreParams kafkaChangelogRestoreParams) {
+    Map<String, SystemStream> storeChangelogs = new StorageConfig(config).getStoreChangelogs();
+    Set<SystemStreamPartition> changelogSSPs = getChangelogSSPForContainer(storeChangelogs, containerContext);
+    // filter out standby store-ssp pairs
+    Map<String, SystemStream> filteredStoreChangelogs =
+        filterStandbySystemStreams(storeChangelogs, containerContext.getContainerModel());
+    SystemAdmins systemAdmins = new SystemAdmins(kafkaChangelogRestoreParams.getSystemAdmins());
+
+    if (new TaskConfig(config).getTransactionalStateRestoreEnabled()) {
+      return new TransactionalStateTaskRestoreManager(
+          kafkaChangelogRestoreParams.getStoreNames(),
+          jobContext,
+          containerContext,
+          taskModel,
+          filteredStoreChangelogs,
+          kafkaChangelogRestoreParams.getInMemoryStores(),
+          kafkaChangelogRestoreParams.getStorageEngineFactories(),
+          kafkaChangelogRestoreParams.getSerdes(),
+          systemAdmins,
+          kafkaChangelogRestoreParams.getStoreConsumers(),
+          metricsRegistry,
+          kafkaChangelogRestoreParams.getCollector(),
+          getSspCache(systemAdmins, clock, changelogSSPs),
+          loggedStoreBaseDir,
+          nonLoggedStoreBaseDir,
+          config,
+          clock
+      );
+    } else {
+      return new NonTransactionalStateTaskRestoreManager(
+          kafkaChangelogRestoreParams.getStoreNames(),
+          jobContext,
+          containerContext,
+          taskModel,
+          filteredStoreChangelogs,
+          kafkaChangelogRestoreParams.getInMemoryStores(),
+          kafkaChangelogRestoreParams.getStorageEngineFactories(),
+          kafkaChangelogRestoreParams.getSerdes(),
+          systemAdmins,
+          getStreamCache(systemAdmins, clock),
+          kafkaChangelogRestoreParams.getStoreConsumers(),
+          metricsRegistry,
+          kafkaChangelogRestoreParams.getCollector(),
+          jobContext.getJobModel().getMaxChangeLogStreamPartitions(),
+          loggedStoreBaseDir,
+          nonLoggedStoreBaseDir,
+          config,
+          clock
+      );
+    }
+  }
+
+  @Override
+  public TaskStorageAdmin getAdmin() {
+    throw new SamzaException("getAdmin() method not supported for KafkaStateBackendFactory");
+  }
+
+  @VisibleForTesting
+  Set<SystemStreamPartition> getChangelogSSPForContainer(Map<String, SystemStream> storeChangelogs,
+      ContainerContext containerContext) {
+    return storeChangelogs.values().stream()
+        .flatMap(ss -> containerContext.getContainerModel().getTasks().values().stream()
+            .map(tm -> new SystemStreamPartition(ss, tm.getChangelogPartition())))
+        .collect(Collectors.toSet());
+  }
+
+  /**
+   * Shared cache across all KafkaRestoreManagers for the Kafka topic
+   *
+   * @param admins system admins used to fetch the stream metadata
+   * @param clock for cache invalidation
+   * @return StreamMetadataCache containing the stream metadata
+   */
+  @VisibleForTesting
+  StreamMetadataCache getStreamCache(SystemAdmins admins, Clock clock) {
+    if (streamCache == null) {
+      streamCache = new StreamMetadataCache(admins, 5000, clock);
+    }
+    return streamCache;
+  }
+
+  /**
+   * Shared cache across KafkaRestoreManagers for the Kafka partition
+   *
+   * @param admins system admins used to fetch the stream metadata
+   * @param clock for cache invalidation
+   * @param ssps SSPs to prefetch
+   * @return SSPMetadataCache containing the partition metadata
+   */
+  @VisibleForTesting
+  SSPMetadataCache getSspCache(SystemAdmins admins, Clock clock, Set<SystemStreamPartition> ssps) {
+    if (sspCache == null) {
+      sspCache = new SSPMetadataCache(admins, Duration.ofSeconds(5), clock, ssps);
+    }
+    return sspCache;
+  }
+
+  @VisibleForTesting
+  Map<String, SystemStream> filterStandbySystemStreams(Map<String, SystemStream> changelogSystemStreams,
+      ContainerModel containerModel) {
+    Map<SystemStreamPartition, String> changelogSSPToStore = new HashMap<>();
+    changelogSystemStreams.forEach((storeName, systemStream) ->
+        containerModel.getTasks().forEach((taskName, taskModel) -> {
+          if (TaskMode.Standby.equals(taskModel.getTaskMode())) {
+            changelogSSPToStore.put(new SystemStreamPartition(systemStream, taskModel.getChangelogPartition()),
+                storeName);
+          }
+        })
+    );
+    // changelogSystemStreams correspond only to active tasks (since those of standby-tasks moved to sideInputs above)
+    return MapUtils.invertMap(changelogSSPToStore).entrySet().stream()
+        .collect(Collectors.toMap(Map.Entry::getKey, x -> x.getValue().getSystemStream()));
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java b/samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java
index cfd24d9..91c0c7c 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/StorageManagerUtil.java
@@ -38,12 +38,15 @@
 import java.util.Set;
 import java.util.stream.Collectors;
 import org.apache.samza.SamzaException;
+import org.apache.samza.checkpoint.CheckpointId;
+import org.apache.samza.checkpoint.CheckpointV2;
 import org.apache.samza.clustermanager.StandbyTaskUtil;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.job.model.TaskMode;
 import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.serializers.CheckpointV2Serde;
 import org.apache.samza.serializers.model.SamzaObjectMapper;
 import org.apache.samza.system.SystemAdmin;
 import org.apache.samza.system.SystemStream;
@@ -56,14 +59,27 @@
 
 public class StorageManagerUtil {
   private static final Logger LOG = LoggerFactory.getLogger(StorageManagerUtil.class);
+  public static final String CHECKPOINT_FILE_NAME = "CHECKPOINT-V2";
   public static final String OFFSET_FILE_NAME_NEW = "OFFSET-v2";
   public static final String OFFSET_FILE_NAME_LEGACY = "OFFSET";
   public static final String SIDE_INPUT_OFFSET_FILE_NAME_LEGACY = "SIDE-INPUT-OFFSETS";
   private static final ObjectMapper OBJECT_MAPPER = SamzaObjectMapper.getObjectMapper();
   private static final TypeReference<Map<SystemStreamPartition, String>> OFFSETS_TYPE_REFERENCE =
             new TypeReference<Map<SystemStreamPartition, String>>() { };
-  private static final ObjectWriter OBJECT_WRITER = OBJECT_MAPPER.writerWithType(OFFSETS_TYPE_REFERENCE);
+  private static final ObjectWriter SSP_OFFSET_OBJECT_WRITER = OBJECT_MAPPER.writerFor(OFFSETS_TYPE_REFERENCE);
   private static final String SST_FILE_SUFFIX = ".sst";
+  private static final CheckpointV2Serde CHECKPOINT_V2_SERDE = new CheckpointV2Serde();
+
+  /**
+   * Returns the path for a storage engine to create its checkpoint based on the current checkpoint id.
+   *
+   * @param taskStoreDir directory of the store as returned by {@link #getTaskStoreDir}
+   * @param checkpointId current checkpoint id
+   * @return String denoting the file path of the store with the given checkpoint id
+   */
+  public static String getCheckpointDirPath(File taskStoreDir, CheckpointId checkpointId) {
+    return taskStoreDir.getPath() + "-" + checkpointId.serialize();
+  }
 
   /**
    * Fetch the starting offset for the input {@link SystemStreamPartition}
@@ -109,6 +125,7 @@
    * @param isSideInput true if store is a side-input store, false if it is a regular store
    * @return true if the store is stale, false otherwise
    */
+  // TODO BLOCKER dchen do these methods need to be updated to also read the new checkpoint file?
   public boolean isStaleStore(File storeDir, long storeDeleteRetentionInMs, long currentTimeMs, boolean isSideInput) {
     long offsetFileLastModifiedTime;
     boolean isStaleStore = false;
@@ -118,7 +135,7 @@
 
       // We check if the new offset-file exists, if so we use its last-modified time, if it doesn't we use the legacy file
       // depending on if it is a side-input or not,
-      // if neither exists, we use 0L (the defauilt return value of lastModified() when file does not exist
+      // if neither exists, we use 0L (the default return value of lastModified() when file does not exist
       File offsetFileRefNew = new File(storeDir, OFFSET_FILE_NAME_NEW);
       File offsetFileRefLegacy = new File(storeDir, OFFSET_FILE_NAME_LEGACY);
       File sideInputOffsetFileRefLegacy = new File(storeDir, SIDE_INPUT_OFFSET_FILE_NAME_LEGACY);
@@ -179,6 +196,7 @@
    * @param isSideInput true if store is a side-input store, false if it is a regular store
    * @return true if the offset file is valid. false otherwise.
    */
+  // TODO BLOCKER dchen do these methods need to be updated to also read the new checkpoint file?
   public boolean isOffsetFileValid(File storeDir, Set<SystemStreamPartition> storeSSPs, boolean isSideInput) {
     boolean hasValidOffsetFile = false;
     if (storeDir.exists()) {
@@ -210,14 +228,14 @@
 
     // First, we write the new-format offset file
     File offsetFile = new File(storeDir, OFFSET_FILE_NAME_NEW);
-    String fileContents = OBJECT_WRITER.writeValueAsString(offsets);
+    String fileContents = SSP_OFFSET_OBJECT_WRITER.writeValueAsString(offsets);
     FileUtil fileUtil = new FileUtil();
     fileUtil.writeWithChecksum(offsetFile, fileContents);
 
     // Now we write the old format offset file, which are different for store-offset and side-inputs
     if (isSideInput) {
       offsetFile = new File(storeDir, SIDE_INPUT_OFFSET_FILE_NAME_LEGACY);
-      fileContents = OBJECT_WRITER.writeValueAsString(offsets);
+      fileContents = SSP_OFFSET_OBJECT_WRITER.writeValueAsString(offsets);
       fileUtil.writeWithChecksum(offsetFile, fileContents);
     } else {
       offsetFile = new File(storeDir, OFFSET_FILE_NAME_LEGACY);
@@ -226,6 +244,19 @@
   }
 
   /**
+   * Writes the checkpoint to the store checkpoint directory based on the checkpointId.
+   *
+   * @param storeDir store or store checkpoint directory to write the checkpoint to
+   * @param checkpoint checkpoint v2 containing the checkpoint Id
+   */
+  public void writeCheckpointV2File(File storeDir, CheckpointV2 checkpoint) {
+    File offsetFile = new File(storeDir, CHECKPOINT_FILE_NAME);
+    byte[] fileContents = CHECKPOINT_V2_SERDE.toBytes(checkpoint);
+    FileUtil fileUtil = new FileUtil();
+    fileUtil.writeWithChecksum(offsetFile, new String(fileContents));
+  }
+
+  /**
    * Delete the offset file for this store, if one exists.
    * @param storeDir the directory of the store
    */
@@ -284,6 +315,24 @@
   }
 
   /**
+   * Read and return the {@link CheckpointV2} from the directory's {@link #CHECKPOINT_FILE_NAME} file.
+   * If the file does not exist, returns null.
+   * // TODO HIGH dchen add tests at all call sites for handling null value.
+   *
+   * @param storagePartitionDir store directory to read the checkpoint file from
+   * @return the {@link CheckpointV2} object retrieved from the checkpoint file if found, otherwise return null
+   */
+  public CheckpointV2 readCheckpointV2File(File storagePartitionDir) {
+    File checkpointFile = new File(storagePartitionDir, CHECKPOINT_FILE_NAME);
+    if (checkpointFile.exists()) {
+      String serializedCheckpointV2 = new FileUtil().readWithChecksum(checkpointFile);
+      return new CheckpointV2Serde().fromBytes(serializedCheckpointV2.getBytes());
+    } else {
+      return null;
+    }
+  }
+
+  /**
    * Read and return the contents of the offset file.
    *
    * @param storagePartitionDir the base directory of the store
diff --git a/samza-core/src/main/java/org/apache/samza/storage/TaskStorageCommitManager.java b/samza-core/src/main/java/org/apache/samza/storage/TaskStorageCommitManager.java
new file mode 100644
index 0000000..cc80a48
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/storage/TaskStorageCommitManager.java
@@ -0,0 +1,365 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import com.google.common.annotations.VisibleForTesting;
+import java.io.File;
+import java.io.FileFilter;
+import java.io.IOException;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.io.filefilter.WildcardFileFilter;
+import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
+import org.apache.samza.checkpoint.Checkpoint;
+import org.apache.samza.checkpoint.CheckpointId;
+import org.apache.samza.checkpoint.CheckpointManager;
+import org.apache.samza.checkpoint.CheckpointV1;
+import org.apache.samza.checkpoint.CheckpointV2;
+import org.apache.samza.checkpoint.kafka.KafkaChangelogSSPOffset;
+import org.apache.samza.config.Config;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.job.model.TaskMode;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+import org.apache.samza.util.FutureUtil;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Handles the commit of the task's state stores.
+ */
+public class TaskStorageCommitManager {
+
+  private static final Logger LOG = LoggerFactory.getLogger(TaskStorageCommitManager.class);
+
+  private final TaskName taskName;
+  private final CheckpointManager checkpointManager;
+  private final ContainerStorageManager containerStorageManager;
+  private final Map<String, TaskBackupManager> stateBackendToBackupManager;
+  private final Partition taskChangelogPartition;
+  private final StorageManagerUtil storageManagerUtil;
+  private final ExecutorService backupExecutor;
+  private final File durableStoreBaseDir;
+  private final Map<String, SystemStream> storeChangelogs;
+  private final TaskInstanceMetrics metrics;
+
+  // Available after init(), since stores are created by ContainerStorageManager#start()
+  private Map<String, StorageEngine> storageEngines;
+
+  public TaskStorageCommitManager(TaskName taskName, Map<String, TaskBackupManager> stateBackendToBackupManager,
+      ContainerStorageManager containerStorageManager, Map<String, SystemStream> storeChangelogs, Partition changelogPartition,
+      CheckpointManager checkpointManager, Config config, ExecutorService backupExecutor,
+      StorageManagerUtil storageManagerUtil, File durableStoreBaseDir, TaskInstanceMetrics metrics) {
+    this.taskName = taskName;
+    this.containerStorageManager = containerStorageManager;
+    this.stateBackendToBackupManager = stateBackendToBackupManager;
+    this.taskChangelogPartition = changelogPartition;
+    this.checkpointManager = checkpointManager;
+    this.backupExecutor = backupExecutor;
+    this.durableStoreBaseDir = durableStoreBaseDir;
+    this.storeChangelogs = storeChangelogs;
+    this.storageManagerUtil = storageManagerUtil;
+    this.metrics = metrics;
+  }
+
+  public void init() {
+    // Assuming that container storage manager has already started and created the stores
+    storageEngines = containerStorageManager.getAllStores(taskName);
+    if (checkpointManager != null) {
+      Checkpoint checkpoint = checkpointManager.readLastCheckpoint(taskName);
+      LOG.debug("Last checkpoint on start for task: {} is: {}", taskName, checkpoint);
+      stateBackendToBackupManager.values()
+          .forEach(storageBackupManager -> storageBackupManager.init(checkpoint));
+    } else {
+      stateBackendToBackupManager.values()
+          .forEach(storageBackupManager -> storageBackupManager.init(null));
+    }
+  }
+
+  /**
+   * Synchronously captures the current state of the stores in order to persist it to the backup manager
+   * in the async {@link #upload(CheckpointId, Map)} phase. Returns a map of state backend factory name to
+   * a map of store name to state checkpoint markers for all configured state backends and stores.
+   *
+   * @param checkpointId {@link CheckpointId} of the current commit
+   * @return a map of state backend factory name to a map of store name to state checkpoint markers
+   */
+  public Map<String, Map<String, String>> snapshot(CheckpointId checkpointId) {
+    // Flush all stores
+    storageEngines.values().forEach(StorageEngine::flush);
+    LOG.debug("Flushed all storage engines for taskName: {}, checkpoint id: {}",
+        taskName, checkpointId);
+
+    long checkpointStartNs = System.nanoTime();
+    // Checkpoint all persisted and durable stores
+    storageEngines.forEach((storeName, storageEngine) -> {
+      if (storageEngine.getStoreProperties().isPersistedToDisk() &&
+          storageEngine.getStoreProperties().isDurableStore()) {
+        storageEngine.checkpoint(checkpointId);
+      }
+    });
+    long checkpointNs = System.nanoTime() - checkpointStartNs;
+    metrics.storeCheckpointNs().update(checkpointNs);
+    LOG.debug("Checkpointed all storage engines for taskName: {}, checkpoint id: {} in {} ns",
+        taskName, checkpointId, checkpointNs);
+
+    // state backend factory -> store Name -> state checkpoint marker
+    Map<String, Map<String, String>> stateBackendToStoreSCMs = new HashMap<>();
+
+    // for each configured state backend factory, backup the state for all stores in this task.
+    stateBackendToBackupManager.forEach((stateBackendFactoryName, backupManager) -> {
+      Map<String, String> snapshotSCMs = backupManager.snapshot(checkpointId);
+      LOG.debug("Created snapshot for taskName: {}, checkpoint id: {}, state backend: {}. Snapshot SCMs: {}",
+          taskName, checkpointId, stateBackendFactoryName, snapshotSCMs);
+      stateBackendToStoreSCMs.put(stateBackendFactoryName, snapshotSCMs);
+    });
+
+    return stateBackendToStoreSCMs;
+  }
+
+  /**
+   * Asynchronously backs up the local state to the remote storage and returns a future containing the committed
+   * map of state backend factory name to the map of store name to state checkpoint marker.
+   *
+   * @param checkpointId the {@link CheckpointId} associated with this commit
+   * @return a future containing the map of state backend factory name to (map of store name to state checkpoint marker).
+   */
+  public CompletableFuture<Map<String, Map<String, String>>> upload(
+      CheckpointId checkpointId, Map<String, Map<String, String>> snapshotSCMs) {
+    // state backend factory -> store Name -> state checkpoint marker
+    Map<String, CompletableFuture<Map<String, String>>> stateBackendToStoreSCMs = new HashMap<>();
+
+    // for each configured state backend factory, backup the state for all stores in this task.
+    stateBackendToBackupManager.forEach((stateBackendFactoryName, backupManager) -> {
+      try {
+        Map<String, String> factorySnapshotSCMs =
+            snapshotSCMs.getOrDefault(stateBackendFactoryName, Collections.emptyMap());
+        LOG.debug("Starting upload for taskName: {}, checkpoint id: {}, state backend snapshot SCM: {}",
+            taskName, checkpointId, factorySnapshotSCMs);
+
+        CompletableFuture<Map<String, String>> uploadFuture =
+            backupManager.upload(checkpointId, factorySnapshotSCMs);
+        uploadFuture.thenAccept(uploadSCMs ->
+            LOG.debug("Finished upload for taskName: {}, checkpoint id: {}, state backend: {}. Upload SCMs: {}",
+                taskName, checkpointId, stateBackendFactoryName, uploadSCMs));
+
+        stateBackendToStoreSCMs.put(stateBackendFactoryName, uploadFuture);
+      } catch (Exception e) {
+        throw new SamzaException(
+            String.format("Error backing up local state for taskName: %s, checkpoint id: %s, state backend: %s",
+                taskName, checkpointId, stateBackendFactoryName), e);
+      }
+    });
+
+    return FutureUtil.toFutureOfMap(stateBackendToStoreSCMs);
+  }
+
+  /**
+   * Writes the {@link Checkpoint} information returned by {@link #upload(CheckpointId, Map)}
+   * in each store directory and store checkpoint directory. Written content depends on the type of {@code checkpoint}.
+   * For {@link CheckpointV2}, writes the entire task {@link CheckpointV2}.
+   * For {@link CheckpointV1}, only writes the changelog ssp offsets in the OFFSET* files.
+   *
+   * Note: The assumption is that this method will be invoked once for each {@link Checkpoint} version that the
+   * task needs to write as determined by {@link org.apache.samza.config.TaskConfig#getCheckpointWriteVersions()}.
+   * This is required for upgrade and rollback compatibility.
+   *
+   * @param checkpoint the latest checkpoint to be persisted to local file system
+   */
+  public void writeCheckpointToStoreDirectories(Checkpoint checkpoint) {
+    if (checkpoint instanceof CheckpointV1) {
+      LOG.debug("Writing CheckpointV1 to store and checkpoint directories for taskName: {} with checkpoint: {}",
+          taskName, checkpoint);
+      // Write CheckpointV1 changelog offsets to store and checkpoint directories
+      writeChangelogOffsetFiles(checkpoint.getOffsets());
+    } else if (checkpoint instanceof CheckpointV2) {
+      LOG.debug("Writing CheckpointV2 to store and checkpoint directories for taskName: {} with checkpoint: {}",
+          taskName, checkpoint);
+      storageEngines.forEach((storeName, storageEngine) -> {
+        // Only write the checkpoint file if the store is durable and persisted to disk
+        if (storageEngine.getStoreProperties().isDurableStore() &&
+            storageEngine.getStoreProperties().isPersistedToDisk()) {
+          CheckpointV2 checkpointV2 = (CheckpointV2) checkpoint;
+
+          try {
+            File storeDir = storageManagerUtil.getTaskStoreDir(durableStoreBaseDir, storeName, taskName, TaskMode.Active);
+            storageManagerUtil.writeCheckpointV2File(storeDir, checkpointV2);
+
+            CheckpointId checkpointId = checkpointV2.getCheckpointId();
+            File checkpointDir = Paths.get(StorageManagerUtil.getCheckpointDirPath(storeDir, checkpointId)).toFile();
+            storageManagerUtil.writeCheckpointV2File(checkpointDir, checkpointV2);
+          } catch (Exception e) {
+            throw new SamzaException(
+                String.format("Write checkpoint file failed for task: %s, storeName: %s, checkpointId: %s",
+                    taskName, storeName, ((CheckpointV2) checkpoint).getCheckpointId()), e);
+          }
+        }
+      });
+    } else {
+      throw new SamzaException("Unsupported checkpoint version: " + checkpoint.getVersion());
+    }
+  }
+
+  /**
+   * Performs any post-commit and cleanup actions after the {@link Checkpoint} is successfully written to the
+   * checkpoint topic. Invokes {@link TaskBackupManager#cleanUp(CheckpointId, Map)} on each of the configured task
+   * backup managers. Deletes all local store checkpoint directories older than the {@code latestCheckpointId}.
+   *
+   * @param latestCheckpointId CheckpointId of the most recent successful commit
+   * @param stateCheckpointMarkers map of state backend factory name to (map of store name to state checkpoint markers)
+   *                               from the latest commit
+   */
+  public CompletableFuture<Void> cleanUp(CheckpointId latestCheckpointId,
+      Map<String, Map<String, String>> stateCheckpointMarkers) {
+    List<CompletableFuture<Void>> cleanUpFutures = new ArrayList<>();
+
+    // Call cleanup on each backup manager
+    stateCheckpointMarkers.forEach((factoryName, storeSCMs) -> {
+      if (stateBackendToBackupManager.containsKey(factoryName)) {
+        LOG.debug("Cleaning up commit for factory: {} for task: {}", factoryName, taskName);
+        TaskBackupManager backupManager = stateBackendToBackupManager.get(factoryName);
+        cleanUpFutures.add(backupManager.cleanUp(latestCheckpointId, storeSCMs));
+      } else {
+        // This may happen during migration from one state backend to another, where the latest commit contains
+        // a state backend that is no longer supported for the current commit manager
+        LOG.warn("Ignored cleanup for scm: {} due to unknown factory: {} ", storeSCMs, factoryName);
+      }
+    });
+
+    return FutureUtil.allOf(cleanUpFutures)
+        .thenAcceptAsync(aVoid -> deleteOldCheckpointDirs(latestCheckpointId), backupExecutor);
+  }
+
+  private void deleteOldCheckpointDirs(CheckpointId latestCheckpointId) {
+    // Delete directories for checkpoints older than latestCheckpointId
+    if (latestCheckpointId != null) {
+      LOG.debug("Deleting checkpoints older than checkpoint id: {}", latestCheckpointId);
+      File[] files = durableStoreBaseDir.listFiles();
+      if (files != null) {
+        for (File storeDir : files) {
+          String storeName = storeDir.getName();
+          String taskStoreName = storageManagerUtil
+              .getTaskStoreDir(durableStoreBaseDir, storeName, taskName, TaskMode.Active).getName();
+          FileFilter fileFilter = new WildcardFileFilter(taskStoreName + "-*");
+          File[] checkpointDirs = storeDir.listFiles(fileFilter);
+          if (checkpointDirs != null) {
+            for (File checkpointDir : checkpointDirs) {
+              if (!checkpointDir.getName().contains(latestCheckpointId.serialize())) {
+                try {
+                  FileUtils.deleteDirectory(checkpointDir);
+                } catch (IOException e) {
+                  throw new SamzaException(
+                      String.format("Unable to delete checkpoint directory: %s", checkpointDir.getName()), e);
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Close all the state backup managers
+   */
+  public void close() {
+    LOG.debug("Stopping backup managers for task {}.", taskName);
+    stateBackendToBackupManager.values().forEach(storageBackupManager -> {
+      if (storageBackupManager != null) {
+        storageBackupManager.close();
+      }
+    });
+  }
+
+  /**
+   * Writes the newest changelog ssp offset for each logged and persistent store to the OFFSET file in the current
+   * store directory (for allowing rollbacks). If the Kafka transactional backup manager is enabled, also writes to
+   * the store checkpoint directory.
+   *
+   * These files are used during container startup to ensure transactional state, and to determine whether
+   * there is any new information in the changelog that is not reflected in the on-disk copy of the store.
+   * If there is any delta, it is replayed from the changelog. E.g. this can happen if the job was run on this host,
+   * then another host, and then back to this host.
+   */
+  @VisibleForTesting
+  void writeChangelogOffsetFiles(Map<SystemStreamPartition, String> checkpointOffsets) {
+    if (storageEngines == null) {
+      throw new SamzaException(String.format(
+          "Storage engines are not initialized and writeChangelogOffsetFiles not be written for task %s", taskName));
+    }
+    storeChangelogs.forEach((storeName, systemStream) -> {
+      SystemStreamPartition changelogSSP = new SystemStreamPartition(
+          systemStream.getSystem(), systemStream.getStream(), taskChangelogPartition);
+
+      // Only write if the store is durable and persisted to disk
+      if (checkpointOffsets.containsKey(changelogSSP) &&
+          storageEngines.containsKey(storeName) &&
+          storageEngines.get(storeName).getStoreProperties().isDurableStore() &&
+          storageEngines.get(storeName).getStoreProperties().isPersistedToDisk()) {
+        LOG.debug("Writing changelog offset for taskName {} store {} changelog {}.", taskName, storeName, systemStream);
+        File currentStoreDir = storageManagerUtil.getTaskStoreDir(durableStoreBaseDir, storeName, taskName, TaskMode.Active);
+        try {
+          KafkaChangelogSSPOffset kafkaChangelogSSPOffset = KafkaChangelogSSPOffset
+              .fromString(checkpointOffsets.get(changelogSSP));
+          // Write offsets to file system if it is non-null
+          String newestOffset = kafkaChangelogSSPOffset.getChangelogOffset();
+          if (newestOffset != null) {
+            // Write changelog SSP offset to the OFFSET files in the task store directory
+            writeChangelogOffsetFile(storeName, changelogSSP, newestOffset, currentStoreDir);
+
+            // Write changelog SSP offset to the OFFSET files in the store checkpoint directory
+            File checkpointDir = Paths.get(StorageManagerUtil.getCheckpointDirPath(
+                currentStoreDir, kafkaChangelogSSPOffset.getCheckpointId())).toFile();
+            writeChangelogOffsetFile(storeName, changelogSSP, newestOffset, checkpointDir);
+          } else {
+            // If newestOffset is null, then it means the changelog ssp is (or has become) empty. This could be
+            // either because the changelog topic was newly added, repartitioned, or manually deleted and recreated.
+            // No need to persist the offset file.
+            LOG.debug("Deleting OFFSET file for taskName {} store {} changelog ssp {} since the newestOffset is null.",
+                taskName, storeName, changelogSSP);
+            storageManagerUtil.deleteOffsetFile(currentStoreDir);
+          }
+        } catch (IOException e) {
+          throw new SamzaException(
+              String.format("Error storing offset for taskName %s store %s changelog %s.", taskName, storeName,
+                  systemStream), e);
+        }
+      }
+    });
+    LOG.debug("Done writing OFFSET files for logged persistent key value stores for task {}", taskName);
+  }
+
+  @VisibleForTesting
+  void writeChangelogOffsetFile(String storeName, SystemStreamPartition ssp, String newestOffset,
+      File writeDirectory) throws IOException {
+    LOG.debug("Storing newest offset {} for taskName {} store {} changelog ssp {} in OFFSET file at path: {}.",
+        newestOffset, taskName, storeName, ssp, writeDirectory);
+    storageManagerUtil.writeOffsetFile(writeDirectory, Collections.singletonMap(ssp, newestOffset), false);
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/system/inmemory/InMemoryManager.java b/samza-core/src/main/java/org/apache/samza/system/inmemory/InMemoryManager.java
index 13ebf6e..bc4b227 100644
--- a/samza-core/src/main/java/org/apache/samza/system/inmemory/InMemoryManager.java
+++ b/samza-core/src/main/java/org/apache/samza/system/inmemory/InMemoryManager.java
@@ -54,7 +54,7 @@
   }
 
   private List<IncomingMessageEnvelope> newSynchronizedLinkedList() {
-    return  Collections.synchronizedList(new LinkedList<IncomingMessageEnvelope>());
+    return Collections.synchronizedList(new LinkedList<IncomingMessageEnvelope>());
   }
 
   /**
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java b/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java
index 589fb14..938a1d3 100644
--- a/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java
@@ -20,15 +20,13 @@
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import net.jodah.failsafe.RetryPolicy;
 
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
-
 import java.util.function.Predicate;
-import net.jodah.failsafe.RetryPolicy;
-
 import org.apache.samza.context.Context;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.AsyncReadWriteTable;
@@ -36,9 +34,8 @@
 import org.apache.samza.table.remote.TableWriteFunction;
 import org.apache.samza.table.utils.TableMetricsUtil;
 
-import static org.apache.samza.table.retry.FailsafeAdapter.failsafe;
-
 import static org.apache.samza.table.BaseReadWriteTable.Func1;
+import static org.apache.samza.table.retry.FailsafeAdapter.failsafe;
 
 
 /**
@@ -156,13 +153,13 @@
 
   private <T> CompletableFuture<T> doRead(Func1<T> func) {
     return readRetryPolicy != null
-        ? failsafe(readRetryPolicy, readRetryMetrics, retryExecutor).future(() -> func.apply())
+        ? failsafe(readRetryPolicy, readRetryMetrics, retryExecutor).getStageAsync(() ->  func.apply())
         : func.apply();
   }
 
   private <T> CompletableFuture<T> doWrite(Func1<T> func) {
     return writeRetryPolicy != null
-        ? failsafe(writeRetryPolicy, writeRetryMetrics, retryExecutor).future(() -> func.apply())
+        ? failsafe(writeRetryPolicy, writeRetryMetrics, retryExecutor).getStageAsync(() -> func.apply())
         : func.apply();
   }
 }
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/FailsafeAdapter.java b/samza-core/src/main/java/org/apache/samza/table/retry/FailsafeAdapter.java
index 650d03a..a9466b0 100644
--- a/samza-core/src/main/java/org/apache/samza/table/retry/FailsafeAdapter.java
+++ b/samza-core/src/main/java/org/apache/samza/table/retry/FailsafeAdapter.java
@@ -19,15 +19,14 @@
 
 package org.apache.samza.table.retry;
 
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.TimeUnit;
-
-import org.apache.samza.SamzaException;
-
-import net.jodah.failsafe.AsyncFailsafe;
 import net.jodah.failsafe.Failsafe;
+import net.jodah.failsafe.FailsafeExecutor;
 import net.jodah.failsafe.RetryPolicy;
 
+import java.time.temporal.ChronoUnit;
+import java.util.concurrent.ScheduledExecutorService;
+import org.apache.samza.SamzaException;
+
 
 /**
  * Helper class adapting the generic {@link TableRetryPolicy} to a failsafe {@link RetryPolicy} and
@@ -39,23 +38,24 @@
    * @return this policy instance
    */
   static RetryPolicy valueOf(TableRetryPolicy policy) {
-    RetryPolicy failSafePolicy = new RetryPolicy();
+    // max retries default changed to 2 in v2.0. switching back to infinite retries by default for back compat.
+    RetryPolicy failSafePolicy = new RetryPolicy().withMaxRetries(-1);
 
     switch (policy.getBackoffType()) {
       case NONE:
         break;
 
       case FIXED:
-        failSafePolicy.withDelay(policy.getSleepTime().toMillis(), TimeUnit.MILLISECONDS);
+        failSafePolicy.withDelay(policy.getSleepTime());
         break;
 
       case RANDOM:
-        failSafePolicy.withDelay(policy.getRandomMin().toMillis(), policy.getRandomMax().toMillis(), TimeUnit.MILLISECONDS);
+        failSafePolicy.withDelay(policy.getRandomMin().toMillis(), policy.getRandomMax().toMillis(), ChronoUnit.MILLIS);
         break;
 
       case EXPONENTIAL:
-        failSafePolicy.withBackoff(policy.getSleepTime().toMillis(), policy.getExponentialMaxSleep().toMillis(), TimeUnit.MILLISECONDS,
-            policy.getExponentialFactor());
+        failSafePolicy.withBackoff(policy.getSleepTime().toMillis(), policy.getExponentialMaxSleep().toMillis(),
+            ChronoUnit.MILLIS, policy.getExponentialFactor());
         break;
 
       default:
@@ -63,17 +63,16 @@
     }
 
     if (policy.getMaxDuration() != null) {
-      failSafePolicy.withMaxDuration(policy.getMaxDuration().toMillis(), TimeUnit.MILLISECONDS);
+      failSafePolicy.withMaxDuration(policy.getMaxDuration());
     }
     if (policy.getMaxAttempts() != null) {
       failSafePolicy.withMaxRetries(policy.getMaxAttempts());
     }
     if (policy.getJitter() != null && policy.getBackoffType() != TableRetryPolicy.BackoffType.RANDOM) {
-      failSafePolicy.withJitter(policy.getJitter().toMillis(), TimeUnit.MILLISECONDS);
+      failSafePolicy.withJitter(policy.getJitter());
     }
 
-    failSafePolicy.retryOn(e -> policy.getRetryPredicate().test(e));
-
+    failSafePolicy.abortOn(policy.getRetryPredicate().negate());
     return failSafePolicy;
   }
 
@@ -82,22 +81,24 @@
    * @param retryPolicy retry policy
    * @param metrics retry metrics
    * @param retryExec executor service for scheduling async retries
-   * @return {@link net.jodah.failsafe.AsyncFailsafe} instance
+   * @return {@link net.jodah.failsafe.FailsafeExecutor} instance
    */
-  static AsyncFailsafe<?> failsafe(RetryPolicy retryPolicy, RetryMetrics metrics, ScheduledExecutorService retryExec) {
+  static <T> FailsafeExecutor<T> failsafe(RetryPolicy<T> retryPolicy, RetryMetrics metrics, ScheduledExecutorService retryExec) {
     long startMs = System.currentTimeMillis();
-    return Failsafe.with(retryPolicy).with(retryExec)
+
+    RetryPolicy<T> retryPolicyWithMetrics = retryPolicy
         .onRetry(e -> metrics.retryCount.inc())
         .onRetriesExceeded(e -> {
           metrics.retryTimer.update(System.currentTimeMillis() - startMs);
           metrics.permFailureCount.inc();
-        })
-        .onSuccess((e, ctx) -> {
-          if (ctx.getExecutions() > 1) {
+        }).onSuccess((e) -> {
+          if (e.getAttemptCount() > 1) {
             metrics.retryTimer.update(System.currentTimeMillis() - startMs);
           } else {
             metrics.successCount.inc();
           }
         });
+
+    return Failsafe.with(retryPolicyWithMetrics).with(retryExec);
   }
 }
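
For context, the failsafe 2.x calls used above replace the 1.x Failsafe.with(policy).with(executor).future(...) chain.
A standalone sketch of the new call shape (the policy settings and supplier are arbitrary examples):

    import java.time.Duration;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;
    import net.jodah.failsafe.Failsafe;
    import net.jodah.failsafe.RetryPolicy;

    public class FailsafeV2Example {
      public static void main(String[] args) {
        ScheduledExecutorService retryExec = Executors.newSingleThreadScheduledExecutor();
        // failsafe 2.x uses java.time delays and exposes async execution via FailsafeExecutor#getStageAsync
        RetryPolicy<String> policy = new RetryPolicy<String>()
            .withMaxRetries(3)
            .withDelay(Duration.ofMillis(100));
        CompletableFuture<String> result = Failsafe.with(policy).with(retryExec)
            .getStageAsync(() -> CompletableFuture.completedFuture("ok"));
        System.out.println(result.join());
        retryExec.shutdown();
      }
    }
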
diff --git a/samza-core/src/main/java/org/apache/samza/util/FutureUtil.java b/samza-core/src/main/java/org/apache/samza/util/FutureUtil.java
new file mode 100644
index 0000000..dc527e9
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/util/FutureUtil.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.util;
+
+import net.jodah.failsafe.Failsafe;
+import net.jodah.failsafe.RetryPolicy;
+
+import java.time.Duration;
+import java.time.temporal.ChronoUnit;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Predicate;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.tuple.Pair;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+public class FutureUtil {
+  private static final Logger LOG = LoggerFactory.getLogger(FutureUtil.class);
+
+  /**
+   * Returns a future that completes when all the futures in the provided collections of futures are complete.
+   * @param futureCollections collections of futures to complete before the returned future is complete
+   */
+  @SafeVarargs
+  public static CompletableFuture<Void> allOf(Collection<? extends CompletionStage<?>>... futureCollections) {
+    List<CompletableFuture<Void>> fvs = new ArrayList<>();
+    for (Collection<? extends CompletionStage<?>> futureCollection : futureCollections) {
+      if (!futureCollection.isEmpty()) {
+        fvs.add(CompletableFuture.allOf(futureCollection.toArray(new CompletableFuture[0])));
+      }
+    }
+
+    return CompletableFuture.allOf(fvs.toArray(new CompletableFuture[0]));
+  }
+
+  /**
+   * Returns a future that completes when all the provided futures are complete.
+   * Returned future completes exceptionally if any future completes with a non-ignored error.
+   */
+  public static CompletableFuture<Void> allOf(Predicate<Throwable> ignoreError, CompletableFuture<?>... futures) {
+    CompletableFuture<Void> allFuture = CompletableFuture.allOf(futures);
+    return allFuture.handle((aVoid, t) -> {
+      for (CompletableFuture<?> future : futures) {
+        try {
+          future.join();
+        } catch (Throwable th) {
+          if (ignoreError.test(th)) {
+            // continue
+          } else {
+            throw th;
+          }
+        }
+      }
+      return null;
+    });
+  }
+
+  /**
+   * Helper method to convert: {@code Pair<CompletableFuture<L>, CompletableFuture<R>>}
+   * to:                       {@code CompletableFuture<Pair<L, R>>}
+   *
+   * Returns a future that completes when both futures complete.
+   * Returned future completes exceptionally if either of the futures complete exceptionally.
+   */
+  public static <L, R> CompletableFuture<Pair<L, R>> toFutureOfPair(
+      Pair<CompletableFuture<L>, CompletableFuture<R>> pairOfFutures) {
+    return CompletableFuture
+        .allOf(pairOfFutures.getLeft(), pairOfFutures.getRight())
+        .thenApply(v -> Pair.of(pairOfFutures.getLeft().join(), pairOfFutures.getRight().join()));
+  }
+
+  /**
+   * Helper method to convert: {@code Map<K, CompletableFuture<V>>}
+   * to:                       {@code CompletableFuture<Map<K, V>>}
+   *
+   * Returns a future that completes when all value futures complete.
+   * Returned future completes exceptionally if any of the value futures complete exceptionally.
+   */
+  public static <K, V> CompletableFuture<Map<K, V>> toFutureOfMap(Map<K, CompletableFuture<V>> keyToValueFutures) {
+    return CompletableFuture
+        .allOf(keyToValueFutures.values().toArray(new CompletableFuture[0]))
+        .thenApply(v -> keyToValueFutures.entrySet().stream()
+            .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().join())));
+  }
+
+  /**
+   * Helper method to convert: {@code Map<K, CompletableFuture<V>>}
+   * to:                       {@code CompletableFuture<Map<K, V>>}
+   *
+   * Returns a future that completes with successful map entries, skipping any entries with ignored errors,
+   * when all value futures complete.
+   * Returned future completes exceptionally if any of the futures complete with a non-ignored error.
+   */
+  public static <K, V> CompletableFuture<Map<K, V>> toFutureOfMap(
+      Predicate<Throwable> ignoreError, Map<K, CompletableFuture<V>> keyToValueFutures) {
+    CompletableFuture<Void> allEntriesFuture =
+        CompletableFuture.allOf(keyToValueFutures.values().toArray(new CompletableFuture[]{}));
+
+    return allEntriesFuture.handle((aVoid, t) -> {
+      Map<K, V> successfulResults = new HashMap<>();
+      for (Map.Entry<K, CompletableFuture<V>> entry : keyToValueFutures.entrySet()) {
+        K key = entry.getKey();
+        try {
+          V value = entry.getValue().join();
+          successfulResults.put(key, value);
+        } catch (Throwable th) {
+          if (ignoreError.test(th)) {
+            // ignore this error and continue with the remaining entries
+            LOG.warn("Ignoring value future completion error for key: {}", key, th);
+          } else {
+            throw th;
+          }
+        }
+      }
+      return successfulResults;
+    });
+  }
+
+  public static <T> CompletableFuture<T> executeAsyncWithRetries(String opName,
+      Supplier<? extends CompletionStage<T>> action,
+      Predicate<? extends Throwable> abortRetries,
+      ExecutorService executor) {
+    Duration maxDuration = Duration.ofMinutes(1);
+
+    RetryPolicy<Object> retryPolicy = new RetryPolicy<>()
+        .withBackoff(100, 10000, ChronoUnit.MILLIS)
+        .withMaxDuration(maxDuration)
+        .abortOn(abortRetries) // stop retrying if predicate returns true
+        .onRetry(e -> LOG.warn("Action: {} attempt: {} completed with error {} after start. Retrying up to {}.",
+            opName, e.getAttemptCount(), e.getElapsedTime(), maxDuration, e.getLastFailure()));
+
+    return Failsafe.with(retryPolicy).with(executor).getStageAsync(action::get);
+  }
+
+  public static <T> CompletableFuture<T> failedFuture(Throwable t) {
+    final CompletableFuture<T> cf = new CompletableFuture<>();
+    cf.completeExceptionally(t);
+    return cf;
+  }
+
+  /**
+   * Removes wrapper exceptions of the provided type from the provided throwable and returns the first cause
+   * that does not match the wrapper type. Useful for unwrapping CompletionException / SamzaException
+   * in stack traces and getting to the underlying cause.
+   *
+   * Returns null if provided Throwable is null or if there is no cause of non-wrapper type in the stack.
+   */
+  public static <T extends Throwable> Throwable unwrapExceptions(Class<? extends Throwable> wrapperClassToUnwrap, T t) {
+    if (t == null) return null;
+    if (wrapperClassToUnwrap == null) return t;
+
+    Throwable originalException = t;
+    while (wrapperClassToUnwrap.isAssignableFrom(originalException.getClass()) &&
+        originalException.getCause() != null) {
+      originalException = originalException.getCause();
+    }
+
+    // can still be the wrapper class if no other cause was found.
+    if (wrapperClassToUnwrap.isAssignableFrom(originalException.getClass())) {
+      return null;
+    } else {
+      return originalException;
+    }
+  }
+}
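Illustrative sketch (not part of this patch) of combining the helpers above; the store names and the use of IllegalStateException as the "ignorable" error type are arbitrary assumptions:

    Map<String, CompletableFuture<String>> storeUploads = new HashMap<>();
    storeUploads.put("store-a", CompletableFuture.completedFuture("scm-a"));
    storeUploads.put("store-b", FutureUtil.failedFuture(new IllegalStateException("transient")));

    // ignore IllegalStateException (after unwrapping the CompletionException thrown by join());
    // any other failure completes the combined future exceptionally
    CompletableFuture<Map<String, String>> uploadedSCMs = FutureUtil.toFutureOfMap(
        t -> FutureUtil.unwrapExceptions(CompletionException.class, t) instanceof IllegalStateException,
        storeUploads);

    uploadedSCMs.join(); // yields {store-a=scm-a}; store-b's error is logged and skipped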
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index 83f5c9d..dceb27b 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -27,7 +27,7 @@
 import java.util
 import java.util.{Base64, Optional}
 import java.util.concurrent.{CountDownLatch, ExecutorService, Executors, ScheduledExecutorService, ThreadPoolExecutor, TimeUnit}
-
+import java.util.function.Consumer
 import com.google.common.annotations.VisibleForTesting
 import com.google.common.util.concurrent.ThreadFactoryBuilder
 import org.apache.samza.checkpoint.{CheckpointListener, OffsetManager, OffsetManagerMetrics}
@@ -38,7 +38,7 @@
 import org.apache.samza.context._
 import org.apache.samza.diagnostics.DiagnosticsManager
 import org.apache.samza.job.model.{ContainerModel, JobModel, TaskMode}
-import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistryMap, MetricsReporter}
+import org.apache.samza.metrics.{JmxServer, JvmMetrics, MetricsRegistry, MetricsRegistryMap, MetricsReporter}
 import org.apache.samza.serializers._
 import org.apache.samza.serializers.model.SamzaObjectMapper
 import org.apache.samza.startpoint.StartpointManager
@@ -52,7 +52,6 @@
 import org.apache.samza.SamzaException
 import org.apache.samza.clustermanager.StandbyTaskUtil
 
-import scala.collection.JavaConversions
 import scala.collection.JavaConverters._
 
 object SamzaContainer extends Logging {
@@ -146,7 +145,6 @@
     val systemConfig = new SystemConfig(config)
     val containerModel = jobModel.getContainers.get(containerId)
     val containerName = "samza-container-%s" format containerId
-    val maxChangeLogStreamPartitions = jobModel.maxChangeLogStreamPartitions
 
     val containerPID = ManagementFactory.getRuntimeMXBean().getName()
 
@@ -343,31 +341,10 @@
 
     debug("Got system stream message serdes: %s" format systemStreamMessageSerdes)
 
-    val storeChangelogs = storageConfig
-      .getStoreNames.asScala
-      .filter(storageConfig.getChangelogStream(_).isPresent)
-      .map(name => (name, storageConfig.getChangelogStream(name).get)).toMap
-      .mapValues(StreamUtil.getSystemStreamFromNames(_))
+    val storeChangelogs = storageConfig.getStoreChangelogs
 
     info("Got change log system streams: %s" format storeChangelogs)
 
-    /*
-     * This keeps track of the changelog SSPs that are associated with the whole container. This is used so that we can
-     * prefetch the metadata about the all of the changelog SSPs associated with the container whenever we need the
-     * metadata about some of the changelog SSPs.
-     * An example use case is when Samza writes offset files for stores ({@link TaskStorageManager}). Each task is
-     * responsible for its own offset file, but if we can do prefetching, then most tasks will already have cached
-     * metadata by the time they need the offset metadata.
-     * Note: By using all changelog streams to build the sspsToPrefetch, any fetches done for persisted stores will
-     * include the ssps for non-persisted stores, so this is slightly suboptimal. However, this does not increase the
-     * actual number of calls to the {@link SystemAdmin}, and we can decouple this logic from the per-task objects (e.g.
-     * {@link TaskStorageManager}).
-     */
-    val changelogSSPMetadataCache = new SSPMetadataCache(systemAdmins,
-      Duration.ofSeconds(5),
-      SystemClock.instance,
-      getChangelogSSPsForContainer(containerModel, storeChangelogs).asJava)
-
     val intermediateStreams = streamConfig
       .getStreamIds()
       .asScala
@@ -398,7 +375,7 @@
       systemMessageSerdes = systemMessageSerdes,
       systemStreamKeySerdes = systemStreamKeySerdes,
       systemStreamMessageSerdes = systemStreamMessageSerdes,
-      changeLogSystemStreams = storeChangelogs.values.toSet,
+      changeLogSystemStreams = storeChangelogs.asScala.values.toSet,
       controlMessageKeySerdes = controlMessageKeySerdes,
       intermediateMessageSerdes = intermediateStreamMessageSerdes)
 
@@ -484,11 +461,19 @@
       null
     }
 
-
     val finalTaskFactory = TaskFactoryUtil.finalizeTaskFactory(
       taskFactory,
       taskThreadPool)
 
+    // executor for performing async commit operations for a task.
+    val commitThreadPoolSize =
+      Math.min(
+        Math.max(containerModel.getTasks.size() * 2, jobConfig.getCommitThreadPoolSize),
+        jobConfig.getCommitThreadPoolMaxSize
+      )
+    val commitThreadPool = Executors.newFixedThreadPool(commitThreadPoolSize,
+      new ThreadFactoryBuilder().setNameFormat("Samza Task Commit Thread-%d").setDaemon(true).build())
+
     val taskModels = containerModel.getTasks.values.asScala
     val containerContext = new ContainerContextImpl(containerModel, samzaContainerMetrics.registry)
     val applicationContainerContextOption = applicationContainerContextFactoryOption
@@ -498,8 +483,6 @@
 
     val timerExecutor = Executors.newSingleThreadScheduledExecutor
 
-    var taskStorageManagers : Map[TaskName, TaskStorageManager] = Map()
-
     val taskInstanceMetrics: Map[TaskName, TaskInstanceMetrics] = taskModels.map(taskModel => {
       (taskModel.getTaskName, new TaskInstanceMetrics("TaskName-%s" format taskModel.getTaskName))
     }).toMap
@@ -517,13 +500,15 @@
     val loggedStorageBaseDir = getLoggedStorageBaseDir(jobConfig, defaultStoreBaseDir)
     info("Got base directory for logged data stores: %s" format loggedStorageBaseDir)
 
+    val stateStorageBackendRestoreFactory = ReflectionUtil
+      .getObj(storageConfig.getStateBackendRestoreFactory(), classOf[StateBackendFactory])
+
     val containerStorageManager = new ContainerStorageManager(
       checkpointManager,
       containerModel,
       streamMetadataCache,
-      changelogSSPMetadataCache,
       systemAdmins,
-      storeChangelogs.asJava,
+      storeChangelogs,
       sideInputStoresToSystemStreams.mapValues(systemStreamSet => systemStreamSet.toSet.asJava).asJava,
       storageEngineFactories.asJava,
       systemFactories.asJava,
@@ -533,15 +518,19 @@
       samzaContainerMetrics,
       jobContext,
       containerContext,
+      stateStorageBackendRestoreFactory,
       taskCollectors.asJava,
       loggedStorageBaseDir,
       nonLoggedStorageBaseDir,
-      maxChangeLogStreamPartitions,
       serdeManager,
       new SystemClock)
 
     storeWatchPaths.addAll(containerStorageManager.getStoreDirectoryPaths)
 
+    val stateStorageBackendBackupFactories = storageConfig.getStateBackendBackupFactories.asScala.map(
+      ReflectionUtil.getObj(_, classOf[StateBackendFactory])
+    )
+
     // Create taskInstances
     val taskInstances: Map[TaskName, TaskInstance] = taskModels
       .filter(taskModel => taskModel.getTaskMode.eq(TaskMode.Active)).map(taskModel => {
@@ -563,15 +552,23 @@
       val taskSideInputSSPs = sideInputStoresToSSPs.values.flatMap(_.asScala).toSet
       info ("Got task side input SSPs: %s" format taskSideInputSSPs)
 
-      val storageManager = TaskStorageManagerFactory.create(
-        taskName,
-        containerStorageManager,
-        storeChangelogs,
-        systemAdmins,
-        loggedStorageBaseDir,
-        taskModel.getChangelogPartition,
-        config,
-        taskModel.getTaskMode)
+      val taskBackupManagerMap = new util.HashMap[String, TaskBackupManager]()
+      stateStorageBackendBackupFactories.asJava.forEach(new Consumer[StateBackendFactory] {
+        override def accept(factory: StateBackendFactory): Unit = {
+          val taskMetricsRegistry =
+            if (taskInstanceMetrics.contains(taskName) &&
+              taskInstanceMetrics.get(taskName).isDefined) taskInstanceMetrics.get(taskName).get.registry
+            else new MetricsRegistryMap
+          val taskBackupManager = factory.getBackupManager(jobContext, containerContext,
+            taskModel, commitThreadPool, taskMetricsRegistry, config, new SystemClock,
+            loggedStorageBaseDir, nonLoggedStorageBaseDir)
+          taskBackupManagerMap.put(factory.getClass.getName, taskBackupManager)
+        }
+      })
+
+      val commitManager = new TaskStorageCommitManager(taskName, taskBackupManagerMap,
+        containerStorageManager, storeChangelogs, taskModel.getChangelogPartition, checkpointManager, config,
+        commitThreadPool, new StorageManagerUtil, loggedStorageBaseDir, taskInstanceMetrics.get(taskName).get)
 
       val tableManager = new TableManager(config)
 
@@ -585,14 +582,16 @@
           consumerMultiplexer = consumerMultiplexer,
           collector = taskCollectors.get(taskName).get,
           offsetManager = offsetManager,
-          storageManager = storageManager,
+          commitManager = commitManager,
+          containerStorageManager = containerStorageManager,
           tableManager = tableManager,
-          systemStreamPartitions = JavaConversions.setAsJavaSet(taskSSPs -- taskSideInputSSPs),
+          systemStreamPartitions = (taskSSPs -- taskSideInputSSPs).asJava,
           exceptionHandler = TaskInstanceExceptionHandler(taskInstanceMetrics.get(taskName).get, taskConfig),
           jobModel = jobModel,
           streamMetadataCache = streamMetadataCache,
           inputStreamMetadata = inputStreamMetadata,
           timerExecutor = timerExecutor,
+          commitThreadPool = commitThreadPool,
           jobContext = jobContext,
           containerContext = containerContext,
           applicationContainerContextOption = applicationContainerContextOption,
@@ -601,7 +600,6 @@
 
       val taskInstance = createTaskInstance(task)
 
-      taskStorageManagers += taskInstance.taskName -> storageManager
       (taskName, taskInstance)
     }).toMap
 
@@ -684,6 +682,7 @@
       diskSpaceMonitor = diskSpaceMonitor,
       hostStatisticsMonitor = hostStatisticsMonitor,
       taskThreadPool = taskThreadPool,
+      commitThreadPool = commitThreadPool,
       timerExecutor = timerExecutor,
       containerContext = containerContext,
       applicationContainerContextOption = applicationContainerContextOption,
@@ -691,19 +690,6 @@
       containerStorageManager = containerStorageManager,
       diagnosticsManager = diagnosticsManager)
   }
-
-  /**
-    * Builds the set of SSPs for all changelogs on this container.
-    */
-  @VisibleForTesting
-  private[container] def getChangelogSSPsForContainer(containerModel: ContainerModel,
-    changeLogSystemStreams: Map[String, SystemStream]): Set[SystemStreamPartition] = {
-    containerModel.getTasks.values().asScala
-      .map(taskModel => taskModel.getChangelogPartition)
-      .flatMap(changelogPartition => changeLogSystemStreams.map { case (_, systemStream) =>
-        new SystemStreamPartition(systemStream, changelogPartition) })
-      .toSet
-  }
 }
 
 class SamzaContainer(
@@ -723,6 +709,7 @@
   reporters: Map[String, MetricsReporter] = Map(),
   jvm: JvmMetrics = null,
   taskThreadPool: ExecutorService = null,
+  commitThreadPool: ExecutorService = null,
   timerExecutor: ScheduledExecutorService = Executors.newSingleThreadScheduledExecutor,
   containerContext: ContainerContext,
   applicationContainerContextOption: Option[ApplicationContainerContext],
@@ -1053,7 +1040,7 @@
       info("Shutting down task thread pool")
       try {
         taskThreadPool.shutdown()
-        if(taskThreadPool.awaitTermination(shutdownMs, TimeUnit.MILLISECONDS)) {
+        if (!taskThreadPool.awaitTermination(shutdownMs, TimeUnit.MILLISECONDS)) {
           taskThreadPool.shutdownNow()
         }
       } catch {
@@ -1061,11 +1048,23 @@
       }
     }
 
+    if (commitThreadPool != null) {
+      info("Shutting down task commit thread pool")
+      try {
+        commitThreadPool.shutdown()
+        if (!commitThreadPool.awaitTermination(shutdownMs, TimeUnit.MILLISECONDS)) {
+          commitThreadPool.shutdownNow()
+        }
+      } catch {
+        case e: Exception => error(e.getMessage, e)
+      }
+    }
+
     if (timerExecutor != null) {
       info("Shutting down timer executor")
       try {
         timerExecutor.shutdown()
-        if (timerExecutor.awaitTermination(shutdownMs, TimeUnit.MILLISECONDS)) {
+        if (!timerExecutor.awaitTermination(shutdownMs, TimeUnit.MILLISECONDS)) {
           timerExecutor.shutdownNow()
         }
       } catch {
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index 2ebe465..8a872de 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -21,21 +21,26 @@
 
 
 import java.util.{Collections, Objects, Optional}
-import java.util.concurrent.ScheduledExecutorService
-
+import java.util.concurrent.{CompletableFuture, ExecutorService, ScheduledExecutorService, Semaphore, TimeUnit}
 import org.apache.samza.SamzaException
-import org.apache.samza.checkpoint.{Checkpoint, CheckpointId, CheckpointedChangelogOffset, OffsetManager}
+import org.apache.samza.checkpoint.kafka.{KafkaChangelogSSPOffset, KafkaStateCheckpointMarker}
+import org.apache.samza.checkpoint.{CheckpointId, CheckpointV1, CheckpointV2, OffsetManager}
 import org.apache.samza.config.{Config, StreamConfig, TaskConfig}
 import org.apache.samza.context._
 import org.apache.samza.job.model.{JobModel, TaskModel}
 import org.apache.samza.scheduler.{CallbackSchedulerImpl, EpochTimeScheduler, ScheduledCallback}
 import org.apache.samza.storage.kv.KeyValueStore
-import org.apache.samza.storage.TaskStorageManager
+import org.apache.samza.storage.{ContainerStorageManager, TaskStorageCommitManager}
 import org.apache.samza.system._
 import org.apache.samza.table.TableManager
 import org.apache.samza.task._
+import org.apache.samza.util.ScalaJavaUtil.JavaOptionals.toRichOptional
 import org.apache.samza.util.{Logging, ScalaJavaUtil}
 
+import java.util
+import java.util.concurrent.atomic.AtomicReference
+import java.util.function.BiConsumer
+import java.util.function.Function
 import scala.collection.JavaConversions._
 import scala.collection.JavaConverters._
 import scala.collection.{JavaConverters, Map}
@@ -48,14 +53,16 @@
   consumerMultiplexer: SystemConsumers,
   collector: TaskInstanceCollector,
   override val offsetManager: OffsetManager = new OffsetManager,
-  storageManager: TaskStorageManager = null,
+  commitManager: TaskStorageCommitManager = null,
+  containerStorageManager: ContainerStorageManager = null,
   tableManager: TableManager = null,
   val systemStreamPartitions: java.util.Set[SystemStreamPartition] = Collections.emptySet(),
   val exceptionHandler: TaskInstanceExceptionHandler = new TaskInstanceExceptionHandler,
   jobModel: JobModel = null,
   streamMetadataCache: StreamMetadataCache = null,
   inputStreamMetadata: Map[SystemStream, SystemStreamMetadata] = Map(),
-  timerExecutor : ScheduledExecutorService = null,
+  timerExecutor: ScheduledExecutorService = null,
+  commitThreadPool: ExecutorService = null,
   jobContext: JobContext,
   containerContext: ContainerContext,
   applicationContainerContextOption: Option[ApplicationContainerContext],
@@ -73,8 +80,9 @@
 
   private val kvStoreSupplier = ScalaJavaUtil.toJavaFunction(
     (storeName: String) => {
-      if (storageManager != null && storageManager.getStore(storeName).isDefined) {
-        storageManager.getStore(storeName).get.asInstanceOf[KeyValueStore[_, _]]
+      if (containerStorageManager != null) {
+        val storeOption = containerStorageManager.getStore(taskName, storeName).toOption
+        if (storeOption.isDefined) storeOption.get.asInstanceOf[KeyValueStore[_, _]] else null
       } else {
         null
       }
@@ -97,12 +105,21 @@
   systemStreamPartitions.foreach(ssp2CaughtupMapping += _ -> false)
 
   private val config: Config = jobContext.getConfig
+  val taskConfig = new TaskConfig(config)
 
   val streamConfig: StreamConfig = new StreamConfig(config)
   override val intermediateStreams: java.util.Set[String] = JavaConverters.setAsJavaSetConverter(streamConfig.getStreamIds.filter(streamConfig.getIsIntermediateStream)).asJava
 
   val streamsToDeleteCommittedMessages: Set[String] = streamConfig.getStreamIds.filter(streamConfig.getDeleteCommittedMessages).map(streamConfig.getPhysicalName).toSet
 
+  val checkpointWriteVersions = taskConfig.getCheckpointWriteVersions
+
+  @volatile var lastCommitStartTimeMs = System.currentTimeMillis()
+  val commitMaxDelayMs = taskConfig.getCommitMaxDelayMs
+  val commitTimeoutMs = taskConfig.getCommitTimeoutMs
+  val commitInProgress = new Semaphore(1)
+  val commitException = new AtomicReference[Exception]()
+
   def registerOffsets {
     debug("Registering offsets for taskName: %s" format taskName)
     offsetManager.register(taskName, systemStreamPartitions)
@@ -121,10 +138,33 @@
   def initTask {
     initCaughtUpMapping()
 
-    val taskConfig = new TaskConfig(config)
+    if (commitManager != null) {
+      debug("Starting commit manager for taskName: %s" format taskName)
+
+      commitManager.init()
+    } else {
+      debug("Skipping commit manager initialization for taskName: %s" format taskName)
+    }
+
+    if (offsetManager != null) {
+      val checkpoint = offsetManager.getLastTaskCheckpoint(taskName)
+      // Only required for checkpointV2
+      if (checkpoint != null && checkpoint.getVersion == 2) {
+        val checkpointV2 = checkpoint.asInstanceOf[CheckpointV2]
+        // call cleanUp on backup managers in case the container previously failed during commit
+        // before completing this step
+
+        // WARNING: cleanUp is NOT optional with blob stores since this is where we reset the TTL for
+        // tracked blobs. if this TTL reset is skipped, some of the blobs retained by future commits may
+        // be deleted in the background by the blob store, leading to data loss.
+        debug("Cleaning up stale state from previous run for taskName: %s" format taskName)
+        commitManager.cleanUp(checkpointV2.getCheckpointId, checkpointV2.getStateCheckpointMarkers)
+      }
+    }
+
     if (taskConfig.getTransactionalStateRestoreEnabled() && taskConfig.getCommitMs > 0) {
-      // Commit immediately so the trimmed changelog messages
-      // will be sealed in a checkpoint
+      debug("Committing immediately on startup for taskName: %s so that the trimmed changelog " +
+        "messages will be sealed in a checkpoint" format taskName)
       commit
     }
 
@@ -178,8 +218,9 @@
     if (ssp2CaughtupMapping(incomingMessageSsp)) {
       metrics.messagesActuallyProcessed.inc
 
-      trace("Processing incoming message envelope for taskName and SSP: %s, %s"
-        format (taskName, incomingMessageSsp))
+      trace("Processing incoming message envelope for taskName: %s SSP: %s offset: %s"
+        format (taskName, incomingMessageSsp, envelope.getOffset))
 
       exceptionHandler.maybeHandle {
         val callback = callbackFactory.createCallback()
@@ -219,68 +260,245 @@
   }
 
   def commit {
-    metrics.commits.inc
+    // ensure that only one commit (including sync and async phases) is ever in progress for a task.
 
-    val allCheckpointOffsets = new java.util.HashMap[SystemStreamPartition, String]()
-    val inputCheckpoint = offsetManager.buildCheckpoint(taskName)
-    if (inputCheckpoint != null) {
-      trace("Got input offsets for taskName: %s as: %s" format(taskName, inputCheckpoint.getOffsets))
-      allCheckpointOffsets.putAll(inputCheckpoint.getOffsets)
+    val commitStartNs = System.nanoTime()
+    // first check if there were any unrecoverable errors during the async stage of the pending commit
+    // and if so, shut down the container.
+    if (commitException.get() != null) {
+      throw new SamzaException("Unrecoverable error during pending commit for taskName: %s." format taskName,
+        commitException.get())
     }
 
-    trace("Flushing producers for taskName: %s" format taskName)
+    // if no commit is in progress for this task, continue with this commit.
+    // if a previous commit is in progress but less than {@code task.commit.max.delay.ms}
+    // have elapsed since it started, skip this commit request.
+    // if more time has elapsed than that, block this commit until the previous commit
+    // is complete, then continue with this commit.
+    if (!commitInProgress.tryAcquire()) {
+      val timeSinceLastCommit = System.currentTimeMillis() - lastCommitStartTimeMs
+      if (timeSinceLastCommit < commitMaxDelayMs) {
+        info("Skipping commit for taskName: %s since another commit is in progress. " +
+          "%s ms have elapsed since the pending commit started." format (taskName, timeSinceLastCommit))
+        metrics.commitsSkipped.set(metrics.commitsSkipped.getValue + 1)
+        return
+      } else {
+        warn("Blocking processing for taskName: %s until in-flight commit is complete. " +
+          "%s ms have elapsed since the pending commit started, " +
+          "which is greater than the max allowed commit delay: %s."
+          format (taskName, timeSinceLastCommit, commitMaxDelayMs))
+
+        if (!commitInProgress.tryAcquire(commitTimeoutMs, TimeUnit.MILLISECONDS)) {
+          val timeSinceLastCommit = System.currentTimeMillis() - lastCommitStartTimeMs
+          throw new SamzaException("Timeout waiting for pending commit for taskName: %s to finish. " +
+            "%s ms have elapsed since the pending commit started. Max allowed commit delay is %s ms " +
+            "and commit timeout beyond that is %s ms" format (taskName, timeSinceLastCommit,
+            commitMaxDelayMs, commitTimeoutMs))
+        }
+      }
+    }
+    // at this point the permit for semaphore has been acquired, proceed with commit.
+    // the first part of the commit needs to be exclusive with processing, so do it on the caller thread.
+    lastCommitStartTimeMs = System.currentTimeMillis()
+
+    metrics.commits.inc
+    val checkpointId = CheckpointId.create()
+
+    debug("Starting sync stage of commit for taskName: %s checkpointId: %s" format (taskName, checkpointId))
+
+    val inputOffsets = offsetManager.getLastProcessedOffsets(taskName)
+    trace("Got last processed input offsets for taskName: %s checkpointId: %s as: %s"
+      format(taskName, checkpointId, inputOffsets))
+
+    trace("Flushing producers for taskName: %s checkpointId: %s" format (taskName, checkpointId))
+    // Flushes output, checkpoint and changelog producers
     collector.flush
 
     if (tableManager != null) {
-      trace("Flushing tables for taskName: %s" format taskName)
+      trace("Flushing tables for taskName: %s checkpointId: %s" format (taskName, checkpointId))
       tableManager.flush()
     }
 
-    var newestChangelogOffsets: Map[SystemStreamPartition, Option[String]] = null
-    if (storageManager != null) {
-      trace("Flushing state stores for taskName: %s" format taskName)
-      newestChangelogOffsets = storageManager.flush()
-      trace("Got newest changelog offsets for taskName: %s as: %s " format(taskName, newestChangelogOffsets))
-    }
+    // create a synchronous snapshot of stores for commit
+    debug("Creating synchronous state store snapshots for taskName: %s checkpointId: %s"
+      format (taskName, checkpointId))
+    val snapshotStartTimeNs = System.nanoTime()
+    val snapshotSCMs = commitManager.snapshot(checkpointId)
+    metrics.snapshotNs.update(System.nanoTime() - snapshotStartTimeNs)
+    trace("Got synchronous snapshot SCMs for taskName: %s checkpointId: %s as: %s "
+      format(taskName, checkpointId, snapshotSCMs))
 
-    val checkpointId = CheckpointId.create()
-    if (storageManager != null && newestChangelogOffsets != null) {
-      trace("Checkpointing stores for taskName: %s with checkpoint id: %s" format (taskName, checkpointId))
-      storageManager.checkpoint(checkpointId, newestChangelogOffsets.toMap)
-    }
+    debug("Submitting async stage of commit for taskName: %s checkpointId: %s for execution"
+      format (taskName, checkpointId))
+    val asyncStageStartNs = System.nanoTime()
+    // rest of the commit can happen asynchronously and concurrently with processing.
+    // schedule it on the commit executor and return. submitted runnable releases the
+    // commit semaphore permit when this commit is complete.
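+    // async stage chain: upload snapshot SCMs -> write checkpoint(s) -> clean up old checkpoint
+    // state -> trim committed messages from intermediate streams -> record metrics and release permit.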
+    commitThreadPool.submit(new Runnable {
+      override def run(): Unit = {
+        debug("Starting async stage of commit for taskName: %s checkpointId: %s" format (taskName, checkpointId))
 
-    if (newestChangelogOffsets != null) {
-      newestChangelogOffsets.foreach {case (ssp, newestOffsetOption) =>
-        val offset = new CheckpointedChangelogOffset(checkpointId, newestOffsetOption.orNull).toString
-        allCheckpointOffsets.put(ssp, offset)
-      }
-    }
-    val checkpoint = new Checkpoint(allCheckpointOffsets)
-    trace("Got combined checkpoint offsets for taskName: %s as: %s" format (taskName, allCheckpointOffsets))
+        try {
+          val uploadStartTimeNs = System.nanoTime()
+          val uploadSCMsFuture = commitManager.upload(checkpointId, snapshotSCMs)
+          uploadSCMsFuture.whenComplete(new BiConsumer[util.Map[String, util.Map[String, String]], Throwable] {
+            override def accept(t: util.Map[String, util.Map[String, String]], throwable: Throwable): Unit = {
+              if (throwable == null) {
+                metrics.asyncUploadNs.update(System.nanoTime() - uploadStartTimeNs)
+              } else {
+                warn("Commit upload did not complete successfully for taskName: %s checkpointId: %s with error msg: %s"
+                  format (taskName, checkpointId, throwable.getMessage))
+              }
+            }
+          })
 
-    offsetManager.writeCheckpoint(taskName, checkpoint)
+          // explicit types required to make scala compiler happy
+          val checkpointWriteFuture: CompletableFuture[util.Map[String, util.Map[String, String]]] =
+            uploadSCMsFuture.thenApplyAsync(writeCheckpoint(checkpointId, inputOffsets), commitThreadPool)
 
-    if (storageManager != null) {
-      trace("Remove old checkpoint stores for taskName: %s" format taskName)
-      try {
-        storageManager.removeOldCheckpoints(checkpointId)
-      } catch {
-        case e: Exception => error("Failed to remove old checkpoints for task: %s. Current checkpointId: %s" format (taskName, checkpointId), e)
-      }
-    }
+          val cleanupStartTimeNs = System.nanoTime()
+          val cleanUpFuture: CompletableFuture[Void] =
+            checkpointWriteFuture.thenComposeAsync(cleanUp(checkpointId), commitThreadPool)
+          cleanUpFuture.whenComplete(new BiConsumer[Void, Throwable] {
+            override def accept(v: Void, throwable: Throwable): Unit = {
+              if (throwable == null) {
+                metrics.asyncCleanupNs.update(System.nanoTime() - cleanupStartTimeNs)
+              } else {
+                warn("Commit cleanup did not complete successfully for taskName: %s checkpointId: %s with error msg: %s"
+                  format (taskName, checkpointId, throwable.getMessage))
+              }
+            }
+          })
 
-    if (inputCheckpoint != null) {
-      trace("Deleting committed input offsets for taskName: %s" format taskName)
-      inputCheckpoint.getOffsets.asScala
-        .filter { case (ssp, _) => streamsToDeleteCommittedMessages.contains(ssp.getStream) } // Only delete data of intermediate streams
-        .groupBy { case (ssp, _) => ssp.getSystem }
-        .foreach { case (systemName: String, offsets: Map[SystemStreamPartition, String]) =>
-          systemAdmins.getSystemAdmin(systemName).deleteMessages(offsets.asJava)
+          val trimFuture = cleanUpFuture.thenRunAsync(
+            trim(checkpointId, inputOffsets), commitThreadPool)
+
+          trimFuture.whenCompleteAsync(handleCompletion(checkpointId, commitStartNs, asyncStageStartNs), commitThreadPool)
+        } catch {
+          case t: Throwable => handleCompletion(checkpointId, commitStartNs, asyncStageStartNs).accept(null, t)
         }
+      }
+    })
+
+    metrics.commitSyncNs.update(System.nanoTime() - commitStartNs)
+    debug("Finishing sync stage of commit for taskName: %s checkpointId: %s" format (taskName, checkpointId))
+  }
+
+  private def writeCheckpoint(checkpointId: CheckpointId, inputOffsets: util.Map[SystemStreamPartition, String]) = {
+    new Function[util.Map[String, util.Map[String, String]], util.Map[String, util.Map[String, String]]]() {
+      override def apply(uploadSCMs: util.Map[String, util.Map[String, String]]) = {
+        trace("Got asynchronous upload SCMs for taskName: %s checkpointId: %s as: %s "
+          format(taskName, checkpointId, uploadSCMs))
+
+        debug("Creating and writing checkpoints for taskName: %s checkpointId: %s" format (taskName, checkpointId))
+        checkpointWriteVersions.foreach(checkpointWriteVersion => {
+          val checkpoint = if (checkpointWriteVersion == 1) {
+            // build CheckpointV1 with KafkaChangelogSSPOffset for backwards compatibility
+            val allCheckpointOffsets = new util.HashMap[SystemStreamPartition, String]()
+            allCheckpointOffsets.putAll(inputOffsets)
+            val newestChangelogOffsets = KafkaStateCheckpointMarker.scmsToSSPOffsetMap(uploadSCMs)
+            newestChangelogOffsets.foreach { case (ssp, newestOffsetOption) =>
+              val offset = new KafkaChangelogSSPOffset(checkpointId, newestOffsetOption.orNull).toString
+              allCheckpointOffsets.put(ssp, offset)
+            }
+            new CheckpointV1(allCheckpointOffsets)
+          } else if (checkpointWriteVersion == 2) {
+            new CheckpointV2(checkpointId, inputOffsets, uploadSCMs)
+          } else {
+            throw new SamzaException("Unsupported checkpoint write version: " + checkpointWriteVersion)
+          }
+
+          trace("Writing checkpoint for taskName: %s checkpointId: %s as: %s"
+            format(taskName, checkpointId, checkpoint))
+
+          // Write input offsets and state checkpoint markers to task store and checkpoint directories
+          commitManager.writeCheckpointToStoreDirectories(checkpoint)
+
+          // Write input offsets and state checkpoint markers to the checkpoint topic atomically
+          offsetManager.writeCheckpoint(taskName, checkpoint)
+        })
+
+        uploadSCMs
+      }
+    }
+  }
+
+  private def cleanUp(checkpointId: CheckpointId) = {
+    new Function[util.Map[String, util.Map[String, String]], CompletableFuture[Void]] {
+      override def apply(uploadSCMs: util.Map[String, util.Map[String, String]]): CompletableFuture[Void] = {
+        // Perform cleanup on unused checkpoints
+        debug("Cleaning up old checkpoint state for taskName: %s checkpointId: %s" format(taskName, checkpointId))
+        try {
+          commitManager.cleanUp(checkpointId, uploadSCMs)
+        } catch {
+          case e: Exception =>
+            // WARNING: cleanUp is NOT optional with blob stores since this is where we reset the TTL for
+            // tracked blobs. if this TTL reset is skipped, some of the blobs retained by future commits may
+            // be deleted in the background by the blob store, leading to data loss.
+            throw new SamzaException(
+              "Failed to remove old checkpoint state for taskName: %s checkpointId: %s."
+                format(taskName, checkpointId), e)
+        }
+      }
+    }
+  }
+
+  private def trim(checkpointId: CheckpointId, inputOffsets: util.Map[SystemStreamPartition, String]) = {
+    new Runnable {
+      override def run(): Unit = {
+        trace("Deleting committed input offsets from intermediate topics for taskName: %s checkpointId: %s"
+          format (taskName, checkpointId))
+        inputOffsets.asScala
+          .filter { case (ssp, _) => streamsToDeleteCommittedMessages.contains(ssp.getStream) } // Only delete data of intermediate streams
+          .groupBy { case (ssp, _) => ssp.getSystem }
+          .foreach { case (systemName: String, offsets: Map[SystemStreamPartition, String]) =>
+            systemAdmins.getSystemAdmin(systemName).deleteMessages(offsets.asJava)
+          }
+      }
+    }
+  }
+
+  private def handleCompletion(checkpointId: CheckpointId, commitStartNs: Long, asyncStageStartNs: Long) = {
+    new BiConsumer[Void, Throwable] {
+      override def accept(v: Void, e: Throwable): Unit = {
+        try {
+          debug("%s finishing async stage of commit for taskName: %s checkpointId: %s."
+            format (if (e == null) "Successfully" else "Unsuccessfully", taskName, checkpointId))
+          if (e != null) {
+            val exception = new SamzaException("Unrecoverable error during async stage of commit " +
+              "for taskName: %s checkpointId: %s" format(taskName, checkpointId), e)
+            val exceptionSet = commitException.compareAndSet(null, exception)
+            if (!exceptionSet) {
+              // should never happen because there should be at most one async stage of commit in progress
+              // for a task and another one shouldn't be scheduled if the previous one failed. throw a new
+              // exception on the caller thread for logging and debugging if this happens.
+              error("Should not have encountered a non-null saved exception during async stage of " +
+                "commit for taskName: %s checkpointId: %s" format(taskName, checkpointId), commitException.get())
+              error("New exception during async stage of commit for taskName: %s checkpointId: %s"
+                format(taskName, checkpointId), exception)
+              throw new SamzaException("Should not have encountered a non-null saved exception " +
+                "during async stage of commit for taskName: %s checkpointId: %s. New exception logged above. " +
+                "Saved exception under Caused By.", commitException.get())
+            }
+          } else {
+            metrics.commitAsyncNs.update(System.nanoTime() - asyncStageStartNs)
+            metrics.commitNs.update(System.nanoTime() - commitStartNs)
+          }
+        } finally {
+          // release the permit indicating that previous commit is complete.
+          commitInProgress.release()
+        }
+      }
     }
   }
 
   def shutdownTask {
+    if (commitManager != null) {
+      debug("Shutting down commit manager for taskName: %s" format taskName)
+      commitManager.close()
+    } else {
+      debug("Skipping commit manager shutdown for taskName: %s" format taskName)
+    }
     applicationTaskContextOption.foreach(applicationTaskContext => {
       debug("Stopping application-defined task context for taskName: %s" format taskName)
       applicationTaskContext.stop()
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
index bdd773c..f13e37a 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstanceMetrics.scala
@@ -37,7 +37,16 @@
   val flushes = newCounter("flush-calls")
   val pendingMessages = newGauge("pending-messages", 0)
   val messagesInFlight = newGauge("messages-in-flight", 0)
-  val asyncCallbackCompleted = newCounter("async-callback-complete-calls");
+  val asyncCallbackCompleted = newCounter("async-callback-complete-calls")
+
+  val commitsSkipped = newGauge("commits-skipped", 0)
+  val commitNs = newTimer("commit-ns")
+  val commitSyncNs = newTimer("commit-sync-ns")
+  val commitAsyncNs = newTimer("commit-async-ns")
+  val snapshotNs = newTimer("snapshot-ns")
+  val storeCheckpointNs = newTimer("store-checkpoint-ns")
+  val asyncUploadNs = newTimer("async-upload-ns")
+  val asyncCleanupNs = newTimer("async-cleanup-ns")
 
   def addOffsetGauge(systemStreamPartition: SystemStreamPartition, getValue: () => String) {
     newGauge("%s-%s-%d-offset" format (systemStreamPartition.getSystem, systemStreamPartition.getStream, systemStreamPartition.getPartition.getPartitionId), getValue)
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/KafkaNonTransactionalStateTaskBackupManager.scala b/samza-core/src/main/scala/org/apache/samza/storage/KafkaNonTransactionalStateTaskBackupManager.scala
new file mode 100644
index 0000000..633191b
--- /dev/null
+++ b/samza-core/src/main/scala/org/apache/samza/storage/KafkaNonTransactionalStateTaskBackupManager.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage
+
+import java.util
+import java.util.concurrent.CompletableFuture
+
+import com.google.common.collect.ImmutableSet
+import org.apache.samza.checkpoint.kafka.KafkaStateCheckpointMarker
+import org.apache.samza.checkpoint.{Checkpoint, CheckpointId}
+import org.apache.samza.container.TaskName
+import org.apache.samza.system._
+import org.apache.samza.util.Logging
+import org.apache.samza.{Partition, SamzaException}
+
+import scala.collection.JavaConverters._
+
+/**
+ * Manages the non-transactional Kafka changelog backup for the stores of a given task.
+ */
+class KafkaNonTransactionalStateTaskBackupManager(
+  taskName: TaskName,
+  storeChangelogs: util.Map[String, SystemStream] = new util.HashMap[String, SystemStream](),
+  systemAdmins: SystemAdmins,
+  partition: Partition) extends Logging with TaskBackupManager {
+
+  override def init(checkpoint: Checkpoint): Unit = {}
+
+  override def snapshot(checkpointId: CheckpointId): util.Map[String, String] = {
+    debug("Getting newest offsets for kafka changelog SSPs.")
+    getNewestChangelogSSPOffsets()
+  }
+
+  override def upload(checkpointId: CheckpointId,
+    stateCheckpointMarkers: util.Map[String, String]): CompletableFuture[util.Map[String, String]] = {
+     CompletableFuture.completedFuture(stateCheckpointMarkers)
+  }
+
+  override def cleanUp(checkpointId: CheckpointId,
+    stateCheckpointMarker: util.Map[String, String]): CompletableFuture[Void] = {
+    CompletableFuture.completedFuture(null)
+  }
+
+  override def close() {}
+
+  /**
+   * Returns the newest offset for each store changelog SSP for this task.
+   * @return A map of store names for this task to their newest changelog offset (null if the changelog is empty), serialized as a KafkaStateCheckpointMarker
+   * @throws SamzaException if there was an error fetching newest offset for any SSP
+   */
+  private def getNewestChangelogSSPOffsets(): util.Map[String, String] = {
+    storeChangelogs.asScala
+      .map { case (storeName, systemStream) => {
+        debug("Fetching newest offset for taskName %s store %s changelog %s" format (taskName, storeName, systemStream))
+        val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
+        val systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem)
+
+        try {
+          val sspMetadataOption = Option(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp)).get(ssp))
+
+          // newest offset == null implies topic is empty
+          val newestOffsetOption = sspMetadataOption.flatMap(sspMetadata => Option(sspMetadata.getNewestOffset))
+          newestOffsetOption.foreach(newestOffset =>
+            debug("Got newest offset %s for taskName %s store %s changelog %s" format(newestOffset, taskName, storeName, systemStream)))
+
+          (storeName, KafkaStateCheckpointMarker.serialize(new KafkaStateCheckpointMarker(ssp, newestOffsetOption.orNull)))
+        } catch {
+          case e: Exception =>
+            throw new SamzaException("Error getting newest changelog offset for taskName %s store %s changelog %s."
+              format(taskName, storeName, systemStream), e)
+        }
+      }}.asJava
+  }
+}
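Illustrative sketch (not part of this patch) of the snapshot/upload/cleanUp lifecycle that the commit flow drives against a TaskBackupManager; backupManager is a hypothetical placeholder for any implementation, e.g. the Kafka changelog backup manager above, and checkpoint writing/error handling are omitted:

    CheckpointId checkpointId = CheckpointId.create();

    // 1. synchronous phase, exclusive with processing: capture local snapshot markers
    Map<String, String> snapshotSCMs = backupManager.snapshot(checkpointId);

    // 2. asynchronous phase: upload the snapshot; the Kafka backup managers above return
    //    the snapshot markers as-is since the changelog producer has already done the work
    CompletableFuture<Map<String, String>> uploadedSCMs = backupManager.upload(checkpointId, snapshotSCMs);

    // 3. after checkpoints are written, clean up state left over from older checkpoints
    CompletableFuture<Void> cleanedUp =
        uploadedSCMs.thenCompose(scms -> backupManager.cleanUp(checkpointId, scms));
    cleanedUp.join();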
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/KafkaTransactionalStateTaskBackupManager.scala b/samza-core/src/main/scala/org/apache/samza/storage/KafkaTransactionalStateTaskBackupManager.scala
new file mode 100644
index 0000000..dc28fe7
--- /dev/null
+++ b/samza-core/src/main/scala/org/apache/samza/storage/KafkaTransactionalStateTaskBackupManager.scala
@@ -0,0 +1,97 @@
+package org.apache.samza.storage
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+import java.util
+import java.util.concurrent.CompletableFuture
+
+import com.google.common.annotations.VisibleForTesting
+import com.google.common.collect.ImmutableSet
+import org.apache.samza.checkpoint.kafka.KafkaStateCheckpointMarker
+import org.apache.samza.checkpoint.{Checkpoint, CheckpointId}
+import org.apache.samza.container.TaskName
+import org.apache.samza.system._
+import org.apache.samza.util.Logging
+import org.apache.samza.{Partition, SamzaException}
+
+import scala.collection.JavaConverters._
+
+/**
+ * Manages the transactional Kafka changelog backup for the stores of a given task.
+ */
+class KafkaTransactionalStateTaskBackupManager(
+  taskName: TaskName,
+  storeChangelogs: util.Map[String, SystemStream] = new util.HashMap[String, SystemStream](),
+  systemAdmins: SystemAdmins,
+  partition: Partition) extends Logging with TaskBackupManager {
+
+  override def init(checkpoint: Checkpoint): Unit = {}
+
+  override def snapshot(checkpointId: CheckpointId): util.Map[String, String] = {
+    debug("Getting newest offsets for kafka changelog SSPs.")
+    getNewestChangelogSSPOffsets(taskName, storeChangelogs, partition, systemAdmins)
+  }
+
+  override def upload(checkpointId: CheckpointId, snapshotCheckpointsMap: util.Map[String, String]):
+  CompletableFuture[util.Map[String, String]] = {
+    CompletableFuture.completedFuture(snapshotCheckpointsMap)
+  }
+
+  override def cleanUp(checkpointId: CheckpointId,
+    stateCheckpointMarker: util.Map[String, String]): CompletableFuture[Void] = {
+    CompletableFuture.completedFuture(null)
+  }
+
+  override def close() {}
+
+  /**
+   * Returns the newest offset for each store changelog SSP for this task. Returned map will
+   * always contain an entry for every changelog SSP.
+   * @return A map of store names for this task to their changelog SSP and newest offset (null if the changelog is empty), wrapped in a serialized KafkaStateCheckpointMarker
+   * @throws SamzaException if there was an error fetching newest offset for any SSP
+   */
+  @VisibleForTesting
+  def getNewestChangelogSSPOffsets(taskName: TaskName, storeChangelogs: util.Map[String, SystemStream],
+      partition: Partition, systemAdmins: SystemAdmins): util.Map[String, String] = {
+    storeChangelogs.asScala
+      .map { case (storeName, systemStream) => {
+        try {
+          debug("Fetching newest offset for taskName %s store %s changelog %s" format (taskName, storeName, systemStream))
+          val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
+          val systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem)
+
+          val sspMetadata = Option(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp)).get(ssp))
+            .getOrElse(throw new SamzaException("Received null metadata for ssp: %s" format ssp))
+
+          // newest offset == null implies topic is empty
+          val newestOffsetOption = Option(sspMetadata.getNewestOffset)
+          newestOffsetOption.foreach(newestOffset =>
+            debug("Got newest offset %s for taskName %s store %s changelog %s" format(newestOffset, taskName, storeName, systemStream)))
+
+          (storeName, KafkaStateCheckpointMarker.serialize(new KafkaStateCheckpointMarker(ssp, newestOffsetOption.orNull)))
+        } catch {
+          case e: Exception =>
+            throw new SamzaException("Error getting newest changelog offset for taskName %s store %s changelog %s."
+              format(taskName, storeName, systemStream), e)
+        }
+      }}
+      .toMap.asJava
+  }
+}
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/NonTransactionalStateTaskStorageManager.scala b/samza-core/src/main/scala/org/apache/samza/storage/NonTransactionalStateTaskStorageManager.scala
deleted file mode 100644
index 7b38749..0000000
--- a/samza-core/src/main/scala/org/apache/samza/storage/NonTransactionalStateTaskStorageManager.scala
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.storage
-
-import java.io._
-
-import com.google.common.annotations.VisibleForTesting
-import com.google.common.collect.ImmutableSet
-import org.apache.samza.checkpoint.CheckpointId
-import org.apache.samza.container.TaskName
-import org.apache.samza.job.model.TaskMode
-import org.apache.samza.system._
-import org.apache.samza.util.Logging
-import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
-import org.apache.samza.{Partition, SamzaException}
-
-import scala.collection.JavaConverters._
-
-/**
- * Manage all the storage engines for a given task
- */
-class NonTransactionalStateTaskStorageManager(
-  taskName: TaskName,
-  containerStorageManager: ContainerStorageManager,
-  storeChangelogs: Map[String, SystemStream] = Map(),
-  systemAdmins: SystemAdmins,
-  loggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
-  partition: Partition) extends Logging with TaskStorageManager {
-
-  private val storageManagerUtil = new StorageManagerUtil
-  private val persistedStores = containerStorageManager.getAllStores(taskName).asScala
-    .filter { case (storeName, storageEngine) => storageEngine.getStoreProperties.isPersistedToDisk }
-
-  def getStore(storeName: String): Option[StorageEngine] =  JavaOptionals.toRichOptional(containerStorageManager.getStore(taskName, storeName)).toOption
-
-  def flush(): Map[SystemStreamPartition, Option[String]] = {
-    debug("Flushing stores.")
-    containerStorageManager.getAllStores(taskName).asScala.values.foreach(_.flush)
-    val newestChangelogSSPOffsets = getNewestChangelogSSPOffsets()
-    writeChangelogOffsetFiles(newestChangelogSSPOffsets)
-    newestChangelogSSPOffsets
-  }
-
-  override def checkpoint(checkpointId: CheckpointId,
-    newestChangelogOffsets: Map[SystemStreamPartition, Option[String]]): Unit = {}
-
-  override def removeOldCheckpoints(checkpointId: CheckpointId): Unit = {}
-
-  @VisibleForTesting
-  def stop() {
-    debug("Stopping stores.")
-    containerStorageManager.stopStores()
-  }
-
-  /**
-   * Returns the newest offset for each store changelog SSP for this task.
-   * @return A map of changelog SSPs for this task to their newest offset (or None if ssp is empty)
-   * @throws SamzaException if there was an error fetching newest offset for any SSP
-   */
-  private def getNewestChangelogSSPOffsets(): Map[SystemStreamPartition, Option[String]] = {
-    storeChangelogs
-      .map { case (storeName, systemStream) => {
-        debug("Fetching newest offset for taskName %s store %s changelog %s" format (taskName, storeName, systemStream))
-        val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
-        val systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem)
-
-        try {
-          val sspMetadataOption = Option(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp)).get(ssp))
-
-          // newest offset == null implies topic is empty
-          val newestOffsetOption = sspMetadataOption.flatMap(sspMetadata => Option(sspMetadata.getNewestOffset))
-          newestOffsetOption.foreach(newestOffset =>
-            debug("Got newest offset %s for taskName %s store %s changelog %s" format(newestOffset, taskName, storeName, systemStream)))
-
-          (ssp, newestOffsetOption)
-        } catch {
-          case e: Exception =>
-            throw new SamzaException("Error getting newest changelog offset for taskName %s store %s changelog %s."
-              format(taskName, storeName, systemStream), e)
-        }
-      }}
-  }
-
-  /**
-   * Writes the newest changelog ssp offset for each persistent store to the OFFSET file on disk.
-   * These files are used during container startup to determine whether there is any new information in the
-   * changelog that is not reflected in the on-disk copy of the store. If there is any delta, it is replayed
-   * from the changelog e.g. This can happen if the job was run on this host, then another
-   * host and back to this host.
-   */
-  private def writeChangelogOffsetFiles(newestChangelogOffsets: Map[SystemStreamPartition, Option[String]]) {
-    debug("Writing OFFSET files for logged persistent key value stores for task %s." format(taskName))
-
-    storeChangelogs
-      .filterKeys(storeName => persistedStores.contains(storeName))
-      .foreach { case (storeName, systemStream) => {
-        debug("Writing changelog offset for taskName %s store %s changelog %s." format(taskName, storeName, systemStream))
-        val currentStoreDir = storageManagerUtil.getTaskStoreDir(loggedStoreBaseDir, storeName, taskName, TaskMode.Active)
-        try {
-          val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
-          newestChangelogOffsets(ssp) match {
-            case Some(newestOffset) => {
-              debug("Storing newest offset %s for taskName %s store %s changelog %s in OFFSET file."
-                format(newestOffset, taskName, storeName, systemStream))
-              // TaskStorageManagers are only created for active tasks
-              storageManagerUtil.writeOffsetFile(currentStoreDir, Map(ssp -> newestOffset).asJava, false)
-              debug("Successfully stored offset %s for taskName %s store %s changelog %s in OFFSET file."
-                format(newestOffset, taskName, storeName, systemStream))
-            }
-            case None => {
-              // if newestOffset is null, then it means the changelog ssp is (or has become) empty. This could be
-              // either because the changelog topic was newly added, repartitioned, or manually deleted and recreated.
-              // No need to persist the offset file.
-              storageManagerUtil.deleteOffsetFile(currentStoreDir)
-              debug("Deleting OFFSET file for taskName %s store %s changelog ssp %s since the newestOffset is null."
-                format (taskName, storeName, ssp))
-            }
-          }
-        } catch {
-          case e: Exception =>
-            throw new SamzaException("Error storing offset for taskName %s store %s changelog %s."
-              format(taskName, storeName, systemStream), e)
-        }
-      }}
-    debug("Done writing OFFSET files for logged persistent key value stores for task %s" format(taskName))
-  }
-}
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/TransactionalStateTaskStorageManager.scala b/samza-core/src/main/scala/org/apache/samza/storage/TransactionalStateTaskStorageManager.scala
deleted file mode 100644
index 0335710..0000000
--- a/samza-core/src/main/scala/org/apache/samza/storage/TransactionalStateTaskStorageManager.scala
+++ /dev/null
@@ -1,201 +0,0 @@
-package org.apache.samza.storage
-
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-import java.io._
-import java.nio.file.Path
-
-import com.google.common.annotations.VisibleForTesting
-import com.google.common.collect.ImmutableSet
-import org.apache.commons.io.FileUtils
-import org.apache.commons.io.filefilter.WildcardFileFilter
-import org.apache.samza.checkpoint.CheckpointId
-import org.apache.samza.{Partition, SamzaException}
-import org.apache.samza.container.TaskName
-import org.apache.samza.job.model.TaskMode
-import org.apache.samza.system._
-import org.apache.samza.util.ScalaJavaUtil.JavaOptionals
-import org.apache.samza.util.Logging
-
-import scala.collection.JavaConverters._
-
-/**
- * Manage all the storage engines for a given task
- */
-class TransactionalStateTaskStorageManager(
-  taskName: TaskName,
-  containerStorageManager: ContainerStorageManager,
-  storeChangelogs: Map[String, SystemStream] = Map(),
-  systemAdmins: SystemAdmins,
-  loggedStoreBaseDir: File = new File(System.getProperty("user.dir"), "state"),
-  partition: Partition,
-  taskMode: TaskMode,
-  storageManagerUtil: StorageManagerUtil) extends Logging with TaskStorageManager {
-
-  def getStore(storeName: String): Option[StorageEngine] =  JavaOptionals.toRichOptional(containerStorageManager.getStore(taskName, storeName)).toOption
-
-  def flush(): Map[SystemStreamPartition, Option[String]] = {
-    debug("Flushing stores.")
-    containerStorageManager.getAllStores(taskName).asScala.values.foreach(_.flush)
-    getNewestChangelogSSPOffsets(taskName, storeChangelogs, partition, systemAdmins)
-  }
-
-  def checkpoint(checkpointId: CheckpointId, newestChangelogOffsets: Map[SystemStreamPartition, Option[String]]): Unit = {
-    debug("Checkpointing stores.")
-
-    val checkpointPaths = containerStorageManager.getAllStores(taskName).asScala
-      .filter { case (storeName, storeEngine) =>
-        storeEngine.getStoreProperties.isLoggedStore && storeEngine.getStoreProperties.isPersistedToDisk}
-      .flatMap { case (storeName, storeEngine) => {
-        val pathOptional = storeEngine.checkpoint(checkpointId)
-        if (pathOptional.isPresent) {
-          Some(storeName, pathOptional.get())
-        } else {
-          None
-        }
-      }}
-      .toMap
-
-    writeChangelogOffsetFiles(checkpointPaths, storeChangelogs, newestChangelogOffsets)
-  }
-
-  def removeOldCheckpoints(latestCheckpointId: CheckpointId): Unit = {
-    if (latestCheckpointId != null) {
-      debug("Removing older checkpoints before " + latestCheckpointId)
-
-      val files = loggedStoreBaseDir.listFiles()
-      if (files != null) {
-        files
-          .foreach(storeDir => {
-            val storeName = storeDir.getName
-            val taskStoreName = storageManagerUtil.getTaskStoreDir(
-              loggedStoreBaseDir, storeName, taskName, taskMode).getName
-            val fileFilter: FileFilter = new WildcardFileFilter(taskStoreName + "-*")
-            val checkpointDirs = storeDir.listFiles(fileFilter)
-
-            if (checkpointDirs != null) {
-              checkpointDirs
-                .filter(!_.getName.contains(latestCheckpointId.toString))
-                .foreach(checkpointDir => {
-                  FileUtils.deleteDirectory(checkpointDir)
-                })
-            }
-          })
-      }
-    }
-  }
-
-  @VisibleForTesting
-  def stop() {
-    debug("Stopping stores.")
-    containerStorageManager.stopStores()
-  }
-
-  /**
-   * Returns the newest offset for each store changelog SSP for this task. Returned map will
-   * always contain an entry for every changelog SSP.
-   * @return A map of changelog SSPs for this task to their newest offset (or None if ssp is empty)
-   * @throws SamzaException if there was an error fetching newest offset for any SSP
-   */
-  @VisibleForTesting
-  def getNewestChangelogSSPOffsets(taskName: TaskName, storeChangelogs: Map[String, SystemStream],
-      partition: Partition, systemAdmins: SystemAdmins): Map[SystemStreamPartition, Option[String]] = {
-    storeChangelogs
-      .map { case (storeName, systemStream) => {
-        try {
-          debug("Fetching newest offset for taskName %s store %s changelog %s" format (taskName, storeName, systemStream))
-          val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
-          val systemAdmin = systemAdmins.getSystemAdmin(systemStream.getSystem)
-
-          val sspMetadata = Option(systemAdmin.getSSPMetadata(ImmutableSet.of(ssp)).get(ssp))
-            .getOrElse(throw new SamzaException("Received null metadata for ssp: %s" format ssp))
-
-          // newest offset == null implies topic is empty
-          val newestOffsetOption = Option(sspMetadata.getNewestOffset)
-          newestOffsetOption.foreach(newestOffset =>
-            debug("Got newest offset %s for taskName %s store %s changelog %s" format(newestOffset, taskName, storeName, systemStream)))
-
-          (ssp, newestOffsetOption)
-        } catch {
-          case e: Exception =>
-            throw new SamzaException("Error getting newest changelog offset for taskName %s store %s changelog %s."
-              format(taskName, storeName, systemStream), e)
-        }
-      }}
-      .toMap
-  }
-
-  /**
- * Writes the newest changelog ssp offset for each persistent store to the OFFSET file in both the checkpoint
-   * and the current store directory (the latter for allowing rollbacks).
-   *
- * These files are used during container startup to ensure transactional state, and to determine whether
-   * there is any new information in the changelog that is not reflected in the on-disk copy of the store.
-   * If there is any delta, it is replayed from the changelog e.g. This can happen if the job was run on this host,
-   * then another host, and then back to this host.
-   */
-  @VisibleForTesting
-  def writeChangelogOffsetFiles(checkpointPaths: Map[String, Path], storeChangelogs: Map[String, SystemStream],
-      newestChangelogOffsets: Map[SystemStreamPartition, Option[String]]): Unit = {
-    debug("Writing OFFSET files for logged persistent key value stores for task %s." format(checkpointPaths))
-
-    storeChangelogs
-      .filterKeys(storeName => checkpointPaths.contains(storeName))
-      .foreach { case (storeName, systemStream) => {
-        try {
-          val ssp = new SystemStreamPartition(systemStream.getSystem, systemStream.getStream, partition)
-          val currentStoreDir = storageManagerUtil.getTaskStoreDir(loggedStoreBaseDir, storeName, taskName, TaskMode.Active)
-          newestChangelogOffsets(ssp) match {
-            case Some(newestOffset) => {
-              // write the offset file for the checkpoint directory
-              val checkpointPath = checkpointPaths(storeName)
-              writeChangelogOffsetFile(storeName, ssp, newestOffset, checkpointPath.toFile)
-              // write the OFFSET file for the current store (for backwards compatibility / allowing rollbacks)
-              writeChangelogOffsetFile(storeName, ssp, newestOffset, currentStoreDir)
-            }
-            case None => {
-              // retain existing behavior for current store directory for backwards compatibility / allowing rollbacks
-
-              // if newestOffset is null, then it means the changelog ssp is (or has become) empty. This could be
-              // either because the changelog topic was newly added, repartitioned, or manually deleted and recreated.
-              // No need to persist the offset file.
-              storageManagerUtil.deleteOffsetFile(currentStoreDir)
-              debug("Deleting OFFSET file for taskName %s current store %s changelog ssp %s since the newestOffset is null."
-                format (taskName, storeName, ssp))
-            }
-          }
-        } catch {
-          case e: Exception =>
-            throw new SamzaException("Error storing offset for taskName %s store %s changelog %s."
-              format(taskName, storeName, systemStream), e)
-        }
-      }}
-    debug("Done writing OFFSET files for logged persistent key value stores for task %s" format(taskName))
-  }
-
-  private def writeChangelogOffsetFile(storeName: String, ssp: SystemStreamPartition,
-      newestOffset: String, dir: File): Unit = {
-    debug("Storing newest offset: %s for taskName: %s store: %s changelog: %s in OFFSET file at path: %s."
-      format(newestOffset, taskName, storeName, ssp, dir))
-    storageManagerUtil.writeOffsetFile(dir, Map(ssp -> newestOffset).asJava, false)
-    debug("Successfully stored offset: %s for taskName: %s store: %s changelog: %s in OFFSET file at path: %s."
-      format(newestOffset, taskName, storeName, ssp, dir))
-  }
-}
diff --git a/samza-core/src/test/java/org/apache/samza/config/TestStorageConfig.java b/samza-core/src/test/java/org/apache/samza/config/TestStorageConfig.java
index 88fbbe0..e634940 100644
--- a/samza-core/src/test/java/org/apache/samza/config/TestStorageConfig.java
+++ b/samza-core/src/test/java/org/apache/samza/config/TestStorageConfig.java
@@ -41,6 +41,8 @@
 public class TestStorageConfig {
   private static final String STORE_NAME0 = "store0";
   private static final String STORE_NAME1 = "store1";
+  private static final String STORE_NAME2 = "store2";
+  private static final String STORE_NAME3 = "store3";
 
   @Test
   public void testGetStoreNames() {
@@ -138,6 +140,68 @@
   }
 
   @Test
+  public void testGetBackupManagerFactories() {
+    String factory1 = "factory1";
+    String factory2 = "factory2";
+    String factory3 = "factory3";
+    StorageConfig storageConfig = new StorageConfig(new MapConfig(
+        ImmutableMap.of(
+            String.format(STORE_BACKEND_BACKUP_FACTORIES, STORE_NAME0), factory1 + "," + factory2,
+            String.format(STORE_BACKEND_BACKUP_FACTORIES, STORE_NAME1), factory1,
+            String.format(STORE_BACKEND_BACKUP_FACTORIES, STORE_NAME2), factory3,
+            // store_name3 should use DEFAULT_STATE_BACKEND_FACTORY due to changelog presence
+            String.format(CHANGELOG_STREAM, STORE_NAME3), "nondefault-changelog-system.streamName"),
+        ImmutableMap.of(
+            String.format(FACTORY, STORE_NAME0), "store0.factory.class",
+            String.format(FACTORY, STORE_NAME1), "store1.factory.class",
+            String.format(FACTORY, STORE_NAME2), "store2.factory.class",
+            String.format(FACTORY, STORE_NAME3), "store3.factory.class",
+            // this store should have no backend factory configured
+            String.format(FACTORY, "noFactoryStore"), "noFactory.factory.class"
+            )
+        ));
+    Set<String> factories = storageConfig.getStateBackendBackupFactories();
+    assertTrue(factories.contains(factory1));
+    assertTrue(factories.contains(factory2));
+    assertTrue(factories.contains(factory3));
+    assertTrue(factories.contains(DEFAULT_STATE_BACKEND_FACTORY));
+    assertEquals(4, factories.size());
+    assertEquals(ImmutableList.of(factory1, factory2), storageConfig.getStoreBackupManagerClassName(STORE_NAME0));
+    assertEquals(ImmutableList.of(factory1), storageConfig.getStoreBackupManagerClassName(STORE_NAME1));
+    assertEquals(ImmutableList.of(factory3), storageConfig.getStoreBackupManagerClassName(STORE_NAME2));
+    assertEquals(DEFAULT_STATE_BACKEND_BACKUP_FACTORIES, storageConfig.getStoreBackupManagerClassName(STORE_NAME3));
+    assertTrue(storageConfig.getStoreBackupManagerClassName("emptyStore").isEmpty());
+    assertTrue(storageConfig.getStoreBackupManagerClassName("noFactoryStore").isEmpty());
+  }
+
+  @Test
+  public void testGetStoreToBackup() {
+    String targetFactory = "target.class";
+    StorageConfig config = new StorageConfig(new MapConfig(
+        ImmutableMap.of(
+            String.format(StorageConfig.STORE_BACKEND_BACKUP_FACTORIES, STORE_NAME0), targetFactory,
+            String.format(StorageConfig.STORE_BACKEND_BACKUP_FACTORIES, STORE_NAME1), targetFactory + "," +
+                DEFAULT_STATE_BACKEND_FACTORY,
+            String.format(StorageConfig.STORE_BACKEND_BACKUP_FACTORIES, STORE_NAME2), DEFAULT_STATE_BACKEND_FACTORY),
+        ImmutableMap.of(
+            String.format(FACTORY, STORE_NAME0), "store0.factory.class",
+            String.format(FACTORY, STORE_NAME1), "store1.factory.class",
+            String.format(FACTORY, STORE_NAME2), "store2.factory.class",
+            String.format(FACTORY, STORE_NAME3), "store3.factory.class",
+            String.format(CHANGELOG_STREAM, STORE_NAME3), "nondefault-changelog-system.streamName"
+        )
+    ));
+
+    List<String> targetStoreNames = config.getBackupStoreNamesForStateBackupFactory(targetFactory);
+    List<String> defaultStoreNames = config.getBackupStoreNamesForStateBackupFactory(
+        DEFAULT_STATE_BACKEND_FACTORY);
+    assertTrue(targetStoreNames.containsAll(ImmutableList.of(STORE_NAME0, STORE_NAME1)));
+    assertEquals(2, targetStoreNames.size());
+    assertTrue(defaultStoreNames.containsAll(ImmutableList.of(STORE_NAME2, STORE_NAME1, STORE_NAME3)));
+    assertEquals(3, defaultStoreNames.size());
+  }
+
+  @Test
   public void testGetAccessLogEnabled() {
     // empty config, access log disabled
     assertFalse(new StorageConfig(new MapConfig()).getAccessLogEnabled(STORE_NAME0));
diff --git a/samza-core/src/test/java/org/apache/samza/storage/TestKafkaChangelogStateBackendFactory.java b/samza-core/src/test/java/org/apache/samza/storage/TestKafkaChangelogStateBackendFactory.java
new file mode 100644
index 0000000..5782a75
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/storage/TestKafkaChangelogStateBackendFactory.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.Partition;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.context.ContainerContextImpl;
+import org.apache.samza.job.model.ContainerModel;
+import org.apache.samza.job.model.TaskModel;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Assert;
+import org.junit.Test;
+
+
+public class TestKafkaChangelogStateBackendFactory {
+
+  @Test
+  public void testGetChangelogSSP() {
+    KafkaChangelogStateBackendFactory factory = new KafkaChangelogStateBackendFactory();
+    TaskName taskName0 = new TaskName("task0");
+    TaskName taskName1 = new TaskName("task1");
+    TaskModel taskModel0 = new TaskModel(taskName0,
+        ImmutableSet.of(new SystemStreamPartition("input", "stream", new Partition(0))),
+        new Partition(10));
+    TaskModel taskModel1 = new TaskModel(taskName1,
+        ImmutableSet.of(new SystemStreamPartition("input", "stream", new Partition(1))), new Partition(11));
+    ContainerModel containerModel = new ContainerModel("processorId",
+        ImmutableMap.of(taskName0, taskModel0, taskName1, taskModel1));
+    Map<String, SystemStream> changeLogSystemStreams = ImmutableMap.of(
+        "store0", new SystemStream("changelogSystem0", "store0-changelog"),
+        "store1", new SystemStream("changelogSystem1", "store1-changelog"));
+    Set<SystemStreamPartition> expected = ImmutableSet.of(
+        new SystemStreamPartition("changelogSystem0", "store0-changelog", new Partition(10)),
+        new SystemStreamPartition("changelogSystem1", "store1-changelog", new Partition(10)),
+        new SystemStreamPartition("changelogSystem0", "store0-changelog", new Partition(11)),
+        new SystemStreamPartition("changelogSystem1", "store1-changelog", new Partition(11)));
+    Assert.assertEquals(expected, factory.getChangelogSSPForContainer(changeLogSystemStreams,
+        new ContainerContextImpl(containerModel, null)));
+  }
+
+  @Test
+  public void testGetChangelogSSPsForContainerNoChangelogs() {
+    KafkaChangelogStateBackendFactory factory = new KafkaChangelogStateBackendFactory();
+    TaskName taskName0 = new TaskName("task0");
+    TaskName taskName1 = new TaskName("task1");
+    TaskModel taskModel0 = new TaskModel(taskName0,
+        ImmutableSet.of(new SystemStreamPartition("input", "stream", new Partition(0))),
+        new Partition(10));
+    TaskModel taskModel1 = new TaskModel(taskName1,
+        ImmutableSet.of(new SystemStreamPartition("input", "stream", new Partition(1))),
+        new Partition(11));
+    ContainerModel containerModel = new ContainerModel("processorId",
+        ImmutableMap.of(taskName0, taskModel0, taskName1, taskModel1));
+    Assert.assertEquals(Collections.emptySet(), factory.getChangelogSSPForContainer(Collections.emptyMap(),
+        new ContainerContextImpl(containerModel, null)));
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/storage/TestTaskStorageCommitManager.java b/samza-core/src/test/java/org/apache/samza/storage/TestTaskStorageCommitManager.java
new file mode 100644
index 0000000..e682aef
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/storage/TestTaskStorageCommitManager.java
@@ -0,0 +1,878 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import com.google.common.collect.ImmutableMap;
+import java.io.File;
+import java.io.FileFilter;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ForkJoinPool;
+import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
+import org.apache.samza.checkpoint.Checkpoint;
+import org.apache.samza.checkpoint.CheckpointId;
+import org.apache.samza.checkpoint.CheckpointManager;
+import org.apache.samza.checkpoint.CheckpointV1;
+import org.apache.samza.checkpoint.CheckpointV2;
+import org.apache.samza.checkpoint.kafka.KafkaChangelogSSPOffset;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.job.model.TaskMode;
+import org.apache.samza.metrics.Timer;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.*;
+
+
+public class TestTaskStorageCommitManager {
+  @Test
+  public void testCommitManagerStart() {
+    CheckpointManager checkpointManager = mock(CheckpointManager.class);
+    TaskBackupManager taskBackupManager1 = mock(TaskBackupManager.class);
+    TaskBackupManager taskBackupManager2 = mock(TaskBackupManager.class);
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    Checkpoint checkpoint = mock(Checkpoint.class);
+
+    TaskName taskName = new TaskName("task1");
+    Map<String, TaskBackupManager> backupManagers = ImmutableMap.of(
+        "factory1", taskBackupManager1,
+        "factory2", taskBackupManager2
+    );
+    TaskStorageCommitManager cm = new TaskStorageCommitManager(taskName, backupManagers, containerStorageManager,
+        Collections.emptyMap(), new Partition(1), checkpointManager, new MapConfig(),
+        ForkJoinPool.commonPool(), new StorageManagerUtil(), null, null);
+
+    when(checkpointManager.readLastCheckpoint(taskName)).thenReturn(checkpoint);
+    cm.init();
+    verify(taskBackupManager1).init(eq(checkpoint));
+    verify(taskBackupManager2).init(eq(checkpoint));
+  }
+
+  @Test
+  public void testCommitManagerStartNullCheckpointManager() {
+    TaskBackupManager taskBackupManager1 = mock(TaskBackupManager.class);
+    TaskBackupManager taskBackupManager2 = mock(TaskBackupManager.class);
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+
+    TaskName task = new TaskName("task1");
+    Map<String, TaskBackupManager> backupManagers = ImmutableMap.of(
+        "factory1", taskBackupManager1,
+        "factory2", taskBackupManager2
+    );
+    TaskStorageCommitManager cm = new TaskStorageCommitManager(task, backupManagers, containerStorageManager,
+        Collections.emptyMap(), new Partition(1), null, new MapConfig(),
+        ForkJoinPool.commonPool(), new StorageManagerUtil(), null, null);
+    cm.init();
+    verify(taskBackupManager1).init(eq(null));
+    verify(taskBackupManager2).init(eq(null));
+  }
+
+  @Test
+  public void testSnapshotAndCommitAllFactories() {
+    CheckpointManager checkpointManager = mock(CheckpointManager.class);
+    TaskBackupManager taskBackupManager1 = mock(TaskBackupManager.class);
+    TaskBackupManager taskBackupManager2 = mock(TaskBackupManager.class);
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    Checkpoint checkpoint = mock(Checkpoint.class);
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    TaskName taskName = new TaskName("task1");
+    Map<String, TaskBackupManager> backupManagers = ImmutableMap.of(
+        "factory1", taskBackupManager1,
+        "factory2", taskBackupManager2
+    );
+    TaskStorageCommitManager cm = new TaskStorageCommitManager(taskName, backupManagers, containerStorageManager,
+        Collections.emptyMap(), new Partition(1), checkpointManager, new MapConfig(),
+        ForkJoinPool.commonPool(), new StorageManagerUtil(), null, metrics);
+    when(checkpointManager.readLastCheckpoint(taskName)).thenReturn(checkpoint);
+    cm.init();
+    verify(taskBackupManager1).init(eq(checkpoint));
+    verify(taskBackupManager2).init(eq(checkpoint));
+
+    CheckpointId newCheckpointId = CheckpointId.create();
+    Map<String, String> factory1Checkpoints = ImmutableMap.of(
+        "store1", "system;stream;1",
+        "store2", "system;stream;2"
+    );
+    Map<String, String> factory2Checkpoints = ImmutableMap.of(
+        "store1", "blobId1",
+        "store2", "blobId2"
+    );
+
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(Collections.emptyMap());
+    when(taskBackupManager1.snapshot(newCheckpointId)).thenReturn(factory1Checkpoints);
+    when(taskBackupManager2.snapshot(newCheckpointId)).thenReturn(factory2Checkpoints);
+
+    when(taskBackupManager1.upload(newCheckpointId, factory1Checkpoints))
+        .thenReturn(CompletableFuture.completedFuture(factory1Checkpoints));
+    when(taskBackupManager2.upload(newCheckpointId, factory2Checkpoints))
+        .thenReturn(CompletableFuture.completedFuture(factory2Checkpoints));
+
+    Map<String, Map<String, String>> snapshotSCMs = cm.snapshot(newCheckpointId);
+    cm.upload(newCheckpointId, snapshotSCMs);
+
+    // Test flow for snapshot
+    verify(taskBackupManager1).snapshot(newCheckpointId);
+    verify(taskBackupManager2).snapshot(newCheckpointId);
+
+    // Test flow for upload
+    verify(taskBackupManager1).upload(newCheckpointId, factory1Checkpoints);
+    verify(taskBackupManager2).upload(newCheckpointId, factory2Checkpoints);
+    verify(checkpointTimer).update(anyLong());
+  }
+
+  @Test
+  public void testFlushAndCheckpointOnSnapshot() {
+    CheckpointManager checkpointManager = mock(CheckpointManager.class);
+    TaskBackupManager taskBackupManager1 = mock(TaskBackupManager.class);
+    TaskBackupManager taskBackupManager2 = mock(TaskBackupManager.class);
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    Checkpoint checkpoint = mock(Checkpoint.class);
+
+    StorageEngine mockLPStore = mock(StorageEngine.class);
+    StoreProperties lpStoreProps = mock(StoreProperties.class);
+    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
+    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(lpStoreProps.isDurableStore()).thenReturn(true);
+    Path mockPath = mock(Path.class);
+    when(mockLPStore.checkpoint(any())).thenReturn(Optional.of(mockPath));
+
+    StorageEngine mockPStore = mock(StorageEngine.class);
+    StoreProperties pStoreProps = mock(StoreProperties.class);
+    when(mockPStore.getStoreProperties()).thenReturn(pStoreProps);
+    when(pStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(pStoreProps.isDurableStore()).thenReturn(false);
+
+    StorageEngine mockLIStore = mock(StorageEngine.class);
+    StoreProperties liStoreProps = mock(StoreProperties.class);
+    when(mockLIStore.getStoreProperties()).thenReturn(liStoreProps);
+    when(liStoreProps.isPersistedToDisk()).thenReturn(false);
+    when(liStoreProps.isDurableStore()).thenReturn(true);
+
+    StorageEngine mockIStore = mock(StorageEngine.class);
+    StoreProperties iStoreProps = mock(StoreProperties.class);
+    when(mockIStore.getStoreProperties()).thenReturn(iStoreProps);
+    when(iStoreProps.isPersistedToDisk()).thenReturn(false);
+    when(iStoreProps.isDurableStore()).thenReturn(false);
+
+    TaskName taskName = new TaskName("task1");
+    Map<String, TaskBackupManager> backupManagers = ImmutableMap.of(
+        "factory1", taskBackupManager1,
+        "factory2", taskBackupManager2
+    );
+    Map<String, StorageEngine> storageEngines = ImmutableMap.of(
+        "storeLP", mockLPStore,
+        "storeP", mockPStore,
+        "storeLI", mockLIStore,
+        "storeI", mockIStore
+    );
+
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    TaskStorageCommitManager cm = new TaskStorageCommitManager(taskName, backupManagers, containerStorageManager,
+        Collections.emptyMap(), new Partition(1), checkpointManager, new MapConfig(),
+        ForkJoinPool.commonPool(), new StorageManagerUtil(), null, metrics);
+    when(checkpointManager.readLastCheckpoint(taskName)).thenReturn(checkpoint);
+    cm.init();
+    verify(taskBackupManager1).init(eq(checkpoint));
+    verify(taskBackupManager2).init(eq(checkpoint));
+
+    CheckpointId newCheckpointId = CheckpointId.create();
+    Map<String, String> factory1Checkpoints = ImmutableMap.of(
+        "store1", "system;stream;1",
+        "store2", "system;stream;2"
+    );
+    Map<String, String> factory2Checkpoints = ImmutableMap.of(
+        "store1", "blobId1",
+        "store2", "blobId2"
+    );
+
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(storageEngines);
+    when(taskBackupManager1.snapshot(newCheckpointId)).thenReturn(factory1Checkpoints);
+    when(taskBackupManager1.upload(newCheckpointId, factory1Checkpoints))
+        .thenReturn(CompletableFuture.completedFuture(factory1Checkpoints));
+    when(taskBackupManager2.snapshot(newCheckpointId)).thenReturn(factory2Checkpoints);
+    when(taskBackupManager2.upload(newCheckpointId, factory2Checkpoints))
+        .thenReturn(CompletableFuture.completedFuture(factory2Checkpoints));
+    when(mockLIStore.checkpoint(newCheckpointId)).thenReturn(Optional.empty());
+
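+    // snapshot() should flush every store but only create a native store checkpoint for logged, persistent stores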
+    cm.init();
+    cm.snapshot(newCheckpointId);
+
+    // Assert stores were flushed
+    verify(mockIStore).flush();
+    verify(mockPStore).flush();
+    verify(mockLIStore).flush();
+    verify(mockLPStore).flush();
+    // only logged and persisted stores are checkpointed
+    verify(mockLPStore).checkpoint(newCheckpointId);
+    // ensure that checkpoint is never called for non-logged persistent stores since they're
+    // always cleared on restart.
+    verify(mockPStore, never()).checkpoint(any());
+    // ensure that checkpoint is never called for non-persistent stores
+    verify(mockIStore, never()).checkpoint(any());
+    verify(mockLIStore, never()).checkpoint(any());
+    verify(checkpointTimer).update(anyLong());
+  }
+
+  @Test(expected = IllegalStateException.class)
+  public void testSnapshotFailsIfErrorCreatingCheckpoint() {
+    CheckpointManager checkpointManager = mock(CheckpointManager.class);
+    TaskBackupManager taskBackupManager1 = mock(TaskBackupManager.class);
+    TaskBackupManager taskBackupManager2 = mock(TaskBackupManager.class);
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+
+    StorageEngine mockLPStore = mock(StorageEngine.class);
+    StoreProperties lpStoreProps = mock(StoreProperties.class);
+    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
+    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(lpStoreProps.isDurableStore()).thenReturn(true);
+    when(mockLPStore.checkpoint(any())).thenThrow(new IllegalStateException());
+
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    TaskName taskName = new TaskName("task1");
+    Map<String, TaskBackupManager> backupManagers = ImmutableMap.of(
+        "factory1", taskBackupManager1,
+        "factory2", taskBackupManager2
+    );
+    Map<String, StorageEngine> storageEngines = ImmutableMap.of(
+        "storeLP", mockLPStore
+    );
+
+    TaskStorageCommitManager cm = new TaskStorageCommitManager(taskName, backupManagers, containerStorageManager,
+        Collections.emptyMap(), new Partition(1), checkpointManager, new MapConfig(),
+        ForkJoinPool.commonPool(), new StorageManagerUtil(), null, metrics);
+
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(storageEngines);
+    CheckpointId newCheckpointId = CheckpointId.create();
+    cm.init();
+    cm.snapshot(newCheckpointId);
+
+    // Assert stores were flushed
+    verify(mockLPStore).flush();
+    // only logged and persisted stores are checkpointed
+    verify(mockLPStore).checkpoint(newCheckpointId);
+    verify(taskBackupManager1, never()).snapshot(any());
+    verify(taskBackupManager2, never()).snapshot(any());
+    verify(taskBackupManager1, never()).upload(any(), any());
+    verify(taskBackupManager2, never()).upload(any(), any());
+    fail("Should have thrown an exception when the storageEngine#checkpoint did not succeed");
+  }
+
+  @Test
+  public void testCleanupAllBackupManagers() {
+    CheckpointManager checkpointManager = mock(CheckpointManager.class);
+    TaskBackupManager taskBackupManager1 = mock(TaskBackupManager.class);
+    TaskBackupManager taskBackupManager2 = mock(TaskBackupManager.class);
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    Checkpoint checkpoint = mock(Checkpoint.class);
+    File durableStoreDir = mock(File.class);
+    when(durableStoreDir.listFiles()).thenReturn(new File[0]);
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    TaskName taskName = new TaskName("task1");
+    Map<String, TaskBackupManager> backupManagers = ImmutableMap.of(
+        "factory1", taskBackupManager1,
+        "factory2", taskBackupManager2
+    );
+    TaskStorageCommitManager cm = new TaskStorageCommitManager(taskName, backupManagers, containerStorageManager,
+        Collections.emptyMap(),  new Partition(1), checkpointManager, new MapConfig(),
+        ForkJoinPool.commonPool(), new StorageManagerUtil(), durableStoreDir, metrics);
+    when(checkpointManager.readLastCheckpoint(taskName)).thenReturn(checkpoint);
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(Collections.emptyMap());
+    when(taskBackupManager1.cleanUp(any(), any())).thenReturn(CompletableFuture.<Void>completedFuture(null));
+    when(taskBackupManager2.cleanUp(any(), any())).thenReturn(CompletableFuture.<Void>completedFuture(null));
+    Map<String, String> factory1Checkpoints = ImmutableMap.of(
+        "store1", "system;stream;1",
+        "store2", "system;stream;2"
+    );
+    Map<String, String> factory2Checkpoints = ImmutableMap.of(
+        "store1", "blobId1",
+        "store2", "blobId2"
+    );
+    Map<String, Map<String, String>> factoryCheckpointsMap = ImmutableMap.of(
+        "factory1", factory1Checkpoints,
+        "factory2", factory2Checkpoints
+    );
+
+    when(taskBackupManager1.cleanUp(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
+    when(taskBackupManager2.cleanUp(any(), any())).thenReturn(CompletableFuture.completedFuture(null));
+
+    CheckpointId newCheckpointId = CheckpointId.create();
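+    // cleanUp() should fan out each factory's uploaded checkpoint map to its backup manager and return a
+    // future that completes once every backup manager has finished cleaning up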
+    cm.cleanUp(newCheckpointId, factoryCheckpointsMap).join();
+
+    verify(taskBackupManager1).cleanUp(newCheckpointId, factory1Checkpoints);
+    verify(taskBackupManager2).cleanUp(newCheckpointId, factory2Checkpoints);
+  }
+
+  @Test
+  public void testCleanupFailsIfBackupManagerNotInitiated() {
+    CheckpointManager checkpointManager = mock(CheckpointManager.class);
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    Checkpoint checkpoint = mock(Checkpoint.class);
+    File durableStoreDir = mock(File.class);
+    when(durableStoreDir.listFiles()).thenReturn(new File[0]);
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    TaskName taskName = new TaskName("task1");
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(Collections.emptyMap());
+    TaskStorageCommitManager cm = new TaskStorageCommitManager(taskName, Collections.emptyMap(), containerStorageManager,
+        Collections.emptyMap(),  new Partition(1), checkpointManager, new MapConfig(),
+        ForkJoinPool.commonPool(), new StorageManagerUtil(), durableStoreDir, metrics);
+    when(checkpointManager.readLastCheckpoint(taskName)).thenReturn(checkpoint);
+
+    Map<String, Map<String, String>> factoryCheckpointsMap = ImmutableMap.of(
+        "factory3", Collections.emptyMap() // factory 3 should be ignored
+    );
+
+    CheckpointId newCheckpointId = CheckpointId.create();
+    cm.cleanUp(newCheckpointId, factoryCheckpointsMap);
+    // should not fail the commit: the job should ignore checkpoints for any factory that was not initialized,
+    // e.g. when the user is migrating from one state backend to another
+  }
+
+  @Test
+  public void testPersistToFileSystemCheckpointV1AndV2Checkpoint() throws IOException {
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    StorageEngine mockLPStore = mock(StorageEngine.class);
+    StoreProperties lpStoreProps = mock(StoreProperties.class);
+    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
+    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(lpStoreProps.isDurableStore()).thenReturn(true);
+    Path mockPath = mock(Path.class);
+    when(mockLPStore.checkpoint(any())).thenReturn(Optional.of(mockPath));
+
+    StorageEngine mockPStore = mock(StorageEngine.class);
+    StoreProperties pStoreProps = mock(StoreProperties.class);
+    when(mockPStore.getStoreProperties()).thenReturn(pStoreProps);
+    when(pStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(pStoreProps.isDurableStore()).thenReturn(false);
+
+    StorageEngine mockLIStore = mock(StorageEngine.class);
+    StoreProperties liStoreProps = mock(StoreProperties.class);
+    when(mockLIStore.getStoreProperties()).thenReturn(liStoreProps);
+    when(liStoreProps.isPersistedToDisk()).thenReturn(false);
+    when(liStoreProps.isDurableStore()).thenReturn(true);
+
+    StorageEngine mockIStore = mock(StorageEngine.class);
+    StoreProperties iStoreProps = mock(StoreProperties.class);
+    when(mockIStore.getStoreProperties()).thenReturn(iStoreProps);
+    when(iStoreProps.isPersistedToDisk()).thenReturn(false);
+    when(iStoreProps.isDurableStore()).thenReturn(false);
+
+    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of(
+        "loggedPersistentStore", mockLPStore,
+        "persistentStore", mockPStore,
+        "loggedInMemStore", mockLIStore,
+        "inMemStore", mockIStore
+    );
+
+    Partition changelogPartition = new Partition(0);
+    SystemStream changelogSystemStream = new SystemStream("changelogSystem", "changelogStream");
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+    java.util.Map<String, SystemStream> storeChangelogsStreams = ImmutableMap.of(
+        "loggedPersistentStore", changelogSystemStream,
+        "loggedInMemStore", new SystemStream("system", "stream")
+    );
+
+    StorageManagerUtil storageManagerUtil = mock(StorageManagerUtil.class);
+    File durableStoreDir = new File("durableStorePath");
+    when(storageManagerUtil.getTaskStoreDir(eq(durableStoreDir), any(), any(), any())).thenReturn(durableStoreDir);
+    TaskName taskName = new TaskName("task");
+
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(taskStores);
+    TaskStorageCommitManager commitManager = spy(new TaskStorageCommitManager(taskName,
+        Collections.emptyMap(), containerStorageManager, storeChangelogsStreams, changelogPartition,
+        null, null, ForkJoinPool.commonPool(), storageManagerUtil, durableStoreDir, metrics));
+    doNothing().when(commitManager).writeChangelogOffsetFile(any(), any(), any(), any());
+
+    CheckpointId newCheckpointId = CheckpointId.create();
+
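+    // a CheckpointV1 carries KafkaChangelogSSPOffset values, which encode both the checkpoint id and the
+    // newest changelog offset for each changelog SSP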
+    String newestOffset = "1";
+    KafkaChangelogSSPOffset kafkaChangelogSSPOffset = new KafkaChangelogSSPOffset(newCheckpointId, newestOffset);
+    java.util.Map<SystemStreamPartition, String> offsetsJava = ImmutableMap.of(
+        changelogSSP, kafkaChangelogSSPOffset.toString()
+    );
+
+    commitManager.init();
+    // invoke persist to file system for v1 checkpoint
+    commitManager.writeCheckpointToStoreDirectories(new CheckpointV1(offsetsJava));
+
+    verify(commitManager).writeChangelogOffsetFiles(offsetsJava);
+    // invoked twice, for OFFSET-V1 and OFFSET-V2
+    verify(commitManager).writeChangelogOffsetFile(
+        eq("loggedPersistentStore"), eq(changelogSSP), eq(newestOffset), eq(durableStoreDir));
+    File checkpointFile = Paths.get(StorageManagerUtil
+        .getCheckpointDirPath(durableStoreDir, kafkaChangelogSSPOffset.getCheckpointId())).toFile();
+    verify(commitManager).writeChangelogOffsetFile(
+        eq("loggedPersistentStore"), eq(changelogSSP), eq(newestOffset), eq(checkpointFile));
+
+    java.util.Map<String, String> storeSCM = ImmutableMap.of(
+        "loggedPersistentStore", "system;loggedPersistentStoreStream;1",
+        "persistentStore", "system;persistentStoreStream;1",
+        "loggedInMemStore", "system;loggedInMemStoreStream;1",
+        "inMemStore", "system;inMemStoreStream;1"
+    );
+    CheckpointV2 checkpoint = new CheckpointV2(newCheckpointId, Collections.emptyMap(), Collections.singletonMap("factory", storeSCM));
+
+    // invoke persist to file system for v2 checkpoint
+    commitManager.writeCheckpointToStoreDirectories(checkpoint);
+    // Validate only durable and persisted stores are persisted
+    // This should be invoked twice, for checkpointV1 and checkpointV2
+    verify(storageManagerUtil, times(2)).getTaskStoreDir(eq(durableStoreDir), eq("loggedPersistentStore"), eq(taskName), any());
+    File checkpointPath = Paths.get(StorageManagerUtil.getCheckpointDirPath(durableStoreDir, newCheckpointId)).toFile();
+    verify(storageManagerUtil).writeCheckpointV2File(eq(checkpointPath), eq(checkpoint));
+  }
+
+  @Test
+  public void testPersistToFileSystemCheckpointV2Only() throws IOException {
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    StorageEngine mockLPStore = mock(StorageEngine.class);
+    StoreProperties lpStoreProps = mock(StoreProperties.class);
+    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
+    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(lpStoreProps.isDurableStore()).thenReturn(true);
+    Path mockPath = mock(Path.class);
+    when(mockLPStore.checkpoint(any())).thenReturn(Optional.of(mockPath));
+
+    StorageEngine mockPStore = mock(StorageEngine.class);
+    StoreProperties pStoreProps = mock(StoreProperties.class);
+    when(mockPStore.getStoreProperties()).thenReturn(pStoreProps);
+    when(pStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(pStoreProps.isDurableStore()).thenReturn(false);
+
+    StorageEngine mockLIStore = mock(StorageEngine.class);
+    StoreProperties liStoreProps = mock(StoreProperties.class);
+    when(mockLIStore.getStoreProperties()).thenReturn(liStoreProps);
+    when(liStoreProps.isPersistedToDisk()).thenReturn(false);
+    when(liStoreProps.isDurableStore()).thenReturn(true);
+
+    StorageEngine mockIStore = mock(StorageEngine.class);
+    StoreProperties iStoreProps = mock(StoreProperties.class);
+    when(mockIStore.getStoreProperties()).thenReturn(iStoreProps);
+    when(iStoreProps.isPersistedToDisk()).thenReturn(false);
+    when(iStoreProps.isDurableStore()).thenReturn(false);
+
+    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of(
+        "loggedPersistentStore", mockLPStore,
+        "persistentStore", mockPStore,
+        "loggedInMemStore", mockLIStore,
+        "inMemStore", mockIStore
+    );
+
+    Partition changelogPartition = new Partition(0);
+    SystemStream changelogSystemStream = new SystemStream("changelogSystem", "changelogStream");
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+    java.util.Map<String, SystemStream> storeChangelogsStreams = ImmutableMap.of(
+        "loggedPersistentStore", changelogSystemStream,
+        "loggedInMemStore", new SystemStream("system", "stream")
+    );
+
+    StorageManagerUtil storageManagerUtil = mock(StorageManagerUtil.class);
+    File durableStoreDir = new File("durableStorePath");
+    when(storageManagerUtil.getTaskStoreDir(eq(durableStoreDir), eq("loggedPersistentStore"), any(), any()))
+        .thenReturn(durableStoreDir);
+    TaskName taskName = new TaskName("task");
+
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(taskStores);
+    TaskStorageCommitManager commitManager = spy(new TaskStorageCommitManager(taskName,
+        Collections.emptyMap(), containerStorageManager, storeChangelogsStreams, changelogPartition,
+        null, null, ForkJoinPool.commonPool(), storageManagerUtil, durableStoreDir, metrics));
+    doNothing().when(commitManager).writeChangelogOffsetFile(any(), any(), any(), any());
+
+    CheckpointId newCheckpointId = CheckpointId.create();
+
+    java.util.Map<String, String> storeSCM = ImmutableMap.of(
+        "loggedPersistentStore", "system;loggedPersistentStoreStream;1",
+        "persistentStore", "system;persistentStoreStream;1",
+        "loggedInMemStore", "system;loggedInMemStoreStream;1",
+        "inMemStore", "system;inMemStoreStream;1"
+    );
+    CheckpointV2 checkpoint = new CheckpointV2(newCheckpointId, Collections.emptyMap(), Collections.singletonMap("factory", storeSCM));
+
+    commitManager.init();
+    // invoke persist to file system
+    commitManager.writeCheckpointToStoreDirectories(checkpoint);
+    // Validate only durable and persisted stores are persisted
+    verify(storageManagerUtil).getTaskStoreDir(eq(durableStoreDir), eq("loggedPersistentStore"), eq(taskName), any());
+    File checkpointPath = Paths.get(StorageManagerUtil.getCheckpointDirPath(durableStoreDir, newCheckpointId)).toFile();
+    verify(storageManagerUtil).writeCheckpointV2File(eq(checkpointPath), eq(checkpoint));
+  }
+
+  @Test
+  public void testWriteChangelogOffsetFilesV1() throws IOException {
+    Map<String, Map<SystemStreamPartition, String>> mockFileSystem = new HashMap<>();
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    StorageEngine mockLPStore = mock(StorageEngine.class);
+    StoreProperties lpStoreProps = mock(StoreProperties.class);
+    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
+    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(lpStoreProps.isDurableStore()).thenReturn(true);
+    Path mockPath = mock(Path.class);
+    when(mockLPStore.checkpoint(any())).thenReturn(Optional.of(mockPath));
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of("loggedPersistentStore", mockLPStore);
+
+    Partition changelogPartition = new Partition(0);
+    SystemStream changelogSystemStream = new SystemStream("changelogSystem", "changelogStream");
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+    java.util.Map<String, SystemStream> storeChangelogsStreams = ImmutableMap.of("loggedPersistentStore", changelogSystemStream);
+
+    StorageManagerUtil storageManagerUtil = mock(StorageManagerUtil.class);
+    File tmpTestPath = new File("store-checkpoint-test");
+    when(storageManagerUtil.getTaskStoreDir(eq(tmpTestPath), eq("loggedPersistentStore"), any(), any())).thenReturn(tmpTestPath);
+    TaskName taskName = new TaskName("task");
+
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(taskStores);
+    TaskStorageCommitManager commitManager = spy(new TaskStorageCommitManager(taskName,
+        Collections.emptyMap(), containerStorageManager, storeChangelogsStreams, changelogPartition,
+        null, null, ForkJoinPool.commonPool(), storageManagerUtil, tmpTestPath, metrics));
+
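+    // redirect writeChangelogOffsetFile calls into an in-memory map keyed by target directory so the test
+    // can assert on the written offsets without touching the real file system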
+    doAnswer(invocation -> {
+      String fileDir = invocation.getArgumentAt(3, File.class).getName();
+      SystemStreamPartition ssp = invocation.getArgumentAt(1, SystemStreamPartition.class);
+      String offset = invocation.getArgumentAt(2, String.class);
+      if (mockFileSystem.containsKey(fileDir)) {
+        mockFileSystem.get(fileDir).put(ssp, offset);
+      } else {
+        Map<SystemStreamPartition, String> sspOffsets = new HashMap<>();
+        sspOffsets.put(ssp, offset);
+        mockFileSystem.put(fileDir, sspOffsets);
+      }
+      return null;
+    }).when(commitManager).writeChangelogOffsetFile(any(), any(), any(), any());
+
+    CheckpointId newCheckpointId = CheckpointId.create();
+
+    String newestOffset = "1";
+    KafkaChangelogSSPOffset kafkaChangelogSSPOffset = new KafkaChangelogSSPOffset(newCheckpointId, newestOffset);
+    java.util.Map<SystemStreamPartition, String> offsetsJava = ImmutableMap.of(
+        changelogSSP, kafkaChangelogSSPOffset.toString()
+    );
+
+    commitManager.init();
+    // invoke persist to file system for v1 checkpoint
+    commitManager.writeCheckpointToStoreDirectories(new CheckpointV1(offsetsJava));
+
+    assertEquals(2, mockFileSystem.size());
+    // check if v2 offsets are written correctly
+    String v2FilePath = StorageManagerUtil
+        .getCheckpointDirPath(tmpTestPath, newCheckpointId);
+    assertTrue(mockFileSystem.containsKey(v2FilePath));
+    assertTrue(mockFileSystem.get(v2FilePath).containsKey(changelogSSP));
+    assertEquals(1, mockFileSystem.get(v2FilePath).size());
+    assertEquals(newestOffset, mockFileSystem.get(v2FilePath).get(changelogSSP));
+    // check if v1 offsets are written correctly
+    String v1FilePath = tmpTestPath.getPath();
+    assertTrue(mockFileSystem.containsKey(v1FilePath));
+    assertTrue(mockFileSystem.get(v1FilePath).containsKey(changelogSSP));
+    assertEquals(1, mockFileSystem.get(v1FilePath).size());
+    assertEquals(newestOffset, mockFileSystem.get(v1FilePath).get(changelogSSP));
+  }
+
+  @Test
+  public void testWriteChangelogOffsetFilesV2andV1() throws IOException {
+    Map<String, Map<SystemStreamPartition, String>> mockFileSystem = new HashMap<>();
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    Map<String, CheckpointV2> mockCheckpointFileSystem = new HashMap<>();
+    StorageEngine mockLPStore = mock(StorageEngine.class);
+    StoreProperties lpStoreProps = mock(StoreProperties.class);
+    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
+    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(lpStoreProps.isDurableStore()).thenReturn(true);
+    Path mockPath = mock(Path.class);
+    when(mockLPStore.checkpoint(any())).thenReturn(Optional.of(mockPath));
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of("loggedPersistentStore", mockLPStore);
+
+    Partition changelogPartition = new Partition(0);
+    SystemStream changelogSystemStream = new SystemStream("changelogSystem", "changelogStream");
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+    java.util.Map<String, SystemStream> storeChangelogsStreams = ImmutableMap.of("loggedPersistentStore", changelogSystemStream);
+
+    StorageManagerUtil storageManagerUtil = mock(StorageManagerUtil.class);
+    File tmpTestPath = new File("store-checkpoint-test");
+    when(storageManagerUtil.getTaskStoreDir(eq(tmpTestPath), eq("loggedPersistentStore"), any(), any())).thenReturn(tmpTestPath);
+    TaskName taskName = new TaskName("task");
+
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(taskStores);
+    TaskStorageCommitManager commitManager = spy(new TaskStorageCommitManager(taskName,
+        Collections.emptyMap(), containerStorageManager, storeChangelogsStreams, changelogPartition,
+        null, null, ForkJoinPool.commonPool(), storageManagerUtil, tmpTestPath, metrics));
+
+    doAnswer(invocation -> {
+      String fileDir = invocation.getArgumentAt(3, File.class).getName();
+      SystemStreamPartition ssp = invocation.getArgumentAt(1, SystemStreamPartition.class);
+      String offset = invocation.getArgumentAt(2, String.class);
+      if (mockFileSystem.containsKey(fileDir)) {
+        mockFileSystem.get(fileDir).put(ssp, offset);
+      } else {
+        Map<SystemStreamPartition, String> sspOffsets = new HashMap<>();
+        sspOffsets.put(ssp, offset);
+        mockFileSystem.put(fileDir, sspOffsets);
+      }
+      return null;
+    }).when(commitManager).writeChangelogOffsetFile(any(), any(), any(), any());
+
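+    // similarly capture writeCheckpointV2File calls so the CheckpointV2 persisted to each directory can be asserted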
+    doAnswer(invocation -> {
+      String storeDir = invocation.getArgumentAt(0, File.class).getName();
+      CheckpointV2 checkpointV2 = invocation.getArgumentAt(1, CheckpointV2.class);
+      mockCheckpointFileSystem.put(storeDir, checkpointV2);
+      return null;
+    }).when(storageManagerUtil).writeCheckpointV2File(any(), any());
+
+    CheckpointId newCheckpointId = CheckpointId.create();
+
+    String newestOffset = "1";
+    KafkaChangelogSSPOffset kafkaChangelogSSPOffset = new KafkaChangelogSSPOffset(newCheckpointId, newestOffset);
+    java.util.Map<SystemStreamPartition, String> offsetsJava = ImmutableMap.of(
+        changelogSSP, kafkaChangelogSSPOffset.toString()
+    );
+
+    commitManager.init();
+    // invoke persist to file system for v1 checkpoint
+    commitManager.writeCheckpointToStoreDirectories(new CheckpointV1(offsetsJava));
+
+    assertEquals(2, mockFileSystem.size());
+    // check if v2 offsets are written correctly
+    String v2FilePath = StorageManagerUtil
+        .getCheckpointDirPath(tmpTestPath, newCheckpointId);
+    assertTrue(mockFileSystem.containsKey(v2FilePath));
+    assertTrue(mockFileSystem.get(v2FilePath).containsKey(changelogSSP));
+    assertEquals(1, mockFileSystem.get(v2FilePath).size());
+    assertEquals(newestOffset, mockFileSystem.get(v2FilePath).get(changelogSSP));
+    // check if v1 offsets are written correctly
+    String v1FilePath = tmpTestPath.getPath();
+    assertTrue(mockFileSystem.containsKey(v1FilePath));
+    assertTrue(mockFileSystem.get(v1FilePath).containsKey(changelogSSP));
+    assertEquals(1, mockFileSystem.get(v1FilePath).size());
+    assertEquals(newestOffset, mockFileSystem.get(v1FilePath).get(changelogSSP));
+
+    java.util.Map<String, String> storeSCM = ImmutableMap.of(
+        "loggedPersistentStore", "system;loggedPersistentStoreStream;1",
+        "persistentStore", "system;persistentStoreStream;1",
+        "loggedInMemStore", "system;loggedInMemStoreStream;1",
+        "inMemStore", "system;inMemStoreStream;1"
+    );
+    CheckpointV2 checkpoint = new CheckpointV2(newCheckpointId, Collections.emptyMap(), Collections.singletonMap("factory", storeSCM));
+
+    // invoke persist to file system with checkpoint v2
+    commitManager.writeCheckpointToStoreDirectories(checkpoint);
+
+    assertTrue(mockCheckpointFileSystem.containsKey(v2FilePath));
+    assertEquals(checkpoint, mockCheckpointFileSystem.get(v2FilePath));
+    assertTrue(mockCheckpointFileSystem.containsKey(v1FilePath));
+    assertEquals(checkpoint, mockCheckpointFileSystem.get(v1FilePath));
+    assertEquals(2, mockCheckpointFileSystem.size());
+
+    CheckpointV2 updatedCheckpoint = new CheckpointV2(
+        newCheckpointId, ImmutableMap.of(
+        new SystemStreamPartition("inputSystem", "inputStream", changelogPartition), "5"),
+        Collections.singletonMap("factory", storeSCM));
+    commitManager.writeCheckpointToStoreDirectories(updatedCheckpoint);
+
+    assertEquals(updatedCheckpoint, mockCheckpointFileSystem.get(v2FilePath));
+    assertEquals(updatedCheckpoint, mockCheckpointFileSystem.get(v1FilePath));
+    assertEquals(2, mockCheckpointFileSystem.size());
+  }
+
+  @Test
+  public void testWriteChangelogOffsetFilesWithEmptyChangelogTopic() throws IOException {
+    Map<String, Map<SystemStreamPartition, String>> mockFileSystem = new HashMap<>();
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    StorageEngine mockLPStore = mock(StorageEngine.class);
+    StoreProperties lpStoreProps = mock(StoreProperties.class);
+    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
+    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(lpStoreProps.isDurableStore()).thenReturn(true);
+    Path mockPath = mock(Path.class);
+    when(mockLPStore.checkpoint(any())).thenReturn(Optional.of(mockPath));
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of("loggedPersistentStore", mockLPStore);
+
+    Partition changelogPartition = new Partition(0);
+    SystemStream changelogSystemStream = new SystemStream("changelogSystem", "changelogStream");
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+    java.util.Map<String, SystemStream> storeChangelogsStreams = ImmutableMap.of("loggedPersistentStore", changelogSystemStream);
+
+    StorageManagerUtil storageManagerUtil = mock(StorageManagerUtil.class);
+    File tmpTestPath = new File("store-checkpoint-test");
+    when(storageManagerUtil.getTaskStoreDir(eq(tmpTestPath), any(), any(), any())).thenReturn(tmpTestPath);
+    TaskName taskName = new TaskName("task");
+
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(taskStores);
+    TaskStorageCommitManager commitManager = spy(new TaskStorageCommitManager(taskName,
+        Collections.emptyMap(), containerStorageManager, storeChangelogsStreams, changelogPartition,
+        null, null, ForkJoinPool.commonPool(), storageManagerUtil, tmpTestPath, metrics));
+
+    doAnswer(invocation -> {
+      String storeName = invocation.getArgumentAt(0, String.class);
+      String fileDir = invocation.getArgumentAt(3, File.class).getName();
+      String mockKey = storeName + fileDir;
+      SystemStreamPartition ssp = invocation.getArgumentAt(1, SystemStreamPartition.class);
+      String offset = invocation.getArgumentAt(2, String.class);
+      if (mockFileSystem.containsKey(mockKey)) {
+        mockFileSystem.get(mockKey).put(ssp, offset);
+      } else {
+        Map<SystemStreamPartition, String> sspOffsets = new HashMap<>();
+        sspOffsets.put(ssp, offset);
+        mockFileSystem.put(mockKey, sspOffsets);
+      }
+      return null;
+    }).when(commitManager).writeChangelogOffsetFile(any(), any(), any(), any());
+
+    CheckpointId newCheckpointId = CheckpointId.create();
+
+    String newestOffset = null;
+    KafkaChangelogSSPOffset kafkaChangelogSSPOffset = new KafkaChangelogSSPOffset(newCheckpointId, newestOffset);
+    java.util.Map<SystemStreamPartition, String> offsetsJava = ImmutableMap.of(
+        changelogSSP, kafkaChangelogSSPOffset.toString()
+    );
+
+    commitManager.init();
+    // invoke persist to file system for v1 checkpoint
+    commitManager.writeCheckpointToStoreDirectories(new CheckpointV1(offsetsJava));
+    assertTrue(mockFileSystem.isEmpty());
+    // verify that delete was called on current store dir offset file
+    verify(storageManagerUtil, times(1)).deleteOffsetFile(eq(tmpTestPath));
+  }
+
+  @Test(expected = SamzaException.class)
+  public void testThrowOnWriteCheckpointDirIfUnsuccessful() {
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    StorageEngine mockLPStore = mock(StorageEngine.class);
+    StoreProperties lpStoreProps = mock(StoreProperties.class);
+    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
+    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
+    when(lpStoreProps.isDurableStore()).thenReturn(true);
+    Path mockPath = mock(Path.class);
+    when(mockLPStore.checkpoint(any())).thenReturn(Optional.of(mockPath));
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of("loggedPersistentStore", mockLPStore);
+
+    java.util.Map<String, SystemStream> storeChangelogsStreams = ImmutableMap.of("loggedPersistentStore", mock(SystemStream.class));
+
+    StorageManagerUtil storageManagerUtil = mock(StorageManagerUtil.class);
+    File tmpTestPath = new File("store-checkpoint-test");
+    when(storageManagerUtil.getTaskStoreDir(eq(tmpTestPath), eq("loggedPersistentStore"), any(), any())).thenReturn(tmpTestPath);
+
+    TaskName taskName = new TaskName("task");
+
+    TaskStorageCommitManager commitManager = spy(new TaskStorageCommitManager(taskName,
+        Collections.emptyMap(), containerStorageManager, storeChangelogsStreams, mock(Partition.class),
+        null, null, ForkJoinPool.commonPool(), storageManagerUtil, tmpTestPath, metrics));
+
+    java.util.Map<String, String> storeSCM = ImmutableMap.of(
+        "loggedPersistentStore", "system;loggedPersistentStoreStream;1",
+        "persistentStore", "system;persistentStoreStream;1",
+        "loggedInMemStore", "system;loggedInMemStoreStream;1",
+        "inMemStore", "system;inMemStoreStream;1"
+    );
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(taskStores);
+    CheckpointV2 checkpoint = new CheckpointV2(CheckpointId.create(), Collections.emptyMap(), Collections.singletonMap("factory", storeSCM));
+    doThrow(IOException.class).when(storageManagerUtil).writeCheckpointV2File(eq(tmpTestPath), eq(checkpoint));
+
+    commitManager.init();
+    // Should throw samza exception since writeCheckpointV2 failed
+    commitManager.writeCheckpointToStoreDirectories(checkpoint);
+  }
+
+  @Test
+  public void testRemoveOldCheckpointsWhenBaseDirContainsRegularFiles() {
+    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
+    CheckpointManager checkpointManager = mock(CheckpointManager.class);
+    TaskBackupManager taskBackupManager1 = mock(TaskBackupManager.class);
+    TaskBackupManager taskBackupManager2 = mock(TaskBackupManager.class);
+    File durableStoreDir = mock(File.class);
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    StorageManagerUtil storageManagerUtil = mock(StorageManagerUtil.class);
+
+    TaskName taskName = new TaskName("task1");
+    Map<String, TaskBackupManager> backupManagers = ImmutableMap.of(
+        "factory1", taskBackupManager1,
+        "factory2", taskBackupManager2
+    );
+
+    when(containerStorageManager.getAllStores(taskName)).thenReturn(Collections.emptyMap());
+    TaskStorageCommitManager cm = new TaskStorageCommitManager(taskName, backupManagers, containerStorageManager,
+        Collections.emptyMap(), new Partition(1), checkpointManager, new MapConfig(),
+        ForkJoinPool.commonPool(), storageManagerUtil, durableStoreDir, metrics);
+
+
+    File mockStoreDir = mock(File.class);
+    String mockStoreDirName = "notDirectory";
+    when(durableStoreDir.listFiles()).thenReturn(new File[] {mockStoreDir});
+    when(mockStoreDir.getName()).thenReturn(mockStoreDirName);
+    when(storageManagerUtil.getTaskStoreDir(eq(durableStoreDir), eq(mockStoreDirName), eq(taskName), eq(TaskMode.Active))).thenReturn(mockStoreDir);
+    // null here can happen if listFiles is called on a non-directory
+    when(mockStoreDir.listFiles(any(FileFilter.class))).thenReturn(null);
+
+    cm.cleanUp(CheckpointId.create(), new HashMap<>()).join();
+    verify(durableStoreDir).listFiles();
+    verify(mockStoreDir).listFiles(any(FileFilter.class));
+    verify(storageManagerUtil).getTaskStoreDir(eq(durableStoreDir), eq(mockStoreDirName), eq(taskName), eq(TaskMode.Active));
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/util/TestFutureUtil.java b/samza-core/src/test/java/org/apache/samza/util/TestFutureUtil.java
new file mode 100644
index 0000000..815eb34
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/util/TestFutureUtil.java
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.util;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.CompletionStage;
+import java.util.function.Predicate;
+import org.apache.samza.SamzaException;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+public class TestFutureUtil {
+
+  /**
+   * Tests that all futures in all collections complete before allOf completes.
+   * Tests that the combined future completes exceptionally if any future completes exceptionally.
+   * Tests that it works with heterogeneous value types.
+   * Tests that it works with heterogeneous collection types.
+   * Tests that it works with CompletionStages as well as CompletableFutures.
+   */
+  @Test
+  public void testAllOf() {
+    // verify that there is no short circuiting
+    CompletableFuture<String> future1 = new CompletableFuture<>();
+    CompletableFuture<String> future2 = new CompletableFuture<>();
+    CompletableFuture<String> future3 = new CompletableFuture<>();
+    CompletableFuture<Integer> future4 = new CompletableFuture<>();
+    ImmutableList<CompletableFuture<?>> collection1 =
+        ImmutableList.of(future1, future2);
+    ImmutableSet<CompletionStage<?>> collection2 =
+        ImmutableSet.of(future3, future4);
+
+    CompletableFuture<Void> allFuture = FutureUtil.allOf(collection1, collection2);
+    future1.complete("1");
+    assertFalse(allFuture.isDone());
+    RuntimeException ex2 = new RuntimeException("2");
+    future2.completeExceptionally(ex2);
+    assertFalse(allFuture.isDone());
+    assertFalse(allFuture.isCompletedExceptionally());
+    future3.complete("3");
+    assertFalse(allFuture.isDone());
+    assertFalse(allFuture.isCompletedExceptionally());
+    future4.complete(4);
+    assertTrue(allFuture.isDone());
+    assertTrue(allFuture.isCompletedExceptionally());
+
+    try {
+      allFuture.join();
+    } catch (Exception e) {
+      assertEquals(ex2, FutureUtil.unwrapExceptions(CompletionException.class, e));
+    }
+  }
+
+  @Test
+  public void testAllOfIgnoringErrorsCompletesSuccessfullyIfNoErrors() {
+    CompletableFuture<String> future1 = new CompletableFuture<>();
+    CompletableFuture<String> future2 = new CompletableFuture<>();
+
+    CompletableFuture<Void> allFuture = FutureUtil.allOf(t -> false, future1, future2);
+    future1.complete("1");
+    assertFalse(allFuture.isDone());
+    future2.complete("2");
+    assertTrue(allFuture.isDone());
+    assertFalse(allFuture.isCompletedExceptionally());
+  }
+
+  @Test
+  public void testAllOfIgnoringErrorsCompletesSuccessfullyIfOnlyIgnoredErrors() {
+    CompletableFuture<String> future1 = new CompletableFuture<>();
+    CompletableFuture<String> future2 = new CompletableFuture<>();
+
+    CompletableFuture<Void> allFuture = FutureUtil.allOf(t -> true, future1, future2);
+    future1.complete("1");
+    assertFalse(allFuture.isDone());
+    RuntimeException ex2 = new RuntimeException("2");
+    future2.completeExceptionally(ex2);
+    assertTrue(allFuture.isDone());
+    assertFalse(allFuture.isCompletedExceptionally());
+  }
+
+  @Test
+  public void testAllOfIgnoringErrorsCompletesExceptionallyIfNonIgnoredErrors() {
+    // also test that each future is checked individually
+    CompletableFuture<String> future1 = new CompletableFuture<>();
+    CompletableFuture<String> future2 = new CompletableFuture<>();
+
+    Predicate<Throwable> mockPredicate = mock(Predicate.class);
+    when(mockPredicate.test(any()))
+        .thenReturn(true)
+        .thenReturn(false);
+    CompletableFuture<Void> allFuture = FutureUtil.allOf(mockPredicate, future1, future2);
+    future1.completeExceptionally(new SamzaException());
+    assertFalse(allFuture.isDone());
+    RuntimeException ex2 = new RuntimeException("2");
+    future2.completeExceptionally(ex2);
+    assertTrue(allFuture.isDone());
+    assertTrue(allFuture.isCompletedExceptionally());
+    verify(mockPredicate, times(2)).test(any());
+  }
+
+  @Test
+  public void testFutureOfMapCompletesExceptionallyIfAValueFutureCompletesExceptionally() {
+    Map<String, CompletableFuture<String>> map = new HashMap<>();
+    map.put("1", CompletableFuture.completedFuture("1"));
+    map.put("2", FutureUtil.failedFuture(new SamzaException()));
+
+    assertTrue(FutureUtil.toFutureOfMap(map).isCompletedExceptionally());
+  }
+
+  @Test
+  public void testFutureOfMapCompletesSuccessfullyIfNoErrors() {
+    Map<String, CompletableFuture<String>> map = new HashMap<>();
+    map.put("1", CompletableFuture.completedFuture("1"));
+    map.put("2", CompletableFuture.completedFuture("2"));
+
+    CompletableFuture<Map<String, String>> result = FutureUtil.toFutureOfMap(t -> true, map);
+    assertTrue(result.isDone());
+    assertFalse(result.isCompletedExceptionally());
+  }
+
+  @Test
+  public void testFutureOfMapCompletesSuccessfullyIfOnlyIgnoredErrors() {
+    Map<String, CompletableFuture<String>> map = new HashMap<>();
+    map.put("1", CompletableFuture.completedFuture("1"));
+    map.put("2", FutureUtil.failedFuture(new SamzaException()));
+
+    CompletableFuture<Map<String, String>> result = FutureUtil
+        .toFutureOfMap(t -> FutureUtil.unwrapExceptions(CompletionException.class, t) instanceof SamzaException, map);
+    assertTrue(result.isDone());
+    result.join();
+    assertFalse(result.isCompletedExceptionally());
+    assertEquals("1", result.join().get("1"));
+    assertFalse(result.join().containsKey("2"));
+  }
+
+  @Test
+  public void testFutureOfMapCompletesExceptionallyIfAnyNonIgnoredErrors() {
+    Map<String, CompletableFuture<String>> map = new HashMap<>();
+    map.put("1", FutureUtil.failedFuture(new RuntimeException()));
+    SamzaException samzaException = new SamzaException();
+    map.put("2", FutureUtil.failedFuture(samzaException));
+
+    Predicate<Throwable> mockPredicate = mock(Predicate.class);
+    when(mockPredicate.test(any()))
+        .thenReturn(true)
+        .thenReturn(false);
+
+    CompletableFuture<Map<String, String>> result = FutureUtil.toFutureOfMap(mockPredicate, map);
+    assertTrue(result.isDone());
+    assertTrue(result.isCompletedExceptionally());
+    verify(mockPredicate, times(2)).test(any()); // verify that each failed value future is tested
+
+    try {
+      result.join();
+      fail("Should have thrown an exception.");
+    } catch (Exception e) {
+      assertEquals(samzaException, FutureUtil.unwrapExceptions(CompletionException.class, e));
+    }
+  }
+
+  @Test
+  public void testUnwrapExceptionUnwrapsMultipleExceptions() {
+    IllegalArgumentException cause = new IllegalArgumentException();
+    Throwable t = new SamzaException(new SamzaException(cause));
+    Throwable unwrappedThrowable = FutureUtil.unwrapExceptions(SamzaException.class, t);
+    assertEquals(cause, unwrappedThrowable);
+  }
+
+  @Test
+  public void testUnwrapExceptionReturnsOriginalExceptionIfNoWrapper() {
+    IllegalArgumentException cause = new IllegalArgumentException();
+    Throwable unwrappedThrowable = FutureUtil.unwrapExceptions(SamzaException.class, cause);
+    assertEquals(cause, unwrappedThrowable);
+  }
+
+  @Test
+  public void testUnwrapExceptionReturnsNullIfNoNonWrapperCause() {
+    Throwable t = new SamzaException(new SamzaException());
+    Throwable unwrappedThrowable = FutureUtil.unwrapExceptions(SamzaException.class, t);
+    assertNull(unwrappedThrowable);
+  }
+
+  @Test
+  public void testUnwrapExceptionReturnsNullIfOriginalExceptionIsNull() {
+    Throwable unwrappedThrowable = FutureUtil.unwrapExceptions(SamzaException.class, null);
+    assertNull(unwrappedThrowable);
+  }
+}
+
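
The TestFutureUtil cases above exercise an allOf variant that tolerates failures matching a predicate. As a rough sketch only (the class and method names below are hypothetical, and this is not the FutureUtil implementation added by this patch), such a combinator can be layered on CompletableFuture.allOf by guarding each input future:

    import java.util.Arrays;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.CompletionException;
    import java.util.function.Predicate;

    // Hypothetical illustration only; not the FutureUtil implementation in this patch.
    final class AllOfIgnoringSketch {
      // Completes only after every input future completes (no short-circuiting), and
      // completes exceptionally only if some failure is not ignored by the predicate.
      static CompletableFuture<Void> allOf(Predicate<Throwable> ignoreError,
          CompletableFuture<?>... futures) {
        CompletableFuture<?>[] guarded = Arrays.stream(futures)
            .map(f -> f.handle((value, error) -> {
              if (error != null && !ignoreError.test(error)) {
                throw new CompletionException(error); // propagate non-ignored failures
              }
              return null; // swallow ignored failures (and discard successful values)
            }))
            .toArray(CompletableFuture[]::new);
        return CompletableFuture.allOf(guarded);
      }
    }

Each failed input is tested individually against the predicate, which is why the tests above verify that the mock predicate is invoked once per failed future.
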
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index 7d18033..a86c49f 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -309,40 +309,6 @@
   }
 
   @Test
-  def testGetChangelogSSPsForContainer() {
-    val taskName0 = new TaskName("task0")
-    val taskName1 = new TaskName("task1")
-    val taskModel0 = new TaskModel(taskName0,
-      Set(new SystemStreamPartition("input", "stream", new Partition(0))),
-      new Partition(10))
-    val taskModel1 = new TaskModel(taskName1,
-      Set(new SystemStreamPartition("input", "stream", new Partition(1))),
-      new Partition(11))
-    val containerModel = new ContainerModel("processorId", Map(taskName0 -> taskModel0, taskName1 -> taskModel1))
-    val changeLogSystemStreams = Map("store0" -> new SystemStream("changelogSystem0", "store0-changelog"),
-      "store1" -> new SystemStream("changelogSystem1", "store1-changelog"))
-    val expected = Set(new SystemStreamPartition("changelogSystem0", "store0-changelog", new Partition(10)),
-      new SystemStreamPartition("changelogSystem1", "store1-changelog", new Partition(10)),
-      new SystemStreamPartition("changelogSystem0", "store0-changelog", new Partition(11)),
-      new SystemStreamPartition("changelogSystem1", "store1-changelog", new Partition(11)))
-    assertEquals(expected, SamzaContainer.getChangelogSSPsForContainer(containerModel, changeLogSystemStreams))
-  }
-
-  @Test
-  def testGetChangelogSSPsForContainerNoChangelogs() {
-    val taskName0 = new TaskName("task0")
-    val taskName1 = new TaskName("task1")
-    val taskModel0 = new TaskModel(taskName0,
-      Set(new SystemStreamPartition("input", "stream", new Partition(0))),
-      new Partition(10))
-    val taskModel1 = new TaskModel(taskName1,
-      Set(new SystemStreamPartition("input", "stream", new Partition(1))),
-      new Partition(11))
-    val containerModel = new ContainerModel("processorId", Map(taskName0 -> taskModel0, taskName1 -> taskModel1))
-    assertEquals(Set(), SamzaContainer.getChangelogSSPsForContainer(containerModel, Map()))
-  }
-
-  @Test
   def testStoreContainerLocality():Unit = {
     this.config = new MapConfig(Map(ClusterManagerConfig.JOB_HOST_AFFINITY_ENABLED -> "true"))
     setupSamzaContainer(None) // re-init with an actual config
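
The TestTaskInstance changes below exercise a commit that is split into a synchronous snapshot stage and an asynchronous upload/checkpoint/cleanup stage. Purely as an orientation sketch (the names and types below are simplified stand-ins, not the actual TaskInstance or TaskStorageCommitManager code), the ordering the tests assert looks roughly like this:

    import java.util.Map;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ExecutorService;

    // Simplified stand-in for the commit phases exercised by the tests below; illustrative only.
    final class CommitLifecycleSketch {
      interface CommitManager {
        // synchronous: flush stores and capture per-backend state checkpoint markers (SCMs)
        Map<String, Map<String, String>> snapshot();
        // asynchronous: upload the snapshot to the state backends (a no-op for the Kafka changelog backend)
        CompletableFuture<Map<String, Map<String, String>>> upload(Map<String, Map<String, String>> scms);
        // write the checkpoint (both versions) to the store directories and the checkpoint topic
        void writeCheckpoints(Map<String, Map<String, String>> uploadedScms);
        // remove old checkpoint state once the new checkpoint is durable
        CompletableFuture<Void> cleanUp();
      }

      static CompletableFuture<Void> commit(CommitManager cm, ExecutorService commitExecutor) {
        // Synchronous stage: snapshot input offsets and local state on the caller thread.
        Map<String, Map<String, String>> scms = cm.snapshot();
        // Asynchronous stage: upload, checkpoint, and clean up off the processing path.
        return cm.upload(scms)
            .thenAcceptAsync(cm::writeCheckpoints, commitExecutor)
            .thenComposeAsync(ignored -> cm.cleanUp(), commitExecutor);
      }
    }

A failure in the synchronous stage fails the commit immediately, while a failure in the asynchronous stage is rethrown by the next commit; and while a previous asynchronous stage is still in progress, a new commit is skipped within task.commit.max.delay.ms and otherwise blocks (failing if the block time is also exceeded). These are the behaviors the skip/throw tests below cover.
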
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
index 4cab185..606f86d 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestTaskInstance.scala
@@ -19,19 +19,20 @@
 
 package org.apache.samza.container
 
-import java.util.Collections
-
-import com.google.common.collect.ImmutableSet
-import org.apache.samza.{Partition, SamzaException}
-import org.apache.samza.checkpoint.{Checkpoint, CheckpointedChangelogOffset, OffsetManager}
+import com.google.common.collect.{ImmutableMap, ImmutableSet}
+import com.google.common.util.concurrent.MoreExecutors
+import org.apache.samza.checkpoint._
+import org.apache.samza.checkpoint.kafka.{KafkaChangelogSSPOffset, KafkaStateCheckpointMarker}
 import org.apache.samza.config.MapConfig
 import org.apache.samza.context.{TaskContext => _, _}
 import org.apache.samza.job.model.TaskModel
-import org.apache.samza.metrics.Counter
-import org.apache.samza.storage.NonTransactionalStateTaskStorageManager
+import org.apache.samza.metrics.{Counter, Gauge, Timer}
+import org.apache.samza.storage.TaskStorageCommitManager
 import org.apache.samza.system.{IncomingMessageEnvelope, StreamMetadataCache, SystemAdmin, SystemConsumers, SystemStream, SystemStreamMetadata, _}
 import org.apache.samza.table.TableManager
 import org.apache.samza.task._
+import org.apache.samza.util.FutureUtil
+import org.apache.samza.{Partition, SamzaException}
 import org.junit.Assert._
 import org.junit.{Before, Test}
 import org.mockito.Matchers._
@@ -42,6 +43,10 @@
 import org.scalatest.junit.AssertionsForJUnit
 import org.scalatest.mockito.MockitoSugar
 
+import java.util
+import java.util.Collections
+import java.util.concurrent.{CompletableFuture, ExecutorService, Executors, ForkJoinPool}
+import java.util.function.Consumer
 import scala.collection.JavaConverters._
 
 class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
@@ -68,7 +73,9 @@
   @Mock
   private var offsetManager: OffsetManager = null
   @Mock
-  private var taskStorageManager: NonTransactionalStateTaskStorageManager = null
+  private var taskCommitManager: TaskStorageCommitManager = null
+  @Mock
+  private var checkpointManager: CheckpointManager = null
   @Mock
   private var taskTableManager: TableManager = null
   // not a mock; using MockTaskInstanceExceptionHandler
@@ -88,6 +95,8 @@
 
   private var taskInstance: TaskInstance = null
 
+  private val numCheckpointVersions = 2 // number of checkpoint versions (V1 and V2) written per commit
+
   @Before
   def setup(): Unit = {
     MockitoAnnotations.initMocks(this)
@@ -98,7 +107,10 @@
       Matchers.eq(this.containerContext), any(), Matchers.eq(this.applicationContainerContext)))
       .thenReturn(this.applicationTaskContext)
     when(this.systemAdmins.getSystemAdmin(SYSTEM_NAME)).thenReturn(this.systemAdmin)
-    when(this.jobContext.getConfig).thenReturn(new MapConfig(Collections.singletonMap("task.commit.ms", "-1")))
+    val taskConfigsMap = new util.HashMap[String, String]()
+    taskConfigsMap.put("task.commit.ms", "-1")
+    taskConfigsMap.put("task.commit.max.delay.ms", "100000")
+    when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap))
     setupTaskInstance(Some(this.applicationTaskContextFactory))
   }
 
@@ -172,7 +184,6 @@
     taskInstance.initTask
 
     verify(this.offsetManager).setStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION, "10")
-    verifyNoMoreInteractions(this.offsetManager)
   }
 
   @Test
@@ -212,18 +223,42 @@
   def testCommitOrder() {
     val commitsCounter = mock[Counter]
     when(this.metrics.commits).thenReturn(commitsCounter)
-    val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
     val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-changelog-stream"), new Partition(0))
-    val changelogOffsets = Map(changelogSSP -> Some("5"))
-    when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
-    when(this.taskStorageManager.flush()).thenReturn(changelogOffsets)
-    doNothing().when(this.taskStorageManager).checkpoint(any(), any[Map[SystemStreamPartition, Option[String]]])
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    val stateCheckpointMarker = KafkaStateCheckpointMarker.serialize(new KafkaStateCheckpointMarker(changelogSSP, "5"))
+    stateCheckpointMarkers.put("storeName", stateCheckpointMarker)
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+
+    val snapshotSCMs = ImmutableMap.of(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)
+    when(this.taskCommitManager.snapshot(any())).thenReturn(snapshotSCMs)
+    val snapshotSCMFuture: CompletableFuture[util.Map[String, util.Map[String, String]]] =
+      CompletableFuture.completedFuture(snapshotSCMs)
+    when(this.taskCommitManager.upload(any(), Matchers.eq(snapshotSCMs))).thenReturn(snapshotSCMFuture) // upload is a no-op for the Kafka changelog backend
+    when(this.taskCommitManager.cleanUp(any(), any())).thenReturn(CompletableFuture.completedFuture[Void](null))
+
+
     taskInstance.commit
 
-    val mockOrder = inOrder(this.offsetManager, this.collector, this.taskTableManager, this.taskStorageManager)
+    val mockOrder = inOrder(this.offsetManager, this.collector, this.taskTableManager, this.taskCommitManager)
 
     // We must first get a snapshot of the input offsets so it doesn't change while we flush. SAMZA-1384
-    mockOrder.verify(this.offsetManager).buildCheckpoint(TASK_NAME)
+    mockOrder.verify(this.offsetManager).getLastProcessedOffsets(TASK_NAME)
 
     // Producers must be flushed next and ideally the output would be flushed before the changelog
     // s.t. the changelog and checkpoints (state and inputs) are captured last
@@ -233,128 +268,671 @@
     mockOrder.verify(this.taskTableManager).flush()
 
     // Local state should be flushed next
-    mockOrder.verify(this.taskStorageManager).flush()
+    mockOrder.verify(this.taskCommitManager).snapshot(any())
+
+    // Upload should be called next with the snapshot SCMs.
+    mockOrder.verify(this.taskCommitManager).upload(any(), Matchers.eq(snapshotSCMs))
 
     // Stores checkpoints should be created next with the newest changelog offsets
-    mockOrder.verify(this.taskStorageManager).checkpoint(any(), Matchers.eq(changelogOffsets))
+    mockOrder.verify(this.taskCommitManager).writeCheckpointToStoreDirectories(any())
 
     // Input checkpoint should be written with the snapshot captured at the beginning of commit and the
     // newest changelog offset captured during storage manager flush
     val captor = ArgumentCaptor.forClass(classOf[Checkpoint])
-    mockOrder.verify(offsetManager).writeCheckpoint(any(), captor.capture)
-    val cp = captor.getValue
-    assertEquals("4", cp.getOffsets.get(SYSTEM_STREAM_PARTITION))
-    assertEquals("5", CheckpointedChangelogOffset.fromString(cp.getOffsets.get(changelogSSP)).getOffset)
+    mockOrder.verify(offsetManager, times(numCheckpointVersions)).writeCheckpoint(any(), captor.capture)
+    val cp = captor.getAllValues
+    assertEquals(numCheckpointVersions, cp.size())
+    cp.forEach(new Consumer[Checkpoint] {
+      override def accept(c: Checkpoint): Unit = {
+        assertEquals("4", c.getOffsets.get(SYSTEM_STREAM_PARTITION))
+        if (c.getVersion == 2) {
+          assertEquals(1, c.getOffsets.size())
+          assertTrue(c.isInstanceOf[CheckpointV2])
+          val checkpointedStateCheckpointMarkers = c.asInstanceOf[CheckpointV2]
+            .getStateCheckpointMarkers.get(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME)
+          assertTrue(checkpointedStateCheckpointMarkers.size() == 1)
+          val checkpointedStateCheckpointMarker = checkpointedStateCheckpointMarkers.get("storeName")
+          assertTrue(checkpointedStateCheckpointMarker.equals(stateCheckpointMarker))
+          val kafkaMarker = KafkaStateCheckpointMarker.deserialize(checkpointedStateCheckpointMarker)
+          assertEquals(kafkaMarker.getChangelogOffset, "5")
+          assertEquals(kafkaMarker.getChangelogSSP, changelogSSP)
+        } else { // c.getVersion == 1
+          assertEquals(2, c.getOffsets.size())
+          assertTrue(c.isInstanceOf[CheckpointV1])
+          assertEquals("5", KafkaChangelogSSPOffset.fromString(c.getOffsets.get(changelogSSP)).getChangelogOffset)
+        }
+      }
+    })
 
     // Old checkpointed stores should be cleared
-    mockOrder.verify(this.taskStorageManager).removeOldCheckpoints(any())
+    mockOrder.verify(this.taskCommitManager).cleanUp(any(), any())
     verify(commitsCounter).inc()
+    verify(snapshotTimer).update(anyLong())
+    verify(uploadTimer).update(anyLong())
+    verify(commitTimer).update(anyLong())
   }
 
   @Test
   def testEmptyChangelogSSPOffsetInCommit() { // e.g. if changelog topic is empty
     val commitsCounter = mock[Counter]
     when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
 
-    val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
+    val inputOffsets = Map(SYSTEM_STREAM_PARTITION -> "4").asJava
     val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-changelog-stream"), new Partition(0))
-    val changelogOffsets = Map(changelogSSP -> None)
-    when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
-    when(this.taskStorageManager.flush()).thenReturn(changelogOffsets)
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    val nullStateCheckpointMarker = KafkaStateCheckpointMarker.serialize(new KafkaStateCheckpointMarker(changelogSSP, null))
+    stateCheckpointMarkers.put("storeName", nullStateCheckpointMarker)
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture(
+        Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)))
     taskInstance.commit
 
     val captor = ArgumentCaptor.forClass(classOf[Checkpoint])
-    verify(offsetManager).writeCheckpoint(any(), captor.capture)
-    val cp = captor.getValue
-    assertEquals("4", cp.getOffsets.get(SYSTEM_STREAM_PARTITION))
-    val message = cp.getOffsets.get(changelogSSP)
-    val checkpointedOffset = CheckpointedChangelogOffset.fromString(message)
-    assertNull(checkpointedOffset.getOffset)
-    assertNotNull(checkpointedOffset.getCheckpointId)
+    verify(offsetManager, times(numCheckpointVersions)).writeCheckpoint(any(), captor.capture)
+    val cp = captor.getAllValues
+    assertEquals(numCheckpointVersions, cp.size())
+    cp.forEach(new Consumer[Checkpoint] {
+      override def accept(checkpoint: Checkpoint): Unit = {
+        assertEquals("4", checkpoint.getOffsets.get(SYSTEM_STREAM_PARTITION))
+        if (checkpoint.getVersion == 2) {
+          assertEquals(1, checkpoint.getOffsets.size())
+          assertTrue(checkpoint.isInstanceOf[CheckpointV2])
+          val checkpointedStateCheckpointMarkers = checkpoint.asInstanceOf[CheckpointV2]
+            .getStateCheckpointMarkers.get(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME)
+          assertTrue(checkpointedStateCheckpointMarkers.size() == 1)
+          val checkpointedStateCheckpointMarker = checkpointedStateCheckpointMarkers.get("storeName")
+          assertTrue(checkpointedStateCheckpointMarker.equals(nullStateCheckpointMarker))
+          val kafkaMarker = KafkaStateCheckpointMarker.deserialize(checkpointedStateCheckpointMarker)
+          assertNull(kafkaMarker.getChangelogOffset)
+          assertEquals(kafkaMarker.getChangelogSSP, changelogSSP)
+        } else { // checkpoint.getVersion == 1
+          assertEquals(2, checkpoint.getOffsets.size())
+          assertTrue(checkpoint.isInstanceOf[CheckpointV1])
+          val message = checkpoint.getOffsets.get(changelogSSP)
+          val checkpointedOffset = KafkaChangelogSSPOffset.fromString(message)
+          assertNull(checkpointedOffset.getChangelogOffset)
+          assertNotNull(checkpointedOffset.getCheckpointId)
+        }
+      }
+    })
     verify(commitsCounter).inc()
+    verify(snapshotTimer).update(anyLong())
+    verify(uploadTimer).update(anyLong())
   }
 
   @Test
   def testEmptyChangelogOffsetsInCommit() { // e.g. if stores have no changelogs
     val commitsCounter = mock[Counter]
     when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
 
-    val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
-    val changelogOffsets = Map[SystemStreamPartition, Option[String]]()
-    when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
-    when(this.taskStorageManager.flush()).thenReturn(changelogOffsets)
+    val inputOffsets = Map(SYSTEM_STREAM_PARTITION -> "4").asJava
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture(
+        Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)))
     taskInstance.commit
 
     val captor = ArgumentCaptor.forClass(classOf[Checkpoint])
-    verify(offsetManager).writeCheckpoint(any(), captor.capture)
-    val cp = captor.getValue
-    assertEquals("4", cp.getOffsets.get(SYSTEM_STREAM_PARTITION))
-    assertEquals(1, cp.getOffsets.size())
+    // verify that writeCheckpoint is invoked twice, once per checkpoint version
+    verify(offsetManager, times(numCheckpointVersions)).writeCheckpoint(any(), captor.capture)
+    val cp = captor.getAllValues
+    assertEquals(numCheckpointVersions, cp.size())
+    cp.forEach(new Consumer[Checkpoint] {
+      override def accept(c: Checkpoint): Unit = {
+        assertEquals("4", c.getOffsets.get(SYSTEM_STREAM_PARTITION))
+        assertEquals(1, c.getOffsets.size())
+      }
+    })
     verify(commitsCounter).inc()
+    verify(snapshotTimer).update(anyLong())
+    verify(uploadTimer).update(anyLong())
   }
 
   @Test
   def testCommitFailsIfErrorGettingChangelogOffset() { // required for transactional state
     val commitsCounter = mock[Counter]
     when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
 
-    val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
-    when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
-    when(this.taskStorageManager.flush()).thenThrow(new SamzaException("Error getting changelog offsets"))
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    when(this.taskCommitManager.snapshot(any())).thenThrow(new SamzaException("Error getting changelog offsets"))
+
+    try {
+      // sync stage exception should be caught and rethrown immediately
+      taskInstance.commit
+
+      verify(commitsCounter).inc()
+      verifyZeroInteractions(snapshotTimer)
+    } catch {
+      case e: SamzaException =>
+        val msg = e.getMessage
+        // exception is expected, container should fail if could not get changelog offsets.
+        return
+    }
+
+    fail("Should have failed commit if error getting newest changelog offsets")
+  }
+
+  @Test
+  def testCommitFailsIfPreviousAsyncUploadFails() {
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture(
+        Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)))
+    // Override the stub above: make the upload future complete exceptionally.
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(FutureUtil.failedFuture[util.Map[String, util.Map[String, String]]](new RuntimeException))
 
     try {
       taskInstance.commit
+
+      verify(commitsCounter).inc()
+      verify(snapshotTimer).update(anyLong())
+      verifyZeroInteractions(uploadTimer)
+      verifyZeroInteractions(commitTimer)
+      verifyZeroInteractions(skippedCounter)
+
+      // async stage exception in first commit should be caught and rethrown by the subsequent commit
+      taskInstance.commit
+      verifyNoMoreInteractions(commitsCounter)
+    } catch {
+      case e: SamzaException =>
+        // exception is expected, container should fail if could not upload previous snapshot.
+        return
+    }
+
+    fail("Should have failed commit if error uploading store contents")
+  }
+
+  @Test
+  def testCommitFailsIfAsyncStoreDirCheckpointWriteFails() { // required for transactional state
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture(
+        Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)))
+    when(this.taskCommitManager.writeCheckpointToStoreDirectories(any()))
+      .thenThrow(new SamzaException("Error creating store checkpoint"))
+
+    try {
+      taskInstance.commit
+
+      verify(commitsCounter).inc()
+      verify(snapshotTimer).update(anyLong())
+      verify(uploadTimer).update(anyLong())
+      verifyZeroInteractions(commitTimer)
+      verifyZeroInteractions(skippedCounter)
+
+      // async stage exception in first commit should be caught and rethrown by the subsequent commit
+      taskInstance.commit
+      verifyNoMoreInteractions(commitsCounter)
     } catch {
       case e: SamzaException =>
         // exception is expected, container should fail if could not get changelog offsets.
         return
     }
 
-    fail("Should have failed commit if error getting newest changelog offests")
+    fail("Should have failed commit if error writing checkpoint to store dirs")
   }
 
   @Test
-  def testCommitFailsIfErrorCreatingStoreCheckpoints() { // required for transactional state
+  def testCommitFailsIfPreviousAsyncCheckpointTopicWriteFails() {
     val commitsCounter = mock[Counter]
     when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
 
-    val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
-    when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
-    when(this.taskStorageManager.flush()).thenReturn(Map[SystemStreamPartition, Option[String]]())
-    when(this.taskStorageManager.checkpoint(any(), any())).thenThrow(new SamzaException("Error creating store checkpoint"))
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture(
+        Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)))
+    doNothing().when(this.taskCommitManager).writeCheckpointToStoreDirectories(any())
+    when(this.offsetManager.writeCheckpoint(any(), any()))
+      .thenThrow(new SamzaException("Error writing checkpoint"))
 
     try {
       taskInstance.commit
+
+      verify(commitsCounter).inc()
+      verify(snapshotTimer).update(anyLong())
+      verify(uploadTimer).update(anyLong())
+      verifyZeroInteractions(commitTimer)
+      verifyZeroInteractions(skippedCounter)
+
+      // async stage exception in first commit should be caught and rethrown by the subsequent commit
+      taskInstance.commit
+      verifyNoMoreInteractions(commitsCounter)
     } catch {
       case e: SamzaException =>
-        // exception is expected, container should fail if could not get changelog offsets.
+        // exception is expected, container should fail if could not write previous checkpoint.
         return
     }
 
-    fail("Should have failed commit if error getting newest changelog offests")
+    fail("Should have failed commit if error writing checkpoints to checkpoint topic")
   }
 
   @Test
-  def testCommitContinuesIfErrorClearingOldCheckpoints() { // required for transactional state
+  def testCommitFailsIfPreviousAsyncCleanUpFails() { // required for blob store backend
     val commitsCounter = mock[Counter]
     when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
 
-    val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
-    when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
-    when(this.taskStorageManager.flush()).thenReturn(Map[SystemStreamPartition, Option[String]]())
-    doNothing().when(this.taskStorageManager).checkpoint(any(), any())
-    when(this.taskStorageManager.removeOldCheckpoints(any()))
-      .thenThrow(new SamzaException("Error clearing old checkpoints"))
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture(
+        Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)))
+    doNothing().when(this.taskCommitManager).writeCheckpointToStoreDirectories(any())
+    when(this.taskCommitManager.cleanUp(any(), any()))
+      .thenReturn(FutureUtil.failedFuture[Void](new SamzaException("Error during cleanup")))
 
     try {
       taskInstance.commit
+
+      verify(commitsCounter).inc()
+      verify(snapshotTimer).update(anyLong())
+      verify(uploadTimer).update(anyLong())
+      verifyZeroInteractions(commitTimer)
+      verifyZeroInteractions(skippedCounter)
+
+      // async stage exception in first commit should be caught and rethrown by the subsequent commit
+      taskInstance.commit
+      verifyNoMoreInteractions(commitsCounter)
     } catch {
       case e: SamzaException =>
-        // exception is expected, container should fail if could not get changelog offsets.
-        fail("Exception from removeOldCheckpoints should have been caught")
+        // exception is expected, container should fail if could not clean up old checkpoint.
+        return
     }
+
+    fail("Should have failed commit if error cleaning up previous commit")
   }
 
+  @Test
+  def testCommitFailsIfPreviousAsyncUploadFailsSynchronously() {
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture(
+        Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)))
+
+    // Fail synchronously instead of returning a failed future.
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenThrow(new RuntimeException)
+
+    try {
+      taskInstance.commit
+
+      verify(commitsCounter).inc()
+      verify(snapshotTimer).update(anyLong())
+      verifyZeroInteractions(uploadTimer)
+      verifyZeroInteractions(commitTimer)
+      verifyZeroInteractions(skippedCounter)
+
+      // async stage exception in first commit should be caught and rethrown by the subsequent commit
+      taskInstance.commit
+      verifyNoMoreInteractions(commitsCounter)
+    } catch {
+      case e: SamzaException =>
+        // exception is expected, container should fail if could not upload previous snapshot.
+        return
+    }
+
+    fail("Should have failed commit if synchronous error during upload in async stage of previous commit")
+  }
+
+  @Test
+  def testCommitSucceedsIfPreviousAsyncStageSucceeds() {
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION,"4")
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+    when(this.taskCommitManager.upload(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture(
+        Collections.singletonMap(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)))
+    doNothing().when(this.taskCommitManager).writeCheckpointToStoreDirectories(any())
+    when(this.taskCommitManager.cleanUp(any(), any()))
+      .thenReturn(CompletableFuture.completedFuture[Void](null))
+
+    taskInstance.commit // async stage will be run by caller due to direct executor
+
+    verify(commitsCounter).inc()
+    verify(snapshotTimer).update(anyLong())
+    verify(uploadTimer).update(anyLong())
+    verify(commitTimer).update(anyLong())
+
+    taskInstance.commit
+
+    // verify that all commit operations ran twice
+    verify(taskCommitManager, times(2)).snapshot(any())
+    verify(taskCommitManager, times(2)).upload(any(), any())
+    // called 2x per commit, once for each checkpoint version
+    verify(taskCommitManager, times(4)).writeCheckpointToStoreDirectories(any())
+    verify(offsetManager, times(4)).writeCheckpoint(any(), any())
+    verify(taskCommitManager, times(2)).cleanUp(any(), any())
+    verify(commitsCounter, times(2)).inc()
+  }
+
+  @Test
+  def testCommitSkipsIfPreviousAsyncCommitInProgressWithinMaxCommitDelay() {
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION, "4")
+    val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-changelog-stream"), new Partition(0))
+
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    val stateCheckpointMarker = KafkaStateCheckpointMarker.serialize(new KafkaStateCheckpointMarker(changelogSSP, "5"))
+    stateCheckpointMarkers.put("storeName", stateCheckpointMarker)
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+
+    val snapshotSCMs = ImmutableMap.of(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)
+    when(this.taskCommitManager.snapshot(any())).thenReturn(snapshotSCMs)
+    val snapshotSCMFuture: CompletableFuture[util.Map[String, util.Map[String, String]]] =
+      CompletableFuture.completedFuture(snapshotSCMs)
+
+    when(this.taskCommitManager.upload(any(), Matchers.eq(snapshotSCMs))).thenReturn(snapshotSCMFuture) // kafka is no-op
+
+    val cleanUpFuture = new CompletableFuture[Void]() // not completed until subsequent commit starts
+    when(this.taskCommitManager.cleanUp(any(), any())).thenReturn(cleanUpFuture)
+
+    // use a separate executor to run the async operations so we can test caller-thread blocking behavior
+    setupTaskInstance(None, ForkJoinPool.commonPool())
+
+    taskInstance.commit // async stage will not complete until cleanUpFuture is completed
+
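+    // second commit should be skipped since the previous commit's async stage (cleanUp) is still in progress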
+    taskInstance.commit
+
+    verify(skippedCounter).set(1)
+
+    verify(commitsCounter, times(1)).inc() // should only have been incremented once on the initial commit
+    verify(snapshotTimer).update(anyLong())
+    verify(uploadTimer).update(anyLong())
+    verifyZeroInteractions(commitTimer)
+
+    cleanUpFuture.complete(null) // just to unblock shared executor
+  }
+
+  @Test
+  def testCommitThrowsIfPreviousAsyncCommitInProgressAfterMaxCommitDelayAndBlockTime() {
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    val uploadTimer = mock[Timer]
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION, "4")
+    val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-changelog-stream"), new Partition(0))
+
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    val stateCheckpointMarker = KafkaStateCheckpointMarker.serialize(new KafkaStateCheckpointMarker(changelogSSP, "5"))
+    stateCheckpointMarkers.put("storeName", stateCheckpointMarker)
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+
+    val snapshotSCMs = ImmutableMap.of(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)
+    when(this.taskCommitManager.snapshot(any())).thenReturn(snapshotSCMs)
+    val snapshotSCMFuture: CompletableFuture[util.Map[String, util.Map[String, String]]] =
+      CompletableFuture.completedFuture(snapshotSCMs)
+
+    when(this.taskCommitManager.upload(any(), Matchers.eq(snapshotSCMs))).thenReturn(snapshotSCMFuture) // kafka is no-op
+
+    val cleanUpFuture = new CompletableFuture[Void]()
+    when(this.taskCommitManager.cleanUp(any(), any())).thenReturn(cleanUpFuture)
+
+    // use a separate executor to run the async operations so we can test caller-thread blocking behavior
+    val taskConfigsMap = new util.HashMap[String, String]()
+    taskConfigsMap.put("task.commit.ms", "-1")
+    // "block" immediately if previous commit async stage not complete
+    taskConfigsMap.put("task.commit.max.delay.ms", "-1")
+    taskConfigsMap.put("task.commit.timeout.ms", "0") // throw exception immediately if blocked
+    when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap)) // override default behavior
+
+    setupTaskInstance(None, ForkJoinPool.commonPool())
+
+    taskInstance.commit // async stage will not complete until cleanUpFuture is completed
+
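+    // with max delay -1 and timeout 0, the second commit should block on the pending async stage and then fail immediately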
+    try {
+      taskInstance.commit // should throw exception
+      fail("Should have thrown an exception if blocked for previous commit async stage.")
+    } catch {
+      case e: Exception =>
+        verify(commitsCounter, times(1)).inc() // should only have been incremented once on the initial commit
+    }
+
+    cleanUpFuture.complete(null) // just to unblock shared executor
+  }
+
+  @Test
+  def testCommitBlocksIfPreviousAsyncCommitInProgressAfterMaxCommitDelayButWithinBlockTime() {
+    val commitsCounter = mock[Counter]
+    when(this.metrics.commits).thenReturn(commitsCounter)
+    val snapshotTimer = mock[Timer]
+    when(this.metrics.snapshotNs).thenReturn(snapshotTimer)
+    val commitTimer = mock[Timer]
+    when(this.metrics.commitNs).thenReturn(commitTimer)
+    val uploadTimer = mock[Timer]
+    val commitSyncTimer = mock[Timer]
+    when(this.metrics.commitSyncNs).thenReturn(commitSyncTimer)
+    val commitAsyncTimer = mock[Timer]
+    when(this.metrics.commitAsyncNs).thenReturn(commitAsyncTimer)
+    when(this.metrics.asyncUploadNs).thenReturn(uploadTimer)
+    val cleanUpTimer = mock[Timer]
+    when(this.metrics.asyncCleanupNs).thenReturn(cleanUpTimer)
+    val skippedCounter = mock[Gauge[Int]]
+    when(this.metrics.commitsSkipped).thenReturn(skippedCounter)
+
+    val inputOffsets = new util.HashMap[SystemStreamPartition, String]()
+    inputOffsets.put(SYSTEM_STREAM_PARTITION, "4")
+    val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-changelog-stream"), new Partition(0))
+
+    val stateCheckpointMarkers: util.Map[String, String] = new util.HashMap[String, String]()
+    val stateCheckpointMarker = KafkaStateCheckpointMarker.serialize(new KafkaStateCheckpointMarker(changelogSSP, "5"))
+    stateCheckpointMarkers.put("storeName", stateCheckpointMarker)
+    when(this.offsetManager.getLastProcessedOffsets(TASK_NAME)).thenReturn(inputOffsets)
+
+    val snapshotSCMs = ImmutableMap.of(KafkaStateCheckpointMarker.KAFKA_STATE_BACKEND_FACTORY_NAME, stateCheckpointMarkers)
+    when(this.taskCommitManager.snapshot(any())).thenReturn(snapshotSCMs)
+    val snapshotSCMFuture: CompletableFuture[util.Map[String, util.Map[String, String]]] =
+      CompletableFuture.completedFuture(snapshotSCMs)
+
+    when(this.taskCommitManager.upload(any(), Matchers.eq(snapshotSCMs))).thenReturn(snapshotSCMFuture) // kafka is no-op
+
+    val cleanUpFuture = new CompletableFuture[Void]()
+    when(this.taskCommitManager.cleanUp(any(), any())).thenReturn(cleanUpFuture)
+
+    // use a separate executor to run the async operations so we can test caller-thread blocking behavior
+    val taskConfigsMap = new util.HashMap[String, String]()
+    taskConfigsMap.put("task.commit.ms", "-1")
+    // "block" immediately if previous commit async stage not complete
+    taskConfigsMap.put("task.commit.max.delay.ms", "-1")
+    taskConfigsMap.put("task.commit.timeout.ms", "1000000") // block until previous stage is complete
+    when(this.jobContext.getConfig).thenReturn(new MapConfig(taskConfigsMap)) // override default behavior
+
+    setupTaskInstance(None, ForkJoinPool.commonPool())
+
+    taskInstance.commit // async stage will not complete until cleanUpFuture is completed
+
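+    // run the second commit on a separate thread; it should block on the commitInProgress semaphore until cleanUpFuture completes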
+    val executorService = Executors.newSingleThreadExecutor()
+    val secondCommitFuture = CompletableFuture.runAsync(new Runnable {
+      override def run(): Unit = taskInstance.commit // will block on executor
+    }, executorService)
+
+    var retries = 0 // wait no more than ~100 millis
+    while (!taskInstance.commitInProgress.hasQueuedThreads && retries < 10) {
+      retries += 1
+      Thread.sleep(10) // wait until commit in other thread blocks on the semaphore.
+    }
+    if (!taskInstance.commitInProgress.hasQueuedThreads) {
+      fail("Other thread should have blocked on semaphore acquisition. " +
+        "May need to increase retries if transient failure.")
+    }
+
+    cleanUpFuture.complete(null) // will eventually unblock the 2nd commit in other thread.
+    secondCommitFuture.join() // will complete when the sync phase of 2nd commit is complete.
+    verify(commitsCounter, times(2)).inc() // should only have been incremented twice - once for each commit
+    verify(snapshotTimer, times(2)).update(anyLong())
+  }
+
   /**
     * Given that no application task context factory is provided, then no lifecycle calls should be made.
     */
@@ -400,7 +978,7 @@
       this.consumerMultiplexer,
       this.collector,
       offsetManager = offsetManagerMock,
-      storageManager = this.taskStorageManager,
+      commitManager = this.taskCommitManager,
       tableManager = this.taskTableManager,
       systemStreamPartitions = ImmutableSet.of(ssp),
       exceptionHandler = this.taskInstanceExceptionHandler,
@@ -418,7 +996,8 @@
   }
 
   private def setupTaskInstance(
-    applicationTaskContextFactory: Option[ApplicationTaskContextFactory[ApplicationTaskContext]]): Unit = {
+    applicationTaskContextFactory: Option[ApplicationTaskContextFactory[ApplicationTaskContext]],
+    commitThreadPool: ExecutorService = MoreExecutors.newDirectExecutorService()): Unit = {
     this.taskInstance = new TaskInstance(this.task,
       this.taskModel,
       this.metrics,
@@ -426,10 +1005,11 @@
       this.consumerMultiplexer,
       this.collector,
       offsetManager = this.offsetManager,
-      storageManager = this.taskStorageManager,
+      commitManager = this.taskCommitManager,
       tableManager = this.taskTableManager,
       systemStreamPartitions = SYSTEM_STREAM_PARTITIONS,
       exceptionHandler = this.taskInstanceExceptionHandler,
+      commitThreadPool = commitThreadPool,
       jobContext = this.jobContext,
       containerContext = this.containerContext,
       applicationContainerContextOption = Some(this.applicationContainerContext),
diff --git a/samza-core/src/test/scala/org/apache/samza/storage/TestTransactionalStateTaskStorageManager.java b/samza-core/src/test/scala/org/apache/samza/storage/TestTransactionalStateTaskStorageManager.java
deleted file mode 100644
index 244a35b..0000000
--- a/samza-core/src/test/scala/org/apache/samza/storage/TestTransactionalStateTaskStorageManager.java
+++ /dev/null
@@ -1,534 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.storage;
-
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
-import java.io.FileFilter;
-import scala.Option;
-import scala.collection.immutable.Map;
-
-import java.io.File;
-import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.util.HashMap;
-import java.util.Optional;
-import org.apache.samza.Partition;
-import org.apache.samza.SamzaException;
-import org.apache.samza.checkpoint.CheckpointId;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.job.model.TaskMode;
-import org.apache.samza.system.SystemAdmin;
-import org.apache.samza.system.SystemAdmins;
-import org.apache.samza.system.SystemStream;
-import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata;
-import org.apache.samza.system.SystemStreamPartition;
-import org.apache.samza.util.ScalaJavaUtil;
-import org.junit.Test;
-import org.mockito.ArgumentCaptor;
-import org.mockito.InOrder;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.fail;
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyBoolean;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.doNothing;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.doThrow;
-import static org.mockito.Mockito.inOrder;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-public class TestTransactionalStateTaskStorageManager {
-  @Test
-  public void testFlushOrder() {
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-    StorageEngine mockStore = mock(StorageEngine.class);
-    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of("mockStore", mockStore);
-    when(csm.getAllStores(any())).thenReturn(taskStores);
-
-    TransactionalStateTaskStorageManager tsm = spy(buildTSM(csm, mock(Partition.class), new StorageManagerUtil()));
-    // stub actual method call
-    doReturn(mock(Map.class)).when(tsm).getNewestChangelogSSPOffsets(any(), any(), any(), any());
-
-    // invoke flush
-    tsm.flush();
-
-    // ensure that stores are flushed before we get newest changelog offsets
-    InOrder inOrder = inOrder(mockStore, tsm);
-    inOrder.verify(mockStore).flush();
-    inOrder.verify(tsm).getNewestChangelogSSPOffsets(any(), any(), any(), any());
-  }
-
-  @Test
-  public void testGetNewestOffsetsReturnsCorrectOffset() {
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-    TransactionalStateTaskStorageManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
-
-    TaskName taskName = mock(TaskName.class);
-    String changelogSystemName = "systemName";
-    String storeName = "storeName";
-    String changelogStreamName = "changelogName";
-    String newestChangelogSSPOffset = "1";
-    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
-    Partition changelogPartition = new Partition(0);
-    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
-
-    Map<String, SystemStream> storeChangelogs =
-        ScalaJavaUtil.toScalaMap(ImmutableMap.of(storeName, changelogSystemStream));
-
-    SystemAdmins systemAdmins = mock(SystemAdmins.class);
-    SystemAdmin systemAdmin = mock(SystemAdmin.class);
-    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
-
-    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
-    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenReturn(systemAdmin);
-    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(ImmutableMap.of(changelogSSP, metadata));
-
-    // invoke the method
-    Map<SystemStreamPartition, Option<String>> offsets =
-        tsm.getNewestChangelogSSPOffsets(
-            taskName, storeChangelogs, changelogPartition, systemAdmins);
-
-    // verify results
-    assertEquals(1, offsets.size());
-    assertEquals(Option.apply(newestChangelogSSPOffset), offsets.apply(changelogSSP));
-  }
-
-  @Test
-  public void testGetNewestOffsetsReturnsNoneForEmptyTopic() {
-    // empty topic == null newest offset
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-    TransactionalStateTaskStorageManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
-
-    TaskName taskName = mock(TaskName.class);
-    String changelogSystemName = "systemName";
-    String storeName = "storeName";
-    String changelogStreamName = "changelogName";
-    String newestChangelogSSPOffset = null;
-    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
-    Partition changelogPartition = new Partition(0);
-    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
-
-    Map<String, SystemStream> storeChangelogs =
-        ScalaJavaUtil.toScalaMap(ImmutableMap.of(storeName, changelogSystemStream));
-
-    SystemAdmins systemAdmins = mock(SystemAdmins.class);
-    SystemAdmin systemAdmin = mock(SystemAdmin.class);
-    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
-
-    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
-    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenReturn(systemAdmin);
-    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(ImmutableMap.of(changelogSSP, metadata));
-
-    // invoke the method
-    Map<SystemStreamPartition, Option<String>> offsets =
-        tsm.getNewestChangelogSSPOffsets(
-            taskName, storeChangelogs, changelogPartition, systemAdmins);
-
-    // verify results
-    assertEquals(1, offsets.size());
-    assertEquals(Option.empty(), offsets.apply(changelogSSP));
-  }
-
-  @Test(expected = SamzaException.class)
-  public void testGetNewestOffsetsThrowsIfNullMetadata() {
-    // empty topic == null newest offset
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-    TransactionalStateTaskStorageManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
-
-    TaskName taskName = mock(TaskName.class);
-    String changelogSystemName = "systemName";
-    String storeName = "storeName";
-    String changelogStreamName = "changelogName";
-    String newestChangelogSSPOffset = null;
-    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
-    Partition changelogPartition = new Partition(0);
-    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
-
-    Map<String, SystemStream> storeChangelogs =
-        ScalaJavaUtil.toScalaMap(ImmutableMap.of(storeName, changelogSystemStream));
-
-    SystemAdmins systemAdmins = mock(SystemAdmins.class);
-    SystemAdmin systemAdmin = mock(SystemAdmin.class);
-    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
-
-    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
-    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenReturn(systemAdmin);
-    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(null);
-
-    // invoke the method
-    Map<SystemStreamPartition, Option<String>> offsets =
-        tsm.getNewestChangelogSSPOffsets(
-            taskName, storeChangelogs, changelogPartition, systemAdmins);
-
-    // verify results
-    fail("Should have thrown an exception if admin didn't return any metadata");
-  }
-
-  @Test(expected = SamzaException.class)
-  public void testGetNewestOffsetsThrowsIfNullSSPMetadata() {
-    // empty topic == null newest offset
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-    TransactionalStateTaskStorageManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
-
-    TaskName taskName = mock(TaskName.class);
-    String changelogSystemName = "systemName";
-    String storeName = "storeName";
-    String changelogStreamName = "changelogName";
-    String newestChangelogSSPOffset = null;
-    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
-    Partition changelogPartition = new Partition(0);
-    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
-
-    Map<String, SystemStream> storeChangelogs =
-        ScalaJavaUtil.toScalaMap(ImmutableMap.of(storeName, changelogSystemStream));
-
-    SystemAdmins systemAdmins = mock(SystemAdmins.class);
-    SystemAdmin systemAdmin = mock(SystemAdmin.class);
-    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
-
-    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
-    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenReturn(systemAdmin);
-    java.util.Map metadataMap = new HashMap() { {
-        put(changelogSSP, null);
-      } };
-    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(metadataMap);
-
-    // invoke the method
-    Map<SystemStreamPartition, Option<String>> offsets =
-        tsm.getNewestChangelogSSPOffsets(
-            taskName, storeChangelogs, changelogPartition, systemAdmins);
-
-    // verify results
-    fail("Should have thrown an exception if admin returned null metadata for changelog SSP");
-  }
-
-  @Test(expected = SamzaException.class)
-  public void testGetNewestOffsetsThrowsIfErrorGettingMetadata() {
-    // empty topic == null newest offset
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-    TransactionalStateTaskStorageManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
-
-    TaskName taskName = mock(TaskName.class);
-    String changelogSystemName = "systemName";
-    String storeName = "storeName";
-    String changelogStreamName = "changelogName";
-    String newestChangelogSSPOffset = null;
-    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
-    Partition changelogPartition = new Partition(0);
-    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
-
-    Map<String, SystemStream> storeChangelogs =
-        ScalaJavaUtil.toScalaMap(ImmutableMap.of(storeName, changelogSystemStream));
-
-    SystemAdmins systemAdmins = mock(SystemAdmins.class);
-    SystemAdmin systemAdmin = mock(SystemAdmin.class);
-    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
-
-    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
-    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenThrow(new SamzaException("Error getting metadata"));
-    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(null);
-
-    // invoke the method
-    Map<SystemStreamPartition, Option<String>> offsets =
-        tsm.getNewestChangelogSSPOffsets(
-            taskName, storeChangelogs, changelogPartition, systemAdmins);
-
-    // verify results
-    fail("Should have thrown an exception if admin had an error getting metadata");
-  }
-
-  @Test
-  public void testCheckpoint() {
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-
-    StorageEngine mockLPStore = mock(StorageEngine.class);
-    StoreProperties lpStoreProps = mock(StoreProperties.class);
-    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
-    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
-    when(lpStoreProps.isLoggedStore()).thenReturn(true);
-    Path mockPath = mock(Path.class);
-    when(mockLPStore.checkpoint(any())).thenReturn(Optional.of(mockPath));
-
-    StorageEngine mockPStore = mock(StorageEngine.class);
-    StoreProperties pStoreProps = mock(StoreProperties.class);
-    when(mockPStore.getStoreProperties()).thenReturn(pStoreProps);
-    when(pStoreProps.isPersistedToDisk()).thenReturn(true);
-    when(pStoreProps.isLoggedStore()).thenReturn(false);
-
-    StorageEngine mockLIStore = mock(StorageEngine.class);
-    StoreProperties liStoreProps = mock(StoreProperties.class);
-    when(mockLIStore.getStoreProperties()).thenReturn(liStoreProps);
-    when(liStoreProps.isPersistedToDisk()).thenReturn(false);
-    when(liStoreProps.isLoggedStore()).thenReturn(true);
-
-    StorageEngine mockIStore = mock(StorageEngine.class);
-    StoreProperties iStoreProps = mock(StoreProperties.class);
-    when(mockIStore.getStoreProperties()).thenReturn(iStoreProps);
-    when(iStoreProps.isPersistedToDisk()).thenReturn(false);
-    when(iStoreProps.isLoggedStore()).thenReturn(false);
-
-    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of(
-        "loggedPersistentStore", mockLPStore,
-        "persistentStore", mockPStore,
-        "loggedInMemStore", mockLIStore,
-        "inMemStore", mockIStore
-    );
-    when(csm.getAllStores(any())).thenReturn(taskStores);
-
-    TransactionalStateTaskStorageManager tsm = spy(buildTSM(csm, mock(Partition.class), new StorageManagerUtil()));
-    // stub actual method call
-    ArgumentCaptor<Map> checkpointPathsCaptor = ArgumentCaptor.forClass(Map.class);
-    doNothing().when(tsm).writeChangelogOffsetFiles(any(), any(), any());
-
-    Map<SystemStreamPartition, Option<String>> offsets = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(mock(SystemStreamPartition.class), Option.apply("1")));
-
-    // invoke checkpoint
-    tsm.checkpoint(CheckpointId.create(), offsets);
-
-    // ensure that checkpoint is never called for non-logged persistent stores since they're
-    // always cleared on restart.
-    verify(mockPStore, never()).checkpoint(any());
-    // ensure that checkpoint is never called for in-memory stores since they're not persistent.
-    verify(mockIStore, never()).checkpoint(any());
-    verify(mockLIStore, never()).checkpoint(any());
-    verify(tsm).writeChangelogOffsetFiles(checkpointPathsCaptor.capture(), any(), eq(offsets));
-    Map<String, Path> checkpointPaths = checkpointPathsCaptor.getValue();
-    assertEquals(1, checkpointPaths.size());
-    assertEquals(mockPath, checkpointPaths.apply("loggedPersistentStore"));
-  }
-
-  @Test(expected = IllegalStateException.class)
-  public void testCheckpointFailsIfErrorCreatingCheckpoint() {
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-
-    StorageEngine mockLPStore = mock(StorageEngine.class);
-    StoreProperties lpStoreProps = mock(StoreProperties.class);
-    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
-    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
-    when(lpStoreProps.isLoggedStore()).thenReturn(true);
-    when(mockLPStore.checkpoint(any())).thenThrow(new IllegalStateException());
-    java.util.Map<String, StorageEngine> taskStores =
-        ImmutableMap.of("loggedPersistentStore", mockLPStore);
-    when(csm.getAllStores(any())).thenReturn(taskStores);
-
-    TransactionalStateTaskStorageManager tsm = spy(buildTSM(csm, mock(Partition.class), new StorageManagerUtil()));
-
-    Map<SystemStreamPartition, Option<String>> offsets = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(mock(SystemStreamPartition.class), Option.apply("1")));
-
-    // invoke checkpoint
-    tsm.checkpoint(CheckpointId.create(), offsets);
-    verify(tsm, never()).writeChangelogOffsetFiles(any(), any(), any());
-    fail("Should have thrown an exception if error creating store checkpoint");
-  }
-
-  @Test(expected = SamzaException.class)
-  public void testCheckpointFailsIfErrorWritingOffsetFiles() {
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-
-    StorageEngine mockLPStore = mock(StorageEngine.class);
-    StoreProperties lpStoreProps = mock(StoreProperties.class);
-    when(mockLPStore.getStoreProperties()).thenReturn(lpStoreProps);
-    when(lpStoreProps.isPersistedToDisk()).thenReturn(true);
-    when(lpStoreProps.isLoggedStore()).thenReturn(true);
-    Path mockPath = mock(Path.class);
-    when(mockLPStore.checkpoint(any())).thenReturn(Optional.of(mockPath));
-    java.util.Map<String, StorageEngine> taskStores =
-        ImmutableMap.of("loggedPersistentStore", mockLPStore);
-    when(csm.getAllStores(any())).thenReturn(taskStores);
-
-    TransactionalStateTaskStorageManager tsm = spy(buildTSM(csm, mock(Partition.class), new StorageManagerUtil()));
-    doThrow(new SamzaException("Error writing offset file"))
-        .when(tsm).writeChangelogOffsetFiles(any(), any(), any());
-
-    Map<SystemStreamPartition, Option<String>> offsets = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(mock(SystemStreamPartition.class), Option.apply("1")));
-
-    // invoke checkpoint
-    tsm.checkpoint(CheckpointId.create(), offsets);
-
-    fail("Should have thrown an exception if error writing offset file.");
-  }
-
-  @Test
-  public void testWriteChangelogOffsetFiles() throws IOException {
-    String storeName = "mockStore";
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-    StorageEngine mockStore = mock(StorageEngine.class);
-    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of(storeName, mockStore);
-    when(csm.getAllStores(any())).thenReturn(taskStores);
-
-    Partition changelogPartition = new Partition(0);
-    SystemStream changelogSS = new SystemStream("system", "changelog");
-    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSS, changelogPartition);
-    StorageManagerUtil smu = spy(new StorageManagerUtil());
-    File mockCurrentStoreDir = mock(File.class);
-    doReturn(mockCurrentStoreDir).when(smu).getTaskStoreDir(any(), eq(storeName), any(), any());
-    doNothing().when(smu).writeOffsetFile(eq(mockCurrentStoreDir), any(), anyBoolean());
-    TransactionalStateTaskStorageManager tsm = spy(buildTSM(csm, changelogPartition, smu));
-
-    String changelogNewestOffset = "1";
-    Map<SystemStreamPartition, Option<String>> offsets = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(changelogSSP, Option.apply(changelogNewestOffset)));
-
-    Path checkpointPath = Files.createTempDirectory("store-checkpoint-test").toAbsolutePath();
-
-    Map<String, Path> checkpointPaths = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(storeName, checkpointPath));
-    Map<String, SystemStream> storeChangelogs = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(storeName, changelogSS));
-
-    // invoke method
-    tsm.writeChangelogOffsetFiles(checkpointPaths, storeChangelogs, offsets);
-
-    // verify that offset file was written to the checkpoint dir
-    java.util.Map<SystemStreamPartition, String> fileOffsets = new StorageManagerUtil()
-        .readOffsetFile(checkpointPath.toFile(), ImmutableSet.of(changelogSSP), false);
-    assertEquals(1, fileOffsets.size());
-    assertEquals(changelogNewestOffset, fileOffsets.get(changelogSSP));
-
-    // verify that offset file write was called on the current dir
-    verify(smu, times(1)).writeOffsetFile(eq(mockCurrentStoreDir), any(), anyBoolean());
-  }
-
-  @Test
-  public void testWriteChangelogOffsetFilesWithEmptyChangelogTopic() throws IOException {
-    String storeName = "mockStore";
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-    StorageEngine mockStore = mock(StorageEngine.class);
-    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of(storeName, mockStore);
-    when(csm.getAllStores(any())).thenReturn(taskStores);
-
-    Partition changelogPartition = new Partition(0);
-    SystemStream changelogSS = new SystemStream("system", "changelog");
-    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSS, changelogPartition);
-    StorageManagerUtil mockSMU = mock(StorageManagerUtil.class);
-    File mockCurrentStoreDir = mock(File.class);
-    when(mockSMU.getTaskStoreDir(any(), eq(storeName), any(), any())).thenReturn(mockCurrentStoreDir);
-    TransactionalStateTaskStorageManager tsm = spy(buildTSM(csm, changelogPartition, mockSMU));
-
-    String changelogNewestOffset = null;
-    Map<SystemStreamPartition, Option<String>> offsets = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(changelogSSP, Option.apply(changelogNewestOffset)));
-
-    Path checkpointPath = Files.createTempDirectory("store-checkpoint-test").toAbsolutePath();
-
-    Map<String, Path> checkpointPaths = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(storeName, checkpointPath));
-    Map<String, SystemStream> storeChangelogs = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(storeName, changelogSS));
-
-    // invoke method
-    tsm.writeChangelogOffsetFiles(checkpointPaths, storeChangelogs, offsets);
-
-    // verify that the offset files were not written to the checkpoint dir
-    assertFalse(Files.exists(new File(checkpointPath.toFile(), StorageManagerUtil.OFFSET_FILE_NAME_LEGACY).toPath()));
-    assertFalse(Files.exists(new File(checkpointPath.toFile(), StorageManagerUtil.OFFSET_FILE_NAME_NEW).toPath()));
-    java.util.Map<SystemStreamPartition, String> fileOffsets = new StorageManagerUtil()
-        .readOffsetFile(checkpointPath.toFile(), ImmutableSet.of(changelogSSP), false);
-    assertEquals(0, fileOffsets.size());
-
-    // verify that delete was called on current store dir offset file
-    verify(mockSMU, times(1)).deleteOffsetFile(eq(mockCurrentStoreDir));
-  }
-
-  /**
-   * This should never happen with CheckpointingTaskStorageManager. #getNewestChangelogSSPOffset must
-   * return a key for every changelog SSP. If the SSP is empty, the value should be none. If it could
-   * not fetch metadata, it should throw an exception instead of skipping the SSP.
-   * If this contract is accidentally broken, ensure that we fail the commit
-   */
-  @Test(expected = SamzaException.class)
-  public void testWriteChangelogOffsetFilesWithNoChangelogOffset() throws IOException {
-    String storeName = "mockStore";
-    ContainerStorageManager csm = mock(ContainerStorageManager.class);
-    StorageEngine mockStore = mock(StorageEngine.class);
-    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of(storeName, mockStore);
-    when(csm.getAllStores(any())).thenReturn(taskStores);
-
-    Partition changelogPartition = new Partition(0);
-    SystemStream changelogSS = new SystemStream("system", "changelog");
-    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSS, changelogPartition);
-    TransactionalStateTaskStorageManager tsm = spy(buildTSM(csm, changelogPartition, new StorageManagerUtil()));
-
-    // no mapping present for changelog newest offset
-    Map<SystemStreamPartition, Option<String>> offsets = ScalaJavaUtil.toScalaMap(ImmutableMap.of());
-
-    Path checkpointPath = Files.createTempDirectory("store-checkpoint-test").toAbsolutePath();
-    Map<String, Path> checkpointPaths = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(storeName, checkpointPath));
-    Map<String, SystemStream> storeChangelogs = ScalaJavaUtil.toScalaMap(
-        ImmutableMap.of(storeName, changelogSS));
-
-    // invoke method
-    tsm.writeChangelogOffsetFiles(checkpointPaths, storeChangelogs, offsets);
-
-    fail("Should have thrown an exception if no changelog offset found for checkpointed store");
-  }
-
-  @Test
-  public void testRemoveOldCheckpointsWhenBaseDirContainsRegularFiles() {
-    TaskName taskName = new TaskName("Partition 0");
-    ContainerStorageManager containerStorageManager = mock(ContainerStorageManager.class);
-    Map<String, SystemStream> changelogSystemStreams = mock(Map.class);
-    SystemAdmins systemAdmins = mock(SystemAdmins.class);
-    File loggedStoreBaseDir = mock(File.class);
-    Partition changelogPartition = new Partition(0);
-    TaskMode taskMode = TaskMode.Active;
-    StorageManagerUtil storageManagerUtil = mock(StorageManagerUtil.class);
-
-    File mockStoreDir = mock(File.class);
-    String mockStoreDirName = "notDirectory";
-
-    when(loggedStoreBaseDir.listFiles()).thenReturn(new File[] {mockStoreDir});
-    when(mockStoreDir.getName()).thenReturn(mockStoreDirName);
-    when(storageManagerUtil.getTaskStoreDir(eq(loggedStoreBaseDir), eq(mockStoreDirName), eq(taskName), eq(taskMode))).thenReturn(mockStoreDir);
-    // null here can happen if listFiles is called on a non-directory
-    when(mockStoreDir.listFiles(any(FileFilter.class))).thenReturn(null);
-
-    TransactionalStateTaskStorageManager tsm = new TransactionalStateTaskStorageManager(taskName, containerStorageManager,
-        changelogSystemStreams, systemAdmins, loggedStoreBaseDir, changelogPartition, taskMode, storageManagerUtil);
-
-    tsm.removeOldCheckpoints(CheckpointId.create());
-  }
-
-  private TransactionalStateTaskStorageManager buildTSM(ContainerStorageManager csm, Partition changelogPartition,
-      StorageManagerUtil smu) {
-    TaskName taskName = new TaskName("Partition 0");
-    Map<String, SystemStream> changelogSystemStreams = mock(Map.class);
-    SystemAdmins systemAdmins = mock(SystemAdmins.class);
-    File loggedStoreBaseDir = mock(File.class);
-    TaskMode taskMode = TaskMode.Active;
-
-    return new TransactionalStateTaskStorageManager(
-        taskName, csm, changelogSystemStreams, systemAdmins,
-        loggedStoreBaseDir, changelogPartition, taskMode, smu);
-  }
-}
\ No newline at end of file
diff --git a/samza-core/src/test/scala/org/apache/samza/storage/TestTaskStorageManager.scala b/samza-kafka/src/test/java/org/apache/samza/storage/TestTaskStorageManager.scala
similarity index 82%
rename from samza-core/src/test/scala/org/apache/samza/storage/TestTaskStorageManager.scala
rename to samza-kafka/src/test/java/org/apache/samza/storage/TestTaskStorageManager.scala
index 957c00c..167bc78 100644
--- a/samza-core/src/test/scala/org/apache/samza/storage/TestTaskStorageManager.scala
+++ b/samza-kafka/src/test/java/org/apache/samza/storage/TestTaskStorageManager.scala
@@ -22,24 +22,27 @@
 import java.io.{File, FileOutputStream, ObjectOutputStream}
 import java.util
 
+import com.google.common.collect.{ImmutableMap, ImmutableSet}
 import org.apache.samza.Partition
+import org.apache.samza.checkpoint.kafka.KafkaStateCheckpointMarker
+import org.apache.samza.checkpoint.{CheckpointId, CheckpointManager, CheckpointV1}
 import org.apache.samza.config._
 import org.apache.samza.container.{SamzaContainerMetrics, TaskInstanceMetrics, TaskName}
 import org.apache.samza.context.{ContainerContext, JobContext}
-import org.apache.samza.job.model.{ContainerModel, TaskMode, TaskModel}
+import org.apache.samza.job.model.{ContainerModel, JobModel, TaskMode, TaskModel}
 import org.apache.samza.serializers.{Serde, StringSerdeFactory}
 import org.apache.samza.storage.StoreProperties.StorePropertiesBuilder
 import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata
 import org.apache.samza.system._
 import org.apache.samza.task.TaskInstanceCollector
-import org.apache.samza.util.{FileUtil, SystemClock}
+import org.apache.samza.util.{Clock, FileUtil, SystemClock}
 import org.junit.Assert._
 import org.junit.runner.RunWith
 import org.junit.runners.Parameterized
 import org.junit.runners.Parameterized.Parameters
 import org.junit.{After, Before, Test}
 import org.mockito.Matchers._
-import org.mockito.{Mockito}
+import org.mockito.Mockito
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
@@ -48,8 +51,6 @@
 import scala.collection.JavaConverters._
 import scala.collection.immutable.HashMap
 import scala.collection.mutable
-import com.google.common.collect.{ImmutableMap, ImmutableSet}
-import org.apache.samza.checkpoint.{Checkpoint, CheckpointManager}
 
 /**
   * This test is parameterized on the offsetFileName and is run for both
@@ -58,7 +59,7 @@
   * @param offsetFileName the name of the offset file.
   */
 @RunWith(value = classOf[Parameterized])
-class TestNonTransactionalStateTaskStorageManager(offsetFileName: String) extends MockitoSugar {
+class TestKafkaNonTransactionalStateTaskBackupManager(offsetFileName: String) extends MockitoSugar {
 
   val store = "store1"
   val loggedStore = "loggedStore1"
@@ -131,44 +132,6 @@
     assertTrue(storeFile.exists())
     assertFalse(offsetFile.exists())
     verify(mockSystemConsumer).register(ssp, "0")
-
-    // Test 2: flush should update the offset file
-    taskManager.flush()
-    assertTrue(offsetFile.exists())
-    validateOffsetFileContents(offsetFile, "kafka.testStream-loggedStore1.0", "50")
-
-    // Test 3: Update sspMetadata before shutdown and verify that offset file is not updated
-    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp)))
-      .thenReturn(ImmutableMap.of(ssp, new SystemStreamPartitionMetadata("0", "100", "101")))
-    taskManager.stop()
-    verify(mockStorageEngine, times(1)).flush() // only called once during Test 2.
-    assertTrue(storeFile.exists())
-    assertTrue(offsetFile.exists())
-    validateOffsetFileContents(offsetFile, "kafka.testStream-loggedStore1.0", "50")
-
-    // Test 4: Initialize again with an updated sspMetadata; Verify that it restores from the correct offset
-    sspMetadata = new SystemStreamPartitionMetadata("0", "150", "151")
-    metadata = new SystemStreamMetadata(getStreamName(loggedStore), new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
-      {
-        put(partition, sspMetadata)
-      }
-    })
-    when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
-    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp)))
-      .thenReturn(ImmutableMap.of(ssp, sspMetadata))
-    when(mockSystemAdmin.getOffsetsAfter(Map(ssp -> "50").asJava)).thenReturn(Map(ssp -> "51").asJava)
-    Mockito.reset(mockSystemConsumer)
-
-    taskManager = new TaskStorageManagerBuilder()
-      .addStore(loggedStore, mockStorageEngine, mockSystemConsumer)
-      .setStreamMetadataCache(mockStreamMetadataCache)
-      .setSystemAdmin("kafka", mockSystemAdmin)
-      .initializeContainerStorageManager()
-      .build
-
-    assertTrue(storeFile.exists())
-    assertTrue(offsetFile.exists())
-    verify(mockSystemConsumer).register(ssp, "51")
   }
 
   /**
@@ -217,7 +180,9 @@
     verify(mockSystemConsumer).register(ssp, "0")
 
     // Test 2: flush should NOT create/update the offset file. Store directory has no files
-    taskManager.flush()
+    val checkpointId = CheckpointId.create()
+    val snapshot = taskManager.snapshot(checkpointId)
+    val stateCheckpointMarkers = taskManager.upload(checkpointId, snapshot)
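+    // flush() is replaced by explicit snapshot and upload phases in the new commit lifecycle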
     assertTrue(storeDirectory.list().isEmpty)
 
     // Test 3: Update sspMetadata before shutdown and verify that offset file is NOT created
@@ -228,7 +193,7 @@
     })
     when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(ss -> metadata))
     when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata))
-    taskManager.stop()
+    taskManager.close()
     assertTrue(storeDirectory.list().isEmpty)
 
     // Test 4: Initialize again with an updated sspMetadata; Verify that it restores from the earliest offset
@@ -368,60 +333,19 @@
       .build
 
     //Invoke test method
-    taskStorageManager.stop()
+    taskStorageManager.close()
 
     //Check conditions
     assertFalse("Offset file doesn't exist!", offsetFile.exists())
   }
 
   /**
-    * Given that the SSPMetadataCache returns metadata, flush should create the offset files.
-    */
-  @Test
-  def testFlushCreatesOffsetFileForLoggedStore() {
-    val partition = new Partition(0)
-
-    val offsetFilePath = new File(storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName, TaskMode.Active) + File.separator + offsetFileName)
-    val anotherOffsetPath = new File(
-      storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, store, taskName, TaskMode.Active) + File.separator + offsetFileName)
-
-    val ssp1 = new SystemStreamPartition("kafka", getStreamName(loggedStore), partition)
-    val ssp2 = new SystemStreamPartition("kafka", getStreamName(store), partition)
-    val sspMetadata = new SystemStreamPartitionMetadata("20", "100", "101")
-
-    val mockSystemAdmin = mock[SystemAdmin]
-    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp1))).thenReturn(ImmutableMap.of(ssp1, sspMetadata))
-    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp2))).thenReturn(ImmutableMap.of(ssp2, sspMetadata))
-
-    //Build TaskStorageManager
-    val taskStorageManager = new TaskStorageManagerBuilder()
-      .addLoggedStore(loggedStore, true)
-      .addStore(store, false)
-      .setSystemAdmin("kafka", mockSystemAdmin)
-      .setStreamMetadataCache(createMockStreamMetadataCache("20", "100", "101"))
-      .setPartition(partition)
-      .initializeContainerStorageManager()
-      .build
-
-    //Invoke test method
-    taskStorageManager.flush()
-
-    //Check conditions
-    assertTrue("Offset file doesn't exist!", offsetFilePath.exists())
-    validateOffsetFileContents(offsetFilePath, "kafka.testStream-loggedStore1.0", "100")
-
-    assertTrue("Offset file got created for a store that is not persisted to the disk!!", !anotherOffsetPath.exists())
-  }
-
-  /**
     * Flush should delete the existing OFFSET file if the changelog partition (for some reason) becomes empty
     */
   @Test
   def testFlushDeletesOffsetFileForLoggedStoreForEmptyPartition() {
     val partition = new Partition(0)
 
-    val offsetFilePath = new File(storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName, TaskMode.Active) + File.separator + offsetFileName)
-
     val ssp = new SystemStreamPartition("kafka", getStreamName(loggedStore), partition)
     val sspMetadata = new SystemStreamPartitionMetadata("0", "100", "101")
     val nullSspMetadata = new SystemStreamPartitionMetadata(null, null, null)
@@ -430,7 +354,7 @@
       .thenReturn(ImmutableMap.of(ssp, sspMetadata))
       .thenReturn(ImmutableMap.of(ssp, nullSspMetadata))
 
-    var metadata = new SystemStreamMetadata(getStreamName(loggedStore), new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
+    val metadata = new SystemStreamMetadata(getStreamName(loggedStore), new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
       {
         put(partition, sspMetadata)
       }
@@ -449,67 +373,15 @@
       .build
 
     //Invoke test method
-    taskStorageManager.flush()
-
-    //Check conditions
-    assertTrue("Offset file doesn't exist!", offsetFilePath.exists())
-    validateOffsetFileContents(offsetFilePath, "kafka.testStream-loggedStore1.0", "100")
+    val checkpointId = CheckpointId.create()
+    var snapshot = taskStorageManager.snapshot(checkpointId)
+    taskStorageManager.upload(checkpointId, snapshot)
 
     //Invoke test method again
-    taskStorageManager.flush()
+    snapshot = taskStorageManager.snapshot(checkpointId)
+    val stateCheckpointMarkers2 = taskStorageManager.upload(checkpointId, snapshot)
 
-    //Check conditions
-    assertFalse("Offset file for null offset exists!", offsetFilePath.exists())
-  }
-
-  @Test
-  def testFlushOverwritesOffsetFileForLoggedStore() {
-    val partition = new Partition(0)
-    val ssp = new SystemStreamPartition("kafka", getStreamName(loggedStore), partition)
-
-    val offsetFilePath = new File(storageManagerUtil.getTaskStoreDir(TaskStorageManagerBuilder.defaultLoggedStoreBaseDir, loggedStore, taskName, TaskMode.Active) + File.separator + offsetFileName)
-    fileUtil.writeWithChecksum(offsetFilePath, "100")
-
-    val sspMetadata = new SystemStreamPartitionMetadata("20", "139", "140")
-    val mockSystemAdmin = mock[SystemAdmin]
-    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp))).thenReturn(ImmutableMap.of(ssp, sspMetadata))
-
-
-    var metadata = new SystemStreamMetadata(getStreamName(loggedStore), new java.util.HashMap[Partition, SystemStreamPartitionMetadata]() {
-      {
-        put(partition, sspMetadata)
-      }
-    })
-
-    val mockStreamMetadataCache = mock[StreamMetadataCache]
-    when(mockStreamMetadataCache.getStreamMetadata(any(), any())).thenReturn(Map(new SystemStream("kafka", getStreamName(loggedStore)) -> metadata))
-
-    //Build TaskStorageManager
-    val taskStorageManager = new TaskStorageManagerBuilder()
-      .addLoggedStore(loggedStore, true)
-      .setSystemAdmin("kafka", mockSystemAdmin)
-      .setPartition(partition)
-      .setStreamMetadataCache(mockStreamMetadataCache)
-      .initializeContainerStorageManager()
-      .build
-
-    //Invoke test method
-    taskStorageManager.flush()
-
-    //Check conditions
-    assertTrue("Offset file doesn't exist!", offsetFilePath.exists())
-    validateOffsetFileContents(offsetFilePath, "kafka.testStream-loggedStore1.0", "139")
-
-    // Flush again
-    when(mockSystemAdmin.getSSPMetadata(ImmutableSet.of(ssp)))
-      .thenReturn(ImmutableMap.of(ssp, new SystemStreamPartitionMetadata("20", "193", "194")))
-
-    //Invoke test method
-    taskStorageManager.flush()
-
-    //Check conditions
-    assertTrue("Offset file doesn't exist!", offsetFilePath.exists())
-    validateOffsetFileContents(offsetFilePath, "kafka.testStream-loggedStore1.0", "193")
+    assertNull(KafkaStateCheckpointMarker.deserialize(stateCheckpointMarkers2.get.get(loggedStore)).getChangelogOffset)
   }
 
   /**
@@ -517,7 +389,7 @@
     * The legacy offset file only contains the offset as a string, while the new offset file contains a map of
     * ssp to offset in json format.
     * The name of the two offset files are given in {@link StorageManagerUtil.OFFSET_FILE_NAME_NEW} and
-    * {@link StorageManagerUtil.OFFSET_FILE_LEGACY}.
+    * {@link StorageManagerUtil.OFFSET_FILE_NAME_LEGACY}.
     */
   private def validateOffsetFileContents(offsetFile: File, ssp: String, offset: String): Unit = {
 
@@ -549,7 +421,7 @@
       .build
 
     //Invoke test method
-    taskStorageManager.stop()
+    taskStorageManager.close()
 
     //Check conditions
     assertTrue("Offset file should not exist!", !offsetFilePath.exists())
@@ -777,7 +649,7 @@
   }
 }
 
-object TestNonTransactionalStateTaskStorageManager {
+object TestKafkaNonTransactionalStateTaskBackupManager {
 
   @Parameters def parameters: util.Collection[Array[String]] = {
     val offsetFileNames = new util.ArrayList[Array[String]]()
@@ -864,7 +736,10 @@
     var containerModel = new ContainerModel("container", tasks.asJava)
 
     val mockSystemAdmins = Mockito.mock(classOf[SystemAdmins])
-    Mockito.when(mockSystemAdmins.getSystemAdmin(org.mockito.Matchers.eq("kafka"))).thenReturn(systemAdminsMap.get("kafka").get)
+    Mockito.when(mockSystemAdmins.getSystemAdmin(org.mockito.Matchers.eq("kafka")))
+      .thenReturn(systemAdminsMap.get("kafka").get)
+    Mockito.when(mockSystemAdmins.getSystemAdmins)
+      .thenReturn(systemAdminsMap.asJava)
 
     var mockStorageEngineFactory : StorageEngineFactory[AnyRef, AnyRef] = Mockito.mock(classOf[StorageEngineFactory[AnyRef, AnyRef]])
 
@@ -900,15 +775,20 @@
 
     val mockCheckpointManager = Mockito.mock(classOf[CheckpointManager])
     when(mockCheckpointManager.readLastCheckpoint(any(classOf[TaskName])))
-      .thenReturn(new Checkpoint(new util.HashMap[SystemStreamPartition, String]()))
+      .thenReturn(new CheckpointV1(new util.HashMap[SystemStreamPartition, String]()))
 
-    val mockSSPMetadataCache = Mockito.mock(classOf[SSPMetadataCache])
+    val mockContainerContext = Mockito.mock(classOf[ContainerContext])
+    when(mockContainerContext.getContainerModel).thenReturn(containerModel)
+
+    val mockJobContext = Mockito.mock(classOf[JobContext])
+    val mockJobModel = Mockito.mock(classOf[JobModel])
+    when(mockJobContext.getJobModel).thenReturn(mockJobModel)
+    when(mockJobModel.getMaxChangeLogStreamPartitions).thenReturn(1)
 
     containerStorageManager = new ContainerStorageManager(
       mockCheckpointManager,
       containerModel,
       streamMetadataCache,
-      mockSSPMetadataCache,
       mockSystemAdmins,
       changeLogSystemStreams.asJava,
       Map[String, util.Set[SystemStream]]().asJava,
@@ -918,12 +798,12 @@
       config,
       new HashMap[TaskName, TaskInstanceMetrics]().asJava,
       Mockito.mock(classOf[SamzaContainerMetrics]),
-      Mockito.mock(classOf[JobContext]),
-      Mockito.mock(classOf[ContainerContext]),
+      mockJobContext,
+      mockContainerContext,
+      new mockKafkaChangelogBackendManager(changeLogSystemStreams),
       new HashMap[TaskName, TaskInstanceCollector].asJava,
       loggedStoreBaseDir,
       TaskStorageManagerBuilder.defaultStoreBaseDir,
-      1,
       null,
       new SystemClock)
     this
@@ -931,18 +811,16 @@
 
 
 
-  def build: NonTransactionalStateTaskStorageManager = {
+  def build: KafkaNonTransactionalStateTaskBackupManager = {
 
     if (containerStorageManager != null) {
       containerStorageManager.start()
     }
 
-    new NonTransactionalStateTaskStorageManager(
+    new KafkaNonTransactionalStateTaskBackupManager(
       taskName = taskName,
-      containerStorageManager = containerStorageManager,
-      storeChangelogs = changeLogSystemStreams,
+      storeChangelogs = changeLogSystemStreams.asJava,
       systemAdmins = buildSystemAdmins(systemAdminsMap),
-      loggedStoreBaseDir = loggedStoreBaseDir,
       partition = partition
     )
   }
@@ -954,4 +832,12 @@
     }
     systemAdmins
   }
+
+  private class mockKafkaChangelogBackendManager(storeSystemStream: Map[String, SystemStream])
+    extends KafkaChangelogStateBackendFactory {
+    override def filterStandbySystemStreams(changelogSystemStreams: util.Map[String, SystemStream], containerModel: ContainerModel):
+    util.Map[String, SystemStream] = storeSystemStream.asJava
+
+    override def getStreamCache(admins: SystemAdmins, clock: Clock): StreamMetadataCache = streamMetadataCache
+  }
 }
diff --git a/samza-kafka/src/test/java/org/apache/samza/storage/TestTransactionalStateTaskBackupManager.java b/samza-kafka/src/test/java/org/apache/samza/storage/TestTransactionalStateTaskBackupManager.java
new file mode 100644
index 0000000..8bb93f7
--- /dev/null
+++ b/samza-kafka/src/test/java/org/apache/samza/storage/TestTransactionalStateTaskBackupManager.java
@@ -0,0 +1,276 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import org.apache.samza.checkpoint.kafka.KafkaStateCheckpointMarker;
+
+import java.util.HashMap;
+import java.util.concurrent.ForkJoinPool;
+import org.apache.samza.Partition;
+import org.apache.samza.SamzaException;
+import org.apache.samza.checkpoint.CheckpointId;
+import org.apache.samza.container.TaskInstanceMetrics;
+import org.apache.samza.container.TaskName;
+import org.apache.samza.metrics.Timer;
+import org.apache.samza.system.SystemAdmin;
+import org.apache.samza.system.SystemAdmins;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamMetadata.SystemStreamPartitionMetadata;
+import org.apache.samza.system.SystemStreamPartition;
+import org.junit.Test;
+import org.mockito.InOrder;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+
+public class TestTransactionalStateTaskBackupManager {
+  @Test
+  public void testFlushOrder() {
+    ContainerStorageManager csm = mock(ContainerStorageManager.class);
+    StorageEngine mockStore = mock(StorageEngine.class);
+    java.util.Map<String, StorageEngine> taskStores = ImmutableMap.of("mockStore", mockStore);
+    when(csm.getAllStores(any())).thenReturn(taskStores);
+    when(mockStore.getStoreProperties()).thenReturn(new StoreProperties
+        .StorePropertiesBuilder().setPersistedToDisk(true).setLoggedStore(true).build());
+    TaskInstanceMetrics metrics = mock(TaskInstanceMetrics.class);
+    Timer checkpointTimer = mock(Timer.class);
+    when(metrics.storeCheckpointNs()).thenReturn(checkpointTimer);
+
+    KafkaTransactionalStateTaskBackupManager tsm = spy(buildTSM(csm, mock(Partition.class), new StorageManagerUtil()));
+    TaskStorageCommitManager commitManager = new TaskStorageCommitManager(new TaskName("task"),
+        ImmutableMap.of("kafka", tsm), csm, null, null, null, null,
+        ForkJoinPool.commonPool(), new StorageManagerUtil(), null, metrics);
+    // stub actual method call
+    doReturn(mock(java.util.Map.class)).when(tsm).getNewestChangelogSSPOffsets(any(), any(), any(), any());
+
+    // init the commit manager and take a snapshot, which should flush the stores
+    commitManager.init();
+    commitManager.snapshot(CheckpointId.create());
+
+    // ensure that stores are flushed before we get newest changelog offsets
+    InOrder inOrder = inOrder(mockStore, tsm);
+    inOrder.verify(mockStore).flush();
+    inOrder.verify(tsm).getNewestChangelogSSPOffsets(any(), any(), any(), any());
+  }
+
+  @Test
+  public void testGetNewestOffsetsReturnsCorrectOffset() {
+    ContainerStorageManager csm = mock(ContainerStorageManager.class);
+    KafkaTransactionalStateTaskBackupManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
+
+    TaskName taskName = mock(TaskName.class);
+    String changelogSystemName = "systemName";
+    String storeName = "storeName";
+    String changelogStreamName = "changelogName";
+    String newestChangelogSSPOffset = "1";
+    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
+    Partition changelogPartition = new Partition(0);
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+
+    java.util.Map<String, SystemStream> storeChangelogs = new HashMap<>();
+    storeChangelogs.put(storeName, changelogSystemStream);
+
+    SystemAdmins systemAdmins = mock(SystemAdmins.class);
+    SystemAdmin systemAdmin = mock(SystemAdmin.class);
+    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
+
+    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
+    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenReturn(systemAdmin);
+    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(ImmutableMap.of(changelogSSP, metadata));
+
+    // invoke the method
+    java.util.Map<String, String> stateCheckpointMarkerMap =
+        tsm.getNewestChangelogSSPOffsets(
+            taskName, storeChangelogs, changelogPartition, systemAdmins);
+
+    // verify results
+    assertEquals(1, stateCheckpointMarkerMap.size());
+    KafkaStateCheckpointMarker kscm = KafkaStateCheckpointMarker.deserialize(stateCheckpointMarkerMap.get(storeName));
+    assertEquals(newestChangelogSSPOffset, kscm.getChangelogOffset());
+    assertEquals(changelogSSP, kscm.getChangelogSSP());
+  }
+
+  @Test
+  public void testGetNewestOffsetsReturnsNoneForEmptyTopic() {
+    // empty topic == null newest offset
+    ContainerStorageManager csm = mock(ContainerStorageManager.class);
+    KafkaTransactionalStateTaskBackupManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
+
+    TaskName taskName = mock(TaskName.class);
+    String changelogSystemName = "systemName";
+    String storeName = "storeName";
+    String changelogStreamName = "changelogName";
+    String newestChangelogSSPOffset = null;
+    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
+    Partition changelogPartition = new Partition(0);
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+
+    java.util.Map<String, SystemStream> storeChangelogs = new HashMap<String, SystemStream>();
+    storeChangelogs.put(storeName, changelogSystemStream);
+
+    SystemAdmins systemAdmins = mock(SystemAdmins.class);
+    SystemAdmin systemAdmin = mock(SystemAdmin.class);
+    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
+
+    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
+    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenReturn(systemAdmin);
+    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(ImmutableMap.of(changelogSSP, metadata));
+
+    // invoke the method
+    java.util.Map<String, String> stateCheckpointMarkerMap =
+        tsm.getNewestChangelogSSPOffsets(
+            taskName, storeChangelogs, changelogPartition, systemAdmins);
+
+    // verify results
+    assertEquals(1, stateCheckpointMarkerMap.size());
+    KafkaStateCheckpointMarker kscm = KafkaStateCheckpointMarker.deserialize(stateCheckpointMarkerMap.get(storeName));
+    assertEquals(changelogSSP, kscm.getChangelogSSP());
+    assertNull(kscm.getChangelogOffset());
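+    // a null changelog offset in the marker indicates an empty changelog topic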
+  }
+
+  @Test(expected = SamzaException.class)
+  public void testGetNewestOffsetsThrowsIfNullMetadata() {
+    // admin returns null metadata for the changelog SSPs
+    ContainerStorageManager csm = mock(ContainerStorageManager.class);
+    KafkaTransactionalStateTaskBackupManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
+
+    TaskName taskName = mock(TaskName.class);
+    String changelogSystemName = "systemName";
+    String storeName = "storeName";
+    String changelogStreamName = "changelogName";
+    String newestChangelogSSPOffset = null;
+    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
+    Partition changelogPartition = new Partition(0);
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+
+    java.util.Map<String, SystemStream> storeChangelogs = new HashMap<>();
+    storeChangelogs.put(storeName, changelogSystemStream);
+
+    SystemAdmins systemAdmins = mock(SystemAdmins.class);
+    SystemAdmin systemAdmin = mock(SystemAdmin.class);
+    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
+
+    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
+    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenReturn(systemAdmin);
+    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(null);
+
+    // invoke the method
+    java.util.Map<String, String> offsets =
+        tsm.getNewestChangelogSSPOffsets(
+            taskName, storeChangelogs, changelogPartition, systemAdmins);
+
+    // verify results
+    fail("Should have thrown an exception if admin didn't return any metadata");
+  }
+
+  @Test(expected = SamzaException.class)
+  public void testGetNewestOffsetsThrowsIfNullSSPMetadata() {
+    // admin returns a metadata map with a null entry for the changelog SSP
+    ContainerStorageManager csm = mock(ContainerStorageManager.class);
+    KafkaTransactionalStateTaskBackupManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
+
+    TaskName taskName = mock(TaskName.class);
+    String changelogSystemName = "systemName";
+    String storeName = "storeName";
+    String changelogStreamName = "changelogName";
+    String newestChangelogSSPOffset = null;
+    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
+    Partition changelogPartition = new Partition(0);
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+
+    java.util.Map<String, SystemStream> storeChangelogs = new HashMap<>();
+    storeChangelogs.put(storeName, changelogSystemStream);
+
+    SystemAdmins systemAdmins = mock(SystemAdmins.class);
+    SystemAdmin systemAdmin = mock(SystemAdmin.class);
+    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
+
+    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
+    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenReturn(systemAdmin);
+    java.util.Map<SystemStreamPartition, SystemStreamPartitionMetadata> metadataMap = new HashMap<>();
+    metadataMap.put(changelogSSP, null);
+    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(metadataMap);
+
+    // invoke the method
+    java.util.Map<String, String> offsets =
+        tsm.getNewestChangelogSSPOffsets(
+            taskName, storeChangelogs, changelogPartition, systemAdmins);
+
+    // verify results
+    fail("Should have thrown an exception if admin returned null metadata for changelog SSP");
+  }
+
+  @Test(expected = SamzaException.class)
+  public void testGetNewestOffsetsThrowsIfErrorGettingMetadata() {
+    // admin lookup throws an error while getting metadata
+    ContainerStorageManager csm = mock(ContainerStorageManager.class);
+    KafkaTransactionalStateTaskBackupManager tsm = buildTSM(csm, mock(Partition.class), new StorageManagerUtil());
+
+    TaskName taskName = mock(TaskName.class);
+    String changelogSystemName = "systemName";
+    String storeName = "storeName";
+    String changelogStreamName = "changelogName";
+    String newestChangelogSSPOffset = null;
+    SystemStream changelogSystemStream = new SystemStream(changelogSystemName, changelogStreamName);
+    Partition changelogPartition = new Partition(0);
+    SystemStreamPartition changelogSSP = new SystemStreamPartition(changelogSystemStream, changelogPartition);
+
+    java.util.Map<String, SystemStream> storeChangelogs = new HashMap<>();
+    storeChangelogs.put(storeName, changelogSystemStream);
+
+    SystemAdmins systemAdmins = mock(SystemAdmins.class);
+    SystemAdmin systemAdmin = mock(SystemAdmin.class);
+    SystemStreamPartitionMetadata metadata = mock(SystemStreamPartitionMetadata.class);
+
+    when(metadata.getNewestOffset()).thenReturn(newestChangelogSSPOffset);
+    when(systemAdmins.getSystemAdmin(changelogSystemName)).thenThrow(new SamzaException("Error getting metadata"));
+    when(systemAdmin.getSSPMetadata(eq(ImmutableSet.of(changelogSSP)))).thenReturn(null);
+
+    // invoke the method
+    java.util.Map<String, String> offsets =
+        tsm.getNewestChangelogSSPOffsets(
+            taskName, storeChangelogs, changelogPartition, systemAdmins);
+
+    // verify results
+    fail("Should have thrown an exception if admin had an error getting metadata");
+  }
+
+  private KafkaTransactionalStateTaskBackupManager buildTSM(ContainerStorageManager csm, Partition changelogPartition,
+      StorageManagerUtil smu) {
+    TaskName taskName = new TaskName("Partition 0");
+    java.util.Map<String, SystemStream> changelogSystemStreams = mock(java.util.Map.class);
+    SystemAdmins systemAdmins = mock(SystemAdmins.class);
+
+    return new KafkaTransactionalStateTaskBackupManager(
+        taskName, changelogSystemStreams, systemAdmins, changelogPartition);
+  }
+}
\ No newline at end of file
diff --git a/samza-test/src/test/java/org/apache/samza/storage/MyStatefulApplication.java b/samza-test/src/test/java/org/apache/samza/storage/MyStatefulApplication.java
new file mode 100644
index 0000000..99c54b0
--- /dev/null
+++ b/samza-test/src/test/java/org/apache/samza/storage/MyStatefulApplication.java
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.storage;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.apache.samza.application.TaskApplication;
+import org.apache.samza.application.descriptors.TaskApplicationDescriptor;
+import org.apache.samza.context.Context;
+import org.apache.samza.operators.KV;
+import org.apache.samza.serializers.KVSerde;
+import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.storage.kv.KeyValueIterator;
+import org.apache.samza.storage.kv.KeyValueStore;
+import org.apache.samza.storage.kv.descriptors.RocksDbTableDescriptor;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.kafka.descriptors.KafkaInputDescriptor;
+import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor;
+import org.apache.samza.task.InitableTask;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.StreamTask;
+import org.apache.samza.task.StreamTaskFactory;
+import org.apache.samza.task.TaskCoordinator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Stateful TaskApplication used for testing task store backup and restore behaviour.
+ * {@link #resetTestState()} should be invoked in the @Before method of any test that uses this class.
+ *
+ * Input Message format:
+ * "num" -> put (key = num, value = num) and flush
+ * "-num" -> delete (key = num) and flush
+ * ":msg" -> act on msg (including flush) but no commit (may be num, shutdown or crash_once)
+ * "shutdown" -> always shutdown the job
+ * "crash_once" -> shut down the job the first time but ignore on subsequent run
+ */
+public class MyStatefulApplication implements TaskApplication {
+
+  public static final Logger LOG = LoggerFactory.getLogger(MyStatefulApplication.class);
+
+  private static Map<String, List<String>> initialStoreContents = new HashMap<>();
+  private static boolean crashedOnce = false;
+  private final String inputSystem;
+  private final String inputTopic;
+  private final Map<String, String> storeToChangelog;
+
+  public MyStatefulApplication(String inputSystem, String inputTopic, Map<String, String> storeToChangelog) {
+    this.inputSystem = inputSystem;
+    this.inputTopic = inputTopic;
+    this.storeToChangelog = storeToChangelog;
+  }
+
+  @Override
+  public void describe(TaskApplicationDescriptor appDescriptor) {
+    KafkaSystemDescriptor ksd = new KafkaSystemDescriptor(inputSystem);
+    KVSerde<String, String> serde = KVSerde.of(new StringSerde(), new StringSerde());
+
+    KafkaInputDescriptor<KV<String, String>> isd = ksd.getInputDescriptor(inputTopic, serde);
+
+    TaskApplicationDescriptor desc = appDescriptor
+        .withInputStream(isd)
+        .withTaskFactory((StreamTaskFactory) () -> new MyTask(storeToChangelog.keySet()));
+
+    storeToChangelog.forEach((storeName, changelogTopic) -> {
+      RocksDbTableDescriptor<String, String> td = new RocksDbTableDescriptor<>(storeName, serde)
+          .withChangelogStream(changelogTopic)
+          .withChangelogReplicationFactor(1);
+      desc.withTable(td);
+    });
+  }
+
+  public static void resetTestState() {
+    initialStoreContents = new HashMap<>();
+    crashedOnce = false;
+  }
+
+  public static Map<String, List<String>> getInitialStoreContents() {
+    return initialStoreContents;
+  }
+
+  static class MyTask implements StreamTask, InitableTask {
+    private final Set<KeyValueStore<String, String>> stores = new HashSet<>();
+    private final Set<String> storeNames;
+
+    MyTask(Set<String> storeNames) {
+      this.storeNames = storeNames;
+    }
+
+    @Override
+    public void init(Context context) {
+      storeNames.forEach(storeName -> {
+        KeyValueStore<String, String> store = (KeyValueStore<String, String>) context.getTaskContext().getStore(storeName);
+        stores.add(store);
+        KeyValueIterator<String, String> storeEntries = store.all();
+        List<String> storeInitialChangelog = new ArrayList<>();
+        while (storeEntries.hasNext()) {
+          storeInitialChangelog.add(storeEntries.next().getValue());
+        }
+        initialStoreContents.put(storeName, storeInitialChangelog);
+        storeEntries.close();
+      });
+    }
+
+    @Override
+    public void process(IncomingMessageEnvelope envelope,
+        MessageCollector collector, TaskCoordinator coordinator) {
+      String key = (String) envelope.getKey();
+      LOG.info("Received key: {}", key);
+
+      if (key.endsWith("crash_once")) {  // endsWith allows :crash_once and crash_once
+        if (!crashedOnce) {
+          crashedOnce = true;
+          coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+        } else {
+          return;
+        }
+      } else if (key.endsWith("shutdown")) {
+        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+      } else if (key.startsWith("-")) {
+        stores.forEach(store -> store.delete(key.substring(1)));
+      } else if (key.startsWith(":")) {
+        // write the message and flush, but don't invoke commit later
+        String msg = key.substring(1);
+        stores.forEach(store -> store.put(msg, msg));
+      } else {
+        stores.forEach(store -> store.put(key, key));
+      }
+      stores.forEach(KeyValueStore::flush);
+
+      if (!key.startsWith(":")) {
+        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
+      }
+    }
+  }
+}
diff --git a/samza-test/src/test/java/org/apache/samza/storage/kv/TransactionalStateIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/storage/kv/TransactionalStateIntegrationTest.java
index e9b12a1..6f9427c 100644
--- a/samza-test/src/test/java/org/apache/samza/storage/kv/TransactionalStateIntegrationTest.java
+++ b/samza-test/src/test/java/org/apache/samza/storage/kv/TransactionalStateIntegrationTest.java
@@ -22,7 +22,6 @@
 import com.google.common.collect.ImmutableList;
 
 import java.io.File;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
@@ -31,25 +30,11 @@
 import java.util.Map;
 import java.util.stream.Collectors;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
-import org.apache.samza.application.TaskApplication;
-import org.apache.samza.application.descriptors.TaskApplicationDescriptor;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.JobCoordinatorConfig;
 import org.apache.samza.config.KafkaConfig;
 import org.apache.samza.config.TaskConfig;
-import org.apache.samza.context.Context;
-import org.apache.samza.operators.KV;
-import org.apache.samza.serializers.KVSerde;
-import org.apache.samza.serializers.StringSerde;
-import org.apache.samza.storage.kv.descriptors.RocksDbTableDescriptor;
-import org.apache.samza.system.IncomingMessageEnvelope;
-import org.apache.samza.system.kafka.descriptors.KafkaInputDescriptor;
-import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor;
-import org.apache.samza.task.InitableTask;
-import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.StreamTask;
-import org.apache.samza.task.StreamTaskFactory;
-import org.apache.samza.task.TaskCoordinator;
+import org.apache.samza.storage.MyStatefulApplication;
 import org.apache.samza.test.framework.StreamApplicationIntegrationTestHarness;
 import org.apache.samza.util.FileUtil;
 import org.junit.Assert;
@@ -60,14 +45,7 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-/**
- * Message format:
- * "num" -> put (key = num, value = num) and flush
- * "-num" -> delete (key = num) and flush
- * ":msg" -> act on msg (including flush) but no commit (may be num, shutdown or crash_once)
- * "shutdown" -> always shutdown the job
- * "crash_once" -> shut down the job the first time but ignore on subsequent run
- */
+
 @RunWith(value = Parameterized.class)
 public class TransactionalStateIntegrationTest extends StreamApplicationIntegrationTestHarness {
   @Parameterized.Parameters(name = "hostAffinity={0}")
@@ -95,9 +73,6 @@
       put(TaskConfig.COMMIT_MAX_DELAY_MS, "0"); // Ensure no commits are skipped due to in progress commits
     } };
 
-  private static List<String> actualInitialStoreContents = new ArrayList<>();
-  private static boolean crashedOnce = false;
-
   private final boolean hostAffinity;
 
   public TransactionalStateIntegrationTest(boolean hostAffinity) {
@@ -109,43 +84,43 @@
   public void setUp() {
     super.setUp();
     // reset static state shared with task between each parameterized iteration
-    crashedOnce = false;
-    actualInitialStoreContents = new ArrayList<>();
+    MyStatefulApplication.resetTestState();
     new FileUtil().rm(new File(LOGGED_STORE_BASE_DIR)); // always clear local store on startup
   }
 
   @Test
   public void testStopAndRestart() {
     List<String> inputMessagesOnInitialRun = Arrays.asList("1", "2", "3", "2", "97", "-97", ":98", ":99", ":crash_once");
+    // null is the delete tombstone for "-97"; "98" and "99" are flushed but uncommitted
     List<String> expectedChangelogMessagesOnInitialRun = Arrays.asList("1", "2", "3", "2", "97", null, "98", "99");
     initialRun(inputMessagesOnInitialRun, expectedChangelogMessagesOnInitialRun);
 
-    // first two are reverts for uncommitted messages from last run
-    List<String> expectedChangelogMessagesOnSecondRun =
+    // first two are reverts for uncommitted messages from last run for keys 98 and 99
+    List<String> expectedChangelogMessagesAfterSecondRun =
         Arrays.asList(null, null, "98", "99", "4", "5", "5");
     List<String> expectedInitialStoreContentsOnSecondRun = Arrays.asList("1", "2", "3");
     secondRun(CHANGELOG_TOPIC,
-        expectedChangelogMessagesOnSecondRun, expectedInitialStoreContentsOnSecondRun);
+        expectedChangelogMessagesAfterSecondRun, expectedInitialStoreContentsOnSecondRun, CONFIGS);
   }
 
   @Test
   public void testWithEmptyChangelogFromInitialRun() {
     // expected changelog messages will always match since we'll read 0 messages
     initialRun(ImmutableList.of("crash_once"), Collections.emptyList());
-    secondRun(CHANGELOG_TOPIC, ImmutableList.of("4", "5", "5"), Collections.emptyList());
+    secondRun(CHANGELOG_TOPIC, ImmutableList.of("4", "5", "5"), Collections.emptyList(), CONFIGS);
   }
 
   @Test
   public void testWithNewChangelogAfterInitialRun() {
     List<String> inputMessagesOnInitialRun = Arrays.asList("1", "2", "3", "2", "97", "-97", ":98", ":99", ":crash_once");
-    List<String> expectedChangelogMessagesOnInitialRun = Arrays.asList("1", "2", "3", "2", "97", null, "98", "99");
-    initialRun(inputMessagesOnInitialRun, expectedChangelogMessagesOnInitialRun);
+    List<String> expectedChangelogMessagesAfterInitialRun = Arrays.asList("1", "2", "3", "2", "97", null, "98", "99");
+    initialRun(inputMessagesOnInitialRun, expectedChangelogMessagesAfterInitialRun);
 
     // admin client delete topic doesn't seem to work, times out up to 60 seconds.
     // simulate delete topic by changing the changelog topic instead.
     String newChangelogTopic = "changelog2";
     LOG.info("Changing changelog topic to: {}", newChangelogTopic);
-    secondRun(newChangelogTopic, ImmutableList.of("98", "99", "4", "5", "5"), Collections.emptyList());
+    secondRun(newChangelogTopic, ImmutableList.of("98", "99", "4", "5", "5"), Collections.emptyList(), CONFIGS);
   }
 
   private void initialRun(List<String> inputMessages, List<String> expectedChangelogMessages) {
@@ -162,7 +137,12 @@
     }
 
     // run the application
-    RunApplicationContext context = runApplication(new MyApplication(CHANGELOG_TOPIC), "myApp", CONFIGS);
+    RunApplicationContext context = runApplication(
+        new MyStatefulApplication(INPUT_SYSTEM, INPUT_TOPIC, Collections.singletonMap(STORE_NAME, CHANGELOG_TOPIC)),
+        "myApp", CONFIGS);
+
+    // wait for the application to finish
+    context.getRunner().waitForFinish();
 
     // consume and verify the changelog messages
     if (expectedChangelogMessages.size() > 0) {
@@ -172,13 +152,11 @@
       Assert.assertEquals(expectedChangelogMessages, changelogMessages);
     }
 
-    // wait for the application to finish
-    context.getRunner().waitForFinish();
     LOG.info("Finished initial run");
   }
 
   private void secondRun(String changelogTopic, List<String> expectedChangelogMessages,
-      List<String> expectedInitialStoreContents) {
+      List<String> expectedInitialStoreContents, Map<String, String> overriddenConfigs) {
     // clear the local store directory
     if (!hostAffinity) {
       new FileUtil().rm(new File(LOGGED_STORE_BASE_DIR));
@@ -190,7 +168,9 @@
     inputMessages.forEach(m -> produceMessage(INPUT_TOPIC, 0, m, m));
 
     // run the application
-    RunApplicationContext context = runApplication(new MyApplication(changelogTopic), "myApp", CONFIGS);
+    RunApplicationContext context = runApplication(
+        new MyStatefulApplication(INPUT_SYSTEM, INPUT_TOPIC, Collections.singletonMap(STORE_NAME, changelogTopic)),
+        "myApp", overriddenConfigs);
 
     // wait for the application to finish
     context.getRunner().waitForFinish();
@@ -202,76 +182,6 @@
     Assert.assertEquals(expectedChangelogMessages, changelogMessages);
 
     // verify the store contents during startup (this is after changelog verification to ensure init has completed)
-    Assert.assertEquals(expectedInitialStoreContents, actualInitialStoreContents);
-  }
-
-  static class MyApplication implements TaskApplication {
-    private final String changelogTopic;
-
-    public MyApplication(String changelogTopic) {
-      this.changelogTopic = changelogTopic;
-    }
-
-    @Override
-    public void describe(TaskApplicationDescriptor appDescriptor) {
-      KafkaSystemDescriptor ksd = new KafkaSystemDescriptor(INPUT_SYSTEM);
-      KVSerde<String, String> serde = KVSerde.of(new StringSerde(), new StringSerde());
-
-      KafkaInputDescriptor<KV<String, String>> isd = ksd.getInputDescriptor(INPUT_TOPIC, serde);
-
-      RocksDbTableDescriptor<String, String> td = new RocksDbTableDescriptor<>(STORE_NAME, serde)
-          .withChangelogStream(changelogTopic)
-          .withChangelogReplicationFactor(1);
-
-      appDescriptor
-          .withInputStream(isd)
-          .withTaskFactory((StreamTaskFactory) () -> new MyTask())
-          .withTable(td);
-    }
-  }
-
-  static class MyTask implements StreamTask, InitableTask {
-    private KeyValueStore<String, String> store;
-
-    @Override
-    public void init(Context context) {
-      this.store = (KeyValueStore<String, String>) context.getTaskContext().getStore(STORE_NAME);
-      KeyValueIterator<String, String> storeEntries = store.all();
-      while (storeEntries.hasNext()) {
-        actualInitialStoreContents.add(storeEntries.next().getValue());
-      }
-      storeEntries.close();
-    }
-
-    @Override
-    public void process(IncomingMessageEnvelope envelope,
-        MessageCollector collector, TaskCoordinator coordinator) {
-      String key = (String) envelope.getKey();
-      LOG.info("Received key: {}", key);
-
-      if (key.endsWith("crash_once")) {  // endsWith allows :crash_once and crash_once
-        if (!crashedOnce) {
-          crashedOnce = true;
-          coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
-        } else {
-          return;
-        }
-      } else if (key.endsWith("shutdown")) {
-        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
-      } else if (key.startsWith("-")) {
-        store.delete(key.substring(1));
-      } else if (key.startsWith(":")) {
-        // write the message and flush, but don't invoke commit later
-        String msg = key.substring(1);
-        store.put(msg, msg);
-      } else {
-        store.put(key, key);
-      }
-      store.flush();
-
-      if (!key.startsWith(":")) {
-        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-      }
-    }
+    Assert.assertEquals(expectedInitialStoreContents, MyStatefulApplication.getInitialStoreContents().get(STORE_NAME));
   }
 }
\ No newline at end of file
diff --git a/samza-test/src/test/java/org/apache/samza/storage/kv/TransactionalStateMultiStoreIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/storage/kv/TransactionalStateMultiStoreIntegrationTest.java
index d50d6bf..e0fbaca 100644
--- a/samza-test/src/test/java/org/apache/samza/storage/kv/TransactionalStateMultiStoreIntegrationTest.java
+++ b/samza-test/src/test/java/org/apache/samza/storage/kv/TransactionalStateMultiStoreIntegrationTest.java
@@ -21,8 +21,8 @@
 
 import com.google.common.collect.ImmutableList;
 
+import com.google.common.collect.ImmutableMap;
 import java.io.File;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
@@ -31,25 +31,12 @@
 import java.util.Map;
 import java.util.stream.Collectors;
 import org.apache.kafka.clients.consumer.ConsumerRecord;
-import org.apache.samza.application.TaskApplication;
-import org.apache.samza.application.descriptors.TaskApplicationDescriptor;
+import org.apache.samza.application.SamzaApplication;
 import org.apache.samza.config.JobConfig;
 import org.apache.samza.config.JobCoordinatorConfig;
 import org.apache.samza.config.KafkaConfig;
 import org.apache.samza.config.TaskConfig;
-import org.apache.samza.context.Context;
-import org.apache.samza.operators.KV;
-import org.apache.samza.serializers.KVSerde;
-import org.apache.samza.serializers.StringSerde;
-import org.apache.samza.storage.kv.descriptors.RocksDbTableDescriptor;
-import org.apache.samza.system.IncomingMessageEnvelope;
-import org.apache.samza.system.kafka.descriptors.KafkaInputDescriptor;
-import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor;
-import org.apache.samza.task.InitableTask;
-import org.apache.samza.task.MessageCollector;
-import org.apache.samza.task.StreamTask;
-import org.apache.samza.task.StreamTaskFactory;
-import org.apache.samza.task.TaskCoordinator;
+import org.apache.samza.storage.MyStatefulApplication;
 import org.apache.samza.test.framework.StreamApplicationIntegrationTestHarness;
 import org.apache.samza.util.FileUtil;
 import org.junit.Assert;
@@ -80,6 +67,7 @@
   private static final String STORE_2_NAME = "store2";
   private static final String STORE_1_CHANGELOG = "changelog1";
   private static final String STORE_2_CHANGELOG = "changelog2";
+  private static final String APP_NAME = "myApp";
   private static final String LOGGED_STORE_BASE_DIR = new File(System.getProperty("java.io.tmpdir"), "logged-store").getAbsolutePath();
   private static final Map<String, String> CONFIGS = new HashMap<String, String>() { {
       put(JobCoordinatorConfig.JOB_COORDINATOR_FACTORY, "org.apache.samza.standalone.PassthroughJobCoordinatorFactory");
@@ -94,9 +82,6 @@
       put(TaskConfig.COMMIT_MAX_DELAY_MS, "0"); // Ensure no commits are skipped due to in progress commits
     } };
 
-  private static List<String> actualInitialStoreContents = new ArrayList<>();
-  private static boolean crashedOnce = false;
-
   private final boolean hostAffinity;
 
   public TransactionalStateMultiStoreIntegrationTest(boolean hostAffinity) {
@@ -108,8 +93,7 @@
   public void setUp() {
     super.setUp();
     // reset static state shared with task between each parameterized iteration
-    crashedOnce = false;
-    actualInitialStoreContents = new ArrayList<>();
+    MyStatefulApplication.resetTestState();
     new FileUtil().rm(new File(LOGGED_STORE_BASE_DIR)); // always clear local store on startup
   }
 
@@ -119,7 +103,7 @@
     List<String> expectedChangelogMessagesOnInitialRun = Arrays.asList("1", "2", "3", "2", "97", null, "98", "99");
     initialRun(inputMessagesOnInitialRun, expectedChangelogMessagesOnInitialRun);
 
-    // first two are reverts for uncommitted messages from last run
+    // first two are reverts for uncommitted messages from last run for keys 98 and 99
     List<String> expectedChangelogMessagesOnSecondRun =
         Arrays.asList(null, null, "98", "99", "4", "5", "5");
     List<String> expectedInitialStoreContentsOnSecondRun = Arrays.asList("1", "2", "3");
@@ -160,8 +144,14 @@
       Assert.assertEquals(inputMessages, readInputMessages);
     }
 
+    SamzaApplication app = new MyStatefulApplication(INPUT_SYSTEM, INPUT_TOPIC, ImmutableMap.of(
+        STORE_1_NAME, STORE_1_CHANGELOG,
+        STORE_2_NAME, STORE_2_CHANGELOG
+    ));
+
     // run the application
-    RunApplicationContext context = runApplication(new MyApplication(STORE_1_CHANGELOG), "myApp", CONFIGS);
+    RunApplicationContext context = runApplication(app, APP_NAME, CONFIGS);
 
     // consume and verify the changelog messages
     if (expectedChangelogMessages.size() > 0) {
@@ -188,8 +178,12 @@
     List<String> inputMessages = Arrays.asList("4", "5", "5", ":shutdown");
     inputMessages.forEach(m -> produceMessage(INPUT_TOPIC, 0, m, m));
 
+    SamzaApplication app = new MyStatefulApplication(INPUT_SYSTEM, INPUT_TOPIC, ImmutableMap.of(
+        STORE_1_NAME, changelogTopic,
+        STORE_2_NAME, STORE_2_CHANGELOG
+    ));
     // run the application
-    RunApplicationContext context = runApplication(new MyApplication(changelogTopic), "myApp", CONFIGS);
+    RunApplicationContext context = runApplication(app, APP_NAME, CONFIGS);
 
     // wait for the application to finish
     context.getRunner().waitForFinish();
@@ -201,81 +195,6 @@
     Assert.assertEquals(expectedChangelogMessages, changelogMessages);
 
     // verify the store contents during startup (this is after changelog verification to ensure init has completed)
-    Assert.assertEquals(expectedInitialStoreContents, actualInitialStoreContents);
-  }
-
-  static class MyApplication implements TaskApplication {
-    private final String changelogTopic;
-
-    public MyApplication(String changelogTopic) {
-      this.changelogTopic = changelogTopic;
-    }
-
-    @Override
-    public void describe(TaskApplicationDescriptor appDescriptor) {
-      KafkaSystemDescriptor ksd = new KafkaSystemDescriptor(INPUT_SYSTEM);
-      KVSerde<String, String> serde = KVSerde.of(new StringSerde(), new StringSerde());
-
-      KafkaInputDescriptor<KV<String, String>> isd = ksd.getInputDescriptor(INPUT_TOPIC, serde);
-
-      RocksDbTableDescriptor<String, String> td1 = new RocksDbTableDescriptor<>(STORE_1_NAME, serde)
-          .withChangelogStream(changelogTopic)
-          .withChangelogReplicationFactor(1);
-
-      RocksDbTableDescriptor<String, String> td2 = new RocksDbTableDescriptor<>(STORE_2_NAME, serde)
-          .withChangelogStream(STORE_2_CHANGELOG)
-          .withChangelogReplicationFactor(1);
-
-      appDescriptor
-          .withInputStream(isd)
-          .withTaskFactory((StreamTaskFactory) () -> new MyTask())
-          .withTable(td1)
-          .withTable(td2);
-    }
-  }
-
-  static class MyTask implements StreamTask, InitableTask {
-    private KeyValueStore<String, String> store;
-
-    @Override
-    public void init(Context context) {
-      this.store = (KeyValueStore<String, String>) context.getTaskContext().getStore(STORE_1_NAME);
-      KeyValueIterator<String, String> storeEntries = store.all();
-      while (storeEntries.hasNext()) {
-        actualInitialStoreContents.add(storeEntries.next().getValue());
-      }
-      storeEntries.close();
-    }
-
-    @Override
-    public void process(IncomingMessageEnvelope envelope,
-        MessageCollector collector, TaskCoordinator coordinator) {
-      String key = (String) envelope.getKey();
-      LOG.info("Received key: {}", key);
-
-      if (key.endsWith("crash_once")) {  // endsWith allows :crash_once and crash_once
-        if (!crashedOnce) {
-          crashedOnce = true;
-          coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
-        } else {
-          return;
-        }
-      } else if (key.endsWith("shutdown")) {
-        coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
-      } else if (key.startsWith("-")) {
-        store.delete(key.substring(1));
-      } else if (key.startsWith(":")) {
-        // write the message and flush, but don't invoke commit later
-        String msg = key.substring(1);
-        store.put(msg, msg);
-      } else {
-        store.put(key, key);
-      }
-      store.flush();
-
-      if (!key.startsWith(":")) {
-        coordinator.commit(TaskCoordinator.RequestScope.CURRENT_TASK);
-      }
-    }
+    Assert.assertEquals(expectedInitialStoreContents, MyStatefulApplication.getInitialStoreContents().get(STORE_1_NAME));
   }
 }
\ No newline at end of file