################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Tests for global index functionality."""
import os
import unittest
import faiss
import numpy as np
from pypaimon import Schema
from pypaimon.globalindex.faiss import FaissIndex
from pypaimon.globalindex.vector_search import VectorSearch
from pypaimon.globalindex.faiss.faiss_options import FaissVectorMetric


class TestFaissVectorGlobalIndexE2E(unittest.TestCase):
"""
End-to-end test for FAISS vector global index.
"""
@classmethod
def setUpClass(cls):
"""Set up test fixtures."""
import tempfile
from pypaimon import CatalogFactory
cls.tempdir = tempfile.mkdtemp()
cls.warehouse = os.path.join(cls.tempdir, 'warehouse')
cls.catalog = CatalogFactory.create({
'warehouse': cls.warehouse
})
        cls.catalog.create_database('default', True)  # True: ignore if it already exists
cls.dimension = 32
cls.n_vectors = 100
# Generate random vectors with a fixed seed for reproducibility
np.random.seed(42)
cls.vectors = np.random.random(
(cls.n_vectors, cls.dimension)
        ).astype(np.float32)

    @classmethod
def tearDownClass(cls):
"""Clean up temp files."""
import shutil
        shutil.rmtree(cls.tempdir, ignore_errors=True)

    def test_faiss_vector_global_index_write_scan_read(self):
"""
Test FAISS vector global index end-to-end.
"""
import pyarrow as pa
import json
from pypaimon.index.index_file_meta import IndexFileMeta
from pypaimon.globalindex.global_index_meta import GlobalIndexMeta
from pypaimon.globalindex.faiss.faiss_index_meta import FaissIndexMeta
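        # ========== Step 1: Create table with vector index options ==========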
pa_schema = pa.schema([
('id', pa.int32()),
('vec', pa.list_(pa.float32())),
])
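        # Table options: the vector.* keys describe the FAISS index; row
        # tracking and data evolution are presumably required so that rows
        # carry stable row IDs for the global index to reference.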
table_options = {
'bucket': '1',
'vector.dim': str(self.dimension),
'vector.metric': 'L2',
'vector.index-type': 'FLAT',
'data-evolution.enabled': 'true',
'row-tracking.enabled': 'true',
'global-index.enabled': 'true',
}
schema = Schema.from_pyarrow_schema(pa_schema, options=table_options)
self.catalog.create_table(
'default.test_faiss_vector_e2e', schema, False
)
table = self.catalog.get_table('default.test_faiss_vector_e2e')
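        # ========== Step 2: Write vector data ==========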
write_builder = table.new_batch_write_builder()
table_write = write_builder.new_write()
table_commit = write_builder.new_commit()
vector_field_name = 'vec'
data = pa.Table.from_pydict({
'id': list(range(self.n_vectors)),
vector_field_name: [self.vectors[i].tolist() for i in range(self.n_vectors)],
}, schema=pa_schema)
table_write.write_arrow(data)
table_commit.commit(table_write.prepare_commit())
table_write.close()
table_commit.close()
# ========== Step 3: Build FAISS index ==========
index_dir = table.path_factory().index_path()
os.makedirs(index_dir, exist_ok=True)
index = FaissIndex.create_flat_index(
self.dimension,
FaissVectorMetric.L2
)
# Add vectors with IDs - same as Java's FaissIndex.addWithIds()
ids = np.arange(self.n_vectors, dtype=np.int64)
vectors = np.ascontiguousarray(self.vectors, dtype=np.float32)
ids = np.ascontiguousarray(ids, dtype=np.int64)
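        # Note: FAISS flat indexes do not implement add_with_ids natively;
        # create_flat_index is assumed to wrap the flat index (e.g. in an
        # IndexIDMap) so explicit row IDs can be attached.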
index._index.add_with_ids(vectors, ids)
self.assertEqual(index.size(), self.n_vectors)
        # Java and Python use the same underlying FAISS serialization format,
        # so the index file written here is readable by either implementation
index_filename = 'faiss-vec-0.index'
index_path = os.path.join(index_dir, index_filename)
faiss.write_index(index._index, index_path)
index.close()
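        # Reload the index from disk to verify round-trip serialization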
loaded_index = FaissIndex.from_file(index_path)
self.assertEqual(loaded_index.size(), self.n_vectors)
# Verify search works on loaded index
query = self.vectors[42:43]
distances, labels = loaded_index.search(query, 1)
self.assertEqual(labels[0][0], 42) # Should find the exact vector
loaded_index.close()
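        # ========== Step 4: Create index metadata and manifest ==========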
# Create FaissIndexMeta
faiss_meta = FaissIndexMeta(
dim=self.dimension,
metric_value=0, # L2
index_type_ordinal=0, # FLAT
num_vectors=self.n_vectors,
min_id=0,
max_id=self.n_vectors - 1
)
# Create GlobalIndexMeta
global_index_meta = GlobalIndexMeta(
row_range_start=0,
row_range_end=self.n_vectors - 1,
index_field_id=1,
index_meta=faiss_meta.serialize()
)
# Create IndexFileMeta
index_file_meta = IndexFileMeta(
index_type='faiss-vector-ann',
file_name=index_filename,
file_size=os.path.getsize(index_path),
row_count=self.n_vectors,
global_index_meta=global_index_meta
)
# Write index manifest file
manifest_dir = table.path_factory().manifest_path()
os.makedirs(manifest_dir, exist_ok=True)
index_manifest_filename = 'index-manifest-0'
index_manifest_path = os.path.join(manifest_dir, index_manifest_filename)
# Simple JSON format for index manifest entries
manifest_entries = [{
'kind': 0, # ADD
'partition': [],
'bucket': 1,
'index_file': {
'file_name': index_file_meta.file_name,
'file_size': index_file_meta.file_size,
'row_count': index_file_meta.row_count,
'index_type': index_file_meta.index_type,
'global_index_meta': {
'row_range_start': global_index_meta.row_range_start,
'row_range_end': global_index_meta.row_range_end,
'index_field_id': global_index_meta.index_field_id,
}
}
}]
with open(index_manifest_path, 'w') as f:
json.dump(manifest_entries, f)
        # ========== Step 5: Update snapshot to include index manifest ==========
from pypaimon.snapshot.snapshot import Snapshot
from pypaimon.common.json_util import JSON
snapshot_manager = table.snapshot_manager()
current_snapshot = snapshot_manager.get_latest_snapshot()
# Create new snapshot with index manifest
# Set next_row_id to n_vectors since we've indexed all rows (0 to n_vectors-1)
new_snapshot = Snapshot(
id=current_snapshot.id + 1,
schema_id=current_snapshot.schema_id,
base_manifest_list=current_snapshot.base_manifest_list,
delta_manifest_list=current_snapshot.delta_manifest_list,
changelog_manifest_list=current_snapshot.changelog_manifest_list,
index_manifest=index_manifest_filename,
commit_kind=current_snapshot.commit_kind,
commit_user=current_snapshot.commit_user,
commit_identifier=current_snapshot.commit_identifier,
version=current_snapshot.version,
time_millis=current_snapshot.time_millis,
total_record_count=current_snapshot.total_record_count,
delta_record_count=current_snapshot.delta_record_count,
next_row_id=self.n_vectors
)
# Write new snapshot
snapshot_path = os.path.join(
snapshot_manager.snapshot_dir, 'snapshot-{}'.format(new_snapshot.id)
)
with open(snapshot_path, 'w') as f:
f.write(JSON.to_json(new_snapshot))
# Update LATEST hint
with open(snapshot_manager.latest_file, 'w') as f:
f.write(str(new_snapshot.id))
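        # ========== Step 6: Vector search through the read path ==========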
query_idx = 42
query_vector = self.vectors[query_idx]
vector_search = VectorSearch(
vector=query_vector,
limit=10,
field_name=vector_field_name
)
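        # Attaching the vector search should make the read return only the
        # top-k (limit=10) rows ranked by the index, as asserted below.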
read_builder = table.new_read_builder().with_vector_search(vector_search)
# Get scan and plan
scan = read_builder.new_scan()
plan = scan.plan()
plan_splits = plan.splits()
# Read data through ReadBuilder API
table_read = read_builder.new_read()
result = table_read.to_arrow(plan_splits)
result_ids = result.column('id').to_pylist()
# Verify result count matches vector search limit
self.assertEqual(
len(result_ids), 10,
"Vector search with limit=10 should return exactly 10 results. "
"Got {} results: {}".format(len(result_ids), result_ids)
)
# Query vector (id=42) should be in results
self.assertIn(
query_idx, result_ids,
"Query vector (id={}) should be in results: {}".format(
query_idx, result_ids
)
        )


class JavaPyFaissE2ETest(unittest.TestCase):
"""
E2E test for FAISS vector index: Java writes, Python reads.
    Follows the JavaPyReadWriteTest pattern.
Run this test with:
1. Run Java test first: mvn test -Dtest=JavaPyFaissE2ETest#testJavaWriteFaissVectorIndex \
-pl paimon-faiss/paimon-faiss-index -Drun.e2e.tests=true
2. Then run Python test: python -m pytest test_global_index.py::JavaPyFaissE2ETest -v
"""
@classmethod
def setUpClass(cls):
"""Set up using the shared e2e warehouse."""
from pypaimon import CatalogFactory
# Use the shared e2e warehouse (same as JavaPyE2ETest)
        # This test runs from the pypaimon/tests directory; the warehouse lives in e2e/warehouse
cls.tempdir = os.path.abspath(os.path.dirname(__file__))
cls.warehouse = os.path.join(cls.tempdir, 'e2e', 'warehouse')
# Create database if it doesn't exist (needed for catalog to find tables)
cls.catalog = CatalogFactory.create({
'warehouse': cls.warehouse
})
        cls.catalog.create_database('default', True)

    def test_read_faiss_vector_table(self):
"""
Read FAISS vector table written by Java (JavaPyFaissE2ETest.testJavaWriteFaissVectorIndex).
"""
from pypaimon.catalog.catalog_exception import TableNotExistException
try:
table = self.catalog.get_table('default.faiss_vector_table_j')
except TableNotExistException:
            self.skipTest(
                "Table 'default.faiss_vector_table_j' not found. "
                "Run the Java test first (see class docstring)."
            )
vector_field_name = 'vec'
# Query with [0.85, 0.15], limit=2
query_vector = np.array([0.85, 0.15], dtype=np.float32)
vector_search = VectorSearch(
vector=query_vector,
limit=2,
field_name=vector_field_name
)
read_builder = table.new_read_builder().with_vector_search(vector_search)
scan = read_builder.new_scan()
plan = scan.plan()
plan_splits = plan.splits()
table_read = read_builder.new_read()
result = table_read.to_arrow(plan_splits)
result_ids = result.column('id').to_pylist()
print(f"Vector search results: {result_ids}")
# With L2 distance, closest to [0.85, 0.15] should be:
# - [0.95, 0.1] (id=1)
# - [0.98, 0.05] (id=3)
self.assertEqual(
len(result_ids), 2,
f"Vector search with limit=2 should return exactly 2 results. Got: {result_ids}"
)
self.assertIn(1, result_ids, "ID 1 ([0.95, 0.1]) should be in results")
        self.assertIn(3, result_ids, "ID 3 ([0.98, 0.05]) should be in results")


if __name__ == '__main__':
unittest.main()