blob: 87944bcc0665cb7393ba8995e82dd64c2a596246 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import UserList
import io
import pathlib
import pytest
import socket
import threading
import weakref
import numpy as np
import pyarrow as pa
from pyarrow.tests.util import changed_environ
try:
from pandas.testing import assert_frame_equal, assert_series_equal
import pandas as pd
except ImportError:
pass
class IpcFixture:
    """Base helper that writes a few random record batches to a sink.

    Subclasses select the IPC format by implementing ``_get_writer``.
    """

    # WriteStats captured from the writer by the most recent write_batches()
    write_stats = None

    def __init__(self, sink_factory=lambda: io.BytesIO()):
        # sink_factory: zero-argument callable returning a fresh writable sink
        # that exposes getvalue() (e.g. io.BytesIO or pa.BufferOutputStream)
        self._sink_factory = sink_factory
        self.sink = self.get_sink()

    def get_sink(self):
        """Return a new sink created by the configured factory."""
        return self._sink_factory()

    def get_source(self):
        """Return everything written to the current sink so far."""
        return self.sink.getvalue()

    def _get_writer(self, sink, schema):
        """Return an IPC writer for *sink*; subclasses must override this."""
        raise NotImplementedError(
            "IpcFixture subclasses must implement _get_writer()")

    def write_batches(self, num_batches=5, as_table=False):
        """Write *num_batches* random 5-row batches and return them.

        If *as_table* is true the batches are combined into a single table
        and written with ``write_table``; otherwise each batch is written
        individually.  The writer's stats are stored in ``write_stats``
        before the writer is closed.
        """
        nrows = 5
        schema = pa.schema([('one', pa.float64()), ('two', pa.utf8())])
        writer = self._get_writer(self.sink, schema)
        batches = [
            pa.record_batch(
                [np.random.randn(nrows),
                 ['foo', None, 'bar', 'bazbaz', 'qux']],
                schema=schema)
            for _ in range(num_batches)
        ]
        if as_table:
            # Passing the schema explicitly lets num_batches=0 work too
            table = pa.Table.from_batches(batches, schema=schema)
            writer.write_table(table)
        else:
            for batch in batches:
                writer.write_batch(batch)
        self.write_stats = writer.stats
        writer.close()
        return batches
class FileFormatFixture(IpcFixture):
    """IpcFixture that writes the random-access IPC *file* format."""

    def _get_writer(self, sink, schema):
        return pa.ipc.new_file(sink, schema)

    def _check_roundtrip(self, as_table=False):
        """Write batches, read each back by index and compare.

        Also checks that the reader and writer report identical statistics.
        """
        batches = self.write_batches(as_table=as_table)
        file_contents = pa.BufferReader(self.get_source())
        reader = pa.ipc.open_file(file_contents)
        assert reader.num_record_batches == len(batches)
        # The file format supports random access, so fetch batches by index
        for i, expected in enumerate(batches):
            assert expected.equals(reader.get_batch(i))
        assert reader.schema.equals(batches[0].schema)
        assert isinstance(reader.stats, pa.ipc.ReadStats)
        assert isinstance(self.write_stats, pa.ipc.WriteStats)
        assert tuple(reader.stats) == tuple(self.write_stats)
class StreamFormatFixture(IpcFixture):
    """IpcFixture that writes the sequential IPC *stream* format."""

    # ARROW-6474: set to True to write the old IPC protocol, which uses a
    # 4-byte length prefix
    use_legacy_ipc_format = False
    # ARROW-9395: set to an IpcWriteOptions, e.g. to write an older
    # metadata version
    options = None

    def _get_writer(self, sink, schema):
        writer_kwargs = {
            'use_legacy_format': self.use_legacy_ipc_format,
            'options': self.options,
        }
        return pa.ipc.new_stream(sink, schema, **writer_kwargs)
class MessageFixture(IpcFixture):
    """IpcFixture that uses RecordBatchStreamWriter directly."""

    def _get_writer(self, sink, schema):
        stream_writer = pa.RecordBatchStreamWriter(sink, schema)
        return stream_writer
@pytest.fixture
def ipc_fixture():
    """Provide a bare IpcFixture backed by the default BytesIO sink."""
    fixture = IpcFixture()
    return fixture
@pytest.fixture
def file_fixture():
    """Provide a FileFormatFixture backed by the default BytesIO sink."""
    fixture = FileFormatFixture()
    return fixture
@pytest.fixture
def stream_fixture():
    """Provide a StreamFormatFixture backed by the default BytesIO sink."""
    fixture = StreamFormatFixture()
    return fixture
