# Licensed to the Apache Software Foundation (ASF) under one
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| from collections import namedtuple |
| from concurrent.futures import ThreadPoolExecutor |
| from functools import partial |
| import glob |
| import gzip |
| import itertools |
| import os |
| import sys |
| import tempfile |
| import traceback |
| |
| from .scenario import Scenario |
from .tester_cpp import CPPTester
from .tester_go import GoTester
from .tester_java import JavaTester
from .tester_js import JSTester
from .tester_rust import RustTester
| from .util import (ARROW_ROOT_DEFAULT, guid, SKIP_ARROW, SKIP_FLIGHT, |
| printer) |
| from . import datagen |
| |
| |
| Failure = namedtuple('Failure', |
| ('test_case', 'producer', 'consumer', 'exc_info')) |
| |
| log = printer.print |
| |
| |
| class Outcome: |
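    """
    The result of running a single test case: a skip, a failure, or
    (when neither is set) a success.
    """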
| def __init__(self): |
| self.failure = None |
| self.skipped = False |
| |
| |
class IntegrationRunner:
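    """
    Runs the matrix of Arrow integration tests: every enabled producer
    against every enabled consumer, plus optional golden-file checks.
    """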
| |
| def __init__(self, json_files, flight_scenarios, testers, tempdir=None, |
| debug=False, stop_on_error=True, gold_dirs=None, |
| serial=False, match=None, **unused_kwargs): |
| self.json_files = json_files |
| self.flight_scenarios = flight_scenarios |
| self.testers = testers |
| self.temp_dir = tempdir or tempfile.mkdtemp() |
| self.debug = debug |
| self.stop_on_error = stop_on_error |
| self.serial = serial |
| self.gold_dirs = gold_dirs |
| self.failures = [] |
| self.match = match |
| |
| if self.match is not None: |
| print("-- Only running tests with {} in their name" |
| .format(self.match)) |
| self.json_files = [json_file for json_file in self.json_files |
| if self.match in json_file.name] |
| |
| def run(self): |
| """ |
| Run Arrow IPC integration tests for the matrix of enabled |
| implementations. |
| """ |
| for producer, consumer in itertools.product( |
| filter(lambda t: t.PRODUCER, self.testers), |
| filter(lambda t: t.CONSUMER, self.testers)): |
| self._compare_implementations( |
| producer, consumer, self._produce_consume, |
| self.json_files) |
| if self.gold_dirs: |
| for gold_dir, consumer in itertools.product( |
| self.gold_dirs, |
| filter(lambda t: t.CONSUMER, self.testers)): |
| log('\n\n\n\n') |
| log('******************************************************') |
| log('Tests against golden files in {}'.format(gold_dir)) |
| log('******************************************************') |
| |
| def run_gold(producer, consumer, outcome, test_case): |
| self._run_gold(gold_dir, producer, consumer, outcome, |
| test_case) |
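            # The golden files on disk stand in for the producer, so the
            # consumer is passed for both roles.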
| self._compare_implementations( |
| consumer, consumer, run_gold, |
| self._gold_tests(gold_dir)) |
| |
| def run_flight(self): |
| """ |
| Run Arrow Flight integration tests for the matrix of enabled |
| implementations. |
| """ |
| servers = filter(lambda t: t.FLIGHT_SERVER, self.testers) |
| clients = filter(lambda t: (t.FLIGHT_CLIENT and t.CONSUMER), |
| self.testers) |
| for server, client in itertools.product(servers, clients): |
| self._compare_flight_implementations(server, client) |
| |
| def _gold_tests(self, gold_dir): |
| prefix = os.path.basename(os.path.normpath(gold_dir)) |
| SUFFIX = ".json.gz" |
| golds = [jf for jf in os.listdir(gold_dir) if jf.endswith(SUFFIX)] |
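        # The golden files are stored gzip-compressed; decompress each one
        # into the temp dir so the testers can read the plain JSON.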
| for json_path in golds: |
| name = json_path[json_path.index('_')+1: -len(SUFFIX)] |
| base_name = prefix + "_" + name + ".gold.json" |
| out_path = os.path.join(self.temp_dir, base_name) |
            with gzip.open(os.path.join(gold_dir, json_path)) as gz:
                with open(out_path, "wb") as out:
                    out.write(gz.read())
| |
            try:
                # Copy the skip set so that the version-specific additions
                # below do not mutate the shared datagen.File instance
                # (skips would otherwise accumulate across gold dirs).
                skip = set(next(f for f in self.json_files
                                if f.name == name).skip)
            except StopIteration:
                skip = set()
| if name == 'union' and prefix == '0.17.1': |
| skip.add("Java") |
            if prefix in ('1.0.0-bigendian', '1.0.0-littleendian'):
| skip.add("Go") |
| skip.add("Java") |
| skip.add("JS") |
| skip.add("Rust") |
| if prefix == '2.0.0-compression': |
| skip.add("JS") |
| skip.add("Rust") |
| |
| # See https://github.com/apache/arrow/pull/9822 for how to |
| # disable specific compression type tests. |
| |
| if prefix == '4.0.0-shareddict': |
| skip.add("Go") |
| |
| yield datagen.File(name, None, None, skip=skip, path=out_path) |
| |
    def _run_test_cases(self, case_runner, test_cases):
| def case_wrapper(test_case): |
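            # cork() buffers this case's output so that logs from
            # concurrently running cases do not interleave.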
| with printer.cork(): |
| return case_runner(test_case) |
| |
| if self.failures and self.stop_on_error: |
| return |
| |
| if self.serial: |
| for outcome in map(case_wrapper, test_cases): |
| if outcome.failure is not None: |
| self.failures.append(outcome.failure) |
| if self.stop_on_error: |
| break |
| |
| else: |
| with ThreadPoolExecutor() as executor: |
| for outcome in executor.map(case_wrapper, test_cases): |
| if outcome.failure is not None: |
| self.failures.append(outcome.failure) |
| if self.stop_on_error: |
| break |
| |
| def _compare_implementations( |
| self, producer, consumer, run_binaries, test_cases): |
| """ |
| Compare Arrow IPC for two implementations (one producer, one consumer). |
| """ |
| log('##########################################################') |
| log('IPC: {0} producing, {1} consuming' |
| .format(producer.name, consumer.name)) |
| log('##########################################################') |
| |
| case_runner = partial(self._run_ipc_test_case, |
| producer, consumer, run_binaries) |
        self._run_test_cases(case_runner, test_cases)
| |
| def _run_ipc_test_case(self, producer, consumer, run_binaries, test_case): |
| """ |
| Run one IPC test case. |
| """ |
| outcome = Outcome() |
| |
| json_path = test_case.path |
| log('==========================================================') |
| log('Testing file {0}'.format(json_path)) |
| log('==========================================================') |
| |
| if producer.name in test_case.skip: |
            log('-- Skipping test because producer {0} does '
                'not support this case'.format(producer.name))
| outcome.skipped = True |
| |
| elif consumer.name in test_case.skip: |
            log('-- Skipping test because consumer {0} does '
                'not support this case'.format(consumer.name))
| outcome.skipped = True |
| |
| elif SKIP_ARROW in test_case.skip: |
| log('-- Skipping test') |
| outcome.skipped = True |
| |
| else: |
| try: |
| run_binaries(producer, consumer, outcome, test_case) |
| except Exception: |
| traceback.print_exc(file=printer.stdout) |
| outcome.failure = Failure(test_case, producer, consumer, |
| sys.exc_info()) |
| |
| return outcome |
| |
| def _produce_consume(self, producer, consumer, outcome, test_case): |
| # Make the random access file |
| json_path = test_case.path |
| file_id = guid()[:8] |
| name = os.path.splitext(os.path.basename(json_path))[0] |
| |
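        # Distinct temp paths for each stage of the round trip:
        # JSON -> file (producer), file -> stream (producer),
        # stream -> file (consumer).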
| producer_file_path = os.path.join(self.temp_dir, file_id + '_' + |
| name + '.json_as_file') |
| producer_stream_path = os.path.join(self.temp_dir, file_id + '_' + |
| name + '.producer_file_as_stream') |
| consumer_file_path = os.path.join(self.temp_dir, file_id + '_' + |
| name + '.consumer_stream_as_file') |
| |
| log('-- Creating binary inputs') |
| producer.json_to_file(json_path, producer_file_path) |
| |
| # Validate the file |
| log('-- Validating file') |
| consumer.validate(json_path, producer_file_path) |
| |
| log('-- Validating stream') |
| producer.file_to_stream(producer_file_path, producer_stream_path) |
| consumer.stream_to_file(producer_stream_path, consumer_file_path) |
| consumer.validate(json_path, consumer_file_path) |
| |
| def _run_gold(self, gold_dir, producer, consumer, outcome, test_case): |
| json_path = test_case.path |
| |
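        # Gold tests validate the pre-generated binaries shipped in
        # gold_dir instead of producing new ones.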
| # Validate the file |
| log('-- Validating file') |
| producer_file_path = os.path.join( |
| gold_dir, "generated_" + test_case.name + ".arrow_file") |
| consumer.validate(json_path, producer_file_path) |
| |
| log('-- Validating stream') |
| consumer_stream_path = os.path.join( |
| gold_dir, "generated_" + test_case.name + ".stream") |
| file_id = guid()[:8] |
| name = os.path.splitext(os.path.basename(json_path))[0] |
| |
| consumer_file_path = os.path.join(self.temp_dir, file_id + '_' + |
| name + '.consumer_stream_as_file') |
| |
| consumer.stream_to_file(consumer_stream_path, consumer_file_path) |
| consumer.validate(json_path, consumer_file_path) |
| |
| def _compare_flight_implementations(self, producer, consumer): |
| log('##########################################################') |
| log('Flight: {0} serving, {1} requesting' |
| .format(producer.name, consumer.name)) |
| log('##########################################################') |
| |
| case_runner = partial(self._run_flight_test_case, producer, consumer) |
        self._run_test_cases(case_runner,
                             self.json_files + self.flight_scenarios)
| |
| def _run_flight_test_case(self, producer, consumer, test_case): |
| """ |
| Run one Flight test case. |
| """ |
| outcome = Outcome() |
| |
| log('=' * 58) |
| log('Testing file {0}'.format(test_case.name)) |
| log('=' * 58) |
| |
| if producer.name in test_case.skip: |
            log('-- Skipping test because producer {0} does '
                'not support this case'.format(producer.name))
| outcome.skipped = True |
| |
| elif consumer.name in test_case.skip: |
            log('-- Skipping test because consumer {0} does '
                'not support this case'.format(consumer.name))
| outcome.skipped = True |
| |
| elif SKIP_FLIGHT in test_case.skip: |
| log('-- Skipping test') |
| outcome.skipped = True |
| |
| else: |
| try: |
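                # Scenarios drive feature-specific server/client logic;
                # JSON files exercise the generic upload/download path.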
| if isinstance(test_case, Scenario): |
| server = producer.flight_server(test_case.name) |
| client_args = {'scenario_name': test_case.name} |
| else: |
| server = producer.flight_server() |
| client_args = {'json_path': test_case.path} |
| |
| with server as port: |
| # Have the client upload the file, then download and |
| # compare |
| consumer.flight_request(port, **client_args) |
| except Exception: |
| traceback.print_exc(file=printer.stdout) |
| outcome.failure = Failure(test_case, producer, consumer, |
| sys.exc_info()) |
| |
| return outcome |
| |
| |
| def get_static_json_files(): |
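    # Static JSON files are checked in under integration/data, as opposed
    # to the files generated at run time by datagen.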
| glob_pattern = os.path.join(ARROW_ROOT_DEFAULT, |
| 'integration', 'data', '*.json') |
| return [ |
| datagen.File(name=os.path.basename(p), path=p, skip=set(), |
| schema=None, batches=None) |
| for p in glob.glob(glob_pattern) |
| ] |
| |
| |
| def run_all_tests(with_cpp=True, with_java=True, with_js=True, |
| with_go=True, with_rust=False, run_flight=False, |
| tempdir=None, **kwargs): |
| tempdir = tempdir or tempfile.mkdtemp(prefix='arrow-integration-') |
| |
| testers = [] |
| |
| if with_cpp: |
| testers.append(CPPTester(**kwargs)) |
| |
| if with_java: |
| testers.append(JavaTester(**kwargs)) |
| |
| if with_js: |
| testers.append(JSTester(**kwargs)) |
| |
| if with_go: |
| testers.append(GoTester(**kwargs)) |
| |
| if with_rust: |
| testers.append(RustTester(**kwargs)) |
| |
| static_json_files = get_static_json_files() |
| generated_json_files = datagen.get_generated_json_files(tempdir=tempdir) |
| json_files = static_json_files + generated_json_files |
| |
| # Additional integration test cases for Arrow Flight. |
| flight_scenarios = [ |
| Scenario( |
| "auth:basic_proto", |
| description="Authenticate using the BasicAuth protobuf."), |
| Scenario( |
| "middleware", |
| description="Ensure headers are propagated via middleware.", |
| skip={"Rust"} # TODO(ARROW-10961): tonic upgrade needed |
| ), |
| ] |
| |
| runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) |
| runner.run() |
| if run_flight: |
| runner.run_flight() |
| |
| fail_count = 0 |
| if runner.failures: |
| log("################# FAILURES #################") |
| for test_case, producer, consumer, exc_info in runner.failures: |
| fail_count += 1 |
| log("FAILED TEST:", end=" ") |
| log(test_case.name, producer.name, "producing, ", |
| consumer.name, "consuming") |
| if exc_info: |
| traceback.print_exception(*exc_info) |
| log() |
| |
| log(fail_count, "failures") |
| if fail_count > 0: |
| sys.exit(1) |
| |
| |
| def write_js_test_json(directory): |
| datagen.generate_map_case().write( |
| os.path.join(directory, 'map.json') |
| ) |
| datagen.generate_nested_case().write( |
| os.path.join(directory, 'nested.json') |
| ) |
| datagen.generate_decimal_case().write( |
| os.path.join(directory, 'decimal.json') |
| ) |
| datagen.generate_datetime_case().write( |
| os.path.join(directory, 'datetime.json') |
| ) |
| datagen.generate_dictionary_case().write( |
| os.path.join(directory, 'dictionary.json') |
| ) |
| datagen.generate_dictionary_unsigned_case().write( |
| os.path.join(directory, 'dictionary_unsigned.json') |
| ) |
| datagen.generate_primitive_case([]).write( |
| os.path.join(directory, 'primitive_no_batches.json') |
| ) |
| datagen.generate_primitive_case([7, 10]).write( |
| os.path.join(directory, 'primitive.json') |
| ) |
| datagen.generate_primitive_case([0, 0, 0]).write( |
| os.path.join(directory, 'primitive-empty.json') |
| ) |