| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.samza.operators.impl; |
| |
| import com.google.common.collect.HashMultimap; |
| import com.google.common.collect.Multimap; |
| import org.apache.samza.Partition; |
| import org.apache.samza.application.descriptors.StreamApplicationDescriptorImpl; |
| import org.apache.samza.config.Config; |
| import org.apache.samza.config.JobConfig; |
| import org.apache.samza.config.MapConfig; |
| import org.apache.samza.config.StreamConfig; |
| import org.apache.samza.container.TaskName; |
| import org.apache.samza.context.Context; |
| import org.apache.samza.context.MockContext; |
| import org.apache.samza.context.TaskContextImpl; |
| import org.apache.samza.system.descriptors.GenericInputDescriptor; |
| import org.apache.samza.system.descriptors.GenericOutputDescriptor; |
| import org.apache.samza.system.descriptors.GenericSystemDescriptor; |
| import org.apache.samza.job.model.ContainerModel; |
| import org.apache.samza.job.model.JobModel; |
| import org.apache.samza.job.model.TaskModel; |
| import org.apache.samza.metrics.MetricsRegistryMap; |
| import org.apache.samza.operators.KV; |
| import org.apache.samza.operators.MessageStream; |
| import org.apache.samza.operators.OutputStream; |
| import org.apache.samza.operators.functions.ClosableFunction; |
| import org.apache.samza.operators.functions.FilterFunction; |
| import org.apache.samza.operators.functions.InitableFunction; |
| import org.apache.samza.operators.functions.JoinFunction; |
| import org.apache.samza.operators.functions.MapFunction; |
| import org.apache.samza.operators.spec.OperatorSpec.OpCode; |
| import org.apache.samza.serializers.IntegerSerde; |
| import org.apache.samza.serializers.KVSerde; |
| import org.apache.samza.serializers.Serde; |
| import org.apache.samza.serializers.StringSerde; |
| import org.apache.samza.storage.kv.KeyValueStore; |
| import org.apache.samza.system.IncomingMessageEnvelope; |
| import org.apache.samza.system.SystemStream; |
| import org.apache.samza.system.SystemStreamPartition; |
| import org.apache.samza.task.MessageCollector; |
| import org.apache.samza.task.TaskCoordinator; |
| import org.apache.samza.testUtils.StreamTestUtils; |
| import org.apache.samza.util.Clock; |
| import org.apache.samza.util.SystemClock; |
| import org.apache.samza.util.TimestampedValue; |
| import org.junit.After; |
| import org.junit.Before; |
| import org.junit.Test; |
| |
| import java.io.Serializable; |
| import java.time.Duration; |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.Collections; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.function.BiFunction; |
| import java.util.function.Function; |
| |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertNotSame; |
| import static org.junit.Assert.assertTrue; |
| import static org.mockito.Matchers.eq; |
| import static org.mockito.Mockito.mock; |
| import static org.mockito.Mockito.when; |
| |
| public class TestOperatorImplGraph { |
| private Context context; |
| |
| @Before |
| public void setup() { |
| this.context = new MockContext(); |
| // individual tests can override this config if necessary |
| when(this.context.getJobContext().getConfig()).thenReturn(mock(Config.class)); |
| TaskModel taskModel = mock(TaskModel.class); |
| when(taskModel.getTaskName()).thenReturn(new TaskName("task 0")); |
| when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel); |
| when(this.context.getTaskContext().getTaskMetricsRegistry()).thenReturn(new MetricsRegistryMap()); |
| when(this.context.getContainerContext().getContainerMetricsRegistry()).thenReturn(new MetricsRegistryMap()); |
| } |
| |
| @After |
| public void tearDown() { |
| BaseTestFunction.reset(); |
| } |
| |
| @Test |
| public void testEmptyChain() { |
| StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { }, mock(Config.class)); |
| OperatorImplGraph opGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), context, mock(Clock.class)); |
| assertEquals(0, opGraph.getAllInputOperators().size()); |
| } |
| |
| @Test |
| public void testLinearChain() { |
| String inputStreamId = "input"; |
| String inputSystem = "input-system"; |
| String inputPhysicalName = "input-stream"; |
| String outputStreamId = "output"; |
| String outputSystem = "output-system"; |
| String outputPhysicalName = "output-stream"; |
| String intermediateSystem = "intermediate-system"; |
| |
| HashMap<String, String> configs = new HashMap<>(); |
| configs.put(JobConfig.JOB_NAME, "jobName"); |
| configs.put(JobConfig.JOB_ID, "jobId"); |
| configs.put(JobConfig.JOB_DEFAULT_SYSTEM, intermediateSystem); |
| StreamTestUtils.addStreamConfigs(configs, inputStreamId, inputSystem, inputPhysicalName); |
| StreamTestUtils.addStreamConfigs(configs, outputStreamId, outputSystem, outputPhysicalName); |
| Config config = new MapConfig(configs); |
| when(this.context.getJobContext().getConfig()).thenReturn(config); |
| |
| StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { |
| GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass"); |
| GenericInputDescriptor inputDescriptor = sd.getInputDescriptor(inputStreamId, mock(Serde.class)); |
| GenericOutputDescriptor outputDescriptor = sd.getOutputDescriptor(outputStreamId, mock(Serde.class)); |
| MessageStream<Object> inputStream = appDesc.getInputStream(inputDescriptor); |
| OutputStream<Object> outputStream = appDesc.getOutputStream(outputDescriptor); |
| |
| inputStream |
| .filter(mock(FilterFunction.class)) |
| .map(mock(MapFunction.class)) |
| .sendTo(outputStream); |
| }, config); |
| |
| OperatorImplGraph opImplGraph = |
| new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class)); |
| |
| InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName)); |
| assertEquals(1, inputOpImpl.registeredOperators.size()); |
| |
| OperatorImpl filterOpImpl = (FlatmapOperatorImpl) inputOpImpl.registeredOperators.iterator().next(); |
| assertEquals(1, filterOpImpl.registeredOperators.size()); |
| assertEquals(OpCode.FILTER, filterOpImpl.getOperatorSpec().getOpCode()); |
| |
| OperatorImpl mapOpImpl = (FlatmapOperatorImpl) filterOpImpl.registeredOperators.iterator().next(); |
| assertEquals(1, mapOpImpl.registeredOperators.size()); |
| assertEquals(OpCode.MAP, mapOpImpl.getOperatorSpec().getOpCode()); |
| |
| OperatorImpl sendToOpImpl = (OutputOperatorImpl) mapOpImpl.registeredOperators.iterator().next(); |
| assertEquals(0, sendToOpImpl.registeredOperators.size()); |
| assertEquals(OpCode.SEND_TO, sendToOpImpl.getOperatorSpec().getOpCode()); |
| } |
| |
| @Test |
| public void testPartitionByChain() { |
| String inputStreamId = "input"; |
| String inputSystem = "input-system"; |
| String inputPhysicalName = "input-stream"; |
| String outputStreamId = "output"; |
| String outputSystem = "output-system"; |
| String outputPhysicalName = "output-stream"; |
| String intermediateStreamId = "jobName-jobId-partition_by-p1"; |
| String intermediateSystem = "intermediate-system"; |
| |
| HashMap<String, String> configs = new HashMap<>(); |
| configs.put(JobConfig.JOB_NAME, "jobName"); |
| configs.put(JobConfig.JOB_ID, "jobId"); |
| configs.put(JobConfig.JOB_DEFAULT_SYSTEM, intermediateSystem); |
| StreamTestUtils.addStreamConfigs(configs, inputStreamId, inputSystem, inputPhysicalName); |
| StreamTestUtils.addStreamConfigs(configs, outputStreamId, outputSystem, outputPhysicalName); |
| Config config = new MapConfig(configs); |
| when(this.context.getJobContext().getConfig()).thenReturn(config); |
| |
| StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { |
| GenericSystemDescriptor isd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass"); |
| GenericSystemDescriptor osd = new GenericSystemDescriptor(outputSystem, "mockFactoryClass"); |
| GenericInputDescriptor inputDescriptor = isd.getInputDescriptor(inputStreamId, mock(Serde.class)); |
| GenericOutputDescriptor outputDescriptor = osd.getOutputDescriptor(outputStreamId, |
| KVSerde.of(mock(IntegerSerde.class), mock(StringSerde.class))); |
| MessageStream<Object> inputStream = appDesc.getInputStream(inputDescriptor); |
| OutputStream<KV<Integer, String>> outputStream = appDesc.getOutputStream(outputDescriptor); |
| |
| inputStream |
| .partitionBy(Object::hashCode, Object::toString, |
| KVSerde.of(mock(IntegerSerde.class), mock(StringSerde.class)), "p1") |
| .sendTo(outputStream); |
| }, config); |
| |
| JobModel jobModel = mock(JobModel.class); |
| ContainerModel containerModel = mock(ContainerModel.class); |
| TaskModel taskModel = mock(TaskModel.class); |
| when(jobModel.getContainers()).thenReturn(Collections.singletonMap("0", containerModel)); |
| when(containerModel.getTasks()).thenReturn(Collections.singletonMap(new TaskName("task 0"), taskModel)); |
| when(taskModel.getSystemStreamPartitions()).thenReturn(Collections.emptySet()); |
| when(((TaskContextImpl) this.context.getTaskContext()).getJobModel()).thenReturn(jobModel); |
| OperatorImplGraph opImplGraph = |
| new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class)); |
| |
| InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName)); |
| assertEquals(1, inputOpImpl.registeredOperators.size()); |
| |
| OperatorImpl partitionByOpImpl = (PartitionByOperatorImpl) inputOpImpl.registeredOperators.iterator().next(); |
| assertEquals(0, partitionByOpImpl.registeredOperators.size()); // is terminal but paired with an input operator |
| assertEquals(OpCode.PARTITION_BY, partitionByOpImpl.getOperatorSpec().getOpCode()); |
| |
| InputOperatorImpl repartitionedInputOpImpl = |
| opImplGraph.getInputOperator(new SystemStream(intermediateSystem, intermediateStreamId)); |
| assertEquals(1, repartitionedInputOpImpl.registeredOperators.size()); |
| |
| OperatorImpl sendToOpImpl = (OutputOperatorImpl) repartitionedInputOpImpl.registeredOperators.iterator().next(); |
| assertEquals(0, sendToOpImpl.registeredOperators.size()); |
| assertEquals(OpCode.SEND_TO, sendToOpImpl.getOperatorSpec().getOpCode()); |
| } |
| |
| @Test |
| public void testBroadcastChain() { |
| String inputStreamId = "input"; |
| String inputSystem = "input-system"; |
| String inputPhysicalName = "input-stream"; |
| HashMap<String, String> configMap = new HashMap<>(); |
| configMap.put(JobConfig.JOB_NAME, "test-job"); |
| configMap.put(JobConfig.JOB_ID, "1"); |
| StreamTestUtils.addStreamConfigs(configMap, inputStreamId, inputSystem, inputPhysicalName); |
| Config config = new MapConfig(configMap); |
| when(this.context.getJobContext().getConfig()).thenReturn(config); |
| StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { |
| GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass"); |
| GenericInputDescriptor inputDescriptor = sd.getInputDescriptor(inputStreamId, mock(Serde.class)); |
| MessageStream<Object> inputStream = appDesc.getInputStream(inputDescriptor); |
| inputStream.filter(mock(FilterFunction.class)); |
| inputStream.map(mock(MapFunction.class)); |
| }, config); |
| |
| OperatorImplGraph opImplGraph = |
| new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class)); |
| |
| InputOperatorImpl inputOpImpl = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName)); |
| assertEquals(2, inputOpImpl.registeredOperators.size()); |
| assertTrue(inputOpImpl.registeredOperators.stream() |
| .anyMatch(opImpl -> ((OperatorImpl) opImpl).getOperatorSpec().getOpCode() == OpCode.FILTER)); |
| assertTrue(inputOpImpl.registeredOperators.stream() |
| .anyMatch(opImpl -> ((OperatorImpl) opImpl).getOperatorSpec().getOpCode() == OpCode.MAP)); |
| } |
| |
| @Test |
| public void testMergeChain() { |
| String inputStreamId = "input"; |
| String inputSystem = "input-system"; |
| StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { |
| GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass"); |
| GenericInputDescriptor inputDescriptor = sd.getInputDescriptor(inputStreamId, mock(Serde.class)); |
| MessageStream<Object> inputStream = appDesc.getInputStream(inputDescriptor); |
| MessageStream<Object> stream1 = inputStream.filter(mock(FilterFunction.class)); |
| MessageStream<Object> stream2 = inputStream.map(mock(MapFunction.class)); |
| stream1.merge(Collections.singleton(stream2)) |
| .map(new TestMapFunction<Object, Object>("test-map-1", (Function & Serializable) m -> m)); |
| }, getConfig()); |
| |
| TaskName mockTaskName = mock(TaskName.class); |
| TaskModel taskModel = mock(TaskModel.class); |
| when(taskModel.getTaskName()).thenReturn(mockTaskName); |
| when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel); |
| |
| OperatorImplGraph opImplGraph = |
| new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class)); |
| |
| Set<OperatorImpl> opSet = opImplGraph.getAllInputOperators().stream().collect(HashSet::new, |
| (s, op) -> addOperatorRecursively(s, op), HashSet::addAll); |
| Object[] mergeOps = opSet.stream().filter(op -> op.getOperatorSpec().getOpCode() == OpCode.MERGE).toArray(); |
| assertEquals(1, mergeOps.length); |
| assertEquals(1, ((OperatorImpl) mergeOps[0]).registeredOperators.size()); |
| OperatorImpl mapOp = (OperatorImpl) ((OperatorImpl) mergeOps[0]).registeredOperators.iterator().next(); |
| assertEquals(mapOp.getOperatorSpec().getOpCode(), OpCode.MAP); |
| |
| // verify that the DAG after merge is only traversed & initialized once |
| assertEquals(TestMapFunction.getInstanceByTaskName(mockTaskName, "test-map-1").numInitCalled, 1); |
| } |
| |
| @Test |
| public void testJoinChain() { |
| String inputStreamId1 = "input1"; |
| String inputStreamId2 = "input2"; |
| String inputSystem = "input-system"; |
| String inputPhysicalName1 = "input-stream1"; |
| String inputPhysicalName2 = "input-stream2"; |
| HashMap<String, String> configs = new HashMap<>(); |
| configs.put(JobConfig.JOB_NAME, "jobName"); |
| configs.put(JobConfig.JOB_ID, "jobId"); |
| StreamTestUtils.addStreamConfigs(configs, inputStreamId1, inputSystem, inputPhysicalName1); |
| StreamTestUtils.addStreamConfigs(configs, inputStreamId2, inputSystem, inputPhysicalName2); |
| Config config = new MapConfig(configs); |
| when(this.context.getJobContext().getConfig()).thenReturn(config); |
| |
| Integer joinKey = new Integer(1); |
| Function<Object, Integer> keyFn = (Function & Serializable) m -> joinKey; |
| JoinFunction testJoinFunction = new TestJoinFunction("jobName-jobId-join-j1", |
| (BiFunction & Serializable) (m1, m2) -> KV.of(m1, m2), keyFn, keyFn); |
| |
| StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { |
| GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass"); |
| GenericInputDescriptor inputDescriptor1 = sd.getInputDescriptor(inputStreamId1, mock(Serde.class)); |
| GenericInputDescriptor inputDescriptor2 = sd.getInputDescriptor(inputStreamId2, mock(Serde.class)); |
| MessageStream<Object> inputStream1 = appDesc.getInputStream(inputDescriptor1); |
| MessageStream<Object> inputStream2 = appDesc.getInputStream(inputDescriptor2); |
| |
| inputStream1.join(inputStream2, testJoinFunction, |
| mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j1"); |
| }, config); |
| |
| TaskName mockTaskName = mock(TaskName.class); |
| TaskModel taskModel = mock(TaskModel.class); |
| when(taskModel.getTaskName()).thenReturn(mockTaskName); |
| when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel); |
| |
| KeyValueStore mockLeftStore = mock(KeyValueStore.class); |
| when(this.context.getTaskContext().getStore(eq("jobName-jobId-join-j1-L"))).thenReturn(mockLeftStore); |
| KeyValueStore mockRightStore = mock(KeyValueStore.class); |
| when(this.context.getTaskContext().getStore(eq("jobName-jobId-join-j1-R"))).thenReturn(mockRightStore); |
| OperatorImplGraph opImplGraph = |
| new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, mock(Clock.class)); |
| |
| // verify that join function is initialized once. |
| assertEquals(TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1").numInitCalled, 1); |
| |
| InputOperatorImpl inputOpImpl1 = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName1)); |
| InputOperatorImpl inputOpImpl2 = opImplGraph.getInputOperator(new SystemStream(inputSystem, inputPhysicalName2)); |
| PartialJoinOperatorImpl leftPartialJoinOpImpl = |
| (PartialJoinOperatorImpl) inputOpImpl1.registeredOperators.iterator().next(); |
| PartialJoinOperatorImpl rightPartialJoinOpImpl = |
| (PartialJoinOperatorImpl) inputOpImpl2.registeredOperators.iterator().next(); |
| |
| assertEquals(leftPartialJoinOpImpl.getOperatorSpec(), rightPartialJoinOpImpl.getOperatorSpec()); |
| assertNotSame(leftPartialJoinOpImpl, rightPartialJoinOpImpl); |
| |
| // verify that left partial join operator calls getFirstKey |
| Object mockLeftMessage = mock(Object.class); |
| long currentTimeMillis = System.currentTimeMillis(); |
| when(mockLeftStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockLeftMessage, currentTimeMillis)); |
| IncomingMessageEnvelope leftMessage = new IncomingMessageEnvelope(mock(SystemStreamPartition.class), "", "", mockLeftMessage); |
| inputOpImpl1.onMessage(leftMessage, mock(MessageCollector.class), mock(TaskCoordinator.class)); |
| |
| // verify that right partial join operator calls getSecondKey |
| Object mockRightMessage = mock(Object.class); |
| when(mockRightStore.get(eq(joinKey))).thenReturn(new TimestampedValue<>(mockRightMessage, currentTimeMillis)); |
| IncomingMessageEnvelope rightMessage = new IncomingMessageEnvelope(mock(SystemStreamPartition.class), "", "", mockRightMessage); |
| inputOpImpl2.onMessage(rightMessage, mock(MessageCollector.class), mock(TaskCoordinator.class)); |
| |
| |
| // verify that the join function apply is called with the correct messages on match |
| assertEquals(((TestJoinFunction) TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1")).joinResults.size(), 1); |
| KV joinResult = (KV) ((TestJoinFunction) TestJoinFunction.getInstanceByTaskName(mockTaskName, "jobName-jobId-join-j1")).joinResults.iterator().next(); |
| assertEquals(joinResult.getKey(), mockLeftMessage); |
| assertEquals(joinResult.getValue(), mockRightMessage); |
| } |
| |
| @Test |
| public void testOperatorGraphInitAndClose() { |
| String inputStreamId1 = "input1"; |
| String inputStreamId2 = "input2"; |
| String inputSystem = "input-system"; |
| |
| TaskName mockTaskName = mock(TaskName.class); |
| TaskModel taskModel = mock(TaskModel.class); |
| when(taskModel.getTaskName()).thenReturn(mockTaskName); |
| when(this.context.getTaskContext().getTaskModel()).thenReturn(taskModel); |
| |
| StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { |
| GenericSystemDescriptor sd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass"); |
| GenericInputDescriptor inputDescriptor1 = sd.getInputDescriptor(inputStreamId1, mock(Serde.class)); |
| GenericInputDescriptor inputDescriptor2 = sd.getInputDescriptor(inputStreamId2, mock(Serde.class)); |
| MessageStream<Object> inputStream1 = appDesc.getInputStream(inputDescriptor1); |
| MessageStream<Object> inputStream2 = appDesc.getInputStream(inputDescriptor2); |
| |
| Function mapFn = (Function & Serializable) m -> m; |
| inputStream1.map(new TestMapFunction<Object, Object>("1", mapFn)) |
| .map(new TestMapFunction<Object, Object>("2", mapFn)); |
| |
| inputStream2.map(new TestMapFunction<Object, Object>("3", mapFn)) |
| .map(new TestMapFunction<Object, Object>("4", mapFn)); |
| }, getConfig()); |
| |
| OperatorImplGraph opImplGraph = new OperatorImplGraph(graphSpec.getOperatorSpecGraph(), this.context, SystemClock.instance()); |
| |
| List<String> initializedOperators = BaseTestFunction.getInitListByTaskName(mockTaskName); |
| |
| // Assert that initialization occurs in topological order. |
| assertEquals(initializedOperators.get(0), "1"); |
| assertEquals(initializedOperators.get(1), "2"); |
| assertEquals(initializedOperators.get(2), "3"); |
| assertEquals(initializedOperators.get(3), "4"); |
| |
| // Assert that finalization occurs in reverse topological order. |
| opImplGraph.close(); |
| List<String> closedOperators = BaseTestFunction.getCloseListByTaskName(mockTaskName); |
| assertEquals(closedOperators.get(0), "4"); |
| assertEquals(closedOperators.get(1), "3"); |
| assertEquals(closedOperators.get(2), "2"); |
| assertEquals(closedOperators.get(3), "1"); |
| } |
| |
| @Test |
| public void testGetStreamToConsumerTasks() { |
| String system = "test-system"; |
| String streamId0 = "test-stream-0"; |
| String streamId1 = "test-stream-1"; |
| |
| HashMap<String, String> configs = new HashMap<>(); |
| configs.put(JobConfig.JOB_NAME, "test-app"); |
| configs.put(JobConfig.JOB_DEFAULT_SYSTEM, "test-system"); |
| StreamTestUtils.addStreamConfigs(configs, streamId0, system, streamId0); |
| StreamTestUtils.addStreamConfigs(configs, streamId1, system, streamId1); |
| Config config = new MapConfig(configs); |
| when(this.context.getJobContext().getConfig()).thenReturn(config); |
| |
| SystemStreamPartition ssp0 = new SystemStreamPartition(system, streamId0, new Partition(0)); |
| SystemStreamPartition ssp1 = new SystemStreamPartition(system, streamId0, new Partition(1)); |
| SystemStreamPartition ssp2 = new SystemStreamPartition(system, streamId1, new Partition(0)); |
| |
| TaskName task0 = new TaskName("Task 0"); |
| TaskName task1 = new TaskName("Task 1"); |
| Set<SystemStreamPartition> ssps = new HashSet<>(); |
| ssps.add(ssp0); |
| ssps.add(ssp2); |
| TaskModel tm0 = new TaskModel(task0, ssps, new Partition(0)); |
| ContainerModel cm0 = new ContainerModel("c0", Collections.singletonMap(task0, tm0)); |
| TaskModel tm1 = new TaskModel(task1, Collections.singleton(ssp1), new Partition(1)); |
| ContainerModel cm1 = new ContainerModel("c1", Collections.singletonMap(task1, tm1)); |
| |
| Map<String, ContainerModel> cms = new HashMap<>(); |
| cms.put(cm0.getId(), cm0); |
| cms.put(cm1.getId(), cm1); |
| |
| JobModel jobModel = new JobModel(config, cms); |
| Multimap<SystemStream, String> streamToTasks = OperatorImplGraph.getStreamToConsumerTasks(jobModel); |
| assertEquals(streamToTasks.get(ssp0.getSystemStream()).size(), 2); |
| assertEquals(streamToTasks.get(ssp2.getSystemStream()).size(), 1); |
| } |
| |
| @Test |
| public void testGetOutputToInputStreams() { |
| String inputStreamId1 = "input1"; |
| String inputStreamId2 = "input2"; |
| String inputStreamId3 = "input3"; |
| String inputSystem = "input-system"; |
| |
| String outputStreamId1 = "output1"; |
| String outputStreamId2 = "output2"; |
| String outputSystem = "output-system"; |
| |
| String intStreamId1 = "test-app-1-partition_by-p1"; |
| String intStreamId2 = "test-app-1-partition_by-p2"; |
| String intSystem = "test-system"; |
| |
| HashMap<String, String> configs = new HashMap<>(); |
| configs.put(JobConfig.JOB_NAME, "test-app"); |
| configs.put(JobConfig.JOB_DEFAULT_SYSTEM, intSystem); |
| StreamTestUtils.addStreamConfigs(configs, inputStreamId1, inputSystem, inputStreamId1); |
| StreamTestUtils.addStreamConfigs(configs, inputStreamId2, inputSystem, inputStreamId2); |
| StreamTestUtils.addStreamConfigs(configs, inputStreamId3, inputSystem, inputStreamId3); |
| StreamTestUtils.addStreamConfigs(configs, outputStreamId1, outputSystem, outputStreamId1); |
| StreamTestUtils.addStreamConfigs(configs, outputStreamId2, outputSystem, outputStreamId2); |
| Config config = new MapConfig(configs); |
| when(this.context.getJobContext().getConfig()).thenReturn(config); |
| |
| StreamApplicationDescriptorImpl graphSpec = new StreamApplicationDescriptorImpl(appDesc -> { |
| GenericSystemDescriptor isd = new GenericSystemDescriptor(inputSystem, "mockFactoryClass"); |
| GenericInputDescriptor inputDescriptor1 = isd.getInputDescriptor(inputStreamId1, mock(Serde.class)); |
| GenericInputDescriptor inputDescriptor2 = isd.getInputDescriptor(inputStreamId2, mock(Serde.class)); |
| GenericInputDescriptor inputDescriptor3 = isd.getInputDescriptor(inputStreamId3, mock(Serde.class)); |
| GenericSystemDescriptor osd = new GenericSystemDescriptor(outputSystem, "mockFactoryClass"); |
| GenericOutputDescriptor outputDescriptor1 = osd.getOutputDescriptor(outputStreamId1, mock(Serde.class)); |
| GenericOutputDescriptor outputDescriptor2 = osd.getOutputDescriptor(outputStreamId2, mock(Serde.class)); |
| MessageStream messageStream1 = appDesc.getInputStream(inputDescriptor1).map(m -> m); |
| MessageStream messageStream2 = appDesc.getInputStream(inputDescriptor2).filter(m -> true); |
| MessageStream messageStream3 = |
| appDesc.getInputStream(inputDescriptor3) |
| .filter(m -> true) |
| .partitionBy(m -> "m", m -> m, mock(KVSerde.class), "p1") |
| .map(m -> m); |
| OutputStream<Object> outputStream1 = appDesc.getOutputStream(outputDescriptor1); |
| OutputStream<Object> outputStream2 = appDesc.getOutputStream(outputDescriptor2); |
| |
| messageStream1 |
| .join(messageStream2, mock(JoinFunction.class), |
| mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(2), "j1") |
| .partitionBy(m -> "m", m -> m, mock(KVSerde.class), "p2") |
| .sendTo(outputStream1); |
| messageStream3 |
| .join(messageStream2, mock(JoinFunction.class), |
| mock(Serde.class), mock(Serde.class), mock(Serde.class), Duration.ofHours(1), "j2") |
| .sendTo(outputStream2); |
| }, config); |
| |
| Multimap<SystemStream, SystemStream> outputToInput = |
| OperatorImplGraph.getIntermediateToInputStreamsMap(graphSpec.getOperatorSpecGraph(), new StreamConfig(config)); |
| Collection<SystemStream> inputs = outputToInput.get(new SystemStream(intSystem, intStreamId2)); |
| assertEquals(inputs.size(), 2); |
| assertTrue(inputs.contains(new SystemStream(inputSystem, inputStreamId1))); |
| assertTrue(inputs.contains(new SystemStream(inputSystem, inputStreamId2))); |
| |
| inputs = outputToInput.get(new SystemStream(intSystem, intStreamId1)); |
| assertEquals(inputs.size(), 1); |
| assertEquals(inputs.iterator().next(), new SystemStream(inputSystem, inputStreamId3)); |
| } |
| |
| @Test |
| public void testGetProducerTaskCountForIntermediateStreams() { |
| String inputStreamId1 = "input1"; |
| String inputStreamId2 = "input2"; |
| String inputStreamId3 = "input3"; |
| String inputSystem1 = "system1"; |
| String inputSystem2 = "system2"; |
| |
| SystemStream input1 = new SystemStream("system1", "intput1"); |
| SystemStream input2 = new SystemStream("system2", "intput2"); |
| SystemStream input3 = new SystemStream("system2", "intput3"); |
| |
| SystemStream int1 = new SystemStream("system1", "int1"); |
| SystemStream int2 = new SystemStream("system1", "int2"); |
| |
| |
| /** |
| * the task assignment looks like the following: |
| * |
| * input1 -----> task0, task1 -----> int1 |
| * ^ |
| * input2 ------> task1, task2--------| |
| * v |
| * input3 ------> task1 -----------> int2 |
| * |
| */ |
| String task0 = "Task 0"; |
| String task1 = "Task 1"; |
| String task2 = "Task 2"; |
| |
| Multimap<SystemStream, String> streamToConsumerTasks = HashMultimap.create(); |
| streamToConsumerTasks.put(input1, task0); |
| streamToConsumerTasks.put(input1, task1); |
| streamToConsumerTasks.put(input2, task1); |
| streamToConsumerTasks.put(input2, task2); |
| streamToConsumerTasks.put(input3, task1); |
| streamToConsumerTasks.put(int1, task0); |
| streamToConsumerTasks.put(int1, task1); |
| streamToConsumerTasks.put(int2, task0); |
| |
| Multimap<SystemStream, SystemStream> intermediateToInputStreams = HashMultimap.create(); |
| intermediateToInputStreams.put(int1, input1); |
| intermediateToInputStreams.put(int1, input2); |
| |
| intermediateToInputStreams.put(int2, input2); |
| intermediateToInputStreams.put(int2, input3); |
| |
| Map<SystemStream, Integer> counts = OperatorImplGraph.getProducerTaskCountForIntermediateStreams( |
| streamToConsumerTasks, intermediateToInputStreams); |
| assertTrue(counts.get(int1) == 3); |
| assertTrue(counts.get(int2) == 2); |
| } |
| |
| private void addOperatorRecursively(HashSet<OperatorImpl> s, OperatorImpl op) { |
| List<OperatorImpl> operators = new ArrayList<>(); |
| operators.add(op); |
| while (!operators.isEmpty()) { |
| OperatorImpl opImpl = operators.remove(0); |
| s.add(opImpl); |
| if (!opImpl.registeredOperators.isEmpty()) { |
| operators.addAll(opImpl.registeredOperators); |
| } |
| } |
| } |
| |
| private Config getConfig() { |
| HashMap<String, String> configMap = new HashMap<>(); |
| configMap.put(JobConfig.JOB_NAME, "test-job"); |
| configMap.put(JobConfig.JOB_ID, "1"); |
| return new MapConfig(configMap); |
| } |
| |
| private static class TestMapFunction<M, OM> extends BaseTestFunction implements MapFunction<M, OM> { |
| final Function<M, OM> mapFn; |
| |
| public TestMapFunction(String opId, Function<M, OM> mapFn) { |
| super(opId); |
| this.mapFn = mapFn; |
| } |
| |
| @Override |
| public OM apply(M message) { |
| return this.mapFn.apply(message); |
| } |
| } |
| |
| private static class TestJoinFunction<K, M, JM, RM> extends BaseTestFunction implements JoinFunction<K, M, JM, RM> { |
| final BiFunction<M, JM, RM> joiner; |
| final Function<M, K> firstKeyFn; |
| final Function<JM, K> secondKeyFn; |
| final Collection<RM> joinResults = new HashSet<>(); |
| |
| public TestJoinFunction(String opId, BiFunction<M, JM, RM> joiner, Function<M, K> firstKeyFn, Function<JM, K> secondKeyFn) { |
| super(opId); |
| this.joiner = joiner; |
| this.firstKeyFn = firstKeyFn; |
| this.secondKeyFn = secondKeyFn; |
| } |
| |
| @Override |
| public RM apply(M message, JM otherMessage) { |
| RM result = this.joiner.apply(message, otherMessage); |
| this.joinResults.add(result); |
| return result; |
| } |
| |
| @Override |
| public K getFirstKey(M message) { |
| return this.firstKeyFn.apply(message); |
| } |
| |
| @Override |
| public K getSecondKey(JM message) { |
| return this.secondKeyFn.apply(message); |
| } |
| } |
| |
| private static abstract class BaseTestFunction implements InitableFunction, ClosableFunction, Serializable { |
| static Map<TaskName, Map<String, BaseTestFunction>> perTaskFunctionMap = new HashMap<>(); |
| static Map<TaskName, List<String>> perTaskInitList = new HashMap<>(); |
| static Map<TaskName, List<String>> perTaskCloseList = new HashMap<>(); |
| int numInitCalled = 0; |
| int numCloseCalled = 0; |
| TaskName taskName = null; |
| final String opId; |
| |
| public BaseTestFunction(String opId) { |
| this.opId = opId; |
| } |
| |
| static public void reset() { |
| perTaskFunctionMap.clear(); |
| perTaskCloseList.clear(); |
| perTaskInitList.clear(); |
| } |
| |
| static public BaseTestFunction getInstanceByTaskName(TaskName taskName, String opId) { |
| return perTaskFunctionMap.get(taskName).get(opId); |
| } |
| |
| static public List<String> getInitListByTaskName(TaskName taskName) { |
| return perTaskInitList.get(taskName); |
| } |
| |
| static public List<String> getCloseListByTaskName(TaskName taskName) { |
| return perTaskCloseList.get(taskName); |
| } |
| |
| @Override |
| public void close() { |
| if (this.taskName == null) { |
| throw new IllegalStateException("Close called before init"); |
| } |
| if (perTaskFunctionMap.get(this.taskName) == null || !perTaskFunctionMap.get(this.taskName).containsKey(opId)) { |
| throw new IllegalStateException("Close called before init"); |
| } |
| |
| if (perTaskCloseList.get(this.taskName) == null) { |
| perTaskCloseList.put(taskName, new ArrayList<>(Collections.singletonList(opId))); |
| } else { |
| perTaskCloseList.get(taskName).add(opId); |
| } |
| |
| this.numCloseCalled++; |
| } |
| |
| @Override |
| public void init(Context context) { |
| TaskName taskName = context.getTaskContext().getTaskModel().getTaskName(); |
| if (perTaskFunctionMap.get(taskName) == null) { |
| perTaskFunctionMap.put(taskName, new HashMap<>(Collections.singletonMap(opId, BaseTestFunction.this))); |
| } else { |
| if (perTaskFunctionMap.get(taskName).containsKey(opId)) { |
| throw new IllegalStateException(String.format("Multiple init called for op %s in the same task instance %s", opId, this.taskName.getTaskName())); |
| } |
| perTaskFunctionMap.get(taskName).put(opId, this); |
| } |
| if (perTaskInitList.get(taskName) == null) { |
| perTaskInitList.put(taskName, new ArrayList<>(Collections.singletonList(opId))); |
| } else { |
| perTaskInitList.get(taskName).add(opId); |
| } |
| this.taskName = taskName; |
| this.numInitCalled++; |
| } |
| } |
| } |