blob: 24c399a1921328cc882e85854a762b82bc4e2f06 [file]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""Vector search read to read index files."""
from abc import ABC, abstractmethod
from pypaimon.globalindex.global_index_meta import GlobalIndexIOMeta
from pypaimon.globalindex.global_index_result import GlobalIndexResult
from pypaimon.globalindex.offset_global_index_reader import OffsetGlobalIndexReader
from pypaimon.globalindex.vector_search import VectorSearch
from pypaimon.globalindex.vector_search_result import ScoredGlobalIndexResult
class VectorSearchRead(ABC):
"""Vector search read to read index files."""
def read_plan(self, plan):
# type: (VectorSearchScanPlan) -> GlobalIndexResult
return self.read(plan.splits())
@abstractmethod
def read(self, splits):
# type: (List[VectorSearchSplit]) -> GlobalIndexResult
pass
class VectorSearchReadImpl(VectorSearchRead):
"""Implementation for VectorSearchRead."""
def __init__(self, table, limit, vector_column, query_vector, filter_=None):
self._table = table
self._limit = limit
self._vector_column = vector_column
self._query_vector = query_vector
self._filter = filter_
def read(self, splits):
# type: (List[VectorSearchSplit]) -> GlobalIndexResult
if not splits:
return GlobalIndexResult.create_empty()
pre_filter = self._pre_filter(splits)
result = ScoredGlobalIndexResult.create_empty()
for split in splits:
split_result = self._eval(
split.row_range_start, split.row_range_end,
split.vector_index_files, pre_filter
)
if split_result is not None:
result = result.or_(split_result)
return result.top_k(self._limit)
def _pre_filter(self, splits):
# type: (list) -> Optional[RoaringBitmap64]
"""Evaluate the scalar filter against scalar global indexes to produce a row-id bitmap."""
if self._filter is None:
return None
# Collect scalar index files across splits, deduplicated by file name.
seen = set()
scalar_files = []
for split in splits:
for index_file in split.scalar_index_files:
if index_file.file_name in seen:
continue
seen.add(index_file.file_name)
scalar_files.append(index_file)
if not scalar_files:
return None
from pypaimon.globalindex.global_index_scanner import GlobalIndexScanner
scanner = GlobalIndexScanner.create(self._table, index_files=scalar_files)
if scanner is None:
return None
try:
result = scanner.scan(self._filter)
if result is None:
return None
return result.results()
finally:
scanner.close()
def _eval(self, row_range_start, row_range_end, vector_index_files,
include_row_ids):
# type: (int, int, list, Optional[RoaringBitmap64]) -> Optional[ScoredGlobalIndexResult]
if not vector_index_files:
return None
index_io_meta_list = []
for index_file in vector_index_files:
meta = index_file.global_index_meta
assert meta is not None
index_io_meta_list.append(
GlobalIndexIOMeta(
file_name=index_file.file_name,
file_size=index_file.file_size,
metadata=meta.index_meta,
external_path=index_file.external_path,
)
)
index_type = vector_index_files[0].index_type
index_path = self._table.path_factory().global_index_path_factory().index_path()
file_io = self._table.file_io
options = self._table.table_schema.options
vector_search = VectorSearch(
vector=self._query_vector,
limit=self._limit,
field_name=self._vector_column.name
)
if include_row_ids is not None:
vector_search = vector_search.with_include_row_ids(include_row_ids)
with _create_vector_reader(
index_type, file_io, index_path,
index_io_meta_list, options
) as reader:
offset_reader = OffsetGlobalIndexReader(reader, row_range_start, row_range_end)
return offset_reader.visit_vector_search(vector_search)
def _create_vector_reader(index_type, file_io, index_path, index_io_meta_list, options=None):
"""Create a global index reader for vector search."""
from pypaimon.globalindex.lumina.lumina_vector_global_index_reader import (
LUMINA_IDENTIFIERS,
LuminaVectorGlobalIndexReader,
)
if index_type in LUMINA_IDENTIFIERS:
return LuminaVectorGlobalIndexReader(
file_io, index_path, index_io_meta_list, options
)
raise ValueError("Unsupported vector index type: '%s'" % index_type)