#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Runs the examples from the README.md file."""
import argparse
import logging
import os
import random
import re
import sys
import tempfile
import unittest
import mock
import yaml
from yaml.loader import SafeLoader
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.typehints import trivial_inference
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform
class FakeSql(beam.PTransform):
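  """A stand-in for the Sql transform that guesses an output schema.

  Only the most basic SELECT queries are supported; see expand() below.
  """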
def __init__(self, query):
self.query = query
def default_label(self):
return 'Sql'
def expand(self, inputs):
if isinstance(inputs, beam.PCollection):
inputs = {'PCOLLECTION': inputs}
# This only handles the most basic of queries, trying to infer the output
# schema...
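    # For example, "SELECT col1, count(col2) AS c FROM PCOLLECTION" yields an
    # output schema of [('col1', <input element type of col1>), ('c', int)].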
m = re.match('select (.*?) from', self.query, flags=re.IGNORECASE)
if not m:
raise ValueError(self.query)
def guess_name_and_type(expr):
expr = expr.strip()
parts = expr.split()
if len(parts) >= 2 and parts[-2].lower() == 'as':
name = parts[-1]
elif re.match(r'[\w.]+', parts[0]):
name = parts[0].split('.')[-1]
else:
name = f'expr{hash(expr)}'
if '(' in expr:
expr = expr.lower()
if expr.startswith('count'):
typ = int
elif expr.startswith('avg'):
typ = float
else:
typ = str
elif '+' in expr:
typ = float
else:
part = parts[0]
if '.' in part:
table, field = part.split('.')
typ = inputs[table].element_type.get_type_for(field)
else:
typ = next(iter(inputs.values())).element_type.get_type_for(name)
# Handle optionals more gracefully.
if (str(typ).startswith('typing.Union[') or
str(typ).startswith('typing.Optional[')):
if len(typ.__args__) == 2 and type(None) in typ.__args__:
typ, = [t for t in typ.__args__ if t is not type(None)]
return name, typ
if m.group(1) == '*':
return inputs['PCOLLECTION'] | beam.Filter(lambda _: True)
else:
output_schema = [
guess_name_and_type(expr) for expr in m.group(1).split(',')
]
output_element = beam.Row(**{name: typ() for name, typ in output_schema})
return next(iter(inputs.values())) | beam.Map(
lambda _: output_element).with_output_types(
trivial_inference.instance_to_type(output_element))
class FakeReadFromPubSub(beam.PTransform):
def __init__(self, topic, format, schema):
pass
def expand(self, p):
data = p | beam.Create([beam.Row(col1='a', col2=1, col3=0.5)])
result = data | beam.Map(
lambda row: beam.transforms.window.TimestampedValue(row, 0))
# TODO(robertwb): Allow this to be inferred.
result.element_type = data.element_type
return result
class FakeWriteToPubSub(beam.PTransform):
def __init__(self, topic, format):
pass
def expand(self, pcoll):
return pcoll
class FakeAggregation(beam.PTransform):
def __init__(self, **unused_kwargs):
pass
def expand(self, pcoll):
return pcoll | beam.GroupBy(lambda _: 'key').aggregate_field(
lambda _: 1, sum, 'count')
RENDER_DIR = None
TEST_TRANSFORMS = {
'Sql': FakeSql,
'ReadFromPubSub': FakeReadFromPubSub,
'WriteToPubSub': FakeWriteToPubSub,
'SomeGroupingTransform': FakeAggregation,
}
class TestProvider(yaml_provider.InlineProvider):
def _affinity(self, other):
# Always try to choose this one.
return float('inf')
class TestEnvironment:
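  """Context manager providing temporary input and output files for tests.

  Used to substitute real paths for the placeholder paths in the docs.
  """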
def __enter__(self):
self.tempdir = tempfile.TemporaryDirectory()
return self
def input_file(self, name, content):
path = os.path.join(self.tempdir.name, name)
with open(path, 'w') as fout:
fout.write(content)
return path
def input_csv(self):
return self.input_file('input.csv', 'col1,col2,col3\nabc,1,2.5\n')
def input_tsv(self):
return self.input_file('input.tsv', 'col1\tcol2\tcol3\nabc\t1\t2.5\n')
def input_json(self):
return self.input_file(
        'input.json', '{"col1": "abc", "col2": 1, "col3": 2.5}\n')
def output_file(self):
return os.path.join(
self.tempdir.name, str(random.randint(0, 1000)) + '.out')
def __exit__(self, *args):
self.tempdir.cleanup()
def replace_recursive(spec, transform_type, arg_name, arg_value):
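  """Recursively sets the given argument on all transforms of a given type.

  For example (paths here are illustrative only):

      replace_recursive(
          {'type': 'ReadFromCsv', 'config': {'path': '/path/to/input.csv'}},
          'ReadFromCsv', 'path', '/tmp/actual_input.csv')

  returns {'type': 'ReadFromCsv', 'config': {'path': '/tmp/actual_input.csv'}}.
  """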
if isinstance(spec, dict):
spec = {
key: replace_recursive(value, transform_type, arg_name, arg_value)
for (key, value) in spec.items()
}
if spec.get('type', None) == transform_type:
spec['config'][arg_name] = arg_value
return spec
elif isinstance(spec, list):
return [
replace_recursive(value, transform_type, arg_name, arg_value)
for value in spec
]
else:
return spec
def create_test_method(test_type, test_name, test_yaml):
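  """Builds a unittest method exercising a single yaml snippet from the docs.

  test_type is one of PARSE (just check that the yaml parses), BUILD
  (construct the pipeline without running it), or RUN (execute it end-to-end).
  """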
test_yaml = test_yaml.replace(
'apache_beam.pkg.module.', 'apache_beam.yaml.readme_test._Fakes.')
test_yaml = test_yaml.replace(
'pkg.module.', 'apache_beam.yaml.readme_test._Fakes.')
def test(self):
with TestEnvironment() as env:
nonlocal test_yaml
test_yaml = test_yaml.replace('/path/to/*.tsv', env.input_tsv())
spec = yaml.load(test_yaml, Loader=SafeLoader)
if test_type == 'PARSE':
return
if 'ReadFromCsv' in test_yaml:
spec = replace_recursive(spec, 'ReadFromCsv', 'path', env.input_csv())
if 'ReadFromText' in test_yaml:
spec = replace_recursive(spec, 'ReadFromText', 'path', env.input_csv())
if 'ReadFromJson' in test_yaml:
spec = replace_recursive(spec, 'ReadFromJson', 'path', env.input_json())
for write in ['WriteToText', 'WriteToCsv', 'WriteToJson']:
if write in test_yaml:
spec = replace_recursive(spec, write, 'path', env.output_file())
modified_yaml = yaml.dump(spec)
options = {
'pickle_library': 'cloudpickle',
'yaml_experimental_features': ['Combine']
}
if RENDER_DIR is not None:
options['runner'] = 'apache_beam.runners.render.RenderRunner'
options['render_output'] = [
os.path.join(RENDER_DIR, test_name + '.png')
]
options['render_leaf_composite_nodes'] = ['.*']
test_provider = TestProvider(TEST_TRANSFORMS)
with mock.patch(
'apache_beam.yaml.yaml_provider.SqlBackedProvider.sql_provider',
lambda self: test_provider):
p = beam.Pipeline(options=PipelineOptions(**options))
yaml_transform.expand_pipeline(
p, modified_yaml, yaml_provider.merge_providers([test_provider]))
if test_type == 'BUILD':
return
p.run().wait_until_finish()
return test
def parse_test_methods(markdown_lines):
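  """Scans markdown lines, yielding a (test_name, test_method) per code block.

  Code blocks whose first line starts with '- type:' are treated as fragments
  and wrapped in a minimal read-from-csv pipeline; blocks that declare their
  own providers are only parsed, not run. Blocks that do not form a complete
  pipeline spec are skipped.
  """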
# pylint: disable=too-many-nested-blocks
code_lines = None
for ix, line in enumerate(markdown_lines):
line = line.rstrip()
if line == '```':
if code_lines is None:
code_lines = []
test_type = 'RUN'
test_name = f'test_line_{ix + 2}'
else:
if code_lines:
if code_lines[0].startswith('- type:'):
# Treat this as a fragment of a larger pipeline.
# pylint: disable=not-an-iterable
code_lines = [
'pipeline:',
' type: chain',
' transforms:',
' - type: ReadFromCsv',
' config:',
' path: whatever',
] + [' ' + line for line in code_lines]
if code_lines[0] == 'pipeline:':
yaml_pipeline = '\n'.join(code_lines)
if 'providers:' in yaml_pipeline:
test_type = 'PARSE'
yield test_name, create_test_method(
test_type,
test_name,
yaml_pipeline)
code_lines = None
elif code_lines is not None:
code_lines.append(line)
def createTestSuite(name, path):
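  """Creates a TestCase class whose methods are the tests parsed from path."""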
with open(path) as readme:
return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))
class _Fakes:
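  """Stand-ins for the pkg.module.* references used in the doc snippets."""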
fn = str
class SomeTransform(beam.PTransform):
    def __init__(self, *args, **kwargs):
pass
def expand(self, pcoll):
return pcoll
# These are copied from $ROOT/website/www/site/content/en/documentation/sdks
# at build time.
YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__), 'docs')
ReadMeTest = createTestSuite(
'ReadMeTest', os.path.join(YAML_DOCS_DIR, 'yaml.md'))
ErrorHandlingTest = createTestSuite(
'ErrorHandlingTest', os.path.join(YAML_DOCS_DIR, 'yaml-errors.md'))
MappingTest = createTestSuite(
'MappingTest', os.path.join(YAML_DOCS_DIR, 'yaml-udf.md'))
CombineTest = createTestSuite(
'CombineTest', os.path.join(YAML_DOCS_DIR, 'yaml-combine.md'))
InlinePythonTest = createTestSuite(
'InlinePythonTest', os.path.join(YAML_DOCS_DIR, 'yaml-inline-python.md'))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--render_dir', default=None)
known_args, unknown_args = parser.parse_known_args(sys.argv)
if known_args.render_dir:
RENDER_DIR = known_args.render_dir
logging.getLogger().setLevel(logging.INFO)
unittest.main(argv=unknown_args)