| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # cython: language_level = 3 |
| |
| from __future__ import absolute_import |
| |
| import collections |
| import enum |
| |
| import six |
| |
| from cython.operator cimport dereference as deref |
| |
| from pyarrow.compat import frombytes, tobytes |
| from pyarrow.lib cimport * |
| from pyarrow.lib import as_buffer |
| from pyarrow.includes.libarrow_flight cimport * |
| from pyarrow.ipc import _ReadPandasOption |
| import pyarrow.lib as lib |
| |
| |
# Default-constructed call options.  FlightCallOptions.unwrap() hands out
# the address of this object whenever the caller supplies no options.
cdef CFlightCallOptions DEFAULT_CALL_OPTIONS


cdef class FlightCallOptions:
    """RPC-layer options for a Flight call."""

    cdef:
        CFlightCallOptions options

    def __init__(self, timeout=None):
        """Create call options.

        Parameters
        ----------
        timeout : float or None
            A timeout for the call, in seconds. None means that the
            timeout defaults to an implementation-specific value.

        """
        if timeout is not None:
            self.options.timeout = CTimeoutDuration(timeout)

    @staticmethod
    cdef CFlightCallOptions* unwrap(obj) except *:
        """Return a pointer to the C++ options wrapped by ``obj``.

        A falsy ``obj`` (such as None) yields the address of the shared
        module-level defaults.

        The ``except *`` clause is required: this function returns a C
        pointer, and without an exception clause the TypeError below
        could not propagate to the caller, which would then dereference
        an invalid pointer.

        Raises
        ------
        TypeError
            If ``obj`` is truthy but not a FlightCallOptions.
        """
        if not obj:
            return &DEFAULT_CALL_OPTIONS
        elif isinstance(obj, FlightCallOptions):
            return &((<FlightCallOptions> obj).options)
        raise TypeError("Expected a FlightCallOptions object, not "
                        "'{}'".format(type(obj)))
| |
| |
# Plain tuple base carrying the two fields; the public class below only
# adds documentation on top of it.
_CertKeyPair = collections.namedtuple('_CertKeyPair', ('cert', 'key'))


class CertKeyPair(_CertKeyPair):
    """A (cert, key) pair holding a TLS certificate and its key for Flight."""
| |
| |
cdef class Action:
    """An operation, identified by a type and a body, for a Flight service."""
    cdef:
        CAction action

    def __init__(self, action_type, buf):
        """Build an action out of its type and its body.

        Parameters
        ----------
        action_type : bytes or str
            The action type; converted with ``tobytes`` before storing.
        buf : Buffer or bytes-like object
            The action body; converted with ``as_buffer`` before storing.
        """
        self.action.type = tobytes(action_type)
        body = as_buffer(buf)
        self.action.body = pyarrow_unwrap_buffer(body)

    @property
    def type(self):
        """The action type, decoded with ``frombytes``."""
        return frombytes(self.action.type)

    @property
    def body(self):
        """The action body (arguments for the action) as a Buffer."""
        return pyarrow_wrap_buffer(self.action.body)

    @staticmethod
    cdef CAction unwrap(action) except *:
        """Return the C++ action held by ``action``.

        Raises TypeError when ``action`` is not an Action instance.
        """
        if isinstance(action, Action):
            return (<Action> action).action
        raise TypeError("Must provide Action, not '{}'".format(
            type(action)))
| |
| |
# Plain tuple base carrying the two fields; the public class below adds
# the make_action() convenience on top of it.
_ActionType = collections.namedtuple('_ActionType', ('type', 'description'))


class ActionType(_ActionType):
    """A (type, description) pair naming an action a Flight service offers."""

    def make_action(self, buf):
        """Build an Action of this type around the given body.

        Parameters
        ----------
        buf : obj
            An Arrow buffer or Python bytes or bytes-like object.

        Returns
        -------
        Action
        """
        action_type = self.type
        return Action(action_type, buf)
| |
| |
cdef class Result:
    """One result produced by executing an Action."""
    cdef:
        unique_ptr[CFlightResult] result

    def __init__(self, buf):
        """Wrap a body in a freshly allocated result.

        Parameters
        ----------
        buf : Buffer or bytes-like object
            The result body; converted with ``as_buffer`` before storing.
        """
        # Allocate the C++ result first, then attach the body to it.
        self.result.reset(new CFlightResult())
        body = as_buffer(buf)
        self.result.get().body = pyarrow_unwrap_buffer(body)

    @property
    def body(self):
        """The Buffer holding the contents of this result."""
        return pyarrow_wrap_buffer(self.result.get().body)
| |
| |
class DescriptorType(enum.Enum):
    """
    The kind of a FlightDescriptor.

    Attributes
    ----------
    UNKNOWN
        The descriptor type is not known.
    PATH
        The Flight stream is identified by a path.
    CMD
        The Flight stream is identified by an application-defined command.
    """

    # Member order and values are part of the public contract.
    UNKNOWN = 0
    PATH = 1
    CMD = 2
| |
| |
cdef class FlightDescriptor:
    """A description of a data stream available from a Flight service.

    Instances are created through :meth:`for_path` or :meth:`for_command`;
    calling the constructor directly raises TypeError.
    """
    cdef:
        CFlightDescriptor descriptor

    def __init__(self):
        # The literal braces in "for_{path,command}" must be doubled:
        # otherwise str.format() treats them as a replacement field and
        # raises KeyError instead of the intended TypeError.
        raise TypeError("Do not call {}'s constructor directly, use "
                        "`pyarrow.flight.FlightDescriptor."
                        "for_{{path,command}}` "
                        "function instead."
                        .format(self.__class__.__name__))

    @staticmethod
    def for_path(*path):
        """Create a FlightDescriptor for a resource path.

        Parameters
        ----------
        *path : bytes or str
            The path components; each is converted with ``tobytes``.
        """
        cdef FlightDescriptor result = \
            FlightDescriptor.__new__(FlightDescriptor)
        result.descriptor.type = CDescriptorTypePath
        result.descriptor.path = [tobytes(p) for p in path]
        return result

    @staticmethod
    def for_command(command):
        """Create a FlightDescriptor for an opaque command.

        Parameters
        ----------
        command : bytes or str
            The command; converted with ``tobytes``.
        """
        cdef FlightDescriptor result = \
            FlightDescriptor.__new__(FlightDescriptor)
        result.descriptor.type = CDescriptorTypeCmd
        result.descriptor.cmd = tobytes(command)
        return result

    @property
    def descriptor_type(self):
        """Get the type of this descriptor as a DescriptorType member.

        Raises RuntimeError if the underlying C++ type is none of the
        three known values.
        """
        if self.descriptor.type == CDescriptorTypeUnknown:
            return DescriptorType.UNKNOWN
        elif self.descriptor.type == CDescriptorTypePath:
            return DescriptorType.PATH
        elif self.descriptor.type == CDescriptorTypeCmd:
            return DescriptorType.CMD
        raise RuntimeError("Invalid descriptor type!")

    @property
    def command(self):
        """Get the command for this descriptor.

        None unless this is a CMD descriptor.
        """
        if self.descriptor_type != DescriptorType.CMD:
            return None
        return self.descriptor.cmd

    @property
    def path(self):
        """Get the path for this descriptor.

        None unless this is a PATH descriptor.
        """
        if self.descriptor_type != DescriptorType.PATH:
            return None
        return self.descriptor.path

    def __repr__(self):
        return "<FlightDescriptor type: {!r}>".format(self.descriptor_type)

    @staticmethod
    cdef CFlightDescriptor unwrap(descriptor) except *:
        """Return the C++ descriptor held by ``descriptor``.

        Raises TypeError when ``descriptor`` is not a FlightDescriptor.
        """
        if not isinstance(descriptor, FlightDescriptor):
            raise TypeError("Must provide a FlightDescriptor, not '{}'".format(
                type(descriptor)))
        return (<FlightDescriptor> descriptor).descriptor
