| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.samza.test.integration |
| |
| import java.util.Properties |
| import java.util.concurrent.CountDownLatch |
| import java.util.concurrent.TimeUnit |
| |
| import kafka.admin.AdminUtils |
| import kafka.common.ErrorMapping |
| import kafka.consumer.Consumer |
| import kafka.consumer.ConsumerConfig |
| import kafka.message.MessageAndMetadata |
| import kafka.producer.KeyedMessage |
| import kafka.producer.Producer |
| import kafka.producer.ProducerConfig |
| import kafka.server.KafkaConfig |
| import kafka.server.KafkaServer |
| import kafka.utils.TestUtils |
| import kafka.utils.TestZKUtils |
| import kafka.utils.Utils |
| import kafka.utils.ZKStringSerializer |
| import kafka.zk.EmbeddedZookeeper |
| |
| import org.I0Itec.zkclient.ZkClient |
| import org.apache.samza.Partition |
| import org.apache.samza.checkpoint.Checkpoint |
| import org.apache.samza.config.Config |
| import org.apache.samza.job.local.ThreadJobFactory |
| import org.apache.samza.config.MapConfig |
| import org.apache.samza.container.TaskName |
| import org.apache.samza.job.ApplicationStatus |
| import org.apache.samza.job.StreamJob |
| import org.apache.samza.storage.kv.KeyValueStore |
| import org.apache.samza.system.kafka.TopicMetadataCache |
| import org.apache.samza.system.{SystemStreamPartition, IncomingMessageEnvelope} |
| import org.apache.samza.task.InitableTask |
| import org.apache.samza.task.MessageCollector |
| import org.apache.samza.task.StreamTask |
| import org.apache.samza.task.TaskContext |
| import org.apache.samza.task.TaskCoordinator |
| import org.apache.samza.task.TaskCoordinator.RequestScope |
| import org.apache.samza.util.ClientUtilTopicMetadataStore |
| import org.apache.samza.util.TopicMetadataStore |
| import org.junit.Assert._ |
| import org.junit.{BeforeClass, AfterClass, Test} |
| |
| import scala.collection.JavaConversions._ |
| import scala.collection.mutable.ArrayBuffer |
| import scala.collection.mutable.HashMap |
| import scala.collection.mutable.SynchronizedMap |
| |
/**
 * Shared test fixture: starts an embedded ZooKeeper and three Kafka brokers,
 * creates the input and changelog (store) topics, and tears everything down
 * after all tests in the class have run.
 */
object TestStatefulTask {
  val INPUT_TOPIC = "input"
  val STATE_TOPIC = "mystore"
  val TOTAL_TASK_NAMES = 1
  val REPLICATION_FACTOR = 3

  val zkConnect: String = TestZKUtils.zookeeperConnect
  var zkClient: ZkClient = null
  val zkConnectionTimeout = 6000
  val zkSessionTimeout = 6000

  val brokerId1 = 0
  val brokerId2 = 1
  val brokerId3 = 2
  // Pick three free ports so the brokers don't collide with anything else on the host.
  val ports = TestUtils.choosePorts(3)
  val (port1, port2, port3) = (ports(0), ports(1), ports(2))

  val props1 = TestUtils.createBrokerConfig(brokerId1, port1)
  val props2 = TestUtils.createBrokerConfig(brokerId2, port2)
  val props3 = TestUtils.createBrokerConfig(brokerId3, port3)

  val config = new java.util.Properties()
  val brokers = "localhost:%d,localhost:%d,localhost:%d" format (port1, port2, port3)
  config.put("metadata.broker.list", brokers)
  // Synchronous producer with acks=-1 so send() returns only after all replicas
  // have the message; keeps the test deterministic.
  config.put("producer.type", "sync")
  config.put("request.required.acks", "-1")
  config.put("serializer.class", "kafka.serializer.StringEncoder")
  val producerConfig = new ProducerConfig(config)
  var producer: Producer[String, String] = null
  val cp1 = new Checkpoint(Map(new SystemStreamPartition("kafka", "topic", new Partition(0)) -> "123"))
  val cp2 = new Checkpoint(Map(new SystemStreamPartition("kafka", "topic", new Partition(0)) -> "12345"))
  var zookeeper: EmbeddedZookeeper = null
  var server1: KafkaServer = null
  var server2: KafkaServer = null
  var server3: KafkaServer = null
  var metadataStore: TopicMetadataStore = null

