| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.kafka.streams; |
| |
| import org.apache.kafka.clients.consumer.ConsumerRecord; |
| import org.apache.kafka.clients.producer.ProducerRecord; |
| import org.apache.kafka.common.header.Header; |
| import org.apache.kafka.common.header.Headers; |
| import org.apache.kafka.common.header.internals.RecordHeader; |
| import org.apache.kafka.common.header.internals.RecordHeaders; |
| import org.apache.kafka.common.serialization.ByteArraySerializer; |
| import org.apache.kafka.common.serialization.LongDeserializer; |
| import org.apache.kafka.common.serialization.LongSerializer; |
| import org.apache.kafka.common.serialization.Serdes; |
| import org.apache.kafka.common.serialization.Serializer; |
| import org.apache.kafka.common.serialization.StringDeserializer; |
| import org.apache.kafka.common.serialization.StringSerializer; |
| import org.apache.kafka.common.utils.Bytes; |
| import org.apache.kafka.common.utils.SystemTime; |
| import org.apache.kafka.streams.errors.TopologyException; |
| import org.apache.kafka.streams.kstream.Consumed; |
| import org.apache.kafka.streams.kstream.Materialized; |
| import org.apache.kafka.streams.processor.Processor; |
| import org.apache.kafka.streams.processor.ProcessorContext; |
| import org.apache.kafka.streams.processor.ProcessorSupplier; |
| import org.apache.kafka.streams.processor.PunctuationType; |
| import org.apache.kafka.streams.processor.Punctuator; |
| import org.apache.kafka.streams.processor.StateStore; |
| import org.apache.kafka.streams.processor.TaskId; |
| import org.apache.kafka.streams.state.KeyValueBytesStoreSupplier; |
| import org.apache.kafka.streams.state.KeyValueIterator; |
| import org.apache.kafka.streams.state.KeyValueStore; |
| import org.apache.kafka.streams.state.Stores; |
| import org.apache.kafka.streams.state.internals.KeyValueStoreBuilder; |
| import org.apache.kafka.streams.test.ConsumerRecordFactory; |
| import org.apache.kafka.streams.test.OutputVerifier; |
| import org.apache.kafka.test.TestUtils; |
| import org.junit.After; |
| import org.junit.Assert; |
| import org.junit.Test; |
| import org.junit.runner.RunWith; |
| import org.junit.runners.Parameterized; |
| |
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;
| |
| import static org.apache.kafka.common.utils.Utils.mkEntry; |
| import static org.apache.kafka.common.utils.Utils.mkMap; |
| import static org.apache.kafka.common.utils.Utils.mkProperties; |
| import static org.hamcrest.CoreMatchers.equalTo; |
| import static org.junit.Assert.assertEquals; |
| import static org.junit.Assert.assertFalse; |
| import static org.junit.Assert.assertNotNull; |
| import static org.junit.Assert.assertThat; |
| import static org.junit.Assert.assertTrue; |
| import static org.junit.Assert.fail; |
| |
| @RunWith(value = Parameterized.class) |
| public class TopologyTestDriverTest { |
| private final static String SOURCE_TOPIC_1 = "source-topic-1"; |
| private final static String SOURCE_TOPIC_2 = "source-topic-2"; |
| private final static String SINK_TOPIC_1 = "sink-topic-1"; |
| private final static String SINK_TOPIC_2 = "sink-topic-2"; |
| |
| private final ConsumerRecordFactory<byte[], byte[]> consumerRecordFactory = new ConsumerRecordFactory<>( |
| new ByteArraySerializer(), |
| new ByteArraySerializer()); |
| |
| private final Headers headers = new RecordHeaders(new Header[]{new RecordHeader("key", "value".getBytes())}); |
| |
| private final byte[] key1 = new byte[0]; |
| private final byte[] value1 = new byte[0]; |
| private final long timestamp1 = 42L; |
| private final ConsumerRecord<byte[], byte[]> consumerRecord1 = consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, headers, timestamp1); |
| |
| private final byte[] key2 = new byte[0]; |
| private final byte[] value2 = new byte[0]; |
| private final long timestamp2 = 43L; |
| private final ConsumerRecord<byte[], byte[]> consumerRecord2 = consumerRecordFactory.create(SOURCE_TOPIC_2, key2, value2, timestamp2); |
| |
| private TopologyTestDriver testDriver; |
| private final Properties config = mkProperties(mkMap( |
| mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "test-TopologyTestDriver"), |
| mkEntry(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"), |
| mkEntry(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getAbsolutePath()) |
| )); |
| private KeyValueStore<String, Long> store; |
| |
| private final StringDeserializer stringDeserializer = new StringDeserializer(); |
| private final LongDeserializer longDeserializer = new LongDeserializer(); |
| private final ConsumerRecordFactory<String, Long> recordFactory = new ConsumerRecordFactory<>( |
| new StringSerializer(), |
| new LongSerializer()); |
| |
| @Parameterized.Parameters(name = "Eos enabled = {0}") |
| public static Collection<Object[]> data() { |
| final List<Object[]> values = new ArrayList<>(); |
| for (final boolean eosEnabled : Arrays.asList(true, false)) { |
| values.add(new Object[] {eosEnabled}); |
| } |
| return values; |
| } |
| |
/**
 * Creates the test instance for one parameter set.
 *
 * @param eosEnabled when {@code true}, the driver config requests exactly-once processing
 */
