Merge branch 'master' into 1.2.0
diff --git a/KEYS b/KEYS
index edcb935..8e23b24 100644
--- a/KEYS
+++ b/KEYS
@@ -690,3 +690,40 @@
 g2aH
 =ukLA
 -----END PGP PUBLIC KEY BLOCK-----
+pub   2048R/5B5EB041 2019-05-17
+uid                  Boris Shkolnik (Key for release signing) <boryas@apache.org>
+sig 3        5B5EB041 2019-05-17  Boris Shkolnik (Key for release signing) <boryas@apache.org>
+sub   2048R/F77F01B7 2019-05-17
+sig          5B5EB041 2019-05-17  Boris Shkolnik (Key for release signing) <boryas@apache.org>
+
+-----BEGIN PGP PUBLIC KEY BLOCK-----
+Version: GnuPG v2.0.22 (GNU/Linux)
+
+mQENBFze7MEBCADJGw76F2EkwUUjEzxBgBLpg5zbNNTnlwgYJ5C/FVPFILK14H+6
+vjN2ID66ZQJCg30w3FltFILgfJhLHjrAabcFbptghqzT0kFCIcrVwiHm7eu7z5C1
+faTkBhKc2xiQjjVRp1Jkr/iNcKct9p6wuhABAiiEOI5iKxzN00Nkf2SaPydP8mUa
+ftEmpHaVLmVwgW5Jh/tCJIoJn0T06uNMZ7C25DI9TJP9UwXkzzw++b7Drd4hp310
+bjrhGb61VU2SbZum4sVi6gNZlXFE0gzfQRLLiTpESPsEaouRmymOG81LaWNbH9+w
+xD5Mc+dRSmuLpYpCTUDcbm18bquN846grToFABEBAAG0PEJvcmlzIFNoa29sbmlr
+IChLZXkgZm9yIHJlbGVhc2Ugc2lnbmluZykgPGJvcnlhc0BhcGFjaGUub3JnPokB
+OQQTAQIAIwUCXN7swQIbAwcLCQgHAwIBBhUIAgkKCwQWAgMBAh4BAheAAAoJEH10
+0M1bXrBBgScIAINYtGoMSPl3UgNaa1bv124milnaa/E9Mu1dkD3nXCcsn2bRtdM7
+oSOxRJ35wHwEEQMWzmEVG+wyxSNdUBaoIAzMV/Ok+yCmKVaPlKINmoUQ2k5RjAle
+kyqz0lcaYluSL5GTqlm2Hw+v5uZ78c2O36PwuhydNedfO1aI7msS89zoeyyVtJQz
+3M+i/fpZ9mbkbcQWU4Bn1HgcsKYCiyDJ6RCoEvQJheIUNKTswj78IgcppvYHtkDF
+NOtUQgw6GozQ0zIVRw7+82qNGkgbPPYjJjfgNuFkgD1GY3Fp4UDs9yX2Qlr0swoP
+f0WOqiDbHPln5EHVtVXwK0994NAPSXzS1tu5AQ0EXN7swQEIANZFy99a86nlddQg
+sldEmg2fhFJDIk8+SXDm2NF5TeTMQ96KZpaFcyOEy1DytA5EzryIynGgV9rt9jhH
+a8HtfH2uiHf3QTIqgnrUimxs2uhxc+Kfx9yfEKaEUBKrnSdrStqxu0DQefbr84KW
+VmhEvKSmzejHoqeV4H7Q+/tmCrMbeVDyOKEQLKYm9Dm3TNJxNRqxNrKzdsjYgqt7
+B1heSXiKzzREsZSWoXLkFX2H6qJpKn6/Mbdu5qD7RIenYZ1BJDPkAkTTCA1rpr1o
+iW2qY7KFzJkFGBnej+G58GcOFwAzIPzzjQsip5XjMedbYpz2V7sdij6K9AO1avDu
+1Fb8iEEAEQEAAYkBHwQYAQIACQUCXN7swQIbDAAKCRB9dNDNW16wQY/0CAC7t4c3
+cSgsWbEIkdY4A/cNAeQQxWmi08x0UMl1xYGfcIOtAwePdQOJJ2TuhBEJKMLmKO64
+IpGzaKwcRaEPBB9lFwJBAJeCizzY76h64LFnus35aA3UJeG3TlyfcggVI5uJG//M
+90TVkxv84Z406Wf2B1RLca3YmsxGdchk6JLD9e2bPXGAFgr+z5sTsHAe6XP8CcoH
+PlhUDyZhNuP6OnvS0r0qkIy4eWZzsuUk3OuJQAkhvqNb/r5eahXtgpasJyIYofeR
+bN7+AsTZ8FGMrkfwrs9a1RFtypldStL70bAl61/yTqngdwFGb0zJi3IxqNzq8Fun
+AE4y9/JMEbF0vpC4
+=lijy
+-----END PGP PUBLIC KEY BLOCK-----
diff --git a/docs/learn/documentation/versioned/api/table-api.md b/docs/learn/documentation/versioned/api/table-api.md
index be87e43..d8695ba 100644
--- a/docs/learn/documentation/versioned/api/table-api.md
+++ b/docs/learn/documentation/versioned/api/table-api.md
@@ -96,24 +96,35 @@
 the **`get`** operation.
 
 {% highlight java %}
- /**
+  /**
    * Gets the value associated with the specified {@code key}.
    *
    * @param key the key with which the associated value is to be fetched.
-   * @return if found, the value associated with the specified {@code key}; 
-   * otherwise, {@code null}.
+   * @param args additional arguments
+   * @return if found, the value associated with the specified {@code key}; otherwise, {@code null}.
    * @throws NullPointerException if the specified {@code key} is {@code null}.
    */
-  V get(K key);
+  V get(K key, Object ... args);
 
- /**
+  /**
    * Asynchronously gets the value associated with the specified {@code key}.
    *
    * @param key the key with which the associated value is to be fetched.
+   * @param args additional arguments
    * @return completableFuture for the requested value
    * @throws NullPointerException if the specified {@code key} is {@code null}.
    */
-  CompletableFuture<V> getAsync(K key);
+  CompletableFuture<V> getAsync(K key, Object ... args);
 {% endhighlight %}
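+
+Additional arguments are passed through, as-is, to the table function backing the table.
+Below is a minimal sketch of a call site (the table instance and the extra "flag" argument
+are hypothetical; the meaning of extra arguments is defined entirely by the table function):
+
+{% highlight java %}
+  // table is a ReadWriteTable<String, String> obtained from the task context.
+  // The trailing argument is forwarded verbatim to the table function.
+  String value = table.get("user-123", "flag");
+{% endhighlight %}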
 
 
@@ -249,6 +250,8 @@
 |`num-gets`|`ReadableTable`|Count of `get/getAsync()` operations
 |`num-getAlls`|`ReadableTable`|Count of `getAll/getAllAsync()` operations
 |`num-missed-lookups`|`ReadableTable`|Count of missed get/getAll() operations
+|`read-ns`|`ReadableTable`|Average latency of `readAsync()` operations
+|`num-reads`|`ReadableTable`|Count of `readAsync()` operations
 |`put-ns`|`ReadWriteTable`|Average latency of `put/putAsync()` operations
 |`putAll-ns`|`ReadWriteTable`|Average latency of `putAll/putAllAsync()` operations
 |`num-puts`|`ReadWriteTable`|Count of `put/putAsync()` operations
@@ -257,6 +260,8 @@
 |`deleteAll-ns`|`ReadWriteTable`|Average latency of `deleteAll/deleteAllAsync()` operations
 |`delete-num`|`ReadWriteTable`|Count of `delete/deleteAsync()` operations
 |`deleteAll-num`|`ReadWriteTable`|Count of `deleteAll/deleteAllAsync()` operations
+|`num-writes`|`ReadWriteTable`|Count of `writeAsync()` operations
+|`write-ns`|`ReadWriteTable`|Average latency of `writeAsync()` operations
 |`flush-ns`|`ReadWriteTable`|Average latency of flush operations
 |`flush-num`|`ReadWriteTable`|Count of flush operations
 |`hit-rate`|`CachingTable`|Cache hit rate (%)
@@ -343,6 +348,92 @@
 They can be found in 
 [`RetryMetrics`] (https://github.com/apache/samza/blob/master/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java).
 
+
+### Supporting Additional Operations
+
+Remote Table allows invoking additional operations on the remote store that are not
+directly supported by the Get/Put/Delete methods. Two categories are supported:
+
+* Get/Put/Delete operations with additional arguments
+* Arbitrary operations through readAsync() and writeAsync()
+
+Implementers of table functions are only required to implement Get/Put/Delete without
+additional arguments. End users can subclass a table function and invoke operations on
+the remote store directly when the table function does not support them.
+
+{% highlight java %}
+ 1  public class MyCouchbaseTableWriteFunction<V> extends CouchbaseTableWriteFunction<V> {
+ 2
+ 3    public static final int OP_COUNTER = 1;
+ 4
+ 5    @Override
+ 6    public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+ 7      if (OP_COUNTER == opId) {
+ 8        String id = (String) args[0];
+ 9        Long delta = Long.valueOf(args[1].toString());
+10        return convertToFuture(bucket.async().counter(id, delta));
+11      }
+12      throw new SamzaException("Unknown opId: " + opId);
+13    }
+14
+15    public CompletableFuture<Long> counterAsync(String id, long delta) {
+16      return table.writeAsync(OP_COUNTER, id, delta);
+17    }
+18  }
+19
+20  public class MyMapFunc implements MapFunction { 
+21
+22    AsyncReadWriteTable table;
+23    MyCouchbaseTableWriteFunction writeFunc;
+24
+25    @Override
+26    public void init(Context context) {
+27      table = context.getTaskContext().getTable(...);
+28      writeFunc = (MyCouchbaseTableWriteFunction) ((RemoteTable) table).getWriteFunction();
+29    }
+30
+31    @Override
+32    public Object apply(Object message) {
+33      return writeFunc.counterAsync("id", 100);
+34    }
+35  }
+{% endhighlight %}
+  
+The code above illustrates invoking the counter() operation on Couchbase; a read-side counterpart is sketched after the notes below.
+
+1. Lines  5-13: writeAsync() is implemented to invoke counter().
+2. Lines 15-16: it is then wrapped in a convenience method. Note that writeAsync() is
+                invoked on the table, so that value-added features such as rate limiting,
+                retry and batching can participate in the call.
+3. Lines 27-28: references to the table and the write function are obtained.
+4. Line     33: the actual invocation.
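+
+A read operation can be exposed in the same way by overriding readAsync() in a read
+function. The sketch below is illustrative only: the class, the opId and the exists()
+call are assumptions, and a real implementation depends on the remote store's client API.
+
+{% highlight java %}
+  public class MyCouchbaseTableReadFunction<V> extends CouchbaseTableReadFunction<V> {
+
+    public static final int OP_EXISTS = 1;
+
+    @Override
+    public <T> CompletableFuture<T> readAsync(int opId, Object... args) {
+      if (OP_EXISTS == opId) {
+        String id = (String) args[0];
+        return convertToFuture(bucket.async().exists(id));
+      }
+      throw new SamzaException("Unknown opId: " + opId);
+    }
+
+    // Invoking readAsync() on the table, rather than on this function directly,
+    // lets rate limiting, retry and batching participate in the call.
+    public CompletableFuture<Boolean> existsAsync(String id) {
+      return table.readAsync(OP_EXISTS, id);
+    }
+  }
+{% endhighlight %}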
+
 ## Local Table
 
 A table is considered local when its data physically co-exists on the same host 
diff --git a/samza-api/src/main/java/org/apache/samza/startpoint/Startpoint.java b/samza-api/src/main/java/org/apache/samza/startpoint/Startpoint.java
index fcca792..a48962b 100644
--- a/samza-api/src/main/java/org/apache/samza/startpoint/Startpoint.java
+++ b/samza-api/src/main/java/org/apache/samza/startpoint/Startpoint.java
@@ -22,12 +22,17 @@
 import com.google.common.base.Objects;
 import java.time.Instant;
 import org.apache.samza.annotation.InterfaceStability;
+import org.codehaus.jackson.annotate.JsonTypeInfo;
 
 /**
  * Startpoint represents a position in a stream partition.
  */
 @InterfaceStability.Evolving
+@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "@type")
 public abstract class Startpoint {
+  // TODO: Remove the @JsonTypeInfo annotation and use the ObjectMapper#enableDefaultTyping method in
+  //  StartpointObjectMapper after upgrading jackson version. That method does not add the appropriate type info to the
+  //  serialized json with the current version (1.9.13) of jackson.
 
   private final long creationTimestamp;
 
diff --git a/samza-api/src/main/java/org/apache/samza/table/AsyncReadWriteTable.java b/samza-api/src/main/java/org/apache/samza/table/AsyncReadWriteTable.java
index dc976b5..bf692e0 100644
--- a/samza-api/src/main/java/org/apache/samza/table/AsyncReadWriteTable.java
+++ b/samza-api/src/main/java/org/apache/samza/table/AsyncReadWriteTable.java
@@ -21,6 +21,8 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
+
+import org.apache.samza.SamzaException;
 import org.apache.samza.context.Context;
 import org.apache.samza.storage.kv.Entry;
 
@@ -36,19 +38,33 @@
    * Asynchronously gets the value associated with the specified {@code key}.
    *
    * @param key the key with which the associated value is to be fetched.
+   * @param args additional arguments
    * @return completableFuture for the requested value
    * @throws NullPointerException if the specified {@code key} is {@code null}.
    */
-  CompletableFuture<V> getAsync(K key);
+  CompletableFuture<V> getAsync(K key, Object ... args);
 
   /**
    * Asynchronously gets the values with which the specified {@code keys} are associated.
    *
    * @param keys the keys with which the associated values are to be fetched.
+   * @param args additional arguments
    * @return completableFuture for the requested entries
    * @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
    */
-  CompletableFuture<Map<K, V>> getAllAsync(List<K> keys);
+  CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args);
+
+  /**
+   * Asynchronously executes a read operation. opId is used to allow tracking of different
+   * types of operations.
+   * @param opId operation identifier
+   * @param args additional arguments
+   * @param <T> return type
+   * @return completableFuture for read result
+   */
+  default <T> CompletableFuture<T> readAsync(int opId, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
 
   /**
    * Asynchronously updates the mapping of the specified key-value pair;
@@ -57,36 +73,52 @@
    *
    * @param key the key with which the specified {@code value} is to be associated.
    * @param value the value with which the specified {@code key} is to be associated.
+   * @param args additional arguments
    * @throws NullPointerException if the specified {@code key} is {@code null}.
    * @return CompletableFuture for the operation
    */
-  CompletableFuture<Void> putAsync(K key, V value);
+  CompletableFuture<Void> putAsync(K key, V value, Object ... args);
 
   /**
    * Asynchronously updates the mappings of the specified key-value {@code entries}.
    * A key is deleted from the table if its corresponding value is {@code null}.
    *
    * @param entries the updated mappings to put into this table.
+   * @param args additional arguments
    * @throws NullPointerException if any of the specified {@code entries} has {@code null} as key.
    * @return CompletableFuture for the operation
    */
-  CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries);
+  CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args);
 
   /**
    * Asynchronously deletes the mapping for the specified {@code key} from this table (if such mapping exists).
    * @param key the key for which the mapping is to be deleted.
+   * @param args additional arguments
    * @throws NullPointerException if the specified {@code key} is {@code null}.
    * @return CompletableFuture for the operation
    */
-  CompletableFuture<Void> deleteAsync(K key);
+  CompletableFuture<Void> deleteAsync(K key, Object ... args);
 
   /**
    * Asynchronously deletes the mappings for the specified {@code keys} from this table.
    * @param keys the keys for which the mappings are to be deleted.
+   * @param args additional arguments
    * @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
    * @return CompletableFuture for the operation
    */
-  CompletableFuture<Void> deleteAllAsync(List<K> keys);
+  CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args);
+
+  /**
+   * Asynchronously executes a write operation. opId is used to allow tracking of different
+   * types of operations.
+   * @param opId operation identifier
+   * @param args additional arguments
+   * @param <T> return type
+   * @return completableFuture for write result
+   */
+  default <T> CompletableFuture<T> writeAsync(int opId, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
 
   /**
    * Initializes the table during container initialization.
diff --git a/samza-api/src/main/java/org/apache/samza/table/ReadWriteTable.java b/samza-api/src/main/java/org/apache/samza/table/ReadWriteTable.java
index a7dad8f..72ee34d 100644
--- a/samza-api/src/main/java/org/apache/samza/table/ReadWriteTable.java
+++ b/samza-api/src/main/java/org/apache/samza/table/ReadWriteTable.java
@@ -21,6 +21,7 @@
 import java.util.List;
 import java.util.Map;
 
+import org.apache.samza.SamzaException;
 import org.apache.samza.annotation.InterfaceStability;
 import org.apache.samza.storage.kv.Entry;
 
@@ -37,19 +38,33 @@
    * Gets the value associated with the specified {@code key}.
    *
    * @param key the key with which the associated value is to be fetched.
+   * @param args additional arguments
    * @return if found, the value associated with the specified {@code key}; otherwise, {@code null}.
    * @throws NullPointerException if the specified {@code key} is {@code null}.
    */
-  V get(K key);
+  V get(K key, Object ... args);
 
   /**
    * Gets the values with which the specified {@code keys} are associated.
    *
    * @param keys the keys with which the associated values are to be fetched.
+   * @param args additional arguments
    * @return a map of the keys that were found and their respective values.
    * @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
    */
-  Map<K, V> getAll(List<K> keys);
+  Map<K, V> getAll(List<K> keys, Object ... args);
+
+  /**
+   * Executes a read operation. opId is used to allow tracking of different
+   * types of operations.
+   * @param opId operation identifier
+   * @param args additional arguments
+   * @param <T> return type
+   * @return read result
+   */
+  default <T> T read(int opId, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
 
   /**
    * Updates the mapping of the specified key-value pair;
@@ -59,9 +75,10 @@
    *
    * @param key the key with which the specified {@code value} is to be associated.
    * @param value the value with which the specified {@code key} is to be associated.
+   * @param args additional arguments
    * @throws NullPointerException if the specified {@code key} is {@code null}.
    */
-  void put(K key, V value);
+  void put(K key, V value, Object ... args);
 
   /**
    * Updates the mappings of the specified key-value {@code entries}.
@@ -69,23 +86,38 @@
    * A key is deleted from the table if its corresponding value is {@code null}.
    *
    * @param entries the updated mappings to put into this table.
+   * @param args additional arguments
    * @throws NullPointerException if any of the specified {@code entries} has {@code null} as key.
    */
-  void putAll(List<Entry<K, V>> entries);
+  void putAll(List<Entry<K, V>> entries, Object ... args);
 
   /**
    * Deletes the mapping for the specified {@code key} from this table (if such mapping exists).
    *
    * @param key the key for which the mapping is to be deleted.
+   * @param args additional arguments
    * @throws NullPointerException if the specified {@code key} is {@code null}.
    */
-  void delete(K key);
+  void delete(K key, Object ... args);
 
   /**
    * Deletes the mappings for the specified {@code keys} from this table.
    *
    * @param keys the keys for which the mappings are to be deleted.
+   * @param args additional arguments
    * @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
    */
-  void deleteAll(List<K> keys);
+  void deleteAll(List<K> keys, Object ... args);
+
+  /**
+   * Executes a write operation. opId is used to allow tracking of different
+   * types of operations.
+   * @param opId operation identifier
+   * @param args additional arguments
+   * @param <T> return type
+   * @return write result
+   */
+  default <T> T write(int opId, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
 }
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/BaseTableFunction.java b/samza-api/src/main/java/org/apache/samza/table/remote/BaseTableFunction.java
new file mode 100644
index 0000000..f099ff9
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/BaseTableFunction.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.remote;
+
+import com.google.common.base.Preconditions;
+
+import org.apache.samza.context.Context;
+import org.apache.samza.table.AsyncReadWriteTable;
+
+
+/**
+ * Base class for all table read and write functions.
+ */
+public abstract class BaseTableFunction implements TableFunction {
+
+  protected AsyncReadWriteTable table;
+
+  @Override
+  public void init(Context context, AsyncReadWriteTable table) {
+    Preconditions.checkNotNull(context, "null context");
+    Preconditions.checkNotNull(table, "null table");
+    this.table = table;
+  }
+}
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/TableFunction.java b/samza-api/src/main/java/org/apache/samza/table/remote/TableFunction.java
new file mode 100644
index 0000000..517a927
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/TableFunction.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.table.remote;
+
+import java.io.Serializable;
+
+import org.apache.samza.annotation.InterfaceStability;
+import org.apache.samza.context.Context;
+import org.apache.samza.operators.functions.ClosableFunction;
+import org.apache.samza.table.AsyncReadWriteTable;
+
+
+/**
+ * The root interface for table read and write functions.
+ */
+@InterfaceStability.Unstable
+public interface TableFunction extends TablePart, ClosableFunction, Serializable {
+
+  /**
+   * Initializes the function before any operation.
+   *
+   * @param context the {@link Context} for this task
+   * @param table the {@link AsyncReadWriteTable} to which this table function belongs
+   */
+  void init(Context context, AsyncReadWriteTable table);
+
+  /**
+   * Determine whether the current operation can be retried with the last thrown exception.
+   * @param exception exception thrown by a table operation
+   * @return whether the operation can be retried
+   */
+  boolean isRetriable(Throwable exception);
+}
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java b/samza-api/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java
index f52bcbe..81cdecc 100644
--- a/samza-api/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java
@@ -63,9 +63,20 @@
      * Get the number of credits required for the {@code key} and {@code value} pair.
      * @param key table key
      * @param value table record
+     * @param args additional arguments
      * @return number of credits
      */
-    int getCredits(K key, V value);
+    int getCredits(K key, V value, Object ... args);
+
+    /**
+     * Get the number of credits required for the {@code opId} and associated {@code args}.
+     * @param opId operation Id
+     * @param args additional arguments
+     * @return number of credits
+     */
+    default int getCredits(int opId, Object ... args) {
+      return 1;
+    }
   }
 
   /**
@@ -92,26 +103,30 @@
     this.waitTimeMetric = timer;
   }
 
-  int getCredits(K key, V value) {
-    return (creditFn == null) ? 1 : creditFn.getCredits(key, value);
+  int getCredits(K key, V value, Object ... args) {
+    return (creditFn == null) ? 1 : creditFn.getCredits(key, value, args);
   }
 
-  int getCredits(Collection<K> keys) {
+  int getCredits(Collection<K> keys, Object ... args) {
     if (creditFn == null) {
       return keys.size();
     } else {
-      return keys.stream().mapToInt(k -> creditFn.getCredits(k, null)).sum();
+      return keys.stream().mapToInt(k -> creditFn.getCredits(k, null, args)).sum();
     }
   }
 
-  int getEntryCredits(Collection<Entry<K, V>> entries) {
+  int getEntryCredits(Collection<Entry<K, V>> entries, Object ... args) {
     if (creditFn == null) {
       return entries.size();
     } else {
-      return entries.stream().mapToInt(e -> creditFn.getCredits(e.getKey(), e.getValue())).sum();
+      return entries.stream().mapToInt(e -> creditFn.getCredits(e.getKey(), e.getValue(), args)).sum();
     }
   }
 
+  int getCredits(int opId, Object ... args) {
+    return (creditFn == null) ? 1 : creditFn.getCredits(opId, args);
+  }
+
   private void throttle(int credits) {
     long startNs = System.nanoTime();
     rateLimiter.acquire(Collections.singletonMap(tag, credits));
@@ -123,33 +138,46 @@
   /**
    * Throttle a request with a key argument if necessary.
    * @param key key used for the table request
+   * @param args additional arguments
    */
-  public void throttle(K key) {
-    throttle(getCredits(key, null));
+  public void throttle(K key, Object ... args) {
+    throttle(getCredits(key, null, args));
   }
 
   /**
    * Throttle a request with both the key and value arguments if necessary.
    * @param key key used for the table request
    * @param value value used for the table request
+   * @param args additional arguments
    */
-  public void throttle(K key, V value) {
-    throttle(getCredits(key, value));
+  public void throttle(K key, V value, Object ... args) {
+    throttle(getCredits(key, value, args));
+  }
+
+  /**
+   * Throttle a request with opId and associated arguments.
+   * @param opId operation Id
+   * @param args associated arguments
+   */
+  public void throttle(int opId, Object ... args) {
+    throttle(getCredits(opId, args));
   }
 
   /**
    * Throttle a request with a collection of keys as the argument if necessary.
    * @param keys collection of keys used for the table request
+   * @param args additional arguments
    */
-  public void throttle(Collection<K> keys) {
-    throttle(getCredits(keys));
+  public void throttle(Collection<K> keys, Object ... args) {
+    throttle(getCredits(keys, args));
   }
 
   /**
    * Throttle a request with a collection of table records as the argument if necessary.
    * @param records collection of records used for the table request
+   * @param args additional arguments
    */
-  public void throttleRecords(Collection<Entry<K, V>> records) {
-    throttle(getEntryCredits(records));
+  public void throttleRecords(Collection<Entry<K, V>> records, Object ... args) {
+    throttle(getEntryCredits(records, args));
   }
 }
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/TableReadFunction.java b/samza-api/src/main/java/org/apache/samza/table/remote/TableReadFunction.java
index 04fc918..03a4f24 100644
--- a/samza-api/src/main/java/org/apache/samza/table/remote/TableReadFunction.java
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/TableReadFunction.java
@@ -19,7 +19,6 @@
 
 package org.apache.samza.table.remote;
 
-import java.io.Serializable;
 import java.util.Collection;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
@@ -28,8 +27,6 @@
 
 import org.apache.samza.SamzaException;
 import org.apache.samza.annotation.InterfaceStability;
-import org.apache.samza.operators.functions.ClosableFunction;
-import org.apache.samza.operators.functions.InitableFunction;
 
 import com.google.common.collect.Iterables;
 
@@ -46,7 +43,7 @@
  * @param <V> the type of the value in this table
  */
 @InterfaceStability.Unstable
-public interface TableReadFunction<K, V> extends TablePart, InitableFunction, ClosableFunction, Serializable {
+public interface TableReadFunction<K, V> extends TableFunction {
   /**
    * Fetch single table record for a specified {@code key}. This method must be thread-safe.
    * The default implementation calls getAsync and blocks on the completion afterwards.
@@ -69,6 +66,17 @@
   CompletableFuture<V> getAsync(K key);
 
   /**
+   * Asynchronously fetch single table record for a specified {@code key} with additional arguments.
+   * This method must be thread-safe.
+   * @param key key for the table record
+   * @param args additional arguments
+   * @return CompletableFuture for the get request
+   */
+  default CompletableFuture<V> getAsync(K key, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
+
+  /**
    * Fetch the table {@code records} for specified {@code keys}. This method must be thread-safe.
    * The default implementation calls getAllAsync and blocks on the completion afterwards.
    * @param keys keys for the table records
@@ -101,11 +109,27 @@
   }
 
   /**
-   * Determine whether the current operation can be retried with the last thrown exception.
-   * @param exception exception thrown by a table operation
-   * @return whether the operation can be retried
+   * Asynchronously fetch the table {@code records} for specified {@code keys} and additional arguments.
+   * This method must be thread-safe.
+   * @param keys keys for the table records
+   * @param args additional arguments
+   * @return CompletableFuture for the get request
    */
-  boolean isRetriable(Throwable exception);
+  default CompletableFuture<Map<K, V>> getAllAsync(Collection<K> keys, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
+
+  /**
+   * Asynchronously read data from table for specified {@code opId} and additional arguments.
+   * This method must be thread-safe.
+   * @param opId operation identifier
+   * @param args additional arguments
+   * @param <T> return type
+   * @return CompletableFuture for the read request
+   */
+  default <T> CompletableFuture<T> readAsync(int opId, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
 
   // optionally implement readObject() to initialize transient states
 }
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/TableWriteFunction.java b/samza-api/src/main/java/org/apache/samza/table/remote/TableWriteFunction.java
index 3b06664..9e274c1 100644
--- a/samza-api/src/main/java/org/apache/samza/table/remote/TableWriteFunction.java
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/TableWriteFunction.java
@@ -19,7 +19,6 @@
 
 package org.apache.samza.table.remote;
 
-import java.io.Serializable;
 import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
@@ -28,8 +27,6 @@
 
 import org.apache.samza.SamzaException;
 import org.apache.samza.annotation.InterfaceStability;
-import org.apache.samza.operators.functions.ClosableFunction;
-import org.apache.samza.operators.functions.InitableFunction;
 import org.apache.samza.storage.kv.Entry;
 
 import com.google.common.collect.Iterables;
@@ -47,7 +44,7 @@
  * @param <V> the type of the value in this table
  */
 @InterfaceStability.Unstable
-public interface TableWriteFunction<K, V> extends TablePart, InitableFunction, ClosableFunction, Serializable {
+public interface TableWriteFunction<K, V> extends TableFunction {
   /**
    * Store single table {@code record} with specified {@code key}. This method must be thread-safe.
    * The default implementation calls putAsync and blocks on the completion afterwards.
@@ -72,6 +69,18 @@
   CompletableFuture<Void> putAsync(K key, V record);
 
   /**
+   * Asynchronously store single table {@code record} with specified {@code key} and additional arguments.
+   * This method must be thread-safe.
+   * @param key key for the table record
+   * @param record table record to be written
+   * @param args additional arguments
+   * @return CompletableFuture for the put request
+   */
+  default CompletableFuture<Void> putAsync(K key, V record, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
+
+  /**
    * Store the table {@code records} with specified {@code keys}. This method must be thread-safe.
    * The default implementation calls putAllAsync and blocks on the completion afterwards.
    * @param records table records to be written
@@ -97,6 +106,17 @@
   }
 
   /**
+   * Asynchronously store the table {@code records} with specified {@code keys} and additional arguments.
+   * This method must be thread-safe.
+   * @param records table records to be written
+   * @param args additional arguments
+   * @return CompletableFuture for the put request
+   */
+  default CompletableFuture<Void> putAllAsync(Collection<Entry<K, V>> records, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
+
+  /**
    * Delete the {@code record} with specified {@code key} from the remote store.
    * The default implementation calls deleteAsync and blocks on the completion afterwards.
    * @param key key to the table record to be deleted
@@ -117,6 +137,16 @@
   CompletableFuture<Void> deleteAsync(K key);
 
   /**
+   * Asynchronously delete the {@code record} with specified {@code key} and additional arguments from the remote store.
+   * @param key key to the table record to be deleted
+   * @param args additional arguments
+   * @return CompletableFuture for the delete request
+   */
+  default CompletableFuture<Void> deleteAsync(K key, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
+
+  /**
    * Delete all {@code records} with the specified {@code keys} from the remote store
    * The default implementation calls deleteAllAsync and blocks on the completion afterwards.
    * @param keys keys for the table records to be written
@@ -143,11 +173,28 @@
   }
 
   /**
-   * Determine whether the current operation can be retried with the last thrown exception.
-   * @param exception exception thrown by a table operation
-   * @return whether the operation can be retried
+   * Asynchronously delete all {@code records} with the specified {@code keys} and additional arguments from
+   * the remote store.
+   *
+   * @param keys keys for the table records to be deleted
+   * @param args additional arguments
+   * @return CompletableFuture for the deleteAll request
    */
-  boolean isRetriable(Throwable exception);
+  default CompletableFuture<Void> deleteAllAsync(Collection<K> keys, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
+
+  /**
+   * Asynchronously write data to table for specified {@code opId} and additional arguments.
+   * This method must be thread-safe.
+   * @param opId operation identifier
+   * @param args additional arguments
+   * @param <T> return type
+   * @return CompletableFuture for the write request
+   */
+  default <T> CompletableFuture<T> writeAsync(int opId, Object ... args) {
+    throw new SamzaException("Not supported");
+  }
 
   /**
    * Flush the remote store (optional)
diff --git a/samza-api/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java b/samza-api/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java
index 262f669..5510dde 100644
--- a/samza-api/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java
+++ b/samza-api/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java
@@ -44,7 +44,7 @@
 
   public TableRateLimiter<String, String> getThrottler(String tag) {
     TableRateLimiter.CreditFunction<String, String> credFn =
-        (TableRateLimiter.CreditFunction<String, String>) (key, value) -> {
+        (TableRateLimiter.CreditFunction<String, String>) (key, value, args) -> {
       int credits = key == null ? 0 : 3;
       credits += value == null ? 0 : 3;
       return credits;
@@ -83,13 +83,29 @@
   }
 
   @Test
+  public void testCreditOpId() {
+    TableRateLimiter<String, String> rateLimitHelper = getThrottler();
+    Assert.assertEquals(1, rateLimitHelper.getCredits(1, 2));
+  }
+
+  @Test
   public void testThrottle() {
     TableRateLimiter<String, String> rateLimitHelper = getThrottler();
     Timer timer = mock(Timer.class);
     rateLimitHelper.setTimerMetric(timer);
+    int times = 0;
     rateLimitHelper.throttle("foo");
-    verify(rateLimitHelper.rateLimiter, times(1)).acquire(anyMapOf(String.class, Integer.class));
-    verify(timer, times(1)).update(anyLong());
+    verify(rateLimitHelper.rateLimiter, times(++times)).acquire(anyMap());
+    verify(timer, times(times)).update(anyLong());
+    rateLimitHelper.throttle("foo", "bar");
+    verify(rateLimitHelper.rateLimiter, times(++times)).acquire(anyMap());
+    verify(timer, times(times)).update(anyLong());
+    rateLimitHelper.throttle(Arrays.asList("foo", "bar"));
+    verify(rateLimitHelper.rateLimiter, times(++times)).acquire(anyMap());
+    verify(timer, times(times)).update(anyLong());
+    rateLimitHelper.throttle(1, 2);
+    verify(rateLimitHelper.rateLimiter, times(++times)).acquire(anyMap());
+    verify(timer, times(times)).update(anyLong());
   }
 
   @Test
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointFanOutPerTask.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointFanOutPerTask.java
new file mode 100644
index 0000000..63efabc
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointFanOutPerTask.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.startpoint;
+
+import com.google.common.base.MoreObjects;
+import com.google.common.base.Objects;
+import java.time.Instant;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.samza.serializers.model.SamzaObjectMapper;
+import org.apache.samza.system.SystemStreamPartition;
+import org.codehaus.jackson.map.annotate.JsonDeserialize;
+import org.codehaus.jackson.map.annotate.JsonSerialize;
+
+
+/**
+ * Holds the {@link Startpoint} fan outs for each {@link SystemStreamPartition}. Each StartpointFanOutPerTask maps to a
+ * {@link org.apache.samza.container.TaskName}.
+ */
+class StartpointFanOutPerTask {
+  // TODO: Remove the @JsonSerialize and @JsonDeserialize annotations and use the SimpleModule#addKeySerializer and
+  //  SimpleModule#addKeyDeserializer methods in StartpointObjectMapper after upgrading jackson version.
+  //  Those methods do not work on nested maps with the current version (1.9.13) of jackson.
+
+  @JsonSerialize
+  @JsonDeserialize
+  private final Instant timestamp;
+
+  @JsonDeserialize(keyUsing = SamzaObjectMapper.SystemStreamPartitionKeyDeserializer.class)
+  @JsonSerialize(keyUsing = SamzaObjectMapper.SystemStreamPartitionKeySerializer.class)
+  private final Map<SystemStreamPartition, Startpoint> fanOuts;
+
+  // required for Jackson deserialization
+  StartpointFanOutPerTask() {
+    this(Instant.now());
+  }
+
+  StartpointFanOutPerTask(Instant timestamp) {
+    this.timestamp = timestamp;
+    this.fanOuts = new HashMap<>();
+  }
+
+  // Unused in code, but useful for auditing when the fan out is serialized into the store
+  Instant getTimestamp() {
+    return timestamp;
+  }
+
+  Map<SystemStreamPartition, Startpoint> getFanOuts() {
+    return fanOuts;
+  }
+
+  @Override
+  public String toString() {
+    return MoreObjects.toStringHelper(this)
+        .add("timestamp", timestamp)
+        .add("fanOuts", fanOuts)
+        .toString();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    StartpointFanOutPerTask that = (StartpointFanOutPerTask) o;
+    return Objects.equal(timestamp, that.timestamp) && Objects.equal(fanOuts, that.fanOuts);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(timestamp, fanOuts);
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKey.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKey.java
deleted file mode 100644
index 7d4753d..0000000
--- a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKey.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import com.google.common.base.MoreObjects;
-import com.google.common.base.Objects;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.system.SystemStreamPartition;
-import org.codehaus.jackson.map.annotate.JsonSerialize;
-
-
-@JsonSerialize(using = StartpointKeySerializer.class)
-class StartpointKey {
-  private final SystemStreamPartition systemStreamPartition;
-  private final TaskName taskName;
-
-  // Constructs a startpoint key with SSP. This means the key will apply to all tasks that are mapped to this SSP
-  StartpointKey(SystemStreamPartition systemStreamPartition) {
-    this(systemStreamPartition, null);
-  }
-
-  // Constructs a startpoint key with SSP and a task.
-  StartpointKey(SystemStreamPartition systemStreamPartition, TaskName taskName) {
-    this.systemStreamPartition = systemStreamPartition;
-    this.taskName = taskName;
-  }
-
-  SystemStreamPartition getSystemStreamPartition() {
-    return systemStreamPartition;
-  }
-
-  TaskName getTaskName() {
-    return taskName;
-  }
-
-  @Override
-  public String toString() {
-    return MoreObjects.toStringHelper(this).toString();
-  }
-
-  @Override
-  public boolean equals(Object o) {
-    if (this == o) {
-      return true;
-    }
-    if (o == null || getClass() != o.getClass()) {
-      return false;
-    }
-    StartpointKey that = (StartpointKey) o;
-    return Objects.equal(systemStreamPartition, that.systemStreamPartition) && Objects.equal(taskName, that.taskName);
-  }
-
-  @Override
-  public int hashCode() {
-    return Objects.hashCode(systemStreamPartition, taskName);
-  }
-}
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKeySerializer.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKeySerializer.java
deleted file mode 100644
index d347fe7..0000000
--- a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKeySerializer.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- *//*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- *//*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.system.SystemStreamPartition;
-import org.codehaus.jackson.JsonGenerator;
-import org.codehaus.jackson.map.JsonSerializer;
-import org.codehaus.jackson.map.SerializerProvider;
-
-
-final class StartpointKeySerializer extends JsonSerializer<StartpointKey> {
-  @Override
-  public void serialize(StartpointKey startpointKey, JsonGenerator jsonGenerator, SerializerProvider provider) throws
-                                                                                                               IOException {
-    Map<String, Object> systemStreamPartitionMap = new HashMap<>();
-    SystemStreamPartition systemStreamPartition = startpointKey.getSystemStreamPartition();
-    TaskName taskName = startpointKey.getTaskName();
-    systemStreamPartitionMap.put("system", systemStreamPartition.getSystem());
-    systemStreamPartitionMap.put("stream", systemStreamPartition.getStream());
-    systemStreamPartitionMap.put("partition", systemStreamPartition.getPartition().getPartitionId());
-    if (taskName != null) {
-      systemStreamPartitionMap.put("taskName", taskName.getTaskName());
-    }
-    jsonGenerator.writeObject(systemStreamPartitionMap);
-  }
-}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java
index da7f990..0c070cb 100644
--- a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java
+++ b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java
@@ -20,87 +20,123 @@
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableMap;
+import java.io.IOException;
 import java.time.Duration;
 import java.time.Instant;
-import java.util.HashSet;
-import java.util.Objects;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.SamzaException;
-import org.apache.samza.config.Config;
 import org.apache.samza.container.TaskName;
 import org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
-import org.apache.samza.job.model.ContainerModel;
-import org.apache.samza.job.model.JobModel;
-import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.metadatastore.MetadataStore;
-import org.apache.samza.metadatastore.MetadataStoreFactory;
-import org.apache.samza.metrics.MetricsRegistry;
-import org.apache.samza.serializers.JsonSerdeV2;
 import org.apache.samza.system.SystemStreamPartition;
+import org.codehaus.jackson.map.ObjectMapper;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 
 /**
- * The StartpointManager reads and writes {@link Startpoint} to the {@link MetadataStore} defined by
- * the configuration task.startpoint.metadata.store.factory.
- *
- * Startpoints are keyed in the MetadataStore by two different formats:
- * 1) Only by {@link SystemStreamPartition}
- * 2) A combination of {@link SystemStreamPartition} and {@link TaskName}
+ * The StartpointManager reads and writes {@link Startpoint} to the provided {@link MetadataStore}.
  *
  * The intention for the StartpointManager is to maintain a strong contract between the caller
  * and how Startpoints are stored in the underlying MetadataStore.
+ *
+ * Startpoints are written in the MetadataStore using keys of two different formats:
+ * 1) {@link SystemStreamPartition} only
+ * 2) A combination of {@link SystemStreamPartition} and {@link TaskName}
+ *
+ * Startpoints are then fanned out to a fan out namespace in the MetadataStore by the
+ * {@link org.apache.samza.clustermanager.ClusterBasedJobCoordinator} or the standalone
+ * {@link org.apache.samza.coordinator.JobCoordinator} upon startup. The
+ * {@link org.apache.samza.checkpoint.OffsetManager} then reads the fan outs to set the starting offsets per task
+ * and per {@link SystemStreamPartition}. The fan outs are deleted once the offsets are committed to the checkpoint.
+ *
+ * The read, write and delete methods are intended for external callers.
+ * The fan out methods are intended to be used within a job coordinator.
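+ *
+ * A minimal usage sketch for an external caller (ssp and startpoint are assumed to be
+ * constructed elsewhere; the fan out methods are typically invoked by the job coordinator):
+ * <pre>
+ *   startpointManager.start();
+ *   startpointManager.writeStartpoint(ssp, startpoint);
+ *   Optional&lt;Startpoint&gt; read = startpointManager.readStartpoint(ssp);
+ *   startpointManager.stop();
+ * </pre>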
  */
 public class StartpointManager {
-  private static final Logger LOG = LoggerFactory.getLogger(StartpointManager.class);
-  public static final String NAMESPACE = "samza-startpoint-v1";
+  private static final Integer VERSION = 1;
+  public static final String NAMESPACE = "samza-startpoint-v" + VERSION;
 
   static final Duration DEFAULT_EXPIRATION_DURATION = Duration.ofHours(12);
 
-  private final MetadataStore metadataStore;
-  private final StartpointSerde startpointSerde = new StartpointSerde();
+  private static final Logger LOG = LoggerFactory.getLogger(StartpointManager.class);
+  private static final String NAMESPACE_FAN_OUT = NAMESPACE + "-fan-out";
 
-  private boolean stopped = false;
+  private final NamespaceAwareCoordinatorStreamStore fanOutStore;
+  private final NamespaceAwareCoordinatorStreamStore readWriteStore;
+  private final ObjectMapper objectMapper = StartpointObjectMapper.getObjectMapper();
 
-  /**
-   *  Constructs a {@link StartpointManager} instance by instantiating a new metadata store connection.
-   *  This is primarily used for testing.
-   */
-  @VisibleForTesting
-  StartpointManager(MetadataStoreFactory metadataStoreFactory, Config config, MetricsRegistry metricsRegistry) {
-    Preconditions.checkNotNull(metadataStoreFactory, "MetadataStoreFactory cannot be null");
-    Preconditions.checkNotNull(config, "Config cannot be null");
-    Preconditions.checkNotNull(metricsRegistry, "MetricsRegistry cannot be null");
-
-    this.metadataStore = metadataStoreFactory.getMetadataStore(NAMESPACE, config, metricsRegistry);
-    LOG.info("StartpointManager created with metadata store: {}", metadataStore.getClass().getCanonicalName());
-    this.metadataStore.init();
-  }
+  private boolean stopped = true;
 
   /**
    *  Builds the StartpointManager based upon the provided {@link MetadataStore} that is instantiated.
    *  Setting up a metadata store instance is expensive which requires opening multiple connections
-   *  and reading tons of information. Fully instantiated metadata store is taken as a constructor argument
+   *  and reading a large amount of metadata. A fully instantiated metadata store is passed in as a constructor argument
    *  to reuse it across different utility classes.
    *
    * @param metadataStore an instance of {@link MetadataStore} used to read/write the start-points.
    */
   public StartpointManager(MetadataStore metadataStore) {
     Preconditions.checkNotNull(metadataStore, "MetadataStore cannot be null");
-    this.metadataStore = new NamespaceAwareCoordinatorStreamStore(metadataStore, NAMESPACE);
-  }
 
-  public void start() {
-    // Metadata store lifecycle is managed outside of the StartpointManager, so not starting it.
+    this.readWriteStore = new NamespaceAwareCoordinatorStreamStore(metadataStore, NAMESPACE);
+    this.fanOutStore = new NamespaceAwareCoordinatorStreamStore(metadataStore, NAMESPACE_FAN_OUT);
+    LOG.info("Startpoints are written to namespace: {} and fanned out to namespace: {} in the metadata store of type: {}",
+        NAMESPACE, NAMESPACE_FAN_OUT, metadataStore.getClass().getCanonicalName());
   }
 
   /**
-   * Writes a {@link Startpoint} that defines the start position for a {@link SystemStreamPartition}.
-   * @param ssp The {@link SystemStreamPartition} to map the {@link Startpoint} against.
-   * @param startpoint Reference to a Startpoint object.
+   * Perform startup operations. Method is idempotent.
    */
+  public void start() {
+    if (stopped) {
+      LOG.info("starting");
+      readWriteStore.init();
+      fanOutStore.init();
+      stopped = false;
+    } else {
+      LOG.warn("already started");
+    }
+  }
+
+  /**
+   * Perform teardown operations. Method is idempotent.
+   */
+  public void stop() {
+    if (!stopped) {
+      LOG.info("stopping");
+      readWriteStore.close();
+      fanOutStore.close();
+      stopped = true;
+    } else {
+      LOG.warn("already stopped");
+    }
+  }
+
+  /**
+   * Writes a {@link Startpoint} that defines the start position for a {@link SystemStreamPartition}.
+   * @param ssp The {@link SystemStreamPartition} to map the {@link Startpoint} against.
+   * @param startpoint Reference to a Startpoint object.
+   */
   public void writeStartpoint(SystemStreamPartition ssp, Startpoint startpoint) {
     writeStartpoint(ssp, null, startpoint);
   }
@@ -117,7 +144,7 @@
     Preconditions.checkNotNull(startpoint, "Startpoint cannot be null");
 
     try {
-      metadataStore.put(toStoreKey(ssp, taskName), startpointSerde.toBytes(startpoint));
+      readWriteStore.put(toReadWriteStoreKey(ssp, taskName), objectMapper.writeValueAsBytes(startpoint));
     } catch (Exception ex) {
       throw new SamzaException(String.format(
           "Startpoint for SSP: %s and task: %s may not have been written to the metadata store.", ssp, taskName), ex);
@@ -127,9 +154,10 @@
   /**
    * Returns the last {@link Startpoint} that defines the start position for a {@link SystemStreamPartition}.
    * @param ssp The {@link SystemStreamPartition} to fetch the {@link Startpoint} for.
-   * @return {@link Startpoint} for the {@link SystemStreamPartition}, or null if it does not exist or if it is too stale
+   * @return {@link Optional} of {@link Startpoint} for the {@link SystemStreamPartition}.
+   *         It is empty if it does not exist or if it is too stale.
    */
-  public Startpoint readStartpoint(SystemStreamPartition ssp) {
+  public Optional<Startpoint> readStartpoint(SystemStreamPartition ssp) {
     return readStartpoint(ssp, null);
   }
 
@@ -137,23 +165,29 @@
    * Returns the {@link Startpoint} for a {@link SystemStreamPartition} and {@link TaskName}.
    * @param ssp The {@link SystemStreamPartition} to fetch the {@link Startpoint} for.
    * @param taskName The {@link TaskName} to fetch the {@link Startpoint} for.
-   * @return {@link Startpoint} for the {@link SystemStreamPartition}, or null if it does not exist or if it is too stale.
+   * @return {@link Optional} of {@link Startpoint} for the {@link SystemStreamPartition} and {@link TaskName}.
+   *         It is empty if it does not exist or if it is too stale.
    */
-  public Startpoint readStartpoint(SystemStreamPartition ssp, TaskName taskName) {
+  public Optional<Startpoint> readStartpoint(SystemStreamPartition ssp, TaskName taskName) {
     Preconditions.checkState(!stopped, "Underlying metadata store not available");
     Preconditions.checkNotNull(ssp, "SystemStreamPartition cannot be null");
 
-    byte[] startpointBytes = metadataStore.get(toStoreKey(ssp, taskName));
+    byte[] startpointBytes = readWriteStore.get(toReadWriteStoreKey(ssp, taskName));
 
-    if (Objects.nonNull(startpointBytes)) {
-      Startpoint startpoint = startpointSerde.fromBytes(startpointBytes);
-      if (Instant.now().minus(DEFAULT_EXPIRATION_DURATION).isBefore(Instant.ofEpochMilli(startpoint.getCreationTimestamp()))) {
-        return startpoint; // return if deserializable and if not stale
+    if (ArrayUtils.isNotEmpty(startpointBytes)) {
+      try {
+        Startpoint startpoint = objectMapper.readValue(startpointBytes, Startpoint.class);
+        if (Instant.now().minus(DEFAULT_EXPIRATION_DURATION).isBefore(Instant.ofEpochMilli(startpoint.getCreationTimestamp()))) {
+          return Optional.of(startpoint); // return if deserializable and if not stale
+        }
+        LOG.warn("Creation timestamp: {} of startpoint: {} has crossed the expiration duration: {}. Ignoring it",
+            startpoint.getCreationTimestamp(), startpoint, DEFAULT_EXPIRATION_DURATION);
+      } catch (IOException ex) {
+        throw new SamzaException(ex);
       }
-      LOG.warn("Stale Startpoint: {} was read. Ignoring.", startpoint);
     }
 
-    return null;
+    return Optional.empty();
   }
 
   /**
@@ -173,71 +207,148 @@
     Preconditions.checkState(!stopped, "Underlying metadata store not available");
     Preconditions.checkNotNull(ssp, "SystemStreamPartition cannot be null");
 
-    metadataStore.delete(toStoreKey(ssp, taskName));
+    readWriteStore.delete(toReadWriteStoreKey(ssp, taskName));
   }
 
   /**
-   * For {@link Startpoint}s keyed only by {@link SystemStreamPartition}, this method re-maps the Startpoints from
-   * SystemStreamPartition to SystemStreamPartition+{@link TaskName} for all tasks provided by the {@link JobModel}
+   * The Startpoints that are written with {@link #writeStartpoint(SystemStreamPartition, Startpoint)} and with
+   * {@link #writeStartpoint(SystemStreamPartition, TaskName, Startpoint)} are moved from a "read-write" namespace
+   * to a "fan out" namespace.
    * This method is not atomic or thread-safe. The intent is for the Samza Processor's coordinator to use this
    * method to assign the Startpoints to the appropriate tasks.
-   * @param jobModel The {@link JobModel} is used to determine which {@link TaskName} each {@link SystemStreamPartition} maps to.
-   * @return The list of {@link SystemStreamPartition}s that were fanned out to SystemStreamPartition+TaskName.
+   * @param taskToSSPs Determines which {@link TaskName} each {@link SystemStreamPartition} maps to.
+   * @return The fanned out {@link Startpoint}s, keyed first by {@link TaskName} and then by {@link SystemStreamPartition}.
+   * @throws IOException if the fan outs could not be serialized and written to the metadata store.
    */
-  public Set<SystemStreamPartition> fanOutStartpointsToTasks(JobModel jobModel) {
+  public Map<TaskName, Map<SystemStreamPartition, Startpoint>> fanOut(Map<TaskName, Set<SystemStreamPartition>> taskToSSPs) throws IOException {
     Preconditions.checkState(!stopped, "Underlying metadata store not available");
-    Preconditions.checkNotNull(jobModel, "JobModel cannot be null");
+    Preconditions.checkArgument(MapUtils.isNotEmpty(taskToSSPs), "taskToSSPs cannot be null or empty");
 
-    HashSet<SystemStreamPartition> sspsToDelete = new HashSet<>();
+    // construct fan out with the existing readWriteStore entries and mark the entries for deletion after fan out
+    Instant now = Instant.now();
+    HashMultimap<SystemStreamPartition, TaskName> deleteKeys = HashMultimap.create();
+    HashMap<TaskName, StartpointFanOutPerTask> fanOuts = new HashMap<>();
+    for (TaskName taskName : taskToSSPs.keySet()) {
+      Set<SystemStreamPartition> ssps = taskToSSPs.get(taskName);
+      if (CollectionUtils.isEmpty(ssps)) {
+        LOG.warn("No SSPs are mapped to taskName: {}", taskName.getTaskName());
+        continue;
+      }
+      for (SystemStreamPartition ssp : ssps) {
+        Optional<Startpoint> startpoint = readStartpoint(ssp); // Read SSP-only key
+        startpoint.ifPresent(sp -> deleteKeys.put(ssp, null));
 
-    // Inspect the job model for TaskName-to-SSPs mapping and re-map startpoints from SSP-only keys to SSP+TaskName keys.
-    for (ContainerModel containerModel: jobModel.getContainers().values()) {
-      for (TaskModel taskModel : containerModel.getTasks().values()) {
-        TaskName taskName = taskModel.getTaskName();
-        for (SystemStreamPartition ssp : taskModel.getSystemStreamPartitions()) {
-          Startpoint startpoint = readStartpoint(ssp); // Read SSP-only key
-          if (startpoint == null) {
-            LOG.debug("No Startpoint for SSP: {} in task: {}", ssp, taskName);
-            continue;
-          }
+        Optional<Startpoint> startpointForTask = readStartpoint(ssp, taskName); // Read SSP+taskName key
+        startpointForTask.ifPresent(sp -> deleteKeys.put(ssp, taskName));
 
-          LOG.info("Grouping Startpoint keyed on SSP: {} to tasks determined by the job model.", ssp);
-          Startpoint startpointForTask = readStartpoint(ssp, taskName);
-          if (startpointForTask == null || startpointForTask.getCreationTimestamp() < startpoint.getCreationTimestamp()) {
-            writeStartpoint(ssp, taskName, startpoint);
-            sspsToDelete.add(ssp); // Mark for deletion
-            LOG.info("Startpoint for SSP: {} remapped with task: {}.", ssp, taskName);
-          } else {
-            LOG.info("Startpoint for SSP: {} and task: {} already exists and will not be overwritten.", ssp, taskName);
-          }
+        Optional<Startpoint> startpointWithPrecedence = resolveStartpointPrecedence(startpoint, startpointForTask);
+        if (!startpointWithPrecedence.isPresent()) {
+          continue;
+        }
 
+        fanOuts.putIfAbsent(taskName, new StartpointFanOutPerTask(now));
+        fanOuts.get(taskName).getFanOuts().put(ssp, startpointWithPrecedence.get());
+      }
+    }
+
+    if (fanOuts.isEmpty()) {
+      LOG.debug("No fan outs created.");
+      return ImmutableMap.of();
+    }
+
+    LOG.info("Fanning out to {} tasks", fanOuts.size());
+
+    // Fan out to store
+    for (TaskName taskName : fanOuts.keySet()) {
+      String fanOutKey = toFanOutStoreKey(taskName);
+      StartpointFanOutPerTask newFanOut = fanOuts.get(taskName);
+      fanOutStore.put(fanOutKey, objectMapper.writeValueAsBytes(newFanOut));
+    }
+
+    for (SystemStreamPartition ssp : deleteKeys.keySet()) {
+      for (TaskName taskName : deleteKeys.get(ssp)) {
+        if (taskName != null) {
+          deleteStartpoint(ssp, taskName);
+        } else {
+          deleteStartpoint(ssp);
         }
       }
     }
 
-    // Delete SSP-only keys
-    sspsToDelete.forEach(ssp -> {
-        deleteStartpoint(ssp);
-        LOG.info("All Startpoints for SSP: {} have been grouped to the appropriate tasks and the SSP was deleted.");
-      });
-
-    return ImmutableSet.copyOf(sspsToDelete);
+    return ImmutableMap.copyOf(fanOuts.entrySet().stream()
+        .collect(Collectors.toMap(fo -> fo.getKey(), fo -> fo.getValue().getFanOuts())));
   }
 
   /**
-   * Relinquish resources held by the underlying {@link MetadataStore}
+   * Read the fanned out {@link Startpoint}s for the given {@link TaskName}
+   * @param taskName to read the fan out Startpoints for
+   * @return fanned out Startpoints
+   * @throws IOException if the fan out could not be read or deserialized from the metadata store
    */
-  public void stop() {
-    stopped = true;
-    // Metadata store lifecycle is managed outside of the StartpointManager, so not closing it.
+  public Map<SystemStreamPartition, Startpoint> getFanOutForTask(TaskName taskName) throws IOException {
+    Preconditions.checkState(!stopped, "Underlying metadata store not available");
+    Preconditions.checkNotNull(taskName, "TaskName cannot be null");
+
+    byte[] fanOutBytes = fanOutStore.get(toFanOutStoreKey(taskName));
+    if (ArrayUtils.isEmpty(fanOutBytes)) {
+      return ImmutableMap.of();
+    }
+    StartpointFanOutPerTask startpointFanOutPerTask = objectMapper.readValue(fanOutBytes, StartpointFanOutPerTask.class);
+    return ImmutableMap.copyOf(startpointFanOutPerTask.getFanOuts());
+  }
+
+  /**
+   * Deletes the fanned out {@link Startpoint} for the given {@link TaskName}
+   * @param taskName to delete the fan out Startpoints for
+   */
+  public void removeFanOutForTask(TaskName taskName) {
+    Preconditions.checkState(!stopped, "Underlying metadata store not available");
+    Preconditions.checkNotNull(taskName, "TaskName cannot be null");
+
+    fanOutStore.delete(toFanOutStoreKey(taskName));
   }
 
   @VisibleForTesting
-  MetadataStore getMetadataStore() {
-    return metadataStore;
+  MetadataStore getReadWriteStore() {
+    return readWriteStore;
   }
 
-  private static String toStoreKey(SystemStreamPartition ssp, TaskName taskName) {
-    return new String(new JsonSerdeV2<>().toBytes(new StartpointKey(ssp, taskName)));
+  @VisibleForTesting
+  MetadataStore getFanOutStore() {
+    return fanOutStore;
+  }
+
+  @VisibleForTesting
+  ObjectMapper getObjectMapper() {
+    return objectMapper;
+  }
+
+  private static Optional<Startpoint> resolveStartpointPrecedence(Optional<Startpoint> startpoint1, Optional<Startpoint> startpoint2) {
+    if (startpoint1.isPresent() && startpoint2.isPresent()) {
+      // if SSP-only and SSP+taskName startpoints both exist, resolve to the one with the latest timestamp
+      if (startpoint1.get().getCreationTimestamp() > startpoint2.get().getCreationTimestamp()) {
+        return startpoint1;
+      }
+      return startpoint2;
+    }
+    return startpoint1.isPresent() ? startpoint1 : startpoint2;
+  }
+
+  private static String toReadWriteStoreKey(SystemStreamPartition ssp, TaskName taskName) {
+    Preconditions.checkArgument(ssp != null, "SystemStreamPartition should be defined");
+    Preconditions.checkArgument(StringUtils.isNotBlank(ssp.getSystem()), "System should be defined");
+    Preconditions.checkArgument(StringUtils.isNotBlank(ssp.getStream()), "Stream should be defined");
+    Preconditions.checkArgument(ssp.getPartition() != null, "Partition should be defined");
+
+    String storeKey = ssp.getSystem() + "." + ssp.getStream() + "." + String.valueOf(ssp.getPartition().getPartitionId());
+    if (taskName != null) {
+      storeKey += "." + taskName.getTaskName();
+    }
+    return storeKey;
+  }
+
+  private static String toFanOutStoreKey(TaskName taskName) {
+    Preconditions.checkArgument(taskName != null, "TaskName should be defined");
+    Preconditions.checkArgument(StringUtils.isNotBlank(taskName.getTaskName()), "TaskName should not be blank");
+
+    return taskName.getTaskName();
   }
 }
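
For context, a minimal end-to-end sketch of the new fan-out lifecycle: an SSP-keyed Startpoint is
written to the read-write namespace, the coordinator fans it out per task, and each task reads and
removes its own fan-out. The StartpointManager constructor taking a MetadataStore is assumed from
the diff context; everything else uses standard Samza types.

import java.util.Map;
import java.util.Set;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.apache.samza.Partition;
import org.apache.samza.container.TaskName;
import org.apache.samza.metadatastore.MetadataStore;
import org.apache.samza.startpoint.Startpoint;
import org.apache.samza.startpoint.StartpointManager;
import org.apache.samza.startpoint.StartpointOldest;
import org.apache.samza.system.SystemStreamPartition;

public class StartpointFanOutExample {
  public static void run(MetadataStore metadataStore) throws Exception {
    StartpointManager manager = new StartpointManager(metadataStore); // assumed constructor
    manager.start();

    SystemStreamPartition ssp = new SystemStreamPartition("kafka", "orders", new Partition(0));
    TaskName task = new TaskName("Partition 0");

    // 1. Write an SSP-keyed Startpoint into the read-write namespace.
    manager.writeStartpoint(ssp, new StartpointOldest());

    // 2. The coordinator fans the read-write entries out to per-task entries.
    Map<TaskName, Set<SystemStreamPartition>> taskToSSPs = ImmutableMap.of(task, ImmutableSet.of(ssp));
    manager.fanOut(taskToSSPs);

    // 3. Each task reads, applies, and finally removes its own fan-out.
    Map<SystemStreamPartition, Startpoint> fanOut = manager.getFanOutForTask(task);
    fanOut.forEach((sspKey, sp) -> System.out.println(sspKey + " -> " + sp));
    manager.removeFanOutForTask(task);

    manager.stop();
  }
}
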
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointObjectMapper.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointObjectMapper.java
new file mode 100644
index 0000000..c5c0434
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointObjectMapper.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.startpoint;
+
+import java.io.IOException;
+import java.time.Instant;
+import org.codehaus.jackson.JsonGenerator;
+import org.codehaus.jackson.JsonParser;
+import org.codehaus.jackson.JsonProcessingException;
+import org.codehaus.jackson.Version;
+import org.codehaus.jackson.map.DeserializationContext;
+import org.codehaus.jackson.map.JsonDeserializer;
+import org.codehaus.jackson.map.JsonSerializer;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.codehaus.jackson.map.SerializerProvider;
+import org.codehaus.jackson.map.jsontype.NamedType;
+import org.codehaus.jackson.map.module.SimpleModule;
+
+
+class StartpointObjectMapper {
+
+  static ObjectMapper getObjectMapper() {
+    ObjectMapper objectMapper = new ObjectMapper();
+    SimpleModule module = new SimpleModule("StartpointModule", new Version(1, 0, 0, ""));
+    module.addSerializer(Instant.class, new CustomInstantSerializer());
+    module.addDeserializer(Instant.class, new CustomInstantDeserializer());
+    objectMapper.registerModule(module);
+
+    // 1. To support polymorphism for serialization, the Startpoint subtypes must be registered here.
+    // 2. The NamedType container class provides a logical name as an external identifier so that the full canonical
+    //    class name is not serialized into the json type property.
+    objectMapper.registerSubtypes(new NamedType(StartpointSpecific.class));
+    objectMapper.registerSubtypes(new NamedType(StartpointTimestamp.class));
+    objectMapper.registerSubtypes(new NamedType(StartpointUpcoming.class));
+    objectMapper.registerSubtypes(new NamedType(StartpointOldest.class));
+
+    return objectMapper;
+  }
+
+  private StartpointObjectMapper() { }
+
+  static class CustomInstantSerializer extends JsonSerializer<Instant> {
+    @Override
+    public void serialize(Instant value, JsonGenerator jsonGenerator, SerializerProvider provider)
+        throws IOException, JsonProcessingException {
+      jsonGenerator.writeObject(Long.valueOf(value.toEpochMilli()));
+    }
+  }
+
+  static class CustomInstantDeserializer extends JsonDeserializer<Instant> {
+    @Override
+    public Instant deserialize(JsonParser jsonParser, DeserializationContext context) throws IOException, JsonProcessingException {
+      return Instant.ofEpochMilli(jsonParser.getLongValue());
+    }
+  }
+}
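
A round-trip sketch of the polymorphic mapper above. It assumes the Startpoint base class carries a
Jackson @JsonTypeInfo annotation (not shown in this patch) so the NamedType registration can embed a
logical subtype name, and that StartpointOldest has a no-arg constructor. Since getObjectMapper() is
package-private, the sketch lives in the same package. Note that the CustomInstant(De)Serializer
pair stores Instants as epoch milliseconds.

package org.apache.samza.startpoint;

import java.io.IOException;
import org.codehaus.jackson.map.ObjectMapper;

class StartpointRoundTripExample {
  static void roundTrip() throws IOException {
    ObjectMapper mapper = StartpointObjectMapper.getObjectMapper();

    // The registered NamedType embeds a logical name (e.g. "StartpointOldest")
    // into the JSON instead of the canonical class name.
    byte[] json = mapper.writeValueAsBytes(new StartpointOldest());

    // Deserializing against the base class resolves back to the subtype.
    Startpoint restored = mapper.readValue(json, Startpoint.class);
    assert restored instanceof StartpointOldest;
  }
}
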
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointSerde.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointSerde.java
deleted file mode 100644
index 361aef3..0000000
--- a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointSerde.java
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import com.google.common.collect.ImmutableMap;
-import java.util.Arrays;
-import java.util.LinkedHashMap;
-import org.apache.samza.SamzaException;
-import org.apache.samza.serializers.Serde;
-import org.codehaus.jackson.map.ObjectMapper;
-
-
-class StartpointSerde implements Serde<Startpoint> {
-  private static final String STARTPOINT_CLASS = "startpointClass";
-  private static final String STARTPOINT_OBJ = "startpointObj";
-
-  private final ObjectMapper mapper = new ObjectMapper();
-
-  @Override
-  public Startpoint fromBytes(byte[] bytes) {
-    try {
-      LinkedHashMap<String, String> deserialized = mapper.readValue(bytes, LinkedHashMap.class);
-      Class<? extends Startpoint> startpointClass =
-          (Class<? extends Startpoint>) Class.forName(deserialized.get(STARTPOINT_CLASS));
-      return mapper.readValue(deserialized.get(STARTPOINT_OBJ), startpointClass);
-    } catch (Exception e) {
-      throw new SamzaException(String.format("Exception in de-serializing startpoint bytes: %s",
-          Arrays.toString(bytes)), e);
-    }
-  }
-
-  @Override
-  public byte[] toBytes(Startpoint startpoint) {
-    try {
-      ImmutableMap.Builder<String, String> mapBuilder = ImmutableMap.builder();
-      mapBuilder.put(STARTPOINT_CLASS, startpoint.getClass().getCanonicalName());
-      mapBuilder.put(STARTPOINT_OBJ, mapper.writeValueAsString(startpoint));
-      return mapper.writeValueAsBytes(mapBuilder.build());
-    } catch (Exception e) {
-      throw new SamzaException(String.format("Exception in serializing: %s", startpoint), e);
-    }
-  }
-}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java b/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java
index 2fde79a..dee0767 100644
--- a/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java
@@ -95,9 +95,9 @@
    * @param records result map
    * @return list of keys missed in the cache
    */
-  private List<K> lookupCache(List<K> keys, Map<K, V> records) {
+  private List<K> lookupCache(List<K> keys, Map<K, V> records, Object ... args) {
     List<K> missKeys = new ArrayList<>();
-    records.putAll(cache.getAll(keys));
+    records.putAll(cache.getAll(keys, args));
     keys.forEach(k -> {
         if (!records.containsKey(k)) {
           missKeys.add(k);
@@ -107,18 +107,18 @@
   }
 
   @Override
-  public V get(K key) {
+  public V get(K key, Object ... args) {
     try {
-      return getAsync(key).get();
+      return getAsync(key, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<V> getAsync(K key) {
+  public CompletableFuture<V> getAsync(K key, Object ... args) {
     incCounter(metrics.numGets);
-    V value = cache.get(key);
+    V value = cache.get(key, args);
     if (value != null) {
       hitCount.incrementAndGet();
       return CompletableFuture.completedFuture(value);
@@ -127,12 +127,12 @@
     long startNs = clock.nanoTime();
     missCount.incrementAndGet();
 
-    return table.getAsync(key).handle((result, e) -> {
+    return table.getAsync(key, args).handle((result, e) -> {
         if (e != null) {
           throw new SamzaException("Failed to get the record for " + key, e);
         } else {
           if (result != null) {
-            cache.put(key, result);
+            cache.put(key, result, args);
           }
           updateTimer(metrics.getNs, clock.nanoTime() - startNs);
           return result;
@@ -141,16 +141,16 @@
   }
 
   @Override
-  public Map<K, V> getAll(List<K> keys) {
+  public Map<K, V> getAll(List<K> keys, Object ... args) {
     try {
-      return getAllAsync(keys).get();
+      return getAllAsync(keys, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
     incCounter(metrics.numGetAlls);
     // Make a copy of entries which might be immutable
     Map<K, V> getAllResult = new HashMap<>();
@@ -161,14 +161,14 @@
     }
 
     long startNs = clock.nanoTime();
-    return table.getAllAsync(missingKeys).handle((records, e) -> {
+    return table.getAllAsync(missingKeys, args).handle((records, e) -> {
         if (e != null) {
           throw new SamzaException("Failed to get records for " + keys, e);
         } else {
           if (records != null) {
             cache.putAll(records.entrySet().stream()
                 .map(r -> new Entry<>(r.getKey(), r.getValue()))
-                .collect(Collectors.toList()));
+                .collect(Collectors.toList()), args);
             getAllResult.putAll(records);
           }
           updateTimer(metrics.getAllNs, clock.nanoTime() - startNs);
@@ -178,28 +178,28 @@
   }
 
   @Override
-  public void put(K key, V value) {
+  public void put(K key, V value, Object ... args) {
     try {
-      putAsync(key, value).get();
+      putAsync(key, value, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Void> putAsync(K key, V value) {
+  public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
     incCounter(metrics.numPuts);
     Preconditions.checkNotNull(table, "Cannot write to a read-only table: " + table);
 
     long startNs = clock.nanoTime();
-    return table.putAsync(key, value).handle((result, e) -> {
+    return table.putAsync(key, value, args).handle((result, e) -> {
         if (e != null) {
           throw new SamzaException(String.format("Failed to put a record, key=%s, value=%s", key, value), e);
         } else if (!isWriteAround) {
           if (value == null) {
-            cache.delete(key);
+            cache.delete(key, args);
           } else {
-            cache.put(key, value);
+            cache.put(key, value, args);
           }
         }
         updateTimer(metrics.putNs, clock.nanoTime() - startNs);
@@ -208,24 +208,24 @@
   }
 
   @Override
-  public void putAll(List<Entry<K, V>> records) {
+  public void putAll(List<Entry<K, V>> records, Object ... args) {
     try {
-      putAllAsync(records).get();
+      putAllAsync(records, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records) {
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records, Object ... args) {
     incCounter(metrics.numPutAlls);
     long startNs = clock.nanoTime();
     Preconditions.checkNotNull(table, "Cannot write to a read-only table: " + table);
-    return table.putAllAsync(records).handle((result, e) -> {
+    return table.putAllAsync(records, args).handle((result, e) -> {
         if (e != null) {
           throw new SamzaException("Failed to put records " + records, e);
         } else if (!isWriteAround) {
-          cache.putAll(records);
+          cache.putAll(records, args);
         }
 
         updateTimer(metrics.putAllNs, clock.nanoTime() - startNs);
@@ -234,24 +234,24 @@
   }
 
   @Override
-  public void delete(K key) {
+  public void delete(K key, Object ... args) {
     try {
-      deleteAsync(key).get();
+      deleteAsync(key, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Void> deleteAsync(K key) {
+  public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
     incCounter(metrics.numDeletes);
     long startNs = clock.nanoTime();
     Preconditions.checkNotNull(table, "Cannot delete from a read-only table: " + table);
-    return table.deleteAsync(key).handle((result, e) -> {
+    return table.deleteAsync(key, args).handle((result, e) -> {
         if (e != null) {
           throw new SamzaException("Failed to delete the record for " + key, e);
         } else if (!isWriteAround) {
-          cache.delete(key);
+          cache.delete(key, args);
         }
         updateTimer(metrics.deleteNs, clock.nanoTime() - startNs);
         return result;
@@ -259,24 +259,24 @@
   }
 
   @Override
-  public void deleteAll(List<K> keys) {
+  public void deleteAll(List<K> keys, Object ... args) {
     try {
-      deleteAllAsync(keys).get();
+      deleteAllAsync(keys, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
     incCounter(metrics.numDeleteAlls);
     long startNs = clock.nanoTime();
     Preconditions.checkNotNull(table, "Cannot delete from a read-only table: " + table);
-    return table.deleteAllAsync(keys).handle((result, e) -> {
+    return table.deleteAllAsync(keys, args).handle((result, e) -> {
         if (e != null) {
           throw new SamzaException("Failed to delete the record for " + keys, e);
         } else if (!isWriteAround) {
-          cache.deleteAll(keys);
+          cache.deleteAll(keys, args);
         }
         updateTimer(metrics.deleteAllNs, clock.nanoTime() - startNs);
         return result;
@@ -284,6 +284,32 @@
   }
 
   @Override
+  public <T> CompletableFuture<T> readAsync(int opId, Object... args) {
+    incCounter(metrics.numReads);
+    long startNs = clock.nanoTime();
+    return table.readAsync(opId, args).handle((result, e) -> {
+        if (e != null) {
+          throw new SamzaException("Failed to read, opId=" + opId, e);
+        }
+        updateTimer(metrics.readNs, clock.nanoTime() - startNs);
+        return (T) result;
+      });
+  }
+
+  @Override
+  public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+    incCounter(metrics.numWrites);
+    long startNs = clock.nanoTime();
+    return table.writeAsync(opId, args).handle((result, e) -> {
+        if (e != null) {
+          throw new SamzaException("Failed to write, opId=" + opId, e);
+        }
+        updateTimer(metrics.writeNs, clock.nanoTime() - startNs);
+        return (T) result;
+      });
+  }
+
+  @Override
   public synchronized void flush() {
     incCounter(metrics.numFlushes);
     long startNs = clock.nanoTime();
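
CachingTable#getAsync implements a read-through cache: hits are served from the cache, and on a
miss the backing table is queried and the cache populated before the future completes. A
self-contained sketch of that pattern, with illustrative names:

import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

public class ReadThroughCache<K, V> {
  private final Map<K, V> cache = new ConcurrentHashMap<>();
  private final Function<K, CompletableFuture<V>> backingFetch;

  public ReadThroughCache(Function<K, CompletableFuture<V>> backingFetch) {
    this.backingFetch = backingFetch;
  }

  public CompletableFuture<V> getAsync(K key) {
    V cached = cache.get(key);
    if (cached != null) {
      return CompletableFuture.completedFuture(cached); // cache hit
    }
    // Cache miss: fetch from the backing store and remember the result.
    return backingFetch.apply(key).thenApply(value -> {
      if (value != null) {
        cache.put(key, value);
      }
      return value;
    });
  }
}
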
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java b/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java
index d8a5d9c..02083f3 100644
--- a/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java
@@ -59,7 +59,7 @@
   }
 
   @Override
-  public V get(K key) {
+  public V get(K key, Object ... args) {
     try {
       return getAsync(key).get();
     } catch (Exception e) {
@@ -68,7 +68,7 @@
   }
 
   @Override
-  public CompletableFuture<V> getAsync(K key) {
+  public CompletableFuture<V> getAsync(K key, Object ... args) {
     CompletableFuture<V> future = new CompletableFuture<>();
     try {
       future.complete(cache.getIfPresent(key));
@@ -79,7 +79,7 @@
   }
 
   @Override
-  public Map<K, V> getAll(List<K> keys) {
+  public Map<K, V> getAll(List<K> keys, Object ... args) {
     try {
       return getAllAsync(keys).get();
     } catch (Exception e) {
@@ -88,7 +88,7 @@
   }
 
   @Override
-  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
     CompletableFuture<Map<K, V>> future = new CompletableFuture<>();
     try {
       future.complete(cache.getAllPresent(keys));
@@ -99,7 +99,7 @@
   }
 
   @Override
-  public void put(K key, V value) {
+  public void put(K key, V value, Object ... args) {
     try {
       putAsync(key, value).get();
     } catch (Exception e) {
@@ -108,7 +108,7 @@
   }
 
   @Override
-  public CompletableFuture<Void> putAsync(K key, V value) {
+  public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
     if (key == null) {
       return deleteAsync(key);
     }
@@ -124,7 +124,7 @@
   }
 
   @Override
-  public void putAll(List<Entry<K, V>> entries) {
+  public void putAll(List<Entry<K, V>> entries, Object ... args) {
     try {
       putAllAsync(entries).get();
     } catch (Exception e) {
@@ -133,7 +133,7 @@
   }
 
   @Override
-  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
     CompletableFuture<Void> future = new CompletableFuture<>();
     try {
       // Separate out put vs delete records
@@ -157,7 +157,7 @@
   }
 
   @Override
-  public void delete(K key) {
+  public void delete(K key, Object ... args) {
     try {
       deleteAsync(key).get();
     } catch (Exception e) {
@@ -166,7 +166,7 @@
   }
 
   @Override
-  public CompletableFuture<Void> deleteAsync(K key) {
+  public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
     CompletableFuture<Void> future = new CompletableFuture<>();
     try {
       cache.invalidate(key);
@@ -178,7 +178,7 @@
   }
 
   @Override
-  public void deleteAll(List<K> keys) {
+  public void deleteAll(List<K> keys, Object ... args) {
     try {
       deleteAllAsync(keys).get();
     } catch (Exception e) {
@@ -187,7 +187,7 @@
   }
 
   @Override
-  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
     CompletableFuture<Void> future = new CompletableFuture<>();
     try {
       cache.invalidateAll(keys);
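
GuavaCacheTable is a thin adapter over a Guava Cache: each table operation delegates to the
corresponding cache call, and the extra args are accepted but ignored. A minimal sketch of the
underlying calls; the builder settings are illustrative, since the table is handed a pre-built cache.

import java.util.concurrent.TimeUnit;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;

class GuavaCacheSketch {
  static void demo() {
    Cache<String, String> cache = CacheBuilder.newBuilder()
        .maximumSize(10_000)                      // bound the cache
        .expireAfterWrite(5, TimeUnit.MINUTES)    // evict stale entries
        .build();

    cache.put("k1", "v1");                        // backs putAsync
    String hit = cache.getIfPresent("k1");        // backs getAsync; null on a miss
    System.out.println(hit);
    cache.invalidate("k1");                       // backs deleteAsync
  }
}
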
diff --git a/samza-core/src/main/java/org/apache/samza/table/ratelimit/AsyncRateLimitedTable.java b/samza-core/src/main/java/org/apache/samza/table/ratelimit/AsyncRateLimitedTable.java
index 69f3dd3..75fed12 100644
--- a/samza-core/src/main/java/org/apache/samza/table/ratelimit/AsyncRateLimitedTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/ratelimit/AsyncRateLimitedTable.java
@@ -19,10 +19,12 @@
 package org.apache.samza.table.ratelimit;
 
 import com.google.common.base.Preconditions;
+
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
+
 import org.apache.samza.config.MetricsConfig;
 import org.apache.samza.context.Context;
 import org.apache.samza.storage.kv.Entry;
@@ -30,6 +32,8 @@
 import org.apache.samza.table.remote.TableRateLimiter;
 import org.apache.samza.table.utils.TableMetricsUtil;
 
+import static org.apache.samza.table.BaseReadWriteTable.Func0;
+import static org.apache.samza.table.BaseReadWriteTable.Func1;
 
 /**
  * A composable read and/or write rate limited asynchronous table implementation
@@ -60,57 +64,59 @@
   }
 
   @Override
-  public CompletableFuture<V> getAsync(K key) {
-    return isReadRateLimited()
-      ? CompletableFuture
-          .runAsync(() -> readRateLimiter.throttle(key), rateLimitingExecutor)
-          .thenCompose((r) -> table.getAsync(key))
-      : table.getAsync(key);
+  public CompletableFuture<V> getAsync(K key, Object ... args) {
+    return doRead(
+        () -> readRateLimiter.throttle(key, args),
+        () -> table.getAsync(key, args));
   }
 
   @Override
-  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
-    return isReadRateLimited()
-      ? CompletableFuture
-          .runAsync(() -> readRateLimiter.throttle(keys), rateLimitingExecutor)
-          .thenCompose((r) -> table.getAllAsync(keys))
-      : table.getAllAsync(keys);
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
+    return doRead(
+        () -> readRateLimiter.throttle(keys, args),
+        () -> table.getAllAsync(keys, args));
   }
 
   @Override
-  public CompletableFuture<Void> putAsync(K key, V value) {
-    return isWriteRateLimited()
-        ? CompletableFuture
-            .runAsync(() -> writeRateLimiter.throttle(key, value), rateLimitingExecutor)
-            .thenCompose((r) -> table.putAsync(key, value))
-        : table.putAsync(key, value);
+  public <T> CompletableFuture<T> readAsync(int opId, Object ... args) {
+    return doRead(
+        () -> readRateLimiter.throttle(opId, args),
+        () -> table.readAsync(opId, args));
   }
 
   @Override
-  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
-    return isWriteRateLimited()
-      ? CompletableFuture
-          .runAsync(() -> writeRateLimiter.throttleRecords(entries), rateLimitingExecutor)
-          .thenCompose((r) -> table.putAllAsync(entries))
-      : table.putAllAsync(entries);
+  public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
+    return doWrite(
+        () -> writeRateLimiter.throttle(key, value, args),
+        () -> table.putAsync(key, value, args));
   }
 
   @Override
-  public CompletableFuture<Void> deleteAsync(K key) {
-    return isWriteRateLimited()
-      ? CompletableFuture
-          .runAsync(() -> writeRateLimiter.throttle(key), rateLimitingExecutor)
-          .thenCompose((r) -> table.deleteAsync(key))
-      : table.deleteAsync(key);
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
+    return doWrite(
+        () -> writeRateLimiter.throttleRecords(entries),
+        () -> table.putAllAsync(entries, args));
   }
 
   @Override
-  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
-    return isWriteRateLimited()
-      ? CompletableFuture
-          .runAsync(() -> writeRateLimiter.throttle(keys), rateLimitingExecutor)
-          .thenCompose((r) -> table.deleteAllAsync(keys))
-      : table.deleteAllAsync(keys);
+  public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
+    return doWrite(
+        () -> writeRateLimiter.throttle(key, args),
+        () -> table.deleteAsync(key, args));
+  }
+
+  @Override
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
+    return doWrite(
+        () -> writeRateLimiter.throttle(keys, args),
+        () -> table.deleteAllAsync(keys, args));
+  }
+
+  @Override
+  public <T> CompletableFuture<T> writeAsync(int opId, Object ... args) {
+    return doWrite(
+        () -> writeRateLimiter.throttle(opId, args),
+        () -> table.writeAsync(opId, args));
   }
 
   @Override
@@ -145,4 +151,21 @@
   private boolean isWriteRateLimited() {
     return writeRateLimiter != null;
   }
+
+  private <T> CompletableFuture<T> doRead(Func0 throttleFunc, Func1<T> func) {
+    return isReadRateLimited()
+        ? CompletableFuture
+            .runAsync(() -> throttleFunc.apply(), rateLimitingExecutor)
+            .thenCompose((r) -> func.apply())
+        : func.apply();
+  }
+
+  private <T> CompletableFuture<T> doWrite(Func0 throttleFunc, Func1<T> func) {
+    return isWriteRateLimited()
+        ? CompletableFuture
+            .runAsync(() -> throttleFunc.apply(), rateLimitingExecutor)
+            .thenCompose((r) -> func.apply())
+        : func.apply();
+  }
+
 }
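
The new doRead/doWrite helpers factor out a throttle-then-compose pattern: throttle() blocks while
acquiring credits, so it runs on the dedicated rate-limiting executor, and the real async call is
chained behind it. The same shape in isolation, with illustrative names:

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Supplier;

public class ThrottledCall {
  private final ExecutorService rateLimitingExecutor = Executors.newSingleThreadExecutor();

  public <T> CompletableFuture<T> withThrottle(Runnable throttle, Supplier<CompletableFuture<T>> call) {
    return CompletableFuture
        .runAsync(throttle, rateLimitingExecutor) // acquire credits off the caller thread
        .thenCompose(ignored -> call.get());      // then issue the real async operation
  }
}
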
diff --git a/samza-core/src/main/java/org/apache/samza/table/remote/AsyncRemoteTable.java b/samza-core/src/main/java/org/apache/samza/table/remote/AsyncRemoteTable.java
index d4dbc03..4b1851b 100644
--- a/samza-core/src/main/java/org/apache/samza/table/remote/AsyncRemoteTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/remote/AsyncRemoteTable.java
@@ -19,9 +19,11 @@
 package org.apache.samza.table.remote;
 
 import com.google.common.base.Preconditions;
+
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
+
 import org.apache.samza.context.Context;
 import org.apache.samza.storage.kv.Entry;
 import org.apache.samza.table.AsyncReadWriteTable;
@@ -46,45 +48,66 @@
   }
 
   @Override
-  public CompletableFuture<V> getAsync(K key) {
-    return readFn.getAsync(key);
+  public CompletableFuture<V> getAsync(K key, Object ... args) {
+    return args.length > 0
+        ? readFn.getAsync(key, args)
+        : readFn.getAsync(key);
   }
 
   @Override
-  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
-    return readFn.getAllAsync(keys);
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
+    return args.length > 0
+        ? readFn.getAllAsync(keys, args)
+        : readFn.getAllAsync(keys);
   }
 
   @Override
-  public CompletableFuture<Void> putAsync(K key, V value) {
+  public <T> CompletableFuture<T> readAsync(int opId, Object... args) {
+    return readFn.readAsync(opId, args);
+  }
+
+  @Override
+  public CompletableFuture<Void> putAsync(K key, V record, Object... args) {
     Preconditions.checkNotNull(writeFn, "null writeFn");
-    return writeFn.putAsync(key, value);
+    return args.length > 0
+        ? writeFn.putAsync(key, record, args)
+        : writeFn.putAsync(key, record);
   }
 
   @Override
-  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
     Preconditions.checkNotNull(writeFn, "null writeFn");
-    return writeFn.putAllAsync(entries);
+    return args.length > 0
+        ? writeFn.putAllAsync(entries, args)
+        : writeFn.putAllAsync(entries);
   }
 
   @Override
-  public CompletableFuture<Void> deleteAsync(K key) {
+  public CompletableFuture<Void> deleteAsync(K key, Object... args) {
     Preconditions.checkNotNull(writeFn, "null writeFn");
-    return writeFn.deleteAsync(key);
+    return args.length > 0
+        ? writeFn.deleteAsync(key, args)
+        : writeFn.deleteAsync(key);
   }
 
   @Override
-  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
     Preconditions.checkNotNull(writeFn, "null writeFn");
-    return writeFn.deleteAllAsync(keys);
+    return args.length > 0
+        ? writeFn.deleteAllAsync(keys, args)
+        : writeFn.deleteAllAsync(keys);
+  }
+
+  @Override
+  public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+    Preconditions.checkNotNull(writeFn, "null writeFn");
+    return writeFn.writeAsync(opId, args);
   }
 
   @Override
   public void init(Context context) {
-    readFn.init(context);
-    if (writeFn != null) {
-      writeFn.init(context);
-    }
+    // Note: Initialization of table functions is done in {@link RemoteTable#init(Context)},
+    //       as we need to pass in the reference to the top-level table.
   }
 
   @Override
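
The args.length checks above preserve backward compatibility: a read or write function that only
overrides the legacy single-argument overload keeps working, because the varargs overload is only
invoked when extra arguments were actually supplied. A sketch of that dispatch with hypothetical types:

import java.util.concurrent.CompletableFuture;

interface Fetcher<K, V> {
  CompletableFuture<V> getAsync(K key);           // legacy overload

  default CompletableFuture<V> getAsync(K key, Object... args) {
    return getAsync(key);                         // default implementation ignores args
  }
}

class Dispatcher<K, V> {
  private final Fetcher<K, V> fetcher;

  Dispatcher(Fetcher<K, V> fetcher) {
    this.fetcher = fetcher;
  }

  CompletableFuture<V> get(K key, Object... args) {
    // Mirror AsyncRemoteTable: only use the varargs overload when args exist.
    return args.length > 0 ? fetcher.getAsync(key, args) : fetcher.getAsync(key);
  }
}
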
diff --git a/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTable.java b/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTable.java
index 8301661..23f4c01 100644
--- a/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTable.java
@@ -74,8 +74,8 @@
  * @param <K> the type of the key in this table
  * @param <V> the type of the value in this table
  */
-public class RemoteTable<K, V> extends BaseReadWriteTable<K, V>
-    implements ReadWriteTable<K, V> {
+public final class RemoteTable<K, V> extends BaseReadWriteTable<K, V>
+    implements ReadWriteTable<K, V>, AsyncReadWriteTable<K, V> {
 
   // Read/write functions
   protected final TableReadFunction<K, V> readFn;
@@ -144,18 +144,18 @@
   }
 
   @Override
-  public V get(K key) {
+  public V get(K key, Object ... args) {
     try {
-      return getAsync(key).get();
+      return getAsync(key, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<V> getAsync(K key) {
+  public CompletableFuture<V> getAsync(K key, Object ... args) {
     Preconditions.checkNotNull(key, "null key");
-    return instrument(() -> asyncTable.getAsync(key), metrics.numGets, metrics.getNs)
+    return instrument(() -> asyncTable.getAsync(key, args), metrics.numGets, metrics.getNs)
         .handle((result, e) -> {
             if (e != null) {
               throw new SamzaException("Failed to get the records for " + key, e);
@@ -168,21 +168,21 @@
   }
 
   @Override
-  public Map<K, V> getAll(List<K> keys) {
+  public Map<K, V> getAll(List<K> keys, Object ... args) {
     try {
-      return getAllAsync(keys).get();
+      return getAllAsync(keys, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
     Preconditions.checkNotNull(keys, "null keys");
     if (keys.isEmpty()) {
       return CompletableFuture.completedFuture(Collections.EMPTY_MAP);
     }
-    return instrument(() -> asyncTable.getAllAsync(keys), metrics.numGetAlls, metrics.getAllNs)
+    return instrument(() -> asyncTable.getAllAsync(keys, args), metrics.numGetAlls, metrics.getAllNs)
         .handle((result, e) -> {
             if (e != null) {
               throw new SamzaException("Failed to get the records for " + keys, e);
@@ -193,39 +193,56 @@
   }
 
   @Override
-  public void put(K key, V value) {
+  public <T> T read(int opId, Object ... args) {
     try {
-      putAsync(key, value).get();
+      return (T) readAsync(opId, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Void> putAsync(K key, V value) {
+  public <T> CompletableFuture<T> readAsync(int opId, Object ... args) {
+    return (CompletableFuture<T>) instrument(() -> asyncTable.readAsync(opId, args), metrics.numReads, metrics.readNs)
+        .exceptionally(e -> {
+            throw new SamzaException(String.format("Failed to read, opId=%d", opId), e);
+          });
+  }
+
+  @Override
+  public void put(K key, V value, Object ... args) {
+    try {
+      putAsync(key, value, args).get();
+    } catch (Exception e) {
+      throw new SamzaException(e);
+    }
+  }
+
+  @Override
+  public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
     Preconditions.checkNotNull(writeFn, "null write function");
     Preconditions.checkNotNull(key, "null key");
     if (value == null) {
-      return deleteAsync(key);
+      return deleteAsync(key, args);
     }
 
-    return instrument(() -> asyncTable.putAsync(key, value), metrics.numPuts, metrics.putNs)
+    return instrument(() -> asyncTable.putAsync(key, value, args), metrics.numPuts, metrics.putNs)
         .exceptionally(e -> {
             throw new SamzaException("Failed to put a record with key=" + key, (Throwable) e);
           });
   }
 
   @Override
-  public void putAll(List<Entry<K, V>> entries) {
+  public void putAll(List<Entry<K, V>> entries, Object ... args) {
     try {
-      putAllAsync(entries).get();
+      putAllAsync(entries, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records) {
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records, Object ... args) {
 
     Preconditions.checkNotNull(writeFn, "null write function");
     Preconditions.checkNotNull(records, "null records");
@@ -241,12 +258,12 @@
 
     CompletableFuture<Void> deleteFuture = deleteKeys.isEmpty()
         ? CompletableFuture.completedFuture(null)
-        : deleteAllAsync(deleteKeys);
+        : deleteAllAsync(deleteKeys, args);
 
     // Return the combined future
     return CompletableFuture.allOf(
         deleteFuture,
-        instrument(() -> asyncTable.putAllAsync(putRecords), metrics.numPutAlls, metrics.putAllNs))
+        instrument(() -> asyncTable.putAllAsync(putRecords, args), metrics.numPutAlls, metrics.putAllNs))
         .exceptionally(e -> {
             String strKeys = records.stream().map(r -> r.getKey().toString()).collect(Collectors.joining(","));
             throw new SamzaException(String.format("Failed to put records with keys=" + strKeys), e);
@@ -254,44 +271,61 @@
   }
 
   @Override
-  public void delete(K key) {
+  public void delete(K key, Object ... args) {
     try {
-      deleteAsync(key).get();
+      deleteAsync(key, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Void> deleteAsync(K key) {
+  public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
     Preconditions.checkNotNull(writeFn, "null write function");
     Preconditions.checkNotNull(key, "null key");
-    return instrument(() -> asyncTable.deleteAsync(key), metrics.numDeletes, metrics.deleteNs)
+    return instrument(() -> asyncTable.deleteAsync(key, args), metrics.numDeletes, metrics.deleteNs)
         .exceptionally(e -> {
             throw new SamzaException(String.format("Failed to delete the record for " + key), (Throwable) e);
           });
   }
 
   @Override
-  public void deleteAll(List<K> keys) {
+  public void deleteAll(List<K> keys, Object ... args) {
     try {
-      deleteAllAsync(keys).get();
+      deleteAllAsync(keys, args).get();
     } catch (Exception e) {
       throw new SamzaException(e);
     }
   }
 
   @Override
-  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
     Preconditions.checkNotNull(writeFn, "null write function");
     Preconditions.checkNotNull(keys, "null keys");
     if (keys.isEmpty()) {
       return CompletableFuture.completedFuture(null);
     }
 
-    return instrument(() -> asyncTable.deleteAllAsync(keys), metrics.numDeleteAlls, metrics.deleteAllNs)
+    return instrument(() -> asyncTable.deleteAllAsync(keys, args), metrics.numDeleteAlls, metrics.deleteAllNs)
         .exceptionally(e -> {
-            throw new SamzaException(String.format("Failed to delete records for " + keys), (Throwable) e);
+            throw new SamzaException(String.format("Failed to delete records for " + keys), e);
+          });
+  }
+
+  @Override
+  public <T> T write(int opId, Object ... args) {
+    try {
+      return (T) writeAsync(opId, args).get();
+    } catch (Exception e) {
+      throw new SamzaException(e);
+    }
+  }
+
+  @Override
+  public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+    return (CompletableFuture<T>) instrument(() -> asyncTable.writeAsync(opId, args), metrics.numWrites, metrics.writeNs)
+        .exceptionally(e -> {
+            throw new SamzaException(String.format("Failed to write, opId=%d", opId), e);
           });
   }
 
@@ -299,6 +333,10 @@
   public void init(Context context) {
     super.init(context);
     asyncTable.init(context);
+    readFn.init(context, this);
+    if (writeFn != null) {
+      writeFn.init(context, this);
+    }
   }
 
   @Override
@@ -320,6 +358,14 @@
     asyncTable.close();
   }
 
+  public TableReadFunction<K, V> getReadFunction() {
+    return readFn;
+  }
+
+  public TableWriteFunction<K, V> getWriteFunction() {
+    return writeFn;
+  }
+
   protected <T> CompletableFuture<T> instrument(Func1<T> func, Counter counter, Timer timer) {
     incCounter(counter);
     final long startNs = clock.nanoTime();
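
read/readAsync and write/writeAsync give remote tables an escape hatch for operations that do not
fit the key-value methods, dispatching on an integer opId. A hedged sketch of how a table function
might service such a call; the opId constant, argument layout, and query are invented for illustration.

import java.util.concurrent.CompletableFuture;

class OpIdReadFunctionSketch {
  static final int OP_COUNT_BY_PREFIX = 1; // hypothetical operation id

  @SuppressWarnings("unchecked")
  <T> CompletableFuture<T> readAsync(int opId, Object... args) {
    if (opId == OP_COUNT_BY_PREFIX) {
      String prefix = (String) args[0];
      long count = prefix.length(); // placeholder for a store-specific query
      return (CompletableFuture<T>) CompletableFuture.completedFuture(count);
    }
    CompletableFuture<T> failed = new CompletableFuture<>();
    failed.completeExceptionally(new UnsupportedOperationException("Unknown opId: " + opId));
    return failed;
  }
}
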
diff --git a/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java b/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java
index 1716244..aca0a4b 100644
--- a/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java
+++ b/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java
@@ -20,6 +20,7 @@
 package org.apache.samza.table.remote;
 
 import com.google.common.base.Preconditions;
+
 import org.apache.samza.config.JavaTableConfig;
 import org.apache.samza.table.ReadWriteTable;
 import org.apache.samza.table.descriptors.RemoteTableDescriptor;
@@ -64,7 +65,7 @@
     JavaTableConfig tableConfig = new JavaTableConfig(context.getJobContext().getConfig());
 
     // Read part
-    TableReadFunction readFn = getReadFn(tableConfig);
+    TableReadFunction readFn = deserializeObject(tableConfig, RemoteTableDescriptor.READ_FN);
     RateLimiter rateLimiter = deserializeObject(tableConfig, RemoteTableDescriptor.RATE_LIMITER);
     if (rateLimiter != null) {
       rateLimiter.init(this.context);
@@ -77,9 +78,9 @@
     TableRetryPolicy readRetryPolicy = deserializeObject(tableConfig, RemoteTableDescriptor.READ_RETRY_POLICY);
 
     // Write part
-    TableWriteFunction writeFn = getWriteFn(tableConfig);
     TableRateLimiter writeRateLimiter = null;
     TableRetryPolicy writeRetryPolicy = null;
+    TableWriteFunction writeFn = deserializeObject(tableConfig, RemoteTableDescriptor.WRITE_FN);
     if (writeFn != null) {
       TableRateLimiter.CreditFunction<?, ?> writeCreditFn = deserializeObject(tableConfig, RemoteTableDescriptor.WRITE_CREDIT_FN);
       writeRateLimiter = rateLimiter != null && rateLimiter.getSupportedTags().contains(RemoteTableDescriptor.RL_WRITE_TAG)
@@ -130,7 +131,9 @@
     super.close();
     tables.forEach(t -> t.close());
     rateLimitingExecutors.values().forEach(e -> e.shutdown());
+    rateLimitingExecutors.clear();
     callbackExecutors.values().forEach(e -> e.shutdown());
+    callbackExecutors.clear();
   }
 
   private <T> T deserializeObject(JavaTableConfig tableConfig, String key) {
@@ -141,22 +144,6 @@
     return SerdeUtils.deserialize(key, entry);
   }
 
-  private TableReadFunction<?, ?> getReadFn(JavaTableConfig tableConfig) {
-    TableReadFunction<?, ?> readFn = deserializeObject(tableConfig, RemoteTableDescriptor.READ_FN);
-    if (readFn != null) {
-      readFn.init(this.context);
-    }
-    return readFn;
-  }
-
-  private TableWriteFunction<?, ?> getWriteFn(JavaTableConfig tableConfig) {
-    TableWriteFunction<?, ?> writeFn = deserializeObject(tableConfig, RemoteTableDescriptor.WRITE_FN);
-    if (writeFn != null) {
-      writeFn.init(this.context);
-    }
-    return writeFn;
-  }
-
   private ScheduledExecutorService createRetryExecutor() {
     return Executors.newSingleThreadScheduledExecutor(runnable -> {
         Thread thread = new Thread(runnable);
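
The added clear() calls make close() leave the executor maps empty, so a stale (already shut down)
executor can never be handed out again by a reused provider. The pattern in isolation, with
illustrative names:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;

class ExecutorRegistry {
  private final Map<String, ExecutorService> executors = new ConcurrentHashMap<>();

  void close() {
    executors.values().forEach(ExecutorService::shutdown);
    executors.clear(); // drop the references so shut-down executors are never reused
  }
}
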
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java b/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java
index ba39517..589fb14 100644
--- a/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java
@@ -93,33 +93,43 @@
   }
 
   @Override
-  public CompletableFuture<V> getAsync(K key) {
-    return doRead(() -> table.getAsync(key));
+  public CompletableFuture<V> getAsync(K key, Object... args) {
+    return doRead(() -> table.getAsync(key, args));
   }
 
   @Override
-  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
-    return doRead(() -> table.getAllAsync(keys));
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
+    return doRead(() -> table.getAllAsync(keys, args));
   }
 
   @Override
-  public CompletableFuture<Void> putAsync(K key, V value) {
-    return doWrite(() -> table.putAsync(key, value));
+  public <T> CompletableFuture<T> readAsync(int opId, Object... args) {
+    return doRead(() -> table.readAsync(opId, args));
   }
 
   @Override
-  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
-    return doWrite(() -> table.putAllAsync(entries));
+  public CompletableFuture<Void> putAsync(K key, V value, Object... args) {
+    return doWrite(() -> table.putAsync(key, value, args));
   }
 
   @Override
-  public CompletableFuture<Void> deleteAsync(K key) {
-    return doWrite(() -> table.deleteAsync(key));
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
+    return doWrite(() -> table.putAllAsync(entries, args));
   }
 
   @Override
-  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
-    return doWrite(() -> table.deleteAllAsync(keys));
+  public CompletableFuture<Void> deleteAsync(K key, Object... args) {
+    return doWrite(() -> table.deleteAsync(key, args));
+  }
+
+  @Override
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
+    return doWrite(() -> table.deleteAllAsync(keys, args));
+  }
+
+  @Override
+  public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+    return doWrite(() -> table.writeAsync(opId, args));
   }
 
   @Override
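
AsyncRetriableTable now centralizes retries for every operation, replacing the deleted
RetriableReadFunction/RetriableWriteFunction wrappers. The real implementation delegates to the
failsafe library; the following standalone sketch shows the same shape, re-scheduling the supplier
on a retry executor after a fixed backoff.

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

public class AsyncRetry {
  private final ScheduledExecutorService retryExecutor = Executors.newSingleThreadScheduledExecutor();

  public <T> CompletableFuture<T> withRetries(Supplier<CompletableFuture<T>> op, int attemptsLeft, long backoffMs) {
    CompletableFuture<T> result = new CompletableFuture<>();
    op.get().whenComplete((value, error) -> {
      if (error == null) {
        result.complete(value);
      } else if (attemptsLeft > 1) {
        // Re-schedule the whole operation after a fixed backoff.
        retryExecutor.schedule(
            () -> withRetries(op, attemptsLeft - 1, backoffMs).whenComplete((v, e) -> {
                if (e == null) {
                  result.complete(v);
                } else {
                  result.completeExceptionally(e);
                }
              }),
            backoffMs, TimeUnit.MILLISECONDS);
      } else {
        result.completeExceptionally(error);
      }
    });
    return result;
  }
}
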
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/RetriableReadFunction.java b/samza-core/src/main/java/org/apache/samza/table/retry/RetriableReadFunction.java
deleted file mode 100644
index 1adddc0..0000000
--- a/samza-core/src/main/java/org/apache/samza/table/retry/RetriableReadFunction.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.table.retry;
-
-import java.util.Collection;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.function.Predicate;
-
-import org.apache.samza.SamzaException;
-import org.apache.samza.table.remote.TableReadFunction;
-import org.apache.samza.table.utils.TableMetricsUtil;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
-import net.jodah.failsafe.RetryPolicy;
-
-import static org.apache.samza.table.retry.FailsafeAdapter.failsafe;
-
-
-/**
- * Wrapper for a {@link TableReadFunction} instance to add common retry
- * support with a {@link TableRetryPolicy}. This wrapper is created by
- * {@link org.apache.samza.table.remote.RemoteTableProvider} when a retry
- * policy is specified together with the {@link TableReadFunction}.
- *
- * Actual retry mechanism is provided by the failsafe library. Retry is
- * attempted in an async way with a {@link ScheduledExecutorService}.
- *
- * @param <K> the type of the key in this table
- * @param <V> the type of the value in this table
- */
-public class RetriableReadFunction<K, V> implements TableReadFunction<K, V> {
-  private final RetryPolicy retryPolicy;
-  private final TableReadFunction<K, V> readFn;
-  private final ScheduledExecutorService retryExecutor;
-
-  @VisibleForTesting
-  RetryMetrics retryMetrics;
-
-  public RetriableReadFunction(TableRetryPolicy policy, TableReadFunction<K, V> readFn,
-      ScheduledExecutorService retryExecutor) {
-    Preconditions.checkNotNull(policy);
-    Preconditions.checkNotNull(readFn);
-    Preconditions.checkNotNull(retryExecutor);
-
-    this.readFn = readFn;
-    this.retryExecutor = retryExecutor;
-    Predicate<Throwable> retryPredicate = policy.getRetryPredicate();
-    policy.withRetryPredicate((ex) -> readFn.isRetriable(ex) || retryPredicate.test(ex));
-    this.retryPolicy = FailsafeAdapter.valueOf(policy);
-  }
-
-  @Override
-  public CompletableFuture<V> getAsync(K key) {
-    return failsafe(retryPolicy, retryMetrics, retryExecutor)
-        .future(() -> readFn.getAsync(key))
-        .exceptionally(e -> {
-            throw new SamzaException("Failed to get the record for " + key + " after retries.", e);
-          });
-  }
-
-  @Override
-  public CompletableFuture<Map<K, V>> getAllAsync(Collection<K> keys) {
-    return failsafe(retryPolicy, retryMetrics, retryExecutor)
-        .future(() -> readFn.getAllAsync(keys))
-        .exceptionally(e -> {
-            throw new SamzaException("Failed to get the records for " + keys + " after retries.", e);
-          });
-  }
-
-  @Override
-  public boolean isRetriable(Throwable exception) {
-    return readFn.isRetriable(exception);
-  }
-
-  /**
-   * Initialize retry-related metrics
-   * @param metricsUtil metrics util
-   */
-  public void setMetrics(TableMetricsUtil metricsUtil) {
-    this.retryMetrics = new RetryMetrics("reader", metricsUtil);
-  }
-}
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/RetriableWriteFunction.java b/samza-core/src/main/java/org/apache/samza/table/retry/RetriableWriteFunction.java
deleted file mode 100644
index 2f3f062..0000000
--- a/samza-core/src/main/java/org/apache/samza/table/retry/RetriableWriteFunction.java
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.table.retry;
-
-import java.util.Collection;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.function.Predicate;
-
-import org.apache.samza.SamzaException;
-import org.apache.samza.storage.kv.Entry;
-import org.apache.samza.table.remote.TableWriteFunction;
-import org.apache.samza.table.utils.TableMetricsUtil;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
-import net.jodah.failsafe.RetryPolicy;
-
-import static org.apache.samza.table.retry.FailsafeAdapter.failsafe;
-
-
-/**
- * Wrapper for a {@link TableWriteFunction} instance to add common retry
- * support with a {@link TableRetryPolicy}. This wrapper is created by
- * {@link org.apache.samza.table.remote.RemoteTableProvider} when a retry
- * policy is specified together with the {@link TableWriteFunction}.
- *
- * Actual retry mechanism is provided by the failsafe library. Retry is
- * attempted in an async way with a {@link ScheduledExecutorService}.
- *
- * @param <K> the type of the key in this table
- * @param <V> the type of the value in this table
- */
-public class RetriableWriteFunction<K, V> implements TableWriteFunction<K, V> {
-  private final RetryPolicy retryPolicy;
-  private final TableWriteFunction<K, V> writeFn;
-  private final ScheduledExecutorService retryExecutor;
-
-  @VisibleForTesting
-  RetryMetrics retryMetrics;
-
-  public RetriableWriteFunction(TableRetryPolicy policy, TableWriteFunction<K, V> writeFn,
-      ScheduledExecutorService retryExecutor)  {
-    Preconditions.checkNotNull(policy);
-    Preconditions.checkNotNull(writeFn);
-    Preconditions.checkNotNull(retryExecutor);
-
-    this.writeFn = writeFn;
-    this.retryExecutor = retryExecutor;
-    Predicate<Throwable> retryPredicate = policy.getRetryPredicate();
-    policy.withRetryPredicate((ex) -> writeFn.isRetriable(ex) || retryPredicate.test(ex));
-    this.retryPolicy = FailsafeAdapter.valueOf(policy);
-  }
-
-  @Override
-  public CompletableFuture<Void> putAsync(K key, V record) {
-    return failsafe(retryPolicy, retryMetrics, retryExecutor)
-        .future(() -> writeFn.putAsync(key, record))
-        .exceptionally(e -> {
-            throw new SamzaException("Failed to get the record for " + key + " after retries.", e);
-          });
-  }
-
-  @Override
-  public CompletableFuture<Void> putAllAsync(Collection<Entry<K, V>> records) {
-    return failsafe(retryPolicy, retryMetrics, retryExecutor)
-        .future(() -> writeFn.putAllAsync(records))
-        .exceptionally(e -> {
-            throw new SamzaException("Failed to put records after retries.", e);
-          });
-  }
-
-  @Override
-  public CompletableFuture<Void> deleteAsync(K key) {
-    return failsafe(retryPolicy, retryMetrics, retryExecutor)
-        .future(() -> writeFn.deleteAsync(key))
-        .exceptionally(e -> {
-            throw new SamzaException("Failed to delete the record for " + key + " after retries.", e);
-          });
-  }
-
-  @Override
-  public CompletableFuture<Void> deleteAllAsync(Collection<K> keys) {
-    return failsafe(retryPolicy, retryMetrics, retryExecutor)
-        .future(() -> writeFn.deleteAllAsync(keys))
-        .exceptionally(e -> {
-            throw new SamzaException("Failed to delete the records for " + keys + " after retries.", e);
-          });
-  }
-
-  @Override
-  public boolean isRetriable(Throwable exception) {
-    return writeFn.isRetriable(exception);
-  }
-
-  /**
-   * Initialize retry-related metrics.
-   * @param metricsUtil metrics util
-   */
-  public void setMetrics(TableMetricsUtil metricsUtil) {
-    this.retryMetrics = new RetryMetrics("writer", metricsUtil);
-  }
-}
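Note: both deleted wrappers build their effective retry predicate the same way; the table function's own isRetriable check is OR-ed with the user-supplied policy predicate before the policy is converted for failsafe, so either side can mark an error as transient. A small standalone sketch of that composition (the sample predicates are illustrative, not from the patch):

import java.io.IOException;
import java.util.concurrent.TimeoutException;
import java.util.function.Predicate;

class RetryPredicateComposition {
  public static void main(String[] args) {
    Predicate<Throwable> policyPredicate = ex -> ex instanceof IOException;   // from TableRetryPolicy, illustrative
    Predicate<Throwable> fnIsRetriable = ex -> ex instanceof TimeoutException; // function-specific check, illustrative
    // The wrapper installs the OR of the two before adapting the policy for failsafe.
    Predicate<Throwable> effective = fnIsRetriable.or(policyPredicate);
    System.out.println(effective.test(new IOException()));  // true -> would retry
  }
}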
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java b/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java
index fbc511c..f717462 100644
--- a/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java
+++ b/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java
@@ -25,8 +25,7 @@
 
 
 /**
- * Wrapper of retry-related metrics common to both {@link RetriableReadFunction} and
- * {@link RetriableWriteFunction}.
+ * Retry-related metrics for remote table reads and writes.
  */
 class RetryMetrics {
   /**
diff --git a/samza-core/src/main/java/org/apache/samza/table/utils/TableMetrics.java b/samza-core/src/main/java/org/apache/samza/table/utils/TableMetrics.java
index df6833e..5906bed 100644
--- a/samza-core/src/main/java/org/apache/samza/table/utils/TableMetrics.java
+++ b/samza-core/src/main/java/org/apache/samza/table/utils/TableMetrics.java
@@ -32,8 +32,10 @@
   // Read metrics
   public final Timer getNs;
   public final Timer getAllNs;
+  public final Timer readNs;
   public final Counter numGets;
   public final Counter numGetAlls;
+  public final Counter numReads;
   public final Counter numMissedLookups;
   // Write metrics
   public final Counter numPuts;
@@ -44,6 +46,8 @@
   public final Timer deleteNs;
   public final Counter numDeleteAlls;
   public final Timer deleteAllNs;
+  public final Counter numWrites;
+  public final Timer writeNs;
   public final Counter numFlushes;
   public final Timer flushNs;
 
@@ -61,6 +65,8 @@
     getNs = tableMetricsUtil.newTimer("get-ns");
     numGetAlls = tableMetricsUtil.newCounter("num-getAlls");
     getAllNs = tableMetricsUtil.newTimer("getAll-ns");
+    numReads = tableMetricsUtil.newCounter("num-reads");
+    readNs = tableMetricsUtil.newTimer("read-ns");
     numMissedLookups = tableMetricsUtil.newCounter("num-missed-lookups");
     // Write metrics
     numPuts = tableMetricsUtil.newCounter("num-puts");
@@ -71,6 +77,8 @@
     deleteNs = tableMetricsUtil.newTimer("delete-ns");
     numDeleteAlls = tableMetricsUtil.newCounter("num-deleteAlls");
     deleteAllNs = tableMetricsUtil.newTimer("deleteAll-ns");
+    numWrites = tableMetricsUtil.newCounter("num-writes");
+    writeNs = tableMetricsUtil.newTimer("write-ns");
     numFlushes = tableMetricsUtil.newCounter("num-flushes");
     flushNs = tableMetricsUtil.newTimer("flush-ns");
   }
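Note: the new numReads/readNs and numWrites/writeNs pairs aggregate across the existing per-operation metrics. A sketch of how an instrumented get path might bump both levels; "metrics" and "readFn" are assumed fields of the surrounding table class, not code from this patch:

  private CompletableFuture<V> instrumentedGet(K key) {
    final long startNs = System.nanoTime();
    return readFn.getAsync(key).whenComplete((value, error) -> {
        long durationNs = System.nanoTime() - startNs;
        metrics.numGets.inc();          // operation-specific metrics
        metrics.getNs.update(durationNs);
        metrics.numReads.inc();         // aggregate read metrics added in this patch
        metrics.readNs.update(durationNs);
      });
  }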
diff --git a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
index 09f778a..90f9668 100644
--- a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
@@ -356,20 +356,13 @@
 
     // delete corresponding startpoints after checkpoint is supposed to be committed
     if (startpointManager != null && startpoints.contains(taskName)) {
-      val sspStartpoints = checkpoint.getOffsets.keySet.asScala
-        .intersect(startpoints.getOrElse(taskName, Map.empty[SystemStreamPartition, Startpoint]).keySet)
-
-      // delete startpoints for this task and the intersection of SSPs between checkpoint and startpoint.
-      sspStartpoints.foreach(ssp => {
-        startpointManager.deleteStartpoint(ssp, taskName)
-        info("Deleted startpoint for SSP: %s and task: %s" format (ssp, taskName))
-      })
+      info("%d startpoint(s) for taskName: %s have been committed to the checkpoint." format (startpoints.get(taskName).size, taskName.getTaskName))
+      startpointManager.removeFanOutForTask(taskName)
       startpoints -= taskName
 
       if (startpoints.isEmpty) {
-        // Stop startpoint manager after last startpoint is deleted
-        startpointManager.stop()
-        info("No more startpoints left to consume. Stopped the startpoint manager.")
+        info("All outstanding startpoints have been committed to the checkpoint.")
+        startpointManager.stop
       }
     }
   }
@@ -384,11 +377,11 @@
     }
 
     if (startpointManager != null) {
-      debug("Ensuring startpoint manager has shut down.")
+      debug("Shutting down startpoint manager.")
 
       startpointManager.stop
     } else {
-      debug("Skipping startpoint manager shutdown because no checkpoint manager is defined.")
+      debug("Skipping startpoint manager shutdown because no startpoint manager is defined.")
     }
   }
 
@@ -517,23 +510,26 @@
     if (startpointManager != null) {
       info("Starting startpoint manager.")
       startpointManager.start
-      val taskNameToSSPs: Map[TaskName, Set[SystemStreamPartition]] = systemStreamPartitions
 
-      taskNameToSSPs.foreach {
+      systemStreamPartitions.foreach {
         case (taskName, systemStreamPartitionSet) => {
-          val sspToStartpoint = systemStreamPartitionSet
-            .map(ssp => (ssp, startpointManager.readStartpoint(ssp, taskName)))
-            .filter(_._2 != null)
-            .toMap
-
-          if (!sspToStartpoint.isEmpty) {
-            startpoints += taskName -> sspToStartpoint
+          Option(startpointManager.getFanOutForTask(taskName)) match {
+            case Some(fanOut) => {
+              val filteredFanOut = fanOut.asScala
+                .filter(f => systemStreamPartitionSet.contains(f._1))
+                .toMap
+              if (!filteredFanOut.isEmpty) {
+                startpoints += taskName -> filteredFanOut
+                info("Startpoint fan out for task: %s - %s" format (taskName, filteredFanOut))
+              }
+            }
+            case None => debug("No startpoints fanned out on taskName: %s" format taskName.getTaskName)
           }
         }
       }
 
       if (startpoints.isEmpty) {
-        info("No startpoints to consume. Stopping startpoint manager.")
+        info("No startpoints to consume.")
         startpointManager.stop
       } else {
         startpoints
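Note: in this new flow the OffsetManager no longer reads SSP-keyed startpoints directly; it fetches the task's fan-out and keeps only the SSPs actually assigned to that task. The equivalent logic as a Java sketch against the getFanOutForTask API shown in this patch ("assignedSSPs" and "startpoints" are assumed locals; store error handling is elided):

Map<SystemStreamPartition, Startpoint> fanOut = startpointManager.getFanOutForTask(taskName);
if (fanOut != null && !fanOut.isEmpty()) {
  Map<SystemStreamPartition, Startpoint> filtered = fanOut.entrySet().stream()
      .filter(e -> assignedSSPs.contains(e.getKey()))  // keep only SSPs assigned to this task
      .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
  if (!filtered.isEmpty()) {
    startpoints.put(taskName, filtered);
  }
}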
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/StartpointManagerTestUtil.java b/samza-core/src/test/java/org/apache/samza/startpoint/StartpointManagerTestUtil.java
index f8eb855..1071904 100644
--- a/samza-core/src/test/java/org/apache/samza/startpoint/StartpointManagerTestUtil.java
+++ b/samza-core/src/test/java/org/apache/samza/startpoint/StartpointManagerTestUtil.java
@@ -18,17 +18,36 @@
  */
 package org.apache.samza.startpoint;
 
+import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.metadatastore.InMemoryMetadataStoreFactory;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.apache.samza.metadatastore.MetadataStoreFactory;
+import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.util.NoOpMetricsRegistry;
 
 
 public class StartpointManagerTestUtil {
+  private final MetadataStore metadataStore;
+  private final StartpointManager startpointManager;
 
-  private StartpointManagerTestUtil() {
+  public StartpointManagerTestUtil() {
+    this(new InMemoryMetadataStoreFactory(), new MapConfig(), new NoOpMetricsRegistry());
   }
 
-  public static StartpointManager getStartpointManager() {
-    return new StartpointManager(new InMemoryMetadataStoreFactory(), new MapConfig(), new NoOpMetricsRegistry());
+  public StartpointManagerTestUtil(MetadataStoreFactory metadataStoreFactory, Config config, MetricsRegistry metricsRegistry) {
+    this.metadataStore = metadataStoreFactory.getMetadataStore(StartpointManager.NAMESPACE, config, metricsRegistry);
+    this.metadataStore.init();
+    this.startpointManager = new StartpointManager(metadataStore);
+    this.startpointManager.start();
+  }
+
+  public StartpointManager getStartpointManager() {
+    return startpointManager;
+  }
+
+  public void stop() {
+    startpointManager.stop();
+    metadataStore.close();
   }
 }
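Note: with this rework, a test owns the util's lifecycle end to end. A typical usage sketch (the SSP and startpoint values are illustrative):

StartpointManagerTestUtil testUtil = new StartpointManagerTestUtil();  // in-memory store, already started
StartpointManager startpointManager = testUtil.getStartpointManager();
startpointManager.writeStartpoint(
    new SystemStreamPartition("system", "stream", new Partition(0)), new StartpointOldest());
// ... exercise the code under test ...
testUtil.stop();  // stops the manager and closes the backing metadata store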
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/StartpointMock.java b/samza-core/src/test/java/org/apache/samza/startpoint/StartpointMock.java
new file mode 100644
index 0000000..7c8b50d
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/startpoint/StartpointMock.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.startpoint;
+
+import java.time.Instant;
+
+
+// Can't mock concrete Startpoint classes because they are final.
+public class StartpointMock extends Startpoint {
+  public StartpointMock() {
+    super(Instant.now().toEpochMilli());
+  }
+
+  public StartpointMock(long creationTimestamp) {
+    super(creationTimestamp);
+  }
+
+  @Override
+  public <IN, OUT> OUT apply(IN input, StartpointVisitor<IN, OUT> startpointVisitor) {
+    // mocked
+    return startpointVisitor.visit(input, new StartpointSpecific("Mocked"));
+  }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointKey.java b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointKey.java
deleted file mode 100644
index 72f922f..0000000
--- a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointKey.java
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import java.io.IOException;
-import java.util.LinkedHashMap;
-import org.apache.samza.Partition;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.serializers.JsonSerdeV2;
-import org.apache.samza.system.SystemStreamPartition;
-import org.codehaus.jackson.map.ObjectMapper;
-import org.junit.Assert;
-import org.junit.Test;
-
-
-public class TestStartpointKey {
-
-  @Test
-  public void testStartpointKey() {
-    SystemStreamPartition ssp1 = new SystemStreamPartition("system", "stream", new Partition(2));
-    SystemStreamPartition ssp2 = new SystemStreamPartition("system", "stream", new Partition(3));
-
-    StartpointKey startpointKey1 = new StartpointKey(ssp1);
-    StartpointKey startpointKey2 = new StartpointKey(ssp1);
-    StartpointKey startpointKeyWithDifferentSSP = new StartpointKey(ssp2);
-    StartpointKey startpointKeyWithTask1 = new StartpointKey(ssp1, new TaskName("t1"));
-    StartpointKey startpointKeyWithTask2 = new StartpointKey(ssp1, new TaskName("t1"));
-    StartpointKey startpointKeyWithDifferentTask = new StartpointKey(ssp1, new TaskName("t2"));
-
-    Assert.assertEquals(startpointKey1, startpointKey2);
-    Assert.assertEquals(new String(new JsonSerdeV2<>().toBytes(startpointKey1)),
-        new String(new JsonSerdeV2<>().toBytes(startpointKey2)));
-    Assert.assertEquals(startpointKeyWithTask1, startpointKeyWithTask2);
-    Assert.assertEquals(new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask1)),
-        new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask2)));
-
-    Assert.assertNotEquals(startpointKey1, startpointKeyWithTask1);
-    Assert.assertNotEquals(new String(new JsonSerdeV2<>().toBytes(startpointKey1)),
-        new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask1)));
-
-    Assert.assertNotEquals(startpointKey1, startpointKeyWithDifferentSSP);
-    Assert.assertNotEquals(new String(new JsonSerdeV2<>().toBytes(startpointKey1)),
-        new String(new JsonSerdeV2<>().toBytes(startpointKeyWithDifferentSSP)));
-    Assert.assertNotEquals(startpointKeyWithTask1, startpointKeyWithDifferentTask);
-    Assert.assertNotEquals(new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask1)),
-        new String(new JsonSerdeV2<>().toBytes(startpointKeyWithDifferentTask)));
-
-    Assert.assertNotEquals(startpointKeyWithTask1, startpointKeyWithDifferentTask);
-    Assert.assertNotEquals(new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask1)),
-        new String(new JsonSerdeV2<>().toBytes(startpointKeyWithDifferentTask)));
-  }
-
-  @Test
-  public void testStartpointKeyFormat() throws IOException {
-    SystemStreamPartition ssp = new SystemStreamPartition("system1", "stream1", new Partition(2));
-    StartpointKey startpointKeyWithTask = new StartpointKey(ssp, new TaskName("t1"));
-    ObjectMapper objectMapper = new ObjectMapper();
-    byte[] jsonBytes = new JsonSerdeV2<>().toBytes(startpointKeyWithTask);
-    LinkedHashMap<String, String> deserialized = objectMapper.readValue(jsonBytes, LinkedHashMap.class);
-
-    Assert.assertEquals(4, deserialized.size());
-    Assert.assertEquals("system1", deserialized.get("system"));
-    Assert.assertEquals("stream1", deserialized.get("stream"));
-    Assert.assertEquals(2, deserialized.get("partition"));
-    Assert.assertEquals("t1", deserialized.get("taskName"));
-  }
-}
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java
index b4c33c3..d615c54 100644
--- a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java
+++ b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java
@@ -21,6 +21,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import java.io.IOException;
 import java.time.Instant;
 import java.util.HashMap;
 import java.util.List;
@@ -34,14 +35,13 @@
 import org.apache.samza.coordinator.metadatastore.CoordinatorStreamStore;
 import org.apache.samza.coordinator.metadatastore.CoordinatorStreamStoreTestUtil;
 import org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
-import org.apache.samza.job.model.ContainerModel;
-import org.apache.samza.job.model.JobModel;
-import org.apache.samza.job.model.TaskModel;
 import org.apache.samza.system.SystemStreamPartition;
+import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
 
+
 public class TestStartpointManager {
 
   private static final Config CONFIG = new MapConfig(ImmutableMap.of("job.name", "test-job", "job.coordinator.system", "test-kafka"));
@@ -53,14 +53,23 @@
   public void setup() {
     CoordinatorStreamStoreTestUtil coordinatorStreamStoreTestUtil = new CoordinatorStreamStoreTestUtil(CONFIG);
     coordinatorStreamStore = coordinatorStreamStoreTestUtil.getCoordinatorStreamStore();
+    coordinatorStreamStore.init();
     startpointManager = new StartpointManager(coordinatorStreamStore);
+    startpointManager.start();
+  }
+
+  @After
+  public void teardown() {
+    startpointManager.stop();
+    coordinatorStreamStore.close();
   }
 
   @Test
   public void testDefaultMetadataStore() {
     StartpointManager startpointManager = new StartpointManager(coordinatorStreamStore);
     Assert.assertNotNull(startpointManager);
-    Assert.assertEquals(NamespaceAwareCoordinatorStreamStore.class, startpointManager.getMetadataStore().getClass());
+    Assert.assertEquals(NamespaceAwareCoordinatorStreamStore.class, startpointManager.getReadWriteStore().getClass());
+    Assert.assertEquals(NamespaceAwareCoordinatorStreamStore.class, startpointManager.getFanOutStore().getClass());
   }
 
   @Test
@@ -68,19 +77,18 @@
     SystemStreamPartition ssp = new SystemStreamPartition("mockSystem", "mockStream", new Partition(2));
     TaskName taskName = new TaskName("MockTask");
 
-    startpointManager.start();
     long staleTimestamp = Instant.now().toEpochMilli() - StartpointManager.DEFAULT_EXPIRATION_DURATION.toMillis() - 2;
     StartpointTimestamp startpoint = new StartpointTimestamp(staleTimestamp, staleTimestamp);
 
     startpointManager.writeStartpoint(ssp, startpoint);
-    Assert.assertNull(startpointManager.readStartpoint(ssp));
+    Assert.assertFalse(startpointManager.readStartpoint(ssp).isPresent());
 
     startpointManager.writeStartpoint(ssp, taskName, startpoint);
-    Assert.assertNull(startpointManager.readStartpoint(ssp, taskName));
+    Assert.assertFalse(startpointManager.readStartpoint(ssp, taskName).isPresent());
   }
 
   @Test
-  public void testNoLongerUsableAfterStop() {
+  public void testNoLongerUsableAfterStop() throws IOException {
     StartpointManager startpointManager = new StartpointManager(coordinatorStreamStore);
     startpointManager.start();
     SystemStreamPartition ssp =
@@ -121,14 +129,23 @@
     } catch (IllegalStateException ex) { }
 
     try {
-      startpointManager.fanOutStartpointsToTasks(new JobModel(new MapConfig(), new HashMap<>()));
+      startpointManager.fanOut(new HashMap<>());
+      Assert.fail("Expected precondition exception.");
+    } catch (IllegalStateException ex) { }
+
+    try {
+      startpointManager.getFanOutForTask(new TaskName("t0"));
+      Assert.fail("Expected precondition exception.");
+    } catch (IllegalStateException ex) { }
+
+    try {
+      startpointManager.removeFanOutForTask(new TaskName("t0"));
       Assert.fail("Expected precondition exception.");
     } catch (IllegalStateException ex) { }
   }
 
   @Test
   public void testBasics() {
-    startpointManager.start();
     SystemStreamPartition ssp =
         new SystemStreamPartition("mockSystem", "mockStream", new Partition(2));
     TaskName taskName = new TaskName("MockTask");
@@ -144,18 +161,18 @@
     Assert.assertNotNull(startpoint4.getCreationTimestamp());
 
     // Test reads on non-existent keys
-    Assert.assertNull(startpointManager.readStartpoint(ssp));
-    Assert.assertNull(startpointManager.readStartpoint(ssp, taskName));
+    Assert.assertFalse(startpointManager.readStartpoint(ssp).isPresent());
+    Assert.assertFalse(startpointManager.readStartpoint(ssp, taskName).isPresent());
 
     // Test writes
     Startpoint startpointFromStore;
     startpointManager.writeStartpoint(ssp, startpoint1);
     startpointManager.writeStartpoint(ssp, taskName, startpoint2);
-    startpointFromStore = startpointManager.readStartpoint(ssp);
+    startpointFromStore = startpointManager.readStartpoint(ssp).get();
     Assert.assertEquals(StartpointTimestamp.class, startpointFromStore.getClass());
     Assert.assertEquals(startpoint1.getTimestampOffset(), ((StartpointTimestamp) startpointFromStore).getTimestampOffset());
     Assert.assertTrue(startpointFromStore.getCreationTimestamp() <= Instant.now().toEpochMilli());
-    startpointFromStore = startpointManager.readStartpoint(ssp, taskName);
+    startpointFromStore = startpointManager.readStartpoint(ssp, taskName).get();
     Assert.assertEquals(StartpointTimestamp.class, startpointFromStore.getClass());
     Assert.assertEquals(startpoint2.getTimestampOffset(), ((StartpointTimestamp) startpointFromStore).getTimestampOffset());
     Assert.assertTrue(startpointFromStore.getCreationTimestamp() <= Instant.now().toEpochMilli());
@@ -163,48 +180,40 @@
     // Test overwrites
     startpointManager.writeStartpoint(ssp, startpoint3);
     startpointManager.writeStartpoint(ssp, taskName, startpoint4);
-    startpointFromStore = startpointManager.readStartpoint(ssp);
+    startpointFromStore = startpointManager.readStartpoint(ssp).get();
     Assert.assertEquals(StartpointSpecific.class, startpointFromStore.getClass());
     Assert.assertEquals(startpoint3.getSpecificOffset(), ((StartpointSpecific) startpointFromStore).getSpecificOffset());
     Assert.assertTrue(startpointFromStore.getCreationTimestamp() <= Instant.now().toEpochMilli());
-    startpointFromStore = startpointManager.readStartpoint(ssp, taskName);
+    startpointFromStore = startpointManager.readStartpoint(ssp, taskName).get();
     Assert.assertEquals(StartpointSpecific.class, startpointFromStore.getClass());
     Assert.assertEquals(startpoint4.getSpecificOffset(), ((StartpointSpecific) startpointFromStore).getSpecificOffset());
     Assert.assertTrue(startpointFromStore.getCreationTimestamp() <= Instant.now().toEpochMilli());
 
     // Test deletes on SSP keys does not affect SSP+TaskName keys
     startpointManager.deleteStartpoint(ssp);
-    Assert.assertNull(startpointManager.readStartpoint(ssp));
-    Assert.assertNotNull(startpointManager.readStartpoint(ssp, taskName));
+    Assert.assertFalse(startpointManager.readStartpoint(ssp).isPresent());
+    Assert.assertTrue(startpointManager.readStartpoint(ssp, taskName).isPresent());
 
     // Test deletes on SSP+TaskName keys does not affect SSP keys
     startpointManager.writeStartpoint(ssp, startpoint3);
     startpointManager.deleteStartpoint(ssp, taskName);
-    Assert.assertNull(startpointManager.readStartpoint(ssp, taskName));
-    Assert.assertNotNull(startpointManager.readStartpoint(ssp));
-
-    startpointManager.stop();
+    Assert.assertFalse(startpointManager.readStartpoint(ssp, taskName).isPresent());
+    Assert.assertTrue(startpointManager.readStartpoint(ssp).isPresent());
   }
 
   @Test
-  public void testGroupStartpointsPerTask() {
-    MapConfig config = new MapConfig();
-    startpointManager.start();
-    SystemStreamPartition sspBroadcast =
-        new SystemStreamPartition("mockSystem1", "mockStream1", new Partition(2));
-    SystemStreamPartition sspBroadcast2 =
-        new SystemStreamPartition("mockSystem3", "mockStream3", new Partition(4));
-    SystemStreamPartition sspSingle =
-        new SystemStreamPartition("mockSystem2", "mockStream2", new Partition(3));
+  public void testFanOutBasic() throws IOException {
+    SystemStreamPartition sspBroadcast = new SystemStreamPartition("mockSystem1", "mockStream1", new Partition(2));
+    SystemStreamPartition sspSingle = new SystemStreamPartition("mockSystem2", "mockStream2", new Partition(3));
+
+    TaskName taskWithNonBroadcast = new TaskName("t1");
 
     List<TaskName> tasks =
-        ImmutableList.of(new TaskName("t0"), new TaskName("t1"), new TaskName("t2"), new TaskName("t3"), new TaskName("t4"), new TaskName("t5"));
+        ImmutableList.of(new TaskName("t0"), taskWithNonBroadcast, new TaskName("t2"), new TaskName("t3"), new TaskName("t4"), new TaskName("t5"));
 
-    Map<TaskName, TaskModel> taskModelMap = tasks.stream()
-        .map(task -> new TaskModel(task, task.getTaskName().equals("t1") ? ImmutableSet.of(sspBroadcast, sspBroadcast2, sspSingle) : ImmutableSet.of(sspBroadcast, sspBroadcast2), new Partition(1)))
-        .collect(Collectors.toMap(taskModel -> taskModel.getTaskName(), taskModel -> taskModel));
-    ContainerModel containerModel = new ContainerModel("container 0", taskModelMap);
-    JobModel jobModel = new JobModel(config, ImmutableMap.of(containerModel.getId(), containerModel));
+    Map<TaskName, Set<SystemStreamPartition>> taskToSSPs = tasks.stream()
+        .collect(Collectors
+            .toMap(task -> task, task -> task.equals(taskWithNonBroadcast) ? ImmutableSet.of(sspBroadcast, sspSingle) : ImmutableSet.of(sspBroadcast)));
 
     StartpointSpecific startpoint42 = new StartpointSpecific("42");
 
@@ -212,54 +221,94 @@
     startpointManager.writeStartpoint(sspSingle, startpoint42);
 
     // startpoint42 should remap with key sspBroadcast to all tasks + sspBroadcast
-    Set<SystemStreamPartition> systemStreamPartitions = startpointManager.fanOutStartpointsToTasks(jobModel);
-    Assert.assertEquals(2, systemStreamPartitions.size());
-    Assert.assertTrue(systemStreamPartitions.containsAll(ImmutableSet.of(sspBroadcast, sspSingle)));
+    Map<TaskName, Map<SystemStreamPartition, Startpoint>> tasksFannedOutTo = startpointManager.fanOut(taskToSSPs);
+    Assert.assertEquals(tasks.size(), tasksFannedOutTo.size());
+    Assert.assertTrue(tasksFannedOutTo.keySet().containsAll(tasks));
+    Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(sspBroadcast).isPresent());
+    Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(sspSingle).isPresent());
 
     for (TaskName taskName : tasks) {
-      // startpoint42 should be mapped to all tasks for sspBroadcast
-      Startpoint startpointFromStore = startpointManager.readStartpoint(sspBroadcast, taskName);
+      Map<SystemStreamPartition, Startpoint> fanOutForTask = startpointManager.getFanOutForTask(taskName);
+      if (taskName.equals(taskWithNonBroadcast)) {
+        // Non-broadcast startpoint should be fanned out to only one task
+        Assert.assertEquals("Should have broadcast and non-broadcast SSP", 2, fanOutForTask.size());
+      } else {
+        Assert.assertEquals("Should only have broadcast SSP", 1, fanOutForTask.size());
+      }
+
+      // Broadcast SSP should be on every task
+      Startpoint startpointFromStore = fanOutForTask.get(sspBroadcast);
       Assert.assertEquals(StartpointSpecific.class, startpointFromStore.getClass());
       Assert.assertEquals(startpoint42.getSpecificOffset(), ((StartpointSpecific) startpointFromStore).getSpecificOffset());
 
-      // startpoint 42 should be mapped only to task "t1" for sspSingle
-      startpointFromStore = startpointManager.readStartpoint(sspSingle, taskName);
-      if (taskName.getTaskName().equals("t1")) {
+      // startpoint42 should be mapped only to taskWithNonBroadcast for the non-broadcast SSP
+      startpointFromStore = fanOutForTask.get(sspSingle);
+      if (taskName.equals(taskWithNonBroadcast)) {
         Assert.assertEquals(StartpointSpecific.class, startpointFromStore.getClass());
         Assert.assertEquals(startpoint42.getSpecificOffset(), ((StartpointSpecific) startpointFromStore).getSpecificOffset());
       } else {
-        Assert.assertNull(startpointFromStore);
+        Assert.assertNull("Should not have non-broadcast SSP", startpointFromStore);
       }
+
+      startpointManager.removeFanOutForTask(taskName);
+      Assert.assertTrue(startpointManager.getFanOutForTask(taskName).isEmpty());
     }
-    Assert.assertNull(startpointManager.readStartpoint(sspBroadcast));
-    Assert.assertNull(startpointManager.readStartpoint(sspSingle));
+  }
 
-    // Test startpoints that were explicit assigned to an sspBroadcast2+TaskName will not be overwritten from fanOutStartpointsToTasks
+  @Test
+  public void testFanOutWithStartpointResolutions() throws IOException {
+    SystemStreamPartition sspBroadcast = new SystemStreamPartition("mockSystem1", "mockStream1", new Partition(2));
+    SystemStreamPartition sspSingle = new SystemStreamPartition("mockSystem2", "mockStream2", new Partition(3));
 
-    StartpointSpecific startpoint1024 = new StartpointSpecific("1024");
+    List<TaskName> tasks =
+        ImmutableList.of(new TaskName("t0"), new TaskName("t1"), new TaskName("t2"), new TaskName("t3"), new TaskName("t4"));
 
-    startpointManager.writeStartpoint(sspBroadcast2, startpoint42);
-    startpointManager.writeStartpoint(sspBroadcast2, tasks.get(1), startpoint1024);
-    startpointManager.writeStartpoint(sspBroadcast2, tasks.get(3), startpoint1024);
+    TaskName taskWithNonBroadcast = tasks.get(1);
+    TaskName taskBroadcastInPast = tasks.get(2);
+    TaskName taskBroadcastInFuture = tasks.get(3);
 
-    Set<SystemStreamPartition> sspsDeleted = startpointManager.fanOutStartpointsToTasks(jobModel);
-    Assert.assertEquals(1, sspsDeleted.size());
-    Assert.assertTrue(sspsDeleted.contains(sspBroadcast2));
+    Map<TaskName, Set<SystemStreamPartition>> taskToSSPs = tasks.stream()
+        .collect(Collectors
+            .toMap(task -> task, task -> task.equals(taskWithNonBroadcast) ? ImmutableSet.of(sspBroadcast, sspSingle) : ImmutableSet.of(sspBroadcast)));
 
-    StartpointSpecific startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(0));
-    Assert.assertEquals(startpoint42.getSpecificOffset(), startpointFromStore.getSpecificOffset());
-    startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(1));
-    Assert.assertEquals(startpoint1024.getSpecificOffset(), startpointFromStore.getSpecificOffset());
-    startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(2));
-    Assert.assertEquals(startpoint42.getSpecificOffset(), startpointFromStore.getSpecificOffset());
-    startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(3));
-    Assert.assertEquals(startpoint1024.getSpecificOffset(), startpointFromStore.getSpecificOffset());
-    startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(4));
-    Assert.assertEquals(startpoint42.getSpecificOffset(), startpointFromStore.getSpecificOffset());
-    startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(5));
-    Assert.assertEquals(startpoint42.getSpecificOffset(), startpointFromStore.getSpecificOffset());
-    Assert.assertNull(startpointManager.readStartpoint(sspBroadcast2));
+    Instant now = Instant.now();
+    StartpointMock startpointPast = new StartpointMock(now.minusMillis(10000L).toEpochMilli());
+    StartpointMock startpointPresent = new StartpointMock(now.toEpochMilli());
+    StartpointMock startpointFuture = new StartpointMock(now.plusMillis(10000L).toEpochMilli());
 
-    startpointManager.stop();
+    startpointManager.getObjectMapper().registerSubtypes(StartpointMock.class);
+    startpointManager.writeStartpoint(sspSingle, startpointPast);
+    startpointManager.writeStartpoint(sspSingle, startpointPresent);
+    startpointManager.writeStartpoint(sspBroadcast, startpointPresent);
+    startpointManager.writeStartpoint(sspBroadcast, taskBroadcastInPast, startpointPast);
+    startpointManager.writeStartpoint(sspBroadcast, taskBroadcastInFuture, startpointFuture);
+
+    Map<TaskName, Map<SystemStreamPartition, Startpoint>> fannedOut = startpointManager.fanOut(taskToSSPs);
+    Assert.assertEquals(tasks.size(), fannedOut.size());
+    Assert.assertTrue(fannedOut.keySet().containsAll(tasks));
+    Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(sspBroadcast).isPresent());
+    for (TaskName taskName : fannedOut.keySet()) {
+      Assert.assertFalse("Should be deleted after fan out for task: " + taskName.getTaskName(), startpointManager.readStartpoint(sspBroadcast, taskName).isPresent());
+    }
+    Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(sspSingle).isPresent());
+
+    for (TaskName taskName : tasks) {
+      Map<SystemStreamPartition, Startpoint> fanOutForTask = startpointManager.getFanOutForTask(taskName);
+      if (taskName.equals(taskWithNonBroadcast)) {
+        Assert.assertEquals(startpointPresent, fanOutForTask.get(sspSingle));
+        Assert.assertEquals(startpointPresent, fanOutForTask.get(sspBroadcast));
+      } else if (taskName.equals(taskBroadcastInPast)) {
+        Assert.assertNull(fanOutForTask.get(sspSingle));
+        Assert.assertEquals(startpointPresent, fanOutForTask.get(sspBroadcast));
+      } else if (taskName.equals(taskBroadcastInFuture)) {
+        Assert.assertNull(fanOutForTask.get(sspSingle));
+        Assert.assertEquals(startpointFuture, fanOutForTask.get(sspBroadcast));
+      } else {
+        Assert.assertNull(fanOutForTask.get(sspSingle));
+        Assert.assertEquals(startpointPresent, fanOutForTask.get(sspBroadcast));
+      }
+      startpointManager.removeFanOutForTask(taskName);
+      Assert.assertTrue(startpointManager.getFanOutForTask(taskName).isEmpty());
+    }
   }
 }
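Note: read together, these tests pin down the fan-out contract: fanOut resolves SSP-keyed startpoints into per-task entries and deletes the originals, getFanOutForTask reads a task's entry back, and removeFanOutForTask cleans it up once the task has checkpointed. A condensed sketch of that lifecycle using only the API exercised above ("ssp" and "task" are illustrative; imports as in the test):

startpointManager.writeStartpoint(ssp, new StartpointSpecific("42"));
Map<TaskName, Map<SystemStreamPartition, Startpoint>> fannedOut =
    startpointManager.fanOut(ImmutableMap.of(task, ImmutableSet.of(ssp)));

// The SSP-keyed startpoint is consumed by the fan-out.
Assert.assertFalse(startpointManager.readStartpoint(ssp).isPresent());

// Each task reads its resolved startpoints from its own fan-out entry...
Startpoint resolved = startpointManager.getFanOutForTask(task).get(ssp);

// ...and deletes the entry after the startpoint is committed to a checkpoint.
startpointManager.removeFanOutForTask(task);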
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointObjectMapper.java b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointObjectMapper.java
new file mode 100644
index 0000000..70a3bf7
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointObjectMapper.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.startpoint;
+
+import java.io.IOException;
+import java.time.Instant;
+import org.apache.samza.Partition;
+import org.apache.samza.system.SystemStreamPartition;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestStartpointObjectMapper {
+  private static final ObjectMapper MAPPER = StartpointObjectMapper.getObjectMapper();
+
+  @Test
+  public void testStartpointSpecificSerde() throws IOException {
+    StartpointSpecific startpointSpecific = new StartpointSpecific("42");
+    Startpoint startpointFromSerde = MAPPER.readValue(MAPPER.writeValueAsBytes(startpointSpecific), Startpoint.class);
+
+    Assert.assertEquals(startpointSpecific.getClass(), startpointFromSerde.getClass());
+    Assert.assertEquals(startpointSpecific.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
+    Assert.assertEquals(startpointSpecific.getSpecificOffset(), ((StartpointSpecific) startpointFromSerde).getSpecificOffset());
+  }
+
+  @Test
+  public void testStartpointTimestampSerde() throws IOException {
+    StartpointTimestamp startpointTimestamp = new StartpointTimestamp(123456L);
+    Startpoint startpointFromSerde = MAPPER.readValue(MAPPER.writeValueAsBytes(startpointTimestamp), Startpoint.class);
+
+    Assert.assertEquals(startpointTimestamp.getClass(), startpointFromSerde.getClass());
+    Assert.assertEquals(startpointTimestamp.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
+    Assert.assertEquals(startpointTimestamp.getTimestampOffset(), ((StartpointTimestamp) startpointFromSerde).getTimestampOffset());
+  }
+
+  @Test
+  public void testStartpointEarliestSerde() throws IOException {
+    StartpointOldest startpointOldest = new StartpointOldest();
+    Startpoint startpointFromSerde = MAPPER.readValue(MAPPER.writeValueAsBytes(startpointOldest), Startpoint.class);
+
+    Assert.assertEquals(startpointOldest.getClass(), startpointFromSerde.getClass());
+    Assert.assertEquals(startpointOldest.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
+  }
+
+  @Test
+  public void testStartpointLatestSerde() throws IOException {
+    StartpointUpcoming startpointUpcoming = new StartpointUpcoming();
+    Startpoint startpointFromSerde = MAPPER.readValue(MAPPER.writeValueAsBytes(startpointUpcoming), Startpoint.class);
+
+    Assert.assertEquals(startpointUpcoming.getClass(), startpointFromSerde.getClass());
+    Assert.assertEquals(startpointUpcoming.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
+  }
+
+  @Test
+  public void testFanOutSerde() throws IOException {
+    StartpointFanOutPerTask startpointFanOutPerTask = new StartpointFanOutPerTask(Instant.now().minusSeconds(60));
+    startpointFanOutPerTask.getFanOuts()
+        .put(new SystemStreamPartition("system1", "stream1", new Partition(1)), new StartpointUpcoming());
+    startpointFanOutPerTask.getFanOuts()
+        .put(new SystemStreamPartition("system2", "stream2", new Partition(2)), new StartpointOldest());
+
+    String serialized = MAPPER.writeValueAsString(startpointFanOutPerTask);
+    StartpointFanOutPerTask startpointFanOutPerTaskFromSerde = MAPPER.readValue(serialized, StartpointFanOutPerTask.class);
+
+    Assert.assertEquals(startpointFanOutPerTask, startpointFanOutPerTaskFromSerde);
+  }
+}
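Note: the fan-out serde test relies on the mapper's polymorphic handling of Startpoint subtypes; custom subtypes such as StartpointMock must be registered first, as TestStartpointManager does above. A short sketch of that round trip:

ObjectMapper mapper = StartpointObjectMapper.getObjectMapper();
mapper.registerSubtypes(StartpointMock.class);  // subtypes outside the built-ins must be registered
byte[] bytes = mapper.writeValueAsBytes(new StartpointMock());
Startpoint roundTripped = mapper.readValue(bytes, Startpoint.class);  // deserializes back to StartpointMock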
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointSerde.java b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointSerde.java
deleted file mode 100644
index 9dfb784..0000000
--- a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointSerde.java
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class TestStartpointSerde {
-  private final StartpointSerde startpointSerde = new StartpointSerde();
-
-  @Test
-  public void testStartpointSpecificSerde() {
-    StartpointSpecific startpointSpecific = new StartpointSpecific("42");
-    Startpoint startpointFromSerde = startpointSerde.fromBytes(startpointSerde.toBytes(startpointSpecific));
-
-    Assert.assertEquals(startpointSpecific.getClass(), startpointFromSerde.getClass());
-    Assert.assertEquals(startpointSpecific.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
-    Assert.assertEquals(startpointSpecific.getSpecificOffset(), ((StartpointSpecific) startpointFromSerde).getSpecificOffset());
-  }
-
-  @Test
-  public void testStartpointTimestampSerde() {
-    StartpointTimestamp startpointTimestamp = new StartpointTimestamp(123456L);
-    Startpoint startpointFromSerde = startpointSerde.fromBytes(startpointSerde.toBytes(startpointTimestamp));
-
-    Assert.assertEquals(startpointTimestamp.getClass(), startpointFromSerde.getClass());
-    Assert.assertEquals(startpointTimestamp.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
-    Assert.assertEquals(startpointTimestamp.getTimestampOffset(), ((StartpointTimestamp) startpointFromSerde).getTimestampOffset());
-  }
-
-  @Test
-  public void testStartpointEarliestSerde() {
-    StartpointOldest startpointOldest = new StartpointOldest();
-    Startpoint startpointFromSerde = startpointSerde.fromBytes(startpointSerde.toBytes(startpointOldest));
-
-    Assert.assertEquals(startpointOldest.getClass(), startpointFromSerde.getClass());
-    Assert.assertEquals(startpointOldest.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
-  }
-
-  @Test
-  public void testStartpointLatestSerde() {
-    StartpointUpcoming startpointUpcoming = new StartpointUpcoming();
-    Startpoint startpointFromSerde = startpointSerde.fromBytes(startpointSerde.toBytes(startpointUpcoming));
-
-    Assert.assertEquals(startpointUpcoming.getClass(), startpointFromSerde.getClass());
-    Assert.assertEquals(startpointUpcoming.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
-  }
-}
diff --git a/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java b/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java
index b5e35cf..44efc5b 100644
--- a/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java
@@ -293,15 +293,15 @@
 
     initTables(cachingTable, guavaTable, remoteTable);
 
-    // 3 per readable table (9)
-    // 5 per read/write table (15)
-    verify(metricsRegistry, times(24)).newCounter(any(), anyString());
+    // 4 per readable table (12)
+    // 6 per read/write table (18)
+    verify(metricsRegistry, times(30)).newCounter(any(), anyString());
 
-    // 2 per readable table (6)
-    // 5 per read/write table (15)
+    // 3 per readable table (9)
+    // 6 per read/write table (18)
     // 1 per remote readable table (1)
     // 1 per remote read/write table (1)
-    verify(metricsRegistry, times(23)).newTimer(any(), anyString());
+    verify(metricsRegistry, times(29)).newTimer(any(), anyString());
 
     // 1 per guava table (1)
     // 3 per caching table (2)
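Note: the updated counts cross-check against the TableMetrics hunk earlier in this patch. Readable tables now register 4 counters each (num-gets, num-getAlls, num-reads, num-missed-lookups) and read/write tables 6 more (num-puts, num-putAlls, num-deletes, num-deleteAlls, num-writes, num-flushes), so the 3 tables under test account for 12 + 18 = 30 newCounter calls. Likewise, 3 read timers (get-ns, getAll-ns, read-ns) and 6 write timers (put-ns, putAll-ns, delete-ns, deleteAll-ns, write-ns, flush-ns) per table, plus the 2 remote-table timers, give 9 + 18 + 1 + 1 = 29 newTimer calls.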
diff --git a/samza-core/src/test/java/org/apache/samza/table/ratelimit/TestAsyncRateLimitedTable.java b/samza-core/src/test/java/org/apache/samza/table/ratelimit/TestAsyncRateLimitedTable.java
index 7c646fb..d1493e2 100644
--- a/samza-core/src/test/java/org/apache/samza/table/ratelimit/TestAsyncRateLimitedTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/ratelimit/TestAsyncRateLimitedTable.java
@@ -32,7 +32,9 @@
 import org.apache.samza.table.remote.TableReadFunction;
 import org.apache.samza.table.remote.TableWriteFunction;
 import org.apache.samza.table.remote.TestRemoteTable;
+
 import org.junit.Assert;
+import org.junit.Before;
 import org.junit.Test;
 
 import static org.mockito.Mockito.*;
@@ -42,6 +44,46 @@
 
   private final ScheduledExecutorService schedExec = Executors.newSingleThreadScheduledExecutor();
 
+  private Map<String, String> readMap = new HashMap<>();
+  private AsyncReadWriteTable readTable;
+  private TableRateLimiter readRateLimiter;
+  private TableReadFunction<String, String> readFn;
+  private AsyncReadWriteTable<String, String> writeTable;
+  private TableRateLimiter<String, String> writeRateLimiter;
+  private TableWriteFunction<String, String> writeFn;
+
+  @Before
+  public void prepare() {
+    // Read part
+    readRateLimiter = mock(TableRateLimiter.class);
+    readFn = mock(TableReadFunction.class);
+    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any());
+    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any(), any());
+    readMap.put("foo", "bar");
+    doReturn(CompletableFuture.completedFuture(readMap)).when(readFn).getAllAsync(any());
+    doReturn(CompletableFuture.completedFuture(readMap)).when(readFn).getAllAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(5)).when(readFn).readAsync(anyInt(), any());
+    AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, null);
+    readTable = new AsyncRateLimitedTable("t1", delegate, readRateLimiter, null, schedExec);
+    readTable.init(TestRemoteTable.getMockContext());
+
+    // Write part
+    writeRateLimiter = mock(TableRateLimiter.class);
+    writeFn = mock(TableWriteFunction.class);
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(5)).when(writeFn).writeAsync(anyInt(), any());
+    delegate = new AsyncRemoteTable(readFn, writeFn);
+    writeTable = new AsyncRateLimitedTable("t1", delegate, readRateLimiter, writeRateLimiter, schedExec);
+    writeTable.init(TestRemoteTable.getMockContext());
+  }
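Note: every test below asserts the same shape: exactly one throttle call of the matching variant, then exactly one delegate call. A minimal sketch of that throttle-then-delegate pattern; the real AsyncRateLimitedTable internals are not part of this excerpt, so this is an assumed shape, not the actual implementation ("delegate" is the wrapped table, "rateLimitingExecutor" the executor passed at construction):

public CompletableFuture<V> getAsync(K key) {
  return CompletableFuture
      .runAsync(() -> readRateLimiter.throttle(key), rateLimitingExecutor)  // one throttle per call
      .thenCompose(ignored -> delegate.getAsync(key));                      // then one delegate call
}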
+
   @Test(expected = NullPointerException.class)
   public void testNotNullTableId() {
     new AsyncRateLimitedTable(null, mock(AsyncReadWriteTable.class),
@@ -71,71 +113,183 @@
   }
 
   @Test
-  public void testGetThrottling() throws Exception {
-    TableRateLimiter readRateLimiter = mock(TableRateLimiter.class);
-    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
-    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any());
-    Map<String, String> result = new HashMap<>();
-    result.put("foo", "bar");
-    doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any());
-    AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, null);
-    AsyncRateLimitedTable table = new AsyncRateLimitedTable("t1", delegate,
-        readRateLimiter, null, schedExec);
-    table.init(TestRemoteTable.getMockContext());
-
-    Assert.assertEquals("bar", table.getAsync("foo").get());
+  public void testGetAsync() {
+    Assert.assertEquals("bar", readTable.getAsync("foo").join());
     verify(readFn, times(1)).getAsync(any());
+    verify(readFn, times(0)).getAsync(any(), any());
     verify(readRateLimiter, times(1)).throttle(anyString());
-    verify(readRateLimiter, times(0)).throttle(anyList());
-
-    Assert.assertEquals(result, table.getAllAsync(Arrays.asList("")).get());
-    verify(readFn, times(1)).getAllAsync(any());
-    verify(readRateLimiter, times(1)).throttle(anyList());
-    verify(readRateLimiter, times(1)).throttle(anyString());
+    verify(readRateLimiter, times(0)).throttle(anyCollection());
+    verify(readRateLimiter, times(0)).throttle(any(), any());
+    verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+    verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+    verifyWritePartNotCalled();
   }
 
   @Test
-  public void testPutThrottling() throws Exception {
-    TableRateLimiter readRateLimiter = mock(TableRateLimiter.class);
-    TableRateLimiter writeRateLimiter = mock(TableRateLimiter.class);
-    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
-    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
-    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any());
-    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
-    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any());
-    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any());
-    AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, writeFn);
-    AsyncRateLimitedTable table = new AsyncRateLimitedTable("t1", delegate,
-        readRateLimiter, writeRateLimiter, schedExec);
-    table.init(TestRemoteTable.getMockContext());
+  public void testGetAsyncWithArgs() {
+    Assert.assertEquals("bar", readTable.getAsync("foo", 1).join());
+    verify(readFn, times(0)).getAsync(any());
+    verify(readFn, times(1)).getAsync(any(), any());
+    verify(readRateLimiter, times(1)).throttle(anyString(), any());
+    verify(readRateLimiter, times(0)).throttle(anyCollection());
+    verify(readRateLimiter, times(0)).throttle(any(), any());
+    verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+    verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+    verifyWritePartNotCalled();
+  }
 
-    table.putAsync("foo", "bar").get();
+  @Test
+  public void testGetAllAsync() {
+    Assert.assertEquals(readMap, readTable.getAllAsync(Arrays.asList("")).join());
+    verify(readFn, times(1)).getAllAsync(any());
+    verify(readFn, times(0)).getAllAsync(any(), any());
+    verify(readRateLimiter, times(0)).throttle(anyString());
+    verify(readRateLimiter, times(1)).throttle(anyCollection());
+    verify(readRateLimiter, times(0)).throttle(any(), any());
+    verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+    verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+    verifyWritePartNotCalled();
+  }
+
+  @Test
+  public void testGetAllAsyncWithArgs() {
+    Assert.assertEquals(readMap, readTable.getAllAsync(Arrays.asList(""), "").join());
+    verify(readFn, times(0)).getAllAsync(any());
+    verify(readFn, times(1)).getAllAsync(any(), any());
+    verify(readRateLimiter, times(0)).throttle(anyString());
+    verify(readRateLimiter, times(1)).throttle(anyCollection(), any());
+    verify(readRateLimiter, times(0)).throttle(anyString(), any());
+    verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+    verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+    verifyWritePartNotCalled();
+  }
+
+  @Test
+  public void testReadAsync() {
+    Assert.assertEquals(5, readTable.readAsync(1, 2).join());
+    verify(readFn, times(1)).readAsync(anyInt(), any());
+    verify(readRateLimiter, times(0)).throttle(anyString());
+    verify(readRateLimiter, times(0)).throttle(anyCollection());
+    verify(readRateLimiter, times(0)).throttle(any(), any());
+    verify(readRateLimiter, times(1)).throttle(anyInt(), any());
+    verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+    verifyWritePartNotCalled();
+  }
+
+  @Test
+  public void testPutAsync() {
+    writeTable.putAsync("foo", "bar").join();
     verify(writeFn, times(1)).putAsync(any(), any());
-    verify(writeRateLimiter, times(1)).throttle(anyString(), anyString());
-    verify(writeRateLimiter, times(0)).throttleRecords(anyList());
+    verify(writeFn, times(0)).putAsync(any(), any(), any());
     verify(writeRateLimiter, times(0)).throttle(anyString());
-    verify(writeRateLimiter, times(0)).throttle(anyList());
-
-    table.putAllAsync(Arrays.asList(new Entry("1", "2"))).get();
-    verify(writeFn, times(1)).putAllAsync(any());
     verify(writeRateLimiter, times(1)).throttle(anyString(), anyString());
-    verify(writeRateLimiter, times(1)).throttleRecords(anyList());
+    verify(writeRateLimiter, times(0)).throttle(anyCollection());
+    verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+    verifyReadPartNotCalled();
+  }
+
+  @Test
+  public void testPutAsyncWithArgs() {
+    writeTable.putAsync("foo", "bar", 1).join();
+    verify(writeFn, times(0)).putAsync(any(), any());
+    verify(writeFn, times(1)).putAsync(any(), any(), any());
     verify(writeRateLimiter, times(0)).throttle(anyString());
-    verify(writeRateLimiter, times(0)).throttle(anyList());
+    verify(writeRateLimiter, times(1)).throttle(anyString(), anyString(), any());
+    verify(writeRateLimiter, times(0)).throttle(anyCollection());
+    verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+    verifyReadPartNotCalled();
+  }
 
-    table.deleteAsync("foo").get();
-    verify(writeFn, times(1)).deleteAsync(anyString());
-    verify(writeRateLimiter, times(1)).throttle(anyString(), anyString());
-    verify(writeRateLimiter, times(1)).throttleRecords(anyList());
-    verify(writeRateLimiter, times(1)).throttle(anyString());
-    verify(writeRateLimiter, times(0)).throttle(anyList());
+  @Test
+  public void testPutAllAsync() {
+    writeTable.putAllAsync(Arrays.asList(new Entry("1", "2"))).join();
+    verify(writeFn, times(1)).putAllAsync(anyCollection());
+    verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+    verify(writeRateLimiter, times(0)).throttle(anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyCollection());
+    verify(writeRateLimiter, times(1)).throttleRecords(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+    verifyReadPartNotCalled();
+  }
 
-    table.deleteAllAsync(Arrays.asList("1", "2")).get();
-    verify(writeFn, times(1)).deleteAllAsync(any());
-    verify(writeRateLimiter, times(1)).throttle(anyString(), anyString());
-    verify(writeRateLimiter, times(1)).throttleRecords(anyList());
+  @Test
+  public void testPutAllAsyncWithArgs() {
+    writeTable.putAllAsync(Arrays.asList(new Entry("1", "2")), Arrays.asList(1)).join();
+    verify(writeFn, times(0)).putAllAsync(anyCollection());
+    verify(writeFn, times(1)).putAllAsync(anyCollection(), any());
+    verify(writeRateLimiter, times(0)).throttle(anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyCollection());
+    verify(writeRateLimiter, times(1)).throttleRecords(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+    verifyReadPartNotCalled();
+  }
+
+  @Test
+  public void testDeleteAsync() {
+    writeTable.deleteAsync("foo").join();
+    verify(writeFn, times(1)).deleteAsync(any());
+    verify(writeFn, times(0)).deleteAsync(any(), any());
     verify(writeRateLimiter, times(1)).throttle(anyString());
-    verify(writeRateLimiter, times(1)).throttle(anyList());
+    verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyCollection());
+    verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+    verifyReadPartNotCalled();
+  }
+
+  @Test
+  public void testDeleteAsyncWithArgs() {
+    writeTable.deleteAsync("foo", 1).join();
+    verify(writeFn, times(0)).deleteAsync(any());
+    verify(writeFn, times(1)).deleteAsync(any(), any());
+    verify(writeRateLimiter, times(1)).throttle(anyString(), any());
+    verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyCollection());
+    verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+    verifyReadPartNotCalled();
+  }
+
+  @Test
+  public void testDeleteAllAsync() {
+    writeTable.deleteAllAsync(Arrays.asList("1", "2")).join();
+    verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+    verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+    verify(writeRateLimiter, times(0)).throttle(anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+    verify(writeRateLimiter, times(1)).throttle(anyCollection());
+    verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+    verifyReadPartNotCalled();
+  }
+
+  @Test
+  public void testDeleteAllAsyncWithArgs() {
+    writeTable.deleteAllAsync(Arrays.asList("1", "2"), 1).join();
+    verify(writeFn, times(0)).deleteAllAsync(anyCollection());
+    verify(writeFn, times(1)).deleteAllAsync(anyCollection(), any());
+    verify(writeRateLimiter, times(0)).throttle(anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+    verify(writeRateLimiter, times(1)).throttle(anyCollection(), any());
+    verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+    verifyReadPartNotCalled();
+  }
+
+  @Test
+  public void testWriteAsync() {
+    Assert.assertEquals(5, writeTable.writeAsync(1, 2).join());
+    verify(writeFn, times(1)).writeAsync(anyInt(), any());
+    verify(writeRateLimiter, times(0)).throttle(anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+    verify(writeRateLimiter, times(1)).throttle(anyInt(), any());
+    verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+    verifyReadPartNotCalled();
   }
 
   @Test
@@ -157,4 +311,34 @@
     verify(writeFn, times(1)).close();
   }
 
+  private void verifyReadPartNotCalled() {
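+    // Guard used by the write-path tests: no readFn overload and no read-side
+    // throttle may have fired.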
+    verify(readFn, times(0)).getAsync(any());
+    verify(readFn, times(0)).getAsync(any(), any());
+    verify(readFn, times(0)).getAllAsync(any(), any());
+    verify(readFn, times(0)).getAllAsync(any(), any(), any());
+    verify(readFn, times(0)).readAsync(anyInt(), any());
+    verify(readRateLimiter, times(0)).throttle(anyString());
+    verify(readRateLimiter, times(0)).throttle(anyCollection());
+    verify(readRateLimiter, times(0)).throttle(any(), any());
+    verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+    verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+  }
+
+  private void verifyWritePartNotCalled() {
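+    // Guard used by the read-path tests: no writeFn overload and no write-side
+    // throttle may have fired.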
+    verify(writeFn, times(0)).putAsync(any(), any());
+    verify(writeFn, times(0)).putAsync(any(), any(), any());
+    verify(writeFn, times(0)).putAllAsync(any());
+    verify(writeFn, times(0)).putAllAsync(any(), any());
+    verify(writeFn, times(0)).deleteAsync(any());
+    verify(writeFn, times(0)).deleteAsync(any(), any());
+    verify(writeFn, times(0)).deleteAllAsync(any());
+    verify(writeFn, times(0)).deleteAllAsync(any(), any());
+    verify(writeFn, times(0)).writeAsync(anyInt(), any());
+    verify(writeRateLimiter, times(0)).throttle(anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+    verify(writeRateLimiter, times(0)).throttle(anyCollection());
+    verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+    verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+  }
+
 }
diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/TestAsyncRemoteTable.java b/samza-core/src/test/java/org/apache/samza/table/remote/TestAsyncRemoteTable.java
index 2a447e0..d557c31 100644
--- a/samza-core/src/test/java/org/apache/samza/table/remote/TestAsyncRemoteTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/remote/TestAsyncRemoteTable.java
@@ -20,7 +20,6 @@
 
 import java.util.Arrays;
 
-import org.apache.samza.context.Context;
 import org.apache.samza.storage.kv.Entry;
 
 import org.junit.Before;
@@ -30,6 +29,7 @@
 
 import static org.junit.Assert.assertTrue;
 
+import static org.mockito.Matchers.*;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -60,6 +60,15 @@
   }
 
   @Test
+  public void testGetAsyncWithArgs() {
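+    // Both the read-only and read-write tables must forward the args overload
+    // to readFn.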
+    int times = 0;
+    roTable.getAsync(1, 1);
+    verify(readFn, times(++times)).getAsync(any(), any());
+    rwTable.getAsync(1, 1);
+    verify(readFn, times(++times)).getAsync(any(), any());
+  }
+
+  @Test
   public void testGetAllAsync() {
     int times = 0;
     roTable.getAllAsync(Arrays.asList(1, 2));
@@ -69,6 +78,24 @@
   }
 
   @Test
+  public void testGetAllAsyncWithArgs() {
+    int times = 0;
+    roTable.getAllAsync(Arrays.asList(1, 2), Arrays.asList(0, 0));
+    verify(readFn, times(++times)).getAllAsync(any(), any());
+    rwTable.getAllAsync(Arrays.asList(1, 2), Arrays.asList(0, 0));
+    verify(readFn, times(++times)).getAllAsync(any(), any());
+  }
+
+  @Test
+  public void testReadAsync() {
+    int times = 0;
+    roTable.readAsync(1, 2, 3);
+    verify(readFn, times(++times)).readAsync(anyInt(), any(), any());
+    rwTable.readAsync(1, 2, 3);
+    verify(readFn, times(++times)).readAsync(anyInt(), any(), any());
+  }
+
+  @Test
   public void testPutAsync() {
     verifyFailure(() -> roTable.putAsync(1, 2));
     rwTable.putAsync(1, 2);
@@ -76,6 +103,13 @@
   }
 
   @Test
+  public void testPutAsyncWithArgs() {
+    verifyFailure(() -> roTable.putAsync(1, 2, 3));
+    rwTable.putAsync(1, 2, 3);
+    verify(writeFn, times(1)).putAsync(any(), any(), any());
+  }
+
+  @Test
   public void testPutAllAsync() {
     verifyFailure(() -> roTable.putAllAsync(Arrays.asList(new Entry(1, 2))));
     rwTable.putAllAsync(Arrays.asList(new Entry(1, 2)));
@@ -83,11 +117,26 @@
   }
 
   @Test
+  public void testPutAllAsyncWithArgs() {
+    verifyFailure(() -> roTable.putAllAsync(Arrays.asList(new Entry(1, 2)), Arrays.asList(0, 0)));
+    rwTable.putAllAsync(Arrays.asList(new Entry(1, 2)), Arrays.asList(0, 0));
+    verify(writeFn, times(1)).putAllAsync(any(), any());
+  }
+
+  @Test
   public void testDeleteAsync() {
     verifyFailure(() -> roTable.deleteAsync(1));
     rwTable.deleteAsync(1);
     verify(writeFn, times(1)).deleteAsync(any());
   }
+
+  @Test
+  public void testDeleteAsyncWithArgs() {
+    verifyFailure(() -> roTable.deleteAsync(1, 2));
+    rwTable.deleteAsync(1, 2);
+    verify(writeFn, times(1)).deleteAsync(any(), any());
+  }
+
   @Test
   public void testDeleteAllAsync() {
     verifyFailure(() -> roTable.deleteAllAsync(Arrays.asList(1)));
@@ -96,13 +145,17 @@
   }
 
   @Test
-  public void testInit() {
-    roTable.init(mock(Context.class));
-    verify(readFn, times(1)).init(any());
-    verify(writeFn, times(0)).init(any());
-    rwTable.init(mock(Context.class));
-    verify(readFn, times(2)).init(any());
-    verify(writeFn, times(1)).init(any());
+  public void testDeleteAllAsyncWithArgs() {
+    verifyFailure(() -> roTable.deleteAllAsync(Arrays.asList(1), Arrays.asList(2)));
+    rwTable.deleteAllAsync(Arrays.asList(1, 2), Arrays.asList(2));
+    verify(writeFn, times(1)).deleteAllAsync(any(), any());
+  }
+
+  @Test
+  public void testWriteAsync() {
+    verifyFailure(() -> roTable.writeAsync(1, 2, 3));
+    rwTable.writeAsync(1, 2, 3);
+    verify(writeFn, times(1)).writeAsync(anyInt(), any(), any());
   }
 
   @Test
diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java b/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java
index 93b2dab..02625bb 100644
--- a/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java
@@ -26,8 +26,12 @@
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.Timer;
 import org.apache.samza.storage.kv.Entry;
+import org.apache.samza.table.AsyncReadWriteTable;
+import org.apache.samza.table.ratelimit.AsyncRateLimitedTable;
+import org.apache.samza.table.retry.AsyncRetriableTable;
 import org.apache.samza.table.retry.TableRetryPolicy;
 
+import org.apache.samza.testUtils.TestUtils;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -38,14 +42,12 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyCollection;
-import static org.mockito.Matchers.anyString;
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.*;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
@@ -55,8 +57,6 @@
 
 public class TestRemoteTable {
 
-  private final ScheduledExecutorService schedExec = Executors.newSingleThreadScheduledExecutor();
-
   public static Context getMockContext() {
     Context context = new MockContext();
     MetricsRegistry metricsRegistry = mock(MetricsRegistry.class);
@@ -88,10 +88,14 @@
         readRateLimiter, writeRateLimiter, rateLimitingExecutor,
         readPolicy, writePolicy, retryExecutor, cbExecutor);
     table.init(getMockContext());
+    verify(readFn, times(1)).init(any(), any());
+    if (writeFn != null) {
+      verify(writeFn, times(1)).init(any(), any());
+    }
     return (T) table;
   }
 
-  private void doTestGet(boolean sync, boolean error, boolean retry) throws Exception {
+  private void doTestGet(boolean sync, boolean error, boolean retry) {
     String tableId = "testGet-" + sync + error + retry;
     TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
     // Sync is backed by async so needs to mock the async method
@@ -114,27 +118,49 @@
       doReturn(true).when(readFn).isRetriable(any());
     }
     RemoteTable<String, String> table = getTable(tableId, readFn, null, retry);
-    Assert.assertEquals("bar", sync ? table.get("foo") : table.getAsync("foo").get());
+    Assert.assertEquals("bar", sync ? table.get("foo") : table.getAsync("foo").join());
     verify(table.readRateLimiter, times(error && retry ? 2 : 1)).throttle(anyString());
   }
 
   @Test
-  public void testGet() throws Exception {
+  public void testInit() {
+    String tableId = "testInit";
+    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+    RemoteTable<String, String> table = getTable(tableId, readFn, writeFn, true);
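+    // Peel the decorator chain via reflection ("asyncTable"/"table" hold the
+    // wrapped layers): RemoteTable -> AsyncRetriableTable -> AsyncRateLimitedTable
+    // -> AsyncRemoteTable.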
+    // AsyncRetriableTable
+    AsyncReadWriteTable innerTable = TestUtils.getFieldValue(table, "asyncTable");
+    Assert.assertTrue(innerTable instanceof AsyncRetriableTable);
+    Assert.assertNotNull(TestUtils.getFieldValue(innerTable, "readRetryMetrics"));
+    Assert.assertNotNull(TestUtils.getFieldValue(innerTable, "writeRetryMetrics"));
+    // AsyncRateLimitedTable
+    innerTable = TestUtils.getFieldValue(innerTable, "table");
+    Assert.assertTrue(innerTable instanceof AsyncRateLimitedTable);
+    // AsyncRemoteTable
+    innerTable = TestUtils.getFieldValue(innerTable, "table");
+    Assert.assertTrue(innerTable instanceof AsyncRemoteTable);
+    // Verify table functions are initialized
+    verify(readFn, times(1)).init(any(), any());
+    verify(writeFn, times(1)).init(any(), any());
+  }
+
+  @Test
+  public void testGet() {
     doTestGet(true, false, false);
   }
 
   @Test
-  public void testGetAsync() throws Exception {
+  public void testGetAsync() {
     doTestGet(false, false, false);
   }
 
-  @Test(expected = ExecutionException.class)
-  public void testGetAsyncError() throws Exception {
+  @Test(expected = RuntimeException.class)
+  public void testGetAsyncError() {
     doTestGet(false, true, false);
   }
 
   @Test
-  public void testGetAsyncErrorRetried() throws Exception {
+  public void testGetAsyncErrorRetried() {
     doTestGet(false, true, true);
   }
 
@@ -160,7 +186,44 @@
           });
   }
 
-  private void doTestPut(boolean sync, boolean error, boolean isDelete, boolean retry) throws Exception {
+  public void doTestRead(boolean sync, boolean error) {
+    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+    RemoteTable<String, String> table = getTable("testRead-" + sync + error,
+        readFn, mock(TableWriteFunction.class), false);
+    CompletableFuture<?> future;
+    if (error) {
+      future = new CompletableFuture();
+      future.completeExceptionally(new RuntimeException("Test exception"));
+    } else {
+      future = CompletableFuture.completedFuture(5);
+    }
+    // Sync is backed by async, so we need to mock the async method
+    doReturn(future).when(readFn).readAsync(anyInt(), any());
+
+    int readResult = sync
+        ? table.read(1, 2)
+        : (Integer) table.readAsync(1, 2).join();
+    verify(readFn, times(1)).readAsync(anyInt(), any());
+    Assert.assertEquals(5, readResult);
+    verify(table.readRateLimiter, times(1)).throttle(anyInt(), any());
+  }
+
+  @Test
+  public void testRead() {
+    doTestRead(true, false);
+  }
+
+  @Test
+  public void testReadAsync() {
+    doTestRead(false, false);
+  }
+
+  @Test(expected = RuntimeException.class)
+  public void testReadAsyncError() {
+    doTestRead(false, true);
+  }
+
+  private void doTestPut(boolean sync, boolean error, boolean isDelete, boolean retry) {
     String tableId = "testPut-" + sync + error + isDelete + retry;
     TableWriteFunction<String, String> mockWriteFn = mock(TableWriteFunction.class);
     TableWriteFunction<String, String> writeFn = mockWriteFn;
@@ -192,7 +255,7 @@
     if (sync) {
       table.put("foo", isDelete ? null : "bar");
     } else {
-      table.putAsync("foo", isDelete ? null : "bar").get();
+      table.putAsync("foo", isDelete ? null : "bar").join();
     }
     ArgumentCaptor<String> keyCaptor = ArgumentCaptor.forClass(String.class);
     ArgumentCaptor<String> valCaptor = ArgumentCaptor.forClass(String.class);
@@ -211,36 +274,36 @@
   }
 
   @Test
-  public void testPut() throws Exception {
+  public void testPut() {
     doTestPut(true, false, false, false);
   }
 
   @Test
-  public void testPutDelete() throws Exception {
+  public void testPutDelete() {
     doTestPut(true, false, true, false);
   }
 
   @Test
-  public void testPutAsync() throws Exception {
+  public void testPutAsync() {
     doTestPut(false, false, false, false);
   }
 
   @Test
-  public void testPutAsyncDelete() throws Exception {
+  public void testPutAsyncDelete() {
     doTestPut(false, false, true, false);
   }
 
-  @Test(expected = ExecutionException.class)
-  public void testPutAsyncError() throws Exception {
+  @Test(expected = RuntimeException.class)
+  public void testPutAsyncError() {
     doTestPut(false, true, false, false);
   }
 
   @Test
-  public void testPutAsyncErrorRetried() throws Exception {
+  public void testPutAsyncErrorRetried() {
     doTestPut(false, true, false, true);
   }
 
-  private void doTestDelete(boolean sync, boolean error) throws Exception {
+  private void doTestDelete(boolean sync, boolean error) {
     TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
     RemoteTable<String, String> table = getTable("testDelete-" + sync + error,
         mock(TableReadFunction.class), writeFn, false);
@@ -257,7 +320,7 @@
     if (sync) {
       table.delete("foo");
     } else {
-      table.deleteAsync("foo").get();
+      table.deleteAsync("foo").join();
     }
     verify(writeFn, times(1)).deleteAsync(argCaptor.capture());
     Assert.assertEquals("foo", argCaptor.getValue());
@@ -265,21 +328,21 @@
   }
 
   @Test
-  public void testDelete() throws Exception {
+  public void testDelete() {
     doTestDelete(true, false);
   }
 
   @Test
-  public void testDeleteAsync() throws Exception {
+  public void testDeleteAsync() {
     doTestDelete(false, false);
   }
 
-  @Test(expected = ExecutionException.class)
-  public void testDeleteAsyncError() throws Exception {
+  @Test(expected = RuntimeException.class)
+  public void testDeleteAsyncError() {
     doTestDelete(false, true);
   }
 
-  private void doTestGetAll(boolean sync, boolean error, boolean partial) throws Exception {
+  private void doTestGetAll(boolean sync, boolean error, boolean partial) {
     TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
     Map<String, String> res = new HashMap<>();
     res.put("foo1", "bar1");
@@ -297,32 +360,32 @@
     doReturn(future).when(readFn).getAllAsync(any());
     RemoteTable<String, String> table = getTable("testGetAll-" + sync + error + partial, readFn, null, false);
     Assert.assertEquals(res, sync ? table.getAll(Arrays.asList("foo1", "foo2"))
-        : table.getAllAsync(Arrays.asList("foo1", "foo2")).get());
+        : table.getAllAsync(Arrays.asList("foo1", "foo2")).join());
     verify(table.readRateLimiter, times(1)).throttle(anyCollection());
   }
 
   @Test
-  public void testGetAll() throws Exception {
+  public void testGetAll() {
     doTestGetAll(true, false, false);
   }
 
   @Test
-  public void testGetAllAsync() throws Exception {
+  public void testGetAllAsync() {
     doTestGetAll(false, false, false);
   }
 
-  @Test(expected = ExecutionException.class)
-  public void testGetAllAsyncError() throws Exception {
+  @Test(expected = RuntimeException.class)
+  public void testGetAllAsyncError() {
     doTestGetAll(false, true, false);
   }
 
   // Partial result is an acceptable scenario
   @Test
-  public void testGetAllPartialResult() throws Exception {
+  public void testGetAllPartialResult() {
     doTestGetAll(false, false, true);
   }
 
-  public void doTestPutAll(boolean sync, boolean error, boolean hasDelete) throws Exception {
+  public void doTestPutAll(boolean sync, boolean error, boolean hasDelete) {
     TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
     RemoteTable<String, String> table = getTable("testPutAll-" + sync + error + hasDelete,
         mock(TableReadFunction.class), writeFn, false);
@@ -344,7 +407,7 @@
     if (sync) {
       table.putAll(entries);
     } else {
-      table.putAllAsync(entries).get();
+      table.putAllAsync(entries).join();
     }
     verify(writeFn, times(1)).putAllAsync(argCaptor.capture());
     if (hasDelete) {
@@ -361,31 +424,31 @@
   }
 
   @Test
-  public void testPutAll() throws Exception {
+  public void testPutAll() {
     doTestPutAll(true, false, false);
   }
 
   @Test
-  public void testPutAllHasDelete() throws Exception {
+  public void testPutAllHasDelete() {
     doTestPutAll(true, false, true);
   }
 
   @Test
-  public void testPutAllAsync() throws Exception {
+  public void testPutAllAsync() {
     doTestPutAll(false, false, false);
   }
 
   @Test
-  public void testPutAllAsyncHasDelete() throws Exception {
+  public void testPutAllAsyncHasDelete() {
     doTestPutAll(false, false, true);
   }
 
-  @Test(expected = ExecutionException.class)
-  public void testPutAllAsyncError() throws Exception {
+  @Test(expected = RuntimeException.class)
+  public void testPutAllAsyncError() {
     doTestPutAll(false, true, false);
   }
 
-  public void doTestDeleteAll(boolean sync, boolean error) throws Exception {
+  public void doTestDeleteAll(boolean sync, boolean error) {
     TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
     RemoteTable<String, String> table = getTable("testDeleteAll-" + sync + error,
         mock(TableReadFunction.class), writeFn, false);
@@ -403,7 +466,7 @@
     if (sync) {
       table.deleteAll(keys);
     } else {
-      table.deleteAllAsync(keys).get();
+      table.deleteAllAsync(keys).join();
     }
     verify(writeFn, times(1)).deleteAllAsync(argCaptor.capture());
     Assert.assertEquals(keys, argCaptor.getValue());
@@ -411,20 +474,57 @@
   }
 
   @Test
-  public void testDeleteAll() throws Exception {
+  public void testDeleteAll() {
     doTestDeleteAll(true, false);
   }
 
   @Test
-  public void testDeleteAllAsync() throws Exception {
+  public void testDeleteAllAsync() {
     doTestDeleteAll(false, false);
   }
 
-  @Test(expected = ExecutionException.class)
-  public void testDeleteAllAsyncError() throws Exception {
+  @Test(expected = RuntimeException.class)
+  public void testDeleteAllAsyncError() {
     doTestDeleteAll(false, true);
   }
 
+  public void doTestWrite(boolean sync, boolean error) {
+    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+    RemoteTable<String, String> table = getTable("testWrite-" + sync + error,
+        mock(TableReadFunction.class), writeFn, false);
+    CompletableFuture<?> future;
+    if (error) {
+      future = new CompletableFuture();
+      future.completeExceptionally(new RuntimeException("Test exception"));
+    } else {
+      future = CompletableFuture.completedFuture(5);
+    }
+    // Sync is backed by async, so we need to mock the async method
+    doReturn(future).when(writeFn).writeAsync(anyInt(), any());
+
+    int writeResult = sync
+        ? table.write(1, 2)
+        : (Integer) table.writeAsync(1, 2).join();
+    verify(writeFn, times(1)).writeAsync(anyInt(), any());
+    Assert.assertEquals(5, writeResult);
+    verify(table.writeRateLimiter, times(1)).throttle(anyInt(), any());
+  }
+
+  @Test
+  public void testWrite() {
+    doTestWrite(true, false);
+  }
+
+  @Test
+  public void testWriteAsync() {
+    doTestWrite(false, false);
+  }
+
+  @Test(expected = RuntimeException.class)
+  public void testWriteAsyncError() {
+    doTestWrite(false, true);
+  }
+
   @Test
   public void testFlush() {
     TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
@@ -434,7 +534,7 @@
   }
 
   @Test
-  public void testGetWithCallbackExecutor() throws Exception {
+  public void testGetWithCallbackExecutor() {
     TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
     // Sync is backed by async so needs to mock the async method
     doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(anyString());
@@ -448,4 +548,106 @@
         Assert.assertNotSame(testThread, Thread.currentThread());
       });
   }
+
+  @Test
+  public void testGetDelegation() {
+    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any());
+    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any(), any());
+    Map<String, String> result = new HashMap<>();
+    result.put("foo", "bar");
+    doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any());
+    doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(5)).when(readFn).readAsync(anyInt(), any());
+
+    RemoteTable<String, String> table = getTable("testGetDelegation", readFn, null,
+        Executors.newSingleThreadExecutor(), true);
+    verify(readFn, times(1)).init(any(), any());
+
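+    // Each call shape must hit exactly one readFn overload; the sibling
+    // overload's invocation count must not move.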
+    // GetAsync
+    verify(readFn, times(0)).getAsync(any());
+    verify(readFn, times(0)).getAsync(any(), any());
+    assertEquals("bar", table.getAsync("foo").join());
+    verify(readFn, times(1)).getAsync(any());
+    verify(readFn, times(0)).getAsync(any(), any());
+    assertEquals("bar", table.getAsync("foo", 1).join());
+    verify(readFn, times(1)).getAsync(any());
+    verify(readFn, times(1)).getAsync(any(), any());
+    // GetAllAsync
+    verify(readFn, times(0)).getAllAsync(any());
+    verify(readFn, times(0)).getAllAsync(any(), any());
+    assertEquals(result, table.getAllAsync(Arrays.asList("foo")).join());
+    verify(readFn, times(1)).getAllAsync(any());
+    verify(readFn, times(0)).getAllAsync(any(), any());
+    assertEquals(result, table.getAllAsync(Arrays.asList("foo"), Arrays.asList(1)).join());
+    verify(readFn, times(1)).getAllAsync(any());
+    verify(readFn, times(1)).getAllAsync(any(), any());
+    // ReadAsync
+    verify(readFn, times(0)).readAsync(anyInt(), any());
+    assertEquals(5, table.readAsync(1, 2).join());
+    verify(readFn, times(1)).readAsync(anyInt(), any());
+
+    table.close();
+  }
+
+  @Test
+  public void testPutAndDeleteDelegation() {
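+    // Write-path mirror of testGetDelegation: stub every writeFn overload and
+    // verify that args/no-args calls stay on their own overloads.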
+    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+    doReturn(true).when(writeFn).isRetriable(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(anyCollection());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(anyCollection(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(anyCollection());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(anyCollection(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).writeAsync(anyInt(), any());
+
+    RemoteTable<String, String> table = getTable("testGetDelegation", readFn, writeFn,
+        Executors.newSingleThreadExecutor(), true);
+    verify(readFn, times(1)).init(any(), any());
+
+    // PutAsync
+    verify(writeFn, times(0)).putAsync(any(), any());
+    verify(writeFn, times(0)).putAsync(any(), any(), any());
+    table.putAsync("roo", "bar").join();
+    verify(writeFn, times(1)).putAsync(any(), any());
+    verify(writeFn, times(0)).putAsync(any(), any(), any());
+    table.putAsync("foo", "bar", 3).join();
+    verify(writeFn, times(1)).putAsync(any(), any());
+    verify(writeFn, times(1)).putAsync(any(), any(), any());
+    // PutAllAsync
+    verify(writeFn, times(0)).putAllAsync(anyCollection());
+    verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+    table.putAllAsync(Arrays.asList(new Entry("foo", "bar"))).join();
+    verify(writeFn, times(1)).putAllAsync(anyCollection());
+    verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+    table.putAllAsync(Arrays.asList(new Entry("foo", "bar")), 2).join();
+    verify(writeFn, times(1)).putAllAsync(anyCollection());
+    verify(writeFn, times(1)).putAllAsync(anyCollection(), any());
+    // DeleteAsync
+    verify(writeFn, times(0)).deleteAsync(any());
+    verify(writeFn, times(0)).deleteAsync(any(), any());
+    table.deleteAsync("foo").join();
+    verify(writeFn, times(1)).deleteAsync(any());
+    verify(writeFn, times(0)).deleteAsync(any(), any());
+    table.deleteAsync("foo", 2).join();
+    verify(writeFn, times(1)).deleteAsync(any());
+    verify(writeFn, times(1)).deleteAsync(any(), any());
+    // DeleteAllAsync
+    verify(writeFn, times(0)).deleteAllAsync(anyCollection());
+    verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+    table.deleteAllAsync(Arrays.asList("foo")).join();
+    verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+    verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+    table.deleteAllAsync(Arrays.asList("foo"), Arrays.asList(2)).join();
+    verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+    verify(writeFn, times(1)).deleteAllAsync(anyCollection(), any());
+    // WriteAsync
+    verify(writeFn, times(0)).writeAsync(anyInt(), any());
+    table.writeAsync(1, 2).join();
+    verify(writeFn, times(1)).writeAsync(anyInt(), any());
+  }
 }
diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/descriptors/TestRemoteTableDescriptor.java b/samza-core/src/test/java/org/apache/samza/table/remote/descriptors/TestRemoteTableDescriptor.java
index 5bc8339..ce89c5a 100644
--- a/samza-core/src/test/java/org/apache/samza/table/remote/descriptors/TestRemoteTableDescriptor.java
+++ b/samza-core/src/test/java/org/apache/samza/table/remote/descriptors/TestRemoteTableDescriptor.java
@@ -111,17 +111,17 @@
 
   @Test
   public void testSerializeWithLimiterAndReadCredFn() {
-    doTestSerialize(createMockRateLimiter(), (k, v) -> 1, null);
+    doTestSerialize(createMockRateLimiter(), (k, v, args) -> 1, null);
   }
 
   @Test
   public void testSerializeWithLimiterAndWriteCredFn() {
-    doTestSerialize(createMockRateLimiter(), null, (k, v) -> 1);
+    doTestSerialize(createMockRateLimiter(), null, (k, v, args) -> 1);
   }
 
   @Test
   public void testSerializeWithLimiterAndReadWriteCredFns() {
-    doTestSerialize(createMockRateLimiter(), (key, value) -> 1, (key, value) -> 1);
+    doTestSerialize(createMockRateLimiter(), (key, value, args) -> 1, (key, value, args) -> 1);
   }
 
   @Test
@@ -263,7 +263,7 @@
     int numCalls = 0;
 
     @Override
-    public int getCredits(K key, V value) {
+    public int getCredits(K key, V value, Object ... args) {
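+      // Credit functions receive the pass-through args from the table API.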
       numCalls++;
       return 1;
     }
diff --git a/samza-core/src/test/java/org/apache/samza/table/retry/TestAsyncRetriableTable.java b/samza-core/src/test/java/org/apache/samza/table/retry/TestAsyncRetriableTable.java
index 1f4b514..ec4307d 100644
--- a/samza-core/src/test/java/org/apache/samza/table/retry/TestAsyncRetriableTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/retry/TestAsyncRetriableTable.java
@@ -23,7 +23,6 @@
 import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -77,7 +76,51 @@
   }
 
   @Test
-  public void testGetWithoutRetry() throws Exception {
+  public void testGetDelegation() {
+    TableRetryPolicy policy = new TableRetryPolicy();
+    policy.withFixedBackoff(Duration.ofMillis(100));
+    TableReadFunction readFn = mock(TableReadFunction.class);
+    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any());
+    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any(), any());
+    Map<String, String> result = new HashMap<>();
+    result.put("foo", "bar");
+    doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any());
+    doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(5)).when(readFn).readAsync(anyInt(), any());
+    AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, null);
+    AsyncRetriableTable table = new AsyncRetriableTable("t1", delegate, policy, null, schedExec, readFn, null);
+
+    table.init(TestRemoteTable.getMockContext());
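+    // The retriable wrapper does not initialize the table functions itself;
+    // that is the outer RemoteTable's job, hence times(0) here.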
+    verify(readFn, times(0)).init(any(), any());
+
+    // GetAsync
+    verify(readFn, times(0)).getAsync(any());
+    verify(readFn, times(0)).getAsync(any(), any());
+    assertEquals("bar", table.getAsync("foo").join());
+    verify(readFn, times(1)).getAsync(any());
+    verify(readFn, times(0)).getAsync(any(), any());
+    assertEquals("bar", table.getAsync("foo", 1).join());
+    verify(readFn, times(1)).getAsync(any());
+    verify(readFn, times(1)).getAsync(any(), any());
+    // GetAllAsync
+    verify(readFn, times(0)).getAllAsync(any());
+    verify(readFn, times(0)).getAllAsync(any(), any());
+    assertEquals(result, table.getAllAsync(Arrays.asList("foo")).join());
+    verify(readFn, times(1)).getAllAsync(any());
+    verify(readFn, times(0)).getAllAsync(any(), any());
+    assertEquals(result, table.getAllAsync(Arrays.asList("foo"), Arrays.asList(1)).join());
+    verify(readFn, times(1)).getAllAsync(any());
+    verify(readFn, times(1)).getAllAsync(any(), any());
+    // ReadAsync
+    verify(readFn, times(0)).readAsync(anyInt(), any());
+    assertEquals(5, table.readAsync(1, 2).join());
+    verify(readFn, times(1)).readAsync(anyInt(), any());
+
+    table.close();
+  }
+
+  @Test
+  public void testGetWithoutRetry() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(100));
     TableReadFunction readFn = mock(TableReadFunction.class);
@@ -91,11 +134,11 @@
 
     int times = 0;
     table.init(TestRemoteTable.getMockContext());
-    verify(readFn, times(1)).init(any());
-    assertEquals("bar", table.getAsync("foo").get());
+    verify(readFn, times(0)).init(any(), any());
+    assertEquals("bar", table.getAsync("foo").join());
     verify(readFn, times(1)).getAsync(any());
     assertEquals(++times, table.readRetryMetrics.successCount.getCount());
-    assertEquals(result, table.getAllAsync(Arrays.asList("foo")).get());
+    assertEquals(result, table.getAllAsync(Arrays.asList("foo")).join());
     verify(readFn, times(1)).getAllAsync(any());
     assertEquals(++times, table.readRetryMetrics.successCount.getCount());
     assertEquals(0, table.readRetryMetrics.retryCount.getCount());
@@ -107,7 +150,7 @@
   }
 
   @Test
-  public void testGetWithRetryDisabled() throws Exception {
+  public void testGetWithRetryDisabled() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(10));
     policy.withStopAfterDelay(Duration.ofMillis(100));
@@ -121,9 +164,9 @@
     table.init(TestRemoteTable.getMockContext());
 
     try {
-      table.getAsync("foo").get();
+      table.getAsync("foo").join();
       fail();
-    } catch (ExecutionException e) {
+    } catch (Throwable t) {
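+      // expected: join() rethrows the failure as an unchecked CompletionException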
     }
 
     verify(readFn, times(1)).getAsync(any());
@@ -134,7 +177,7 @@
   }
 
   @Test
-  public void testGetAllWithOneRetry() throws Exception {
+  public void testGetAllWithOneRetry() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(10));
     TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
@@ -153,13 +196,13 @@
           future.completeExceptionally(new RuntimeException("test exception"));
         }
         return future;
-      }).when(readFn).getAllAsync(any());
+      }).when(readFn).getAllAsync(anyCollection());
 
     AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, null);
     AsyncRetriableTable table = new AsyncRetriableTable("t1", delegate, policy, null, schedExec, readFn, null);
     table.init(TestRemoteTable.getMockContext());
 
-    assertEquals(map, table.getAllAsync(Arrays.asList("foo1", "foo2")).get());
+    assertEquals(map, table.getAllAsync(Arrays.asList("foo1", "foo2")).join());
     verify(readFn, times(2)).getAllAsync(any());
     assertEquals(1, table.readRetryMetrics.retryCount.getCount());
     assertEquals(0, table.readRetryMetrics.successCount.getCount());
@@ -168,7 +211,7 @@
   }
 
   @Test
-  public void testGetWithPermFailureOnTimeout() throws Exception {
+  public void testGetWithPermFailureOnTimeout() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(5));
     policy.withStopAfterDelay(Duration.ofMillis(100));
@@ -182,9 +225,9 @@
     table.init(TestRemoteTable.getMockContext());
 
     try {
-      table.getAsync("foo").get();
+      table.getAsync("foo").join();
       fail();
-    } catch (ExecutionException e) {
+    } catch (Throwable t) {
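+      // expected failure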
     }
 
     verify(readFn, atLeast(3)).getAsync(any());
@@ -195,7 +238,7 @@
   }
 
   @Test
-  public void testGetWithPermFailureOnMaxCount() throws Exception {
+  public void testGetWithPermFailureOnMaxCount() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(5));
     policy.withStopAfterAttempts(10);
@@ -209,9 +252,9 @@
     table.init(TestRemoteTable.getMockContext());
 
     try {
-      table.getAsync("foo").get();
+      table.getAsync("foo").join();
       fail();
-    } catch (ExecutionException e) {
+    } catch (Throwable t) {
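+      // expected failure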
     }
 
     verify(readFn, atLeast(11)).getAsync(any());
@@ -222,7 +265,68 @@
   }
 
   @Test
-  public void testPutWithoutRetry() throws Exception {
+  public void testPutAndDeleteDelegation() {
+    TableRetryPolicy policy = new TableRetryPolicy();
+    policy.withFixedBackoff(Duration.ofMillis(100));
+    TableReadFunction readFn = mock(TableReadFunction.class);
+    TableWriteFunction writeFn = mock(TableWriteFunction.class);
+    doReturn(true).when(writeFn).isRetriable(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any(), any());
+    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).writeAsync(anyInt(), any());
+    AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, writeFn);
+    AsyncRetriableTable table = new AsyncRetriableTable("t1", delegate, null, policy, schedExec, readFn, writeFn);
+
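+    // As with the read path: each put/delete call shape must land on exactly
+    // one writeFn overload.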
+    // PutAsync
+    verify(writeFn, times(0)).putAsync(any(), any());
+    verify(writeFn, times(0)).putAsync(any(), any(), any());
+    table.putAsync(1, 2).join();
+    verify(writeFn, times(1)).putAsync(any(), any());
+    verify(writeFn, times(0)).putAsync(any(), any(), any());
+    table.putAsync(1, 2, 3).join();
+    verify(writeFn, times(1)).putAsync(any(), any());
+    verify(writeFn, times(1)).putAsync(any(), any(), any());
+    // PutAllAsync
+    verify(writeFn, times(0)).putAllAsync(anyCollection());
+    verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+    table.putAllAsync(Arrays.asList(1)).join();
+    verify(writeFn, times(1)).putAllAsync(anyCollection());
+    verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+    table.putAllAsync(Arrays.asList(1), Arrays.asList(1)).join();
+    verify(writeFn, times(1)).putAllAsync(anyCollection());
+    verify(writeFn, times(1)).putAllAsync(anyCollection(), any());
+    // DeleteAsync
+    verify(writeFn, times(0)).deleteAsync(any());
+    verify(writeFn, times(0)).deleteAsync(any(), any());
+    table.deleteAsync(1).join();
+    verify(writeFn, times(1)).deleteAsync(any());
+    verify(writeFn, times(0)).deleteAsync(any(), any());
+    table.deleteAsync(1, 2).join();
+    verify(writeFn, times(1)).deleteAsync(any());
+    verify(writeFn, times(1)).deleteAsync(any(), any());
+    // DeleteAllAsync
+    verify(writeFn, times(0)).deleteAllAsync(anyCollection());
+    verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+    table.deleteAllAsync(Arrays.asList(1)).join();
+    verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+    verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+    table.deleteAllAsync(Arrays.asList(1), Arrays.asList(2)).join();
+    verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+    verify(writeFn, times(1)).deleteAllAsync(anyCollection(), any());
+    // WriteAsync
+    verify(writeFn, times(0)).writeAsync(anyInt(), any());
+    table.writeAsync(1, 2).join();
+    verify(writeFn, times(1)).writeAsync(anyInt(), any());
+  }
+
+  @Test
+  public void testPutWithoutRetry() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(100));
     TableReadFunction readFn = mock(TableReadFunction.class);
@@ -237,18 +341,18 @@
 
     int times = 0;
     table.init(TestRemoteTable.getMockContext());
-    verify(readFn, times(1)).init(any());
-    verify(writeFn, times(1)).init(any());
-    table.putAsync("foo", "bar").get();
+    verify(readFn, times(0)).init(any(), any());
+    verify(writeFn, times(0)).init(any(), any());
+    table.putAsync("foo", "bar").join();
     verify(writeFn, times(1)).putAsync(any(), any());
     assertEquals(++times, table.writeRetryMetrics.successCount.getCount());
-    table.putAllAsync(Arrays.asList(new Entry("1", "2"))).get();
+    table.putAllAsync(Arrays.asList(new Entry("1", "2"))).join();
     verify(writeFn, times(1)).putAllAsync(any());
     assertEquals(++times, table.writeRetryMetrics.successCount.getCount());
-    table.deleteAsync("1").get();
+    table.deleteAsync("1").join();
     verify(writeFn, times(1)).deleteAsync(any());
     assertEquals(++times, table.writeRetryMetrics.successCount.getCount());
-    table.deleteAllAsync(Arrays.asList("1", "2")).get();
+    table.deleteAllAsync(Arrays.asList("1", "2")).join();
     verify(writeFn, times(1)).deleteAllAsync(any());
     assertEquals(++times, table.writeRetryMetrics.successCount.getCount());
     assertEquals(0, table.writeRetryMetrics.retryCount.getCount());
@@ -258,7 +362,7 @@
   }
 
   @Test
-  public void testPutWithRetryDisabled() throws Exception {
+  public void testPutWithRetryDisabled() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(10));
     policy.withStopAfterDelay(Duration.ofMillis(100));
@@ -273,9 +377,9 @@
     table.init(TestRemoteTable.getMockContext());
 
     try {
-      table.putAsync("foo", "bar").get();
+      table.putAsync("foo", "bar").join();
       fail();
-    } catch (ExecutionException e) {
+    } catch (Throwable t) {
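+      // expected: join() rethrows the failure as an unchecked CompletionException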
     }
 
     verify(writeFn, times(1)).putAsync(any(), any());
@@ -286,7 +390,7 @@
   }
 
   @Test
-  public void testPutAllWithOneRetry() throws Exception {
+  public void testPutAllWithOneRetry() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(10));
     TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
@@ -309,7 +413,7 @@
     AsyncRetriableTable table = new AsyncRetriableTable("t1", delegate, null, policy, schedExec, readFn, writeFn);
     table.init(TestRemoteTable.getMockContext());
 
-    table.putAllAsync(Arrays.asList(new Entry(1, 2))).get();
+    table.putAllAsync(Arrays.asList(new Entry(1, 2))).join();
     verify(writeFn, times(2)).putAllAsync(any());
     assertEquals(1, table.writeRetryMetrics.retryCount.getCount());
     assertEquals(0, table.writeRetryMetrics.successCount.getCount());
@@ -318,7 +422,7 @@
   }
 
   @Test
-  public void testPutWithPermFailureOnTimeout() throws Exception {
+  public void testPutWithPermFailureOnTimeout() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(5));
     policy.withStopAfterDelay(Duration.ofMillis(100));
@@ -333,9 +437,9 @@
     table.init(TestRemoteTable.getMockContext());
 
     try {
-      table.putAsync("foo", "bar").get();
+      table.putAsync("foo", "bar").join();
       fail();
-    } catch (ExecutionException e) {
+    } catch (Throwable t) {
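+      // expected failure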
     }
 
     verify(writeFn, atLeast(3)).putAsync(any(), any());
@@ -346,7 +450,7 @@
   }
 
   @Test
-  public void testPutWithPermFailureOnMaxCount() throws Exception {
+  public void testPutWithPermFailureOnMaxCount() {
     TableRetryPolicy policy = new TableRetryPolicy();
     policy.withFixedBackoff(Duration.ofMillis(5));
     policy.withStopAfterAttempts(10);
@@ -361,9 +465,9 @@
     table.init(TestRemoteTable.getMockContext());
 
     try {
-      table.putAllAsync(Arrays.asList(new Entry(1, 2))).get();
+      table.putAllAsync(Arrays.asList(new Entry(1, 2))).join();
       fail();
-    } catch (ExecutionException e) {
+    } catch (Throwable t) {
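+      // expected failure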
     }
 
     verify(writeFn, atLeast(11)).putAllAsync(any());
diff --git a/samza-core/src/test/java/org/apache/samza/table/retry/TestRetriableTableFunctions.java b/samza-core/src/test/java/org/apache/samza/table/retry/TestRetriableTableFunctions.java
deleted file mode 100644
index 34aac9f..0000000
--- a/samza-core/src/test/java/org/apache/samza/table/retry/TestRetriableTableFunctions.java
+++ /dev/null
@@ -1,312 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.table.retry;
-
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
-import junit.framework.Assert;
-import org.apache.samza.context.Context;
-import org.apache.samza.storage.kv.Entry;
-import org.apache.samza.table.Table;
-import org.apache.samza.table.remote.TableReadFunction;
-import org.apache.samza.table.remote.TableWriteFunction;
-import org.apache.samza.table.remote.TestRemoteTable;
-import org.apache.samza.table.utils.TableMetricsUtil;
-import org.junit.Test;
-
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyString;
-import static org.mockito.Mockito.atLeast;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-
-
-public class TestRetriableTableFunctions {
-
-  private final ScheduledExecutorService schedExec = Executors.newSingleThreadScheduledExecutor();
-
-  public TableMetricsUtil getMetricsUtil(String tableId) {
-    Table table = mock(Table.class);
-    Context context = TestRemoteTable.getMockContext();
-    return new TableMetricsUtil(context, table, tableId);
-  }
-
-  @Test
-  public void testFirstTimeSuccessGet() throws Exception {
-    String tableId = "testFirstTimeSuccessGet";
-    TableRetryPolicy policy = new TableRetryPolicy();
-    policy.withFixedBackoff(Duration.ofMillis(100));
-    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
-    doReturn(true).when(readFn).isRetriable(any());
-    doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(anyString());
-    RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
-    retryIO.setMetrics(getMetricsUtil(tableId));
-    Assert.assertEquals("bar", retryIO.getAsync("foo").get());
-    verify(readFn, times(1)).getAsync(anyString());
-
-    Assert.assertEquals(0, retryIO.retryMetrics.retryCount.getCount());
-    Assert.assertEquals(1, retryIO.retryMetrics.successCount.getCount());
-    Assert.assertEquals(0, retryIO.retryMetrics.retryTimer.getSnapshot().getMax());
-  }
-
-  @Test
-  public void testRetryEngagedGet() throws Exception {
-    String tableId = "testRetryEngagedGet";
-    TableRetryPolicy policy = new TableRetryPolicy();
-    policy.withFixedBackoff(Duration.ofMillis(10));
-    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
-    doReturn(true).when(readFn).isRetriable(any());
-
-    int[] times = {0};
-    Map<String, String> map = new HashMap<>();
-    map.put("foo1", "bar1");
-    map.put("foo2", "bar2");
-    doAnswer(invocation -> {
-        CompletableFuture<Map<String, String>> future = new CompletableFuture();
-        if (times[0] > 0) {
-          future.complete(map);
-        } else {
-          times[0]++;
-          future.completeExceptionally(new RuntimeException("test exception"));
-        }
-        return future;
-      }).when(readFn).getAllAsync(any());
-
-    RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
-    retryIO.setMetrics(getMetricsUtil(tableId));
-    Assert.assertEquals(map, retryIO.getAllAsync(Arrays.asList("foo1", "foo2")).get());
-    verify(readFn, times(2)).getAllAsync(any());
-
-    Assert.assertEquals(1, retryIO.retryMetrics.retryCount.getCount());
-    Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
-    Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
-  }
-
-  @Test
-  public void testRetryExhaustedTimeGet() throws Exception {
-    String tableId = "testRetryExhaustedTime";
-    TableRetryPolicy policy = new TableRetryPolicy();
-    policy.withFixedBackoff(Duration.ofMillis(5));
-    policy.withStopAfterDelay(Duration.ofMillis(100));
-    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
-    doReturn(true).when(readFn).isRetriable(any());
-
-    CompletableFuture<String> future = new CompletableFuture();
-    future.completeExceptionally(new RuntimeException("test exception"));
-    doReturn(future).when(readFn).getAsync(anyString());
-
-    RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
-    retryIO.setMetrics(getMetricsUtil(tableId));
-    try {
-      retryIO.getAsync("foo").get();
-      Assert.fail();
-    } catch (ExecutionException e) {
-    }
-
-    // Conservatively: must be at least 3 attempts with 5ms backoff and 100ms maxDelay
-    verify(readFn, atLeast(3)).getAsync(anyString());
-    Assert.assertTrue(retryIO.retryMetrics.retryCount.getCount() >= 3);
-    Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
-    Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
-  }
-
-  @Test
-  public void testRetryExhaustedAttemptsGet() throws Exception {
-    String tableId = "testRetryExhaustedAttempts";
-    TableRetryPolicy policy = new TableRetryPolicy();
-    policy.withFixedBackoff(Duration.ofMillis(5));
-    policy.withStopAfterAttempts(10);
-    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
-    doReturn(true).when(readFn).isRetriable(any());
-
-    CompletableFuture<String> future = new CompletableFuture();
-    future.completeExceptionally(new RuntimeException("test exception"));
-    doReturn(future).when(readFn).getAllAsync(any());
-
-    RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
-    retryIO.setMetrics(getMetricsUtil(tableId));
-    try {
-      retryIO.getAllAsync(Arrays.asList("foo1", "foo2")).get();
-      Assert.fail();
-    } catch (ExecutionException e) {
-    }
-
-    // 1 initial try + 10 retries
-    verify(readFn, times(11)).getAllAsync(any());
-    Assert.assertEquals(10, retryIO.retryMetrics.retryCount.getCount());
-    Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
-    Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
-  }
-
-  @Test
-  public void testFirstTimeSuccessPut() throws Exception {
-    String tableId = "testFirstTimeSuccessPut";
-    TableRetryPolicy policy = new TableRetryPolicy();
-    policy.withFixedBackoff(Duration.ofMillis(100));
-    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
-    doReturn(true).when(writeFn).isRetriable(any());
-    doReturn(CompletableFuture.completedFuture("bar")).when(writeFn).putAsync(anyString(), anyString());
-    RetriableWriteFunction<String, String> retryIO = new RetriableWriteFunction<>(policy, writeFn, schedExec);
-    retryIO.setMetrics(getMetricsUtil(tableId));
-    retryIO.putAsync("foo", "bar").get();
-    verify(writeFn, times(1)).putAsync(anyString(), anyString());
-
-    Assert.assertEquals(0, retryIO.retryMetrics.retryCount.getCount());
-    Assert.assertEquals(1, retryIO.retryMetrics.successCount.getCount());
-    Assert.assertEquals(0, retryIO.retryMetrics.retryTimer.getSnapshot().getMax());
-  }
-
-  @Test
-  public void testRetryEngagedPut() throws Exception {
-    String tableId = "testRetryEngagedPut";
-    TableRetryPolicy policy = new TableRetryPolicy();
-    policy.withFixedBackoff(Duration.ofMillis(10));
-    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
-    doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
-    doReturn(true).when(writeFn).isRetriable(any());
-
-    int[] times = new int[] {0};
-    List<Entry<String, String>> records = new ArrayList<>();
-    records.add(new Entry<>("foo1", "bar1"));
-    records.add(new Entry<>("foo2", "bar2"));
-    doAnswer(invocation -> {
-        CompletableFuture<Map<String, String>> future = new CompletableFuture();
-        if (times[0] > 0) {
-          future.complete(null);
-        } else {
-          times[0]++;
-          future.completeExceptionally(new RuntimeException("test exception"));
-        }
-        return future;
-      }).when(writeFn).putAllAsync(any());
-
-    RetriableWriteFunction<String, String> retryIO = new RetriableWriteFunction<>(policy, writeFn, schedExec);
-    retryIO.setMetrics(getMetricsUtil(tableId));
-    retryIO.putAllAsync(records).get();
-    verify(writeFn, times(2)).putAllAsync(any());
-
-    Assert.assertEquals(1, retryIO.retryMetrics.retryCount.getCount());
-    Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
-    Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
-  }
-
-  @Test
-  public void testRetryExhaustedTimePut() throws Exception {
-    String tableId = "testRetryExhaustedTimePut";
-    TableRetryPolicy policy = new TableRetryPolicy();
-    policy.withFixedBackoff(Duration.ofMillis(5));
-    policy.withStopAfterDelay(Duration.ofMillis(100));
-    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
-    doReturn(true).when(writeFn).isRetriable(any());
-
-    CompletableFuture<String> future = new CompletableFuture();
-    future.completeExceptionally(new RuntimeException("test exception"));
-    doReturn(future).when(writeFn).deleteAsync(anyString());
-
-    RetriableWriteFunction<String, String> retryIO = new RetriableWriteFunction<>(policy, writeFn, schedExec);
-    retryIO.setMetrics(getMetricsUtil(tableId));
-    try {
-      retryIO.deleteAsync("foo").get();
-      Assert.fail();
-    } catch (ExecutionException e) {
-    }
-
-    // Conservatively: must be at least 3 attempts with 5ms backoff and 100ms maxDelay
-    verify(writeFn, atLeast(3)).deleteAsync(anyString());
-    Assert.assertTrue(retryIO.retryMetrics.retryCount.getCount() >= 3);
-    Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
-    Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
-  }
-
-  @Test
-  public void testRetryExhaustedAttemptsPut() throws Exception {
-    String tableId = "testRetryExhaustedAttemptsPut";
-    TableRetryPolicy policy = new TableRetryPolicy();
-    policy.withFixedBackoff(Duration.ofMillis(5));
-    policy.withStopAfterAttempts(10);
-    TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
-    doReturn(true).when(writeFn).isRetriable(any());
-
-    CompletableFuture<String> future = new CompletableFuture();
-    future.completeExceptionally(new RuntimeException("test exception"));
-    doReturn(future).when(writeFn).deleteAllAsync(any());
-
-    RetriableWriteFunction<String, String> retryIO = new RetriableWriteFunction<>(policy, writeFn, schedExec);
-    retryIO.setMetrics(getMetricsUtil(tableId));
-    try {
-      retryIO.deleteAllAsync(Arrays.asList("foo1", "foo2")).get();
-      Assert.fail();
-    } catch (ExecutionException e) {
-    }
-
-    // 1 initial try + 10 retries
-    verify(writeFn, times(11)).deleteAllAsync(any());
-    Assert.assertEquals(10, retryIO.retryMetrics.retryCount.getCount());
-    Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
-    Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
-  }
-
-  @Test
-  public void testMixedIsRetriablePredicates() throws Exception {
-    String tableId = "testMixedIsRetriablePredicates";
-    TableRetryPolicy policy = new TableRetryPolicy();
-    policy.withFixedBackoff(Duration.ofMillis(100));
-
-    // Retry should be attempted based on the custom classification, ie. retry on NPE
-    policy.withRetryPredicate(ex -> ex instanceof NullPointerException);
-
-    TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
-
-    // Table fn classification only retries on IllegalArgumentException
-    doAnswer(arg -> arg.getArgumentAt(0, Throwable.class) instanceof IllegalArgumentException).when(readFn).isRetriable(any());
-
-    int[] times = new int[1];
-    doAnswer(arg -> {
-        if (times[0]++ == 0) {
-          CompletableFuture<String> future = new CompletableFuture();
-          future.completeExceptionally(new NullPointerException("test exception"));
-          return future;
-        } else {
-          return CompletableFuture.completedFuture("bar");
-        }
-      }).when(readFn).getAsync(any());
-
-    RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
-    retryIO.setMetrics(getMetricsUtil(tableId));
-
-    Assert.assertEquals("bar", retryIO.getAsync("foo").get());
-
-    verify(readFn, times(2)).getAsync(anyString());
-    Assert.assertEquals(1, retryIO.retryMetrics.retryCount.getCount());
-    Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
-    Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
-  }
-}
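Note: the removed tests above exercised the `RetriableReadFunction` / `RetriableWriteFunction` wrappers around `TableRetryPolicy`. For reference, the policy itself is composed exactly as in those tests; a minimal sketch using only the calls shown above (the package path and wrapper class name are assumptions):

{% highlight java %}
import java.time.Duration;

import org.apache.samza.table.retry.TableRetryPolicy;

public class RetryPolicyExample {
  static TableRetryPolicy buildPolicy() {
    TableRetryPolicy policy = new TableRetryPolicy();
    policy.withFixedBackoff(Duration.ofMillis(5)); // wait 5 ms between attempts
    policy.withStopAfterAttempts(10);              // 1 initial try + 10 retries = 11 calls total
    // Per testMixedIsRetriablePredicates above, a retry is attempted when either this
    // predicate or the table function's own isRetriable() classifies the exception.
    policy.withRetryPredicate(ex -> ex instanceof NullPointerException);
    return policy;
  }
}
{% endhighlight %}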
diff --git a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
index 0c009a1..4c53ec2 100644
--- a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
+++ b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
@@ -22,12 +22,12 @@
 import java.util
 import java.util.function.BiConsumer
 
-import org.apache.samza.{Partition, SamzaException}
 import org.apache.samza.config.MapConfig
 import org.apache.samza.container.TaskName
 import org.apache.samza.startpoint.{StartpointManagerTestUtil, StartpointOldest, StartpointUpcoming}
 import org.apache.samza.system.SystemStreamMetadata.{OffsetType, SystemStreamPartitionMetadata}
 import org.apache.samza.system._
+import org.apache.samza.{Partition, SamzaException}
 import org.junit.Assert._
 import org.junit.Test
 import org.mockito.Matchers._
@@ -67,13 +67,14 @@
     val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
     val config = new MapConfig
     val checkpointManager = getCheckpointManager(systemStreamPartition, taskName)
-    val startpointManager = getStartpointManager()
+    val startpointManagerUtil = getStartpointManagerUtil
     val systemAdmins = mock(classOf[SystemAdmins])
     when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
-    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins, Map(), new OffsetManagerMetrics)
+    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
     offsetManager.register(taskName, Set(systemStreamPartition))
-    startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
+    startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+    assertTrue(startpointManagerUtil.getStartpointManager.readStartpoint(systemStreamPartition, taskName).isPresent)
+    assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName, systemStreamPartition)).keySet().contains(taskName))
     offsetManager.start
     assertTrue(checkpointManager.isStarted)
     assertEquals(1, checkpointManager.registered.size)
@@ -90,11 +91,19 @@
     assertEquals("46", offsetManager.getStartingOffset(taskName, systemStreamPartition).get)
     // Should not update null offset
     offsetManager.update(taskName, systemStreamPartition, null)
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
     checkpoint(offsetManager, taskName)
-    assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should delete after checkpoint commit
+    intercept[IllegalStateException] {
+      // StartpointManager should stop after last fan out is removed
+      startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+    }
+    startpointManagerUtil.getStartpointManager.start
+    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should be deleted after checkpoint commit
     val expectedCheckpoint = new Checkpoint(Map(systemStreamPartition -> "47").asJava)
     assertEquals(expectedCheckpoint, checkpointManager.readLastCheckpoint(taskName))
+    startpointManagerUtil.stop
   }
+
   @Test
   def testGetAndSetStartpoint {
     val taskName1 = new TaskName("c")
@@ -106,25 +115,23 @@
     val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
     val config = new MapConfig
     val checkpointManager = getCheckpointManager(systemStreamPartition, taskName1)
-    val startpointManager = getStartpointManager()
+    val startpointManagerUtil = getStartpointManagerUtil()
     val systemAdmins = mock(classOf[SystemAdmins])
     when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
-    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins, Map(), new OffsetManagerMetrics)
+    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
 
     offsetManager.register(taskName1, Set(systemStreamPartition))
     val startpoint1 = new StartpointOldest
-    startpointManager.writeStartpoint(systemStreamPartition, taskName1, startpoint1)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName1))
+    startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName1, startpoint1)
+    assertTrue(startpointManagerUtil.getStartpointManager.readStartpoint(systemStreamPartition, taskName1).isPresent)
+    assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName1, systemStreamPartition)).keySet().contains(taskName1))
     offsetManager.start
     val startpoint2 = new StartpointUpcoming
     offsetManager.setStartpoint(taskName2, systemStreamPartition, startpoint2)
 
     assertEquals(Option(startpoint1), offsetManager.getStartpoint(taskName1, systemStreamPartition))
     assertEquals(Option(startpoint2), offsetManager.getStartpoint(taskName2, systemStreamPartition))
-
-    assertEquals(startpoint1, startpointManager.readStartpoint(systemStreamPartition, taskName1))
-    // Startpoint written to offset manager, but not directly to startpoint manager.
-    assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName2))
+    startpointManagerUtil.stop
   }
 
   @Test
@@ -137,34 +144,36 @@
     val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
     val config = new MapConfig
     val checkpointManager = getCheckpointManager(systemStreamPartition, taskName)
-    val startpointManager = getStartpointManager()
+    val startpointManagerUtil = getStartpointManagerUtil()
     val systemAdmins = mock(classOf[SystemAdmins])
     when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
-    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins, Map(), new OffsetManagerMetrics)
+    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
     offsetManager.register(taskName, Set(systemStreamPartition))
 
     // Pre-populate startpoint
-    startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+    startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+    assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName, systemStreamPartition)).keySet().contains(taskName))
     offsetManager.start
     // Should get offset 45 back from the checkpoint manager, which is last processed, and system admin should return 46 as starting offset.
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
     checkpoint(offsetManager, taskName)
-    assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should delete after checkpoint commit
+    intercept[IllegalStateException] {
+      // StartpointManager should stop after last fan out is removed
+      startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+    }
+    startpointManagerUtil.getStartpointManager.start
+    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should be deleted after checkpoint commit
     assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     offsetManager.update(taskName, systemStreamPartition, "46")
 
     offsetManager.update(taskName, systemStreamPartition, "47")
-    startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
     checkpoint(offsetManager, taskName)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
     assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
 
     offsetManager.update(taskName, systemStreamPartition, "48")
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
     checkpoint(offsetManager, taskName)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
     assertEquals("48", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
+    startpointManagerUtil.stop
   }
 
   @Test
@@ -207,22 +216,21 @@
     val checkpoint = new Checkpoint(Map(systemStreamPartition1 -> "45").asJava)
     // Checkpoint manager only has partition 1.
     val checkpointManager = getCheckpointManager(systemStreamPartition1, taskName1)
-    val startpointManager = getStartpointManager()
+    val startpointManagerUtil = getStartpointManagerUtil()
+
     val config = new MapConfig
     val systemAdmins = mock(classOf[SystemAdmins])
     when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
-    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins)
+    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins)
     // Register both partitions. Partition 2 shouldn't have a checkpoint.
     offsetManager.register(taskName1, Set(systemStreamPartition1))
     offsetManager.register(taskName2, Set(systemStreamPartition2))
-    startpointManager.writeStartpoint(systemStreamPartition1, taskName1, new StartpointOldest)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition1, taskName1))
     offsetManager.start
     assertTrue(checkpointManager.isStarted)
     assertEquals(2, checkpointManager.registered.size)
     assertEquals(checkpoint, checkpointManager.readLastCheckpoint(taskName1))
     assertNull(checkpointManager.readLastCheckpoint(taskName2))
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition1, taskName1)) // no checkpoint commit so this should still be there
+    startpointManagerUtil.stop
   }
 
   @Test
@@ -314,7 +322,7 @@
     val checkpointManager = getCheckpointManager1(systemStreamPartition,
                                                  new Checkpoint(Map(systemStreamPartition -> "45", systemStreamPartition2 -> "100").asJava),
                                                  taskName)
-    val startpointManager = getStartpointManager()
+    val startpointManagerUtil = getStartpointManagerUtil()
     val consumer = new SystemConsumerWithCheckpointCallback
     val systemAdmins = mock(classOf[SystemAdmins])
     when(systemAdmins.getSystemAdmin(systemName)).thenReturn(getSystemAdmin)
@@ -325,16 +333,26 @@
     else
       Map()
 
-    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins,
+    val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins,
       checkpointListeners, new OffsetManagerMetrics)
     offsetManager.register(taskName, Set(systemStreamPartition, systemStreamPartition2))
 
-    startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
+    startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+    assertTrue(startpointManagerUtil.getStartpointManager.readStartpoint(systemStreamPartition, taskName).isPresent)
+    assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName, systemStreamPartition)).keySet().contains(taskName))
+    assertFalse(startpointManagerUtil.getStartpointManager.readStartpoint(systemStreamPartition, taskName).isPresent)
     offsetManager.start
     // Should get offset 45 back from the checkpoint manager, which is last processed, and system admin should return 46 as starting offset.
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
     checkpoint(offsetManager, taskName)
-    assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint be deleted at first checkpoint
+
+    intercept[IllegalStateException] {
+      // StartpointManager should stop after last fan out is removed
+      startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+    }
+    startpointManagerUtil.getStartpointManager.start
+    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should be deleted at first checkpoint
+
     assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     assertEquals("100", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
     assertEquals("45", consumer.recentCheckpoint.get(systemStreamPartition))
@@ -343,10 +361,7 @@
 
     offsetManager.update(taskName, systemStreamPartition, "46")
     offsetManager.update(taskName, systemStreamPartition, "47")
-    startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
     checkpoint(offsetManager, taskName)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
     assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     assertEquals("100", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
     assertEquals("47", consumer.recentCheckpoint.get(systemStreamPartition))
@@ -354,15 +369,13 @@
 
     offsetManager.update(taskName, systemStreamPartition, "48")
     offsetManager.update(taskName, systemStreamPartition2, "101")
-    startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
     checkpoint(offsetManager, taskName)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
     assertEquals("48", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
     assertEquals("101", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
     assertEquals("48", consumer.recentCheckpoint.get(systemStreamPartition))
     assertNull(consumer.recentCheckpoint.get(systemStreamPartition2))
     offsetManager.stop
+    startpointManagerUtil.stop
   }
 
   /**
@@ -380,35 +393,38 @@
     val testStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")).asJava)
     val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
     val checkpointManager = getCheckpointManager(systemStreamPartition, taskName)
-    val startpointManager = getStartpointManager()
+    val startpointManagerUtil = getStartpointManagerUtil()
     val systemAdmins = mock(classOf[SystemAdmins])
     when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
-    val offsetManager = OffsetManager(systemStreamMetadata, new MapConfig, checkpointManager, getStartpointManager(), systemAdmins, Map(), new OffsetManagerMetrics)
+    val offsetManager = OffsetManager(systemStreamMetadata, new MapConfig, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
     offsetManager.register(taskName, Set(systemStreamPartition))
-    startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+    startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+    assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName, systemStreamPartition)).keySet().contains(taskName))
     offsetManager.start
 
     // Should get offset 45 back from the checkpoint manager, which is last processed, and system admin should return 46 as starting offset.
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
+    assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
     checkpoint(offsetManager, taskName)
-    assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint be deleted at first checkpoint
+    intercept[IllegalStateException] {
+      // StartpointManager should stop after last fan out is removed
+      startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+    }
+    startpointManagerUtil.getStartpointManager.start
+    assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should be deleted at first checkpoint
     assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
 
-    startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
-
     offsetManager.update(taskName, systemStreamPartition, "46")
     // Get checkpoint snapshot like we do at the beginning of TaskInstance.commit()
     val checkpoint46 = offsetManager.buildCheckpoint(taskName)
     offsetManager.update(taskName, systemStreamPartition, "47") // Offset updated before checkpoint
     offsetManager.writeCheckpoint(taskName, checkpoint46)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
     assertEquals(Some("47"), offsetManager.getLastProcessedOffset(taskName, systemStreamPartition))
     assertEquals("46", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
 
     // Now write the checkpoint for the latest offset
     val checkpoint47 = offsetManager.buildCheckpoint(taskName)
     offsetManager.writeCheckpoint(taskName, checkpoint47)
-    assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
+    startpointManagerUtil.stop
     assertEquals(Some("47"), offsetManager.getLastProcessedOffset(taskName, systemStreamPartition))
     assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
   }
@@ -523,10 +539,9 @@
     }
   }
 
-  private def getStartpointManager() = {
-    val startpointManager = StartpointManagerTestUtil.getStartpointManager()
-    startpointManager.start
-    startpointManager
+  private def getStartpointManagerUtil() = {
+    val startpointManagerUtil = new StartpointManagerTestUtil
+    startpointManagerUtil
   }
 
   private def getSystemAdmin: SystemAdmin = {
@@ -540,4 +555,8 @@
       override def offsetComparator(offset1: String, offset2: String) = null
     }
   }
+
+  private def asTaskToSSPMap(taskName: TaskName, ssps: SystemStreamPartition*) = {
+    Map(taskName -> ssps.toSet.asJava).asJava
+  }
 }
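Note: the rewritten assertions above follow the new `StartpointManager` fan-out lifecycle: a startpoint written per SSP becomes unreadable once `fanOut` moves it into a per-task namespace, and the fan-out entry is deleted on the first checkpoint commit. A minimal Java sketch of that flow; the `Optional` return of `readStartpoint` and the exact map shapes are inferred from the assertions above, not from the API docs:

{% highlight java %}
import java.util.Collections;

import org.apache.samza.Partition;
import org.apache.samza.container.TaskName;
import org.apache.samza.startpoint.StartpointManager;
import org.apache.samza.startpoint.StartpointOldest;
import org.apache.samza.system.SystemStreamPartition;

public class FanOutExample {
  static void fanOutLifecycle(StartpointManager mgr) throws Exception {
    TaskName task = new TaskName("task 0");
    SystemStreamPartition ssp =
        new SystemStreamPartition("test-system", "test-stream", new Partition(0));

    mgr.start();
    mgr.writeStartpoint(ssp, task, new StartpointOldest());
    // Before fan-out the startpoint is readable directly.
    assert mgr.readStartpoint(ssp, task).isPresent();

    // fanOut moves the startpoint into the per-task fan-out namespace...
    mgr.fanOut(Collections.singletonMap(task, Collections.singleton(ssp)));
    // ...so the direct read is now empty, but the task's fan-out contains the SSP.
    assert !mgr.readStartpoint(ssp, task).isPresent();
    assert mgr.getFanOutForTask(task).containsKey(ssp);

    // After the first checkpoint commit the fan-out entry is removed, and the
    // manager stops once the last fan-out is gone (hence the intercepted
    // IllegalStateException in the tests above).
  }
}
{% endhighlight %}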
diff --git a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/BaseCouchbaseTableFunction.java b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/BaseCouchbaseTableFunction.java
index 813fb97..e805975 100644
--- a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/BaseCouchbaseTableFunction.java
+++ b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/BaseCouchbaseTableFunction.java
@@ -22,24 +22,26 @@
 import com.couchbase.client.java.Bucket;
 import com.couchbase.client.java.error.TemporaryFailureException;
 import com.couchbase.client.java.error.TemporaryLockFailureException;
+
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
-import java.io.Serializable;
+
 import java.time.Duration;
 import java.util.List;
+
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.context.Context;
-import org.apache.samza.operators.functions.ClosableFunction;
-import org.apache.samza.operators.functions.InitableFunction;
 import org.apache.samza.serializers.Serde;
+import org.apache.samza.table.AsyncReadWriteTable;
+import org.apache.samza.table.remote.BaseTableFunction;
 
 
 /**
  * Base class for {@link CouchbaseTableReadFunction} and {@link CouchbaseTableWriteFunction}
  * @param <V> Type of values to read from / write to couchbase
  */
-public abstract class BaseCouchbaseTableFunction<V> implements InitableFunction, ClosableFunction, Serializable {
+public abstract class BaseCouchbaseTableFunction<V> extends BaseTableFunction {
 
   // Clients
   private final static CouchbaseBucketRegistry COUCHBASE_BUCKET_REGISTRY = new CouchbaseBucketRegistry();
@@ -75,11 +77,9 @@
     environmentConfigs = new CouchbaseEnvironmentConfigs();
   }
 
-  /**
-   * Helper method to initialize {@link Bucket}.
-   */
   @Override
-  public void init(Context context) {
+  public void init(Context context, AsyncReadWriteTable table) {
+    super.init(context, table);
     bucket = COUCHBASE_BUCKET_REGISTRY.getBucket(bucketName, clusterNodes, environmentConfigs);
   }
 
diff --git a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableReadFunction.java b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableReadFunction.java
index ae758a5..c05ec55 100644
--- a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableReadFunction.java
+++ b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableReadFunction.java
@@ -43,16 +43,23 @@
 import com.couchbase.client.java.document.Document;
 import com.couchbase.client.java.document.JsonDocument;
 import com.couchbase.client.java.document.json.JsonObject;
+
 import com.google.common.base.Preconditions;
+
 import java.util.NoSuchElementException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
+
 import org.apache.commons.lang3.StringUtils;
+
 import org.apache.samza.SamzaException;
 import org.apache.samza.context.Context;
+import org.apache.samza.table.AsyncReadWriteTable;
 import org.apache.samza.table.remote.TableReadFunction;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+
 import rx.Single;
 import rx.SingleSubscriber;
 
@@ -66,6 +73,7 @@
  */
 public class CouchbaseTableReadFunction<V> extends BaseCouchbaseTableFunction<V>
     implements TableReadFunction<String, V> {
+
   private static final Logger LOGGER = LoggerFactory.getLogger(CouchbaseTableReadFunction.class);
 
   protected final Class<? extends Document<?>> documentType;
@@ -83,8 +91,8 @@
   }
 
   @Override
-  public void init(Context context) {
-    super.init(context);
+  public void init(Context context, AsyncReadWriteTable table) {
+    super.init(context, table);
     LOGGER.info("Read function for bucket {} initialized successfully", bucketName);
   }
 
@@ -123,10 +131,10 @@
     return future;
   }
 
-  /**
+  /*
    * Helper method to read bytes from binaryDocument and release the buffer.
    */
-  private void handleGetAsyncBinaryDocument(BinaryDocument binaryDocument, CompletableFuture<V> future, String key) {
+  protected void handleGetAsyncBinaryDocument(BinaryDocument binaryDocument, CompletableFuture<V> future, String key) {
     ByteBuf buffer = binaryDocument.content();
     try {
       byte[] bytes;
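Note: the `init` signature change above applies to all table functions in this release: they now extend `BaseTableFunction` and receive the owning table alongside the context. A hedged sketch of a custom function honoring the new contract; the class name and method bodies are hypothetical:

{% highlight java %}
import java.util.concurrent.CompletableFuture;

import org.apache.samza.context.Context;
import org.apache.samza.table.AsyncReadWriteTable;
import org.apache.samza.table.remote.BaseTableFunction;
import org.apache.samza.table.remote.TableReadFunction;

public class MyReadFunction extends BaseTableFunction
    implements TableReadFunction<String, String> {

  @Override
  public void init(Context context, AsyncReadWriteTable table) {
    super.init(context, table); // let the base class record the context and owning table first
    // ... per-function setup (open clients, resolve config) goes here ...
  }

  @Override
  public CompletableFuture<String> getAsync(String key) {
    return CompletableFuture.completedFuture(null); // placeholder lookup
  }

  @Override
  public boolean isRetriable(Throwable exception) {
    return false; // nothing is retried in this sketch
  }
}
{% endhighlight %}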
diff --git a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableWriteFunction.java b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableWriteFunction.java
index cbdcb00..96800ae 100644
--- a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableWriteFunction.java
+++ b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableWriteFunction.java
@@ -24,15 +24,21 @@
 import com.couchbase.client.java.document.Document;
 import com.couchbase.client.java.document.JsonDocument;
 import com.couchbase.client.java.document.json.JsonObject;
+
 import com.google.common.base.Preconditions;
+
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.TimeUnit;
+
 import org.apache.commons.lang3.StringUtils;
 import org.apache.samza.SamzaException;
 import org.apache.samza.context.Context;
+import org.apache.samza.table.AsyncReadWriteTable;
 import org.apache.samza.table.remote.TableWriteFunction;
+
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+
 import rx.Single;
 import rx.SingleSubscriber;
 
@@ -45,6 +51,7 @@
  */
 public class CouchbaseTableWriteFunction<V> extends BaseCouchbaseTableFunction<V>
     implements TableWriteFunction<String, V> {
+
   private static final Logger LOGGER = LoggerFactory.getLogger(CouchbaseTableWriteFunction.class);
 
   /**
@@ -59,8 +66,8 @@
   }
 
   @Override
-  public void init(Context context) {
-    super.init(context);
+  public void init(Context context, AsyncReadWriteTable table) {
+    super.init(context, table);
     LOGGER.info("Write function for bucket {} initialized successfully", bucketName);
   }
 
@@ -83,10 +90,10 @@
         String.format("Failed to delete key %s from bucket %s.", key, bucketName));
   }
 
-  /**
+  /*
    * Helper method for putAsync and deleteAsync to convert Single to CompletableFuture.
    */
-  private CompletableFuture<Void> asyncWriteHelper(Single<? extends Document<?>> single, String errorMessage) {
+  protected CompletableFuture<Void> asyncWriteHelper(Single<? extends Document<?>> single, String errorMessage) {
     CompletableFuture<Void> future = new CompletableFuture<>();
     single.subscribe(new SingleSubscriber<Document>() {
       @Override
diff --git a/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableReadFunction.java b/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableReadFunction.java
index 4532fd7..9e388f1 100644
--- a/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableReadFunction.java
+++ b/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableReadFunction.java
@@ -28,16 +28,23 @@
 import com.couchbase.client.java.document.json.JsonObject;
 import com.couchbase.client.java.error.TemporaryFailureException;
 import com.couchbase.client.java.error.TemporaryLockFailureException;
+
 import java.util.List;
 import java.util.concurrent.TimeUnit;
+
 import org.apache.samza.SamzaException;
+import org.apache.samza.context.Context;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.table.AsyncReadWriteTable;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
+
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
+
 import rx.Observable;
 
 import static org.junit.Assert.*;
@@ -176,7 +183,7 @@
         CouchbaseEnvironmentConfigs.class)).toReturn(bucket);
     CouchbaseTableReadFunction<V> readFunction =
         new CouchbaseTableReadFunction<>(DEFAULT_BUCKET_NAME, valueClass, DEFAULT_CLUSTER_NODE).withSerde(serde);
-    readFunction.init(null);
+    readFunction.init(mock(Context.class), mock(AsyncReadWriteTable.class));
     return readFunction;
   }
 }
diff --git a/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableWriteFunction.java b/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableWriteFunction.java
index c63472e..e600de8 100644
--- a/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableWriteFunction.java
+++ b/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableWriteFunction.java
@@ -26,16 +26,23 @@
 import com.couchbase.client.java.document.json.JsonObject;
 import com.couchbase.client.java.error.TemporaryFailureException;
 import com.couchbase.client.java.error.TemporaryLockFailureException;
+
 import java.util.List;
 import java.util.concurrent.TimeUnit;
+
 import org.apache.samza.SamzaException;
+import org.apache.samza.context.Context;
 import org.apache.samza.serializers.Serde;
 import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.table.AsyncReadWriteTable;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
+
 import org.powermock.api.mockito.PowerMockito;
 import org.powermock.core.classloader.annotations.PrepareForTest;
 import org.powermock.modules.junit4.PowerMockRunner;
+
 import rx.Observable;
 
 import static org.junit.Assert.*;
@@ -190,9 +197,9 @@
     when(bucket.async()).thenReturn(asyncBucket);
     PowerMockito.stub(PowerMockito.method(CouchbaseBucketRegistry.class, "getBucket", String.class, List.class,
         CouchbaseEnvironmentConfigs.class)).toReturn(bucket);
-    CouchbaseTableWriteFunction<V> readFunction =
+    CouchbaseTableWriteFunction<V> writeFunction =
         new CouchbaseTableWriteFunction<>(DEFAULT_BUCKET_NAME, valueClass, DEFAULT_CLUSTER_NODE).withSerde(serde);
-    readFunction.init(null);
-    return readFunction;
+    writeFunction.init(mock(Context.class), mock(AsyncReadWriteTable.class));
+    return writeFunction;
   }
 }
diff --git a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalTable.java b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalTable.java
index a269452..2a27d18 100644
--- a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalTable.java
+++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalTable.java
@@ -39,7 +39,7 @@
  * @param <K> the type of the key in this table
  * @param <V> the type of the value in this table
  */
-public class LocalTable<K, V> extends BaseReadWriteTable<K, V> {
+public final class LocalTable<K, V> extends BaseReadWriteTable<K, V> {
 
   protected final KeyValueStore<K, V> kvStore;
 
@@ -55,7 +55,7 @@
   }
 
   @Override
-  public V get(K key) {
+  public V get(K key, Object ... args) {
     V result = instrument(metrics.numGets, metrics.getNs, () -> kvStore.get(key));
     if (result == null) {
       incCounter(metrics.numMissedLookups);
@@ -64,10 +64,10 @@
   }
 
   @Override
-  public CompletableFuture<V> getAsync(K key) {
+  public CompletableFuture<V> getAsync(K key, Object ... args) {
     CompletableFuture<V> future = new CompletableFuture();
     try {
-      future.complete(get(key));
+      future.complete(get(key, args));
     } catch (Exception e) {
       future.completeExceptionally(e);
     }
@@ -75,17 +75,17 @@
   }
 
   @Override
-  public Map<K, V> getAll(List<K> keys) {
+  public Map<K, V> getAll(List<K> keys, Object ... args) {
     Map<K, V> result = instrument(metrics.numGetAlls, metrics.getAllNs, () -> kvStore.getAll(keys));
     result.values().stream().filter(Objects::isNull).forEach(v -> incCounter(metrics.numMissedLookups));
     return result;
   }
 
   @Override
-  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+  public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
     CompletableFuture<Map<K, V>> future = new CompletableFuture();
     try {
-      future.complete(getAll(keys));
+      future.complete(getAll(keys, args));
     } catch (Exception e) {
       future.completeExceptionally(e);
     }
@@ -93,7 +93,7 @@
   }
 
   @Override
-  public void put(K key, V value) {
+  public void put(K key, V value, Object ... args) {
     if (value != null) {
       instrument(metrics.numPuts, metrics.putNs, () -> kvStore.put(key, value));
     } else {
@@ -102,10 +102,10 @@
   }
 
   @Override
-  public CompletableFuture<Void> putAsync(K key, V value) {
+  public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
     CompletableFuture<Void> future = new CompletableFuture();
     try {
-      put(key, value);
+      put(key, value, args);
       future.complete(null);
     } catch (Exception e) {
       future.completeExceptionally(e);
@@ -114,7 +114,7 @@
   }
 
   @Override
-  public void putAll(List<Entry<K, V>> entries) {
+  public void putAll(List<Entry<K, V>> entries, Object ... args) {
     List<Entry<K, V>> toPut = new LinkedList<>();
     List<K> toDelete = new LinkedList<>();
     entries.forEach(e -> {
@@ -135,10 +135,10 @@
   }
 
   @Override
-  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
+  public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
     CompletableFuture<Void> future = new CompletableFuture();
     try {
-      putAll(entries);
+      putAll(entries, args);
       future.complete(null);
     } catch (Exception e) {
       future.completeExceptionally(e);
@@ -147,15 +147,15 @@
   }
 
   @Override
-  public void delete(K key) {
+  public void delete(K key, Object ... args) {
     instrument(metrics.numDeletes, metrics.deleteNs, () -> kvStore.delete(key));
   }
 
   @Override
-  public CompletableFuture<Void> deleteAsync(K key) {
+  public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
     CompletableFuture<Void> future = new CompletableFuture();
     try {
-      delete(key);
+      delete(key, args);
       future.complete(null);
     } catch (Exception e) {
       future.completeExceptionally(e);
@@ -164,15 +164,15 @@
   }
 
   @Override
-  public void deleteAll(List<K> keys) {
+  public void deleteAll(List<K> keys, Object ... args) {
     instrument(metrics.numDeleteAlls, metrics.deleteAllNs, () -> kvStore.deleteAll(keys));
   }
 
   @Override
-  public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+  public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
     CompletableFuture<Void> future = new CompletableFuture();
     try {
-      deleteAll(keys);
+      deleteAll(keys, args);
       future.complete(null);
     } catch (Exception e) {
       future.completeExceptionally(e);
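Note: every `LocalTable` operation above gains a trailing `Object... args` so the unified table API can carry operation-specific arguments; the local implementation threads them through to its synchronous counterpart but the underlying `kvStore` calls ignore them, so existing call sites compile unchanged. A short usage sketch; the table handle is assumed to be an initialized local table:

{% highlight java %}
import java.util.Arrays;

import org.apache.samza.storage.kv.Entry;
import org.apache.samza.storage.kv.LocalTable;

public class LocalTableArgsExample {
  static void demo(LocalTable<String, String> table) {
    table.put("k1", "v1");                       // old shape still works: varargs default to empty
    String v = table.get("k1");                  // args are ignored on the local kvStore path
    table.putAllAsync(Arrays.asList(new Entry<>("k2", "v2"))).join();
    table.deleteAsync("k1", "extra-arg").join(); // extra args accepted, but unused locally
  }
}
{% endhighlight %}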
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/util/RemoteStoreIOResolverTestFactory.java b/samza-sql/src/test/java/org/apache/samza/sql/util/RemoteStoreIOResolverTestFactory.java
index 7ba50fd..4a1d299 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/RemoteStoreIOResolverTestFactory.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/RemoteStoreIOResolverTestFactory.java
@@ -33,6 +33,7 @@
 import org.apache.samza.sql.interfaces.SqlIOResolver;
 import org.apache.samza.sql.interfaces.SqlIOResolverFactory;
 import org.apache.samza.storage.kv.descriptors.RocksDbTableDescriptor;
+import org.apache.samza.table.remote.BaseTableFunction;
 import org.apache.samza.table.remote.TableReadFunction;
 import org.apache.samza.table.remote.TableWriteFunction;
 
@@ -51,7 +52,8 @@
     return new TestRemoteStoreIOResolver(config);
   }
 
-  public static class InMemoryWriteFunction implements TableWriteFunction<Object, Object> {
+  public static class InMemoryWriteFunction extends BaseTableFunction
+      implements TableWriteFunction<Object, Object> {
 
     @Override
     public CompletableFuture<Void> putAsync(Object key, Object record) {
@@ -71,7 +73,8 @@
     }
   }
 
-  static class InMemoryReadFunction implements TableReadFunction<Object, Object> {
+  static class InMemoryReadFunction extends BaseTableFunction
+      implements TableReadFunction<Object, Object> {
 
     @Override
     public CompletableFuture<Object> getAsync(Object key) {
diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestCouchbaseRemoteTableEndToEnd.java b/samza-test/src/test/java/org/apache/samza/test/table/TestCouchbaseRemoteTableEndToEnd.java
index b05adcd..46f91d4 100644
--- a/samza-test/src/test/java/org/apache/samza/test/table/TestCouchbaseRemoteTableEndToEnd.java
+++ b/samza-test/src/test/java/org/apache/samza/test/table/TestCouchbaseRemoteTableEndToEnd.java
@@ -28,10 +28,12 @@
 import com.couchbase.client.java.env.DefaultCouchbaseEnvironment;
 import com.couchbase.mock.BucketConfiguration;
 import com.couchbase.mock.CouchbaseMock;
+
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
+
 import org.apache.samza.application.StreamApplication;
 import org.apache.samza.config.Config;
 import org.apache.samza.config.MapConfig;
@@ -44,11 +46,13 @@
 import org.apache.samza.system.descriptors.GenericInputDescriptor;
 import org.apache.samza.table.Table;
 import org.apache.samza.table.descriptors.RemoteTableDescriptor;
+import org.apache.samza.table.remote.BaseTableFunction;
 import org.apache.samza.table.remote.TableReadFunction;
 import org.apache.samza.table.remote.couchbase.CouchbaseTableReadFunction;
 import org.apache.samza.table.remote.couchbase.CouchbaseTableWriteFunction;
 import org.apache.samza.test.harness.IntegrationTestHarness;
 import org.apache.samza.test.util.Base64Serializer;
+
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
@@ -132,23 +136,24 @@
       GenericInputDescriptor<String> inputDescriptor =
           inputSystemDescriptor.getInputDescriptor("User", new NoOpSerde<>());
 
-      CouchbaseTableReadFunction<String> readFunction = new CouchbaseTableReadFunction<>(inputBucketName, String.class,
-          "couchbase://127.0.0.1").withBootstrapCarrierDirectPort(couchbaseMock.getCarrierPort(inputBucketName))
+      CouchbaseTableReadFunction<String> readFunction = new CouchbaseTableReadFunction<>(inputBucketName,
+              String.class, "couchbase://127.0.0.1")
+          .withBootstrapCarrierDirectPort(couchbaseMock.getCarrierPort(inputBucketName))
           .withBootstrapHttpDirectPort(couchbaseMock.getHttpPort())
           .withSerde(new StringSerde());
 
-      CouchbaseTableWriteFunction<JsonObject> writeFunction =
-          new CouchbaseTableWriteFunction<>(outputBucketName, JsonObject.class,
-              "couchbase://127.0.0.1").withBootstrapCarrierDirectPort(couchbaseMock.getCarrierPort(outputBucketName))
-              .withBootstrapHttpDirectPort(couchbaseMock.getHttpPort());
+      CouchbaseTableWriteFunction<JsonObject> writeFunction = new CouchbaseTableWriteFunction<>(outputBucketName,
+              JsonObject.class, "couchbase://127.0.0.1")
+          .withBootstrapCarrierDirectPort(couchbaseMock.getCarrierPort(outputBucketName))
+          .withBootstrapHttpDirectPort(couchbaseMock.getHttpPort());
 
-      RemoteTableDescriptor<String, String> inputTableDesc = new RemoteTableDescriptor<>("input-table");
-      inputTableDesc.withReadFunction(readFunction).withRateLimiterDisabled();
+      RemoteTableDescriptor inputTableDesc = new RemoteTableDescriptor<String, String>("input-table")
+          .withReadFunction(readFunction)
+          .withRateLimiterDisabled();
       Table<KV<String, String>> inputTable = appDesc.getTable(inputTableDesc);
 
-      RemoteTableDescriptor<String, JsonObject> outputTableDesc = new RemoteTableDescriptor<>("output-table");
-      outputTableDesc.withReadFunction(new DummyReadFunction<>())
-          .withWriteFunction(writeFunction)
+      RemoteTableDescriptor outputTableDesc = new RemoteTableDescriptor<String, JsonObject>("output-table")
+          .withReadFunction(new DummyReadFunction<>()).withWriteFunction(writeFunction)
           .withRateLimiterDisabled();
       Table<KV<String, JsonObject>> outputTable = appDesc.getTable(outputTableDesc);
 
@@ -189,7 +194,7 @@
     }
   }
 
-  static class DummyReadFunction<K, V> implements TableReadFunction<K, V> {
+  static class DummyReadFunction<K, V> extends BaseTableFunction implements TableReadFunction<K, V> {
     @Override
     public CompletableFuture<V> getAsync(K key) {
       return null;
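Note: the descriptor setup above was reshaped from mutate-then-use into a single fluent chain; each `with*` call returns the descriptor, so a table is declared in one expression. A condensed, self-contained sketch of the pattern (import paths assumed):

{% highlight java %}
import org.apache.samza.application.descriptors.StreamApplicationDescriptor;
import org.apache.samza.operators.KV;
import org.apache.samza.table.Table;
import org.apache.samza.table.descriptors.RemoteTableDescriptor;
import org.apache.samza.table.remote.TableReadFunction;

public class DescriptorChainExample {
  static Table<KV<String, String>> declareTable(StreamApplicationDescriptor appDesc,
      TableReadFunction<String, String> readFn) {
    // Each with* call returns the descriptor, so declaration reads as one expression.
    RemoteTableDescriptor<String, String> desc =
        new RemoteTableDescriptor<String, String>("input-table")
            .withReadFunction(readFn)
            .withRateLimiterDisabled();
    return appDesc.getTable(desc);
  }
}
{% endhighlight %}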
diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTableEndToEnd.java b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTableEndToEnd.java
index 96f99ca..c48f46c 100644
--- a/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTableEndToEnd.java
+++ b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTableEndToEnd.java
@@ -32,6 +32,7 @@
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
@@ -42,11 +43,14 @@
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.context.Context;
 import org.apache.samza.context.MockContext;
+import org.apache.samza.operators.TableImpl;
+import org.apache.samza.operators.functions.MapFunction;
 import org.apache.samza.system.descriptors.GenericInputDescriptor;
 import org.apache.samza.metrics.Counter;
 import org.apache.samza.metrics.MetricsRegistry;
 import org.apache.samza.metrics.Timer;
 import org.apache.samza.operators.KV;
+import org.apache.samza.table.ReadWriteTable;
 import org.apache.samza.table.descriptors.TableDescriptor;
 import org.apache.samza.system.descriptors.DelegatingSystemDescriptor;
 import org.apache.samza.runtime.LocalApplicationRunner;
@@ -54,6 +58,7 @@
 import org.apache.samza.table.Table;
 import org.apache.samza.table.descriptors.CachingTableDescriptor;
 import org.apache.samza.table.descriptors.GuavaCacheTableDescriptor;
+import org.apache.samza.table.remote.BaseTableFunction;
 import org.apache.samza.table.remote.RemoteTable;
 import org.apache.samza.table.descriptors.RemoteTableDescriptor;
 import org.apache.samza.table.remote.TableRateLimiter;
@@ -74,20 +79,21 @@
 
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyString;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.withSettings;
+import static org.mockito.Mockito.*;
 
 
 public class TestRemoteTableEndToEnd extends IntegrationTestHarness {
 
+  static Map<String, AtomicInteger> counters = new HashMap<>();
   static Map<String, List<EnrichedPageView>> writtenRecords = new HashMap<>();
 
-  static class InMemoryReadFunction implements TableReadFunction<Integer, Profile> {
+  static class InMemoryProfileReadFunction extends BaseTableFunction
+      implements TableReadFunction<Integer, Profile> {
+
     private final String serializedProfiles;
     private transient Map<Integer, Profile> profileMap;
 
-    private InMemoryReadFunction(String profiles) {
+    private InMemoryProfileReadFunction(String profiles) {
       this.serializedProfiles = profiles;
     }
 
@@ -103,20 +109,32 @@
     }
 
     @Override
+    public CompletableFuture<Profile> getAsync(Integer key, Object ... args) {
+      Profile profile = profileMap.get(key);
+      boolean append = (boolean) args[0];
+      if (append) {
+        profile = new Profile(profile.memberId, profile.company + "-r");
+      }
+      return CompletableFuture.completedFuture(profile);
+    }
+
+    @Override
     public boolean isRetriable(Throwable exception) {
       return false;
     }
 
-    static InMemoryReadFunction getInMemoryReadFunction(String serializedProfiles) {
-      return new InMemoryReadFunction(serializedProfiles);
+    static InMemoryProfileReadFunction getInMemoryReadFunction(String testName, String serializedProfiles) {
+      return new InMemoryProfileReadFunction(serializedProfiles);
     }
   }
 
-  static class InMemoryWriteFunction implements TableWriteFunction<Integer, EnrichedPageView> {
-    private transient List<EnrichedPageView> records;
-    private String testName;
+  static class InMemoryEnrichedPageViewWriteFunction extends BaseTableFunction
+      implements TableWriteFunction<Integer, EnrichedPageView> {
 
-    public InMemoryWriteFunction(String testName) {
+    private String testName;
+    private transient List<EnrichedPageView> records;
+
+    public InMemoryEnrichedPageViewWriteFunction(String testName) {
       this.testName = testName;
     }
 
@@ -130,6 +148,17 @@
 
     @Override
     public CompletableFuture<Void> putAsync(Integer key, EnrichedPageView record) {
+      System.out.println("==> " + testName + " writing " + record.getPageKey());
+      records.add(record);
+      return CompletableFuture.completedFuture(null);
+    }
+
+    @Override
+    public CompletableFuture<Void> putAsync(Integer key, EnrichedPageView record, Object ... args) {
+      boolean append = (boolean) args[0];
+      if (append) {
+        record = new EnrichedPageView(record.pageKey, record.memberId, record.company + "-w");
+      }
       records.add(record);
       return CompletableFuture.completedFuture(null);
     }
@@ -146,7 +175,98 @@
     }
   }
 
-  private <K, V> Table<KV<K, V>> getCachingTable(TableDescriptor<K, V, ?> actualTableDesc, boolean defaultCache, String id, StreamApplicationDescriptor appDesc) {
+  static class InMemoryCounterReadFunction extends BaseTableFunction
+      implements TableReadFunction {
+
+    private final String testName;
+
+    private InMemoryCounterReadFunction(String testName) {
+      this.testName = testName;
+    }
+
+    @Override
+    public CompletableFuture readAsync(int opId, Object... args) {
+      if (1 == opId) {
+        boolean shouldReturnValue = (boolean) args[0];
+        AtomicInteger counter = counters.get(testName);
+        Integer result = shouldReturnValue ? counter.get() : null;
+        return CompletableFuture.completedFuture(result);
+      } else {
+        throw new SamzaException("Invalid opId: " + opId);
+      }
+    }
+
+    @Override
+    public CompletableFuture getAsync(Object key) {
+      throw new SamzaException("Not supported");
+    }
+
+    @Override
+    public boolean isRetriable(Throwable exception) {
+      return false;
+    }
+  }
+
+  static class InMemoryCounterWriteFunction extends BaseTableFunction
+      implements TableWriteFunction {
+
+    private final String testName;
+
+    private InMemoryCounterWriteFunction(String testName) {
+      this.testName = testName;
+    }
+
+    @Override
+    public CompletableFuture<Void> putAsync(Object key, Object record) {
+      throw new SamzaException("Not supported");
+    }
+
+    @Override
+    public CompletableFuture<Void> deleteAsync(Object key) {
+      throw new SamzaException("Not supported");
+    }
+
+    @Override
+    public CompletableFuture writeAsync(int opId, Object... args) {
+      Integer result;
+      AtomicInteger counter = counters.get(testName);
+      boolean shouldModify = (boolean) args[0];
+      switch (opId) {
+        case 1:
+          result = shouldModify ? counter.incrementAndGet() : counter.get();
+          break;
+        case 2:
+          result = shouldModify ? counter.decrementAndGet() : counter.get();
+          break;
+        default:
+          throw new SamzaException("Invalid opId: " + opId);
+      }
+      return CompletableFuture.completedFuture(result);
+    }
+
+    @Override
+    public boolean isRetriable(Throwable exception) {
+      return false;
+    }
+  }
+
+  static class DummyReadFunction extends BaseTableFunction
+      implements TableReadFunction {
+
+    @Override
+    public CompletableFuture getAsync(Object key) {
+      throw new SamzaException("Not supported");
+    }
+
+    @Override
+    public boolean isRetriable(Throwable exception) {
+      return false;
+    }
+  }
+
+  private <K, V> Table<KV<K, V>> getCachingTable(TableDescriptor<K, V, ?> actualTableDesc, boolean defaultCache,
+      StreamApplicationDescriptor appDesc) {
+    String id = actualTableDesc.getTableId();
     CachingTableDescriptor<K, V> cachingDesc;
     if (defaultCache) {
       cachingDesc = new CachingTableDescriptor<>("caching-table-" + id, actualTableDesc);
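Note: the counter functions above implement the new numbered-operation hooks, `readAsync(int opId, Object... args)` and `writeAsync(int opId, Object... args)`, which let a remote table expose operations beyond plain get/put. A hedged sketch of the calling side, assuming the `ReadWriteTable` handle forwards these numbered ops to the functions above:

{% highlight java %}
import org.apache.samza.table.ReadWriteTable;

public class CounterTableExample {
  // Per the functions above: opId 1 increments (write side) or reads (read side),
  // opId 2 decrements; args[0] is the modify/return flag.
  static int incrementAndRead(ReadWriteTable<Integer, Integer> counterTable) {
    counterTable.writeAsync(1, true).join();                 // increment the counter
    return (Integer) counterTable.readAsync(1, true).join(); // read the current value
  }
}
{% endhighlight %}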
@@ -161,75 +281,60 @@
     return appDesc.getTable(cachingDesc);
   }
 
-  static class MyReadFunction implements TableReadFunction {
-    @Override
-    public CompletableFuture getAsync(Object key) {
-      return null;
-    }
-
-    @Override
-    public boolean isRetriable(Throwable exception) {
-      return false;
-    }
-  }
-
   private void doTestStreamTableJoinRemoteTable(boolean withCache, boolean defaultCache, String testName) throws Exception {
-    final InMemoryWriteFunction writer = new InMemoryWriteFunction(testName);
 
     writtenRecords.put(testName, new ArrayList<>());
 
     int count = 10;
-    PageView[] pageViews = generatePageViews(count);
-    String profiles = Base64Serializer.serialize(generateProfiles(count));
+    final PageView[] pageViews = generatePageViews(count);
+    final String profiles = Base64Serializer.serialize(generateProfiles(count));
 
-    int partitionCount = 4;
-    Map<String, String> configs = TestLocalTableEndToEnd.getBaseJobConfig(bootstrapUrl(), zkConnect());
-
+    final int partitionCount = 4;
+    final Map<String, String> configs = TestLocalTableEndToEnd.getBaseJobConfig(bootstrapUrl(), zkConnect());
     configs.put("streams.PageView.samza.system", "test");
     configs.put("streams.PageView.source", Base64Serializer.serialize(pageViews));
     configs.put("streams.PageView.partitionCount", String.valueOf(partitionCount));
 
     final RateLimiter readRateLimiter = mock(RateLimiter.class, withSettings().serializable());
-    final RateLimiter writeRateLimiter = mock(RateLimiter.class, withSettings().serializable());
-    final TableRateLimiter.CreditFunction creditFunction = (k, v)->1;
+    final TableRateLimiter.CreditFunction creditFunction = (k, v, args) -> 1;
     final StreamApplication app = appDesc -> {
-      RemoteTableDescriptor<Integer, TestTableData.Profile> inputTableDesc = new RemoteTableDescriptor<>("profile-table-1");
-      inputTableDesc
-          .withReadFunction(InMemoryReadFunction.getInMemoryReadFunction(profiles))
+
+      final RemoteTableDescriptor joinTableDesc =
+              new RemoteTableDescriptor<Integer, TestTableData.Profile>("profile-table-1")
+          .withReadFunction(InMemoryProfileReadFunction.getInMemoryReadFunction(testName, profiles))
           .withRateLimiter(readRateLimiter, creditFunction, null);
 
-      // dummy reader
-      TableReadFunction readFn = new MyReadFunction();
+      final RemoteTableDescriptor outputTableDesc =
+              new RemoteTableDescriptor<Integer, EnrichedPageView>("enriched-page-view-table-1")
+          .withReadFunction(new DummyReadFunction())
+          .withReadRateLimiterDisabled()
+          .withWriteFunction(new InMemoryEnrichedPageViewWriteFunction(testName))
+          .withWriteRateLimit(1000);
 
-      RemoteTableDescriptor<Integer, EnrichedPageView> outputTableDesc = new RemoteTableDescriptor<>("enriched-page-view-table-1");
-      outputTableDesc
-          .withReadFunction(readFn)
-          .withWriteFunction(writer)
-          .withRateLimiter(writeRateLimiter, creditFunction, creditFunction);
-
-      Table<KV<Integer, EnrichedPageView>> outputTable = withCache
-          ? getCachingTable(outputTableDesc, defaultCache, "output", appDesc)
+      final Table<KV<Integer, EnrichedPageView>> outputTable = withCache
+          ? getCachingTable(outputTableDesc, defaultCache, appDesc)
           : appDesc.getTable(outputTableDesc);
 
-      Table<KV<Integer, Profile>> inputTable = withCache
-          ? getCachingTable(inputTableDesc, defaultCache, "input", appDesc)
-          : appDesc.getTable(inputTableDesc);
+      final Table<KV<Integer, Profile>> joinTable = withCache
+          ? getCachingTable(joinTableDesc, defaultCache, appDesc)
+          : appDesc.getTable(joinTableDesc);
 
-      DelegatingSystemDescriptor ksd = new DelegatingSystemDescriptor("test");
-      GenericInputDescriptor<PageView> isd = ksd.getInputDescriptor("PageView", new NoOpSerde<>());
+      final DelegatingSystemDescriptor ksd = new DelegatingSystemDescriptor("test");
+      final GenericInputDescriptor<PageView> isd = ksd.getInputDescriptor("PageView", new NoOpSerde<>());
+
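+      // Join page views against the remote profile table, then write the
+      // enriched views out to the remote output table.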
       appDesc.getInputStream(isd)
           .map(pv -> new KV<>(pv.getMemberId(), pv))
-          .join(inputTable, new PageViewToProfileJoinFunction())
+          .join(joinTable, new PageViewToProfileJoinFunction())
           .map(m -> new KV(m.getMemberId(), m))
           .sendTo(outputTable);
     };
 
-    Config config = new MapConfig(configs);
+    final Config config = new MapConfig(configs);
     final LocalApplicationRunner runner = new LocalApplicationRunner(app, config);
     executeRun(runner, config);
     runner.waitForFinish();
 
-    int numExpected = count * partitionCount;
+    final int numExpected = count * partitionCount;
     Assert.assertEquals(numExpected, writtenRecords.get(testName).size());
     Assert.assertTrue(writtenRecords.get(testName).get(0) instanceof EnrichedPageView);
   }
@@ -312,6 +417,162 @@
     }
     table.flush();
     table.close();
-    Assert.assertEquals(2, failureCount);
+  }
+
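+  /**
+   * Exercises the read/write-with-args variants of {@code ReadWriteTable} end to end:
+   * page views are joined with profiles via explicit table calls inside a MapFunction,
+   * and a counter table is driven through opId-based readAsync/writeAsync calls.
+   */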
+  private void doTestReadWriteWithArgs(boolean withCache, boolean defaultCache, String testName) throws Exception {
+
+    writtenRecords.put(testName, new ArrayList<>());
+    counters.put(testName, new AtomicInteger());
+
+    final int count = 10;
+    final PageView[] pageViews = generatePageViews(count);
+    final String profiles = Base64Serializer.serialize(generateProfiles(count));
+
+    final int partitionCount = 4;
+    final Map<String, String> configs = TestLocalTableEndToEnd.getBaseJobConfig(bootstrapUrl(), zkConnect());
+    configs.put("streams.PageView.samza.system", "test");
+    configs.put("streams.PageView.source", Base64Serializer.serialize(pageViews));
+    configs.put("streams.PageView.partitionCount", String.valueOf(partitionCount));
+
+    final RateLimiter readRateLimiter = mock(RateLimiter.class, withSettings().serializable());
+    final TableRateLimiter.CreditFunction creditFunction = (k, v, args) -> 1;
+    final StreamApplication app = appDesc -> {
+
+      final RemoteTableDescriptor joinTableDesc =
+          new RemoteTableDescriptor<Integer, TestTableData.Profile>("profile-table-1")
+              .withReadFunction(InMemoryProfileReadFunction.getInMemoryReadFunction(testName, profiles))
+              .withRateLimiter(readRateLimiter, creditFunction, null);
+
+      final RemoteTableDescriptor outputTableDesc =
+          new RemoteTableDescriptor<Integer, EnrichedPageView>("enriched-page-view-table-1")
+              .withReadFunction(new DummyReadFunction())
+              .withReadRateLimiterDisabled()
+              .withWriteFunction(new InMemoryEnrichedPageViewWriteFunction(testName))
+              .withWriteRateLimit(1000);
+
+      final RemoteTableDescriptor counterTableDesc =
+          new RemoteTableDescriptor("counter-table-1")
+              .withReadFunction(new InMemoryCounterReadFunction(testName))
+              .withWriteFunction(new InMemoryCounterWriteFunction(testName))
+              .withRateLimiterDisabled();
+
+      final Table joinTable = withCache
+          ? getCachingTable(joinTableDesc, defaultCache, appDesc)
+          : appDesc.getTable(joinTableDesc);
+
+      final Table outputTable = withCache
+          ? getCachingTable(outputTableDesc, defaultCache, appDesc)
+          : appDesc.getTable(outputTableDesc);
+
+      final Table counterTable = withCache
+          ? getCachingTable(counterTableDesc, defaultCache, appDesc)
+          : appDesc.getTable(counterTableDesc);
+
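+      // Capture only the table ids here; JoinMapFunction re-resolves the actual
+      // table handles from the TaskContext in init(), so it captures just strings.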
+      final String joinTableName = ((TableImpl) joinTable).getTableId();
+      final String outputTableName = ((TableImpl) outputTable).getTableId();
+      final String counterTableName = ((TableImpl) counterTable).getTableId();
+
+      final DelegatingSystemDescriptor ksd = new DelegatingSystemDescriptor("test");
+      final GenericInputDescriptor<PageView> isd = ksd.getInputDescriptor("PageView", new NoOpSerde<>());
+
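+      // Drive the join through a MapFunction so the test can invoke the
+      // with-args table operations directly.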
+      appDesc.getInputStream(isd)
+          .map(new JoinMapFunction(joinTableName, outputTableName, counterTableName));
+    };
+
+    final Config config = new MapConfig(configs);
+    final LocalApplicationRunner runner = new LocalApplicationRunner(app, config);
+    executeRun(runner, config);
+    runner.waitForFinish();
+
+    final int numExpected = count * partitionCount;
+    Assert.assertEquals(numExpected, writtenRecords.get(testName).size());
+    Assert.assertTrue(writtenRecords.get(testName).get(0) instanceof EnrichedPageView);
+    // Every written record should carry the "-r-w" suffix, confirming that both
+    // the read-side and write-side args took effect.
+    writtenRecords.get(testName).forEach(epv -> Assert.assertTrue(epv.company.endsWith("-r-w")));
+    Assert.assertEquals(numExpected, counters.get(testName).get());
+  }
+
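+  /**
+   * Joins each PageView with its Profile and writes the EnrichedPageView back,
+   * passing an extra arg through every table call; along the way it asserts the
+   * opId-based semantics of the counter table.
+   */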
+  private static class JoinMapFunction implements MapFunction<PageView, EnrichedPageView> {
+
+    private final String joinTableName;
+    private final String outputTableName;
+    private final String counterTableName;
+    private ReadWriteTable<Integer, Profile> joinTable;
+    private ReadWriteTable<Integer, EnrichedPageView> outputTable;
+    private ReadWriteTable counterTable;
+
+    private JoinMapFunction(String joinTableName, String outputTableName, String counterTableName) {
+      this.joinTableName = joinTableName;
+      this.outputTableName = outputTableName;
+      this.counterTableName = counterTableName;
+    }
+
+    @Override
+    public void init(Context context) {
+      joinTable = context.getTaskContext().getTable(joinTableName);
+      outputTable = context.getTaskContext().getTable(outputTableName);
+      counterTable = context.getTaskContext().getTable(counterTableName);
+    }
+
+    @Override
+    public EnrichedPageView apply(PageView pageView) {
+      try {
+        // Reads with an unsupported opId must fail (see badOpId below).
+        badOpId();
+
+        // Counter manipulation: a false flag makes the in-memory counter
+        // functions skip the operation (get returns null, inc/dec are no-ops).
+        Assert.assertNull(getCounterValue(false));
+        Integer beforeValue = getCounterValue(true);
+        Assert.assertEquals(beforeValue, incCounterValue(false));
+        Assert.assertEquals(beforeValue, decCounterValue(false));
+        Assert.assertEquals(Integer.valueOf(beforeValue + 1), incCounterValue(true));
+        Assert.assertEquals(beforeValue, decCounterValue(true));
+        Assert.assertEquals(beforeValue, getCounterValue(true));
+        incCounterValue(true);
+
+        // Join the page view with its profile and write the enriched view back,
+        // passing an extra arg through both table calls.
+        final Profile profile = joinTable.getAsync(pageView.memberId, true).join();
+        final EnrichedPageView epv = new EnrichedPageView(pageView.getPageKey(), profile.memberId, profile.company);
+        outputTable.putAsync(epv.memberId, epv, true).join();
+        return epv;
+      } catch (Exception ex) {
+        throw new SamzaException(ex);
+      }
+    }
+
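+    // Counter helpers: the first argument is the opId interpreted by the
+    // in-memory counter functions (read op 1 = get; write op 1 = increment,
+    // write op 2 = decrement), and the trailing boolean is an extra arg that
+    // appears to control whether the operation actually takes effect.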
+    private Integer getCounterValue(boolean shouldReturn) {
+      return (Integer) counterTable.readAsync(1, shouldReturn).join();
+    }
+
+    private Integer incCounterValue(boolean shouldModify) {
+      return (Integer) counterTable.writeAsync(1, shouldModify).join();
+    }
+
+    private Integer decCounterValue(boolean shouldModify) {
+      return (Integer) counterTable.writeAsync(2, shouldModify).join();
+    }
+
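+    // Issues a read with an opId the output table's read function presumably
+    // does not support, and expects the failure to surface as a SamzaException.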
+    private void badOpId() {
+      try {
+        outputTable.readAsync(0).join();
+        Assert.fail("Shouldn't reach here");
+      } catch (SamzaException ex) {
+        // Expected exception
+      }
+    }
+  }
+
+  @Test
+  public void testReadWriteWithArgs() throws Exception {
+    doTestReadWriteWithArgs(false, false, "testReadWriteWithArgs");
+  }
+
+  @Test
+  public void testReadWriteWithArgsWithCache() throws Exception {
+    doTestReadWriteWithArgs(true, false, "testReadWriteWithArgsWithCache");
+  }
+
+  @Test
+  public void testReadWriteWithArgsWithDefaultCache() throws Exception {
+    doTestReadWriteWithArgs(true, true, "testReadWriteWithArgsWithDefaultCache");
   }
 }