  /**
   * Boots ZooKeeper and the three brokers, then creates and validates the
   * test topics. Runs once before any test in the class.
   */
  @BeforeClass
  def beforeSetupServers {
    zookeeper = new EmbeddedZookeeper(zkConnect)
    server1 = TestUtils.createServer(new KafkaConfig(props1))
    server2 = TestUtils.createServer(new KafkaConfig(props2))
    server3 = TestUtils.createServer(new KafkaConfig(props3))
    zkClient = new ZkClient(zkConnect + "/", 6000, 6000, ZKStringSerializer)
    producer = new Producer(producerConfig)
    metadataStore = new ClientUtilTopicMetadataStore(brokers, "some-job-name")

    createTopics
    validateTopics
  }

  /** Creates the input topic and the store changelog topic. */
  def createTopics {
    AdminUtils.createTopic(
      zkClient,
      INPUT_TOPIC,
      TOTAL_TASK_NAMES,
      REPLICATION_FACTOR)

    AdminUtils.createTopic(
      zkClient,
      STATE_TOPIC,
      TOTAL_TASK_NAMES,
      REPLICATION_FACTOR)
  }

  /**
   * Polls topic metadata until both test topics exist and have no error code
   * (i.e. have leaders), retrying up to 100 times with a 500 ms pause between
   * attempts. Fails the test run if the topics never become available.
   */
  def validateTopics {
    val topics = Set(STATE_TOPIC, INPUT_TOPIC)
    var done = false
    var retries = 0

    while (!done && retries < 100) {
      try {
        val topicMetadataMap = TopicMetadataCache.getTopicMetadata(topics, "kafka", metadataStore.getTopicInfo)

        topics.foreach(topic => {
          val topicMetadata = topicMetadataMap(topic)
          val errorCode = topicMetadata.errorCode

          ErrorMapping.maybeThrowException(errorCode)
        })

        done = true
      } catch {
        case e: Exception =>
          // Fix: the original println("msg", e) relied on Scala auto-tupling and
          // printed a (String, Exception) tuple. Build a real message instead.
          System.err.println("Got exception while validating test topics. Waiting and retrying. " + e)
          retries += 1
          Thread.sleep(500)
      }
    }

    if (retries >= 100) {
      fail("Unable to successfully create topics. Tried to validate %s times." format retries)
    }
  }

  /**
   * Shuts down the brokers and ZooKeeper, and removes the brokers' log
   * directories so repeated runs start clean.
   */
  @AfterClass
  def afterCleanLogDirs {
    server1.shutdown
    server1.awaitShutdown()
    server2.shutdown
    server2.awaitShutdown()
    server3.shutdown
    server3.awaitShutdown()
    Utils.rm(server1.config.logDirs)
    Utils.rm(server2.config.logDirs)
    Utils.rm(server3.config.logDirs)
    zkClient.close
    zookeeper.shutdown
  }
}
| |
| /** |
| * Test that does the following: |
| * |
| * 1. Starts ZK, and 3 kafka brokers. |
| * 2. Create two topics: input and mystore. |
| * 3. Validate that the topics were created successfully and have leaders. |
| * 4. Start a single partition of TestTask using ThreadJobFactory. |
 * 5. Send six messages to input (1,2,3,2,99,-99), which contain one dupe (2) and one delete (-99).
| * 6. Validate that all messages were received by TestTask. |
 * 7. Validate that TestTask updated the store for all six messages (five puts and one delete), and that the messages ended up in the mystore topic.
| * 8. Kill the job. |
| * 9. Start the job again. |
| * 10. Validate that the job restored all messages (1,2,3) to the store. |
| * 11. Send three more messages to input (4,5,5), and validate that TestTask receives them. |
| * 12. Kill the job again. |
| */ |
/**
 * Drives the stateful-task test: starts the job, sends messages, inspects the
 * store changelog topic, kills the job, restarts it, and verifies the store
 * was restored from the changelog.
 */