| |
| |
cdef class Ticket:
    """A token used to request a particular Flight stream."""

    cdef:
        CTicket ticket

    def __init__(self, ticket):
        """Store ``ticket`` after converting it with ``tobytes``."""
        self.ticket.ticket = tobytes(ticket)

    @property
    def ticket(self):
        """The contents stored in the underlying C++ ticket."""
        return self.ticket.ticket

    def __repr__(self):
        return '<Ticket {}>'.format(self.ticket.ticket)
| |
| |
cdef class Location:
    """Where a Flight service can be reached."""
    cdef:
        CLocation location

    def __init__(self, uri):
        """Parse ``uri`` into a location; a failed parse raises."""
        check_status(CLocation.Parse(tobytes(uri), &self.location))

    def __repr__(self):
        return '<Location {}>'.format(self.location.ToString())

    @property
    def uri(self):
        """The string form of this location, as given by ``ToString()``."""
        return self.location.ToString()

    def equals(self, Location other):
        """Compare with another Location; same result as ``self == other``."""
        return self == other

    def __eq__(self, other):
        # Defer to the other operand for anything that is not a Location.
        if isinstance(other, Location):
            return self.location.Equals((<Location> other).location)
        return NotImplemented

    @staticmethod
    def for_grpc_tcp(host, port):
        """Build a Location for a gRPC service reached over TCP."""
        cdef:
            c_string host_bytes = tobytes(host)
            int port_number = port
            Location loc = Location.__new__(Location)
        check_status(
            CLocation.ForGrpcTcp(host_bytes, port_number, &loc.location))
        return loc

    @staticmethod
    def for_grpc_tls(host, port):
        """Build a Location for a gRPC service reached over TLS."""
        cdef:
            c_string host_bytes = tobytes(host)
            int port_number = port
            Location loc = Location.__new__(Location)
        check_status(
            CLocation.ForGrpcTls(host_bytes, port_number, &loc.location))
        return loc

    @staticmethod
    def for_grpc_unix(path):
        """Build a Location for a gRPC service on a domain socket."""
        cdef:
            c_string socket_path = tobytes(path)
            Location loc = Location.__new__(Location)
        check_status(CLocation.ForGrpcUnix(socket_path, &loc.location))
        return loc

    @staticmethod
    cdef Location wrap(CLocation location):
        """Wrap a C++ location in a new Python Location (no parsing)."""
        cdef Location wrapped = Location.__new__(Location)
        wrapped.location = location
        return wrapped

    @staticmethod
    cdef CLocation unwrap(object location) except *:
        """Convert a Location or a text URI to a C++ location.

        Text strings are parsed as URIs.  Anything that is neither a
        Location nor a text string raises TypeError.
        """
        cdef CLocation parsed
        if isinstance(location, Location):
            return (<Location> location).location
        if isinstance(location, six.text_type):
            check_status(CLocation.Parse(tobytes(location), &parsed))
            return parsed
        raise TypeError("Must provide a Location, not '{}'".format(
            type(location)))
| |
| |
cdef class FlightEndpoint:
    """A ticket for a Flight stream plus the locations that serve it."""
    cdef:
        CFlightEndpoint endpoint

    def __init__(self, ticket, locations):
        """Build an endpoint from a ticket and the locations serving it.

        Parameters
        ----------
        ticket : Ticket or bytes
            the ticket needed to access this flight
        locations : list of string URIs
            locations where this flight is available; Location
            instances are accepted as well and used as-is

        Raises
        ------
        ArrowException
            If one of the location URIs is not a valid URI.
        """
        cdef:
            CLocation c_location

        # A Ticket contributes its contents; anything else is used directly.
        raw_ticket = ticket.ticket if isinstance(ticket, Ticket) else ticket
        self.endpoint.ticket.ticket = tobytes(raw_ticket)

        for location in locations:
            if not isinstance(location, Location):
                c_location = CLocation()
                check_status(CLocation.Parse(tobytes(location), &c_location))
            else:
                c_location = (<Location> location).location
            self.endpoint.locations.push_back(c_location)

    @property
    def ticket(self):
        """A new Ticket built from the ticket stored in this endpoint."""
        return Ticket(self.endpoint.ticket.ticket)

    @property
    def locations(self):
        """The locations of this endpoint, each wrapped as a Location."""
        return [Location.wrap(location)
                for location in self.endpoint.locations]
