blob: e5d80df938051904baf914dd934acd966e960a66 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# cython: language_level = 3
# cython: embedsignature = True
import collections
import contextlib
import enum
import re
import socket
import time
import threading
import warnings
from cython.operator cimport dereference as deref
from cython.operator cimport postincrement
from libcpp cimport bool as c_bool
from pyarrow.lib cimport *
from pyarrow.lib import ArrowException, ArrowInvalid
from pyarrow.lib import as_buffer, frombytes, tobytes
from pyarrow.includes.libarrow_flight cimport *
from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin
import pyarrow.lib as lib
# Shared call options used when no FlightCallOptions is supplied:
# FlightCallOptions.unwrap returns a pointer to this module-level value
# for falsy arguments (e.g. None).
cdef CFlightCallOptions DEFAULT_CALL_OPTIONS
cdef int check_flight_status(const CStatus& status) nogil except -1:
    # Convert a C++ Status into a Python exception, or return 0 if it is OK.
    #
    # Declared nogil so it can be called inside `with nogil:` blocks; the
    # GIL is reacquired only on the error paths that build Python
    # exception objects. Dispatch order:
    #   1. A FlightStatusDetail whose code is one of those listed below
    #      raises the matching Flight*Error(message, extra_info).
    #   2. A FlightWriteSizeStatusDetail raises
    #      FlightWriteSizeExceededError(message, limit, actual).
    #   3. Anything else -- including a FlightStatusDetail whose code is
    #      not listed below -- is delegated to pyarrow's check_status.
    cdef shared_ptr[FlightStatusDetail] detail
    if status.ok():
        return 0
    detail = FlightStatusDetail.UnwrapStatus(status)
    if detail:
        with gil:
            message = frombytes(status.message(), safe=True)
            detail_msg = detail.get().extra_info()
            if detail.get().code() == CFlightStatusInternal:
                raise FlightInternalError(message, detail_msg)
            elif detail.get().code() == CFlightStatusFailed:
                # If the message has the shape matched by
                # _FLIGHT_SERVER_ERROR_REGEX (a wrapped Python exception),
                # rewrite it into a more readable form.
                message = _munge_grpc_python_error(message)
                raise FlightServerError(message, detail_msg)
            elif detail.get().code() == CFlightStatusTimedOut:
                raise FlightTimedOutError(message, detail_msg)
            elif detail.get().code() == CFlightStatusCancelled:
                raise FlightCancelledError(message, detail_msg)
            elif detail.get().code() == CFlightStatusUnauthenticated:
                raise FlightUnauthenticatedError(message, detail_msg)
            elif detail.get().code() == CFlightStatusUnauthorized:
                raise FlightUnauthorizedError(message, detail_msg)
            elif detail.get().code() == CFlightStatusUnavailable:
                raise FlightUnavailableError(message, detail_msg)
    # Write-size violations carry their own detail type with the limit
    # and the actual size, so callers can retry with a smaller batch.
    size_detail = FlightWriteSizeStatusDetail.UnwrapStatus(status)
    if size_detail:
        with gil:
            message = frombytes(status.message(), safe=True)
            raise FlightWriteSizeExceededError(
                message,
                size_detail.get().limit(), size_detail.get().actual())
    return check_status(status)
_FLIGHT_SERVER_ERROR_REGEX = re.compile(
r'Flight RPC failed with message: (.*). Detail: '
r'Python exception: (.*)',
re.DOTALL
)
def _munge_grpc_python_error(message):
    """Reformat a gRPC-wrapped Python server exception message.

    If *message* matches ``_FLIGHT_SERVER_ERROR_REGEX``, return
    ``Flight RPC failed with Python exception "<exception>: <message>"``;
    otherwise return *message* unchanged.
    """
    match = _FLIGHT_SERVER_ERROR_REGEX.match(message)
    if not match:
        return message
    original_message, python_exception = match.group(1, 2)
    return ('Flight RPC failed with Python exception \"{}: {}\"'
            .format(python_exception, original_message))
cdef IpcWriteOptions _get_options(options):
    # Resolve user-supplied IPC write options into an IpcWriteOptions via
    # pyarrow.ipc's helper, passing use_legacy_format=None.
    resolved = _get_legacy_format_default(use_legacy_format=None,
                                          options=options)
    return <IpcWriteOptions> resolved
cdef class FlightCallOptions(_Weakrefable):
    """RPC-layer options for a Flight call."""

    cdef:
        CFlightCallOptions options

    def __init__(self, timeout=None, write_options=None, headers=None):
        """Create call options.

        Parameters
        ----------
        timeout : float, None
            A timeout for the call, in seconds. None means that the
            timeout defaults to an implementation-specific value.
        write_options : pyarrow.ipc.IpcWriteOptions, optional
            IPC write options. The default options can be controlled
            by environment variables (see pyarrow.ipc).
        headers : List[Tuple[str, str]], optional
            A list of arbitrary headers as key, value tuples. Keys and
            values may be str or bytes; str is converted with tobytes.
        """
        cdef IpcWriteOptions c_write_options

        if timeout is not None:
            self.options.timeout = CTimeoutDuration(timeout)
        if write_options is not None:
            c_write_options = _get_options(write_options)
            self.options.write_options = c_write_options.c_options
        if headers is not None:
            # Headers are stored as pairs of std::string, which Cython only
            # builds from bytes; encode explicitly so the documented str
            # form works (bytes pass through tobytes unchanged).
            self.options.headers = [(tobytes(key), tobytes(value))
                                    for key, value in headers]

    @staticmethod
    cdef CFlightCallOptions* unwrap(obj):
        # Falsy values (e.g. None) select the shared module-level defaults.
        # Otherwise the returned pointer refers to storage inside `obj`,
        # so `obj` must stay alive while the pointer is used.
        if not obj:
            return &DEFAULT_CALL_OPTIONS
        elif isinstance(obj, FlightCallOptions):
            return &((<FlightCallOptions> obj).options)
        raise TypeError("Expected a FlightCallOptions object, not "
                        "'{}'".format(type(obj)))
_CertKeyPair = collections.namedtuple('_CertKeyPair', ['cert', 'key'])
class CertKeyPair(_CertKeyPair):
"""A TLS certificate and key for use in Flight."""
cdef class FlightError(Exception):
    """Base class for Flight-specific errors.

    Parameters
    ----------
    message : str, default ''
        Human-readable error message.
    extra_info : bytes or str, default b''
        Additional information attached to the error; stored as bytes in
        the ``extra_info`` attribute.
    """

    # An instance __dict__ lets Python attributes such as extra_info be set
    # on this extension type.
    cdef dict __dict__

    def __init__(self, message='', extra_info=b''):
        super().__init__(message)
        self.extra_info = tobytes(extra_info)

    cdef CStatus to_status(self):
        # Generic fallback: an UnknownError status prefixed with
        # "Flight error: ". The subclasses override this.
        cdef c_string text = tobytes("Flight error: {}".format(str(self)))
        return CStatus_UnknownError(text)
cdef class FlightInternalError(FlightError, ArrowException):
    """Flight error mapped to the CFlightStatusInternal status code."""

    cdef CStatus to_status(self):
        cdef c_string text = tobytes(str(self))
        return MakeFlightError(CFlightStatusInternal, text, self.extra_info)
cdef class FlightTimedOutError(FlightError, ArrowException):
    """Flight error mapped to the CFlightStatusTimedOut status code."""

    cdef CStatus to_status(self):
        cdef c_string text = tobytes(str(self))
        return MakeFlightError(CFlightStatusTimedOut, text, self.extra_info)
cdef class FlightCancelledError(FlightError, ArrowException):
    """Flight error mapped to the CFlightStatusCancelled status code."""

    cdef CStatus to_status(self):
        cdef c_string text = tobytes(str(self))
        return MakeFlightError(CFlightStatusCancelled, text, self.extra_info)
cdef class FlightServerError(FlightError, ArrowException):
    """Flight error mapped to the CFlightStatusFailed status code."""

    cdef CStatus to_status(self):
        cdef c_string text = tobytes(str(self))
        return MakeFlightError(CFlightStatusFailed, text, self.extra_info)
cdef class FlightUnauthenticatedError(FlightError, ArrowException):
    """Flight error mapped to the CFlightStatusUnauthenticated code."""

    cdef CStatus to_status(self):
        cdef c_string text = tobytes(str(self))
        return MakeFlightError(CFlightStatusUnauthenticated, text,
                               self.extra_info)
cdef class FlightUnauthorizedError(FlightError, ArrowException):
    """Flight error mapped to the CFlightStatusUnauthorized status code."""

    cdef CStatus to_status(self):
        cdef c_string text = tobytes(str(self))
        return MakeFlightError(CFlightStatusUnauthorized, text,
                               self.extra_info)
cdef class FlightUnavailableError(FlightError, ArrowException):
    """Flight error mapped to the CFlightStatusUnavailable status code."""

    cdef CStatus to_status(self):
        cdef c_string text = tobytes(str(self))
        return MakeFlightError(CFlightStatusUnavailable, text,
                               self.extra_info)
class FlightWriteSizeExceededError(ArrowInvalid):
    """A write operation exceeded the client-configured limit.

    Attributes
    ----------
    limit
        The limit that was exceeded.
    actual
        The actual size of the attempted write.
    """

    def __init__(self, message, limit, actual):
        super().__init__(message)
        self.limit, self.actual = limit, actual
cdef class Action(_Weakrefable):
    """An action executable on a Flight service."""

    cdef:
        CAction action

    def __init__(self, action_type, buf):
        """Create an action from a type and a buffer.

        Parameters
        ----------
        action_type : bytes or str
            The action name; str is converted with tobytes.
        buf : Buffer or bytes-like object
            The action body.
        """
        self.action.type = tobytes(action_type)
        body = as_buffer(buf)
        self.action.body = pyarrow_unwrap_buffer(body)

    @property
    def type(self):
        """The action type, decoded with frombytes."""
        return frombytes(self.action.type)

    @property
    def body(self):
        """The action body (arguments for the action), as a Buffer."""
        return pyarrow_wrap_buffer(self.action.body)

    @staticmethod
    cdef CAction unwrap(action) except *:
        # Copy out the C++ action, rejecting anything that is not an Action.
        if isinstance(action, Action):
            return (<Action> action).action
        raise TypeError("Must provide Action, not '{}'".format(
            type(action)))
_ActionType = collections.namedtuple('_ActionType', 'type description')


class ActionType(_ActionType):
    """A type of action that is executable on a Flight service.

    A named 2-tuple with fields ``type`` and ``description``.
    """

    def make_action(self, buf):
        """Create an Action with this type.

        Parameters
        ----------
        buf : obj
            An Arrow buffer or Python bytes or bytes-like object.

        Returns
        -------
        Action
        """
        action_type = self.type
        return Action(action_type, buf)
cdef class Result(_Weakrefable):
    """A result from executing an Action."""

    cdef:
        unique_ptr[CFlightResult] result

    def __init__(self, buf):
        """Create a new result.

        Parameters
        ----------
        buf : Buffer or bytes-like object
            The result body.
        """
        cdef CFlightResult* c_result = new CFlightResult()
        self.result.reset(c_result)
        c_result.body = pyarrow_unwrap_buffer(as_buffer(buf))

    @property
    def body(self):
        """The result payload, as a Buffer."""
        return pyarrow_wrap_buffer(self.result.get().body)
cdef class BasicAuth(_Weakrefable):
    """A container for basic auth."""

    cdef:
        unique_ptr[CBasicAuth] basic_auth

    def __init__(self, username=None, password=None):
        """Create a new basic auth object.

        Parameters
        ----------
        username : string
        password : string
        """
        self.basic_auth.reset(new CBasicAuth())
        if username:
            self.basic_auth.get().username = tobytes(username)
        if password:
            self.basic_auth.get().password = tobytes(password)

    @property
    def username(self):
        """Get the username (as bytes)."""
        return self.basic_auth.get().username

    @property
    def password(self):
        """Get the password (as bytes)."""
        return self.basic_auth.get().password

    @staticmethod
    def deserialize(string):
        """Parse a BasicAuth from its serialized form.

        Parameters
        ----------
        string : bytes or str
            The serialized data. str is converted with tobytes, so the str
            returned by serialize() can be passed back directly.
        """
        cdef BasicAuth auth = BasicAuth()
        # The C++ deserializer takes a std::string, which Cython only
        # builds from bytes; encode str so serialize() output round-trips.
        check_flight_status(
            DeserializeBasicAuth(tobytes(string), &auth.basic_auth))
        return auth

    def serialize(self):
        """Serialize this object, returning the result decoded via frombytes.

        NOTE(review): the serialized form appears to be binary; decoding it
        as text may fail for some inputs. Consider returning bytes in a
        future API revision.
        """
        cdef:
            c_string auth
        check_flight_status(SerializeBasicAuth(deref(self.basic_auth), &auth))
        return frombytes(auth)
