blob: a0a65419974e611e9e551edefcd2b15398b392e9 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests corresponding to Dataflow's iobase module."""
from __future__ import absolute_import
import unittest
from apache_beam import Create
from apache_beam import error
from apache_beam import pvalue
from apache_beam.runners.dataflow.native_io.iobase import ConcatPosition
from apache_beam.runners.dataflow.native_io.iobase import DynamicSplitRequest
from apache_beam.runners.dataflow.native_io.iobase import DynamicSplitResultWithPosition
from apache_beam.runners.dataflow.native_io.iobase import NativeSink
from apache_beam.runners.dataflow.native_io.iobase import NativeSinkWriter
from apache_beam.runners.dataflow.native_io.iobase import NativeSource
from apache_beam.runners.dataflow.native_io.iobase import ReaderPosition
from apache_beam.runners.dataflow.native_io.iobase import ReaderProgress
from apache_beam.runners.dataflow.native_io.iobase import _dict_printable_fields
from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite
from apache_beam.testing.test_pipeline import TestPipeline
class TestHelperFunctions(unittest.TestCase):
def test_dict_printable_fields(self):
dict_object = {
'key_alpha': '1',
'key_beta': None,
'key_charlie': [],
'key_delta': 2.0,
'key_echo': 'skip_me',
'key_fox': 0
}
skip_fields = [
'key_echo',
]
self.assertEqual(
sorted(_dict_printable_fields(dict_object, skip_fields)),
[
"key_alpha='1'",
'key_delta=2.0',
'key_fox=0'
]
)
class TestNativeSource(unittest.TestCase):
def test_reader_method(self):
native_source = NativeSource()
self.assertRaises(NotImplementedError, native_source.reader)
def test_repr_method(self):
class FakeSource(NativeSource):
"""A fake source modeled after BigQuerySource, which inherits from
NativeSource."""
def __init__(self, table=None, dataset=None, project=None, query=None,
validate=False, coder=None, use_std_sql=False,
flatten_results=True):
self.validate = validate
fake_source = FakeSource()
self.assertEqual(fake_source.__repr__(), '<FakeSource validate=False>')
class TestReaderProgress(unittest.TestCase):
def test_out_of_bounds_percent_complete(self):
with self.assertRaises(ValueError):
ReaderProgress(percent_complete=-0.1)
with self.assertRaises(ValueError):
ReaderProgress(percent_complete=1.1)
def test_position_property(self):
reader_progress = ReaderProgress(position=ReaderPosition())
self.assertEqual(type(reader_progress.position), ReaderPosition)
def test_percent_complete_property(self):
reader_progress = ReaderProgress(percent_complete=0.5)
self.assertEqual(reader_progress.percent_complete, 0.5)
class TestReaderPosition(unittest.TestCase):
def test_invalid_concat_position_type(self):
with self.assertRaises(AssertionError):
ReaderPosition(concat_position=1)
def test_valid_concat_position_type(self):
ReaderPosition(concat_position=ConcatPosition(None, None))
class TestConcatPosition(unittest.TestCase):
def test_invalid_position_type(self):
with self.assertRaises(AssertionError):
ConcatPosition(None, position=1)
def test_valid_position_type(self):
ConcatPosition(None, position=ReaderPosition())
class TestDynamicSplitRequest(unittest.TestCase):
def test_invalid_progress_type(self):
with self.assertRaises(AssertionError):
DynamicSplitRequest(progress=1)
def test_valid_progress_type(self):
DynamicSplitRequest(progress=ReaderProgress())
class TestDynamicSplitResultWithPosition(unittest.TestCase):
def test_invalid_stop_position_type(self):
with self.assertRaises(AssertionError):
DynamicSplitResultWithPosition(stop_position=1)
def test_valid_stop_position_type(self):
DynamicSplitResultWithPosition(stop_position=ReaderPosition())
class TestNativeSink(unittest.TestCase):
def test_writer_method(self):
native_sink = NativeSink()
self.assertRaises(NotImplementedError, native_sink.writer)
def test_repr_method(self):
class FakeSink(NativeSink):
"""A fake sink modeled after BigQuerySink, which inherits from
NativeSink."""
def __init__(self, validate=False, dataset=None, project=None,
schema=None, create_disposition='create',
write_disposition=None, coder=None):
self.validate = validate
fake_sink = FakeSink()
self.assertEqual(fake_sink.__repr__(), "<FakeSink ['validate=False']>")
def test_on_direct_runner(self):
class FakeSink(NativeSink):
"""A fake sink outputing a number of elements."""
def __init__(self):
self.written_values = []
self.writer_instance = FakeSinkWriter(self.written_values)
def writer(self):
return self.writer_instance
class FakeSinkWriter(NativeSinkWriter):
"""A fake sink writer for testing."""
def __init__(self, written_values):
self.written_values = written_values
def __enter__(self):
return self
def __exit__(self, *unused_args):
pass
def Write(self, value):
self.written_values.append(value)
p = TestPipeline()
sink = FakeSink()
p | Create(['a', 'b', 'c']) | _NativeWrite(sink) # pylint: disable=expression-not-assigned
p.run()
self.assertEqual(['a', 'b', 'c'], sorted(sink.written_values))
class Test_NativeWrite(unittest.TestCase):
def setUp(self):
self.native_sink = NativeSink()
self.native_write = _NativeWrite(self.native_sink)
def test_expand_method_pcollection_errors(self):
with self.assertRaises(error.TransformError):
self.native_write.expand(None)
with self.assertRaises(error.TransformError):
pcoll = pvalue.PCollection(pipeline=None)
self.native_write.expand(pcoll)
if __name__ == '__main__':
unittest.main()