| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import ast |
| import base64 |
| import os |
| import struct |
| import tempfile |
| import threading |
| import time |
| import traceback |
| |
| import numpy as np |
| import pytest |
| import pyarrow as pa |
| |
| from pyarrow.lib import tobytes |
| from pyarrow.util import pathlib, find_free_port |
| |
| try: |
| from pyarrow import flight |
| from pyarrow.flight import ( |
| FlightClient, FlightServerBase, |
| ServerAuthHandler, ClientAuthHandler, |
| ServerMiddleware, ServerMiddlewareFactory, |
| ClientMiddleware, ClientMiddlewareFactory, |
| ) |
| except ImportError: |
| flight = None |
| FlightClient, FlightServerBase = object, object |
| ServerAuthHandler, ClientAuthHandler = object, object |
| ServerMiddleware, ServerMiddlewareFactory = object, object |
| ClientMiddleware, ClientMiddlewareFactory = object, object |
| |
| # Marks all of the tests in this module |
| # Ignore these with pytest ... -m 'not flight' |
| pytestmark = pytest.mark.flight |
| |
| |
def test_import():
    """Import pyarrow.flight unguarded so an ImportError fails this test."""
    # The guarded import at module level replaces the names with
    # placeholders on ImportError; this import lets the error propagate.
    import pyarrow.flight  # noqa
| |
| |
def resource_root():
    """Return the directory holding the Flight test resources.

    The directory is ``$ARROW_TEST_DATA/flight``.

    Raises
    ------
    RuntimeError
        If the ``ARROW_TEST_DATA`` environment variable is unset or empty.
    """
    data_dir = os.environ.get("ARROW_TEST_DATA")
    if data_dir:
        return pathlib.Path(data_dir) / "flight"
    raise RuntimeError("Test resources not found; set "
                       "ARROW_TEST_DATA to <repo root>/testing/data")
| |
| |
def read_flight_resource(path):
    """Return the raw bytes of a Flight test resource file.

    Parameters
    ----------
    path : str
        Location of the resource relative to the directory returned by
        ``resource_root()``.

    Returns
    -------
    bytes or None
        The file contents, or None if ``resource_root()`` returns a
        falsy value.

    Raises
    ------
    RuntimeError
        If the resource file does not exist. The original
        FileNotFoundError is attached as the cause.
    """
    root = resource_root()
    if not root:
        return None
    # Build the full path once so the error message and the open() call
    # cannot drift apart.
    resource = root / path
    try:
        with resource.open("rb") as f:
            return f.read()
    except FileNotFoundError as exc:
        raise RuntimeError(
            "Test resource {} not found; did you initialize the "
            "test resource submodule?\n{}".format(resource,
                                                  traceback.format_exc())
        ) from exc
| |
| |
def example_tls_certs():
    """Load the test TLS material.

    Returns a dict with the contents of ``root-ca.pem`` under
    ``"root_cert"`` and a list of two ``flight.CertKeyPair`` objects under
    ``"certificates"``. All values are file contents obtained from
    ``read_flight_resource``, not paths.
    """
    root_cert = read_flight_resource("root-ca.pem")
    pairs = []
    for cert_name, key_name in (("cert0.pem", "cert0.key"),
                                ("cert1.pem", "cert1.key")):
        cert = read_flight_resource(cert_name)
        key = read_flight_resource(key_name)
        pairs.append(flight.CertKeyPair(cert=cert, key=key))
    return {"root_cert": root_cert, "certificates": pairs}
| |
| |
def simple_ints_table():
    """Build a one-column table named 'some_ints' holding five integers."""
    column = pa.array([-10, -5, 0, 5, 10])
    return pa.Table.from_arrays([column], names=['some_ints'])
| |
| |
def simple_dicts_table():
    """Build a one-column table of dictionary-encoded strings.

    The 'some_dicts' column is a chunked array made of two dictionary
    arrays that share the same dictionary values.
    """
    dictionary = pa.array(["foo", "baz", "quux"], type=pa.utf8())
    first_chunk = pa.DictionaryArray.from_arrays([1, 0, None], dictionary)
    second_chunk = pa.DictionaryArray.from_arrays([2, 1], dictionary)
    column = pa.chunked_array([first_chunk, second_chunk])
    return pa.Table.from_arrays([column], names=['some_dicts'])
| |
| |
class ConstantFlightServer(FlightServerBase):
    """A Flight server whose do_get serves freshly built constant tables.

    See ARROW-4796: this server implementation will segfault if Flight
    does not properly hold a reference to the Table object.
    """

    CRITERIA = b"the expected criteria"

    def __init__(self, location=None, options=None, **kwargs):
        """Store the write options and register the table factories."""
        super().__init__(location, **kwargs)
        self.options = options
        # Maps ticket bytes to a zero-argument callable producing a Table.
        self.table_factories = {
            b'ints': simple_ints_table,
            b'dicts': simple_dicts_table,
        }

    def list_flights(self, context, criteria):
        """Yield one empty-schema FlightInfo iff criteria equals CRITERIA."""
        if criteria != self.CRITERIA:
            return
        descriptor = flight.FlightDescriptor.for_path('/foo')
        yield flight.FlightInfo(pa.schema([]), descriptor, [], -1, -1)

    def do_get(self, context, ticket):
        """Return a stream over a newly built table selected by the ticket."""
        # Build the table on every call so that Flight holds the only
        # reference to it.
        make_table = self.table_factories[ticket.ticket]
        return flight.RecordBatchStream(make_table(), options=self.options)
| |
| |
class MetadataFlightServer(FlightServerBase):
    """A Flight server that tags batches with a running index as metadata."""

    def __init__(self, options=None, **kwargs):
        super().__init__(**kwargs)
        self.options = options

    def do_get(self, context, ticket):
        """Stream a fixed five-row table, each batch paired with its index."""
        values = pa.array([-10, -5, 0, 5, 10])
        table = pa.Table.from_arrays([values], names=['a'])
        numbered = self.number_batches(table)
        return flight.GeneratorStream(table.schema, numbered,
                                      options=self.options)

    def do_put(self, context, descriptor, reader, writer):
        """Check each uploaded chunk and write its index back as metadata.

        Every chunk must hold the next expected single value in column 'a'
        and carry metadata equal to the little-endian packed chunk index.
        """
        expected_data = [-10, -5, 0, 5, 10]
        counter = 0
        while True:
            try:
                batch, buf = reader.read_chunk()
            except StopIteration:
                # The reader is exhausted; nothing more to verify.
                return
            expected_batch = pa.RecordBatch.from_arrays(
                [pa.array([expected_data[counter]])],
                ['a']
            )
            assert batch.equals(expected_batch)
            assert buf is not None
            client_counter, = struct.unpack('<i', buf.to_pybytes())
            assert counter == client_counter
            writer.write(struct.pack('<i', counter))
            counter += 1

    @staticmethod
    def number_batches(table):
        """Yield (batch, metadata) pairs; metadata is the packed index."""
        for position, batch in enumerate(table.to_batches()):
            yield batch, struct.pack('<i', position)
| |
| |
class EchoFlightServer(FlightServerBase):
    """A Flight server whose do_get replays the most recent do_put upload."""

    def __init__(self, location=None, expected_schema=None, **kwargs):
        super().__init__(location, **kwargs)
        self.expected_schema = expected_schema
        # Result of the latest do_put; None until something is uploaded.
        self.last_message = None

    def do_get(self, context, ticket):
        """Return a stream over whatever do_put stored last."""
        stored = self.last_message
        return flight.RecordBatchStream(stored)

    def do_put(self, context, descriptor, reader, writer):
        """Store the whole upload, checking its schema if one is expected."""
        expected = self.expected_schema
        if expected:
            assert expected == reader.schema
        self.last_message = reader.read_all()
| |
| |
class EchoStreamFlightServer(EchoFlightServer):
    """An echo server that replays the upload as individual record batches."""

    def do_get(self, context, ticket):
        """Return the stored data in batches of at most 1024 rows."""
        stored = self.last_message
        schema = stored.schema
        batches = stored.to_batches(max_chunksize=1024)
        return flight.GeneratorStream(schema, batches)

    def list_actions(self, context):
        """Advertise no actions."""
        return []

    def do_action(self, context, action):
        """Answer "who-am-i" with the peer identity and the encoded peer."""
        if action.type != "who-am-i":
            raise NotImplementedError
        return [context.peer_identity(), context.peer().encode("utf-8")]
