| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import collections |
| import warnings |
| |
| import numpy as np |
| |
| import pyarrow as pa |
| from pyarrow.lib import SerializationContext, py_buffer, builtin_pickle |
| |
try:
    import cloudpickle
except ImportError:
    # Fall back to the stdlib pickle (re-exported by pyarrow.lib as
    # builtin_pickle); it has the same dumps/loads interface but covers
    # fewer object types (e.g. no lambdas).
    cloudpickle = builtin_pickle
| |
| |
try:
    # This function is available after numpy-0.16.0.
    # See also: https://github.com/numpy/numpy/blob/master/numpy/lib/format.py
    from numpy.lib.format import descr_to_dtype
except ImportError:
    # Vendored backport for older numpy versions (< 1.16).
    def descr_to_dtype(descr):
        '''
        descr may be stored as dtype.descr, which is a list of (name, format,
        [shape]) tuples where format may be a str or a tuple. Offsets are not
        explicitly saved, rather empty fields with name, format == '', '|Vn'
        are added as padding. This function reverses the process, eliminating
        the empty padding fields.
        '''
        if isinstance(descr, str):
            # No padding removal needed
            return np.dtype(descr)
        elif isinstance(descr, tuple):
            # subtype, will always have a shape descr[1]
            dt = descr_to_dtype(descr[0])
            return np.dtype((dt, descr[1]))
        # Structured dtype: walk the field list while tracking the running
        # byte offset so padding entries can be dropped without losing the
        # layout of the real fields.
        fields = []
        offset = 0
        for field in descr:
            if len(field) == 2:
                name, descr_str = field
                dt = descr_to_dtype(descr_str)
            else:
                name, descr_str, shape = field
                dt = np.dtype((descr_to_dtype(descr_str), shape))

            # Ignore padding bytes, which will be void bytes with '' as name
            # Once support for blank names is removed, only "if name == ''"
            # needed)
            is_pad = (name == '' and dt.type is np.void and dt.names is None)
            if not is_pad:
                fields.append((name, dt, offset))

            # Padding still advances the offset even though it is dropped.
            offset += dt.itemsize

        names, formats, offsets = zip(*fields)
        # names may be (title, names) tuples
        nametups = (n if isinstance(n, tuple) else (None, n) for n in names)
        titles, names = zip(*nametups)
        return np.dtype({'names': names, 'formats': formats, 'titles': titles,
                         'offsets': offsets, 'itemsize': offset})
| |
| |
| def _deprecate_serialization(name): |
| msg = ( |
| "'pyarrow.{}' is deprecated as of 2.0.0 and will be removed in a " |
| "future version. Use pickle or the pyarrow IPC functionality instead." |
| ).format(name) |
| warnings.warn(msg, FutureWarning, stacklevel=3) |
| |
| |
| # ---------------------------------------------------------------------- |
| # Set up serialization for numpy with dtype object (primitive types are |
| # handled efficiently with Arrow's Tensor facilities, see |
| # python_to_arrow.cc) |
| |
| def _serialize_numpy_array_list(obj): |
| if obj.dtype.str != '|O': |
| # Make the array c_contiguous if necessary so that we can call change |
| # the view. |
| if not obj.flags.c_contiguous: |
| obj = np.ascontiguousarray(obj) |
| return obj.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype) |
| else: |
| return obj.tolist(), np.lib.format.dtype_to_descr(obj.dtype) |
| |
| |
| def _deserialize_numpy_array_list(data): |
| if data[1] != '|O': |
| assert data[0].dtype == np.uint8 |
| return data[0].view(descr_to_dtype(data[1])) |
| else: |
| return np.array(data[0], dtype=np.dtype(data[1])) |
| |
| |
| def _serialize_numpy_matrix(obj): |
| if obj.dtype.str != '|O': |
| # Make the array c_contiguous if necessary so that we can call change |
| # the view. |
| if not obj.flags.c_contiguous: |
| obj = np.ascontiguousarray(obj.A) |
| return obj.A.view('uint8'), np.lib.format.dtype_to_descr(obj.dtype) |
| else: |
| return obj.A.tolist(), np.lib.format.dtype_to_descr(obj.dtype) |
| |
| |
| def _deserialize_numpy_matrix(data): |
| if data[1] != '|O': |
| assert data[0].dtype == np.uint8 |
| return np.matrix(data[0].view(descr_to_dtype(data[1])), |
| copy=False) |
| else: |
| return np.matrix(data[0], dtype=np.dtype(data[1]), copy=False) |
| |
| |
| # ---------------------------------------------------------------------- |
| # pyarrow.RecordBatch-specific serialization matters |
| |
def _serialize_pyarrow_recordbatch(batch):
    """Write *batch* into an in-memory IPC stream and return its buffer."""
    sink = pa.BufferOutputStream()
    writer = pa.RecordBatchStreamWriter(sink, schema=batch.schema)
    try:
        writer.write_batch(batch)
    finally:
        writer.close()
    return sink.getvalue()  # getvalue() also closes the stream.