def test_empty_file():
    """Opening a zero-byte buffer as an IPC file must raise ArrowInvalid."""
    empty_source = pa.BufferReader(b'')
    with pytest.raises(pa.ArrowInvalid):
        pa.ipc.open_file(empty_source)
def test_file_simple_roundtrip(file_fixture):
    """Round-trip batches written one by one through the file format."""
    write_as_table = False
    file_fixture._check_roundtrip(as_table=write_as_table)
def test_file_write_table(file_fixture):
    """Round-trip batches written as one table through the file format."""
    write_as_table = True
    file_fixture._check_roundtrip(as_table=write_as_table)
@pytest.mark.parametrize("sink_factory", [
    lambda: io.BytesIO(),
    lambda: pa.BufferOutputStream()
])
def test_file_read_all(sink_factory):
    """read_all() returns every written batch as one table, for any sink."""
    fixture = FileFormatFixture(sink_factory)
    written = fixture.write_batches()
    reader = pa.ipc.open_file(pa.BufferReader(fixture.get_source()))
    assert reader.read_all().equals(pa.Table.from_batches(written))
def test_open_file_from_buffer(file_fixture):
    # ARROW-2859; APIs accept the buffer protocol
    file_fixture.write_batches()
    data = file_fixture.get_source()
    readers = [
        pa.ipc.open_file(data),
        pa.ipc.open_file(pa.BufferReader(data)),
        pa.RecordBatchFileReader(data),
    ]
    first, *others = [r.read_all() for r in readers]
    for other in others:
        assert first.equals(other)

    # One schema message plus five record batches
    stats = readers[0].stats
    assert stats.num_messages == 6
    assert stats.num_record_batches == 5
    for other_reader in readers[1:]:
        assert other_reader.stats == stats
@pytest.mark.pandas
def test_file_read_pandas(file_fixture):
    """read_pandas() equals the concatenation of the per-batch DataFrames."""
    written = file_fixture.write_batches()
    frames = [b.to_pandas() for b in written]
    reader = pa.ipc.open_file(pa.BufferReader(file_fixture.get_source()))
    result = reader.read_pandas()
    assert_frame_equal(result, pd.concat(frames).reset_index(drop=True))
def test_file_pathlib(file_fixture, tmpdir):
    """open_file() accepts a pathlib.Path as well as an open OSFile."""
    file_fixture.write_batches()
    source = file_fixture.get_source()
    path = tmpdir.join('file.arrow').strpath
    with open(path, 'wb') as f:
        f.write(source)
    with pa.ipc.open_file(pathlib.Path(path)) as reader:
        t1 = reader.read_all()
    # Close the OSFile deterministically instead of leaving it to the GC
    with pa.OSFile(path) as f:
        t2 = pa.ipc.open_file(f).read_all()
    assert t1.equals(t2)
def test_empty_stream():
    """Opening an empty byte stream as an IPC stream must raise."""
    empty_source = io.BytesIO(b'')
    with pytest.raises(pa.ArrowInvalid):
        pa.ipc.open_stream(empty_source)
@pytest.mark.pandas
def test_stream_categorical_roundtrip(stream_fixture):
    """Ordered categorical columns survive a stream write/read round-trip."""
    categories = pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
                                categories=['foo', 'bar'],
                                ordered=True)
    df = pd.DataFrame({'one': np.random.randn(5), 'two': categories})
    batch = pa.RecordBatch.from_pandas(df)
    writer = stream_fixture._get_writer(stream_fixture.sink, batch.schema)
    with writer as wr:
        wr.write_batch(batch)
    source = pa.BufferReader(stream_fixture.get_source())
    table = pa.ipc.open_stream(source).read_all()
    assert_frame_equal(table.to_pandas(), df)
def test_open_stream_from_buffer(stream_fixture):
    # ARROW-2859: stream APIs accept the buffer protocol
    stream_fixture.write_batches()
    data = stream_fixture.get_source()
    readers = [
        pa.ipc.open_stream(data),
        pa.ipc.open_stream(pa.BufferReader(data)),
        pa.RecordBatchStreamReader(data),
    ]
    tables = [r.read_all() for r in readers]
    for table in tables[1:]:
        assert tables[0].equals(table)

    # One schema message plus five record batches
    stats = readers[0].stats
    assert stats.num_messages == 6
    assert stats.num_record_batches == 5
    for other_reader in readers[1:]:
        assert other_reader.stats == stats
    assert tuple(stats) == tuple(stream_fixture.write_stats)