| |
| |
cdef class FlightInfo:
    """Metadata describing a Flight stream."""
    cdef:
        unique_ptr[CFlightInfo] info

    def __init__(self, Schema schema, FlightDescriptor descriptor, endpoints,
                 total_records, total_bytes):
        """Assemble a FlightInfo from its parts.

        Parameters
        ----------
        schema : Schema
            the schema of the data in this flight.
        descriptor : FlightDescriptor
            the descriptor for this flight.
        endpoints : list of FlightEndpoint
            a list of endpoints where this flight is available.
        total_records : int
            the total records in this flight, or -1 if unknown
        total_bytes : int
            the total bytes in this flight, or -1 if unknown

        Raises
        ------
        TypeError
            If an element of ``endpoints`` is not a FlightEndpoint.
        """
        cdef:
            shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
            vector[CFlightEndpoint] c_endpoints

        for endpoint in endpoints:
            # Reject foreign objects before touching the C++ vector.
            if not isinstance(endpoint, FlightEndpoint):
                raise TypeError('Endpoint {} is not instance of'
                                ' FlightEndpoint'.format(endpoint))
            c_endpoints.push_back((<FlightEndpoint> endpoint).endpoint)

        check_status(CreateFlightInfo(c_schema,
                                      descriptor.descriptor,
                                      c_endpoints,
                                      total_records,
                                      total_bytes, &self.info))

    @property
    def total_records(self):
        """The total record count of this flight, or -1 if unknown."""
        return self.info.get().total_records()

    @property
    def total_bytes(self):
        """The size in bytes of the data in this flight, or -1 if unknown."""
        return self.info.get().total_bytes()

    @property
    def schema(self):
        """The schema of the data in this flight."""
        cdef:
            shared_ptr[CSchema] c_schema
            # GetSchema needs a dictionary memo; it is discarded afterwards.
            CDictionaryMemo scratch_memo

        check_status(self.info.get().GetSchema(&scratch_memo, &c_schema))
        return pyarrow_wrap_schema(c_schema)

    @property
    def descriptor(self):
        """The descriptor of the data in this flight."""
        cdef FlightDescriptor wrapped = \
            FlightDescriptor.__new__(FlightDescriptor)
        wrapped.descriptor = self.info.get().descriptor()
        return wrapped

    @property
    def endpoints(self):
        """The endpoints where this flight is available."""
        # TODO: get Cython to iterate over reference directly
        cdef:
            vector[CFlightEndpoint] c_endpoints = self.info.get().endpoints()
            FlightEndpoint wrapped

        py_endpoints = []
        for c_endpoint in c_endpoints:
            wrapped = FlightEndpoint.__new__(FlightEndpoint)
            wrapped.endpoint = c_endpoint
            py_endpoints.append(wrapped)
        return py_endpoints
| |
| |
cdef class FlightStreamChunk:
    """A RecordBatch paired with optional application metadata."""
    cdef:
        CFlightStreamChunk chunk

    @property
    def data(self):
        """The RecordBatch of this chunk, or None if there is none."""
        return (None if self.chunk.data == NULL
                else pyarrow_wrap_batch(self.chunk.data))

    @property
    def app_metadata(self):
        """The metadata Buffer of this chunk, or None if there is none."""
        return (None if self.chunk.app_metadata == NULL
                else pyarrow_wrap_buffer(self.chunk.app_metadata))

    def __iter__(self):
        # Lets callers unpack a chunk as ``data, app_metadata = chunk``.
        pair = (self.data, self.app_metadata)
        return iter(pair)
| |
| |
cdef class _MetadataRecordBatchReader:
    """State shared by readers of Flight streams."""

    # Kept as its own class so that the "real" reader class can list the
    # pure-Python mixin class among its bases.

    cdef dict __dict__
    cdef shared_ptr[CMetadataRecordBatchReader] reader

    # Readable (but not writable) from Python code.
    cdef readonly Schema schema
| |
| |
cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader,
                                     _ReadPandasOption):
    """A reader for Flight streams."""

    def __iter__(self):
        """Yield each FlightStreamChunk of the stream until it is finished."""
        while True:
            # read_chunk() signals the end of the stream with
            # StopIteration.  Letting that escape from a generator body
            # is converted to RuntimeError under PEP 479 semantics, so
            # translate it into a normal generator return.
            try:
                chunk = self.read_chunk()
            except StopIteration:
                return
            yield chunk

    def read_all(self):
        """Read the entire contents of the stream as a Table."""
        cdef:
            shared_ptr[CTable] c_table
        with nogil:
            check_status(self.reader.get().ReadAll(&c_table))
        return pyarrow_wrap_table(c_table)

    def read_chunk(self):
        """Read the next RecordBatch along with any metadata.

        Returns
        -------
        chunk : FlightStreamChunk
            Can be unpacked as ``data, app_metadata``:

            data : RecordBatch
                The next RecordBatch in the stream.
            app_metadata : Buffer or None
                Application-specific metadata for the batch as defined by
                Flight.

        Raises
        ------
        StopIteration
            when the stream is finished
        """
        cdef:
            FlightStreamChunk chunk = FlightStreamChunk()

        with nogil:
            check_status(self.reader.get().Next(&chunk.chunk))

        # A chunk without data marks the end of the stream.
        if chunk.chunk.data == NULL:
            raise StopIteration

        return chunk
| |
| |
cdef class FlightStreamReader(MetadataRecordBatchReader):
    """A Flight stream reader whose read operation can be canceled."""

    def cancel(self):
        """Cancel the read operation."""
        # The stored reader is treated as a CFlightStreamReader here.
        cdef CFlightStreamReader* c_reader = \
            <CFlightStreamReader*> self.reader.get()
        with nogil:
            c_reader.Cancel()
| |
| |
cdef class FlightStreamWriter(_CRecordBatchWriter):
    """A RecordBatchWriter that can also send application metadata."""

    def write_with_metadata(self, RecordBatch batch, buf):
        """Write a RecordBatch together with Flight metadata.

        Parameters
        ----------
        batch : RecordBatch
            The next RecordBatch in the stream.
        buf : Buffer
            Application-specific metadata for the batch as defined by
            Flight.
        """
        cdef:
            shared_ptr[CBuffer] c_metadata = \
                pyarrow_unwrap_buffer(as_buffer(buf))
            # The stored writer is treated as a CFlightStreamWriter here.
            CFlightStreamWriter* c_writer = \
                <CFlightStreamWriter*> self.writer.get()
        with nogil:
            # NOTE(review): the meaning of the trailing literal 1 is not
            # visible in this file -- confirm against the C++ signature.
            check_status(
                c_writer.WriteWithMetadata(deref(batch.batch),
                                           c_metadata,
                                           1))
| |
| |
cdef class FlightMetadataReader:
    """Reads the Flight metadata messages sent during a DoPut."""

    cdef:
        unique_ptr[CFlightMetadataReader] reader

    def read(self):
        """Read the next metadata message.

        Returns the message as a Buffer, or None when the underlying
        reader produced no buffer.
        """
        cdef shared_ptr[CBuffer] message
        with nogil:
            check_status(self.reader.get().ReadMetadata(&message))
        return None if message == NULL else pyarrow_wrap_buffer(message)