class DescriptorType(enum.Enum):
    """
    Kind of a FlightDescriptor.

    Attributes
    ----------
    UNKNOWN
        The descriptor type is not known.
    PATH
        The stream is identified by a path.
    CMD
        The stream is identified by an application-defined command.
    """

    UNKNOWN = 0
    PATH = 1
    CMD = 2
class FlightMethod(enum.Enum):
    """The implemented methods in Flight.

    INVALID is used for any method value that is not recognized.
    """

    INVALID = 0
    HANDSHAKE = 1
    LIST_FLIGHTS = 2
    GET_FLIGHT_INFO = 3
    GET_SCHEMA = 4
    DO_GET = 5
    DO_PUT = 6
    DO_ACTION = 7
    LIST_ACTIONS = 8
    DO_EXCHANGE = 9
cdef wrap_flight_method(CFlightMethod method):
    # Map a C++ CFlightMethod onto the Python FlightMethod enum; values
    # not listed here map to FlightMethod.INVALID.
    wrapped = FlightMethod.INVALID
    if method == CFlightMethodHandshake:
        wrapped = FlightMethod.HANDSHAKE
    elif method == CFlightMethodListFlights:
        wrapped = FlightMethod.LIST_FLIGHTS
    elif method == CFlightMethodGetFlightInfo:
        wrapped = FlightMethod.GET_FLIGHT_INFO
    elif method == CFlightMethodGetSchema:
        wrapped = FlightMethod.GET_SCHEMA
    elif method == CFlightMethodDoGet:
        wrapped = FlightMethod.DO_GET
    elif method == CFlightMethodDoPut:
        wrapped = FlightMethod.DO_PUT
    elif method == CFlightMethodDoAction:
        wrapped = FlightMethod.DO_ACTION
    elif method == CFlightMethodListActions:
        wrapped = FlightMethod.LIST_ACTIONS
    elif method == CFlightMethodDoExchange:
        wrapped = FlightMethod.DO_EXCHANGE
    return wrapped
cdef class FlightDescriptor(_Weakrefable):
    """A description of a data stream available from a Flight service."""

    cdef:
        CFlightDescriptor descriptor

    def __init__(self):
        # Braces around "path,command" are doubled so str.format treats
        # them as literal text rather than a (nonexistent) replacement
        # field, which previously raised KeyError instead of this
        # TypeError.
        raise TypeError("Do not call {}'s constructor directly, use "
                        "`pyarrow.flight.FlightDescriptor.for_{{path,command}}` "
                        "function instead."
                        .format(self.__class__.__name__))

    @staticmethod
    def for_path(*path):
        """Create a FlightDescriptor for a resource path.

        Each path component is converted with tobytes.
        """
        cdef FlightDescriptor result = \
            FlightDescriptor.__new__(FlightDescriptor)
        result.descriptor.type = CDescriptorTypePath
        result.descriptor.path = [tobytes(p) for p in path]
        return result

    @staticmethod
    def for_command(command):
        """Create a FlightDescriptor for an opaque command."""
        cdef FlightDescriptor result = \
            FlightDescriptor.__new__(FlightDescriptor)
        result.descriptor.type = CDescriptorTypeCmd
        result.descriptor.cmd = tobytes(command)
        return result

    @property
    def descriptor_type(self):
        """Get the type of this descriptor as a DescriptorType."""
        if self.descriptor.type == CDescriptorTypeUnknown:
            return DescriptorType.UNKNOWN
        elif self.descriptor.type == CDescriptorTypePath:
            return DescriptorType.PATH
        elif self.descriptor.type == CDescriptorTypeCmd:
            return DescriptorType.CMD
        raise RuntimeError("Invalid descriptor type!")

    @property
    def command(self):
        """Get the command (bytes) for this descriptor, or None if it is
        not a command descriptor."""
        if self.descriptor_type != DescriptorType.CMD:
            return None
        return self.descriptor.cmd

    @property
    def path(self):
        """Get the path components (list of bytes) for this descriptor, or
        None if it is not a path descriptor."""
        if self.descriptor_type != DescriptorType.PATH:
            return None
        return self.descriptor.path

    def __repr__(self):
        if self.descriptor_type == DescriptorType.PATH:
            return "<FlightDescriptor path: {!r}>".format(self.path)
        elif self.descriptor_type == DescriptorType.CMD:
            return "<FlightDescriptor command: {!r}>".format(self.command)
        else:
            return "<FlightDescriptor type: {!r}>".format(self.descriptor_type)

    @staticmethod
    cdef CFlightDescriptor unwrap(descriptor) except *:
        if not isinstance(descriptor, FlightDescriptor):
            raise TypeError("Must provide a FlightDescriptor, not '{}'".format(
                type(descriptor)))
        return (<FlightDescriptor> descriptor).descriptor

    def serialize(self):
        """Get the wire-format representation of this type.

        Useful when interoperating with non-Flight systems (e.g. REST
        services) that may want to return Flight types.
        """
        cdef c_string out
        check_flight_status(self.descriptor.SerializeToString(&out))
        return out

    @classmethod
    def deserialize(cls, serialized):
        """Parse the wire-format representation of this type.

        Useful when interoperating with non-Flight systems (e.g. REST
        services) that may want to return Flight types.
        """
        cdef FlightDescriptor descriptor = \
            FlightDescriptor.__new__(FlightDescriptor)
        check_flight_status(CFlightDescriptor.Deserialize(
            tobytes(serialized), &descriptor.descriptor))
        return descriptor

    def __eq__(self, other):
        # A typed `FlightDescriptor other` argument would let None through
        # and then read a field from the wrong object; compare only
        # against real descriptors and defer otherwise.
        if not isinstance(other, FlightDescriptor):
            return NotImplemented
        return self.descriptor == (<FlightDescriptor> other).descriptor
cdef class Ticket(_Weakrefable):
    """A ticket for requesting a Flight stream."""

    cdef:
        CTicket ticket

    def __init__(self, ticket):
        """Create a ticket from its contents (bytes or str)."""
        self.ticket.ticket = tobytes(ticket)

    @property
    def ticket(self):
        """The ticket contents, as bytes."""
        return self.ticket.ticket

    def serialize(self):
        """Get the wire-format representation of this type.

        Useful when interoperating with non-Flight systems (e.g. REST
        services) that may want to return Flight types.
        """
        cdef c_string out
        check_flight_status(self.ticket.SerializeToString(&out))
        return out

    @classmethod
    def deserialize(cls, serialized):
        """Parse the wire-format representation of this type.

        Useful when interoperating with non-Flight systems (e.g. REST
        services) that may want to return Flight types.
        """
        cdef:
            CTicket c_ticket
            Ticket ticket
        check_flight_status(
            CTicket.Deserialize(tobytes(serialized), &c_ticket))
        ticket = Ticket.__new__(Ticket)
        ticket.ticket = c_ticket
        return ticket

    def __eq__(self, other):
        # Avoid a typed `Ticket other` argument, which admits None and
        # would then read a field from the wrong object.
        if not isinstance(other, Ticket):
            return NotImplemented
        return self.ticket == (<Ticket> other).ticket

    def __repr__(self):
        return '<Ticket {}>'.format(self.ticket.ticket)
cdef class Location(_Weakrefable):
    """The location of a Flight service."""

    cdef:
        CLocation location

    def __init__(self, uri):
        """Parse a Location from a URI given as str or bytes."""
        check_flight_status(CLocation.Parse(tobytes(uri), &self.location))

    def __repr__(self):
        return '<Location {}>'.format(self.location.ToString())

    @property
    def uri(self):
        """The URI of this location, as returned by ToString()."""
        return self.location.ToString()

    def equals(self, Location other):
        """Return whether this location equals *other*."""
        return self == other

    def __eq__(self, other):
        if isinstance(other, Location):
            return self.location.Equals((<Location> other).location)
        return NotImplemented

    @staticmethod
    def for_grpc_tcp(host, port):
        """Create a Location for a TCP-based gRPC service."""
        cdef:
            c_string host_bytes = tobytes(host)
            int port_number = port
            Location created = Location.__new__(Location)
        check_flight_status(
            CLocation.ForGrpcTcp(host_bytes, port_number, &created.location))
        return created

    @staticmethod
    def for_grpc_tls(host, port):
        """Create a Location for a TLS-based gRPC service."""
        cdef:
            c_string host_bytes = tobytes(host)
            int port_number = port
            Location created = Location.__new__(Location)
        check_flight_status(
            CLocation.ForGrpcTls(host_bytes, port_number, &created.location))
        return created

    @staticmethod
    def for_grpc_unix(path):
        """Create a Location for a domain socket-based gRPC service."""
        cdef:
            c_string path_bytes = tobytes(path)
            Location created = Location.__new__(Location)
        check_flight_status(
            CLocation.ForGrpcUnix(path_bytes, &created.location))
        return created

    @staticmethod
    cdef Location wrap(CLocation location):
        # Box a C++ CLocation into a new Python Location (copying it).
        cdef Location wrapped = Location.__new__(Location)
        wrapped.location = location
        return wrapped

    @staticmethod
    cdef CLocation unwrap(object location) except *:
        # Accept either a Location or a URI string; anything else is a
        # TypeError.
        cdef CLocation c_location
        if isinstance(location, Location):
            return (<Location> location).location
        if isinstance(location, str):
            check_flight_status(
                CLocation.Parse(tobytes(location), &c_location))
            return c_location
        raise TypeError("Must provide a Location, not '{}'".format(
            type(location)))
cdef class FlightEndpoint(_Weakrefable):
    """A Flight stream, along with the ticket and locations to access it."""

    cdef:
        CFlightEndpoint endpoint

    def __init__(self, ticket, locations):
        """Create a FlightEndpoint from a ticket and list of locations.

        Parameters
        ----------
        ticket : Ticket or bytes
            the ticket needed to access this flight
        locations : list of string URIs or Location objects
            locations where this flight is available

        Raises
        ------
        ArrowException
            If one of the location URIs is not a valid URI.
        """
        cdef:
            CLocation c_location

        if isinstance(ticket, Ticket):
            self.endpoint.ticket.ticket = tobytes(ticket.ticket)
        else:
            self.endpoint.ticket.ticket = tobytes(ticket)

        for location in locations:
            if isinstance(location, Location):
                c_location = (<Location> location).location
            else:
                c_location = CLocation()
                check_flight_status(
                    CLocation.Parse(tobytes(location), &c_location))
            self.endpoint.locations.push_back(c_location)

    @property
    def ticket(self):
        """Get the ticket in this endpoint."""
        return Ticket(self.endpoint.ticket.ticket)

    @property
    def locations(self):
        """Get the locations of this endpoint as a list of Location."""
        return [Location.wrap(location)
                for location in self.endpoint.locations]

    def __repr__(self):
        return "<FlightEndpoint ticket: {!r} locations: {!r}>".format(
            self.ticket, self.locations)

    def __eq__(self, other):
        # Avoid a typed `FlightEndpoint other` argument, which admits None
        # and would then read a field from the wrong object.
        if not isinstance(other, FlightEndpoint):
            return NotImplemented
        return self.endpoint == (<FlightEndpoint> other).endpoint
cdef class SchemaResult(_Weakrefable):
    """A result from a getschema request. Holding a schema"""

    cdef:
        unique_ptr[CSchemaResult] result

    def __init__(self, Schema schema not None):
        """Create a SchemaResult from a schema.

        Parameters
        ----------
        schema : Schema
            the schema of the data in this flight.

        Raises
        ------
        TypeError
            If schema is None.
        """
        # `not None`: a typed Schema argument otherwise admits None, which
        # would be unwrapped into an empty shared_ptr and handed to C++.
        cdef:
            shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
        check_flight_status(CreateSchemaResult(c_schema, &self.result))

    @property
    def schema(self):
        """The schema of the data in this flight."""
        cdef:
            shared_ptr[CSchema] schema
            CDictionaryMemo dummy_memo
        check_flight_status(self.result.get().GetSchema(&dummy_memo, &schema))
        return pyarrow_wrap_schema(schema)
