blob: ea4f496c2d04c5486b10adc6cbb55c50a3b4e852 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Generates Python wrappers for external transforms (specifically,
SchemaTransforms)
"""
import argparse
import datetime
import logging
import os
import shutil
import subprocess
import typing
from typing import Any
from typing import Dict
from typing import List
from typing import Union
import yaml
from gen_protos import LICENSE_HEADER
from gen_protos import PROJECT_ROOT
from gen_protos import PYTHON_SDK_ROOT
# SDK keys that may appear in a `destinations` mapping of the expansion
# services config. validate_sdks_destinations() rejects any other key.
SUPPORTED_SDK_DESTINATIONS = ['python']
# NOTE(review): not referenced by any code visible in this file; presumably a
# filename suffix for generated wrapper modules -- confirm before relying on it.
PYTHON_SUFFIX = "_et.py"
# Directory (<python sdk root>/apache_beam/transforms/xlang) that is the
# default output location of write_wrappers_to_destinations() and the
# directory run_script() empties when cleanup is requested.
PY_WRAPPER_OUTPUT_DIR = os.path.join(
    PYTHON_SDK_ROOT, 'apache_beam', 'transforms', 'xlang')
def generate_transforms_config(input_services, output_file):
  """
  Writes a YAML file listing the configuration of every discovered transform.

  The input is a YAML file holding a list of expansion service gradle targets.
  Every service must carry a `destinations` field giving, per SDK, the default
  package (as a relative path) that its generated wrappers should be imported
  into::

    - gradle_target: 'sdks:java:io:expansion-service:shadowJar'
      destinations:
        python: 'apache_beam/io'

  Each service is queried through :class:`ExternalTransformProvider` to
  discover its external transforms. The relevant details of every transform
  are collected and dumped to a new YAML file, which is later consumed to
  generate the wrappers.

  Importing generated transforms to an existing package
  -----------------------------------------------------
  Running the script on the config above creates a new module at
  `apache_beam/transforms/xlang/io.py` that holds all generated wrappers whose
  destination is 'apache_beam/io'. To expose them from the `apache_beam.io`
  package (or any other package), add this line to the package's
  `__init__.py` file::

    from apache_beam.transforms.xlang.io import *

  Modifying a transform's name and destination
  --------------------------------------------
  A service may also declare modifications for particular transforms.
  Currently the generated wrapper's **name** and **destination** package can
  be modified:

  - By default, the wrapper class name is derived from the transform's
    identifier. This can be overridden by providing a name explicitly.
  - By default, generated wrappers are made available to the package given by
    their expansion service. This can be overridden by providing a relative
    path to a different package.

  Such modifications look like this::

    - gradle_target: 'sdks:java:io:expansion-service:shadowJar'
      destinations:
        python: 'apache_beam/io'
      transforms:
        'beam:schematransform:org.apache.beam:my_transform:v1':
          name: 'MyCustomTransformName'
          destinations:
            python: 'apache_beam/io/gcp'

  Without the modifications, the transform with identifier
  `beam:schematransform:org.apache.beam:my_transform:v1` would get the
  inferred class name `MyTransform` and be written to the module
  `apache_beam/transforms/xlang/io.py`. With them, the provided name
  `MyCustomTransformName` is used and the wrapper is written to
  `apache_beam/transforms/xlang/io_gcp.py`, which can likewise be exposed
  through an `__init__.py` import::

    from apache_beam.transforms.xlang.io_gcp import *

  Skipping transforms
  -------------------
  To leave a transform out, list its identifier under `skip_transforms`::

    - gradle_target: 'sdks:java:io:expansion-service:shadowJar'
      destinations:
        python: 'apache_beam/io'
      skip_transforms:
        - 'beam:schematransform:org.apache.beam:some_transform:v1'

  Args:
    input_services: path of the input expansion services YAML file.
    output_file: path the transforms config YAML is written to.

  Raises:
    ValueError: if a service has no `destinations` field, or if a destination
      is rejected by :func:`validate_sdks_destinations`.
  """
  from apache_beam.transforms.external import BeamJarExpansionService
  from apache_beam.transforms.external_transform_provider import ExternalTransform
  from apache_beam.transforms.external_transform_provider import ExternalTransformProvider

  with open(input_services) as services_file:
    services = yaml.safe_load(services_file)

  transform_list: List[Dict[str, Any]] = []
  for service in services:
    target = service['gradle_target']

    if "destinations" not in service:
      raise ValueError(
          f"Expansion service with target '{target}' does not "
          "specify any default destinations.")
    service_destinations: Dict[str, str] = service['destinations']
    for sdk, dest in service_destinations.items():
      validate_sdks_destinations(sdk, dest, target)

    skipped_identifiers = service.get('skip_transforms', [])

    # The dynamic provider starts the expansion service and discovers every
    # transform it exposes.
    provider = ExternalTransformProvider(BeamJarExpansionService(target))
    discovered: Dict[str, ExternalTransform] = provider.get_all()

    for identifier, wrapper in discovered.items():
      if identifier in skipped_identifiers:
        continue

      # Start from the service-level defaults; per-transform overrides below
      # must not leak into other transforms, hence the copy.
      transform_destinations = service_destinations.copy()

      overrides = {}
      if 'transforms' in service and identifier in service['transforms']:
        overrides = service['transforms'][identifier]
        for sdk, dest in overrides.get('destinations', {}).items():
          validate_sdks_destinations(sdk, dest, target, identifier)
          transform_destinations[sdk] = dest

      fields = {}
      for param in wrapper.configuration_schema.values():
        tp, nullable = pretty_type(param.type)
        fields[param.original_name] = {
            'type': str(tp),
            'description': param.description,
            'nullable': nullable
        }

      transform_list.append({
          'identifier': identifier,
          'name': overrides.get('name', wrapper.__name__),
          'destinations': transform_destinations,
          'default_service': target,
          'fields': fields,
          'description': wrapper.description
      })

  with open(output_file, 'w') as out:
    out.write(LICENSE_HEADER.lstrip())
    out.write(
        "# NOTE: This file is autogenerated and should "
        "not be edited by hand.\n")
    out.write(
        "# Configs are generated based on the expansion service\n"
        f"# configuration in {input_services.replace(PROJECT_ROOT, '')}.\n")
    out.write("# Refer to gen_xlang_wrappers.py for more info.\n")
    today = datetime.datetime.now().date()
    out.write(f"#\n# Last updated on: {today}\n\n")
    yaml.dump(transform_list, out)

  logging.info("Successfully wrote transform configs to file: %s", output_file)
def validate_sdks_destinations(sdk, dest, service, identifier=None):
  """
  Checks one (sdk, destination) pair taken from the expansion services config.

  Args:
    sdk: SDK key of the destination; must be in SUPPORTED_SDK_DESTINATIONS.
    dest: '/'-separated path, relative to PYTHON_SDK_ROOT, that must be an
      existing directory.
    service: gradle target of the service, used in the error message when no
      identifier is given.
    identifier: optional transform identifier; when truthy the error message
      names the transform instead of the service.

  Raises:
    ValueError: if the SDK is unsupported or the destination is not an
      existing directory.
  """
  subject = (
      f"Identifier '{identifier}'" if identifier else f"Service '{service}'")

  if sdk not in SUPPORTED_SDK_DESTINATIONS:
    raise ValueError(
        subject + " specifies a destination for an invalid SDK:"
        f" '{sdk}'. The supported SDKs are {SUPPORTED_SDK_DESTINATIONS}")

  destination_dir = os.path.join(PYTHON_SDK_ROOT, *dest.split('/'))
  if os.path.isdir(destination_dir):
    return
  raise ValueError(
      subject + f" specifies an invalid destination '{dest}'."
      " Please make sure the destination is an existing directory.")
def pretty_type(tp):
  """
  Takes a type and returns a tuple containing a pretty string representing it
  and a bool signifying if it is nullable or not.

  For optional types (a two-member Union where one member is NoneType), the
  contained type is unwrapped and returned, regardless of whether None is
  written first or last in the Union. This does not recurse however, so inner
  Optional types are not affected. A nullable Union of more than two members
  is reported as nullable but left intact.

  E.g. the input typing.Optional[typing.Dict[int, typing.Optional[str]]] will
  return the Dict type rendered without the "typing." prefix, together with
  True.

  Note: types whose module is not builtins, typing or numpy are returned as
  the type object itself rather than a string.
  """
  none_type = type(None)
  nullable = False
  args = typing.get_args(tp)
  if typing.get_origin(tp) is Union and none_type in args:
    nullable = True
    # only unwrap if it's a single nullable type. if the type is truly a union
    # of multiple types, leave it alone.
    if len(args) == 2:
      # The args are classes, so NoneType has to be recognised by identity;
      # isinstance(NoneType, NoneType) is False and would filter nothing.
      tp = next(arg for arg in args if arg is not none_type)

  # TODO(ahmedabu98): Make this more generic to support other remote SDKs
  # Potentially use Runner API types
  if tp.__module__ == 'builtins':
    tp = tp.__name__
  elif tp.__module__ == 'typing':
    tp = str(tp).replace("typing.", "")
  elif tp.__module__ == 'numpy':
    tp = "%s.%s" % (tp.__module__, tp.__name__)

  return (tp, nullable)
def get_wrappers_from_transform_configs(config_file) -> Dict[str, List[str]]:
  """
  Renders the source code of external transform wrapper classes (subclasses
  of :class:`ExternalTransform`).

  Reads a YAML file holding a list of SchemaTransform configurations and
  renders one wrapper class per configuration, documentation included, using
  the `python_xlang_wrapper.template` Jinja template found under
  PYTHON_SDK_ROOT.

  Every configuration must name the Python destination that its generated
  class belongs to.

  Args:
    config_file: path of the transforms config YAML file.

  Returns:
    The rendered classes, grouped by Python destination.
  """
  from jinja2 import Environment
  from jinja2 import FileSystemLoader

  jinja_env = Environment(loader=FileSystemLoader(PYTHON_SDK_ROOT))
  wrapper_template = jinja_env.get_template("python_xlang_wrapper.template")

  with open(config_file) as config_stream:
    transforms = yaml.safe_load(config_stream)

  # Several wrappers can share one destination (e.g. through modified
  # destinations), so each destination maps to a list of rendered classes.
  grouped: Dict[str, List[str]] = {}
  for config in transforms:
    service_target = config['default_service']
    description = config['description']
    destination = config['destinations']['python']
    class_name = config['name']
    fields = config['fields']
    identifier = config['identifier']

    # Python requires parameters without defaults to come first, so collect
    # the two kinds separately and keep the declared order within each.
    required_params = []
    optional_params = []
    for field_name, info in fields.items():
      details = {
          "name": field_name,
          "type": info['type'],
          "description": info['description'],
      }
      if info['nullable']:
        details["default"] = None
        optional_params.append(details)
      else:
        required_params.append(details)

    rendered_class = wrapper_template.render(
        class_name=class_name,
        identifier=identifier,
        parameters=required_params + optional_params,
        description=description,
        default_expansion_service=(
            f"BeamJarExpansionService(\"{service_target}\")"))

    grouped.setdefault(destination, []).append(rendered_class)

  return grouped
def write_wrappers_to_destinations(
    grouped_wrappers: Dict[str, List[str]],
    output_dir=PY_WRAPPER_OUTPUT_DIR,
    format_code=True):
  """
  Takes a dictionary of generated wrapper code, grouped by destination.
  For each destination, create a new file containing the respective wrapper
  classes. Each file includes the Apache License header and relevant imports.

  The module name is the destination with the 'apache_beam/' prefix removed
  and remaining '/' replaced by '_' (e.g. 'apache_beam/io/gcp' -> 'io_gcp').

  Note: the Jinja template should already follow linting and formatting rules.

  Args:
    grouped_wrappers: rendered wrapper class sources keyed by destination.
    output_dir: directory the wrapper modules are written to.
    format_code: if True, run `yapf --in-place` on the written modules.

  Raises:
    subprocess.CalledProcessError: if the yapf invocation fails.
  """
  written_files = []
  for dest, wrappers in grouped_wrappers.items():
    module_name = dest.replace('apache_beam/', '').replace('/', '_')
    module_path = os.path.join(output_dir, module_name) + ".py"
    with open(module_path, "w") as file:
      file.write(LICENSE_HEADER.lstrip())
      file.write(
          "\n# NOTE: This file contains autogenerated external transform(s)\n"
          "# and should not be edited by hand.\n"
          "# Refer to gen_xlang_wrappers.py for more info.\n\n")
      file.write(
          "\"\"\""
          "Cross-language transforms in this module can be imported from the\n"
          f":py:mod:`{dest.replace('/', '.')}` package."
          "\"\"\"\n\n")
      file.write(
          "# pylint:disable=line-too-long\n\n"
          "from apache_beam.transforms.external import "
          "BeamJarExpansionService\n"
          "from apache_beam.transforms.external_transform_provider "
          "import ExternalTransform\n")
      for wrapper in wrappers:
        file.write(wrapper + "\n")
    written_files.append(module_path)

  logging.info("Created external transform wrapper modules: %s", written_files)

  # yapf rejects `--in-place` when it is given no files, so with check=True an
  # empty input would raise. Only format when something was actually written.
  if format_code and written_files:
    formatting_cmd = ['yapf', '--in-place', *written_files]
    subprocess.run(formatting_cmd, capture_output=True, check=True)
def delete_generated_files(root_dir):
  """
  Scans for and deletes generated wrapper files.

  Every entry directly under `root_dir` is removed except `__init__.py`:
  files and symlinks are unlinked, directories are removed recursively.

  Args:
    root_dir: directory whose contents are deleted.
  """
  logging.info("Deleting external transform wrappers from dir %s", root_dir)
  # Filter up front instead of removing from the list while iterating over it;
  # in-place removal makes the loop skip the entry that follows `__init__.py`.
  deleted_files = [
      entry for entry in os.listdir(root_dir) if entry != '__init__.py'
  ]
  for entry in deleted_files:
    path = os.path.join(root_dir, entry)
    if os.path.isfile(path) or os.path.islink(path):
      os.unlink(path)
    else:
      shutil.rmtree(path)
  logging.info("Successfully deleted files: %s", deleted_files)
def run_script(
    cleanup,
    generate_config_only,
    input_expansion_services,
    transforms_config_source):
  """
  Drives the wrapper generation end to end.

  Args:
    cleanup: if truthy, first delete everything under PY_WRAPPER_OUTPUT_DIR
      except `__init__.py`, which removes outdated wrappers.
    generate_config_only: if truthy, stop after the transforms config is
      available and generate no wrappers.
    input_expansion_services: path of the expansion services YAML; only used
      when no transforms config source is given.
    transforms_config_source: path of an existing transforms config YAML. If
      falsy, one is generated at
      <PROJECT_ROOT>/sdks/standard_external_transforms.yaml.

  Raises:
    RuntimeError: if a transforms config source is given but does not exist.
  """
  if cleanup:
    delete_generated_files(PY_WRAPPER_OUTPUT_DIR)

  if transforms_config_source:
    if not os.path.exists(transforms_config_source):
      raise RuntimeError(
          "Could not find the provided transforms config "
          f"source: {transforms_config_source}")
  else:
    # No config was supplied, so build one. This step requires the expansion
    # service.
    transforms_config_source = os.path.join(
        PROJECT_ROOT, 'sdks', 'standard_external_transforms.yaml')
    generate_transforms_config(
        input_services=input_expansion_services,
        output_file=transforms_config_source)

  if generate_config_only:
    return

  write_wrappers_to_destinations(
      get_wrappers_from_transform_configs(transforms_config_source))
# Command-line entry point: parse the flags and hand them to run_script().
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--cleanup',
      dest='cleanup',
      action='store_true',
      help="Whether to cleanup existing generated wrappers first.")
  # Adjacent string literals are concatenated verbatim, so each fragment that
  # is continued on the next line must end with an explicit space.
  parser.add_argument(
      '--generate-config-only',
      dest='generate_config_only',
      action='store_true',
      help="If set, will generate the transform config only without generating "
      "any wrappers.")
  parser.add_argument(
      '--input-expansion-services',
      dest='input_expansion_services',
      default=os.path.join(
          PROJECT_ROOT, 'sdks', 'standard_expansion_services.yaml'),
      help=(
          "Absolute path to the input YAML file that contains "
          "expansion service configs. Ignored if a transforms config "
          "source is provided."))
  parser.add_argument(
      '--transforms-config-source',
      dest='transforms_config_source',
      help=(
          "Absolute path to a source transforms config YAML file to "
          "generate wrapper modules from. If not provided, one will be "
          "created by this script."))
  args = parser.parse_args()

  run_script(
      args.cleanup,
      args.generate_config_only,
      args.input_expansion_services,
      args.transforms_config_source)