################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
FAISS Vector Global Index Reader.
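
Example (a minimal sketch; file_io, index_path, io_metas, options and
vector_search are assumed to come from the surrounding Paimon setup, with
index files produced by the matching writer):

    with FaissVectorGlobalIndexReader(file_io, index_path, io_metas, options) as reader:
        result = reader.visit_vector_search(vector_search)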
"""

import os
import shutil
import tempfile
import uuid
from typing import Dict, List, Optional

import numpy as np

from pypaimon.globalindex.global_index_reader import GlobalIndexReader
from pypaimon.globalindex.global_index_result import GlobalIndexResult
from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta
from pypaimon.globalindex.vector_search_result import DictBasedVectorSearchResult
from pypaimon.globalindex.roaring_bitmap import RoaringBitmap64
from pypaimon.globalindex.faiss.faiss_options import (
FaissVectorIndexOptions,
FaissVectorMetric,
FaissIndexType,
)
from pypaimon.globalindex.faiss.faiss_index_meta import FaissIndexMeta
from pypaimon.globalindex.faiss.faiss_index import FaissIndex


class FaissVectorGlobalIndexReader(GlobalIndexReader):
"""
Vector global index reader using FAISS.
"""

    def __init__(
self,
file_io: 'FileIO',
index_path: str,
io_metas: List[GlobalIndexIOMeta],
options: FaissVectorIndexOptions
):
self._file_io = file_io
self._index_path = index_path
self._io_metas = io_metas
self._options = options
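        # Populated lazily: metas on the first search, indices per file on demand.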
self._indices: List[Optional[FaissIndex]] = []
self._index_metas: List[FaissIndexMeta] = []
self._local_index_files: List[str] = []
self._metas_loaded = False
self._indices_loaded = False

    def visit_vector_search(self, vector_search: 'VectorSearch') -> Optional[GlobalIndexResult]:
"""Perform vector similarity search."""
try:
# First load only metadata
self._ensure_load_metas()
include_row_ids = vector_search.include_row_ids
# If include_row_ids is specified, check which indices contain matching rows
if include_row_ids is not None:
matching_indices = []
for i, meta in enumerate(self._index_metas):
if self._has_overlap(meta.min_id, meta.max_id, include_row_ids):
matching_indices.append(i)
# If no index contains matching rowIds, return empty
if not matching_indices:
return None
# Load only matching indices
self._ensure_load_indices(matching_indices)
else:
# Load all indices
self._ensure_load_all_indices()
return self._search(vector_search)
except Exception as e:
raise RuntimeError(
f"Failed to search FAISS vector index with field_name={vector_search.field_name}, "
f"limit={vector_search.limit}"
) from e

    def _has_overlap(self, min_id: int, max_id: int, include_row_ids: RoaringBitmap64) -> bool:
"""Check if the range [min_id, max_id] has any overlap with include_row_ids."""
for row_id in include_row_ids:
if min_id <= row_id <= max_id:
return True
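            # RoaringBitmap64 iterates row ids in ascending order, so once a
            # row id exceeds max_id no later id can land inside the range.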
if row_id > max_id:
break
return False

    def _search(self, vector_search: 'VectorSearch') -> Optional[GlobalIndexResult]:
"""Perform the actual search across all loaded indices."""
query_vector = np.array(vector_search.vector, dtype=np.float32)
# L2 normalize the query vector if enabled
if self._options.normalize:
query_vector = self._normalize_l2(query_vector)
limit = vector_search.limit
include_row_ids = vector_search.include_row_ids
        # When filtering by row ids, oversample so enough candidates
        # survive the post-filter below.
search_k = limit
if include_row_ids is not None:
search_k = max(
limit * self._options.search_factor,
include_row_ids.cardinality()
)
        # Merge candidates from every index, keeping the global
        # top-`limit` rows by score.
results: Dict[int, float] = {}
for index in self._indices:
if index is None:
continue
# Configure search parameters based on index type
self._configure_search_params(index)
            # Cap k at the index size (with a floor of 1, so FAISS always
            # receives a valid k even for an empty index).
effective_k = min(search_k, max(1, index.size()))
if effective_k <= 0:
continue
# Perform search
distances, labels = index.search(query_vector, effective_k)
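            # distances/labels have shape (1, effective_k); FAISS pads
            # slots it cannot fill with label -1.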
for i in range(effective_k):
row_id = int(labels[0, i])
if row_id < 0:
# Invalid result
continue
# Filter by include row IDs if specified
if include_row_ids is not None and row_id not in include_row_ids:
continue
# Convert distance to score (higher is better)
score = self._convert_distance_to_score(float(distances[0, i]))
# Keep top-k results
if len(results) < limit:
results[row_id] = score
                else:
                    # Evict the lowest-scoring entry if the new candidate
                    # beats it (linear scan; fine for small limits).
                    min_row_id = min(results.keys(), key=lambda k: results[k])
if score > results[min_row_id]:
del results[min_row_id]
results[row_id] = score
if not results:
return None
return DictBasedVectorSearchResult(results)

    def _configure_search_params(self, index: FaissIndex) -> None:
"""Configure search parameters based on index type."""
if index.index_type == FaissIndexType.HNSW:
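            # efSearch widens the HNSW search beam: higher values raise
            # recall at the cost of latency.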
index.set_hnsw_ef_search(self._options.ef_search)
elif index.index_type in (FaissIndexType.IVF, FaissIndexType.IVF_PQ, FaissIndexType.IVF_SQ8):
            # Probe at least the configured number of lists, scaled up with
            # the index size (size // 10) to trade latency for recall.
            effective_nprobe = max(self._options.nprobe, 1, index.size() // 10)
index.set_ivf_nprobe(effective_nprobe)

    def _convert_distance_to_score(self, distance: float) -> float:
"""Convert distance to similarity score."""
if self._options.metric == FaissVectorMetric.L2:
# For L2 distance, smaller is better, so invert it
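            # e.g. distance 0.0 -> score 1.0, distance 3.0 -> score 0.25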
return 1.0 / (1.0 + distance)
else:
# Inner product is already a similarity
return distance

    @staticmethod
def _normalize_l2(vector: np.ndarray) -> np.ndarray:
"""L2 normalize the vector."""
norm = np.linalg.norm(vector)
if norm > 0:
return vector / norm
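        # A zero vector has no direction to normalize; return it unchanged.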
return vector

    def _ensure_load_metas(self) -> None:
        """Load metadata for all index files, keeping positions aligned with io_metas."""
        if self._metas_loaded:
            return
        for io_meta in self._io_metas:
            if not io_meta.metadata:
                # Positions in _index_metas must line up with _io_metas;
                # silently skipping a meta would shift every later index.
                raise RuntimeError(
                    f"Missing FAISS index metadata for file {io_meta.file_name}"
                )
            self._index_metas.append(FaissIndexMeta.deserialize(io_meta.metadata))
        self._metas_loaded = True

    def _ensure_load_all_indices(self) -> None:
        """Load all indices, skipping positions already loaded by a filtered search."""
        if self._indices_loaded:
            return
        for i in range(len(self._io_metas)):
            if i >= len(self._indices) or self._indices[i] is None:
                self._load_index_at(i)
        self._indices_loaded = True

    def _ensure_load_indices(self, positions: List[int]) -> None:
"""Load only the specified indices by their positions."""
# Ensure indices list is large enough
while len(self._indices) < len(self._io_metas):
self._indices.append(None)
for pos in positions:
if self._indices[pos] is None:
self._load_index_at(pos)

    def _load_index_at(self, position: int) -> None:
"""Load a single index at the specified position."""
io_meta = self._io_metas[position]
# Read index file from storage
index_file_path = f"{self._index_path}/{io_meta.file_name}"
# Create a temp file for the FAISS index
# Add to tracking list immediately to ensure cleanup on any failure
temp_path = None
try:
temp_file = tempfile.NamedTemporaryFile(
prefix=f"paimon-faiss-{uuid.uuid4()}-",
suffix=".faiss",
delete=False
)
temp_path = temp_file.name
# Track immediately after creation to prevent leaks
self._local_index_files.append(temp_path)
            # Stream the index bytes into the temp file rather than
            # buffering the whole file in memory.
            with self._file_io.new_input_stream(index_file_path) as input_stream:
                shutil.copyfileobj(input_stream, temp_file)
            temp_file.close()
# Load FAISS index from temp file
index = FaissIndex.from_file(temp_path)
# Ensure indices list is large enough
while len(self._indices) <= position:
self._indices.append(None)
self._indices[position] = index
except Exception as e:
# Clean up on failure
if temp_path is not None:
try:
temp_file.close()
except Exception:
pass
if os.path.exists(temp_path):
os.unlink(temp_path)
# Remove from tracking list since we've already cleaned it up
if temp_path in self._local_index_files:
self._local_index_files.remove(temp_path)
            raise

    def close(self) -> None:
"""Close the reader and release resources."""
# Close all FAISS indices
for index in self._indices:
if index is not None:
try:
index.close()
except Exception:
pass
        self._indices.clear()
        # Reset the flag so the cleared list is not mistaken for a loaded state.
        self._indices_loaded = False
# Delete local temporary files
for local_file in self._local_index_files:
try:
if os.path.exists(local_file):
os.unlink(local_file)
except Exception:
pass
self._local_index_files.clear()

    def __enter__(self) -> 'FaissVectorGlobalIndexReader':
return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()