Merge branch 'master' into 1.2.0
diff --git a/KEYS b/KEYS
index edcb935..8e23b24 100644
--- a/KEYS
+++ b/KEYS
@@ -690,3 +690,40 @@
g2aH
=ukLA
-----END PGP PUBLIC KEY BLOCK-----
+pub 2048R/5B5EB041 2019-05-17
+uid Boris Shkolnik (Key for release signing) <boryas@apache.org>
+sig 3 5B5EB041 2019-05-17 Boris Shkolnik (Key for release signing) <boryas@apache.org>
+sub 2048R/F77F01B7 2019-05-17
+sig 5B5EB041 2019-05-17 Boris Shkolnik (Key for release signing) <boryas@apache.org>
+
+-----BEGIN PGP PUBLIC KEY BLOCK-----
+Version: GnuPG v2.0.22 (GNU/Linux)
+
+mQENBFze7MEBCADJGw76F2EkwUUjEzxBgBLpg5zbNNTnlwgYJ5C/FVPFILK14H+6
+vjN2ID66ZQJCg30w3FltFILgfJhLHjrAabcFbptghqzT0kFCIcrVwiHm7eu7z5C1
+faTkBhKc2xiQjjVRp1Jkr/iNcKct9p6wuhABAiiEOI5iKxzN00Nkf2SaPydP8mUa
+ftEmpHaVLmVwgW5Jh/tCJIoJn0T06uNMZ7C25DI9TJP9UwXkzzw++b7Drd4hp310
+bjrhGb61VU2SbZum4sVi6gNZlXFE0gzfQRLLiTpESPsEaouRmymOG81LaWNbH9+w
+xD5Mc+dRSmuLpYpCTUDcbm18bquN846grToFABEBAAG0PEJvcmlzIFNoa29sbmlr
+IChLZXkgZm9yIHJlbGVhc2Ugc2lnbmluZykgPGJvcnlhc0BhcGFjaGUub3JnPokB
+OQQTAQIAIwUCXN7swQIbAwcLCQgHAwIBBhUIAgkKCwQWAgMBAh4BAheAAAoJEH10
+0M1bXrBBgScIAINYtGoMSPl3UgNaa1bv124milnaa/E9Mu1dkD3nXCcsn2bRtdM7
+oSOxRJ35wHwEEQMWzmEVG+wyxSNdUBaoIAzMV/Ok+yCmKVaPlKINmoUQ2k5RjAle
+kyqz0lcaYluSL5GTqlm2Hw+v5uZ78c2O36PwuhydNedfO1aI7msS89zoeyyVtJQz
+3M+i/fpZ9mbkbcQWU4Bn1HgcsKYCiyDJ6RCoEvQJheIUNKTswj78IgcppvYHtkDF
+NOtUQgw6GozQ0zIVRw7+82qNGkgbPPYjJjfgNuFkgD1GY3Fp4UDs9yX2Qlr0swoP
+f0WOqiDbHPln5EHVtVXwK0994NAPSXzS1tu5AQ0EXN7swQEIANZFy99a86nlddQg
+sldEmg2fhFJDIk8+SXDm2NF5TeTMQ96KZpaFcyOEy1DytA5EzryIynGgV9rt9jhH
+a8HtfH2uiHf3QTIqgnrUimxs2uhxc+Kfx9yfEKaEUBKrnSdrStqxu0DQefbr84KW
+VmhEvKSmzejHoqeV4H7Q+/tmCrMbeVDyOKEQLKYm9Dm3TNJxNRqxNrKzdsjYgqt7
+B1heSXiKzzREsZSWoXLkFX2H6qJpKn6/Mbdu5qD7RIenYZ1BJDPkAkTTCA1rpr1o
+iW2qY7KFzJkFGBnej+G58GcOFwAzIPzzjQsip5XjMedbYpz2V7sdij6K9AO1avDu
+1Fb8iEEAEQEAAYkBHwQYAQIACQUCXN7swQIbDAAKCRB9dNDNW16wQY/0CAC7t4c3
+cSgsWbEIkdY4A/cNAeQQxWmi08x0UMl1xYGfcIOtAwePdQOJJ2TuhBEJKMLmKO64
+IpGzaKwcRaEPBB9lFwJBAJeCizzY76h64LFnus35aA3UJeG3TlyfcggVI5uJG//M
+90TVkxv84Z406Wf2B1RLca3YmsxGdchk6JLD9e2bPXGAFgr+z5sTsHAe6XP8CcoH
+PlhUDyZhNuP6OnvS0r0qkIy4eWZzsuUk3OuJQAkhvqNb/r5eahXtgpasJyIYofeR
+bN7+AsTZ8FGMrkfwrs9a1RFtypldStL70bAl61/yTqngdwFGb0zJi3IxqNzq8Fun
+AE4y9/JMEbF0vpC4
+=lijy
+-----END PGP PUBLIC KEY BLOCK-----
diff --git a/docs/learn/documentation/versioned/api/table-api.md b/docs/learn/documentation/versioned/api/table-api.md
index be87e43..d8695ba 100644
--- a/docs/learn/documentation/versioned/api/table-api.md
+++ b/docs/learn/documentation/versioned/api/table-api.md
@@ -96,24 +96,25 @@
the **`get`** operation.
{% highlight java %}
- /**
+ /**
* Gets the value associated with the specified {@code key}.
*
* @param key the key with which the associated value is to be fetched.
- * @return if found, the value associated with the specified {@code key};
- * otherwise, {@code null}.
+ * @param args additional arguments
+ * @return if found, the value associated with the specified {@code key}; otherwise, {@code null}.
* @throws NullPointerException if the specified {@code key} is {@code null}.
*/
- V get(K key);
+ V get(K key, Object ... args);
- /**
+ /**
* Asynchronously gets the value associated with the specified {@code key}.
*
* @param key the key with which the associated value is to be fetched.
+ * @param args additional arguments
* @return completableFuture for the requested value
* @throws NullPointerException if the specified {@code key} is {@code null}.
*/
- CompletableFuture<V> getAsync(K key);
+ CompletableFuture<V> getAsync(K key, Object ... args);
{% endhighlight %}
@@ -249,6 +250,8 @@
|`num-gets`|`ReadableTable`|Count of `get/getAsync()` operations
|`num-getAlls`|`ReadableTable`|Count of `getAll/getAllAsync()` operations
|`num-missed-lookups`|`ReadableTable`|Count of missed get/getAll() operations
+|`read-ns`|`ReadableTable`|Average latency of `readAsync()` operations|
+|`num-reads`|`ReadableTable`|Count of `readAsync()` operations
|`put-ns`|`ReadWriteTable`|Average latency of `put/putAsync()` operations
|`putAll-ns`|`ReadWriteTable`|Average latency of `putAll/putAllAsync()` operations
|`num-puts`|`ReadWriteTable`|Count of `put/putAsync()` operations
@@ -257,6 +260,8 @@
|`deleteAll-ns`|`ReadWriteTable`|Average latency of `deleteAll/deleteAllAsync()` operations
|`delete-num`|`ReadWriteTable`|Count of `delete/deleteAsync()` operations
|`deleteAll-num`|`ReadWriteTable`|Count of `deleteAll/deleteAllAsync()` operations
+|`num-writes`|`ReadWriteTable`|Count of `writeAsync()` operations
+|`write-ns`|`ReadWriteTable`|Average latency of `writeAsync()` operations
|`flush-ns`|`ReadWriteTable`|Average latency of flush operations
|`flush-num`|`ReadWriteTable`|Count of flush operations
|`hit-rate`|`CachingTable`|Cache hit rate (%)
@@ -343,6 +348,66 @@
They can be found in
[`RetryMetrics`] (https://github.com/apache/samza/blob/master/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java).
+
+### Supporting Additional Operations
+
+Remote Table allows invoking additional operations on the remote store that are not directly
+supported through the Get/Put/Delete methods. Two categories of operations are supported:
+
+* Get/Put/Delete operations with additional arguments
+* Arbitrary operations through readAsync() and writeAsync()
+
+Implementers of table functions are only required to provide implementations of Get/Put/Delete
+without additional arguments. End users can subclass a table function and invoke operations
+on the remote store directly if those operations are not supported by the table function.
+
+{% highlight java %}
+ 1 public class MyCouchbaseTableWriteFunction<V> extends CouchbaseTableWriteFunction<V> {
+ 2
+ 3 public static final int OP_COUNTER = 1;
+ 4
+ 5 @Override
+ 6 public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+ 7 if (OP_COUNTER == opId) {
+ 8 String id = (String) args[0];
+ 9 Long delta = Long.valueOf(args[1].toString());
+10 return convertToFuture(bucket.async().counter(id, delta));
+11 }
+12 throw new SamzaException("Unknown opId" + opId);
+13 }
+14
+15 public CompletableFuture<Long> counterAsync(String id, long delta) {
+16 return table.writeAsync(OP_COUNTER, id, delta);
+17 }
+18 }
+19
+20 public class MyMapFunc implements MapFunction {
+21
+22 AsyncReadWriteTable table;
+23 MyCouchbaseTableWriteFunction writeFunc;
+24
+25 @Override
+26 public void init(Context context) {
+27 table = context.getTaskContext().getTable(...);
+28 writeFunc = (MyCouchbaseTableWriteFunction) ((RemoteTable) table).getWriteFunction();
+29 }
+30
+31 @Override
+32 public Object apply(Object message) {
+33 return writeFunc.counterAsync("id", 100);
+34 }
+35 }
+{% endhighlight %}
+
+The code above illustrates invoking the counter() operation on Couchbase.
+
+1. Line 5-13: method writeAsync() is implemented to invoke counter().
+2. Line 15-16: it is then wrapped by a convenience method. Notice here we invoke writeAsync()
+ on the table, so that other value-added features such as rate limiting,
+ retry and batching can participate in this call.
+3. Line 27-28: references to the table and write function are obtained
+4. Line 33: the actual invocation.
+
## Local Table
A table is considered local when its data physically co-exists on the same host
diff --git a/samza-api/src/main/java/org/apache/samza/startpoint/Startpoint.java b/samza-api/src/main/java/org/apache/samza/startpoint/Startpoint.java
index fcca792..a48962b 100644
--- a/samza-api/src/main/java/org/apache/samza/startpoint/Startpoint.java
+++ b/samza-api/src/main/java/org/apache/samza/startpoint/Startpoint.java
@@ -22,12 +22,17 @@
import com.google.common.base.Objects;
import java.time.Instant;
import org.apache.samza.annotation.InterfaceStability;
+import org.codehaus.jackson.annotate.JsonTypeInfo;
/**
* Startpoint represents a position in a stream partition.
*/
@InterfaceStability.Evolving
+@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "@type")
public abstract class Startpoint {
+ // TODO: Remove the @JsonTypeInfo annotation and use the ObjectMapper#enableDefaultTyping method in
+ // StartpointObjectMapper after upgrading jackson version. That method does not add the appropriate type info to the
+ // serialized json with the current version (1.9.13) of jackson.
private final long creationTimestamp;
diff --git a/samza-api/src/main/java/org/apache/samza/table/AsyncReadWriteTable.java b/samza-api/src/main/java/org/apache/samza/table/AsyncReadWriteTable.java
index dc976b5..bf692e0 100644
--- a/samza-api/src/main/java/org/apache/samza/table/AsyncReadWriteTable.java
+++ b/samza-api/src/main/java/org/apache/samza/table/AsyncReadWriteTable.java
@@ -21,6 +21,8 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
+
+import org.apache.samza.SamzaException;
import org.apache.samza.context.Context;
import org.apache.samza.storage.kv.Entry;
@@ -36,19 +38,33 @@
* Asynchronously gets the value associated with the specified {@code key}.
*
* @param key the key with which the associated value is to be fetched.
+ * @param args additional arguments
* @return completableFuture for the requested value
* @throws NullPointerException if the specified {@code key} is {@code null}.
*/
- CompletableFuture<V> getAsync(K key);
+ CompletableFuture<V> getAsync(K key, Object ... args);
/**
* Asynchronously gets the values with which the specified {@code keys} are associated.
*
* @param keys the keys with which the associated values are to be fetched.
+ * @param args additional arguments
* @return completableFuture for the requested entries
* @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
*/
- CompletableFuture<Map<K, V>> getAllAsync(List<K> keys);
+ CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args);
+
+ /**
+ * Asynchronously executes a read operation. opId is used to allow tracking of different
+ * types of operation.
+ * @param opId operation identifier
+ * @param args additional arguments
+ * @param <T> return type
+ * @return completableFuture for read result
+ */
+ default <T> CompletableFuture<T> readAsync(int opId, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
/**
* Asynchronously updates the mapping of the specified key-value pair;
@@ -57,36 +73,52 @@
*
* @param key the key with which the specified {@code value} is to be associated.
* @param value the value with which the specified {@code key} is to be associated.
+ * @param args additional arguments
* @throws NullPointerException if the specified {@code key} is {@code null}.
* @return CompletableFuture for the operation
*/
- CompletableFuture<Void> putAsync(K key, V value);
+ CompletableFuture<Void> putAsync(K key, V value, Object ... args);
/**
* Asynchronously updates the mappings of the specified key-value {@code entries}.
* A key is deleted from the table if its corresponding value is {@code null}.
*
* @param entries the updated mappings to put into this table.
+ * @param args additional arguments
* @throws NullPointerException if any of the specified {@code entries} has {@code null} as key.
* @return CompletableFuture for the operation
*/
- CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries);
+ CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args);
/**
* Asynchronously deletes the mapping for the specified {@code key} from this table (if such mapping exists).
* @param key the key for which the mapping is to be deleted.
+ * @param args additional arguments
* @throws NullPointerException if the specified {@code key} is {@code null}.
* @return CompletableFuture for the operation
*/
- CompletableFuture<Void> deleteAsync(K key);
+ CompletableFuture<Void> deleteAsync(K key, Object ... args);
/**
* Asynchronously deletes the mappings for the specified {@code keys} from this table.
* @param keys the keys for which the mappings are to be deleted.
+ * @param args additional arguments
* @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
* @return CompletableFuture for the operation
*/
- CompletableFuture<Void> deleteAllAsync(List<K> keys);
+ CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args);
+
+ /**
+ * Asynchronously executes a write operation. opId is used to allow tracking of different
+ * types of operation.
+ * @param opId operation identifier
+ * @param args additional arguments
+ * @param <T> return type
+ * @return completableFuture for write result
+ */
+ default <T> CompletableFuture<T> writeAsync(int opId, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
/**
* Initializes the table during container initialization.
diff --git a/samza-api/src/main/java/org/apache/samza/table/ReadWriteTable.java b/samza-api/src/main/java/org/apache/samza/table/ReadWriteTable.java
index a7dad8f..72ee34d 100644
--- a/samza-api/src/main/java/org/apache/samza/table/ReadWriteTable.java
+++ b/samza-api/src/main/java/org/apache/samza/table/ReadWriteTable.java
@@ -21,6 +21,7 @@
import java.util.List;
import java.util.Map;
+import org.apache.samza.SamzaException;
import org.apache.samza.annotation.InterfaceStability;
import org.apache.samza.storage.kv.Entry;
@@ -37,19 +38,34 @@
* Gets the value associated with the specified {@code key}.
*
* @param key the key with which the associated value is to be fetched.
+ * @param args additional arguments
* @return if found, the value associated with the specified {@code key}; otherwise, {@code null}.
* @throws NullPointerException if the specified {@code key} is {@code null}.
*/
- V get(K key);
+ V get(K key, Object ... args);
/**
* Gets the values with which the specified {@code keys} are associated.
*
* @param keys the keys with which the associated values are to be fetched.
+ * @param args additional arguments
* @return a map of the keys that were found and their respective values.
* @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
*/
- Map<K, V> getAll(List<K> keys);
+ Map<K, V> getAll(List<K> keys, Object ... args);
+
+ /**
+ * Executes a read operation. opId is used to allow tracking of different
+ * types of operation.
+ * @param opId operation identifier
+ * @param args additional arguments
+ * @param <T> return type
+ * @return read result
+ */
+
+ default <T> T read(int opId, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
/**
* Updates the mapping of the specified key-value pair;
@@ -59,9 +75,10 @@
*
* @param key the key with which the specified {@code value} is to be associated.
* @param value the value with which the specified {@code key} is to be associated.
+ * @param args additional arguments
* @throws NullPointerException if the specified {@code key} is {@code null}.
*/
- void put(K key, V value);
+ void put(K key, V value, Object ... args);
/**
* Updates the mappings of the specified key-value {@code entries}.
@@ -69,23 +86,38 @@
* A key is deleted from the table if its corresponding value is {@code null}.
*
* @param entries the updated mappings to put into this table.
+ * @param args additional arguments
* @throws NullPointerException if any of the specified {@code entries} has {@code null} as key.
*/
- void putAll(List<Entry<K, V>> entries);
+ void putAll(List<Entry<K, V>> entries, Object ... args);
/**
* Deletes the mapping for the specified {@code key} from this table (if such mapping exists).
*
* @param key the key for which the mapping is to be deleted.
+ * @param args additional arguments
* @throws NullPointerException if the specified {@code key} is {@code null}.
*/
- void delete(K key);
+ void delete(K key, Object ... args);
/**
* Deletes the mappings for the specified {@code keys} from this table.
*
* @param keys the keys for which the mappings are to be deleted.
+ * @param args additional arguments
* @throws NullPointerException if the specified {@code keys} list, or any of the keys, is {@code null}.
*/
- void deleteAll(List<K> keys);
+ void deleteAll(List<K> keys, Object ... args);
+
+ /**
+ * Executes a write operation. opId is used to allow tracking of different
+ * types of operation.
+ * @param opId operation identifier
+ * @param args additional arguments
+ * @param <T> return type
+ * @return write result
+ */
+ default <T> T write(int opId, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
}
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/BaseTableFunction.java b/samza-api/src/main/java/org/apache/samza/table/remote/BaseTableFunction.java
new file mode 100644
index 0000000..f099ff9
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/BaseTableFunction.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.table.remote;
+
+import com.google.common.base.Preconditions;
+
+import org.apache.samza.context.Context;
+import org.apache.samza.table.AsyncReadWriteTable;
+
+
+/**
+ * Base class for all table read and write functions.
+ */
+abstract public class BaseTableFunction implements TableFunction {
+
+ protected AsyncReadWriteTable table;
+
+ @Override
+ public void init(Context context, AsyncReadWriteTable table) {
+ Preconditions.checkNotNull(context, "null context");
+ Preconditions.checkNotNull(table, "null table");
+ this.table = table;
+ }
+}
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/TableFunction.java b/samza-api/src/main/java/org/apache/samza/table/remote/TableFunction.java
new file mode 100644
index 0000000..517a927
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/TableFunction.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.table.remote;
+
+import java.io.Serializable;
+
+import org.apache.samza.annotation.InterfaceStability;
+import org.apache.samza.context.Context;
+import org.apache.samza.operators.functions.ClosableFunction;
+import org.apache.samza.table.AsyncReadWriteTable;
+
+
+/**
+ * The root interface for table read and write function.
+ */
+@InterfaceStability.Unstable
+public interface TableFunction extends TablePart, ClosableFunction, Serializable {
+
+ /**
+ * Initializes the function before any operation.
+ *
+ * @param context the {@link Context} for this task
+ * @param table the {@link AsyncReadWriteTable} which this table function belongs to
+ */
+ void init(Context context, AsyncReadWriteTable table);
+
+ /**
+ * Determine whether the current operation can be retried with the last thrown exception.
+ * @param exception exception thrown by a table operation
+ * @return whether the operation can be retried
+ */
+ boolean isRetriable(Throwable exception);
+}
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java b/samza-api/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java
index f52bcbe..81cdecc 100644
--- a/samza-api/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/TableRateLimiter.java
@@ -63,9 +63,20 @@
* Get the number of credits required for the {@code key} and {@code value} pair.
* @param key table key
* @param value table record
+ * @param args additional arguments
* @return number of credits
*/
- int getCredits(K key, V value);
+ int getCredits(K key, V value, Object ... args);
+
+ /**
+ * Get the number of credits required for the {@code opId} and associated {@code args}.
+ * @param opId operation Id
+ * @param args additional arguments
+ * @return number of credits
+ */
+ default int getCredits(int opId, Object ... args) {
+ return 1;
+ }
}
/**
@@ -92,26 +103,30 @@
this.waitTimeMetric = timer;
}
- int getCredits(K key, V value) {
- return (creditFn == null) ? 1 : creditFn.getCredits(key, value);
+ int getCredits(K key, V value, Object ... args) {
+ return (creditFn == null) ? 1 : creditFn.getCredits(key, value, args);
}
- int getCredits(Collection<K> keys) {
+ int getCredits(Collection<K> keys, Object ... args) {
if (creditFn == null) {
return keys.size();
} else {
- return keys.stream().mapToInt(k -> creditFn.getCredits(k, null)).sum();
+ return keys.stream().mapToInt(k -> creditFn.getCredits(k, null, args)).sum();
}
}
- int getEntryCredits(Collection<Entry<K, V>> entries) {
+ int getEntryCredits(Collection<Entry<K, V>> entries, Object ... args) {
if (creditFn == null) {
return entries.size();
} else {
- return entries.stream().mapToInt(e -> creditFn.getCredits(e.getKey(), e.getValue())).sum();
+ return entries.stream().mapToInt(e -> creditFn.getCredits(e.getKey(), e.getValue(), args)).sum();
}
}
+ int getCredits(int opId, Object ... args) {
+ return (creditFn == null) ? 1 : creditFn.getCredits(opId, args);
+ }
+
private void throttle(int credits) {
long startNs = System.nanoTime();
rateLimiter.acquire(Collections.singletonMap(tag, credits));
@@ -123,33 +138,46 @@
/**
* Throttle a request with a key argument if necessary.
* @param key key used for the table request
+ * @param args additional arguments
*/
- public void throttle(K key) {
- throttle(getCredits(key, null));
+ public void throttle(K key, Object ... args) {
+ throttle(getCredits(key, null, args));
}
/**
* Throttle a request with both the key and value arguments if necessary.
* @param key key used for the table request
* @param value value used for the table request
+ * @param args additional arguments
*/
- public void throttle(K key, V value) {
- throttle(getCredits(key, value));
+ public void throttle(K key, V value, Object ... args) {
+ throttle(getCredits(key, value, args));
+ }
+
+ /**
+ * Throttle a request with opId and associated arguments
+ * @param opId operation Id
+ * @param args associated arguments
+ */
+ public void throttle(int opId, Object ... args) {
+ throttle(getCredits(opId, args));
}
/**
* Throttle a request with a collection of keys as the argument if necessary.
* @param keys collection of keys used for the table request
+ * @param args additional arguments
*/
- public void throttle(Collection<K> keys) {
- throttle(getCredits(keys));
+ public void throttle(Collection<K> keys, Object ... args) {
+ throttle(getCredits(keys, args));
}
/**
* Throttle a request with a collection of table records as the argument if necessary.
* @param records collection of records used for the table request
+ * @param args additional arguments
*/
- public void throttleRecords(Collection<Entry<K, V>> records) {
- throttle(getEntryCredits(records));
+ public void throttleRecords(Collection<Entry<K, V>> records, Object ... args) {
+ throttle(getEntryCredits(records, args));
}
}
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/TableReadFunction.java b/samza-api/src/main/java/org/apache/samza/table/remote/TableReadFunction.java
index 04fc918..03a4f24 100644
--- a/samza-api/src/main/java/org/apache/samza/table/remote/TableReadFunction.java
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/TableReadFunction.java
@@ -19,7 +19,6 @@
package org.apache.samza.table.remote;
-import java.io.Serializable;
import java.util.Collection;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
@@ -28,8 +27,6 @@
import org.apache.samza.SamzaException;
import org.apache.samza.annotation.InterfaceStability;
-import org.apache.samza.operators.functions.ClosableFunction;
-import org.apache.samza.operators.functions.InitableFunction;
import com.google.common.collect.Iterables;
@@ -46,7 +43,7 @@
* @param <V> the type of the value in this table
*/
@InterfaceStability.Unstable
-public interface TableReadFunction<K, V> extends TablePart, InitableFunction, ClosableFunction, Serializable {
+public interface TableReadFunction<K, V> extends TableFunction {
/**
* Fetch single table record for a specified {@code key}. This method must be thread-safe.
* The default implementation calls getAsync and blocks on the completion afterwards.
@@ -69,6 +66,17 @@
CompletableFuture<V> getAsync(K key);
/**
+ * Asynchronously fetch single table record for a specified {@code key} with additional arguments.
+ * This method must be thread-safe.
+ * @param key key for the table record
+ * @param args additional arguments
+ * @return CompletableFuture for the get request
+ */
+ default CompletableFuture<V> getAsync(K key, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
+
+ /**
* Fetch the table {@code records} for specified {@code keys}. This method must be thread-safe.
* The default implementation calls getAllAsync and blocks on the completion afterwards.
* @param keys keys for the table records
@@ -101,11 +109,27 @@
}
/**
- * Determine whether the current operation can be retried with the last thrown exception.
- * @param exception exception thrown by a table operation
- * @return whether the operation can be retried
+ * Asynchronously fetch the table {@code records} for specified {@code keys} and additional arguments.
+ * This method must be thread-safe.
+ * @param keys keys for the table records
+ * @param args additional arguments
+ * @return CompletableFuture for the get request
*/
- boolean isRetriable(Throwable exception);
+ default CompletableFuture<Map<K, V>> getAllAsync(Collection<K> keys, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
+
+ /**
+ * Asynchronously read data from table for specified {@code opId} and additional arguments.
+ * This method must be thread-safe.
+ * @param opId operation identifier
+ * @param args additional arguments
+ * @param <T> return type
+ * @return CompletableFuture for the read request
+ */
+ default <T> CompletableFuture<T> readAsync(int opId, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
// optionally implement readObject() to initialize transient states
}
diff --git a/samza-api/src/main/java/org/apache/samza/table/remote/TableWriteFunction.java b/samza-api/src/main/java/org/apache/samza/table/remote/TableWriteFunction.java
index 3b06664..9e274c1 100644
--- a/samza-api/src/main/java/org/apache/samza/table/remote/TableWriteFunction.java
+++ b/samza-api/src/main/java/org/apache/samza/table/remote/TableWriteFunction.java
@@ -19,7 +19,6 @@
package org.apache.samza.table.remote;
-import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CompletableFuture;
@@ -28,8 +27,6 @@
import org.apache.samza.SamzaException;
import org.apache.samza.annotation.InterfaceStability;
-import org.apache.samza.operators.functions.ClosableFunction;
-import org.apache.samza.operators.functions.InitableFunction;
import org.apache.samza.storage.kv.Entry;
import com.google.common.collect.Iterables;
@@ -47,7 +44,7 @@
* @param <V> the type of the value in this table
*/
@InterfaceStability.Unstable
-public interface TableWriteFunction<K, V> extends TablePart, InitableFunction, ClosableFunction, Serializable {
+public interface TableWriteFunction<K, V> extends TableFunction {
/**
* Store single table {@code record} with specified {@code key}. This method must be thread-safe.
* The default implementation calls putAsync and blocks on the completion afterwards.
@@ -72,6 +69,18 @@
CompletableFuture<Void> putAsync(K key, V record);
/**
+ * Asynchronously store single table {@code record} with specified {@code key} and additional arguments.
+ * This method must be thread-safe.
+ * @param key key for the table record
+ * @param record table record to be written
+ * @param args additional arguments
+ * @return CompletableFuture for the put request
+ */
+ default CompletableFuture<Void> putAsync(K key, V record, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
+
+ /**
* Store the table {@code records} with specified {@code keys}. This method must be thread-safe.
* The default implementation calls putAllAsync and blocks on the completion afterwards.
* @param records table records to be written
@@ -97,6 +106,17 @@
}
/**
+ * Asynchronously store the table {@code records} with specified {@code keys} and additional arguments.
+ * This method must be thread-safe.
+ * @param records table records to be written
+ * @param args additional arguments
+ * @return CompletableFuture for the put request
+ */
+ default CompletableFuture<Void> putAllAsync(Collection<Entry<K, V>> records, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
+
+ /**
* Delete the {@code record} with specified {@code key} from the remote store.
* The default implementation calls deleteAsync and blocks on the completion afterwards.
* @param key key to the table record to be deleted
@@ -117,6 +137,16 @@
CompletableFuture<Void> deleteAsync(K key);
/**
+ * Asynchronously delete the {@code record} with specified {@code key} and additional arguments from the remote store
+ * @param key key to the table record to be deleted
+ * @param args additional arguments
+ * @return CompletableFuture for the delete request
+ */
+ default CompletableFuture<Void> deleteAsync(K key, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
+
+ /**
* Delete all {@code records} with the specified {@code keys} from the remote store
* The default implementation calls deleteAllAsync and blocks on the completion afterwards.
* @param keys keys for the table records to be written
@@ -143,11 +173,28 @@
}
/**
- * Determine whether the current operation can be retried with the last thrown exception.
- * @param exception exception thrown by a table operation
- * @return whether the operation can be retried
+ * Asynchronously delete all {@code records} with the specified {@code keys} and additional arguments from
+ * the remote store.
+ *
+ * @param keys keys for the table records to be deleted
+ * @param args additional arguments
+ * @return CompletableFuture for the deleteAll request
*/
- boolean isRetriable(Throwable exception);
+ default CompletableFuture<Void> deleteAllAsync(Collection<K> keys, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
+
+ /**
+ * Asynchronously write data to table for specified {@code opId} and additional arguments.
+ * This method must be thread-safe.
+ * @param opId operation identifier
+ * @param args additional arguments
+ * @param <T> return type
+ * @return CompletableFuture for the write request
+ */
+ default <T> CompletableFuture<T> writeAsync(int opId, Object ... args) {
+ throw new SamzaException("Not supported");
+ }
/**
* Flush the remote store (optional)
diff --git a/samza-api/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java b/samza-api/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java
index 262f669..5510dde 100644
--- a/samza-api/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java
+++ b/samza-api/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java
@@ -44,7 +44,7 @@
public TableRateLimiter<String, String> getThrottler(String tag) {
TableRateLimiter.CreditFunction<String, String> credFn =
- (TableRateLimiter.CreditFunction<String, String>) (key, value) -> {
+ (TableRateLimiter.CreditFunction<String, String>) (key, value, args) -> {
int credits = key == null ? 0 : 3;
credits += value == null ? 0 : 3;
return credits;
@@ -83,13 +83,29 @@
}
@Test
+ public void testCreditOpId() {
+ TableRateLimiter<String, String> rateLimitHelper = getThrottler();
+ Assert.assertEquals(1, rateLimitHelper.getCredits(1, 2));
+ }
+
+ @Test
public void testThrottle() {
TableRateLimiter<String, String> rateLimitHelper = getThrottler();
Timer timer = mock(Timer.class);
rateLimitHelper.setTimerMetric(timer);
+ int times = 0;
rateLimitHelper.throttle("foo");
- verify(rateLimitHelper.rateLimiter, times(1)).acquire(anyMapOf(String.class, Integer.class));
- verify(timer, times(1)).update(anyLong());
+ verify(rateLimitHelper.rateLimiter, times(++times)).acquire(anyMap());
+ verify(timer, times(times)).update(anyLong());
+ rateLimitHelper.throttle("foo", "bar");
+ verify(rateLimitHelper.rateLimiter, times(++times)).acquire(anyMap());
+ verify(timer, times(times)).update(anyLong());
+ rateLimitHelper.throttle(Arrays.asList("foo", "bar"));
+ verify(rateLimitHelper.rateLimiter, times(++times)).acquire(anyMap());
+ verify(timer, times(times)).update(anyLong());
+ rateLimitHelper.throttle(1, 2);
+ verify(rateLimitHelper.rateLimiter, times(++times)).acquire(anyMap());
+ verify(timer, times(times)).update(anyLong());
}
@Test
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointFanOutPerTask.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointFanOutPerTask.java
new file mode 100644
index 0000000..63efabc
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointFanOutPerTask.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.startpoint;
+
+import com.google.common.base.MoreObjects;
+import com.google.common.base.Objects;
+import java.time.Instant;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.samza.serializers.model.SamzaObjectMapper;
+import org.apache.samza.system.SystemStreamPartition;
+import org.codehaus.jackson.map.annotate.JsonDeserialize;
+import org.codehaus.jackson.map.annotate.JsonSerialize;
+
+
+/**
+ * Holds the {@link Startpoint} fan outs for each {@link SystemStreamPartition}. Each StartpointFanOutPerTask maps to a
+ * {@link org.apache.samza.container.TaskName}
+ */
+class StartpointFanOutPerTask {
+ // TODO: Remove the @JsonSerialize and @JsonDeserialize annotations and use the SimpleModule#addKeySerializer and
+ // SimpleModule#addKeyDeserializer methods in StartpointObjectMapper after upgrading jackson version.
+ // Those methods do not work on nested maps with the current version (1.9.13) of jackson.
+
+ @JsonSerialize
+ @JsonDeserialize
+ private final Instant timestamp;
+
+ @JsonDeserialize(keyUsing = SamzaObjectMapper.SystemStreamPartitionKeyDeserializer.class)
+ @JsonSerialize(keyUsing = SamzaObjectMapper.SystemStreamPartitionKeySerializer.class)
+ private final Map<SystemStreamPartition, Startpoint> fanOuts;
+
+ // required for Jackson deserialization
+ StartpointFanOutPerTask() {
+ this(Instant.now());
+ }
+
+ StartpointFanOutPerTask(Instant timestamp) {
+ this.timestamp = timestamp;
+ this.fanOuts = new HashMap<>();
+ }
+
+ // Unused in code, but useful for auditing when the fan out is serialized into the store
+ Instant getTimestamp() {
+ return timestamp;
+ }
+
+ Map<SystemStreamPartition, Startpoint> getFanOuts() {
+ return fanOuts;
+ }
+
+ @Override
+ public String toString() {
+ return MoreObjects.toStringHelper(this).toString();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ StartpointFanOutPerTask that = (StartpointFanOutPerTask) o;
+ return Objects.equal(timestamp, that.timestamp) && Objects.equal(fanOuts, that.fanOuts);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(timestamp, fanOuts);
+ }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKey.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKey.java
deleted file mode 100644
index 7d4753d..0000000
--- a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKey.java
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import com.google.common.base.MoreObjects;
-import com.google.common.base.Objects;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.system.SystemStreamPartition;
-import org.codehaus.jackson.map.annotate.JsonSerialize;
-
-
-@JsonSerialize(using = StartpointKeySerializer.class)
-class StartpointKey {
- private final SystemStreamPartition systemStreamPartition;
- private final TaskName taskName;
-
- // Constructs a startpoint key with SSP. This means the key will apply to all tasks that are mapped to this SSP
- StartpointKey(SystemStreamPartition systemStreamPartition) {
- this(systemStreamPartition, null);
- }
-
- // Constructs a startpoint key with SSP and a task.
- StartpointKey(SystemStreamPartition systemStreamPartition, TaskName taskName) {
- this.systemStreamPartition = systemStreamPartition;
- this.taskName = taskName;
- }
-
- SystemStreamPartition getSystemStreamPartition() {
- return systemStreamPartition;
- }
-
- TaskName getTaskName() {
- return taskName;
- }
-
- @Override
- public String toString() {
- return MoreObjects.toStringHelper(this).toString();
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- StartpointKey that = (StartpointKey) o;
- return Objects.equal(systemStreamPartition, that.systemStreamPartition) && Objects.equal(taskName, that.taskName);
- }
-
- @Override
- public int hashCode() {
- return Objects.hashCode(systemStreamPartition, taskName);
- }
-}
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKeySerializer.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKeySerializer.java
deleted file mode 100644
index d347fe7..0000000
--- a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointKeySerializer.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- *//*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- *//*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.system.SystemStreamPartition;
-import org.codehaus.jackson.JsonGenerator;
-import org.codehaus.jackson.map.JsonSerializer;
-import org.codehaus.jackson.map.SerializerProvider;
-
-
-final class StartpointKeySerializer extends JsonSerializer<StartpointKey> {
- @Override
- public void serialize(StartpointKey startpointKey, JsonGenerator jsonGenerator, SerializerProvider provider) throws
- IOException {
- Map<String, Object> systemStreamPartitionMap = new HashMap<>();
- SystemStreamPartition systemStreamPartition = startpointKey.getSystemStreamPartition();
- TaskName taskName = startpointKey.getTaskName();
- systemStreamPartitionMap.put("system", systemStreamPartition.getSystem());
- systemStreamPartitionMap.put("stream", systemStreamPartition.getStream());
- systemStreamPartitionMap.put("partition", systemStreamPartition.getPartition().getPartitionId());
- if (taskName != null) {
- systemStreamPartitionMap.put("taskName", taskName.getTaskName());
- }
- jsonGenerator.writeObject(systemStreamPartitionMap);
- }
-}
\ No newline at end of file
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java
index da7f990..0c070cb 100644
--- a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java
+++ b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointManager.java
@@ -20,87 +20,114 @@
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableMap;
+import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
-import java.util.HashSet;
-import java.util.Objects;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.commons.collections4.CollectionUtils;
+import org.apache.commons.collections4.MapUtils;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.lang3.StringUtils;
import org.apache.samza.SamzaException;
-import org.apache.samza.config.Config;
import org.apache.samza.container.TaskName;
import org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
-import org.apache.samza.job.model.ContainerModel;
-import org.apache.samza.job.model.JobModel;
-import org.apache.samza.job.model.TaskModel;
import org.apache.samza.metadatastore.MetadataStore;
-import org.apache.samza.metadatastore.MetadataStoreFactory;
-import org.apache.samza.metrics.MetricsRegistry;
-import org.apache.samza.serializers.JsonSerdeV2;
import org.apache.samza.system.SystemStreamPartition;
+import org.codehaus.jackson.map.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
- * The StartpointManager reads and writes {@link Startpoint} to the {@link MetadataStore} defined by
- * the configuration task.startpoint.metadata.store.factory.
- *
- * Startpoints are keyed in the MetadataStore by two different formats:
- * 1) Only by {@link SystemStreamPartition}
- * 2) A combination of {@link SystemStreamPartition} and {@link TaskName}
+ * The StartpointManager reads and writes {@link Startpoint} to the provided {@link MetadataStore}
*
* The intention for the StartpointManager is to maintain a strong contract between the caller
* and how Startpoints are stored in the underlying MetadataStore.
+ *
+ * Startpoints are written in the MetadataStore using keys of two different formats:
+ * 1) {@link SystemStreamPartition} only
+ * 2) A combination of {@link SystemStreamPartition} and {@link TaskName}
+ *
+ * Startpoints are then fanned out to a fan out namespace in the MetadataStore by the
+ * {@link org.apache.samza.clustermanager.ClusterBasedJobCoordinator} or the standalone
+ * {@link org.apache.samza.coordinator.JobCoordinator} upon startup and the
+ * {@link org.apache.samza.checkpoint.OffsetManager} gets the fan outs to set the starting offsets per task and per
+ * {@link SystemStreamPartition}. The fan outs are deleted once the offsets are committed to the checkpoint.
+ *
+ * The read, write and delete methods are intended for external callers.
+ * The fan out methods are intended to be used within a job coordinator.
*/
public class StartpointManager {
- private static final Logger LOG = LoggerFactory.getLogger(StartpointManager.class);
- public static final String NAMESPACE = "samza-startpoint-v1";
+ private static final Integer VERSION = 1;
+ public static final String NAMESPACE = "samza-startpoint-v" + VERSION;
static final Duration DEFAULT_EXPIRATION_DURATION = Duration.ofHours(12);
- private final MetadataStore metadataStore;
- private final StartpointSerde startpointSerde = new StartpointSerde();
+ private static final Logger LOG = LoggerFactory.getLogger(StartpointManager.class);
+ private static final String NAMESPACE_FAN_OUT = NAMESPACE + "-fan-out";
- private boolean stopped = false;
+ private final NamespaceAwareCoordinatorStreamStore fanOutStore;
+ private final NamespaceAwareCoordinatorStreamStore readWriteStore;
+ private final ObjectMapper objectMapper = StartpointObjectMapper.getObjectMapper();
- /**
- * Constructs a {@link StartpointManager} instance by instantiating a new metadata store connection.
- * This is primarily used for testing.
- */
- @VisibleForTesting
- StartpointManager(MetadataStoreFactory metadataStoreFactory, Config config, MetricsRegistry metricsRegistry) {
- Preconditions.checkNotNull(metadataStoreFactory, "MetadataStoreFactory cannot be null");
- Preconditions.checkNotNull(config, "Config cannot be null");
- Preconditions.checkNotNull(metricsRegistry, "MetricsRegistry cannot be null");
-
- this.metadataStore = metadataStoreFactory.getMetadataStore(NAMESPACE, config, metricsRegistry);
- LOG.info("StartpointManager created with metadata store: {}", metadataStore.getClass().getCanonicalName());
- this.metadataStore.init();
- }
+ private boolean stopped = true;
/**
* Builds the StartpointManager based upon the provided {@link MetadataStore} that is instantiated.
* Setting up a metadata store instance is expensive which requires opening multiple connections
- * and reading tons of information. Fully instantiated metadata store is taken as a constructor argument
+ * and reading tons of information. Fully instantiated metadata store is passed in as a constructor argument
* to reuse it across different utility classes.
*
* @param metadataStore an instance of {@link MetadataStore} used to read/write the start-points.
*/
public StartpointManager(MetadataStore metadataStore) {
Preconditions.checkNotNull(metadataStore, "MetadataStore cannot be null");
- this.metadataStore = new NamespaceAwareCoordinatorStreamStore(metadataStore, NAMESPACE);
- }
- public void start() {
- // Metadata store lifecycle is managed outside of the StartpointManager, so not starting it.
+ this.readWriteStore = new NamespaceAwareCoordinatorStreamStore(metadataStore, NAMESPACE);
+ this.fanOutStore = new NamespaceAwareCoordinatorStreamStore(metadataStore, NAMESPACE_FAN_OUT);
+ LOG.info("Startpoints are written to namespace: {} and fanned out to namespace: {} in the metadata store of type: {}",
+ NAMESPACE, NAMESPACE_FAN_OUT, metadataStore.getClass().getCanonicalName());
}
/**
- * Writes a {@link Startpoint} that defines the start position for a {@link SystemStreamPartition}.
- * @param ssp The {@link SystemStreamPartition} to map the {@link Startpoint} against.
- * @param startpoint Reference to a Startpoint object.
+ * Perform startup operations. Method is idempotent.
*/
+ public void start() {
+ if (stopped) {
+ LOG.info("starting");
+ readWriteStore.init();
+ fanOutStore.init();
+ stopped = false;
+ } else {
+ LOG.warn("already started");
+ }
+ }
+
+ /**
+ * Perform teardown operations. Method is idempotent.
+ */
+ public void stop() {
+ if (!stopped) {
+ LOG.info("stopping");
+ readWriteStore.close();
+ fanOutStore.close();
+ stopped = true;
+ } else {
+ LOG.warn("already stopped");
+ }
+ }
+
+ /**
+ * Writes a {@link Startpoint} that defines the start position for a {@link SystemStreamPartition}.
+ * @param ssp The {@link SystemStreamPartition} to map the {@link Startpoint} against.
+ * @param startpoint Reference to a Startpoint object.
+ */
public void writeStartpoint(SystemStreamPartition ssp, Startpoint startpoint) {
writeStartpoint(ssp, null, startpoint);
}
@@ -117,7 +144,7 @@
Preconditions.checkNotNull(startpoint, "Startpoint cannot be null");
try {
- metadataStore.put(toStoreKey(ssp, taskName), startpointSerde.toBytes(startpoint));
+ readWriteStore.put(toReadWriteStoreKey(ssp, taskName), objectMapper.writeValueAsBytes(startpoint));
} catch (Exception ex) {
throw new SamzaException(String.format(
"Startpoint for SSP: %s and task: %s may not have been written to the metadata store.", ssp, taskName), ex);
@@ -127,9 +154,10 @@
/**
* Returns the last {@link Startpoint} that defines the start position for a {@link SystemStreamPartition}.
* @param ssp The {@link SystemStreamPartition} to fetch the {@link Startpoint} for.
- * @return {@link Startpoint} for the {@link SystemStreamPartition}, or null if it does not exist or if it is too stale
+ * @return {@link Optional} of {@link Startpoint} for the {@link SystemStreamPartition}.
+ * It is empty if it does not exist or if it is too stale.
*/
- public Startpoint readStartpoint(SystemStreamPartition ssp) {
+ public Optional<Startpoint> readStartpoint(SystemStreamPartition ssp) {
return readStartpoint(ssp, null);
}
@@ -137,23 +165,29 @@
* Returns the {@link Startpoint} for a {@link SystemStreamPartition} and {@link TaskName}.
* @param ssp The {@link SystemStreamPartition} to fetch the {@link Startpoint} for.
* @param taskName The {@link TaskName} to fetch the {@link Startpoint} for.
- * @return {@link Startpoint} for the {@link SystemStreamPartition}, or null if it does not exist or if it is too stale.
+ * @return {@link Optional} of {@link Startpoint} for the {@link SystemStreamPartition} and {@link TaskName}.
+ * It is empty if it does not exist or if it is too stale.
*/
- public Startpoint readStartpoint(SystemStreamPartition ssp, TaskName taskName) {
+ public Optional<Startpoint> readStartpoint(SystemStreamPartition ssp, TaskName taskName) {
Preconditions.checkState(!stopped, "Underlying metadata store not available");
Preconditions.checkNotNull(ssp, "SystemStreamPartition cannot be null");
- byte[] startpointBytes = metadataStore.get(toStoreKey(ssp, taskName));
+ byte[] startpointBytes = readWriteStore.get(toReadWriteStoreKey(ssp, taskName));
- if (Objects.nonNull(startpointBytes)) {
- Startpoint startpoint = startpointSerde.fromBytes(startpointBytes);
- if (Instant.now().minus(DEFAULT_EXPIRATION_DURATION).isBefore(Instant.ofEpochMilli(startpoint.getCreationTimestamp()))) {
- return startpoint; // return if deserializable and if not stale
+ if (ArrayUtils.isNotEmpty(startpointBytes)) {
+ try {
+ Startpoint startpoint = objectMapper.readValue(startpointBytes, Startpoint.class);
+ if (Instant.now().minus(DEFAULT_EXPIRATION_DURATION).isBefore(Instant.ofEpochMilli(startpoint.getCreationTimestamp()))) {
+ return Optional.of(startpoint); // return if deserializable and if not stale
+ }
+ LOG.warn("Creation timestamp: {} of startpoint: {} has crossed the expiration duration: {}. Ignoring it",
+ startpoint.getCreationTimestamp(), startpoint, DEFAULT_EXPIRATION_DURATION);
+ } catch (IOException ex) {
+ throw new SamzaException(ex);
}
- LOG.warn("Stale Startpoint: {} was read. Ignoring.", startpoint);
}
- return null;
+ return Optional.empty();
}
/**
@@ -173,71 +207,148 @@
Preconditions.checkState(!stopped, "Underlying metadata store not available");
Preconditions.checkNotNull(ssp, "SystemStreamPartition cannot be null");
- metadataStore.delete(toStoreKey(ssp, taskName));
+ readWriteStore.delete(toReadWriteStoreKey(ssp, taskName));
}
/**
- * For {@link Startpoint}s keyed only by {@link SystemStreamPartition}, this method re-maps the Startpoints from
- * SystemStreamPartition to SystemStreamPartition+{@link TaskName} for all tasks provided by the {@link JobModel}
+ * The Startpoints that are written with {@link #writeStartpoint(SystemStreamPartition, Startpoint)} and with
+ * {@link #writeStartpoint(SystemStreamPartition, TaskName, Startpoint)} are moved from a "read-write" namespace
+ * to a "fan out" namespace.
* This method is not atomic or thread-safe. The intent is for the Samza Processor's coordinator to use this
* method to assign the Startpoints to the appropriate tasks.
- * @param jobModel The {@link JobModel} is used to determine which {@link TaskName} each {@link SystemStreamPartition} maps to.
- * @return The list of {@link SystemStreamPartition}s that were fanned out to SystemStreamPartition+TaskName.
+ * @param taskToSSPs Determines which {@link TaskName} each {@link SystemStreamPartition} maps to.
+ * @return The fanned out {@link Startpoint}s per {@link SystemStreamPartition}, keyed by {@link TaskName}; empty if nothing was fanned out.
*/
- public Set<SystemStreamPartition> fanOutStartpointsToTasks(JobModel jobModel) {
+ public Map<TaskName, Map<SystemStreamPartition, Startpoint>> fanOut(Map<TaskName, Set<SystemStreamPartition>> taskToSSPs) throws IOException {
Preconditions.checkState(!stopped, "Underlying metadata store not available");
- Preconditions.checkNotNull(jobModel, "JobModel cannot be null");
+ Preconditions.checkArgument(MapUtils.isNotEmpty(taskToSSPs), "taskToSSPs cannot be null or empty");
- HashSet<SystemStreamPartition> sspsToDelete = new HashSet<>();
+ // construct fan out with the existing readWriteStore entries and mark the entries for deletion after fan out
+ Instant now = Instant.now();
+ HashMultimap<SystemStreamPartition, TaskName> deleteKeys = HashMultimap.create();
+ HashMap<TaskName, StartpointFanOutPerTask> fanOuts = new HashMap<>();
+ for (TaskName taskName : taskToSSPs.keySet()) {
+ Set<SystemStreamPartition> ssps = taskToSSPs.get(taskName);
+ if (CollectionUtils.isEmpty(ssps)) {
+ LOG.warn("No SSPs are mapped to taskName: {}", taskName.getTaskName());
+ continue;
+ }
+ for (SystemStreamPartition ssp : ssps) {
+ Optional<Startpoint> startpoint = readStartpoint(ssp); // Read SSP-only key
+ startpoint.ifPresent(sp -> deleteKeys.put(ssp, null));
- // Inspect the job model for TaskName-to-SSPs mapping and re-map startpoints from SSP-only keys to SSP+TaskName keys.
- for (ContainerModel containerModel: jobModel.getContainers().values()) {
- for (TaskModel taskModel : containerModel.getTasks().values()) {
- TaskName taskName = taskModel.getTaskName();
- for (SystemStreamPartition ssp : taskModel.getSystemStreamPartitions()) {
- Startpoint startpoint = readStartpoint(ssp); // Read SSP-only key
- if (startpoint == null) {
- LOG.debug("No Startpoint for SSP: {} in task: {}", ssp, taskName);
- continue;
- }
+ Optional<Startpoint> startpointForTask = readStartpoint(ssp, taskName); // Read SSP+taskName key
+ startpointForTask.ifPresent(sp -> deleteKeys.put(ssp, taskName));
- LOG.info("Grouping Startpoint keyed on SSP: {} to tasks determined by the job model.", ssp);
- Startpoint startpointForTask = readStartpoint(ssp, taskName);
- if (startpointForTask == null || startpointForTask.getCreationTimestamp() < startpoint.getCreationTimestamp()) {
- writeStartpoint(ssp, taskName, startpoint);
- sspsToDelete.add(ssp); // Mark for deletion
- LOG.info("Startpoint for SSP: {} remapped with task: {}.", ssp, taskName);
- } else {
- LOG.info("Startpoint for SSP: {} and task: {} already exists and will not be overwritten.", ssp, taskName);
- }
+ Optional<Startpoint> startpointWithPrecedence = resolveStartpointPrecendence(startpoint, startpointForTask);
+ if (!startpointWithPrecedence.isPresent()) {
+ continue;
+ }
+ fanOuts.putIfAbsent(taskName, new StartpointFanOutPerTask(now));
+ fanOuts.get(taskName).getFanOuts().put(ssp, startpointWithPrecedence.get());
+ }
+ }
+
+ if (fanOuts.isEmpty()) {
+ LOG.debug("No fan outs created.");
+ return ImmutableMap.of();
+ }
+
+ LOG.info("Fanning out to {} tasks", fanOuts.size());
+
+ // Fan out to store
+ for (TaskName taskName : fanOuts.keySet()) {
+ String fanOutKey = toFanOutStoreKey(taskName);
+ StartpointFanOutPerTask newFanOut = fanOuts.get(taskName);
+ fanOutStore.put(fanOutKey, objectMapper.writeValueAsBytes(newFanOut));
+ }
+
+ for (SystemStreamPartition ssp : deleteKeys.keySet()) {
+ for (TaskName taskName : deleteKeys.get(ssp)) {
+ if (taskName != null) {
+ deleteStartpoint(ssp, taskName);
+ } else {
+ deleteStartpoint(ssp);
}
}
}
- // Delete SSP-only keys
- sspsToDelete.forEach(ssp -> {
- deleteStartpoint(ssp);
- LOG.info("All Startpoints for SSP: {} have been grouped to the appropriate tasks and the SSP was deleted.");
- });
-
- return ImmutableSet.copyOf(sspsToDelete);
+ return ImmutableMap.copyOf(fanOuts.entrySet().stream()
+ .collect(Collectors.toMap(fo -> fo.getKey(), fo -> fo.getValue().getFanOuts())));
}
/**
- * Relinquish resources held by the underlying {@link MetadataStore}
+ * Read the fanned out {@link Startpoint}s for the given {@link TaskName}
+ * @param taskName to read the fan out Startpoints for
+ * @return fanned out Startpoints keyed by {@link SystemStreamPartition}, or an empty map if none are stored for the task
*/
- public void stop() {
- stopped = true;
- // Metadata store lifecycle is managed outside of the StartpointManager, so not closing it.
+ public Map<SystemStreamPartition, Startpoint> getFanOutForTask(TaskName taskName) throws IOException {
+ Preconditions.checkState(!stopped, "Underlying metadata store not available");
+ Preconditions.checkNotNull(taskName, "TaskName cannot be null");
+
+ byte[] fanOutBytes = fanOutStore.get(toFanOutStoreKey(taskName));
+ if (ArrayUtils.isEmpty(fanOutBytes)) {
+ return ImmutableMap.of();
+ }
+ StartpointFanOutPerTask startpointFanOutPerTask = objectMapper.readValue(fanOutBytes, StartpointFanOutPerTask.class);
+ return ImmutableMap.copyOf(startpointFanOutPerTask.getFanOuts());
+ }
+
+ /**
+ * Deletes the fanned out {@link Startpoint}s for the given {@link TaskName}
+ * @param taskName to delete the fan out Startpoints for
+ */
+ public void removeFanOutForTask(TaskName taskName) {
+ Preconditions.checkState(!stopped, "Underlying metadata store not available");
+ Preconditions.checkNotNull(taskName, "TaskName cannot be null");
+
+ fanOutStore.delete(toFanOutStoreKey(taskName));
}
@VisibleForTesting
- MetadataStore getMetadataStore() {
- return metadataStore;
+ MetadataStore getReadWriteStore() {
+ return readWriteStore;
}
- private static String toStoreKey(SystemStreamPartition ssp, TaskName taskName) {
- return new String(new JsonSerdeV2<>().toBytes(new StartpointKey(ssp, taskName)));
+ /**
+ * Exposes the {@link MetadataStore} in which the fan outs are kept. Package-private and intended for tests only.
+ * @return the fan out store used by this manager
+ */
+ @VisibleForTesting
+ MetadataStore getFanOutStore() {
+ return fanOutStore;
+ }
+
+ /**
+ * Exposes the {@link ObjectMapper} this manager uses to (de)serialize stored values. Intended for tests only.
+ * @return the object mapper used by this manager
+ */
+ @VisibleForTesting
+ ObjectMapper getObjectMapper() {
+ return objectMapper;
+ }
+
+  /**
+   * Picks which of two optional {@link Startpoint}s takes effect, e.g. an SSP-only one versus an SSP+taskName one.
+   *
+   * @param startpoint1 first candidate
+   * @param startpoint2 second candidate
+   * @return the only candidate present, or when both are present the one with the later creation timestamp
+   *         (the second candidate on equal timestamps); empty when neither is present
+   */
+  private static Optional<Startpoint> resolveStartpointPrecendence(Optional<Startpoint> startpoint1, Optional<Startpoint> startpoint2) {
+    // With at most one candidate present there is nothing to arbitrate.
+    if (!startpoint1.isPresent()) {
+      return startpoint2;
+    }
+    if (!startpoint2.isPresent()) {
+      return startpoint1;
+    }
+    // Both exist: the first wins only when it is strictly newer.
+    boolean firstIsNewer = startpoint1.get().getCreationTimestamp() > startpoint2.get().getCreationTimestamp();
+    return firstIsNewer ? startpoint1 : startpoint2;
+  }
+
+  /**
+   * Builds the read-write store key: {@code system.stream.partitionId}, followed by {@code .taskName} when a
+   * task name is supplied.
+   *
+   * @param ssp the partition the key is for; its system and stream must be non-blank and its partition non-null
+   * @param taskName optional task name; {@code null} yields an SSP-only key
+   * @return the store key
+   * @throws IllegalArgumentException if any part of {@code ssp} is missing
+   */
+  private static String toReadWriteStoreKey(SystemStreamPartition ssp, TaskName taskName) {
+    Preconditions.checkArgument(ssp != null, "SystemStreamPartition should be defined");
+    Preconditions.checkArgument(StringUtils.isNotBlank(ssp.getSystem()), "System should be defined");
+    Preconditions.checkArgument(StringUtils.isNotBlank(ssp.getStream()), "Stream should be defined");
+    Preconditions.checkArgument(ssp.getPartition() != null, "Partition should be defined");
+
+    StringBuilder key = new StringBuilder()
+        .append(ssp.getSystem()).append(".")
+        .append(ssp.getStream()).append(".")
+        .append(ssp.getPartition().getPartitionId());
+    if (taskName != null) {
+      key.append(".").append(taskName.getTaskName());
+    }
+    return key.toString();
+  }
+
+  /**
+   * Builds the fan out store key, which is the task name itself.
+   *
+   * @param taskName the task to build the key for; must be non-null with a non-blank name
+   * @return the fan out store key
+   * @throws IllegalArgumentException if {@code taskName} is null or its name is blank
+   */
+  private static String toFanOutStoreKey(TaskName taskName) {
+    Preconditions.checkArgument(taskName != null, "TaskName should be defined");
+    String key = taskName.getTaskName();
+    Preconditions.checkArgument(StringUtils.isNotBlank(key), "TaskName should not be blank");
+
+    return key;
   }
}
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointObjectMapper.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointObjectMapper.java
new file mode 100644
index 0000000..c5c0434
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointObjectMapper.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.startpoint;
+
+import java.io.IOException;
+import java.time.Instant;
+import org.codehaus.jackson.JsonGenerator;
+import org.codehaus.jackson.JsonParser;
+import org.codehaus.jackson.JsonProcessingException;
+import org.codehaus.jackson.Version;
+import org.codehaus.jackson.map.DeserializationContext;
+import org.codehaus.jackson.map.JsonDeserializer;
+import org.codehaus.jackson.map.JsonSerializer;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.codehaus.jackson.map.SerializerProvider;
+import org.codehaus.jackson.map.jsontype.NamedType;
+import org.codehaus.jackson.map.module.SimpleModule;
+
+
+/** Factory for the Jackson {@link ObjectMapper} used to serialize and deserialize {@link Startpoint}s. */
+class StartpointObjectMapper {
+  /** @return a new mapper that writes {@link Instant}s as epoch millis and knows the Startpoint subtypes */
+  static ObjectMapper getObjectMapper() {
+    SimpleModule startpointModule = new SimpleModule("StartpointModule", new Version(1, 0, 0, ""));
+    startpointModule.addSerializer(Instant.class, new CustomInstantSerializer());
+    startpointModule.addDeserializer(Instant.class, new CustomInstantDeserializer());
+    ObjectMapper mapper = new ObjectMapper();
+    mapper.registerModule(startpointModule);
+    // Polymorphic serialization needs every Startpoint subtype registered. Each is wrapped in a NamedType,
+    // which provides a logical name so the full canonical class name is not written to the json type property.
+    mapper.registerSubtypes(
+        new NamedType(StartpointSpecific.class), new NamedType(StartpointTimestamp.class),
+        new NamedType(StartpointUpcoming.class), new NamedType(StartpointOldest.class));
+    return mapper;
+  }
+
+  private StartpointObjectMapper() { }
+
+  /** Serializes an {@link Instant} as its epoch-millisecond value. */
+  static class CustomInstantSerializer extends JsonSerializer<Instant> {
+    @Override
+    public void serialize(Instant value, JsonGenerator jsonGenerator, SerializerProvider provider)
+        throws IOException, JsonProcessingException {
+      // writeNumber emits the primitive directly; no boxing or generic writeObject dispatch is needed.
+      jsonGenerator.writeNumber(value.toEpochMilli());
+    }
+  }
+
+  /** Deserializes an epoch-millisecond number into an {@link Instant}. */
+  static class CustomInstantDeserializer extends JsonDeserializer<Instant> {
+    @Override
+    public Instant deserialize(JsonParser jsonParser, DeserializationContext context) throws IOException, JsonProcessingException {
+      return Instant.ofEpochMilli(jsonParser.getLongValue());
+    }
+  }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointSerde.java b/samza-core/src/main/java/org/apache/samza/startpoint/StartpointSerde.java
deleted file mode 100644
index 361aef3..0000000
--- a/samza-core/src/main/java/org/apache/samza/startpoint/StartpointSerde.java
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import com.google.common.collect.ImmutableMap;
-import java.util.Arrays;
-import java.util.LinkedHashMap;
-import org.apache.samza.SamzaException;
-import org.apache.samza.serializers.Serde;
-import org.codehaus.jackson.map.ObjectMapper;
-
-
-class StartpointSerde implements Serde<Startpoint> {
- private static final String STARTPOINT_CLASS = "startpointClass";
- private static final String STARTPOINT_OBJ = "startpointObj";
-
- private final ObjectMapper mapper = new ObjectMapper();
-
- @Override
- public Startpoint fromBytes(byte[] bytes) {
- try {
- LinkedHashMap<String, String> deserialized = mapper.readValue(bytes, LinkedHashMap.class);
- Class<? extends Startpoint> startpointClass =
- (Class<? extends Startpoint>) Class.forName(deserialized.get(STARTPOINT_CLASS));
- return mapper.readValue(deserialized.get(STARTPOINT_OBJ), startpointClass);
- } catch (Exception e) {
- throw new SamzaException(String.format("Exception in de-serializing startpoint bytes: %s",
- Arrays.toString(bytes)), e);
- }
- }
-
- @Override
- public byte[] toBytes(Startpoint startpoint) {
- try {
- ImmutableMap.Builder<String, String> mapBuilder = ImmutableMap.builder();
- mapBuilder.put(STARTPOINT_CLASS, startpoint.getClass().getCanonicalName());
- mapBuilder.put(STARTPOINT_OBJ, mapper.writeValueAsString(startpoint));
- return mapper.writeValueAsBytes(mapBuilder.build());
- } catch (Exception e) {
- throw new SamzaException(String.format("Exception in serializing: %s", startpoint), e);
- }
- }
-}
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java b/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java
index 2fde79a..dee0767 100644
--- a/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/caching/CachingTable.java
@@ -95,9 +95,9 @@
* @param records result map
* @return list of keys missed in the cache
*/
- private List<K> lookupCache(List<K> keys, Map<K, V> records) {
+ private List<K> lookupCache(List<K> keys, Map<K, V> records, Object ... args) {
List<K> missKeys = new ArrayList<>();
- records.putAll(cache.getAll(keys));
+ records.putAll(cache.getAll(keys, args));
keys.forEach(k -> {
if (!records.containsKey(k)) {
missKeys.add(k);
@@ -107,18 +107,18 @@
}
@Override
- public V get(K key) {
+ public V get(K key, Object ... args) {
try {
- return getAsync(key).get();
+ return getAsync(key, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<V> getAsync(K key) {
+ public CompletableFuture<V> getAsync(K key, Object ... args) {
incCounter(metrics.numGets);
- V value = cache.get(key);
+ V value = cache.get(key, args);
if (value != null) {
hitCount.incrementAndGet();
return CompletableFuture.completedFuture(value);
@@ -127,12 +127,12 @@
long startNs = clock.nanoTime();
missCount.incrementAndGet();
- return table.getAsync(key).handle((result, e) -> {
+ return table.getAsync(key, args).handle((result, e) -> {
if (e != null) {
throw new SamzaException("Failed to get the record for " + key, e);
} else {
if (result != null) {
- cache.put(key, result);
+ cache.put(key, result, args);
}
updateTimer(metrics.getNs, clock.nanoTime() - startNs);
return result;
@@ -141,16 +141,16 @@
}
@Override
- public Map<K, V> getAll(List<K> keys) {
+ public Map<K, V> getAll(List<K> keys, Object ... args) {
try {
- return getAllAsync(keys).get();
+ return getAllAsync(keys, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+ public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
incCounter(metrics.numGetAlls);
// Make a copy of entries which might be immutable
Map<K, V> getAllResult = new HashMap<>();
@@ -161,14 +161,14 @@
}
long startNs = clock.nanoTime();
- return table.getAllAsync(missingKeys).handle((records, e) -> {
+ return table.getAllAsync(missingKeys, args).handle((records, e) -> {
if (e != null) {
throw new SamzaException("Failed to get records for " + keys, e);
} else {
if (records != null) {
cache.putAll(records.entrySet().stream()
.map(r -> new Entry<>(r.getKey(), r.getValue()))
- .collect(Collectors.toList()));
+ .collect(Collectors.toList()), args);
getAllResult.putAll(records);
}
updateTimer(metrics.getAllNs, clock.nanoTime() - startNs);
@@ -178,28 +178,28 @@
}
@Override
- public void put(K key, V value) {
+ public void put(K key, V value, Object ... args) {
try {
- putAsync(key, value).get();
+ putAsync(key, value, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Void> putAsync(K key, V value) {
+ public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
incCounter(metrics.numPuts);
Preconditions.checkNotNull(table, "Cannot write to a read-only table: " + table);
long startNs = clock.nanoTime();
- return table.putAsync(key, value).handle((result, e) -> {
+ return table.putAsync(key, value, args).handle((result, e) -> {
if (e != null) {
throw new SamzaException(String.format("Failed to put a record, key=%s, value=%s", key, value), e);
} else if (!isWriteAround) {
if (value == null) {
- cache.delete(key);
+ cache.delete(key, args);
} else {
- cache.put(key, value);
+ cache.put(key, value, args);
}
}
updateTimer(metrics.putNs, clock.nanoTime() - startNs);
@@ -208,24 +208,24 @@
}
@Override
- public void putAll(List<Entry<K, V>> records) {
+ public void putAll(List<Entry<K, V>> records, Object ... args) {
try {
- putAllAsync(records).get();
+ putAllAsync(records, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records) {
+ public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records, Object ... args) {
incCounter(metrics.numPutAlls);
long startNs = clock.nanoTime();
Preconditions.checkNotNull(table, "Cannot write to a read-only table: " + table);
- return table.putAllAsync(records).handle((result, e) -> {
+ return table.putAllAsync(records, args).handle((result, e) -> {
if (e != null) {
throw new SamzaException("Failed to put records " + records, e);
} else if (!isWriteAround) {
- cache.putAll(records);
+ cache.putAll(records, args);
}
updateTimer(metrics.putAllNs, clock.nanoTime() - startNs);
@@ -234,24 +234,24 @@
}
@Override
- public void delete(K key) {
+ public void delete(K key, Object ... args) {
try {
- deleteAsync(key).get();
+ deleteAsync(key, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Void> deleteAsync(K key) {
+ public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
incCounter(metrics.numDeletes);
long startNs = clock.nanoTime();
Preconditions.checkNotNull(table, "Cannot delete from a read-only table: " + table);
- return table.deleteAsync(key).handle((result, e) -> {
+ return table.deleteAsync(key, args).handle((result, e) -> {
if (e != null) {
throw new SamzaException("Failed to delete the record for " + key, e);
} else if (!isWriteAround) {
- cache.delete(key);
+ cache.delete(key, args);
}
updateTimer(metrics.deleteNs, clock.nanoTime() - startNs);
return result;
@@ -259,24 +259,24 @@
}
@Override
- public void deleteAll(List<K> keys) {
+ public void deleteAll(List<K> keys, Object ... args) {
try {
- deleteAllAsync(keys).get();
+ deleteAllAsync(keys, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+ public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
incCounter(metrics.numDeleteAlls);
long startNs = clock.nanoTime();
Preconditions.checkNotNull(table, "Cannot delete from a read-only table: " + table);
- return table.deleteAllAsync(keys).handle((result, e) -> {
+ return table.deleteAllAsync(keys, args).handle((result, e) -> {
if (e != null) {
throw new SamzaException("Failed to delete the record for " + keys, e);
} else if (!isWriteAround) {
- cache.deleteAll(keys);
+ cache.deleteAll(keys, args);
}
updateTimer(metrics.deleteAllNs, clock.nanoTime() - startNs);
return result;
@@ -284,6 +284,32 @@
}
@Override
+  public <T> CompletableFuture<T> readAsync(int opId, Object... args) {
+    // Count the read, then delegate straight to the underlying table.
+    incCounter(metrics.numReads);
+    final long beginNs = clock.nanoTime();
+    return table.readAsync(opId, args).handle((value, error) -> {
+      if (error == null) {
+        // Latency is recorded for successful reads only.
+        updateTimer(metrics.readNs, clock.nanoTime() - beginNs);
+        return (T) value;
+      }
+      throw new SamzaException("Failed to read, opId=" + opId, error);
+    });
+  }
+
+  /**
+   * Counts the write and delegates it to the underlying table, recording latency on success.
+   * A failure is rethrown wrapped in a {@link SamzaException}.
+   */
+  @Override
+  public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+    incCounter(metrics.numWrites);
+    final long beginNs = clock.nanoTime();
+    return table.writeAsync(opId, args).handle((value, error) -> {
+      if (error == null) {
+        // Latency is recorded for successful writes only.
+        updateTimer(metrics.writeNs, clock.nanoTime() - beginNs);
+        return (T) value;
+      }
+      throw new SamzaException("Failed to write, opId=" + opId, error);
+    });
+  }
+
+ @Override
public synchronized void flush() {
incCounter(metrics.numFlushes);
long startNs = clock.nanoTime();
diff --git a/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java b/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java
index d8a5d9c..02083f3 100644
--- a/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/caching/guava/GuavaCacheTable.java
@@ -59,7 +59,7 @@
}
@Override
- public V get(K key) {
+ public V get(K key, Object ... args) {
try {
return getAsync(key).get();
} catch (Exception e) {
@@ -68,7 +68,7 @@
}
@Override
- public CompletableFuture<V> getAsync(K key) {
+ public CompletableFuture<V> getAsync(K key, Object ... args) {
CompletableFuture<V> future = new CompletableFuture<>();
try {
future.complete(cache.getIfPresent(key));
@@ -79,7 +79,7 @@
}
@Override
- public Map<K, V> getAll(List<K> keys) {
+ public Map<K, V> getAll(List<K> keys, Object ... args) {
try {
return getAllAsync(keys).get();
} catch (Exception e) {
@@ -88,7 +88,7 @@
}
@Override
- public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+ public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
CompletableFuture<Map<K, V>> future = new CompletableFuture<>();
try {
future.complete(cache.getAllPresent(keys));
@@ -99,7 +99,7 @@
}
@Override
- public void put(K key, V value) {
+ public void put(K key, V value, Object ... args) {
try {
putAsync(key, value).get();
} catch (Exception e) {
@@ -108,7 +108,7 @@
}
@Override
- public CompletableFuture<Void> putAsync(K key, V value) {
+ public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
if (key == null) {
return deleteAsync(key);
}
@@ -124,7 +124,7 @@
}
@Override
- public void putAll(List<Entry<K, V>> entries) {
+ public void putAll(List<Entry<K, V>> entries, Object ... args) {
try {
putAllAsync(entries).get();
} catch (Exception e) {
@@ -133,7 +133,7 @@
}
@Override
- public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
+ public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
CompletableFuture<Void> future = new CompletableFuture<>();
try {
// Separate out put vs delete records
@@ -157,7 +157,7 @@
}
@Override
- public void delete(K key) {
+ public void delete(K key, Object ... args) {
try {
deleteAsync(key).get();
} catch (Exception e) {
@@ -166,7 +166,7 @@
}
@Override
- public CompletableFuture<Void> deleteAsync(K key) {
+ public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
CompletableFuture<Void> future = new CompletableFuture<>();
try {
cache.invalidate(key);
@@ -178,7 +178,7 @@
}
@Override
- public void deleteAll(List<K> keys) {
+ public void deleteAll(List<K> keys, Object ... args) {
try {
deleteAllAsync(keys).get();
} catch (Exception e) {
@@ -187,7 +187,7 @@
}
@Override
- public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+ public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
CompletableFuture<Void> future = new CompletableFuture<>();
try {
cache.invalidateAll(keys);
diff --git a/samza-core/src/main/java/org/apache/samza/table/ratelimit/AsyncRateLimitedTable.java b/samza-core/src/main/java/org/apache/samza/table/ratelimit/AsyncRateLimitedTable.java
index 69f3dd3..75fed12 100644
--- a/samza-core/src/main/java/org/apache/samza/table/ratelimit/AsyncRateLimitedTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/ratelimit/AsyncRateLimitedTable.java
@@ -19,10 +19,12 @@
package org.apache.samza.table.ratelimit;
import com.google.common.base.Preconditions;
+
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
+
import org.apache.samza.config.MetricsConfig;
import org.apache.samza.context.Context;
import org.apache.samza.storage.kv.Entry;
@@ -30,6 +32,8 @@
import org.apache.samza.table.remote.TableRateLimiter;
import org.apache.samza.table.utils.TableMetricsUtil;
+import static org.apache.samza.table.BaseReadWriteTable.Func0;
+import static org.apache.samza.table.BaseReadWriteTable.Func1;
/**
* A composable read and/or write rate limited asynchronous table implementation
@@ -60,57 +64,59 @@
}
@Override
- public CompletableFuture<V> getAsync(K key) {
- return isReadRateLimited()
- ? CompletableFuture
- .runAsync(() -> readRateLimiter.throttle(key), rateLimitingExecutor)
- .thenCompose((r) -> table.getAsync(key))
- : table.getAsync(key);
+ public CompletableFuture<V> getAsync(K key, Object ... args) {
+ return doRead(
+ () -> readRateLimiter.throttle(key, args),
+ () -> table.getAsync(key, args));
}
@Override
- public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
- return isReadRateLimited()
- ? CompletableFuture
- .runAsync(() -> readRateLimiter.throttle(keys), rateLimitingExecutor)
- .thenCompose((r) -> table.getAllAsync(keys))
- : table.getAllAsync(keys);
+ public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
+ return doRead(
+ () -> readRateLimiter.throttle(keys, args),
+ () -> table.getAllAsync(keys, args));
}
@Override
- public CompletableFuture<Void> putAsync(K key, V value) {
- return isWriteRateLimited()
- ? CompletableFuture
- .runAsync(() -> writeRateLimiter.throttle(key, value), rateLimitingExecutor)
- .thenCompose((r) -> table.putAsync(key, value))
- : table.putAsync(key, value);
+ public <T> CompletableFuture<T> readAsync(int opId, Object ... args) {
+ return doRead(
+ () -> readRateLimiter.throttle(opId, args),
+ () -> table.readAsync(opId, args));
}
@Override
- public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
- return isWriteRateLimited()
- ? CompletableFuture
- .runAsync(() -> writeRateLimiter.throttleRecords(entries), rateLimitingExecutor)
- .thenCompose((r) -> table.putAllAsync(entries))
- : table.putAllAsync(entries);
+ public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
+ return doWrite(
+ () -> writeRateLimiter.throttle(key, value, args),
+ () -> table.putAsync(key, value, args));
}
@Override
- public CompletableFuture<Void> deleteAsync(K key) {
- return isWriteRateLimited()
- ? CompletableFuture
- .runAsync(() -> writeRateLimiter.throttle(key), rateLimitingExecutor)
- .thenCompose((r) -> table.deleteAsync(key))
- : table.deleteAsync(key);
+ public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
+ return doWrite(
+ () -> writeRateLimiter.throttleRecords(entries),
+ () -> table.putAllAsync(entries, args));
}
@Override
- public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
- return isWriteRateLimited()
- ? CompletableFuture
- .runAsync(() -> writeRateLimiter.throttle(keys), rateLimitingExecutor)
- .thenCompose((r) -> table.deleteAllAsync(keys))
- : table.deleteAllAsync(keys);
+ public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
+ return doWrite(
+ () -> writeRateLimiter.throttle(key, args),
+ () -> table.deleteAsync(key, args));
+ }
+
+ @Override
+ public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
+ return doWrite(
+ () -> writeRateLimiter.throttle(keys, args),
+ () -> table.deleteAllAsync(keys, args));
+ }
+
+ @Override
+ public <T> CompletableFuture<T> writeAsync(int opId, Object ... args) {
+ return doWrite(
+ () -> writeRateLimiter.throttle(opId, args),
+ () -> table.writeAsync(opId, args));
}
@Override
@@ -145,4 +151,21 @@
private boolean isWriteRateLimited() {
return writeRateLimiter != null;
}
+
+  /**
+   * Runs a read, throttling first when {@code isReadRateLimited()} reports true.
+   *
+   * @param throttleFunc the throttling step, executed on {@code rateLimitingExecutor}
+   * @param func the table operation producing the result future
+   * @return the future of {@code func}, composed after throttling when rate limited
+   */
+  private <T> CompletableFuture<T> doRead(Func0 throttleFunc, Func1<T> func) {
+    if (!isReadRateLimited()) {
+      return func.apply();
+    }
+    return CompletableFuture
+        .runAsync(() -> throttleFunc.apply(), rateLimitingExecutor)
+        .thenCompose(ignored -> func.apply());
+  }
+
+  /**
+   * Runs a write, throttling first when a write rate limiter is configured.
+   *
+   * @param throttleFunc the throttling step, executed on {@code rateLimitingExecutor}
+   * @param func the table operation producing the result future
+   * @return the future of {@code func}, composed after throttling when rate limited
+   */
+  private <T> CompletableFuture<T> doWrite(Func0 throttleFunc, Func1<T> func) {
+    if (!isWriteRateLimited()) {
+      return func.apply();
+    }
+    return CompletableFuture
+        .runAsync(() -> throttleFunc.apply(), rateLimitingExecutor)
+        .thenCompose(ignored -> func.apply());
+  }
+
}
diff --git a/samza-core/src/main/java/org/apache/samza/table/remote/AsyncRemoteTable.java b/samza-core/src/main/java/org/apache/samza/table/remote/AsyncRemoteTable.java
index d4dbc03..4b1851b 100644
--- a/samza-core/src/main/java/org/apache/samza/table/remote/AsyncRemoteTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/remote/AsyncRemoteTable.java
@@ -19,9 +19,11 @@
package org.apache.samza.table.remote;
import com.google.common.base.Preconditions;
+
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
+
import org.apache.samza.context.Context;
import org.apache.samza.storage.kv.Entry;
import org.apache.samza.table.AsyncReadWriteTable;
@@ -46,45 +48,66 @@
}
@Override
- public CompletableFuture<V> getAsync(K key) {
- return readFn.getAsync(key);
+ public CompletableFuture<V> getAsync(K key, Object ... args) {
+ return args.length > 0
+ ? readFn.getAsync(key, args)
+ : readFn.getAsync(key);
}
@Override
- public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
- return readFn.getAllAsync(keys);
+ public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
+ return args.length > 0
+ ? readFn.getAllAsync(keys, args)
+ : readFn.getAllAsync(keys);
}
@Override
- public CompletableFuture<Void> putAsync(K key, V value) {
+ public <T> CompletableFuture<T> readAsync(int opId, Object... args) {
+ return readFn.readAsync(opId, args);
+ }
+
+ @Override
+ public CompletableFuture<Void> putAsync(K key, V record, Object... args) {
Preconditions.checkNotNull(writeFn, "null writeFn");
- return writeFn.putAsync(key, value);
+ return args.length > 0
+ ? writeFn.putAsync(key, record, args)
+ : writeFn.putAsync(key, record);
}
@Override
- public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
+ public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
Preconditions.checkNotNull(writeFn, "null writeFn");
- return writeFn.putAllAsync(entries);
+ return args.length > 0
+ ? writeFn.putAllAsync(entries, args)
+ : writeFn.putAllAsync(entries);
}
@Override
- public CompletableFuture<Void> deleteAsync(K key) {
+ public CompletableFuture<Void> deleteAsync(K key, Object... args) {
Preconditions.checkNotNull(writeFn, "null writeFn");
- return writeFn.deleteAsync(key);
+ return args.length > 0
+ ? writeFn.deleteAsync(key, args)
+ : writeFn.deleteAsync(key);
}
@Override
- public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+ public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
Preconditions.checkNotNull(writeFn, "null writeFn");
- return writeFn.deleteAllAsync(keys);
+ return args.length > 0
+ ? writeFn.deleteAllAsync(keys, args)
+ : writeFn.deleteAllAsync(keys);
+ }
+
+ @Override
+ public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+ Preconditions.checkNotNull(writeFn, "null writeFn");
+ return writeFn.writeAsync(opId, args);
}
@Override
public void init(Context context) {
- readFn.init(context);
- if (writeFn != null) {
- writeFn.init(context);
- }
+ // Note: Initialization of table functions is done in {@link RemoteTable#init(Context)},
+ // as we need to pass in the reference to the top level table
}
@Override
diff --git a/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTable.java b/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTable.java
index 8301661..23f4c01 100644
--- a/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTable.java
@@ -74,8 +74,8 @@
* @param <K> the type of the key in this table
* @param <V> the type of the value in this table
*/
-public class RemoteTable<K, V> extends BaseReadWriteTable<K, V>
- implements ReadWriteTable<K, V> {
+public final class RemoteTable<K, V> extends BaseReadWriteTable<K, V>
+ implements ReadWriteTable<K, V>, AsyncReadWriteTable<K, V> {
// Read/write functions
protected final TableReadFunction<K, V> readFn;
@@ -144,18 +144,18 @@
}
@Override
- public V get(K key) {
+ public V get(K key, Object ... args) {
try {
- return getAsync(key).get();
+ return getAsync(key, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<V> getAsync(K key) {
+ public CompletableFuture<V> getAsync(K key, Object ... args) {
Preconditions.checkNotNull(key, "null key");
- return instrument(() -> asyncTable.getAsync(key), metrics.numGets, metrics.getNs)
+ return instrument(() -> asyncTable.getAsync(key, args), metrics.numGets, metrics.getNs)
.handle((result, e) -> {
if (e != null) {
throw new SamzaException("Failed to get the records for " + key, e);
@@ -168,21 +168,21 @@
}
@Override
- public Map<K, V> getAll(List<K> keys) {
+ public Map<K, V> getAll(List<K> keys, Object ... args) {
try {
- return getAllAsync(keys).get();
+ return getAllAsync(keys, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+ public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
Preconditions.checkNotNull(keys, "null keys");
if (keys.isEmpty()) {
return CompletableFuture.completedFuture(Collections.EMPTY_MAP);
}
- return instrument(() -> asyncTable.getAllAsync(keys), metrics.numGetAlls, metrics.getAllNs)
+ return instrument(() -> asyncTable.getAllAsync(keys, args), metrics.numGetAlls, metrics.getAllNs)
.handle((result, e) -> {
if (e != null) {
throw new SamzaException("Failed to get the records for " + keys, e);
@@ -193,39 +193,56 @@
}
@Override
- public void put(K key, V value) {
+ public <T> T read(int opId, Object ... args) {
try {
- putAsync(key, value).get();
+ return (T) readAsync(opId, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Void> putAsync(K key, V value) {
+ public <T> CompletableFuture<T> readAsync(int opId, Object ... args) {
+ return (CompletableFuture<T>) instrument(() -> asyncTable.readAsync(opId, args), metrics.numReads, metrics.readNs)
+ .exceptionally(e -> {
+ throw new SamzaException(String.format("Failed to read, opId=%d", opId), e);
+ });
+ }
+
+ @Override
+ public void put(K key, V value, Object ... args) {
+ try {
+ putAsync(key, value, args).get();
+ } catch (Exception e) {
+ throw new SamzaException(e);
+ }
+ }
+
+ @Override
+ public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
Preconditions.checkNotNull(writeFn, "null write function");
Preconditions.checkNotNull(key, "null key");
if (value == null) {
- return deleteAsync(key);
+ return deleteAsync(key, args);
}
- return instrument(() -> asyncTable.putAsync(key, value), metrics.numPuts, metrics.putNs)
+ return instrument(() -> asyncTable.putAsync(key, value, args), metrics.numPuts, metrics.putNs)
.exceptionally(e -> {
throw new SamzaException("Failed to put a record with key=" + key, (Throwable) e);
});
}
@Override
- public void putAll(List<Entry<K, V>> entries) {
+ public void putAll(List<Entry<K, V>> entries, Object ... args) {
try {
- putAllAsync(entries).get();
+ putAllAsync(entries, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records) {
+ public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> records, Object ... args) {
Preconditions.checkNotNull(writeFn, "null write function");
Preconditions.checkNotNull(records, "null records");
@@ -241,12 +258,12 @@
CompletableFuture<Void> deleteFuture = deleteKeys.isEmpty()
? CompletableFuture.completedFuture(null)
- : deleteAllAsync(deleteKeys);
+ : deleteAllAsync(deleteKeys, args);
// Return the combined future
return CompletableFuture.allOf(
deleteFuture,
- instrument(() -> asyncTable.putAllAsync(putRecords), metrics.numPutAlls, metrics.putAllNs))
+ instrument(() -> asyncTable.putAllAsync(putRecords, args), metrics.numPutAlls, metrics.putAllNs))
.exceptionally(e -> {
String strKeys = records.stream().map(r -> r.getKey().toString()).collect(Collectors.joining(","));
throw new SamzaException(String.format("Failed to put records with keys=" + strKeys), e);
@@ -254,44 +271,61 @@
}
@Override
- public void delete(K key) {
+ public void delete(K key, Object ... args) {
try {
- deleteAsync(key).get();
+ deleteAsync(key, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Void> deleteAsync(K key) {
+ public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
Preconditions.checkNotNull(writeFn, "null write function");
Preconditions.checkNotNull(key, "null key");
- return instrument(() -> asyncTable.deleteAsync(key), metrics.numDeletes, metrics.deleteNs)
+ return instrument(() -> asyncTable.deleteAsync(key, args), metrics.numDeletes, metrics.deleteNs)
.exceptionally(e -> {
throw new SamzaException(String.format("Failed to delete the record for " + key), (Throwable) e);
});
}
@Override
- public void deleteAll(List<K> keys) {
+ public void deleteAll(List<K> keys, Object ... args) {
try {
- deleteAllAsync(keys).get();
+ deleteAllAsync(keys, args).get();
} catch (Exception e) {
throw new SamzaException(e);
}
}
@Override
- public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+ public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
Preconditions.checkNotNull(writeFn, "null write function");
Preconditions.checkNotNull(keys, "null keys");
if (keys.isEmpty()) {
return CompletableFuture.completedFuture(null);
}
- return instrument(() -> asyncTable.deleteAllAsync(keys), metrics.numDeleteAlls, metrics.deleteAllNs)
+ return instrument(() -> asyncTable.deleteAllAsync(keys, args), metrics.numDeleteAlls, metrics.deleteAllNs)
.exceptionally(e -> {
- throw new SamzaException(String.format("Failed to delete records for " + keys), (Throwable) e);
+ throw new SamzaException(String.format("Failed to delete records for " + keys), e);
+ });
+ }
+
+ @Override
+ public <T> T write(int opId, Object ... args) {
+ try {
+ return (T) writeAsync(opId, args).get();
+ } catch (Exception e) {
+ throw new SamzaException(e);
+ }
+ }
+
+ @Override
+ public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+ return (CompletableFuture<T>) instrument(() -> asyncTable.writeAsync(opId, args), metrics.numWrites, metrics.writeNs)
+ .exceptionally(e -> {
+ throw new SamzaException(String.format("Failed to write, opId=%d", opId), e);
});
}
@@ -299,6 +333,10 @@
public void init(Context context) {
super.init(context);
asyncTable.init(context);
+ readFn.init(context, this);
+ if (writeFn != null) {
+ writeFn.init(context, this);
+ }
}
@Override
@@ -320,6 +358,14 @@
asyncTable.close();
}
+ public TableReadFunction<K, V> getReadFunction() {
+ return readFn;
+ }
+
+ public TableWriteFunction<K, V> getWriteFunction() {
+ return writeFn;
+ }
+
protected <T> CompletableFuture<T> instrument(Func1<T> func, Counter counter, Timer timer) {
incCounter(counter);
final long startNs = clock.nanoTime();
diff --git a/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java b/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java
index 1716244..aca0a4b 100644
--- a/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java
+++ b/samza-core/src/main/java/org/apache/samza/table/remote/RemoteTableProvider.java
@@ -20,6 +20,7 @@
package org.apache.samza.table.remote;
import com.google.common.base.Preconditions;
+
import org.apache.samza.config.JavaTableConfig;
import org.apache.samza.table.ReadWriteTable;
import org.apache.samza.table.descriptors.RemoteTableDescriptor;
@@ -64,7 +65,7 @@
JavaTableConfig tableConfig = new JavaTableConfig(context.getJobContext().getConfig());
// Read part
- TableReadFunction readFn = getReadFn(tableConfig);
+ TableReadFunction readFn = deserializeObject(tableConfig, RemoteTableDescriptor.READ_FN);
RateLimiter rateLimiter = deserializeObject(tableConfig, RemoteTableDescriptor.RATE_LIMITER);
if (rateLimiter != null) {
rateLimiter.init(this.context);
@@ -77,9 +78,9 @@
TableRetryPolicy readRetryPolicy = deserializeObject(tableConfig, RemoteTableDescriptor.READ_RETRY_POLICY);
// Write part
- TableWriteFunction writeFn = getWriteFn(tableConfig);
TableRateLimiter writeRateLimiter = null;
TableRetryPolicy writeRetryPolicy = null;
+ TableWriteFunction writeFn = deserializeObject(tableConfig, RemoteTableDescriptor.WRITE_FN);
if (writeFn != null) {
TableRateLimiter.CreditFunction<?, ?> writeCreditFn = deserializeObject(tableConfig, RemoteTableDescriptor.WRITE_CREDIT_FN);
writeRateLimiter = rateLimiter != null && rateLimiter.getSupportedTags().contains(RemoteTableDescriptor.RL_WRITE_TAG)
@@ -130,7 +131,9 @@
super.close();
tables.forEach(t -> t.close());
rateLimitingExecutors.values().forEach(e -> e.shutdown());
+ rateLimitingExecutors.clear();
callbackExecutors.values().forEach(e -> e.shutdown());
+ callbackExecutors.clear();
}
private <T> T deserializeObject(JavaTableConfig tableConfig, String key) {
@@ -141,22 +144,6 @@
return SerdeUtils.deserialize(key, entry);
}
- private TableReadFunction<?, ?> getReadFn(JavaTableConfig tableConfig) {
- TableReadFunction<?, ?> readFn = deserializeObject(tableConfig, RemoteTableDescriptor.READ_FN);
- if (readFn != null) {
- readFn.init(this.context);
- }
- return readFn;
- }
-
- private TableWriteFunction<?, ?> getWriteFn(JavaTableConfig tableConfig) {
- TableWriteFunction<?, ?> writeFn = deserializeObject(tableConfig, RemoteTableDescriptor.WRITE_FN);
- if (writeFn != null) {
- writeFn.init(this.context);
- }
- return writeFn;
- }
-
private ScheduledExecutorService createRetryExecutor() {
return Executors.newSingleThreadScheduledExecutor(runnable -> {
Thread thread = new Thread(runnable);
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java b/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java
index ba39517..589fb14 100644
--- a/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java
+++ b/samza-core/src/main/java/org/apache/samza/table/retry/AsyncRetriableTable.java
@@ -93,33 +93,43 @@
}
@Override
- public CompletableFuture<V> getAsync(K key) {
- return doRead(() -> table.getAsync(key));
+ public CompletableFuture<V> getAsync(K key, Object... args) {
+ return doRead(() -> table.getAsync(key, args));
}
@Override
- public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
- return doRead(() -> table.getAllAsync(keys));
+ public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
+ return doRead(() -> table.getAllAsync(keys, args));
}
@Override
- public CompletableFuture<Void> putAsync(K key, V value) {
- return doWrite(() -> table.putAsync(key, value));
+ public <T> CompletableFuture<T> readAsync(int opId, Object... args) {
+ return doRead(() -> table.readAsync(opId, args));
}
@Override
- public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
- return doWrite(() -> table.putAllAsync(entries));
+ public CompletableFuture<Void> putAsync(K key, V value, Object... args) {
+ return doWrite(() -> table.putAsync(key, value, args));
}
@Override
- public CompletableFuture<Void> deleteAsync(K key) {
- return doWrite(() -> table.deleteAsync(key));
+ public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
+ return doWrite(() -> table.putAllAsync(entries, args));
}
@Override
- public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
- return doWrite(() -> table.deleteAllAsync(keys));
+ public CompletableFuture<Void> deleteAsync(K key, Object... args) {
+ return doWrite(() -> table.deleteAsync(key, args));
+ }
+
+ @Override
+ public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
+ return doWrite(() -> table.deleteAllAsync(keys, args));
+ }
+
+ @Override
+ public <T> CompletableFuture<T> writeAsync(int opId, Object... args) {
+ return doWrite(() -> table.writeAsync(opId, args));
}
@Override
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/RetriableReadFunction.java b/samza-core/src/main/java/org/apache/samza/table/retry/RetriableReadFunction.java
deleted file mode 100644
index 1adddc0..0000000
--- a/samza-core/src/main/java/org/apache/samza/table/retry/RetriableReadFunction.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.table.retry;
-
-import java.util.Collection;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.function.Predicate;
-
-import org.apache.samza.SamzaException;
-import org.apache.samza.table.remote.TableReadFunction;
-import org.apache.samza.table.utils.TableMetricsUtil;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
-import net.jodah.failsafe.RetryPolicy;
-
-import static org.apache.samza.table.retry.FailsafeAdapter.failsafe;
-
-
-/**
- * Wrapper for a {@link TableReadFunction} instance to add common retry
- * support with a {@link TableRetryPolicy}. This wrapper is created by
- * {@link org.apache.samza.table.remote.RemoteTableProvider} when a retry
- * policy is specified together with the {@link TableReadFunction}.
- *
- * Actual retry mechanism is provided by the failsafe library. Retry is
- * attempted in an async way with a {@link ScheduledExecutorService}.
- *
- * @param <K> the type of the key in this table
- * @param <V> the type of the value in this table
- */
-public class RetriableReadFunction<K, V> implements TableReadFunction<K, V> {
- private final RetryPolicy retryPolicy;
- private final TableReadFunction<K, V> readFn;
- private final ScheduledExecutorService retryExecutor;
-
- @VisibleForTesting
- RetryMetrics retryMetrics;
-
- public RetriableReadFunction(TableRetryPolicy policy, TableReadFunction<K, V> readFn,
- ScheduledExecutorService retryExecutor) {
- Preconditions.checkNotNull(policy);
- Preconditions.checkNotNull(readFn);
- Preconditions.checkNotNull(retryExecutor);
-
- this.readFn = readFn;
- this.retryExecutor = retryExecutor;
- Predicate<Throwable> retryPredicate = policy.getRetryPredicate();
- policy.withRetryPredicate((ex) -> readFn.isRetriable(ex) || retryPredicate.test(ex));
- this.retryPolicy = FailsafeAdapter.valueOf(policy);
- }
-
- @Override
- public CompletableFuture<V> getAsync(K key) {
- return failsafe(retryPolicy, retryMetrics, retryExecutor)
- .future(() -> readFn.getAsync(key))
- .exceptionally(e -> {
- throw new SamzaException("Failed to get the record for " + key + " after retries.", e);
- });
- }
-
- @Override
- public CompletableFuture<Map<K, V>> getAllAsync(Collection<K> keys) {
- return failsafe(retryPolicy, retryMetrics, retryExecutor)
- .future(() -> readFn.getAllAsync(keys))
- .exceptionally(e -> {
- throw new SamzaException("Failed to get the records for " + keys + " after retries.", e);
- });
- }
-
- @Override
- public boolean isRetriable(Throwable exception) {
- return readFn.isRetriable(exception);
- }
-
- /**
- * Initialize retry-related metrics
- * @param metricsUtil metrics util
- */
- public void setMetrics(TableMetricsUtil metricsUtil) {
- this.retryMetrics = new RetryMetrics("reader", metricsUtil);
- }
-}
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/RetriableWriteFunction.java b/samza-core/src/main/java/org/apache/samza/table/retry/RetriableWriteFunction.java
deleted file mode 100644
index 2f3f062..0000000
--- a/samza-core/src/main/java/org/apache/samza/table/retry/RetriableWriteFunction.java
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.table.retry;
-
-import java.util.Collection;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.function.Predicate;
-
-import org.apache.samza.SamzaException;
-import org.apache.samza.storage.kv.Entry;
-import org.apache.samza.table.remote.TableWriteFunction;
-import org.apache.samza.table.utils.TableMetricsUtil;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.base.Preconditions;
-import net.jodah.failsafe.RetryPolicy;
-
-import static org.apache.samza.table.retry.FailsafeAdapter.failsafe;
-
-
-/**
- * Wrapper for a {@link TableWriteFunction} instance to add common retry
- * support with a {@link TableRetryPolicy}. This wrapper is created by
- * {@link org.apache.samza.table.remote.RemoteTableProvider} when a retry
- * policy is specified together with the {@link TableWriteFunction}.
- *
- * Actual retry mechanism is provided by the failsafe library. Retry is
- * attempted in an async way with a {@link ScheduledExecutorService}.
- *
- * @param <K> the type of the key in this table
- * @param <V> the type of the value in this table
- */
-public class RetriableWriteFunction<K, V> implements TableWriteFunction<K, V> {
- private final RetryPolicy retryPolicy;
- private final TableWriteFunction<K, V> writeFn;
- private final ScheduledExecutorService retryExecutor;
-
- @VisibleForTesting
- RetryMetrics retryMetrics;
-
- public RetriableWriteFunction(TableRetryPolicy policy, TableWriteFunction<K, V> writeFn,
- ScheduledExecutorService retryExecutor) {
- Preconditions.checkNotNull(policy);
- Preconditions.checkNotNull(writeFn);
- Preconditions.checkNotNull(retryExecutor);
-
- this.writeFn = writeFn;
- this.retryExecutor = retryExecutor;
- Predicate<Throwable> retryPredicate = policy.getRetryPredicate();
- policy.withRetryPredicate((ex) -> writeFn.isRetriable(ex) || retryPredicate.test(ex));
- this.retryPolicy = FailsafeAdapter.valueOf(policy);
- }
-
- @Override
- public CompletableFuture<Void> putAsync(K key, V record) {
- return failsafe(retryPolicy, retryMetrics, retryExecutor)
- .future(() -> writeFn.putAsync(key, record))
- .exceptionally(e -> {
- throw new SamzaException("Failed to get the record for " + key + " after retries.", e);
- });
- }
-
- @Override
- public CompletableFuture<Void> putAllAsync(Collection<Entry<K, V>> records) {
- return failsafe(retryPolicy, retryMetrics, retryExecutor)
- .future(() -> writeFn.putAllAsync(records))
- .exceptionally(e -> {
- throw new SamzaException("Failed to put records after retries.", e);
- });
- }
-
- @Override
- public CompletableFuture<Void> deleteAsync(K key) {
- return failsafe(retryPolicy, retryMetrics, retryExecutor)
- .future(() -> writeFn.deleteAsync(key))
- .exceptionally(e -> {
- throw new SamzaException("Failed to delete the record for " + key + " after retries.", e);
- });
- }
-
- @Override
- public CompletableFuture<Void> deleteAllAsync(Collection<K> keys) {
- return failsafe(retryPolicy, retryMetrics, retryExecutor)
- .future(() -> writeFn.deleteAllAsync(keys))
- .exceptionally(e -> {
- throw new SamzaException("Failed to delete the records for " + keys + " after retries.", e);
- });
- }
-
- @Override
- public boolean isRetriable(Throwable exception) {
- return writeFn.isRetriable(exception);
- }
-
- /**
- * Initialize retry-related metrics.
- * @param metricsUtil metrics util
- */
- public void setMetrics(TableMetricsUtil metricsUtil) {
- this.retryMetrics = new RetryMetrics("writer", metricsUtil);
- }
-}
diff --git a/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java b/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java
index fbc511c..f717462 100644
--- a/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java
+++ b/samza-core/src/main/java/org/apache/samza/table/retry/RetryMetrics.java
@@ -25,8 +25,7 @@
/**
- * Wrapper of retry-related metrics common to both {@link RetriableReadFunction} and
- * {@link RetriableWriteFunction}.
+ * Retry-related metrics
*/
class RetryMetrics {
/**
diff --git a/samza-core/src/main/java/org/apache/samza/table/utils/TableMetrics.java b/samza-core/src/main/java/org/apache/samza/table/utils/TableMetrics.java
index df6833e..5906bed 100644
--- a/samza-core/src/main/java/org/apache/samza/table/utils/TableMetrics.java
+++ b/samza-core/src/main/java/org/apache/samza/table/utils/TableMetrics.java
@@ -32,8 +32,10 @@
// Read metrics
public final Timer getNs;
public final Timer getAllNs;
+ public final Timer readNs;
public final Counter numGets;
public final Counter numGetAlls;
+ public final Counter numReads;
public final Counter numMissedLookups;
// Write metrics
public final Counter numPuts;
@@ -44,6 +46,8 @@
public final Timer deleteNs;
public final Counter numDeleteAlls;
public final Timer deleteAllNs;
+ public final Counter numWrites;
+ public final Timer writeNs;
public final Counter numFlushes;
public final Timer flushNs;
@@ -61,6 +65,8 @@
getNs = tableMetricsUtil.newTimer("get-ns");
numGetAlls = tableMetricsUtil.newCounter("num-getAlls");
getAllNs = tableMetricsUtil.newTimer("getAll-ns");
+ numReads = tableMetricsUtil.newCounter("num-reads");
+ readNs = tableMetricsUtil.newTimer("read-ns");
numMissedLookups = tableMetricsUtil.newCounter("num-missed-lookups");
// Write metrics
numPuts = tableMetricsUtil.newCounter("num-puts");
@@ -71,6 +77,8 @@
deleteNs = tableMetricsUtil.newTimer("delete-ns");
numDeleteAlls = tableMetricsUtil.newCounter("num-deleteAlls");
deleteAllNs = tableMetricsUtil.newTimer("deleteAll-ns");
+ numWrites = tableMetricsUtil.newCounter("num-writes");
+ writeNs = tableMetricsUtil.newTimer("write-ns");
numFlushes = tableMetricsUtil.newCounter("num-flushes");
flushNs = tableMetricsUtil.newTimer("flush-ns");
}
diff --git a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
index 09f778a..90f9668 100644
--- a/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
+++ b/samza-core/src/main/scala/org/apache/samza/checkpoint/OffsetManager.scala
@@ -356,20 +356,13 @@
// delete corresponding startpoints after checkpoint is supposed to be committed
if (startpointManager != null && startpoints.contains(taskName)) {
- val sspStartpoints = checkpoint.getOffsets.keySet.asScala
- .intersect(startpoints.getOrElse(taskName, Map.empty[SystemStreamPartition, Startpoint]).keySet)
-
- // delete startpoints for this task and the intersection of SSPs between checkpoint and startpoint.
- sspStartpoints.foreach(ssp => {
- startpointManager.deleteStartpoint(ssp, taskName)
- info("Deleted startpoint for SSP: %s and task: %s" format (ssp, taskName))
- })
+ info("%d startpoint(s) for taskName: %s have been committed to the checkpoint." format (startpoints.get(taskName).size, taskName.getTaskName))
+ startpointManager.removeFanOutForTask(taskName)
startpoints -= taskName
if (startpoints.isEmpty) {
- // Stop startpoint manager after last startpoint is deleted
- startpointManager.stop()
- info("No more startpoints left to consume. Stopped the startpoint manager.")
+ info("All outstanding startpoints have been committed to the checkpoint.")
+ startpointManager.stop
}
}
}
@@ -384,11 +377,11 @@
}
if (startpointManager != null) {
- debug("Ensuring startpoint manager has shut down.")
+ debug("Shutting down startpoint manager.")
startpointManager.stop
} else {
- debug("Skipping startpoint manager shutdown because no checkpoint manager is defined.")
+ debug("Skipping startpoint manager shutdown because no startpoint manager is defined.")
}
}
@@ -517,23 +510,26 @@
if (startpointManager != null) {
info("Starting startpoint manager.")
startpointManager.start
- val taskNameToSSPs: Map[TaskName, Set[SystemStreamPartition]] = systemStreamPartitions
- taskNameToSSPs.foreach {
+ systemStreamPartitions.foreach {
case (taskName, systemStreamPartitionSet) => {
- val sspToStartpoint = systemStreamPartitionSet
- .map(ssp => (ssp, startpointManager.readStartpoint(ssp, taskName)))
- .filter(_._2 != null)
- .toMap
-
- if (!sspToStartpoint.isEmpty) {
- startpoints += taskName -> sspToStartpoint
+ Option(startpointManager.getFanOutForTask(taskName)) match {
+ case Some(fanOut) => {
+ val filteredFanOut = fanOut.asScala
+ .filter(f => systemStreamPartitionSet.contains(f._1))
+ .toMap
+ if (!filteredFanOut.isEmpty) {
+ startpoints += taskName -> filteredFanOut
+ info("Startpoint fan out for task: %s - %s" format (taskName, filteredFanOut))
+ }
+ }
+ case None => debug("No startpoints fanned out on taskName: %s" format taskName.getTaskName)
}
}
}
if (startpoints.isEmpty) {
- info("No startpoints to consume. Stopping startpoint manager.")
+ info("No startpoints to consume.")
startpointManager.stop
} else {
startpoints
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/StartpointManagerTestUtil.java b/samza-core/src/test/java/org/apache/samza/startpoint/StartpointManagerTestUtil.java
index f8eb855..1071904 100644
--- a/samza-core/src/test/java/org/apache/samza/startpoint/StartpointManagerTestUtil.java
+++ b/samza-core/src/test/java/org/apache/samza/startpoint/StartpointManagerTestUtil.java
@@ -18,17 +18,36 @@
*/
package org.apache.samza.startpoint;
+import org.apache.samza.config.Config;
import org.apache.samza.config.MapConfig;
import org.apache.samza.metadatastore.InMemoryMetadataStoreFactory;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.apache.samza.metadatastore.MetadataStoreFactory;
+import org.apache.samza.metrics.MetricsRegistry;
import org.apache.samza.util.NoOpMetricsRegistry;
public class StartpointManagerTestUtil {
+ private final MetadataStore metadataStore;
+ private final StartpointManager startpointManager;
- private StartpointManagerTestUtil() {
+ public StartpointManagerTestUtil() {
+ this(new InMemoryMetadataStoreFactory(), new MapConfig(), new NoOpMetricsRegistry());
}
- public static StartpointManager getStartpointManager() {
- return new StartpointManager(new InMemoryMetadataStoreFactory(), new MapConfig(), new NoOpMetricsRegistry());
+ public StartpointManagerTestUtil(MetadataStoreFactory metadataStoreFactory, Config config, MetricsRegistry metricsRegistry) {
+ this.metadataStore = metadataStoreFactory.getMetadataStore(StartpointManager.NAMESPACE, config, metricsRegistry);
+ this.metadataStore.init();
+ this.startpointManager = new StartpointManager(metadataStore);
+ this.startpointManager.start();
+ }
+
+ public StartpointManager getStartpointManager() {
+ return startpointManager;
+ }
+
+ public void stop() {
+ startpointManager.stop();
+ metadataStore.close();
}
}
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/StartpointMock.java b/samza-core/src/test/java/org/apache/samza/startpoint/StartpointMock.java
new file mode 100644
index 0000000..7c8b50d
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/startpoint/StartpointMock.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.startpoint;
+
+import java.time.Instant;
+
+
+// Can't mock concrete Startpoint classes because they are final.
+public class StartpointMock extends Startpoint {
+ public StartpointMock() {
+ super(Instant.now().toEpochMilli());
+ }
+
+ public StartpointMock(long creationTimestamp) {
+ super(creationTimestamp);
+ }
+
+ @Override
+ public <IN, OUT> OUT apply(IN input, StartpointVisitor<IN, OUT> startpointVisitor) {
+ // mocked
+ return startpointVisitor.visit(input, new StartpointSpecific("Mocked"));
+ }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointKey.java b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointKey.java
deleted file mode 100644
index 72f922f..0000000
--- a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointKey.java
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import java.io.IOException;
-import java.util.LinkedHashMap;
-import org.apache.samza.Partition;
-import org.apache.samza.container.TaskName;
-import org.apache.samza.serializers.JsonSerdeV2;
-import org.apache.samza.system.SystemStreamPartition;
-import org.codehaus.jackson.map.ObjectMapper;
-import org.junit.Assert;
-import org.junit.Test;
-
-
-public class TestStartpointKey {
-
- @Test
- public void testStartpointKey() {
- SystemStreamPartition ssp1 = new SystemStreamPartition("system", "stream", new Partition(2));
- SystemStreamPartition ssp2 = new SystemStreamPartition("system", "stream", new Partition(3));
-
- StartpointKey startpointKey1 = new StartpointKey(ssp1);
- StartpointKey startpointKey2 = new StartpointKey(ssp1);
- StartpointKey startpointKeyWithDifferentSSP = new StartpointKey(ssp2);
- StartpointKey startpointKeyWithTask1 = new StartpointKey(ssp1, new TaskName("t1"));
- StartpointKey startpointKeyWithTask2 = new StartpointKey(ssp1, new TaskName("t1"));
- StartpointKey startpointKeyWithDifferentTask = new StartpointKey(ssp1, new TaskName("t2"));
-
- Assert.assertEquals(startpointKey1, startpointKey2);
- Assert.assertEquals(new String(new JsonSerdeV2<>().toBytes(startpointKey1)),
- new String(new JsonSerdeV2<>().toBytes(startpointKey2)));
- Assert.assertEquals(startpointKeyWithTask1, startpointKeyWithTask2);
- Assert.assertEquals(new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask1)),
- new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask2)));
-
- Assert.assertNotEquals(startpointKey1, startpointKeyWithTask1);
- Assert.assertNotEquals(new String(new JsonSerdeV2<>().toBytes(startpointKey1)),
- new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask1)));
-
- Assert.assertNotEquals(startpointKey1, startpointKeyWithDifferentSSP);
- Assert.assertNotEquals(new String(new JsonSerdeV2<>().toBytes(startpointKey1)),
- new String(new JsonSerdeV2<>().toBytes(startpointKeyWithDifferentSSP)));
- Assert.assertNotEquals(startpointKeyWithTask1, startpointKeyWithDifferentTask);
- Assert.assertNotEquals(new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask1)),
- new String(new JsonSerdeV2<>().toBytes(startpointKeyWithDifferentTask)));
-
- Assert.assertNotEquals(startpointKeyWithTask1, startpointKeyWithDifferentTask);
- Assert.assertNotEquals(new String(new JsonSerdeV2<>().toBytes(startpointKeyWithTask1)),
- new String(new JsonSerdeV2<>().toBytes(startpointKeyWithDifferentTask)));
- }
-
- @Test
- public void testStartpointKeyFormat() throws IOException {
- SystemStreamPartition ssp = new SystemStreamPartition("system1", "stream1", new Partition(2));
- StartpointKey startpointKeyWithTask = new StartpointKey(ssp, new TaskName("t1"));
- ObjectMapper objectMapper = new ObjectMapper();
- byte[] jsonBytes = new JsonSerdeV2<>().toBytes(startpointKeyWithTask);
- LinkedHashMap<String, String> deserialized = objectMapper.readValue(jsonBytes, LinkedHashMap.class);
-
- Assert.assertEquals(4, deserialized.size());
- Assert.assertEquals("system1", deserialized.get("system"));
- Assert.assertEquals("stream1", deserialized.get("stream"));
- Assert.assertEquals(2, deserialized.get("partition"));
- Assert.assertEquals("t1", deserialized.get("taskName"));
- }
-}
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java
index b4c33c3..d615c54 100644
--- a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java
+++ b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointManager.java
@@ -21,6 +21,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
+import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.List;
@@ -34,14 +35,13 @@
import org.apache.samza.coordinator.metadatastore.CoordinatorStreamStore;
import org.apache.samza.coordinator.metadatastore.CoordinatorStreamStoreTestUtil;
import org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
-import org.apache.samza.job.model.ContainerModel;
-import org.apache.samza.job.model.JobModel;
-import org.apache.samza.job.model.TaskModel;
import org.apache.samza.system.SystemStreamPartition;
+import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+
public class TestStartpointManager {
private static final Config CONFIG = new MapConfig(ImmutableMap.of("job.name", "test-job", "job.coordinator.system", "test-kafka"));
@@ -53,14 +53,23 @@
public void setup() {
CoordinatorStreamStoreTestUtil coordinatorStreamStoreTestUtil = new CoordinatorStreamStoreTestUtil(CONFIG);
coordinatorStreamStore = coordinatorStreamStoreTestUtil.getCoordinatorStreamStore();
+ coordinatorStreamStore.init();
startpointManager = new StartpointManager(coordinatorStreamStore);
+ startpointManager.start();
+ }
+
+ @After
+ public void teardown() {
+ startpointManager.stop();
+ coordinatorStreamStore.close();
}
@Test
public void testDefaultMetadataStore() {
StartpointManager startpointManager = new StartpointManager(coordinatorStreamStore);
Assert.assertNotNull(startpointManager);
- Assert.assertEquals(NamespaceAwareCoordinatorStreamStore.class, startpointManager.getMetadataStore().getClass());
+ Assert.assertEquals(NamespaceAwareCoordinatorStreamStore.class, startpointManager.getReadWriteStore().getClass());
+ Assert.assertEquals(NamespaceAwareCoordinatorStreamStore.class, startpointManager.getFanOutStore().getClass());
}
@Test
@@ -68,19 +77,18 @@
SystemStreamPartition ssp = new SystemStreamPartition("mockSystem", "mockStream", new Partition(2));
TaskName taskName = new TaskName("MockTask");
- startpointManager.start();
long staleTimestamp = Instant.now().toEpochMilli() - StartpointManager.DEFAULT_EXPIRATION_DURATION.toMillis() - 2;
StartpointTimestamp startpoint = new StartpointTimestamp(staleTimestamp, staleTimestamp);
startpointManager.writeStartpoint(ssp, startpoint);
- Assert.assertNull(startpointManager.readStartpoint(ssp));
+ Assert.assertFalse(startpointManager.readStartpoint(ssp).isPresent());
startpointManager.writeStartpoint(ssp, taskName, startpoint);
- Assert.assertNull(startpointManager.readStartpoint(ssp, taskName));
+ Assert.assertFalse(startpointManager.readStartpoint(ssp, taskName).isPresent());
}
@Test
- public void testNoLongerUsableAfterStop() {
+ public void testNoLongerUsableAfterStop() throws IOException {
StartpointManager startpointManager = new StartpointManager(coordinatorStreamStore);
startpointManager.start();
SystemStreamPartition ssp =
@@ -121,14 +129,23 @@
} catch (IllegalStateException ex) { }
try {
- startpointManager.fanOutStartpointsToTasks(new JobModel(new MapConfig(), new HashMap<>()));
+ startpointManager.fanOut(new HashMap<>());
+ Assert.fail("Expected precondition exception.");
+ } catch (IllegalStateException ex) { }
+
+ try {
+ startpointManager.getFanOutForTask(new TaskName("t0"));
+ Assert.fail("Expected precondition exception.");
+ } catch (IllegalStateException ex) { }
+
+ try {
+ startpointManager.removeFanOutForTask(new TaskName("t0"));
Assert.fail("Expected precondition exception.");
} catch (IllegalStateException ex) { }
}
@Test
public void testBasics() {
- startpointManager.start();
SystemStreamPartition ssp =
new SystemStreamPartition("mockSystem", "mockStream", new Partition(2));
TaskName taskName = new TaskName("MockTask");
@@ -144,18 +161,18 @@
Assert.assertNotNull(startpoint4.getCreationTimestamp());
// Test reads on non-existent keys
- Assert.assertNull(startpointManager.readStartpoint(ssp));
- Assert.assertNull(startpointManager.readStartpoint(ssp, taskName));
+ Assert.assertFalse(startpointManager.readStartpoint(ssp).isPresent());
+ Assert.assertFalse(startpointManager.readStartpoint(ssp, taskName).isPresent());
// Test writes
Startpoint startpointFromStore;
startpointManager.writeStartpoint(ssp, startpoint1);
startpointManager.writeStartpoint(ssp, taskName, startpoint2);
- startpointFromStore = startpointManager.readStartpoint(ssp);
+ startpointFromStore = startpointManager.readStartpoint(ssp).get();
Assert.assertEquals(StartpointTimestamp.class, startpointFromStore.getClass());
Assert.assertEquals(startpoint1.getTimestampOffset(), ((StartpointTimestamp) startpointFromStore).getTimestampOffset());
Assert.assertTrue(startpointFromStore.getCreationTimestamp() <= Instant.now().toEpochMilli());
- startpointFromStore = startpointManager.readStartpoint(ssp, taskName);
+ startpointFromStore = startpointManager.readStartpoint(ssp, taskName).get();
Assert.assertEquals(StartpointTimestamp.class, startpointFromStore.getClass());
Assert.assertEquals(startpoint2.getTimestampOffset(), ((StartpointTimestamp) startpointFromStore).getTimestampOffset());
Assert.assertTrue(startpointFromStore.getCreationTimestamp() <= Instant.now().toEpochMilli());
@@ -163,48 +180,40 @@
// Test overwrites
startpointManager.writeStartpoint(ssp, startpoint3);
startpointManager.writeStartpoint(ssp, taskName, startpoint4);
- startpointFromStore = startpointManager.readStartpoint(ssp);
+ startpointFromStore = startpointManager.readStartpoint(ssp).get();
Assert.assertEquals(StartpointSpecific.class, startpointFromStore.getClass());
Assert.assertEquals(startpoint3.getSpecificOffset(), ((StartpointSpecific) startpointFromStore).getSpecificOffset());
Assert.assertTrue(startpointFromStore.getCreationTimestamp() <= Instant.now().toEpochMilli());
- startpointFromStore = startpointManager.readStartpoint(ssp, taskName);
+ startpointFromStore = startpointManager.readStartpoint(ssp, taskName).get();
Assert.assertEquals(StartpointSpecific.class, startpointFromStore.getClass());
Assert.assertEquals(startpoint4.getSpecificOffset(), ((StartpointSpecific) startpointFromStore).getSpecificOffset());
Assert.assertTrue(startpointFromStore.getCreationTimestamp() <= Instant.now().toEpochMilli());
// Test deletes on SSP keys does not affect SSP+TaskName keys
startpointManager.deleteStartpoint(ssp);
- Assert.assertNull(startpointManager.readStartpoint(ssp));
- Assert.assertNotNull(startpointManager.readStartpoint(ssp, taskName));
+ Assert.assertFalse(startpointManager.readStartpoint(ssp).isPresent());
+ Assert.assertTrue(startpointManager.readStartpoint(ssp, taskName).isPresent());
// Test deletes on SSP+TaskName keys does not affect SSP keys
startpointManager.writeStartpoint(ssp, startpoint3);
startpointManager.deleteStartpoint(ssp, taskName);
- Assert.assertNull(startpointManager.readStartpoint(ssp, taskName));
- Assert.assertNotNull(startpointManager.readStartpoint(ssp));
-
- startpointManager.stop();
+ Assert.assertFalse(startpointManager.readStartpoint(ssp, taskName).isPresent());
+ Assert.assertTrue(startpointManager.readStartpoint(ssp).isPresent());
}
@Test
- public void testGroupStartpointsPerTask() {
- MapConfig config = new MapConfig();
- startpointManager.start();
- SystemStreamPartition sspBroadcast =
- new SystemStreamPartition("mockSystem1", "mockStream1", new Partition(2));
- SystemStreamPartition sspBroadcast2 =
- new SystemStreamPartition("mockSystem3", "mockStream3", new Partition(4));
- SystemStreamPartition sspSingle =
- new SystemStreamPartition("mockSystem2", "mockStream2", new Partition(3));
+ public void testFanOutBasic() throws IOException {
+ SystemStreamPartition sspBroadcast = new SystemStreamPartition("mockSystem1", "mockStream1", new Partition(2));
+ SystemStreamPartition sspSingle = new SystemStreamPartition("mockSystem2", "mockStream2", new Partition(3));
+
+ TaskName taskWithNonBroadcast = new TaskName("t1");
List<TaskName> tasks =
- ImmutableList.of(new TaskName("t0"), new TaskName("t1"), new TaskName("t2"), new TaskName("t3"), new TaskName("t4"), new TaskName("t5"));
+ ImmutableList.of(new TaskName("t0"), taskWithNonBroadcast, new TaskName("t2"), new TaskName("t3"), new TaskName("t4"), new TaskName("t5"));
- Map<TaskName, TaskModel> taskModelMap = tasks.stream()
- .map(task -> new TaskModel(task, task.getTaskName().equals("t1") ? ImmutableSet.of(sspBroadcast, sspBroadcast2, sspSingle) : ImmutableSet.of(sspBroadcast, sspBroadcast2), new Partition(1)))
- .collect(Collectors.toMap(taskModel -> taskModel.getTaskName(), taskModel -> taskModel));
- ContainerModel containerModel = new ContainerModel("container 0", taskModelMap);
- JobModel jobModel = new JobModel(config, ImmutableMap.of(containerModel.getId(), containerModel));
+ Map<TaskName, Set<SystemStreamPartition>> taskToSSPs = tasks.stream()
+ .collect(Collectors
+ .toMap(task -> task, task -> task.equals(taskWithNonBroadcast) ? ImmutableSet.of(sspBroadcast, sspSingle) : ImmutableSet.of(sspBroadcast)));
StartpointSpecific startpoint42 = new StartpointSpecific("42");
@@ -212,54 +221,94 @@
startpointManager.writeStartpoint(sspSingle, startpoint42);
// startpoint42 should remap with key sspBroadcast to all tasks + sspBroadcast
- Set<SystemStreamPartition> systemStreamPartitions = startpointManager.fanOutStartpointsToTasks(jobModel);
- Assert.assertEquals(2, systemStreamPartitions.size());
- Assert.assertTrue(systemStreamPartitions.containsAll(ImmutableSet.of(sspBroadcast, sspSingle)));
+ Map<TaskName, Map<SystemStreamPartition, Startpoint>> tasksFannedOutTo = startpointManager.fanOut(taskToSSPs);
+ Assert.assertEquals(tasks.size(), tasksFannedOutTo.size());
+ Assert.assertTrue(tasksFannedOutTo.keySet().containsAll(tasks));
+ Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(sspBroadcast).isPresent());
+ Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(sspSingle).isPresent());
for (TaskName taskName : tasks) {
- // startpoint42 should be mapped to all tasks for sspBroadcast
- Startpoint startpointFromStore = startpointManager.readStartpoint(sspBroadcast, taskName);
+ Map<SystemStreamPartition, Startpoint> fanOutForTask = startpointManager.getFanOutForTask(taskName);
+ if (taskName.equals(taskWithNonBroadcast)) {
+ // Non-broadcast startpoint should be fanned out to only one task
+ Assert.assertEquals("Should have broadcast and non-broadcast SSP", 2, fanOutForTask.size());
+ } else {
+ Assert.assertEquals("Should only have broadcast SSP", 1, fanOutForTask.size());
+ }
+
+ // Broadcast SSP should be on every task
+ Startpoint startpointFromStore = fanOutForTask.get(sspBroadcast);
Assert.assertEquals(StartpointSpecific.class, startpointFromStore.getClass());
Assert.assertEquals(startpoint42.getSpecificOffset(), ((StartpointSpecific) startpointFromStore).getSpecificOffset());
- // startpoint 42 should be mapped only to task "t1" for sspSingle
- startpointFromStore = startpointManager.readStartpoint(sspSingle, taskName);
- if (taskName.getTaskName().equals("t1")) {
+ // startpoint mapped only to task "t1" for Non-broadcast SSP
+ startpointFromStore = fanOutForTask.get(sspSingle);
+ if (taskName.equals(taskWithNonBroadcast)) {
Assert.assertEquals(StartpointSpecific.class, startpointFromStore.getClass());
Assert.assertEquals(startpoint42.getSpecificOffset(), ((StartpointSpecific) startpointFromStore).getSpecificOffset());
} else {
- Assert.assertNull(startpointFromStore);
+ Assert.assertNull("Should not have non-broadcast SSP", startpointFromStore);
}
+
+ startpointManager.removeFanOutForTask(taskName);
+ Assert.assertTrue(startpointManager.getFanOutForTask(taskName).isEmpty());
}
- Assert.assertNull(startpointManager.readStartpoint(sspBroadcast));
- Assert.assertNull(startpointManager.readStartpoint(sspSingle));
+ }
- // Test startpoints that were explicit assigned to an sspBroadcast2+TaskName will not be overwritten from fanOutStartpointsToTasks
+ @Test
+ public void testFanOutWithStartpointResolutions() throws IOException {
+ SystemStreamPartition sspBroadcast = new SystemStreamPartition("mockSystem1", "mockStream1", new Partition(2));
+ SystemStreamPartition sspSingle = new SystemStreamPartition("mockSystem2", "mockStream2", new Partition(3));
- StartpointSpecific startpoint1024 = new StartpointSpecific("1024");
+ List<TaskName> tasks =
+ ImmutableList.of(new TaskName("t0"), new TaskName("t1"), new TaskName("t2"), new TaskName("t3"), new TaskName("t4"));
- startpointManager.writeStartpoint(sspBroadcast2, startpoint42);
- startpointManager.writeStartpoint(sspBroadcast2, tasks.get(1), startpoint1024);
- startpointManager.writeStartpoint(sspBroadcast2, tasks.get(3), startpoint1024);
+ TaskName taskWithNonBroadcast = tasks.get(1);
+ TaskName taskBroadcastInPast = tasks.get(2);
+ TaskName taskBroadcastInFuture = tasks.get(3);
- Set<SystemStreamPartition> sspsDeleted = startpointManager.fanOutStartpointsToTasks(jobModel);
- Assert.assertEquals(1, sspsDeleted.size());
- Assert.assertTrue(sspsDeleted.contains(sspBroadcast2));
+ Map<TaskName, Set<SystemStreamPartition>> taskToSSPs = tasks.stream()
+ .collect(Collectors
+ .toMap(task -> task, task -> task.equals(taskWithNonBroadcast) ? ImmutableSet.of(sspBroadcast, sspSingle) : ImmutableSet.of(sspBroadcast)));
- StartpointSpecific startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(0));
- Assert.assertEquals(startpoint42.getSpecificOffset(), startpointFromStore.getSpecificOffset());
- startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(1));
- Assert.assertEquals(startpoint1024.getSpecificOffset(), startpointFromStore.getSpecificOffset());
- startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(2));
- Assert.assertEquals(startpoint42.getSpecificOffset(), startpointFromStore.getSpecificOffset());
- startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(3));
- Assert.assertEquals(startpoint1024.getSpecificOffset(), startpointFromStore.getSpecificOffset());
- startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(4));
- Assert.assertEquals(startpoint42.getSpecificOffset(), startpointFromStore.getSpecificOffset());
- startpointFromStore = (StartpointSpecific) startpointManager.readStartpoint(sspBroadcast2, tasks.get(5));
- Assert.assertEquals(startpoint42.getSpecificOffset(), startpointFromStore.getSpecificOffset());
- Assert.assertNull(startpointManager.readStartpoint(sspBroadcast2));
+ Instant now = Instant.now();
+ StartpointMock startpointPast = new StartpointMock(now.minusMillis(10000L).toEpochMilli());
+ StartpointMock startpointPresent = new StartpointMock(now.toEpochMilli());
+ StartpointMock startpointFuture = new StartpointMock(now.plusMillis(10000L).toEpochMilli());
- startpointManager.stop();
+ startpointManager.getObjectMapper().registerSubtypes(StartpointMock.class);
+ startpointManager.writeStartpoint(sspSingle, startpointPast);
+ startpointManager.writeStartpoint(sspSingle, startpointPresent);
+ startpointManager.writeStartpoint(sspBroadcast, startpointPresent);
+ startpointManager.writeStartpoint(sspBroadcast, taskBroadcastInPast, startpointPast);
+ startpointManager.writeStartpoint(sspBroadcast, taskBroadcastInFuture, startpointFuture);
+
+ Map<TaskName, Map<SystemStreamPartition, Startpoint>> fannedOut = startpointManager.fanOut(taskToSSPs);
+ Assert.assertEquals(tasks.size(), fannedOut.size());
+ Assert.assertTrue(fannedOut.keySet().containsAll(tasks));
+ Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(sspBroadcast).isPresent());
+ for (TaskName taskName : fannedOut.keySet()) {
+ Assert.assertFalse("Should be deleted after fan out for task: " + taskName.getTaskName(), startpointManager.readStartpoint(sspBroadcast, taskName).isPresent());
+ }
+ Assert.assertFalse("Should be deleted after fan out", startpointManager.readStartpoint(sspSingle).isPresent());
+
+ for (TaskName taskName : tasks) {
+ Map<SystemStreamPartition, Startpoint> fanOutForTask = startpointManager.getFanOutForTask(taskName);
+ if (taskName.equals(taskWithNonBroadcast)) {
+ Assert.assertEquals(startpointPresent, fanOutForTask.get(sspSingle));
+ Assert.assertEquals(startpointPresent, fanOutForTask.get(sspBroadcast));
+ } else if (taskName.equals(taskBroadcastInPast)) {
+ Assert.assertNull(fanOutForTask.get(sspSingle));
+ Assert.assertEquals(startpointPresent, fanOutForTask.get(sspBroadcast));
+ } else if (taskName.equals(taskBroadcastInFuture)) {
+ Assert.assertNull(fanOutForTask.get(sspSingle));
+ Assert.assertEquals(startpointFuture, fanOutForTask.get(sspBroadcast));
+ } else {
+ Assert.assertNull(fanOutForTask.get(sspSingle));
+ Assert.assertEquals(startpointPresent, fanOutForTask.get(sspBroadcast));
+ }
+ startpointManager.removeFanOutForTask(taskName);
+ Assert.assertTrue(startpointManager.getFanOutForTask(taskName).isEmpty());
+ }
}
}
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointObjectMapper.java b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointObjectMapper.java
new file mode 100644
index 0000000..70a3bf7
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointObjectMapper.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.startpoint;
+
+import java.io.IOException;
+import java.time.Instant;
+import org.apache.samza.Partition;
+import org.apache.samza.system.SystemStreamPartition;
+import org.codehaus.jackson.map.ObjectMapper;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestStartpointObjectMapper {
+ private static final ObjectMapper MAPPER = StartpointObjectMapper.getObjectMapper();
+
+ @Test
+ public void testStartpointSpecificSerde() throws IOException {
+ StartpointSpecific startpointSpecific = new StartpointSpecific("42");
+ Startpoint startpointFromSerde = MAPPER.readValue(MAPPER.writeValueAsBytes(startpointSpecific), Startpoint.class);
+
+ Assert.assertEquals(startpointSpecific.getClass(), startpointFromSerde.getClass());
+ Assert.assertEquals(startpointSpecific.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
+ Assert.assertEquals(startpointSpecific.getSpecificOffset(), ((StartpointSpecific) startpointFromSerde).getSpecificOffset());
+ }
+
+ @Test
+ public void testStartpointTimestampSerde() throws IOException {
+ StartpointTimestamp startpointTimestamp = new StartpointTimestamp(123456L);
+ Startpoint startpointFromSerde = MAPPER.readValue(MAPPER.writeValueAsBytes(startpointTimestamp), Startpoint.class);
+
+ Assert.assertEquals(startpointTimestamp.getClass(), startpointFromSerde.getClass());
+ Assert.assertEquals(startpointTimestamp.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
+ Assert.assertEquals(startpointTimestamp.getTimestampOffset(), ((StartpointTimestamp) startpointFromSerde).getTimestampOffset());
+ }
+
+ @Test
+ public void testStartpointEarliestSerde() throws IOException {
+ StartpointOldest startpointOldest = new StartpointOldest();
+ Startpoint startpointFromSerde = MAPPER.readValue(MAPPER.writeValueAsBytes(startpointOldest), Startpoint.class);
+
+ Assert.assertEquals(startpointOldest.getClass(), startpointFromSerde.getClass());
+ Assert.assertEquals(startpointOldest.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
+ }
+
+ @Test
+ public void testStartpointLatestSerde() throws IOException {
+ StartpointUpcoming startpointUpcoming = new StartpointUpcoming();
+ Startpoint startpointFromSerde = MAPPER.readValue(MAPPER.writeValueAsBytes(startpointUpcoming), Startpoint.class);
+
+ Assert.assertEquals(startpointUpcoming.getClass(), startpointFromSerde.getClass());
+ Assert.assertEquals(startpointUpcoming.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
+ }
+
+ @Test
+ public void testFanOutSerde() throws IOException {
+ StartpointFanOutPerTask startpointFanOutPerTask = new StartpointFanOutPerTask(Instant.now().minusSeconds(60));
+ startpointFanOutPerTask.getFanOuts()
+ .put(new SystemStreamPartition("system1", "stream1", new Partition(1)), new StartpointUpcoming());
+ startpointFanOutPerTask.getFanOuts()
+ .put(new SystemStreamPartition("system2", "stream2", new Partition(2)), new StartpointOldest());
+
+ String serialized = MAPPER.writeValueAsString(startpointFanOutPerTask);
+ StartpointFanOutPerTask startpointFanOutPerTaskFromSerde = MAPPER.readValue(serialized, StartpointFanOutPerTask.class);
+
+ Assert.assertEquals(startpointFanOutPerTask, startpointFanOutPerTaskFromSerde);
+ }
+}
diff --git a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointSerde.java b/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointSerde.java
deleted file mode 100644
index 9dfb784..0000000
--- a/samza-core/src/test/java/org/apache/samza/startpoint/TestStartpointSerde.java
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.samza.startpoint;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-public class TestStartpointSerde {
- private final StartpointSerde startpointSerde = new StartpointSerde();
-
- @Test
- public void testStartpointSpecificSerde() {
- StartpointSpecific startpointSpecific = new StartpointSpecific("42");
- Startpoint startpointFromSerde = startpointSerde.fromBytes(startpointSerde.toBytes(startpointSpecific));
-
- Assert.assertEquals(startpointSpecific.getClass(), startpointFromSerde.getClass());
- Assert.assertEquals(startpointSpecific.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
- Assert.assertEquals(startpointSpecific.getSpecificOffset(), ((StartpointSpecific) startpointFromSerde).getSpecificOffset());
- }
-
- @Test
- public void testStartpointTimestampSerde() {
- StartpointTimestamp startpointTimestamp = new StartpointTimestamp(123456L);
- Startpoint startpointFromSerde = startpointSerde.fromBytes(startpointSerde.toBytes(startpointTimestamp));
-
- Assert.assertEquals(startpointTimestamp.getClass(), startpointFromSerde.getClass());
- Assert.assertEquals(startpointTimestamp.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
- Assert.assertEquals(startpointTimestamp.getTimestampOffset(), ((StartpointTimestamp) startpointFromSerde).getTimestampOffset());
- }
-
- @Test
- public void testStartpointEarliestSerde() {
- StartpointOldest startpointOldest = new StartpointOldest();
- Startpoint startpointFromSerde = startpointSerde.fromBytes(startpointSerde.toBytes(startpointOldest));
-
- Assert.assertEquals(startpointOldest.getClass(), startpointFromSerde.getClass());
- Assert.assertEquals(startpointOldest.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
- }
-
- @Test
- public void testStartpointLatestSerde() {
- StartpointUpcoming startpointUpcoming = new StartpointUpcoming();
- Startpoint startpointFromSerde = startpointSerde.fromBytes(startpointSerde.toBytes(startpointUpcoming));
-
- Assert.assertEquals(startpointUpcoming.getClass(), startpointFromSerde.getClass());
- Assert.assertEquals(startpointUpcoming.getCreationTimestamp(), startpointFromSerde.getCreationTimestamp());
- }
-}
diff --git a/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java b/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java
index b5e35cf..44efc5b 100644
--- a/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java
@@ -293,15 +293,15 @@
initTables(cachingTable, guavaTable, remoteTable);
- // 3 per readable table (9)
- // 5 per read/write table (15)
- verify(metricsRegistry, times(24)).newCounter(any(), anyString());
+ // 4 per readable table (12)
+ // 6 per read/write table (18)
+ verify(metricsRegistry, times(30)).newCounter(any(), anyString());
- // 2 per readable table (6)
- // 5 per read/write table (15)
+ // 3 per readable table (9)
+ // 6 per read/write table (18)
// 1 per remote readable table (1)
// 1 per remote read/write table (1)
- verify(metricsRegistry, times(23)).newTimer(any(), anyString());
+ verify(metricsRegistry, times(29)).newTimer(any(), anyString());
// 1 per guava table (1)
// 3 per caching table (2)
diff --git a/samza-core/src/test/java/org/apache/samza/table/ratelimit/TestAsyncRateLimitedTable.java b/samza-core/src/test/java/org/apache/samza/table/ratelimit/TestAsyncRateLimitedTable.java
index 7c646fb..d1493e2 100644
--- a/samza-core/src/test/java/org/apache/samza/table/ratelimit/TestAsyncRateLimitedTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/ratelimit/TestAsyncRateLimitedTable.java
@@ -32,7 +32,9 @@
import org.apache.samza.table.remote.TableReadFunction;
import org.apache.samza.table.remote.TableWriteFunction;
import org.apache.samza.table.remote.TestRemoteTable;
+
import org.junit.Assert;
+import org.junit.Before;
import org.junit.Test;
import static org.mockito.Mockito.*;
@@ -42,6 +44,46 @@
private final ScheduledExecutorService schedExec = Executors.newSingleThreadScheduledExecutor();
+ private Map<String, String> readMap = new HashMap<>();
+ private AsyncReadWriteTable readTable;
+ private TableRateLimiter readRateLimiter;
+ private TableReadFunction<String, String> readFn;
+ private AsyncReadWriteTable<String, String> writeTable;
+ private TableRateLimiter<String, String> writeRateLimiter;
+ private TableWriteFunction<String, String> writeFn;
+
+ @Before
+ public void prepare() {
+ // Read part
+ readRateLimiter = mock(TableRateLimiter.class);
+ readFn = mock(TableReadFunction.class);
+ doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any());
+ doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any(), any());
+ readMap.put("foo", "bar");
+ doReturn(CompletableFuture.completedFuture(readMap)).when(readFn).getAllAsync(any());
+ doReturn(CompletableFuture.completedFuture(readMap)).when(readFn).getAllAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(5)).when(readFn).readAsync(anyInt(), any());
+ AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, null);
+ readTable = new AsyncRateLimitedTable("t1", delegate, readRateLimiter, null, schedExec);
+ readTable.init(TestRemoteTable.getMockContext());
+
+ // Write part
+ writeRateLimiter = mock(TableRateLimiter.class);
+ writeFn = mock(TableWriteFunction.class);
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(5)).when(writeFn).writeAsync(anyInt(), any());
+ delegate = new AsyncRemoteTable(readFn, writeFn);
+ writeTable = new AsyncRateLimitedTable("t1", delegate, readRateLimiter, writeRateLimiter, schedExec);
+ writeTable.init(TestRemoteTable.getMockContext());
+ }
+
@Test(expected = NullPointerException.class)
public void testNotNullTableId() {
new AsyncRateLimitedTable(null, mock(AsyncReadWriteTable.class),
@@ -71,71 +113,183 @@
}
@Test
- public void testGetThrottling() throws Exception {
- TableRateLimiter readRateLimiter = mock(TableRateLimiter.class);
- TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
- doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any());
- Map<String, String> result = new HashMap<>();
- result.put("foo", "bar");
- doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any());
- AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, null);
- AsyncRateLimitedTable table = new AsyncRateLimitedTable("t1", delegate,
- readRateLimiter, null, schedExec);
- table.init(TestRemoteTable.getMockContext());
-
- Assert.assertEquals("bar", table.getAsync("foo").get());
+ public void testGetAsync() {
+ Assert.assertEquals("bar", readTable.getAsync("foo").join());
verify(readFn, times(1)).getAsync(any());
+ verify(readFn, times(0)).getAsync(any(), any());
verify(readRateLimiter, times(1)).throttle(anyString());
- verify(readRateLimiter, times(0)).throttle(anyList());
-
- Assert.assertEquals(result, table.getAllAsync(Arrays.asList("")).get());
- verify(readFn, times(1)).getAllAsync(any());
- verify(readRateLimiter, times(1)).throttle(anyList());
- verify(readRateLimiter, times(1)).throttle(anyString());
+ verify(readRateLimiter, times(0)).throttle(anyCollection());
+ verify(readRateLimiter, times(0)).throttle(any(), any());
+ verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+ verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+ verifyWritePartNotCalled();
}
@Test
- public void testPutThrottling() throws Exception {
- TableRateLimiter readRateLimiter = mock(TableRateLimiter.class);
- TableRateLimiter writeRateLimiter = mock(TableRateLimiter.class);
- TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
- TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
- doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any());
- doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
- doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any());
- doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any());
- AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, writeFn);
- AsyncRateLimitedTable table = new AsyncRateLimitedTable("t1", delegate,
- readRateLimiter, writeRateLimiter, schedExec);
- table.init(TestRemoteTable.getMockContext());
+ public void testGetAsyncWithArgs() {
+ Assert.assertEquals("bar", readTable.getAsync("foo", 1).join());
+ verify(readFn, times(0)).getAsync(any());
+ verify(readFn, times(1)).getAsync(any(), any());
+ verify(readRateLimiter, times(1)).throttle(anyString(), any());
+ verify(readRateLimiter, times(0)).throttle(anyCollection());
+ verify(readRateLimiter, times(0)).throttle(any(), any());
+ verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+ verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+ verifyWritePartNotCalled();
+ }
- table.putAsync("foo", "bar").get();
+ @Test
+ public void testGetAllAsync() {
+ Assert.assertEquals(readMap, readTable.getAllAsync(Arrays.asList("")).join());
+ verify(readFn, times(1)).getAllAsync(any());
+ verify(readFn, times(0)).getAllAsync(any(), any());
+ verify(readRateLimiter, times(0)).throttle(anyString());
+ verify(readRateLimiter, times(1)).throttle(anyCollection());
+ verify(readRateLimiter, times(0)).throttle(any(), any());
+ verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+ verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+ verifyWritePartNotCalled();
+ }
+
+ @Test
+ public void testGetAllAsyncWithArgs() {
+ Assert.assertEquals(readMap, readTable.getAllAsync(Arrays.asList(""), "").join());
+ verify(readFn, times(0)).getAllAsync(any());
+ verify(readFn, times(1)).getAllAsync(any(), any());
+ verify(readRateLimiter, times(0)).throttle(anyString());
+ verify(readRateLimiter, times(1)).throttle(anyCollection(), any());
+ verify(readRateLimiter, times(0)).throttle(anyString(), any());
+ verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+ verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+ verifyWritePartNotCalled();
+ }
+
+ @Test
+ public void testReadAsync() {
+ Assert.assertEquals(5, readTable.readAsync(1, 2).join());
+ verify(readFn, times(1)).readAsync(anyInt(), any());
+ verify(readRateLimiter, times(0)).throttle(anyString());
+ verify(readRateLimiter, times(0)).throttle(anyCollection());
+ verify(readRateLimiter, times(0)).throttle(any(), any());
+ verify(readRateLimiter, times(1)).throttle(anyInt(), any());
+ verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+ verifyWritePartNotCalled();
+ }
+
+ @Test
+ public void testPutAsync() {
+ writeTable.putAsync("foo", "bar").join();
verify(writeFn, times(1)).putAsync(any(), any());
- verify(writeRateLimiter, times(1)).throttle(anyString(), anyString());
- verify(writeRateLimiter, times(0)).throttleRecords(anyList());
+ verify(writeFn, times(0)).putAsync(any(), any(), any());
verify(writeRateLimiter, times(0)).throttle(anyString());
- verify(writeRateLimiter, times(0)).throttle(anyList());
-
- table.putAllAsync(Arrays.asList(new Entry("1", "2"))).get();
- verify(writeFn, times(1)).putAllAsync(any());
verify(writeRateLimiter, times(1)).throttle(anyString(), anyString());
- verify(writeRateLimiter, times(1)).throttleRecords(anyList());
+ verify(writeRateLimiter, times(0)).throttle(anyCollection());
+ verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+ verifyReadPartNotCalled();
+ }
+
+ @Test
+ public void testPutAsyncWithArgs() {
+ writeTable.putAsync("foo", "bar", 1).join();
+ verify(writeFn, times(0)).putAsync(any(), any());
+ verify(writeFn, times(1)).putAsync(any(), any(), any());
verify(writeRateLimiter, times(0)).throttle(anyString());
- verify(writeRateLimiter, times(0)).throttle(anyList());
+ verify(writeRateLimiter, times(1)).throttle(anyString(), anyString(), any());
+ verify(writeRateLimiter, times(0)).throttle(anyCollection());
+ verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+ verifyReadPartNotCalled();
+ }
- table.deleteAsync("foo").get();
- verify(writeFn, times(1)).deleteAsync(anyString());
- verify(writeRateLimiter, times(1)).throttle(anyString(), anyString());
- verify(writeRateLimiter, times(1)).throttleRecords(anyList());
- verify(writeRateLimiter, times(1)).throttle(anyString());
- verify(writeRateLimiter, times(0)).throttle(anyList());
+ @Test
+ public void testPutAllAsync() {
+ writeTable.putAllAsync(Arrays.asList(new Entry("1", "2"))).join();
+ verify(writeFn, times(1)).putAllAsync(anyCollection());
+ verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+ verify(writeRateLimiter, times(0)).throttle(anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyCollection());
+ verify(writeRateLimiter, times(1)).throttleRecords(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+ verifyReadPartNotCalled();
+ }
- table.deleteAllAsync(Arrays.asList("1", "2")).get();
- verify(writeFn, times(1)).deleteAllAsync(any());
- verify(writeRateLimiter, times(1)).throttle(anyString(), anyString());
- verify(writeRateLimiter, times(1)).throttleRecords(anyList());
+ @Test
+ public void testPutAllAsyncWithArgs() {
+ writeTable.putAllAsync(Arrays.asList(new Entry("1", "2")), Arrays.asList(1)).join();
+ verify(writeFn, times(0)).putAllAsync(anyCollection());
+ verify(writeFn, times(1)).putAllAsync(anyCollection(), any());
+ verify(writeRateLimiter, times(0)).throttle(anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyCollection());
+ verify(writeRateLimiter, times(1)).throttleRecords(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+ verifyReadPartNotCalled();
+ }
+
+ @Test
+ public void testDeleteAsync() {
+ writeTable.deleteAsync("foo").join();
+ verify(writeFn, times(1)).deleteAsync(any());
+ verify(writeFn, times(0)).deleteAsync(any(), any());
verify(writeRateLimiter, times(1)).throttle(anyString());
- verify(writeRateLimiter, times(1)).throttle(anyList());
+ verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyCollection());
+ verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+ verifyReadPartNotCalled();
+ }
+
+ @Test
+ public void testDeleteAsyncWithArgs() {
+ writeTable.deleteAsync("foo", 1).join();
+ verify(writeFn, times(0)).deleteAsync(any());
+ verify(writeFn, times(1)).deleteAsync(any(), any());
+ verify(writeRateLimiter, times(1)).throttle(anyString(), any());
+ verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyCollection());
+ verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+ verifyReadPartNotCalled();
+ }
+
+ @Test
+ public void testDeleteAllAsync() {
+ writeTable.deleteAllAsync(Arrays.asList("1", "2")).join();
+ verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+ verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+ verify(writeRateLimiter, times(0)).throttle(anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+ verify(writeRateLimiter, times(1)).throttle(anyCollection());
+ verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+ verifyReadPartNotCalled();
+ }
+
+ @Test
+ public void testDeleteAllAsyncWithArgs() {
+ writeTable.deleteAllAsync(Arrays.asList("1", "2"), 1).join();
+ verify(writeFn, times(0)).deleteAllAsync(anyCollection());
+ verify(writeFn, times(1)).deleteAllAsync(anyCollection(), any());
+ verify(writeRateLimiter, times(0)).throttle(anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+ verify(writeRateLimiter, times(1)).throttle(anyCollection(), any());
+ verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+ verifyReadPartNotCalled();
+ }
+
+ @Test
+ public void testWriteAsync() {
+ Assert.assertEquals(5, writeTable.writeAsync(1, 2).join());
+ verify(writeFn, times(1)).writeAsync(anyInt(), any());
+ verify(writeRateLimiter, times(0)).throttle(anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+ verify(writeRateLimiter, times(1)).throttle(anyInt(), any());
+ verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+ verifyReadPartNotCalled();
}
@Test
@@ -157,4 +311,34 @@
verify(writeFn, times(1)).close();
}
+ private void verifyReadPartNotCalled() {
+ verify(readFn, times(0)).getAsync(any());
+ verify(readFn, times(0)).getAsync(any(), any());
+ verify(readFn, times(0)).getAllAsync(any(), any());
+ verify(readFn, times(0)).getAllAsync(any(), any(), any());
+ verify(readFn, times(0)).readAsync(anyInt(), any());
+ verify(readRateLimiter, times(0)).throttle(anyString());
+ verify(readRateLimiter, times(0)).throttle(anyCollection());
+ verify(readRateLimiter, times(0)).throttle(any(), any());
+ verify(readRateLimiter, times(0)).throttle(anyInt(), any());
+ verify(readRateLimiter, times(0)).throttleRecords(anyCollection());
+ }
+
+ private void verifyWritePartNotCalled() {
+ verify(writeFn, times(0)).putAsync(any(), any());
+ verify(writeFn, times(0)).putAsync(any(), any(), any());
+ verify(writeFn, times(0)).putAllAsync(any());
+ verify(writeFn, times(0)).putAllAsync(any(), any());
+ verify(writeFn, times(0)).deleteAsync(any());
+ verify(writeFn, times(0)).deleteAsync(any(), any());
+ verify(writeFn, times(0)).deleteAllAsync(any());
+ verify(writeFn, times(0)).deleteAllAsync(any(), any());
+ verify(writeFn, times(0)).writeAsync(anyInt(), any());
+ verify(writeRateLimiter, times(0)).throttle(anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyString(), anyString());
+ verify(writeRateLimiter, times(0)).throttle(anyCollection());
+ verify(writeRateLimiter, times(0)).throttleRecords(anyCollection());
+ verify(writeRateLimiter, times(0)).throttle(anyInt(), any());
+ }
+
}
diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/TestAsyncRemoteTable.java b/samza-core/src/test/java/org/apache/samza/table/remote/TestAsyncRemoteTable.java
index 2a447e0..d557c31 100644
--- a/samza-core/src/test/java/org/apache/samza/table/remote/TestAsyncRemoteTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/remote/TestAsyncRemoteTable.java
@@ -20,7 +20,6 @@
import java.util.Arrays;
-import org.apache.samza.context.Context;
import org.apache.samza.storage.kv.Entry;
import org.junit.Before;
@@ -30,6 +29,7 @@
import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.*;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
@@ -60,6 +60,15 @@
}
@Test
+ public void testGetAsyncWithArgs() {
+ int times = 0;
+ roTable.getAsync(1, 1);
+ verify(readFn, times(++times)).getAsync(any(), any());
+ rwTable.getAsync(1, 1);
+ verify(readFn, times(++times)).getAsync(any(), any());
+ }
+
+ @Test
public void testGetAllAsync() {
int times = 0;
roTable.getAllAsync(Arrays.asList(1, 2));
@@ -69,6 +78,24 @@
}
@Test
+ public void testGetAllAsyncWithArgs() {
+ int times = 0;
+ roTable.getAllAsync(Arrays.asList(1, 2), Arrays.asList(0, 0));
+ verify(readFn, times(++times)).getAllAsync(any(), any());
+ rwTable.getAllAsync(Arrays.asList(1, 2), Arrays.asList(0, 0));
+ verify(readFn, times(++times)).getAllAsync(any(), any());
+ }
+
+ @Test
+ public void testReadAsync() {
+ int times = 0;
+ roTable.readAsync(1, 2, 3);
+ verify(readFn, times(++times)).readAsync(anyInt(), any(), any());
+ rwTable.readAsync(1, 2, 3);
+ verify(readFn, times(++times)).readAsync(anyInt(), any(), any());
+ }
+
+ @Test
public void testPutAsync() {
verifyFailure(() -> roTable.putAsync(1, 2));
rwTable.putAsync(1, 2);
@@ -76,6 +103,13 @@
}
@Test
+ public void testPutAsyncWithArgs() {
+ verifyFailure(() -> roTable.putAsync(1, 2, 3));
+ rwTable.putAsync(1, 2, 3);
+ verify(writeFn, times(1)).putAsync(any(), any(), any());
+ }
+
+ @Test
public void testPutAllAsync() {
verifyFailure(() -> roTable.putAllAsync(Arrays.asList(new Entry(1, 2))));
rwTable.putAllAsync(Arrays.asList(new Entry(1, 2)));
@@ -83,11 +117,26 @@
}
@Test
+ public void testPutAllAsyncWithArgs() {
+ verifyFailure(() -> roTable.putAllAsync(Arrays.asList(new Entry(1, 2)), Arrays.asList(0, 0)));
+ rwTable.putAllAsync(Arrays.asList(new Entry(1, 2)), Arrays.asList(0, 0));
+ verify(writeFn, times(1)).putAllAsync(any(), any());
+ }
+
+ @Test
public void testDeleteAsync() {
verifyFailure(() -> roTable.deleteAsync(1));
rwTable.deleteAsync(1);
verify(writeFn, times(1)).deleteAsync(any());
}
+
+ @Test
+ public void testDeleteAsyncWithArgs() {
+ verifyFailure(() -> roTable.deleteAsync(1, 2));
+ rwTable.deleteAsync(1, 2);
+ verify(writeFn, times(1)).deleteAsync(any(), any());
+ }
+
@Test
public void testDeleteAllAsync() {
verifyFailure(() -> roTable.deleteAllAsync(Arrays.asList(1)));
@@ -96,13 +145,17 @@
}
@Test
- public void testInit() {
- roTable.init(mock(Context.class));
- verify(readFn, times(1)).init(any());
- verify(writeFn, times(0)).init(any());
- rwTable.init(mock(Context.class));
- verify(readFn, times(2)).init(any());
- verify(writeFn, times(1)).init(any());
+ public void testDeleteAllAsyncWithArgs() {
+ verifyFailure(() -> roTable.deleteAllAsync(Arrays.asList(1), Arrays.asList(2)));
+ rwTable.deleteAllAsync(Arrays.asList(1, 2), Arrays.asList(2));
+ verify(writeFn, times(1)).deleteAllAsync(any(), any());
+ }
+
+ @Test
+ public void testWriteAsync() {
+ verifyFailure(() -> roTable.writeAsync(1, 2, 3));
+ rwTable.writeAsync(1, 2, 3);
+ verify(writeFn, times(1)).writeAsync(anyInt(), any(), any());
}
@Test
diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java b/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java
index 93b2dab..02625bb 100644
--- a/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java
@@ -26,8 +26,12 @@
import org.apache.samza.metrics.MetricsRegistry;
import org.apache.samza.metrics.Timer;
import org.apache.samza.storage.kv.Entry;
+import org.apache.samza.table.AsyncReadWriteTable;
+import org.apache.samza.table.ratelimit.AsyncRateLimitedTable;
+import org.apache.samza.table.retry.AsyncRetriableTable;
import org.apache.samza.table.retry.TableRetryPolicy;
+import org.apache.samza.testUtils.TestUtils;
import org.junit.Assert;
import org.junit.Test;
@@ -38,14 +42,12 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyCollection;
-import static org.mockito.Matchers.anyString;
+import static org.junit.Assert.*;
+import static org.mockito.Matchers.*;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
@@ -55,8 +57,6 @@
public class TestRemoteTable {
- private final ScheduledExecutorService schedExec = Executors.newSingleThreadScheduledExecutor();
-
public static Context getMockContext() {
Context context = new MockContext();
MetricsRegistry metricsRegistry = mock(MetricsRegistry.class);
@@ -88,10 +88,14 @@
readRateLimiter, writeRateLimiter, rateLimitingExecutor,
readPolicy, writePolicy, retryExecutor, cbExecutor);
table.init(getMockContext());
+ verify(readFn, times(1)).init(any(), any());
+ if (writeFn != null) {
+ verify(writeFn, times(1)).init(any(), any());
+ }
return (T) table;
}
- private void doTestGet(boolean sync, boolean error, boolean retry) throws Exception {
+ private void doTestGet(boolean sync, boolean error, boolean retry) {
String tableId = "testGet-" + sync + error + retry;
TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
// Sync is backed by async so needs to mock the async method
@@ -114,27 +118,49 @@
doReturn(true).when(readFn).isRetriable(any());
}
RemoteTable<String, String> table = getTable(tableId, readFn, null, retry);
- Assert.assertEquals("bar", sync ? table.get("foo") : table.getAsync("foo").get());
+ Assert.assertEquals("bar", sync ? table.get("foo") : table.getAsync("foo").join());
verify(table.readRateLimiter, times(error && retry ? 2 : 1)).throttle(anyString());
}
@Test
- public void testGet() throws Exception {
+ public void testInit() {
+ String tableId = "testInit";
+ TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+ TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+ RemoteTable<String, String> table = getTable(tableId, readFn, writeFn, true);
+ // AsyncRetriableTable
+ AsyncReadWriteTable innerTable = TestUtils.getFieldValue(table, "asyncTable");
+ Assert.assertTrue(innerTable instanceof AsyncRetriableTable);
+ Assert.assertNotNull(TestUtils.getFieldValue(innerTable, "readRetryMetrics"));
+ Assert.assertNotNull(TestUtils.getFieldValue(innerTable, "writeRetryMetrics"));
+ // AsyncRateLimitedTable
+ innerTable = TestUtils.getFieldValue(innerTable, "table");
+ Assert.assertTrue(innerTable instanceof AsyncRateLimitedTable);
+ // AsyncRemoteTable
+ innerTable = TestUtils.getFieldValue(innerTable, "table");
+ Assert.assertTrue(innerTable instanceof AsyncRemoteTable);
+ // Verify table functions are initialized
+ verify(readFn, times(1)).init(any(), any());
+ verify(writeFn, times(1)).init(any(), any());
+ }
+
+ @Test
+ public void testGet() {
doTestGet(true, false, false);
}
@Test
- public void testGetAsync() throws Exception {
+ public void testGetAsync() {
doTestGet(false, false, false);
}
- @Test(expected = ExecutionException.class)
- public void testGetAsyncError() throws Exception {
+ @Test(expected = RuntimeException.class)
+ public void testGetAsyncError() {
doTestGet(false, true, false);
}
@Test
- public void testGetAsyncErrorRetried() throws Exception {
+ public void testGetAsyncErrorRetried() {
doTestGet(false, true, true);
}
@@ -160,7 +186,44 @@
});
}
- private void doTestPut(boolean sync, boolean error, boolean isDelete, boolean retry) throws Exception {
+ public void doTestRead(boolean sync, boolean error) {
+ TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+ RemoteTable<String, String> table = getTable("testRead-" + sync + error,
+ readFn, mock(TableWriteFunction.class), false);
+ CompletableFuture<?> future;
+ if (error) {
+ future = new CompletableFuture();
+ future.completeExceptionally(new RuntimeException("Test exception"));
+ } else {
+ future = CompletableFuture.completedFuture(5);
+ }
+ // Sync is backed by async so needs to mock the async method
+ doReturn(future).when(readFn).readAsync(anyInt(), any());
+
+ int readResult = sync
+ ? table.read(1, 2)
+ : (Integer) table.readAsync(1, 2).join();
+ verify(readFn, times(1)).readAsync(anyInt(), any());
+ Assert.assertEquals(5, readResult);
+ verify(table.readRateLimiter, times(1)).throttle(anyInt(), any());
+ }
+
+ @Test
+ public void testRead() {
+ doTestRead(true, false);
+ }
+
+ @Test
+ public void testReadAsync() {
+ doTestRead(false, false);
+ }
+
+ @Test(expected = RuntimeException.class)
+ public void testReadAsyncError() {
+ doTestRead(false, true);
+ }
+
+ private void doTestPut(boolean sync, boolean error, boolean isDelete, boolean retry) {
String tableId = "testPut-" + sync + error + isDelete + retry;
TableWriteFunction<String, String> mockWriteFn = mock(TableWriteFunction.class);
TableWriteFunction<String, String> writeFn = mockWriteFn;
@@ -192,7 +255,7 @@
if (sync) {
table.put("foo", isDelete ? null : "bar");
} else {
- table.putAsync("foo", isDelete ? null : "bar").get();
+ table.putAsync("foo", isDelete ? null : "bar").join();
}
ArgumentCaptor<String> keyCaptor = ArgumentCaptor.forClass(String.class);
ArgumentCaptor<String> valCaptor = ArgumentCaptor.forClass(String.class);
@@ -211,36 +274,36 @@
}
@Test
- public void testPut() throws Exception {
+ public void testPut() {
doTestPut(true, false, false, false);
}
@Test
- public void testPutDelete() throws Exception {
+ public void testPutDelete() {
doTestPut(true, false, true, false);
}
@Test
- public void testPutAsync() throws Exception {
+ public void testPutAsync() {
doTestPut(false, false, false, false);
}
@Test
- public void testPutAsyncDelete() throws Exception {
+ public void testPutAsyncDelete() {
doTestPut(false, false, true, false);
}
- @Test(expected = ExecutionException.class)
- public void testPutAsyncError() throws Exception {
+ @Test(expected = RuntimeException.class)
+ public void testPutAsyncError() {
doTestPut(false, true, false, false);
}
@Test
- public void testPutAsyncErrorRetried() throws Exception {
+ public void testPutAsyncErrorRetried() {
doTestPut(false, true, false, true);
}
- private void doTestDelete(boolean sync, boolean error) throws Exception {
+ private void doTestDelete(boolean sync, boolean error) {
TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
RemoteTable<String, String> table = getTable("testDelete-" + sync + error,
mock(TableReadFunction.class), writeFn, false);
@@ -257,7 +320,7 @@
if (sync) {
table.delete("foo");
} else {
- table.deleteAsync("foo").get();
+ table.deleteAsync("foo").join();
}
verify(writeFn, times(1)).deleteAsync(argCaptor.capture());
Assert.assertEquals("foo", argCaptor.getValue());
@@ -265,21 +328,21 @@
}
@Test
- public void testDelete() throws Exception {
+ public void testDelete() {
doTestDelete(true, false);
}
@Test
- public void testDeleteAsync() throws Exception {
+ public void testDeleteAsync() {
doTestDelete(false, false);
}
- @Test(expected = ExecutionException.class)
- public void testDeleteAsyncError() throws Exception {
+ @Test(expected = RuntimeException.class)
+ public void testDeleteAsyncError() {
doTestDelete(false, true);
}
- private void doTestGetAll(boolean sync, boolean error, boolean partial) throws Exception {
+ private void doTestGetAll(boolean sync, boolean error, boolean partial) {
TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
Map<String, String> res = new HashMap<>();
res.put("foo1", "bar1");
@@ -297,32 +360,32 @@
doReturn(future).when(readFn).getAllAsync(any());
RemoteTable<String, String> table = getTable("testGetAll-" + sync + error + partial, readFn, null, false);
Assert.assertEquals(res, sync ? table.getAll(Arrays.asList("foo1", "foo2"))
- : table.getAllAsync(Arrays.asList("foo1", "foo2")).get());
+ : table.getAllAsync(Arrays.asList("foo1", "foo2")).join());
verify(table.readRateLimiter, times(1)).throttle(anyCollection());
}
@Test
- public void testGetAll() throws Exception {
+ public void testGetAll() {
doTestGetAll(true, false, false);
}
@Test
- public void testGetAllAsync() throws Exception {
+ public void testGetAllAsync() {
doTestGetAll(false, false, false);
}
- @Test(expected = ExecutionException.class)
- public void testGetAllAsyncError() throws Exception {
+ @Test(expected = RuntimeException.class)
+ public void testGetAllAsyncError() {
doTestGetAll(false, true, false);
}
// Partial result is an acceptable scenario
@Test
- public void testGetAllPartialResult() throws Exception {
+ public void testGetAllPartialResult() {
doTestGetAll(false, false, true);
}
- public void doTestPutAll(boolean sync, boolean error, boolean hasDelete) throws Exception {
+ public void doTestPutAll(boolean sync, boolean error, boolean hasDelete) {
TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
RemoteTable<String, String> table = getTable("testPutAll-" + sync + error + hasDelete,
mock(TableReadFunction.class), writeFn, false);
@@ -344,7 +407,7 @@
if (sync) {
table.putAll(entries);
} else {
- table.putAllAsync(entries).get();
+ table.putAllAsync(entries).join();
}
verify(writeFn, times(1)).putAllAsync(argCaptor.capture());
if (hasDelete) {
@@ -361,31 +424,31 @@
}
@Test
- public void testPutAll() throws Exception {
+ public void testPutAll() {
doTestPutAll(true, false, false);
}
@Test
- public void testPutAllHasDelete() throws Exception {
+ public void testPutAllHasDelete() {
doTestPutAll(true, false, true);
}
@Test
- public void testPutAllAsync() throws Exception {
+ public void testPutAllAsync() {
doTestPutAll(false, false, false);
}
@Test
- public void testPutAllAsyncHasDelete() throws Exception {
+ public void testPutAllAsyncHasDelete() {
doTestPutAll(false, false, true);
}
- @Test(expected = ExecutionException.class)
- public void testPutAllAsyncError() throws Exception {
+ @Test(expected = RuntimeException.class)
+ public void testPutAllAsyncError() {
doTestPutAll(false, true, false);
}
- public void doTestDeleteAll(boolean sync, boolean error) throws Exception {
+ public void doTestDeleteAll(boolean sync, boolean error) {
TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
RemoteTable<String, String> table = getTable("testDeleteAll-" + sync + error,
mock(TableReadFunction.class), writeFn, false);
@@ -403,7 +466,7 @@
if (sync) {
table.deleteAll(keys);
} else {
- table.deleteAllAsync(keys).get();
+ table.deleteAllAsync(keys).join();
}
verify(writeFn, times(1)).deleteAllAsync(argCaptor.capture());
Assert.assertEquals(keys, argCaptor.getValue());
@@ -411,20 +474,57 @@
}
@Test
- public void testDeleteAll() throws Exception {
+ public void testDeleteAll() {
doTestDeleteAll(true, false);
}
@Test
- public void testDeleteAllAsync() throws Exception {
+ public void testDeleteAllAsync() {
doTestDeleteAll(false, false);
}
- @Test(expected = ExecutionException.class)
- public void testDeleteAllAsyncError() throws Exception {
+ @Test(expected = RuntimeException.class)
+ public void testDeleteAllAsyncError() {
doTestDeleteAll(false, true);
}
+ public void doTestWrite(boolean sync, boolean error) {
+ TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
+ RemoteTable<String, String> table = getTable("testWrite-" + sync + error,
+ mock(TableReadFunction.class), writeFn, false);
+ CompletableFuture<?> future;
+ if (error) {
+ future = new CompletableFuture();
+ future.completeExceptionally(new RuntimeException("Test exception"));
+ } else {
+ future = CompletableFuture.completedFuture(5);
+ }
+ // Sync is backed by async so needs to mock the async method
+ doReturn(future).when(writeFn).writeAsync(anyInt(), any());
+
+ int writeResult = sync
+ ? table.write(1, 2)
+ : (Integer) table.writeAsync(1, 2).join();
+ verify(writeFn, times(1)).writeAsync(anyInt(), any());
+ Assert.assertEquals(5, writeResult);
+ verify(table.writeRateLimiter, times(1)).throttle(anyInt(), any());
+ }
+
+ @Test
+ public void testWrite() {
+ doTestWrite(true, false);
+ }
+
+ @Test
+ public void testWriteAsync() {
+ doTestWrite(false, false);
+ }
+
+ @Test(expected = RuntimeException.class)
+ public void testWriteAsyncError() {
+ doTestWrite(false, true);
+ }
+
@Test
public void testFlush() {
TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
@@ -434,7 +534,7 @@
}
@Test
- public void testGetWithCallbackExecutor() throws Exception {
+ public void testGetWithCallbackExecutor() {
TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
// Sync is backed by async so needs to mock the async method
doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(anyString());
@@ -448,4 +548,106 @@
Assert.assertNotSame(testThread, Thread.currentThread());
});
}
+
+ @Test
+ public void testGetDelegation() {
+ TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+ doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any());
+ doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any(), any());
+ Map<String, String> result = new HashMap<>();
+ result.put("foo", "bar");
+ doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any());
+ doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(5)).when(readFn).readAsync(anyInt(), any());
+
+ RemoteTable<String, String> table = getTable("testGetDelegation", readFn, null,
+ Executors.newSingleThreadExecutor(), true);
+ verify(readFn, times(1)).init(any(), any());
+
+ // GetAsync
+ verify(readFn, times(0)).getAsync(any());
+ verify(readFn, times(0)).getAsync(any(), any());
+ assertEquals("bar", table.getAsync("foo").join());
+ verify(readFn, times(1)).getAsync(any());
+ verify(readFn, times(0)).getAsync(any(), any());
+ assertEquals("bar", table.getAsync("foo", 1).join());
+ verify(readFn, times(1)).getAsync(any());
+ verify(readFn, times(1)).getAsync(any(), any());
+ // GetAllAsync
+ verify(readFn, times(0)).getAllAsync(any());
+ verify(readFn, times(0)).getAllAsync(any(), any());
+ assertEquals(result, table.getAllAsync(Arrays.asList("foo")).join());
+ verify(readFn, times(1)).getAllAsync(any());
+ verify(readFn, times(0)).getAllAsync(any(), any());
+ assertEquals(result, table.getAllAsync(Arrays.asList("foo"), Arrays.asList(1)).join());
+ verify(readFn, times(1)).getAllAsync(any());
+ verify(readFn, times(1)).getAllAsync(any(), any());
+ // ReadAsync
+ verify(readFn, times(0)).readAsync(anyInt(), any());
+ assertEquals(5, table.readAsync(1, 2).join());
+ verify(readFn, times(1)).readAsync(anyInt(), any());
+
+ table.close();
+ }
+
+ @Test
+ public void testPutAndDeleteDelegation() {
+ TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
+ TableWriteFunction writeFn = mock(TableWriteFunction.class);
+ doReturn(true).when(writeFn).isRetriable(any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(anyCollection());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(anyCollection(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(anyCollection());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(anyCollection(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).writeAsync(anyInt(), any());
+
+ RemoteTable<String, String> table = getTable("testGetDelegation", readFn, writeFn,
+ Executors.newSingleThreadExecutor(), true);
+ verify(readFn, times(1)).init(any(), any());
+
+ // PutAsync
+ verify(writeFn, times(0)).putAsync(any(), any());
+ verify(writeFn, times(0)).putAsync(any(), any(), any());
+ table.putAsync("roo", "bar").join();
+ verify(writeFn, times(1)).putAsync(any(), any());
+ verify(writeFn, times(0)).putAsync(any(), any(), any());
+ table.putAsync("foo", "bar", 3).join();
+ verify(writeFn, times(1)).putAsync(any(), any());
+ verify(writeFn, times(1)).putAsync(any(), any(), any());
+ // PutAllAsync
+ verify(writeFn, times(0)).putAllAsync(anyCollection());
+ verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+ table.putAllAsync(Arrays.asList(new Entry("foo", "bar"))).join();
+ verify(writeFn, times(1)).putAllAsync(anyCollection());
+ verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+ table.putAllAsync(Arrays.asList(new Entry("foo", "bar")), 2).join();
+ verify(writeFn, times(1)).putAllAsync(anyCollection());
+ verify(writeFn, times(1)).putAllAsync(anyCollection(), any());
+ // DeleteAsync
+ verify(writeFn, times(0)).deleteAsync(any());
+ verify(writeFn, times(0)).deleteAsync(any(), any());
+ table.deleteAsync("foo").join();
+ verify(writeFn, times(1)).deleteAsync(any());
+ verify(writeFn, times(0)).deleteAsync(any(), any());
+ table.deleteAsync("foo", 2).join();
+ verify(writeFn, times(1)).deleteAsync(any());
+ verify(writeFn, times(1)).deleteAsync(any(), any());
+ // DeleteAllAsync
+ verify(writeFn, times(0)).deleteAllAsync(anyCollection());
+ verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+ table.deleteAllAsync(Arrays.asList("foo")).join();
+ verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+ verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+ table.deleteAllAsync(Arrays.asList("foo"), Arrays.asList(2)).join();
+ verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+ verify(writeFn, times(1)).deleteAllAsync(anyCollection(), any());
+ // WriteAsync
+ verify(writeFn, times(0)).writeAsync(anyInt(), any());
+ table.writeAsync(1, 2).join();
+ verify(writeFn, times(1)).writeAsync(anyInt(), any());
+ }
}
diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/descriptors/TestRemoteTableDescriptor.java b/samza-core/src/test/java/org/apache/samza/table/remote/descriptors/TestRemoteTableDescriptor.java
index 5bc8339..ce89c5a 100644
--- a/samza-core/src/test/java/org/apache/samza/table/remote/descriptors/TestRemoteTableDescriptor.java
+++ b/samza-core/src/test/java/org/apache/samza/table/remote/descriptors/TestRemoteTableDescriptor.java
@@ -111,17 +111,17 @@
@Test
public void testSerializeWithLimiterAndReadCredFn() {
- doTestSerialize(createMockRateLimiter(), (k, v) -> 1, null);
+ doTestSerialize(createMockRateLimiter(), (k, v, args) -> 1, null);
}
@Test
public void testSerializeWithLimiterAndWriteCredFn() {
- doTestSerialize(createMockRateLimiter(), null, (k, v) -> 1);
+ doTestSerialize(createMockRateLimiter(), null, (k, v, args) -> 1);
}
@Test
public void testSerializeWithLimiterAndReadWriteCredFns() {
- doTestSerialize(createMockRateLimiter(), (key, value) -> 1, (key, value) -> 1);
+ doTestSerialize(createMockRateLimiter(), (key, value, args) -> 1, (key, value, args) -> 1);
}
@Test
@@ -263,7 +263,7 @@
int numCalls = 0;
@Override
- public int getCredits(K key, V value) {
+ public int getCredits(K key, V value, Object ... args) {
numCalls++;
return 1;
}
diff --git a/samza-core/src/test/java/org/apache/samza/table/retry/TestAsyncRetriableTable.java b/samza-core/src/test/java/org/apache/samza/table/retry/TestAsyncRetriableTable.java
index 1f4b514..ec4307d 100644
--- a/samza-core/src/test/java/org/apache/samza/table/retry/TestAsyncRetriableTable.java
+++ b/samza-core/src/test/java/org/apache/samza/table/retry/TestAsyncRetriableTable.java
@@ -23,7 +23,6 @@
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
@@ -77,7 +76,51 @@
}
@Test
- public void testGetWithoutRetry() throws Exception {
+ public void testGetDelegation() {
+ TableRetryPolicy policy = new TableRetryPolicy();
+ policy.withFixedBackoff(Duration.ofMillis(100));
+ TableReadFunction readFn = mock(TableReadFunction.class);
+ doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any());
+ doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any(), any());
+ Map<String, String> result = new HashMap<>();
+ result.put("foo", "bar");
+ doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any());
+ doReturn(CompletableFuture.completedFuture(result)).when(readFn).getAllAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(5)).when(readFn).readAsync(anyInt(), any());
+ AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, null);
+ AsyncRetriableTable table = new AsyncRetriableTable("t1", delegate, policy, null, schedExec, readFn, null);
+
+ table.init(TestRemoteTable.getMockContext());
+ verify(readFn, times(0)).init(any(), any());
+
+ // GetAsync: the one-argument and the two-argument overloads must each reach readFn exactly once
+ verify(readFn, times(0)).getAsync(any());
+ verify(readFn, times(0)).getAsync(any(), any());
+ assertEquals("bar", table.getAsync("foo").join());
+ verify(readFn, times(1)).getAsync(any());
+ verify(readFn, times(0)).getAsync(any(), any());
+ assertEquals("bar", table.getAsync("foo", 1).join());
+ verify(readFn, times(1)).getAsync(any());
+ verify(readFn, times(1)).getAsync(any(), any());
+ // GetAllAsync: same per-overload check for the batch reads, which return the stubbed result map
+ verify(readFn, times(0)).getAllAsync(any());
+ verify(readFn, times(0)).getAllAsync(any(), any());
+ assertEquals(result, table.getAllAsync(Arrays.asList("foo")).join());
+ verify(readFn, times(1)).getAllAsync(any());
+ verify(readFn, times(0)).getAllAsync(any(), any());
+ assertEquals(result, table.getAllAsync(Arrays.asList("foo"), Arrays.asList(1)).join());
+ verify(readFn, times(1)).getAllAsync(any());
+ verify(readFn, times(1)).getAllAsync(any(), any());
+ // ReadAsync: a single readAsync call must reach readFn once and yield the stubbed value 5
+ verify(readFn, times(0)).readAsync(anyInt(), any());
+ assertEquals(5, table.readAsync(1, 2).join());
+ verify(readFn, times(1)).readAsync(anyInt(), any());
+
+ table.close();
+ }
+
+ @Test
+ public void testGetWithoutRetry() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(100));
TableReadFunction readFn = mock(TableReadFunction.class);
@@ -91,11 +134,11 @@
int times = 0;
table.init(TestRemoteTable.getMockContext());
- verify(readFn, times(1)).init(any());
- assertEquals("bar", table.getAsync("foo").get());
+ verify(readFn, times(0)).init(any(), any());
+ assertEquals("bar", table.getAsync("foo").join());
verify(readFn, times(1)).getAsync(any());
assertEquals(++times, table.readRetryMetrics.successCount.getCount());
- assertEquals(result, table.getAllAsync(Arrays.asList("foo")).get());
+ assertEquals(result, table.getAllAsync(Arrays.asList("foo")).join());
verify(readFn, times(1)).getAllAsync(any());
assertEquals(++times, table.readRetryMetrics.successCount.getCount());
assertEquals(0, table.readRetryMetrics.retryCount.getCount());
@@ -107,7 +150,7 @@
}
@Test
- public void testGetWithRetryDisabled() throws Exception {
+ public void testGetWithRetryDisabled() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(10));
policy.withStopAfterDelay(Duration.ofMillis(100));
@@ -121,9 +164,9 @@
table.init(TestRemoteTable.getMockContext());
try {
- table.getAsync("foo").get();
+ table.getAsync("foo").join();
fail();
- } catch (ExecutionException e) {
+ } catch (java.util.concurrent.CompletionException expected) { // narrow on purpose: Throwable would also swallow fail()'s AssertionError
}
verify(readFn, times(1)).getAsync(any());
@@ -134,7 +177,7 @@
}
@Test
- public void testGetAllWithOneRetry() throws Exception {
+ public void testGetAllWithOneRetry() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(10));
TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
@@ -153,13 +196,13 @@
future.completeExceptionally(new RuntimeException("test exception"));
}
return future;
- }).when(readFn).getAllAsync(any());
+ }).when(readFn).getAllAsync(anyCollection());
AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, null);
AsyncRetriableTable table = new AsyncRetriableTable("t1", delegate, policy, null, schedExec, readFn, null);
table.init(TestRemoteTable.getMockContext());
- assertEquals(map, table.getAllAsync(Arrays.asList("foo1", "foo2")).get());
+ assertEquals(map, table.getAllAsync(Arrays.asList("foo1", "foo2")).join());
verify(readFn, times(2)).getAllAsync(any());
assertEquals(1, table.readRetryMetrics.retryCount.getCount());
assertEquals(0, table.readRetryMetrics.successCount.getCount());
@@ -168,7 +211,7 @@
}
@Test
- public void testGetWithPermFailureOnTimeout() throws Exception {
+ public void testGetWithPermFailureOnTimeout() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(5));
policy.withStopAfterDelay(Duration.ofMillis(100));
@@ -182,9 +225,9 @@
table.init(TestRemoteTable.getMockContext());
try {
- table.getAsync("foo").get();
+ table.getAsync("foo").join();
fail();
- } catch (ExecutionException e) {
+ } catch (java.util.concurrent.CompletionException expected) { // narrow on purpose: Throwable would also swallow fail()'s AssertionError
}
verify(readFn, atLeast(3)).getAsync(any());
@@ -195,7 +238,7 @@
}
@Test
- public void testGetWithPermFailureOnMaxCount() throws Exception {
+ public void testGetWithPermFailureOnMaxCount() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(5));
policy.withStopAfterAttempts(10);
@@ -209,9 +252,9 @@
table.init(TestRemoteTable.getMockContext());
try {
- table.getAsync("foo").get();
+ table.getAsync("foo").join();
fail();
- } catch (ExecutionException e) {
+ } catch (java.util.concurrent.CompletionException expected) { // narrow on purpose: Throwable would also swallow fail()'s AssertionError
}
verify(readFn, atLeast(11)).getAsync(any());
@@ -222,7 +265,68 @@
}
@Test
- public void testPutWithoutRetry() throws Exception {
+ public void testPutAndDeleteDelegation() {
+ TableRetryPolicy policy = new TableRetryPolicy();
+ policy.withFixedBackoff(Duration.ofMillis(100));
+ TableReadFunction readFn = mock(TableReadFunction.class);
+ TableWriteFunction writeFn = mock(TableWriteFunction.class);
+ doReturn(true).when(writeFn).isRetriable(any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any(), any());
+ doReturn(CompletableFuture.completedFuture(null)).when(writeFn).writeAsync(anyInt(), any());
+ AsyncReadWriteTable delegate = new AsyncRemoteTable(readFn, writeFn);
+ AsyncRetriableTable table = new AsyncRetriableTable("t1", delegate, null, policy, schedExec, readFn, writeFn);
+
+ // PutAsync: the two-argument and the three-argument overloads must each reach writeFn exactly once
+ verify(writeFn, times(0)).putAsync(any(), any());
+ verify(writeFn, times(0)).putAsync(any(), any(), any());
+ table.putAsync(1, 2).join();
+ verify(writeFn, times(1)).putAsync(any(), any());
+ verify(writeFn, times(0)).putAsync(any(), any(), any());
+ table.putAsync(1, 2, 3).join();
+ verify(writeFn, times(1)).putAsync(any(), any());
+ verify(writeFn, times(1)).putAsync(any(), any(), any());
+ // PutAllAsync: same per-overload check for the batch writes
+ verify(writeFn, times(0)).putAllAsync(anyCollection());
+ verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+ table.putAllAsync(Arrays.asList(1)).join();
+ verify(writeFn, times(1)).putAllAsync(anyCollection());
+ verify(writeFn, times(0)).putAllAsync(anyCollection(), any());
+ table.putAllAsync(Arrays.asList(1), Arrays.asList(1)).join();
+ verify(writeFn, times(1)).putAllAsync(anyCollection());
+ verify(writeFn, times(1)).putAllAsync(anyCollection(), any());
+ // DeleteAsync: the one-argument and the two-argument overloads must each reach writeFn exactly once
+ verify(writeFn, times(0)).deleteAsync(any());
+ verify(writeFn, times(0)).deleteAsync(any(), any());
+ table.deleteAsync(1).join();
+ verify(writeFn, times(1)).deleteAsync(any());
+ verify(writeFn, times(0)).deleteAsync(any(), any());
+ table.deleteAsync(1, 2).join();
+ verify(writeFn, times(1)).deleteAsync(any());
+ verify(writeFn, times(1)).deleteAsync(any(), any());
+ // DeleteAllAsync: same per-overload check for the batch deletes
+ verify(writeFn, times(0)).deleteAllAsync(anyCollection());
+ verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+ table.deleteAllAsync(Arrays.asList(1)).join();
+ verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+ verify(writeFn, times(0)).deleteAllAsync(anyCollection(), any());
+ table.deleteAllAsync(Arrays.asList(1), Arrays.asList(2)).join();
+ verify(writeFn, times(1)).deleteAllAsync(anyCollection());
+ verify(writeFn, times(1)).deleteAllAsync(anyCollection(), any());
+ // WriteAsync: a single writeAsync call must reach writeFn once
+ verify(writeFn, times(0)).writeAsync(anyInt(), any());
+ table.writeAsync(1, 2).join();
+ verify(writeFn, times(1)).writeAsync(anyInt(), any());
+ }
+
+ @Test
+ public void testPutWithoutRetry() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(100));
TableReadFunction readFn = mock(TableReadFunction.class);
@@ -237,18 +341,18 @@
int times = 0;
table.init(TestRemoteTable.getMockContext());
- verify(readFn, times(1)).init(any());
- verify(writeFn, times(1)).init(any());
- table.putAsync("foo", "bar").get();
+ verify(readFn, times(0)).init(any(), any());
+ verify(writeFn, times(0)).init(any(), any());
+ table.putAsync("foo", "bar").join();
verify(writeFn, times(1)).putAsync(any(), any());
assertEquals(++times, table.writeRetryMetrics.successCount.getCount());
- table.putAllAsync(Arrays.asList(new Entry("1", "2"))).get();
+ table.putAllAsync(Arrays.asList(new Entry("1", "2"))).join();
verify(writeFn, times(1)).putAllAsync(any());
assertEquals(++times, table.writeRetryMetrics.successCount.getCount());
- table.deleteAsync("1").get();
+ table.deleteAsync("1").join();
verify(writeFn, times(1)).deleteAsync(any());
assertEquals(++times, table.writeRetryMetrics.successCount.getCount());
- table.deleteAllAsync(Arrays.asList("1", "2")).get();
+ table.deleteAllAsync(Arrays.asList("1", "2")).join();
verify(writeFn, times(1)).deleteAllAsync(any());
assertEquals(++times, table.writeRetryMetrics.successCount.getCount());
assertEquals(0, table.writeRetryMetrics.retryCount.getCount());
@@ -258,7 +362,7 @@
}
@Test
- public void testPutWithRetryDisabled() throws Exception {
+ public void testPutWithRetryDisabled() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(10));
policy.withStopAfterDelay(Duration.ofMillis(100));
@@ -273,9 +377,9 @@
table.init(TestRemoteTable.getMockContext());
try {
- table.putAsync("foo", "bar").get();
+ table.putAsync("foo", "bar").join();
fail();
- } catch (ExecutionException e) {
+ } catch (java.util.concurrent.CompletionException expected) { // narrow on purpose: Throwable would also swallow fail()'s AssertionError
}
verify(writeFn, times(1)).putAsync(any(), any());
@@ -286,7 +390,7 @@
}
@Test
- public void testPutAllWithOneRetry() throws Exception {
+ public void testPutAllWithOneRetry() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(10));
TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
@@ -309,7 +413,7 @@
AsyncRetriableTable table = new AsyncRetriableTable("t1", delegate, null, policy, schedExec, readFn, writeFn);
table.init(TestRemoteTable.getMockContext());
- table.putAllAsync(Arrays.asList(new Entry(1, 2))).get();
+ table.putAllAsync(Arrays.asList(new Entry(1, 2))).join();
verify(writeFn, times(2)).putAllAsync(any());
assertEquals(1, table.writeRetryMetrics.retryCount.getCount());
assertEquals(0, table.writeRetryMetrics.successCount.getCount());
@@ -318,7 +422,7 @@
}
@Test
- public void testPutWithPermFailureOnTimeout() throws Exception {
+ public void testPutWithPermFailureOnTimeout() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(5));
policy.withStopAfterDelay(Duration.ofMillis(100));
@@ -333,9 +437,9 @@
table.init(TestRemoteTable.getMockContext());
try {
- table.putAsync("foo", "bar").get();
+ table.putAsync("foo", "bar").join();
fail();
- } catch (ExecutionException e) {
+ } catch (java.util.concurrent.CompletionException expected) { // narrow on purpose: Throwable would also swallow fail()'s AssertionError
}
verify(writeFn, atLeast(3)).putAsync(any(), any());
@@ -346,7 +450,7 @@
}
@Test
- public void testPutWithPermFailureOnMaxCount() throws Exception {
+ public void testPutWithPermFailureOnMaxCount() {
TableRetryPolicy policy = new TableRetryPolicy();
policy.withFixedBackoff(Duration.ofMillis(5));
policy.withStopAfterAttempts(10);
@@ -361,9 +465,9 @@
table.init(TestRemoteTable.getMockContext());
try {
- table.putAllAsync(Arrays.asList(new Entry(1, 2))).get();
+ table.putAllAsync(Arrays.asList(new Entry(1, 2))).join();
fail();
- } catch (ExecutionException e) {
+ } catch (java.util.concurrent.CompletionException expected) { // narrow on purpose: Throwable would also swallow fail()'s AssertionError
}
verify(writeFn, atLeast(11)).putAllAsync(any());
diff --git a/samza-core/src/test/java/org/apache/samza/table/retry/TestRetriableTableFunctions.java b/samza-core/src/test/java/org/apache/samza/table/retry/TestRetriableTableFunctions.java
deleted file mode 100644
index 34aac9f..0000000
--- a/samza-core/src/test/java/org/apache/samza/table/retry/TestRetriableTableFunctions.java
+++ /dev/null
@@ -1,312 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.samza.table.retry;
-
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
-import junit.framework.Assert;
-import org.apache.samza.context.Context;
-import org.apache.samza.storage.kv.Entry;
-import org.apache.samza.table.Table;
-import org.apache.samza.table.remote.TableReadFunction;
-import org.apache.samza.table.remote.TableWriteFunction;
-import org.apache.samza.table.remote.TestRemoteTable;
-import org.apache.samza.table.utils.TableMetricsUtil;
-import org.junit.Test;
-
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyString;
-import static org.mockito.Mockito.atLeast;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-
-
-public class TestRetriableTableFunctions {
-
- private final ScheduledExecutorService schedExec = Executors.newSingleThreadScheduledExecutor();
-
- public TableMetricsUtil getMetricsUtil(String tableId) {
- Table table = mock(Table.class);
- Context context = TestRemoteTable.getMockContext();
- return new TableMetricsUtil(context, table, tableId);
- }
-
- @Test
- public void testFirstTimeSuccessGet() throws Exception {
- String tableId = "testFirstTimeSuccessGet";
- TableRetryPolicy policy = new TableRetryPolicy();
- policy.withFixedBackoff(Duration.ofMillis(100));
- TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
- doReturn(true).when(readFn).isRetriable(any());
- doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(anyString());
- RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
- retryIO.setMetrics(getMetricsUtil(tableId));
- Assert.assertEquals("bar", retryIO.getAsync("foo").get());
- verify(readFn, times(1)).getAsync(anyString());
-
- Assert.assertEquals(0, retryIO.retryMetrics.retryCount.getCount());
- Assert.assertEquals(1, retryIO.retryMetrics.successCount.getCount());
- Assert.assertEquals(0, retryIO.retryMetrics.retryTimer.getSnapshot().getMax());
- }
-
- @Test
- public void testRetryEngagedGet() throws Exception {
- String tableId = "testRetryEngagedGet";
- TableRetryPolicy policy = new TableRetryPolicy();
- policy.withFixedBackoff(Duration.ofMillis(10));
- TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
- doReturn(true).when(readFn).isRetriable(any());
-
- int[] times = {0};
- Map<String, String> map = new HashMap<>();
- map.put("foo1", "bar1");
- map.put("foo2", "bar2");
- doAnswer(invocation -> {
- CompletableFuture<Map<String, String>> future = new CompletableFuture();
- if (times[0] > 0) {
- future.complete(map);
- } else {
- times[0]++;
- future.completeExceptionally(new RuntimeException("test exception"));
- }
- return future;
- }).when(readFn).getAllAsync(any());
-
- RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
- retryIO.setMetrics(getMetricsUtil(tableId));
- Assert.assertEquals(map, retryIO.getAllAsync(Arrays.asList("foo1", "foo2")).get());
- verify(readFn, times(2)).getAllAsync(any());
-
- Assert.assertEquals(1, retryIO.retryMetrics.retryCount.getCount());
- Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
- Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
- }
-
- @Test
- public void testRetryExhaustedTimeGet() throws Exception {
- String tableId = "testRetryExhaustedTime";
- TableRetryPolicy policy = new TableRetryPolicy();
- policy.withFixedBackoff(Duration.ofMillis(5));
- policy.withStopAfterDelay(Duration.ofMillis(100));
- TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
- doReturn(true).when(readFn).isRetriable(any());
-
- CompletableFuture<String> future = new CompletableFuture();
- future.completeExceptionally(new RuntimeException("test exception"));
- doReturn(future).when(readFn).getAsync(anyString());
-
- RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
- retryIO.setMetrics(getMetricsUtil(tableId));
- try {
- retryIO.getAsync("foo").get();
- Assert.fail();
- } catch (ExecutionException e) {
- }
-
- // Conservatively: must be at least 3 attempts with 5ms backoff and 100ms maxDelay
- verify(readFn, atLeast(3)).getAsync(anyString());
- Assert.assertTrue(retryIO.retryMetrics.retryCount.getCount() >= 3);
- Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
- Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
- }
-
- @Test
- public void testRetryExhaustedAttemptsGet() throws Exception {
- String tableId = "testRetryExhaustedAttempts";
- TableRetryPolicy policy = new TableRetryPolicy();
- policy.withFixedBackoff(Duration.ofMillis(5));
- policy.withStopAfterAttempts(10);
- TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
- doReturn(true).when(readFn).isRetriable(any());
-
- CompletableFuture<String> future = new CompletableFuture();
- future.completeExceptionally(new RuntimeException("test exception"));
- doReturn(future).when(readFn).getAllAsync(any());
-
- RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
- retryIO.setMetrics(getMetricsUtil(tableId));
- try {
- retryIO.getAllAsync(Arrays.asList("foo1", "foo2")).get();
- Assert.fail();
- } catch (ExecutionException e) {
- }
-
- // 1 initial try + 10 retries
- verify(readFn, times(11)).getAllAsync(any());
- Assert.assertEquals(10, retryIO.retryMetrics.retryCount.getCount());
- Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
- Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
- }
-
- @Test
- public void testFirstTimeSuccessPut() throws Exception {
- String tableId = "testFirstTimeSuccessPut";
- TableRetryPolicy policy = new TableRetryPolicy();
- policy.withFixedBackoff(Duration.ofMillis(100));
- TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
- doReturn(true).when(writeFn).isRetriable(any());
- doReturn(CompletableFuture.completedFuture("bar")).when(writeFn).putAsync(anyString(), anyString());
- RetriableWriteFunction<String, String> retryIO = new RetriableWriteFunction<>(policy, writeFn, schedExec);
- retryIO.setMetrics(getMetricsUtil(tableId));
- retryIO.putAsync("foo", "bar").get();
- verify(writeFn, times(1)).putAsync(anyString(), anyString());
-
- Assert.assertEquals(0, retryIO.retryMetrics.retryCount.getCount());
- Assert.assertEquals(1, retryIO.retryMetrics.successCount.getCount());
- Assert.assertEquals(0, retryIO.retryMetrics.retryTimer.getSnapshot().getMax());
- }
-
- @Test
- public void testRetryEngagedPut() throws Exception {
- String tableId = "testRetryEngagedPut";
- TableRetryPolicy policy = new TableRetryPolicy();
- policy.withFixedBackoff(Duration.ofMillis(10));
- TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
- doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any());
- doReturn(true).when(writeFn).isRetriable(any());
-
- int[] times = new int[] {0};
- List<Entry<String, String>> records = new ArrayList<>();
- records.add(new Entry<>("foo1", "bar1"));
- records.add(new Entry<>("foo2", "bar2"));
- doAnswer(invocation -> {
- CompletableFuture<Map<String, String>> future = new CompletableFuture();
- if (times[0] > 0) {
- future.complete(null);
- } else {
- times[0]++;
- future.completeExceptionally(new RuntimeException("test exception"));
- }
- return future;
- }).when(writeFn).putAllAsync(any());
-
- RetriableWriteFunction<String, String> retryIO = new RetriableWriteFunction<>(policy, writeFn, schedExec);
- retryIO.setMetrics(getMetricsUtil(tableId));
- retryIO.putAllAsync(records).get();
- verify(writeFn, times(2)).putAllAsync(any());
-
- Assert.assertEquals(1, retryIO.retryMetrics.retryCount.getCount());
- Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
- Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
- }
-
- @Test
- public void testRetryExhaustedTimePut() throws Exception {
- String tableId = "testRetryExhaustedTimePut";
- TableRetryPolicy policy = new TableRetryPolicy();
- policy.withFixedBackoff(Duration.ofMillis(5));
- policy.withStopAfterDelay(Duration.ofMillis(100));
- TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
- doReturn(true).when(writeFn).isRetriable(any());
-
- CompletableFuture<String> future = new CompletableFuture();
- future.completeExceptionally(new RuntimeException("test exception"));
- doReturn(future).when(writeFn).deleteAsync(anyString());
-
- RetriableWriteFunction<String, String> retryIO = new RetriableWriteFunction<>(policy, writeFn, schedExec);
- retryIO.setMetrics(getMetricsUtil(tableId));
- try {
- retryIO.deleteAsync("foo").get();
- Assert.fail();
- } catch (ExecutionException e) {
- }
-
- // Conservatively: must be at least 3 attempts with 5ms backoff and 100ms maxDelay
- verify(writeFn, atLeast(3)).deleteAsync(anyString());
- Assert.assertTrue(retryIO.retryMetrics.retryCount.getCount() >= 3);
- Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
- Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
- }
-
- @Test
- public void testRetryExhaustedAttemptsPut() throws Exception {
- String tableId = "testRetryExhaustedAttemptsPut";
- TableRetryPolicy policy = new TableRetryPolicy();
- policy.withFixedBackoff(Duration.ofMillis(5));
- policy.withStopAfterAttempts(10);
- TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class);
- doReturn(true).when(writeFn).isRetriable(any());
-
- CompletableFuture<String> future = new CompletableFuture();
- future.completeExceptionally(new RuntimeException("test exception"));
- doReturn(future).when(writeFn).deleteAllAsync(any());
-
- RetriableWriteFunction<String, String> retryIO = new RetriableWriteFunction<>(policy, writeFn, schedExec);
- retryIO.setMetrics(getMetricsUtil(tableId));
- try {
- retryIO.deleteAllAsync(Arrays.asList("foo1", "foo2")).get();
- Assert.fail();
- } catch (ExecutionException e) {
- }
-
- // 1 initial try + 10 retries
- verify(writeFn, times(11)).deleteAllAsync(any());
- Assert.assertEquals(10, retryIO.retryMetrics.retryCount.getCount());
- Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
- Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
- }
-
- @Test
- public void testMixedIsRetriablePredicates() throws Exception {
- String tableId = "testMixedIsRetriablePredicates";
- TableRetryPolicy policy = new TableRetryPolicy();
- policy.withFixedBackoff(Duration.ofMillis(100));
-
- // Retry should be attempted based on the custom classification, ie. retry on NPE
- policy.withRetryPredicate(ex -> ex instanceof NullPointerException);
-
- TableReadFunction<String, String> readFn = mock(TableReadFunction.class);
-
- // Table fn classification only retries on IllegalArgumentException
- doAnswer(arg -> arg.getArgumentAt(0, Throwable.class) instanceof IllegalArgumentException).when(readFn).isRetriable(any());
-
- int[] times = new int[1];
- doAnswer(arg -> {
- if (times[0]++ == 0) {
- CompletableFuture<String> future = new CompletableFuture();
- future.completeExceptionally(new NullPointerException("test exception"));
- return future;
- } else {
- return CompletableFuture.completedFuture("bar");
- }
- }).when(readFn).getAsync(any());
-
- RetriableReadFunction<String, String> retryIO = new RetriableReadFunction<>(policy, readFn, schedExec);
- retryIO.setMetrics(getMetricsUtil(tableId));
-
- Assert.assertEquals("bar", retryIO.getAsync("foo").get());
-
- verify(readFn, times(2)).getAsync(anyString());
- Assert.assertEquals(1, retryIO.retryMetrics.retryCount.getCount());
- Assert.assertEquals(0, retryIO.retryMetrics.successCount.getCount());
- Assert.assertTrue(retryIO.retryMetrics.retryTimer.getSnapshot().getMax() > 0);
- }
-}
diff --git a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
index 0c009a1..4c53ec2 100644
--- a/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
+++ b/samza-core/src/test/scala/org/apache/samza/checkpoint/TestOffsetManager.scala
@@ -22,12 +22,12 @@
import java.util
import java.util.function.BiConsumer
-import org.apache.samza.{Partition, SamzaException}
import org.apache.samza.config.MapConfig
import org.apache.samza.container.TaskName
import org.apache.samza.startpoint.{StartpointManagerTestUtil, StartpointOldest, StartpointUpcoming}
import org.apache.samza.system.SystemStreamMetadata.{OffsetType, SystemStreamPartitionMetadata}
import org.apache.samza.system._
+import org.apache.samza.{Partition, SamzaException}
import org.junit.Assert._
import org.junit.Test
import org.mockito.Matchers._
@@ -67,13 +67,14 @@
val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
val config = new MapConfig
val checkpointManager = getCheckpointManager(systemStreamPartition, taskName)
- val startpointManager = getStartpointManager()
+ val startpointManagerUtil = getStartpointManagerUtil
val systemAdmins = mock(classOf[SystemAdmins])
when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
- val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins, Map(), new OffsetManagerMetrics)
+ val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
offsetManager.register(taskName, Set(systemStreamPartition))
- startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
+ startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+ assertTrue(startpointManagerUtil.getStartpointManager.readStartpoint(systemStreamPartition, taskName).isPresent)
+ assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName, systemStreamPartition)).keySet().contains(taskName))
offsetManager.start
assertTrue(checkpointManager.isStarted)
assertEquals(1, checkpointManager.registered.size)
@@ -90,11 +91,19 @@
assertEquals("46", offsetManager.getStartingOffset(taskName, systemStreamPartition).get)
// Should not update null offset
offsetManager.update(taskName, systemStreamPartition, null)
+ assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
checkpoint(offsetManager, taskName)
- assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should delete after checkpoint commit
+ intercept[IllegalStateException] {
+ // StartpointManager should stop after last fan out is removed
+ startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+ }
+ startpointManagerUtil.getStartpointManager.start
+ assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should delete after checkpoint commit
val expectedCheckpoint = new Checkpoint(Map(systemStreamPartition -> "47").asJava)
assertEquals(expectedCheckpoint, checkpointManager.readLastCheckpoint(taskName))
+ startpointManagerUtil.stop
}
+
@Test
def testGetAndSetStartpoint {
val taskName1 = new TaskName("c")
@@ -106,25 +115,23 @@
val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
val config = new MapConfig
val checkpointManager = getCheckpointManager(systemStreamPartition, taskName1)
- val startpointManager = getStartpointManager()
+ val startpointManagerUtil = getStartpointManagerUtil()
val systemAdmins = mock(classOf[SystemAdmins])
when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
- val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins, Map(), new OffsetManagerMetrics)
+ val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
offsetManager.register(taskName1, Set(systemStreamPartition))
val startpoint1 = new StartpointOldest
- startpointManager.writeStartpoint(systemStreamPartition, taskName1, startpoint1)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName1))
+ startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName1, startpoint1)
+ assertTrue(startpointManagerUtil.getStartpointManager.readStartpoint(systemStreamPartition, taskName1).isPresent)
+ assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName1, systemStreamPartition)).keySet().contains(taskName1))
offsetManager.start
val startpoint2 = new StartpointUpcoming
offsetManager.setStartpoint(taskName2, systemStreamPartition, startpoint2)
assertEquals(Option(startpoint1), offsetManager.getStartpoint(taskName1, systemStreamPartition))
assertEquals(Option(startpoint2), offsetManager.getStartpoint(taskName2, systemStreamPartition))
-
- assertEquals(startpoint1, startpointManager.readStartpoint(systemStreamPartition, taskName1))
- // Startpoint written to offset manager, but not directly to startpoint manager.
- assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName2))
+ startpointManagerUtil.stop
}
@Test
@@ -137,34 +144,36 @@
val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
val config = new MapConfig
val checkpointManager = getCheckpointManager(systemStreamPartition, taskName)
- val startpointManager = getStartpointManager()
+ val startpointManagerUtil = getStartpointManagerUtil()
val systemAdmins = mock(classOf[SystemAdmins])
when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
- val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins, Map(), new OffsetManagerMetrics)
+ val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
offsetManager.register(taskName, Set(systemStreamPartition))
// Pre-populate startpoint
- startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+ startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+ assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName, systemStreamPartition)).keySet().contains(taskName))
offsetManager.start
// Should get offset 45 back from the checkpoint manager, which is last processed, and system admin should return 46 as starting offset.
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
+ assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
checkpoint(offsetManager, taskName)
- assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should delete after checkpoint commit
+ intercept[IllegalStateException] {
+ // StartpointManager should stop after last fan out is removed
+ startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+ }
+ startpointManagerUtil.getStartpointManager.start
+ assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should be deleted after checkpoint commit
assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
offsetManager.update(taskName, systemStreamPartition, "46")
offsetManager.update(taskName, systemStreamPartition, "47")
- startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
checkpoint(offsetManager, taskName)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
offsetManager.update(taskName, systemStreamPartition, "48")
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
checkpoint(offsetManager, taskName)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
assertEquals("48", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
+ startpointManagerUtil.stop
}
@Test
@@ -207,22 +216,21 @@
val checkpoint = new Checkpoint(Map(systemStreamPartition1 -> "45").asJava)
// Checkpoint manager only has partition 1.
val checkpointManager = getCheckpointManager(systemStreamPartition1, taskName1)
- val startpointManager = getStartpointManager()
+ val startpointManagerUtil = getStartpointManagerUtil()
+
val config = new MapConfig
val systemAdmins = mock(classOf[SystemAdmins])
when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
- val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins)
+ val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins)
// Register both partitions. Partition 2 shouldn't have a checkpoint.
offsetManager.register(taskName1, Set(systemStreamPartition1))
offsetManager.register(taskName2, Set(systemStreamPartition2))
- startpointManager.writeStartpoint(systemStreamPartition1, taskName1, new StartpointOldest)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition1, taskName1))
offsetManager.start
assertTrue(checkpointManager.isStarted)
assertEquals(2, checkpointManager.registered.size)
assertEquals(checkpoint, checkpointManager.readLastCheckpoint(taskName1))
assertNull(checkpointManager.readLastCheckpoint(taskName2))
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition1, taskName1)) // no checkpoint commit so this should still be there
+ startpointManagerUtil.stop
}
@Test
@@ -314,7 +322,7 @@
val checkpointManager = getCheckpointManager1(systemStreamPartition,
new Checkpoint(Map(systemStreamPartition -> "45", systemStreamPartition2 -> "100").asJava),
taskName)
- val startpointManager = getStartpointManager()
+ val startpointManagerUtil = getStartpointManagerUtil()
val consumer = new SystemConsumerWithCheckpointCallback
val systemAdmins = mock(classOf[SystemAdmins])
when(systemAdmins.getSystemAdmin(systemName)).thenReturn(getSystemAdmin)
@@ -325,16 +333,26 @@
else
Map()
- val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, getStartpointManager(), systemAdmins,
+ val offsetManager = OffsetManager(systemStreamMetadata, config, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins,
checkpointListeners, new OffsetManagerMetrics)
offsetManager.register(taskName, Set(systemStreamPartition, systemStreamPartition2))
- startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
+ startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+ assertTrue(startpointManagerUtil.getStartpointManager.readStartpoint(systemStreamPartition, taskName).isPresent)
+ assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName, systemStreamPartition)).keySet().contains(taskName))
+ assertFalse(startpointManagerUtil.getStartpointManager.readStartpoint(systemStreamPartition, taskName).isPresent)
offsetManager.start
// Should get offset 45 back from the checkpoint manager, which is last processed, and system admin should return 46 as starting offset.
+ assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
checkpoint(offsetManager, taskName)
- assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint be deleted at first checkpoint
+
+ intercept[IllegalStateException] {
+ // StartpointManager should stop after last fan out is removed
+ startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+ }
+ startpointManagerUtil.getStartpointManager.start
+ assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should be deleted at first checkpoint
+
assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
assertEquals("100", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
assertEquals("45", consumer.recentCheckpoint.get(systemStreamPartition))
@@ -343,10 +361,7 @@
offsetManager.update(taskName, systemStreamPartition, "46")
offsetManager.update(taskName, systemStreamPartition, "47")
- startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
checkpoint(offsetManager, taskName)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
assertEquals("100", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
assertEquals("47", consumer.recentCheckpoint.get(systemStreamPartition))
@@ -354,15 +369,13 @@
offsetManager.update(taskName, systemStreamPartition, "48")
offsetManager.update(taskName, systemStreamPartition2, "101")
- startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
checkpoint(offsetManager, taskName)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
assertEquals("48", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
assertEquals("101", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition2).getValue)
assertEquals("48", consumer.recentCheckpoint.get(systemStreamPartition))
assertNull(consumer.recentCheckpoint.get(systemStreamPartition2))
offsetManager.stop
+ startpointManagerUtil.stop
}
/**
@@ -380,35 +393,38 @@
val testStreamMetadata = new SystemStreamMetadata(systemStream.getStream, Map(partition -> new SystemStreamPartitionMetadata("0", "1", "2")).asJava)
val systemStreamMetadata = Map(systemStream -> testStreamMetadata)
val checkpointManager = getCheckpointManager(systemStreamPartition, taskName)
- val startpointManager = getStartpointManager()
+ val startpointManagerUtil = getStartpointManagerUtil()
val systemAdmins = mock(classOf[SystemAdmins])
when(systemAdmins.getSystemAdmin("test-system")).thenReturn(getSystemAdmin)
- val offsetManager = OffsetManager(systemStreamMetadata, new MapConfig, checkpointManager, getStartpointManager(), systemAdmins, Map(), new OffsetManagerMetrics)
+ val offsetManager = OffsetManager(systemStreamMetadata, new MapConfig, checkpointManager, startpointManagerUtil.getStartpointManager, systemAdmins, Map(), new OffsetManagerMetrics)
offsetManager.register(taskName, Set(systemStreamPartition))
- startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+ startpointManagerUtil.getStartpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
+ assertTrue(startpointManagerUtil.getStartpointManager.fanOut(asTaskToSSPMap(taskName, systemStreamPartition)).keySet().contains(taskName))
offsetManager.start
// Should get offset 45 back from the checkpoint manager, which is last processed, and system admin should return 46 as starting offset.
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName))
+ assertTrue(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition))
checkpoint(offsetManager, taskName)
- assertNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint be deleted at first checkpoint
+ intercept[IllegalStateException] {
+ // StartpointManager should stop after last fan out is removed
+ startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName)
+ }
+ startpointManagerUtil.getStartpointManager.start
+ assertFalse(startpointManagerUtil.getStartpointManager.getFanOutForTask(taskName).containsKey(systemStreamPartition)) // Startpoint should be deleted at first checkpoint
assertEquals("45", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
- startpointManager.writeStartpoint(systemStreamPartition, taskName, new StartpointOldest)
-
offsetManager.update(taskName, systemStreamPartition, "46")
// Get checkpoint snapshot like we do at the beginning of TaskInstance.commit()
val checkpoint46 = offsetManager.buildCheckpoint(taskName)
offsetManager.update(taskName, systemStreamPartition, "47") // Offset updated before checkpoint
offsetManager.writeCheckpoint(taskName, checkpoint46)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
assertEquals(Some("47"), offsetManager.getLastProcessedOffset(taskName, systemStreamPartition))
assertEquals("46", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
// Now write the checkpoint for the latest offset
val checkpoint47 = offsetManager.buildCheckpoint(taskName)
offsetManager.writeCheckpoint(taskName, checkpoint47)
- assertNotNull(startpointManager.readStartpoint(systemStreamPartition, taskName)) // Startpoint should only be deleted at first checkpoint
+ startpointManagerUtil.stop
assertEquals(Some("47"), offsetManager.getLastProcessedOffset(taskName, systemStreamPartition))
assertEquals("47", offsetManager.offsetManagerMetrics.checkpointedOffsets.get(systemStreamPartition).getValue)
}
@@ -523,10 +539,9 @@
}
}
- private def getStartpointManager() = {
- val startpointManager = StartpointManagerTestUtil.getStartpointManager()
- startpointManager.start
- startpointManager
+ private def getStartpointManagerUtil() = {
+ val startpointManagerUtil = new StartpointManagerTestUtil
+ startpointManagerUtil
}
private def getSystemAdmin: SystemAdmin = {
@@ -540,4 +555,8 @@
override def offsetComparator(offset1: String, offset2: String) = null
}
}
+
+ private def asTaskToSSPMap(taskName: TaskName, ssps: SystemStreamPartition*) = {
+ Map(taskName -> ssps.toSet.asJava).asJava
+ }
}
diff --git a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/BaseCouchbaseTableFunction.java b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/BaseCouchbaseTableFunction.java
index 813fb97..e805975 100644
--- a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/BaseCouchbaseTableFunction.java
+++ b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/BaseCouchbaseTableFunction.java
@@ -22,24 +22,26 @@
import com.couchbase.client.java.Bucket;
import com.couchbase.client.java.error.TemporaryFailureException;
import com.couchbase.client.java.error.TemporaryLockFailureException;
+
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
-import java.io.Serializable;
+
import java.time.Duration;
import java.util.List;
+
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.samza.context.Context;
-import org.apache.samza.operators.functions.ClosableFunction;
-import org.apache.samza.operators.functions.InitableFunction;
import org.apache.samza.serializers.Serde;
+import org.apache.samza.table.AsyncReadWriteTable;
+import org.apache.samza.table.remote.BaseTableFunction;
/**
* Base class for {@link CouchbaseTableReadFunction} and {@link CouchbaseTableWriteFunction}
* @param <V> Type of values to read from / write to couchbase
*/
-public abstract class BaseCouchbaseTableFunction<V> implements InitableFunction, ClosableFunction, Serializable {
+public abstract class BaseCouchbaseTableFunction<V> extends BaseTableFunction {
// Clients
private final static CouchbaseBucketRegistry COUCHBASE_BUCKET_REGISTRY = new CouchbaseBucketRegistry();
@@ -75,11 +77,9 @@
environmentConfigs = new CouchbaseEnvironmentConfigs();
}
- /**
- * Helper method to initialize {@link Bucket}.
- */
@Override
- public void init(Context context) {
+ public void init(Context context, AsyncReadWriteTable table) {
+ super.init(context, table);
bucket = COUCHBASE_BUCKET_REGISTRY.getBucket(bucketName, clusterNodes, environmentConfigs);
}
diff --git a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableReadFunction.java b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableReadFunction.java
index ae758a5..c05ec55 100644
--- a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableReadFunction.java
+++ b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableReadFunction.java
@@ -43,16 +43,23 @@
import com.couchbase.client.java.document.Document;
import com.couchbase.client.java.document.JsonDocument;
import com.couchbase.client.java.document.json.JsonObject;
+
import com.google.common.base.Preconditions;
+
import java.util.NoSuchElementException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
+
import org.apache.commons.lang3.StringUtils;
+
import org.apache.samza.SamzaException;
import org.apache.samza.context.Context;
+import org.apache.samza.table.AsyncReadWriteTable;
import org.apache.samza.table.remote.TableReadFunction;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+
import rx.Single;
import rx.SingleSubscriber;
@@ -66,6 +73,7 @@
*/
public class CouchbaseTableReadFunction<V> extends BaseCouchbaseTableFunction<V>
implements TableReadFunction<String, V> {
+
private static final Logger LOGGER = LoggerFactory.getLogger(CouchbaseTableReadFunction.class);
protected final Class<? extends Document<?>> documentType;
@@ -83,8 +91,8 @@
}
@Override
- public void init(Context context) {
- super.init(context);
+ public void init(Context context, AsyncReadWriteTable table) {
+ super.init(context, table);
LOGGER.info("Read function for bucket {} initialized successfully", bucketName);
}
@@ -123,10 +131,10 @@
return future;
}
- /**
+ /*
* Helper method to read bytes from binaryDocument and release the buffer.
*/
- private void handleGetAsyncBinaryDocument(BinaryDocument binaryDocument, CompletableFuture<V> future, String key) {
+ protected void handleGetAsyncBinaryDocument(BinaryDocument binaryDocument, CompletableFuture<V> future, String key) {
ByteBuf buffer = binaryDocument.content();
try {
byte[] bytes;
diff --git a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableWriteFunction.java b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableWriteFunction.java
index cbdcb00..96800ae 100644
--- a/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableWriteFunction.java
+++ b/samza-kv-couchbase/src/main/java/org/apache/samza/table/remote/couchbase/CouchbaseTableWriteFunction.java
@@ -24,15 +24,21 @@
import com.couchbase.client.java.document.Document;
import com.couchbase.client.java.document.JsonDocument;
import com.couchbase.client.java.document.json.JsonObject;
+
import com.google.common.base.Preconditions;
+
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
+
import org.apache.commons.lang3.StringUtils;
import org.apache.samza.SamzaException;
import org.apache.samza.context.Context;
+import org.apache.samza.table.AsyncReadWriteTable;
import org.apache.samza.table.remote.TableWriteFunction;
+
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+
import rx.Single;
import rx.SingleSubscriber;
@@ -45,6 +51,7 @@
*/
public class CouchbaseTableWriteFunction<V> extends BaseCouchbaseTableFunction<V>
implements TableWriteFunction<String, V> {
+
private static final Logger LOGGER = LoggerFactory.getLogger(CouchbaseTableWriteFunction.class);
/**
@@ -59,8 +66,8 @@
}
@Override
- public void init(Context context) {
- super.init(context);
+ public void init(Context context, AsyncReadWriteTable table) {
+ super.init(context, table);
LOGGER.info("Write function for bucket {} initialized successfully", bucketName);
}
@@ -83,10 +90,10 @@
String.format("Failed to delete key %s from bucket %s.", key, bucketName));
}
- /**
+ /*
* Helper method for putAsync and deleteAsync to convert Single to CompletableFuture.
*/
- private CompletableFuture<Void> asyncWriteHelper(Single<? extends Document<?>> single, String errorMessage) {
+ protected CompletableFuture<Void> asyncWriteHelper(Single<? extends Document<?>> single, String errorMessage) {
CompletableFuture<Void> future = new CompletableFuture<>();
single.subscribe(new SingleSubscriber<Document>() {
@Override
diff --git a/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableReadFunction.java b/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableReadFunction.java
index 4532fd7..9e388f1 100644
--- a/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableReadFunction.java
+++ b/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableReadFunction.java
@@ -28,16 +28,23 @@
import com.couchbase.client.java.document.json.JsonObject;
import com.couchbase.client.java.error.TemporaryFailureException;
import com.couchbase.client.java.error.TemporaryLockFailureException;
+
import java.util.List;
import java.util.concurrent.TimeUnit;
+
import org.apache.samza.SamzaException;
+import org.apache.samza.context.Context;
import org.apache.samza.serializers.Serde;
import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.table.AsyncReadWriteTable;
+
import org.junit.Test;
import org.junit.runner.RunWith;
+
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
+
import rx.Observable;
import static org.junit.Assert.*;
@@ -176,7 +183,7 @@
CouchbaseEnvironmentConfigs.class)).toReturn(bucket);
CouchbaseTableReadFunction<V> readFunction =
new CouchbaseTableReadFunction<>(DEFAULT_BUCKET_NAME, valueClass, DEFAULT_CLUSTER_NODE).withSerde(serde);
- readFunction.init(null);
+ readFunction.init(mock(Context.class), mock(AsyncReadWriteTable.class));
return readFunction;
}
}
diff --git a/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableWriteFunction.java b/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableWriteFunction.java
index c63472e..e600de8 100644
--- a/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableWriteFunction.java
+++ b/samza-kv-couchbase/src/test/java/org/apache/samza/table/remote/couchbase/TestCouchbaseTableWriteFunction.java
@@ -26,16 +26,23 @@
import com.couchbase.client.java.document.json.JsonObject;
import com.couchbase.client.java.error.TemporaryFailureException;
import com.couchbase.client.java.error.TemporaryLockFailureException;
+
import java.util.List;
import java.util.concurrent.TimeUnit;
+
import org.apache.samza.SamzaException;
+import org.apache.samza.context.Context;
import org.apache.samza.serializers.Serde;
import org.apache.samza.serializers.StringSerde;
+import org.apache.samza.table.AsyncReadWriteTable;
+
import org.junit.Test;
import org.junit.runner.RunWith;
+
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
+
import rx.Observable;
import static org.junit.Assert.*;
@@ -190,9 +197,9 @@
when(bucket.async()).thenReturn(asyncBucket);
PowerMockito.stub(PowerMockito.method(CouchbaseBucketRegistry.class, "getBucket", String.class, List.class,
CouchbaseEnvironmentConfigs.class)).toReturn(bucket);
- CouchbaseTableWriteFunction<V> readFunction =
+ CouchbaseTableWriteFunction<V> writeFunction =
new CouchbaseTableWriteFunction<>(DEFAULT_BUCKET_NAME, valueClass, DEFAULT_CLUSTER_NODE).withSerde(serde);
- readFunction.init(null);
- return readFunction;
+ writeFunction.init(mock(Context.class), mock(AsyncReadWriteTable.class));
+ return writeFunction;
}
}
diff --git a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalTable.java b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalTable.java
index a269452..2a27d18 100644
--- a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalTable.java
+++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalTable.java
@@ -39,7 +39,7 @@
* @param <K> the type of the key in this table
* @param <V> the type of the value in this table
*/
-public class LocalTable<K, V> extends BaseReadWriteTable<K, V> {
+public final class LocalTable<K, V> extends BaseReadWriteTable<K, V> {
protected final KeyValueStore<K, V> kvStore;
@@ -55,7 +55,7 @@
}
@Override
- public V get(K key) {
+ public V get(K key, Object ... args) {
V result = instrument(metrics.numGets, metrics.getNs, () -> kvStore.get(key));
if (result == null) {
incCounter(metrics.numMissedLookups);
@@ -64,10 +64,10 @@
}
@Override
- public CompletableFuture<V> getAsync(K key) {
+ public CompletableFuture<V> getAsync(K key, Object ... args) {
CompletableFuture<V> future = new CompletableFuture();
try {
- future.complete(get(key));
+ future.complete(get(key, args));
} catch (Exception e) {
future.completeExceptionally(e);
}
@@ -75,17 +75,17 @@
}
@Override
- public Map<K, V> getAll(List<K> keys) {
+ public Map<K, V> getAll(List<K> keys, Object ... args) {
Map<K, V> result = instrument(metrics.numGetAlls, metrics.getAllNs, () -> kvStore.getAll(keys));
result.values().stream().filter(Objects::isNull).forEach(v -> incCounter(metrics.numMissedLookups));
return result;
}
@Override
- public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) {
+ public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys, Object ... args) {
CompletableFuture<Map<K, V>> future = new CompletableFuture();
try {
- future.complete(getAll(keys));
+ future.complete(getAll(keys, args));
} catch (Exception e) {
future.completeExceptionally(e);
}
@@ -93,7 +93,7 @@
}
@Override
- public void put(K key, V value) {
+ public void put(K key, V value, Object ... args) {
if (value != null) {
instrument(metrics.numPuts, metrics.putNs, () -> kvStore.put(key, value));
} else {
@@ -102,10 +102,10 @@
}
@Override
- public CompletableFuture<Void> putAsync(K key, V value) {
+ public CompletableFuture<Void> putAsync(K key, V value, Object ... args) {
CompletableFuture<Void> future = new CompletableFuture();
try {
- put(key, value);
+ put(key, value, args);
future.complete(null);
} catch (Exception e) {
future.completeExceptionally(e);
@@ -114,7 +114,7 @@
}
@Override
- public void putAll(List<Entry<K, V>> entries) {
+ public void putAll(List<Entry<K, V>> entries, Object ... args) {
List<Entry<K, V>> toPut = new LinkedList<>();
List<K> toDelete = new LinkedList<>();
entries.forEach(e -> {
@@ -135,10 +135,10 @@
}
@Override
- public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) {
+ public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries, Object ... args) {
CompletableFuture<Void> future = new CompletableFuture();
try {
- putAll(entries);
+ putAll(entries, args);
future.complete(null);
} catch (Exception e) {
future.completeExceptionally(e);
@@ -147,15 +147,15 @@
}
@Override
- public void delete(K key) {
+ public void delete(K key, Object ... args) {
instrument(metrics.numDeletes, metrics.deleteNs, () -> kvStore.delete(key));
}
@Override
- public CompletableFuture<Void> deleteAsync(K key) {
+ public CompletableFuture<Void> deleteAsync(K key, Object ... args) {
CompletableFuture<Void> future = new CompletableFuture();
try {
- delete(key);
+ delete(key, args);
future.complete(null);
} catch (Exception e) {
future.completeExceptionally(e);
@@ -164,15 +164,15 @@
}
@Override
- public void deleteAll(List<K> keys) {
+ public void deleteAll(List<K> keys, Object ... args) {
instrument(metrics.numDeleteAlls, metrics.deleteAllNs, () -> kvStore.deleteAll(keys));
}
@Override
- public CompletableFuture<Void> deleteAllAsync(List<K> keys) {
+ public CompletableFuture<Void> deleteAllAsync(List<K> keys, Object ... args) {
CompletableFuture<Void> future = new CompletableFuture();
try {
- deleteAll(keys);
+ deleteAll(keys, args);
future.complete(null);
} catch (Exception e) {
future.completeExceptionally(e);
diff --git a/samza-sql/src/test/java/org/apache/samza/sql/util/RemoteStoreIOResolverTestFactory.java b/samza-sql/src/test/java/org/apache/samza/sql/util/RemoteStoreIOResolverTestFactory.java
index 7ba50fd..4a1d299 100644
--- a/samza-sql/src/test/java/org/apache/samza/sql/util/RemoteStoreIOResolverTestFactory.java
+++ b/samza-sql/src/test/java/org/apache/samza/sql/util/RemoteStoreIOResolverTestFactory.java
@@ -33,6 +33,7 @@
import org.apache.samza.sql.interfaces.SqlIOResolver;
import org.apache.samza.sql.interfaces.SqlIOResolverFactory;
import org.apache.samza.storage.kv.descriptors.RocksDbTableDescriptor;
+import org.apache.samza.table.remote.BaseTableFunction;
import org.apache.samza.table.remote.TableReadFunction;
import org.apache.samza.table.remote.TableWriteFunction;
@@ -51,7 +52,8 @@
return new TestRemoteStoreIOResolver(config);
}
- public static class InMemoryWriteFunction implements TableWriteFunction<Object, Object> {
+ public static class InMemoryWriteFunction extends BaseTableFunction
+ implements TableWriteFunction<Object, Object> {
@Override
public CompletableFuture<Void> putAsync(Object key, Object record) {
@@ -71,7 +73,8 @@
}
}
- static class InMemoryReadFunction implements TableReadFunction<Object, Object> {
+ static class InMemoryReadFunction extends BaseTableFunction
+ implements TableReadFunction<Object, Object> {
@Override
public CompletableFuture<Object> getAsync(Object key) {
diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestCouchbaseRemoteTableEndToEnd.java b/samza-test/src/test/java/org/apache/samza/test/table/TestCouchbaseRemoteTableEndToEnd.java
index b05adcd..46f91d4 100644
--- a/samza-test/src/test/java/org/apache/samza/test/table/TestCouchbaseRemoteTableEndToEnd.java
+++ b/samza-test/src/test/java/org/apache/samza/test/table/TestCouchbaseRemoteTableEndToEnd.java
@@ -28,10 +28,12 @@
import com.couchbase.client.java.env.DefaultCouchbaseEnvironment;
import com.couchbase.mock.BucketConfiguration;
import com.couchbase.mock.CouchbaseMock;
+
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
+
import org.apache.samza.application.StreamApplication;
import org.apache.samza.config.Config;
import org.apache.samza.config.MapConfig;
@@ -44,11 +46,13 @@
import org.apache.samza.system.descriptors.GenericInputDescriptor;
import org.apache.samza.table.Table;
import org.apache.samza.table.descriptors.RemoteTableDescriptor;
+import org.apache.samza.table.remote.BaseTableFunction;
import org.apache.samza.table.remote.TableReadFunction;
import org.apache.samza.table.remote.couchbase.CouchbaseTableReadFunction;
import org.apache.samza.table.remote.couchbase.CouchbaseTableWriteFunction;
import org.apache.samza.test.harness.IntegrationTestHarness;
import org.apache.samza.test.util.Base64Serializer;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
@@ -132,23 +136,24 @@
GenericInputDescriptor<String> inputDescriptor =
inputSystemDescriptor.getInputDescriptor("User", new NoOpSerde<>());
- CouchbaseTableReadFunction<String> readFunction = new CouchbaseTableReadFunction<>(inputBucketName, String.class,
- "couchbase://127.0.0.1").withBootstrapCarrierDirectPort(couchbaseMock.getCarrierPort(inputBucketName))
+ CouchbaseTableReadFunction<String> readFunction = new CouchbaseTableReadFunction<>(inputBucketName,
+ String.class, "couchbase://127.0.0.1")
+ .withBootstrapCarrierDirectPort(couchbaseMock.getCarrierPort(inputBucketName))
.withBootstrapHttpDirectPort(couchbaseMock.getHttpPort())
.withSerde(new StringSerde());
- CouchbaseTableWriteFunction<JsonObject> writeFunction =
- new CouchbaseTableWriteFunction<>(outputBucketName, JsonObject.class,
- "couchbase://127.0.0.1").withBootstrapCarrierDirectPort(couchbaseMock.getCarrierPort(outputBucketName))
- .withBootstrapHttpDirectPort(couchbaseMock.getHttpPort());
+ CouchbaseTableWriteFunction<JsonObject> writeFunction = new CouchbaseTableWriteFunction<>(outputBucketName,
+ JsonObject.class, "couchbase://127.0.0.1")
+ .withBootstrapCarrierDirectPort(couchbaseMock.getCarrierPort(outputBucketName))
+ .withBootstrapHttpDirectPort(couchbaseMock.getHttpPort());
- RemoteTableDescriptor<String, String> inputTableDesc = new RemoteTableDescriptor<>("input-table");
- inputTableDesc.withReadFunction(readFunction).withRateLimiterDisabled();
+ RemoteTableDescriptor inputTableDesc = new RemoteTableDescriptor<String, String>("input-table")
+ .withReadFunction(readFunction)
+ .withRateLimiterDisabled();
Table<KV<String, String>> inputTable = appDesc.getTable(inputTableDesc);
- RemoteTableDescriptor<String, JsonObject> outputTableDesc = new RemoteTableDescriptor<>("output-table");
- outputTableDesc.withReadFunction(new DummyReadFunction<>())
- .withWriteFunction(writeFunction)
+ RemoteTableDescriptor outputTableDesc = new RemoteTableDescriptor<String, JsonObject>("output-table")
+ .withReadFunction(new DummyReadFunction<>()).withWriteFunction(writeFunction)
.withRateLimiterDisabled();
Table<KV<String, JsonObject>> outputTable = appDesc.getTable(outputTableDesc);
@@ -189,7 +194,7 @@
}
}
- static class DummyReadFunction<K, V> implements TableReadFunction<K, V> {
+ static class DummyReadFunction<K, V> extends BaseTableFunction implements TableReadFunction<K, V> {
@Override
public CompletableFuture<V> getAsync(K key) {
return null;
diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTableEndToEnd.java b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTableEndToEnd.java
index 96f99ca..c48f46c 100644
--- a/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTableEndToEnd.java
+++ b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTableEndToEnd.java
@@ -32,6 +32,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -42,11 +43,14 @@
import org.apache.samza.config.MapConfig;
import org.apache.samza.context.Context;
import org.apache.samza.context.MockContext;
+import org.apache.samza.operators.TableImpl;
+import org.apache.samza.operators.functions.MapFunction;
import org.apache.samza.system.descriptors.GenericInputDescriptor;
import org.apache.samza.metrics.Counter;
import org.apache.samza.metrics.MetricsRegistry;
import org.apache.samza.metrics.Timer;
import org.apache.samza.operators.KV;
+import org.apache.samza.table.ReadWriteTable;
import org.apache.samza.table.descriptors.TableDescriptor;
import org.apache.samza.system.descriptors.DelegatingSystemDescriptor;
import org.apache.samza.runtime.LocalApplicationRunner;
@@ -54,6 +58,7 @@
import org.apache.samza.table.Table;
import org.apache.samza.table.descriptors.CachingTableDescriptor;
import org.apache.samza.table.descriptors.GuavaCacheTableDescriptor;
+import org.apache.samza.table.remote.BaseTableFunction;
import org.apache.samza.table.remote.RemoteTable;
import org.apache.samza.table.descriptors.RemoteTableDescriptor;
import org.apache.samza.table.remote.TableRateLimiter;
@@ -74,20 +79,21 @@
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyString;
-import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.withSettings;
+import static org.mockito.Mockito.*;
public class TestRemoteTableEndToEnd extends IntegrationTestHarness {
+ static Map<String, AtomicInteger> counters = new HashMap<>();
static Map<String, List<EnrichedPageView>> writtenRecords = new HashMap<>();
- static class InMemoryReadFunction implements TableReadFunction<Integer, Profile> {
+ static class InMemoryProfileReadFunction extends BaseTableFunction
+ implements TableReadFunction<Integer, Profile> {
+
private final String serializedProfiles;
private transient Map<Integer, Profile> profileMap;
- private InMemoryReadFunction(String profiles) {
+ private InMemoryProfileReadFunction(String profiles) {
this.serializedProfiles = profiles;
}
@@ -103,20 +109,32 @@
}
@Override
+ public CompletableFuture<Profile> getAsync(Integer key, Object ... args) {
+ Profile profile = profileMap.get(key);
+ boolean append = (boolean) args[0];
+ if (append) {
+ profile = new Profile(profile.memberId, profile.company + "-r");
+ }
+ return CompletableFuture.completedFuture(profile);
+ }
+
+ @Override
public boolean isRetriable(Throwable exception) {
return false;
}
- static InMemoryReadFunction getInMemoryReadFunction(String serializedProfiles) {
- return new InMemoryReadFunction(serializedProfiles);
+ static InMemoryProfileReadFunction getInMemoryReadFunction(String testName, String serializedProfiles) {
+ return new InMemoryProfileReadFunction(serializedProfiles);
}
}
- static class InMemoryWriteFunction implements TableWriteFunction<Integer, EnrichedPageView> {
- private transient List<EnrichedPageView> records;
- private String testName;
+ static class InMemoryEnrichedPageViewWriteFunction extends BaseTableFunction
+ implements TableWriteFunction<Integer, EnrichedPageView> {
- public InMemoryWriteFunction(String testName) {
+ private String testName;
+ private transient List<EnrichedPageView> records;
+
+ public InMemoryEnrichedPageViewWriteFunction(String testName) {
this.testName = testName;
}
@@ -130,6 +148,17 @@
@Override
public CompletableFuture<Void> putAsync(Integer key, EnrichedPageView record) {
+ System.out.println("==> " + testName + " writing " + record.getPageKey());
+ records.add(record);
+ return CompletableFuture.completedFuture(null);
+ }
+
+ @Override
+ public CompletableFuture<Void> putAsync(Integer key, EnrichedPageView record, Object ... args) {
+ boolean append = (boolean) args[0];
+ if (append) {
+ record = new EnrichedPageView(record.pageKey, record.memberId, record.company + "-w");
+ }
records.add(record);
return CompletableFuture.completedFuture(null);
}
@@ -146,7 +175,98 @@
}
}
- private <K, V> Table<KV<K, V>> getCachingTable(TableDescriptor<K, V, ?> actualTableDesc, boolean defaultCache, String id, StreamApplicationDescriptor appDesc) {
+ static class InMemoryCounterReadFunction extends BaseTableFunction
+ implements TableReadFunction {
+
+ private final String testName;
+
+ private InMemoryCounterReadFunction(String testName) {
+ this.testName = testName;
+ }
+
+ @Override
+ public CompletableFuture readAsync(int opId, Object... args) {
+ if (1 == opId) {
+ boolean shouldReturnValue = (boolean) args[0];
+ AtomicInteger counter = counters.get(testName);
+ Integer result = shouldReturnValue ? counter.get() : null;
+ return CompletableFuture.completedFuture(result);
+ } else {
+ throw new SamzaException("Invalid opId: " + opId);
+ }
+ }
+
+ @Override
+ public CompletableFuture getAsync(Object key) {
+ throw new SamzaException("Not supported");
+ }
+
+ @Override
+ public boolean isRetriable(Throwable exception) {
+ return false;
+ }
+ }
+
+ static class InMemoryCounterWriteFunction extends BaseTableFunction
+ implements TableWriteFunction {
+
+ private final String testName;
+
+ private InMemoryCounterWriteFunction(String testName) {
+ this.testName = testName;
+ }
+
+ @Override
+ public CompletableFuture<Void> putAsync(Object key, Object record) {
+ throw new SamzaException("Not supported");
+ }
+
+ @Override
+ public CompletableFuture<Void> deleteAsync(Object key) {
+ throw new SamzaException("Not supported");
+ }
+
+ @Override
+ public CompletableFuture writeAsync(int opId, Object... args) {
+ Integer result;
+ AtomicInteger counter = counters.get(testName);
+ boolean shouldModify = (boolean) args[0];
+ switch (opId) {
+ case 1:
+ result = shouldModify ? counter.incrementAndGet() : counter.get();
+ break;
+ case 2:
+ result = shouldModify ? counter.decrementAndGet() : counter.get();
+ break;
+ default:
+ throw new SamzaException("Invalid opId: " + opId);
+ }
+ return CompletableFuture.completedFuture(result);
+ }
+
+ @Override
+ public boolean isRetriable(Throwable exception) {
+ return false;
+ }
+ }
+
+ static class DummyReadFunction extends BaseTableFunction
+ implements TableReadFunction {
+
+ @Override
+ public CompletableFuture getAsync(Object key) {
+ throw new SamzaException("Not supported");
+ }
+
+ @Override
+ public boolean isRetriable(Throwable exception) {
+ return false;
+ }
+ }
+
+ private <K, V> Table<KV<K, V>> getCachingTable(TableDescriptor<K, V, ?> actualTableDesc, boolean defaultCache,
+ StreamApplicationDescriptor appDesc) {
+ String id = actualTableDesc.getTableId();
CachingTableDescriptor<K, V> cachingDesc;
if (defaultCache) {
cachingDesc = new CachingTableDescriptor<>("caching-table-" + id, actualTableDesc);
@@ -161,75 +281,60 @@
return appDesc.getTable(cachingDesc);
}
- static class MyReadFunction implements TableReadFunction {
- @Override
- public CompletableFuture getAsync(Object key) {
- return null;
- }
-
- @Override
- public boolean isRetriable(Throwable exception) {
- return false;
- }
- }
-
private void doTestStreamTableJoinRemoteTable(boolean withCache, boolean defaultCache, String testName) throws Exception {
- final InMemoryWriteFunction writer = new InMemoryWriteFunction(testName);
writtenRecords.put(testName, new ArrayList<>());
int count = 10;
- PageView[] pageViews = generatePageViews(count);
- String profiles = Base64Serializer.serialize(generateProfiles(count));
+ final PageView[] pageViews = generatePageViews(count);
+ final String profiles = Base64Serializer.serialize(generateProfiles(count));
- int partitionCount = 4;
- Map<String, String> configs = TestLocalTableEndToEnd.getBaseJobConfig(bootstrapUrl(), zkConnect());
-
+ final int partitionCount = 4;
+ final Map<String, String> configs = TestLocalTableEndToEnd.getBaseJobConfig(bootstrapUrl(), zkConnect());
configs.put("streams.PageView.samza.system", "test");
configs.put("streams.PageView.source", Base64Serializer.serialize(pageViews));
configs.put("streams.PageView.partitionCount", String.valueOf(partitionCount));
final RateLimiter readRateLimiter = mock(RateLimiter.class, withSettings().serializable());
- final RateLimiter writeRateLimiter = mock(RateLimiter.class, withSettings().serializable());
- final TableRateLimiter.CreditFunction creditFunction = (k, v)->1;
+ final TableRateLimiter.CreditFunction creditFunction = (k, v, args) -> 1;
final StreamApplication app = appDesc -> {
- RemoteTableDescriptor<Integer, TestTableData.Profile> inputTableDesc = new RemoteTableDescriptor<>("profile-table-1");
- inputTableDesc
- .withReadFunction(InMemoryReadFunction.getInMemoryReadFunction(profiles))
+
+ final RemoteTableDescriptor joinTableDesc =
+ new RemoteTableDescriptor<Integer, TestTableData.Profile>("profile-table-1")
+ .withReadFunction(InMemoryProfileReadFunction.getInMemoryReadFunction(testName, profiles))
.withRateLimiter(readRateLimiter, creditFunction, null);
- // dummy reader
- TableReadFunction readFn = new MyReadFunction();
+ final RemoteTableDescriptor outputTableDesc =
+ new RemoteTableDescriptor<Integer, EnrichedPageView>("enriched-page-view-table-1")
+ .withReadFunction(new DummyReadFunction())
+ .withReadRateLimiterDisabled()
+ .withWriteFunction(new InMemoryEnrichedPageViewWriteFunction(testName))
+ .withWriteRateLimit(1000);
- RemoteTableDescriptor<Integer, EnrichedPageView> outputTableDesc = new RemoteTableDescriptor<>("enriched-page-view-table-1");
- outputTableDesc
- .withReadFunction(readFn)
- .withWriteFunction(writer)
- .withRateLimiter(writeRateLimiter, creditFunction, creditFunction);
-
- Table<KV<Integer, EnrichedPageView>> outputTable = withCache
- ? getCachingTable(outputTableDesc, defaultCache, "output", appDesc)
+ final Table<KV<Integer, EnrichedPageView>> outputTable = withCache
+ ? getCachingTable(outputTableDesc, defaultCache, appDesc)
: appDesc.getTable(outputTableDesc);
- Table<KV<Integer, Profile>> inputTable = withCache
- ? getCachingTable(inputTableDesc, defaultCache, "input", appDesc)
- : appDesc.getTable(inputTableDesc);
+ final Table<KV<Integer, Profile>> joinTable = withCache
+ ? getCachingTable(joinTableDesc, defaultCache, appDesc)
+ : appDesc.getTable(joinTableDesc);
- DelegatingSystemDescriptor ksd = new DelegatingSystemDescriptor("test");
- GenericInputDescriptor<PageView> isd = ksd.getInputDescriptor("PageView", new NoOpSerde<>());
+ final DelegatingSystemDescriptor ksd = new DelegatingSystemDescriptor("test");
+ final GenericInputDescriptor<PageView> isd = ksd.getInputDescriptor("PageView", new NoOpSerde<>());
+
appDesc.getInputStream(isd)
.map(pv -> new KV<>(pv.getMemberId(), pv))
- .join(inputTable, new PageViewToProfileJoinFunction())
+ .join(joinTable, new PageViewToProfileJoinFunction())
.map(m -> new KV(m.getMemberId(), m))
.sendTo(outputTable);
};
- Config config = new MapConfig(configs);
+ final Config config = new MapConfig(configs);
final LocalApplicationRunner runner = new LocalApplicationRunner(app, config);
executeRun(runner, config);
runner.waitForFinish();
- int numExpected = count * partitionCount;
+ final int numExpected = count * partitionCount;
Assert.assertEquals(numExpected, writtenRecords.get(testName).size());
Assert.assertTrue(writtenRecords.get(testName).get(0) instanceof EnrichedPageView);
}
@@ -312,6 +417,162 @@
}
table.flush();
table.close();
- Assert.assertEquals(2, failureCount);
+ }
+
+ private void doTestReadWriteWithArgs(boolean withCache, boolean defaultCache, String testName) throws Exception {
+
+ writtenRecords.put(testName, new ArrayList<>());
+ counters.put(testName, new AtomicInteger());
+
+ final int count = 10;
+ final PageView[] pageViews = generatePageViews(count);
+ final String profiles = Base64Serializer.serialize(generateProfiles(count));
+
+ final int partitionCount = 4;
+ final Map<String, String> configs = TestLocalTableEndToEnd.getBaseJobConfig(bootstrapUrl(), zkConnect());
+ configs.put("streams.PageView.samza.system", "test");
+ configs.put("streams.PageView.source", Base64Serializer.serialize(pageViews));
+ configs.put("streams.PageView.partitionCount", String.valueOf(partitionCount));
+
+ final RateLimiter readRateLimiter = mock(RateLimiter.class, withSettings().serializable());
+ final TableRateLimiter.CreditFunction creditFunction = (k, v, args) -> 1;
+ final StreamApplication app = appDesc -> {
+
+ final RemoteTableDescriptor joinTableDesc =
+ new RemoteTableDescriptor<Integer, TestTableData.Profile>("profile-table-1")
+ .withReadFunction(InMemoryProfileReadFunction.getInMemoryReadFunction(testName, profiles))
+ .withRateLimiter(readRateLimiter, creditFunction, null);
+
+ final RemoteTableDescriptor outputTableDesc =
+ new RemoteTableDescriptor<Integer, EnrichedPageView>("enriched-page-view-table-1")
+ .withReadFunction(new DummyReadFunction())
+ .withReadRateLimiterDisabled()
+ .withWriteFunction(new InMemoryEnrichedPageViewWriteFunction(testName))
+ .withWriteRateLimit(1000);
+
+ final RemoteTableDescriptor counterTableDesc =
+ new RemoteTableDescriptor("counter-table-1")
+ .withReadFunction(new InMemoryCounterReadFunction(testName))
+ .withWriteFunction(new InMemoryCounterWriteFunction(testName))
+ .withRateLimiterDisabled();
+
+ final Table joinTable = withCache
+ ? getCachingTable(joinTableDesc, defaultCache, appDesc)
+ : appDesc.getTable(joinTableDesc);
+
+ final Table outputTable = withCache
+ ? getCachingTable(outputTableDesc, defaultCache, appDesc)
+ : appDesc.getTable(outputTableDesc);
+
+ final Table counterTable = withCache
+ ? getCachingTable(counterTableDesc, defaultCache, appDesc)
+ : appDesc.getTable(counterTableDesc);
+
+ final String joinTableName = ((TableImpl) joinTable).getTableId();
+ final String outputTableName = ((TableImpl) outputTable).getTableId();
+ final String counterTableName = ((TableImpl) counterTable).getTableId();
+
+ final DelegatingSystemDescriptor ksd = new DelegatingSystemDescriptor("test");
+ final GenericInputDescriptor<PageView> isd = ksd.getInputDescriptor("PageView", new NoOpSerde<>());
+
+ appDesc.getInputStream(isd)
+ .map(new JoinMapFunction(joinTableName, outputTableName, counterTableName));
+ };
+
+ final Config config = new MapConfig(configs);
+ final LocalApplicationRunner runner = new LocalApplicationRunner(app, config);
+ executeRun(runner, config);
+ runner.waitForFinish();
+
+ final int numExpected = count * partitionCount;
+ Assert.assertEquals(numExpected, writtenRecords.get(testName).size());
+ Assert.assertTrue(writtenRecords.get(testName).get(0) instanceof EnrichedPageView);
+ writtenRecords.get(testName).forEach(epv -> {
+ Assert.assertTrue(epv.company.endsWith("-r-w"));
+ });
+ Assert.assertEquals(numExpected, counters.get(testName).get());
+ }
+
+ static private class JoinMapFunction implements MapFunction<PageView, EnrichedPageView> {
+
+ private final String joinTableName;
+ private final String outputTableName;
+ private final String counterTableName;
+ private ReadWriteTable<Integer, Profile> inputTable;
+ private ReadWriteTable<Integer, EnrichedPageView> outputTable;
+ private ReadWriteTable counterTable;
+
+ private JoinMapFunction(String joinTableName, String outputTableName, String counterTableName) {
+ this.joinTableName = joinTableName;
+ this.outputTableName = outputTableName;
+ this.counterTableName = counterTableName;
+ }
+
+ @Override
+ public void init(Context context) {
+ inputTable = context.getTaskContext().getTable(joinTableName);
+ outputTable = context.getTaskContext().getTable(outputTableName);
+ counterTable = context.getTaskContext().getTable(counterTableName);
+ }
+
+ @Override
+ public EnrichedPageView apply(PageView pageView) {
+ try {
+ // Counter manipulation
+ badOpId();
+ Assert.assertNull(getCounterValue(false));
+ Integer beforeValue = getCounterValue(true);
+ Assert.assertEquals(beforeValue, incCounterValue(false));
+ Assert.assertEquals(beforeValue, decCounterValue(false));
+ Assert.assertEquals(Integer.valueOf(beforeValue + 1), incCounterValue(true));
+ Assert.assertEquals(beforeValue, decCounterValue(true));
+ Assert.assertEquals(beforeValue, getCounterValue(true));
+ incCounterValue(true);
+
+ // Generate EnrichedPageView
+ Profile profile = inputTable.getAsync(pageView.memberId, true).join();
+ EnrichedPageView epv = new EnrichedPageView(pageView.getPageKey(), profile.memberId, profile.company);
+ outputTable.putAsync(epv.memberId, epv, true).join();
+ return epv;
+ } catch (Exception ex) {
+ throw new SamzaException(ex);
+ }
+ }
+
+ private Integer getCounterValue(boolean shouldReturn) {
+ return (Integer) counterTable.readAsync(1, shouldReturn).join();
+ }
+
+ private Integer incCounterValue(boolean shouldModifdy) {
+ return (Integer) counterTable.writeAsync(1, shouldModifdy).join();
+ }
+
+ private Integer decCounterValue(boolean shouldModifdy) {
+ return (Integer) counterTable.writeAsync(2, shouldModifdy).join();
+ }
+
+ private void badOpId() {
+ try {
+ outputTable.readAsync(0).join();
+ Assert.fail("Shouldn't reach here");
+ } catch (SamzaException ex) {
+ // Expected exception
+ }
+ }
+ }
+
+ @Test
+ public void testReadWriteWithArgs() throws Exception {
+ doTestReadWriteWithArgs(false, false, "testReadWriteWithArgs");
+ }
+
+ @Test
+ public void testReadWriteWithArgsWithCache() throws Exception {
+ doTestReadWriteWithArgs(true, false, "testReadWriteWithArgsWithCache");
+ }
+
+ @Test
+ public void testReadWriteWithArgsWithDefaultCache() throws Exception {
+ doTestReadWriteWithArgs(true, true, "testReadWriteWithArgsWithDefaultCache");
}
}