| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import ast |
| import base64 |
| import itertools |
| import os |
| import pathlib |
| import signal |
| import struct |
| import tempfile |
| import threading |
| import time |
| import traceback |
| import json |
| from datetime import datetime |
| |
| try: |
| import numpy as np |
| except ImportError: |
| np = None |
| import pytest |
| import pyarrow as pa |
| |
| from pyarrow.lib import IpcReadOptions, ReadStats, tobytes |
| from pyarrow.util import find_free_port |
| from pyarrow.tests import util |
| |
| try: |
| from pyarrow import flight |
| from pyarrow.flight import ( |
| FlightClient, FlightServerBase, |
| ServerAuthHandler, ClientAuthHandler, |
| ServerMiddleware, ServerMiddlewareFactory, |
| ClientMiddleware, ClientMiddlewareFactory, |
| FlightCallOptions, |
| ) |
| except ImportError: |
| flight = None |
| FlightClient, FlightServerBase = object, object |
| ServerAuthHandler, ClientAuthHandler = object, object |
| ServerMiddleware, ServerMiddlewareFactory = object, object |
| ClientMiddleware, ClientMiddlewareFactory = object, object |
| FlightCallOptions = object |
| |
# Apply the ``flight`` marker to every test in this module.
# Deselect them with: pytest ... -m 'not flight'
pytestmark = pytest.mark.flight
| |
| |
def test_import():
    """Import ``pyarrow.flight`` unguarded so an ImportError is reported.

    The module-level import of ``pyarrow.flight`` is wrapped in
    ``try``/``except ImportError``; this test makes the failure visible.
    """
    import pyarrow.flight  # noqa
| |
| |
def resource_root():
    """Return the directory that holds the Flight test resources.

    The directory is ``$ARROW_TEST_DATA/flight``.

    Raises
    ------
    RuntimeError
        If ``ARROW_TEST_DATA`` is unset or empty.
    """
    test_data = os.environ.get("ARROW_TEST_DATA")
    if test_data:
        return pathlib.Path(test_data) / "flight"
    raise RuntimeError("Test resources not found; set "
                       "ARROW_TEST_DATA to <repo root>/testing/data")
| |
| |
def read_flight_resource(path):
    """Return the raw bytes of a Flight test resource.

    ``path`` is resolved relative to ``resource_root()``.

    Raises
    ------
    RuntimeError
        If the resource file does not exist; the message embeds the
        traceback of the underlying ``FileNotFoundError``.
    """
    base_dir = resource_root()
    if not base_dir:
        return None
    resource_path = base_dir / path
    try:
        with resource_path.open("rb") as stream:
            return stream.read()
    except FileNotFoundError:
        raise RuntimeError(
            f"Test resource {resource_path} not found; did you initialize the "
            f"test resource submodule?\n{traceback.format_exc()}"
        )
| |
| |
def example_tls_certs():
    """Load the example TLS material used by the tests.

    Returns a dict with the contents of ``root-ca.pem`` under ``root_cert``
    and two ``flight.CertKeyPair`` objects (``cert0`` and ``cert1``) under
    ``certificates``.
    """
    # Read the root certificate first, matching the original read order.
    root_cert = read_flight_resource("root-ca.pem")
    key_pairs = [
        flight.CertKeyPair(
            cert=read_flight_resource(f"cert{index}.pem"),
            key=read_flight_resource(f"cert{index}.key"),
        )
        for index in range(2)
    ]
    return {"root_cert": root_cert, "certificates": key_pairs}
| |
| |
def simple_ints_table():
    """Build a single-column table of five integers named ``some_ints``."""
    ints = pa.array([-10, -5, 0, 5, 10])
    return pa.Table.from_arrays([ints], names=['some_ints'])
| |
| |
def simple_dicts_table():
    """Build a single-column table of dictionary-encoded chunks.

    The column ``some_dicts`` has three chunks; the last one uses a
    dictionary with one additional value (``"new"``).
    """
    base_dictionary = pa.array(["foo", "baz", "quux"], type=pa.utf8())
    extended_dictionary = pa.array(
        ["foo", "baz", "quux", "new"], type=pa.utf8())
    chunk_specs = [
        ([1, 0, None], base_dictionary),
        ([2, 1], base_dictionary),
        ([0, 3], extended_dictionary),
    ]
    chunks = [
        pa.DictionaryArray.from_arrays(indices, dictionary)
        for indices, dictionary in chunk_specs
    ]
    return pa.Table.from_arrays(
        [pa.chunked_array(chunks)], names=['some_dicts'])
| |
| |
def multiple_column_table():
    """Build a two-column table: strings in ``a`` and integers in ``b``."""
    strings = pa.array(['foo', 'bar', 'baz', 'qux'])
    numbers = pa.array([1, 2, 3, 4])
    return pa.Table.from_arrays([strings, numbers], names=['a', 'b'])
| |
| |
class ConstantFlightServer(FlightServerBase):
    """Flight server whose DoGet responses are fixed tables keyed by ticket.

    See ARROW-4796: this server implementation will segfault if Flight
    does not properly hold a reference to the Table object.
    """

    CRITERIA = b"the expected criteria"

    def __init__(self, location=None, options=None, **kwargs):
        super().__init__(location, **kwargs)
        # Passed through to ``RecordBatchStream`` in ``do_get``.
        self.options = options
        # Ticket bytes -> zero-argument callable building the table.
        self.table_factories = {
            b'ints': simple_ints_table,
            b'dicts': simple_dicts_table,
            b'multi': multiple_column_table,
        }

    def list_flights(self, context, criteria):
        """Yield one empty-schema flight when ``criteria`` is ``CRITERIA``."""
        if criteria != self.CRITERIA:
            return
        descriptor = flight.FlightDescriptor.for_path('/foo')
        yield flight.FlightInfo(pa.schema([]), descriptor, [])

    def do_get(self, context, ticket):
        """Stream a freshly built table for the requested ticket."""
        # The table is created on demand and not stored anywhere else, so
        # Flight ends up holding the only reference to it.
        build_table = self.table_factories[ticket.ticket]
        return flight.RecordBatchStream(build_table(), options=self.options)
| |
| |
class MetadataFlightServer(FlightServerBase):
    """Flight server that tags each batch with its index as app metadata."""

    def __init__(self, options=None, **kwargs):
        super().__init__(**kwargs)
        # Passed through to ``GeneratorStream`` in ``do_get``.
        self.options = options

    def do_get(self, context, ticket):
        """Stream a fixed integer table, pairing each batch with its index."""
        table = pa.Table.from_arrays(
            [pa.array([-10, -5, 0, 5, 10])], names=['a'])
        numbered = self.number_batches(table)
        return flight.GeneratorStream(
            table.schema, numbered, options=self.options)

    def do_put(self, context, descriptor, reader, writer):
        """Check each uploaded batch and counter, echoing the counter back.

        Every batch must hold the next value of the expected sequence and
        carry a little-endian int32 equal to its position.
        """
        expected_values = [-10, -5, 0, 5, 10]
        assert reader.stats.num_messages == 1
        for position, (batch, metadata) in enumerate(reader):
            expected_batch = pa.RecordBatch.from_arrays(
                [pa.array([expected_values[position]])],
                ['a']
            )
            assert batch.equals(expected_batch)
            assert metadata is not None
            client_position, = struct.unpack('<i', metadata.to_pybytes())
            assert position == client_position
            writer.write(struct.pack('<i', position))
        assert reader.stats.num_messages == 6
        assert reader.stats.num_record_batches == 5

    @staticmethod
    def number_batches(table):
        """Yield ``(batch, packed_index)`` pairs for the batches of a table."""
        for index, batch in enumerate(table.to_batches()):
            yield batch, struct.pack('<i', index)
