#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import glob
import json
import logging
import os
import re
import shutil
import tempfile
import unittest
from datetime import datetime
from tempfile import TemporaryDirectory
import hamcrest as hc
import pandas
import pytest
import pytz
from parameterized import param
from parameterized import parameterized
import apache_beam as beam
from apache_beam import Create
from apache_beam import Map
from apache_beam.io import filebasedsource
from apache_beam.io import source_test_utils
from apache_beam.io.iobase import RangeTracker
from apache_beam.io.parquetio import ReadAllFromParquet
from apache_beam.io.parquetio import ReadAllFromParquetBatched
from apache_beam.io.parquetio import ReadFromParquet
from apache_beam.io.parquetio import ReadFromParquetBatched
from apache_beam.io.parquetio import WriteToParquet
from apache_beam.io.parquetio import WriteToParquetBatched
from apache_beam.io.parquetio import _create_parquet_sink
from apache_beam.io.parquetio import _create_parquet_source
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.util import LogElements
try:
import pyarrow as pa
import pyarrow.parquet as pq
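# The major version is used below to skip tests whose expected behavior
# differs across pyarrow releases (INT96 errors, compliant nested types,
# LZ4 support).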
ARROW_MAJOR_VERSION, _, _ = map(int, pa.__version__.split('.'))
except ImportError:
pa = None
pq = None
ARROW_MAJOR_VERSION = 0
@unittest.skipIf(pa is None, "PyArrow is not installed.")
@pytest.mark.uses_pyarrow
class TestParquet(unittest.TestCase):
def setUp(self):
# Reduce the size of thread pools. Without this, test execution may fail
# in environments with limited resources.
filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
self.temp_dir = tempfile.mkdtemp()
self.RECORDS = [{
'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'
},
{
'name': 'Henry',
'favorite_number': 3,
'favorite_color': 'green'
},
{
'name': 'Toby',
'favorite_number': 7,
'favorite_color': 'brown'
},
{
'name': 'Gordon',
'favorite_number': 4,
'favorite_color': 'blue'
},
{
'name': 'Emily',
'favorite_number': -1,
'favorite_color': 'Red'
},
{
'name': 'Percy',
'favorite_number': 6,
'favorite_color': 'Green'
},
{
'name': 'Peter',
'favorite_number': 3,
'favorite_color': None
}]
self.SCHEMA = pa.schema([('name', pa.string(), False),
('favorite_number', pa.int64(), False),
('favorite_color', pa.string())])
self.SCHEMA96 = pa.schema([('name', pa.string(), False),
('favorite_number', pa.timestamp('ns'), False),
('favorite_color', pa.string())])
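# SCHEMA96 swaps favorite_number for a nanosecond timestamp column and
# backs the INT96 timestamp tests below.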
self.RECORDS_NESTED = [{
'items': [
{
'name': 'Thomas',
'favorite_number': 1,
'favorite_color': 'blue'
},
{
'name': 'Henry',
'favorite_number': 3,
'favorite_color': 'green'
},
]
},
{
'items': [
{
'name': 'Toby',
'favorite_number': 7,
'favorite_color': 'brown'
},
]
}]
self.SCHEMA_NESTED = pa.schema([(
'items',
pa.list_(
pa.struct([('name', pa.string(), False),
('favorite_number', pa.int64(), False),
('favorite_color', pa.string())])))])
def tearDown(self):
shutil.rmtree(self.temp_dir)
def _record_to_columns(self, records, schema):
col_list = []
for n in schema.names:
column = []
for r in records:
column.append(r[n])
col_list.append(column)
return col_list
def _records_as_arrow(self, schema=None, count=None):
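"""Builds a pa.Table with `count` rows by cycling through self.RECORDS."""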
if schema is None:
schema = self.SCHEMA
if count is None:
count = len(self.RECORDS)
len_records = len(self.RECORDS)
data = []
for i in range(count):
data.append(self.RECORDS[i % len_records])
col_data = self._record_to_columns(data, schema)
col_array = [pa.array(c, schema.types[cn]) for cn, c in enumerate(col_data)]
return pa.Table.from_arrays(col_array, schema=schema)
def _write_data(
self,
directory=None,
schema=None,
prefix=tempfile.template,
row_group_size=1000,
codec='none',
count=None):
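"""Writes the test records (cycled to `count` rows) to a new parquet file
and returns its path."""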
if directory is None:
directory = self.temp_dir
with tempfile.NamedTemporaryFile(delete=False, dir=directory,
prefix=prefix) as f:
table = self._records_as_arrow(schema, count)
pq.write_table(
table,
f,
row_group_size=row_group_size,
compression=codec,
use_deprecated_int96_timestamps=True)
return f.name
def _write_pattern(self, num_files, with_filename=False):
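"""Writes `num_files` parquet files into a fresh temp dir and returns a
glob pattern matching them (plus the file list when with_filename=True)."""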
assert num_files > 0
temp_dir = tempfile.mkdtemp(dir=self.temp_dir)
file_list = []
for _ in range(num_files):
file_list.append(self._write_data(directory=temp_dir, prefix='mytemp'))
if with_filename:
return (temp_dir + os.path.sep + 'mytemp*', file_list)
return temp_dir + os.path.sep + 'mytemp*'
def _run_parquet_test(
self,
pattern,
columns,
desired_bundle_size,
perform_splitting,
expected_result):
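"""Reads `pattern` via _create_parquet_source and verifies the result:
with perform_splitting the generated splits are checked against the
unsplit source, otherwise the records read are compared to
expected_result."""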
source = _create_parquet_source(pattern, columns=columns)
if perform_splitting:
assert desired_bundle_size
sources_info = [
(split.source, split.start_position, split.stop_position)
for split in source.split(desired_bundle_size=desired_bundle_size)
]
if len(sources_info) < 2:
raise ValueError(
'Test is trivial. Please adjust it so that at least '
'two splits get generated')
source_test_utils.assert_sources_equal_reference_source(
(source, None, None), sources_info)
else:
read_records = source_test_utils.read_from_source(source, None, None)
self.assertCountEqual(expected_result, read_records)
def test_read_without_splitting(self):
file_name = self._write_data()
expected_result = [self._records_as_arrow()]
self._run_parquet_test(file_name, None, None, False, expected_result)
def test_read_with_splitting(self):
file_name = self._write_data()
expected_result = [self._records_as_arrow()]
self._run_parquet_test(file_name, None, 100, True, expected_result)
def test_source_display_data(self):
file_name = 'some_parquet_source'
source = \
_create_parquet_source(
file_name,
validate=False
)
dd = DisplayData.create_from(source)
expected_items = [
DisplayDataItemMatcher('compression', 'auto'),
DisplayDataItemMatcher('file_pattern', file_name)
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_read_display_data(self):
file_name = 'some_parquet_source'
read = \
ReadFromParquet(
file_name,
validate=False)
read_batched = \
ReadFromParquetBatched(
file_name,
validate=False)
expected_items = [
DisplayDataItemMatcher('compression', 'auto'),
DisplayDataItemMatcher('file_pattern', file_name)
]
hc.assert_that(
DisplayData.create_from(read).items,
hc.contains_inanyorder(*expected_items))
hc.assert_that(
DisplayData.create_from(read_batched).items,
hc.contains_inanyorder(*expected_items))
def test_sink_display_data(self):
file_name = 'some_parquet_sink'
sink = _create_parquet_sink(
file_name,
self.SCHEMA,
'none',
False,
False,
'.end',
0,
None,
'application/x-parquet')
dd = DisplayData.create_from(sink)
expected_items = [
DisplayDataItemMatcher('schema', str(self.SCHEMA)),
DisplayDataItemMatcher(
'file_pattern',
'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('compression', 'uncompressed')
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_write_display_data(self):
file_name = 'some_parquet_sink'
write = WriteToParquet(file_name, self.SCHEMA)
dd = DisplayData.create_from(write)
expected_items = [
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('schema', str(self.SCHEMA)),
DisplayDataItemMatcher('row_group_buffer_size', str(64 * 1024 * 1024)),
DisplayDataItemMatcher(
'file_pattern',
'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
DisplayDataItemMatcher('compression', 'uncompressed')
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
def test_write_batched_display_data(self):
file_name = 'some_parquet_sink'
write = WriteToParquetBatched(file_name, self.SCHEMA)
dd = DisplayData.create_from(write)
expected_items = [
DisplayDataItemMatcher('codec', 'none'),
DisplayDataItemMatcher('schema', str(self.SCHEMA)),
DisplayDataItemMatcher(
'file_pattern',
'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
DisplayDataItemMatcher('compression', 'uncompressed')
]
hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
@unittest.skipIf(
ARROW_MAJOR_VERSION >= 13,
'pyarrow 13.x and above does not throw ArrowInvalid error')
def test_sink_transform_int96(self):
with self.assertRaisesRegex(Exception, 'would lose data'):
# Should throw an error "ArrowInvalid: Casting from timestamp[ns] to
# timestamp[us] would lose data"
dst = tempfile.NamedTemporaryFile()
path = dst.name
with TestPipeline() as p:
_ = p \
| Create(self.RECORDS) \
| WriteToParquet(
path, self.SCHEMA96, num_shards=1, shard_name_template='')
def test_sink_transform(self):
with TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, "tmp_filename")
with TestPipeline() as p:
_ = p \
| Create(self.RECORDS) \
| WriteToParquet(
path, self.SCHEMA, num_shards=1, shard_name_template='')
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| ReadFromParquet(path) \
| Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
def test_sink_transform_batched(self):
with TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, "tmp_filename")
with TestPipeline() as p:
_ = p \
| Create([self._records_as_arrow()]) \
| WriteToParquetBatched(
path, self.SCHEMA, num_shards=1, shard_name_template='')
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| ReadFromParquet(path) \
| Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
def test_sink_transform_compliant_nested_type(self):
if ARROW_MAJOR_VERSION < 4:
  self.skipTest(
      'Writing with compliant nested type is only '
      'supported in pyarrow 4.x and above')
with TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, 'tmp_filename')
with TestPipeline() as p:
_ = p \
| Create(self.RECORDS_NESTED) \
| WriteToParquet(
path, self.SCHEMA_NESTED, num_shards=1,
shard_name_template='', use_compliant_nested_type=True)
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| ReadFromParquet(path) \
| Map(json.dumps)
assert_that(
readback, equal_to([json.dumps(r) for r in self.RECORDS_NESTED]))
def test_schema_read_write(self):
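# No Arrow schema is passed to WriteToParquet here: the sink is expected
# to derive one from the beam.Row elements, and as_rows=True reads the
# records back as Rows.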
with TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, 'tmp_filename')
rows = [beam.Row(a=1, b='x'), beam.Row(a=2, b='y')]
stable_repr = lambda row: json.dumps(row._asdict())
with TestPipeline() as p:
_ = p | Create(rows) | WriteToParquet(path)
with TestPipeline() as p:
readback = (
p
| ReadFromParquet(path + '*', as_rows=True)
| Map(stable_repr))
assert_that(readback, equal_to([stable_repr(r) for r in rows]))
def test_write_with_nullable_fields_missing_data(self):
"""Test WriteToParquet with nullable fields where some fields are missing.
This test addresses the bug reported in:
https://github.com/apache/beam/issues/35791
where WriteToParquet fails with a KeyError if any nullable
field is missing in the data.
"""
# Define PyArrow schema with all fields nullable
schema = pa.schema([
pa.field("id", pa.int64(), nullable=True),
pa.field("name", pa.string(), nullable=True),
pa.field("age", pa.int64(), nullable=True),
pa.field("email", pa.string(), nullable=True),
])
# Sample data with missing nullable fields
data = [
{
'id': 1, 'name': 'Alice', 'age': 30
}, # missing 'email'
{
'id': 2, 'name': 'Bob', 'age': 25, 'email': 'bob@example.com'
}, # all fields present
{
'id': 3, 'name': 'Charlie', 'age': None, 'email': None
}, # explicit None values
{
'id': 4, 'name': 'David'
}, # missing 'age' and 'email'
]
with TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, 'nullable_test')
# Write data with missing nullable fields - this should not raise KeyError
with TestPipeline() as p:
_ = (
p
| Create(data)
| WriteToParquet(
path, schema, num_shards=1, shard_name_template=''))
# Read back and verify the data
with TestPipeline() as p:
readback = (
p
| ReadFromParquet(path + '*')
| Map(json.dumps, sort_keys=True))
# Expected data should have None for missing nullable fields
expected_data = [
{
'id': 1, 'name': 'Alice', 'age': 30, 'email': None
},
{
'id': 2, 'name': 'Bob', 'age': 25, 'email': 'bob@example.com'
},
{
'id': 3, 'name': 'Charlie', 'age': None, 'email': None
},
{
'id': 4, 'name': 'David', 'age': None, 'email': None
},
]
assert_that(
readback,
equal_to([json.dumps(r, sort_keys=True) for r in expected_data]))
def test_batched_read(self):
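# Records are written one dict at a time but read back as a single
# pa.Table by ReadFromParquetBatched.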
with TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, "tmp_filename")
with TestPipeline() as p:
_ = p \
| Create(self.RECORDS, reshuffle=False) \
| WriteToParquet(
path, self.SCHEMA, num_shards=1, shard_name_template='')
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| ReadFromParquetBatched(path)
assert_that(readback, equal_to([self._records_as_arrow()]))
@parameterized.expand([
param(compression_type='snappy'),
param(compression_type='gzip'),
param(compression_type='brotli'),
param(compression_type='lz4'),
param(compression_type='zstd')
])
def test_sink_transform_compressed(self, compression_type):
if compression_type == 'lz4' and ARROW_MAJOR_VERSION == 1:
  self.skipTest(
      "Writing with LZ4 compression is not supported in "
      "pyarrow 1.x")
with TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, "tmp_filename")
with TestPipeline() as p:
_ = p \
| Create(self.RECORDS) \
| WriteToParquet(
path, self.SCHEMA, codec=compression_type,
num_shards=1, shard_name_template='')
with TestPipeline() as p:
# json used for stable sortability
readback = \
p \
| ReadFromParquet(path + '*') \
| Map(json.dumps)
assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
def test_read_reentrant(self):
file_name = self._write_data(count=6, row_group_size=3)
source = _create_parquet_source(file_name)
source_test_utils.assert_reentrant_reads_succeed((source, None, None))
def test_read_without_splitting_multiple_row_group(self):
file_name = self._write_data(count=12000, row_group_size=1000)
# We expect 12000 elements, split into batches of 1000 elements. Create
# a list of pa.Table instances to model this expectation.
expected_result = [
pa.Table.from_batches([batch]) for batch in self._records_as_arrow(
count=12000).to_batches(max_chunksize=1000)
]
self._run_parquet_test(file_name, None, None, False, expected_result)
def test_read_with_splitting_multiple_row_group(self):
file_name = self._write_data(count=12000, row_group_size=1000)
# We expect 12000 elements, split into batches of 1000 elements. Create
# a list of pa.Table instances to model this expectation.
expected_result = [
pa.Table.from_batches([batch]) for batch in self._records_as_arrow(
count=12000).to_batches(max_chunksize=1000)
]
self._run_parquet_test(file_name, None, 10000, True, expected_result)
def test_dynamic_work_rebalancing(self):
# This test depends on count being sufficiently large and on the ratio of
# count to row_group_size also being sufficiently large (the exact ratio
# needed to pass varies with row_group_size and with the version of
# pyarrow being tested against).
file_name = self._write_data(count=280, row_group_size=20)
source = _create_parquet_source(file_name)
splits = [split for split in source.split(desired_bundle_size=float('inf'))]
assert len(splits) == 1
source_test_utils.assert_split_at_fraction_exhaustive(
splits[0].source, splits[0].start_position, splits[0].stop_position)
def test_min_bundle_size(self):
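# A min_bundle_size larger than the file suppresses splitting entirely,
# while min_bundle_size=0 with a tiny desired_bundle_size should produce
# multiple splits.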
file_name = self._write_data(count=120, row_group_size=20)
source = _create_parquet_source(
file_name, min_bundle_size=100 * 1024 * 1024)
splits = [split for split in source.split(desired_bundle_size=1)]
self.assertEqual(len(splits), 1)
source = _create_parquet_source(file_name, min_bundle_size=0)
splits = [split for split in source.split(desired_bundle_size=1)]
self.assertNotEqual(len(splits), 1)
def _convert_to_timestamped_record(self, record):
timestamped_record = record.copy()
timestamped_record['favorite_number'] =\
pandas.Timestamp(timestamped_record['favorite_number'])
return timestamped_record
def test_int96_type_conversion(self):
file_name = self._write_data(
count=120, row_group_size=20, schema=self.SCHEMA96)
orig = self._records_as_arrow(count=120, schema=self.SCHEMA96)
expected_result = [
pa.Table.from_batches([batch], schema=self.SCHEMA96)
for batch in orig.to_batches(max_chunksize=20)
]
self._run_parquet_test(file_name, None, None, False, expected_result)
def test_split_points(self):
file_name = self._write_data(count=12000, row_group_size=3000)
source = _create_parquet_source(file_name)
splits = [split for split in source.split(desired_bundle_size=float('inf'))]
assert len(splits) == 1
range_tracker = splits[0].source.get_range_tracker(
splits[0].start_position, splits[0].stop_position)
split_points_report = []
for _ in splits[0].source.read(range_tracker):
split_points_report.append(range_tracker.split_points())
# There are a total of four row groups. Each row group has 3000 records.
# When reading records of the first group, range_tracker.split_points()
# should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
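# Only while reading the final row group does the tracker know that
# exactly one split point remains, hence the closing (3, 1) entry.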
self.assertEqual(
split_points_report,
[
(0, RangeTracker.SPLIT_POINTS_UNKNOWN),
(1, RangeTracker.SPLIT_POINTS_UNKNOWN),
(2, RangeTracker.SPLIT_POINTS_UNKNOWN),
(3, 1),
])
def test_selective_columns(self):
file_name = self._write_data()
orig = self._records_as_arrow()
name_column = self.SCHEMA.field('name')
expected_result = [
pa.Table.from_arrays(
[orig.column('name')],
schema=pa.schema([('name', name_column.type, name_column.nullable)
]))
]
self._run_parquet_test(file_name, ['name'], None, False, expected_result)
def test_sink_transform_multiple_row_group(self):
with TemporaryDirectory() as tmp_dirname:
path = os.path.join(tmp_dirname, "tmp_filename")
# Pin to FnApiRunner since test assumes fixed bundle size
with TestPipeline('FnApiRunner') as p:
# writing 623200 bytes of data
_ = p \
| Create(self.RECORDS * 4000) \
| WriteToParquet(
path, self.SCHEMA, num_shards=1, codec='none',
shard_name_template='', row_group_buffer_size=250000)
self.assertEqual(pq.read_metadata(path).num_row_groups, 3)
def test_read_all_from_parquet_single_file(self):
path = self._write_data()
with TestPipeline() as p:
assert_that(
p \
| Create([path]) \
| ReadAllFromParquet(),
equal_to(self.RECORDS))
with TestPipeline() as p:
assert_that(
p \
| Create([path]) \
| ReadAllFromParquetBatched(),
equal_to([self._records_as_arrow()]))
def test_read_all_from_parquet_many_single_files(self):
path1 = self._write_data()
path2 = self._write_data()
path3 = self._write_data()
with TestPipeline() as p:
assert_that(
p \
| Create([path1, path2, path3]) \
| ReadAllFromParquet(),
equal_to(self.RECORDS * 3))
with TestPipeline() as p:
assert_that(
p \
| Create([path1, path2, path3]) \
| ReadAllFromParquetBatched(),
equal_to([self._records_as_arrow()] * 3))
def test_read_all_from_parquet_file_pattern(self):
file_pattern = self._write_pattern(5)
with TestPipeline() as p:
assert_that(
p \
| Create([file_pattern]) \
| ReadAllFromParquet(),
equal_to(self.RECORDS * 5))
with TestPipeline() as p:
assert_that(
p \
| Create([file_pattern]) \
| ReadAllFromParquetBatched(),
equal_to([self._records_as_arrow()] * 5))
def test_read_all_from_parquet_many_file_patterns(self):
file_pattern1 = self._write_pattern(5)
file_pattern2 = self._write_pattern(2)
file_pattern3 = self._write_pattern(3)
with TestPipeline() as p:
assert_that(
p \
| Create([file_pattern1, file_pattern2, file_pattern3]) \
| ReadAllFromParquet(),
equal_to(self.RECORDS * 10))
with TestPipeline() as p:
assert_that(
p \
| Create([file_pattern1, file_pattern2, file_pattern3]) \
| ReadAllFromParquetBatched(),
equal_to([self._records_as_arrow()] * 10))
def test_read_all_from_parquet_with_filename(self):
file_pattern, file_paths = self._write_pattern(3, with_filename=True)
result = [(path, record) for path in file_paths for record in self.RECORDS]
with TestPipeline() as p:
assert_that(
p \
| Create([file_pattern]) \
| ReadAllFromParquet(with_filename=True),
equal_to(result))
class GenerateEvent(beam.PTransform):
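"""Emits three {'age': ...} elements per second of event time (seconds
1-20 of 2021-03-01 00:00 UTC) on a TestStream, advancing the watermark
every five seconds; used by the streaming WriteToParquet tests below."""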
@staticmethod
def sample_data():
return GenerateEvent()
def expand(self, input):
  def event_time(second):
    # Event-time timestamp for the given second of 2021-03-01 00:00 UTC.
    return datetime(
        2021, 3, 1, 0, 0, second, 0, tzinfo=pytz.UTC).timestamp()

  elem = [{'age': 10}, {'age': 20}, {'age': 30}]
  stream = TestStream()
  # Emit the same three elements once per second of event time for
  # seconds 1..20, advancing the watermark every five seconds.
  for second in range(1, 21):
    if second % 5 == 0:
      stream = stream.advance_watermark_to(event_time(second))
    stream = stream.add_elements(
        elements=elem, event_timestamp=event_time(second))
  stream = stream.advance_watermark_to(event_time(25))
  return input | stream.advance_watermark_to_infinity()
class WriteStreamingTest(unittest.TestCase):
def setUp(self):
super().setUp()
self.tempdir = tempfile.mkdtemp()
def tearDown(self):
if os.path.exists(self.tempdir):
shutil.rmtree(self.tempdir)
def test_write_streaming_2_shards_default_shard_name_template(
self, num_shards=2):
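# With a 60s triggering frequency all ~25s of input lands in one output
# window, so the test expects exactly num_shards files.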
with TestPipeline() as p:
output = (p | GenerateEvent.sample_data())
# ParquetIO
pyschema = pa.schema([('age', pa.int64())])
output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet(
file_path_prefix=self.tempdir + "/output_WriteToParquet",
file_name_suffix=".parquet",
num_shards=num_shards,
triggering_frequency=60,
schema=pyschema)
_ = output2 | 'LogElements after WriteToParquet' >> LogElements(
prefix='after WriteToParquet ', with_window=True, level=logging.INFO)
# Regex to match the expected windowed file pattern
# Example:
# output_WriteToParquet-[1614556800.0, 1614556805.0)-00000-of-00002.parquet
# It captures: window_interval, shard_num, total_shards
pattern_string = (
r'.*-\[(?P<window_start>[\d\.]+), '
r'(?P<window_end>[\d\.]+|Infinity)\)-'
r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$')
pattern = re.compile(pattern_string)
file_names = []
for file_name in glob.glob(self.tempdir + '/output_WriteToParquet*'):
match = pattern.match(file_name)
self.assertIsNotNone(
match, f"File name {file_name} did not match expected pattern.")
if match:
file_names.append(file_name)
print("Found files matching expected pattern:", file_names)
self.assertEqual(
len(file_names),
num_shards,
"expected %d files, but got: %d" % (num_shards, len(file_names)))
def test_write_streaming_2_shards_custom_shard_name_template(
self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'):
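# With the '-V-SSSSS-of-NNNNN' template the window boundaries appear as
# datetimes (rather than float timestamps) in the file names; the regex
# below asserts that format.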
with TestPipeline() as p:
output = (p | GenerateEvent.sample_data())
# ParquetIO
pyschema = pa.schema([('age', pa.int64())])
output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet(
file_path_prefix=self.tempdir + "/output_WriteToParquet",
file_name_suffix=".parquet",
shard_name_template=shard_name_template,
num_shards=num_shards,
triggering_frequency=60,
schema=pyschema)
_ = output2 | 'LogElements after WriteToParquet' >> LogElements(
prefix='after WriteToParquet ', with_window=True, level=logging.INFO)
# Regex to match the expected windowed file pattern
# Example:
# output_WriteToParquet-[2021-03-01T00-00-00, 2021-03-01T00-01-00)-
# 00000-of-00002.parquet
# It captures: window_interval, shard_num, total_shards
pattern_string = (
r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), '
r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-'
r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$')
pattern = re.compile(pattern_string)
file_names = []
for file_name in glob.glob(self.tempdir + '/output_WriteToParquet*'):
match = pattern.match(file_name)
self.assertIsNotNone(
match, f"File name {file_name} did not match expected pattern.")
if match:
file_names.append(file_name)
print("Found files matching expected pattern:", file_names)
self.assertEqual(
len(file_names),
num_shards,
"expected %d files, but got: %d" % (num_shards, len(file_names)))
def test_write_streaming_2_shards_custom_shard_name_template_5s_window(
self,
num_shards=2,
shard_name_template='-V-SSSSS-of-NNNNN',
triggering_frequency=5):
with TestPipeline() as p:
output = (p | GenerateEvent.sample_data())
# ParquetIO
pyschema = pa.schema([('age', pa.int64())])
output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet(
file_path_prefix=self.tempdir + "/output_WriteToParquet",
file_name_suffix=".parquet",
shard_name_template=shard_name_template,
num_shards=num_shards,
triggering_frequency=triggering_frequency,
schema=pyschema)
_ = output2 | 'LogElements after WriteToParquet' >> LogElements(
prefix='after WriteToParquet ', with_window=True, level=logging.INFO)
# Regex to match the expected windowed file pattern
# Example:
# output_WriteToParquet-[2021-03-01T00-00-00, 2021-03-01T00-01-00)-
# 00000-of-00002.parquet
# It captures: window_interval, shard_num, total_shards
pattern_string = (
r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), '
r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-'
r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$')
pattern = re.compile(pattern_string)
file_names = []
for file_name in glob.glob(self.tempdir + '/output_WriteToParquet*'):
match = pattern.match(file_name)
self.assertIsNotNone(
match, f"File name {file_name} did not match expected pattern.")
if match:
file_names.append(file_name)
print("Found files matching expected pattern:", file_names)
# for 5s window size, the input should be processed by 5 windows with
# 2 shards per window
self.assertEqual(
    len(file_names),
    num_shards * 5,
    "expected %d files, but got: %d" % (num_shards * 5, len(file_names)))
def test_write_streaming_undef_shards_default_shard_name_template_windowed_pcoll( # pylint: disable=line-too-long
self):
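# num_shards=0 leaves sharding up to the runner, so the test only checks
# a lower bound: at least one file per 10s window.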
with TestPipeline() as p:
output = (
p | GenerateEvent.sample_data()
| 'User windowing' >> beam.transforms.core.WindowInto(
beam.transforms.window.FixedWindows(10),
trigger=beam.transforms.trigger.AfterWatermark(),
accumulation_mode=beam.transforms.trigger.AccumulationMode.
DISCARDING,
allowed_lateness=beam.utils.timestamp.Duration(seconds=0)))
# ParquetIO
pyschema = pa.schema([('age', pa.int64())])
output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet(
file_path_prefix=self.tempdir + "/output_WriteToParquet",
file_name_suffix=".parquet",
num_shards=0,
schema=pyschema)
_ = output2 | 'LogElements after WriteToParquet' >> LogElements(
prefix='after WriteToParquet ', with_window=True, level=logging.INFO)
# Regex to match the expected windowed file pattern
# Example:
# output_WriteToParquet-[1614556800.0, 1614556805.0)-00000-of-00002.parquet
# It captures: window_interval, shard_num, total_shards
pattern_string = (
r'.*-\[(?P<window_start>[\d\.]+), '
r'(?P<window_end>[\d\.]+|Infinity)\)-'
r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.parquet$')
pattern = re.compile(pattern_string)
file_names = []
for file_name in glob.glob(self.tempdir + '/output_WriteToParquet*'):
match = pattern.match(file_name)
self.assertIsNotNone(
match, f"File name {file_name} did not match expected pattern.")
if match:
file_names.append(file_name)
print("Found files matching expected pattern:", file_names)
self.assertGreaterEqual(
    len(file_names),
    1 * 3,  # 25s of data covered by three 10s windows
    "expected at least %d files, but got: %d" % (1 * 3, len(file_names)))
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()