#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Runs integration tests in the tests directory."""
import contextlib
import copy
import glob
import itertools
import logging
import os
import unittest
import uuid

import mock
import yaml

import apache_beam as beam
from apache_beam.io import filesystems
from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.utils import python_callable
from apache_beam.yaml import yaml_provider
from apache_beam.yaml import yaml_transform


@contextlib.contextmanager
def gcs_temp_dir(bucket):
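  """Yields a unique temporary GCS path under the given bucket, deleting it on
  exit.
  """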
  gcs_tempdir = bucket + '/yaml-' + str(uuid.uuid4())
  yield gcs_tempdir
  filesystems.FileSystems.delete([gcs_tempdir])


@contextlib.contextmanager
def temp_bigquery_table(project, prefix='yaml_bq_it_'):
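  """Creates a temporary BigQuery dataset in the given project, yields a table
  reference inside it, and deletes the whole dataset on exit.
  """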
  bigquery_client = BigQueryWrapper()
  dataset_id = '%s_%s' % (prefix, uuid.uuid4().hex)
  bigquery_client.get_or_create_dataset(project, dataset_id)
  logging.info("Created dataset %s in project %s", dataset_id, project)
  yield f'{project}:{dataset_id}.tmp_table'
  request = bigquery.BigqueryDatasetsDeleteRequest(
      projectId=project, datasetId=dataset_id, deleteContents=True)
  logging.info("Deleting dataset %s in project %s", dataset_id, project)
  bigquery_client.client.datasets.Delete(request)


def replace_recursive(spec, vars):
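  """Recursively substitutes {placeholder} strings in the given spec via
  str.format(**vars), leaving multi-line strings (such as inline code) as-is.
  """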
  if isinstance(spec, dict):
    return {
        key: replace_recursive(value, vars)
        for (key, value) in spec.items()
    }
  elif isinstance(spec, list):
    return [replace_recursive(value, vars) for value in spec]
  elif isinstance(spec, str) and '{' in spec and '{\n' not in spec:
    try:
      return spec.format(**vars)
    except Exception as exn:
      raise ValueError(f"Error evaluating {spec}: {exn}") from exn
  else:
    return spec


def transform_types(spec):
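  """Yields the types of all leaf transforms in the given (preprocessed)
  pipeline spec, recursing into composite and chain transforms.
  """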
  if spec.get('type', None) in (None, 'composite', 'chain'):
    if 'source' in spec:
      yield from transform_types(spec['source'])
    for t in spec.get('transforms', []):
      yield from transform_types(t)
    if 'sink' in spec:
      yield from transform_types(spec['sink'])
  else:
    yield spec['type']


def provider_sets(spec, require_available=False):
"""For transforms that are vended by multiple providers, yields all possible
combinations of providers to use.
"""
  try:
    for p in spec['pipelines']:
      _ = yaml_transform.preprocess(copy.deepcopy(p['pipeline']))
  except Exception as exn:
    print(exn)
    all_transform_types = []
  else:
    all_transform_types = set.union(
        *(
            set(
                transform_types(
                    yaml_transform.preprocess(copy.deepcopy(p['pipeline']))))
            for p in spec['pipelines']))

  def filter_to_available(t, providers):
    if require_available:
      for p in providers:
        if not p.available():
          raise ValueError(f"Provider {p} required for {t} is not available.")
      return providers
    else:
      return [p for p in providers if p.available()]

  standard_providers = yaml_provider.standard_providers()
  multiple_providers = {
      t: filter_to_available(t, standard_providers[t])
      for t in all_transform_types
      if len(filter_to_available(t, standard_providers[t])) > 1
  }
  if not multiple_providers:
    yield 'only', standard_providers
  else:
    names, provider_lists = zip(*sorted(multiple_providers.items()))
    for ix, c in enumerate(itertools.product(*provider_lists)):
      yield (
          '_'.join(
              n + '_' + type(p.underlying_provider()).__name__
              for (n, p) in zip(names, c)) + f'_{ix}',
          dict(standard_providers, **{n: [p]
                                      for (n, p) in zip(names, c)}))


def create_test_methods(spec):
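  """Yields (name, test_method) pairs, one per provider combination yielded by
  provider_sets, each method running every pipeline in the spec with those
  providers patched in.
  """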
  for suffix, providers in provider_sets(spec):

    def test(self, providers=providers):  # default arg to capture loop value
      vars = {}
      with contextlib.ExitStack() as stack:
        stack.enter_context(
            mock.patch(
                'apache_beam.yaml.yaml_provider.standard_providers',
                lambda: providers))
        for fixture in spec.get('fixtures', []):
          vars[fixture['name']] = stack.enter_context(
              python_callable.PythonCallableWithSource.
              load_from_fully_qualified_name(fixture['type'])(
                  **yaml_transform.SafeLineLoader.strip_metadata(
                      fixture.get('config', {}))))
        for pipeline_spec in spec['pipelines']:
          with beam.Pipeline(options=PipelineOptions(
              pickle_library='cloudpickle',
              **yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
                  'options', {})))) as p:
            yaml_transform.expand_pipeline(
                p, replace_recursive(pipeline_spec, vars))

    yield f'test_{suffix}', test


def parse_test_files(filepattern):
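  """Creates a unittest.TestCase subclass in this module's globals for each
  YAML file matching the given filepattern.
  """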
  for path in glob.glob(filepattern):
    with open(path) as fin:
      suite_name = os.path.splitext(os.path.basename(path))[0].title().replace(
          '-', '') + 'Test'
      print(path, suite_name)
      methods = dict(
          create_test_methods(
              yaml.load(fin, Loader=yaml_transform.SafeLineLoader)))
      globals()[suite_name] = type(suite_name, (unittest.TestCase, ), methods)


logging.getLogger().setLevel(logging.INFO)
parse_test_files(os.path.join(os.path.dirname(__file__), 'tests', '*.yaml'))

if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()