@pytest.mark.pandas
def test_stream_write_dispatch(stream_fixture):
    # ARROW-1616: write() accepts both tables and record batches
    df = pd.DataFrame({
        'one': np.random.randn(5),
        'two': pd.Categorical(['foo', np.nan, 'bar', 'foo', 'foo'],
                              categories=['foo', 'bar'],
                              ordered=True)
    })
    as_table = pa.Table.from_pandas(df, preserve_index=False)
    as_batch = pa.RecordBatch.from_pandas(df, preserve_index=False)
    with stream_fixture._get_writer(stream_fixture.sink,
                                    as_table.schema) as wr:
        for item in (as_table, as_batch):
            wr.write(item)
    reader = pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
    expected = pd.concat([df, df], ignore_index=True)
    assert_frame_equal(reader.read_all().to_pandas(), expected)
@pytest.mark.pandas
def test_stream_write_table_batches(stream_fixture):
    # ARROW-504: write_table() honours max_chunksize when emitting batches
    df = pd.DataFrame({'one': np.random.randn(20)})
    head = pa.RecordBatch.from_pandas(df[:10], preserve_index=False)
    whole = pa.RecordBatch.from_pandas(df, preserve_index=False)
    table = pa.Table.from_batches([head, whole, head])
    with stream_fixture._get_writer(stream_fixture.sink, table.schema) as wr:
        wr.write_table(table, max_chunksize=15)
    batches = list(pa.ipc.open_stream(stream_fixture.get_source()))
    # 10-row chunk, the 20-row chunk split as 15 + 5, then a 10-row chunk
    assert [len(b) for b in batches] == [10, 15, 5, 10]
    expected = pd.concat([df[:10], df, df[:10]], ignore_index=True)
    assert_frame_equal(pa.Table.from_batches(batches).to_pandas(), expected)
@pytest.mark.parametrize('use_legacy_ipc_format', [False, True])
def test_stream_simple_roundtrip(stream_fixture, use_legacy_ipc_format):
    """Batches read back from a stream match those written, in order."""
    stream_fixture.use_legacy_ipc_format = use_legacy_ipc_format
    written = stream_fixture.write_batches()
    reader = pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
    assert reader.schema.equals(written[0].schema)
    received = list(reader)
    assert len(received) == len(written)
    for got, want in zip(received, written):
        assert got.equals(want)
    # The stream is exhausted
    with pytest.raises(StopIteration):
        reader.read_next_batch()
@pytest.mark.zstd
def test_compression_roundtrip():
    """A Codec's compression_level is forwarded to the C++ IPC writer."""
    values = np.random.randint(0, 10, 10000)
    table = pa.Table.from_arrays([values], names=["values"])

    def write_compressed(compression):
        out = io.BytesIO()
        opts = pa.ipc.IpcWriteOptions(compression=compression)
        with pa.ipc.RecordBatchFileWriter(
                out, table.schema, options=opts) as writer:
            writer.write_table(table)
        return out

    default_sink = write_compressed('zstd')
    level5_sink = write_compressed(pa.Codec('zstd', compression_level=5))

    # In theory level 5 should compress better than the default, but this
    # test only checks that compression_level reaches the C++ layer, so any
    # size difference is enough.
    assert len(level5_sink.getvalue()) != len(default_sink.getvalue())

    t1 = pa.ipc.open_file(default_sink).read_all()
    t2 = pa.ipc.open_file(level5_sink).read_all()
    assert t1 == t2
def test_write_options():
    """IpcWriteOptions defaults, setters and constructor keywords."""
    V4 = pa.ipc.MetadataVersion.V4
    opts = pa.ipc.IpcWriteOptions()

    # Defaults
    assert opts.allow_64bit is False
    assert opts.use_legacy_format is False
    assert opts.metadata_version == pa.ipc.MetadataVersion.V5

    # Simple setters round-trip
    opts.allow_64bit = True
    assert opts.allow_64bit is True
    opts.use_legacy_format = True
    assert opts.use_legacy_format is True
    opts.metadata_version = V4
    assert opts.metadata_version == V4
    for bad_version in ('V5', 42):
        with pytest.raises((TypeError, ValueError)):
            opts.metadata_version = bad_version

    # Codec names are accepted case-insensitively and normalised to lower case
    assert opts.compression is None
    available = [name for name in ('lz4', 'zstd')
                 if pa.Codec.is_available(name)]
    for name in available:
        opts.compression = name
        assert opts.compression == name
        opts.compression = name.upper()
        assert opts.compression == name
    opts.compression = None
    assert opts.compression is None
    with pytest.raises(TypeError):
        opts.compression = 0

    assert opts.use_threads is True
    opts.use_threads = False
    assert opts.use_threads is False

    # Every option can also be passed to the constructor
    if not pa.Codec.is_available('lz4'):
        return
    opts = pa.ipc.IpcWriteOptions(
        metadata_version=V4,
        allow_64bit=True,
        use_legacy_format=True,
        compression='lz4',
        use_threads=False)
    assert opts.metadata_version == V4
    assert opts.allow_64bit is True
    assert opts.use_legacy_format is True
    assert opts.compression == 'lz4'
    assert opts.use_threads is False
