################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import json
from typing import Dict
import pyflink.datastream.data_stream as data_stream
from pyflink.common import typeinfo
from pyflink.common.configuration import Configuration
from pyflink.common.serialization import SimpleStringSchema, DeserializationSchema
from pyflink.common.typeinfo import Types
from pyflink.common.types import Row
from pyflink.common.watermark_strategy import WatermarkStrategy
from pyflink.datastream.connectors.base import DeliveryGuarantee
from pyflink.datastream.connectors.kafka import KafkaSource, KafkaTopicPartition, \
KafkaOffsetsInitializer, KafkaOffsetResetStrategy, KafkaRecordSerializationSchema, KafkaSink, \
FlinkKafkaProducer, FlinkKafkaConsumer
from pyflink.datastream.formats.avro import AvroRowDeserializationSchema, AvroRowSerializationSchema
from pyflink.datastream.formats.csv import CsvRowDeserializationSchema, CsvRowSerializationSchema
from pyflink.datastream.formats.json import JsonRowDeserializationSchema, JsonRowSerializationSchema
from pyflink.java_gateway import get_gateway
from pyflink.testing.test_case_utils import (
PyFlinkStreamingTestCase,
PyFlinkTestCase,
invoke_java_object_method,
to_java_data_structure,
)
from pyflink.util.java_utils import to_jarray, is_instance_of, get_field_value


class KafkaSourceTests(PyFlinkStreamingTestCase):
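    """Tests for the KafkaSource builder and the legacy FlinkKafkaConsumer/FlinkKafkaProducer."""
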
def test_legacy_kafka_connector(self):
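        # The legacy connectors wrap a Java function; the assertions below read its
        # private fields through the reflection helpers get_field_value and
        # invoke_java_object_method.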
source_topic = 'test_source_topic'
sink_topic = 'test_sink_topic'
props = {'bootstrap.servers': 'localhost:9092', 'group.id': 'test_group'}
type_info = Types.ROW([Types.INT(), Types.STRING()])
# Test for kafka consumer
deserialization_schema = JsonRowDeserializationSchema.builder() \
.type_info(type_info=type_info).build()
flink_kafka_consumer = FlinkKafkaConsumer(source_topic, deserialization_schema, props)
flink_kafka_consumer.set_start_from_earliest()
flink_kafka_consumer.set_commit_offsets_on_checkpoints(True)
j_properties = get_field_value(flink_kafka_consumer.get_java_function(), 'properties')
self.assertEqual('localhost:9092', j_properties.getProperty('bootstrap.servers'))
self.assertEqual('test_group', j_properties.getProperty('group.id'))
self.assertTrue(get_field_value(flink_kafka_consumer.get_java_function(),
'enableCommitOnCheckpoints'))
j_start_up_mode = get_field_value(flink_kafka_consumer.get_java_function(), 'startupMode')
j_deserializer = get_field_value(flink_kafka_consumer.get_java_function(), 'deserializer')
j_deserialize_type_info = invoke_java_object_method(j_deserializer, "getProducedType")
deserialize_type_info = typeinfo._from_java_type(j_deserialize_type_info)
        self.assertEqual(deserialize_type_info, type_info)
self.assertTrue(j_start_up_mode.equals(get_gateway().jvm
.org.apache.flink.streaming.connectors
.kafka.config.StartupMode.EARLIEST))
j_topic_desc = get_field_value(flink_kafka_consumer.get_java_function(),
'topicsDescriptor')
j_topics = invoke_java_object_method(j_topic_desc, 'getFixedTopics')
self.assertEqual(['test_source_topic'], list(j_topics))
# Test for kafka producer
serialization_schema = JsonRowSerializationSchema.builder().with_type_info(type_info) \
.build()
flink_kafka_producer = FlinkKafkaProducer(sink_topic, serialization_schema, props)
flink_kafka_producer.set_write_timestamp_to_kafka(False)
j_producer_config = get_field_value(flink_kafka_producer.get_java_function(),
'producerConfig')
self.assertEqual('localhost:9092', j_producer_config.getProperty('bootstrap.servers'))
self.assertEqual('test_group', j_producer_config.getProperty('group.id'))
self.assertFalse(get_field_value(flink_kafka_producer.get_java_function(),
'writeTimestampToKafka'))
def test_compiling(self):
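        # Only checks that a pipeline using the source compiles into an execution plan;
        # no Kafka broker is contacted.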
source = KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_topics('test_topic') \
.set_value_only_deserializer(SimpleStringSchema()) \
.build()
ds = self.env.from_source(source=source,
watermark_strategy=WatermarkStrategy.for_monotonous_timestamps(),
source_name='kafka source')
ds.print()
plan = json.loads(self.env.get_execution_plan())
self.assertEqual('Source: kafka source', plan['nodes'][0]['type'])
def test_set_properties(self):
source = KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_group_id('test_group_id') \
.set_client_id_prefix('test_client_id_prefix') \
.set_property('test_property', 'test_value') \
.set_topics('test_topic') \
.set_value_only_deserializer(SimpleStringSchema()) \
.build()
conf = self._get_kafka_source_configuration(source)
self.assertEqual(conf.get_string('bootstrap.servers', ''), 'localhost:9092')
self.assertEqual(conf.get_string('group.id', ''), 'test_group_id')
self.assertEqual(conf.get_string('client.id.prefix', ''), 'test_client_id_prefix')
self.assertEqual(conf.get_string('test_property', ''), 'test_value')
def test_set_topics(self):
source = KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_topics('test_topic1', 'test_topic2') \
.set_value_only_deserializer(SimpleStringSchema()) \
.build()
kafka_subscriber = get_field_value(source.get_java_function(), 'subscriber')
self.assertEqual(
kafka_subscriber.getClass().getCanonicalName(),
'org.apache.flink.connector.kafka.source.enumerator.subscriber.TopicListSubscriber'
)
topics = get_field_value(kafka_subscriber, 'topics')
self.assertTrue(is_instance_of(topics, get_gateway().jvm.java.util.List))
self.assertEqual(topics.size(), 2)
self.assertEqual(topics[0], 'test_topic1')
self.assertEqual(topics[1], 'test_topic2')
def test_set_topic_pattern(self):
source = KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_topic_pattern('test_topic*') \
.set_value_only_deserializer(SimpleStringSchema()) \
.build()
kafka_subscriber = get_field_value(source.get_java_function(), 'subscriber')
self.assertEqual(
kafka_subscriber.getClass().getCanonicalName(),
'org.apache.flink.connector.kafka.source.enumerator.subscriber.TopicPatternSubscriber'
)
topic_pattern = get_field_value(kafka_subscriber, 'topicPattern')
self.assertTrue(is_instance_of(topic_pattern, get_gateway().jvm.java.util.regex.Pattern))
self.assertEqual(topic_pattern.toString(), 'test_topic*')
def test_set_partitions(self):
topic_partition_1 = KafkaTopicPartition('test_topic', 1)
topic_partition_2 = KafkaTopicPartition('test_topic', 2)
source = KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_partitions({topic_partition_1, topic_partition_2}) \
.set_value_only_deserializer(SimpleStringSchema()) \
.build()
kafka_subscriber = get_field_value(source.get_java_function(), 'subscriber')
self.assertEqual(
kafka_subscriber.getClass().getCanonicalName(),
'org.apache.flink.connector.kafka.source.enumerator.subscriber.PartitionSetSubscriber'
)
partitions = get_field_value(kafka_subscriber, 'subscribedPartitions')
self.assertTrue(is_instance_of(partitions, get_gateway().jvm.java.util.Set))
self.assertTrue(topic_partition_1._to_j_topic_partition() in partitions)
self.assertTrue(topic_partition_2._to_j_topic_partition() in partitions)
def test_set_starting_offsets(self):
def _build_source(initializer: KafkaOffsetsInitializer):
return KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_topics('test_topic') \
.set_value_only_deserializer(SimpleStringSchema()) \
.set_group_id('test_group') \
.set_starting_offsets(initializer) \
.build()
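        # The expected offsets below are the sentinel values used by the connector's
        # KafkaPartitionSplit: -1 for latest, -2 for earliest and -3 for committed offsets.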
self._check_reader_handled_offsets_initializer(
_build_source(KafkaOffsetsInitializer.latest()), -1, KafkaOffsetResetStrategy.LATEST
)
self._check_reader_handled_offsets_initializer(
_build_source(KafkaOffsetsInitializer.earliest()), -2,
KafkaOffsetResetStrategy.EARLIEST
)
self._check_reader_handled_offsets_initializer(
_build_source(KafkaOffsetsInitializer.committed_offsets()), -3,
KafkaOffsetResetStrategy.NONE
)
self._check_reader_handled_offsets_initializer(
_build_source(KafkaOffsetsInitializer.committed_offsets(
KafkaOffsetResetStrategy.LATEST
)), -3, KafkaOffsetResetStrategy.LATEST
)
self._check_timestamp_offsets_initializer(
_build_source(KafkaOffsetsInitializer.timestamp(100)), 100
)
specified_offsets = {
KafkaTopicPartition('test_topic1', 1): 1000,
KafkaTopicPartition('test_topic2', 2): 2000
}
self._check_specified_offsets_initializer(
_build_source(KafkaOffsetsInitializer.offsets(specified_offsets)), specified_offsets,
KafkaOffsetResetStrategy.EARLIEST
)
self._check_specified_offsets_initializer(
_build_source(KafkaOffsetsInitializer.offsets(
specified_offsets,
KafkaOffsetResetStrategy.LATEST
)),
specified_offsets,
KafkaOffsetResetStrategy.LATEST
)
def test_bounded(self):
def _build_source(initializer: KafkaOffsetsInitializer):
return KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_topics('test_topic') \
.set_value_only_deserializer(SimpleStringSchema()) \
.set_group_id('test_group') \
.set_bounded(initializer) \
.build()
def _check_bounded(source: KafkaSource):
self.assertEqual(
get_field_value(source.get_java_function(), 'boundedness').toString(), 'BOUNDED'
)
self._test_set_bounded_or_unbounded(_build_source, _check_bounded)
def test_unbounded(self):
def _build_source(initializer: KafkaOffsetsInitializer):
return KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_topics('test_topic') \
.set_value_only_deserializer(SimpleStringSchema()) \
.set_group_id('test_group') \
.set_unbounded(initializer) \
.build()
def _check_bounded(source: KafkaSource):
self.assertEqual(
get_field_value(source.get_java_function(), 'boundedness').toString(),
'CONTINUOUS_UNBOUNDED'
)
self._test_set_bounded_or_unbounded(_build_source, _check_bounded)
def _test_set_bounded_or_unbounded(self, _build_source, _check_boundedness):
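        # Reuses the offsets-initializer checks from test_set_starting_offsets, but runs
        # them against the stopping offsets initializer (is_start=False) of each source.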
source = _build_source(KafkaOffsetsInitializer.latest())
_check_boundedness(source)
self._check_reader_handled_offsets_initializer(
source, -1, KafkaOffsetResetStrategy.LATEST, False
)
source = _build_source(KafkaOffsetsInitializer.earliest())
_check_boundedness(source)
self._check_reader_handled_offsets_initializer(
source, -2, KafkaOffsetResetStrategy.EARLIEST, False
)
source = _build_source(KafkaOffsetsInitializer.committed_offsets())
_check_boundedness(source)
self._check_reader_handled_offsets_initializer(
source, -3, KafkaOffsetResetStrategy.NONE, False
)
source = _build_source(KafkaOffsetsInitializer.committed_offsets(
KafkaOffsetResetStrategy.LATEST
))
_check_boundedness(source)
self._check_reader_handled_offsets_initializer(
source, -3, KafkaOffsetResetStrategy.LATEST, False
)
source = _build_source(KafkaOffsetsInitializer.timestamp(100))
_check_boundedness(source)
self._check_timestamp_offsets_initializer(source, 100, False)
specified_offsets = {
KafkaTopicPartition('test_topic1', 1): 1000,
KafkaTopicPartition('test_topic2', 2): 2000
}
source = _build_source(KafkaOffsetsInitializer.offsets(specified_offsets))
_check_boundedness(source)
self._check_specified_offsets_initializer(
source, specified_offsets, KafkaOffsetResetStrategy.EARLIEST, False
)
source = _build_source(KafkaOffsetsInitializer.offsets(
specified_offsets,
KafkaOffsetResetStrategy.LATEST)
)
_check_boundedness(source)
self._check_specified_offsets_initializer(
source, specified_offsets, KafkaOffsetResetStrategy.LATEST, False
)
def test_set_value_only_deserializer(self):
def _check(schema: DeserializationSchema, class_name: str):
source = KafkaSource.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_topics('test_topic') \
.set_value_only_deserializer(schema) \
.build()
deserialization_schema_wrapper = get_field_value(source.get_java_function(),
'deserializationSchema')
self.assertEqual(
deserialization_schema_wrapper.getClass().getCanonicalName(),
'org.apache.flink.connector.kafka.source.reader.deserializer'
'.KafkaValueOnlyDeserializationSchemaWrapper'
)
deserialization_schema = get_field_value(deserialization_schema_wrapper,
'deserializationSchema')
self.assertEqual(deserialization_schema.getClass().getCanonicalName(),
class_name)
_check(SimpleStringSchema(), 'org.apache.flink.api.common.serialization.SimpleStringSchema')
_check(
JsonRowDeserializationSchema.builder().type_info(Types.ROW([Types.STRING()])).build(),
'org.apache.flink.formats.json.JsonRowDeserializationSchema'
)
_check(
CsvRowDeserializationSchema.Builder(Types.ROW([Types.STRING()])).build(),
'org.apache.flink.formats.csv.CsvRowDeserializationSchema'
)
avro_schema_string = """
{
"type": "record",
"name": "test_record",
"fields": []
}
"""
_check(
AvroRowDeserializationSchema(avro_schema_string=avro_schema_string),
'org.apache.flink.formats.avro.AvroRowDeserializationSchema'
)
def _check_reader_handled_offsets_initializer(self,
source: KafkaSource,
offset: int,
reset_strategy: KafkaOffsetResetStrategy,
is_start: bool = True):
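        # Verifies that the private starting/stopping offsets initializer of the Java
        # source is a ReaderHandledOffsetsInitializer with the expected sentinel offset
        # and offset reset strategy.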
if is_start:
field_name = 'startingOffsetsInitializer'
else:
field_name = 'stoppingOffsetsInitializer'
offsets_initializer = get_field_value(source.get_java_function(), field_name)
self.assertEqual(
offsets_initializer.getClass().getCanonicalName(),
'org.apache.flink.connector.kafka.source.enumerator.initializer'
'.ReaderHandledOffsetsInitializer'
)
starting_offset = get_field_value(offsets_initializer, 'startingOffset')
self.assertEqual(starting_offset, offset)
offset_reset_strategy = get_field_value(offsets_initializer, 'offsetResetStrategy')
self.assertTrue(
offset_reset_strategy.equals(reset_strategy._to_j_offset_reset_strategy())
)
def _check_timestamp_offsets_initializer(self,
source: KafkaSource,
timestamp: int,
is_start: bool = True):
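        # Verifies that the offsets initializer is a TimestampOffsetsInitializer created
        # for the expected starting timestamp.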
if is_start:
field_name = 'startingOffsetsInitializer'
else:
field_name = 'stoppingOffsetsInitializer'
offsets_initializer = get_field_value(source.get_java_function(), field_name)
self.assertEqual(
offsets_initializer.getClass().getCanonicalName(),
'org.apache.flink.connector.kafka.source.enumerator.initializer'
'.TimestampOffsetsInitializer'
)
starting_timestamp = get_field_value(offsets_initializer, 'startingTimestamp')
self.assertEqual(starting_timestamp, timestamp)
def _check_specified_offsets_initializer(self,
source: KafkaSource,
offsets: Dict[KafkaTopicPartition, int],
reset_strategy: KafkaOffsetResetStrategy,
is_start: bool = True):
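        # Verifies that the offsets initializer is a SpecifiedOffsetsInitializer and that
        # its Java offset map matches the expected Python dict entry by entry.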
if is_start:
field_name = 'startingOffsetsInitializer'
else:
field_name = 'stoppingOffsetsInitializer'
offsets_initializer = get_field_value(source.get_java_function(), field_name)
self.assertEqual(
offsets_initializer.getClass().getCanonicalName(),
'org.apache.flink.connector.kafka.source.enumerator.initializer'
'.SpecifiedOffsetsInitializer'
)
initial_offsets = get_field_value(offsets_initializer, 'initialOffsets')
self.assertTrue(is_instance_of(initial_offsets, get_gateway().jvm.java.util.Map))
self.assertEqual(initial_offsets.size(), len(offsets))
for j_topic_partition in initial_offsets:
topic_partition = KafkaTopicPartition(j_topic_partition.topic(),
j_topic_partition.partition())
self.assertIsNotNone(offsets.get(topic_partition))
self.assertEqual(initial_offsets[j_topic_partition], offsets[topic_partition])
offset_reset_strategy = get_field_value(offsets_initializer, 'offsetResetStrategy')
self.assertTrue(
offset_reset_strategy.equals(reset_strategy._to_j_offset_reset_strategy())
)
@staticmethod
def _get_kafka_source_configuration(source: KafkaSource):
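        # KafkaSource#getConfiguration() is not part of the public API, so it is invoked
        # reflectively (setAccessible) on the underlying Java source object.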
jvm = get_gateway().jvm
j_source = source.get_java_function()
j_to_configuration = j_source.getClass().getDeclaredMethod(
'getConfiguration', to_jarray(jvm.java.lang.Class, [])
)
j_to_configuration.setAccessible(True)
j_configuration = j_to_configuration.invoke(j_source, to_jarray(jvm.java.lang.Object, []))
        return Configuration(j_configuration=j_configuration)


class KafkaSinkTests(PyFlinkStreamingTestCase):
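    """Tests for building and configuring KafkaSink instances."""
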
def test_compile(self):
sink = KafkaSink.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_record_serializer(self._build_serialization_schema()) \
.build()
ds = self.env.from_collection([], type_info=Types.STRING())
ds.sink_to(sink)
plan = json.loads(self.env.get_execution_plan())
self.assertEqual(plan['nodes'][1]['type'], 'Sink: Writer')
self.assertEqual(plan['nodes'][2]['type'], 'Sink: Committer')
    def test_set_bootstrap_servers(self):
sink = KafkaSink.builder() \
.set_bootstrap_servers('localhost:9092,localhost:9093') \
.set_record_serializer(self._build_serialization_schema()) \
.build()
config = get_field_value(sink.get_java_function(), 'kafkaProducerConfig')
self.assertEqual(config.get('bootstrap.servers'), 'localhost:9092,localhost:9093')
def test_set_delivery_guarantee(self):
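        # A sink built without an explicit delivery guarantee should default to NONE.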
sink = KafkaSink.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_record_serializer(self._build_serialization_schema()) \
.build()
guarantee = get_field_value(sink.get_java_function(), 'deliveryGuarantee')
self.assertEqual(guarantee.toString(), 'none')
sink = KafkaSink.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_delivery_guarantee(DeliveryGuarantee.AT_LEAST_ONCE) \
.set_record_serializer(self._build_serialization_schema()) \
.build()
guarantee = get_field_value(sink.get_java_function(), 'deliveryGuarantee')
self.assertEqual(guarantee.toString(), 'at-least-once')
sink = KafkaSink.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_delivery_guarantee(DeliveryGuarantee.EXACTLY_ONCE) \
.set_record_serializer(self._build_serialization_schema()) \
.build()
guarantee = get_field_value(sink.get_java_function(), 'deliveryGuarantee')
self.assertEqual(guarantee.toString(), 'exactly-once')
def test_set_transactional_id_prefix(self):
sink = KafkaSink.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_transactional_id_prefix('test-prefix') \
.set_record_serializer(self._build_serialization_schema()) \
.build()
prefix = get_field_value(sink.get_java_function(), 'transactionalIdPrefix')
self.assertEqual(prefix, 'test-prefix')
def test_set_property(self):
sink = KafkaSink.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_record_serializer(self._build_serialization_schema()) \
.set_property('test-key', 'test-value') \
.build()
config = get_field_value(sink.get_java_function(), 'kafkaProducerConfig')
self.assertEqual(config.get('test-key'), 'test-value')
def test_set_record_serializer(self):
sink = KafkaSink.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_record_serializer(self._build_serialization_schema()) \
.build()
serializer = get_field_value(sink.get_java_function(), 'recordSerializer')
self.assertEqual(serializer.getClass().getCanonicalName(),
'org.apache.flink.connector.kafka.sink.'
'KafkaRecordSerializationSchemaBuilder.'
'KafkaRecordSerializationSchemaWrapper')
topic_selector = get_field_value(serializer, 'topicSelector')
self.assertEqual(topic_selector.apply(None), 'test-topic')
value_serializer = get_field_value(serializer, 'valueSerializationSchema')
self.assertEqual(value_serializer.getClass().getCanonicalName(),
'org.apache.flink.api.common.serialization.SimpleStringSchema')
@staticmethod
def _build_serialization_schema() -> KafkaRecordSerializationSchema:
return KafkaRecordSerializationSchema.builder() \
.set_topic('test-topic') \
.set_value_serialization_schema(SimpleStringSchema()) \
            .build()


class KafkaRecordSerializationSchemaTests(PyFlinkTestCase):
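    """Tests for KafkaRecordSerializationSchema builders and serialization behaviour."""
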
def test_set_topic(self):
input_type = Types.ROW([Types.STRING()])
serialization_schema = KafkaRecordSerializationSchema.builder() \
.set_topic('test-topic') \
.set_value_serialization_schema(
JsonRowSerializationSchema.builder().with_type_info(input_type).build()) \
.build()
jvm = get_gateway().jvm
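        # The serialization schema has to be opened with an initialization context and a
        # sink context before serialize() can be called.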
serialization_schema._j_serialization_schema.open(
jvm.org.apache.flink.connector.testutils.formats.DummyInitializationContext(),
jvm.org.apache.flink.connector.kafka.sink.DefaultKafkaSinkContext(
0, 1, jvm.java.util.Properties()))
j_record = serialization_schema._j_serialization_schema.serialize(
to_java_data_structure(Row('test')), None, None
)
self.assertEqual(j_record.topic(), 'test-topic')
self.assertIsNone(j_record.key())
self.assertEqual(j_record.value(), b'{"f0":"test"}')
def test_set_topic_selector(self):
def _select(data):
data = data[0]
if data == 'a':
return 'topic-a'
elif data == 'b':
return 'topic-b'
else:
return 'topic-dead-letter'
def _check_record(data, topic, serialized_data):
input_type = Types.ROW([Types.STRING()])
serialization_schema = KafkaRecordSerializationSchema.builder() \
.set_topic_selector(_select) \
.set_value_serialization_schema(
JsonRowSerializationSchema.builder().with_type_info(input_type).build()) \
.build()
jvm = get_gateway().jvm
serialization_schema._j_serialization_schema.open(
jvm.org.apache.flink.connector.testutils.formats.DummyInitializationContext(),
jvm.org.apache.flink.connector.kafka.sink.DefaultKafkaSinkContext(
0, 1, jvm.java.util.Properties()))
sink = KafkaSink.builder() \
.set_bootstrap_servers('localhost:9092') \
.set_record_serializer(serialization_schema) \
.build()
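            # KafkaSink supports preprocessing: sink_to() registers the sink's transformer
            # on the mock stream, and feed() replays the row through it before serialization
            # so that the Python topic selector takes effect.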
ds = MockDataStream(Types.ROW([Types.STRING()]))
ds.sink_to(sink)
row = Row(data)
topic_row = ds.feed(row) # type: Row
j_record = serialization_schema._j_serialization_schema.serialize(
to_java_data_structure(topic_row), None, None
)
self.assertEqual(j_record.topic(), topic)
self.assertIsNone(j_record.key())
self.assertEqual(j_record.value(), serialized_data)
_check_record('a', 'topic-a', b'{"f0":"a"}')
_check_record('b', 'topic-b', b'{"f0":"b"}')
_check_record('c', 'topic-dead-letter', b'{"f0":"c"}')
_check_record('d', 'topic-dead-letter', b'{"f0":"d"}')
def test_set_key_serialization_schema(self):
def _check_key_serialization_schema(key_serialization_schema, expected_class):
serialization_schema = KafkaRecordSerializationSchema.builder() \
.set_topic('test-topic') \
.set_key_serialization_schema(key_serialization_schema) \
.set_value_serialization_schema(SimpleStringSchema()) \
.build()
schema_field = get_field_value(serialization_schema._j_serialization_schema,
'keySerializationSchema')
self.assertIsNotNone(schema_field)
self.assertEqual(schema_field.getClass().getCanonicalName(), expected_class)
self._check_serialization_schema_implementations(_check_key_serialization_schema)
def test_set_value_serialization_schema(self):
def _check_value_serialization_schema(value_serialization_schema, expected_class):
serialization_schema = KafkaRecordSerializationSchema.builder() \
.set_topic('test-topic') \
.set_value_serialization_schema(value_serialization_schema) \
.build()
schema_field = get_field_value(serialization_schema._j_serialization_schema,
'valueSerializationSchema')
self.assertIsNotNone(schema_field)
self.assertEqual(schema_field.getClass().getCanonicalName(), expected_class)
self._check_serialization_schema_implementations(_check_value_serialization_schema)
@staticmethod
def _check_serialization_schema_implementations(check_function):
input_type = Types.ROW([Types.STRING()])
check_function(
JsonRowSerializationSchema.builder().with_type_info(input_type).build(),
'org.apache.flink.formats.json.JsonRowSerializationSchema'
)
check_function(
CsvRowSerializationSchema.Builder(input_type).build(),
'org.apache.flink.formats.csv.CsvRowSerializationSchema'
)
avro_schema_string = """
{
"type": "record",
"name": "test_record",
"fields": []
}
"""
check_function(
AvroRowSerializationSchema(avro_schema_string=avro_schema_string),
'org.apache.flink.formats.avro.AvroRowSerializationSchema'
)
check_function(
SimpleStringSchema(),
'org.apache.flink.api.common.serialization.SimpleStringSchema'
        )


class MockDataStream(data_stream.DataStream):
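    """Minimal DataStream stand-in that records the map operations added by a sink's
    preprocessing transformer and replays them through feed()."""
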
def __init__(self, original_type=None):
super().__init__(None)
self._operators = []
self._type = original_type
def feed(self, data):
for op in self._operators:
data = op(data)
return data
def get_type(self):
return self._type
def map(self, f, output_type=None):
self._operators.append(f)
self._type = output_type
def sink_to(self, sink):
ds = self
from pyflink.datastream.connectors.base import SupportsPreprocessing
if isinstance(sink, SupportsPreprocessing) and sink.get_transformer() is not None:
ds = sink.get_transformer().apply(self)
return ds