| |
| |
def _deserialize_pyarrow_recordbatch(buf):
    """Reconstruct a single RecordBatch from an IPC stream buffer."""
    with pa.RecordBatchStreamReader(buf) as stream:
        return stream.read_next_batch()
| |
| |
| # ---------------------------------------------------------------------- |
| # pyarrow.Array-specific serialization matters |
| |
def _serialize_pyarrow_array(array):
    """Serialize a pyarrow.Array by wrapping it in a one-column batch."""
    # TODO(suquark): implement more efficient array serialization.
    wrapper = pa.RecordBatch.from_arrays([array], [''])
    return _serialize_pyarrow_recordbatch(wrapper)
| |
| |
def _deserialize_pyarrow_array(buf):
    """Recover the array from its one-column RecordBatch wrapper."""
    # TODO(suquark): implement more efficient array deserialization.
    return _deserialize_pyarrow_recordbatch(buf).columns[0]
| |
| |
| # ---------------------------------------------------------------------- |
| # pyarrow.Table-specific serialization matters |
| |
def _serialize_pyarrow_table(table):
    """Write *table* into an in-memory IPC stream and return its buffer."""
    sink = pa.BufferOutputStream()
    with pa.RecordBatchStreamWriter(sink, schema=table.schema) as writer:
        writer.write_table(table)
    return sink.getvalue()  # getvalue() also closes the stream.
| |
| |
def _deserialize_pyarrow_table(buf):
    """Reconstruct a Table from an IPC stream buffer."""
    with pa.RecordBatchStreamReader(buf) as stream:
        return stream.read_all()
| |
| |
def _pickle_to_buffer(x):
    """Pickle *x* at the highest protocol and wrap it in an Arrow buffer."""
    payload = builtin_pickle.dumps(x, protocol=builtin_pickle.HIGHEST_PROTOCOL)
    return py_buffer(payload)
| |
| |
def _load_pickle_from_buffer(data):
    """Unpickle from an Arrow buffer via a zero-copy memoryview."""
    return builtin_pickle.loads(memoryview(data))
