| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| import dataclasses |
| import typing as t |
| |
| import apache_beam as beam |
| from apache_beam import typehints |
| from apache_beam.io.iobase import SourceBase |
| from apache_beam.pipeline import AppliedPTransform |
| from apache_beam.pipeline import PTransformOverride |
| from apache_beam.runners.direct.direct_runner import _GroupAlsoByWindowDoFn |
| from apache_beam.transforms import ptransform |
| from apache_beam.transforms.window import GlobalWindows |
| |
# Generic type variables used by the type-hint decorators on the
# replacement transforms below: K appears in key / element position and
# V in value position.
K = t.TypeVar("K")
V = t.TypeVar("V")
| |
| |
@dataclasses.dataclass
class _Create(beam.PTransform):
  """Replacement transform that ``dask_overrides`` substitutes for
  ``beam.Create``.

  It only carries the values of the transform it replaces; no elements
  are produced by this class itself.
  """
  # Values copied from the ``values`` attribute of the replaced
  # ``beam.Create``. Annotated as a variable-length tuple:
  # ``Tuple[Any]`` would describe a tuple of exactly one element.
  values: t.Tuple[t.Any, ...]

  def expand(self, input_or_inputs):
    """Return the output of ``PCollection.from_`` for the given input.

    Args:
      input_or_inputs: The input handed to this transform on application.

    Returns:
      Whatever ``beam.pvalue.PCollection.from_(input_or_inputs)`` returns.
    """
    return beam.pvalue.PCollection.from_(input_or_inputs)

  def get_windowing(self, inputs: t.Any) -> beam.Windowing:
    """Return a ``beam.Windowing`` built from ``GlobalWindows()``.

    Args:
      inputs: Not consulted; the result is the same for any input.
    """
    return beam.Windowing(GlobalWindows())
| |
| |
@typehints.with_input_types(K)
@typehints.with_output_types(K)
class _Reshuffle(beam.PTransform):
  """Replacement transform that ``dask_overrides`` substitutes for
  ``beam.Reshuffle``.

  The type hints declare the same element type ``K`` on input and output.
  """
  def expand(self, input_or_inputs):
    """Return ``PCollection.from_`` applied to the transform's input."""
    reshuffled = beam.pvalue.PCollection.from_(input_or_inputs)
    return reshuffled
| |
| |
@dataclasses.dataclass
class _Read(beam.PTransform):
  """Replacement transform that ``dask_overrides`` substitutes for
  ``beam.io.Read``.
  """
  # Source object copied from the ``source`` attribute of the replaced
  # ``beam.io.Read`` transform.
  source: SourceBase

  def expand(self, input_or_inputs):
    """Return ``PCollection.from_`` applied to the transform's input."""
    read_output = beam.pvalue.PCollection.from_(input_or_inputs)
    return read_output
| |
| |
@typehints.with_input_types(t.Tuple[K, V])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupByKeyOnly(beam.PTransform):
  """Grouping step applied by ``_GroupByKey``.

  Declared to take ``(K, V)`` pairs and emit ``(K, Iterable[V])`` pairs.
  """
  def expand(self, input_or_inputs):
    """Return ``PCollection.from_`` applied to the transform's input."""
    grouped = beam.pvalue.PCollection.from_(input_or_inputs)
    return grouped

  def infer_output_type(self, input_type):
    """Derive the output type hint from ``input_type``.

    The key and value types are extracted with
    ``trivial_inference.key_value_types``; the result pairs the key type
    with an iterable of the value type.
    """
    kv_types = typehints.trivial_inference.key_value_types(input_type)
    key_type, value_type = kv_types
    grouped_values_type = typehints.Iterable[value_type]
    return typehints.KV[key_type, grouped_values_type]
