[YAML] Allow explicitly including external provider lists. (#31604)
This can be more useful than jinja's {% include %} as it can refer
to urls, does not have to assume perfect indentation of the included
file, and avoids applying templitization to the included file.diff --git a/sdks/python/apache_beam/yaml/pipeline.schema.yaml b/sdks/python/apache_beam/yaml/pipeline.schema.yaml
index 40f576c..f68a730 100644
--- a/sdks/python/apache_beam/yaml/pipeline.schema.yaml
+++ b/sdks/python/apache_beam/yaml/pipeline.schema.yaml
@@ -154,6 +154,27 @@
- transforms
- config
+ providerInclude:
+ # TODO(robertwb): Consider enumerating the provider types along with
+ # the arguments they accept/expect (possibly in a separate schema file).
+ type: object
+ properties:
+ include: { type: string }
+ __line__: {}
+ __uuid__: {}
+ additionalProperties: false
+ required:
+ - include
+
+ providerOrProviderInclude:
+ if:
+ properties:
+ include {}
+ then:
+ $ref: '#/$defs/providerInclude'
+ else:
+ $ref: '#/$defs/provider'
+
type: object
properties:
pipeline:
@@ -185,7 +206,7 @@
providers:
type: array
items:
- $ref: '#/$defs/provider'
+ $ref: '#/$defs/providerOrProviderInclude'
options:
type: object
required:
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index 794cad0..46ec070 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -34,6 +34,7 @@
from typing import Callable
from typing import Dict
from typing import Iterable
+from typing import Iterator
from typing import Mapping
from typing import Optional
@@ -45,6 +46,7 @@
import apache_beam.dataframe.io
import apache_beam.io
import apache_beam.transforms.util
+from apache_beam.io.filesystems import FileSystems
from apache_beam.portability.api import schema_pb2
from apache_beam.runners import pipeline_context
from apache_beam.testing.util import assert_that
@@ -222,7 +224,10 @@
config['version'] = beam_version
if type in cls._provider_types:
try:
- return cls._provider_types[type](urns, **config)
+ result = cls._provider_types[type](urns, **config)
+ if not hasattr(result, 'to_json'):
+ result.to_json = lambda: spec
+ return result
except Exception as exn:
raise ValueError(
f'Unable to instantiate provider of type {type} '
@@ -1153,18 +1158,44 @@
self._underlying_provider.cache_artifacts()
-def parse_providers(provider_specs):
- providers = collections.defaultdict(list)
+def flatten_included_provider_specs(
+ provider_specs: Iterable[Mapping]) -> Iterator[Mapping]:
+ from apache_beam.yaml.yaml_transform import SafeLineLoader
for provider_spec in provider_specs:
- provider = ExternalProvider.provider_from_spec(provider_spec)
- for transform_type in provider.provided_transforms():
- providers[transform_type].append(provider)
- # TODO: Do this better.
- provider.to_json = lambda result=provider_spec: result
- return providers
+ if 'include' in provider_spec:
+ if len(SafeLineLoader.strip_metadata(provider_spec)) != 1:
+ raise ValueError(
+ f"When using include, it must be the only parameter: "
+ f"{provider_spec} "
+ f"at line {{SafeLineLoader.get_line(provider_spec)}}")
+ include_uri = provider_spec['include']
+ try:
+ with urllib.request.urlopen(include_uri) as response:
+ content = response.read()
+ except (ValueError, urllib.error.URLError) as exn:
+ if 'unknown url type' in str(exn):
+ with FileSystems.open(include_uri) as fin:
+ content = fin.read()
+ else:
+ raise
+ included_providers = yaml.load(content, Loader=SafeLineLoader)
+ if not isinstance(included_providers, list):
+ raise ValueError(
+ f"Included file {include_uri} must be a list of Providers "
+ f"at line {{SafeLineLoader.get_line(provider_spec)}}")
+ yield from flatten_included_provider_specs(included_providers)
+ else:
+ yield provider_spec
-def merge_providers(*provider_sets):
+def parse_providers(provider_specs: Iterable[Mapping]) -> Iterable[Provider]:
+ return [
+ ExternalProvider.provider_from_spec(provider_spec)
+ for provider_spec in flatten_included_provider_specs(provider_specs)
+ ]
+
+
+def merge_providers(*provider_sets) -> Mapping[str, Iterable[Provider]]:
result = collections.defaultdict(list)
for provider_set in provider_sets:
if isinstance(provider_set, Provider):
diff --git a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
index ec71422..5a30c7d 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py
@@ -16,9 +16,15 @@
#
import logging
+import os
+import tempfile
import unittest
+import yaml
+
+from apache_beam.yaml import yaml_provider
from apache_beam.yaml.yaml_provider import YamlProviders
+from apache_beam.yaml.yaml_transform import SafeLineLoader
class WindowIntoTest(unittest.TestCase):
@@ -63,6 +69,80 @@
self.parse_duration('s', 'size')
+class ProviderParsingTest(unittest.TestCase):
+
+ INLINE_PROVIDER = {'type': 'TEST', 'name': 'INLINED'}
+ INCLUDED_PROVIDER = {'type': 'TEST', 'name': 'INCLUDED'}
+ EXTRA_PROVIDER = {'type': 'TEST', 'name': 'EXTRA'}
+
+ @classmethod
+ def setUpClass(cls):
+ cls.tempdir = tempfile.TemporaryDirectory()
+ cls.to_include = os.path.join(cls.tempdir.name, 'providers.yaml')
+ with open(cls.to_include, 'w') as fout:
+ yaml.dump([cls.INCLUDED_PROVIDER], fout)
+ cls.to_include_nested = os.path.join(
+ cls.tempdir.name, 'nested_providers.yaml')
+ with open(cls.to_include_nested, 'w') as fout:
+ yaml.dump([{'include': cls.to_include}, cls.EXTRA_PROVIDER], fout)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.tempdir.cleanup()
+
+ def test_include_file(self):
+ flattened = [
+ SafeLineLoader.strip_metadata(spec)
+ for spec in yaml_provider.flatten_included_provider_specs([
+ self.INLINE_PROVIDER,
+ {
+ 'include': self.to_include
+ },
+ ])
+ ]
+
+ self.assertEqual([
+ self.INLINE_PROVIDER,
+ self.INCLUDED_PROVIDER,
+ ],
+ flattened)
+
+ def test_include_url(self):
+ flattened = [
+ SafeLineLoader.strip_metadata(spec)
+ for spec in yaml_provider.flatten_included_provider_specs([
+ self.INLINE_PROVIDER,
+ {
+ 'include': 'file:///' + self.to_include
+ },
+ ])
+ ]
+
+ self.assertEqual([
+ self.INLINE_PROVIDER,
+ self.INCLUDED_PROVIDER,
+ ],
+ flattened)
+
+ def test_nested_include(self):
+ flattened = [
+ SafeLineLoader.strip_metadata(spec)
+ for spec in yaml_provider.flatten_included_provider_specs([
+ self.INLINE_PROVIDER,
+ {
+ 'include': self.to_include_nested
+ },
+ ])
+ ]
+
+ self.assertEqual([
+ self.INLINE_PROVIDER,
+ self.INCLUDED_PROVIDER,
+ self.EXTRA_PROVIDER,
+ ],
+ flattened)
+
+
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py b/sdks/python/apache_beam/yaml/yaml_transform.py
index 3fb5bb7..ea51a30 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -1044,10 +1044,6 @@
# Calling expand directly to avoid outer layer of nesting.
return YamlTransform(
pipeline_as_composite(pipeline_spec['pipeline']),
- {
- **yaml_provider.parse_providers(pipeline_spec.get('providers', [])),
- **{
- key: yaml_provider.as_provider_list(key, value)
- for (key, value) in (providers or {}).items()
- }
- }).expand(beam.pvalue.PBegin(pipeline))
+ yaml_provider.merge_providers(
+ pipeline_spec.get('providers', []), providers or
+ {})).expand(beam.pvalue.PBegin(pipeline))