| |
| |
cdef class FlightMetadataWriter:
    """Sends Flight metadata messages during a DoPut."""

    cdef:
        unique_ptr[CFlightMetadataWriter] writer

    def write(self, message):
        """Write the next metadata message.

        Parameters
        ----------
        message : Buffer
            The message to send; converted with ``as_buffer`` first.
        """
        cdef shared_ptr[CBuffer] c_message = \
            pyarrow_unwrap_buffer(as_buffer(message))
        with nogil:
            check_status(self.writer.get().WriteMetadata(deref(c_message)))
| |
| |
cdef class FlightClient:
    """A client to a Flight service.

    Instances are created through :meth:`connect`; calling the
    constructor directly raises TypeError.
    """
    cdef:
        unique_ptr[CFlightClient] client

    def __init__(self):
        raise TypeError("Do not call {}'s constructor directly, use "
                        "`pyarrow.flight.FlightClient.connect` instead."
                        .format(self.__class__.__name__))

    @staticmethod
    def connect(location, tls_root_certs=None, override_hostname=None):
        """
        Connect to a Flight service at the given location.

        Parameters
        ----------
        location : Location or str
            Location to connect to. A text string is parsed as a URI.
        tls_root_certs : bytes or None
            PEM-encoded TLS root certificates. Left unset when falsy.
        override_hostname : str or None
            Override the hostname checked by TLS. Insecure, use with caution.
            Left unset when falsy.

        Returns
        -------
        client : FlightClient
        """
        cdef:
            FlightClient result = FlightClient.__new__(FlightClient)
            CLocation c_location = Location.unwrap(location)
            CFlightClientOptions c_options

        if tls_root_certs:
            c_options.tls_root_certs = tobytes(tls_root_certs)
        if override_hostname:
            c_options.override_hostname = tobytes(override_hostname)

        with nogil:
            check_status(CFlightClient.Connect(c_location, c_options,
                                               &result.client))

        return result

    def authenticate(self, auth_handler, options: FlightCallOptions = None):
        """Authenticate to the server.

        Parameters
        ----------
        auth_handler : ClientAuthHandler
            The authentication mechanism to use.
        options : FlightCallOptions
            Options for this call.

        Raises
        ------
        TypeError
            If ``auth_handler`` is not a ClientAuthHandler.
        """
        cdef:
            unique_ptr[CClientAuthHandler] handler
            CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)

        if not isinstance(auth_handler, ClientAuthHandler):
            raise TypeError(
                "FlightClient.authenticate takes a ClientAuthHandler, "
                "not '{}'".format(type(auth_handler)))
        handler.reset((<ClientAuthHandler> auth_handler).to_handler())
        with nogil:
            check_status(self.client.get().Authenticate(deref(c_options),
                                                        move(handler)))

    def list_actions(self, options: FlightCallOptions = None):
        """List the actions available on a service.

        Parameters
        ----------
        options : FlightCallOptions
            Options for this call.

        Returns
        -------
        actions : list of ActionType
        """
        cdef:
            vector[CActionType] results
            CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)

        with nogil:
            check_status(
                self.client.get().ListActions(deref(c_options), &results))

        result = []
        for action_type in results:
            py_action = ActionType(frombytes(action_type.type),
                                   frombytes(action_type.description))
            result.append(py_action)

        return result

    def do_action(self, action: Action, options: FlightCallOptions = None):
        """Execute an action on a service.

        This is a generator: nothing is sent until iteration starts.

        Parameters
        ----------
        action : Action
            The action to execute.
        options : FlightCallOptions
            Options for this call.

        Yields
        ------
        result : Result
        """
        cdef:
            unique_ptr[CResultStream] results
            Result result
            CAction c_action = Action.unwrap(action)
            CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
        with nogil:
            check_status(
                self.client.get().DoAction(deref(c_options), c_action,
                                           &results))

        while True:
            result = Result.__new__(Result)
            with nogil:
                check_status(results.get().Next(&result.result))
                # An empty result marks the end of the stream.
                if result.result == NULL:
                    break
            yield result

    def list_flights(self, options: FlightCallOptions = None):
        """List the flights available on a service.

        This is a generator: nothing is sent until iteration starts.
        A default-constructed criteria object is passed to the service.

        Parameters
        ----------
        options : FlightCallOptions
            Options for this call.

        Yields
        ------
        info : FlightInfo
        """
        cdef:
            unique_ptr[CFlightListing] listing
            FlightInfo result
            CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
            CCriteria c_criteria

        with nogil:
            check_status(self.client.get().ListFlights(deref(c_options),
                                                       c_criteria, &listing))

        while True:
            result = FlightInfo.__new__(FlightInfo)
            with nogil:
                check_status(listing.get().Next(&result.info))
                # An empty info marks the end of the listing.
                if result.info == NULL:
                    break
            yield result

    def get_flight_info(self, descriptor: FlightDescriptor,
                        options: FlightCallOptions = None):
        """Request information about an available flight.

        Parameters
        ----------
        descriptor : FlightDescriptor
            The descriptor of the flight to look up.
        options : FlightCallOptions
            Options for this call.

        Returns
        -------
        info : FlightInfo
        """
        cdef:
            FlightInfo result = FlightInfo.__new__(FlightInfo)
            CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
            CFlightDescriptor c_descriptor = \
                FlightDescriptor.unwrap(descriptor)

        with nogil:
            check_status(self.client.get().GetFlightInfo(
                deref(c_options), c_descriptor, &result.info))

        return result

    def do_get(self, ticket: Ticket, options: FlightCallOptions = None):
        """Request the data for a flight.

        Parameters
        ----------
        ticket : Ticket
            The ticket identifying the stream to fetch.
        options : FlightCallOptions
            Options for this call.

        Returns
        -------
        reader : FlightStreamReader
        """
        cdef:
            unique_ptr[CFlightStreamReader] reader
            CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)

        with nogil:
            check_status(
                self.client.get().DoGet(
                    deref(c_options), ticket.ticket, &reader))
        result = FlightStreamReader()
        # Ownership of the C++ reader moves to the Python wrapper.
        result.reader.reset(reader.release())
        result.schema = pyarrow_wrap_schema(result.reader.get().schema())
        return result

    def do_put(self, descriptor: FlightDescriptor, schema: Schema,
               options: FlightCallOptions = None):
        """Upload data to a flight.

        Parameters
        ----------
        descriptor : FlightDescriptor
            The descriptor of the flight being uploaded.
        schema : Schema
            The schema of the data that will be written.
        options : FlightCallOptions
            Options for this call.

        Returns
        -------
        writer : FlightStreamWriter
        reader : FlightMetadataReader
        """
        cdef:
            shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
            unique_ptr[CFlightStreamWriter] writer
            unique_ptr[CFlightMetadataReader] metadata_reader
            CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
            CFlightDescriptor c_descriptor = \
                FlightDescriptor.unwrap(descriptor)
            FlightMetadataReader reader = FlightMetadataReader()

        with nogil:
            check_status(self.client.get().DoPut(
                deref(c_options),
                c_descriptor,
                c_schema,
                &writer,
                &reader.reader))
        result = FlightStreamWriter()
        # Ownership of the C++ writer moves to the Python wrapper.
        result.writer.reset(writer.release())
        return result, reader