class TestStatefulTask {
  import TestStatefulTask._

  val jobFactory = new ThreadJobFactory

  val jobConfig = Map(
    "job.factory.class" -> jobFactory.getClass.getCanonicalName,
    "job.name" -> "hello-stateful-world",
    "task.class" -> "org.apache.samza.test.integration.TestTask",
    "task.inputs" -> "kafka.input",
    "serializers.registry.string.class" -> "org.apache.samza.serializers.StringSerdeFactory",
    "stores.mystore.factory" -> "org.apache.samza.storage.kv.KeyValueStorageEngineFactory",
    "stores.mystore.key.serde" -> "string",
    "stores.mystore.msg.serde" -> "string",
    "stores.mystore.changelog" -> "kafka.mystore",

    "systems.kafka.samza.factory" -> "org.apache.samza.system.kafka.KafkaSystemFactory",
    // Always start consuming at offset 0. This avoids a race condition between
    // the producer and the consumer in this test (SAMZA-166, SAMZA-224).
    "systems.kafka.samza.offset.default" -> "oldest", // applies to a nonempty topic
    "systems.kafka.consumer.auto.offset.reset" -> "smallest", // applies to an empty topic
    "systems.kafka.samza.msg.serde" -> "string",
    "systems.kafka.consumer.zookeeper.connect" -> zkConnect,
    "systems.kafka.producer.metadata.broker.list" -> ("localhost:%s" format port1),
    // Since using state, need a checkpoint manager
    "task.checkpoint.factory" -> "org.apache.samza.checkpoint.kafka.KafkaCheckpointManagerFactory",
    "task.checkpoint.system" -> "kafka",
    "task.checkpoint.replication.factor" -> "1",
    // However, don't have the inputs use the checkpoint manager
    // since the second part of the test expects to replay the input streams.
    "systems.kafka.streams.input.samza.reset.offset" -> "true")

  /**
   * Single entry point: the two phases share broker state, so they must run
   * in this exact order within one JUnit test method.
   */
  @Test
  def testShouldStartAndRestore {
    // Have to do this in one test to guarantee ordering.
    testShouldStartTaskForFirstTime
    testShouldRestoreStore
  }

  /**
   * Phase 1: start the job with an empty store, send six messages (including
   * a dupe and a delete), and verify they all reach the changelog topic.
   */
  def testShouldStartTaskForFirstTime {
    val (job, task) = startJob

    // Validate that restored is empty.
    assertEquals(0, task.initFinished.getCount)
    assertEquals(0, task.restored.size)
    assertEquals(0, task.received.size)

    // Send some messages to input stream.
    send(task, "1")
    send(task, "2")
    send(task, "3")
    send(task, "2")
    send(task, "99")
    send(task, "-99") // leading "-" makes the task delete key "99"

    // Validate that messages appear in store stream (offsets 0..5 inclusive).
    val messages = readAll(STATE_TOPIC, 5, "testShouldStartTaskForFirstTime")

    assertEquals(6, messages.length)
    assertEquals("1", messages(0))
    assertEquals("2", messages(1))
    assertEquals("3", messages(2))
    assertEquals("2", messages(3))
    assertEquals("99", messages(4))
    assertNull(messages(5)) // the delete shows up as a null-payload tombstone

    stopJob(job)
  }

