#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Transform Beam PTransforms into Dask Bag operations.
A minimum set of operation substitutions, to adap Beam's PTransform model
to Dask Bag functions.
"""
import abc
import dataclasses
import logging
import math
import typing as t
from dataclasses import field
import dask.bag as db
import apache_beam
from apache_beam import DoFn
from apache_beam import TaggedOutput
from apache_beam.pipeline import AppliedPTransform
from apache_beam.runners.common import DoFnContext
from apache_beam.runners.common import DoFnInvoker
from apache_beam.runners.common import DoFnSignature
from apache_beam.runners.common import Receiver
from apache_beam.runners.common import _OutputHandler
from apache_beam.runners.dask.overrides import _Create
from apache_beam.runners.dask.overrides import _Flatten
from apache_beam.runners.dask.overrides import _GroupByKeyOnly
from apache_beam.transforms.sideinputs import SideInputMap
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import TimestampedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.utils.windowed_value import WindowedValue
# Inputs to DaskBagOps.
OpInput = t.Union[db.Bag, t.Sequence[db.Bag], None]
OpSide = t.Optional[t.Sequence[SideInputMap]]
# Value types for PCollections (possibly Windowed Values).
PCollVal = t.Union[WindowedValue, t.Any]
_LOGGER = logging.getLogger(__name__)
def get_windowed_value(item: t.Any, window_fn: WindowFn) -> WindowedValue:
"""Wraps a value (item) inside a Window."""
if isinstance(item, TaggedOutput):
item = item.value
if isinstance(item, WindowedValue):
windowed_value = item
elif isinstance(item, TimestampedValue):
assign_context = WindowFn.AssignContext(item.timestamp, item.value)
windowed_value = WindowedValue(
item.value, item.timestamp, tuple(window_fn.assign(assign_context)))
else:
windowed_value = WindowedValue(item, 0, (GlobalWindow(), ))
return windowed_value
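# Example (a sketch): a plain value falls through to the final branch above
# and is wrapped with timestamp 0 in the global window:
#
#   wv = get_windowed_value('a', window_fn)
#   assert wv == WindowedValue('a', 0, (GlobalWindow(), ))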
def defenestrate(x):
"""Extracts the underlying item from a Window."""
if isinstance(x, WindowedValue):
return x.value
return x
@dataclasses.dataclass
class DaskBagWindowedIterator:
"""Iterator for `apache_beam.transforms.sideinputs.SideInputMap`"""
bag: db.Bag
window_fn: WindowFn
def __iter__(self):
    # FIXME(cisaacstern): list() is likely inefficient, since it presumably
    # materializes the full result before iterating over it. Doing this for
    # now as a proof-of-concept. Can we generate results incrementally?
for result in list(self.bag):
yield get_windowed_value(result, self.window_fn)
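# Intended use (a sketch; the exact wiring is assumed, see dask_runner.py):
# wrap a side input's bag so `SideInputMap` can iterate it as windowed values:
#
#   SideInputMap(
#       type(side_input),
#       side_input._view_options(),
#       DaskBagWindowedIterator(side_input_bag, window_fn))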
@dataclasses.dataclass
class TaggingReceiver(Receiver):
"""A Receiver that handles tagged `WindowValue`s."""
tag: str
values: t.List[PCollVal]
def receive(self, windowed_value: WindowedValue):
if self.tag:
output = TaggedOutput(self.tag, windowed_value)
else:
output = windowed_value
self.values.append(output)
@dataclasses.dataclass
class OneReceiver(dict):
"""A Receiver that tags value via dictionary lookup key."""
values: t.List[PCollVal] = field(default_factory=list)
  def __missing__(self, key):
    # Invoked only for absent keys; lazily create and cache a receiver.
    self[key] = TaggingReceiver(key, self.values)
    return self[key]
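# How outputs are routed (a sketch; names are illustrative):
#
#   receivers = OneReceiver()
#   receivers[None].receive(wv)     # main output: appended unchanged
#   receivers['side'].receive(wv)   # wrapped as TaggedOutput('side', wv)
#
# Both land, in arrival order, in the shared `receivers.values` list.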
@dataclasses.dataclass
class DaskBagOp(abc.ABC):
"""Abstract Base Class for all Dask-supported Operations.
All DaskBagOps must support an `apply()` operation, which invokes the dask
bag upon the previous op's input.
Attributes
applied: The underlying `AppliedPTransform` which holds the code for the
target operation.
bag_kwargs: (optional) Keyword arguments applied to input bags, usually
from the pipeline's `DaskOptions`.
"""
applied: AppliedPTransform
bag_kwargs: t.Dict = dataclasses.field(default_factory=dict)
@property
def transform(self):
return self.applied.transform
@abc.abstractmethod
def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
pass
class NoOp(DaskBagOp):
"""An identity on a dask bag: returns the input as-is."""
def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
return input_bag
class Create(DaskBagOp):
"""The beginning of a Beam pipeline; the input must be `None`."""
def apply(self, input_bag: OpInput, side_inputs: OpSide = None) -> db.Bag:
assert input_bag is None, 'Create expects no input!'
original_transform = t.cast(_Create, self.transform)
items = original_transform.values
npartitions = self.bag_kwargs.get('npartitions')
partition_size = self.bag_kwargs.get('partition_size')
    if npartitions and partition_size:
      raise ValueError(
          f'Please specify either `dask_npartitions` or '
          f'`dask_partition_size`, but not both: '
          f'{npartitions=}, {partition_size=}.')
    if not npartitions and not partition_size:
      # `partition_size` is inversely related to `npartitions`.
      # Ideal "chunk sizes" in Dask are around 10-100 MBs; the assumption
      # here is that ~128 items per partition lands near that range.
      default_size = 128
      partition_size = max(default_size, math.ceil(math.sqrt(len(items)) / 10))
      if partition_size == default_size:
        _LOGGER.warning(
            'The new default partition size is %d; it used to be 1 '
            'in previous DaskRunner versions.',
            default_size)
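    # Worked example (illustrative numbers): with len(items) == 4_000_000,
    # ceil(sqrt(4_000_000) / 10) == 200, so partition_size == max(128, 200)
    # == 200 and no warning fires; for smaller inputs the default of 128
    # wins and the warning above is logged.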
return db.from_sequence(
items, npartitions=npartitions, partition_size=partition_size)
def apply_dofn_to_bundle(
items, do_fn_invoker_args, do_fn_invoker_kwargs, tagged_receivers):
"""Invokes a DoFn within a bundle, implemented as a Dask partition."""
do_fn_invoker = DoFnInvoker.create_invoker(
*do_fn_invoker_args, **do_fn_invoker_kwargs)
do_fn_invoker.invoke_setup()
do_fn_invoker.invoke_start_bundle()
for it in items:
do_fn_invoker.invoke_process(it)
results = [v.value for v in tagged_receivers.values]
do_fn_invoker.invoke_finish_bundle()
do_fn_invoker.invoke_teardown()
return results
class ParDo(DaskBagOp):
"""Apply a pure function in an embarrassingly-parallel way.
This consumes a sequence of items and returns a sequence of items.
"""
def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag:
transform = t.cast(apache_beam.ParDo, self.transform)
args, kwargs = transform.raw_side_inputs
args = list(args)
main_input = next(iter(self.applied.main_inputs.values()))
window_fn = main_input.windowing.windowfn if hasattr(
main_input, "windowing") else None
tagged_receivers = OneReceiver()
do_fn_invoker_args = [
DoFnSignature(transform.fn),
_OutputHandler(
window_fn=window_fn,
main_receivers=tagged_receivers[None],
tagged_receivers=tagged_receivers,
per_element_output_counter=None,
output_batch_converter=None,
process_yields_batches=False,
process_batch_yields_elements=False),
]
do_fn_invoker_kwargs = dict(
context=DoFnContext(transform.label, state=None),
side_inputs=side_inputs,
input_args=args,
input_kwargs=kwargs,
user_state_context=None,
bundle_finalizer_param=DoFn.BundleFinalizerParam(),
)
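    # First wrap every element in a WindowedValue, then process each Dask
    # partition as if it were a Beam bundle.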
return input_bag.map(get_windowed_value, window_fn).map_partitions(
apply_dofn_to_bundle,
do_fn_invoker_args,
do_fn_invoker_kwargs,
tagged_receivers,
)
class GroupByKey(DaskBagOp):
"""Group a PCollection into a mapping of keys to elements."""
def apply(self, input_bag: db.Bag, side_inputs: OpSide = None) -> db.Bag:
def key(item):
return item[0]
def value(item):
k, v = item
return k, [defenestrate(elm[1]) for elm in v]
return input_bag.groupby(key).map(value)
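# Shape sketch for GroupByKey (illustrative data; grouped values may arrive
# wrapped in WindowedValues and are unwrapped via `defenestrate`):
#
#   [('a', 1), ('a', 2), ('b', 3)]  ->  [('a', [1, 2]), ('b', [3])]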
class Flatten(DaskBagOp):
"""Produces a flattened bag from a collection of bags."""
def apply(
self, input_bag: t.List[db.Bag], side_inputs: OpSide = None) -> db.Bag:
assert isinstance(input_bag, list), 'Must take a sequence of bags!'
return db.concat(input_bag)
TRANSLATIONS = {
_Create: Create,
apache_beam.ParDo: ParDo,
_GroupByKeyOnly: GroupByKey,
_Flatten: Flatten,
}
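# A sketch of how a runner might consult this table (the actual DaskRunner
# wiring lives in dask_runner.py and may differ):
#
#   op_class = TRANSLATIONS.get(type(applied_ptransform.transform), NoOp)
#   op = op_class(applied_ptransform)
#   output_bag = op.apply(input_bag, side_inputs)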