[BEAM-8575] Test a customized window fn work as expected
diff --git a/sdks/python/apache_beam/transforms/window_test.py b/sdks/python/apache_beam/transforms/window_test.py
index dda651d..bee33e0 100644
--- a/sdks/python/apache_beam/transforms/window_test.py
+++ b/sdks/python/apache_beam/transforms/window_test.py
@@ -22,7 +22,10 @@
import unittest
from builtins import range
+from nose.plugins.attrib import attr
+
import apache_beam as beam
+from apache_beam.coders import coders
from apache_beam.runners import pipeline_context
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
@@ -39,6 +42,7 @@
from apache_beam.transforms.core import Windowing
from apache_beam.transforms.trigger import AccumulationMode
from apache_beam.transforms.trigger import AfterCount
+from apache_beam.transforms.window import NonMergingWindowFn
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
@@ -65,6 +69,23 @@
reify_windows = core.ParDo(ReifyWindowsFn())
+class TestCustomWindows(NonMergingWindowFn):
+ """A custom non merging window fn which assigns elements into interval windows
+ [0, 3), [3, 5) and [5, element timestamp) based on the element timestamps.
+ """
+
+ def assign(self, context):
+ timestamp = context.timestamp
+ if timestamp < 3:
+ return [IntervalWindow(0, 3)]
+ elif timestamp < 5:
+ return [IntervalWindow(3, 5)]
+ else:
+ return [IntervalWindow(5, timestamp)]
+
+ def get_window_coder(self):
+ return coders.IntervalWindowCoder()
+
class WindowTest(unittest.TestCase):
@@ -281,6 +302,19 @@
assert_that(mean_per_window, equal_to([(0, 2.0), (1, 7.0)]),
label='assert:mean')
+ @attr('ValidatesRunner')
+ def test_custom_windows(self):
+ with TestPipeline() as p:
+ pcoll = self.timestamped_key_values(p, 'key', 0, 1, 2, 3, 4, 5, 6)
+ result = (pcoll
+ | 'custom window' >> WindowInto(TestCustomWindows())
+ | GroupByKey()
+ | 'sort values' >> MapTuple(lambda k, vs: (k, sorted(vs))))
+ assert_that(result, equal_to([('key', [0, 1, 2]),
+ ('key', [3, 4]),
+ ('key', [5]),
+ ('key', [6])]))
+
class RunnerApiTest(unittest.TestCase):