| |
| |
class EchoFlightServer(FlightServerBase):
    """Flight server that serves back whatever was last uploaded via DoPut."""

    def __init__(self, location=None, expected_schema=None, **kwargs):
        super().__init__(location, **kwargs)
        # When truthy, ``do_put`` asserts the upload schema against this.
        self.expected_schema = expected_schema
        # Most recently uploaded table; ``None`` until the first DoPut.
        self.last_message = None

    def do_get(self, context, ticket):
        """Stream the table stored by the latest ``do_put``."""
        return flight.RecordBatchStream(self.last_message)

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole upload and remember it as ``last_message``."""
        expected = self.expected_schema
        if expected:
            assert expected == reader.schema
        self.last_message = reader.read_all()

    def do_exchange(self, context, descriptor, reader, writer):
        """Consume every incoming chunk without responding."""
        for _ in reader:
            pass
| |
| |
class EchoStreamFlightServer(EchoFlightServer):
    """Echo server that replies with a stream of individual record batches."""

    def do_get(self, context, ticket):
        """Stream the last upload in batches of at most 1024 rows."""
        uploaded = self.last_message
        batches = uploaded.to_batches(max_chunksize=1024)
        return flight.GeneratorStream(uploaded.schema, batches)

    def list_actions(self, context):
        """Advertise no actions."""
        return []

    def do_action(self, context, action):
        """Answer ``who-am-i`` with the peer identity and the peer address."""
        if action.type != "who-am-i":
            raise NotImplementedError
        return [context.peer_identity(), context.peer().encode("utf-8")]
| |
| |
class GetInfoFlightServer(FlightServerBase):
    """Flight server returning a fixed ``FlightInfo`` for any descriptor."""

    def get_flight_info(self, context, descriptor):
        """Build a ``FlightInfo`` with two endpoints for ``descriptor``."""
        schema = pa.schema([('a', pa.int32())])
        expiration = pa.scalar(
            "2023-04-05T12:34:56.789012345").cast(pa.timestamp("ns"))
        plain_endpoint = flight.FlightEndpoint(b'', ['grpc://test'])
        expiring_endpoint = flight.FlightEndpoint(
            b'',
            [flight.Location.for_grpc_tcp('localhost', 5005)],
            expiration,
            "endpoint app metadata"
        )
        return flight.FlightInfo(
            schema,
            descriptor,
            [plain_endpoint, expiring_endpoint],
            1,
            42,
            True,
            "info app metadata"
        )

    def get_schema(self, context, descriptor):
        """Return the schema of ``get_flight_info`` as a ``SchemaResult``."""
        schema = self.get_flight_info(context, descriptor).schema
        return flight.SchemaResult(schema)
| |
| |
class ListActionsFlightServer(FlightServerBase):
    """Flight server that advertises a fixed list of actions."""

    @classmethod
    def expected_actions(cls):
        """Return the advertised actions as tuples and one ``ActionType``."""
        tuple_actions = [("action-1", "description"), ("action-2", "")]
        return tuple_actions + [flight.ActionType("action-3", "more detail")]

    def list_actions(self, context):
        """Yield every entry of ``expected_actions()``."""
        for action in self.expected_actions():
            yield action
| |
| |
class ListActionsErrorFlightServer(FlightServerBase):
    """Flight server whose ListActions yields a tuple, then a bare string."""

    def list_actions(self, context):
        """Yield a ``(type, description)`` tuple followed by the str ``foo``."""
        yield from (("action-1", ""), "foo")
| |
| |
class CheckTicketFlightServer(FlightServerBase):
    """A Flight server that compares the given ticket to an expected value.

    Parameters
    ----------
    expected_ticket
        Value that ``do_get`` asserts to be equal to ``ticket.ticket``.
    location : optional
        Forwarded to ``FlightServerBase``.
    **kwargs
        Forwarded to ``FlightServerBase``.
    """

    def __init__(self, expected_ticket, location=None, **kwargs):
        super().__init__(location, **kwargs)
        self.expected_ticket = expected_ticket
        # Initialised here so the attribute exists before any DoPut call.
        self.last_message = None

    def do_get(self, context, ticket):
        """Assert the ticket matches, then stream a fixed int32 table."""
        assert self.expected_ticket == ticket.ticket
        data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
        table = pa.Table.from_arrays(data1, names=['a'])
        return flight.RecordBatchStream(table)

    def do_put(self, context, descriptor, reader, writer=None):
        """Store the uploaded table in ``last_message``.

        ``writer`` is accepted but unused. The other servers in this module
        receive it as the fourth ``do_put`` argument; without the parameter
        a DoPut call would fail with ``TypeError``. It defaults to ``None``
        so three-argument callers keep working.
        """
        self.last_message = reader.read_all()
| |
| |
class InvalidStreamFlightServer(FlightServerBase):
    """A Flight server that tries to return messages with differing schemas."""

    # Schema announced for the stream; only the first table conforms to it.
    schema = pa.schema([('a', pa.int32())])

    def do_get(self, context, ticket):
        """Return a stream whose second table does not match ``schema``."""
        data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
        data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0], type=pa.float64())]
        # ``data1``/``data2`` are lists of arrays, so compare the element
        # types. Reading ``.type`` on the list itself raises AttributeError,
        # which kept the mismatched stream from ever being produced.
        assert data1[0].type != data2[0].type
        table1 = pa.Table.from_arrays(data1, names=['a'])
        table2 = pa.Table.from_arrays(data2, names=['a'])
        assert table1.schema == self.schema

        return flight.GeneratorStream(self.schema, [table1, table2])
| |
| |
class NeverSendsDataFlightServer(FlightServerBase):
    """Flight server whose streams consist mostly or only of empty tables."""

    schema = pa.schema([('a', pa.int32())])

    def do_get(self, context, ticket):
        """Return an endless stream of empty tables, or a finite one.

        The ticket ``yield_data`` selects a finite stream of two empty
        tables followed by one five-row batch.
        """
        if ticket.ticket != b'yield_data':
            endless_empties = itertools.repeat(self.schema.empty_table())
            return flight.GeneratorStream(self.schema, endless_empties)

        # Check that the server handler will ignore empty tables
        # up to a certain extent
        real_batch = pa.RecordBatch.from_arrays(
            [range(5)], schema=self.schema)
        items = [
            self.schema.empty_table(),
            self.schema.empty_table(),
            real_batch,
        ]
        return flight.GeneratorStream(self.schema, items)
| |
| |
class SlowFlightServer(FlightServerBase):
    """Flight server with deliberately slow responses, for timeout tests."""

    def do_get(self, context, ticket):
        """Return a stream that pauses for ten seconds after its first table."""
        schema = pa.schema([('a', pa.int32())])
        return flight.GeneratorStream(schema, self.slow_stream())

    def do_action(self, context, action):
        """Sleep for half a second, then return no results."""
        time.sleep(0.5)
        return []

    @staticmethod
    def slow_stream():
        """Yield one table, sleep ten seconds, then yield an equal table."""
        values = pa.array([-10, -5, 0, 5, 10], type=pa.int32())
        yield pa.Table.from_arrays([values], names=['a'])
        # The second message should never get sent; the client should
        # cancel before we send this
        time.sleep(10)
        yield pa.Table.from_arrays([values], names=['a'])
| |
| |
class ErrorFlightServer(FlightServerBase):
    """Flight server that raises a chosen error for each request."""

    @staticmethod
    def error_cases():
        """Map action names to the exception class that ``do_action`` raises."""
        flight_errors = {
            "internal": flight.FlightInternalError,
            "timedout": flight.FlightTimedOutError,
            "cancel": flight.FlightCancelledError,
            "unauthenticated": flight.FlightUnauthenticatedError,
            "unauthorized": flight.FlightUnauthorizedError,
        }
        other_errors = {
            "notimplemented": NotImplementedError,
            "invalid": pa.ArrowInvalid,
            "key": KeyError,
        }
        return {**flight_errors, **other_errors}

    def do_action(self, context, action):
        """Raise the error selected by ``action.type``.

        ``protobuf`` raises ``FlightUnauthorizedError`` with an extra bytes
        payload; unknown types raise ``NotImplementedError``.
        """
        exc_type = ErrorFlightServer.error_cases().get(action.type)
        if exc_type is not None:
            raise exc_type("foo")
        if action.type == "protobuf":
            raise flight.FlightUnauthorizedError(
                "foo", b'this is an error message')
        raise NotImplementedError

    def list_flights(self, context, criteria):
        """Yield one flight, then fail with ``FlightInternalError``."""
        descriptor = flight.FlightDescriptor.for_path('/foo')
        yield flight.FlightInfo(pa.schema([]), descriptor, [])
        raise flight.FlightInternalError("foo")

    def do_put(self, context, descriptor, reader, writer):
        """Raise the Flight error selected by ``descriptor.command``.

        Commands that match nothing complete without raising.
        """
        plain_errors = {
            b"internal": flight.FlightInternalError,
            b"timedout": flight.FlightTimedOutError,
            b"cancel": flight.FlightCancelledError,
            b"unauthenticated": flight.FlightUnauthenticatedError,
            b"unauthorized": flight.FlightUnauthorizedError,
        }
        command = descriptor.command
        if command in plain_errors:
            raise plain_errors[command]("foo")
        if command == b"protobuf":
            raise flight.FlightUnauthorizedError(
                "foo", b'this is an error message')
| |
| |
class ExchangeFlightServer(FlightServerBase):
    """Flight server exercising DoExchange through several commands.

    The descriptor command selects the behaviour: ``echo``, ``get``,
    ``put`` or ``transform``.
    """

    def __init__(self, options=None, **kwargs):
        super().__init__(**kwargs)
        # Passed to ``writer.begin`` by ``exchange_echo``.
        self.options = options

    def do_exchange(self, context, descriptor, reader, writer):
        """Dispatch to the handler named by the descriptor command.

        Raises ``pa.ArrowInvalid`` for non-command descriptors and for
        unknown commands.
        """
        assert reader.stats.num_messages == 0
        if descriptor.descriptor_type != flight.DescriptorType.CMD:
            raise pa.ArrowInvalid("Must provide a command descriptor")
        handlers = {
            b"echo": self.exchange_echo,
            b"get": self.exchange_do_get,
            b"put": self.exchange_do_put,
            b"transform": self.exchange_transform,
        }
        handler = handlers.get(descriptor.command)
        if handler is None:
            raise pa.ArrowInvalid(
                f"Unknown command: {descriptor.command}")
        return handler(context, reader, writer)

    def exchange_do_get(self, context, reader, writer):
        """Emulate DoGet with DoExchange."""
        values = pa.array(range(0, 10 * 1024))
        table = pa.Table.from_arrays([values], names=["a"])
        writer.begin(table.schema)
        writer.write_table(table)

    def exchange_do_put(self, context, reader, writer):
        """Emulate DoPut with DoExchange.

        Counts the uploaded batches and replies with the count as UTF-8
        metadata.
        """
        num_batches = 0
        for num_batches, chunk in enumerate(reader, start=1):
            if not chunk.data:
                raise pa.ArrowInvalid("All chunks must have data.")
            assert reader.stats.num_messages != 0
            assert reader.stats.num_record_batches == num_batches
        writer.write_metadata(str(num_batches).encode("utf-8"))

    def exchange_echo(self, context, reader, writer):
        """Write every incoming chunk straight back to the client."""
        assert reader.stats.num_messages == 0
        stream_begun = False
        for chunk in reader:
            batch = chunk.data
            metadata = chunk.app_metadata
            # The output stream starts on the first chunk with truthy data.
            if batch and not stream_begun:
                writer.begin(batch.schema, options=self.options)
                stream_begun = True
            if metadata:
                if batch:
                    writer.write_with_metadata(batch, metadata)
                else:
                    writer.write_metadata(metadata)
            elif batch:
                assert reader.stats.num_messages != 0
                writer.write_batch(batch)
            else:
                assert False, "Should not happen"

    def exchange_transform(self, context, reader, writer):
        """Sum rows in an uploaded table.

        Every field must be an integer type; the reply is a table with one
        column named ``sum``.
        """
        assert reader.stats.num_messages == 0
        for field in reader.schema:
            if not pa.types.is_integer(field.type):
                raise pa.ArrowInvalid(f"Invalid field: {field!r}")
        uploaded = reader.read_all()
        assert reader.stats.num_messages != 0
        row_totals = [0] * uploaded.num_rows
        for column in uploaded:
            for row_index, cell in enumerate(column):
                row_totals[row_index] += cell.as_py()
        summed = pa.Table.from_arrays([pa.array(row_totals)], names=["sum"])
        writer.begin(summed.schema)
        writer.write_table(summed)
| |
| |
class HttpBasicServerAuthHandler(ServerAuthHandler):
    """Server side of a username/password handshake based on ``BasicAuth``."""

    def __init__(self, creds):
        super().__init__()
        # Mapping of username -> password.
        self.creds = creds

    def authenticate(self, outgoing, incoming):
        """Validate the credentials and reply with the username as token."""
        credentials = flight.BasicAuth.deserialize(incoming.read())
        username = credentials.username
        if username not in self.creds:
            raise flight.FlightUnauthenticatedError("unknown user")
        if credentials.password != self.creds[username]:
            raise flight.FlightUnauthenticatedError("wrong password")
        outgoing.write(tobytes(username))

    def is_valid(self, token):
        """Return ``token`` if it is a known username, otherwise raise."""
        if not token:
            raise flight.FlightUnauthenticatedError("token not provided")
        if token in self.creds:
            return token
        raise flight.FlightUnauthenticatedError("unknown user")
| |
| |
class HttpBasicClientAuthHandler(ClientAuthHandler):
    """Client side of a username/password handshake based on ``BasicAuth``."""

    def __init__(self, username, password):
        super().__init__()
        # Token returned by the server; ``None`` until ``authenticate`` runs.
        self.token = None
        self.basic_auth = flight.BasicAuth(username, password)

    def authenticate(self, outgoing, incoming):
        """Send the serialized credentials and store the reply as token."""
        outgoing.write(self.basic_auth.serialize())
        self.token = incoming.read()

    def get_token(self):
        """Return the token received during ``authenticate``."""
        return self.token
| |
| |
class TokenServerAuthHandler(ServerAuthHandler):
    """Server side of a handshake that issues a base64 ``secret:`` token."""

    def __init__(self, creds):
        super().__init__()
        # Mapping of username -> password.
        self.creds = creds

    def authenticate(self, outgoing, incoming):
        """Read username then password, and reply with the encoded token."""
        username = incoming.read()
        password = incoming.read()
        if username not in self.creds or self.creds[username] != password:
            raise flight.FlightUnauthenticatedError(
                "invalid username/password")
        outgoing.write(base64.b64encode(b'secret:' + username))

    def is_valid(self, token):
        """Decode ``token`` and return what follows the ``secret:`` prefix."""
        prefix = b'secret:'
        decoded = base64.b64decode(token)
        if decoded.startswith(prefix):
            return decoded[len(prefix):]
        raise flight.FlightUnauthenticatedError("invalid token")
| |
| |
class TokenClientAuthHandler(ClientAuthHandler):
    """Client side of the token handshake: send credentials, keep the token."""

    def __init__(self, username, password):
        super().__init__()
        # Empty until the server replies in ``authenticate``.
        self.token = b''
        self.username = username
        self.password = password

    def authenticate(self, outgoing, incoming):
        """Write username then password, and store the server's reply."""
        for credential in (self.username, self.password):
            outgoing.write(credential)
        self.token = incoming.read()

    def get_token(self):
        """Return the token received during ``authenticate``."""
        return self.token
| |
| |
class NoopAuthHandler(ServerAuthHandler):
    """Auth handler that performs no checks at all."""

    def authenticate(self, outgoing, incoming):
        """Accept the handshake without reading or writing anything."""
        return None

    def is_valid(self, token):
        """Accept any token and report an empty identity.

        An empty string is returned because returning None causes a
        Type error.
        """
        identity = ""
        return identity
| |
| |
def case_insensitive_header_lookup(headers, lookup_key):
    """Return the value stored under ``lookup_key``, ignoring case.

    The first key of ``headers`` (in iteration order) whose lowercase form
    equals the lowercase ``lookup_key`` is used; ``None`` is returned when
    no key matches.
    """
    matches = (
        headers.get(name)
        for name in headers
        if name.lower() == lookup_key.lower()
    )
    return next(matches, None)
| |
| |
class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory):
    """Factory for ``ClientHeaderAuthMiddleware`` that stores its credential."""

    def __init__(self):
        # Empty until a middleware instance calls ``set_call_credential``.
        self.call_credential = []

    def start_call(self, info):
        """Create a middleware that reports back to this factory."""
        middleware = ClientHeaderAuthMiddleware(self)
        return middleware

    def set_call_credential(self, call_credential):
        """Replace the stored call credential."""
        self.call_credential = call_credential
| |
| |
class ClientHeaderAuthMiddleware(ClientMiddleware):
    """
    ClientMiddleware that captures the authorization header sent by
    the server.

    When the response headers contain an ``Authorization`` entry
    (matched case-insensitively), its first value is handed to the
    factory as a call credential.

    Parameters
    ----------
    factory : ClientHeaderAuthMiddlewareFactory
        Receives the call credential through ``set_call_credential``
        when an authorization header is found.
    """

    def __init__(self, factory):
        self.factory = factory

    def received_headers(self, headers):
        """Forward the first authorization header value to the factory."""
        auth_header = case_insensitive_header_lookup(headers, 'Authorization')
        if not auth_header:
            return
        credential = [b'authorization', auth_header[0].encode("utf-8")]
        self.factory.set_call_credential(credential)