cdef class FlightInfo(_Weakrefable):
    """A description of a Flight stream."""

    cdef:
        unique_ptr[CFlightInfo] info

    def __init__(self, Schema schema not None,
                 FlightDescriptor descriptor not None, endpoints,
                 total_records, total_bytes):
        """Create a FlightInfo object from a schema, descriptor, and endpoints.

        Parameters
        ----------
        schema : Schema
            the schema of the data in this flight.
        descriptor : FlightDescriptor
            the descriptor for this flight.
        endpoints : list of FlightEndpoint
            a list of endpoints where this flight is available.
        total_records : int
            the total records in this flight, or -1 if unknown
        total_bytes : int
            the total bytes in this flight, or -1 if unknown

        Raises
        ------
        TypeError
            If schema or descriptor is None, or an endpoint is not a
            FlightEndpoint.
        """
        # `not None` on the typed arguments: otherwise None is accepted and
        # `descriptor.descriptor` would read a field from the wrong object.
        cdef:
            shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
            vector[CFlightEndpoint] c_endpoints

        for endpoint in endpoints:
            if not isinstance(endpoint, FlightEndpoint):
                raise TypeError('Endpoint {} is not instance of'
                                ' FlightEndpoint'.format(endpoint))
            c_endpoints.push_back((<FlightEndpoint> endpoint).endpoint)

        check_flight_status(CreateFlightInfo(c_schema,
                                             descriptor.descriptor,
                                             c_endpoints,
                                             total_records,
                                             total_bytes, &self.info))

    @property
    def total_records(self):
        """The total record count of this flight, or -1 if unknown."""
        return self.info.get().total_records()

    @property
    def total_bytes(self):
        """The size in bytes of the data in this flight, or -1 if unknown."""
        return self.info.get().total_bytes()

    @property
    def schema(self):
        """The schema of the data in this flight."""
        cdef:
            shared_ptr[CSchema] schema
            CDictionaryMemo dummy_memo
        check_flight_status(self.info.get().GetSchema(&dummy_memo, &schema))
        return pyarrow_wrap_schema(schema)

    @property
    def descriptor(self):
        """The descriptor of the data in this flight."""
        cdef FlightDescriptor result = \
            FlightDescriptor.__new__(FlightDescriptor)
        result.descriptor = self.info.get().descriptor()
        return result

    @property
    def endpoints(self):
        """The endpoints where this flight is available."""
        # TODO: get Cython to iterate over reference directly
        cdef:
            vector[CFlightEndpoint] endpoints = self.info.get().endpoints()
            FlightEndpoint py_endpoint

        result = []
        for endpoint in endpoints:
            py_endpoint = FlightEndpoint.__new__(FlightEndpoint)
            py_endpoint.endpoint = endpoint
            result.append(py_endpoint)
        return result

    def serialize(self):
        """Get the wire-format representation of this type.

        Useful when interoperating with non-Flight systems (e.g. REST
        services) that may want to return Flight types.
        """
        cdef c_string out
        check_flight_status(self.info.get().SerializeToString(&out))
        return out

    @classmethod
    def deserialize(cls, serialized):
        """Parse the wire-format representation of this type.

        Useful when interoperating with non-Flight systems (e.g. REST
        services) that may want to return Flight types.
        """
        cdef FlightInfo info = FlightInfo.__new__(FlightInfo)
        check_flight_status(CFlightInfo.Deserialize(
            tobytes(serialized), &info.info))
        return info
cdef class FlightStreamChunk(_Weakrefable):
    """A RecordBatch with application metadata on the side.

    Iterating a chunk yields ``(data, app_metadata)``, so it can be
    unpacked like a 2-tuple.
    """

    cdef:
        CFlightStreamChunk chunk

    @property
    def data(self):
        """The RecordBatch of this chunk, or None if absent."""
        return (None if self.chunk.data == NULL
                else pyarrow_wrap_batch(self.chunk.data))

    @property
    def app_metadata(self):
        """The application metadata Buffer of this chunk, or None."""
        return (None if self.chunk.app_metadata == NULL
                else pyarrow_wrap_buffer(self.chunk.app_metadata))

    def __iter__(self):
        pair = (self.data, self.app_metadata)
        return iter(pair)

    def __repr__(self):
        has_data = self.chunk.data != NULL
        has_metadata = self.chunk.app_metadata != NULL
        return "<FlightStreamChunk with data: {} with metadata: {}>".format(
            has_data, has_metadata)
cdef class _MetadataRecordBatchReader(_Weakrefable, _ReadPandasMixin):
    """A reader for Flight streams."""

    # Needs to be separate class so the "real" class can subclass the
    # pure-Python mixin class

    cdef dict __dict__
    cdef shared_ptr[CMetadataRecordBatchReader] reader

    def __iter__(self):
        # read_chunk() signals end-of-stream by raising StopIteration.
        # Under PEP 479 (generator_stop) semantics, a StopIteration that
        # escapes a generator body becomes RuntimeError, so translate it
        # into a normal end of iteration here.
        while True:
            try:
                chunk = self.read_chunk()
            except StopIteration:
                return
            yield chunk

    @property
    def schema(self):
        """Get the schema for this reader."""
        cdef shared_ptr[CSchema] c_schema
        with nogil:
            c_schema = GetResultValue(self.reader.get().GetSchema())
        return pyarrow_wrap_schema(c_schema)

    def read_all(self):
        """Read the entire contents of the stream as a Table."""
        cdef:
            shared_ptr[CTable] c_table
        with nogil:
            check_flight_status(self.reader.get().ReadAll(&c_table))
        return pyarrow_wrap_table(c_table)

    def read_chunk(self):
        """Read the next RecordBatch along with any metadata.

        Returns
        -------
        FlightStreamChunk
            Unpackable as ``(data, app_metadata)``, where ``data`` is the
            next RecordBatch (or None) and ``app_metadata`` is the
            application-specific metadata Buffer (or None).

        Raises
        ------
        StopIteration
            when the stream is finished
        """
        cdef:
            FlightStreamChunk chunk = FlightStreamChunk()

        with nogil:
            check_flight_status(self.reader.get().Next(&chunk.chunk))

        if chunk.chunk.data == NULL and chunk.chunk.app_metadata == NULL:
            raise StopIteration

        return chunk

    def to_reader(self):
        """Convert this reader into a regular RecordBatchReader.

        This may fail if the schema cannot be read from the remote end.
        """
        cdef RecordBatchReader reader
        reader = RecordBatchReader.__new__(RecordBatchReader)
        reader.reader = GetResultValue(MakeRecordBatchReader(self.reader))
        return reader
cdef class MetadataRecordBatchReader(_MetadataRecordBatchReader):
    """The virtual base class for readers for Flight streams.

    All functionality is inherited from _MetadataRecordBatchReader.
    """
cdef class FlightStreamReader(MetadataRecordBatchReader):
    """A reader that can also be canceled."""

    def cancel(self):
        """Cancel the read operation."""
        cdef CFlightStreamReader* stream_reader = \
            <CFlightStreamReader*> self.reader.get()
        with nogil:
            stream_reader.Cancel()
cdef class MetadataRecordBatchWriter(_CRecordBatchWriter):
    """A RecordBatchWriter that also allows writing application metadata.

    This class is a context manager; on exit, close() will be called.
    """

    cdef CMetadataRecordBatchWriter* _writer(self) nogil:
        # View the base class's writer as a metadata-capable writer.
        return <CMetadataRecordBatchWriter*> self.writer.get()

    def begin(self, schema: Schema, options=None):
        """Prepare to write data to this stream with the given schema."""
        cdef:
            shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
            CIpcWriteOptions c_options = _get_options(options).c_options
        with nogil:
            check_flight_status(self._writer().Begin(c_schema, c_options))

    def write_metadata(self, buf):
        """Write Flight metadata by itself."""
        cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf))
        with nogil:
            check_flight_status(
                self._writer().WriteMetadata(c_buf))

    def write_batch(self, RecordBatch batch not None):
        """
        Write RecordBatch to stream.

        Parameters
        ----------
        batch : RecordBatch

        Raises
        ------
        TypeError
            If batch is None.
        FlightWriteSizeExceededError
            If the serialized batch exceeds the client's write size limit.
        """
        # Override superclass method to use check_flight_status so we
        # can generate FlightWriteSizeExceededError. We don't do this
        # for write_table as callers who intend to handle the error
        # and retry with a smaller batch should be working with
        # individual batches to have control.
        # `not None` guards the deref(batch.batch) below.
        with nogil:
            check_flight_status(
                self.writer.get().WriteRecordBatch(deref(batch.batch)))

    def write_with_metadata(self, RecordBatch batch not None, buf):
        """Write a RecordBatch along with Flight metadata.

        Parameters
        ----------
        batch : RecordBatch
            The next RecordBatch in the stream.
        buf : Buffer
            Application-specific metadata for the batch as defined by
            Flight.
        """
        cdef shared_ptr[CBuffer] c_buf = pyarrow_unwrap_buffer(as_buffer(buf))
        with nogil:
            check_flight_status(
                self._writer().WriteWithMetadata(deref(batch.batch), c_buf))
cdef class FlightStreamWriter(MetadataRecordBatchWriter):
    """A writer that also allows closing the write side of a stream."""

    def done_writing(self):
        """Indicate that the client is done writing, but not done reading."""
        cdef CFlightStreamWriter* stream_writer = \
            <CFlightStreamWriter*> self.writer.get()
        with nogil:
            check_flight_status(stream_writer.DoneWriting())
cdef class FlightMetadataReader(_Weakrefable):
    """A reader for Flight metadata messages sent during a DoPut."""

    cdef:
        unique_ptr[CFlightMetadataReader] reader

    def read(self):
        """Read the next metadata message.

        Returns
        -------
        Buffer or None
            None when the underlying reader produces a null buffer.
        """
        cdef shared_ptr[CBuffer] message
        with nogil:
            check_flight_status(self.reader.get().ReadMetadata(&message))
        return None if message == NULL else pyarrow_wrap_buffer(message)
cdef class FlightMetadataWriter(_Weakrefable):
    """A sender for Flight metadata messages during a DoPut."""

    cdef:
        unique_ptr[CFlightMetadataWriter] writer

    def write(self, message):
        """Write the next metadata message.

        Parameters
        ----------
        message : Buffer
            Anything accepted by as_buffer (e.g. a Buffer or bytes).
        """
        cdef shared_ptr[CBuffer] c_message
        c_message = pyarrow_unwrap_buffer(as_buffer(message))
        with nogil:
            check_flight_status(
                self.writer.get().WriteMetadata(deref(c_message)))
cdef class FlightClient(_Weakrefable):
"""A client to a Flight service.
Connect to a Flight service on the given host and port.
Parameters
----------
location : str, tuple or Location
Location to connect to. Either a gRPC URI like `grpc://localhost:port`,
a tuple of (host, port) pair, or a Location instance.
tls_root_certs : bytes or None
PEM-encoded TLS root certificates
cert_chain : bytes or None
Client certificate if using mutual TLS
private_key : bytes or None
Client private key for cert_chain if using mutual TLS
override_hostname : str or None
Override the hostname checked by TLS. Insecure, use with caution.
middleware : list optional, default None
A list of ClientMiddlewareFactory instances.
write_size_limit_bytes : int optional, default None
A soft limit on the size of a data payload sent to the
server. Enabled if positive. If enabled, writing a record
batch that (when serialized) exceeds this limit will raise an
exception; the client can retry the write with a smaller
batch.
disable_server_verification : boolean optional, default False
A flag that indicates that, if the client is connecting
with TLS, that it skips server verification. If this is
enabled, all other TLS settings are overridden.
generic_options : list optional, default None
A list of generic (string, int or string) option tuples passed
to the underlying transport. Effect is implementation
dependent.
"""
cdef:
unique_ptr[CFlightClient] client
def __init__(self, location, *, tls_root_certs=None, cert_chain=None,
             private_key=None, override_hostname=None, middleware=None,
             write_size_limit_bytes=None,
             disable_server_verification=None, generic_options=None):
    # Normalize `location` to a Location: a URI string/bytes is parsed, a
    # (host, port) tuple becomes a TLS location when TLS settings are
    # given (root certs, or disable_server_verification set to anything
    # but None) and a plain TCP location otherwise.
    if isinstance(location, (bytes, str)):
        location = Location(location)
    elif isinstance(location, tuple):
        host, port = location
        wants_tls = (bool(tls_root_certs)
                     or disable_server_verification is not None)
        make_location = (Location.for_grpc_tls if wants_tls
                         else Location.for_grpc_tcp)
        location = make_location(host, port)
    elif not isinstance(location, Location):
        raise TypeError('`location` argument must be a string, tuple or a '
                        'Location instance')

    self.init(location, tls_root_certs, cert_chain, private_key,
              override_hostname, middleware, write_size_limit_bytes,
              disable_server_verification, generic_options)