  /**
   * Phase 2: restart the job, verify the store was restored from the
   * changelog (1,2,3 — key 99 was deleted), let the input replay, then send
   * three more messages and verify the full changelog contents.
   */
  def testShouldRestoreStore {
    val (job, task) = startJob

    // Validate that restored has expected data.
    assertEquals(3, task.restored.size)
    assertTrue(task.restored.contains("1"))
    assertTrue(task.restored.contains("2"))
    assertTrue(task.restored.contains("3"))

    var count = 0

    // We should get the original six messages replayed from the input stream
    // (1,2,3,2,99,-99), since input offsets were reset. Note that this
    // triggers six new outgoing messages to the STATE_TOPIC.
    // Fix: wait for all six replayed messages (not just four) before
    // re-arming the latch, so a late replayed message can't race with the
    // received.last assertion in send() below.
    while (task.received.size < 6 && count < 100) {
      Thread.sleep(600)
      count += 1
    }

    assertTrue("Timed out waiting to received messages. Received thus far: " + task.received.size, count < 100)

    // Reset the count down latch after the replayed messages come in.
    task.awaitMessage

    // Send some messages to input stream.
    send(task, "4")
    send(task, "5")
    send(task, "5")

    // Validate that messages appear in store stream (offsets 0..14 inclusive).
    val messages = readAll(STATE_TOPIC, 14, "testShouldRestoreStore")

    assertEquals(15, messages.length)
    // From initial start.
    assertEquals("1", messages(0))
    assertEquals("2", messages(1))
    assertEquals("3", messages(2))
    assertEquals("2", messages(3))
    assertEquals("99", messages(4))
    assertNull(messages(5))
    // From second startup.
    assertEquals("1", messages(6))
    assertEquals("2", messages(7))
    assertEquals("3", messages(8))
    assertEquals("2", messages(9))
    assertEquals("99", messages(10))
    assertNull(messages(11))
    // From sending in this method.
    assertEquals("4", messages(12))
    assertEquals("5", messages(13))
    assertEquals("5", messages(14))

    stopJob(job)
  }

  /**
   * Start a job for TestJob, and do some basic sanity checks around startup
   * time, number of partitions, etc.
   *
   * @return the running job and the single registered TestTask instance.
   */
  def startJob = {
    val job = jobFactory.getJob(new MapConfig(jobConfig))

    // Start task.
    job.submit
    assertEquals(ApplicationStatus.Running, job.waitForStatus(ApplicationStatus.Running, 60000))
    TestTask.awaitTaskRegistered
    val tasks = TestTask.tasks

    assertEquals("Should only have a single partition in this task", 1, tasks.size)

    val task = tasks.values.toList.head

    task.initFinished.await(60, TimeUnit.SECONDS)
    assertEquals(0, task.initFinished.getCount)

    (job, task)
  }

  /**
   * Kill a job, and wait for an unsuccessful finish (since this throws an
   * interrupt, which is forwarded on to ThreadJob, and marked as a failure).
   */
  def stopJob(job: StreamJob) {
    // Shutdown task.
    job.kill
    assertEquals(ApplicationStatus.UnsuccessfulFinish, job.waitForFinish(60000))
  }

  /**
   * Send a message to the input topic, and validate that it gets to the test task.
   */
  def send(task: TestTask, msg: String) {
    producer.send(new KeyedMessage(INPUT_TOPIC, msg))
    task.awaitMessage
    assertEquals(msg, task.received.last)
  }

  /**
   * Read all messages from a topic starting from last saved offset for group.
   * To read all from offset 0, specify a unique, new group string.
   *
   * @param topic topic to consume.
   * @param maxOffsetInclusive highest offset to read up to (inclusive).
   * @param group consumer group id; a fresh group starts at offset 0.
   * @return decoded messages in offset order; null entries mark tombstones.
   */
  def readAll(topic: String, maxOffsetInclusive: Int, group: String): List[String] = {
    val props = new Properties

    props.put("zookeeper.connect", zkConnect)
    props.put("group.id", group)
    props.put("auto.offset.reset", "smallest")

    val consumerConfig = new ConsumerConfig(props)
    val consumerConnector = Consumer.create(consumerConfig)
    // These never need reassignment; val instead of the original vars.
    val stream = consumerConnector.createMessageStreams(Map(topic -> 1)).get(topic).get.get(0).iterator
    var message: MessageAndMetadata[Array[Byte], Array[Byte]] = null
    val messages = ArrayBuffer[String]()

    while (message == null || message.offset < maxOffsetInclusive) {
      message = stream.next
      if (message.message == null) {
        messages += null
      } else {
        messages += new String(message.message, "UTF-8")
      }
      System.err.println("TestStatefulTask.readAll(): offset=%s, message=%s" format (message.offset, messages.last))
    }

    consumerConnector.shutdown

    messages.toList
  }
}
| |
/**
 * Registry that running TestTask instances report into, so the test can get
 * hold of them without reaching into the ThreadJob/SamzaContainer internals.
 */