| |
| |
class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Validates incoming username and password.

    ``Basic`` credentials must decode to ``test``/``password`` and
    ``Bearer`` tokens must equal ``token1234``. Everything else, including
    a missing or malformed ``Authorization`` header, is rejected with
    ``FlightUnauthenticatedError``.
    """

    def start_call(self, info, headers):
        """Authenticate the call and return a middleware carrying its token.

        Parameters
        ----------
        info
            Call information; not inspected here.
        headers
            Incoming headers; the ``Authorization`` entry is looked up
            case-insensitively and its first value is used.

        Raises
        ------
        flight.FlightUnauthenticatedError
            If the header is absent, malformed, or carries wrong
            credentials.
        """
        error_message = 'Invalid credentials'
        auth_header = case_insensitive_header_lookup(
            headers,
            'Authorization'
        )
        # The lookup returns None when the header is absent; report that as
        # an authentication failure rather than crashing on ``None[0]``.
        if not auth_header:
            raise flight.FlightUnauthenticatedError(error_message)

        values = auth_header[0].split(' ')
        if len(values) < 2:
            # A scheme without credentials cannot be valid.
            raise flight.FlightUnauthenticatedError(error_message)
        scheme, credential = values[0], values[1]

        if scheme == 'Basic':
            try:
                pair = base64.b64decode(credential).decode("utf-8").split(':')
            except ValueError:
                # Invalid base64 (binascii.Error) and invalid UTF-8
                # (UnicodeDecodeError) are both ValueError subclasses.
                raise flight.FlightUnauthenticatedError(error_message)
            # The slice also rejects payloads with no ':' separator.
            if pair[:2] != ['test', 'password']:
                raise flight.FlightUnauthenticatedError(error_message)
            token = 'token1234'
        elif scheme == 'Bearer':
            token = credential
            if token != 'token1234':
                raise flight.FlightUnauthenticatedError(error_message)
        else:
            raise flight.FlightUnauthenticatedError(error_message)

        return HeaderAuthServerMiddleware(token)
| |
| |
class HeaderAuthServerMiddleware(ServerMiddleware):
    """ServerMiddleware that sends its token back as a Bearer header."""

    def __init__(self, token):
        self.token = token

    def sending_headers(self):
        """Return an ``authorization`` header holding the Bearer token."""
        header_value = 'Bearer ' + self.token
        return {'authorization': header_value}
| |
| |
class HeaderAuthFlightServer(FlightServerBase):
    """Flight server whose actions return the token held by ``auth``."""

    def do_action(self, context, action):
        """Return the token part of the ``auth`` middleware's header.

        Raises ``FlightUnauthenticatedError`` when no middleware is
        registered under the key ``auth``.
        """
        middleware = context.get_middleware("auth")
        if not middleware:
            raise flight.FlightUnauthenticatedError(
                'No token auth middleware found.')
        auth_header = case_insensitive_header_lookup(
            middleware.sending_headers(), 'Authorization')
        # Header value is "<scheme> <token>"; the token is the second part.
        token = auth_header.split(' ')[1]
        return [token.encode("utf-8")]
| |
| |
class ArbitraryHeadersServerMiddlewareFactory(ServerMiddlewareFactory):
    """Factory that hands the incoming headers to a new middleware."""

    def start_call(self, info, headers):
        """Wrap the incoming headers of this call in a middleware."""
        middleware = ArbitraryHeadersServerMiddleware(headers)
        return middleware
| |
| |
class ArbitraryHeadersServerMiddleware(ServerMiddleware):
    """ServerMiddleware that sends back the headers it was created with."""

    def __init__(self, incoming):
        # Headers received when the call started.
        self.incoming = incoming

    def sending_headers(self):
        """Return the stored incoming headers unchanged."""
        outgoing = self.incoming
        return outgoing
| |
| |
class ArbitraryHeadersFlightServer(FlightServerBase):
    """Flight server returning two test header values from its middleware."""

    def do_action(self, context, action):
        """Return the first values of ``test-header-1`` and ``test-header-2``.

        Raises ``FlightServerError`` when no middleware is registered under
        the key ``arbitrary-headers``.
        """
        middleware = context.get_middleware("arbitrary-headers")
        if not middleware:
            raise flight.FlightServerError("No headers middleware found")
        headers = middleware.sending_headers()
        found = [
            case_insensitive_header_lookup(headers, name)
            for name in ('test-header-1', 'test-header-2')
        ]
        return [values[0].encode("utf-8") for values in found]
| |
| |
class HeaderServerMiddleware(ServerMiddleware):
    """Middleware that only carries a per-call value for the RPC method."""

    def __init__(self, special_value):
        # Read by the RPC method body via ``context.get_middleware``.
        self.special_value = special_value
| |
| |
class HeaderServerMiddlewareFactory(ServerMiddlewareFactory):
    """Factory whose middleware always carries the same hard-coded value."""

    def start_call(self, info, headers):
        """Return a ``HeaderServerMiddleware`` holding ``"right value"``."""
        special_value = "right value"
        return HeaderServerMiddleware(special_value)
| |
| |
class HeaderFlightServer(FlightServerBase):
    """Flight server echoing the value held by the ``test`` middleware."""

    def do_action(self, context, action):
        """Return the encoded special value, or ``b""`` without middleware."""
        middleware = context.get_middleware("test")
        if not middleware:
            return [b""]
        return [middleware.special_value.encode()]
| |
| |
class MultiHeaderFlightServer(FlightServerBase):
    """Test sending/receiving multiple (binary-valued) headers."""

    def do_action(self, context, action):
        """Return the ``repr`` of the client headers, UTF-8 encoded."""
        client_headers = context.get_middleware("test").client_headers
        return [repr(client_headers).encode("utf-8")]
| |
| |
class SelectiveAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Require an ``x-auth-token`` header for all methods but ListActions."""

    def start_call(self, info, headers):
        """Check the token header and return a middleware carrying it.

        ListActions calls skip the check and get no middleware. Other calls
        raise ``FlightUnauthenticatedError`` when the header is missing or
        its first value is not ``password``.
        """
        if info.method == flight.FlightMethod.LIST_ACTIONS:
            # No auth needed
            return None

        header_values = headers.get("x-auth-token")
        if not header_values:
            raise flight.FlightUnauthenticatedError("No token")

        token = header_values[0]
        if token == "password":
            return HeaderServerMiddleware(token)
        raise flight.FlightUnauthenticatedError("Invalid token")
| |
| |
class SelectiveAuthClientMiddlewareFactory(ClientMiddlewareFactory):
    """Factory producing a ``SelectiveAuthClientMiddleware`` for each call."""

    def start_call(self, info):
        """Return a fresh ``SelectiveAuthClientMiddleware``."""
        middleware = SelectiveAuthClientMiddleware()
        return middleware
| |
| |
class SelectiveAuthClientMiddleware(ClientMiddleware):
    """Client middleware that attaches a fixed ``x-auth-token`` header."""

    def sending_headers(self):
        """Return the ``x-auth-token`` header with the value ``password``."""
        return {"x-auth-token": "password"}
| |
| |
class RecordingServerMiddlewareFactory(ServerMiddlewareFactory):
    """Server middleware factory that logs the method of every call."""

    def __init__(self):
        super().__init__()
        # Methods seen so far, in call order.
        self.methods = []

    def start_call(self, info, headers):
        """Append ``info.method`` to ``methods``; create no middleware."""
        self.methods.append(info.method)
| |
| |
class RecordingClientMiddlewareFactory(ClientMiddlewareFactory):
    """Client middleware factory that logs the method of every call."""

    def __init__(self):
        super().__init__()
        # Methods seen so far, in call order.
        self.methods = []

    def start_call(self, info):
        """Append ``info.method`` to ``methods``; create no middleware."""
        self.methods.append(info.method)
| |
| |
class MultiHeaderClientMiddlewareFactory(ClientMiddlewareFactory):
    """Test sending/receiving multiple (binary-valued) headers."""

    def __init__(self):
        # Updated by the middleware instances with the headers they
        # receive; read in test_middleware_multi_header.
        self.last_headers = {}

    def start_call(self, info):
        """Create a middleware that reports back to this factory."""
        middleware = MultiHeaderClientMiddleware(self)
        return middleware
| |
| |
class MultiHeaderClientMiddleware(ClientMiddleware):
    """Test sending/receiving multiple (binary-valued) headers."""

    # Headers sent with every call: multi-valued text and binary entries.
    EXPECTED = {
        "x-text": ["foo", "bar"],
        "x-binary-bin": [b"\x00", b"\x01"],
        # ARROW-16606: ensure mixed-case headers are accepted
        "x-MIXED-case": ["baz"],
        b"x-other-MIXED-case": ["baz"],
    }

    def __init__(self, factory):
        self.factory = factory

    def sending_headers(self):
        """Return the ``EXPECTED`` header mapping."""
        return self.EXPECTED

    def received_headers(self, headers):
        """Merge the received headers into the factory's ``last_headers``."""
        # Lets the test code see the last set of headers we received.
        recorded = self.factory.last_headers
        recorded.update(headers)
| |
| |
class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory):
    """Test sending/receiving multiple (binary-valued) headers."""

    def start_call(self, info, headers):
        """Wrap the headers sent by the client in a middleware."""
        middleware = MultiHeaderServerMiddleware(headers)
        return middleware
| |
| |
class MultiHeaderServerMiddleware(ServerMiddleware):
    """Test sending/receiving multiple (binary-valued) headers."""

    def __init__(self, client_headers):
        # Headers received from the client when the call started.
        self.client_headers = client_headers

    def sending_headers(self):
        """Return the header mapping the client middleware also sends."""
        expected = MultiHeaderClientMiddleware.EXPECTED
        return expected