| |
| |
cdef class FlightDataStream:
    """Abstract base class for Flight data streams."""

    cdef CFlightDataStream* to_stream(self) except *:
        """Build the C++ data stream backing this Python object.

        The C++ object is never exposed to Python, so its lifetime can
        be managed from the Cython/C++ side.  This base implementation
        always raises NotImplementedError.
        """
        raise NotImplementedError
| |
| |
cdef class RecordBatchStream(FlightDataStream):
    """A Flight data stream that serves RecordBatches from a data source."""
    cdef:
        object data_source

    def __init__(self, data_source):
        """Wrap a data source in a RecordBatchStream.

        Parameters
        ----------
        data_source : RecordBatchReader or Table

        Raises
        ------
        TypeError
            If ``data_source`` is neither of the accepted types.
        """
        accepted = (_CRecordBatchReader, lib.Table)
        if not isinstance(data_source, accepted):
            raise TypeError("Expected RecordBatchReader or Table, "
                            "but got: {}".format(type(data_source)))
        self.data_source = data_source

    cdef CFlightDataStream* to_stream(self) except *:
        """Build a C++ record batch stream over the stored data source."""
        cdef:
            shared_ptr[CRecordBatchReader] reader
        if isinstance(self.data_source, _CRecordBatchReader):
            # Readers are used directly.
            reader = (<_CRecordBatchReader> self.data_source).reader
        elif isinstance(self.data_source, lib.Table):
            # Tables are read through a TableBatchReader.
            table = (<Table> self.data_source).table
            reader.reset(new TableBatchReader(deref(table)))
        else:
            raise RuntimeError("Can't construct RecordBatchStream "
                               "from type {}".format(type(self.data_source)))
        return new CRecordBatchStream(reader)
| |
| |
cdef class GeneratorStream(FlightDataStream):
    """A Flight data stream fed by a Python generator."""
    cdef:
        shared_ptr[CSchema] schema
        object generator
        # The substream the client is currently consuming, if any.
        # Substreams are produced by the generator.
        unique_ptr[CFlightDataStream] current_stream

    def __init__(self, schema, generator):
        """Build a GeneratorStream over a Python generator.

        Parameters
        ----------
        schema : Schema
            The schema for the data to be returned.

        generator : iterator or iterable
            The generator should yield other FlightDataStream objects,
            Tables, RecordBatches, or RecordBatchReaders.
        """
        self.schema = pyarrow_unwrap_schema(schema)
        # Any iterable is accepted; only its iterator is kept.
        self.generator = iter(generator)

    cdef CFlightDataStream* to_stream(self) except *:
        """Build a C++ stream that calls back into this object for data."""
        cdef:
            function[cb_data_stream_next] callback = &_data_stream_next
        return new CPyGeneratorFlightDataStream(self, self.schema, callback)
| |
| |
cdef class ServerCallContext:
    """State and context belonging to a single call."""
    cdef:
        const CServerCallContext* context

    def peer_identity(self):
        """Return the identity of the authenticated peer.

        May be the empty string.
        """
        identity = self.context.peer_identity()
        return tobytes(identity)

    @staticmethod
    cdef ServerCallContext wrap(const CServerCallContext& context):
        """Wrap a C++ call context; only its address is stored."""
        cdef ServerCallContext wrapped = \
            ServerCallContext.__new__(ServerCallContext)
        wrapped.context = &context
        return wrapped
| |
| |
cdef class ServerAuthReader:
    """Reads the messages a client sends during an auth handshake."""
    cdef:
        CServerAuthReader* reader

    def read(self):
        """Read the next message from the client.

        Raises ValueError once this reader has been poisoned.
        """
        cdef c_string token
        if self.reader == NULL:
            raise ValueError("Cannot use ServerAuthReader outside "
                             "ServerAuthHandler.authenticate")
        with nogil:
            check_status(self.reader.Read(&token))
        return token

    cdef void poison(self):
        """Make this object unusable from now on.

        The object only borrows a pointer obtained from a reference, so
        Python users must not be able to reach it after the reference
        has gone away.
        """
        self.reader = NULL

    @staticmethod
    cdef ServerAuthReader wrap(CServerAuthReader* reader):
        """Wrap a borrowed C++ reader pointer."""
        cdef ServerAuthReader wrapped = \
            ServerAuthReader.__new__(ServerAuthReader)
        wrapped.reader = reader
        return wrapped
| |
| |
cdef class ServerAuthSender:
    """Writes messages to the client during an auth handshake."""
    cdef:
        CServerAuthSender* sender

    def write(self, message):
        """Send ``message`` to the client.

        Raises ValueError once this sender has been poisoned.
        """
        # The message is converted before the validity check.
        cdef c_string payload = tobytes(message)
        if self.sender == NULL:
            raise ValueError("Cannot use ServerAuthSender outside "
                             "ServerAuthHandler.authenticate")
        with nogil:
            check_status(self.sender.Write(payload))

    cdef void poison(self):
        """Make this object unusable from now on.

        The object only borrows a pointer obtained from a reference, so
        Python users must not be able to reach it after the reference
        has gone away.
        """
        self.sender = NULL

    @staticmethod
    cdef ServerAuthSender wrap(CServerAuthSender* sender):
        """Wrap a borrowed C++ sender pointer."""
        cdef ServerAuthSender wrapped = \
            ServerAuthSender.__new__(ServerAuthSender)
        wrapped.sender = sender
        return wrapped
| |
| |
cdef class ClientAuthReader:
    """Reads the messages a server sends during an auth handshake."""
    cdef:
        CClientAuthReader* reader

    def read(self):
        """Read the next message from the server.

        Raises ValueError once this reader has been poisoned.
        """
        cdef c_string token
        if self.reader == NULL:
            raise ValueError("Cannot use ClientAuthReader outside "
                             "ClientAuthHandler.authenticate")
        with nogil:
            check_status(self.reader.Read(&token))
        return token

    cdef void poison(self):
        """Make this object unusable from now on.

        The object only borrows a pointer obtained from a reference, so
        Python users must not be able to reach it after the reference
        has gone away.
        """
        self.reader = NULL

    @staticmethod
    cdef ClientAuthReader wrap(CClientAuthReader* reader):
        """Wrap a borrowed C++ reader pointer."""
        cdef ClientAuthReader wrapped = \
            ClientAuthReader.__new__(ClientAuthReader)
        wrapped.reader = reader
        return wrapped
