#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
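
"""Integration tests for the parquetio read and write transforms."""
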
# pytype: skip-file

import logging
import string
import unittest
import uuid
from collections import Counter
from datetime import datetime

import pytest
import pytz

import apache_beam as beam
from apache_beam import Create
from apache_beam import DoFn
from apache_beam import FlatMap
from apache_beam import Flatten
from apache_beam import Map
from apache_beam import ParDo
from apache_beam import Reshuffle
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.parquetio import ReadAllFromParquet
from apache_beam.io.parquetio import WriteToParquet
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import BeamAssertException
from apache_beam.transforms import CombineGlobally
from apache_beam.transforms.combiners import Count
from apache_beam.transforms.periodicsequence import PeriodicImpulse

try:
  import pyarrow as pa
except ImportError:
  pa = None


@unittest.skipIf(pa is None, "PyArrow is not installed.")
class TestParquetIT(unittest.TestCase):
  def setUp(self):
    pass

  def tearDown(self):
    pass

  @pytest.mark.it_postcommit
  def test_parquetio_it(self):
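    """Round-trips generated records through WriteToParquet and
    ReadAllFromParquet, then validates the element sum and per-name counts.
    """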
    file_prefix = "parquet_it_test"
    init_size = 10
    data_size = 20000
    with TestPipeline(is_integration_test=True) as p:
      pcol = self._generate_data(p, file_prefix, init_size, data_size)
      self._verify_data(pcol, init_size, data_size)

  @staticmethod
  def _sum_verifier(init_size, data_size, x):
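    # Each of the init_size input elements produces the numbers
    # 0..data_size-1, so the global sum must equal
    # sum(range(data_size)) * init_size.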
    expected = sum(range(data_size)) * init_size
    if x != expected:
      raise BeamAssertException(
          "incorrect sum: expected(%d) actual(%d)" % (expected, x))
    return []

  @staticmethod
  def _count_verifier(init_size, data_size, x):
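    # Names are drawn four cyclic letters at a time, so the first letter of
    # the i-th name is string.ascii_uppercase[(4 * i) % 26] and fully
    # determines the name. Rebuild the expected counts from that pattern.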
    name, count = x[0].decode('utf-8'), x[1]
    counter = Counter(
        [string.ascii_uppercase[x % 26] for x in range(0, data_size * 4, 4)])
    expected_count = counter[name[0]] * init_size
    if count != expected_count:
      raise BeamAssertException(
          "incorrect count(%s): expected(%d) actual(%d)" %
          (name, expected_count, count))
    return []

  def _verify_data(self, pcol, init_size, data_size):
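    # Read the written files back and validate them in two branches: a
    # global sum and a per-name count. Both verifiers emit no elements, so
    # flattening them with pcol makes the cleanup step depend on both
    # validations before the test files are deleted.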
    read = pcol | 'read' >> ReadAllFromParquet()
    v1 = (
        read
        | 'get_number' >> Map(lambda x: x['number'])
        | 'sum_globally' >> CombineGlobally(sum)
        | 'validate_number' >> FlatMap(
            lambda x: TestParquetIT._sum_verifier(init_size, data_size, x)))
    v2 = (
        read
        | 'make_pair' >> Map(lambda x: (x['name'], x['number']))
        | 'count_per_key' >> Count.PerKey()
        | 'validate_name' >> FlatMap(
            lambda x: TestParquetIT._count_verifier(init_size, data_size, x)))
    _ = ((v1, v2, pcol)
         | 'flatten' >> Flatten()
         | 'reshuffle' >> Reshuffle()
         | 'cleanup' >> Map(lambda x: FileSystems.delete([x])))

  def _generate_data(self, p, output_prefix, init_size, data_size):
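    # Fan out init_size seed elements through ProducerFn (data_size records
    # each) and write them to snappy-compressed Parquet files, returning a
    # PCollection of the written file names.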
    init_data = [x for x in range(init_size)]
    lines = (
        p
        | 'create' >> Create(init_data)
        | 'produce' >> ParDo(ProducerFn(data_size)))
    schema = pa.schema([('name', pa.binary()), ('number', pa.int64())])
    files = lines | 'write' >> WriteToParquet(
        output_prefix, schema, codec='snappy', file_name_suffix='.parquet')
    return files


class ProducerFn(DoFn):
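  """Yields `number` {'name': ..., 'number': ...} records per input element,
  with cyclic four-letter names and sequential integers.
  """
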
  def __init__(self, number):
    super().__init__()
    self._number = number
    self._string_index = 0
    self._number_index = 0

  def process(self, element):
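    # Reset both indices per element so every input element yields the same
    # deterministic sequence of records; the verifiers rely on this.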
    self._string_index = 0
    self._number_index = 0
    for _ in range(self._number):
      yield {'name': self.get_string(4), 'number': self.get_int()}

  def get_string(self, length):
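    # Take `length` consecutive letters from a cyclic A-Z sequence, so
    # successive names start four letters apart: exactly the pattern that
    # _count_verifier reconstructs.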
    s = []
    for _ in range(length):
      s.append(string.ascii_uppercase[self._string_index])
      self._string_index = (self._string_index + 1) % 26
    return ''.join(s)

  def get_int(self):
    i = self._number_index
    self._number_index = self._number_index + 1
    return i


@unittest.skipIf(pa is None, "PyArrow is not installed.")
class WriteStreamingIT(unittest.TestCase):
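  """Integration test for streaming Parquet writes driven by a bounded
  PeriodicImpulse source.
  """
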
  def setUp(self):
    self.test_pipeline = TestPipeline(is_integration_test=True)
    self.runner_name = type(self.test_pipeline.runner).__name__
    super().setUp()

  def test_write_streaming_2_shards_default_shard_name_template(
      self, num_shards=2):
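    # PeriodicImpulse fires once per second over a ~19-second window, and
    # WriteToParquet emits num_shards files per triggering_frequency interval
    # under a unique GCS prefix.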
    args = self.test_pipeline.get_full_options_as_args(streaming=True)
    unique_id = str(uuid.uuid4())
    output_file = f'gs://apache-beam-testing-integration-testing/iobase/test-{unique_id}'  # pylint: disable=line-too-long
    p = beam.Pipeline(argv=args)
    pyschema = pa.schema([('age', pa.int64())])
    _ = (
        p
        | "generate impulse" >> PeriodicImpulse(
            start_timestamp=datetime(2021, 3, 1, 0, 0, 1, 0,
                                     tzinfo=pytz.UTC).timestamp(),
            stop_timestamp=datetime(2021, 3, 1, 0, 0, 20, 0,
                                    tzinfo=pytz.UTC).timestamp(),
            fire_interval=1)
        | "generate data" >> beam.Map(lambda t: {'age': t * 10})
        | 'WriteToParquet' >> beam.io.WriteToParquet(
            file_path_prefix=output_file,
            file_name_suffix=".parquet",
            num_shards=num_shards,
            triggering_frequency=60,
            schema=pyschema))
    result = p.run()
    result.wait_until_finish(duration=600 * 1000)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()