| |
| |
class LargeMetadataFlightServer(FlightServerBase):
    """Regression test for ARROW-13253."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Metadata payload of 2**31 + 1 space bytes.
        payload_size = 2 ** 31 + 1
        self._metadata = b' ' * payload_size

    def do_get(self, context, ticket):
        """Stream one single-row batch paired with the large metadata."""
        schema = pa.schema([('a', pa.int64())])
        batch = pa.record_batch([[1]], schema=schema)
        return flight.GeneratorStream(schema, [(batch, self._metadata)])

    def do_exchange(self, context, descriptor, reader, writer):
        """Write the large metadata payload to the client."""
        writer.write_metadata(self._metadata)
| |
| |
def test_repr():
    """Check ``repr`` of the Flight value types and constructor type checks."""
    expiration = pa.scalar("2023-04-05T12:34:56").cast(pa.timestamp("s"))
    endpoint = flight.FlightEndpoint(
        b"foo", [], expiration, b"endpoint app metadata"
    )
    info = flight.FlightInfo(
        pa.schema([]), flight.FlightDescriptor.for_path(), [],
        1, 42, True, b"test app metadata"
    )

    # (object, expected repr) pairs.
    expected_reprs = [
        (flight.Action("foo", b""),
         "<pyarrow.flight.Action type='foo' body=(0 bytes)>"),
        (flight.ActionType("foo", "bar"),
         "ActionType(type='foo', description='bar')"),
        (flight.BasicAuth("user", "pass"),
         "<pyarrow.flight.BasicAuth username=b'user' password=(redacted)>"),
        (flight.FlightDescriptor.for_command("foo"),
         "<pyarrow.flight.FlightDescriptor cmd=b'foo'>"),
        (endpoint,
         "<pyarrow.flight.FlightEndpoint "
         "ticket=<pyarrow.flight.Ticket ticket=b'foo'> "
         "locations=[] "
         "expiration_time=2023-04-05 12:34:56+00:00 "
         "app_metadata=b'endpoint app metadata'>"),
        (info,
         "<pyarrow.flight.FlightInfo "
         "schema= "
         "descriptor=<pyarrow.flight.FlightDescriptor path=[]> "
         "endpoints=[] "
         "total_records=1 "
         "total_bytes=42 "
         "ordered=True "
         "app_metadata=b'test app metadata'>"),
        (flight.Location("grpc+tcp://localhost:1234"),
         "<pyarrow.flight.Location b'grpc+tcp://localhost:1234'>"),
        (flight.Result(b"foo"),
         "<pyarrow.flight.Result body=(3 bytes)>"),
        (flight.SchemaResult(pa.schema([])),
         "<pyarrow.flight.SchemaResult schema=()>"),
        (flight.SchemaResult(pa.schema([("int", "int64")])),
         "<pyarrow.flight.SchemaResult schema=(int: int64)>"),
        (flight.Ticket(b"foo"),
         "<pyarrow.flight.Ticket ticket=b'foo'>"),
    ]
    for obj, expected in expected_reprs:
        assert repr(obj) == expected
    assert info.schema == pa.schema([])

    # A FlightInfo built without a schema reports ``schema=None``.
    schemaless_info = flight.FlightInfo(
        None, flight.FlightDescriptor.for_path(), [],
        1, 42, True, b"test app metadata"
    )
    assert repr(schemaless_info) == (
        "<pyarrow.flight.FlightInfo "
        "schema=None "
        "descriptor=<pyarrow.flight.FlightDescriptor path=[]> "
        "endpoints=[] "
        "total_records=1 "
        "total_bytes=42 "
        "ordered=True "
        "app_metadata=b'test app metadata'>")
    assert schemaless_info.schema is None

    # Each of these constructor calls must be rejected with TypeError.
    invalid_constructions = [
        lambda: flight.Action("foo", None),
        lambda: flight.FlightEndpoint(object(), []),
        lambda: flight.FlightEndpoint(
            "foo", ["grpc://test", b"grpc://test", object()]),
        lambda: flight.FlightEndpoint(
            "foo", [], expiration_time="2023-04-05T01:02:03"),
        lambda: flight.FlightEndpoint(
            "foo", [], expiration_time=datetime(2023, 4, 5, 1, 2, 3)),
        lambda: flight.FlightEndpoint("foo", [], app_metadata=object()),
    ]
    for construct in invalid_constructions:
        with pytest.raises(TypeError):
            construct()
| |
| |
def test_eq():
    """Check ``==`` and ``!=`` on the Flight value types.

    Every factory below returns an ``(lhs, rhs)`` pair whose members are
    expected to compare unequal. Each factory is invoked twice so equality
    is also checked between independently constructed instances.
    """
    def blank_info(**overrides):
        # FlightInfo with an empty schema, an empty path descriptor and no
        # endpoints; *overrides* supplies the optional keyword fields.
        return flight.FlightInfo(
            pa.schema([]), flight.FlightDescriptor.for_path(), [], **overrides)

    pair_factories = [
        # Action / ActionType / BasicAuth
        lambda: (flight.Action("foo", b""), flight.Action("bar", b"")),
        lambda: (flight.Action("foo", b""), flight.Action("foo", b"bar")),
        lambda: (
            flight.ActionType("foo", "bar"),
            flight.ActionType("foo", "baz"),
        ),
        lambda: (
            flight.BasicAuth("user", "pass"),
            flight.BasicAuth("user2", "pass"),
        ),
        lambda: (
            flight.BasicAuth("user", "pass"),
            flight.BasicAuth("user", "pass2"),
        ),
        # FlightDescriptor
        lambda: (
            flight.FlightDescriptor.for_command("foo"),
            flight.FlightDescriptor.for_path("foo"),
        ),
        # FlightEndpoint: ticket, locations, expiration time, app metadata
        lambda: (
            flight.FlightEndpoint(b"foo", []),
            flight.FlightEndpoint(b"bar", []),
        ),
        lambda: (
            flight.FlightEndpoint(
                b"foo", [flight.Location("grpc+tcp://localhost:1234")]),
            flight.FlightEndpoint(
                b"foo", [flight.Location("grpc+tls://localhost:1234")]),
        ),
        lambda: (
            flight.FlightEndpoint(
                b"foo", [],
                pa.scalar("2023-04-05T12:34:56").cast(pa.timestamp("s"))),
            flight.FlightEndpoint(
                b"foo", [],
                pa.scalar("2023-04-05T12:34:56.789").cast(pa.timestamp("ms"))),
        ),
        lambda: (
            flight.FlightEndpoint(b"foo", [], app_metadata=b''),
            flight.FlightEndpoint(b"foo", [], app_metadata=b'meta'),
        ),
        # FlightInfo: schema, descriptor, endpoints, then keyword fields
        lambda: (
            blank_info(),
            flight.FlightInfo(
                pa.schema([("ints", pa.int64())]),
                flight.FlightDescriptor.for_path(), []),
        ),
        lambda: (
            blank_info(),
            flight.FlightInfo(
                pa.schema([]),
                flight.FlightDescriptor.for_command(b"foo"), []),
        ),
        lambda: (
            flight.FlightInfo(
                pa.schema([]),
                flight.FlightDescriptor.for_path(),
                [flight.FlightEndpoint(b"foo", [])]),
            flight.FlightInfo(
                pa.schema([]),
                flight.FlightDescriptor.for_path(),
                [flight.FlightEndpoint(b"bar", [])]),
        ),
        lambda: (blank_info(total_records=-1), blank_info(total_records=1)),
        lambda: (blank_info(total_bytes=-1), blank_info(total_bytes=42)),
        lambda: (blank_info(ordered=False), blank_info(ordered=True)),
        lambda: (blank_info(app_metadata=b""),
                 blank_info(app_metadata=b"meta")),
        # Location / Result / SchemaResult / Ticket
        lambda: (
            flight.Location("grpc+tcp://localhost:1234"),
            flight.Location("grpc+tls://localhost:1234"),
        ),
        lambda: (flight.Result(b"foo"), flight.Result(b"bar")),
        lambda: (
            flight.SchemaResult(pa.schema([])),
            flight.SchemaResult(pa.schema([("ints", pa.int64())])),
        ),
        lambda: (flight.Ticket(b""), flight.Ticket(b"foo")),
    ]

    for make_pair in pair_factories:
        left_a, right_a = make_pair()
        left_b, right_b = make_pair()
        # Reflexive and symmetric equality between two builds of each side.
        for first, second in ((left_a, left_b), (right_a, right_b)):
            assert first == first
            assert first == second
            assert second == first
        # The two sides of a pair must differ.
        assert left_a != right_a
| |
| |
def test_flight_info_defaults():
    """Omitted, ``-1`` and ``None`` counts must all be reported as ``-1``."""
    def build(**counts):
        return flight.FlightInfo(
            pa.schema([]), flight.FlightDescriptor.for_path(), [], **counts)

    infos = [
        build(),
        build(total_records=-1, total_bytes=-1),
        build(total_records=None, total_bytes=None),
    ]

    for info in infos:
        assert info.total_records == -1

    for info in infos:
        assert info.total_bytes == -1
| |
| |
def test_flight_server_location_argument():
    """FlightServerBase accepts None, a URI string or a (host, port) tuple."""
    candidates = (
        None,
        'grpc://localhost:0',
        ('localhost', find_free_port()),
    )
    for candidate in candidates:
        with FlightServerBase(candidate) as started:
            assert isinstance(started, FlightServerBase)
| |
| |
def test_server_exit_reraises_exception():
    """An exception raised inside the server's ``with`` block must escape it."""
    with pytest.raises(ValueError), FlightServerBase():
        raise ValueError()
| |
| |
@pytest.mark.threading
@pytest.mark.slow
def test_client_wait_for_available():
    """wait_for_available() must block until a late-starting server is up.

    The server is started from a background thread after a 0.5 s delay, so
    the client call has to wait at least that long before returning.
    """
    location = ('localhost', find_free_port())
    server = None

    def serve():
        # ``nonlocal`` binds the enclosing function's ``server`` so the test
        # body can shut it down; ``global`` would instead create an unrelated
        # module-level name and leave the local ``server`` as None.
        nonlocal server
        time.sleep(0.5)
        server = FlightServerBase(location)
        server.serve()

    with FlightClient(location) as client:
        thread = threading.Thread(target=serve, daemon=True)
        thread.start()

        try:
            started = time.time()
            client.wait_for_available(timeout=5)
            elapsed = time.time() - started
            assert elapsed >= 0.5
        finally:
            # Stop the background server so it does not outlive the test.
            # ``server`` may still be None if the wait failed early.
            if server is not None:
                server.shutdown()
                thread.join(timeout=5)
| |
| |
def test_flight_list_flights():
    """List flights with and without criteria against ConstantFlightServer."""
    with ConstantFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        # No criteria: nothing is listed.
        unfiltered = list(client.list_flights())
        assert unfiltered == []

        # The server's known criteria yields exactly one flight.
        matching = list(client.list_flights(ConstantFlightServer.CRITERIA))
        assert len(matching) == 1
| |
| |
def test_flight_client_close():
    """close() may be called repeatedly; later calls on the client fail."""
    with ConstantFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        listed = list(client.list_flights())
        assert listed == []

        # Closing twice must not raise (idempotent).
        for _ in range(2):
            client.close()

        with pytest.raises(pa.ArrowInvalid):
            list(client.list_flights())
| |
| |
def test_flight_do_get_ints():
    """Fetch the 'ints' table via do_get with default and V4 IPC options.

    Also checks that a non-IpcWriteOptions ``options`` value on the server
    surfaces to the client as a FlightServerError.
    """
    expected = simple_ints_table()

    # Default write options.
    with ConstantFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        received = client.do_get(flight.Ticket(b'ints')).read_all()
        assert received.equals(expected)

    # Explicit V4 metadata version.
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with ConstantFlightServer(options=v4_options) as server, \
            flight.connect(('localhost', server.port)) as client:
        received = client.do_get(flight.Ticket(b'ints')).read_all()
        assert received.equals(expected)

        # Same request, consumed through the RecordBatchReader interface.
        stream = client.do_get(flight.Ticket(b'ints'))
        received = stream.to_reader().read_all()
        assert received.equals(expected)

    # Invalid options type is reported by the server.
    with pytest.raises(flight.FlightServerError,
                       match="expected IpcWriteOptions, got <class 'int'>"):
        with ConstantFlightServer(options=42) as server, \
                flight.connect(('localhost', server.port)) as client:
            client.do_get(flight.Ticket(b'ints')).read_all()
| |
| |
@pytest.mark.pandas
def test_do_get_ints_pandas():
    """Fetch the 'ints' table via do_get and read it as a pandas DataFrame."""
    expected_values = simple_ints_table().column(0).to_pylist()

    with ConstantFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        frame = client.do_get(flight.Ticket(b'ints')).read_pandas()
        assert list(frame['some_ints']) == expected_values
| |
| |
def test_flight_do_get_dicts():
    """Fetch the dictionary-encoded table and check the reader statistics."""
    expected_table = simple_dicts_table()
    expected_stats = ReadStats(
        num_messages=6,
        num_record_batches=3,
        num_dictionary_batches=2,
        num_dictionary_deltas=0,
        num_replaced_dictionaries=1
    )

    with ConstantFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        stream = client.do_get(flight.Ticket(b'dicts'))
        # Checked before any data is read from the stream.
        assert stream.stats.num_messages == 1

        received = stream.read_all()
        assert received.equals(expected_table)
        assert stream.stats == expected_stats
| |
| |
def test_flight_do_get_ticket():
    """The ticket given to do_get must reach the server unchanged."""
    column = pa.array([-10, -5, 0, 5, 10], type=pa.int32())
    expected = pa.Table.from_arrays([column], names=['a'])

    with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server, \
            flight.connect(('localhost', server.port)) as client:
        received = client.do_get(flight.Ticket(b'the-ticket')).read_all()
        assert received.equals(expected)
| |
| |
def test_flight_get_info():
    """Make sure FlightEndpoint accepts string and object URIs.

    Also checks the FlightInfo fields and the per-endpoint expiration time
    and app metadata returned by GetInfoFlightServer.
    """
    # The client is used as a context manager, like the other tests in this
    # module, so it is closed even when an assertion fails.
    with GetInfoFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        info = client.get_flight_info(flight.FlightDescriptor.for_command(b''))
        assert info.total_records == 1
        assert info.total_bytes == 42
        assert info.ordered
        assert info.app_metadata == b"info app metadata"
        assert info.schema == pa.schema([('a', pa.int32())])
        assert len(info.endpoints) == 2
        assert len(info.endpoints[0].locations) == 1
        assert info.endpoints[0].expiration_time is None
        assert info.endpoints[0].app_metadata == b""
        assert info.endpoints[0].locations[0] == flight.Location('grpc://test')
        assert info.endpoints[1].expiration_time == \
            pa.scalar("2023-04-05T12:34:56.789012345+00:00") \
            .cast(pa.timestamp("ns", "UTC"))
        assert info.endpoints[1].app_metadata == b"endpoint app metadata"
        assert info.endpoints[1].locations[0] == \
            flight.Location.for_grpc_tcp('localhost', 5005)
| |
| |
def test_flight_get_schema():
    """get_schema() must return the schema served by GetInfoFlightServer."""
    expected_schema = pa.schema([('a', pa.int32())])
    with GetInfoFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        descriptor = flight.FlightDescriptor.for_command(b'')
        schema_result = client.get_schema(descriptor)
        assert schema_result.schema == expected_schema
| |
| |
def test_list_actions():
    """The return type of ListActions is validated (ARROW-6392)."""
    # A server yielding invalid entries must produce a server error.
    with ListActionsErrorFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        with pytest.raises(
                flight.FlightServerError,
                match="Results of list_actions must be ActionType or tuple"):
            list(client.list_actions())

    # A well-behaved server round-trips its declared actions.
    with ListActionsFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        listed = list(client.list_actions())
        assert listed == ListActionsFlightServer.expected_actions()
| |
| |
class ConvenienceServer(FlightServerBase):
    """
    Server for testing various implementation conveniences (auto-boxing, etc.)
    """

    @property
    def simple_action_results(self):
        """Raw bytes results returned for the 'simple-action' action."""
        return [b'foo', b'bar', b'baz']

    def do_action(self, context, action):
        """Dispatch on ``action.type``.

        - 'simple-action': return ``simple_action_results``.
        - 'echo': return the action body as the only result.
        - 'bad-action': return a ``str`` result.
        - 'arrow-exception': raise ``pa.ArrowMemoryError``.
        - 'forever': return a generator yielding ``b'foo'`` until the call
          is cancelled.

        Any other action type returns None.
        """
        kind = action.type
        if kind == 'simple-action':
            return self.simple_action_results
        if kind == 'echo':
            return [action.body]
        if kind == 'bad-action':
            return ['foo']
        if kind == 'arrow-exception':
            raise pa.ArrowMemoryError()
        if kind == 'forever':
            def stream_until_cancelled():
                while not context.is_cancelled():
                    yield b'foo'
            return stream_until_cancelled()
        return None
