| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """Runs the examples from the README.md file.""" |
| |
| import argparse |
| import logging |
| import os |
| import random |
| import re |
| import sys |
| import tempfile |
| import unittest |
| |
| import mock |
| import yaml |
| from yaml.loader import SafeLoader |
| |
| import apache_beam as beam |
| from apache_beam.options.pipeline_options import PipelineOptions |
| from apache_beam.typehints import trivial_inference |
| from apache_beam.yaml import yaml_provider |
| from apache_beam.yaml import yaml_testing |
| from apache_beam.yaml import yaml_transform |
| from apache_beam.yaml import yaml_utils |
| |
| |
| class FakeSql(beam.PTransform): |
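  """Stand-in for SqlTransform that guesses an output schema from the query.

  Only the most basic SELECT statements are recognized; this lets the
  documentation examples run without a SQL expansion service.
  """
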
| def __init__(self, query): |
| self.query = query |
| |
| def default_label(self): |
| return 'Sql' |
| |
| def expand(self, inputs): |
| if isinstance(inputs, beam.PCollection): |
| inputs = {'PCOLLECTION': inputs} |
    # This only handles the most basic queries, inferring the output schema
    # from the select list.
| m = re.match('select (.*?) from', self.query, flags=re.IGNORECASE) |
| if not m: |
| raise ValueError(self.query) |
| |
| def guess_name_and_type(expr): |
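      """Heuristically infers a (name, type) pair for one select expression."""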
| expr = expr.strip().replace('`', '') |
| if expr.endswith('*'): |
| return 'unknown', str |
| parts = expr.split() |
| if len(parts) >= 2 and parts[-2].lower() == 'as': |
| name = parts[-1] |
      elif re.match(r'[\w.]+$', parts[0]):
| name = parts[0].split('.')[-1] |
| else: |
| name = f'expr{hash(expr)}' |
| if '(' in expr: |
| expr = expr.lower() |
| if expr.startswith('count'): |
| typ = int |
| elif expr.startswith('avg'): |
| typ = float |
| else: |
| typ = str |
| elif '+' in expr: |
| typ = float |
| else: |
| part = parts[0] |
| if '.' in part: |
| table, field = part.split('.') |
| typ = inputs[table].element_type.get_type_for(field) |
| else: |
| typ = next(iter(inputs.values())).element_type.get_type_for(name) |
| # Handle optionals more gracefully. |
| if (str(typ).startswith('typing.Union[') or |
| str(typ).startswith('typing.Optional[')): |
| if len(typ.__args__) == 2 and type(None) in typ.__args__: |
| typ, = [t for t in typ.__args__ if t is not type(None)] |
| return name, typ |
| |
| if m.group(1) == '*': |
| return next(iter(inputs.values())) | beam.Filter(lambda _: True) |
| else: |
| output_schema = [ |
| guess_name_and_type(expr) for expr in m.group(1).split(',') |
| ] |
| output_element = beam.Row(**{name: typ() for name, typ in output_schema}) |
| return next(iter(inputs.values())) | beam.Map( |
| lambda _: output_element).with_output_types( |
| trivial_inference.instance_to_type(output_element)) |
| |
| |
| class FakeReadFromPubSub(beam.PTransform): |
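  """Stand-in for ReadFromPubSub that emits a single timestamped test row."""
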
| def __init__(self, topic, format, schema): |
| pass |
| |
| def expand(self, p): |
| data = p | beam.Create([beam.Row(col1='a', col2=1, col3=0.5)]) |
| result = data | beam.Map( |
| lambda row: beam.transforms.window.TimestampedValue(row, 0)) |
| # TODO(robertwb): Allow this to be inferred. |
| result.element_type = data.element_type |
| return result |
| |
| |
| class FakeWriteToPubSub(beam.PTransform): |
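  """Stand-in for WriteToPubSub that simply passes its input through."""
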
| def __init__(self, topic, format): |
| pass |
| |
| def expand(self, pcoll): |
| return pcoll |
| |
| |
| class FakeAggregation(beam.PTransform): |
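  """Stand-in for a grouping transform that emits a single global count."""
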
| def __init__(self, **unused_kwargs): |
| pass |
| |
| def expand(self, pcoll): |
| return pcoll | beam.GroupBy(lambda _: 'key').aggregate_field( |
| lambda _: 1, sum, 'count') |
| |
| |
| class _Fakes: |
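  """Stand-ins for the pkg.module.* callables referenced in the docs."""
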
  fn = str  # Stands in for an arbitrary callable such as pkg.module.fn.
| |
| class SomeTransform(beam.PTransform): |
    # No explicit self: absorbs whatever construction arguments are given.
    def __init__(*args, **kwargs):
| pass |
| |
| def expand(self, pcoll): |
| return pcoll |
| |
| |
RENDER_DIR = None  # When set via --render_dir, pipelines are rendered to PNGs.
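# Maps transform type names used in the documentation to the fake
# implementations above, so the examples can run without external services.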
| TEST_TRANSFORMS = { |
| 'Sql': FakeSql, |
| 'ReadFromPubSub': FakeReadFromPubSub, |
| 'WriteToPubSub': FakeWriteToPubSub, |
| 'SomeGroupingTransform': FakeAggregation, |
| 'SomeTransform': _Fakes.SomeTransform, |
| 'AnotherTransform': _Fakes.SomeTransform, |
| } |
| |
| |
| class TestProvider(yaml_provider.InlineProvider): |
| def _affinity(self, other): |
| # Always try to choose this one. |
| return float('inf') |
| |
| |
| class TestEnvironment: |
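  """Context manager providing temporary input and output files for tests."""
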
| def __enter__(self): |
| self.tempdir = tempfile.TemporaryDirectory() |
| return self |
| |
| def input_file(self, name, content): |
| path = os.path.join(self.tempdir.name, name) |
| with open(path, 'w') as fout: |
| fout.write(content) |
| return path |
| |
| def input_csv(self): |
| return self.input_file('input.csv', 'col1,col2,col3\na,1,2.5\n') |
| |
| def input_tsv(self): |
| return self.input_file('input.tsv', 'col1\tcol2\tcol3\nabc\t1\t2.5\n') |
| |
| def input_json(self): |
| return self.input_file( |
        'input.json', '{"col1": "abc", "col2": 1, "col3": 2.5}\n')
| |
| def output_file(self): |
| return os.path.join( |
| self.tempdir.name, str(random.randint(0, 1000)) + '.out') |
| |
| def udf_file(self, name): |
| if name == 'my_mapping': |
| lines = '\n'.join(['def my_mapping(row):', '\treturn "good"']) |
| else: |
| lines = '\n'.join(['def my_filter(row):', '\treturn True']) |
| return self.input_file('udf.py', lines) |
| |
| def __exit__(self, *args): |
| self.tempdir.cleanup() |
| |
| |
| def replace_recursive(spec, transform_type, arg_name, arg_value): |
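  """Recursively sets config[arg_name] on every transform of the given type.

  For example (with an illustrative path), replace_recursive(
      spec, 'ReadFromCsv', 'path', '/tmp/input.csv')
  rewrites the path of every ReadFromCsv transform in the pipeline spec.
  """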
| if isinstance(spec, dict): |
| spec = { |
| key: replace_recursive(value, transform_type, arg_name, arg_value) |
| for (key, value) in spec.items() |
| } |
| if spec.get('type', None) == transform_type: |
| spec['config'][arg_name] = arg_value |
| return spec |
| elif isinstance(spec, list): |
| return [ |
| replace_recursive(value, transform_type, arg_name, arg_value) |
| for value in spec |
| ] |
| else: |
| return spec |
| |
| |
| def create_test_method(test_type, test_name, test_yaml): |
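  """Returns a test method for the given YAML snippet.

  Depending on test_type, the method only parses the YAML (PARSE), also
  constructs the pipeline (BUILD), or executes it as well (RUN), with fake
  transforms and providers substituted for external services.
  """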
| test_yaml = test_yaml.replace( |
| 'apache_beam.pkg.module.', 'apache_beam.yaml.readme_test._Fakes.') |
| test_yaml = test_yaml.replace( |
| 'pkg.module.', 'apache_beam.yaml.readme_test._Fakes.') |
| |
| def test(self): |
| with TestEnvironment() as env: |
| nonlocal test_yaml |
| test_yaml = test_yaml.replace('/path/to/*.tsv', env.input_tsv()) |
| if 'MapToFields' in test_yaml or 'Filter' in test_yaml: |
| if 'my_mapping' in test_yaml: |
| test_yaml = test_yaml.replace( |
| '/path/to/some/udf.py', env.udf_file('my_mapping')) |
| elif 'my_filter' in test_yaml: |
| test_yaml = test_yaml.replace( |
| '/path/to/some/udf.py', env.udf_file('my_filter')) |
| spec = yaml.load(test_yaml, Loader=SafeLoader) |
| if test_type == 'PARSE': |
| return |
| if 'ReadFromCsv' in test_yaml: |
| spec = replace_recursive(spec, 'ReadFromCsv', 'path', env.input_csv()) |
| if 'ReadFromText' in test_yaml: |
| spec = replace_recursive(spec, 'ReadFromText', 'path', env.input_csv()) |
| if 'ReadFromJson' in test_yaml: |
| spec = replace_recursive(spec, 'ReadFromJson', 'path', env.input_json()) |
| for write in ['WriteToText', 'WriteToCsv', 'WriteToJson']: |
| if write in test_yaml: |
| spec = replace_recursive(spec, write, 'path', env.output_file()) |
| modified_yaml = yaml.dump(spec) |
| options = {'pickle_library': 'cloudpickle'} |
| if RENDER_DIR is not None: |
| options['runner'] = 'apache_beam.runners.render.RenderRunner' |
| options['render_output'] = [ |
| os.path.join(RENDER_DIR, test_name + '.png') |
| ] |
| options['render_leaf_composite_nodes'] = ['.*'] |
| test_provider = TestProvider(TEST_TRANSFORMS) |
| with mock.patch( |
| 'apache_beam.yaml.yaml_provider.SqlBackedProvider.sql_provider', |
| lambda self: test_provider): |
| # TODO(polber) - remove once there is support for ExternalTransforms |
| # in precommits |
| with mock.patch( |
| 'apache_beam.yaml.yaml_provider.ExternalProvider.create_transform', |
| lambda *args, **kwargs: _Fakes.SomeTransform(*args, **kwargs)): |
| # Uses the FnApiRunner to ensure errors are mocked/passed through |
| # correctly |
| p = beam.Pipeline('FnApiRunner', options=PipelineOptions(**options)) |
| yaml_transform.expand_pipeline( |
| p, modified_yaml, yaml_provider.merge_providers([test_provider])) |
| if test_type == 'BUILD': |
| return |
| p.run().wait_until_finish() |
| |
| return test |
| |
| |
| def parse_test_methods(markdown_lines): |
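  """Yields (test_name, test_method) pairs for each code block in the lines.

  Blocks that are bare transform lists (starting with '- type:') get wrapped
  in a synthetic pipeline, with placeholder reads for any undefined inputs;
  blocks containing a tests: section yield one additional test per test spec.
  """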
| # pylint: disable=too-many-nested-blocks |
| |
| def extract_inputs(input_spec): |
| if not input_spec: |
| return set() |
| elif isinstance(input_spec, str): |
| return set([input_spec.split('.')[0]]) |
| elif isinstance(input_spec, list): |
| return set.union(*[extract_inputs(v) for v in input_spec]) |
| elif isinstance(input_spec, dict): |
| return set.union(*[extract_inputs(v) for v in input_spec.values()]) |
| else: |
| raise ValueError("Misformed inputs: " + input_spec) |
| |
| def extract_name(input_spec): |
| return input_spec.get('name', input_spec.get('type')) |
| |
| code_lines = None |
| last_pipeline = None |
| for ix, line in enumerate(markdown_lines): |
| line = line.rstrip() |
| if line == '```': |
| if code_lines is None: |
| code_lines = [] |
| test_type = 'RUN' |
        test_name = f'test_line_{ix + 2}'  # First line of the code block.
| else: |
| if code_lines: |
| if code_lines[0].startswith('- type:'): |
| specs = yaml.load('\n'.join(code_lines), Loader=SafeLoader) |
| if 'dependencies:' in specs: |
| test_type = 'PARSE' |
| is_chain = not any('input' in spec for spec in specs) |
| if is_chain: |
| undefined_inputs = set(['input']) |
| else: |
| undefined_inputs = set.union( |
| *[extract_inputs(spec.get('input')) for spec in specs]) - set( |
| extract_name(spec) for spec in specs) |
| # Treat this as a fragment of a larger pipeline. |
| # pylint: disable=not-an-iterable |
| code_lines = [ |
| 'pipeline:', |
| ' type: chain' if is_chain else '', |
| ' transforms:', |
| ] + [ |
| ' - {type: ReadFromCsv, name: "%s", config: {path: x}}' % |
| undefined_input for undefined_input in undefined_inputs |
| ] + [' ' + line for line in code_lines] |
| if code_lines[0] == 'pipeline:': |
| yaml_pipeline = '\n'.join(code_lines) |
| last_pipeline = yaml_pipeline |
| if 'providers:' in yaml_pipeline or 'tests:' in yaml_pipeline: |
| test_type = 'PARSE' |
| yield test_name, create_test_method( |
| test_type, |
| test_name, |
| yaml_pipeline) |
| if 'tests:' in code_lines: |
| if code_lines[0] == 'pipeline:': |
| yaml_pipeline = '\n'.join(code_lines) |
| else: |
| yaml_pipeline = last_pipeline |
| for sub_ix, test_spec in enumerate(yaml.load( |
| '\n'.join(code_lines), |
| Loader=yaml_utils.SafeLineLoader)['tests']): |
| suffix = test_spec.get('name', str(sub_ix)) |
| yield ( |
| test_name + '_' + suffix, |
                  # Default arguments capture the loop variables by value.
| lambda _, yp=yaml_pipeline, ts=test_spec: yaml_testing. |
| run_test(yp, ts)) |
| |
| code_lines = None |
| elif code_lines is not None: |
| code_lines.append(line) |
| |
| |
| def createTestSuite(name, path): |
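  """Creates a TestCase class whose test methods come from a markdown file."""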
| with open(path) as readme: |
| return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme))) |
| |
| |
| # These are copied from $ROOT/website/www/site/content/en/documentation/sdks |
| # at build time. |
YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__), 'docs')
| |
| ReadMeTest = createTestSuite( |
| 'ReadMeTest', os.path.join(YAML_DOCS_DIR, 'yaml.md')) |
| |
| ErrorHandlingTest = createTestSuite( |
| 'ErrorHandlingTest', os.path.join(YAML_DOCS_DIR, 'yaml-errors.md')) |
| |
| MappingTest = createTestSuite( |
| 'MappingTest', os.path.join(YAML_DOCS_DIR, 'yaml-udf.md')) |
| |
| CombineTest = createTestSuite( |
| 'CombineTest', os.path.join(YAML_DOCS_DIR, 'yaml-combine.md')) |
| |
| InlinePythonTest = createTestSuite( |
| 'InlinePythonTest', os.path.join(YAML_DOCS_DIR, 'yaml-inline-python.md')) |
| |
| JoinTest = createTestSuite( |
| 'JoinTest', os.path.join(YAML_DOCS_DIR, 'yaml-join.md')) |
| |
| TestingTest = createTestSuite( |
| 'TestingTest', os.path.join(YAML_DOCS_DIR, 'yaml-testing.md')) |
| |
| if __name__ == '__main__': |
| parser = argparse.ArgumentParser() |
| parser.add_argument('--render_dir', default=None) |
| known_args, unknown_args = parser.parse_known_args(sys.argv) |
| if known_args.render_dir: |
| RENDER_DIR = known_args.render_dir |
| logging.getLogger().setLevel(logging.INFO) |
| unittest.main(argv=unknown_args) |