cdef init(self, Location location, tls_root_certs, cert_chain,
          private_key, override_hostname, middleware,
          write_size_limit_bytes, disable_server_verification,
          generic_options):
    # Build CFlightClientOptions from the Python-level arguments, then
    # connect and store the resulting client in self.client.
    # (The unused `c_port` local from the previous version was removed.)
    cdef:
        CLocation c_location = Location.unwrap(location)
        CFlightClientOptions c_options = CFlightClientOptions.Defaults()
        function[cb_client_middleware_start_call] start_call = \
            &_client_middleware_start_call
        CIntStringVariant variant

    if tls_root_certs:
        c_options.tls_root_certs = tobytes(tls_root_certs)
    if cert_chain:
        c_options.cert_chain = tobytes(cert_chain)
    if private_key:
        c_options.private_key = tobytes(private_key)
    if override_hostname:
        c_options.override_hostname = tobytes(override_hostname)
    if disable_server_verification is not None:
        c_options.disable_server_verification = disable_server_verification
    if middleware:
        for factory in middleware:
            # Wrap each Python factory in a CPyClientMiddlewareFactory that
            # is given the start_call callback.
            c_options.middleware.push_back(
                <shared_ptr[CClientMiddlewareFactory]>
                make_shared[CPyClientMiddlewareFactory](
                    <PyObject*> factory, start_call))
    # None maps to 0; per the class docstring the limit is only enabled
    # when positive.
    if write_size_limit_bytes is not None:
        c_options.write_size_limit_bytes = write_size_limit_bytes
    else:
        c_options.write_size_limit_bytes = 0
    if generic_options:
        for key, value in generic_options:
            # str/bytes values are passed as strings; anything else is
            # converted to a C int.
            if isinstance(value, (str, bytes)):
                variant = CIntStringVariant(<c_string> tobytes(value))
            else:
                variant = CIntStringVariant(<int> value)
            c_options.generic_options.push_back(
                pair[c_string, CIntStringVariant](tobytes(key), variant))

    with nogil:
        check_flight_status(CFlightClient.Connect(c_location, c_options,
                                                  &self.client))
def wait_for_available(self, timeout=5):
    """Block until the server can be contacted.

    Repeatedly calls ``list_flights`` until it succeeds (or the server
    reports it is not implemented, which still proves it is reachable).

    Parameters
    ----------
    timeout : int, default 5
        The maximum seconds to wait.

    Raises
    ------
    FlightUnavailableError
        If the server is still unavailable once `timeout` has elapsed.
    """
    # time.monotonic() cannot jump with wall-clock adjustments, unlike
    # time.time(), so the deadline is neither skipped nor extended by a
    # system clock change.
    deadline = time.monotonic() + timeout
    while True:
        try:
            list(self.list_flights())
        except FlightUnavailableError:
            if time.monotonic() >= deadline:
                raise
            time.sleep(0.025)
        except NotImplementedError:
            # allow if list_flights is not implemented, because
            # the server can be contacted nonetheless
            return
        else:
            return
@classmethod
def connect(cls, location, tls_root_certs=None, cert_chain=None,
            private_key=None, override_hostname=None,
            disable_server_verification=None):
    """Deprecated: use the FlightClient constructor instead.

    Emits a FutureWarning and forwards the arguments to the constructor
    of `cls`, so subclasses get instances of their own type.
    """
    # FutureWarning (shown by default) marks this as a deprecation rather
    # than a generic UserWarning.
    warnings.warn("The 'FlightClient.connect' method is deprecated, use "
                  "FlightClient constructor or pyarrow.flight.connect "
                  "function instead", FutureWarning)
    return cls(
        location, tls_root_certs=tls_root_certs,
        cert_chain=cert_chain, private_key=private_key,
        override_hostname=override_hostname,
        disable_server_verification=disable_server_verification
    )
def authenticate(self, auth_handler, options: FlightCallOptions = None):
    """Authenticate to the server.

    Parameters
    ----------
    auth_handler : ClientAuthHandler
        The authentication mechanism to use.
    options : FlightCallOptions
        Options for this call.

    Raises
    ------
    TypeError
        If auth_handler is not a ClientAuthHandler.
    """
    cdef:
        unique_ptr[CClientAuthHandler] c_handler
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)

    if not isinstance(auth_handler, ClientAuthHandler):
        raise TypeError(
            "FlightClient.authenticate takes a ClientAuthHandler, "
            "not '{}'".format(type(auth_handler)))
    # Ownership of the C++ handler is moved into the client call.
    c_handler.reset((<ClientAuthHandler> auth_handler).to_handler())
    with nogil:
        check_flight_status(
            self.client.get().Authenticate(deref(c_options),
                                           move(c_handler)))
def authenticate_basic_token(self, username, password,
                             options: FlightCallOptions = None):
    """Authenticate to the server with HTTP basic authentication.

    Parameters
    ----------
    username : string
        Username to authenticate with
    password : string
        Password to authenticate with
    options : FlightCallOptions
        Options for this call

    Returns
    -------
    tuple : Tuple[bytes, bytes]
        A (key, value) pair representing the FlightCallOptions
        authorization header entry of a bearer token.
    """
    cdef:
        CResult[pair[c_string, c_string]] outcome
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
        c_string c_username = tobytes(username)
        c_string c_password = tobytes(password)

    with nogil:
        outcome = self.client.get().AuthenticateBasicToken(
            deref(c_options), c_username, c_password)
    check_flight_status(outcome.status())
    return GetResultValue(outcome)
def list_actions(self, options: FlightCallOptions = None):
    """List the actions available on a service.

    Parameters
    ----------
    options : FlightCallOptions
        Options for this call.

    Returns
    -------
    actions : list of ActionType
        One entry per action type reported by the server, with the type
        name and description decoded via ``frombytes``.
    """
    cdef:
        vector[CActionType] c_actions
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
        CActionType c_action

    with nogil:
        check_flight_status(
            self.client.get().ListActions(deref(c_options), &c_actions))

    actions = []
    for c_action in c_actions:
        actions.append(ActionType(frombytes(c_action.type),
                                  frombytes(c_action.description)))
    return actions
def do_action(self, action, options: FlightCallOptions = None):
    """
    Execute an action on a service.

    Parameters
    ----------
    action : str, bytes, tuple, or Action
        Can be action type name (no body), type and body, or any Action
        object
    options : FlightCallOptions
        RPC options

    Returns
    -------
    results : iterator of Result values

    Raises
    ------
    TypeError
        If ``action`` is not one of the accepted forms.

    Notes
    -----
    This method is a generator: argument conversion, the DoAction RPC
    and any resulting errors only happen once iteration starts.
    """
    cdef:
        unique_ptr[CResultStream] results
        Result result
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)

    # Normalize the accepted shorthand forms into an Action.
    if isinstance(action, (str, bytes)):
        action = Action(action, b'')
    elif isinstance(action, tuple):
        action = Action(*action)
    elif not isinstance(action, Action):
        raise TypeError("Action must be Action instance, string, or tuple")

    cdef CAction c_action = Action.unwrap(<Action> action)
    with nogil:
        check_flight_status(
            self.client.get().DoAction(deref(c_options), c_action,
                                       &results))

    while True:
        # A fresh Result object per item, since each one is yielded.
        result = Result.__new__(Result)
        with nogil:
            check_flight_status(results.get().Next(&result.result))
            # A null result marks the end of the result stream.
            if result.result == NULL:
                break
        yield result
def list_flights(self, criteria: bytes = None,
                 options: FlightCallOptions = None):
    """List the flights available on a service.

    Parameters
    ----------
    criteria : bytes, optional
        Filter expression sent to the server. Falsy values (None or
        empty) send an empty expression.
    options : FlightCallOptions
        RPC options.

    Returns
    -------
    flights : iterator of FlightInfo

    Notes
    -----
    This method is a generator: the ListFlights RPC is only issued once
    iteration starts.
    """
    cdef:
        unique_ptr[CFlightListing] listing
        FlightInfo result
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
        CCriteria c_criteria

    if criteria:
        c_criteria.expression = tobytes(criteria)

    with nogil:
        check_flight_status(
            self.client.get().ListFlights(deref(c_options),
                                          c_criteria, &listing))

    while True:
        # A fresh FlightInfo object per item, since each one is yielded.
        result = FlightInfo.__new__(FlightInfo)
        with nogil:
            check_flight_status(listing.get().Next(&result.info))
            # A null info marks the end of the listing.
            if result.info == NULL:
                break
        yield result
def get_flight_info(self, descriptor: FlightDescriptor,
                    options: FlightCallOptions = None):
    """Request information about an available flight.

    Parameters
    ----------
    descriptor : FlightDescriptor
        The descriptor of the flight.
    options : FlightCallOptions
        RPC options.

    Returns
    -------
    info : FlightInfo
    """
    cdef:
        FlightInfo result = FlightInfo.__new__(FlightInfo)
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
        CFlightDescriptor c_descriptor = \
            FlightDescriptor.unwrap(descriptor)

    with nogil:
        # The C++ call writes directly into the Python object's info
        # pointer.
        check_flight_status(self.client.get().GetFlightInfo(
            deref(c_options), c_descriptor, &result.info))

    return result
def get_schema(self, descriptor: FlightDescriptor,
               options: FlightCallOptions = None):
    """Request schema for an available flight.

    Parameters
    ----------
    descriptor : FlightDescriptor
        The descriptor of the flight whose schema is requested.
    options : FlightCallOptions
        RPC options.

    Returns
    -------
    result : SchemaResult
    """
    cdef:
        SchemaResult result = SchemaResult.__new__(SchemaResult)
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
        CFlightDescriptor c_descriptor = \
            FlightDescriptor.unwrap(descriptor)

    with nogil:
        # Use check_flight_status, like every other RPC on this client, so
        # that Flight-specific failures (unauthenticated, unavailable,
        # server errors, ...) surface as the matching FlightError subclass
        # instead of a generic Arrow exception.
        check_flight_status(
            self.client.get()
            .GetSchema(deref(c_options), c_descriptor, &result.result)
        )

    return result
def do_get(self, ticket: Ticket, options: FlightCallOptions = None):
    """Request the data for a flight.

    Parameters
    ----------
    ticket : Ticket
        The ticket identifying the data stream to fetch.
    options : FlightCallOptions
        RPC options.

    Returns
    -------
    reader : FlightStreamReader
    """
    cdef:
        unique_ptr[CFlightStreamReader] reader
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)

    with nogil:
        check_flight_status(
            self.client.get().DoGet(
                deref(c_options), ticket.ticket, &reader))

    # Transfer ownership of the C++ reader to the Python wrapper.
    result = FlightStreamReader()
    result.reader.reset(reader.release())
    return result
def do_put(self, descriptor: FlightDescriptor, schema: Schema,
           options: FlightCallOptions = None):
    """Upload data to a flight.

    Parameters
    ----------
    descriptor : FlightDescriptor
        A descriptor for the flight.
    schema : Schema
        Schema passed to the DoPut call (the schema of the data to be
        uploaded).
    options : FlightCallOptions
        RPC options.

    Returns
    -------
    writer : FlightStreamWriter
    reader : FlightMetadataReader
    """
    cdef:
        shared_ptr[CSchema] c_schema = pyarrow_unwrap_schema(schema)
        unique_ptr[CFlightStreamWriter] writer
        # NOTE(review): metadata_reader appears unused; the C++ call writes
        # into reader.reader below instead.
        unique_ptr[CFlightMetadataReader] metadata_reader
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
        CFlightDescriptor c_descriptor = \
            FlightDescriptor.unwrap(descriptor)
        FlightMetadataReader reader = FlightMetadataReader()

    with nogil:
        check_flight_status(self.client.get().DoPut(
            deref(c_options),
            c_descriptor,
            c_schema,
            &writer,
            &reader.reader))

    # Transfer ownership of the C++ writer to the Python wrapper.
    result = FlightStreamWriter()
    result.writer.reset(writer.release())
    return result, reader