| |
| |
def test_do_action_result_convenience():
    """do_action accepts a bare action type or a (type, body) tuple."""
    with ConvenienceServer() as server, \
            FlightClient(('localhost', server.port)) as client:

        # Action given only by its type, without a body.
        simple_bodies = [res.body for res in client.do_action('simple-action')]
        assert simple_bodies == server.simple_action_results

        # Action given as a (type, body) tuple; the server echoes the body.
        payload = b'the-body'
        echoed = [res.body for res in client.do_action(('echo', payload))]
        assert echoed == [payload]
| |
| |
def test_nicer_server_exceptions():
    """Server-side failures reach the client as FlightServerError."""
    # (action type, message fragment expected in the client-side error)
    cases = (
        ('bad-action', "a bytes-like object is required"),
        # While Flight/C++ sends across the original status code, it
        # doesn't get mapped to the equivalent code here, since we
        # want to be able to distinguish between client- and server-
        # side errors.
        ('arrow-exception', "ArrowMemoryError"),
    )
    with ConvenienceServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        for action_type, fragment in cases:
            with pytest.raises(flight.FlightServerError, match=fragment):
                list(client.do_action(action_type))
| |
| |
def test_get_port():
    """A server bound to port 0 must report a positive ``port``."""
    started = GetInfoFlightServer("grpc://localhost:0")
    try:
        bound_port = started.port
        assert bound_port > 0
    finally:
        started.shutdown()
| |
| |
@pytest.mark.skipif(os.name == 'nt',
                    reason="Unix sockets can't be tested on Windows")
def test_flight_domain_socket():
    """Run do_get for 'ints' and 'dicts' over a Unix domain socket."""
    # (ticket bytes, factory for the table the server should return)
    cases = (
        (b'ints', simple_ints_table),
        (b'dicts', simple_dicts_table),
    )
    with tempfile.NamedTemporaryFile() as sock:
        # Close the temporary file right away; only its name is used below.
        sock.close()
        location = flight.Location.for_grpc_unix(sock.name)
        with ConstantFlightServer(location=location), \
                FlightClient(location) as client:
            for ticket_bytes, make_table in cases:
                stream = client.do_get(flight.Ticket(ticket_bytes))
                expected = make_table()
                assert stream.schema.equals(expected.schema)
                received = stream.read_all()
                assert received.equals(expected)
| |
| |
@pytest.mark.slow
def test_flight_large_message():
    """Try sending/receiving a large message via Flight.

    See ARROW-4421: by default, gRPC won't allow us to send messages >
    4MiB in size.
    """
    num_rows = 10 * 1024 * 1024
    big_table = pa.Table.from_arrays([pa.array(range(0, num_rows))],
                                     names=['a'])

    with EchoFlightServer(expected_schema=big_table.schema) as server, \
            FlightClient(('localhost', server.port)) as client:
        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = client.do_put(descriptor, big_table.schema)
        # Upload everything as one giant chunk.
        writer.write_table(big_table, num_rows)
        writer.close()

        echoed = client.do_get(flight.Ticket(b'')).read_all()
        assert echoed.equals(big_table)
| |
| |
def test_flight_generator_stream():
    """Upload a table, then download it again from EchoStreamFlightServer."""
    uploaded = pa.Table.from_arrays([pa.array(range(0, 10 * 1024))],
                                    names=['a'])

    with EchoStreamFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = client.do_put(descriptor, uploaded.schema)
        writer.write_table(uploaded)
        writer.close()

        downloaded = client.do_get(flight.Ticket(b'')).read_all()
        assert downloaded.equals(uploaded)
| |
| |
def test_flight_invalid_generator_stream():
    """Reading a stream with mismatched schemas raises an ArrowException."""
    with InvalidStreamFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        with pytest.raises(pa.ArrowException):
            stream = client.do_get(flight.Ticket(b''))
            stream.read_all()
| |
| |
def test_timeout_fires():
    """A 0.2 s call timeout must fire against SlowFlightServer."""
    with SlowFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        short_timeout = flight.FlightCallOptions(timeout=0.2)
        empty_action = flight.Action("", b"")
        # gRPC error messages change based on version, so only the
        # exception type is checked, not the message.
        with pytest.raises(flight.FlightTimedOutError):
            list(client.do_action(empty_action, options=short_timeout))
| |
| |
def test_timeout_passes():
    """A generous 5 s timeout must not fire on a fast do_get request."""
    with ConstantFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        generous = flight.FlightCallOptions(timeout=5.0)
        stream = client.do_get(flight.Ticket(b'ints'), options=generous)
        stream.read_all()
| |
| |
def test_read_options():
    """IpcReadOptions passed via FlightCallOptions restrict the columns read."""
    only_b = pa.Table.from_arrays([pa.array([1, 2, 3, 4])], names=["b"])

    with ConstantFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        # Request only the field at index 1.
        field_filter = IpcReadOptions(included_fields=[1])
        call_options = flight.FlightCallOptions(read_options=field_filter)

        filtered = client.do_get(
            flight.Ticket(b'multi'), options=call_options).read_all()
        unfiltered = client.do_get(flight.Ticket(b'multi')).read_all()

        assert unfiltered.num_columns == 2
        assert filtered.num_columns == 1
        assert filtered == only_b
        assert unfiltered == multiple_column_table()
| |
| |
# Server-side auth handlers shared by the authentication tests below.
# Both are configured with the single credential pair test / p4ssw0rd.
basic_auth_handler = HttpBasicServerAuthHandler(
    creds={b"test": b"p4ssw0rd"})

token_auth_handler = TokenServerAuthHandler(
    creds={b"test": b"p4ssw0rd"})
| |
| |
@pytest.mark.slow
def test_http_basic_unauth():
    """An action sent without authenticating first must be rejected."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server, \
            FlightClient(('localhost', server.port)) as client:
        who_am_i = flight.Action("who-am-i", b"")
        with pytest.raises(flight.FlightUnauthenticatedError,
                           match=".*unauthenticated.*"):
            list(client.do_action(who_am_i))
| |
| |
@pytest.mark.skipif(os.name == 'nt',
                    reason="ARROW-10013: gRPC on Windows corrupts peer()")
def test_http_basic_auth():
    """Authenticate with HTTP basic auth, then query identity and peer."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server, \
            FlightClient(('localhost', server.port)) as client:
        who_am_i = flight.Action("who-am-i", b"")
        client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd'))
        responses = client.do_action(who_am_i)

        # First result: the authenticated identity.
        identity = next(responses)
        assert identity.body.to_pybytes() == b'test'

        # Second result: the peer address, which must be non-empty.
        peer = next(responses)
        assert peer.body.to_pybytes() != b''
| |
| |
def test_http_basic_auth_invalid_password():
    """HTTP basic auth with the wrong password must be rejected."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server, \
            FlightClient(('localhost', server.port)) as client:
        who_am_i = flight.Action("who-am-i", b"")
        bad_credentials = HttpBasicClientAuthHandler('test', 'wrong')
        # Both the authenticate call and the following action are inside
        # the block, so the error may be raised by either of them.
        with pytest.raises(flight.FlightUnauthenticatedError,
                           match=".*wrong password.*"):
            client.authenticate(bad_credentials)
            next(client.do_action(who_am_i))
| |
| |
def test_token_auth():
    """Authenticate via the token handshake and check the identity."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server, \
            FlightClient(('localhost', server.port)) as client:
        who_am_i = flight.Action("who-am-i", b"")
        client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd'))
        first_result = next(client.do_action(who_am_i))
        assert first_result.body.to_pybytes() == b'test'
