| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| from __future__ import absolute_import |
| from __future__ import division |
| |
| import logging |
| import string |
| import sys |
| import unittest |
| from collections import Counter |
| |
| from nose.plugins.attrib import attr |
| |
| from apache_beam import Create |
| from apache_beam import DoFn |
| from apache_beam import FlatMap |
| from apache_beam import Flatten |
| from apache_beam import Map |
| from apache_beam import ParDo |
| from apache_beam import Reshuffle |
| from apache_beam.io.filesystems import FileSystems |
| from apache_beam.io.parquetio import ReadAllFromParquet |
| from apache_beam.io.parquetio import WriteToParquet |
| from apache_beam.testing.test_pipeline import TestPipeline |
| from apache_beam.testing.util import BeamAssertException |
| from apache_beam.transforms import CombineGlobally |
| from apache_beam.transforms.combiners import Count |
| |
| try: |
| import pyarrow as pa |
| except ImportError: |
| pa = None |
| |
| |
@unittest.skipIf(pa is None, "PyArrow is not installed.")
class TestParquetIT(unittest.TestCase):
  """Integration test round-tripping generated records through Parquet.

  The pipeline produces records with ``ProducerFn``, writes them with
  ``WriteToParquet``, reads them back with ``ReadAllFromParquet``, validates
  the sum of the ``number`` field and the per-name record counts, and finally
  passes the elements emitted by the write step to ``FileSystems.delete``.
  """

  @classmethod
  def setUpClass(cls):
    """Expose ``assertCountEqual`` under that name on Python 2 as well."""
    running_on_python2 = sys.version_info[0] < 3
    if running_on_python2:
      # Method has been renamed in Python 3
      cls.assertCountEqual = cls.assertItemsEqual

  def setUp(self):
    """No per-test setup is performed."""

  def tearDown(self):
    """No per-test teardown is performed."""

  @attr('IT')
  def test_parquetio_it(self):
    """Build the write/read/verify pipeline and run it to completion."""
    init_size, data_size = 10, 20000
    file_prefix = "parquet_it_test"
    pipeline = TestPipeline(is_integration_test=True)
    written = self._generate_data(
        pipeline, file_prefix, init_size, data_size)
    self._verify_data(written, init_size, data_size)
    pipeline.run().wait_until_finish()

  @staticmethod
  def _sum_verifier(init_size, data_size, x):
    """Check the global sum of the ``number`` field.

    Args:
      init_size: number of seed elements fed to the producer.
      data_size: number of records produced per seed element.
      x: the sum actually observed in the pipeline.

    Returns:
      An empty list, so that a ``FlatMap`` using this emits nothing.

    Raises:
      BeamAssertException: if ``x`` differs from the expected sum.
    """
    # Each seed element yields the numbers 0..data_size-1 exactly once.
    expected = sum(range(data_size)) * init_size
    if x != expected:
      message = "incorrect sum: expected(%d) actual(%d)" % (expected, x)
      raise BeamAssertException(message)
    return []

  @staticmethod
  def _count_verifier(init_size, data_size, x):
    """Check the number of records observed for one name.

    Args:
      init_size: number of seed elements fed to the producer.
      data_size: number of records produced per seed element.
      x: a ``(name, count)`` pair whose name is a UTF-8 encoded byte string.

    Returns:
      An empty list, so that a ``FlatMap`` using this emits nothing.

    Raises:
      BeamAssertException: if the count differs from the expected one.
    """
    name = x[0].decode('utf-8')
    count = x[1]
    # Names are four letters long, so consecutive names start four letters
    # apart; tally how often each first letter occurs per seed element.
    first_letters = Counter(
        string.ascii_uppercase[offset % 26]
        for offset in range(0, data_size * 4, 4))
    expected_count = first_letters[name[0]] * init_size
    if count != expected_count:
      message = "incorrect count(%s): expected(%d) actual(%d)" % (
          name, expected_count, count)
      raise BeamAssertException(message)
    return []

  def _verify_data(self, pcol, init_size, data_size):
    """Attach read-back, validation and cleanup steps to ``pcol``.

    Args:
      pcol: the PCollection returned by the write step; its elements are fed
        to ``ReadAllFromParquet`` and later to ``FileSystems.delete``.
      init_size: number of seed elements fed to the producer.
      data_size: number of records produced per seed element.
    """
    records = pcol | 'read' >> ReadAllFromParquet()

    numbers = records | 'get_number' >> Map(lambda rec: rec['number'])
    total = numbers | 'sum_globally' >> CombineGlobally(sum)
    sum_check = total | 'validate_number' >> FlatMap(
        lambda observed: TestParquetIT._sum_verifier(
            init_size, data_size, observed))

    pairs = records | 'make_pair' >> Map(
        lambda rec: (rec['name'], rec['number']))
    per_name = pairs | 'count_per_key' >> Count.PerKey()
    count_check = per_name | 'validate_name' >> FlatMap(
        lambda pair: TestParquetIT._count_verifier(
            init_size, data_size, pair))

    # Both verifiers return empty lists, so only the elements of pcol reach
    # the cleanup step. NOTE(review): flattening the verifier outputs in
    # presumably makes cleanup wait for validation -- confirm.
    merged = (sum_check, count_check, pcol) | 'flatten' >> Flatten()
    reshuffled = merged | 'reshuffle' >> Reshuffle()
    _ = reshuffled | 'cleanup' >> Map(
        lambda path: FileSystems.delete([path]))

  def _generate_data(self, p, output_prefix, init_size, data_size):
    """Attach the record-producing and Parquet-writing steps to ``p``.

    Args:
      p: the pipeline to extend.
      output_prefix: file path prefix handed to ``WriteToParquet``.
      init_size: number of seed elements to create.
      data_size: number of records ``ProducerFn`` emits per seed element.

    Returns:
      The PCollection produced by the ``WriteToParquet`` step.
    """
    seeds = list(range(init_size))
    schema = pa.schema([
        ('name', pa.binary()),
        ('number', pa.int64()),
    ])

    created = p | 'create' >> Create(seeds)
    produced = created | 'produce' >> ParDo(ProducerFn(data_size))
    return produced | 'write' >> WriteToParquet(
        output_prefix,
        schema,
        codec='snappy',
        file_name_suffix='.parquet')