def test_write_options_legacy_exclusive(stream_fixture):
    """Passing both use_legacy_format and options to new_stream is an error."""
    # Configure the conflicting settings up front so that pytest.raises
    # only wraps the call that is expected to raise.
    stream_fixture.use_legacy_ipc_format = True
    stream_fixture.options = pa.ipc.IpcWriteOptions()
    with pytest.raises(
            ValueError,
            match="provide at most one of options and use_legacy_format"):
        stream_fixture.write_batches()
@pytest.mark.parametrize('options', [
    pa.ipc.IpcWriteOptions(),
    pa.ipc.IpcWriteOptions(allow_64bit=True),
    pa.ipc.IpcWriteOptions(use_legacy_format=True),
    pa.ipc.IpcWriteOptions(metadata_version=pa.ipc.MetadataVersion.V4),
    pa.ipc.IpcWriteOptions(use_legacy_format=True,
                           metadata_version=pa.ipc.MetadataVersion.V4),
])
def test_stream_options_roundtrip(stream_fixture, options):
    """Streams written with various IpcWriteOptions read back unchanged."""
    stream_fixture.use_legacy_ipc_format = None
    stream_fixture.options = options
    written = stream_fixture.write_batches()
    raw = stream_fixture.get_source()

    # The first message carries the requested metadata version
    first_message = pa.ipc.read_message(raw)
    assert first_message.metadata_version == options.metadata_version

    reader = pa.ipc.open_stream(pa.BufferReader(raw))
    assert reader.schema.equals(written[0].schema)
    received = list(reader)
    assert len(received) == len(written)
    for got, want in zip(received, written):
        assert got.equals(want)
    with pytest.raises(StopIteration):
        reader.read_next_batch()
def test_dictionary_delta(stream_fixture):
    """Dictionary deltas are only written when emit_dictionary_deltas is set."""
    ty = pa.dictionary(pa.int8(), pa.utf8())
    data = [["foo", "foo", None],
            ["foo", "bar", "foo"],  # potential delta
            ["foo", "bar"],
            ["foo", None, "bar", "quux"],  # potential delta
            ["bar", "quux"],  # replacement
            ]
    batches = [
        pa.RecordBatch.from_arrays([pa.array(values, type=ty)],
                                   names=['dicts'])
        for values in data
    ]
    schema = batches[0].schema

    def stats_after_writing():
        sink = pa.MockOutputStream()
        with stream_fixture._get_writer(sink, schema) as writer:
            for batch in batches:
                writer.write_batch(batch)
        return writer.stats

    def check(stats, replaced, deltas):
        assert stats.num_record_batches == 5
        assert stats.num_dictionary_batches == 4
        assert stats.num_replaced_dictionaries == replaced
        assert stats.num_dictionary_deltas == deltas

    # By default, every dictionary change is written as a full replacement
    check(stats_after_writing(), replaced=3, deltas=0)

    stream_fixture.use_legacy_ipc_format = None
    stream_fixture.options = pa.ipc.IpcWriteOptions(
        emit_dictionary_deltas=True)
    check(stats_after_writing(), replaced=1, deltas=2)
def test_envvar_set_legacy_ipc_format():
    """Legacy-format environment variables affect both stream and file writers."""
    schema = pa.schema([pa.field('foo', pa.int32())])
    V4 = pa.ipc.MetadataVersion.V4
    V5 = pa.ipc.MetadataVersion.V5

    def check(expect_legacy, expect_version):
        # new_stream and new_file must both honour the current environment
        for make_writer in (pa.ipc.new_stream, pa.ipc.new_file):
            w = make_writer(pa.BufferOutputStream(), schema)
            assert bool(w._use_legacy_format) == expect_legacy
            assert w._metadata_version == expect_version

    check(False, V5)
    with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
        check(True, V5)
    with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
        check(False, V4)
    with changed_environ('ARROW_PRE_1_0_METADATA_VERSION', '1'):
        with changed_environ('ARROW_PRE_0_15_IPC_FORMAT', '1'):
            check(True, V4)