| |
| |
class GetInfoFlightServer(FlightServerBase):
    """A Flight server that answers GetFlightInfo and GetSchema."""

    def get_flight_info(self, context, descriptor):
        """Return a fixed FlightInfo that echoes the given descriptor.

        One endpoint is built from a URI string and the other from a
        Location object.
        """
        schema = pa.schema([('a', pa.int32())])
        string_endpoint = flight.FlightEndpoint(b'', ['grpc://test'])
        object_endpoint = flight.FlightEndpoint(
            b'',
            [flight.Location.for_grpc_tcp('localhost', 5005)],
        )
        endpoints = [string_endpoint, object_endpoint]
        return flight.FlightInfo(schema, descriptor, endpoints, -1, -1)

    def get_schema(self, context, descriptor):
        """Return the schema of the FlightInfo wrapped in a SchemaResult."""
        schema = self.get_flight_info(context, descriptor).schema
        return flight.SchemaResult(schema)
| |
| |
class ListActionsFlightServer(FlightServerBase):
    """A Flight server that tests ListActions."""

    @classmethod
    def expected_actions(cls):
        """Return the actions served: two plain tuples and one ActionType."""
        described = ("action-1", "description")
        undescribed = ("action-2", "")
        boxed = flight.ActionType("action-3", "more detail")
        return [described, undescribed, boxed]

    def list_actions(self, context):
        """Yield every entry of expected_actions() in order."""
        for action in self.expected_actions():
            yield action
| |
| |
class ListActionsErrorFlightServer(FlightServerBase):
    """A Flight server whose ListActions yields a tuple, then a bare string."""

    def list_actions(self, context):
        """Yield a (type, description) tuple followed by the string "foo"."""
        for item in (("action-1", ""), "foo"):
            yield item
| |
| |
class CheckTicketFlightServer(FlightServerBase):
    """A Flight server that compares the given ticket to an expected value."""

    def __init__(self, expected_ticket, location=None, **kwargs):
        """Remember the ticket bytes that do_get must receive."""
        super().__init__(location, **kwargs)
        self.expected_ticket = expected_ticket
        # Set by do_put; initialised here so the attribute always exists.
        self.last_message = None

    def do_get(self, context, ticket):
        """Assert the ticket matches, then stream a fixed int32 table."""
        assert self.expected_ticket == ticket.ticket
        data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
        table = pa.Table.from_arrays(data1, names=['a'])
        return flight.RecordBatchStream(table)

    def do_put(self, context, descriptor, reader, writer=None):
        """Store everything the client uploads in ``last_message``.

        ``writer`` is accepted (and ignored) so the signature matches the
        other do_put implementations in this module; it defaults to None
        to stay compatible with three-argument callers.
        """
        self.last_message = reader.read_all()
| |
| |
class InvalidStreamFlightServer(FlightServerBase):
    """A Flight server that tries to return messages with differing schemas."""

    # Schema announced for the stream; only the first table matches it.
    schema = pa.schema([('a', pa.int32())])

    def do_get(self, context, ticket):
        """Return a GeneratorStream over an int32 table and a float64 table."""
        data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
        data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0], type=pa.float64())]
        # data1/data2 are one-element lists, so compare the types of the
        # arrays they hold; a list itself has no ``type`` attribute.
        assert data1[0].type != data2[0].type
        table1 = pa.Table.from_arrays(data1, names=['a'])
        table2 = pa.Table.from_arrays(data2, names=['a'])
        assert table1.schema == self.schema

        return flight.GeneratorStream(self.schema, [table1, table2])
| |
| |
class SlowFlightServer(FlightServerBase):
    """A Flight server that delays its responses to test timeouts."""

    def do_get(self, context, ticket):
        """Return a stream backed by the slow generator."""
        schema = pa.schema([('a', pa.int32())])
        return flight.GeneratorStream(schema, self.slow_stream())

    def do_action(self, context, action):
        """Sleep for half a second, then return no results."""
        time.sleep(0.5)
        return []

    @staticmethod
    def slow_stream():
        """Yield one table, sleep for ten seconds, then yield another."""
        columns = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())]
        names = ['a']
        yield pa.Table.from_arrays(columns, names=names)
        # The client is expected to cancel before this sleep finishes, so
        # the second table should never be sent.
        time.sleep(10)
        yield pa.Table.from_arrays(columns, names=names)
| |
| |
class ErrorFlightServer(FlightServerBase):
    """A Flight server that uses all the Flight-specific errors."""

    def do_action(self, context, action):
        """Raise the Flight error selected by the action type.

        Unknown action types raise NotImplementedError.
        """
        kind = action.type
        if kind == "protobuf":
            # This case also passes a bytes payload as a second argument.
            err_msg = b'this is an error message'
            raise flight.FlightUnauthorizedError("foo", err_msg)

        simple_errors = {
            "internal": flight.FlightInternalError,
            "timedout": flight.FlightTimedOutError,
            "cancel": flight.FlightCancelledError,
            "unauthenticated": flight.FlightUnauthenticatedError,
            "unauthorized": flight.FlightUnauthorizedError,
        }
        error_cls = simple_errors.get(kind)
        if error_cls is None:
            raise NotImplementedError
        raise error_cls("foo")

    def list_flights(self, context, criteria):
        """Yield one FlightInfo and then fail with FlightInternalError."""
        descriptor = flight.FlightDescriptor.for_path('/foo')
        yield flight.FlightInfo(pa.schema([]), descriptor, [], -1, -1)
        raise flight.FlightInternalError("foo")
| |
| |
class ExchangeFlightServer(FlightServerBase):
    """A server for testing DoExchange."""

    def __init__(self, options=None, **kwargs):
        super().__init__(**kwargs)
        self.options = options

    def do_exchange(self, context, descriptor, reader, writer):
        """Dispatch on the command bytes of a CMD descriptor.

        Raises pa.ArrowInvalid for non-command descriptors and for
        commands other than echo, get, put and transform.
        """
        if descriptor.descriptor_type != flight.DescriptorType.CMD:
            raise pa.ArrowInvalid("Must provide a command descriptor")

        handlers = {
            b"echo": self.exchange_echo,
            b"get": self.exchange_do_get,
            b"put": self.exchange_do_put,
            b"transform": self.exchange_transform,
        }
        handler = handlers.get(descriptor.command)
        if handler is None:
            raise pa.ArrowInvalid(
                "Unknown command: {}".format(descriptor.command))
        return handler(context, reader, writer)

    def exchange_do_get(self, context, reader, writer):
        """Emulate DoGet with DoExchange."""
        values = pa.array(range(0, 10 * 1024))
        table = pa.Table.from_arrays([values], names=["a"])
        writer.begin(table.schema)
        writer.write_table(table)

    def exchange_do_put(self, context, reader, writer):
        """Emulate DoPut with DoExchange.

        Counts the uploaded chunks and writes the count back as metadata.
        """
        received = 0
        for chunk in reader:
            if not chunk.data:
                raise pa.ArrowInvalid("All chunks must have data.")
            received += 1
        writer.write_metadata(str(received).encode("utf-8"))

    def exchange_echo(self, context, reader, writer):
        """Run a simple echo server."""
        begun = False
        for chunk in reader:
            # The schema is only known once the first data chunk arrives.
            if not begun and chunk.data:
                writer.begin(chunk.data.schema, options=self.options)
                begun = True

            has_both = chunk.app_metadata and chunk.data
            if has_both:
                writer.write_with_metadata(chunk.data, chunk.app_metadata)
            elif chunk.app_metadata:
                writer.write_metadata(chunk.app_metadata)
            elif chunk.data:
                writer.write_batch(chunk.data)
            else:
                assert False, "Should not happen"

    def exchange_transform(self, context, reader, writer):
        """Sum rows in an uploaded table.

        Every field must be of integer type; the reply is a single-column
        table named "sum".
        """
        for field in reader.schema:
            if not pa.types.is_integer(field.type):
                raise pa.ArrowInvalid("Invalid field: " + repr(field))

        uploaded = reader.read_all()
        totals = [0] * uploaded.num_rows
        for column in uploaded:
            for index, scalar in enumerate(column):
                totals[index] = totals[index] + scalar.as_py()

        result = pa.Table.from_arrays([pa.array(totals)], names=["sum"])
        writer.begin(result.schema)
        writer.write_table(result)
| |
| |
class HttpBasicServerAuthHandler(ServerAuthHandler):
    """An example implementation of HTTP basic authentication."""

    def __init__(self, creds):
        """Keep ``creds``, which is looked up by username for the password."""
        super().__init__()
        self.creds = creds

    def authenticate(self, outgoing, incoming):
        """Check the received BasicAuth and reply with the username."""
        credentials = flight.BasicAuth.deserialize(incoming.read())
        user = credentials.username
        if user not in self.creds:
            raise flight.FlightUnauthenticatedError("unknown user")
        if self.creds[user] != credentials.password:
            raise flight.FlightUnauthenticatedError("wrong password")
        outgoing.write(tobytes(user))

    def is_valid(self, token):
        """Return the token if it is non-empty and a known key of creds."""
        if token:
            if token in self.creds:
                return token
            raise flight.FlightUnauthenticatedError("unknown user")
        raise flight.FlightUnauthenticatedError("token not provided")