def do_exchange(self, descriptor: FlightDescriptor,
                options: FlightCallOptions = None):
    """Start a bidirectional data exchange with a server.

    Parameters
    ----------
    descriptor : FlightDescriptor
        A descriptor for the flight.
    options : FlightCallOptions
        RPC options.

    Returns
    -------
    writer : FlightStreamWriter
    reader : FlightStreamReader
    """
    cdef:
        unique_ptr[CFlightStreamWriter] c_writer
        unique_ptr[CFlightStreamReader] c_reader
        CFlightCallOptions* c_options = FlightCallOptions.unwrap(options)
        CFlightDescriptor c_descriptor = \
            FlightDescriptor.unwrap(descriptor)

    with nogil:
        check_flight_status(self.client.get().DoExchange(
            deref(c_options),
            c_descriptor,
            &c_writer,
            &c_reader))

    # Transfer ownership of both C++ endpoints to their Python wrappers.
    py_writer = FlightStreamWriter()
    py_writer.writer.reset(c_writer.release())
    py_reader = FlightStreamReader()
    py_reader.reader.reset(c_reader.release())
    return py_writer, py_reader
cdef class FlightDataStream(_Weakrefable):
    """Abstract base class for Flight data streams.

    Subclasses (e.g. RecordBatchStream, GeneratorStream) implement
    ``to_stream`` to produce the C++ stream served to clients.
    """

    cdef CFlightDataStream* to_stream(self) except *:
        """Create the C++ data stream for the backing Python object.

        We don't expose the C++ object to Python, so we can manage its
        lifetime from the Cython/C++ side.

        The returned pointer is newly allocated; the callers in this module
        take ownership by wrapping it in a ``unique_ptr`` immediately.
        """
        raise NotImplementedError
cdef class RecordBatchStream(FlightDataStream):
    """A Flight data stream backed by RecordBatches.

    The batches come either from a RecordBatchReader or from a Table, and
    are serialized with the IPC write options given at construction.
    """

    cdef:
        object data_source
        CIpcWriteOptions write_options

    def __init__(self, data_source, options=None):
        """Create a RecordBatchStream from a data source.

        Parameters
        ----------
        data_source : RecordBatchReader or Table
            Where the record batches are read from.
        options : pyarrow.ipc.IpcWriteOptions, optional
            IPC serialization options.

        Raises
        ------
        TypeError
            If ``data_source`` is neither a RecordBatchReader nor a Table.
        """
        if not isinstance(data_source, (RecordBatchReader, lib.Table)):
            raise TypeError("Expected RecordBatchReader or Table, "
                            "but got: {}".format(type(data_source)))
        self.data_source = data_source
        self.write_options = _get_options(options).c_options

    cdef CFlightDataStream* to_stream(self) except *:
        cdef shared_ptr[CRecordBatchReader] reader
        source = self.data_source
        if isinstance(source, RecordBatchReader):
            # Share the C++ reader already owned by the Python object.
            reader = (<RecordBatchReader> source).reader
        elif isinstance(source, lib.Table):
            # Adapt the table into a batch reader.
            reader.reset(new TableBatchReader(deref((<Table> source).table)))
        else:
            raise RuntimeError("Can't construct RecordBatchStream "
                               "from type {}".format(type(source)))
        return new CRecordBatchStream(reader, self.write_options)
cdef class GeneratorStream(FlightDataStream):
    """A Flight data stream backed by a Python generator."""

    cdef:
        shared_ptr[CSchema] schema
        object generator
        # A substream currently being consumed by the client, if
        # present. Produced by the generator.
        unique_ptr[CFlightDataStream] current_stream
        CIpcWriteOptions c_options

    def __init__(self, schema, generator, options=None):
        """Create a GeneratorStream from a Python generator.

        Parameters
        ----------
        schema : Schema
            The schema for the data to be returned.
        generator : iterator or iterable
            The generator should yield other FlightDataStream objects,
            Tables, RecordBatches, or RecordBatchReaders. An item may also
            be a ``(data, metadata)`` pair, but metadata is only accepted
            when ``data`` is a RecordBatch.
        options : pyarrow.ipc.IpcWriteOptions, optional
            IPC serialization options.
        """
        self.schema = pyarrow_unwrap_schema(schema)
        # Accept any iterable; it is consumed lazily as the client reads.
        self.generator = iter(generator)
        self.c_options = _get_options(options).c_options

    cdef CFlightDataStream* to_stream(self) except *:
        cdef:
            # The C++ stream calls back into _data_stream_next, passing
            # this object as its ``self`` argument.
            function[cb_data_stream_next] callback = &_data_stream_next
        return new CPyGeneratorFlightDataStream(self, self.schema, callback,
                                                self.c_options)
cdef class ServerCallContext(_Weakrefable):
    """Per-call state/context.

    Instances are built with ``ServerCallContext.wrap``, which stores a
    pointer to the C++ server call context.
    """

    cdef:
        const CServerCallContext* context

    def peer_identity(self):
        """Get the identity of the authenticated peer, as bytes.

        May be the empty string.
        """
        return tobytes(self.context.peer_identity())

    def peer(self):
        """Get the address of the peer."""
        # Set safe=True as gRPC on Windows sometimes gives garbage bytes
        return frombytes(self.context.peer(), safe=True)

    def get_middleware(self, key):
        """
        Get a middleware instance by key.

        Returns None if the middleware was not found.
        """
        cdef:
            CServerMiddleware* c_middleware = \
                self.context.GetMiddleware(CPyServerMiddlewareName)
            CPyServerMiddleware* bundled

        # Only the bundled Python middleware, registered under
        # CPyServerMiddlewareName, holds Python-side instances.
        if c_middleware == NULL or \
                c_middleware.name() != CPyServerMiddlewareName:
            return None

        bundled = <CPyServerMiddleware*> c_middleware
        wrapper = <_ServerMiddlewareWrapper> bundled.py_object()
        return wrapper.middleware.get(key)

    @staticmethod
    cdef ServerCallContext wrap(const CServerCallContext& context):
        cdef ServerCallContext wrapped = \
            ServerCallContext.__new__(ServerCallContext)
        wrapped.context = &context
        return wrapped
cdef class ServerAuthReader(_Weakrefable):
    """A reader for messages from the client during an auth handshake.

    Once ``poison`` has cleared the underlying pointer, ``read`` raises
    ValueError instead of touching it.
    """

    cdef:
        CServerAuthReader* reader

    def read(self):
        """Read the next handshake message sent by the client.

        Returns
        -------
        token : bytes

        Raises
        ------
        ValueError
            If used outside ServerAuthHandler.authenticate.
        """
        cdef c_string message
        if self.reader == NULL:
            raise ValueError("Cannot use ServerAuthReader outside "
                             "ServerAuthHandler.authenticate")
        with nogil:
            check_flight_status(self.reader.Read(&message))
        return message

    cdef void poison(self):
        """Prevent further usage of this object.

        This object is constructed by taking a pointer to a reference,
        so we want to make sure Python users do not access this after
        the reference goes away.
        """
        self.reader = NULL

    @staticmethod
    cdef ServerAuthReader wrap(CServerAuthReader* reader):
        cdef ServerAuthReader wrapped = \
            ServerAuthReader.__new__(ServerAuthReader)
        wrapped.reader = reader
        return wrapped
cdef class ServerAuthSender(_Weakrefable):
    """A writer for messages to the client during an auth handshake.

    Once ``poison`` has cleared the underlying pointer, ``write`` raises
    ValueError instead of touching it.
    """

    cdef:
        CServerAuthSender* sender

    def write(self, message):
        """Send a handshake message to the client.

        Parameters
        ----------
        message : bytes or str
            Converted with ``tobytes`` before sending.

        Raises
        ------
        ValueError
            If used outside ServerAuthHandler.authenticate.
        """
        # Convert first: a bad message type is reported before the
        # usage check, as before.
        cdef c_string payload = tobytes(message)
        if self.sender == NULL:
            raise ValueError("Cannot use ServerAuthSender outside "
                             "ServerAuthHandler.authenticate")
        with nogil:
            check_flight_status(self.sender.Write(payload))

    cdef void poison(self):
        """Prevent further usage of this object.

        This object is constructed by taking a pointer to a reference,
        so we want to make sure Python users do not access this after
        the reference goes away.
        """
        self.sender = NULL

    @staticmethod
    cdef ServerAuthSender wrap(CServerAuthSender* sender):
        cdef ServerAuthSender wrapped = \
            ServerAuthSender.__new__(ServerAuthSender)
        wrapped.sender = sender
        return wrapped
cdef class ClientAuthReader(_Weakrefable):
    """A reader for messages from the server during an auth handshake.

    Once ``poison`` has cleared the underlying pointer, ``read`` raises
    ValueError instead of touching it.
    """

    cdef:
        CClientAuthReader* reader

    def read(self):
        """Read the next handshake message sent by the server.

        Returns
        -------
        token : bytes

        Raises
        ------
        ValueError
            If used outside ClientAuthHandler.authenticate.
        """
        cdef c_string message
        if self.reader == NULL:
            raise ValueError("Cannot use ClientAuthReader outside "
                             "ClientAuthHandler.authenticate")
        with nogil:
            check_flight_status(self.reader.Read(&message))
        return message

    cdef void poison(self):
        """Prevent further usage of this object.

        This object is constructed by taking a pointer to a reference,
        so we want to make sure Python users do not access this after
        the reference goes away.
        """
        self.reader = NULL

    @staticmethod
    cdef ClientAuthReader wrap(CClientAuthReader* reader):
        cdef ClientAuthReader wrapped = \
            ClientAuthReader.__new__(ClientAuthReader)
        wrapped.reader = reader
        return wrapped
cdef class ClientAuthSender(_Weakrefable):
    """A writer for messages to the server during an auth handshake.

    Once ``poison`` has cleared the underlying pointer, ``write`` raises
    ValueError instead of touching it.
    """

    cdef:
        CClientAuthSender* sender

    def write(self, message):
        """Send a handshake message to the server.

        Parameters
        ----------
        message : bytes or str
            Converted with ``tobytes`` before sending.

        Raises
        ------
        ValueError
            If used outside ClientAuthHandler.authenticate.
        """
        # Convert first: a bad message type is reported before the
        # usage check, as before.
        cdef c_string payload = tobytes(message)
        if self.sender == NULL:
            raise ValueError("Cannot use ClientAuthSender outside "
                             "ClientAuthHandler.authenticate")
        with nogil:
            check_flight_status(self.sender.Write(payload))

    cdef void poison(self):
        """Prevent further usage of this object.

        This object is constructed by taking a pointer to a reference,
        so we want to make sure Python users do not access this after
        the reference goes away.
        """
        self.sender = NULL

    @staticmethod
    cdef ClientAuthSender wrap(CClientAuthSender* sender):
        cdef ClientAuthSender wrapped = \
            ClientAuthSender.__new__(ClientAuthSender)
        wrapped.sender = sender
        return wrapped
