SAMZA-2751: RunLoop and High-Level API changes for Drain (#1616)
diff --git a/gradle.properties b/gradle.properties
index 1a6ee97..cab76e0 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -16,7 +16,7 @@
# under the License.
group=org.apache.samza
version=1.7.0-SNAPSHOT
-scalaSuffix=2.11
+scalaSuffix=2.12
# after changing this value, run `$ ./gradlew wrapper` and commit the resulting changed files
gradleVersion=5.2.1
diff --git a/samza-api/src/main/java/org/apache/samza/system/DrainMessage.java b/samza-api/src/main/java/org/apache/samza/system/DrainMessage.java
new file mode 100644
index 0000000..4373d51
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/system/DrainMessage.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.system;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+/**
+ * The DrainMessage is a control message that marks a drain request for a job run.
+ * Its run id lets consumers ignore drain messages left over from a different run.
+ */
+public class DrainMessage extends ControlMessage {
+  /**
+   * Id used to invalidate DrainMessages between runs. Ties to app.run.id from config.
+   */
+  private final String runId;
+
+  /** Creates a DrainMessage for the given run id with a null task name. */
+  public DrainMessage(String runId) {
+    this(null, runId);
+  }
+
+  /** Creates a DrainMessage from the "task-name" and "run-id" JSON properties. */
+  public DrainMessage(@JsonProperty("task-name") String taskName, @JsonProperty("run-id") String runId) {
+    super(taskName);
+    this.runId = runId;
+  }
+
+  /** Returns the run id given at construction; null if none was supplied. */
+  public String getRunId() {
+    return runId;
+  }
+
+  @Override
+  public int hashCode() {
+    // A null runId contributes 0 so hashing never dereferences it.
+    return 31 * super.hashCode() + (runId != null ? runId.hashCode() : 0);
+  }
+
+  /** Compares class, the parent's equality and run id; safe when either run id is null. */
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    }
+    if (obj == null || getClass() != obj.getClass()) {
+      return false;
+    }
+    final DrainMessage other = (DrainMessage) obj;
+    if (!super.equals(other)) {
+      return false;
+    }
+    // Calling runId.equals(...) unconditionally would throw NPE for a null runId.
+    return runId == null ? other.runId == null : runId.equals(other.runId);
+  }
+}
diff --git a/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java b/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
index 2112463..1f4a740 100644
--- a/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
+++ b/samza-api/src/main/java/org/apache/samza/system/IncomingMessageEnvelope.java
@@ -132,7 +132,6 @@
if (envelopeKeyorOffset == null) {
return new SystemStreamPartition(systemStreamPartition, 0);
}
-
// modulo 31 first to best spread out the hashcode and then modulo elasticityFactor for actual keyBucket
// Note: elasticityFactor <= 16 so modulo 31 is safe to do.
int keyBucket = (Math.abs(envelopeKeyorOffset.hashCode()) % 31) % elasticityFactor;
@@ -162,6 +161,10 @@
return END_OF_STREAM_OFFSET.equals(offset);
}
+  /**
+   * Tells whether this envelope carries a {@link DrainMessage}.
+   *
+   * @return true if the message is a DrainMessage (or a subclass of it); false otherwise, including when the
+   *         message is null
+   */
+  public boolean isDrain() {
+    // instanceof evaluates to false for null, so no separate null check is needed
+    return message instanceof DrainMessage;
+  }
+
/**
* This method is deprecated in favor of WatermarkManager.buildEndOfStreamEnvelope(SystemStreamPartition ssp).
*
@@ -172,6 +175,10 @@
return new IncomingMessageEnvelope(ssp, END_OF_STREAM_OFFSET, null, new EndOfStreamMessage(null));
}
+  /**
+   * Builds an envelope that wraps a {@link DrainMessage} for the given partition.
+   *
+   * @param ssp partition the envelope is addressed to
+   * @param runId run id to stamp on the DrainMessage
+   * @return envelope carrying the DrainMessage, built with null for the two middle constructor arguments
+   */
+  public static IncomingMessageEnvelope buildDrainMessage(SystemStreamPartition ssp, String runId) {
+    final DrainMessage drainMessage = new DrainMessage(runId);
+    return new IncomingMessageEnvelope(ssp, null, null, drainMessage);
+  }
+
public static IncomingMessageEnvelope buildWatermarkEnvelope(SystemStreamPartition ssp, long watermark) {
return new IncomingMessageEnvelope(ssp, null, null, new WatermarkMessage(watermark, null));
}
diff --git a/samza-api/src/main/java/org/apache/samza/system/MessageType.java b/samza-api/src/main/java/org/apache/samza/system/MessageType.java
index 7129d00..9f58621 100644
--- a/samza-api/src/main/java/org/apache/samza/system/MessageType.java
+++ b/samza-api/src/main/java/org/apache/samza/system/MessageType.java
@@ -26,7 +26,8 @@
public enum MessageType {
USER_MESSAGE,
WATERMARK,
- END_OF_STREAM;
+ END_OF_STREAM,
+ DRAIN;
/**
* Returns the {@link MessageType} of a particular intermediate stream message.
@@ -38,6 +39,8 @@
return WATERMARK;
} else if (message instanceof EndOfStreamMessage) {
return END_OF_STREAM;
+ } else if (message instanceof DrainMessage) {
+ return DRAIN;
} else {
return USER_MESSAGE;
}
diff --git a/samza-api/src/main/java/org/apache/samza/task/DrainListenerTask.java b/samza-api/src/main/java/org/apache/samza/task/DrainListenerTask.java
new file mode 100644
index 0000000..c5ca7d5
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/task/DrainListenerTask.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.task;
+
+/**
+ * The DrainListenerTask augments {@link StreamTask} allowing the method implementor to specify code to be
+ * executed when the 'drain' is reached for a task.
+ */
+public interface DrainListenerTask {
+ /**
+ * Guaranteed to be invoked when all SSPs processed by this task have drained.
+ *
+ * @param collector Contains the means of sending message envelopes to an output stream.
+ * @param coordinator Manages execution of tasks.
+ *
+ * @throws Exception Any exception types encountered during the execution of the processing task.
+ */
+ void onDrain(MessageCollector collector, TaskCoordinator coordinator) throws Exception;
+}
diff --git a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
index 145f4cd..d36ac66 100644
--- a/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/JobConfig.java
@@ -166,9 +166,12 @@
// Enable DrainMonitor in Samza Containers
// Default is false for now. Will be turned on after testing
- public static final String DRAIN_MONITOR_ENABLED = "samza.drain-monitor.enabled";
+ public static final String DRAIN_MONITOR_ENABLED = "job.drain-monitor.enabled";
public static final boolean DRAIN_MONITOR_ENABLED_DEFAULT = false;
+ // Config key and default (60 seconds, in milliseconds) for the interval at which the DrainMonitor polls
+ public static final String DRAIN_MONITOR_POLL_INTERVAL_MILLIS = "job.drain-monitor.poll.interval.ms";
+ public static final long DRAIN_MONITOR_POLL_INTERVAL_MILLIS_DEFAULT = 60_000;
+
// Enable ClusterBasedJobCoordinator aka ApplicationMaster High Availability (AM-HA).
// High availability allows new AM to establish connection with already running containers
public static final String YARN_AM_HIGH_AVAILABILITY_ENABLED = "yarn.am.high-availability.enabled";
@@ -479,6 +482,10 @@
return getBoolean(DRAIN_MONITOR_ENABLED, DRAIN_MONITOR_ENABLED_DEFAULT);
}
+  /**
+   * Looks up the poll interval of the DrainMonitor.
+   *
+   * @return the value read for {@code job.drain-monitor.poll.interval.ms}, in milliseconds, with
+   *         {@link #DRAIN_MONITOR_POLL_INTERVAL_MILLIS_DEFAULT} passed as the fallback
+   */
+  public long getDrainMonitorPollIntervalMillis() {
+    final long pollIntervalMillis =
+        getLong(DRAIN_MONITOR_POLL_INTERVAL_MILLIS, DRAIN_MONITOR_POLL_INTERVAL_MILLIS_DEFAULT);
+    return pollIntervalMillis;
+  }
+
public long getContainerHeartbeatRetryCount() {
return getLong(YARN_CONTAINER_HEARTBEAT_RETRY_COUNT, YARN_CONTAINER_HEARTBEAT_RETRY_COUNT_DEFAULT);
}
diff --git a/samza-core/src/main/java/org/apache/samza/config/StreamConfig.java b/samza-core/src/main/java/org/apache/samza/config/StreamConfig.java
index 9f774b8..9fc74a7 100644
--- a/samza-core/src/main/java/org/apache/samza/config/StreamConfig.java
+++ b/samza-core/src/main/java/org/apache/samza/config/StreamConfig.java
@@ -167,7 +167,7 @@
* @param systemStream system stream to map to stream id
* @return stream id corresponding to the system stream
*/
- private String systemStreamToStreamId(SystemStream systemStream) {
+ public String systemStreamToStreamId(SystemStream systemStream) {
List<String> streamIds = getStreamIdsForSystem(systemStream.getSystem()).stream()
.filter(streamId -> systemStream.getStream().equals(getPhysicalName(streamId))).collect(Collectors.toList());
if (streamIds.size() > 1) {
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
index b7a80a3..5bbcbae 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoop.java
@@ -24,6 +24,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -35,6 +36,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.samza.SamzaException;
+import org.apache.samza.system.DrainMessage;
import org.apache.samza.system.IncomingMessageEnvelope;
import org.apache.samza.system.MessageType;
import org.apache.samza.system.SystemConsumers;
@@ -70,7 +72,6 @@
private final List<AsyncTaskWorker> taskWorkers;
private final SystemConsumers consumerMultiplexer;
private final Map<SystemStreamPartition, List<AsyncTaskWorker>> sspToTaskWorkerMapping;
-
private final ExecutorService threadPool;
private final CoordinatorRequests coordinatorRequests;
private final Object latch;
@@ -86,9 +87,11 @@
private volatile boolean shutdownNow = false;
private volatile Throwable throwable = null;
private final HighResolutionClock clock;
- private final boolean isAsyncCommitEnabled;
+ private boolean isAsyncCommitEnabled;
private volatile boolean runLoopResumedSinceLastChecked;
private final int elasticityFactor;
+ private final String runId;
+ private boolean isDraining = false;
public RunLoop(Map<TaskName, RunLoopTask> runLoopTasks,
ExecutorService threadPool,
@@ -103,7 +106,7 @@
HighResolutionClock clock,
boolean isAsyncCommitEnabled) {
this(runLoopTasks, threadPool, consumerMultiplexer, maxConcurrency, windowMs, commitMs, callbackTimeoutMs,
- maxThrottlingDelayMs, maxIdleMs, containerMetrics, clock, isAsyncCommitEnabled, 1);
+ maxThrottlingDelayMs, maxIdleMs, containerMetrics, clock, isAsyncCommitEnabled, 1, null);
}
public RunLoop(Map<TaskName, RunLoopTask> runLoopTasks,
@@ -118,7 +121,8 @@
SamzaContainerMetrics containerMetrics,
HighResolutionClock clock,
boolean isAsyncCommitEnabled,
- int elasticityFactor) {
+ int elasticityFactor,
+ String runId) {
this.threadPool = threadPool;
this.consumerMultiplexer = consumerMultiplexer;
@@ -134,18 +138,30 @@
this.latch = new Object();
this.workerTimer = Executors.newSingleThreadScheduledExecutor();
this.clock = clock;
- Map<TaskName, AsyncTaskWorker> workers = new HashMap<>();
+ // assign runId before creating workers. As the inner AsyncTaskWorker class is not static, it relies on
+ // the outer class fields to be init first
+ this.runId = runId;
+ Map<TaskName, AsyncTaskWorker> workers = new HashMap<>();
for (RunLoopTask task : runLoopTasks.values()) {
workers.put(task.taskName(), new AsyncTaskWorker(task));
}
// Partions and tasks assigned to the container will not change during the run loop life time
this.sspToTaskWorkerMapping = Collections.unmodifiableMap(getSspToAsyncTaskWorkerMap(runLoopTasks, workers));
+
this.taskWorkers = Collections.unmodifiableList(new ArrayList<>(workers.values()));
this.isAsyncCommitEnabled = isAsyncCommitEnabled;
this.elasticityFactor = elasticityFactor;
}
/**
+ * Sets the RunLoop to drain mode.
+ * */
+ private void drain() {
+ // remember that the loop is draining; the commit path checks this flag and then commits on the run loop thread
+ isDraining = true;
+ // with async commit off, isReady() only lets a commit proceed when no messages are in flight
+ isAsyncCommitEnabled = false;
+ }
+
+ /**
* Returns mapping of the SystemStreamPartition to the AsyncTaskWorkers to efficiently route the envelopes
*/
private static Map<SystemStreamPartition, List<AsyncTaskWorker>> getSspToAsyncTaskWorkerMap(
@@ -297,7 +313,8 @@
* when elasticity is enabled,
* sspToTaskWorkerMapping has workers for a SSP which has keyBucket
* hence need to use envelop.getSSP(elasticityFactor)
- * Additionally, when envelope is EnofStream or Watermark, it needs to be sent to all works for the ssp irrespective of keyBucket
+ * Additionally, when envelope is EndOfStream or Watermark or Drain, it needs to be sent to all workers for the ssp
+ * irrespective of keyBucket
* @param envelope
* @return list of workers for the envelope
*/
@@ -309,9 +326,14 @@
final SystemStreamPartition sspOfEnvelope = envelope.getSystemStreamPartition(elasticityFactor);
List<AsyncTaskWorker> listOfWorkersForEnvelope = null;
- // if envelope is end of stream or watermark, it needs to be routed to all tasks consuming the ssp irresp of keybucket
+ // if envelope is end of stream or watermark or drain, it needs to be routed to all tasks consuming the ssp
+ // irrespective of keyBucket
MessageType messageType = MessageType.of(envelope.getMessage());
- if (envelope.isEndOfStream() || MessageType.END_OF_STREAM == messageType || MessageType.WATERMARK == messageType) {
+ if (envelope.isEndOfStream()
+ || envelope.isDrain()
+ || messageType == MessageType.END_OF_STREAM
+ || messageType == MessageType.DRAIN
+ || messageType == MessageType.WATERMARK) {
//sspToTaskWorkerMapping has ssps with keybucket so extract and check only system, stream and partition and ignore the keybucket
listOfWorkersForEnvelope = sspToTaskWorkerMapping.entrySet()
@@ -391,6 +413,20 @@
}
/**
+ * Resume the runloop thread. This method is triggered after a task has completed drain.
+ */
+ private void resumeAfterDrain() {
+ log.trace("Resume loop thread");
+ // pick up a shutdown-now request recorded in the coordinator requests
+ if (coordinatorRequests.shouldShutdownNow()) {
+ shutdownNow = true;
+ }
+ // wake up every thread waiting on the latch and record that the loop has been resumed
+ synchronized (latch) {
+ latch.notifyAll();
+ runLoopResumedSinceLastChecked = true;
+ }
+ }
+
+ /**
* Set the throwable and abort run loop. The throwable will be thrown from the run loop thread
* @param t throwable
*/
@@ -426,6 +462,7 @@
COMMIT,
PROCESS,
END_OF_STREAM,
+ DRAIN,
SCHEDULER,
NO_OP
}
@@ -443,7 +480,8 @@
this.task = task;
this.callbackManager = new TaskCallbackManager(this, callbackTimer, callbackTimeoutMs, maxConcurrency, clock);
Set<SystemStreamPartition> sspSet = getWorkingSSPSet(task);
- this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet, !task.intermediateStreams().isEmpty());
+ this.state = new AsyncTaskState(task.taskName(), task.metrics(), sspSet, !task.intermediateStreams().isEmpty(),
+ runId);
}
private void init() {
@@ -514,18 +552,46 @@
case END_OF_STREAM:
endOfStream();
break;
+ case DRAIN:
+ drain();
+ break;
default:
//no op
break;
}
}
+ /**
+ * Called when a task has drained i.e all SSPs for the task have received a drain message.
+ * Marks the task complete, invokes the task's drain hook, requests shutdown of the current task and commits.
+ * The run loop is resumed in all cases, even when one of these steps throws.
+ */
+ private void drain() {
+ // mark the task complete first; nextOp() returns NO_OP for a completed task
+ state.complete = true;
+ // reset the pending drain flag
+ state.startDrain();
+ try {
+ ReadableCoordinator coordinator = new ReadableCoordinator(task.taskName());
+
+ task.drain(coordinator);
+
+ // issue a shutdown request for the task
+ coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+ coordinatorRequests.update(coordinator);
+
+ // issue a commit explicitly before we shutdown the task
+ // Adding commit to coordinator will not work as the state is marked complete and NO_OP will always be the
+ // next operation for this task
+ commit();
+ } finally {
+ resumeAfterDrain();
+ }
+ }
+
private void endOfStream() {
state.complete = true;
try {
ReadableCoordinator coordinator = new ReadableCoordinator(task.taskName());
task.endOfStream(coordinator);
+
// issue a request for shutdown of the task
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
coordinatorRequests.update(coordinator);
@@ -538,7 +604,6 @@
} finally {
resume();
}
-
}
/**
@@ -638,9 +703,12 @@
}
};
- if (threadPool != null) {
+ if (threadPool != null && !isDraining) {
log.trace("Task {} commits on the thread pool", task.taskName());
threadPool.submit(commitWorker);
+ } else if (isDraining) {
+ log.trace("Task {} commits on the run loop thread as task is draining", task.taskName());
+ commitWorker.run();
} else {
log.trace("Task {} commits on the run loop thread", task.taskName());
commitWorker.run();
@@ -759,24 +827,31 @@
private volatile boolean needScheduler = false;
private volatile boolean complete = false;
private volatile boolean endOfStream = false;
+ private volatile boolean shouldDrain = false;
private volatile boolean windowInFlight = false;
private volatile boolean commitInFlight = false;
private volatile boolean schedulerInFlight = false;
private final AtomicInteger messagesInFlight = new AtomicInteger(0);
- private final ArrayDeque<PendingEnvelope> pendingEnvelopeQueue;
+ private final ArrayDeque<PendingEnvelope> pendingEnvelopeQueue;
//Set of SSPs that we are currently processing for this task instance
private final Set<SystemStreamPartition> processingSspSet;
+ //Set of SSPs of this task instance that have not yet received a drain message or reached end of stream
+ private final Set<SystemStreamPartition> processingSspSetToDrain;
private final TaskName taskName;
private final TaskInstanceMetrics taskMetrics;
private final boolean hasIntermediateStreams;
+ private final String runId;
- AsyncTaskState(TaskName taskName, TaskInstanceMetrics taskMetrics, Set<SystemStreamPartition> sspSet, boolean hasIntermediateStreams) {
+ AsyncTaskState(TaskName taskName, TaskInstanceMetrics taskMetrics, Set<SystemStreamPartition> sspSet,
+ boolean hasIntermediateStreams, String runId) {
this.taskName = taskName;
this.taskMetrics = taskMetrics;
this.pendingEnvelopeQueue = new ArrayDeque<>();
this.processingSspSet = sspSet;
+ this.processingSspSetToDrain = new HashSet<>(sspSet);
this.hasIntermediateStreams = hasIntermediateStreams;
+ this.runId = runId;
}
private boolean checkEndOfStream() {
@@ -796,6 +871,7 @@
&& sspInSet.getPartition().equals(sspOfEnvelope.getPartition()))
.findFirst();
ssp.ifPresent(processingSspSet::remove);
+ ssp.ifPresent(processingSspSetToDrain::remove);
}
if (!hasIntermediateStreams) {
pendingEnvelopeQueue.remove();
@@ -805,6 +881,57 @@
return processingSspSet.isEmpty();
}
+    /**
+     * Checks whether every SSP of this task has received a drain message for the current run.
+     * Only the envelope at the head of the pending queue is inspected. A drain envelope whose run id is missing
+     * or differs from this run's id is dropped from the queue. A matching one switches the RunLoop to drain
+     * mode and removes its SSP from the set of SSPs that are still waiting to drain.
+     *
+     * @return true when the pending queue is non-empty and no SSP is left to drain; false otherwise, and always
+     *         false once end of stream has been reached
+     */
+    private boolean shouldDrain() {
+      if (endOfStream) {
+        return false;
+      }
+
+      if (pendingEnvelopeQueue.isEmpty()) {
+        // if no messages are in the queue, the task has probably already drained or there are no messages from
+        // the chooser
+        return false;
+      }
+
+      IncomingMessageEnvelope envelope = pendingEnvelopeQueue.peek().envelope;
+      if (envelope.isDrain()) {
+        final DrainMessage message = (DrainMessage) envelope.getMessage();
+        final String messageRunId = message.getRunId();
+        // A drain message without a run id can never match the current run. Check for null explicitly so that
+        // such a message is discarded instead of throwing a NullPointerException on the run loop.
+        if (messageRunId == null || !messageRunId.equals(runId)) {
+          // Removing the drain message from the pending queue as it doesn't match with the current runId
+          // Removing it will ensure that it is not picked up by process()
+          pendingEnvelopeQueue.remove();
+        } else {
+          // set the RunLoop to drain mode
+          if (!isDraining) {
+            drain();
+          }
+
+          if (elasticityFactor <= 1) {
+            SystemStreamPartition ssp = envelope.getSystemStreamPartition();
+            processingSspSetToDrain.remove(ssp);
+          } else {
+            // SystemConsumers will write only one envelope (enclosing DrainMessage) per SSP in its buffer.
+            // This envelope doesn't have keyBucket info in its SSP. With elasticity, the same SSP can be
+            // processed by multiple tasks. Therefore, if envelope contains drain message, the ssp of envelope
+            // should be removed from task's processing set irrespective of keyBucket.
+            SystemStreamPartition sspOfEnvelope = envelope.getSystemStreamPartition();
+            Optional<SystemStreamPartition> ssp = processingSspSetToDrain.stream()
+                .filter(sspInSet -> sspInSet.getSystemStream().equals(sspOfEnvelope.getSystemStream())
+                    && sspInSet.getPartition().equals(sspOfEnvelope.getPartition()))
+                .findFirst();
+            ssp.ifPresent(processingSspSetToDrain::remove);
+          }
+
+          if (!hasIntermediateStreams) {
+            // Without intermediate streams the drain envelope is consumed here. With intermediate streams it
+            // is left in the pending queue so that the DAG picks up the Drain message and propagates it to
+            // the intermediate streams
+            pendingEnvelopeQueue.remove();
+          }
+        }
+      }
+      return processingSspSetToDrain.isEmpty();
+    }
+
/**
* Returns whether the task is ready to do process/window/commit.
*
@@ -813,6 +940,11 @@
if (checkEndOfStream()) {
endOfStream = true;
}
+
+ if (shouldDrain()) {
+ shouldDrain = true;
+ }
+
if (coordinatorRequests.commitRequests().remove(taskName)) {
needCommit = true;
}
@@ -826,9 +958,9 @@
*/
if (needCommit) {
return (messagesInFlight.get() == 0 || isAsyncCommitEnabled) && !opInFlight;
- } else if (needWindow || needScheduler || endOfStream) {
+ } else if (needWindow || needScheduler || endOfStream || shouldDrain) {
/*
- * A task is ready for window, scheduler or end-of-stream operation.
+ * A task is ready for window, scheduler, drain or end-of-stream operation.
*/
return messagesInFlight.get() == 0 && !opInFlight;
} else {
@@ -847,14 +979,24 @@
*/
private WorkerOp nextOp() {
- if (complete) return WorkerOp.NO_OP;
+ if (complete) {
+ return WorkerOp.NO_OP;
+ }
if (isReady()) {
- if (needCommit) return WorkerOp.COMMIT;
- else if (needWindow) return WorkerOp.WINDOW;
- else if (needScheduler) return WorkerOp.SCHEDULER;
- else if (endOfStream && pendingEnvelopeQueue.isEmpty()) return WorkerOp.END_OF_STREAM;
- else if (!pendingEnvelopeQueue.isEmpty()) return WorkerOp.PROCESS;
+ if (needCommit) {
+ return WorkerOp.COMMIT;
+ } else if (needWindow) {
+ return WorkerOp.WINDOW;
+ } else if (needScheduler) {
+ return WorkerOp.SCHEDULER;
+ } else if (endOfStream && pendingEnvelopeQueue.isEmpty()) {
+ return WorkerOp.END_OF_STREAM;
+ } else if (shouldDrain && pendingEnvelopeQueue.isEmpty()) {
+ return WorkerOp.DRAIN;
+ } else if (!pendingEnvelopeQueue.isEmpty()) {
+ return WorkerOp.PROCESS;
+ }
}
return WorkerOp.NO_OP;
}
@@ -876,6 +1018,10 @@
windowInFlight = true;
}
+ private void startDrain() {
+ // reset the flag that isReady() raises when shouldDrain() returns true
+ shouldDrain = false;
+ }
+
private void startCommit() {
needCommit = false;
commitInFlight = true;
@@ -918,6 +1064,7 @@
int queueSize = pendingEnvelopeQueue.size();
taskMetrics.pendingMessages().set(queueSize);
log.trace("Insert envelope to task {} queue.", taskName);
+ log.trace("Insert envelope to task {} queue.", taskName);
log.debug("Task {} pending envelope count is {} after insertion.", taskName, queueSize);
}
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
index 9d20950..cb021d0 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopFactory.java
@@ -41,7 +41,8 @@
SamzaContainerMetrics containerMetrics,
TaskConfig taskConfig,
HighResolutionClock clock,
- int elasticityFactor) {
+ int elasticityFactor,
+ String runId) {
long taskWindowMs = taskConfig.getWindowMs();
@@ -65,6 +66,8 @@
log.info("Got elasticity factor: {}.", elasticityFactor);
+ log.info("Got current run Id: {}.", runId);
+
log.info("Run loop in asynchronous mode.");
return new RunLoop(
@@ -80,6 +83,7 @@
containerMetrics,
clock,
isAsyncCommitEnabled,
- elasticityFactor);
+ elasticityFactor,
+ runId);
}
}
diff --git a/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java b/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
index 551da88..af63494 100644
--- a/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
+++ b/samza-core/src/main/java/org/apache/samza/container/RunLoopTask.java
@@ -99,6 +99,15 @@
void endOfStream(ReadableCoordinator coordinator);
/**
+ * Called when all {@link SystemStreamPartition} processed by a task have drained. This is called only
+ * once per task. {@link RunLoop} will issue a shutdown request to the coordinator immediately following the
+ * invocation of this method.
+ *
+ * @param coordinator manages execution of tasks.
+ */
+ void drain(ReadableCoordinator coordinator);
+
+ /**
* Indicates whether {@link #window} should be invoked on this task. If true, {@link RunLoop}
* will schedule window to execute periodically according to its windowMs.
*
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java b/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java
index 6b5c98e..b46fc3f 100644
--- a/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainMonitor.java
@@ -33,6 +33,7 @@
import org.apache.samza.SamzaException;
import org.apache.samza.config.ApplicationConfig;
import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
import org.apache.samza.coordinator.metadatastore.NamespaceAwareCoordinatorStreamStore;
import org.apache.samza.metadatastore.MetadataStore;
import org.slf4j.Logger;
@@ -65,7 +66,6 @@
STOPPED
}
- private static final int POLLING_INTERVAL_MILLIS = 60_000;
private static final int INITIAL_POLL_DELAY_MILLIS = 0;
private final ScheduledExecutorService schedulerService =
@@ -75,7 +75,6 @@
.setDaemon(true)
.build());
private final String appRunId;
- private final long pollingIntervalMillis;
private final NamespaceAwareCoordinatorStreamStore drainMetadataStore;
// Used to guard write access to state.
private final Object lock = new Object();
@@ -83,20 +82,23 @@
@GuardedBy("lock")
private State state = State.INIT;
private DrainCallback callback;
+ private long pollingIntervalMillis;
public DrainMonitor(MetadataStore metadataStore, Config config) {
- this(metadataStore, config, POLLING_INTERVAL_MILLIS);
- }
-
- public DrainMonitor(MetadataStore metadataStore, Config config, long pollingIntervalMillis) {
Preconditions.checkNotNull(metadataStore, "MetadataStore parameter cannot be null.");
Preconditions.checkNotNull(config, "Config parameter cannot be null.");
- Preconditions.checkArgument(pollingIntervalMillis > 0,
- String.format("Polling interval specified is %d ms. It should be greater than 0.", pollingIntervalMillis));
this.drainMetadataStore =
new NamespaceAwareCoordinatorStreamStore(metadataStore, DrainUtils.DRAIN_METADATA_STORE_NAMESPACE);
ApplicationConfig applicationConfig = new ApplicationConfig(config);
this.appRunId = applicationConfig.getRunId();
+ JobConfig jobConfig = new JobConfig(config);
+ this.pollingIntervalMillis = jobConfig.getDrainMonitorPollIntervalMillis();
+ }
+
+ public DrainMonitor(MetadataStore metadataStore, Config config, long pollingIntervalMillis) {
+ this(metadataStore, config);
+ Preconditions.checkArgument(pollingIntervalMillis > 0,
+ String.format("Polling interval specified is %d ms. It should be greater than 0.", pollingIntervalMillis));
this.pollingIntervalMillis = pollingIntervalMillis;
}
@@ -207,12 +209,12 @@
* One time check check to see if there are any DrainNotification messages available in the
* metadata store for the current deployment.
* */
- static boolean shouldDrain(NamespaceAwareCoordinatorStreamStore drainMetadataStore, String deploymentId) {
+ static boolean shouldDrain(NamespaceAwareCoordinatorStreamStore drainMetadataStore, String runId) {
final Optional<List<DrainNotification>> drainNotifications = readDrainNotificationMessages(drainMetadataStore);
if (drainNotifications.isPresent()) {
final ImmutableList<DrainNotification> filteredDrainNotifications = drainNotifications.get()
.stream()
- .filter(notification -> deploymentId.equals(notification.getDeploymentId()))
+ .filter(notification -> runId.equals(notification.getRunId()))
.collect(ImmutableList.toImmutableList());
return !filteredDrainNotifications.isEmpty();
}
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java b/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java
index a16595e..dd97cd7 100644
--- a/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainNotification.java
@@ -33,26 +33,26 @@
/**
* Unique identifier for a deployment so drain notifications messages can be invalidated across a job restarts.
*/
- private final String deploymentId;
+ private final String runId;
- public DrainNotification(UUID uuid, String deploymentId) {
+ public DrainNotification(UUID uuid, String runId) {
this.uuid = uuid;
- this.deploymentId = deploymentId;
+ this.runId = runId;
}
public UUID getUuid() {
return this.uuid;
}
- public String getDeploymentId() {
- return deploymentId;
+ public String getRunId() {
+ return runId;
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder("DrainMessage{");
sb.append(" UUID: ").append(uuid);
- sb.append(", deploymentId: '").append(deploymentId).append('\'');
+ sb.append(", runId: '").append(runId).append('\'');
sb.append('}');
return sb.toString();
}
@@ -67,11 +67,11 @@
}
DrainNotification that = (DrainNotification) o;
return Objects.equal(uuid, that.uuid)
- && Objects.equal(deploymentId, that.deploymentId);
+ && Objects.equal(runId, that.runId);
}
@Override
public int hashCode() {
- return Objects.hashCode(uuid, deploymentId);
+ return Objects.hashCode(uuid, runId);
}
}
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java b/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java
index 48e6db1..72cdfbb 100644
--- a/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainNotificationObjectMapper.java
@@ -67,7 +67,7 @@
throws IOException {
Map<String, Object> drainMessageMap = new HashMap<>();
drainMessageMap.put("uuid", value.getUuid().toString());
- drainMessageMap.put("deploymentId", value.getDeploymentId());
+ drainMessageMap.put("runId", value.getRunId());
jsonGenerator.writeObject(drainMessageMap);
}
}
@@ -79,8 +79,8 @@
ObjectCodec oc = jsonParser.getCodec();
JsonNode node = oc.readTree(jsonParser);
UUID uuid = UUID.fromString(node.get("uuid").textValue());
- String deploymentId = node.get("deploymentId").textValue();
- return new DrainNotification(uuid, deploymentId);
+ String runId = node.get("runId").textValue();
+ return new DrainNotification(uuid, runId);
}
}
}
diff --git a/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java b/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java
index c100a47..f712381 100644
--- a/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java
+++ b/samza-core/src/main/java/org/apache/samza/drain/DrainUtils.java
@@ -47,23 +47,23 @@
* Writes a {@link DrainNotification} to the underlying metastore. This method should be used by external controllers
* to issue a DrainNotification to the JobCoordinator and Samza Containers.
* @param metadataStore Metadata store to write drain notification to.
- * @param deploymentId deploymentId for the DrainNotification
+ * @param runId runId for the DrainNotification
*
* @return generated uuid for the DrainNotification
*/
- public static UUID writeDrainNotification(MetadataStore metadataStore, String deploymentId) {
+ public static UUID writeDrainNotification(MetadataStore metadataStore, String runId) {
Preconditions.checkArgument(metadataStore != null, "MetadataStore cannot be null.");
- Preconditions.checkArgument(!Strings.isNullOrEmpty(deploymentId), "deploymentId should be non-null.");
- LOG.info("Attempting to write DrainNotification to metadata-store for the deployment ID {}", deploymentId);
+ Preconditions.checkArgument(!Strings.isNullOrEmpty(runId), "runId should be non-null.");
+ LOG.info("Attempting to write DrainNotification to metadata-store for the deployment ID {}", runId);
final NamespaceAwareCoordinatorStreamStore drainMetadataStore =
new NamespaceAwareCoordinatorStreamStore(metadataStore, DRAIN_METADATA_STORE_NAMESPACE);
final ObjectMapper objectMapper = DrainNotificationObjectMapper.getObjectMapper();
final UUID uuid = UUID.randomUUID();
- final DrainNotification message = new DrainNotification(uuid, deploymentId);
+ final DrainNotification message = new DrainNotification(uuid, runId);
try {
drainMetadataStore.put(message.getUuid().toString(), objectMapper.writeValueAsBytes(message));
drainMetadataStore.flush();
- LOG.info("DrainNotification with id {} written to metadata-store for the deployment ID {}", uuid, deploymentId);
+ LOG.info("DrainNotification with id {} written to metadata-store for the deployment ID {}", uuid, runId);
} catch (Exception ex) {
throw new SamzaException(
String.format("DrainNotification might have been not written to metastore %s", message), ex);
@@ -73,23 +73,23 @@
/**
* Cleans up DrainNotifications for the current deployment from the underlying metadata store.
- * The current deploymentId is extracted from the config.
+ * The current runId is extracted from the config.
*
* @param metadataStore underlying metadata store
- * @param config Config for the job. Used to extract the deploymentId of the job.
+ * @param config Config for the job. Used to extract the runId of the job.
* */
public static void cleanup(MetadataStore metadataStore, Config config) {
Preconditions.checkArgument(metadataStore != null, "MetadataStore cannot be null.");
Preconditions.checkNotNull(config, "Config parameter cannot be null.");
final ApplicationConfig applicationConfig = new ApplicationConfig(config);
- final String deploymentId = applicationConfig.getRunId();
+ final String runId = applicationConfig.getRunId();
final ObjectMapper objectMapper = DrainNotificationObjectMapper.getObjectMapper();
final NamespaceAwareCoordinatorStreamStore drainMetadataStore =
new NamespaceAwareCoordinatorStreamStore(metadataStore, DRAIN_METADATA_STORE_NAMESPACE);
- if (DrainMonitor.shouldDrain(drainMetadataStore, deploymentId)) {
- LOG.info("Attempting to clean up DrainNotifications from the metadata-store for the current deployment {}", deploymentId);
+ if (DrainMonitor.shouldDrain(drainMetadataStore, runId)) {
+ LOG.info("Attempting to clean up DrainNotifications from the metadata-store for the current deployment {}", runId);
drainMetadataStore.all()
.values()
.stream()
@@ -101,19 +101,19 @@
throw new SamzaException(e);
}
})
- .filter(notification -> deploymentId.equals(notification.getDeploymentId()))
+ .filter(notification -> runId.equals(notification.getRunId()))
.forEach(notification -> drainMetadataStore.delete(notification.getUuid().toString()));
drainMetadataStore.flush();
- LOG.info("Successfully cleaned up DrainNotifications from the metadata-store for the current deployment {}", deploymentId);
+ LOG.info("Successfully cleaned up DrainNotifications from the metadata-store for the current deployment {}", runId);
} else {
LOG.info("No DrainNotification found in the metadata-store for the current deployment {}. No need to cleanup.",
- deploymentId);
+ runId);
}
}
/**
- * Cleans up all DrainNotifications irrespective of the deploymentId.
+ * Cleans up all DrainNotifications irrespective of the runId.
* */
public static void cleanupAll(MetadataStore metadataStore) {
Preconditions.checkArgument(metadataStore != null, "MetadataStore cannot be null.");
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java
index 4f93f2c..7dd5d19 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/BroadcastOperatorImpl.java
@@ -21,10 +21,12 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
+import org.apache.samza.config.ApplicationConfig;
import org.apache.samza.context.Context;
import org.apache.samza.operators.spec.BroadcastOperatorSpec;
import org.apache.samza.operators.spec.OperatorSpec;
import org.apache.samza.system.ControlMessage;
+import org.apache.samza.system.DrainMessage;
import org.apache.samza.system.EndOfStreamMessage;
import org.apache.samza.system.OutgoingMessageEnvelope;
import org.apache.samza.system.SystemStream;
@@ -40,11 +42,13 @@
private final BroadcastOperatorSpec<M> broadcastOpSpec;
private final SystemStream systemStream;
private final String taskName;
+ private final String runId;
BroadcastOperatorImpl(BroadcastOperatorSpec<M> broadcastOpSpec, SystemStream systemStream, Context context) {
this.broadcastOpSpec = broadcastOpSpec;
this.systemStream = systemStream;
this.taskName = context.getTaskContext().getTaskModel().getTaskName().getTaskName();
+ this.runId = new ApplicationConfig(context.getJobContext().getConfig()).getRunId();
}
@Override
@@ -74,6 +78,12 @@
}
@Override
+ protected Collection<Void> handleDrain(MessageCollector collector, TaskCoordinator coordinator) {
+ sendControlMessage(new DrainMessage(taskName, runId), collector);
+ return Collections.emptyList();
+ }
+
+ @Override
protected Collection<Void> handleWatermark(long watermark, MessageCollector collector, TaskCoordinator coordinator) {
sendControlMessage(new WatermarkMessage(watermark, taskName), collector);
return Collections.emptyList();
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/DrainStates.java b/samza-core/src/main/java/org/apache/samza/operators/impl/DrainStates.java
new file mode 100644
index 0000000..30a341f
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/DrainStates.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.operators.impl;
+
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.samza.system.DrainMessage;
+import org.apache.samza.system.SystemStream;
+import org.apache.samza.system.SystemStreamPartition;
+
+
+/**
+ * Tracks the drain state of the input streams of a task. For each system stream partition (ssp) it records which
+ * upstream tasks have sent a {@link DrainMessage}. Once the number of reporting tasks equals the expected count,
+ * or a drain message without a task name arrives, the ssp is marked as drained. For a stream to be drained,
+ * all its partitions assigned to the task need to be drained.
+ *
+ * This class is thread-safe: per-ssp updates are synchronized and the drained flag is volatile.
+ */
+public class DrainStates {
+ private static final class DrainState {
+ // names of the upstream tasks that have already sent a drain message for this ssp
+ private final Set<String> tasks = new HashSet<>();
+ private final int expectedTotal;
+ private volatile boolean drained = false;
+
+ DrainState(int expectedTotal) {
+ this.expectedTotal = expectedTotal;
+ }
+
+ synchronized void update(String taskName) {
+ if (taskName != null) {
+ // aggregate the drain messages; drained once the number of distinct reporting tasks equals expectedTotal
+ tasks.add(taskName);
+ drained = tasks.size() == expectedTotal;
+ } else {
+ // no task name: the drain message is coming from either source or aggregator task
+ drained = true;
+ }
+ }
+
+ boolean isDrained() {
+ return drained;
+ }
+
+ @Override
+ public String toString() {
+ return "DrainState: [Tasks : "
+ + tasks
+ + ", isDrained : "
+ + drained
+ + "]";
+ }
+ }
+
+ private final Map<SystemStreamPartition, DrainState> drainStates;
+
+ /**
+ * Constructs the drain states for a task, one per ssp.
+ * @param ssps all the ssps assigned to this task
+ * @param producerTaskCounts mapping from a stream to the number of upstream tasks that produce to it; 0 if absent
+ */
+ DrainStates(Set<SystemStreamPartition> ssps, Map<SystemStream, Integer> producerTaskCounts) {
+ this.drainStates = ssps.stream()
+ .collect(Collectors.toMap(
+ ssp -> ssp,
+ ssp -> new DrainState(producerTaskCounts.getOrDefault(ssp.getSystemStream(), 0))));
+ }
+
+ /**
+ * Update the state upon receiving a drain message.
+ * @param eos the received {@link DrainMessage} (not an end-of-stream message); its task name identifies the sender
+ * @param ssp ssp the message arrived on; must be one passed to the constructor, else a NullPointerException is thrown
+ */
+ void update(DrainMessage eos, SystemStreamPartition ssp) {
+ DrainState state = drainStates.get(ssp);
+ state.update(eos.getTaskName());
+ }
+
+ /**
+ * Checks if every tracked partition of the system-stream is drained; true if no partition of it is tracked.
+ * */
+ boolean isDrained(SystemStream systemStream) {
+ return drainStates.entrySet().stream()
+ .filter(entry -> entry.getKey().getSystemStream().equals(systemStream))
+ .allMatch(entry -> entry.getValue().isDrained());
+ }
+
+ /**
+ * Checks if all tracked input SSPs of the task have drained; true if no ssp is tracked.
+ * */
+ boolean areAllStreamsDrained() {
+ return drainStates.values().stream().allMatch(DrainState::isDrained);
+ }
+
+ @Override
+ public String toString() {
+ return drainStates.toString();
+ }
+}
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
index c404475..5723d91 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java
@@ -1,4 +1,4 @@
-/*
+ /*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
@@ -39,6 +39,7 @@
import org.apache.samza.operators.functions.WatermarkFunction;
import org.apache.samza.operators.spec.OperatorSpec;
import org.apache.samza.scheduler.CallbackScheduler;
+import org.apache.samza.system.DrainMessage;
import org.apache.samza.system.EndOfStreamMessage;
import org.apache.samza.system.SystemStream;
import org.apache.samza.system.SystemStreamPartition;
@@ -86,6 +87,8 @@
private TaskModel taskModel;
// end-of-stream states
private EndOfStreamStates eosStates;
+ // drain states
+ private DrainStates drainStates;
// watermark states
private WatermarkStates watermarkStates;
private CallbackScheduler callbackScheduler;
@@ -125,6 +128,7 @@
this.taskName = taskContext.getTaskModel().getTaskName();
this.eosStates = (EndOfStreamStates) internalTaskContext.fetchObject(EndOfStreamStates.class.getName());
this.watermarkStates = (WatermarkStates) internalTaskContext.fetchObject(WatermarkStates.class.getName());
+ this.drainStates = (DrainStates) internalTaskContext.fetchObject(DrainStates.class.getName());
this.controlMessageSender = new ControlMessageSender(internalTaskContext.getStreamMetadataCache());
this.taskModel = taskContext.getTaskModel();
this.callbackScheduler = taskContext.getCallbackScheduler();
@@ -362,6 +366,88 @@
}
/**
+ * This method is invoked when all input streams to this operator have encountered the drain-and-stop control message.
+ * Inheriting classes should handle drain-and-stop by overriding this function.
+ * The default no-op implementation is for in-memory operators to handle drain-and-stop. Output operators need to
+ * override this to actually propagate drain-and-stop over the wire.
+ * @param collector message collector
+ * @param coordinator task coordinator
+ * @return results to be emitted when this operator reaches drain-and-stop
+ */
+ protected Collection<RM> handleDrain(MessageCollector collector, TaskCoordinator coordinator) {
+ return Collections.emptyList();
+ }
+
+ /**
+ * Aggregate {@link DrainMessage} from each ssp of the stream and, once the stream is drained, invoke
+ * {@link #onDrainOfStream(MessageCollector, TaskCoordinator)}; requests task shutdown when all streams are drained.
+ * @param drainMessage {@link DrainMessage} object
+ * @param ssp system stream partition
+ * @param collector message collector
+ * @param coordinator task coordinator
+ */
+ public final CompletionStage<Void> aggregateDrainMessages(DrainMessage drainMessage, SystemStreamPartition ssp,
+ MessageCollector collector, TaskCoordinator coordinator) {
+ LOG.info("Received drain message from task {} in {}", drainMessage.getTaskName(), ssp);
+ drainStates.update(drainMessage, ssp);
+
+ SystemStream stream = ssp.getSystemStream();
+ CompletionStage<Void> drainFuture = CompletableFuture.completedFuture(null);
+
+ if (drainStates.isDrained(stream)) {
+ LOG.info("Input {} reaches the end for task {}", stream.toString(), taskName.getTaskName());
+ if (drainMessage.getTaskName() != null && shouldTaskBroadcastToOtherPartitions(ssp)) {
+ // This is the aggregation task, which already received all the drain messages from upstream
+ // broadcast the drain message (without a task name) to all the peer partitions
+ // additionally if elasticity is enabled
+ // then only one of the elastic tasks of the ssp will broadcast
+ controlMessageSender.broadcastToOtherPartitions(new DrainMessage(drainMessage.getRunId()), ssp, collector);
+ }
+
+ drainFuture = onDrainOfStream(collector, coordinator)
+ .thenAccept(result -> {
+ if (drainStates.areAllStreamsDrained()) {
+ // all input streams have been drained, shut down the task
+ LOG.info("All input streams have been drained for task {}", taskName.getTaskName());
+ coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
+ }
+ });
+ }
+
+ return drainFuture;
+ }
+
+
+ /**
+ * Invoke {@link #handleDrain(MessageCollector, TaskCoordinator)} if all the input streams to the current operator
+ * have been drained, and pass its results to the registered downstream operators.
+ * Then propagate the drain to those downstream operators.
+ * @param collector message collector
+ * @param coordinator task coordinator
+ */
+ private CompletionStage<Void> onDrainOfStream(MessageCollector collector, TaskCoordinator coordinator) {
+ CompletionStage<Void> drainFuture = CompletableFuture.completedFuture(null);
+
+ if (inputStreams.stream().allMatch(input -> drainStates.isDrained(input))) {
+ Collection<RM> results = handleDrain(collector, coordinator);
+
+ CompletionStage<Void> resultFuture = CompletableFuture.allOf(
+ results.stream()
+ .flatMap(r -> this.registeredOperators.stream()
+ .map(op -> op.onMessageAsync(r, collector, coordinator)))
+ .toArray(CompletableFuture[]::new));
+
+ // propagate DrainMessage to downstream operators only after the emitted results have been processed
+ drainFuture = resultFuture.thenCompose(x ->
+ CompletableFuture.allOf(this.registeredOperators.stream()
+ .map(op -> op.onDrainOfStream(collector, coordinator))
+ .toArray(CompletableFuture[]::new)));
+ }
+
+ return drainFuture;
+ }
+
+ /**
* Aggregate the {@link WatermarkMessage} from each ssp into a watermark. Then call onWatermark() if
* a new watermark exits.
* @param watermarkMessage a {@link WatermarkMessage} object
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
index 1aa67d8..c62b0b2 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImplGraph.java
@@ -115,11 +115,16 @@
// set states for end-of-stream; don't include side inputs (see SAMZA-2303)
internalTaskContext.registerObject(EndOfStreamStates.class.getName(),
new EndOfStreamStates(internalTaskContext.getSspsExcludingSideInputs(), producerTaskCounts));
+
// set states for watermark; don't include side inputs (see SAMZA-2303)
internalTaskContext.registerObject(WatermarkStates.class.getName(),
new WatermarkStates(internalTaskContext.getSspsExcludingSideInputs(), producerTaskCounts,
context.getContainerContext().getContainerMetricsRegistry()));
+ // set states for drain; don't include side inputs (see SAMZA-2303)
+ internalTaskContext.registerObject(DrainStates.class.getName(),
+ new DrainStates(internalTaskContext.getSspsExcludingSideInputs(), producerTaskCounts));
+
specGraph.getInputOperators().forEach((streamId, inputOpSpec) -> {
SystemStream systemStream = streamConfig.streamIdToSystemStream(streamId);
InputOperatorImpl inputOperatorImpl =
diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java
index 47ad4f6..643abdf 100644
--- a/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java
+++ b/samza-core/src/main/java/org/apache/samza/operators/impl/PartitionByOperatorImpl.java
@@ -20,12 +20,14 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
+import org.apache.samza.config.ApplicationConfig;
import org.apache.samza.context.Context;
import org.apache.samza.context.InternalTaskContext;
import org.apache.samza.operators.functions.MapFunction;
import org.apache.samza.operators.spec.OperatorSpec;
import org.apache.samza.operators.spec.PartitionByOperatorSpec;
import org.apache.samza.system.ControlMessage;
+import org.apache.samza.system.DrainMessage;
import org.apache.samza.system.EndOfStreamMessage;
import org.apache.samza.system.OutgoingMessageEnvelope;
import org.apache.samza.system.StreamMetadataCache;
@@ -48,6 +50,7 @@
private final MapFunction<? super M, ? extends K> keyFunction;
private final MapFunction<? super M, ? extends V> valueFunction;
private final String taskName;
+ private final String runId;
private final ControlMessageSender controlMessageSender;
PartitionByOperatorImpl(PartitionByOperatorSpec<M, K, V> partitionByOpSpec,
@@ -57,6 +60,7 @@
this.keyFunction = partitionByOpSpec.getKeyFunction();
this.valueFunction = partitionByOpSpec.getValueFunction();
this.taskName = internalTaskContext.getContext().getTaskContext().getTaskModel().getTaskName().getTaskName();
+ this.runId = new ApplicationConfig(internalTaskContext.getContext().getJobContext().getConfig()).getRunId();
StreamMetadataCache streamMetadataCache = internalTaskContext.getStreamMetadataCache();
this.controlMessageSender = new ControlMessageSender(streamMetadataCache);
}
@@ -95,6 +99,12 @@
}
@Override
+ protected Collection<Void> handleDrain(MessageCollector collector, TaskCoordinator coordinator) {
+ sendControlMessage(new DrainMessage(taskName, runId), collector);
+ return Collections.emptyList();
+ }
+
+ @Override
protected Collection<Void> handleWatermark(long watermark, MessageCollector collector, TaskCoordinator coordinator) {
sendControlMessage(new WatermarkMessage(watermark, taskName), collector);
return Collections.emptyList();
diff --git a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
index 654d9a0..cda04377 100644
--- a/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
+++ b/samza-core/src/main/java/org/apache/samza/processor/StreamProcessor.java
@@ -48,6 +48,7 @@
import org.apache.samza.coordinator.JobCoordinatorFactory;
import org.apache.samza.coordinator.JobCoordinatorListener;
import org.apache.samza.diagnostics.DiagnosticsManager;
+import org.apache.samza.drain.DrainMonitor;
import org.apache.samza.job.model.JobModel;
import org.apache.samza.metadatastore.MetadataStore;
import org.apache.samza.metrics.MetricsRegistry;
@@ -402,12 +403,18 @@
*/
MetricsRegistryMap metricsRegistryMap = new MetricsRegistryMap();
+ DrainMonitor drainMonitor = null;
+ JobConfig jobConfig = new JobConfig(config);
+ if (metadataStore != null && jobConfig.getDrainMonitorEnabled()) {
+ drainMonitor = new DrainMonitor(metadataStore, config);
+ }
+
return SamzaContainer.apply(processorId, jobModel, ScalaJavaUtil.toScalaMap(this.customMetricsReporter),
metricsRegistryMap, this.taskFactory, JobContextImpl.fromConfigWithDefaults(this.config, jobModel),
Option.apply(this.applicationDefinedContainerContextFactoryOptional.orElse(null)),
Option.apply(this.applicationDefinedTaskContextFactoryOptional.orElse(null)),
Option.apply(this.externalContextOptional.orElse(null)), null, startpointManager,
- Option.apply(diagnosticsManager.orElse(null)), null);
+ Option.apply(diagnosticsManager.orElse(null)), drainMonitor);
}
private static JobCoordinator createJobCoordinator(Config config, String processorId, MetricsRegistry metricsRegistry, MetadataStore metadataStore) {
diff --git a/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java b/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java
index f499eb3..cd153b5 100644
--- a/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java
+++ b/samza-core/src/main/java/org/apache/samza/runtime/ContainerLaunchUtil.java
@@ -143,7 +143,8 @@
MetricsRegistryMap metricsRegistryMap = new MetricsRegistryMap();
DrainMonitor drainMonitor = null;
- if (new JobConfig(config).getDrainMonitorEnabled()) {
+ JobConfig jobConfig = new JobConfig(config);
+ if (jobConfig.getDrainMonitorEnabled()) {
drainMonitor = new DrainMonitor(coordinatorStreamStore, config);
}
diff --git a/samza-core/src/main/java/org/apache/samza/serializers/IntermediateMessageSerde.java b/samza-core/src/main/java/org/apache/samza/serializers/IntermediateMessageSerde.java
index 83a0a35..a7afecd 100644
--- a/samza-core/src/main/java/org/apache/samza/serializers/IntermediateMessageSerde.java
+++ b/samza-core/src/main/java/org/apache/samza/serializers/IntermediateMessageSerde.java
@@ -22,6 +22,7 @@
import java.util.Arrays;
import org.apache.samza.SamzaException;
+import org.apache.samza.system.DrainMessage;
import org.apache.samza.system.EndOfStreamMessage;
import org.apache.samza.system.MessageType;
import org.apache.samza.system.WatermarkMessage;
@@ -56,11 +57,13 @@
private final Serde userMessageSerde;
private final Serde<WatermarkMessage> watermarkSerde;
private final Serde<EndOfStreamMessage> eosSerde;
+ private final Serde<DrainMessage> drainMessageSerde;
public IntermediateMessageSerde(Serde userMessageSerde) {
this.userMessageSerde = userMessageSerde;
this.watermarkSerde = new JsonSerdeV2<>(WatermarkMessage.class);
this.eosSerde = new JsonSerdeV2<>(EndOfStreamMessage.class);
+ this.drainMessageSerde = new JsonSerdeV2<>(DrainMessage.class);
}
@Override
@@ -93,6 +96,9 @@
case END_OF_STREAM:
object = eosSerde.fromBytes(data);
break;
+ case DRAIN:
+ object = drainMessageSerde.fromBytes(data);
+ break;
default:
throw new UnsupportedOperationException(String.format("Message type %s is not supported", type.name()));
}
@@ -118,6 +124,9 @@
case END_OF_STREAM:
data = eosSerde.toBytes((EndOfStreamMessage) object);
break;
+ case DRAIN:
+ data = drainMessageSerde.toBytes((DrainMessage) object);
+ break;
default:
throw new SamzaException("Unknown message type: " + type.name());
}
diff --git a/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java b/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java
index 6274c15..981e847 100644
--- a/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java
+++ b/samza-core/src/main/java/org/apache/samza/storage/SideInputTask.java
@@ -97,6 +97,11 @@
}
@Override
+ public void drain(ReadableCoordinator coordinator) {
+ LOG.info("Task {} has drained", this.taskName);
+ }
+
+ @Override
public boolean isWindowableTask() {
return false;
}
diff --git a/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java b/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
index 0079fab..0db960a 100644
--- a/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
+++ b/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java
@@ -31,6 +31,7 @@
import org.apache.samza.system.EndOfStreamMessage;
import org.apache.samza.system.IncomingMessageEnvelope;
import org.apache.samza.system.MessageType;
+import org.apache.samza.system.DrainMessage;
import org.apache.samza.system.SystemStream;
import org.apache.samza.system.WatermarkMessage;
import org.apache.samza.util.Clock;
@@ -129,6 +130,12 @@
inputOpImpl.aggregateEndOfStream(eosMessage, ime.getSystemStreamPartition(), collector, coordinator);
break;
+ case DRAIN:
+ DrainMessage drainMessage = (DrainMessage) ime.getMessage();
+ processFuture =
+ inputOpImpl.aggregateDrainMessages(drainMessage, ime.getSystemStreamPartition(), collector, coordinator);
+ break;
+
case WATERMARK:
WatermarkMessage watermarkMessage = (WatermarkMessage) ime.getMessage();
processFuture = inputOpImpl.aggregateWatermark(watermarkMessage, ime.getSystemStreamPartition(), collector,
diff --git a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
index c157b87..f0aa20b 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/SamzaContainer.scala
@@ -27,7 +27,6 @@
import java.util.concurrent._
import java.util.function.Consumer
import java.util.{Base64, Optional}
-import com.google.common.annotations.VisibleForTesting
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.samza.SamzaException
import org.apache.samza.checkpoint.{CheckpointListener, OffsetManager, OffsetManagerMetrics}
@@ -426,6 +425,8 @@
val pollIntervalMs = taskConfig.getPollIntervalMs
+ val appConfig = new ApplicationConfig(config)
+
val consumerMultiplexer = new SystemConsumers(
chooser = chooser,
consumers = consumers,
@@ -435,7 +436,8 @@
dropDeserializationError = dropDeserializationError,
pollIntervalMs = pollIntervalMs,
clock = () => clock.nanoTime(),
- elasticityFactor = jobConfig.getElasticityFactor)
+ elasticityFactor = jobConfig.getElasticityFactor,
+ runId = appConfig.getRunId)
val producerMultiplexer = new SystemProducers(
producers = producers,
@@ -622,7 +624,7 @@
val maxThrottlingDelayMs = config.getLong("container.disk.quota.delay.max.ms", TimeUnit.SECONDS.toMillis(1))
- val runLoop = RunLoopFactory.createRunLoop(
+ val runLoop: Runnable = RunLoopFactory.createRunLoop(
taskInstances,
consumerMultiplexer,
taskThreadPool,
@@ -630,7 +632,8 @@
samzaContainerMetrics,
taskConfig,
clock,
- jobConfig.getElasticityFactor)
+ jobConfig.getElasticityFactor,
+ appConfig.getRunId)
val containerMemoryMb : Int = new ClusterManagerConfig(config).getContainerMemoryMb
@@ -747,6 +750,7 @@
def getStatus(): SamzaContainerStatus = status
def drain() {
+ // set the SystemConsumers multiplexer and RunLoop in drain mode
consumerMultiplexer.drain
}
diff --git a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
index d75d911..606f32a 100644
--- a/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
+++ b/samza-core/src/main/scala/org/apache/samza/container/TaskInstance.scala
@@ -72,6 +72,7 @@
val taskName: TaskName = taskModel.getTaskName
val isInitableTask = task.isInstanceOf[InitableTask]
+ val isDrainTask = task.isInstanceOf[DrainListenerTask]
val isEndOfStreamListenerTask = task.isInstanceOf[EndOfStreamListenerTask]
val isClosableTask = task.isInstanceOf[ClosableTask]
@@ -111,6 +112,12 @@
val streamConfig: StreamConfig = new StreamConfig(config)
override val intermediateStreams: java.util.Set[String] = JavaConverters.setAsJavaSetConverter(streamConfig.getStreamIds.filter(streamConfig.getIsIntermediateStream)).asJava
+ val intermediateSSPs: Set[SystemStreamPartition] = systemStreamPartitions.filter(ssp => {
+ val systemStream = ssp.getSystemStream
+ val streamId = streamConfig.systemStreamToStreamId(systemStream)
+ intermediateStreams.contains(streamId)
+ }).toSet
+
val streamsToDeleteCommittedMessages: Set[String] = streamConfig.getStreamIds.filter(streamConfig.getDeleteCommittedMessages).map(streamConfig.getPhysicalName).toSet
val checkpointWriteVersions = new TaskConfig(config).getCheckpointWriteVersions
@@ -189,20 +196,27 @@
}
/**
- * Computes the starting offset for the partitions assigned to the task and registers them with the underlying {@see SystemConsumers}.
- *
- * Starting offset for a partition of the task is computed in the following manner:
- *
- * 1. If a startpoint exists for a task, system stream partition and it resolves to a offset, then the resolved offset is used as the starting offset.
- * 2. Else, the checkpointed offset for the system stream partition is used as the starting offset.
+ * This method registers the following with the underlying {@see SystemConsumers}:
+ * a) starting offsets for all SSPs assigned to the task
+ * b) intermediate SSPs assigned to the task
+ *
+ * Starting offset for a partition of the task is computed in the following manner:
+ *
+ * 1. If a startpoint exists for the task and system stream partition and it resolves to an offset, then the resolved offset is used as the starting offset.
+ * 2. Else, the checkpointed offset for the system stream partition is used as the starting offset.
*/
def registerConsumers() {
debug("Registering consumers for taskName: %s" format taskName)
+
systemStreamPartitions.foreach(systemStreamPartition => {
val startingOffset: String = getStartingOffset(systemStreamPartition)
consumerMultiplexer.register(systemStreamPartition, startingOffset)
metrics.addOffsetGauge(systemStreamPartition, () => offsetManager.getLastProcessedOffset(taskName, systemStreamPartition).orNull)
})
+
+ intermediateSSPs.foreach(ssp => {
+ consumerMultiplexer.registerIntermediateSSP(ssp)
+ })
}
def process(envelope: IncomingMessageEnvelope, coordinator: ReadableCoordinator,
@@ -241,6 +255,16 @@
}
}
+ def drain(coordinator: ReadableCoordinator): Unit = {
+ task match {
+ case drainListenerTask: DrainListenerTask =>
+ exceptionHandler.maybeHandle {
+ drainListenerTask.onDrain(collector, coordinator)
+ }
+ case _ =>
+ }
+ }
+
def window(coordinator: ReadableCoordinator) {
if (isWindowableTask) {
trace("Windowing for taskName: %s" format taskName)
@@ -613,14 +637,19 @@
// if elasticityFactor > 1, find the SSP with keyBucket
var incomingMessageSsp = envelope.getSystemStreamPartition(elasticityFactor)
- // if envelope is end of stream or watermark, it needs to be routed to all tasks consuming the ssp irresp of keyBucket
+ // if envelope is end of stream or watermark or drain,
+ // it needs to be routed to all tasks consuming the ssp irresp of keyBucket
val messageType = MessageType.of(envelope.getMessage)
- if (envelope.isEndOfStream() || MessageType.END_OF_STREAM == messageType || MessageType.WATERMARK == messageType) {
+ if (envelope.isEndOfStream()
+ || envelope.isDrain()
+ || messageType == MessageType.END_OF_STREAM
+ || messageType == MessageType.WATERMARK) {
+
incomingMessageSsp = systemStreamPartitions
.filter(ssp => ssp.getSystemStream.equals(incomingMessageSsp.getSystemStream)
&& ssp.getPartition.equals(incomingMessageSsp.getPartition))
.toIterator.next()
- debug("for watermark or end of stream envelope, found incoming ssp as {}" format incomingMessageSsp)
+ debug("for watermark or end-of-stream or drain envelope, found incoming ssp as %s".format(incomingMessageSsp))
}
incomingMessageSsp
}
diff --git a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
index f1c476a..ac1a558 100644
--- a/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
+++ b/samza-core/src/main/scala/org/apache/samza/job/local/ThreadJobFactory.scala
@@ -89,7 +89,7 @@
}
var drainMonitor: DrainMonitor = null
if (jobConfig.getDrainMonitorEnabled()) {
- drainMonitor = new DrainMonitor(coordinatorStreamStore, config)
+ drainMonitor = new DrainMonitor(coordinatorStreamStore, config, jobConfig.getDrainMonitorPollIntervalMillis)
}
val containerId = "0"
diff --git a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
index c2e02a4..40b2999 100644
--- a/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
+++ b/samza-core/src/main/scala/org/apache/samza/storage/ContainerStorageManager.java
@@ -45,6 +45,7 @@
import org.apache.samza.checkpoint.Checkpoint;
import org.apache.samza.checkpoint.CheckpointManager;
import org.apache.samza.checkpoint.CheckpointV2;
+import org.apache.samza.config.ApplicationConfig;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
import org.apache.samza.config.StorageConfig;
@@ -120,6 +121,8 @@
private static final int RESTORE_THREAD_POOL_SHUTDOWN_TIMEOUT_SECONDS = 60;
+ private static final int DEFAULT_SIDE_INPUT_ELASTICITY_FACTOR = 1;
+
/** Maps containing relevant per-task objects */
private final Map<TaskName, TaskInstanceMetrics> taskInstanceMetrics;
private final Map<TaskName, TaskInstanceCollector> taskInstanceCollectors;
@@ -297,11 +300,13 @@
MessageChooser chooser = DefaultChooser.apply(inputStreamMetadata, new RoundRobinChooserFactory(), config,
sideInputSystemConsumersMetrics.registry(), systemAdmins);
+ ApplicationConfig applicationConfig = new ApplicationConfig(config);
+
sideInputSystemConsumers =
new SystemConsumers(chooser, ScalaJavaUtil.toScalaMap(sideInputConsumers), systemAdmins, serdeManager,
sideInputSystemConsumersMetrics, SystemConsumers.DEFAULT_NO_NEW_MESSAGES_TIMEOUT(), SystemConsumers.DEFAULT_DROP_SERIALIZATION_ERROR(),
TaskConfig.DEFAULT_POLL_INTERVAL_MS, ScalaJavaUtil.toScalaFunction(() -> System.nanoTime()),
- JobConfig.DEFAULT_JOB_ELASTICITY_FACTOR);
+ JobConfig.DEFAULT_JOB_ELASTICITY_FACTOR, applicationConfig.getRunId());
}
}
@@ -922,6 +927,8 @@
new SamzaContainerMetrics(SIDEINPUTS_METRICS_PREFIX + this.samzaContainerMetrics.source(),
this.samzaContainerMetrics.registry(), SIDEINPUTS_METRICS_PREFIX);
+ ApplicationConfig applicationConfig = new ApplicationConfig(config);
+
this.sideInputRunLoop = new RunLoop(sideInputTasks,
null, // all operations are executed in the main runloop thread
this.sideInputSystemConsumers,
@@ -934,7 +941,9 @@
taskConfig.getMaxIdleMs(),
sideInputContainerMetrics,
System::nanoTime,
- false); // commit must be synchronous to ensure integrity of state flush
+ false, // commit must be synchronous to ensure integrity of state flush
+ DEFAULT_SIDE_INPUT_ELASTICITY_FACTOR,
+ applicationConfig.getRunId());
try {
sideInputsExecutor.submit(() -> {
diff --git a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
index bcecebb..d1b639c 100644
--- a/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
+++ b/samza-core/src/main/scala/org/apache/samza/system/SystemConsumers.scala
@@ -28,9 +28,8 @@
import java.util.HashSet
import java.util.Queue
import java.util.Set
-import java.util.function.Predicate
+import java.util.function.{Consumer}
import java.util.stream.Collectors
-
import scala.collection.JavaConverters._
import org.apache.samza.serializers.SerdeManager
import org.apache.samza.util.{Logging, TimerUtil}
@@ -38,6 +37,7 @@
import org.apache.samza.SamzaException
import org.apache.samza.config.TaskConfig
+
object SystemConsumers {
val DEFAULT_NO_NEW_MESSAGES_TIMEOUT = 10
val DEFAULT_DROP_SERIALIZATION_ERROR = false
@@ -116,7 +116,12 @@
*/
val clock: () => Long = () => System.nanoTime(),
- val elasticityFactor: Int = 1) extends Logging with TimerUtil {
+ val elasticityFactor: Int = 1,
+
+ /**
+ * Identifier of the current deployment.
+ * */
+ val runId: String = null) extends Logging with TimerUtil {
/**
* Mapping from the {@see SystemStreamPartition} to the registered offsets.
@@ -130,6 +135,10 @@
*/
private val sspKeyBucketsRegistered = new HashSet[SystemStreamPartition] ()
+ private val intermediateSSPs = new HashSet[SystemStreamPartition]()
+
+ private val intermediateSystems = new HashSet[String]()
+
/**
* A buffer of incoming messages grouped by SystemStreamPartition. These
* messages are handed out to the MessageChooser as it needs them.
@@ -155,10 +164,8 @@
*/
private var started = false
- /**
- * Denotes if the SystemConsumers is in drain mode.
- * */
- private var draining = false
+ /**
+ * Denotes whether this SystemConsumers is in drain mode. Set to true by drain() and
+ * read by refresh to decide which systems are polled.
+ * NOTE(review): marked volatile, presumably because drain() can be called from a thread
+ * other than the one that polls; confirm against the callers.
+ */
+ @volatile
+ private var isDraining = false
/**
* Default timeout to noNewMessagesTimeout. Every time SystemConsumers
@@ -193,8 +200,6 @@
// but the actual systemConsumer which consumes from the input does not know about KeyBucket.
// hence, use an SSP without KeyBucket
consumer.register(removeKeyBucket(systemStreamPartition), offset)
- chooser.register(removeKeyBucket(systemStreamPartition), offset)
- debug("consumer.register and chooser.register for ssp: %s with offset %s" format (systemStreamPartition, offset))
}
debug("Starting consumers.")
@@ -214,15 +219,12 @@
chooser.start
+
started = true
refresh
}
- def drain: Unit = {
- draining = true
- }
-
def stop {
if (started) {
debug("Stopping consumers.")
@@ -237,6 +239,14 @@
}
}
+  /**
+   * Puts this SystemConsumers into drain mode.
+   *
+   * The first invocation sets the draining flag, stops every underlying consumer and
+   * enqueues the drain control messages. Any later invocation is a no-op.
+   */
+  def drain(): Unit = {
+    if (isDraining) {
+      return
+    }
+
+    isDraining = true
+    info("SystemConsumers is set to drain mode.")
+    for (consumer <- consumers.values) {
+      consumer.stop
+    }
+    writeDrainControlMessageToSspQueue()
+  }
def register(ssp: SystemStreamPartition, offset: String) {
// If elasticity is enabled then the RunLoop gives SSP with keybucket
@@ -255,6 +265,8 @@
metrics.registerSystemStreamPartition(systemStreamPartition)
unprocessedMessagesBySSP.put(systemStreamPartition, new ArrayDeque[IncomingMessageEnvelope]())
+ chooser.register(systemStreamPartition, offset)
+
try {
val consumer = consumers(systemStreamPartition.getSystem)
val existingOffset = sspToRegisteredOffsets.get(systemStreamPartition)
@@ -268,12 +280,17 @@
}
}
+  /**
+   * Records an intermediate stream partition together with the name of its system.
+   *
+   * @param ssp the intermediate SystemStreamPartition to remember
+   */
+  def registerIntermediateSSP(ssp: SystemStreamPartition): Unit = {
+    debug("Registering intermediate stream: %s".format(ssp))
+    intermediateSSPs.add(ssp)
+    val owningSystem = ssp.getSystem
+    intermediateSystems.add(owningSystem)
+  }
def isEndOfStream(systemStreamPartition: SystemStreamPartition) = {
endOfStreamSSPs.contains(removeKeyBucket(systemStreamPartition))
}
- def choose (updateChooser: Boolean = true): IncomingMessageEnvelope = {
+ def choose(updateChooser: Boolean = true): IncomingMessageEnvelope = {
val envelopeFromChooser = chooser.choose
updateTimer(metrics.deserializationNs) {
@@ -398,16 +415,54 @@
}
private def refresh {
- if (draining) {
- trace("Skipping refresh of chooser as the multiplexer is in drain mode.")
- return
- }
- trace("Refreshing chooser with new messages.")
-
// Update last poll time so we don't poll too frequently.
lastPollNs = clock()
- // Poll every system for new messages.
- consumers.keys.map(poll(_))
+
+ if (isDraining) {
+ trace("Refreshing chooser with new messages from intermediate systems.")
+
+ // scala 2.11 doesn't allow using syntactical sugar: intermediateSystems.foreach(poll(_)) over java collections
+ intermediateSystems.forEach(new Consumer[String] {
+ override def accept(system: String): Unit = poll(system)
+ })
+ } else {
+ trace("Refreshing chooser with new messages.")
+ consumers.keys.foreach(poll(_))
+ }
+ }
+
+ /**
+ * Appends drain control messages to the unprocessed-message queues of the source SSPs.
+ *
+ * The target set is a copy of sspKeyBucketsRegistered minus intermediateSSPs and minus
+ * endOfStreamSSPs. For every SSP in that set, a watermark envelope (only when at least one
+ * intermediate SSP is registered) and then a drain envelope carrying runId are appended to the
+ * SSP's queue, totalUnprocessedMessages is incremented once per appended envelope, and
+ * tryUpdate is called for the SSP.
+ */
+ private def writeDrainControlMessageToSspQueue() {
+ val sspsToDrain = new HashSet(sspKeyBucketsRegistered)
+
+ // only write Drain ControlMessages to source SSPs
+ // sspsToDrain = allSSPs - intermediateSSPs - eosSSPs
+ sspsToDrain.removeAll(intermediateSSPs)
+ sspsToDrain.removeAll(endOfStreamSSPs)
+
+ sspsToDrain.forEach(new Consumer[SystemStreamPartition] {
+ override def accept(ssp: SystemStreamPartition): Unit = {
+ // reuse the SSP's existing queue when there is one, otherwise start with an empty queue
+ val envelopes: Queue[IncomingMessageEnvelope] =
+ if (unprocessedMessagesBySSP.containsKey(ssp)) {
+ unprocessedMessagesBySSP.get(ssp)
+ } else {
+ new util.ArrayDeque[IncomingMessageEnvelope]()
+ }
+
+ // Add watermark ControlMessage only if there are intermediate SSPs as low-level API task doesn't process
+ // WatermarkMessages
+ if (!intermediateSSPs.isEmpty) {
+ envelopes.add(IncomingMessageEnvelope.buildWatermarkEnvelope(ssp, Long.MaxValue))
+ totalUnprocessedMessages += 1
+ }
+ // Add Drain ControlMessage
+ envelopes.add(IncomingMessageEnvelope.buildDrainMessage(ssp, runId))
+ totalUnprocessedMessages += 1
+ // store the queue back; this matters when it was newly created above
+ unprocessedMessagesBySSP.put(ssp, envelopes)
+
+ // update the chooser with the messages
+ tryUpdate(ssp)
+ }
+ })
 }
/**
diff --git a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
index 7bd5f9a..e584974 100644
--- a/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
+++ b/samza-core/src/test/java/org/apache/samza/container/TestRunLoop.java
@@ -20,8 +20,9 @@
package org.apache.samza.container;
import com.google.common.collect.ImmutableMap;
-import java.util.Collections;
+import java.util.Arrays;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
@@ -49,6 +50,7 @@
public class TestRunLoop {
// Immutable objects shared by all test methods.
+ private final String runId = "foo";
private final ExecutorService executor = null;
private final SamzaContainerMetrics containerMetrics = new SamzaContainerMetrics("container", new MetricsRegistryMap(), "");
private final long windowMs = -1;
@@ -60,13 +62,23 @@
private final Partition p1 = new Partition(1);
private final TaskName taskName0 = new TaskName(p0.toString());
private final TaskName taskName1 = new TaskName(p1.toString());
- private final SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0);
- private final SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p1);
- private final IncomingMessageEnvelope envelope00 = new IncomingMessageEnvelope(ssp0, "0", "key0", "value0");
- private final IncomingMessageEnvelope envelope11 = new IncomingMessageEnvelope(ssp1, "1", "key1", "value1");
- private final IncomingMessageEnvelope envelope01 = new IncomingMessageEnvelope(ssp0, "1", "key0", "value0");
- private final IncomingMessageEnvelope ssp0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp0);
- private final IncomingMessageEnvelope ssp1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(ssp1);
+ private final SystemStreamPartition sspA0 = new SystemStreamPartition("testSystem", "testStreamA", p0);
+ private final SystemStreamPartition sspA1 = new SystemStreamPartition("testSystem", "testStreamA", p1);
+ private final SystemStreamPartition sspB0 = new SystemStreamPartition("testSystem", "testStreamB", p0);
+ private final SystemStreamPartition sspB1 = new SystemStreamPartition("testSystem", "testStreamB", p1);
+ private final IncomingMessageEnvelope envelopeA00 = new IncomingMessageEnvelope(sspA0, "0", "key0", "value0");
+ private final IncomingMessageEnvelope envelopeA11 = new IncomingMessageEnvelope(sspA1, "1", "key1", "value1");
+ private final IncomingMessageEnvelope envelopeA01 = new IncomingMessageEnvelope(sspA0, "1", "key0", "value0");
+ private final IncomingMessageEnvelope envelopeB00 = new IncomingMessageEnvelope(sspB0, "0", "key0", "value0");
+ private final IncomingMessageEnvelope envelopeB11 = new IncomingMessageEnvelope(sspB1, "1", "key1", "value1");
+ private final IncomingMessageEnvelope sspA0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(sspA0);
+ private final IncomingMessageEnvelope sspA1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(sspA1);
+ private final IncomingMessageEnvelope sspB0EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(sspB0);
+ private final IncomingMessageEnvelope sspB1EndOfStream = IncomingMessageEnvelope.buildEndOfStreamEnvelope(sspB1);
+ private final IncomingMessageEnvelope sspA0Drain = IncomingMessageEnvelope.buildDrainMessage(sspA0, runId);
+ private final IncomingMessageEnvelope sspA1Drain = IncomingMessageEnvelope.buildDrainMessage(sspA1, runId);
+ private final IncomingMessageEnvelope sspB0Drain = IncomingMessageEnvelope.buildDrainMessage(sspB0, runId);
+ private final IncomingMessageEnvelope sspB1Drain = IncomingMessageEnvelope.buildDrainMessage(sspB1, runId);
@Rule
public Timeout maxTestDurationInSeconds = Timeout.seconds(120);
@@ -75,8 +87,8 @@
public void testProcessMultipleTasks() {
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
- RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
+ RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1);
Map<TaskName, RunLoopTask> tasks = new HashMap<>();
tasks.put(taskName0, task0);
@@ -84,12 +96,13 @@
int maxMessagesInFlight = 1;
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
- callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(ssp0EndOfStream).thenReturn(ssp1EndOfStream).thenReturn(null);
+ callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false, 1, "foo");
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(sspA0EndOfStream).thenReturn(
+ sspA1EndOfStream).thenReturn(null);
runLoop.run();
- verify(task0).process(eq(envelope00), any(), any());
- verify(task1).process(eq(envelope11), any(), any());
+ verify(task0).process(eq(envelopeA00), any(), any());
+ verify(task1).process(eq(envelopeA11), any(), any());
assertEquals(4L, containerMetrics.envelopes().getCount());
}
@@ -97,9 +110,9 @@
@Test
public void testProcessInOrder() {
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(sspA0EndOfStream).thenReturn(null);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
int maxMessagesInFlight = 1;
@@ -108,8 +121,8 @@
runLoop.run();
InOrder inOrder = inOrder(task0);
- inOrder.verify(task0).process(eq(envelope00), any(), any());
- inOrder.verify(task0).process(eq(envelope01), any(), any());
+ inOrder.verify(task0).process(eq(envelopeA00), any(), any());
+ inOrder.verify(task0).process(eq(envelopeA01), any(), any());
}
@Test
@@ -119,7 +132,7 @@
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
OffsetManager offsetManager = mock(OffsetManager.class);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
when(task0.offsetManager()).thenReturn(offsetManager);
CountDownLatch firstMessageBarrier = new CountDownLatch(1);
doAnswer(invocation -> {
@@ -134,7 +147,7 @@
return null;
});
return null;
- }).when(task0).process(eq(envelope00), any(), any());
+ }).when(task0).process(eq(envelopeA00), any(), any());
doAnswer(invocation -> {
assertEquals(1, task0.metrics().messagesInFlight().getValue());
@@ -145,21 +158,21 @@
callback.complete();
firstMessageBarrier.countDown();
return null;
- }).when(task0).process(eq(envelope01), any(), any());
+ }).when(task0).process(eq(envelopeA01), any(), any());
Map<TaskName, RunLoopTask> tasks = new HashMap<>();
tasks.put(taskName0, task0);
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(null);
runLoop.run();
InOrder inOrder = inOrder(task0);
- inOrder.verify(task0).process(eq(envelope00), any(), any());
- inOrder.verify(task0).process(eq(envelope01), any(), any());
+ inOrder.verify(task0).process(eq(envelopeA00), any(), any());
+ inOrder.verify(task0).process(eq(envelopeA01), any(), any());
- verify(offsetManager).update(eq(taskName0), eq(ssp0), eq(envelope00.getOffset()));
+ verify(offsetManager).update(eq(taskName0), eq(sspA0), eq(envelopeA00.getOffset()));
assertEquals(2L, containerMetrics.processes().getCount());
}
@@ -168,9 +181,9 @@
public void testProcessElasticityEnabled() {
TaskName taskName0 = new TaskName(p0.toString() + " 0");
- SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStream", p0);
- SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0, 0);
- SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p0, 1);
+ SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStreamA", p0);
+ SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStreamA", p0, 0);
+ SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStreamA", p0, 1);
// create two IME such that one of their ssp keybucket maps to ssp0 and the other one maps to ssp1
// task in the runloop should process only the first ime (aka the one whose ssp keybucket is ssp0)
@@ -197,12 +210,12 @@
}).when(task0).process(eq(envelope01), any(), any());
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream).thenReturn(null);
+ when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(sspA0EndOfStream).thenReturn(null);
Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
int maxMessagesInFlight = 1;
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
- callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2);
+ callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, null);
runLoop.run();
verify(task0).process(eq(envelope00), any(), any());
@@ -216,9 +229,9 @@
TaskName taskName0 = new TaskName(p0.toString() + " 0");
TaskName taskName1 = new TaskName(p0.toString() + " 1");
- SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStream", p0);
- SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStream", p0, 0);
- SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStream", p0, 1);
+ SystemStreamPartition ssp = new SystemStreamPartition("testSystem", "testStreamA", p0);
+ SystemStreamPartition ssp0 = new SystemStreamPartition("testSystem", "testStreamA", p0, 0);
+ SystemStreamPartition ssp1 = new SystemStreamPartition("testSystem", "testStreamA", p0, 1);
// create EOS IME such that its ssp keybucket maps to ssp0 and not to ssp1
// task in the runloop should give this ime to both it tasks
@@ -237,7 +250,7 @@
Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0, taskName1, task1);
int maxMessagesInFlight = 1;
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
- callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2);
+ callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, null);
runLoop.run();
verify(task0).endOfStream(any());
@@ -245,12 +258,104 @@
}
@Test
+  public void testDrainWithElasticityEnabled() {
+    // With elasticity factor 2, partition p0 of testStreamA is split into key buckets 0 and 1.
+    TaskName bucket0TaskName = new TaskName(p0.toString() + " 0");
+    TaskName bucket1TaskName = new TaskName(p0.toString() + " 1");
+    SystemStreamPartition sspWithoutBucket = new SystemStreamPartition("testSystem", "testStreamA", p0);
+    SystemStreamPartition sspBucket0 = new SystemStreamPartition("testSystem", "testStreamA", p0, 0);
+    SystemStreamPartition sspBucket1 = new SystemStreamPartition("testSystem", "testStreamA", p0, 1);
+
+    // Drain envelope whose key-bucketed SSP resolves to bucket 0 and not to bucket 1.
+    IncomingMessageEnvelope drainEnvelope =
+        spy(IncomingMessageEnvelope.buildDrainMessage(sspWithoutBucket, runId));
+    when(drainEnvelope.getSystemStreamPartition(2)).thenReturn(sspBucket0);
+
+    // One task per key bucket; the drain envelope must reach both of them irrespective of the key bucket.
+    RunLoopTask bucket0Task = getMockRunLoopTask(bucket0TaskName, sspBucket0);
+    RunLoopTask bucket1Task = getMockRunLoopTask(bucket1TaskName, sspBucket1);
+
+    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
+    when(consumerMultiplexer.choose(false)).thenReturn(drainEnvelope).thenReturn(null);
+
+    Map<TaskName, RunLoopTask> tasks =
+        ImmutableMap.of(bucket0TaskName, bucket0Task, bucket1TaskName, bucket1Task);
+    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 1, windowMs, commitMs,
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 2, runId);
+    runLoop.run();
+
+    verify(bucket0Task).drain(any());
+    verify(bucket1Task).drain(any());
+  }
+
+
+  @Test
+  public void testDrainForTasksWithSingleSSP() {
+    TaskName firstTaskName = new TaskName(p0.toString() + " 0");
+    TaskName secondTaskName = new TaskName(p1.toString() + " 1");
+
+    // each task owns exactly one partition of testStreamA
+    RunLoopTask firstTask = getMockRunLoopTask(firstTaskName, sspA0);
+    RunLoopTask secondTask = getMockRunLoopTask(secondTaskName, sspA1);
+
+    // the multiplexer hands out one data envelope per task, then one drain message per SSP
+    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
+    when(consumerMultiplexer.choose(false))
+        .thenReturn(envelopeA00, envelopeA11, sspA0Drain, sspA1Drain);
+
+    Map<TaskName, RunLoopTask> tasks =
+        ImmutableMap.of(firstTaskName, firstTask, secondTaskName, secondTask);
+    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 1, windowMs, commitMs,
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId);
+    runLoop.run();
+
+    // exactly one process call per task
+    verify(firstTask).process(any(), any(), any());
+    verify(secondTask).process(any(), any(), any());
+    // exactly one drain and one commit per task
+    verify(firstTask).drain(any());
+    verify(secondTask).drain(any());
+    verify(firstTask).commit();
+    verify(secondTask).commit();
+  }
+
+  @Test
+  public void testDrainForTasksWithMultipleSSP() {
+    TaskName firstTaskName = new TaskName(p0.toString() + " 0");
+    TaskName secondTaskName = new TaskName(p1.toString() + " 1");
+
+    // each task owns one partition of testStreamA and one partition of testStreamB
+    RunLoopTask firstTask = getMockRunLoopTask(firstTaskName, sspA0, sspB0);
+    RunLoopTask secondTask = getMockRunLoopTask(secondTaskName, sspA1, sspB1);
+
+    // the multiplexer hands out all data envelopes first, then one drain message per SSP
+    SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
+    when(consumerMultiplexer.choose(false))
+        .thenReturn(envelopeA00, envelopeA11, envelopeB00, envelopeB11,
+            sspA0Drain, sspA1Drain, sspB0Drain, sspB1Drain);
+
+    Map<TaskName, RunLoopTask> tasks =
+        ImmutableMap.of(firstTaskName, firstTask, secondTaskName, secondTask);
+    RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, 1, windowMs, commitMs,
+        callbackTimeoutMs, maxThrottlingDelayMs, 0, containerMetrics, () -> 0L, false, 1, runId);
+    runLoop.run();
+
+    // two process calls per task, one for each of its SSPs
+    verify(firstTask, times(2)).process(any(), any(), any());
+    verify(secondTask, times(2)).process(any(), any(), any());
+    // exactly one drain and one commit per task
+    verify(firstTask).drain(any());
+    verify(secondTask).drain(any());
+    verify(firstTask).commit();
+    verify(secondTask).commit();
+  }
+
+ @Test
public void testWindow() {
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
int maxMessagesInFlight = 1;
long windowMs = 1;
- RunLoopTask task = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task = getMockRunLoopTask(taskName0, sspA0);
when(task.isWindowableTask()).thenReturn(true);
final AtomicInteger windowCount = new AtomicInteger(0);
@@ -277,7 +382,7 @@
public void testCommitSingleTask() {
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
doAnswer(invocation -> {
ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
@@ -288,9 +393,9 @@
callback.complete();
return null;
- }).when(task0).process(eq(envelope00), any(), any());
+ }).when(task0).process(eq(envelopeA00), any(), any());
- RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+ RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1);
Map<TaskName, RunLoopTask> tasks = new HashMap<>();
tasks.put(this.taskName0, task0);
@@ -300,7 +405,7 @@
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
//have a null message in between to make sure task0 finishes processing and invoke the commit
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(null);
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(null);
runLoop.run();
@@ -315,7 +420,7 @@
public void testCommitAllTasks() {
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
doAnswer(invocation -> {
ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
@@ -326,9 +431,9 @@
callback.complete();
return null;
- }).when(task0).process(eq(envelope00), any(), any());
+ }).when(task0).process(eq(envelopeA00), any(), any());
- RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+ RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1);
Map<TaskName, RunLoopTask> tasks = new HashMap<>();
tasks.put(this.taskName0, task0);
@@ -338,7 +443,7 @@
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
//have a null message in between to make sure task0 finishes processing and invoke the commit
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(null);
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(null);
runLoop.run();
@@ -354,7 +459,7 @@
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
int maxMessagesInFlight = 1;
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
doAnswer(invocation -> {
ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
@@ -363,9 +468,9 @@
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
callback.complete();
return null;
- }).when(task0).process(eq(envelope00), any(), any());
+ }).when(task0).process(eq(envelopeA00), any(), any());
- RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+ RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1);
doAnswer(invocation -> {
ReadableCoordinator coordinator = invocation.getArgumentAt(1, ReadableCoordinator.class);
TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
@@ -374,7 +479,7 @@
coordinator.shutdown(TaskCoordinator.RequestScope.CURRENT_TASK);
callback.complete();
return null;
- }).when(task1).process(eq(envelope11), any(), any());
+ }).when(task1).process(eq(envelopeA11), any(), any());
Map<TaskName, RunLoopTask> tasks = new HashMap<>();
tasks.put(taskName0, task0);
@@ -383,7 +488,7 @@
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
// consensus is reached after envelope1 is processed.
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope11).thenReturn(null);
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA11).thenReturn(null);
runLoop.run();
verify(task0).process(any(), any(), any());
@@ -397,8 +502,8 @@
public void testEndOfStreamWithMultipleTasks() {
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
- RunLoopTask task1 = getMockRunLoopTask(taskName1, ssp1);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0, sspB0);
+ RunLoopTask task1 = getMockRunLoopTask(taskName1, sspA1, sspB1);
Map<TaskName, RunLoopTask> tasks = new HashMap<>();
@@ -409,21 +514,27 @@
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
when(consumerMultiplexer.choose(false))
- .thenReturn(envelope00)
- .thenReturn(envelope11)
- .thenReturn(ssp0EndOfStream)
- .thenReturn(ssp1EndOfStream)
+ .thenReturn(envelopeA00)
+ .thenReturn(envelopeA11)
+ .thenReturn(envelopeB00)
+ .thenReturn(envelopeB11)
+ .thenReturn(sspA0EndOfStream)
+ .thenReturn(sspB0EndOfStream)
+ .thenReturn(sspB1EndOfStream)
+ .thenReturn(sspA1EndOfStream)
.thenReturn(null);
runLoop.run();
- verify(task0).process(eq(envelope00), any(), any());
+ verify(task0).process(eq(envelopeA00), any(), any());
+ verify(task0).process(eq(envelopeB00), any(), any());
verify(task0).endOfStream(any());
- verify(task1).process(eq(envelope11), any(), any());
+ verify(task1).process(eq(envelopeA11), any(), any());
+ verify(task1).process(eq(envelopeB11), any(), any());
verify(task1).endOfStream(any());
- assertEquals(4L, containerMetrics.envelopes().getCount());
+ assertEquals(8L, containerMetrics.envelopes().getCount());
}
@Test
@@ -433,7 +544,7 @@
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
OffsetManager offsetManager = mock(OffsetManager.class);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
when(task0.offsetManager()).thenReturn(offsetManager);
CountDownLatch firstMessageBarrier = new CountDownLatch(2);
doAnswer(invocation -> {
@@ -445,7 +556,7 @@
return null;
});
return null;
- }).when(task0).process(eq(envelope00), any(), any());
+ }).when(task0).process(eq(envelopeA00), any(), any());
doAnswer(invocation -> {
assertEquals(1, task0.metrics().messagesInFlight().getValue());
@@ -455,7 +566,7 @@
callback.complete();
firstMessageBarrier.countDown();
return null;
- }).when(task0).process(eq(envelope01), any(), any());
+ }).when(task0).process(eq(envelopeA01), any(), any());
doAnswer(invocation -> {
assertEquals(0, task0.metrics().messagesInFlight().getValue());
@@ -469,7 +580,7 @@
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(ssp0EndOfStream)
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(sspA0EndOfStream)
.thenAnswer(invocation -> {
// this ensures that the end of stream message has passed through run loop BEFORE the last remaining in flight message completes
firstMessageBarrier.countDown();
@@ -482,10 +593,67 @@
}
@Test
+ public void testDrainWaitsForInFlightMessages() {
+ // Two messages may be in flight at once, so envelopeA01 is dispatched while envelopeA00 is still pending.
+ int maxMessagesInFlight = 2;
+ ExecutorService taskExecutor = Executors.newFixedThreadPool(1);
+ try {
+ SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
+ OffsetManager offsetManager = mock(OffsetManager.class);
+
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
+ when(task0.offsetManager()).thenReturn(offsetManager);
+ // Opens only after BOTH envelopeA01 has been processed and the run loop has polled past the drain message.
+ CountDownLatch firstMessageBarrier = new CountDownLatch(2);
+ // envelopeA00: the callback is completed asynchronously, and only once the barrier opens.
+ doAnswer(invocation -> {
+ TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+ TaskCallback callback = callbackFactory.createCallback();
+ taskExecutor.submit(() -> {
+ firstMessageBarrier.await();
+ callback.complete();
+ return null;
+ });
+ return null;
+ }).when(task0).process(eq(envelopeA00), any(), any());
+
+ // envelopeA01: completes synchronously while envelopeA00 is still in flight.
+ doAnswer(invocation -> {
+ assertEquals(1, task0.metrics().messagesInFlight().getValue());
+
+ TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
+ TaskCallback callback = callbackFactory.createCallback();
+ callback.complete();
+ firstMessageBarrier.countDown();
+ return null;
+ }).when(task0).process(eq(envelopeA01), any(), any());
+
+ // drain must only be invoked once nothing is in flight and both callbacks have completed.
+ doAnswer(invocation -> {
+ assertEquals(0, task0.metrics().messagesInFlight().getValue());
+ assertEquals(2, task0.metrics().asyncCallbackCompleted().getCount());
+
+ return null;
+ }).when(task0).drain(any());
+
+ Map<TaskName, RunLoopTask> tasks = new HashMap<>();
+ tasks.put(taskName0, task0);
+
+ RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
+ callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
+ 1, runId);
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(sspA0Drain)
+ .thenAnswer(invocation -> {
+ // this ensures that the drain message has passed through run loop BEFORE the last remaining in flight
+ // message completes
+ firstMessageBarrier.countDown();
+ return null;
+ });
+
+ runLoop.run();
+
+ verify(task0).drain(any());
+ } finally {
+ // Always release the worker thread; if the run loop or an assertion fails it may still be parked on the barrier.
+ taskExecutor.shutdownNow();
+ }
+ }
+
+ @Test
public void testEndOfStreamCommitBehavior() {
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
doAnswer(invocation -> {
ReadableCoordinator coordinator = invocation.getArgumentAt(0, ReadableCoordinator.class);
@@ -500,7 +668,7 @@
int maxMessagesInFlight = 1;
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(ssp0EndOfStream).thenReturn(null);
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(sspA0EndOfStream).thenReturn(null);
runLoop.run();
@@ -511,13 +679,36 @@
}
@Test
+ public void testDrainCommitBehavior() {
+ // A single task consuming sspA0, with at most one message in flight.
+ final int maxMessagesInFlight = 1;
+ SystemConsumers consumers = mock(SystemConsumers.class);
+ RunLoopTask drainedTask = getMockRunLoopTask(taskName0, sspA0);
+ Map<TaskName, RunLoopTask> taskMap = ImmutableMap.of(taskName0, drainedTask);
+
+ RunLoop loop = new RunLoop(taskMap, executor, consumers, maxMessagesInFlight, windowMs, commitMs,
+ callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false,
+ 1, runId);
+ // The consumer hands out one data envelope, then the drain message, then nothing.
+ when(consumers.choose(false)).thenReturn(envelopeA00, sspA0Drain, null);
+
+ loop.run();
+
+ // The task must be drained first and committed afterwards.
+ InOrder callOrder = inOrder(drainedTask);
+ callOrder.verify(drainedTask).drain(any());
+ callOrder.verify(drainedTask).commit();
+ }
+
+ @Test
public void testCommitWithMessageInFlightWhenAsyncCommitIsEnabled() {
int maxMessagesInFlight = 2;
ExecutorService taskExecutor = Executors.newFixedThreadPool(2);
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
OffsetManager offsetManager = mock(OffsetManager.class);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
when(task0.offsetManager()).thenReturn(offsetManager);
CountDownLatch firstMessageBarrier = new CountDownLatch(1);
doAnswer(invocation -> {
@@ -532,7 +723,7 @@
return null;
});
return null;
- }).when(task0).process(eq(envelope00), any(), any());
+ }).when(task0).process(eq(envelopeA00), any(), any());
CountDownLatch secondMessageBarrier = new CountDownLatch(1);
doAnswer(invocation -> {
@@ -550,7 +741,7 @@
return null;
});
return null;
- }).when(task0).process(eq(envelope01), any(), any());
+ }).when(task0).process(eq(envelopeA01), any(), any());
doAnswer(invocation -> {
assertEquals(1, task0.metrics().asyncCallbackCompleted().getCount());
@@ -565,12 +756,12 @@
RunLoop runLoop = new RunLoop(tasks, executor, consumerMultiplexer, maxMessagesInFlight, windowMs, commitMs,
callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, true);
- when(consumerMultiplexer.choose(false)).thenReturn(envelope00).thenReturn(envelope01).thenReturn(null);
+ when(consumerMultiplexer.choose(false)).thenReturn(envelopeA00).thenReturn(envelopeA01).thenReturn(null);
runLoop.run();
InOrder inOrder = inOrder(task0);
- inOrder.verify(task0).process(eq(envelope00), any(), any());
- inOrder.verify(task0).process(eq(envelope01), any(), any());
+ inOrder.verify(task0).process(eq(envelopeA00), any(), any());
+ inOrder.verify(task0).process(eq(envelopeA01), any(), any());
inOrder.verify(task0).commit();
}
@@ -578,12 +769,12 @@
public void testExceptionIsPropagated() {
SystemConsumers consumerMultiplexer = mock(SystemConsumers.class);
- RunLoopTask task0 = getMockRunLoopTask(taskName0, ssp0);
+ RunLoopTask task0 = getMockRunLoopTask(taskName0, sspA0);
doAnswer(invocation -> {
TaskCallbackFactory callbackFactory = invocation.getArgumentAt(2, TaskCallbackFactory.class);
callbackFactory.createCallback().failure(new Exception("Intentional failure"));
return null;
- }).when(task0).process(eq(envelope00), any(), any());
+ }).when(task0).process(eq(envelopeA00), any(), any());
Map<TaskName, RunLoopTask> tasks = ImmutableMap.of(taskName0, task0);
@@ -592,16 +783,16 @@
callbackTimeoutMs, maxThrottlingDelayMs, maxIdleMs, containerMetrics, () -> 0L, false);
when(consumerMultiplexer.choose(false))
- .thenReturn(envelope00)
- .thenReturn(ssp0EndOfStream)
+ .thenReturn(envelopeA00)
+ .thenReturn(sspA0EndOfStream)
.thenReturn(null);
runLoop.run();
}
- private RunLoopTask getMockRunLoopTask(TaskName taskName, SystemStreamPartition ssp0) {
+ private RunLoopTask getMockRunLoopTask(TaskName taskName, SystemStreamPartition ... ssps) {
RunLoopTask task0 = mock(RunLoopTask.class);
- when(task0.systemStreamPartitions()).thenReturn(Collections.singleton(ssp0));
+ when(task0.systemStreamPartitions()).thenReturn(new HashSet<>(Arrays.asList(ssps)));
when(task0.metrics()).thenReturn(new TaskInstanceMetrics("test", new MetricsRegistryMap(), ""));
when(task0.taskName()).thenReturn(taskName);
return task0;
diff --git a/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java b/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java
index e7666aa..80061b3 100644
--- a/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java
+++ b/samza-core/src/test/java/org/apache/samza/drain/DrainMonitorTests.java
@@ -41,13 +41,13 @@
* Tests for {@link DrainMonitor}
* */
public class DrainMonitorTests {
- private static final String TEST_DEPLOYMENT_ID = "foo";
+ private static final String TEST_RUN_ID = "foo";
private static final Config
CONFIG = new MapConfig(ImmutableMap.of(
"job.name", "test-job",
"job.coordinator.system", "test-kafka",
- ApplicationConfig.APP_RUN_ID, TEST_DEPLOYMENT_ID));
+ ApplicationConfig.APP_RUN_ID, TEST_RUN_ID));
private CoordinatorStreamStore coordinatorStreamStore;
@@ -114,7 +114,7 @@
final AtomicInteger numCallbacks = new AtomicInteger(0);
final CountDownLatch latch = new CountDownLatch(1);
// write drain before monitor start
- DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+ DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
DrainMonitor drainMonitor = new DrainMonitor(coordinatorStreamStore, CONFIG);
drainMonitor.registerDrainCallback(() -> {
numCallbacks.incrementAndGet();
@@ -141,7 +141,7 @@
latch.countDown();
});
drainMonitor.start();
- DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+ DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
if (!latch.await(2, TimeUnit.SECONDS)) {
Assert.fail("Timed out waiting for drain callback to complete");
}
@@ -150,8 +150,8 @@
}
@Test
- public void testCallbackNotCalledDueToMismatchedDeploymentId() throws InterruptedException {
- // The test fails due to timeout as the published DrainNotification's deploymentId doesn't match deploymentId
+ public void testCallbackNotCalledDueToMismatchedRunId() throws InterruptedException {
+ // The test fails due to timeout as the published DrainNotification's runId doesn't match runId
// in the config
exceptionRule.expect(AssertionError.class);
exceptionRule.expectMessage("Timed out waiting for drain callback to complete.");
@@ -166,8 +166,8 @@
});
drainMonitor.start();
- final String mismatchedDeploymentId = "bar";
- DrainUtils.writeDrainNotification(coordinatorStreamStore, mismatchedDeploymentId);
+ final String mismatchedRunId = "bar";
+ DrainUtils.writeDrainNotification(coordinatorStreamStore, mismatchedRunId);
if (!latch.await(2, TimeUnit.SECONDS)) {
Assert.fail("Timed out waiting for drain callback to complete.");
}
@@ -184,16 +184,16 @@
@Test
public void testShouldDrain() {
- DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+ DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
NamespaceAwareCoordinatorStreamStore drainStore =
new NamespaceAwareCoordinatorStreamStore(coordinatorStreamStore, DrainUtils.DRAIN_METADATA_STORE_NAMESPACE);
- Assert.assertTrue(DrainMonitor.shouldDrain(drainStore, TEST_DEPLOYMENT_ID));
+ Assert.assertTrue(DrainMonitor.shouldDrain(drainStore, TEST_RUN_ID));
// Cleanup old drain message
DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
- final String mismatchedDeploymentId = "bar";
- DrainUtils.writeDrainNotification(coordinatorStreamStore, mismatchedDeploymentId);
- Assert.assertFalse(DrainMonitor.shouldDrain(drainStore, TEST_DEPLOYMENT_ID));
+ final String mismatchedRunId = "bar";
+ DrainUtils.writeDrainNotification(coordinatorStreamStore, mismatchedRunId);
+ Assert.assertFalse(DrainMonitor.shouldDrain(drainStore, TEST_RUN_ID));
}
}
diff --git a/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java b/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java
index 1726535..43b91ad 100644
--- a/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java
+++ b/samza-core/src/test/java/org/apache/samza/drain/DrainUtilsTests.java
@@ -45,11 +45,11 @@
* Tests for {@link DrainUtils}
* */
public class DrainUtilsTests {
- private static final String TEST_DEPLOYMENT_ID = "foo";
+ private static final String TEST_RUN_ID = "foo";
private static final Config CONFIG = new MapConfig(ImmutableMap.of(
"job.name", "test-job",
"job.coordinator.system", "test-kafka",
- ApplicationConfig.APP_RUN_ID, TEST_DEPLOYMENT_ID));
+ ApplicationConfig.APP_RUN_ID, TEST_RUN_ID));
private CoordinatorStreamStore coordinatorStreamStore;
@@ -68,17 +68,17 @@
@Test
public void testWrites() {
- String deploymentId1 = "foo1";
- String deploymentId2 = "foo2";
- String deploymentId3 = "foo3";
+ String runId1 = "foo1";
+ String runId2 = "foo2";
+ String runId3 = "foo3";
- UUID uuid1 = DrainUtils.writeDrainNotification(coordinatorStreamStore, deploymentId1);
- UUID uuid2 = DrainUtils.writeDrainNotification(coordinatorStreamStore, deploymentId2);
- UUID uuid3 = DrainUtils.writeDrainNotification(coordinatorStreamStore, deploymentId3);
+ UUID uuid1 = DrainUtils.writeDrainNotification(coordinatorStreamStore, runId1);
+ UUID uuid2 = DrainUtils.writeDrainNotification(coordinatorStreamStore, runId2);
+ UUID uuid3 = DrainUtils.writeDrainNotification(coordinatorStreamStore, runId3);
- DrainNotification expectedDrainNotification1 = new DrainNotification(uuid1, deploymentId1);
- DrainNotification expectedDrainNotification2 = new DrainNotification(uuid2, deploymentId2);
- DrainNotification expectedDrainNotification3 = new DrainNotification(uuid3, deploymentId3);
+ DrainNotification expectedDrainNotification1 = new DrainNotification(uuid1, runId1);
+ DrainNotification expectedDrainNotification2 = new DrainNotification(uuid2, runId2);
+ DrainNotification expectedDrainNotification3 = new DrainNotification(uuid3, runId3);
Set<DrainNotification> expectedDrainNotifications = new HashSet<>(Arrays.asList(expectedDrainNotification1,
expectedDrainNotification2, expectedDrainNotification3));
@@ -90,23 +90,23 @@
@Test
public void testCleanup() {
- DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+ DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
final Optional<List<DrainNotification>> drainNotifications1 = readDrainNotificationMessages(coordinatorStreamStore);
Assert.assertFalse(drainNotifications1.isPresent());
- final String deploymentId = "bar";
- DrainUtils.writeDrainNotification(coordinatorStreamStore, deploymentId);
+ final String runId = "bar";
+ DrainUtils.writeDrainNotification(coordinatorStreamStore, runId);
DrainUtils.cleanup(coordinatorStreamStore, CONFIG);
final Optional<List<DrainNotification>> drainNotifications2 = readDrainNotificationMessages(coordinatorStreamStore);
Assert.assertTrue(drainNotifications2.isPresent());
- Assert.assertEquals(deploymentId, drainNotifications2.get().get(0).getDeploymentId());
+ Assert.assertEquals(runId, drainNotifications2.get().get(0).getRunId());
}
@Test
public void testCleanupAll() {
- DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
- DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_DEPLOYMENT_ID);
+ DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
+ DrainUtils.writeDrainNotification(coordinatorStreamStore, TEST_RUN_ID);
DrainUtils.writeDrainNotification(coordinatorStreamStore, "bar");
DrainUtils.cleanupAll(coordinatorStreamStore);
final Optional<List<DrainNotification>> drainNotifications = readDrainNotificationMessages(coordinatorStreamStore);
diff --git a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
index d9bb916..6b2154e 100644
--- a/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
+++ b/samza-core/src/test/scala/org/apache/samza/container/TestSamzaContainer.scala
@@ -57,7 +57,7 @@
@Mock
private var taskInstance: TaskInstance = null
@Mock
- private var runLoop: Runnable = null
+ private var runLoop: RunLoop = null
@Mock
private var systemAdmins: SystemAdmins = null
@Mock
diff --git a/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java b/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
index 943a8ca..60323d6 100644
--- a/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
+++ b/samza-test/src/main/java/org/apache/samza/test/framework/TestRunner.java
@@ -20,6 +20,7 @@
package org.apache.samza.test.framework;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
import java.io.File;
import java.time.Duration;
import java.util.HashMap;
@@ -28,6 +29,9 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.commons.lang3.RandomStringUtils;
import org.apache.samza.SamzaException;
@@ -93,10 +97,12 @@
private static final Logger LOG = LoggerFactory.getLogger(TestRunner.class);
private static final String JOB_DEFAULT_SYSTEM = "default-samza-system";
private static final String APP_NAME = "samza-test";
+ private static final long DEFAULT_DELAY_BETWEEN_MESSAGES = 0L;
private Map<String, String> configs;
private SamzaApplication app;
private ExternalContext externalContext;
+ private InMemoryMetadataStoreFactory inMemoryMetadataStoreFactory;
/*
* inMemoryScope is a unique global key per TestRunner, this key when configured with {@link InMemorySystemDescriptor}
* provides an isolated state to run with in memory system
@@ -201,6 +207,13 @@
}
/**
+ * Returns the configs currently held by this {@link TestRunner}.
+ *
+ * @return a new {@link MapConfig} constructed from the runner's config map
+ */
+ public Config getConfig() {
+ final Config currentConfig = new MapConfig(configs);
+ return currentConfig;
+ }
+
+ /**
* Passes the user provided external context to {@link LocalApplicationRunner}
*
* @param externalContext external context provided by user
@@ -213,6 +226,18 @@
}
/**
+ * Overrides the metadata store factory handed to {@link LocalApplicationRunner} when the application is run.
+ * When no factory is set, a fresh {@link InMemoryMetadataStoreFactory} is created for the run instead.
+ *
+ * @param inMemoryMetadataStoreFactory factory to use; must not be null
+ * @return this {@link TestRunner}
+ */
+ public TestRunner setInMemoryMetadataFactory(InMemoryMetadataStoreFactory inMemoryMetadataStoreFactory) {
+ // checkNotNull returns its argument, so validation and assignment collapse into one statement.
+ this.inMemoryMetadataStoreFactory = Preconditions.checkNotNull(inMemoryMetadataStoreFactory);
+ return this;
+ }
+
+ /**
* Adds the provided input stream with mock data to the test application.
*
* @param descriptor describes the stream that is supposed to be input to Samza application
@@ -224,10 +249,29 @@
*/
public <StreamMessageType> TestRunner addInputStream(InMemoryInputDescriptor descriptor,
List<StreamMessageType> messages) {
+ // No pacing requested: delegate with the default (zero) delay so every message is added up front.
+ return addInputStream(descriptor, messages, DEFAULT_DELAY_BETWEEN_MESSAGES);
+ }
+
+ /**
+ * Adds the provided input stream with mock data to the test application. When
+ * {@code delayBetweenMessagesInMillis} is greater than zero, the mock messages are added one at a time
+ * with that delay between messages instead of all at once.
+ * Default configs and user added configs have
+ * a higher precedence over system and stream descriptor generated configs.
+ * @param descriptor describes the stream that is supposed to be input to Samza application
+ * @param messages messages used to initialize partition 0 of the stream. These messages should always
+ * be deserialized
+ * @param delayBetweenMessagesInMillis delay between messages, in milliseconds
+ * @param <StreamMessageType> message with null key or a KV {@link org.apache.samza.operators.KV}.
+ * A key of which represents key of {@link org.apache.samza.system.IncomingMessageEnvelope} or
+ * {@link org.apache.samza.system.OutgoingMessageEnvelope} and value is message
+ * @return this {@link TestRunner}
+ */
+ public <StreamMessageType> TestRunner addInputStream(InMemoryInputDescriptor descriptor,
+ List<StreamMessageType> messages, long delayBetweenMessagesInMillis) {
- Preconditions.checkNotNull(descriptor, messages);
+ // Guava's two-argument checkNotNull treats its second argument as the error message, so each
+ // reference needs its own call to actually be validated.
+ Preconditions.checkNotNull(descriptor, "descriptor must not be null");
+ Preconditions.checkNotNull(messages, "messages must not be null");
- Map<Integer, Iterable<StreamMessageType>> partitionData = new HashMap<Integer, Iterable<StreamMessageType>>();
+ Map<Integer, Iterable<StreamMessageType>> partitionData = new HashMap<>();
partitionData.put(0, messages);
- initializeInMemoryInputStream(descriptor, partitionData);
+ initializeInMemoryInputStream(descriptor, partitionData, delayBetweenMessagesInMillis);
return this;
}
@@ -244,10 +288,28 @@
*/
public <StreamMessageType> TestRunner addInputStream(InMemoryInputDescriptor descriptor,
Map<Integer, ? extends Iterable<StreamMessageType>> messages) {
+ // No pacing requested: delegate with the default (zero) delay so every message is added up front.
+ return addInputStream(descriptor, messages, DEFAULT_DELAY_BETWEEN_MESSAGES);
+ }
+
+ /**
+ * Adds the provided input stream with mock data to the test application. When
+ * {@code delayBetweenMessagesInMillis} is greater than zero, the mock messages are added one at a time
+ * with that delay between messages instead of all at once: TestRunner cycles through all partitions, adding
+ * one message per partition, and repeats this periodically until all messages are exhausted.
+ * Default configs and user added configs have a higher precedence over system and stream descriptor generated configs.
+ * @param descriptor describes the stream that is supposed to be input to Samza application
+ * @param messages map whose key is partitionId and value is messages in the partition. These messages should always
+ * be deserialized
+ * @param delayBetweenMessagesInMillis delay between messages, in milliseconds
+ * @param <StreamMessageType> message with null key or a KV {@link org.apache.samza.operators.KV}.
+ * A key of which represents key of {@link org.apache.samza.system.IncomingMessageEnvelope} or
+ * {@link org.apache.samza.system.OutgoingMessageEnvelope} and value is message
+ * @return this {@link TestRunner}
+ */
+ public <StreamMessageType> TestRunner addInputStream(InMemoryInputDescriptor descriptor,
+ Map<Integer, ? extends Iterable<StreamMessageType>> messages, long delayBetweenMessagesInMillis) {
- Preconditions.checkNotNull(descriptor, messages);
- Map<Integer, Iterable<StreamMessageType>> partitionData = new HashMap<Integer, Iterable<StreamMessageType>>();
- partitionData.putAll(messages);
- initializeInMemoryInputStream(descriptor, partitionData);
+ // Guava's two-argument checkNotNull treats its second argument as the error message, so each
+ // reference needs its own call to actually be validated.
+ Preconditions.checkNotNull(descriptor, "descriptor must not be null");
+ Preconditions.checkNotNull(messages, "messages must not be null");
+ Map<Integer, Iterable<StreamMessageType>> partitionData = new HashMap<>(messages);
+ initializeInMemoryInputStream(descriptor, partitionData, delayBetweenMessagesInMillis);
return this;
}
@@ -291,7 +353,10 @@
// Cleaning store directories to ensure current run does not pick up state from previous run
deleteStoreDirectories();
Config config = new MapConfig(JobPlanner.generateSingleJobConfig(configs));
- final LocalApplicationRunner runner = new LocalApplicationRunner(app, config, new InMemoryMetadataStoreFactory());
+ InMemoryMetadataStoreFactory metadataStoreFactory = inMemoryMetadataStoreFactory != null
+ ? inMemoryMetadataStoreFactory
+ : new InMemoryMetadataStoreFactory();
+ final LocalApplicationRunner runner = new LocalApplicationRunner(app, config, metadataStoreFactory);
runner.run(externalContext);
if (!runner.waitForFinish(timeout)) {
throw new SamzaException("Timed out waiting for application to finish");
@@ -378,7 +443,7 @@
* @param descriptor describes a stream to initialize with the in memory system
*/
private <StreamMessageType> void initializeInMemoryInputStream(InMemoryInputDescriptor<?> descriptor,
- Map<Integer, Iterable<StreamMessageType>> partitionData) {
+ Map<Integer, Iterable<StreamMessageType>> partitionData, long delayInMillis) {
String systemName = descriptor.getSystemName();
String streamName = (String) descriptor.getPhysicalName().orElse(descriptor.getStreamId());
if (this.app instanceof LegacyTaskApplication) {
@@ -402,19 +467,64 @@
factory.getAdmin(systemName, config).createStream(spec);
InMemorySystemProducer producer = (InMemorySystemProducer) factory.getProducer(systemName, config, null);
SystemStream sysStream = new SystemStream(systemName, streamName);
- partitionData.forEach((partitionId, partition) -> {
- partition.forEach(e -> {
- Object key = e instanceof KV ? ((KV) e).getKey() : null;
- Object value = e instanceof KV ? ((KV) e).getValue() : e;
- if (value instanceof IncomingMessageEnvelope) {
- producer.send((IncomingMessageEnvelope) value);
- } else {
- producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), key, value));
- }
+ if (delayInMillis > 0) {
+ delayedInitialization(partitionData, producer, sysStream, delayInMillis);
+ } else {
+ partitionData.forEach((partitionId, partition) -> {
+ partition.forEach(e -> {
+ Object key = e instanceof KV ? ((KV) e).getKey() : null;
+ Object value = e instanceof KV ? ((KV) e).getValue() : e;
+ if (value instanceof IncomingMessageEnvelope) {
+ producer.send((IncomingMessageEnvelope) value);
+ } else {
+ producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), key, value));
+ }
+ });
+ producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), null,
+ new EndOfStreamMessage(null)));
});
- producer.send(systemName, new OutgoingMessageEnvelope(sysStream, Integer.valueOf(partitionId), null,
- new EndOfStreamMessage(null)));
- });
+ }
+ }
+
+ /**
+ * Feeds {@code partitionData} into the in memory stream gradually instead of all at once. Every
+ * {@code delayInMillis} milliseconds a single-threaded scheduler sends at most one pending message per partition.
+ * On the first run after a partition has no messages left, an {@link EndOfStreamMessage} is sent for it, and on
+ * the first run after every partition has been ended the scheduler shuts itself down.
+ *
+ * @param partitionData map from partition id to the messages destined for that partition
+ * @param producer in memory producer used to send the messages
+ * @param systemStream stream the messages are sent to
+ * @param delayInMillis period between two rounds of sends, in milliseconds
+ * @param <StreamMessageType> message with null key or a KV {@link org.apache.samza.operators.KV}
+ */
+ private <StreamMessageType> void delayedInitialization(Map<Integer, Iterable<StreamMessageType>> partitionData,
+ InMemorySystemProducer producer, SystemStream systemStream, long delayInMillis) {
+ final Set<Integer> endOfStreamPartitions = new HashSet<>();
+ final int numPartitions = partitionData.size();
+ final String systemName = systemStream.getSystem();
+ // Copy each partition's messages into its own queue that the scheduled task consumes from the head.
+ final Map<Integer, LinkedList<StreamMessageType>> messageQueuesByPartition = partitionData.entrySet()
+ .stream()
+ .collect(Collectors.toMap(
+ Map.Entry::getKey,
+ entry -> Lists.newLinkedList(entry.getValue())));
+
+ final ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
+ executor.scheduleAtFixedRate(() -> {
+ try {
+ if (endOfStreamPartitions.size() == numPartitions) {
+ // shutdown scheduled executor as EOS is reached for all partitions
+ executor.shutdownNow();
+ return;
+ }
+ messageQueuesByPartition.forEach((partitionId, messages) -> {
+ if (!messages.isEmpty()) {
+ final StreamMessageType e = messages.removeFirst();
+ final Object key = e instanceof KV ? ((KV) e).getKey() : null;
+ final Object value = e instanceof KV ? ((KV) e).getValue() : e;
+ if (value instanceof IncomingMessageEnvelope) {
+ producer.send((IncomingMessageEnvelope) value);
+ } else {
+ producer.send(systemName, new OutgoingMessageEnvelope(systemStream, partitionId, key, value));
+ }
+ } else if (endOfStreamPartitions.add(partitionId)) {
+ // end of input is reached; Set.add returns true only the first time, so EOS is sent exactly once
+ // for the partition
+ producer.send(systemName,
+ new OutgoingMessageEnvelope(systemStream, partitionId, null, new EndOfStreamMessage(null)));
+ }
+ });
+ } catch (RuntimeException ex) {
+ // scheduleAtFixedRate suppresses all further runs once a run throws, and nobody inspects the returned
+ // future; report the failure and release the scheduler thread instead of leaking it silently.
+ LOG.error("Failed to add delayed input messages to {}; stopping delayed initialization", systemStream, ex);
+ executor.shutdownNow();
+ }
+ }, 0L, delayInMillis, TimeUnit.MILLISECONDS);
+ }
private void deleteStoreDirectories() {
diff --git a/samza-test/src/test/java/org/apache/samza/test/controlmessages/TestData.java b/samza-test/src/test/java/org/apache/samza/test/TestData.java
similarity index 97%
rename from samza-test/src/test/java/org/apache/samza/test/controlmessages/TestData.java
rename to samza-test/src/test/java/org/apache/samza/test/TestData.java
index 52d8380..5b2f137 100644
--- a/samza-test/src/test/java/org/apache/samza/test/controlmessages/TestData.java
+++ b/samza-test/src/test/java/org/apache/samza/test/TestData.java
@@ -17,7 +17,7 @@
* under the License.
*/
-package org.apache.samza.test.controlmessages;
+package org.apache.samza.test;
import java.io.Serializable;
import org.apache.samza.SamzaException;
diff --git a/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java
index 849f64c..bbe03ae 100644
--- a/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java
+++ b/samza-test/src/test/java/org/apache/samza/test/controlmessages/WatermarkIntegrationTest.java
@@ -67,8 +67,8 @@
import org.apache.samza.system.SystemStreamPartition;
import org.apache.samza.task.StreamOperatorTask;
import org.apache.samza.task.TestStreamOperatorTask;
-import org.apache.samza.test.controlmessages.TestData.PageView;
-import org.apache.samza.test.controlmessages.TestData.PageViewJsonSerdeFactory;
+import org.apache.samza.test.TestData.PageView;
+import org.apache.samza.test.TestData.PageViewJsonSerdeFactory;
import org.apache.samza.test.harness.IntegrationTestHarness;
import org.apache.samza.test.util.SimpleSystemAdmin;
import org.apache.samza.test.util.TestStreamConsumer;
diff --git a/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java
new file mode 100644
index 0000000..c1a6f91
--- /dev/null
+++ b/samza-test/src/test/java/org/apache/samza/test/drain/DrainHighLevelApiIntegrationTest.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.drain;
+
+import com.google.common.collect.ImmutableMap;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import org.apache.samza.application.StreamApplication;
+import org.apache.samza.application.descriptors.StreamApplicationDescriptor;
+import org.apache.samza.config.ApplicationConfig;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.drain.DrainUtils;
+import org.apache.samza.metadatastore.InMemoryMetadataStoreFactory;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.operators.KV;
+import org.apache.samza.serializers.IntegerSerde;
+import org.apache.samza.serializers.KVSerde;
+import org.apache.samza.serializers.NoOpSerde;
+import org.apache.samza.system.descriptors.DelegatingSystemDescriptor;
+import org.apache.samza.system.descriptors.GenericInputDescriptor;
+import org.apache.samza.test.framework.TestRunner;
+import org.apache.samza.test.framework.system.descriptors.InMemoryInputDescriptor;
+import org.apache.samza.test.framework.system.descriptors.InMemorySystemDescriptor;
+import org.apache.samza.test.table.TestTableData;
+import org.apache.samza.test.table.TestTableData.PageView;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+
+/**
+ * End to end integration test to check drain functionality with samza high-level API.
+ */
+public class DrainHighLevelApiIntegrationTest {
+  // Messages seen by the sink; static because the application under test writes to it. Cleared per test run.
+  private static final List<PageView> RECEIVED = new ArrayList<>();
+  private static final String SYSTEM_NAME = "test";
+  private static final String STREAM_ID = "PageView";
+
+  /** Reads page views, repartitions them by member id and records each repartitioned view in RECEIVED. */
+  private static class PageViewEventCountHighLevelApplication implements StreamApplication {
+    @Override
+    public void describe(StreamApplicationDescriptor appDescriptor) {
+      DelegatingSystemDescriptor sd = new DelegatingSystemDescriptor(SYSTEM_NAME);
+      GenericInputDescriptor<KV<String, PageView>> isd =
+          sd.getInputDescriptor(STREAM_ID, KVSerde.of(new NoOpSerde<>(), new NoOpSerde<>()));
+      appDescriptor.getInputStream(isd)
+          .map(KV::getValue)
+          .partitionBy(PageView::getMemberId, pv -> pv,
+              KVSerde.of(new IntegerSerde(), new TestTableData.PageViewJsonSerde()), "p1")
+          .sink((m, collector, coordinator) -> RECEIVED.add(m.getValue()));
+    }
+  }
+
+  // The test can be occasionally flaky, so we set Ignore annotation
+  // Remove ignore annotation and run the test as follows:
+  // ./gradlew :samza-test:test --tests org.apache.samza.test.drain.DrainHighLevelApiIntegrationTest -PscalaSuffix=2.12
+  @Ignore
+  @Test
+  public void testPipeline() {
+    String runId = "DrainTestId";
+    int numPageViews = 40;
+    RECEIVED.clear();
+
+    InMemorySystemDescriptor isd = new InMemorySystemDescriptor(SYSTEM_NAME);
+    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor(STREAM_ID, new NoOpSerde<>());
+    InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
+
+    Map<String, String> customConfig = ImmutableMap.of(
+        ApplicationConfig.APP_RUN_ID, runId,
+        JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "100",
+        JobConfig.DRAIN_MONITOR_ENABLED, "true");
+
+    // Create a TestRunner
+    // Set an InMemoryMetadataFactory. We will use this factory in the test to create a metadata store and
+    // write the drain message to it
+    // Mock data comprises 40 messages across 4 partitions. TestRunner adds a 1 second delay between messages
+    // per partition when writing messages to the InputStream
+    TestRunner testRunner = TestRunner.of(new PageViewEventCountHighLevelApplication())
+        .setInMemoryMetadataFactory(metadataStoreFactory)
+        .addConfig(customConfig)
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, 4), 1000L);
+
+    Config configFromRunner = testRunner.getConfig();
+    MetadataStore metadataStore = metadataStoreFactory.getMetadataStore("NoOp", configFromRunner, new MetricsRegistryMap());
+
+    // write drain message after a delay; the finally block stops the scheduler thread even if run() throws
+    ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
+    try {
+      Callable<UUID> writeDrain = () -> DrainUtils.writeDrainNotification(metadataStore, runId);
+      executorService.schedule(writeDrain, 2000L, TimeUnit.MILLISECONDS);
+      testRunner.run(Duration.ofSeconds(20));
+    } finally {
+      executorService.shutdownNow();
+    }
+
+    assertTrue("No messages were processed before the drain", RECEIVED.size() > 0);
+    assertTrue("Drain did not stop processing early, received " + RECEIVED.size(), RECEIVED.size() < numPageViews);
+  }
+}
diff --git a/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java
new file mode 100644
index 0000000..7579c6e
--- /dev/null
+++ b/samza-test/src/test/java/org/apache/samza/test/drain/DrainLowLevelApiIntegrationTest.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.drain;
+
+import com.google.common.collect.ImmutableMap;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+import org.apache.samza.application.TaskApplication;
+import org.apache.samza.application.descriptors.TaskApplicationDescriptor;
+import org.apache.samza.config.ApplicationConfig;
+import org.apache.samza.config.Config;
+import org.apache.samza.config.JobConfig;
+import org.apache.samza.context.Context;
+import org.apache.samza.drain.DrainUtils;
+import org.apache.samza.metadatastore.InMemoryMetadataStoreFactory;
+import org.apache.samza.metadatastore.MetadataStore;
+import org.apache.samza.metrics.MetricsRegistryMap;
+import org.apache.samza.serializers.NoOpSerde;
+import org.apache.samza.system.IncomingMessageEnvelope;
+import org.apache.samza.system.descriptors.DelegatingSystemDescriptor;
+import org.apache.samza.system.descriptors.GenericInputDescriptor;
+import org.apache.samza.task.DrainListenerTask;
+import org.apache.samza.task.EndOfStreamListenerTask;
+import org.apache.samza.task.InitableTask;
+import org.apache.samza.task.MessageCollector;
+import org.apache.samza.task.StreamTask;
+import org.apache.samza.task.StreamTaskFactory;
+import org.apache.samza.task.TaskCoordinator;
+import org.apache.samza.test.framework.TestRunner;
+import org.apache.samza.test.framework.system.descriptors.InMemoryInputDescriptor;
+import org.apache.samza.test.framework.system.descriptors.InMemorySystemDescriptor;
+import org.apache.samza.test.table.TestTableData;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+
+/**
+ * End to end integration test to check drain functionality with samza low-level API.
+ */
+public class DrainLowLevelApiIntegrationTest {
+  private static final List<TestTableData.PageView> RECEIVED = new ArrayList<>();
+
+  // Count onDrain / onEndOfStream callbacks; static, so shared by every task instance.
+  private static int drainCounter = 0;
+  private static int eosCounter = 0;
+
+  private static final String SYSTEM_NAME = "test";
+  private static final String STREAM_ID = "PageView";
+
+  private static class PageViewEventCountLowLevelApplication implements TaskApplication {
+    @Override
+    public void describe(TaskApplicationDescriptor appDescriptor) {
+      DelegatingSystemDescriptor ksd = new DelegatingSystemDescriptor(SYSTEM_NAME);
+      GenericInputDescriptor<TestTableData.PageView> pageViewIsd = ksd.getInputDescriptor(STREAM_ID, new NoOpSerde<>());
+      appDescriptor
+          .withInputStream(pageViewIsd)
+          .withTaskFactory((StreamTaskFactory) PageViewEventCountStreamTask::new);
+    }
+  }
+
+  private static class PageViewEventCountStreamTask implements StreamTask, InitableTask, DrainListenerTask, EndOfStreamListenerTask {
+    public PageViewEventCountStreamTask() {
+    }
+
+    @Override
+    public void init(Context context) {
+    }
+
+    @Override
+    public void process(IncomingMessageEnvelope message, MessageCollector collector, TaskCoordinator coordinator) {
+      RECEIVED.add((TestTableData.PageView) message.getMessage());
+    }
+
+    @Override
+    public void onDrain(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      drainCounter++;
+    }
+
+    @Override
+    public void onEndOfStream(MessageCollector collector, TaskCoordinator coordinator) throws Exception {
+      eosCounter++;
+    }
+  }
+
+  // The test can be occasionally flaky, so we set Ignore annotation
+  // Remove ignore annotation and run the test as follows:
+  // ./gradlew :samza-test:test --tests org.apache.samza.test.drain.DrainLowLevelApiIntegrationTest -PscalaSuffix=2.12
+  @Ignore
+  @Test
+  public void testPipeline() {
+    String runId = "DrainTestId";
+    int numPageViews = 40;
+    RECEIVED.clear();
+
+    InMemorySystemDescriptor isd = new InMemorySystemDescriptor(SYSTEM_NAME);
+    InMemoryInputDescriptor<TestTableData.PageView> inputDescriptor = isd.getInputDescriptor(STREAM_ID, new NoOpSerde<>());
+    InMemoryMetadataStoreFactory metadataStoreFactory = new InMemoryMetadataStoreFactory();
+
+    Map<String, String> customConfig = ImmutableMap.of(
+        ApplicationConfig.APP_RUN_ID, runId,
+        JobConfig.DRAIN_MONITOR_POLL_INTERVAL_MILLIS, "100",
+        JobConfig.DRAIN_MONITOR_ENABLED, "true");
+
+    // Create a TestRunner
+    // Set an InMemoryMetadataFactory. We will use this factory in the test to create a metadata store and
+    // write the drain message to it
+    // Mock data comprises 40 messages across 4 partitions. TestRunner adds a 1 second delay between messages
+    // per partition when writing messages to the InputStream
+    TestRunner testRunner = TestRunner.of(new PageViewEventCountLowLevelApplication())
+        .setInMemoryMetadataFactory(metadataStoreFactory)
+        .addConfig(customConfig)
+        .addInputStream(inputDescriptor, TestTableData.generatePartitionedPageViews(numPageViews, 4), 1000L);
+
+    Config configFromRunner = testRunner.getConfig();
+    MetadataStore metadataStore = metadataStoreFactory.getMetadataStore("NoOp", configFromRunner, new MetricsRegistryMap());
+
+    // write drain message after a delay; the finally block stops the scheduler thread even if run() throws
+    ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
+    try {
+      Callable<UUID> writeDrain = () -> DrainUtils.writeDrainNotification(metadataStore, runId);
+      executorService.schedule(writeDrain, 2000L, TimeUnit.MILLISECONDS);
+      testRunner.run(Duration.ofSeconds(25));
+    } finally {
+      executorService.shutdownNow();
+    }
+
+    assertTrue("No messages were processed before the drain", RECEIVED.size() > 0);
+    assertTrue("Drain did not stop processing early, received " + RECEIVED.size(), RECEIVED.size() < numPageViews);
+  }
+}
diff --git a/samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTest.java b/samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTest.java
index 6afc77c..2a6e70c 100644
--- a/samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTest.java
+++ b/samza-test/src/test/java/org/apache/samza/test/framework/StreamApplicationIntegrationTest.java
@@ -42,7 +42,7 @@
import org.apache.samza.system.kafka.descriptors.KafkaOutputDescriptor;
import org.apache.samza.system.kafka.descriptors.KafkaSystemDescriptor;
import org.apache.samza.table.Table;
-import org.apache.samza.test.controlmessages.TestData;
+import org.apache.samza.test.TestData;
import org.apache.samza.test.framework.system.descriptors.InMemoryInputDescriptor;
import org.apache.samza.test.framework.system.descriptors.InMemoryOutputDescriptor;
import org.apache.samza.test.framework.system.descriptors.InMemorySystemDescriptor;
@@ -53,7 +53,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import static org.apache.samza.test.controlmessages.TestData.PageView;
+import static org.apache.samza.test.TestData.PageView;
public class StreamApplicationIntegrationTest {
private static final Logger LOG = LoggerFactory.getLogger(StreamApplicationIntegrationTest.class);