def test_stream_read_all(stream_fixture):
    """read_all() on a stream reader returns a table of every written batch."""
    expected = pa.Table.from_batches(stream_fixture.write_batches())
    reader = pa.ipc.open_stream(pa.BufferReader(stream_fixture.get_source()))
    assert reader.read_all().equals(expected)
@pytest.mark.pandas
def test_stream_read_pandas(stream_fixture):
    """read_pandas() on a stream equals the concatenated batch DataFrames."""
    expected_frames = []
    for batch in stream_fixture.write_batches():
        expected_frames.append(batch.to_pandas())
    reader = pa.ipc.open_stream(stream_fixture.get_source())
    result = reader.read_pandas()
    expected = pd.concat(expected_frames).reset_index(drop=True)
    assert_frame_equal(result, expected)
@pytest.fixture
def example_messages(stream_fixture):
    """Return (written batches, raw Messages read back from the stream)."""
    batches = stream_fixture.write_batches()
    message_reader = pa.MessageReader.open_stream(
        pa.BufferReader(stream_fixture.get_source()))
    messages = [msg for msg in message_reader]
    return batches, messages
def test_message_ctors_no_segfault():
    """Directly constructing Message/MessageReader raises instead of crashing."""
    for cls in (pa.Message, pa.MessageReader):
        with pytest.raises(TypeError):
            repr(cls())
def test_message_reader(example_messages):
    """A stream yields one schema message then one message per batch."""
    _, messages = example_messages
    assert len(messages) == 6
    expected_types = ['schema'] + ['record batch'] * 5
    for msg, expected_type in zip(messages, expected_types):
        assert msg.type == expected_type
        assert isinstance(msg.metadata, pa.Buffer)
        assert isinstance(msg.body, pa.Buffer)
        assert msg.metadata_version == pa.MetadataVersion.V5
def test_message_serialize_read_message(example_messages):
    """Message.serialize() output can be read back from several source kinds."""
    _, messages = example_messages
    # Use the first message read from the stream (the schema message)
    msg = messages[0]
    buf = msg.serialize()
    # Two back-to-back copies of the serialized message, so the same reader
    # can be asked for a message twice before it runs out of data
    reader = pa.BufferReader(buf.to_pybytes() * 2)
    restored = pa.ipc.read_message(buf)
    restored2 = pa.ipc.read_message(reader)  # consumes the first copy
    restored3 = pa.ipc.read_message(buf.to_pybytes())
    restored4 = pa.ipc.read_message(reader)  # consumes the second copy
    assert msg.equals(restored)
    assert msg.equals(restored2)
    assert msg.equals(restored3)
    assert msg.equals(restored4)
    # A 2-byte buffer is not a valid message
    with pytest.raises(pa.ArrowInvalid, match="Corrupted message"):
        pa.ipc.read_message(pa.BufferReader(b'ab'))
    # Both copies have been read, so the reader is at end of input
    with pytest.raises(EOFError):
        pa.ipc.read_message(reader)
@pytest.mark.gzip
def test_message_read_from_compressed(example_messages):
    # Part of ARROW-5910: read_message() works on a decompressing input stream
    _, messages = example_messages
    for original in messages:
        out = pa.BufferOutputStream()
        with pa.output_stream(out, compression='gzip') as gz_out:
            original.serialize_to(gz_out)
        gz_in = pa.input_stream(out.getvalue(), compression='gzip')
        assert pa.ipc.read_message(gz_in).equals(original)
def test_message_read_record_batch(example_messages):
    """Each record-batch message decodes back to the batch that was written."""
    batches, messages = example_messages
    batch_messages = messages[1:]  # skip the leading schema message
    for expected, msg in zip(batches, batch_messages):
        decoded = pa.ipc.read_record_batch(msg, expected.schema)
        assert decoded.equals(expected)
def test_read_record_batch_on_stream_error_message():
    # ARROW-5374: a clear error when the first message is a schema message
    batch = pa.record_batch([pa.array([b"foo"], type=pa.utf8())],
                            names=['strs'])
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, batch.schema) as writer:
        writer.write_batch(batch)
    stream_bytes = sink.getvalue()
    with pytest.raises(IOError,
                       match="type record batch but got schema"):
        pa.ipc.read_record_batch(stream_bytes, batch.schema)