| |
| |
def test_token_auth_invalid():
    """The token handshake must fail when the password is wrong."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server, \
            FlightClient(('localhost', server.port)) as client:
        bad_credentials = TokenClientAuthHandler('test', 'wrong')
        with pytest.raises(flight.FlightUnauthenticatedError):
            client.authenticate(bad_credentials)
| |
| |
# Module-level instances for the header-based authentication tests.
header_auth_server_middleware_factory = (
    HeaderAuthServerMiddlewareFactory()
)
# Auth handler passed as ``auth_handler`` by the header-auth tests below.
no_op_auth_handler = NoopAuthHandler()
| |
| |
def test_authenticate_basic_token():
    """authenticate_basic_token returns the bearer authorization header."""
    server_middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=server_middleware) as server, \
            FlightClient(('localhost', server.port)) as client:
        header = client.authenticate_basic_token(b'test', b'password')
        assert header[0] == b'authorization'
        assert header[1] == b'Bearer token1234'
| |
| |
def test_authenticate_basic_token_invalid_password():
    """authenticate_basic_token must reject an invalid password."""
    server_middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=server_middleware) as server, \
            FlightClient(('localhost', server.port)) as client:
        with pytest.raises(flight.FlightUnauthenticatedError):
            client.authenticate_basic_token(b'test', b'badpassword')
| |
| |
def test_authenticate_basic_token_and_action():
    """Use the header from authenticate_basic_token on a later doAction."""
    server_middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=server_middleware) as server, \
            FlightClient(('localhost', server.port)) as client:
        header = client.authenticate_basic_token(b'test', b'password')
        assert header[0] == b'authorization'
        assert header[1] == b'Bearer token1234'

        # Send the returned header along with the next call.
        call_options = flight.FlightCallOptions(headers=[header])
        test_action = flight.Action('test-action', b'')
        responses = list(client.do_action(action=test_action,
                                          options=call_options))
        assert responses[0].body.to_pybytes() == b'token1234'
| |
| |
def test_authenticate_basic_token_with_client_middleware():
    """Test authenticate_basic_token with client middleware
    to intercept authorization header returned by the
    HTTP header auth enabled server.

    The same Basic-auth call is issued twice; after each one the client
    middleware is expected to hold the bearer credential.
    """
    client_auth_middleware = ClientHeaderAuthMiddlewareFactory()
    # The client is managed by the ``with`` statement so it is closed even
    # when one of the assertions below fails.
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={
        "auth": HeaderAuthServerMiddlewareFactory()
    }) as server, \
            FlightClient(('localhost', server.port),
                         middleware=[client_auth_middleware]) as client:
        encoded_credentials = base64.b64encode(b'test:password')
        options = flight.FlightCallOptions(headers=[
            (b'authorization', b'Basic ' + encoded_credentials)
        ])
        result = list(client.do_action(
            action=flight.Action('test-action', b''), options=options))
        assert result[0].body.to_pybytes() == b'token1234'
        assert client_auth_middleware.call_credential[0] == b'authorization'
        assert client_auth_middleware.call_credential[1] == \
            b'Bearer ' + b'token1234'
        result2 = list(client.do_action(
            action=flight.Action('test-action', b''), options=options))
        assert result2[0].body.to_pybytes() == b'token1234'
        assert client_auth_middleware.call_credential[0] == b'authorization'
        assert client_auth_middleware.call_credential[1] == \
            b'Bearer ' + b'token1234'
| |
| |
def test_arbitrary_headers_in_flight_call_options():
    """Pass several arbitrary headers through FlightCallOptions."""
    server_middleware = {
        "auth": HeaderAuthServerMiddlewareFactory(),
        "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
    }
    with ArbitraryHeadersFlightServer(
            auth_handler=no_op_auth_handler,
            middleware=server_middleware) as server, \
            FlightClient(('localhost', server.port)) as client:
        auth_header = client.authenticate_basic_token(b'test', b'password')
        assert auth_header[0] == b'authorization'
        assert auth_header[1] == b'Bearer token1234'

        extra_headers = [
            auth_header,
            (b'test-header-1', b'value1'),
            (b'test-header-2', b'value2')
        ]
        call_options = flight.FlightCallOptions(headers=extra_headers)
        responses = list(client.do_action(
            flight.Action("test-action", b""), options=call_options))
        # The server answers with the two test header values, in order.
        assert responses[0].body.to_pybytes() == b'value1'
        assert responses[1].body.to_pybytes() == b'value2'
| |
| |
def test_location_invalid():
    """An unparseable URI is rejected by both the client and the server."""
    parse_error = ".*Cannot parse URI:.*"

    # Client side.
    with pytest.raises(pa.ArrowInvalid, match=parse_error):
        flight.connect("%")

    # Server side.
    with pytest.raises(pa.ArrowInvalid, match=parse_error):
        ConstantFlightServer("%")
| |
| |
def test_location_unknown_scheme():
    """Locations with non-Flight schemes keep their URI unchanged."""
    s3_location = flight.Location("s3://foo")
    assert s3_location.uri == b"s3://foo"

    https_location = flight.Location("https://example.com/bar.parquet")
    assert https_location.uri == b"https://example.com/bar.parquet"
| |
| |
@pytest.mark.slow
@pytest.mark.requires_testing_data
def test_tls_fails():
    """A TLS client without the root certificate must fail to connect."""
    tls_material = example_tls_certs()
    server_certs = tls_material["certificates"]

    # This is a slow test since gRPC retries a few times before giving up
    # on certificate verification.
    with ConstantFlightServer(tls_certificates=server_certs) as s, \
            FlightClient("grpc+tls://localhost:" + str(s.port)) as client:
        # gRPC error messages change based on version, so only the
        # exception type is checked.
        with pytest.raises(flight.FlightUnavailableError):
            stream = client.do_get(flight.Ticket(b'ints'))
            stream.read_all()
| |
| |
@pytest.mark.requires_testing_data
def test_tls_do_get():
    """Fetch the 'ints' table over TLS using the example root certificate."""
    expected = simple_ints_table()
    tls_material = example_tls_certs()

    with ConstantFlightServer(
            tls_certificates=tls_material["certificates"]) as s, \
            FlightClient(('localhost', s.port),
                         tls_root_certs=tls_material["root_cert"]) as client:
        received = client.do_get(flight.Ticket(b'ints')).read_all()
        assert received.equals(expected)
| |
| |
@pytest.mark.requires_testing_data
def test_tls_disable_server_verification():
    """Try a simple do_get call over TLS with server verification disabled.

    The test is skipped when constructing the client raises
    NotImplementedError for ``disable_server_verification``.
    """
    table = simple_ints_table()
    certs = example_tls_certs()

    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
        try:
            client = FlightClient(('localhost', s.port),
                                  disable_server_verification=True)
        except NotImplementedError:
            pytest.skip('disable_server_verification feature is not available')
        # Use the client as a context manager so it is closed even when the
        # request or the assertion fails.
        with client:
            data = client.do_get(flight.Ticket(b'ints')).read_all()
            assert data.equals(table)
| |
| |
@pytest.mark.requires_testing_data
def test_tls_override_hostname():
    """Overriding the TLS hostname with a wrong value must make calls fail."""
    tls_material = example_tls_certs()

    with ConstantFlightServer(
            tls_certificates=tls_material["certificates"]) as s, \
            flight.connect(('localhost', s.port),
                           tls_root_certs=tls_material["root_cert"],
                           override_hostname="fakehostname") as client:
        ints_ticket = flight.Ticket(b'ints')
        with pytest.raises(flight.FlightUnavailableError):
            client.do_get(ints_ticket)
| |
| |
def test_flight_do_get_metadata():
    """Read batches with per-batch metadata from MetadataFlightServer.

    Each chunk's metadata is unpacked as a little-endian int32 and must
    equal the chunk's position in the stream.
    """
    expected = pa.Table.from_arrays([pa.array([-10, -5, 0, 5, 10])],
                                    names=['a'])

    received_batches = []
    with MetadataFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        stream = client.do_get(flight.Ticket(b''))
        for position, (batch, metadata) in enumerate(stream):
            received_batches.append(batch)
            server_idx, = struct.unpack('<i', metadata.to_pybytes())
            assert position == server_idx
        received = pa.Table.from_batches(received_batches)
        assert received.equals(expected)
| |
| |
def test_flight_metadata_record_batch_reader_iterator():
    """Iterating the reader must yield the same batches as read_chunk()."""
    via_read_chunk = []
    via_iteration = []

    # Pass 1: pull chunks explicitly until StopIteration.
    with MetadataFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        stream = client.do_get(flight.Ticket(b''))
        position = 0
        while True:
            try:
                batch, metadata = stream.read_chunk()
            except StopIteration:
                break
            via_read_chunk.append(batch)
            server_idx, = struct.unpack('<i', metadata.to_pybytes())
            assert position == server_idx
            position += 1

    # Pass 2: use the reader as an iterator.
    with MetadataFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        stream = client.do_get(flight.Ticket(b''))
        for position, (batch, metadata) in enumerate(stream):
            via_iteration.append(batch)
            server_idx, = struct.unpack('<i', metadata.to_pybytes())
            assert position == server_idx

    assert via_read_chunk == via_iteration
| |
| |
def test_flight_do_get_metadata_v4():
    """do_get against MetadataFlightServer using the V4 metadata version."""
    values = pa.array([-10, -5, 0, 5, 10])
    expected = pa.Table.from_arrays([values], names=['a'])
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)

    with MetadataFlightServer(options=v4_options) as server, \
            FlightClient(('localhost', server.port)) as client:
        received = client.do_get(flight.Ticket(b'')).read_all()
        assert received.equals(expected)
| |
| |
def test_flight_do_put_metadata():
    """Upload one-row batches with metadata and check the server's replies.

    The batch index is sent as little-endian int32 metadata and the server
    is expected to send the same index back after each batch.
    """
    source = pa.Table.from_arrays([pa.array([-10, -5, 0, 5, 10])],
                                  names=['a'])

    with MetadataFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        descriptor = flight.FlightDescriptor.for_path('')
        writer, metadata_reader = client.do_put(descriptor, source.schema)
        with writer:
            single_row_batches = source.to_batches(max_chunksize=1)
            for position, batch in enumerate(single_row_batches):
                writer.write_with_metadata(batch, struct.pack('<i', position))

                reply = metadata_reader.read()
                assert reply is not None
                server_idx, = struct.unpack('<i', reply.to_pybytes())
                assert position == server_idx
| |
| |
@pytest.mark.numpy
def test_flight_do_put_limit():
    """do_put with a 4096-byte write size limit on the client.

    Writing the whole batch must raise FlightWriteSizeExceededError;
    writing it as two halves must succeed and round-trip through do_get.
    """
    size_limit = 4096
    ones = np.ones(768, dtype=np.int64())
    oversized = pa.RecordBatch.from_arrays([pa.array(ones)], names=['a'])

    with EchoFlightServer() as server, \
            FlightClient(('localhost', server.port),
                         write_size_limit_bytes=size_limit) as client:
        descriptor = flight.FlightDescriptor.for_path('')
        writer, metadata_reader = client.do_put(descriptor, oversized.schema)
        with writer:
            # The full batch is over the limit.
            with pytest.raises(flight.FlightWriteSizeExceededError,
                               match="exceeded soft limit") as excinfo:
                writer.write_batch(oversized)
            assert excinfo.value.limit == 4096

            # Each half of the batch is accepted.
            halves = [oversized.slice(0, 384), oversized.slice(384)]
            for half in halves:
                writer.write_batch(half)

        expected = pa.Table.from_batches([oversized])
        actual = client.do_get(flight.Ticket(b'')).read_all()
        assert expected == actual
| |
| |
@pytest.mark.slow
def test_cancel_do_get():
    """Reading after a client-side cancel raises FlightCancelledError."""
    with ConstantFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        stream = client.do_get(flight.Ticket(b'ints'))
        stream.cancel()
        # The message is matched case-insensitively on "cancel".
        with pytest.raises(flight.FlightCancelledError,
                           match="(?i).*cancel.*"):
            stream.read_chunk()
| |
| |
@pytest.mark.threading
@pytest.mark.slow
def test_cancel_do_get_threaded():
    """Cancel a DoGet from the main thread while another thread reads it."""
    with SlowFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        stream = client.do_get(flight.Ticket(b'ints'))

        first_chunk_read = threading.Event()
        cancel_issued = threading.Event()
        outcome_lock = threading.Lock()
        got_cancelled_error = threading.Event()

        def read_in_background():
            # Read one chunk, tell the main thread, then wait for the
            # cancellation before attempting the second read.
            stream.read_chunk()
            first_chunk_read.set()
            cancel_issued.wait(timeout=5)
            try:
                stream.read_chunk()
            except flight.FlightCancelledError:
                with outcome_lock:
                    got_cancelled_error.set()

        worker = threading.Thread(target=read_in_background, daemon=True)
        worker.start()

        # Cancel only after the worker has consumed the first message.
        first_chunk_read.wait(timeout=5)
        stream.cancel()
        cancel_issued.set()
        worker.join(timeout=1)

        with outcome_lock:
            assert got_cancelled_error.is_set()
| |
| |
def test_streaming_do_action():
    """Read one result from the endless 'forever' action, then drop it."""
    with ConvenienceServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        forever = flight.Action('forever', b'')
        results = client.do_action(forever)
        first = next(results)
        assert first.body == b'foo'
        # Dropping the result iterator implicitly cancels the call.
        del results
| |
| |
def test_roundtrip_types():
    """Serialize and deserialize each Flight type and compare the results."""
    ticket = flight.Ticket("foo")
    path_descriptor = flight.FlightDescriptor.for_path("a", "b", "test.arrow")

    # (class providing deserialize(), instance to round-trip)
    simple_cases = [
        (flight.Action, flight.Action("action1", b"action1-body")),
        (flight.Ticket, ticket),
        (flight.Result, flight.Result(b"result1")),
        (flight.BasicAuth, flight.BasicAuth("username1", "password1")),
        (flight.SchemaResult,
         flight.SchemaResult(pa.schema([('a', pa.int32())]))),
        (flight.FlightDescriptor, flight.FlightDescriptor.for_command("test")),
        (flight.FlightDescriptor, path_descriptor),
    ]
    for cls, original in simple_cases:
        assert original == cls.deserialize(original.serialize())

    # FlightInfo is compared attribute by attribute.
    nanosecond_expiry = \
        pa.scalar("2023-04-05T12:34:56.789012345").cast(pa.timestamp("ns"))
    info = flight.FlightInfo(
        pa.schema([('a', pa.int32())]),
        path_descriptor,
        [
            flight.FlightEndpoint(b'', ['grpc://test']),
            flight.FlightEndpoint(
                b'',
                [flight.Location.for_grpc_tcp('localhost', 5005)],
                nanosecond_expiry,
                b'endpoint app metadata'
            ),
        ],
        1,
        42,
        True,
        b'test app metadata'
    )
    restored = flight.FlightInfo.deserialize(info.serialize())
    compared_attributes = (
        "schema",
        "descriptor",
        "total_bytes",
        "total_records",
        "ordered",
        "app_metadata",
        "endpoints",
    )
    for attribute in compared_attributes:
        assert getattr(info, attribute) == getattr(restored, attribute)

    # Endpoint with mixed location kinds, an expiration time and metadata.
    endpoint = flight.FlightEndpoint(
        ticket,
        ['grpc://test', flight.Location.for_grpc_tcp('localhost', 5005)],
        pa.scalar("2023-04-05T12:34:56").cast(pa.timestamp("s")),
        b'endpoint app metadata'
    )
    restored_endpoint = flight.FlightEndpoint.deserialize(endpoint.serialize())
    assert endpoint == restored_endpoint
| |
| |
def test_roundtrip_errors():
    """Flight errors raised on the server reach the client with their type."""
    foo_message = ".*foo.*"

    with ErrorFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:

        # do_action: one error per case declared by the server.
        for action_name, expected_exc in ErrorFlightServer.error_cases().items():
            with pytest.raises(expected_exc, match=foo_message):
                list(client.do_action(flight.Action(action_name, b"")))

        # list_flights
        with pytest.raises(flight.FlightInternalError, match=foo_message):
            list(client.list_flights())

        payload = pa.Table.from_arrays([pa.array([-10, -5, 0, 5, 10])],
                                       names=['a'])

        # do_put: the command selects which error is expected.
        put_errors = {
            'internal': flight.FlightInternalError,
            'timedout': flight.FlightTimedOutError,
            'cancel': flight.FlightCancelledError,
            'unauthenticated': flight.FlightUnauthenticatedError,
            'unauthorized': flight.FlightUnauthorizedError,
        }

        for command, expected_exc in put_errors.items():
            # With data written before closing.
            with pytest.raises(expected_exc, match=foo_message):
                writer, _ = client.do_put(
                    flight.FlightDescriptor.for_command(command),
                    payload.schema)
                writer.write_table(payload)
                writer.close()

            # Closing straight away, without writing any data.
            with pytest.raises(expected_exc, match=foo_message):
                writer, _ = client.do_put(
                    flight.FlightDescriptor.for_command(command),
                    payload.schema)
                writer.close()
| |
| |
def test_do_put_independent_read_write():
    """Ensure that separate threads can read/write on a DoPut."""
    # ARROW-6063: previously this would cause gRPC to abort when the
    # writer was closed (due to simultaneous reads), or would hang
    # forever.
    table = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])

    with MetadataFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        descriptor = flight.FlightDescriptor.for_path('')
        writer, metadata_reader = client.do_put(descriptor, table.schema)

        messages_read = 0

        def _reader_thread():
            # Count metadata messages until read() returns None.
            nonlocal messages_read
            while True:
                if metadata_reader.read() is None:
                    break
                messages_read += 1

        thread = threading.Thread(target=_reader_thread)
        thread.start()

        batches = table.to_batches(max_chunksize=1)
        with writer:
            for position, batch in enumerate(batches):
                writer.write_with_metadata(
                    batch, struct.pack('<i', position))
            # Causes the server to stop writing and end the call
            writer.done_writing()
            # Thus reader thread will break out of loop
            thread.join()
        # writer.close() won't segfault since reader thread has
        # stopped
        assert messages_read == len(batches)
| |
| |
def test_server_middleware_same_thread():
    """Ensure that server middleware run on the same thread as the RPC."""
    middleware = {"test": HeaderServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=middleware) as server, \
            FlightClient(('localhost', server.port)) as client:
        action = flight.Action(b"test", b"")
        responses = list(client.do_action(action))
        assert len(responses) == 1
        # The single result body must carry the expected marker value.
        assert b"right value" == responses[0].body.to_pybytes()
| |
| |
def test_middleware_reject():
    """Test rejecting an RPC with server middleware.

    An unauthenticated client may call ``list_actions`` but nothing
    else; a client that installs ``SelectiveAuthClientMiddlewareFactory``
    is let through for ``do_action`` as well.
    """
    middleware = {"test": SelectiveAuthServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=middleware) as server:
        location = ('localhost', server.port)

        with FlightClient(location) as client:
            # The middleware allows this through without auth.
            with pytest.raises(pa.ArrowNotImplementedError):
                list(client.list_actions())

            # But not anything else.
            with pytest.raises(flight.FlightUnauthenticatedError):
                list(client.do_action(flight.Action(b"", b"")))

        # Use a context manager so the authenticated client is closed
        # as well, instead of being leaked when the test ends.
        with FlightClient(
                location,
                middleware=[SelectiveAuthClientMiddlewareFactory()]
        ) as auth_client:
            response = next(auth_client.do_action(flight.Action(b"", b"")))
            assert b"password" == response.body.to_pybytes()
| |
| |
def test_middleware_mapping():
    """Test that middleware records methods correctly."""
    server_middleware = RecordingServerMiddlewareFactory()
    client_middleware = RecordingClientMiddlewareFactory()
    with FlightServerBase(middleware={"test": server_middleware}) as server, \
            FlightClient(
                ('localhost', server.port),
                middleware=[client_middleware]
            ) as client:
        descriptor = flight.FlightDescriptor.for_command(b"")

        def _put():
            writer, _ = client.do_put(descriptor, pa.schema([]))
            writer.close()

        def _exchange():
            writer, _ = client.do_exchange(descriptor)
            writer.close()

        # Each RPC is paired with the method the middleware should
        # record for it; every call is expected to raise
        # NotImplementedError.
        rpcs = [
            (flight.FlightMethod.LIST_FLIGHTS,
             lambda: list(client.list_flights())),
            (flight.FlightMethod.GET_FLIGHT_INFO,
             lambda: client.get_flight_info(descriptor)),
            (flight.FlightMethod.GET_SCHEMA,
             lambda: client.get_schema(descriptor)),
            (flight.FlightMethod.DO_GET,
             lambda: client.do_get(flight.Ticket(b""))),
            (flight.FlightMethod.DO_PUT, _put),
            (flight.FlightMethod.DO_ACTION,
             lambda: list(client.do_action(flight.Action(b"", b"")))),
            (flight.FlightMethod.LIST_ACTIONS,
             lambda: list(client.list_actions())),
            (flight.FlightMethod.DO_EXCHANGE, _exchange),
        ]
        for _, invoke in rpcs:
            with pytest.raises(NotImplementedError):
                invoke()

        expected = [method for method, _ in rpcs]
        assert server_middleware.methods == expected
        assert client_middleware.methods == expected
| |
| |
def test_extra_info():
    """Ensure ``extra_info`` of a server-raised Flight error is preserved.

    The "protobuf" action is expected to fail with
    ``FlightUnauthorizedError`` carrying a fixed ``extra_info`` payload.
    """
    with ErrorFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        # pytest.raises reports "DID NOT RAISE" when no error occurs,
        # which is clearer than a bare ``assert False``.
        with pytest.raises(flight.FlightUnauthorizedError) as exc_info:
            list(client.do_action(flight.Action("protobuf", b"")))
        extra_info = exc_info.value.extra_info
        assert extra_info is not None
        assert extra_info == b'this is an error message'
| |
| |
@pytest.mark.requires_testing_data
def test_mtls():
    """Test mutual TLS (mTLS) with gRPC."""
    certs = example_tls_certs()
    table = simple_ints_table()

    # The same certificate/key pair is used by server and client.
    key_pair = certs["certificates"][0]
    root_cert = certs["root_cert"]

    server = ConstantFlightServer(
        tls_certificates=[key_pair],
        verify_client=True,
        root_certificates=root_cert)
    with server as s, \
            FlightClient(
                ('localhost', s.port),
                tls_root_certs=root_cert,
                cert_chain=key_pair.cert,
                private_key=key_pair.key) as client:
        reader = client.do_get(flight.Ticket(b'ints'))
        data = reader.read_all()
        assert data.equals(table)
| |
| |
def test_doexchange_get():
    """Emulate DoGet with DoExchange."""
    column = pa.array(range(0, 10 * 1024))
    expected = pa.Table.from_arrays([column], names=["a"])

    with ExchangeFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"get"))
        # Nothing is written; the whole table is just read back.
        with writer:
            received = reader.read_all()
        assert expected == received
| |
| |
def test_doexchange_put():
    """Emulate DoPut with DoExchange."""
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"put"))
        with writer:
            writer.begin(data.schema)
            for record_batch in batches:
                writer.write_batch(record_batch)
            writer.done_writing()

            # The reply is a metadata-only chunk whose payload is the
            # number of batches, encoded as UTF-8 text.
            reply = reader.read_chunk()
            assert reply.data is None
            assert reply.app_metadata == str(len(batches)).encode("utf-8")
            # Metadata only message is not counted as an ipc data message
            assert reader.stats.num_messages == 0
| |
| |
def test_doexchange_echo():
    """Try a DoExchange echo server."""
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"echo"))
        with writer:
            # Read/write metadata before starting data.
            for payload in (str(i).encode("utf-8") for i in range(10)):
                writer.write_metadata(payload)
                echoed = reader.read_chunk()
                assert echoed.data is None
                assert echoed.app_metadata == payload

            # Now write data without metadata.
            writer.begin(data.schema)
            for seen, batch in enumerate(batches, start=1):
                writer.write_batch(batch)
                assert reader.schema == data.schema
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata is None
                assert reader.stats.num_record_batches == seen

            # And write data with metadata.  The record batch counter
            # keeps growing from where the previous loop left off.
            already_seen = len(batches)
            for i, batch in enumerate(batches):
                payload = str(i).encode("utf-8")
                writer.write_with_metadata(batch, payload)
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata == payload
                assert reader.stats.num_record_batches == \
                    already_seen + i + 1
| |
| |
def test_doexchange_echo_v4():
    """Try a DoExchange echo server using the V4 metadata version."""
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    # The same write options are given to the server and to the
    # client-side writer.
    options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with ExchangeFlightServer(options=options) as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"echo"))
        with writer:
            # Now write data without metadata.
            writer.begin(data.schema, options=options)
            for batch in batches:
                writer.write_batch(batch)
                assert reader.schema == data.schema
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata is None
| |
| |
def test_doexchange_transform():
    """Transform a table with a service."""
    # Columns a, b, c hold 0..1023, 1..1024 and 2..1025 respectively.
    columns = [pa.array(range(start, 1024 + start)) for start in range(3)]
    data = pa.Table.from_arrays(columns, names=["a", "b", "c"])
    # Expected single column: 3, 6, ..., 3072.
    sums = pa.array(range(3, 1024 * 3 + 3, 3))
    expected = pa.Table.from_arrays([sums], names=["sum"])

    with ExchangeFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:
        writer, reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"transform"))
        with writer:
            writer.begin(data.schema)
            writer.write_table(data)
            writer.done_writing()
            transformed = reader.read_all()
        assert expected == transformed
| |
| |
def test_middleware_multi_header():
    """Test sending/receiving multiple (binary-valued) headers."""
    server_middleware = {"test": MultiHeaderServerMiddlewareFactory()}
    with MultiHeaderFlightServer(middleware=server_middleware) as server:
        headers = MultiHeaderClientMiddlewareFactory()
        with FlightClient(
                ('localhost', server.port),
                middleware=[headers]) as client:
            response = next(client.do_action(flight.Action(b"", b"")))
            # The server echoes the headers it got back to us.
            echoed_text = response.body.to_pybytes().decode("utf-8")
            client_headers = ast.literal_eval(echoed_text)
            # Don't directly compare; gRPC may add headers like User-Agent.
            expected_headers = MultiHeaderClientMiddleware.EXPECTED
            for raw_name, values in expected_headers.items():
                # Lower-case the name, and decode it when it is bytes.
                name = raw_name.lower()
                if isinstance(name, bytes):
                    name = name.decode("ascii")
                assert client_headers.get(name) == values
                assert headers.last_headers.get(name) == values
| |
| |
@pytest.mark.requires_testing_data
def test_generic_options():
    """Test setting generic client options.

    Each client is opened with a context manager so it is closed even
    when the expected exception is not raised.
    """
    certs = example_tls_certs()

    with ConstantFlightServer(tls_certificates=certs["certificates"]) as s:
        location = ('localhost', s.port)

        # Try setting a string argument that will make requests fail
        options = [("grpc.ssl_target_name_override", "fakehostname")]
        with flight.connect(location,
                            tls_root_certs=certs["root_cert"],
                            generic_options=options) as client:
            with pytest.raises(flight.FlightUnavailableError):
                client.do_get(flight.Ticket(b'ints'))

        # Try setting an int argument that will make requests fail
        options = [("grpc.max_receive_message_length", 32)]
        with flight.connect(location,
                            tls_root_certs=certs["root_cert"],
                            generic_options=options) as client:
            with pytest.raises((pa.ArrowInvalid,
                                flight.FlightCancelledError)):
                client.do_get(flight.Ticket(b'ints'))
| |
| |
class CancelFlightServer(FlightServerBase):
    """A server for testing StopToken."""

    def do_get(self, context, ticket):
        """Return a never-ending stream of empty record batches."""
        empty_schema = pa.schema([])
        empty_batch = pa.RecordBatch.from_arrays([], schema=empty_schema)
        endless_batches = itertools.repeat(empty_batch)
        return flight.GeneratorStream(empty_schema, endless_batches)

    def do_exchange(self, context, descriptor, reader, writer):
        """Write an empty batch every 0.5 s until the call is cancelled."""
        empty_schema = pa.schema([])
        empty_batch = pa.RecordBatch.from_arrays([], schema=empty_schema)
        writer.begin(empty_schema)
        while True:
            if context.is_cancelled():
                break
            writer.write_batch(empty_batch)
            time.sleep(0.5)
| |
| |
@pytest.mark.threading
def test_interrupt():
    """Check that SIGINT interrupts a blocking Flight ``read_all``.

    Exercised for both a DoGet reader and a DoExchange reader against
    ``CancelFlightServer``.
    """
    if threading.current_thread().ident != threading.main_thread().ident:
        pytest.skip("test only works from main Python thread")

    exc_types = (KeyboardInterrupt, pa.ArrowCancelled)

    def signal_from_thread():
        # Raise SIGINT after a short delay (presumably long enough for
        # the caller to be inside read_all()).
        time.sleep(0.5)
        signal.raise_signal(signal.SIGINT)

    def test(read_all):
        interrupter = threading.Thread(target=signal_from_thread)
        try:
            try:
                with pytest.raises(exc_types) as exc_info:
                    interrupter.start()
                    read_all()
            finally:
                interrupter.join()
        except KeyboardInterrupt:
            # In case KeyboardInterrupt didn't interrupt read_all
            # above, at least prevent it from stopping the test suite
            pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all")
        raised = exc_info.value
        # Accept the expected type on the exception itself or on its
        # __context__ (__context__ is sometimes None).
        assert isinstance(raised, exc_types) or \
            isinstance(raised.__context__, exc_types)

    with CancelFlightServer() as server, \
            FlightClient(("localhost", server.port)) as client:

        get_reader = client.do_get(flight.Ticket(b""))
        test(get_reader.read_all)

        writer, exchange_reader = client.do_exchange(
            flight.FlightDescriptor.for_command(b"echo"))
        test(exchange_reader.read_all)
        try:
            writer.close()
        except (KeyboardInterrupt, flight.FlightCancelledError):
            # Silence the Cancelled/Interrupt exception
            pass
| |
| |
def test_never_sends_data():
    """Regression test for ARROW-12779."""
    expected_error = "application server implementation error"
    with NeverSendsDataFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        with pytest.raises(flight.FlightServerError, match=expected_error):
            reader = client.do_get(flight.Ticket(b''))
            reader.read_all()

        # Check that the server handler will ignore empty tables
        # up to a certain extent
        reader = client.do_get(flight.Ticket(b'yield_data'))
        assert reader.read_all().num_rows == 5
| |
| |
@pytest.mark.large_memory
@pytest.mark.slow
def test_large_descriptor():
    """Regression test for ARROW-13253 (descriptor over 2**31 bytes)."""
    # Placed here with appropriate marks since some CI pipelines
    # can't run the C++ equivalent
    oversized_command = b' ' * (2 ** 31 + 1)
    large_descriptor = flight.FlightDescriptor.for_command(oversized_command)
    error_text = "Failed to serialize Flight descriptor"

    with FlightServerBase() as server, \
            flight.connect(('localhost', server.port)) as client:
        with pytest.raises(OSError, match=error_text):
            writer, _ = client.do_put(large_descriptor, pa.schema([]))
            writer.close()
        with pytest.raises(pa.ArrowException, match=error_text):
            client.do_exchange(large_descriptor)
| |
| |
@pytest.mark.large_memory
@pytest.mark.slow
def test_large_metadata_client():
    """Regression test for ARROW-13253 (app_metadata over 2**31 bytes)."""
    overflow = "app_metadata size overflow"
    descriptor = flight.FlightDescriptor.for_command(b'')
    metadata = b' ' * (2 ** 31 + 1)

    # Oversized metadata written by the client.
    with EchoFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        with pytest.raises(pa.ArrowCapacityError, match=overflow):
            put_writer, _ = client.do_put(descriptor, pa.schema([]))
            with put_writer:
                put_writer.write_metadata(metadata)
                put_writer.close()
        with pytest.raises(pa.ArrowCapacityError, match=overflow):
            exchange_writer, _ = client.do_exchange(descriptor)
            with exchange_writer:
                exchange_writer.write_metadata(metadata)

    # Drop the large buffer before starting the second server.
    del metadata

    # Overflow reported while reading from LargeMetadataFlightServer.
    with LargeMetadataFlightServer() as server, \
            flight.connect(('localhost', server.port)) as client:
        with pytest.raises(flight.FlightServerError, match=overflow):
            get_reader = client.do_get(flight.Ticket(b''))
            get_reader.read_all()
        with pytest.raises(pa.ArrowException, match=overflow):
            exchange_writer, exchange_reader = client.do_exchange(descriptor)
            with exchange_writer:
                exchange_reader.read_all()
| |
| |
class ActionNoneFlightServer(EchoFlightServer):
    """A server that implements a side effect to a non iterable action."""
    # Kept as a class attribute for backward compatibility; every
    # instance replaces it with its own list in __init__ so that
    # appended values do not leak between server instances.
    VALUES = []

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the base server and reset VALUES."""
        super().__init__(*args, **kwargs)
        self.VALUES = []

    def do_action(self, context, action):
        """Handle the "get_value" and "append" actions.

        "get_value" returns a one-element list holding the JSON-encoded
        VALUES; "append" adds ``True`` to VALUES and returns None (a
        non-iterable result).  Any other action type raises
        NotImplementedError.
        """
        if action.type == "get_value":
            return [json.dumps(self.VALUES).encode('utf-8')]
        elif action.type == "append":
            self.VALUES.append(True)
            return None
        raise NotImplementedError