cdef CStatus _data_stream_next(void* self, CFlightPayload* payload) except *:
    """Callback for implementing FlightDataStream in Python.

    ``self`` must be the GeneratorStream being served. Fills ``payload``
    with the next message:

    * from the currently active substream, if there is one and it is not
      exhausted;
    * otherwise from the next item produced by the generator.

    End of data is signalled by leaving ``payload.ipc_message.metadata``
    null. Generator items may be a FlightDataStream, Table,
    RecordBatchReader or RecordBatch. An item may also be a 2-element
    list/tuple ``(data, metadata)``, but metadata is only accepted
    alongside a RecordBatch.

    A FlightError raised by the generator is returned as a status.

    Raises
    ------
    RuntimeError
        If ``self`` is not a GeneratorStream.
    ValueError
        If metadata accompanies a non-RecordBatch item, or if an item's
        schema differs from the declared schema of the GeneratorStream.
    TypeError
        If the generator yields an unsupported type.
    """
    cdef:
        unique_ptr[CFlightDataStream] data_stream

    py_stream = <object> self
    if not isinstance(py_stream, GeneratorStream):
        raise RuntimeError("self object in callback is not GeneratorStream")
    stream = <GeneratorStream> py_stream

    # Drain the active substream first, if any.
    if stream.current_stream != nullptr:
        check_flight_status(stream.current_stream.get().Next(payload))
        # If the stream ended, see if there's another stream from the
        # generator
        if payload.ipc_message.metadata != nullptr:
            return CStatus_OK()
        stream.current_stream.reset(nullptr)

    try:
        result = next(stream.generator)
    except StopIteration:
        # Null metadata tells the caller that the stream is finished.
        payload.ipc_message.metadata.reset(<CBuffer*> nullptr)
        return CStatus_OK()
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()

    # Split optional (data, metadata) pairs.
    if isinstance(result, (list, tuple)):
        result, metadata = result
    else:
        result, metadata = result, None

    # Tables and readers are adapted into a RecordBatchStream and then
    # handled by the FlightDataStream branch below.
    if isinstance(result, (Table, RecordBatchReader)):
        if metadata:
            raise ValueError("Can only return metadata alongside a "
                             "RecordBatch.")
        result = RecordBatchStream(result)

    stream_schema = pyarrow_wrap_schema(stream.schema)
    if isinstance(result, FlightDataStream):
        if metadata:
            raise ValueError("Can only return metadata alongside a "
                             "RecordBatch.")
        data_stream = unique_ptr[CFlightDataStream](
            (<FlightDataStream> result).to_stream())
        substream_schema = pyarrow_wrap_schema(data_stream.get().schema())
        if substream_schema != stream_schema:
            raise ValueError("Got a FlightDataStream whose schema does not "
                             "match the declared schema of this "
                             "GeneratorStream. "
                             "Got: {}\nExpected: {}".format(substream_schema,
                                                            stream_schema))
        # Install the new substream (paired with its Python object) and
        # recurse so the payload is filled from it.
        stream.current_stream.reset(
            new CPyFlightDataStream(result, move(data_stream)))
        return _data_stream_next(self, payload)
    elif isinstance(result, RecordBatch):
        batch = <RecordBatch> result
        if batch.schema != stream_schema:
            raise ValueError("Got a RecordBatch whose schema does not "
                             "match the declared schema of this "
                             "GeneratorStream. "
                             "Got: {}\nExpected: {}".format(batch.schema,
                                                            stream_schema))
        # Serialize the single batch directly into the payload.
        check_flight_status(GetRecordBatchPayload(
            deref(batch.batch),
            stream.c_options,
            &payload.ipc_message))
        if metadata:
            payload.app_metadata = pyarrow_unwrap_buffer(as_buffer(metadata))
    else:
        raise TypeError("GeneratorStream must be initialized with "
                        "an iterator of FlightDataStream, Table, "
                        "RecordBatch, or RecordBatchStreamReader objects, "
                        "not {}.".format(type(result)))
    return CStatus_OK()
cdef CStatus _list_flights(void* self, const CServerCallContext& context,
                           const CCriteria* c_criteria,
                           unique_ptr[CFlightListing]* listing) except *:
    """Callback for implementing ListFlights in Python.

    Calls ``list_flights(context, criteria_expression)`` on the Python
    server object (``self``). It copies every returned FlightInfo into a
    C++ vector and stores a CSimpleFlightListing over them in
    ``listing``.

    The returned iterable is fully consumed here, before the listing is
    built. A FlightError raised while producing results becomes the
    returned status. A non-FlightInfo item raises TypeError.
    """
    cdef:
        vector[CFlightInfo] flights

    try:
        result = (<object> self).list_flights(ServerCallContext.wrap(context),
                                              c_criteria.expression)
        for info in result:
            if not isinstance(info, FlightInfo):
                raise TypeError("FlightServerBase.list_flights must return "
                                "FlightInfo instances, but got {}".format(
                                    type(info)))
            # push_back copies the C++ FlightInfo value.
            flights.push_back(deref((<FlightInfo> info).info.get()))
        listing.reset(new CSimpleFlightListing(flights))
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    return CStatus_OK()
cdef CStatus _get_flight_info(void* self, const CServerCallContext& context,
                              CFlightDescriptor c_descriptor,
                              unique_ptr[CFlightInfo]* info) except *:
    """Callback for implementing Flight servers in Python.

    Implements GetFlightInfo. It calls ``get_flight_info(context,
    descriptor)`` on the Python server object (``self``) and copies the
    returned FlightInfo into ``info``.

    A FlightError raised by the implementation becomes the returned
    status. A non-FlightInfo result raises TypeError.
    """
    cdef:
        FlightDescriptor py_descriptor = \
            FlightDescriptor.__new__(FlightDescriptor)
    py_descriptor.descriptor = c_descriptor
    try:
        result = (<object> self).get_flight_info(
            ServerCallContext.wrap(context),
            py_descriptor)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    if not isinstance(result, FlightInfo):
        raise TypeError("FlightServerBase.get_flight_info must return "
                        "a FlightInfo instance, but got {}".format(
                            type(result)))
    info.reset(new CFlightInfo(deref((<FlightInfo> result).info.get())))
    return CStatus_OK()
cdef CStatus _get_schema(void* self, const CServerCallContext& context,
                         CFlightDescriptor c_descriptor,
                         unique_ptr[CSchemaResult]* info) except *:
    """Callback for implementing Flight servers in Python.

    Implements GetSchema. It calls ``get_schema(context, descriptor)`` on
    the Python server object (``self``) and copies the returned
    SchemaResult into ``info``.

    As in the other server callbacks, a FlightError raised by the
    implementation is converted into the matching Flight status, so the
    client sees the intended error code. Any other exception propagates.

    Raises
    ------
    TypeError
        If ``get_schema`` does not return a SchemaResult.
    """
    cdef:
        FlightDescriptor py_descriptor = \
            FlightDescriptor.__new__(FlightDescriptor)
    py_descriptor.descriptor = c_descriptor
    try:
        result = (<object> self).get_schema(ServerCallContext.wrap(context),
                                            py_descriptor)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    if not isinstance(result, SchemaResult):
        # Name the method the server actually implements.
        raise TypeError("FlightServerBase.get_schema must return "
                        "a SchemaResult instance, but got {}".format(
                            type(result)))
    info.reset(new CSchemaResult(deref((<SchemaResult> result).result.get())))
    return CStatus_OK()
cdef CStatus _do_put(void* self, const CServerCallContext& context,
                     unique_ptr[CFlightMessageReader] reader,
                     unique_ptr[CFlightMetadataWriter] writer) except *:
    """Callback for implementing Flight servers in Python.

    Implements DoPut. It wraps the C++ message reader and metadata writer
    in Python objects, taking ownership of both. It then calls
    ``do_put(context, descriptor, reader, writer)`` on the Python server
    object (``self``).

    A FlightError raised by the implementation becomes the returned
    status.
    """
    cdef:
        MetadataRecordBatchReader py_reader = MetadataRecordBatchReader()
        FlightMetadataWriter py_writer = FlightMetadataWriter()
        FlightDescriptor descriptor = \
            FlightDescriptor.__new__(FlightDescriptor)

    # Read the descriptor before ownership of ``reader`` is released to
    # the Python wrapper.
    descriptor.descriptor = reader.get().descriptor()
    py_reader.reader.reset(reader.release())
    py_writer.writer.reset(writer.release())
    try:
        (<object> self).do_put(ServerCallContext.wrap(context), descriptor,
                               py_reader, py_writer)
        return CStatus_OK()
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
cdef CStatus _do_get(void* self, const CServerCallContext& context,
                     CTicket ticket,
                     unique_ptr[CFlightDataStream]* stream) except *:
    """Callback for implementing Flight servers in Python.

    Implements DoGet. It calls ``do_get(context, ticket)`` on the Python
    server object (``self``), which must return a FlightDataStream. The
    corresponding C++ stream is stored in ``stream``.

    A FlightError raised by the implementation becomes the returned
    status. Any other return type raises TypeError.
    """
    cdef:
        unique_ptr[CFlightDataStream] data_stream

    py_ticket = Ticket(ticket.ticket)
    try:
        result = (<object> self).do_get(ServerCallContext.wrap(context),
                                        py_ticket)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    if not isinstance(result, FlightDataStream):
        raise TypeError("FlightServerBase.do_get must return "
                        "a FlightDataStream")
    data_stream = unique_ptr[CFlightDataStream](
        (<FlightDataStream> result).to_stream())
    # Pair the C++ stream with the Python object it was built from
    # (presumably so the Python object outlives the stream being sent).
    stream[0] = unique_ptr[CFlightDataStream](
        new CPyFlightDataStream(result, move(data_stream)))
    return CStatus_OK()
cdef CStatus _do_exchange(void* self, const CServerCallContext& context,
                          unique_ptr[CFlightMessageReader] reader,
                          unique_ptr[CFlightMessageWriter] writer) except *:
    """Callback for implementing Flight servers in Python.

    Implements DoExchange. It wraps the C++ message reader and writer in
    Python objects, taking ownership of both. It then calls
    ``do_exchange(context, descriptor, reader, writer)`` on the Python
    server object (``self``).

    A FlightError raised by the implementation becomes the returned
    status.
    """
    cdef:
        MetadataRecordBatchReader py_reader = MetadataRecordBatchReader()
        MetadataRecordBatchWriter py_writer = MetadataRecordBatchWriter()
        FlightDescriptor descriptor = \
            FlightDescriptor.__new__(FlightDescriptor)

    # Read the descriptor before ownership of ``reader`` is released to
    # the Python wrapper.
    descriptor.descriptor = reader.get().descriptor()
    py_reader.reader.reset(reader.release())
    py_writer.writer.reset(writer.release())
    try:
        (<object> self).do_exchange(ServerCallContext.wrap(context),
                                    descriptor, py_reader, py_writer)
        return CStatus_OK()
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
cdef CStatus _do_action_result_next(
    void* self,
    unique_ptr[CFlightResult]* result
) except *:
    """Callback for implementing Flight servers in Python.

    ``self`` is the Python iterator of action results. Each call advances
    it once:

    * a Result item is copied into ``result``;
    * any other item is first wrapped with ``Result(...)``;
    * on exhaustion ``result`` is reset to null (presumably read by the
      C++ side as end-of-stream).

    A FlightError raised by the iterator becomes the returned status.
    """
    cdef:
        CFlightResult* c_result

    try:
        action_result = next(<object> self)
        if not isinstance(action_result, Result):
            action_result = Result(action_result)
        c_result = (<Result> action_result).result.get()
        result.reset(new CFlightResult(deref(c_result)))
    except StopIteration:
        result.reset(nullptr)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    return CStatus_OK()
cdef CStatus _do_action(void* self, const CServerCallContext& context,
                        const CAction& action,
                        unique_ptr[CResultStream]* result) except *:
    """Callback for implementing Flight servers in Python.

    Implements DoAction. It calls ``do_action(context, action)`` on the
    Python server object (``self``) and stores in ``result`` a C++ result
    stream over ``iter(responses)``. That stream pulls items through
    _do_action_result_next.

    A FlightError raised by ``do_action`` itself becomes the returned
    status. An un-iterable return value raises TypeError from ``iter``.
    """
    cdef:
        function[cb_result_next] ptr = &_do_action_result_next

    py_action = Action(action.type, pyarrow_wrap_buffer(action.body))
    try:
        responses = (<object> self).do_action(ServerCallContext.wrap(context),
                                              py_action)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    # Let the application return an iterator or anything convertible
    # into one
    result.reset(new CPyFlightResultStream(iter(responses), ptr))
    return CStatus_OK()
cdef CStatus _list_actions(void* self, const CServerCallContext& context,
                           vector[CActionType]* actions) except *:
    """Callback for implementing Flight servers in Python.

    Implements ListActions. It calls ``list_actions(context)`` on the
    Python server object (``self``) and appends one CActionType per item
    to ``actions``. Each item must be a tuple (such as an ActionType)
    whose first two elements are the type and the description.

    A FlightError raised while producing results becomes the returned
    status. A non-tuple item raises TypeError.
    """
    cdef:
        # Reused for every item; push_back stores a copy.
        CActionType action_type

    # Method should return a list of ActionTypes or similar tuple
    try:
        result = (<object> self).list_actions(ServerCallContext.wrap(context))
        for action in result:
            if not isinstance(action, tuple):
                raise TypeError(
                    "Results of list_actions must be ActionType or tuple")
            action_type.type = tobytes(action[0])
            action_type.description = tobytes(action[1])
            actions.push_back(action_type)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    return CStatus_OK()