| |
| |
| # ---------------------------------------------------------------------- |
| # pandas-specific serialization matters |
| |
def _register_custom_pandas_handlers(context):
    """Register pandas-specific serialization handlers on *context*.

    No-op when pandas is not installed.  DataFrames and Series use
    pyarrow.pandas_compat's serialized-dict representation; Index and the
    pandas extension arrays (Interval/Period/Datetime) round-trip through
    pickle.  Sparse pandas containers raise NotImplementedError.
    """
    # ARROW-1784, faster path for pandas-only visibility

    try:
        import pandas as pd
    except ImportError:
        return

    import pyarrow.pandas_compat as pdcompat

    # Shared message template for the rejected sparse pandas types.
    sparse_type_error_msg = (
        '{0} serialization is not supported.\n'
        'Note that {0} is planned to be deprecated '
        'in pandas future releases.\n'
        'See https://github.com/pandas-dev/pandas/issues/19239 '
        'for more information.'
    )

    def _serialize_pandas_dataframe(obj):
        # has_sparse guards the attribute access: pandas versions that
        # removed SparseDataFrame would otherwise raise AttributeError.
        if (pdcompat._pandas_api.has_sparse and
                isinstance(obj, pd.SparseDataFrame)):
            raise NotImplementedError(
                sparse_type_error_msg.format('SparseDataFrame')
            )

        return pdcompat.dataframe_to_serialized_dict(obj)

    def _deserialize_pandas_dataframe(data):
        return pdcompat.serialized_dict_to_dataframe(data)

    def _serialize_pandas_series(obj):
        if (pdcompat._pandas_api.has_sparse and
                isinstance(obj, pd.SparseSeries)):
            raise NotImplementedError(
                sparse_type_error_msg.format('SparseSeries')
            )

        # A Series travels as a single-column DataFrame keyed by its name;
        # the column is recovered positionally on deserialization.
        return _serialize_pandas_dataframe(pd.DataFrame({obj.name: obj}))

    def _deserialize_pandas_series(data):
        deserialized = _deserialize_pandas_dataframe(data)
        return deserialized[deserialized.columns[0]]

    context.register_type(
        pd.Series, 'pd.Series',
        custom_serializer=_serialize_pandas_series,
        custom_deserializer=_deserialize_pandas_series)

    context.register_type(
        pd.Index, 'pd.Index',
        custom_serializer=_pickle_to_buffer,
        custom_deserializer=_load_pickle_from_buffer)

    # The extension-array classes appeared (and moved) across pandas
    # versions, so probe for each module before registering.
    if hasattr(pd.core, 'arrays'):
        if hasattr(pd.core.arrays, 'interval'):
            context.register_type(
                pd.core.arrays.interval.IntervalArray,
                'pd.core.arrays.interval.IntervalArray',
                custom_serializer=_pickle_to_buffer,
                custom_deserializer=_load_pickle_from_buffer)

        if hasattr(pd.core.arrays, 'period'):
            context.register_type(
                pd.core.arrays.period.PeriodArray,
                'pd.core.arrays.period.PeriodArray',
                custom_serializer=_pickle_to_buffer,
                custom_deserializer=_load_pickle_from_buffer)

        if hasattr(pd.core.arrays, 'datetimes'):
            context.register_type(
                pd.core.arrays.datetimes.DatetimeArray,
                'pd.core.arrays.datetimes.DatetimeArray',
                custom_serializer=_pickle_to_buffer,
                custom_deserializer=_load_pickle_from_buffer)

    context.register_type(
        pd.DataFrame, 'pd.DataFrame',
        custom_serializer=_serialize_pandas_dataframe,
        custom_deserializer=_deserialize_pandas_dataframe)
| |
| |
def register_torch_serialization_handlers(serialization_context):
    # ----------------------------------------------------------------------
    # Set up serialization for pytorch tensors
    #
    # Deprecated public entry point; silently a no-op when torch is not
    # installed.
    _deprecate_serialization("register_torch_serialization_handlers")

    try:
        import torch

        def _serialize_torch_tensor(obj):
            # Sparse tensors become Arrow SparseCOOTensor; dense tensors
            # are handed over as numpy arrays.
            if obj.is_sparse:
                # NOTE(review): uses the private _values()/_indices()
                # accessors rather than the public .values()/.indices() —
                # presumably to allow uncoalesced tensors too; confirm.
                return pa.SparseCOOTensor.from_numpy(
                    obj._values().detach().numpy(),
                    obj._indices().detach().numpy().T,
                    shape=list(obj.shape))
            else:
                return obj.detach().numpy()

        def _deserialize_torch_tensor(data):
            if isinstance(data, pa.SparseCOOTensor):
                # to_numpy() yields (values, coords); values come back as
                # an (nnz, 1) column — hence the [:, 0] — and coords are
                # transposed back to torch's (ndim, nnz) index layout.
                return torch.sparse_coo_tensor(
                    indices=data.to_numpy()[1].T,
                    values=data.to_numpy()[0][:, 0],
                    size=data.shape)
            else:
                # Dense path: wrap the incoming numpy array as a tensor.
                return torch.from_numpy(data)

        # Register every concrete dense tensor class plus the base class.
        for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
                  torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
                  torch.IntTensor, torch.LongTensor, torch.Tensor]:
            serialization_context.register_type(
                t, "torch." + t.__name__,
                custom_serializer=_serialize_torch_tensor,
                custom_deserializer=_deserialize_torch_tensor)
    except ImportError:
        # no torch
        pass