# ----------------------------------------------------------------------
# Socket streaming tests
class StreamReaderServer(threading.Thread):
    """Thread that accepts one TCP connection and reads an Arrow IPC stream.

    Call ``init()`` (not the constructor) to bind the listening socket and
    obtain its port, then ``start()`` the thread and connect a client.
    """

    def init(self, do_read_all):
        """Bind a listening socket on an ephemeral localhost port.

        Parameters
        ----------
        do_read_all : bool
            If true, ``run()`` reads the whole stream into a table with
            ``read_all()``; otherwise it collects the individual batches.

        Returns
        -------
        int
            The port the server listens on.
        """
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._sock.bind(('127.0.0.1', 0))
        self._sock.listen(1)
        _, port = self._sock.getsockname()
        self._do_read_all = do_read_all
        self._schema = None
        self._batches = []
        self._table = None
        return port

    def run(self):
        try:
            connection, _ = self._sock.accept()
        finally:
            # Only one client is ever served, so release the listening
            # socket as soon as accept() returns (or fails).
            self._sock.close()
        # Close the makefile() wrapper and then the connection itself
        with connection, connection.makefile(mode='rb') as source:
            reader = pa.ipc.open_stream(source)
            self._schema = reader.schema
            if self._do_read_all:
                self._table = reader.read_all()
            else:
                self._batches.extend(reader)

    def get_result(self):
        """Return ``(schema, table)`` or ``(schema, list_of_batches)``."""
        return (self._schema,
                self._table if self._do_read_all else self._batches)
class SocketStreamFixture(IpcFixture):
    """IpcFixture whose sink is a TCP socket connected to a StreamReaderServer.

    No sink exists until ``start_server()`` is called, because the socket
    has to be connected first.
    """

    def __init__(self):
        # XXX(wesm): test will decide when to start socket server, so the
        # base-class constructor (which creates a sink) is deliberately not
        # called. This should probably be refactored
        pass

    def start_server(self, do_read_all):
        """Start the reader thread and connect ``self.sink`` to it."""
        self._server = StreamReaderServer()
        port = self._server.init(do_read_all)
        self._server.start()
        self._sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._sock.connect(('127.0.0.1', port))
        self.sink = self.get_sink()

    def stop_and_get_result(self):
        """Finish the stream, close the connection, return the server result."""
        import struct
        # Trailing zero-length marker written after the stream; presumably a
        # belt-and-braces end-of-stream signal - TODO confirm still needed
        self.sink.write(struct.pack('Q', 0))
        self.sink.flush()
        # The socket's file descriptor is only released once every
        # makefile() object has been closed too, so close the sink first.
        self.sink.close()
        self._sock.close()
        self._server.join()
        return self._server.get_result()

    def get_sink(self):
        return self._sock.makefile(mode='wb')

    def _get_writer(self, sink, schema):
        return pa.RecordBatchStreamWriter(sink, schema)
@pytest.fixture
def socket_fixture():
    """Provide a SocketStreamFixture; each test starts its own server."""
    fixture = SocketStreamFixture()
    return fixture
def test_socket_simple_roundtrip(socket_fixture):
    """Batches streamed over a TCP socket arrive intact and in order."""
    socket_fixture.start_server(do_read_all=False)
    sent = socket_fixture.write_batches()
    schema, received = socket_fixture.stop_and_get_result()
    assert schema.equals(sent[0].schema)
    assert len(received) == len(sent)
    for got, want in zip(received, sent):
        assert got.equals(want)
def test_socket_read_all(socket_fixture):
    """read_all() on a socket-backed stream returns the whole table."""
    socket_fixture.start_server(do_read_all=True)
    sent = socket_fixture.write_batches()
    table = socket_fixture.stop_and_get_result()[1]
    assert table.equals(pa.Table.from_batches(sent))
# ----------------------------------------------------------------------
# Miscellaneous IPC tests
@pytest.mark.pandas
def test_ipc_file_stream_has_eos():
    # ARROW-5395: the body of an IPC file is itself an EOS-terminated stream
    df = pd.DataFrame({'foo': [1.5]})
    sink = pa.BufferOutputStream()
    write_file(pa.RecordBatch.from_pandas(df), sink)
    # skip the file magic
    stream_view = sink.getvalue()[8:]
    # read_pandas() fails if it encounters footer data instead of EOS
    rdf = pa.ipc.open_stream(stream_view).read_pandas()
    assert_frame_equal(df, rdf)