| |
| |
cdef class ClientAuthSender:
    """Writes messages to the server during an auth handshake."""
    cdef:
        CClientAuthSender* sender

    def write(self, message):
        """Send ``message`` to the server.

        Raises ValueError once this sender has been poisoned.
        """
        # The message is converted before the validity check.
        cdef c_string payload = tobytes(message)
        if self.sender == NULL:
            raise ValueError("Cannot use ClientAuthSender outside "
                             "ClientAuthHandler.authenticate")
        with nogil:
            check_status(self.sender.Write(payload))

    cdef void poison(self):
        """Make this object unusable from now on.

        The object only borrows a pointer obtained from a reference, so
        Python users must not be able to reach it after the reference
        has gone away.
        """
        self.sender = NULL

    @staticmethod
    cdef ClientAuthSender wrap(CClientAuthSender* sender):
        """Wrap a borrowed C++ sender pointer."""
        cdef ClientAuthSender wrapped = \
            ClientAuthSender.__new__(ClientAuthSender)
        wrapped.sender = sender
        return wrapped
| |
| |
cdef void _data_stream_next(void* self, CFlightPayload* payload) except *:
    """Callback that produces the next payload of a GeneratorStream.

    The active substream, if any, is drained first.  Once it is
    exhausted the Python generator is asked for the next item, which is
    either turned into a new substream or written to ``payload``
    directly.  A payload with empty IPC metadata signals the end.
    """
    cdef:
        unique_ptr[CFlightDataStream] data_stream

    py_stream = <object> self
    if not isinstance(py_stream, GeneratorStream):
        raise RuntimeError("self object in callback is not GeneratorStream")
    stream = <GeneratorStream> py_stream

    if stream.current_stream != nullptr:
        check_status(stream.current_stream.get().Next(payload))
        # A payload with metadata means the substream is still producing
        # data; otherwise it ended and the generator is consulted again.
        if payload.ipc_message.metadata != nullptr:
            return
        stream.current_stream.reset(nullptr)

    try:
        result = next(stream.generator)
    except StopIteration:
        # The generator is exhausted: report the end of the stream.
        payload.ipc_message.metadata.reset(<CBuffer*> nullptr)
        return

    # Items may come as a (data, metadata) pair or as bare data.
    if isinstance(result, (list, tuple)):
        result, metadata = result
    else:
        metadata = None

    if isinstance(result, (Table, _CRecordBatchReader)):
        if metadata:
            raise ValueError("Can only return metadata alongside a "
                             "RecordBatch.")
        # Tables and readers are served through a RecordBatchStream.
        result = RecordBatchStream(result)

    stream_schema = pyarrow_wrap_schema(stream.schema)
    if isinstance(result, FlightDataStream):
        if metadata:
            raise ValueError("Can only return metadata alongside a "
                             "RecordBatch.")
        data_stream = unique_ptr[CFlightDataStream](
            (<FlightDataStream> result).to_stream())
        substream_schema = pyarrow_wrap_schema(data_stream.get().schema())
        if substream_schema != stream_schema:
            raise ValueError("Got a FlightDataStream whose schema does not "
                             "match the declared schema of this "
                             "GeneratorStream. "
                             "Got: {}\nExpected: {}".format(substream_schema,
                                                            stream_schema))
        stream.current_stream.reset(
            new CPyFlightDataStream(result, move(data_stream)))
        # Pull the first payload out of the substream just installed.
        _data_stream_next(self, payload)
    elif isinstance(result, RecordBatch):
        batch = <RecordBatch> result
        if batch.schema != stream_schema:
            raise ValueError("Got a RecordBatch whose schema does not "
                             "match the declared schema of this "
                             "GeneratorStream. "
                             "Got: {}\nExpected: {}".format(batch.schema,
                                                            stream_schema))
        check_status(_GetRecordBatchPayload(
            deref(batch.batch),
            c_default_memory_pool(),
            &payload.ipc_message))
        if metadata:
            payload.app_metadata = pyarrow_unwrap_buffer(as_buffer(metadata))
    else:
        raise TypeError("GeneratorStream must be initialized with "
                        "an iterator of FlightDataStream, Table, "
                        "RecordBatch, or RecordBatchStreamReader objects, "
                        "not {}.".format(type(result)))
| |
| |
cdef void _list_flights(void* self, const CServerCallContext& context,
                        const CCriteria* c_criteria,
                        unique_ptr[CFlightListing]* listing) except *:
    """Callback that serves ListFlights from the Python server object.

    Calls the server's ``list_flights`` with the wrapped call context
    and the criteria expression, then copies every returned FlightInfo
    into a new C++ listing.
    """
    cdef:
        vector[CFlightInfo] c_infos
    py_server = <object> self
    py_context = ServerCallContext.wrap(context)
    py_infos = py_server.list_flights(py_context, c_criteria.expression)
    for py_info in py_infos:
        if not isinstance(py_info, FlightInfo):
            raise TypeError("FlightServerBase.list_flights must return "
                            "FlightInfo instances, but got {}".format(
                                type(py_info)))
        c_infos.push_back(deref((<FlightInfo> py_info).info.get()))
    listing.reset(new CSimpleFlightListing(c_infos))
| |
| |
cdef void _get_flight_info(void* self, const CServerCallContext& context,
                           CFlightDescriptor c_descriptor,
                           unique_ptr[CFlightInfo]* info) except *:
    """Callback for implementing GetFlightInfo in Python.

    Wraps ``c_descriptor`` in a Python FlightDescriptor, calls
    ``get_flight_info`` on the Python object behind ``self``, and stores
    a copy of the returned FlightInfo in ``info``.

    Raises TypeError if the method does not return a FlightInfo.
    """
    cdef FlightDescriptor wrapped = \
        FlightDescriptor.__new__(FlightDescriptor)

    wrapped.descriptor = c_descriptor
    handler = (<object> self).get_flight_info
    flight_info = handler(ServerCallContext.wrap(context), wrapped)
    if isinstance(flight_info, FlightInfo):
        info.reset(
            new CFlightInfo(deref((<FlightInfo> flight_info).info.get())))
        return
    raise TypeError("FlightServerBase.get_flight_info must return "
                    "a FlightInfo instance, but got {}".format(
                        type(flight_info)))