| |
| |
class HttpBasicClientAuthHandler(ClientAuthHandler):
    """An example implementation of HTTP basic authentication."""

    def __init__(self, username, password):
        super().__init__()
        # Filled in by authenticate() with whatever the server sends back.
        self.token = None
        self.basic_auth = flight.BasicAuth(username, password)

    def authenticate(self, outgoing, incoming):
        """Send the serialized credentials and keep the server's reply."""
        outgoing.write(self.basic_auth.serialize())
        self.token = incoming.read()

    def get_token(self):
        """Return the token stored by authenticate()."""
        return self.token
| |
| |
class TokenServerAuthHandler(ServerAuthHandler):
    """An example implementation of authentication via handshake."""

    def __init__(self, creds):
        super().__init__()
        self.creds = creds

    def authenticate(self, outgoing, incoming):
        """Read a username and a password; reply with an encoded token.

        The token is the base64 encoding of ``b'secret:'`` followed by the
        username.
        """
        username = incoming.read()
        password = incoming.read()
        if username not in self.creds or self.creds[username] != password:
            raise flight.FlightUnauthenticatedError(
                "invalid username/password")
        outgoing.write(base64.b64encode(b'secret:' + username))

    def is_valid(self, token):
        """Decode the token and return what follows the ``secret:`` prefix."""
        prefix = b'secret:'
        decoded = base64.b64decode(token)
        if decoded.startswith(prefix):
            return decoded[len(prefix):]
        raise flight.FlightUnauthenticatedError("invalid token")
| |
| |
class TokenClientAuthHandler(ClientAuthHandler):
    """An example implementation of authentication via handshake."""

    def __init__(self, username, password):
        super().__init__()
        self.username = username
        self.password = password
        # Empty until authenticate() stores the server's reply.
        self.token = b''

    def authenticate(self, outgoing, incoming):
        """Send the username, then the password, and keep the reply."""
        for credential in (self.username, self.password):
            outgoing.write(credential)
        self.token = incoming.read()

    def get_token(self):
        """Return the token stored by authenticate()."""
        return self.token
| |
| |
class NoopAuthHandler(ServerAuthHandler):
    """A no-op auth handler."""

    def authenticate(self, outgoing, incoming):
        """Do nothing: neither stream is read from or written to."""
        return None

    def is_valid(self, token):
        """Return an empty string as the identity.

        The original note on this method states that returning None
        causes a TypeError, hence the empty string.
        """
        empty_identity = ""
        return empty_identity
| |
| |
def case_insensitive_header_lookup(headers, lookup_key):
    """Return the value stored under ``lookup_key``, ignoring case.

    The keys of ``headers`` are scanned in iteration order and the value
    of the first key whose lowercase form equals the lowercase form of
    ``lookup_key`` is returned. None is returned when no key matches.
    """
    matches = (
        headers.get(name)
        for name in headers
        if name.lower() == lookup_key.lower()
    )
    return next(matches, None)
| |
| |
class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory):
    """ClientMiddlewareFactory that creates ClientHeaderAuthMiddleware."""

    def __init__(self):
        # Replaced through set_call_credential(); starts out empty.
        self.call_credential = []

    def start_call(self, info):
        """Create a middleware that reports back to this factory."""
        middleware = ClientHeaderAuthMiddleware(self)
        return middleware

    def set_call_credential(self, call_credential):
        """Replace the stored call credential."""
        self.call_credential = call_credential
| |
| |
class ClientHeaderAuthMiddleware(ClientMiddleware):
    """
    ClientMiddleware that extracts the authorization header
    from the server.

    This is an example of a ClientMiddleware that can extract
    the bearer token authorization header from a HTTP header
    authentication enabled server.

    Parameters
    ----------
    factory : ClientHeaderAuthMiddlewareFactory
        This factory is used to set call credentials if an
        authorization header is found in the headers from the server.
    """

    def __init__(self, factory):
        self.factory = factory

    def received_headers(self, headers):
        """Forward the first Authorization header value to the factory.

        The lookup is case insensitive. When the headers carry no
        Authorization entry (or an empty one), the factory's call
        credential is left untouched instead of failing on a None
        lookup result.
        """
        auth_header = case_insensitive_header_lookup(headers, 'Authorization')
        if not auth_header:
            return
        self.factory.set_call_credential([
            b'authorization',
            auth_header[0].encode("utf-8")])
| |
| |
class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Validates incoming username and password."""

    def start_call(self, info, headers):
        """Authenticate a call from its Authorization header.

        ``Basic`` credentials must decode to ``test:password`` and
        ``Bearer`` credentials must equal ``token1234``.

        Returns
        -------
        HeaderAuthServerMiddleware
            Middleware carrying the bearer token for the call.

        Raises
        ------
        flight.FlightUnauthenticatedError
            If the header is missing, malformed, uses another scheme, or
            carries the wrong credentials.
        """
        error_message = 'Invalid credentials'
        auth_header = case_insensitive_header_lookup(
            headers,
            'Authorization'
        )
        # A missing header must be reported as an authentication failure,
        # not as a TypeError from subscripting None.
        if not auth_header:
            raise flight.FlightUnauthenticatedError(error_message)

        values = auth_header[0].split(' ')
        if len(values) < 2:
            raise flight.FlightUnauthenticatedError(error_message)
        scheme, credentials = values[0], values[1]

        if scheme == 'Basic':
            try:
                decoded = base64.b64decode(credentials).decode("utf-8")
            except ValueError:
                # Covers both invalid base64 and invalid UTF-8.
                raise flight.FlightUnauthenticatedError(
                    error_message) from None
            # Split on the first colon only, so that trailing
            # ":anything" becomes part of the password and is rejected.
            username, _, password = decoded.partition(':')
            if username != 'test' or password != 'password':
                raise flight.FlightUnauthenticatedError(error_message)
            token = 'token1234'
        elif scheme == 'Bearer':
            token = credentials
            if token != 'token1234':
                raise flight.FlightUnauthenticatedError(error_message)
        else:
            raise flight.FlightUnauthenticatedError(error_message)

        return HeaderAuthServerMiddleware(token)
| |
| |
class HeaderAuthServerMiddleware(ServerMiddleware):
    """A ServerMiddleware that sends a bearer token back to the client."""

    def __init__(self, token):
        self.token = token

    def sending_headers(self):
        """Return an authorization header holding ``Bearer <token>``."""
        bearer = 'Bearer ' + self.token
        return {'authorization': bearer}
| |
| |
class HeaderAuthFlightServer(FlightServerBase):
    """A Flight server that tests with basic token authentication."""

    def do_action(self, context, action):
        """Return the token carried by the "auth" middleware.

        The token is the second space-separated part of the middleware's
        Authorization header, encoded as UTF-8.
        """
        middleware = context.get_middleware("auth")
        if not middleware:
            raise flight.FlightUnauthenticatedError(
                'No token auth middleware found.')
        auth_header = case_insensitive_header_lookup(
            middleware.sending_headers(), 'Authorization')
        token = auth_header.split(' ')[1]
        return [token.encode("utf-8")]
| |
| |
class ArbitraryHeadersServerMiddlewareFactory(ServerMiddlewareFactory):
    """A ServerMiddlewareFactory that transports arbitrary headers."""

    def start_call(self, info, headers):
        """Wrap the incoming headers in an ArbitraryHeadersServerMiddleware."""
        middleware = ArbitraryHeadersServerMiddleware(headers)
        return middleware
| |
| |
class ArbitraryHeadersServerMiddleware(ServerMiddleware):
    """A ServerMiddleware that transports arbitrary headers."""

    def __init__(self, incoming):
        # Headers received for the call, handed back by sending_headers().
        self.incoming = incoming

    def sending_headers(self):
        """Return the headers this middleware was created with."""
        received = self.incoming
        return received
| |
| |
class ArbitraryHeadersFlightServer(FlightServerBase):
    """A Flight server that tests multiple arbitrary headers."""

    def do_action(self, context, action):
        """Return the first values of test-header-1 and test-header-2.

        The values are read from the "arbitrary-headers" middleware and
        encoded as UTF-8.
        """
        middleware = context.get_middleware("arbitrary-headers")
        if not middleware:
            raise flight.FlightServerError("No headers middleware found")
        headers = middleware.sending_headers()
        return [
            case_insensitive_header_lookup(headers, name)[0].encode("utf-8")
            for name in ('test-header-1', 'test-header-2')
        ]
| |
| |
class HeaderServerMiddleware(ServerMiddleware):
    """Carry a per-call value so the RPC method body can read it."""

    def __init__(self, special_value):
        # Exposed as a plain attribute on the middleware instance.
        self.special_value = special_value
| |
| |
class HeaderServerMiddlewareFactory(ServerMiddlewareFactory):
    """Expose a per-call hard-coded value to the RPC method body."""

    def start_call(self, info, headers):
        """Return a HeaderServerMiddleware holding "right value"."""
        special_value = "right value"
        return HeaderServerMiddleware(special_value)
| |
| |
class HeaderFlightServer(FlightServerBase):
    """Echo back the per-call hard-coded value."""

    def do_action(self, context, action):
        """Return the "test" middleware's value, or b"" without middleware."""
        middleware = context.get_middleware("test")
        payload = middleware.special_value.encode() if middleware else b""
        return [payload]