| |
| |
def test_none_action_side_effect():
    """Ensure that actions are executed even when we don't consume iterator.

    See https://issues.apache.org/jira/browse/ARROW-14255
    """
    with ActionNoneFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        # Deliberately leave the result iterator of "append" unconsumed.
        client.do_action(flight.Action("append", b""))
        results = client.do_action(flight.Action("get_value", b""))
        first_body = next(results).body.to_pybytes()
        assert json.loads(first_body) == [True]
| |
| |
@pytest.mark.slow  # Takes a while for gRPC to "realize" writes fail
def test_write_error_propagation():
    """
    Ensure that exceptions during writing preserve error context.

    See https://issues.apache.org/jira/browse/ARROW-16592.
    """
    expected_message = "foo"
    expected_info = b"bar"
    exc = flight.FlightCancelledError(
        expected_message, extra_info=expected_info)
    descriptor = flight.FlightDescriptor.for_command(b"")
    schema = pa.schema([("int64", pa.int64())])

    class FailServer(flight.FlightServerBase):
        """Server whose DoPut and DoExchange handlers always raise."""

        def do_put(self, context, descriptor, reader, writer):
            raise exc

        def do_exchange(self, context, descriptor, reader, writer):
            raise exc

    def _drain(read):
        # Keep reading until the call fails with a Flight error.
        try:
            while True:
                read()
        except flight.FlightError:
            return

    def _check_failure(writer, read, write):
        # Set a concurrent reader - ensure this doesn't block the
        # writer side from calling Close()
        worker = threading.Thread(target=_drain, args=(read,), daemon=True)
        worker.start()

        # Write until the server-side error reaches the writer.
        with pytest.raises(flight.FlightCancelledError) as exc_info:
            while True:
                write()
        assert exc_info.value.extra_info == expected_info

        with pytest.raises(flight.FlightCancelledError) as exc_info:
            writer.close()
        assert exc_info.value.extra_info == expected_info
        worker.join()

    with FailServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        # DoPut
        put_writer, put_reader = client.do_put(descriptor, schema)
        _check_failure(
            put_writer,
            put_reader.read,
            lambda: put_writer.write_batch(
                pa.record_batch([[1]], schema=schema)))

        # DoExchange
        exchange_writer, exchange_reader = client.do_exchange(descriptor)
        _check_failure(
            exchange_writer,
            exchange_reader.read_chunk,
            lambda: exchange_writer.write_metadata(b" "))