@pytest.mark.pandas
def test_ipc_zero_copy_numpy():
    """A batch written to an IPC file reads back to an equal DataFrame."""
    df = pd.DataFrame({'foo': [1.5]})
    sink = pa.BufferOutputStream()
    write_file(pa.RecordBatch.from_pandas(df), sink)
    batches = read_file(pa.BufferReader(sink.getvalue()))
    rdf = pd.DataFrame(batches[0].to_pandas())
    assert_frame_equal(df, rdf)
def test_ipc_stream_no_batches():
    # ARROW-2307: a stream with only a schema reads back as an empty table
    columns = [pa.array([1, 2, 3, 4]),
               pa.array(['foo', 'bar', 'baz', 'qux'])]
    table = pa.Table.from_arrays(columns, names=['a', 'b'])
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema):
        pass  # close without writing any batch
    with pa.ipc.open_stream(sink.getvalue()) as reader:
        result = reader.read_all()
    assert result.schema.equals(table.schema)
    assert len(result) == 0
@pytest.mark.pandas
def test_get_record_batch_size():
    """The serialized batch size exceeds the raw float64 payload size."""
    num_rows, bytes_per_value = 10, 8
    frame = pd.DataFrame({'foo': np.random.randn(num_rows)})
    batch = pa.RecordBatch.from_pandas(frame)
    assert pa.ipc.get_record_batch_size(batch) > num_rows * bytes_per_value
def _check_serialize_pandas_round_trip(df, use_threads=False):
    """Assert that *df* survives serialize_pandas/deserialize_pandas.

    Not a test itself, so it carries no pytest marker; callers are marked.
    *use_threads* selects 2 serialization threads and threaded
    deserialization.
    """
    buf = pa.serialize_pandas(df, nthreads=2 if use_threads else 1)
    result = pa.deserialize_pandas(buf, use_threads=use_threads)
    assert_frame_equal(result, df)
@pytest.mark.pandas
def test_pandas_serialize_round_trip():
    """serialize_pandas round-trips a frame with a named index."""
    df = pd.DataFrame(
        {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')},
        index=pd.Index([1, 2, 3], name='my_index'),
        columns=['foo', 'bar'],
    )
    _check_serialize_pandas_round_trip(df)
@pytest.mark.pandas
def test_pandas_serialize_round_trip_nthreads():
    """Same as the single-threaded round-trip, but with threads enabled."""
    df = pd.DataFrame(
        {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')},
        index=pd.Index([1, 2, 3], name='my_index'),
        columns=['foo', 'bar'],
    )
    _check_serialize_pandas_round_trip(df, use_threads=True)
@pytest.mark.pandas
def test_pandas_serialize_round_trip_multi_index():
    """A MultiIndex with one named and one unnamed level round-trips."""
    levels = [
        pd.Index([1, 2, 3], name='level_1'),
        pd.Index(list('def'), name=None),
    ]
    df = pd.DataFrame(
        {'foo': [1.5, 1.6, 1.7], 'bar': list('abc')},
        index=pd.MultiIndex.from_arrays(levels),
        columns=['foo', 'bar'],
    )
    _check_serialize_pandas_round_trip(df)
@pytest.mark.pandas
def test_serialize_pandas_empty_dataframe():
    """An empty DataFrame round-trips through serialize_pandas."""
    _check_serialize_pandas_round_trip(pd.DataFrame())
@pytest.mark.pandas
def test_pandas_serialize_round_trip_not_string_columns():
    """Integer column labels survive serialize_pandas/deserialize_pandas."""
    rows = list(zip([1.5, 1.6, 1.7], 'abc'))
    df = pd.DataFrame(rows)
    restored = pa.deserialize_pandas(pa.serialize_pandas(df))
    assert_frame_equal(restored, df)
@pytest.mark.pandas
def test_serialize_pandas_no_preserve_index():
    """preserve_index controls whether a custom index is kept."""
    df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3])
    cases = [
        # Without the index, the frame is expected back with a default index
        (False, pd.DataFrame({'a': [1, 2, 3]})),
        (True, df),
    ]
    for preserve, expected in cases:
        buf = pa.serialize_pandas(df, preserve_index=preserve)
        assert_frame_equal(pa.deserialize_pandas(buf), expected)
@pytest.mark.pandas
@pytest.mark.filterwarnings("ignore:'pyarrow:FutureWarning")
def test_serialize_with_pandas_objects():
    """pa.serialize round-trips DataFrames and Series nested in a dict."""
    df = pd.DataFrame({'a': [1, 2, 3]}, index=[1, 2, 3])
    s = pd.Series([1, 2, 3, 4])
    payload = {'a_series': df['a'], 'a_frame': df, 's_series': s}
    restored = pa.deserialize(pa.serialize(payload).to_buffer())

    assert_frame_equal(restored['a_frame'], df)
    assert_series_equal(restored['a_series'], df['a'])
    assert restored['a_series'].name == 'a'
    assert_series_equal(restored['s_series'], s)
    assert restored['s_series'].name is None