| |
| |
class MultiHeaderFlightServer(FlightServerBase):
    """Test sending/receiving multiple (binary-valued) headers."""

    def do_action(self, context, action):
        """Return the repr of the "test" middleware's client headers."""
        client_headers = context.get_middleware("test").client_headers
        return [repr(client_headers).encode("utf-8")]
| |
| |
class SelectiveAuthServerMiddlewareFactory(ServerMiddlewareFactory):
    """Deny access to certain methods based on a header."""

    def start_call(self, info, headers):
        """Require x-auth-token to be "password" except for LIST_ACTIONS.

        Returns None for LIST_ACTIONS, otherwise a HeaderServerMiddleware
        holding the accepted token.
        """
        if info.method == flight.FlightMethod.LIST_ACTIONS:
            # This method is open to everyone.
            return None

        tokens = headers.get("x-auth-token")
        if not tokens:
            raise flight.FlightUnauthenticatedError("No token")

        first_token = tokens[0]
        if first_token == "password":
            return HeaderServerMiddleware(first_token)
        raise flight.FlightUnauthenticatedError("Invalid token")
| |
| |
class SelectiveAuthClientMiddlewareFactory(ClientMiddlewareFactory):
    """Create a SelectiveAuthClientMiddleware for every call."""

    def start_call(self, info):
        middleware = SelectiveAuthClientMiddleware()
        return middleware
| |
| |
class SelectiveAuthClientMiddleware(ClientMiddleware):
    """Send the x-auth-token header with the value "password"."""

    def sending_headers(self):
        headers = {}
        headers["x-auth-token"] = "password"
        return headers
| |
| |
class RecordingServerMiddlewareFactory(ServerMiddlewareFactory):
    """Record what methods were called."""

    def __init__(self):
        super().__init__()
        # Methods seen by start_call, in call order.
        self.methods = []

    def start_call(self, info, headers):
        """Append the call's method to ``methods``; create no middleware."""
        seen = self.methods
        seen.append(info.method)
        return None
| |
| |
class RecordingClientMiddlewareFactory(ClientMiddlewareFactory):
    """Record what methods were called."""

    def __init__(self):
        super().__init__()
        # Methods seen by start_call, in call order.
        self.methods = []

    def start_call(self, info):
        """Append the call's method to ``methods``; create no middleware."""
        seen = self.methods
        seen.append(info.method)
        return None
| |
| |
class MultiHeaderClientMiddlewareFactory(ClientMiddlewareFactory):
    """Test sending/receiving multiple (binary-valued) headers."""

    def __init__(self):
        # Overwritten by the middleware instances created in start_call
        # with the headers they receive.
        self.last_headers = {}

    def start_call(self, info):
        """Create a middleware that reports received headers to us."""
        middleware = MultiHeaderClientMiddleware(self)
        return middleware
| |
| |
class MultiHeaderClientMiddleware(ClientMiddleware):
    """Test sending/receiving multiple (binary-valued) headers."""

    # Headers sent with every call: one text header and one binary
    # header, each with two values.
    EXPECTED = {
        "x-text": ["foo", "bar"],
        "x-binary-bin": [b"\x00", b"\x01"],
    }

    def __init__(self, factory):
        self.factory = factory

    def sending_headers(self):
        """Return the EXPECTED header mapping."""
        outgoing = self.EXPECTED
        return outgoing

    def received_headers(self, headers):
        """Store the received headers on the factory for the test to read."""
        factory = self.factory
        factory.last_headers = headers
| |
| |
class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory):
    """Test sending/receiving multiple (binary-valued) headers."""

    def start_call(self, info, headers):
        """Wrap the client's headers in a MultiHeaderServerMiddleware."""
        middleware = MultiHeaderServerMiddleware(headers)
        return middleware
| |
| |
class MultiHeaderServerMiddleware(ServerMiddleware):
    """Test sending/receiving multiple (binary-valued) headers."""

    def __init__(self, client_headers):
        # Headers received from the client for this call.
        self.client_headers = client_headers

    def sending_headers(self):
        """Reply with the same headers the client middleware sends."""
        reply = MultiHeaderClientMiddleware.EXPECTED
        return reply
| |
| |
def test_flight_server_location_argument():
    """A server can be built from None, a URI string or a (host, port)."""
    candidates = (
        None,
        'grpc://localhost:0',
        ('localhost', find_free_port()),
    )
    for candidate in candidates:
        with FlightServerBase(candidate) as server:
            assert isinstance(server, FlightServerBase)
| |
| |
def test_server_exit_reraises_exception():
    """An exception raised inside the server's with-block propagates."""
    with pytest.raises(ValueError), FlightServerBase():
        raise ValueError()
| |
| |
@pytest.mark.slow
def test_client_wait_for_available():
    """wait_for_available blocks until a late-starting server is up."""
    location = ('localhost', find_free_port())
    server = None

    def serve():
        # Rebind the enclosing function's variable (not a module global)
        # so the test body can shut the server down afterwards.
        nonlocal server
        time.sleep(0.5)
        server = FlightServerBase(location)
        server.serve()

    client = FlightClient(location)
    thread = threading.Thread(target=serve, daemon=True)
    thread.start()

    try:
        started = time.time()
        client.wait_for_available(timeout=5)
        elapsed = time.time() - started
        # The server only starts after the 0.5 s sleep in serve().
        assert elapsed >= 0.5
    finally:
        # Stop the background server so it does not outlive the test.
        if server is not None:
            server.shutdown()
            thread.join(timeout=5)
| |
| |
def test_flight_list_flights():
    """list_flights returns nothing by default and one flight on a match."""
    with ConstantFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        unfiltered = list(client.list_flights())
        assert unfiltered == []
        matching = list(client.list_flights(ConstantFlightServer.CRITERIA))
        assert len(matching) == 1
| |
| |
def test_flight_do_get_ints():
    """do_get round-trips the ints table with default and V4 options."""
    expected = simple_ints_table()
    ticket = flight.Ticket(b'ints')

    with ConstantFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        received = client.do_get(ticket).read_all()
        assert received.equals(expected)

    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with ConstantFlightServer(options=v4_options) as server:
        client = flight.connect(('localhost', server.port))
        received = client.do_get(ticket).read_all()
        assert received.equals(expected)

        # Same request, read through the RecordBatchReader interface.
        received = client.do_get(ticket).to_reader().read_all()
        assert received.equals(expected)

    # Options of the wrong type are reported as a server error.
    with pytest.raises(flight.FlightServerError,
                       match="expected IpcWriteOptions, got <class 'int'>"):
        with ConstantFlightServer(options=42) as server:
            client = flight.connect(('localhost', server.port))
            client.do_get(ticket).read_all()
| |
| |
@pytest.mark.pandas
def test_do_get_ints_pandas():
    """do_get results can be read straight into pandas."""
    expected_values = simple_ints_table().column(0).to_pylist()

    with ConstantFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        frame = client.do_get(flight.Ticket(b'ints')).read_pandas()
        assert list(frame['some_ints']) == expected_values
| |
| |
def test_flight_do_get_dicts():
    """do_get round-trips the dictionary-encoded table."""
    expected = simple_dicts_table()

    with ConstantFlightServer() as server:
        client = flight.connect(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b'dicts'))
        received = reader.read_all()
        assert received.equals(expected)
| |
| |
def test_flight_do_get_ticket():
    """Make sure Tickets get passed to the server."""
    column = pa.array([-10, -5, 0, 5, 10], type=pa.int32())
    expected = pa.Table.from_arrays([column], names=['a'])
    with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server:
        client = flight.connect(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b'the-ticket'))
        received = reader.read_all()
        assert received.equals(expected)