| |
| |
| def _register_collections_serialization_handlers(serialization_context): |
| def _serialize_deque(obj): |
| return list(obj) |
| |
| def _deserialize_deque(data): |
| return collections.deque(data) |
| |
| serialization_context.register_type( |
| collections.deque, "collections.deque", |
| custom_serializer=_serialize_deque, |
| custom_deserializer=_deserialize_deque) |
| |
| def _serialize_ordered_dict(obj): |
| return list(obj.keys()), list(obj.values()) |
| |
| def _deserialize_ordered_dict(data): |
| return collections.OrderedDict(zip(data[0], data[1])) |
| |
| serialization_context.register_type( |
| collections.OrderedDict, "collections.OrderedDict", |
| custom_serializer=_serialize_ordered_dict, |
| custom_deserializer=_deserialize_ordered_dict) |
| |
| def _serialize_default_dict(obj): |
| return list(obj.keys()), list(obj.values()), obj.default_factory |
| |
| def _deserialize_default_dict(data): |
| return collections.defaultdict(data[2], zip(data[0], data[1])) |
| |
| serialization_context.register_type( |
| collections.defaultdict, "collections.defaultdict", |
| custom_serializer=_serialize_default_dict, |
| custom_deserializer=_deserialize_default_dict) |
| |
| def _serialize_counter(obj): |
| return list(obj.keys()), list(obj.values()) |
| |
| def _deserialize_counter(data): |
| return collections.Counter(dict(zip(data[0], data[1]))) |
| |
| serialization_context.register_type( |
| collections.Counter, "collections.Counter", |
| custom_serializer=_serialize_counter, |
| custom_deserializer=_deserialize_counter) |
| |
| |
| # ---------------------------------------------------------------------- |
| # Set up serialization for scipy sparse matrices. Primitive types are handled |
| # efficiently with Arrow's SparseTensor facilities, see numpy_convert.cc) |
| |
| def _register_scipy_handlers(serialization_context): |
| try: |
| from scipy.sparse import (csr_matrix, csc_matrix, coo_matrix, |
| isspmatrix_coo, isspmatrix_csr, |
| isspmatrix_csc, isspmatrix) |
| |
| def _serialize_scipy_sparse(obj): |
| if isspmatrix_coo(obj): |
| return 'coo', pa.SparseCOOTensor.from_scipy(obj) |
| |
| elif isspmatrix_csr(obj): |
| return 'csr', pa.SparseCSRMatrix.from_scipy(obj) |
| |
| elif isspmatrix_csc(obj): |
| return 'csc', pa.SparseCSCMatrix.from_scipy(obj) |
| |
| elif isspmatrix(obj): |
| return 'csr', pa.SparseCOOTensor.from_scipy(obj.to_coo()) |
| |
| else: |
| raise NotImplementedError( |
| "Serialization of {} is not supported.".format(obj[0])) |
| |
| def _deserialize_scipy_sparse(data): |
| if data[0] == 'coo': |
| return data[1].to_scipy() |
| |
| elif data[0] == 'csr': |
| return data[1].to_scipy() |
| |
| elif data[0] == 'csc': |
| return data[1].to_scipy() |
| |
| else: |
| return data[1].to_scipy() |
| |
| serialization_context.register_type( |
| coo_matrix, 'scipy.sparse.coo.coo_matrix', |
| custom_serializer=_serialize_scipy_sparse, |
| custom_deserializer=_deserialize_scipy_sparse) |
| |
| serialization_context.register_type( |
| csr_matrix, 'scipy.sparse.csr.csr_matrix', |
| custom_serializer=_serialize_scipy_sparse, |
| custom_deserializer=_deserialize_scipy_sparse) |
| |
| serialization_context.register_type( |
| csc_matrix, 'scipy.sparse.csc.csc_matrix', |
| custom_serializer=_serialize_scipy_sparse, |
| custom_deserializer=_deserialize_scipy_sparse) |
| |
| except ImportError: |
| # no scipy |
| pass |
| |
| |
| # ---------------------------------------------------------------------- |
| # Set up serialization for pydata/sparse tensors. |
| |
def _register_pydata_sparse_handlers(serialization_context):
    """Register handlers for pydata/sparse COO tensors.

    No-op when the ``sparse`` package is not installed.
    """
    try:
        import sparse

        def _serialize_pydata_sparse(obj):
            if isinstance(obj, sparse.COO):
                return 'coo', pa.SparseCOOTensor.from_pydata_sparse(obj)
            else:
                # BUG FIX: report the offending object's type; previously
                # the message always printed sparse.COO (the one type that
                # IS supported).
                raise NotImplementedError(
                    "Serialization of {} is not supported.".format(
                        type(obj)))

        def _deserialize_pydata_sparse(data):
            if data[0] == 'coo':
                # to_numpy() yields (values, coords); values come back as
                # an (nnz, 1) column — hence [:, 0] — and coords are
                # transposed back to sparse's (ndim, nnz) layout.
                data_array, coords = data[1].to_numpy()
                return sparse.COO(
                    data=data_array[:, 0],
                    coords=coords.T, shape=data[1].shape)
            # BUG FIX: previously an unknown tag silently fell through and
            # returned None; fail loudly instead.
            raise ValueError(
                "Unknown sparse tensor format: {}".format(data[0]))

        serialization_context.register_type(
            sparse.COO, 'sparse.COO',
            custom_serializer=_serialize_pydata_sparse,
            custom_deserializer=_deserialize_pydata_sparse)

    except ImportError:
        # no pydata/sparse
        pass