@pytest.mark.pandas
def test_schema_batch_serialize_methods():
    """Schema.serialize() and RecordBatch.serialize() can be read back."""
    nrows = 5
    df = pd.DataFrame({
        'one': np.random.randn(nrows),
        'two': ['foo', np.nan, 'bar', 'bazbaz', 'qux']})
    batch = pa.RecordBatch.from_pandas(df)
    schema = pa.ipc.read_schema(batch.schema.serialize())
    restored = pa.ipc.read_record_batch(batch.serialize(), schema)
    assert restored.equals(batch)
def test_schema_serialization_with_metadata():
    """Schema- and field-level metadata survive schema serialization."""
    field_metadata = {b'foo': b'bar', b'kind': b'field'}
    schema_metadata = {b'foo': b'bar', b'kind': b'schema'}
    fields = [
        pa.field('a', pa.int8()),
        pa.field('b', pa.string(), metadata=field_metadata),
    ]
    schema = pa.schema(fields, metadata=schema_metadata)
    restored = pa.ipc.read_schema(schema.serialize())

    assert restored.equals(schema)
    assert restored.metadata == schema_metadata
    assert restored[0].metadata is None
    assert restored[1].metadata == field_metadata
def test_deprecated_pyarrow_ns_apis():
    """Top-level pa.open_stream/pa.open_file warn and point to pyarrow.ipc."""
    table = pa.table([pa.array([1, 2, 3, 4])], names=['a'])
    cases = [
        (pa.ipc.new_stream, lambda buf: pa.open_stream(buf),
         "please use pyarrow.ipc.open_stream"),
        (pa.ipc.new_file, lambda buf: pa.open_file(buf),
         "please use pyarrow.ipc.open_file"),
    ]
    for new_writer, deprecated_open, message in cases:
        sink = pa.BufferOutputStream()
        with new_writer(sink, table.schema) as writer:
            writer.write(table)
        with pytest.warns(FutureWarning, match=message):
            deprecated_open(sink.getvalue())
def write_file(batch, sink):
    """Write *batch* to *sink* as a single-batch Arrow IPC file."""
    writer = pa.ipc.new_file(sink, batch.schema)
    with writer:
        writer.write_batch(batch)
def read_file(source):
    """Return every record batch in the IPC file *source*, in order."""
    with pa.ipc.open_file(source) as reader:
        batches = []
        for index in range(reader.num_record_batches):
            batches.append(reader.get_batch(index))
    return batches
def test_write_empty_ipc_file():
    # ARROW-3894: IPC file was not being properly initialized when no record
    # batches are being written
    schema = pa.schema([('field', pa.int64())])
    sink = pa.BufferOutputStream()
    with pa.ipc.new_file(sink, schema):
        pass  # close without writing any batch
    source = pa.BufferReader(sink.getvalue())
    with pa.RecordBatchFileReader(source) as reader:
        table = reader.read_all()
    assert len(table) == 0
    assert table.schema.equals(schema)
def test_py_record_batch_reader():
    """RecordBatchReader.from_batches works with an iterable or an iterator,
    and the source is no longer referenced once the reader's with-block exits.
    """
    def make_schema():
        return pa.schema([('field', pa.int64())])

    def make_batches():
        schema = make_schema()
        batch1 = pa.record_batch([[1, 2, 3]], schema=schema)
        batch2 = pa.record_batch([[4, 5]], schema=schema)
        return [batch1, batch2]

    # With iterable
    # (UserList is used because a plain list cannot be weakly referenced)
    batches = UserList(make_batches())  # weakrefable
    wr = weakref.ref(batches)
    with pa.ipc.RecordBatchReader.from_batches(make_schema(),
                                               batches) as reader:
        # Drop the test's own reference; the reader must keep the source
        # alive while it is still in use
        batches = None
        assert wr() is not None
        assert list(reader) == make_batches()
    # After the with-block, nothing may keep the source alive any more
    assert wr() is None

    # With iterator
    batches = iter(UserList(make_batches()))  # weakrefable
    wr = weakref.ref(batches)
    with pa.ipc.RecordBatchReader.from_batches(make_schema(),
                                               batches) as reader:
        batches = None
        assert wr() is not None
        assert list(reader) == make_batches()
    assert wr() is None