| |
| |
def test_flight_get_info():
    """Make sure FlightEndpoint accepts string and object URIs."""
    with GetInfoFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_command(b'')
        info = client.get_flight_info(descriptor)
        assert info.total_records == -1
        assert info.total_bytes == -1
        assert info.schema == pa.schema([('a', pa.int32())])

        endpoints = info.endpoints
        assert len(endpoints) == 2
        string_endpoint, object_endpoint = endpoints[0], endpoints[1]
        assert len(string_endpoint.locations) == 1
        assert string_endpoint.locations[0] == flight.Location('grpc://test')
        expected_location = flight.Location.for_grpc_tcp('localhost', 5005)
        assert object_endpoint.locations[0] == expected_location
| |
| |
def test_flight_get_schema():
    """Make sure GetSchema returns correct schema."""
    expected_schema = pa.schema([('a', pa.int32())])
    with GetInfoFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_command(b'')
        result = client.get_schema(descriptor)
        assert result.schema == expected_schema
| |
| |
def test_list_actions():
    """Make sure the return type of ListActions is validated."""
    # ARROW-6392
    expected_error = ("Results of list_actions must be "
                      "ActionType or tuple")
    with ListActionsErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        with pytest.raises(flight.FlightServerError, match=expected_error):
            list(client.list_actions())

    with ListActionsFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        received = list(client.list_actions())
        assert received == ListActionsFlightServer.expected_actions()
| |
| |
class ConvenienceServer(FlightServerBase):
    """
    Server for testing various implementation conveniences (auto-boxing, etc.)
    """

    @property
    def simple_action_results(self):
        """Results returned for the 'simple-action' action."""
        return [b'foo', b'bar', b'baz']

    def do_action(self, context, action):
        """Handle the test actions; other action types return None.

        'simple-action' returns fixed bytes, 'echo' returns the action
        body, 'bad-action' returns a str instead of bytes, and
        'arrow-exception' raises pa.ArrowMemoryError.
        """
        kind = action.type
        if kind == 'arrow-exception':
            raise pa.ArrowMemoryError()
        if kind == 'simple-action':
            return self.simple_action_results
        if kind == 'echo':
            return [action.body]
        if kind == 'bad-action':
            return ['foo']
        return None
| |
| |
def test_do_action_result_convenience():
    """do_action accepts a bare action type or a (type, body) tuple."""
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))

        # Action given as a type string only, with no body.
        bodies = [result.body for result in client.do_action('simple-action')]
        assert bodies == server.simple_action_results

        # Action given as a (type, body) tuple.
        payload = b'the-body'
        echoed = [result.body
                  for result in client.do_action(('echo', payload))]
        assert echoed == [payload]
| |
| |
def test_nicer_server_exceptions():
    """Server-side failures surface as FlightServerError with a message."""
    # While Flight/C++ sends across the original status code, it
    # doesn't get mapped to the equivalent code here, since we
    # want to be able to distinguish between client- and server-
    # side errors. Hence both cases expect FlightServerError.
    expectations = (
        ('bad-action', "a bytes-like object is required"),
        ('arrow-exception', "ArrowMemoryError"),
    )
    with ConvenienceServer() as server:
        client = FlightClient(('localhost', server.port))
        for action_type, message in expectations:
            with pytest.raises(flight.FlightServerError, match=message):
                list(client.do_action(action_type))
| |
| |
def test_get_port():
    """Make sure the port attribute is positive for a port-0 location."""
    server = GetInfoFlightServer("grpc://localhost:0")
    try:
        bound_port = server.port
        assert bound_port > 0
    finally:
        server.shutdown()
| |
| |
@pytest.mark.skipif(os.name == 'nt',
                    reason="Unix sockets can't be tested on Windows")
def test_flight_domain_socket():
    """Try simple do_get calls over a Unix domain socket."""
    cases = (
        (b'ints', simple_ints_table),
        (b'dicts', simple_dicts_table),
    )
    with tempfile.NamedTemporaryFile() as sock:
        # Only the generated name is needed for the socket location.
        sock.close()
        location = flight.Location.for_grpc_unix(sock.name)
        with ConstantFlightServer(location=location):
            client = FlightClient(location)

            for ticket_bytes, make_table in cases:
                reader = client.do_get(flight.Ticket(ticket_bytes))
                table = make_table()
                assert reader.schema.equals(table.schema)
                data = reader.read_all()
                assert data.equals(table)
| |
| |
@pytest.mark.slow
def test_flight_large_message():
    """Try sending/receiving a large message via Flight.

    See ARROW-4421: by default, gRPC won't allow us to send messages >
    4MiB in size.
    """
    row_count = 10 * 1024 * 1024
    data = pa.Table.from_arrays([pa.array(range(row_count))], names=['a'])

    with EchoFlightServer(expected_schema=data.schema) as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = client.do_put(descriptor, data.schema)
        # Pass the row count as the chunk size so that everything goes
        # out as a single giant chunk.
        writer.write_table(data, row_count)
        writer.close()
        echoed = client.do_get(flight.Ticket(b'')).read_all()
        assert echoed.equals(data)
| |
| |
def test_flight_generator_stream():
    """Upload a table, then download it again from the echo stream server."""
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=['a'])

    with EchoStreamFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('test')
        writer, _ = client.do_put(descriptor, data.schema)
        writer.write_table(data)
        writer.close()

        echoed = client.do_get(flight.Ticket(b'')).read_all()
        assert echoed.equals(data)
| |
| |
def test_flight_invalid_generator_stream():
    """Reading a stream with mismatched schemas must raise."""
    with InvalidStreamFlightServer() as server:
        location = ('localhost', server.port)
        client = FlightClient(location)
        with pytest.raises(pa.ArrowException):
            reader = client.do_get(flight.Ticket(b''))
            reader.read_all()
| |
| |
def test_timeout_fires():
    """Make sure a short call timeout fires against a slow server."""
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        call_options = flight.FlightCallOptions(timeout=0.2)
        slow_action = flight.Action("", b"")
        # The text of gRPC errors varies between versions, so only the
        # exception type is checked, not the message.
        with pytest.raises(flight.FlightTimedOutError):
            list(client.do_action(slow_action, options=call_options))
| |
| |
def test_timeout_passes():
    """Make sure a generous timeout does not fire on a fast request."""
    with ConstantFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        call_options = flight.FlightCallOptions(timeout=5.0)
        reader = client.do_get(flight.Ticket(b'ints'), options=call_options)
        reader.read_all()
| |
| |
# Server-side auth handlers shared by the authentication tests below.
# Each is configured with its own single-entry credentials mapping.
basic_auth_handler = HttpBasicServerAuthHandler(
    creds={b"test": b"p4ssw0rd"},
)

token_auth_handler = TokenServerAuthHandler(
    creds={b"test": b"p4ssw0rd"},
)
| |
| |
@pytest.mark.slow
def test_http_basic_unauth():
    """Calling an action without authenticating first must be rejected."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        anonymous_client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        with pytest.raises(flight.FlightUnauthenticatedError,
                           match=".*unauthenticated.*"):
            list(anonymous_client.do_action(who_am_i))
| |
| |
@pytest.mark.skipif(os.name == 'nt',
                    reason="ARROW-10013: gRPC on Windows corrupts peer()")
def test_http_basic_auth():
    """Authenticate with the Python HTTP basic auth handlers."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        credentials = HttpBasicClientAuthHandler('test', 'p4ssw0rd')
        client.authenticate(credentials)

        responses = client.do_action(who_am_i)
        # First result: the authenticated identity.
        identity = next(responses)
        assert identity.body.to_pybytes() == b'test'
        # Second result: the peer address, which must not be empty.
        peer_address = next(responses)
        assert peer_address.body.to_pybytes() != b''
| |
| |
def test_http_basic_auth_invalid_password():
    """A wrong password must be rejected as unauthenticated."""
    with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        # Both the handshake and the follow-up call sit inside the raises
        # block, so the error may come from either of them.
        with pytest.raises(flight.FlightUnauthenticatedError,
                           match=".*wrong password.*"):
            bad_credentials = HttpBasicClientAuthHandler('test', 'wrong')
            client.authenticate(bad_credentials)
            next(client.do_action(who_am_i))