| |
| |
def _register_default_serialization_handlers(serialization_context):
    """Install the default set of custom serialization handlers."""

    # ------------------------------------------------------------------
    # Primitive datatypes

    # TODO(pcm): This is currently a workaround until arrow supports
    # arbitrary precision integers. This is only called on long integers,
    # see the associated case in the append method in python_to_arrow.cc
    serialization_context.register_type(
        int, "int",
        custom_serializer=lambda obj: str(obj),
        custom_deserializer=lambda data: int(data))

    serialization_context.register_type(
        type(lambda: 0), "function",
        pickle=True)

    serialization_context.register_type(type, "type", pickle=True)

    # ------------------------------------------------------------------
    # numpy and pyarrow container types (order preserved from the
    # historical one-call-per-type registrations)
    for typ, tag, serializer, deserializer in [
        (np.matrix, 'np.matrix',
         _serialize_numpy_matrix, _deserialize_numpy_matrix),
        (np.ndarray, 'np.array',
         _serialize_numpy_array_list, _deserialize_numpy_array_list),
        (pa.Array, 'pyarrow.Array',
         _serialize_pyarrow_array, _deserialize_pyarrow_array),
        (pa.RecordBatch, 'pyarrow.RecordBatch',
         _serialize_pyarrow_recordbatch, _deserialize_pyarrow_recordbatch),
        (pa.Table, 'pyarrow.Table',
         _serialize_pyarrow_table, _deserialize_pyarrow_table),
    ]:
        serialization_context.register_type(
            typ, tag,
            custom_serializer=serializer,
            custom_deserializer=deserializer)

    # Optional-dependency handlers register themselves only when the
    # corresponding package is importable.
    _register_collections_serialization_handlers(serialization_context)
    _register_custom_pandas_handlers(serialization_context)
    _register_scipy_handlers(serialization_context)
    _register_pydata_sparse_handlers(serialization_context)
| |
| |
def register_default_serialization_handlers(serialization_context):
    # Public, deprecated entry point: warn, then delegate to the private
    # implementation shared with default_serialization_context().
    _deprecate_serialization("register_default_serialization_handlers")
    _register_default_serialization_handlers(serialization_context)
| |
| |
def default_serialization_context():
    """Return a SerializationContext with all default handlers installed.

    Deprecated: emits a FutureWarning on every call.
    """
    _deprecate_serialization("default_serialization_context")
    ctx = SerializationContext()
    _register_default_serialization_handlers(ctx)
    return ctx