blob: 4cab1859f699c598c1a528840164cd309b2fc5e7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samza.container
import java.util.Collections
import com.google.common.collect.ImmutableSet
import org.apache.samza.{Partition, SamzaException}
import org.apache.samza.checkpoint.{Checkpoint, CheckpointedChangelogOffset, OffsetManager}
import org.apache.samza.config.MapConfig
import org.apache.samza.context.{TaskContext => _, _}
import org.apache.samza.job.model.TaskModel
import org.apache.samza.metrics.Counter
import org.apache.samza.storage.NonTransactionalStateTaskStorageManager
import org.apache.samza.system.{IncomingMessageEnvelope, StreamMetadataCache, SystemAdmin, SystemConsumers, SystemStream, SystemStreamMetadata, _}
import org.apache.samza.table.TableManager
import org.apache.samza.task._
import org.junit.Assert._
import org.junit.{Before, Test}
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.mockito.{ArgumentCaptor, Matchers, Mock, MockitoAnnotations}
import org.scalatest.junit.AssertionsForJUnit
import org.scalatest.mockito.MockitoSugar
import scala.collection.JavaConverters._
class TestTaskInstance extends AssertionsForJUnit with MockitoSugar {
private val SYSTEM_NAME = "test-system"
private val TASK_NAME = new TaskName("taskName")
private val SYSTEM_STREAM_PARTITION =
new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-stream"), new Partition(0))
private val SYSTEM_STREAM_PARTITIONS = ImmutableSet.of(SYSTEM_STREAM_PARTITION)
@Mock
private var task: AllTask = null
@Mock
private var taskModel: TaskModel = null
@Mock
private var metrics: TaskInstanceMetrics = null
@Mock
private var systemAdmins: SystemAdmins = null
@Mock
private var systemAdmin: SystemAdmin = null
@Mock
private var consumerMultiplexer: SystemConsumers = null
@Mock
private var collector: TaskInstanceCollector = null
@Mock
private var offsetManager: OffsetManager = null
@Mock
private var taskStorageManager: NonTransactionalStateTaskStorageManager = null
@Mock
private var taskTableManager: TableManager = null
// not a mock; using MockTaskInstanceExceptionHandler
private var taskInstanceExceptionHandler: MockTaskInstanceExceptionHandler = null
@Mock
private var jobContext: JobContext = null
@Mock
private var containerContext: ContainerContext = null
@Mock
private var applicationContainerContext: ApplicationContainerContext = null
@Mock
private var applicationTaskContextFactory: ApplicationTaskContextFactory[ApplicationTaskContext] = null
@Mock
private var applicationTaskContext: ApplicationTaskContext = null
@Mock
private var externalContext: ExternalContext = null
private var taskInstance: TaskInstance = null
@Before
def setup(): Unit = {
MockitoAnnotations.initMocks(this)
// not using Mockito mock since Mockito doesn't work well with the call-by-name argument in maybeHandle
this.taskInstanceExceptionHandler = new MockTaskInstanceExceptionHandler
when(this.taskModel.getTaskName).thenReturn(TASK_NAME)
when(this.applicationTaskContextFactory.create(Matchers.eq(this.externalContext), Matchers.eq(this.jobContext),
Matchers.eq(this.containerContext), any(), Matchers.eq(this.applicationContainerContext)))
.thenReturn(this.applicationTaskContext)
when(this.systemAdmins.getSystemAdmin(SYSTEM_NAME)).thenReturn(this.systemAdmin)
when(this.jobContext.getConfig).thenReturn(new MapConfig(Collections.singletonMap("task.commit.ms", "-1")))
setupTaskInstance(Some(this.applicationTaskContextFactory))
}
@Test
def testProcess() {
val processesCounter = mock[Counter]
when(this.metrics.processes).thenReturn(processesCounter)
val messagesActuallyProcessedCounter = mock[Counter]
when(this.metrics.messagesActuallyProcessed).thenReturn(messagesActuallyProcessedCounter)
when(this.offsetManager.getStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION)).thenReturn(Some("0"))
val envelope = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "0", null, null)
val coordinator = mock[ReadableCoordinator]
val callbackFactory = mock[TaskCallbackFactory]
val callback = mock[TaskCallback]
when(callbackFactory.createCallback()).thenReturn(callback)
this.taskInstance.process(envelope, coordinator, callbackFactory)
assertEquals(1, this.taskInstanceExceptionHandler.numTimesCalled)
verify(this.task).processAsync(envelope, this.collector, coordinator, callback)
verify(processesCounter).inc()
verify(messagesActuallyProcessedCounter).inc()
}
@Test
def testWindow() {
val windowsCounter = mock[Counter]
when(this.metrics.windows).thenReturn(windowsCounter)
val coordinator = mock[ReadableCoordinator]
this.taskInstance.window(coordinator)
assertEquals(1, this.taskInstanceExceptionHandler.numTimesCalled)
verify(this.task).window(this.collector, coordinator)
verify(windowsCounter).inc()
}
@Test
def testInitTask(): Unit = {
this.taskInstance.initTask
val contextCaptor = ArgumentCaptor.forClass(classOf[Context])
verify(this.task).init(contextCaptor.capture())
val actualContext = contextCaptor.getValue
assertEquals(this.jobContext, actualContext.getJobContext)
assertEquals(this.containerContext, actualContext.getContainerContext)
assertEquals(this.taskModel, actualContext.getTaskContext.getTaskModel)
assertEquals(this.applicationContainerContext, actualContext.getApplicationContainerContext)
assertEquals(this.applicationTaskContext, actualContext.getApplicationTaskContext)
assertEquals(this.externalContext, actualContext.getExternalContext)
verify(this.applicationTaskContext).start()
}
@Test
def testShutdownTask(): Unit = {
this.taskInstance.shutdownTask
verify(this.applicationTaskContext).stop()
verify(this.task).close()
}
/**
* Tests that the init() method of task can override the existing offset assignment.
* This helps verify wiring for the task context (i.e. offset manager).
*/
@Test
def testManualOffsetReset() {
when(this.task.init(any())).thenAnswer(new Answer[Void] {
override def answer(invocation: InvocationOnMock): Void = {
val context = invocation.getArgumentAt(0, classOf[Context])
context.getTaskContext.setStartingOffset(SYSTEM_STREAM_PARTITION, "10")
null
}
})
taskInstance.initTask
verify(this.offsetManager).setStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION, "10")
verifyNoMoreInteractions(this.offsetManager)
}
@Test
def testIgnoreMessagesOlderThanStartingOffsets() {
val processesCounter = mock[Counter]
when(this.metrics.processes).thenReturn(processesCounter)
val messagesActuallyProcessedCounter = mock[Counter]
when(this.metrics.messagesActuallyProcessed).thenReturn(messagesActuallyProcessedCounter)
when(this.offsetManager.getStartingOffset(TASK_NAME, SYSTEM_STREAM_PARTITION)).thenReturn(Some("5"))
when(this.systemAdmin.offsetComparator(any(), any())).thenAnswer(new Answer[Integer] {
override def answer(invocation: InvocationOnMock): Integer = {
val offset1 = invocation.getArgumentAt(0, classOf[String])
val offset2 = invocation.getArgumentAt(1, classOf[String])
offset1.toLong.compareTo(offset2.toLong)
}
})
val oldEnvelope = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "0", null, null)
val newEnvelope0 = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "5", null, null)
val newEnvelope1 = new IncomingMessageEnvelope(SYSTEM_STREAM_PARTITION, "7", null, null)
val mockCoordinator = mock[ReadableCoordinator]
val mockCallback = mock[TaskCallback]
val mockCallbackFactory = mock[TaskCallbackFactory]
when(mockCallbackFactory.createCallback()).thenReturn(mockCallback)
this.taskInstance.process(oldEnvelope, mockCoordinator, mockCallbackFactory)
this.taskInstance.process(newEnvelope0, mockCoordinator, mockCallbackFactory)
this.taskInstance.process(newEnvelope1, mockCoordinator, mockCallbackFactory)
verify(this.task).processAsync(Matchers.eq(newEnvelope0), Matchers.eq(this.collector), Matchers.eq(mockCoordinator), Matchers.eq(mockCallback))
verify(this.task).processAsync(Matchers.eq(newEnvelope1), Matchers.eq(this.collector), Matchers.eq(mockCoordinator), Matchers.eq(mockCallback))
verify(this.task, never()).processAsync(Matchers.eq(oldEnvelope), any(), any(), any())
verify(processesCounter, times(3)).inc()
verify(messagesActuallyProcessedCounter, times(2)).inc()
}
@Test
def testCommitOrder() {
val commitsCounter = mock[Counter]
when(this.metrics.commits).thenReturn(commitsCounter)
val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-changelog-stream"), new Partition(0))
val changelogOffsets = Map(changelogSSP -> Some("5"))
when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
when(this.taskStorageManager.flush()).thenReturn(changelogOffsets)
doNothing().when(this.taskStorageManager).checkpoint(any(), any[Map[SystemStreamPartition, Option[String]]])
taskInstance.commit
val mockOrder = inOrder(this.offsetManager, this.collector, this.taskTableManager, this.taskStorageManager)
// We must first get a snapshot of the input offsets so it doesn't change while we flush. SAMZA-1384
mockOrder.verify(this.offsetManager).buildCheckpoint(TASK_NAME)
// Producers must be flushed next and ideally the output would be flushed before the changelog
// s.t. the changelog and checkpoints (state and inputs) are captured last
mockOrder.verify(this.collector).flush
// Tables should be flushed next
mockOrder.verify(this.taskTableManager).flush()
// Local state should be flushed next next
mockOrder.verify(this.taskStorageManager).flush()
// Stores checkpoints should be created next with the newest changelog offsets
mockOrder.verify(this.taskStorageManager).checkpoint(any(), Matchers.eq(changelogOffsets))
// Input checkpoint should be written with the snapshot captured at the beginning of commit and the
// newest changelog offset captured during storage manager flush
val captor = ArgumentCaptor.forClass(classOf[Checkpoint])
mockOrder.verify(offsetManager).writeCheckpoint(any(), captor.capture)
val cp = captor.getValue
assertEquals("4", cp.getOffsets.get(SYSTEM_STREAM_PARTITION))
assertEquals("5", CheckpointedChangelogOffset.fromString(cp.getOffsets.get(changelogSSP)).getOffset)
// Old checkpointed stores should be cleared
mockOrder.verify(this.taskStorageManager).removeOldCheckpoints(any())
verify(commitsCounter).inc()
}
@Test
def testEmptyChangelogSSPOffsetInCommit() { // e.g. if changelog topic is empty
val commitsCounter = mock[Counter]
when(this.metrics.commits).thenReturn(commitsCounter)
val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
val changelogSSP = new SystemStreamPartition(new SystemStream(SYSTEM_NAME, "test-changelog-stream"), new Partition(0))
val changelogOffsets = Map(changelogSSP -> None)
when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
when(this.taskStorageManager.flush()).thenReturn(changelogOffsets)
taskInstance.commit
val captor = ArgumentCaptor.forClass(classOf[Checkpoint])
verify(offsetManager).writeCheckpoint(any(), captor.capture)
val cp = captor.getValue
assertEquals("4", cp.getOffsets.get(SYSTEM_STREAM_PARTITION))
val message = cp.getOffsets.get(changelogSSP)
val checkpointedOffset = CheckpointedChangelogOffset.fromString(message)
assertNull(checkpointedOffset.getOffset)
assertNotNull(checkpointedOffset.getCheckpointId)
verify(commitsCounter).inc()
}
@Test
def testEmptyChangelogOffsetsInCommit() { // e.g. if stores have no changelogs
val commitsCounter = mock[Counter]
when(this.metrics.commits).thenReturn(commitsCounter)
val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
val changelogOffsets = Map[SystemStreamPartition, Option[String]]()
when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
when(this.taskStorageManager.flush()).thenReturn(changelogOffsets)
taskInstance.commit
val captor = ArgumentCaptor.forClass(classOf[Checkpoint])
verify(offsetManager).writeCheckpoint(any(), captor.capture)
val cp = captor.getValue
assertEquals("4", cp.getOffsets.get(SYSTEM_STREAM_PARTITION))
assertEquals(1, cp.getOffsets.size())
verify(commitsCounter).inc()
}
@Test
def testCommitFailsIfErrorGettingChangelogOffset() { // required for transactional state
val commitsCounter = mock[Counter]
when(this.metrics.commits).thenReturn(commitsCounter)
val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
when(this.taskStorageManager.flush()).thenThrow(new SamzaException("Error getting changelog offsets"))
try {
taskInstance.commit
} catch {
case e: SamzaException =>
// exception is expected, container should fail if could not get changelog offsets.
return
}
fail("Should have failed commit if error getting newest changelog offests")
}
@Test
def testCommitFailsIfErrorCreatingStoreCheckpoints() { // required for transactional state
val commitsCounter = mock[Counter]
when(this.metrics.commits).thenReturn(commitsCounter)
val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
when(this.taskStorageManager.flush()).thenReturn(Map[SystemStreamPartition, Option[String]]())
when(this.taskStorageManager.checkpoint(any(), any())).thenThrow(new SamzaException("Error creating store checkpoint"))
try {
taskInstance.commit
} catch {
case e: SamzaException =>
// exception is expected, container should fail if could not get changelog offsets.
return
}
fail("Should have failed commit if error getting newest changelog offests")
}
@Test
def testCommitContinuesIfErrorClearingOldCheckpoints() { // required for transactional state
val commitsCounter = mock[Counter]
when(this.metrics.commits).thenReturn(commitsCounter)
val inputOffsets = new Checkpoint(Map(SYSTEM_STREAM_PARTITION -> "4").asJava)
when(this.offsetManager.buildCheckpoint(TASK_NAME)).thenReturn(inputOffsets)
when(this.taskStorageManager.flush()).thenReturn(Map[SystemStreamPartition, Option[String]]())
doNothing().when(this.taskStorageManager).checkpoint(any(), any())
when(this.taskStorageManager.removeOldCheckpoints(any()))
.thenThrow(new SamzaException("Error clearing old checkpoints"))
try {
taskInstance.commit
} catch {
case e: SamzaException =>
// exception is expected, container should fail if could not get changelog offsets.
fail("Exception from removeOldCheckpoints should have been caught")
}
}
/**
* Given that no application task context factory is provided, then no lifecycle calls should be made.
*/
@Test
def testNoApplicationTaskContextFactoryProvided() {
setupTaskInstance(None)
this.taskInstance.initTask
this.taskInstance.shutdownTask
verifyZeroInteractions(this.applicationTaskContext)
}
@Test(expected = classOf[SystemProducerException])
def testProducerExceptionsIsPropagated() {
when(this.metrics.commits).thenReturn(mock[Counter])
when(this.collector.flush).thenThrow(new SystemProducerException("systemProducerException"))
try {
taskInstance.commit // Should not swallow the SystemProducerException
} finally {
verify(offsetManager, never()).writeCheckpoint(any(), any())
}
}
@Test
def testInitCaughtUpMapping() {
val offsetManagerMock = mock[OffsetManager]
when(offsetManagerMock.getStartingOffset(anyObject(), anyObject())).thenReturn(Option("42"))
val cacheMock = mock[StreamMetadataCache]
val systemStreamMetadata = mock[SystemStreamMetadata]
when(cacheMock.getSystemStreamMetadata(anyObject(), anyBoolean()))
.thenReturn(systemStreamMetadata)
val sspMetadata = mock[SystemStreamMetadata.SystemStreamPartitionMetadata]
when(sspMetadata.getUpcomingOffset).thenReturn("42")
when(systemStreamMetadata.getSystemStreamPartitionMetadata)
.thenReturn(Collections.singletonMap(new Partition(0), sspMetadata))
val ssp = new SystemStreamPartition("test-system", "test-stream", new Partition(0))
val inputStreamMetadata = collection.Map(ssp.getSystemStream -> systemStreamMetadata)
val taskInstance = new TaskInstance(this.task,
this.taskModel,
this.metrics,
this.systemAdmins,
this.consumerMultiplexer,
this.collector,
offsetManager = offsetManagerMock,
storageManager = this.taskStorageManager,
tableManager = this.taskTableManager,
systemStreamPartitions = ImmutableSet.of(ssp),
exceptionHandler = this.taskInstanceExceptionHandler,
streamMetadataCache = cacheMock,
inputStreamMetadata = Map.empty ++ inputStreamMetadata,
jobContext = this.jobContext,
containerContext = this.containerContext,
applicationContainerContextOption = Some(this.applicationContainerContext),
applicationTaskContextFactoryOption = Some(this.applicationTaskContextFactory),
externalContextOption = Some(this.externalContext))
taskInstance.initCaughtUpMapping()
assertTrue(taskInstance.ssp2CaughtupMapping(ssp))
}
private def setupTaskInstance(
applicationTaskContextFactory: Option[ApplicationTaskContextFactory[ApplicationTaskContext]]): Unit = {
this.taskInstance = new TaskInstance(this.task,
this.taskModel,
this.metrics,
this.systemAdmins,
this.consumerMultiplexer,
this.collector,
offsetManager = this.offsetManager,
storageManager = this.taskStorageManager,
tableManager = this.taskTableManager,
systemStreamPartitions = SYSTEM_STREAM_PARTITIONS,
exceptionHandler = this.taskInstanceExceptionHandler,
jobContext = this.jobContext,
containerContext = this.containerContext,
applicationContainerContextOption = Some(this.applicationContainerContext),
applicationTaskContextFactoryOption = applicationTaskContextFactory,
externalContextOption = Some(this.externalContext))
}
/**
* Task type which has all task traits, which can be mocked.
*/
trait AllTask extends AsyncStreamTask with InitableTask with ClosableTask with WindowableTask {}
/**
* Mock version of [TaskInstanceExceptionHandler] which just does a passthrough execution and keeps track of the
* number of times it is called. This is used to verify that the handler does get used to wrap the actual processing.
*/
class MockTaskInstanceExceptionHandler extends TaskInstanceExceptionHandler {
var numTimesCalled = 0
override def maybeHandle(tryCodeBlock: => Unit): Unit = {
numTimesCalled += 1
tryCodeBlock
}
}
}