cdef CStatus _server_authenticate(void* self, CServerAuthSender* outgoing,
                                  CServerAuthReader* incoming) except *:
    """Callback for implementing authentication in Python.

    Wraps the C++ handshake channels and passes them to the Python
    handler's ``authenticate(sender, reader)``. A FlightError becomes the
    returned status.
    """
    sender = ServerAuthSender.wrap(outgoing)
    reader = ServerAuthReader.wrap(incoming)
    try:
        (<object> self).authenticate(sender, reader)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    finally:
        # The wrappers hold raw pointers to C++ objects owned elsewhere;
        # disable them so Python code that kept a reference cannot use
        # them after this call returns.
        sender.poison()
        reader.poison()
    return CStatus_OK()
cdef CStatus _is_valid(void* self, const c_string& token,
                       c_string* peer_identity) except *:
    """Callback for implementing authentication in Python.

    Asks the Python handler's ``is_valid(token)`` for the peer identity.
    The result is converted with ``tobytes`` and written to
    ``peer_identity``. A FlightError (e.g. an invalid token) becomes the
    returned status.
    """
    try:
        peer_identity[0] = tobytes((<object> self).is_valid(token))
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    return CStatus_OK()
cdef CStatus _client_authenticate(void* self, CClientAuthSender* outgoing,
                                  CClientAuthReader* incoming) except *:
    """Callback for implementing authentication in Python.

    Wraps the C++ handshake channels and passes them to the Python
    handler's ``authenticate(sender, reader)``. A FlightError becomes the
    returned status.
    """
    sender = ClientAuthSender.wrap(outgoing)
    reader = ClientAuthReader.wrap(incoming)
    try:
        (<object> self).authenticate(sender, reader)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    finally:
        # The wrappers hold raw pointers to C++ objects owned elsewhere;
        # disable them so Python code that kept a reference cannot use
        # them after this call returns.
        sender.poison()
        reader.poison()
    return CStatus_OK()
cdef CStatus _get_token(void* self, c_string* token) except *:
    """Callback for implementing authentication in Python.

    Stores the Python handler's ``get_token()``, converted with
    ``tobytes``, into ``token``. A FlightError becomes the returned
    status.
    """
    try:
        token[0] = tobytes((<object> self).get_token())
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    return CStatus_OK()
cdef CStatus _middleware_sending_headers(
        void* self, CAddCallHeaders* add_headers) except *:
    """Callback for implementing middleware.

    Calls the Python middleware's ``sending_headers()`` and forwards every
    returned header/value pair to ``add_headers``. ``str`` names and
    values are ASCII-encoded, and ``bytes`` pass through unchanged. A
    lone ``str``/``bytes`` value is treated as a single-element list.
    A FlightError from ``sending_headers`` becomes the returned status.
    """
    try:
        headers = (<object> self).sending_headers()
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()

    if not headers:
        return CStatus_OK()

    for name, raw_values in headers.items():
        if isinstance(raw_values, (str, bytes)):
            raw_values = (raw_values,)
        # Headers in gRPC (and HTTP/1, HTTP/2) are required to be
        # valid ASCII.
        if isinstance(name, str):
            name = name.encode("ascii")
        for value in raw_values:
            # bytes values are forwarded untouched.
            if isinstance(value, str):
                value = value.encode("ascii")
            add_headers.AddHeader(name, value)
    return CStatus_OK()
cdef CStatus _middleware_call_completed(
        void* self,
        const CStatus& call_status) except *:
    """Callback for implementing middleware.

    Converts ``call_status`` into the equivalent Python exception, or
    None on success. The result is passed to the middleware's
    ``call_completed``. A FlightError raised by ``call_completed`` itself
    becomes the returned status.
    """
    try:
        try:
            # Raises the Python exception corresponding to a failed status.
            check_flight_status(call_status)
        except Exception as e:
            (<object> self).call_completed(e)
        else:
            (<object> self).call_completed(None)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    return CStatus_OK()
cdef CStatus _middleware_received_headers(
        void* self,
        const CCallHeaders& c_headers) except *:
    """Callback for implementing middleware.

    Converts the received C++ headers with ``convert_headers`` and passes
    the resulting dict to the middleware's ``received_headers``. A
    FlightError becomes the returned status.
    """
    try:
        headers = convert_headers(c_headers)
        (<object> self).received_headers(headers)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    return CStatus_OK()
cdef dict convert_headers(const CCallHeaders& c_headers):
    """Convert C++ call headers into a Python dict.

    Keys are ASCII-decoded ``str``. Values are collected into a list per
    key, so repeated headers are preserved. A value is decoded to ``str``
    unless its key ends with ``"-bin"``, in which case it stays ``bytes``.
    """
    cdef:
        CCallHeaders.const_iterator header_iter = c_headers.cbegin()

    headers = {}
    while header_iter != c_headers.cend():
        header = c_string(deref(header_iter).first).decode("ascii")
        value = c_string(deref(header_iter).second)
        if not header.endswith("-bin"):
            # Text header values in gRPC (and HTTP/1, HTTP/2) are
            # required to be valid ASCII. Binary header values are
            # exposed as bytes.
            value = value.decode("ascii")
        headers.setdefault(header, []).append(value)
        postincrement(header_iter)

    return headers
cdef CStatus _server_middleware_start_call(
        void* self,
        const CCallInfo& c_info,
        const CCallHeaders& c_headers,
        shared_ptr[CServerMiddleware]* c_instance) except *:
    """Callback for implementing server middleware.

    Asks the Python factory (``self``) for a middleware instance for this
    call. A truthy instance is wrapped into ``c_instance``; a falsy one
    (such as None) leaves ``c_instance`` untouched. A FlightError raised
    by the factory becomes the returned status, rejecting the call.
    """
    try:
        call_info = wrap_call_info(c_info)
        headers = convert_headers(c_headers)
        instance = (<object> self).start_call(call_info, headers)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    else:
        if instance:
            ServerMiddleware.wrap(instance, c_instance)
        return CStatus_OK()
cdef CStatus _client_middleware_start_call(
        void* self,
        const CCallInfo& c_info,
        unique_ptr[CClientMiddleware]* c_instance) except *:
    """Callback for implementing client middleware.

    Asks the Python factory (``self``) for a middleware instance for this
    call. A truthy instance is wrapped into ``c_instance``; a falsy one
    (such as None) leaves ``c_instance`` untouched. A FlightError raised
    by the factory becomes the returned status.
    """
    try:
        call_info = wrap_call_info(c_info)
        instance = (<object> self).start_call(call_info)
    except FlightError as flight_error:
        return (<FlightError> flight_error).to_status()
    else:
        if instance:
            ClientMiddleware.wrap(instance, c_instance)
        return CStatus_OK()
cdef class ServerAuthHandler(_Weakrefable):
    """Authentication middleware for a server.

    To implement an authentication mechanism, subclass this class and
    override its methods.
    """

    def authenticate(self, outgoing, incoming):
        """Conduct the handshake with the client.

        May raise an error if the client cannot authenticate. A raised
        FlightError is reported to the client with its corresponding
        status.

        Parameters
        ----------
        outgoing : ServerAuthSender
            A channel to send messages to the client. Only usable for the
            duration of this call.
        incoming : ServerAuthReader
            A channel to read messages from the client. Only usable for
            the duration of this call.
        """
        raise NotImplementedError

    def is_valid(self, token):
        """Validate a client token, returning their identity.

        May return an empty string (if the auth mechanism does not
        name the peer) or raise an exception (if the token is
        invalid).

        Parameters
        ----------
        token : bytes
            The authentication token from the client.

        Returns
        -------
        identity : bytes or str
            Converted with ``tobytes``.
        """
        raise NotImplementedError

    cdef PyServerAuthHandler* to_handler(self):
        # Build a newly allocated C++ handler whose callbacks dispatch to
        # this object via _server_authenticate and _is_valid.
        cdef PyServerAuthHandlerVtable vtable
        vtable.authenticate = _server_authenticate
        vtable.is_valid = _is_valid
        return new PyServerAuthHandler(self, vtable)
cdef class ClientAuthHandler(_Weakrefable):
    """Authentication plugin for a client.

    Subclass this class and override ``authenticate`` and ``get_token``,
    then pass an instance to ``FlightClient.authenticate``.
    """

    def authenticate(self, outgoing, incoming):
        """Conduct the handshake with the server.

        A raised FlightError is converted into the corresponding status.

        Parameters
        ----------
        outgoing : ClientAuthSender
            A channel to send messages to the server. Only usable for the
            duration of this call.
        incoming : ClientAuthReader
            A channel to read messages from the server. Only usable for
            the duration of this call.
        """
        raise NotImplementedError

    def get_token(self):
        """Get the auth token for a call.

        Returns
        -------
        token : bytes or str
            Converted with ``tobytes``.
        """
        raise NotImplementedError

    cdef PyClientAuthHandler* to_handler(self):
        # Build a newly allocated C++ handler whose callbacks dispatch to
        # this object via _client_authenticate and _get_token.
        cdef PyClientAuthHandlerVtable vtable
        vtable.authenticate = _client_authenticate
        vtable.get_token = _get_token
        return new PyClientAuthHandler(self, vtable)
# Immutable record with a single ``method`` field; CallInfo subclasses it
# only to carry a public name and docstring.
_CallInfo = collections.namedtuple("_CallInfo", ["method"])


class CallInfo(_CallInfo):
    """Information about a particular RPC for Flight middleware.

    Attributes
    ----------
    method
        The RPC method being invoked, as produced by
        ``wrap_flight_method``.
    """
cdef wrap_call_info(const CCallInfo& c_info):
    """Convert a C++ CCallInfo into a Python CallInfo."""
    return CallInfo(method=wrap_flight_method(c_info.method))
cdef class ClientMiddlewareFactory(_Weakrefable):
    """A factory for new middleware instances.

    All middleware methods will be called from the same thread as the
    RPC method implementation. That is, thread-locals set in the
    client are accessible from the middleware itself.
    """

    def start_call(self, info):
        """Called at the start of an RPC.

        This must be thread-safe and must not raise exceptions.

        The default implementation returns None, i.e. it does not
        intercept the call.

        Parameters
        ----------
        info : CallInfo
            Information about the call.

        Returns
        -------
        instance : ClientMiddleware
            An instance of ClientMiddleware (the instance to use for
            the call), or None if this call is not intercepted.
        """
cdef class ClientMiddleware(_Weakrefable):
    """Client-side middleware for a call, instantiated per RPC.

    Methods here should be fast and must be infallible: they should
    not raise exceptions or stall indefinitely.
    """

    def sending_headers(self):
        """A callback before headers are sent.

        The default implementation returns None (no headers are added).

        Returns
        -------
        headers : dict
            A dictionary of header values to add to the request, or
            None if no headers are to be added. The dictionary should
            have string keys and string or list-of-string values.

            Bytes values are allowed, but the underlying transport may
            not support them or may restrict them. For gRPC, binary
            values are only allowed on headers ending in "-bin".
        """

    def received_headers(self, headers):
        """A callback when headers are received.

        The default implementation does nothing.

        Parameters
        ----------
        headers : dict
            A dictionary of headers from the server. Keys are strings
            and values are lists of strings (for text headers) or
            bytes (for binary headers).
        """

    def call_completed(self, exception):
        """A callback when the call finishes.

        The default implementation does nothing.

        Parameters
        ----------
        exception : ArrowException
            If the call errored, this is the equivalent
            exception. Will be None if the call succeeded.
        """

    @staticmethod
    cdef void wrap(object py_middleware,
                   unique_ptr[CClientMiddleware]* c_instance):
        # Install into c_instance a C++ middleware whose callbacks
        # dispatch to py_middleware via the _middleware_* functions.
        cdef PyClientMiddlewareVtable vtable
        vtable.sending_headers = _middleware_sending_headers
        vtable.received_headers = _middleware_received_headers
        vtable.call_completed = _middleware_call_completed
        c_instance[0].reset(new CPyClientMiddleware(py_middleware, vtable))