object TestTask {
  val tasks = new HashMap[TaskName, TestTask] with SynchronizedMap[TaskName, TestTask]
  @volatile var allTasksRegistered = new CountDownLatch(TestStatefulTask.TOTAL_TASK_NAMES)

  /**
   * Called by each task from init() to make itself visible to the test.
   * Counts down the registration latch once recorded.
   */
  def register(taskName: TaskName, task: TestTask) {
    tasks.put(taskName, task)
    allTasksRegistered.countDown()
  }

  /**
   * Blocks until every expected task has registered, sanity-checks the
   * registry size, then re-arms the latch so it can be reused the next time
   * a job is started.
   */
  def awaitTaskRegistered {
    allTasksRegistered.await(60, TimeUnit.SECONDS)
    assertEquals(0, allTasksRegistered.getCount)
    assertEquals(TestStatefulTask.TOTAL_TASK_NAMES, tasks.size)
    // Fresh latch for the next job start.
    allTasksRegistered = new CountDownLatch(TestStatefulTask.TOTAL_TASK_NAMES)
  }
}
| |
/**
 * Test task that mirrors every received message into the "mystore" key-value
 * store (messages starting with "-" delete the corresponding key), and
 * records what it restored and received so the test can assert on it.
 */
class TestTask extends StreamTask with InitableTask {
  var store: KeyValueStore[String, String] = null
  // Values found in the store at init time (i.e. restored from the changelog).
  var restored = Set[String]()
  // Every message processed, in arrival order.
  var received = ArrayBuffer[String]()
  val initFinished = new CountDownLatch(1)
  // Re-armed after each message; lets the test wait for one message at a time.
  var gotMessage = new CountDownLatch(1)

  /**
   * Registers this task with the TestTask registry, grabs the store, and
   * snapshots its current contents into `restored`.
   */
  def init(config: Config, context: TaskContext) {
    TestTask.register(context.getTaskName, this)
    store = context
      .getStore(TestStatefulTask.STATE_TOPIC)
      .asInstanceOf[KeyValueStore[String, String]]
    val iter = store.all
    try {
      restored ++= iter
        .map(_.getValue)
        .toSet
      System.err.println("TestTask.init(): %s" format restored)
    } finally {
      // Fix: close the iterator even if reading the restored state throws.
      iter.close
    }
    initFinished.countDown()
  }

  /**
   * Records the message, applies it to the store (a leading "-" means delete
   * the named key; anything else is put as both key and value), commits, and
   * signals the test that a message arrived.
   */
  def process(envelope: IncomingMessageEnvelope, collector: MessageCollector, coordinator: TaskCoordinator) {
    val msg = envelope.getMessage.asInstanceOf[String]

    System.err.println("TestTask.process(): %s" format msg)

    received += msg

    // A negative string means delete
    if (msg.startsWith("-")) {
      store.delete(msg.substring(1))
    } else {
      store.put(msg, msg)
    }

    coordinator.commit(RequestScope.ALL_TASKS_IN_CONTAINER)

    // Notify sender that we got a message.
    gotMessage.countDown
  }

  /**
   * Blocks (up to 60s) until process() signals a message, then re-arms the
   * latch for the next message.
   */
  def awaitMessage {
    assertTrue("Timed out of waiting for message rather than received one.", gotMessage.await(60, TimeUnit.SECONDS))
    assertEquals(0, gotMessage.getCount)
    gotMessage = new CountDownLatch(1)
  }
}