| |
| |
cdef void _do_put(void* self, const CServerCallContext& context,
                  unique_ptr[CFlightMessageReader] reader,
                  unique_ptr[CFlightMetadataWriter] writer) except *:
    """Callback for implementing DoPut in Python.

    Moves the C++ message reader and metadata writer into Python wrapper
    objects and passes them, together with the wrapped call context and
    the flight descriptor read from the reader, to ``do_put`` on the
    Python object behind ``self``.
    """
    cdef:
        MetadataRecordBatchReader batch_reader = MetadataRecordBatchReader()
        FlightMetadataWriter metadata_writer = FlightMetadataWriter()
        FlightDescriptor py_descriptor = \
            FlightDescriptor.__new__(FlightDescriptor)

    # The descriptor has to be read before ``reader`` is released below.
    py_descriptor.descriptor = reader.get().descriptor()

    # Hand ownership of the C++ reader and writer to the Python wrappers.
    batch_reader.reader.reset(reader.release())
    batch_reader.schema = pyarrow_wrap_schema(
        batch_reader.reader.get().schema())
    metadata_writer.writer.reset(writer.release())

    (<object> self).do_put(ServerCallContext.wrap(context), py_descriptor,
                           batch_reader, metadata_writer)
| |
| |
cdef void _do_get(void* self, const CServerCallContext& context,
                  CTicket ticket,
                  unique_ptr[CFlightDataStream]* stream) except *:
    """Callback for implementing DoGet in Python.

    Wraps the ticket, calls ``do_get`` on the Python object behind
    ``self``, converts the returned FlightDataStream to its C++ stream,
    and stores a CPyFlightDataStream built from the Python object and
    that C++ stream in ``stream``.

    Raises TypeError if the method does not return a FlightDataStream.
    """
    cdef unique_ptr[CFlightDataStream] inner_stream

    py_ticket = Ticket(ticket.ticket)
    py_stream = (<object> self).do_get(ServerCallContext.wrap(context),
                                       py_ticket)
    if not isinstance(py_stream, FlightDataStream):
        raise TypeError("FlightServerBase.do_get must return "
                        "a FlightDataStream")

    inner_stream.reset((<FlightDataStream> py_stream).to_stream())
    stream.reset(new CPyFlightDataStream(py_stream, move(inner_stream)))
| |
| |
cdef void _do_action_result_next(void* self,
                                 unique_ptr[CFlightResult]* result) except *:
    """Callback that fetches the next result of a DoAction call.

    ``self`` is treated as a Python iterator. The next item must be a
    Result; a copy of its C++ value is stored in ``result``. When the
    iterator is exhausted, ``result`` is reset to null instead.

    Raises TypeError if the iterator yields something other than a
    Result.
    """
    cdef CFlightResult* c_item

    try:
        item = next(<object> self)
    except StopIteration:
        # Exhausted iterator: report it with a null result.
        result.reset(nullptr)
        return

    if not isinstance(item, Result):
        raise TypeError("Result of FlightServerBase.do_action must "
                        "return an iterator of Result objects")
    c_item = (<Result> item).result.get()
    result.reset(new CFlightResult(deref(c_item)))
| |
| |
cdef void _do_action(void* self, const CServerCallContext& context,
                     const CAction& action,
                     unique_ptr[CResultStream]* result) except *:
    """Callback for implementing DoAction in Python.

    Wraps the C++ action in a Python Action, calls ``do_action`` on the
    Python object behind ``self``, and stores a CPyFlightResultStream
    in ``result`` that pulls items through ``_do_action_result_next``.

    ``do_action`` may return an iterator, any other iterable (such as a
    list), or None; None is treated as "no results".

    Raises TypeError if the returned value is not iterable.
    """
    cdef function[cb_result_next] next_callback = &_do_action_result_next

    py_action = Action(action.type, pyarrow_wrap_buffer(action.body))
    responses = (<object> self).do_action(ServerCallContext.wrap(context),
                                          py_action)
    if responses is None:
        # The application returned nothing: send an empty result stream.
        responses = ()
    # _do_action_result_next calls next() on the stored object, so it
    # must be an iterator. iter() returns an existing iterator unchanged
    # and converts any other iterable.
    result.reset(new CPyFlightResultStream(iter(responses), next_callback))
| |
| |
cdef void _list_actions(void* self, const CServerCallContext& context,
                        vector[CActionType]* actions) except *:
    """Callback for implementing ListActions in Python.

    Calls ``list_actions`` on the Python object behind ``self`` and
    appends one CActionType to ``actions`` per returned item. Each item
    is indexed as a sequence: element 0 is the action type and element 1
    is its description (an ActionType or a similar tuple).
    """
    cdef CActionType c_action

    py_actions = (<object> self).list_actions(
        ServerCallContext.wrap(context))
    for entry in py_actions:
        c_action.type = tobytes(entry[0])
        c_action.description = tobytes(entry[1])
        # push_back copies, so c_action can be reused for the next item.
        actions.push_back(c_action)
| |
| |
cdef void _server_authenticate(void* self, CServerAuthSender* outgoing,
                               CServerAuthReader* incoming) except *:
    """Callback for running the server side of the auth handshake.

    Wraps the C++ sender and reader in Python objects and passes them
    to ``authenticate`` on the Python object behind ``self``. Both
    wrappers are poisoned afterwards, whether or not ``authenticate``
    raised.
    """
    auth_sender = ServerAuthSender.wrap(outgoing)
    auth_reader = ServerAuthReader.wrap(incoming)
    try:
        (<object> self).authenticate(auth_sender, auth_reader)
    finally:
        # Always poison the wrappers once the handshake is over.
        auth_sender.poison()
        auth_reader.poison()
| |
| |
cdef void _is_valid(void* self, const c_string& token,
                    c_string* peer_identity) except *:
    """Callback for validating a client token in Python.

    Passes ``token`` to ``is_valid`` on the Python object behind
    ``self``, converts the return value with ``tobytes``, and writes it
    to ``peer_identity``.
    """
    cdef c_string identity

    py_identity = (<object> self).is_valid(token)
    identity = tobytes(py_identity)
    peer_identity[0] = identity
| |
| |
cdef void _client_authenticate(void* self, CClientAuthSender* outgoing,
                               CClientAuthReader* incoming) except *:
    """Callback for running the client side of the auth handshake.

    Wraps the C++ sender and reader in Python objects and passes them
    to ``authenticate`` on the Python object behind ``self``. Both
    wrappers are poisoned afterwards, whether or not ``authenticate``
    raised.
    """
    auth_sender = ClientAuthSender.wrap(outgoing)
    auth_reader = ClientAuthReader.wrap(incoming)
    try:
        (<object> self).authenticate(auth_sender, auth_reader)
    finally:
        # Always poison the wrappers once the handshake is over.
        auth_sender.poison()
        auth_reader.poison()