| |
| |
def test_token_auth():
    """Authenticate through the token handshake and check the identity."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        client = FlightClient(('localhost', server.port))
        who_am_i = flight.Action("who-am-i", b"")
        credentials = TokenClientAuthHandler('test', 'p4ssw0rd')
        client.authenticate(credentials)

        responses = client.do_action(who_am_i)
        identity = next(responses)
        assert identity.body.to_pybytes() == b'test'
| |
| |
def test_token_auth_invalid():
    """The token handshake must fail when given the wrong password."""
    with EchoStreamFlightServer(auth_handler=token_auth_handler) as server:
        location = ('localhost', server.port)
        client = FlightClient(location)
        with pytest.raises(flight.FlightUnauthenticatedError):
            bad_credentials = TokenClientAuthHandler('test', 'wrong')
            client.authenticate(bad_credentials)
| |
| |
# Module-level objects for the header-based authentication tests below:
# a server middleware factory and an auth handler instance.
header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory()
no_op_auth_handler = NoopAuthHandler()
| |
| |
def test_authenticate_basic_token():
    """authenticate_basic_token returns the bearer authorization header."""
    middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=middleware) as server:
        client = FlightClient(('localhost', server.port))
        header_pair = client.authenticate_basic_token(b'test', b'password')
        # The pair is (header name, header value).
        assert header_pair[0] == b'authorization'
        assert header_pair[1] == b'Bearer token1234'
| |
| |
def test_authenticate_basic_token_invalid_password():
    """authenticate_basic_token must fail when the password is wrong."""
    middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=middleware) as server:
        location = ('localhost', server.port)
        client = FlightClient(location)
        with pytest.raises(flight.FlightUnauthenticatedError):
            client.authenticate_basic_token(b'test', b'badpassword')
| |
| |
def test_authenticate_basic_token_and_action():
    """Use the header from authenticate_basic_token in a later do_action."""
    middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=middleware) as server:
        client = FlightClient(('localhost', server.port))
        header_pair = client.authenticate_basic_token(b'test', b'password')
        assert header_pair[0] == b'authorization'
        assert header_pair[1] == b'Bearer token1234'

        # Send the returned header along with the next call.
        call_options = flight.FlightCallOptions(headers=[header_pair])
        test_action = flight.Action('test-action', b'')
        responses = list(client.do_action(action=test_action,
                                          options=call_options))
        assert responses[0].body.to_pybytes() == b'token1234'
| |
| |
def test_authenticate_basic_token_with_client_middleware():
    """Capture the server's authorization header in client middleware.

    The call carries HTTP basic credentials; the client middleware
    factory is then expected to hold the bearer credential. The same
    call is made twice and checked identically both times.
    """
    server_middleware = {"auth": HeaderAuthServerMiddlewareFactory()}
    with HeaderAuthFlightServer(auth_handler=no_op_auth_handler,
                                middleware=server_middleware) as server:
        client_auth_middleware = ClientHeaderAuthMiddlewareFactory()
        client = FlightClient(
            ('localhost', server.port),
            middleware=[client_auth_middleware]
        )
        encoded_credentials = base64.b64encode(b'test:password')
        basic_header = (b'authorization', b'Basic ' + encoded_credentials)
        call_options = flight.FlightCallOptions(headers=[basic_header])
        expected_bearer = b'Bearer ' + b'token1234'

        for _attempt in range(2):
            responses = list(client.do_action(
                action=flight.Action('test-action', b''),
                options=call_options))
            assert responses[0].body.to_pybytes() == b'token1234'
            captured = client_auth_middleware.call_credential
            assert captured[0] == b'authorization'
            assert client_auth_middleware.call_credential[1] == \
                expected_bearer
| |
| |
def test_arbitrary_headers_in_flight_call_options():
    """Send several custom headers next to the auth header."""
    server_middleware = {
        "auth": HeaderAuthServerMiddlewareFactory(),
        "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory()
    }
    with ArbitraryHeadersFlightServer(
            auth_handler=no_op_auth_handler,
            middleware=server_middleware) as server:
        client = FlightClient(('localhost', server.port))
        header_pair = client.authenticate_basic_token(b'test', b'password')
        assert header_pair[0] == b'authorization'
        assert header_pair[1] == b'Bearer token1234'

        all_headers = [
            header_pair,
            (b'test-header-1', b'value1'),
            (b'test-header-2', b'value2')
        ]
        call_options = flight.FlightCallOptions(headers=all_headers)
        test_action = flight.Action("test-action", b"")
        responses = list(client.do_action(test_action,
                                          options=call_options))
        # The first two results carry the two custom header values.
        assert responses[0].body.to_pybytes() == b'value1'
        assert responses[1].body.to_pybytes() == b'value2'
| |
| |
def test_location_invalid():
    """An unparseable URI is rejected by both client and server."""
    parse_error = ".*Cannot parse URI:.*"

    with pytest.raises(pa.ArrowInvalid, match=parse_error):
        flight.connect("%")

    with pytest.raises(pa.ArrowInvalid, match=parse_error):
        ConstantFlightServer("%")
| |
| |
def test_location_unknown_scheme():
    """A Location with a non-Flight scheme exposes its URI unchanged."""
    cases = [
        ("s3://foo", b"s3://foo"),
        ("https://example.com/bar.parquet",
         b"https://example.com/bar.parquet"),
    ]
    for uri, expected_uri in cases:
        assert flight.Location(uri).uri == expected_uri
| |
| |
@pytest.mark.slow
@pytest.mark.requires_testing_data
def test_tls_fails():
    """A client given no root certificates must fail against a TLS server."""
    tls = example_tls_certs()

    with ConstantFlightServer(tls_certificates=tls["certificates"]) as s:
        # No root certs are passed, so certificate verification fails.
        # This test is slow because gRPC retries a few times first.
        uri = "grpc+tls://localhost:" + str(s.port)
        client = FlightClient(uri)

        # The text of gRPC errors varies between versions, so only the
        # exception type is checked.
        with pytest.raises(flight.FlightUnavailableError):
            reader = client.do_get(flight.Ticket(b'ints'))
            reader.read_all()
| |
| |
@pytest.mark.requires_testing_data
def test_tls_do_get():
    """Fetch the 'ints' table over TLS using the example root cert."""
    expected = simple_ints_table()
    tls = example_tls_certs()

    with ConstantFlightServer(tls_certificates=tls["certificates"]) as s:
        location = ('localhost', s.port)
        client = FlightClient(location, tls_root_certs=tls["root_cert"])
        reader = client.do_get(flight.Ticket(b'ints'))
        received = reader.read_all()
        assert received.equals(expected)
| |
| |
@pytest.mark.requires_testing_data
def test_tls_disable_server_verification():
    """Fetch the 'ints' table over TLS with server verification off."""
    expected = simple_ints_table()
    tls = example_tls_certs()

    with ConstantFlightServer(tls_certificates=tls["certificates"]) as s:
        location = ('localhost', s.port)
        try:
            client = FlightClient(location,
                                  disable_server_verification=True)
        except NotImplementedError:
            # The option is not available in every build; skip, don't fail.
            pytest.skip('disable_server_verification feature is not available')
        reader = client.do_get(flight.Ticket(b'ints'))
        received = reader.read_all()
        assert received.equals(expected)
| |
| |
@pytest.mark.requires_testing_data
def test_tls_override_hostname():
    """Overriding the hostname with a wrong value makes the call fail."""
    tls = example_tls_certs()

    with ConstantFlightServer(tls_certificates=tls["certificates"]) as s:
        location = ('localhost', s.port)
        client = flight.connect(location,
                                tls_root_certs=tls["root_cert"],
                                override_hostname="fakehostname")
        with pytest.raises(flight.FlightUnavailableError):
            client.do_get(flight.Ticket(b'ints'))
| |
| |
def test_flight_do_get_metadata():
    """Read a do_get stream chunk by chunk and check each chunk's metadata.

    Each chunk's metadata is unpacked as a little-endian int32 and must
    equal the chunk's position in the stream.
    """
    column = pa.array([-10, -5, 0, 5, 10])
    expected = pa.Table.from_arrays([column], names=['a'])

    received_batches = []
    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b''))
        position = 0
        while True:
            try:
                batch, metadata = reader.read_chunk()
            except StopIteration:
                # The stream is exhausted.
                break
            received_batches.append(batch)
            server_idx, = struct.unpack('<i', metadata.to_pybytes())
            assert position == server_idx
            position += 1
        received = pa.Table.from_batches(received_batches)
        assert received.equals(expected)
| |
| |
def test_flight_do_get_metadata_v4():
    """Run do_get against a server set to the V4 IPC metadata version."""
    column = pa.array([-10, -5, 0, 5, 10])
    expected = pa.Table.from_arrays([column], names=['a'])
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)

    with MetadataFlightServer(options=v4_options) as server:
        client = FlightClient(('localhost', server.port))
        received = client.do_get(flight.Ticket(b'')).read_all()
        assert received.equals(expected)
| |
| |
def test_flight_do_put_metadata():
    """Write batches with metadata and check what the server sends back.

    Every batch is sent with its index packed as a little-endian int32;
    the metadata read back after each write must unpack to that index.
    """
    column = pa.array([-10, -5, 0, 5, 10])
    table = pa.Table.from_arrays([column], names=['a'])

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('')
        writer, metadata_reader = client.do_put(descriptor, table.schema)
        with writer:
            single_row_batches = table.to_batches(max_chunksize=1)
            for idx, batch in enumerate(single_row_batches):
                writer.write_with_metadata(batch, struct.pack('<i', idx))
                reply = metadata_reader.read()
                assert reply is not None
                server_idx, = struct.unpack('<i', reply.to_pybytes())
                assert idx == server_idx
| |
| |
def test_flight_do_put_limit():
    """Check that do_put enforces the client's write size limit.

    The client is created with ``write_size_limit_bytes=4096``. Writing a
    batch of 768 int64 values (6144 bytes of data) must raise
    FlightWriteSizeExceededError carrying that limit. The same rows sent
    as two 384-row slices (3072 bytes of data each) must succeed and be
    echoed back intact.
    """
    large_batch = pa.RecordBatch.from_arrays([
        # Pass the dtype type itself (np.int64), not a scalar instance
        # (np.int64()).
        pa.array(np.ones(768, dtype=np.int64)),
    ], names=['a'])

    with EchoFlightServer() as server:
        client = FlightClient(('localhost', server.port),
                              write_size_limit_bytes=4096)
        writer, metadata_reader = client.do_put(
            flight.FlightDescriptor.for_path(''),
            large_batch.schema)
        with writer:
            with pytest.raises(flight.FlightWriteSizeExceededError,
                               match="exceeded soft limit") as excinfo:
                writer.write_batch(large_batch)
            assert excinfo.value.limit == 4096
            # The writer is still used after the rejected write: the same
            # rows go through when sent in two halves.
            smaller_batches = [
                large_batch.slice(0, 384),
                large_batch.slice(384),
            ]
            for batch in smaller_batches:
                writer.write_batch(batch)
        expected = pa.Table.from_batches([large_batch])
        actual = client.do_get(flight.Ticket(b'')).read_all()
        assert expected == actual
| |
| |
@pytest.mark.slow
def test_cancel_do_get():
    """Reading from a DoGet stream after cancelling it must raise."""
    with ConstantFlightServer() as server:
        location = ('localhost', server.port)
        client = FlightClient(location)
        stream = client.do_get(flight.Ticket(b'ints'))
        stream.cancel()
        with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"):
            stream.read_chunk()
| |
| |
@pytest.mark.slow
def test_cancel_do_get_threaded():
    """Cancel a DoGet from the main thread while a worker is reading it."""
    with SlowFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b'ints'))

        first_chunk_read = threading.Event()
        cancel_issued = threading.Event()
        outcome_lock = threading.Lock()
        saw_cancelled_error = threading.Event()

        def read_until_cancelled():
            # Read one chunk, report it, then wait for the cancellation
            # before trying to read again.
            reader.read_chunk()
            first_chunk_read.set()
            cancel_issued.wait(timeout=5)
            try:
                reader.read_chunk()
            except flight.FlightCancelledError:
                with outcome_lock:
                    saw_cancelled_error.set()

        worker = threading.Thread(target=read_until_cancelled, daemon=True)
        worker.start()
        # Cancel only after the worker has had time to read its first chunk.
        first_chunk_read.wait(timeout=5)
        reader.cancel()
        cancel_issued.set()
        worker.join(timeout=1)

        with outcome_lock:
            assert saw_cancelled_error.is_set()
| |
| |
def test_roundtrip_types():
    """Serialize and deserialize Flight types and compare the results."""
    ticket = flight.Ticket("foo")
    ticket_copy = flight.Ticket.deserialize(ticket.serialize())
    assert ticket == ticket_copy

    command_desc = flight.FlightDescriptor.for_command("test")
    command_copy = flight.FlightDescriptor.deserialize(
        command_desc.serialize())
    assert command_desc == command_copy

    path_desc = flight.FlightDescriptor.for_path("a", "b", "test.arrow")
    path_copy = flight.FlightDescriptor.deserialize(path_desc.serialize())
    assert path_desc == path_copy

    # One endpoint uses a plain URI string, the other a Location object.
    endpoints = [
        flight.FlightEndpoint(b'', ['grpc://test']),
        flight.FlightEndpoint(
            b'',
            [flight.Location.for_grpc_tcp('localhost', 5005)],
        ),
    ]
    schema = pa.schema([('a', pa.int32())])
    info = flight.FlightInfo(schema, path_desc, endpoints, -1, -1)
    info2 = flight.FlightInfo.deserialize(info.serialize())

    assert info.schema == info2.schema
    assert info.descriptor == info2.descriptor
    assert info.total_bytes == info2.total_bytes
    assert info.total_records == info2.total_records
    assert info.endpoints == info2.endpoints
| |
| |
def test_roundtrip_errors():
    """Each server-raised Flight error reaches the client as its own type."""
    # Action type sent to the server, and the exception expected back.
    action_errors = [
        ("internal", flight.FlightInternalError),
        ("timedout", flight.FlightTimedOutError),
        ("cancel", flight.FlightCancelledError),
        ("unauthenticated", flight.FlightUnauthenticatedError),
        ("unauthorized", flight.FlightUnauthorizedError),
    ]

    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        for action_type, error_type in action_errors:
            with pytest.raises(error_type, match=".*foo.*"):
                list(client.do_action(flight.Action(action_type, b"")))

        # The same propagation is checked for list_flights.
        with pytest.raises(flight.FlightInternalError, match=".*foo.*"):
            list(client.list_flights())
| |
| |
def test_do_put_independent_read_write():
    """Read DoPut metadata on one thread while writing on another."""
    # ARROW-6063: this used to make gRPC abort when the writer was closed
    # (because of the simultaneous reads), or hang forever.
    column = pa.array([-10, -5, 0, 5, 10])
    table = pa.Table.from_arrays([column], names=['a'])

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        descriptor = flight.FlightDescriptor.for_path('')
        writer, metadata_reader = client.do_put(descriptor, table.schema)

        acknowledged = 0

        def _drain_metadata():
            # Count metadata messages until read() returns None.
            nonlocal acknowledged
            while metadata_reader.read() is not None:
                acknowledged += 1

        reader_thread = threading.Thread(target=_drain_metadata)
        reader_thread.start()

        batches = table.to_batches(max_chunksize=1)
        with writer:
            for idx, batch in enumerate(batches):
                writer.write_with_metadata(batch, struct.pack('<i', idx))
            # Makes the server stop writing and end the call...
            writer.done_writing()
            # ...so the reader thread falls out of its loop.
            reader_thread.join()
        # writer.close() won't segfault here, since the reader thread has
        # already stopped.
        assert acknowledged == len(batches)
| |
| |
def test_server_middleware_same_thread():
    """Ensure that server middleware run on the same thread as the RPC."""
    middleware = {"test": HeaderServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=middleware) as server:
        client = FlightClient(('localhost', server.port))
        test_action = flight.Action(b"test", b"")
        responses = list(client.do_action(test_action))
        assert len(responses) == 1
        body = responses[0].body.to_pybytes()
        assert b"right value" == body
| |
| |
def test_middleware_reject():
    """Check that server middleware can reject calls selectively."""
    middleware = {"test": SelectiveAuthServerMiddlewareFactory()}
    with HeaderFlightServer(middleware=middleware) as server:
        location = ('localhost', server.port)
        plain_client = FlightClient(location)

        # list_actions gets past the middleware without credentials and
        # fails only with a not-implemented error.
        with pytest.raises(pa.ArrowNotImplementedError):
            list(plain_client.list_actions())

        # Every other call is refused without credentials.
        with pytest.raises(flight.FlightUnauthenticatedError):
            list(plain_client.do_action(flight.Action(b"", b"")))

        # With the matching client middleware the same action succeeds.
        authed_client = FlightClient(
            location,
            middleware=[SelectiveAuthClientMiddlewareFactory()]
        )
        responses = authed_client.do_action(flight.Action(b"", b""))
        first = next(responses)
        assert b"password" == first.body.to_pybytes()
| |
| |
def test_middleware_mapping():
    """Check which Flight method the middleware records for each RPC.

    Every RPC is issued once, in a fixed order, against a bare
    FlightServerBase. Each call is expected to raise NotImplementedError.
    Both recording factories must end up with the matching FlightMethod
    values in the same order.
    """
    server_recorder = RecordingServerMiddlewareFactory()
    client_recorder = RecordingClientMiddlewareFactory()
    with FlightServerBase(middleware={"test": server_recorder}) as server:
        client = FlightClient(
            ('localhost', server.port),
            middleware=[client_recorder]
        )
        descriptor = flight.FlightDescriptor.for_command(b"")

        with pytest.raises(NotImplementedError):
            list(client.list_flights())

        with pytest.raises(NotImplementedError):
            client.get_flight_info(descriptor)

        with pytest.raises(NotImplementedError):
            client.get_schema(descriptor)

        with pytest.raises(NotImplementedError):
            client.do_get(flight.Ticket(b""))

        with pytest.raises(NotImplementedError):
            put_writer, _ = client.do_put(descriptor, pa.schema([]))
            put_writer.close()

        with pytest.raises(NotImplementedError):
            list(client.do_action(flight.Action(b"", b"")))

        with pytest.raises(NotImplementedError):
            list(client.list_actions())

        with pytest.raises(NotImplementedError):
            exchange_writer, _ = client.do_exchange(descriptor)
            exchange_writer.close()

        # Same order as the calls above.
        expected_methods = [
            flight.FlightMethod.LIST_FLIGHTS,
            flight.FlightMethod.GET_FLIGHT_INFO,
            flight.FlightMethod.GET_SCHEMA,
            flight.FlightMethod.DO_GET,
            flight.FlightMethod.DO_PUT,
            flight.FlightMethod.DO_ACTION,
            flight.FlightMethod.LIST_ACTIONS,
            flight.FlightMethod.DO_EXCHANGE,
        ]
        assert server_recorder.methods == expected_methods
        assert client_recorder.methods == expected_methods
| |
| |
def test_extra_info():
    """Check that a Flight error's extra_info payload reaches the client.

    The "protobuf" action must raise FlightUnauthorizedError, and the
    exception's ``extra_info`` must equal the expected bytes.
    """
    with ErrorFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        # pytest.raises fails the test itself if nothing is raised, which
        # replaces the manual try / assert False / except pattern.
        with pytest.raises(flight.FlightUnauthorizedError) as excinfo:
            list(client.do_action(flight.Action("protobuf", b"")))
        extra_info = excinfo.value.extra_info
        assert extra_info is not None
        assert extra_info == b'this is an error message'
| |
| |
@pytest.mark.requires_testing_data
def test_mtls():
    """Run do_get over mutual TLS (mTLS) with gRPC.

    The first example certificate is used by the server and is also
    presented by the client as its own certificate chain and key.
    """
    tls = example_tls_certs()
    expected = simple_ints_table()
    key_pair = tls["certificates"][0]

    with ConstantFlightServer(
            tls_certificates=[key_pair],
            verify_client=True,
            root_certificates=tls["root_cert"]) as s:
        client = FlightClient(
            ('localhost', s.port),
            tls_root_certs=tls["root_cert"],
            cert_chain=key_pair.cert,
            private_key=key_pair.key)
        received = client.do_get(flight.Ticket(b'ints')).read_all()
        assert received.equals(expected)
| |
| |
def test_doexchange_get():
    """Use the DoExchange "get" command as a stand-in for DoGet."""
    column = pa.array(range(0, 10 * 1024))
    expected = pa.Table.from_arrays([column], names=["a"])

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        get_command = flight.FlightDescriptor.for_command(b"get")
        writer, reader = client.do_exchange(get_command)
        # Nothing is written; the writer is only held open for the read.
        with writer:
            received = reader.read_all()
        assert expected == received
| |
| |
def test_doexchange_put():
    """Use the DoExchange "put" command as a stand-in for DoPut.

    After all batches are written, the server's single reply must carry
    no data and app metadata equal to the number of batches sent.
    """
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        put_command = flight.FlightDescriptor.for_command(b"put")
        writer, reader = client.do_exchange(put_command)
        with writer:
            writer.begin(data.schema)
            for batch in batches:
                writer.write_batch(batch)
            writer.done_writing()

            reply = reader.read_chunk()
            assert reply.data is None
            batch_count = str(len(batches)).encode("utf-8")
            assert reply.app_metadata == batch_count
| |
| |
def test_doexchange_echo():
    """Exercise the DoExchange "echo" command in three phases.

    Metadata-only messages come first, then data without metadata, then
    data with metadata. Each message is read back right after it is
    written.
    """
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        echo_command = flight.FlightDescriptor.for_command(b"echo")
        writer, reader = client.do_exchange(echo_command)
        with writer:
            # Phase 1: metadata only, before any data is started.
            for i in range(10):
                payload = str(i).encode("utf-8")
                writer.write_metadata(payload)
                echoed = reader.read_chunk()
                assert echoed.data is None
                assert echoed.app_metadata == payload

            # Phase 2: data without metadata.
            writer.begin(data.schema)
            for batch in batches:
                writer.write_batch(batch)
                assert reader.schema == data.schema
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata is None

            # Phase 3: data with metadata attached.
            for i, batch in enumerate(batches):
                payload = str(i).encode("utf-8")
                writer.write_with_metadata(batch, payload)
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata == payload
| |
| |
def test_doexchange_echo_v4():
    """Echo data over DoExchange using the V4 metadata version."""
    column = pa.array(range(0, 10 * 1024))
    data = pa.Table.from_arrays([column], names=["a"])
    batches = data.to_batches(max_chunksize=512)

    # The same V4 write options go to the server and to the client writer.
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with ExchangeFlightServer(options=v4_options) as server:
        client = FlightClient(("localhost", server.port))
        echo_command = flight.FlightDescriptor.for_command(b"echo")
        writer, reader = client.do_exchange(echo_command)
        with writer:
            # Write data without metadata and read each batch straight back.
            writer.begin(data.schema, options=v4_options)
            for batch in batches:
                writer.write_batch(batch)
                assert reader.schema == data.schema
                echoed = reader.read_chunk()
                assert echoed.data == batch
                assert echoed.app_metadata is None
| |
| |
def test_doexchange_transform():
    """Send a three-column table and expect a single "sum" column back."""
    columns = [
        pa.array(range(0, 1024)),
        pa.array(range(1, 1025)),
        pa.array(range(2, 1026)),
    ]
    data = pa.Table.from_arrays(columns, names=["a", "b", "c"])
    # Row i of the input holds i, i + 1 and i + 2, which sum to 3 * i + 3.
    sums = pa.array(range(3, 1024 * 3 + 3, 3))
    expected = pa.Table.from_arrays([sums], names=["sum"])

    with ExchangeFlightServer() as server:
        client = FlightClient(("localhost", server.port))
        transform_command = flight.FlightDescriptor.for_command(b"transform")
        writer, reader = client.do_exchange(transform_command)
        with writer:
            writer.begin(data.schema)
            writer.write_table(data)
            writer.done_writing()
            received = reader.read_all()
        assert expected == received
| |
| |
def test_middleware_multi_header():
    """Test sending/receiving multiple (binary-valued) headers."""
    server_middleware = {"test": MultiHeaderServerMiddlewareFactory()}
    with MultiHeaderFlightServer(middleware=server_middleware) as server:
        client_factory = MultiHeaderClientMiddlewareFactory()
        client = FlightClient(('localhost', server.port),
                              middleware=[client_factory])
        response = next(client.do_action(flight.Action(b"", b"")))

        # The response body is the repr of the headers the server received;
        # parse it back into a Python object.
        body_text = response.body.to_pybytes().decode("utf-8")
        seen_by_server = ast.literal_eval(body_text)

        # Only the expected headers are compared, because gRPC may add
        # extra ones such as User-Agent.
        expected_headers = MultiHeaderClientMiddleware.EXPECTED
        for header, values in expected_headers.items():
            assert seen_by_server.get(header) == values
            assert client_factory.last_headers.get(header) == values
| |
| |
@pytest.mark.requires_testing_data
def test_generic_options():
    """Pass generic gRPC client options and check that they take effect."""
    tls = example_tls_certs()

    with ConstantFlightServer(tls_certificates=tls["certificates"]) as s:
        location = ('localhost', s.port)

        # A string-valued option chosen to make requests fail.
        string_options = [("grpc.ssl_target_name_override", "fakehostname")]
        client = flight.connect(location,
                                tls_root_certs=tls["root_cert"],
                                generic_options=string_options)
        with pytest.raises(flight.FlightUnavailableError):
            client.do_get(flight.Ticket(b'ints'))

        # An int-valued option chosen to make requests fail.
        int_options = [("grpc.max_receive_message_length", 32)]
        client = flight.connect(location,
                                tls_root_certs=tls["root_cert"],
                                generic_options=int_options)
        with pytest.raises(pa.ArrowInvalid):
            client.do_get(flight.Ticket(b'ints'))