| |
| |
def test_interpreter_shutdown():
    """
    Ensure that the gRPC server is stopped at interpreter shutdown.

    See https://issues.apache.org/jira/browse/ARROW-16597.
    """
    # The scenario itself is in a separate script run via the test
    # utilities.
    script_name = "arrow_16597.py"
    util.invoke_script(script_name)
| |
| |
class TracingFlightServer(FlightServerBase):
    """A server that echoes back trace context values."""

    def do_action(self, context, action):
        """Return one UTF-8 encoded "key: value" entry per trace value."""
        tracing_middleware = context.get_middleware("tracing")
        trace_context = tracing_middleware.trace_context
        # Don't turn this method into a generator since then
        # trace_context will be evaluated after we've exited the scope
        # of the OTel span (and so the value we want won't be present)
        return (
            f"{name}: {content}".encode("utf-8")
            for name, content in trace_context.items()
        )
| |
| |
def test_tracing():
    """Smoke test: the tracing middleware must not break a DoAction."""
    middleware = {"tracing": flight.TracingServerMiddlewareFactory()}
    # Pretend we have an OTel implementation
    trace_headers = [
        (b"traceparent", b"00-000ff00f00f0ff000f0f00ff0f00fff0-"
                         b"000f0000f0f00000-00"),
        (b"tracestate", b""),
    ]
    with TracingFlightServer(middleware=middleware) as server, \
            FlightClient(('localhost', server.port)) as client:
        # We can't tell if Arrow was built with OpenTelemetry support,
        # so we can't count on any particular values being there; we
        # can only ensure things don't blow up either way.
        options = flight.FlightCallOptions(headers=trace_headers)
        # Drain the results; the values themselves are not checked.
        list(client.do_action((b"", b""), options=options))
| |
| |
def test_do_put_does_not_crash_when_schema_is_none():
    """Check that ``do_put`` with ``schema=None`` raises ``TypeError``."""
    msg = ("Argument 'schema' has incorrect type "
           r"\(expected pyarrow.lib.Schema, got NoneType\)")
    # Use the client as a context manager so it is closed once the
    # check is done instead of being leaked.
    with FlightClient('grpc+tls://localhost:9643',
                      disable_server_verification=True) as client:
        with pytest.raises(TypeError, match=msg):
            client.do_put(flight.FlightDescriptor.for_command('foo'),
                          schema=None)
| |
| |
def test_headers_trailers():
    """Ensure that server-sent headers/trailers make it through."""

    class HeadersTrailersFlightServer(FlightServerBase):
        """Server adding two headers and two trailers to GetFlightInfo."""

        def get_flight_info(self, context, descriptor):
            context.add_header("x-header", "header-value")
            context.add_header("x-header-bin", "header\x01value")
            context.add_trailer("x-trailer", "trailer-value")
            context.add_trailer("x-trailer-bin", "trailer\x01value")
            return flight.FlightInfo(pa.schema([]), descriptor, [])

    class HeadersTrailersMiddlewareFactory(ClientMiddlewareFactory):
        """Collects the (key, value) pairs seen by its middleware."""

        def __init__(self):
            self.headers = []

        def start_call(self, info):
            return HeadersTrailersMiddleware(self)

    class HeadersTrailersMiddleware(ClientMiddleware):
        """Flattens received headers into the factory's list."""

        def __init__(self, factory):
            self.factory = factory

        def received_headers(self, headers):
            self.factory.headers.extend(
                (key, value)
                for key, values in headers.items()
                for value in values
            )

    # The "-bin" entries are expected back as bytes, the others as str.
    expected_entries = [
        ("x-header", "header-value"),
        ("x-header-bin", b"header\x01value"),
        ("x-trailer", "trailer-value"),
        ("x-trailer-bin", b"trailer\x01value"),
    ]

    factory = HeadersTrailersMiddlewareFactory()
    with HeadersTrailersFlightServer() as server, \
            FlightClient(("localhost", server.port),
                         middleware=[factory]) as client:
        client.get_flight_info(flight.FlightDescriptor.for_path(""))
        for entry in expected_entries:
            assert entry in factory.headers
| |
| |
def test_flight_dictionary_deltas_do_exchange():
    """Check read stats for dictionary deltas/replacements on DoExchange.

    The dictionary table is sent in both directions, once with
    dictionary deltas enabled on the writers and once without, and the
    reader statistics are compared on the server and on the client.
    """

    def _stats(deltas, replaced):
        # Only the delta/replacement counters differ between commands.
        return ReadStats(
            num_messages=6,
            num_record_batches=3,
            num_dictionary_batches=2,
            num_dictionary_deltas=deltas,
            num_replaced_dictionaries=replaced
        )

    expected_stats = {
        'dict_deltas': _stats(1, 0),
        'dict_replacement': _stats(0, 1),
    }

    class DeltaFlightServer(ConstantFlightServer):
        def do_exchange(self, context, descriptor, reader, writer):
            expected_table = simple_dicts_table()
            received_table = reader.read_all()
            assert received_table.equals(expected_table)
            assert reader.stats == expected_stats[descriptor.command.decode()]

            # Send the same table back, with deltas enabled only for
            # the 'dict_deltas' command.
            command = descriptor.command
            if command == b'dict_deltas':
                writer.begin(
                    expected_table.schema,
                    options=pa.ipc.IpcWriteOptions(
                        emit_dictionary_deltas=True))
            elif command == b'dict_replacement':
                writer.begin(expected_table.schema)
            else:
                return
            writer.write_table(expected_table)

    with DeltaFlightServer() as server, \
            FlightClient(('localhost', server.port)) as client:
        expected_table = simple_dicts_table()
        for command in ["dict_deltas", "dict_replacement"]:
            call_options = flight.FlightCallOptions(
                write_options=pa.ipc.IpcWriteOptions(
                    emit_dictionary_deltas=True)
            )
            writer, reader = client.do_exchange(
                flight.FlightDescriptor.for_command(command),
                options=call_options)
            # Send client table with dictionary updates
            with writer:
                begin_options = pa.ipc.IpcWriteOptions(
                    emit_dictionary_deltas=(command == "dict_deltas"))
                writer.begin(expected_table.schema, options=begin_options)
                writer.write_table(expected_table)
                writer.done_writing()
                received_table = reader.read_all()

            assert received_table.equals(expected_table)
            assert reader.stats == expected_stats[command]
| |
| |
@pytest.fixture
def call_options_args(request):
    """Return FlightCallOptions keyword arguments for ``request.param``.

    "default" yields a timeout with every other option set to None,
    "all" yields a value for every option, and any other parameter
    yields an empty dict.
    """
    variant = request.param
    if variant == "default":
        return dict(
            timeout=3,
            headers=None,
            write_options=None,
            read_options=None,
        )
    if variant == "all":
        read_options = pa.ipc.IpcReadOptions(
            use_threads=False,
            ensure_alignment=pa.ipc.Alignment.DataTypeSpecific,
        )
        return dict(
            timeout=7,
            headers=[(b"abc", b"def")],
            write_options=pa.ipc.IpcWriteOptions(compression="zstd"),
            read_options=read_options,
        )
    return {}
| |
| |
@pytest.mark.parametrize(
    "call_options_args", ["default", "all"], indirect=True)
def test_call_options_repr(call_options_args):
    """Check that the FlightCallOptions repr mentions every argument.

    Arguments passed as None only need their name to appear; all others
    must appear as ``name=value``.
    """
    # https://github.com/apache/arrow/issues/47358
    call_options = FlightCallOptions(**call_options_args)
    # Named options_repr so the builtin repr() is not shadowed.
    options_repr = repr(call_options)

    for arg, val in call_options_args.items():
        if val is None:
            assert arg in options_repr
        else:
            assert f"{arg}={val}" in options_repr