| |
| |
cdef void _get_token(void* self, c_string* token) except *:
    """Callback for fetching the client's auth token in Python.

    Calls ``get_token`` on the Python object behind ``self``, converts
    the return value with ``tobytes``, and writes it to ``token``.
    """
    cdef c_string c_token

    py_token = (<object> self).get_token()
    c_token = tobytes(py_token)
    token[0] = c_token
| |
| |
cdef class ServerAuthHandler:
    """Authentication middleware for a server.

    Subclass this and override ``authenticate`` and ``is_valid`` to
    implement an authentication mechanism; the base implementations
    raise NotImplementedError.

    """

    def authenticate(self, outgoing, incoming):
        """Conduct the handshake with the client.

        May raise an error if the client cannot authenticate.

        Parameters
        ----------
        outgoing : ServerAuthSender
            A channel to send messages to the client.
        incoming : ServerAuthReader
            A channel to read messages from the client.
        """
        raise NotImplementedError

    def is_valid(self, token):
        """Validate a client token, returning their identity.

        May return an empty string (if the auth mechanism does not
        name the peer) or raise an exception (if the token is
        invalid).

        Parameters
        ----------
        token : bytes
            The authentication token from the client.

        """
        raise NotImplementedError

    cdef PyServerAuthHandler* to_handler(self):
        # Build the callback table used to create the C++ handler that
        # wraps this object.
        cdef PyServerAuthHandlerVtable callbacks
        callbacks.is_valid = _is_valid
        callbacks.authenticate = _server_authenticate
        return new PyServerAuthHandler(self, callbacks)
| |
| |
cdef class ClientAuthHandler:
    """Authentication plugin for a client.

    Subclass this and override ``authenticate`` and ``get_token``; the
    base implementations raise NotImplementedError.
    """

    def authenticate(self, outgoing, incoming):
        """Conduct the handshake with the server.

        Parameters
        ----------
        outgoing : ClientAuthSender
            A channel to send messages to the server.
        incoming : ClientAuthReader
            A channel to read messages from the server.
        """
        raise NotImplementedError

    def get_token(self):
        """Get the auth token for a call."""
        raise NotImplementedError

    cdef PyClientAuthHandler* to_handler(self):
        # Build the callback table used to create the C++ handler that
        # wraps this object.
        cdef PyClientAuthHandlerVtable callbacks
        callbacks.get_token = _get_token
        callbacks.authenticate = _client_authenticate
        return new PyClientAuthHandler(self, callbacks)
| |
| |
cdef class FlightServerBase:
    """A Flight service definition.

    Override the RPC methods (``list_flights``, ``get_flight_info``,
    ``do_put``, ``do_get``, ``list_actions``, ``do_action``) to define
    your Flight service. The base implementations raise
    NotImplementedError.

    """

    cdef:
        unique_ptr[PyFlightServer] server

    def run(self, location, auth_handler=None, tls_certificates=None):
        """Start this server.

        Builds the server options, initializes the underlying server
        with them and then serves requests via ``ServeWithSignals``.

        Parameters
        ----------
        location : Location
        auth_handler : ServerAuthHandler
            An authentication mechanism to use. May be None.
        tls_certificates : list
            A list of (certificate, key) pairs.

        Raises
        ------
        TypeError
            If ``auth_handler`` is truthy but is not a
            ServerAuthHandler.
        """
        cdef:
            PyFlightServerVtable callbacks = PyFlightServerVtable()
            PyFlightServer* raw_server
            unique_ptr[CFlightServerOptions] options
            CCertKeyPair key_pair

        options.reset(new CFlightServerOptions(Location.unwrap(location)))

        if auth_handler:
            if not isinstance(auth_handler, ServerAuthHandler):
                raise TypeError("auth_handler must be a ServerAuthHandler, "
                                "not a '{}'".format(type(auth_handler)))
            options.get().auth_handler.reset(
                (<ServerAuthHandler> auth_handler).to_handler())

        # A falsy value (None, empty list) means no TLS certificates.
        for cert, key in tls_certificates or ():
            key_pair.pem_cert = tobytes(cert)
            key_pair.pem_key = tobytes(key)
            options.get().tls_certificates.push_back(key_pair)

        # Route every RPC to the module-level callbacks, which dispatch
        # to the Python methods of this object.
        callbacks.do_get = &_do_get
        callbacks.do_put = &_do_put
        callbacks.do_action = &_do_action
        callbacks.list_actions = &_list_actions
        callbacks.list_flights = &_list_flights
        callbacks.get_flight_info = &_get_flight_info

        raw_server = new PyFlightServer(self, callbacks)
        self.server.reset(raw_server)
        with nogil:
            check_status(raw_server.Init(deref(options)))
            check_status(raw_server.ServeWithSignals())

    def list_flights(self, context, criteria):
        """List available flights; must be overridden.

        Expected to return an iterable of FlightInfo objects.
        """
        raise NotImplementedError

    def get_flight_info(self, context, descriptor):
        """Describe a single flight; must be overridden.

        Expected to return a FlightInfo.
        """
        raise NotImplementedError

    def do_put(self, context, descriptor, reader,
               writer: FlightMetadataWriter):
        """Handle data uploaded by a client; must be overridden."""
        raise NotImplementedError

    def do_get(self, context, ticket):
        """Serve the data for a ticket; must be overridden.

        Expected to return a FlightDataStream.
        """
        raise NotImplementedError

    def list_actions(self, context):
        """List the supported actions; must be overridden.

        Expected to return an iterable of ActionType objects or
        similar (type, description) tuples.
        """
        raise NotImplementedError

    def do_action(self, context, action):
        """Execute an action; must be overridden.

        Expected to return an iterator of Result objects.
        """
        raise NotImplementedError

    def shutdown(self):
        """Shut down the server, blocking until current requests finish.

        Do not call this directly from the implementation of a Flight
        method, as then the server will block forever waiting for that
        request to finish. Instead, call this method from a background
        thread.

        Does nothing if no server has been created yet.
        """
        cdef PyFlightServer* active = self.server.get()

        if active == NULL:
            return
        # Must not hold the GIL: shutdown waits for pending RPCs to
        # complete. Holding the GIL means Python-implemented Flight
        # methods will never get to run, so this will hang
        # indefinitely.
        with nogil:
            active.Shutdown()