| |
| |
class ProducerFn(DoFn):
  """DoFn emitting a fixed, deterministic run of records per input element.

  Every record is a dict with a four-letter ``name`` taken cyclically from
  the uppercase ASCII alphabet and a ``number`` counting up from zero.
  """

  def __init__(self, number):
    """Remember how many records to emit for each input element.

    Args:
      number: count of records yielded by each ``process`` call.
    """
    super(ProducerFn, self).__init__()
    self._number = number
    self._string_index = self._number_index = 0

  def process(self, element):
    """Yield ``number`` records; the value of ``element`` is not used."""
    # Restart both sequences so each input element yields the same records.
    self._string_index = self._number_index = 0
    for _ in range(self._number):
      name = self.get_string(4)
      number = self.get_int()
      yield {'name': name, 'number': number}

  def get_string(self, length):
    """Return the next ``length`` letters, wrapping around after ``Z``."""
    alphabet = string.ascii_uppercase
    cursor = self._string_index
    letters = []
    for _ in range(length):
      letters.append(alphabet[cursor])
      cursor = (cursor + 1) % 26
    self._string_index = cursor
    return ''.join(letters)

  def get_int(self):
    """Return the current counter value and advance the counter by one."""
    current = self._number_index
    self._number_index = current + 1
    return current
| |
| |
if __name__ == '__main__':
  # Raise the root logger to INFO before handing control to the unittest
  # runner when this file is executed directly.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()