public TopologyTestDriverTest(final boolean eosEnabled) {
    if (!eosEnabled) {
        return;
    }
    config.put(StreamsConfig.PROCESSING_GUARANTEE_CONFIG, StreamsConfig.EXACTLY_ONCE);
}
| |
| private final static class Record { |
| private final Object key; |
| private final Object value; |
| private final long timestamp; |
| private final long offset; |
| private final String topic; |
| private final Headers headers; |
| |
| Record(final ConsumerRecord consumerRecord, |
| final long newOffset) { |
| key = consumerRecord.key(); |
| value = consumerRecord.value(); |
| timestamp = consumerRecord.timestamp(); |
| offset = newOffset; |
| topic = consumerRecord.topic(); |
| headers = consumerRecord.headers(); |
| } |
| |
| Record(final Object key, |
| final Object value, |
| final Headers headers, |
| final long timestamp, |
| final long offset, |
| final String topic) { |
| this.key = key; |
| this.value = value; |
| this.headers = headers; |
| this.timestamp = timestamp; |
| this.offset = offset; |
| this.topic = topic; |
| } |
| |
| @Override |
| public String toString() { |
| return "key: " + key + ", value: " + value + ", timestamp: " + timestamp + ", offset: " + offset + ", topic: " + topic; |
| } |
| |
| @Override |
| public boolean equals(final Object o) { |
| if (this == o) { |
| return true; |
| } |
| if (o == null || getClass() != o.getClass()) { |
| return false; |
| } |
| final Record record = (Record) o; |
| return timestamp == record.timestamp && |
| offset == record.offset && |
| Objects.equals(key, record.key) && |
| Objects.equals(value, record.value) && |
| Objects.equals(topic, record.topic) && |
| Objects.equals(headers, record.headers); |
| } |
| |
| @Override |
| public int hashCode() { |
| return Objects.hash(key, value, headers, timestamp, offset, topic); |
| } |
| } |
| |
| private final static class Punctuation { |
| private final long intervalMs; |
| private final PunctuationType punctuationType; |
| private final Punctuator callback; |
| |
| Punctuation(final long intervalMs, |
| final PunctuationType punctuationType, |
| final Punctuator callback) { |
| this.intervalMs = intervalMs; |
| this.punctuationType = punctuationType; |
| this.callback = callback; |
| } |
| } |
| |
| private final class MockPunctuator implements Punctuator { |
| private final List<Long> punctuatedAt = new LinkedList<>(); |
| |
| @Override |
| public void punctuate(final long timestamp) { |
| punctuatedAt.add(timestamp); |
| } |
| } |
| |
| private final class MockProcessor implements Processor { |
| private final Collection<Punctuation> punctuations; |
| private ProcessorContext context; |
| |
| private boolean initialized = false; |
| private boolean closed = false; |
| private final List<Record> processedRecords = new ArrayList<>(); |
| |
| MockProcessor(final Collection<Punctuation> punctuations) { |
| this.punctuations = punctuations; |
| } |
| |
| @Override |
| public void init(final ProcessorContext context) { |
| initialized = true; |
| this.context = context; |
| for (final Punctuation punctuation : punctuations) { |
| this.context.schedule(Duration.ofMillis(punctuation.intervalMs), punctuation.punctuationType, punctuation.callback); |
| } |
| } |
| |
| @Override |
| public void process(final Object key, final Object value) { |
| processedRecords.add(new Record(key, value, context.headers(), context.timestamp(), context.offset(), context.topic())); |
| context.forward(key, value); |
| } |
| |
| @Override |
| public void close() { |
| closed = true; |
| } |
| } |
| |
| private final List<MockProcessor> mockProcessors = new ArrayList<>(); |
| |
| private final class MockProcessorSupplier implements ProcessorSupplier { |
| private final Collection<Punctuation> punctuations; |
| |
| private MockProcessorSupplier() { |
| this(Collections.emptySet()); |
| } |
| |
| private MockProcessorSupplier(final Collection<Punctuation> punctuations) { |
| this.punctuations = punctuations; |
| } |
| |
| @Override |
| public Processor get() { |
| final MockProcessor mockProcessor = new MockProcessor(punctuations); |
| mockProcessors.add(mockProcessor); |
| return mockProcessor; |
| } |
| } |
| |
| @After |
| public void tearDown() { |
| if (testDriver != null) { |
| testDriver.close(); |
| } |
| } |
| |
| private Topology setupSourceSinkTopology() { |
| final Topology topology = new Topology(); |
| |
| final String sourceName = "source"; |
| |
| topology.addSource(sourceName, SOURCE_TOPIC_1); |
| topology.addSink("sink", SINK_TOPIC_1, sourceName); |
| |
| return topology; |
| } |
| |
| private Topology setupTopologyWithTwoSubtopologies() { |
| final Topology topology = new Topology(); |
| |
| final String sourceName1 = "source-1"; |
| final String sourceName2 = "source-2"; |
| |
| topology.addSource(sourceName1, SOURCE_TOPIC_1); |
| topology.addSink("sink-1", SINK_TOPIC_1, sourceName1); |
| topology.addSource(sourceName2, SINK_TOPIC_1); |
| topology.addSink("sink-2", SINK_TOPIC_2, sourceName2); |
| |
| return topology; |
| } |
| |
| |
| private Topology setupSingleProcessorTopology() { |
| return setupSingleProcessorTopology(-1, null, null); |
| } |
| |
| private Topology setupSingleProcessorTopology(final long punctuationIntervalMs, |
| final PunctuationType punctuationType, |
| final Punctuator callback) { |
| final Collection<Punctuation> punctuations; |
| if (punctuationIntervalMs > 0 && punctuationType != null && callback != null) { |
| punctuations = Collections.singleton(new Punctuation(punctuationIntervalMs, punctuationType, callback)); |
| } else { |
| punctuations = Collections.emptySet(); |
| } |
| |
| final Topology topology = new Topology(); |
| |
| final String sourceName = "source"; |
| |
| topology.addSource(sourceName, SOURCE_TOPIC_1); |
| topology.addProcessor("processor", new MockProcessorSupplier(punctuations), sourceName); |
| |
| return topology; |
| } |
| |
| private Topology setupMultipleSourceTopology(final String... sourceTopicNames) { |
| final Topology topology = new Topology(); |
| |
| final String[] processorNames = new String[sourceTopicNames.length]; |
| int i = 0; |
| for (final String sourceTopicName : sourceTopicNames) { |
| final String sourceName = sourceTopicName + "-source"; |
| final String processorName = sourceTopicName + "-processor"; |
| topology.addSource(sourceName, sourceTopicName); |
| processorNames[i++] = processorName; |
| topology.addProcessor(processorName, new MockProcessorSupplier(), sourceName); |
| } |
| topology.addSink("sink-topic", SINK_TOPIC_1, processorNames); |
| |
| return topology; |
| } |
| |
| private Topology setupGlobalStoreTopology(final String... sourceTopicNames) { |
| if (sourceTopicNames.length == 0) { |
| throw new IllegalArgumentException("sourceTopicNames cannot be empty"); |
| } |
| final Topology topology = new Topology(); |
| |
| for (final String sourceTopicName : sourceTopicNames) { |
| topology.addGlobalStore( |
| Stores.<Bytes, byte[]>keyValueStoreBuilder(Stores.inMemoryKeyValueStore(sourceTopicName + "-globalStore"), null, null).withLoggingDisabled(), |
| sourceTopicName, |
| null, |
| null, |
| sourceTopicName, |
| sourceTopicName + "-processor", |
| new MockProcessorSupplier() |
| ); |
| } |
| |
| return topology; |
| } |
| |
| @Test |
| public void shouldInitProcessor() { |
| testDriver = new TopologyTestDriver(setupSingleProcessorTopology(), config); |
| assertTrue(mockProcessors.get(0).initialized); |
| } |
| |
| @Test |
| public void shouldCloseProcessor() { |
| testDriver = new TopologyTestDriver(setupSingleProcessorTopology(), config); |
| |
| testDriver.close(); |
| assertTrue(mockProcessors.get(0).closed); |
| // As testDriver is already closed, bypassing @After tearDown testDriver.close(). |
| testDriver = null; |
| } |
| |
| @Test |
| public void shouldThrowForUnknownTopic() { |
| final String unknownTopic = "unknownTopic"; |
| final ConsumerRecordFactory<byte[], byte[]> consumerRecordFactory = new ConsumerRecordFactory<>( |
| "unknownTopic", |
| new ByteArraySerializer(), |
| new ByteArraySerializer()); |
| |
| testDriver = new TopologyTestDriver(new Topology(), config); |
| try { |
| testDriver.pipeInput(consumerRecordFactory.create((byte[]) null)); |
| fail("Should have throw IllegalArgumentException"); |
| } catch (final IllegalArgumentException exception) { |
| assertEquals("Unknown topic: " + unknownTopic, exception.getMessage()); |
| } |
| } |
| |
| @Test |
| public void shouldProcessRecordForTopic() { |
| testDriver = new TopologyTestDriver(setupSourceSinkTopology(), config); |
| |
| testDriver.pipeInput(consumerRecord1); |
| final ProducerRecord outputRecord = testDriver.readOutput(SINK_TOPIC_1); |
| |
| assertEquals(key1, outputRecord.key()); |
| assertEquals(value1, outputRecord.value()); |
| assertEquals(SINK_TOPIC_1, outputRecord.topic()); |
| } |
| |
| @Test |
| public void shouldSetRecordMetadata() { |
| testDriver = new TopologyTestDriver(setupSingleProcessorTopology(), config); |
| |
| testDriver.pipeInput(consumerRecord1); |
| |
| final List<Record> processedRecords = mockProcessors.get(0).processedRecords; |
| assertEquals(1, processedRecords.size()); |
| |
| final Record record = processedRecords.get(0); |
| final Record expectedResult = new Record(consumerRecord1, 0L); |
| |
| assertThat(record, equalTo(expectedResult)); |
| } |
| |
| @Test |
| public void shouldSendRecordViaCorrectSourceTopic() { |
| testDriver = new TopologyTestDriver(setupMultipleSourceTopology(SOURCE_TOPIC_1, SOURCE_TOPIC_2), config); |
| |
| final List<Record> processedRecords1 = mockProcessors.get(0).processedRecords; |
| final List<Record> processedRecords2 = mockProcessors.get(1).processedRecords; |
| |
| testDriver.pipeInput(consumerRecord1); |
| |
| assertEquals(1, processedRecords1.size()); |
| assertEquals(0, processedRecords2.size()); |
| |
| Record record = processedRecords1.get(0); |
| Record expectedResult = new Record(consumerRecord1, 0L); |
| assertThat(record, equalTo(expectedResult)); |
| |
| testDriver.pipeInput(consumerRecord2); |
| |
| assertEquals(1, processedRecords1.size()); |
| assertEquals(1, processedRecords2.size()); |
| |
| record = processedRecords2.get(0); |
| expectedResult = new Record(consumerRecord2, 0L); |
| assertThat(record, equalTo(expectedResult)); |
| } |
| |
| @Test |
| public void shouldUseSourceSpecificDeserializers() { |
| final Topology topology = new Topology(); |
| |
| final String sourceName1 = "source-1"; |
| final String sourceName2 = "source-2"; |
| final String processor = "processor"; |
| |
| topology.addSource(sourceName1, Serdes.Long().deserializer(), Serdes.String().deserializer(), SOURCE_TOPIC_1); |
| topology.addSource(sourceName2, Serdes.Integer().deserializer(), Serdes.Double().deserializer(), SOURCE_TOPIC_2); |
| topology.addProcessor(processor, new MockProcessorSupplier(), sourceName1, sourceName2); |
| topology.addSink( |
| "sink", |
| SINK_TOPIC_1, |
| new Serializer<Object>() { |
| @Override |
| public byte[] serialize(final String topic, final Object data) { |
| if (data instanceof Long) { |
| return Serdes.Long().serializer().serialize(topic, (Long) data); |
| } |
| return Serdes.Integer().serializer().serialize(topic, (Integer) data); |
| } |
| @Override |
| public void close() {} |
| @Override |
| public void configure(final Map configs, final boolean isKey) {} |
| }, |
| new Serializer<Object>() { |
| @Override |
| public byte[] serialize(final String topic, final Object data) { |
| if (data instanceof String) { |
| return Serdes.String().serializer().serialize(topic, (String) data); |
| } |
| return Serdes.Double().serializer().serialize(topic, (Double) data); |
| } |
| @Override |
| public void close() {} |
| @Override |
| public void configure(final Map configs, final boolean isKey) {} |
| }, |
| processor); |
| |
| testDriver = new TopologyTestDriver(topology, config); |
| |
| final ConsumerRecordFactory<Long, String> source1Factory = new ConsumerRecordFactory<>( |
| SOURCE_TOPIC_1, |
| Serdes.Long().serializer(), |
| Serdes.String().serializer()); |
| final ConsumerRecordFactory<Integer, Double> source2Factory = new ConsumerRecordFactory<>( |
| SOURCE_TOPIC_2, |
| Serdes.Integer().serializer(), |
| Serdes.Double().serializer()); |
| |
| final Long source1Key = 42L; |
| final String source1Value = "anyString"; |
| final Integer source2Key = 73; |
| final Double source2Value = 3.14; |
| |
| final ConsumerRecord<byte[], byte[]> consumerRecord1 = source1Factory.create(source1Key, source1Value); |
| final ConsumerRecord<byte[], byte[]> consumerRecord2 = source2Factory.create(source2Key, source2Value); |
| |
| testDriver.pipeInput(consumerRecord1); |
| OutputVerifier.compareKeyValue( |
| testDriver.readOutput(SINK_TOPIC_1, Serdes.Long().deserializer(), Serdes.String().deserializer()), |
| source1Key, |
| source1Value); |
| |
| testDriver.pipeInput(consumerRecord2); |
| OutputVerifier.compareKeyValue( |
| testDriver.readOutput(SINK_TOPIC_1, Serdes.Integer().deserializer(), Serdes.Double().deserializer()), |
| source2Key, |
| source2Value); |
| } |
| |
| @Test |
| public void shouldUseSinkSpecificSerializers() { |
| final Topology topology = new Topology(); |
| |
| final String sourceName1 = "source-1"; |
| final String sourceName2 = "source-2"; |
| |
| topology.addSource(sourceName1, Serdes.Long().deserializer(), Serdes.String().deserializer(), SOURCE_TOPIC_1); |
| topology.addSource(sourceName2, Serdes.Integer().deserializer(), Serdes.Double().deserializer(), SOURCE_TOPIC_2); |
| topology.addSink("sink-1", SINK_TOPIC_1, Serdes.Long().serializer(), Serdes.String().serializer(), sourceName1); |
| topology.addSink("sink-2", SINK_TOPIC_2, Serdes.Integer().serializer(), Serdes.Double().serializer(), sourceName2); |
| |
| testDriver = new TopologyTestDriver(topology, config); |
| |
| final ConsumerRecordFactory<Long, String> source1Factory = new ConsumerRecordFactory<>( |
| SOURCE_TOPIC_1, |
| Serdes.Long().serializer(), |
| Serdes.String().serializer()); |
| final ConsumerRecordFactory<Integer, Double> source2Factory = new ConsumerRecordFactory<>( |
| SOURCE_TOPIC_2, |
| Serdes.Integer().serializer(), |
| Serdes.Double().serializer()); |
| |
| final Long source1Key = 42L; |
| final String source1Value = "anyString"; |
| final Integer source2Key = 73; |
| final Double source2Value = 3.14; |
| |
| final ConsumerRecord<byte[], byte[]> consumerRecord1 = source1Factory.create(source1Key, source1Value); |
| final ConsumerRecord<byte[], byte[]> consumerRecord2 = source2Factory.create(source2Key, source2Value); |
| |
| testDriver.pipeInput(consumerRecord1); |
| OutputVerifier.compareKeyValue( |
| testDriver.readOutput(SINK_TOPIC_1, Serdes.Long().deserializer(), Serdes.String().deserializer()), |
| source1Key, |
| source1Value); |
| |
| testDriver.pipeInput(consumerRecord2); |
| OutputVerifier.compareKeyValue( |
| testDriver.readOutput(SINK_TOPIC_2, Serdes.Integer().deserializer(), Serdes.Double().deserializer()), |
| source2Key, |
| source2Value); |
| } |
| |
| @Test |
| public void shouldProcessConsumerRecordList() { |
| testDriver = new TopologyTestDriver(setupMultipleSourceTopology(SOURCE_TOPIC_1, SOURCE_TOPIC_2), config); |
| |
| final List<Record> processedRecords1 = mockProcessors.get(0).processedRecords; |
| final List<Record> processedRecords2 = mockProcessors.get(1).processedRecords; |
| |
| final List<ConsumerRecord<byte[], byte[]>> testRecords = new ArrayList<>(2); |
| testRecords.add(consumerRecord1); |
| testRecords.add(consumerRecord2); |
| |
| testDriver.pipeInput(testRecords); |
| |
| assertEquals(1, processedRecords1.size()); |
| assertEquals(1, processedRecords2.size()); |
| |
| Record record = processedRecords1.get(0); |
| Record expectedResult = new Record(consumerRecord1, 0L); |
| assertThat(record, equalTo(expectedResult)); |
| |
| record = processedRecords2.get(0); |
| expectedResult = new Record(consumerRecord2, 0L); |
| assertThat(record, equalTo(expectedResult)); |
| } |
| |
| @Test |
| public void shouldForwardRecordsFromSubtopologyToSubtopology() { |
| testDriver = new TopologyTestDriver(setupTopologyWithTwoSubtopologies(), config); |
| |
| testDriver.pipeInput(consumerRecord1); |
| |
| ProducerRecord outputRecord = testDriver.readOutput(SINK_TOPIC_1); |
| assertEquals(key1, outputRecord.key()); |
| assertEquals(value1, outputRecord.value()); |
| assertEquals(SINK_TOPIC_1, outputRecord.topic()); |
| |
| outputRecord = testDriver.readOutput(SINK_TOPIC_2); |
| assertEquals(key1, outputRecord.key()); |
| assertEquals(value1, outputRecord.value()); |
| assertEquals(SINK_TOPIC_2, outputRecord.topic()); |
| } |
| |
| @Test |
| public void shouldPopulateGlobalStore() { |
| testDriver = new TopologyTestDriver(setupGlobalStoreTopology(SOURCE_TOPIC_1), config); |
| |
| final KeyValueStore<byte[], byte[]> globalStore = testDriver.getKeyValueStore(SOURCE_TOPIC_1 + "-globalStore"); |
| Assert.assertNotNull(globalStore); |
| Assert.assertNotNull(testDriver.getAllStateStores().get(SOURCE_TOPIC_1 + "-globalStore")); |
| |
| testDriver.pipeInput(consumerRecord1); |
| |
| final List<Record> processedRecords = mockProcessors.get(0).processedRecords; |
| assertEquals(1, processedRecords.size()); |
| |
| final Record record = processedRecords.get(0); |
| final Record expectedResult = new Record(consumerRecord1, 0L); |
| assertThat(record, equalTo(expectedResult)); |
| } |
| |
| @Test |
| public void shouldPunctuateOnStreamsTime() { |
| final MockPunctuator mockPunctuator = new MockPunctuator(); |
| testDriver = new TopologyTestDriver( |
| setupSingleProcessorTopology(10L, PunctuationType.STREAM_TIME, mockPunctuator), |
| config); |
| |
| final List<Long> expectedPunctuations = new LinkedList<>(); |
| |
| expectedPunctuations.add(42L); |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 42L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 42L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| expectedPunctuations.add(51L); |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 51L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 52L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| expectedPunctuations.add(61L); |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 61L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 65L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| expectedPunctuations.add(71L); |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 71L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 72L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| expectedPunctuations.add(95L); |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 95L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| expectedPunctuations.add(101L); |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 101L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| testDriver.pipeInput(consumerRecordFactory.create(SOURCE_TOPIC_1, key1, value1, 102L)); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| } |
| |
| @Test |
| public void shouldPunctuateOnWallClockTime() { |
| final MockPunctuator mockPunctuator = new MockPunctuator(); |
| testDriver = new TopologyTestDriver( |
| setupSingleProcessorTopology(10L, PunctuationType.WALL_CLOCK_TIME, mockPunctuator), |
| config, |
| 0); |
| |
| final List<Long> expectedPunctuations = new LinkedList<>(); |
| |
| testDriver.advanceWallClockTime(5L); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| expectedPunctuations.add(14L); |
| testDriver.advanceWallClockTime(9L); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| testDriver.advanceWallClockTime(1L); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| expectedPunctuations.add(35L); |
| testDriver.advanceWallClockTime(20L); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| |
| expectedPunctuations.add(40L); |
| testDriver.advanceWallClockTime(5L); |
| assertThat(mockPunctuator.punctuatedAt, equalTo(expectedPunctuations)); |
| } |
| |
| @Test |
| public void shouldReturnAllStores() { |
| final Topology topology = setupSourceSinkTopology(); |
| topology.addProcessor("processor", () -> null, "source"); |
| topology.addStateStore( |
| new KeyValueStoreBuilder<>( |
| Stores.inMemoryKeyValueStore("store"), |
| Serdes.ByteArray(), |
| Serdes.ByteArray(), |
| new SystemTime()), |
| "processor"); |
| topology.addGlobalStore( |
| new KeyValueStoreBuilder<>( |
| Stores.inMemoryKeyValueStore("globalStore"), |
| Serdes.ByteArray(), |
| Serdes.ByteArray(), |
| new SystemTime()).withLoggingDisabled(), |
| "sourceProcessorName", |
| Serdes.ByteArray().deserializer(), |
| Serdes.ByteArray().deserializer(), |
| "globalTopicName", |
| "globalProcessorName", |
| () -> null); |
| |
| testDriver = new TopologyTestDriver(topology, config); |
| |
| final Set<String> expectedStoreNames = new HashSet<>(); |
| expectedStoreNames.add("store"); |
| expectedStoreNames.add("globalStore"); |
| final Map<String, StateStore> allStores = testDriver.getAllStateStores(); |
| assertThat(allStores.keySet(), equalTo(expectedStoreNames)); |
| for (final StateStore store : allStores.values()) { |
| assertNotNull(store); |
| } |
| } |
| |
| @Test |
| public void shouldReturnAllStoresNames() { |
| final Topology topology = setupSourceSinkTopology(); |
| topology.addStateStore( |
| new KeyValueStoreBuilder<>( |
| Stores.inMemoryKeyValueStore("store"), |
| Serdes.ByteArray(), |
| Serdes.ByteArray(), |
| new SystemTime())); |
| topology.addGlobalStore( |
| new KeyValueStoreBuilder<>( |
| Stores.inMemoryKeyValueStore("globalStore"), |
| Serdes.ByteArray(), |
| Serdes.ByteArray(), |
| new SystemTime()).withLoggingDisabled(), |
| "sourceProcessorName", |
| Serdes.ByteArray().deserializer(), |
| Serdes.ByteArray().deserializer(), |
| "globalTopicName", |
| "globalProcessorName", |
| () -> null); |
| |
| testDriver = new TopologyTestDriver(topology, config); |
| |
| final Set<String> expectedStoreNames = new HashSet<>(); |
| expectedStoreNames.add("store"); |
| expectedStoreNames.add("globalStore"); |
| assertThat(testDriver.getAllStateStores().keySet(), equalTo(expectedStoreNames)); |
| } |
| |
| private void setup() { |
| setup(Stores.inMemoryKeyValueStore("aggStore")); |
| } |
| |
| private void setup(final KeyValueBytesStoreSupplier storeSupplier) { |
| final Topology topology = new Topology(); |
| topology.addSource("sourceProcessor", "input-topic"); |
| topology.addProcessor("aggregator", new CustomMaxAggregatorSupplier(), "sourceProcessor"); |
| topology.addStateStore(Stores.keyValueStoreBuilder( |
| storeSupplier, |
| Serdes.String(), |
| Serdes.Long()), |
| "aggregator"); |
| topology.addSink("sinkProcessor", "result-topic", "aggregator"); |
| |
| config.setProperty(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); |
| config.setProperty(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass().getName()); |
| testDriver = new TopologyTestDriver(topology, config); |
| |
| store = testDriver.getKeyValueStore("aggStore"); |
| store.put("a", 21L); |
| } |
| |
| @Test |
| public void shouldFlushStoreForFirstInput() { |
| setup(); |
| testDriver.pipeInput(recordFactory.create("input-topic", "a", 1L, 9999L)); |
| OutputVerifier.compareKeyValue(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer), "a", 21L); |
| Assert.assertNull(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer)); |
| } |
| |
| @Test |
| public void shouldNotUpdateStoreForSmallerValue() { |
| setup(); |
| testDriver.pipeInput(recordFactory.create("input-topic", "a", 1L, 9999L)); |
| Assert.assertThat(store.get("a"), equalTo(21L)); |
| OutputVerifier.compareKeyValue(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer), "a", 21L); |
| Assert.assertNull(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer)); |
| } |
| |
| @Test |
| public void shouldNotUpdateStoreForLargerValue() { |
| setup(); |
| testDriver.pipeInput(recordFactory.create("input-topic", "a", 42L, 9999L)); |
| Assert.assertThat(store.get("a"), equalTo(42L)); |
| OutputVerifier.compareKeyValue(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer), "a", 42L); |
| Assert.assertNull(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer)); |
| } |
| |
| @Test |
| public void shouldUpdateStoreForNewKey() { |
| setup(); |
| testDriver.pipeInput(recordFactory.create("input-topic", "b", 21L, 9999L)); |
| Assert.assertThat(store.get("b"), equalTo(21L)); |
| OutputVerifier.compareKeyValue(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer), "a", 21L); |
| OutputVerifier.compareKeyValue(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer), "b", 21L); |
| Assert.assertNull(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer)); |
| } |
| |
| @Test |
| public void shouldPunctuateIfEvenTimeAdvances() { |
| setup(); |
| testDriver.pipeInput(recordFactory.create("input-topic", "a", 1L, 9999L)); |
| OutputVerifier.compareKeyValue(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer), "a", 21L); |
| |
| testDriver.pipeInput(recordFactory.create("input-topic", "a", 1L, 9999L)); |
| Assert.assertNull(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer)); |
| |
| testDriver.pipeInput(recordFactory.create("input-topic", "a", 1L, 10000L)); |
| OutputVerifier.compareKeyValue(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer), "a", 21L); |
| Assert.assertNull(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer)); |
| } |
| |
| @Test |
| public void shouldPunctuateIfWallClockTimeAdvances() { |
| setup(); |
| testDriver.advanceWallClockTime(60000); |
| OutputVerifier.compareKeyValue(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer), "a", 21L); |
| Assert.assertNull(testDriver.readOutput("result-topic", stringDeserializer, longDeserializer)); |
| } |
| |
| private class CustomMaxAggregatorSupplier implements ProcessorSupplier<String, Long> { |
| @Override |
| public Processor<String, Long> get() { |
| return new CustomMaxAggregator(); |
| } |
| } |
| |
| private class CustomMaxAggregator implements Processor<String, Long> { |
| ProcessorContext context; |
| private KeyValueStore<String, Long> store; |
| |
| @SuppressWarnings("unchecked") |
| @Override |
| public void init(final ProcessorContext context) { |
| this.context = context; |
| context.schedule(Duration.ofMinutes(1), PunctuationType.WALL_CLOCK_TIME, timestamp -> flushStore()); |
| context.schedule(Duration.ofSeconds(10), PunctuationType.STREAM_TIME, timestamp -> flushStore()); |
| store = (KeyValueStore<String, Long>) context.getStateStore("aggStore"); |
| } |
| |
| @Override |
| public void process(final String key, final Long value) { |
| final Long oldValue = store.get(key); |
| if (oldValue == null || value > oldValue) { |
| store.put(key, value); |
| } |
| } |
| |
| private void flushStore() { |
| try (final KeyValueIterator<String, Long> it = store.all()) { |
| while (it.hasNext()) { |
| final KeyValue<String, Long> next = it.next(); |
| context.forward(next.key, next.value); |
| } |
| } |
| } |
| |
| @Override |
| public void close() {} |
| } |
| |
| @Test |
| public void shouldAllowPrePopulatingStatesStoresWithCachingEnabled() { |
| final Topology topology = new Topology(); |
| topology.addSource("sourceProcessor", "input-topic"); |
| topology.addProcessor("aggregator", new CustomMaxAggregatorSupplier(), "sourceProcessor"); |
| topology.addStateStore(Stores.keyValueStoreBuilder( |
| Stores.inMemoryKeyValueStore("aggStore"), |
| Serdes.String(), |
| Serdes.Long()).withCachingEnabled(), // intentionally turn on caching to achieve better test coverage |
| "aggregator"); |
| |
| testDriver = new TopologyTestDriver(topology, config); |
| |
| store = testDriver.getKeyValueStore("aggStore"); |
| store.put("a", 21L); |
| } |
| |
| @Test |
| public void shouldCleanUpPersistentStateStoresOnClose() { |
| final Topology topology = new Topology(); |
| topology.addSource("sourceProcessor", "input-topic"); |
| topology.addProcessor( |
| "storeProcessor", |
| new ProcessorSupplier() { |
| @Override |
| public Processor get() { |
| return new Processor<String, Long>() { |
| private KeyValueStore<String, Long> store; |
| |
| @SuppressWarnings("unchecked") |
| @Override |
| public void init(final ProcessorContext context) { |
| this.store = (KeyValueStore<String, Long>) context.getStateStore("storeProcessorStore"); |
| } |
| |
| @Override |
| public void process(final String key, final Long value) { |
| store.put(key, value); |
| } |
| |
| @Override |
| public void close() {} |
| }; |
| } |
| }, |
| "sourceProcessor" |
| ); |
| topology.addStateStore(Stores.keyValueStoreBuilder(Stores.persistentKeyValueStore("storeProcessorStore"), Serdes.String(), Serdes.Long()), "storeProcessor"); |
| |
| final Properties config = new Properties(); |
| config.put(StreamsConfig.APPLICATION_ID_CONFIG, "test-TopologyTestDriver-cleanup"); |
| config.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, "dummy:1234"); |
| config.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getAbsolutePath()); |
| config.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.String().getClass().getName()); |
| config.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.Long().getClass().getName()); |
| |
| try (final TopologyTestDriver testDriver = new TopologyTestDriver(topology, config)) { |
| Assert.assertNull(testDriver.getKeyValueStore("storeProcessorStore").get("a")); |
| testDriver.pipeInput(recordFactory.create("input-topic", "a", 1L)); |
| Assert.assertEquals(1L, testDriver.getKeyValueStore("storeProcessorStore").get("a")); |
| } |
| |
| |
| try (final TopologyTestDriver testDriver = new TopologyTestDriver(topology, config)) { |
| Assert.assertNull( |
| "Closing the prior test driver should have cleaned up this store and value.", |
| testDriver.getKeyValueStore("storeProcessorStore").get("a") |
| ); |
| } |
| |
| } |
| |
| @Test |
| public void shouldFeedStoreFromGlobalKTable() { |
| final StreamsBuilder builder = new StreamsBuilder(); |
| builder.globalTable("topic", |
| Consumed.with(Serdes.String(), Serdes.String()), |
| Materialized.as("globalStore")); |
| try (final TopologyTestDriver testDriver = new TopologyTestDriver(builder.build(), config)) { |
| final KeyValueStore<String, String> globalStore = testDriver.getKeyValueStore("globalStore"); |
| Assert.assertNotNull(globalStore); |
| Assert.assertNotNull(testDriver.getAllStateStores().get("globalStore")); |
| final ConsumerRecordFactory<String, String> recordFactory = new ConsumerRecordFactory<>(new StringSerializer(), new StringSerializer()); |
| testDriver.pipeInput(recordFactory.create("topic", "k1", "value1")); |
| // we expect to have both in the global store, the one from pipeInput and the one from the producer |
| Assert.assertEquals("value1", globalStore.get("k1")); |
| } |
| } |
| |
| private Topology setupMultipleSourcesPatternTopology(final Pattern... sourceTopicPatternNames) { |
| final Topology topology = new Topology(); |
| |
| final String[] processorNames = new String[sourceTopicPatternNames.length]; |
| int i = 0; |
| for (final Pattern sourceTopicPatternName : sourceTopicPatternNames) { |
| final String sourceName = sourceTopicPatternName + "-source"; |
| final String processorName = sourceTopicPatternName + "-processor"; |
| topology.addSource(sourceName, sourceTopicPatternName); |
| processorNames[i++] = processorName; |
| topology.addProcessor(processorName, new MockProcessorSupplier(), sourceName); |
| } |
| topology.addSink("sink-topic", SINK_TOPIC_1, processorNames); |
| return topology; |
| } |
| |
| @Test |
| public void shouldProcessFromSourcesThatMatchMultiplePattern() { |
| |
| final Pattern pattern2Source1 = Pattern.compile("source-topic-\\d"); |
| final Pattern pattern2Source2 = Pattern.compile("source-topic-[A-Z]"); |
| final String consumerTopic2 = "source-topic-Z"; |
| |
| final ConsumerRecord<byte[], byte[]> consumerRecord2 = consumerRecordFactory.create(consumerTopic2, key2, value2, timestamp2); |
| |
| testDriver = new TopologyTestDriver(setupMultipleSourcesPatternTopology(pattern2Source1, pattern2Source2), config); |
| |
| final List<Record> processedRecords1 = mockProcessors.get(0).processedRecords; |
| final List<Record> processedRecords2 = mockProcessors.get(1).processedRecords; |
| |
| testDriver.pipeInput(consumerRecord1); |
| |
| assertEquals(1, processedRecords1.size()); |
| assertEquals(0, processedRecords2.size()); |
| |
| final Record record1 = processedRecords1.get(0); |
| final Record expectedResult1 = new Record(consumerRecord1, 0L); |
| assertThat(record1, equalTo(expectedResult1)); |
| |
| testDriver.pipeInput(consumerRecord2); |
| |
| assertEquals(1, processedRecords1.size()); |
| assertEquals(1, processedRecords2.size()); |
| |
| final Record record2 = processedRecords2.get(0); |
| final Record expectedResult2 = new Record(consumerRecord2, 0L); |
| assertThat(record2, equalTo(expectedResult2)); |
| } |
| |
| @Test |
| public void shouldProcessFromSourceThatMatchPattern() { |
| final String sourceName = "source"; |
| final Pattern pattern2Source1 = Pattern.compile("source-topic-\\d"); |
| |
| final Topology topology = new Topology(); |
| |
| topology.addSource(sourceName, pattern2Source1); |
| topology.addSink("sink", SINK_TOPIC_1, sourceName); |
| |
| testDriver = new TopologyTestDriver(topology, config); |
| testDriver.pipeInput(consumerRecord1); |
| |
| final ProducerRecord outputRecord = testDriver.readOutput(SINK_TOPIC_1); |
| assertEquals(key1, outputRecord.key()); |
| assertEquals(value1, outputRecord.value()); |
| assertEquals(SINK_TOPIC_1, outputRecord.topic()); |
| } |
| |
| @Test |
| public void shouldThrowPatternNotValidForTopicNameException() { |
| final String sourceName = "source"; |
| final String pattern2Source1 = "source-topic-\\d"; |
| |
| final Topology topology = new Topology(); |
| |
| topology.addSource(sourceName, pattern2Source1); |
| topology.addSink("sink", SINK_TOPIC_1, sourceName); |
| |
| testDriver = new TopologyTestDriver(topology, config); |
| try { |
| testDriver.pipeInput(consumerRecord1); |
| } catch (final TopologyException exception) { |
| final String str = |
| String.format( |
| "Invalid topology: Topology add source of type String for topic: %s cannot contain regex pattern for " + |
| "input record topic: %s and hence cannot process the message.", |
| pattern2Source1, |
| SOURCE_TOPIC_1); |
| assertEquals(str, exception.getMessage()); |
| } |
| } |
| |
| @Test |
| public void shouldNotCreateStateDirectoryForStatelessTopology() { |
| setup(); |
| final String stateDir = config.getProperty(StreamsConfig.STATE_DIR_CONFIG); |
| final File appDir = new File(stateDir, config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG)); |
| assertFalse(appDir.exists()); |
| } |
| |
| @Test |
| public void shouldCreateStateDirectoryForStatefulTopology() { |
| setup(Stores.persistentKeyValueStore("aggStore")); |
| final String stateDir = config.getProperty(StreamsConfig.STATE_DIR_CONFIG); |
| final File appDir = new File(stateDir, config.getProperty(StreamsConfig.APPLICATION_ID_CONFIG)); |
| |
| assertTrue(appDir.exists()); |
| assertTrue(appDir.isDirectory()); |
| |
| final TaskId taskId = new TaskId(0, 0); |
| assertTrue(new File(appDir, taskId.toString()).exists()); |
| } |
| } |