#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Runs the examples from the README.md file."""
import argparse
import logging
import os
import random
import re
import sys
import tempfile
import unittest
import mock
import yaml
from yaml.loader import SafeLoader
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.typehints import trivial_inference
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_testing
from apache_beam.yaml import yaml_transform
from apache_beam.yaml import yaml_utils
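# The classes below are lightweight fakes for transforms that would otherwise
# require external services (a SQL expansion service, Pub/Sub, etc.), so the
# documentation pipelines can be built and run hermetically on a local runner.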
class FakeSql(beam.PTransform):
  def __init__(self, query):
    self.query = query

  def default_label(self):
    return 'Sql'

  def expand(self, inputs):
    if isinstance(inputs, beam.PCollection):
      inputs = {'PCOLLECTION': inputs}

    # This only handles the most basic of queries, trying to infer the output
    # schema...
    m = re.match('select (.*?) from', self.query, flags=re.IGNORECASE)
    if not m:
      raise ValueError(self.query)

    def guess_name_and_type(expr):
      expr = expr.strip().replace('`', '')
      if expr.endswith('*'):
        return 'unknown', str
      parts = expr.split()
      if len(parts) >= 2 and parts[-2].lower() == 'as':
        name = parts[-1]
      elif re.match(r'[\w.]+', parts[0]):
        name = parts[0].split('.')[-1]
      else:
        name = f'expr{hash(expr)}'
      if '(' in expr:
        expr = expr.lower()
        if expr.startswith('count'):
          typ = int
        elif expr.startswith('avg'):
          typ = float
        else:
          typ = str
      elif '+' in expr:
        typ = float
      else:
        part = parts[0]
        if '.' in part:
          table, field = part.split('.')
          typ = inputs[table].element_type.get_type_for(field)
        else:
          typ = next(iter(inputs.values())).element_type.get_type_for(name)
      # Handle optionals more gracefully.
      if (str(typ).startswith('typing.Union[') or
          str(typ).startswith('typing.Optional[')):
        if len(typ.__args__) == 2 and type(None) in typ.__args__:
          typ, = [t for t in typ.__args__ if t is not type(None)]
      return name, typ

    if m.group(1) == '*':
      return next(iter(inputs.values())) | beam.Filter(lambda _: True)
    else:
      output_schema = [
          guess_name_and_type(expr) for expr in m.group(1).split(',')
      ]
      output_element = beam.Row(**{name: typ() for name, typ in output_schema})
      return next(iter(inputs.values())) | beam.Map(
          lambda _: output_element).with_output_types(
              trivial_inference.instance_to_type(output_element))
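# Stands in for the Pub/Sub source used in the docs: emits a single fixed row
# with a timestamp so streaming examples can run as a bounded pipeline.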
class FakeReadFromPubSub(beam.PTransform):
  def __init__(self, topic, format, schema):
    pass

  def expand(self, p):
    data = p | beam.Create([beam.Row(col1='a', col2=1, col3=0.5)])
    result = data | beam.Map(
        lambda row: beam.transforms.window.TimestampedValue(row, 0))
    # TODO(robertwb): Allow this to be inferred.
    result.element_type = data.element_type
    return result


class FakeWriteToPubSub(beam.PTransform):
  def __init__(self, topic, format):
    pass

  def expand(self, pcoll):
    return pcoll


class FakeAggregation(beam.PTransform):
  def __init__(self, **unused_kwargs):
    pass

  def expand(self, pcoll):
    return pcoll | beam.GroupBy(lambda _: 'key').aggregate_field(
        lambda _: 1, sum, 'count')
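# Container for the fake callables and transforms that the docs reference via
# "pkg.module." style paths (rewritten to this module in create_test_method).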
class _Fakes:
  fn = str

  class SomeTransform(beam.PTransform):
    def __init__(self, *args, **kwargs):
      pass

    def expand(self, pcoll):
      return pcoll
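# Directory to write rendered pipeline graphs to (set via --render_dir);
# when None, the example pipelines are actually executed.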
RENDER_DIR = None
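# Transform types used in the documentation examples, mapped to the hermetic
# fakes defined above.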
TEST_TRANSFORMS = {
    'Sql': FakeSql,
    'ReadFromPubSub': FakeReadFromPubSub,
    'WriteToPubSub': FakeWriteToPubSub,
    'SomeGroupingTransform': FakeAggregation,
    'SomeTransform': _Fakes.SomeTransform,
    'AnotherTransform': _Fakes.SomeTransform,
}
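# An InlineProvider that always wins provider selection, ensuring the fake
# transforms above are chosen over any real (e.g. external) implementations.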
class TestProvider(yaml_provider.InlineProvider):
  def _affinity(self, other):
    # Always try to choose this one.
    return float('inf')
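# Context manager that creates throwaway input, output, and UDF files in a
# temporary directory for the duration of a single test.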
class TestEnvironment:
  def __enter__(self):
    self.tempdir = tempfile.TemporaryDirectory()
    return self

  def input_file(self, name, content):
    path = os.path.join(self.tempdir.name, name)
    with open(path, 'w') as fout:
      fout.write(content)
    return path

  def input_csv(self):
    return self.input_file('input.csv', 'col1,col2,col3\na,1,2.5\n')

  def input_tsv(self):
    return self.input_file('input.tsv', 'col1\tcol2\tcol3\nabc\t1\t2.5\n')
  def input_json(self):
    return self.input_file(
        'input.json', '{"col1": "abc", "col2": 1, "col3": 2.5}\n')
  def output_file(self):
    return os.path.join(
        self.tempdir.name, str(random.randint(0, 1000)) + '.out')

  def udf_file(self, name):
    if name == 'my_mapping':
      lines = '\n'.join(['def my_mapping(row):', '\treturn "good"'])
    else:
      lines = '\n'.join(['def my_filter(row):', '\treturn True'])
    return self.input_file('udf.py', lines)

  def __exit__(self, *args):
    self.tempdir.cleanup()
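# Recursively walks a pipeline spec and overrides a single config argument for
# every transform of the given type (e.g. pointing file paths at temp files).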
def replace_recursive(spec, transform_type, arg_name, arg_value):
  if isinstance(spec, dict):
    spec = {
        key: replace_recursive(value, transform_type, arg_name, arg_value)
        for (key, value) in spec.items()
    }
    if spec.get('type', None) == transform_type:
      spec['config'][arg_name] = arg_value
    return spec
  elif isinstance(spec, list):
    return [
        replace_recursive(value, transform_type, arg_name, arg_value)
        for value in spec
    ]
  else:
    return spec
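# Builds a unittest method for one documentation snippet. Depending on
# test_type, the snippet is merely parsed ('PARSE'), expanded into a pipeline
# ('BUILD'), or fully executed ('RUN').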
def create_test_method(test_type, test_name, test_yaml):
  test_yaml = test_yaml.replace(
      'apache_beam.pkg.module.', 'apache_beam.yaml.readme_test._Fakes.')
  test_yaml = test_yaml.replace(
      'pkg.module.', 'apache_beam.yaml.readme_test._Fakes.')

  def test(self):
    with TestEnvironment() as env:
      nonlocal test_yaml
      test_yaml = test_yaml.replace('/path/to/*.tsv', env.input_tsv())
      if 'MapToFields' in test_yaml or 'Filter' in test_yaml:
        if 'my_mapping' in test_yaml:
          test_yaml = test_yaml.replace(
              '/path/to/some/udf.py', env.udf_file('my_mapping'))
        elif 'my_filter' in test_yaml:
          test_yaml = test_yaml.replace(
              '/path/to/some/udf.py', env.udf_file('my_filter'))
      spec = yaml.load(test_yaml, Loader=SafeLoader)
      if test_type == 'PARSE':
        return
      if 'ReadFromCsv' in test_yaml:
        spec = replace_recursive(spec, 'ReadFromCsv', 'path', env.input_csv())
      if 'ReadFromText' in test_yaml:
        spec = replace_recursive(spec, 'ReadFromText', 'path', env.input_csv())
      if 'ReadFromJson' in test_yaml:
        spec = replace_recursive(
            spec, 'ReadFromJson', 'path', env.input_json())
      for write in ['WriteToText', 'WriteToCsv', 'WriteToJson']:
        if write in test_yaml:
          spec = replace_recursive(spec, write, 'path', env.output_file())
      modified_yaml = yaml.dump(spec)
      options = {'pickle_library': 'cloudpickle'}
      if RENDER_DIR is not None:
        options['runner'] = 'apache_beam.runners.render.RenderRunner'
        options['render_output'] = [
            os.path.join(RENDER_DIR, test_name + '.png')
        ]
        options['render_leaf_composite_nodes'] = ['.*']
      test_provider = TestProvider(TEST_TRANSFORMS)
      with mock.patch(
          'apache_beam.yaml.yaml_provider.SqlBackedProvider.sql_provider',
          lambda self: test_provider):
        # TODO(polber) - remove once there is support for ExternalTransforms
        # in precommits
        with mock.patch(
            'apache_beam.yaml.yaml_provider.ExternalProvider.create_transform',
            lambda *args, **kwargs: _Fakes.SomeTransform(*args, **kwargs)):
          # Uses the FnApiRunner to ensure errors are mocked/passed through
          # correctly
          p = beam.Pipeline('FnApiRunner', options=PipelineOptions(**options))
          yaml_transform.expand_pipeline(
              p, modified_yaml, yaml_provider.merge_providers([test_provider]))
          if test_type == 'BUILD':
            return
          p.run().wait_until_finish()

  return test
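# Scans the markdown for ``` code fences, reconstructs each YAML snippet into a
# complete pipeline spec (adding placeholder sources for fragments), and yields
# (test_name, test_method) pairs for the generated test case.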
def parse_test_methods(markdown_lines):
  # pylint: disable=too-many-nested-blocks
  def extract_inputs(input_spec):
    if not input_spec:
      return set()
    elif isinstance(input_spec, str):
      return set([input_spec.split('.')[0]])
    elif isinstance(input_spec, list):
      return set.union(*[extract_inputs(v) for v in input_spec])
    elif isinstance(input_spec, dict):
      return set.union(*[extract_inputs(v) for v in input_spec.values()])
    else:
raise ValueError("Misformed inputs: " + input_spec)
  def extract_name(input_spec):
    return input_spec.get('name', input_spec.get('type'))

  code_lines = None
  last_pipeline = None
  for ix, line in enumerate(markdown_lines):
    line = line.rstrip()
    if line == '```':
      if code_lines is None:
        code_lines = []
        test_type = 'RUN'
        test_name = f'test_line_{ix + 2}'
      else:
        if code_lines:
          if code_lines[0].startswith('- type:'):
            specs = yaml.load('\n'.join(code_lines), Loader=SafeLoader)
            if 'dependencies:' in specs:
              test_type = 'PARSE'
            is_chain = not any('input' in spec for spec in specs)
            if is_chain:
              undefined_inputs = set(['input'])
            else:
              undefined_inputs = set.union(
                  *[extract_inputs(spec.get('input')) for spec in specs]) - set(
                      extract_name(spec) for spec in specs)
            # Treat this as a fragment of a larger pipeline.
            # pylint: disable=not-an-iterable
            code_lines = [
                'pipeline:',
                '  type: chain' if is_chain else '',
                '  transforms:',
            ] + [
                '    - {type: ReadFromCsv, name: "%s", config: {path: x}}' %
                undefined_input for undefined_input in undefined_inputs
            ] + ['    ' + line for line in code_lines]
          if code_lines[0] == 'pipeline:':
            yaml_pipeline = '\n'.join(code_lines)
            last_pipeline = yaml_pipeline
            if 'providers:' in yaml_pipeline or 'tests:' in yaml_pipeline:
              test_type = 'PARSE'
            yield test_name, create_test_method(
                test_type, test_name, yaml_pipeline)
          if 'tests:' in code_lines:
            test_spec = '\n'.join(code_lines)
            if code_lines[0] == 'pipeline:':
              yaml_pipeline = '\n'.join(code_lines)
            else:
              yaml_pipeline = last_pipeline
            for sub_ix, test_spec in enumerate(yaml.load(
                '\n'.join(code_lines),
                Loader=yaml_utils.SafeLineLoader)['tests']):
              suffix = test_spec.get('name', str(sub_ix))
              yield (
                  test_name + '_' + suffix,
                  # The yp=... ts=... is to capture the looped closure values.
                  lambda _, yp=yaml_pipeline, ts=test_spec: yaml_testing.
                  run_test(yp, ts))
        code_lines = None
    elif code_lines is not None:
      code_lines.append(line)
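# Dynamically builds a unittest.TestCase subclass whose test methods are the
# examples extracted from the given markdown file.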
def createTestSuite(name, path):
  with open(path) as readme:
    return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))
# These are copied from $ROOT/website/www/site/content/en/documentation/sdks
# at build time.
YAML_DOCS_DIR = os.path.join(os.path.dirname(__file__), 'docs')
ReadMeTest = createTestSuite(
    'ReadMeTest', os.path.join(YAML_DOCS_DIR, 'yaml.md'))
ErrorHandlingTest = createTestSuite(
    'ErrorHandlingTest', os.path.join(YAML_DOCS_DIR, 'yaml-errors.md'))
MappingTest = createTestSuite(
    'MappingTest', os.path.join(YAML_DOCS_DIR, 'yaml-udf.md'))
CombineTest = createTestSuite(
    'CombineTest', os.path.join(YAML_DOCS_DIR, 'yaml-combine.md'))
InlinePythonTest = createTestSuite(
    'InlinePythonTest', os.path.join(YAML_DOCS_DIR, 'yaml-inline-python.md'))
JoinTest = createTestSuite(
    'JoinTest', os.path.join(YAML_DOCS_DIR, 'yaml-join.md'))
TestingTest = createTestSuite(
    'TestingTest', os.path.join(YAML_DOCS_DIR, 'yaml-testing.md'))
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--render_dir', default=None)
  known_args, unknown_args = parser.parse_known_args(sys.argv)
  if known_args.render_dir:
    RENDER_DIR = known_args.render_dir
  logging.getLogger().setLevel(logging.INFO)
  unittest.main(argv=unknown_args)