| |
| |
@typehints.with_input_types(t.Tuple[K, t.Iterable[V]])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupAlsoByWindow(beam.ParDo):
  """ParDo built around ``_GroupAlsoByWindowDoFn``. Not used yet."""
  def __init__(self, windowing):
    """Wrap a ``_GroupAlsoByWindowDoFn`` created from ``windowing``.

    Args:
      windowing: Passed to ``_GroupAlsoByWindowDoFn`` and also kept on
        the instance as ``self.windowing``.
    """
    do_fn = _GroupAlsoByWindowDoFn(windowing)
    super().__init__(do_fn)
    self.windowing = windowing

  def expand(self, input_or_inputs):
    """Return ``PCollection.from_`` applied to the transform's input."""
    regrouped = beam.pvalue.PCollection.from_(input_or_inputs)
    return regrouped
| |
| |
@typehints.with_input_types(t.Tuple[K, V])
@typehints.with_output_types(t.Tuple[K, t.Iterable[V]])
class _GroupByKey(beam.PTransform):
  """Replacement transform that ``dask_overrides`` substitutes for
  ``beam.GroupByKey``.

  Expands to a single ``_GroupByKeyOnly`` step labelled ``"GroupByKey"``.
  """
  def expand(self, input_or_inputs):
    """Apply the labelled ``_GroupByKeyOnly`` step to the input."""
    labelled_group_step = "GroupByKey" >> _GroupByKeyOnly()
    return input_or_inputs | labelled_group_step
| |
| |
class _Flatten(beam.PTransform):
  """Replacement transform that ``dask_overrides`` substitutes for
  ``beam.Flatten``.
  """
  def expand(self, input_or_inputs):
    """Create the output PCollection for the flattened inputs.

    The output is marked bounded only when every input reports
    ``is_bounded`` as true.
    """
    bounded_flags = (pcoll.is_bounded for pcoll in input_or_inputs)
    all_bounded = all(bounded_flags)
    return beam.pvalue.PCollection(self.pipeline, is_bounded=all_bounded)
| |
| |
def dask_overrides() -> t.List[PTransformOverride]:
  """Build the transform overrides defined in this module.

  Each override matches an applied transform whose class is exactly one
  of ``beam.Create``, ``beam.Reshuffle``, ``beam.io.Read``,
  ``beam.GroupByKey`` or ``beam.Flatten`` (subclasses are not matched)
  and supplies the corresponding private replacement transform.

  Returns:
    One instance of each override, in the order Create, Reshuffle, Read,
    GroupByKey, Flatten.
  """
  class CreateOverride(PTransformOverride):
    """Swap ``beam.Create`` for ``_Create``, carrying over its values."""
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      transform = applied_ptransform.transform
      # Exact class comparison: subclasses of beam.Create do not match.
      return transform.__class__ is beam.Create

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      original = t.cast(beam.Create, applied_ptransform.transform)
      return _Create(original.values)

  class ReshuffleOverride(PTransformOverride):
    """Swap ``beam.Reshuffle`` for ``_Reshuffle``."""
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      transform = applied_ptransform.transform
      return transform.__class__ is beam.Reshuffle

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _Reshuffle()

  class ReadOverride(PTransformOverride):
    """Swap ``beam.io.Read`` for ``_Read``, carrying over its source."""
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      transform = applied_ptransform.transform
      return transform.__class__ is beam.io.Read

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      original = t.cast(beam.io.Read, applied_ptransform.transform)
      return _Read(original.source)

  class GroupByKeyOverride(PTransformOverride):
    """Swap ``beam.GroupByKey`` for ``_GroupByKey``."""
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      transform = applied_ptransform.transform
      return transform.__class__ is beam.GroupByKey

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _GroupByKey()

  class FlattenOverride(PTransformOverride):
    """Swap ``beam.Flatten`` for ``_Flatten``."""
    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
      transform = applied_ptransform.transform
      return transform.__class__ is beam.Flatten

    def get_replacement_transform_for_applied_ptransform(
        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
      return _Flatten()

  # Instantiate in the same order the overrides are declared above.
  override_classes = (
      CreateOverride,
      ReshuffleOverride,
      ReadOverride,
      GroupByKeyOverride,
      FlattenOverride,
  )
  return [override_cls() for override_cls in override_classes]