cdef class ServerMiddlewareFactory(_Weakrefable):
    """A factory for new middleware instances.

    All middleware methods will be called from the same thread as the
    RPC method implementation. That is, thread-locals set in the
    middleware are accessible from the method itself.
    """

    def start_call(self, info, headers):
        """Called at the start of an RPC.

        This must be thread-safe.

        The default implementation returns None, i.e. it does not
        intercept the call.

        Parameters
        ----------
        info : CallInfo
            Information about the call.
        headers : dict
            A dictionary of headers from the client. Keys are strings
            and values are lists of strings (for text headers) or
            bytes (for binary headers).

        Returns
        -------
        instance : ServerMiddleware
            An instance of ServerMiddleware (the instance to use for
            the call), or None if this call is not intercepted.

        Raises
        ------
        exception : pyarrow.ArrowException
            If an exception is raised, the call will be rejected with
            the given error.
        """
cdef class ServerMiddleware(_Weakrefable):
    """Server-side middleware for a call, instantiated per RPC.

    Methods here should be fast and must be infallible: they should
    not raise exceptions or stall indefinitely.
    """

    def sending_headers(self):
        """A callback before headers are sent.

        The default implementation returns None (no headers are added).

        Returns
        -------
        headers : dict
            A dictionary of header values to add to the response, or
            None if no headers are to be added. The dictionary should
            have string keys and string or list-of-string values.

            Bytes values are allowed, but the underlying transport may
            not support them or may restrict them. For gRPC, binary
            values are only allowed on headers ending in "-bin".
        """

    def call_completed(self, exception):
        """A callback when the call finishes.

        The default implementation does nothing.

        Parameters
        ----------
        exception : pyarrow.ArrowException
            If the call errored, this is the equivalent
            exception. Will be None if the call succeeded.
        """

    @staticmethod
    cdef void wrap(object py_middleware,
                   shared_ptr[CServerMiddleware]* c_instance):
        # Install into c_instance a C++ middleware whose callbacks
        # dispatch to py_middleware via the _middleware_* functions.
        # Server middleware has no received_headers hook.
        cdef PyServerMiddlewareVtable vtable
        vtable.sending_headers = _middleware_sending_headers
        vtable.call_completed = _middleware_call_completed
        c_instance[0].reset(new CPyServerMiddleware(py_middleware, vtable))
cdef class _ServerMiddlewareFactoryWrapper(ServerMiddlewareFactory):
    """Wrapper to bundle server middleware into a single C++ one.

    Holds the user's factories keyed by name and fans ``start_call`` out
    to each of them.
    """

    cdef:
        dict factories

    def __init__(self, dict factories):
        self.factories = factories

    def start_call(self, info, headers):
        """Start the call on every wrapped factory.

        Returns a _ServerMiddlewareWrapper over the truthy instances
        (keyed like ``self.factories``), or None if no factory
        intercepted the call.
        """
        instances = {}
        for name, factory in self.factories.items():
            candidate = factory.start_call(info, headers)
            # Factories opt out by returning None (or any falsy value).
            # Names come from a dict, so they cannot collide.
            if candidate:
                instances[name] = candidate
        if not instances:
            return None
        return _ServerMiddlewareWrapper(instances)
cdef class _ServerMiddlewareWrapper(ServerMiddleware):
    """Server middleware that fans callbacks out to several instances.

    ``middleware`` maps each key to the instance created for the current
    call; ServerCallContext.get_middleware looks instances up in it.
    """

    cdef:
        dict middleware

    def __init__(self, dict middleware):
        self.middleware = middleware

    def sending_headers(self):
        """Merge the headers returned by every wrapped instance.

        Returns a ``collections.defaultdict(list)`` mapping each header
        key to the concatenated values from all instances. A bare str or
        bytes value counts as a single value.
        """
        merged = collections.defaultdict(list)
        for instance in self.middleware.values():
            extra = instance.sending_headers()
            if extra:
                # Headers are multi-valued: extend rather than replace.
                for key, values in extra.items():
                    if isinstance(values, (bytes, str)):
                        values = (values,)
                    merged[key].extend(values)
        return merged

    def call_completed(self, exception):
        """Forward the call outcome to every wrapped instance."""
        for instance in self.middleware.values():
            instance.call_completed(exception)
cdef class FlightServerBase(_Weakrefable):
    """A Flight service definition.

    Override methods to define your Flight service.

    Parameters
    ----------
    location : str, tuple or Location optional, default None
        Location to serve on. Either a gRPC URI like `grpc://localhost:port`,
        a tuple of (host, port) pair, or a Location instance.
        If None is passed then the server will be started on localhost with a
        system provided random port.
    auth_handler : ServerAuthHandler optional, default None
        An authentication mechanism to use. May be None.
    tls_certificates : list optional, default None
        A list of (certificate, key) pairs.
    verify_client : boolean optional, default False
        If True, then enable mutual TLS: require the client to present
        a client certificate, and validate the certificate.
    root_certificates : bytes optional, default None
        If enabling mutual TLS, this specifies the PEM-encoded root
        certificate used to validate client certificates.
    middleware : dict optional, default None
        A dictionary of :class:`ServerMiddlewareFactory` items. The
        keys are used to retrieve the middleware instance during calls
        (see :meth:`ServerCallContext.get_middleware`).
    """

    cdef:
        unique_ptr[PyFlightServer] server

    def __init__(self, location=None, auth_handler=None,
                 tls_certificates=None, verify_client=None,
                 root_certificates=None, middleware=None):
        # Normalize `location` into a Location instance.
        if isinstance(location, (bytes, str)):
            location = Location(location)
        elif isinstance(location, (tuple, type(None))):
            if location is None:
                # Port 0 lets the system pick a free port.
                location = ('localhost', 0)
            host, port = location
            if tls_certificates:
                location = Location.for_grpc_tls(host, port)
            else:
                location = Location.for_grpc_tcp(host, port)
        elif not isinstance(location, Location):
            raise TypeError('`location` argument must be a string, tuple or a '
                            'Location instance')
        self.init(location, auth_handler, tls_certificates, verify_client,
                  tobytes(root_certificates or b""), middleware)

    cdef init(self, Location location, ServerAuthHandler auth_handler,
              list tls_certificates, c_bool verify_client,
              bytes root_certificates, dict middleware):
        """Build the C++ server options and initialize the C++ server."""
        cdef:
            PyFlightServerVtable vtable = PyFlightServerVtable()
            PyFlightServer* c_server
            unique_ptr[CFlightServerOptions] c_options
            CCertKeyPair c_cert
            function[cb_server_middleware_start_call] start_call = \
                &_server_middleware_start_call
            pair[c_string, shared_ptr[CServerMiddlewareFactory]] c_middleware

        c_options.reset(new CFlightServerOptions(Location.unwrap(location)))
        # mTLS configuration
        c_options.get().verify_client = verify_client
        c_options.get().root_certificates = root_certificates

        if auth_handler:
            if not isinstance(auth_handler, ServerAuthHandler):
                raise TypeError("auth_handler must be a ServerAuthHandler, "
                                "not a '{}'".format(type(auth_handler)))
            c_options.get().auth_handler.reset(
                (<ServerAuthHandler> auth_handler).to_handler())

        if tls_certificates:
            for cert, key in tls_certificates:
                c_cert.pem_cert = tobytes(cert)
                c_cert.pem_key = tobytes(key)
                c_options.get().tls_certificates.push_back(c_cert)

        if middleware:
            # All Python middleware factories are bundled behind a single
            # C++ factory registered under CPyServerMiddlewareName.
            py_middleware = _ServerMiddlewareFactoryWrapper(middleware)
            c_middleware.first = CPyServerMiddlewareName
            c_middleware.second.reset(new CPyServerMiddlewareFactory(
                py_middleware,
                start_call))
            c_options.get().middleware.push_back(c_middleware)

        vtable.list_flights = &_list_flights
        vtable.get_flight_info = &_get_flight_info
        vtable.get_schema = &_get_schema
        vtable.do_put = &_do_put
        vtable.do_get = &_do_get
        vtable.do_exchange = &_do_exchange
        vtable.list_actions = &_list_actions
        vtable.do_action = &_do_action

        c_server = new PyFlightServer(self, vtable)
        self.server.reset(c_server)
        with nogil:
            check_flight_status(c_server.Init(deref(c_options)))

    @property
    def port(self):
        """
        Get the port that this server is listening on.

        Returns a non-positive value if the operation is invalid
        (e.g. init() was not called or server is listening on a domain
        socket).
        """
        # Guard against dereferencing a null server, e.g. when a subclass
        # overrides __init__ without calling the base implementation.
        if self.server.get() == nullptr:
            return -1
        return self.server.get().port()

    def list_flights(self, context, criteria):
        """Handle ListFlights. Override in a subclass."""
        raise NotImplementedError

    def get_flight_info(self, context, descriptor):
        """Handle GetFlightInfo. Override in a subclass."""
        raise NotImplementedError

    def get_schema(self, context, descriptor):
        """Handle GetSchema. Override in a subclass."""
        raise NotImplementedError

    def do_put(self, context, descriptor, reader,
               writer: FlightMetadataWriter):
        """Handle DoPut. Override in a subclass."""
        raise NotImplementedError

    def do_get(self, context, ticket):
        """Handle DoGet. Override in a subclass."""
        raise NotImplementedError

    def do_exchange(self, context, descriptor, reader, writer):
        """Handle DoExchange. Override in a subclass."""
        raise NotImplementedError

    def list_actions(self, context):
        """Handle ListActions. Override in a subclass."""
        raise NotImplementedError

    def do_action(self, context, action):
        """Handle DoAction. Override in a subclass."""
        raise NotImplementedError

    def serve(self):
        """Start serving.

        This method only returns if shutdown() is called or a signal is
        received.
        """
        if self.server.get() == nullptr:
            raise ValueError("serve() on uninitialized FlightServerBase")
        with nogil:
            check_flight_status(self.server.get().ServeWithSignals())

    def run(self):
        """Deprecated alias of :meth:`serve`."""
        # stacklevel=2 attributes the warning to the caller's code.
        warnings.warn("The 'FlightServer.run' method is deprecated, use "
                      "FlightServer.serve method instead", stacklevel=2)
        self.serve()

    def shutdown(self):
        """Shut down the server, blocking until current requests finish.

        Do not call this directly from the implementation of a Flight
        method, as then the server will block forever waiting for that
        request to finish. Instead, call this method from a background
        thread.
        """
        # Must not hold the GIL: shutdown waits for pending RPCs to
        # complete. Holding the GIL means Python-implemented Flight
        # methods will never get to run, so this will hang
        # indefinitely.
        if self.server.get() == nullptr:
            raise ValueError("shutdown() on uninitialized FlightServerBase")
        with nogil:
            check_flight_status(self.server.get().Shutdown())

    def wait(self):
        """Block until server is terminated with shutdown."""
        # Same guard as serve()/shutdown(): avoid dereferencing null.
        if self.server.get() == nullptr:
            raise ValueError("wait() on uninitialized FlightServerBase")
        with nogil:
            self.server.get().Wait()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Stop accepting work, then block until the server has terminated.
        self.shutdown()
        self.wait()
def connect(location, **kwargs):
    """
    Connect to a Flight server.

    This is a convenience wrapper that constructs a :class:`FlightClient`,
    forwarding ``location`` and every keyword argument unchanged.

    Parameters
    ----------
    location : str, tuple or Location
        Location to connect to. Either a gRPC URI like `grpc://localhost:port`,
        a tuple of (host, port) pair, or a Location instance.
    tls_root_certs : bytes or None
        PEM-encoded TLS root certificates.
    cert_chain : str or None
        If provided, enables TLS mutual authentication.
    private_key : str or None
        If provided, enables TLS mutual authentication.
    override_hostname : str or None
        Override the hostname checked by TLS. Insecure, use with caution.
    middleware : list or None
        A list of ClientMiddlewareFactory instances to apply.
    write_size_limit_bytes : int or None
        A soft limit on the size of a data payload sent to the
        server. Enabled if positive. If enabled, writing a record
        batch that (when serialized) exceeds this limit will raise an
        exception; the client can retry the write with a smaller
        batch.
    disable_server_verification : boolean or None
        Disable verifying the server when using TLS.
        Insecure, use with caution.
    generic_options : list or None
        A list of generic (string, int or string) options to pass to
        the underlying transport.

    Returns
    -------
    client : FlightClient
    """
    client = FlightClient(location, **kwargs)
    return client