# blob: 19c8abc2e40bb34b0f4eefb5416f6ff660fde7c8 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""
Tests for write_arrow_batch: write PyArrow RecordBatch/Table to tsfile
and verify correctness by reading back.
"""
import os
from datetime import date
import numpy as np
import pandas as pd
import pytest
pa = pytest.importorskip("pyarrow", reason="pyarrow is not installed")
from tsfile import ColumnCategory, ColumnSchema, TableSchema, TSDataType, TsFileReader
from tsfile import TsFileTableWriter
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_schema(table_name, extra_cols):
    """Construct a TableSchema whose first column is the string TAG 'device'.

    Args:
        table_name: name of the table inside the tsfile.
        extra_cols: additional ColumnSchema objects appended after the tag column.
    """
    columns = [ColumnSchema("device", TSDataType.STRING, ColumnCategory.TAG)]
    columns.extend(extra_cols)
    return TableSchema(table_name, columns)
def _read_all_arrow(file_path, table_name, columns, start=0, end=10**18, batch_size=4096):
    """Read all rows of *table_name* from *file_path* and return a pa.Table.

    Drains ``read_arrow_batch()`` until it returns None, then stitches the
    batches into a single table.  Returns an empty ``pa.Table`` when the
    query yields no batches.

    Args:
        file_path: path of the tsfile to read.
        table_name: table to query.
        columns: column names to fetch ('time' is returned implicitly).
        start: inclusive start of the query time window.
        end: inclusive end of the query time window.
        batch_size: rows requested per read_arrow_batch() call.
    """
    reader = TsFileReader(file_path)
    try:
        rs = reader.query_table_batch(
            table_name=table_name,
            column_names=columns,
            start_time=start,
            end_time=end,
            batch_size=batch_size,
        )
        try:
            batches = []
            while True:
                batch = rs.read_arrow_batch()
                if batch is None:
                    break
                batches.append(batch)
        finally:
            rs.close()
    finally:
        # Always release the result set and reader, even when the query or a
        # read raises; the previous version leaked both handles on error.
        reader.close()
    if not batches:
        return pa.table({})
    # pa.concat_tables() only accepts pa.Table inputs.  If read_arrow_batch()
    # yields pa.RecordBatch objects (as its name suggests), assemble them via
    # Table.from_batches instead; Tables keep going through concat_tables.
    if isinstance(batches[0], pa.RecordBatch):
        return pa.Table.from_batches(batches)
    return pa.concat_tables(batches)
# ---------------------------------------------------------------------------
# Basic write + read-back
# ---------------------------------------------------------------------------
def test_write_arrow_basic():
    """Round-trip 1 000 rows through write_arrow_batch and check the values."""
    path = "test_write_arrow_basic.tsfile"
    table_name = "t"
    n = 1000
    schema = _make_schema(table_name, [
        ColumnSchema("value1", TSDataType.INT64, ColumnCategory.FIELD),
        ColumnSchema("value2", TSDataType.DOUBLE, ColumnCategory.FIELD),
    ])
    arrays = {
        "time": pa.array(np.arange(n, dtype="int64"), type=pa.timestamp("ns")),
        "device": pa.array([f"d{i}" for i in range(n)], type=pa.string()),
        "value1": pa.array(np.arange(n, dtype="int64"), type=pa.int64()),
        "value2": pa.array(np.arange(n, dtype="float64") * 1.5, type=pa.float64()),
    }
    batch = pa.record_batch(arrays)
    try:
        if os.path.exists(path):
            os.remove(path)
        with TsFileTableWriter(path, schema) as writer:
            writer.write_arrow_batch(batch)
        readback = _read_all_arrow(path, table_name, ["device", "value1", "value2"])
        assert len(readback) == n
        frame = readback.to_pandas().sort_values("time").reset_index(drop=True)
        assert frame["value1"].tolist() == list(range(n))
        for i, value in enumerate(frame["value2"]):
            assert abs(value - i * 1.5) < 1e-9
    finally:
        if os.path.exists(path):
            os.remove(path)
# ---------------------------------------------------------------------------
# pa.Table input
# ---------------------------------------------------------------------------
def test_write_arrow_from_table():
    """A pa.Table (possibly multi-chunk) is accepted by write_arrow_batch."""
    path = "test_write_arrow_from_table.tsfile"
    table_name = "t"
    n = 500
    schema = _make_schema(
        table_name,
        [ColumnSchema("v", TSDataType.INT32, ColumnCategory.FIELD)],
    )
    tbl = pa.table({
        "time": pa.array(np.arange(n, dtype="int64"), type=pa.timestamp("ns")),
        "device": pa.array(["dev"] * n, type=pa.string()),
        "v": pa.array(np.arange(n, dtype="int32"), type=pa.int32()),
    })
    try:
        if os.path.exists(path):
            os.remove(path)
        with TsFileTableWriter(path, schema) as writer:
            writer.write_arrow_batch(tbl)
        readback = _read_all_arrow(path, table_name, ["device", "v"])
        assert len(readback) == n
        frame = readback.to_pandas().sort_values("time").reset_index(drop=True)
        assert frame["v"].tolist() == list(range(n))
    finally:
        if os.path.exists(path):
            os.remove(path)
# ---------------------------------------------------------------------------
# Multiple batches
# ---------------------------------------------------------------------------
def test_write_arrow_multiple_batches():
    """Sequential write_arrow_batch calls accumulate rows in one file."""
    path = "test_write_arrow_multi.tsfile"
    table_name = "t"
    rows_per_batch = 300
    num_batches = 4
    total = rows_per_batch * num_batches
    schema = _make_schema(
        table_name,
        [ColumnSchema("v", TSDataType.INT64, ColumnCategory.FIELD)],
    )

    def make_batch(offset):
        # Timestamps and values continue from where the previous batch stopped.
        values = np.arange(offset, offset + rows_per_batch, dtype="int64")
        return pa.record_batch({
            "time": pa.array(values, type=pa.timestamp("ns")),
            "device": pa.array(["dev"] * rows_per_batch, type=pa.string()),
            "v": pa.array(values, type=pa.int64()),
        })

    try:
        if os.path.exists(path):
            os.remove(path)
        with TsFileTableWriter(path, schema) as writer:
            for index in range(num_batches):
                writer.write_arrow_batch(make_batch(index * rows_per_batch))
        readback = _read_all_arrow(path, table_name, ["device", "v"])
        assert len(readback) == total
    finally:
        if os.path.exists(path):
            os.remove(path)
# ---------------------------------------------------------------------------
# All supported data types
# ---------------------------------------------------------------------------
def test_write_arrow_all_datatypes():
    """Write every supported data type and verify values read back correctly.

    Covers BOOLEAN, INT32, INT64, FLOAT, DOUBLE, STRING and DATE fields plus a
    STRING tag.  Integer, boolean and string columns are compared exactly; the
    FLOAT column is compared with a tolerance (previously it was written but
    its values were never verified).
    """
    path = "test_write_arrow_all_types.tsfile"
    table_name = "t"
    n = 200
    schema = TableSchema(table_name, [
        ColumnSchema("tag", TSDataType.STRING, ColumnCategory.TAG),
        ColumnSchema("bool_col", TSDataType.BOOLEAN, ColumnCategory.FIELD),
        ColumnSchema("int32_col", TSDataType.INT32, ColumnCategory.FIELD),
        ColumnSchema("int64_col", TSDataType.INT64, ColumnCategory.FIELD),
        ColumnSchema("float_col", TSDataType.FLOAT, ColumnCategory.FIELD),
        ColumnSchema("double_col", TSDataType.DOUBLE, ColumnCategory.FIELD),
        ColumnSchema("str_col", TSDataType.STRING, ColumnCategory.FIELD),
        ColumnSchema("date_col", TSDataType.DATE, ColumnCategory.FIELD),
    ])
    # DATE values are supplied as days-since-epoch, matching pa.date32().
    dates_days = [
        (date(2025, 1, (i % 28) + 1) - date(1970, 1, 1)).days for i in range(n)
    ]
    batch = pa.record_batch({
        "time": pa.array(np.arange(n, dtype="int64"), type=pa.timestamp("ns")),
        "tag": pa.array([f"dev{i}" for i in range(n)], type=pa.string()),
        "bool_col": pa.array([i % 2 == 0 for i in range(n)], type=pa.bool_()),
        "int32_col": pa.array(np.arange(n, dtype="int32"), type=pa.int32()),
        "int64_col": pa.array(np.arange(n, dtype="int64") * 10, type=pa.int64()),
        "float_col": pa.array(np.arange(n, dtype="float32") * 0.5, type=pa.float32()),
        "double_col": pa.array(np.arange(n, dtype="float64") * 1.1, type=pa.float64()),
        "str_col": pa.array([f"s{i}" for i in range(n)], type=pa.string()),
        "date_col": pa.array(dates_days, type=pa.date32()),
    })
    try:
        if os.path.exists(path):
            os.remove(path)
        with TsFileTableWriter(path, schema) as w:
            w.write_arrow_batch(batch)
        result = _read_all_arrow(
            path, table_name,
            ["tag", "bool_col", "int32_col", "int64_col",
             "float_col", "double_col", "str_col", "date_col"],
        )
        assert len(result) == n
        df = result.to_pandas().sort_values("time").reset_index(drop=True)
        for col in ["tag", "bool_col", "int32_col", "int64_col",
                    "float_col", "double_col", "str_col", "date_col"]:
            assert col in df.columns, f"Column '{col}' missing from result"
        assert list(df["int32_col"]) == list(range(n))
        assert list(df["int64_col"]) == [i * 10 for i in range(n)]
        for i in range(n):
            assert df["bool_col"].iloc[i] == (i % 2 == 0)
            assert abs(df["double_col"].iloc[i] - i * 1.1) < 1e-9
            # i * 0.5 is exactly representable in float32 for these i; the
            # tolerance just guards against float32 round-trip widening.
            assert abs(df["float_col"].iloc[i] - i * 0.5) < 1e-6
            assert df["str_col"].iloc[i] == f"s{i}"
        # NOTE(review): date_col is only presence-checked; its read-back
        # representation (days vs. date objects) is not pinned down here —
        # TODO verify values once the DATE mapping is confirmed.
    finally:
        if os.path.exists(path):
            os.remove(path)
# ---------------------------------------------------------------------------
# Parity with write_dataframe
# ---------------------------------------------------------------------------
def test_write_arrow_parity_with_dataframe():
    """write_arrow_batch and write_dataframe must produce identical data."""
    arrow_path = "test_write_arrow_parity_arrow.tsfile"
    df_path = "test_write_arrow_parity_df.tsfile"
    table_name = "t"
    n = 500

    def fresh_schema():
        # Each writer gets its own schema instance, mirroring the two
        # identical schema literals the original test constructed.
        return TableSchema(table_name, [
            ColumnSchema("device", TSDataType.STRING, ColumnCategory.TAG),
            ColumnSchema("v_i32", TSDataType.INT32, ColumnCategory.FIELD),
            ColumnSchema("v_f64", TSDataType.DOUBLE, ColumnCategory.FIELD),
            ColumnSchema("v_bool", TSDataType.BOOLEAN, ColumnCategory.FIELD),
            ColumnSchema("v_str", TSDataType.STRING, ColumnCategory.FIELD),
        ])

    timestamps = np.arange(n, dtype="int64")
    v_i32 = np.arange(n, dtype="int32")
    v_f64 = np.arange(n, dtype="float64") * 2.5
    v_bool = np.array([i % 3 == 0 for i in range(n)])
    v_str = [f"row{i}" for i in range(n)]
    device = ["dev"] * n
    batch = pa.record_batch({
        "time": pa.array(timestamps, type=pa.timestamp("ns")),
        "device": pa.array(device, type=pa.string()),
        "v_i32": pa.array(v_i32, type=pa.int32()),
        "v_f64": pa.array(v_f64, type=pa.float64()),
        "v_bool": pa.array(v_bool, type=pa.bool_()),
        "v_str": pa.array(v_str, type=pa.string()),
    })
    dataframe = pd.DataFrame({
        "time": pd.Series(timestamps, dtype="int64"),
        "device": device,
        "v_i32": pd.Series(v_i32, dtype="int32"),
        "v_f64": pd.Series(v_f64, dtype="float64"),
        "v_bool": pd.Series(v_bool, dtype="bool"),
        "v_str": v_str,
    })
    cols = ["device", "v_i32", "v_f64", "v_bool", "v_str"]
    try:
        for stale in (arrow_path, df_path):
            if os.path.exists(stale):
                os.remove(stale)
        with TsFileTableWriter(arrow_path, fresh_schema()) as writer:
            writer.write_arrow_batch(batch)
        with TsFileTableWriter(df_path, fresh_schema()) as writer:
            writer.write_dataframe(dataframe)
        from_arrow = _read_all_arrow(arrow_path, table_name, cols).to_pandas()
        from_df = _read_all_arrow(df_path, table_name, cols).to_pandas()
        from_arrow = from_arrow.sort_values("time").reset_index(drop=True)
        from_df = from_df.sort_values("time").reset_index(drop=True)
        assert len(from_arrow) == len(from_df) == n
        for col in ("v_i32", "v_str", "v_bool"):
            assert list(from_arrow[col]) == list(from_df[col])
        for lhs, rhs in zip(from_arrow["v_f64"], from_df["v_f64"]):
            assert abs(lhs - rhs) < 1e-9
    finally:
        for stale in (arrow_path, df_path):
            if os.path.exists(stale):
                os.remove(stale)
# ---------------------------------------------------------------------------
# Large batch
# ---------------------------------------------------------------------------
def test_write_arrow_large_batch():
    """A single 100 k-row batch round-trips with the expected row count."""
    path = "test_write_arrow_large.tsfile"
    table_name = "t"
    n = 100_000
    schema = _make_schema(
        table_name,
        [ColumnSchema("v", TSDataType.DOUBLE, ColumnCategory.FIELD)],
    )
    columns = {
        "time": pa.array(np.arange(n, dtype="int64"), type=pa.timestamp("ns")),
        "device": pa.array(["d"] * n, type=pa.string()),
        "v": pa.array(np.random.rand(n), type=pa.float64()),
    }
    batch = pa.record_batch(columns)
    try:
        if os.path.exists(path):
            os.remove(path)
        with TsFileTableWriter(path, schema) as writer:
            writer.write_arrow_batch(batch)
        readback = _read_all_arrow(path, table_name, ["device", "v"], batch_size=8192)
        assert len(readback) == n
    finally:
        if os.path.exists(path):
            os.remove(path)
if __name__ == "__main__":
    # Run from this file's directory so the temporary .tsfile artifacts the
    # tests create (relative paths) land next to the test file.
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    pytest.main([__file__, "-v", "-s"])