blob: 2685f745bf3dc0731301950dfc206acef1f8a697 [file]
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
import os
import struct
from abc import ABC, ABCMeta, abstractmethod
from enum import Enum
from typing import cast
import h5py
import numpy as np
from solrorbit.exceptions import InvalidExtensionException
from solrorbit.utils.parse import ConfigurationError
class Context(Enum):
"""DataSet context enum. Can be used to add additional context for how a
data-set should be interpreted.
"""
INDEX = 1
QUERY = 2
NEIGHBORS = 3
MAX_DISTANCE_NEIGHBORS = 4
MIN_SCORE_NEIGHBORS = 5
PARENTS = 6
ATTRIBUTES = 7
class DataSet(ABC):
"""DataSet interface. Used for reading data-sets from files.
Methods:
read: Read a chunk of data from the data-set
seek: Get to position in the data-set
size: Gets the number of items in the data-set
reset: Resets internal state of data-set to beginning
"""
__metaclass__ = ABCMeta
BEGINNING = 0
@abstractmethod
def read(self, chunk_size: int):
"""Read vector for given chunk size
@param chunk_size: limits vector size to read
"""
@abstractmethod
def seek(self, offset: int):
"""
Move reader to given offset
@param offset: value to move reader pointer to
"""
@abstractmethod
def size(self):
"""
Returns size of dataset
"""
@abstractmethod
def reset(self):
"""
Resets the dataset reader
"""
def get_data_set(data_set_format: str, path: str, context: Context):
"""
Factory method to get instance of Dataset for given format.
Args:
data_set_format: File format like hdf5, bigann
path: Data set file path
context: Dataset Context Enum
Returns: DataSet instance
"""
if data_set_format == HDF5DataSet.FORMAT_NAME:
return HDF5DataSet(path, context)
if data_set_format == BigANNVectorDataSet.FORMAT_NAME:
return create_big_ann_dataset(path)
raise ConfigurationError("Invalid data set format")
class HDF5DataSet(DataSet):
""" Data-set format corresponding to `ANN Benchmarks
<https://github.com/erikbern/ann-benchmarks#data-sets>`_
"""
FORMAT_NAME = "hdf5"
def __init__(self, dataset_path: str, context: Context):
self.dataset_path = dataset_path
self.context = self.parse_context(context)
self.current = self.BEGINNING
self.data = None
def _load(self):
if self.data is None:
file = h5py.File(self.dataset_path)
self.data = cast(h5py.Dataset, file[self.context])
def read(self, chunk_size: int):
# load file first before read
self._load()
if self.current >= self.size():
return None
end_offset = self.current + chunk_size
if end_offset > self.size():
end_offset = self.size()
vectors = cast(np.ndarray, self.data[self.current:end_offset])
self.current = end_offset
return vectors
def seek(self, offset: int):
if offset < self.BEGINNING:
raise Exception("Offset must be greater than or equal to 0")
if offset >= self.size():
raise Exception("Offset must be less than the data set size")
self.current = offset
def size(self):
# load file first before return size
self._load()
return self.data.len()
def reset(self):
self.current = self.BEGINNING
# pylint: disable=R0911
@staticmethod
def parse_context(context: Context) -> str:
if context == Context.NEIGHBORS:
return "neighbors"
if context == Context.INDEX:
return "train"
if context == Context.QUERY:
return "test"
if context == Context.PARENTS:
return "parents" # used in nested benchmarks to get the parent document id associated with each vector.
if context == Context.MAX_DISTANCE_NEIGHBORS:
return "max_distance_neighbors"
if context == Context.MIN_SCORE_NEIGHBORS:
return "min_score_neighbors"
if context == Context.ATTRIBUTES:
return "attributes"
raise Exception("Unsupported context")
class BigANNDataSet(DataSet):
DATA_SET_HEADER_LENGTH = 8
FORMAT_NAME = "bigann"
def __init__(self, dataset_path: str):
self.dataset_path = dataset_path
self.file = None
self.num_bytes = 0
self.current = DataSet.BEGINNING
self.bytes_per_num = 0
self.rows = 0
self.row_length = 0
def _init_internal_params(self):
self.file = open(self.dataset_path, 'rb')
self.file.seek(DataSet.BEGINNING, os.SEEK_END)
self.num_bytes = self.file.tell()
if self.num_bytes < BigANNDataSet.DATA_SET_HEADER_LENGTH:
raise Exception("Invalid file: file size cannot be less than {} bytes".format(
BigANNDataSet.DATA_SET_HEADER_LENGTH))
self.file.seek(BigANNDataSet.BEGINNING)
self.rows = int.from_bytes(self.file.read(4), "little")
self.row_length = int.from_bytes(self.file.read(4), "little")
self.bytes_per_num = self._get_data_size()
self.reader = self._value_reader()
def _load(self):
# load file if it is not loaded yet
if self.file is None:
self._init_internal_params()
def read(self, chunk_size: int):
# load file first before read
self._load()
if self.current >= self.size():
return None
end_offset = self.current + chunk_size
if end_offset > self.size():
end_offset = self.size()
vectors = np.asarray(
[self._read_vector() for _ in range(end_offset - self.current)]
)
self.current = end_offset
return vectors
def _get_file_byte_offset(self, offset):
"""Return file byte offset for given offset"""
return BigANNDataSet.DATA_SET_HEADER_LENGTH + (self.row_length * self.bytes_per_num * offset)
def seek(self, offset: int):
# load file first before seek
self._load()
if offset < self.BEGINNING:
raise Exception("Offset must be greater than or equal to 0")
if offset >= self.size():
raise Exception("Offset must be less than the data set size")
bytes_offset = self._get_file_byte_offset(offset)
self.file.seek(bytes_offset)
self.current = offset
def _read_vector(self):
return np.asarray([self.reader(self.file) for _ in range(self.row_length)])
def size(self):
# load file first before return size
self._load()
return self.rows
def reset(self):
if self.file:
self.file.seek(BigANNDataSet.DATA_SET_HEADER_LENGTH)
self.current = BigANNDataSet.BEGINNING
def __del__(self):
if self.file:
self.file.close()
@abstractmethod
def _get_supported_extension(self):
"""Return list of supported extension by this dataset"""
def _get_extension(self):
ext = self.dataset_path.split('.')[-1]
supported_extension = self._get_supported_extension()
if ext not in supported_extension:
raise InvalidExtensionException(
"Unknown extension :{}, supported extensions are: {}".format(
ext, str(supported_extension)))
return ext
@abstractmethod
def get_data_size(self, extension):
"""Return data size based on extension"""
def _get_data_size(self):
"""Return data size"""
ext = self._get_extension()
return self.get_data_size(ext)
@abstractmethod
def _get_value_reader(self, extension):
"""Return value reader based on extension"""
def _value_reader(self):
ext = self._get_extension()
return self._get_value_reader(ext)
class BigANNVectorDataSet(BigANNDataSet):
""" Data-set format for vector data-sets for `Big ANN Benchmarks
<https://big-ann-benchmarks.com/index.html#bench-datasets>`
"""
U8BIN_EXTENSION = "u8bin"
FBIN_EXTENSION = "fbin"
SUPPORTED_EXTENSION = [
FBIN_EXTENSION, U8BIN_EXTENSION
]
BYTES_PER_U8INT = 1
BYTES_PER_FLOAT = 4
def _init_internal_params(self):
super()._init_internal_params()
if (self.num_bytes - BigANNDataSet.DATA_SET_HEADER_LENGTH) != (
self.rows * self.row_length * self.bytes_per_num):
raise Exception("Invalid file. File size is not matching with expected estimated "
"value based on number of points, dimension and bytes per point")
def _get_supported_extension(self):
return BigANNVectorDataSet.SUPPORTED_EXTENSION
def get_data_size(self, extension):
if extension == BigANNVectorDataSet.U8BIN_EXTENSION:
return BigANNVectorDataSet.BYTES_PER_U8INT
if extension == BigANNVectorDataSet.FBIN_EXTENSION:
return BigANNVectorDataSet.BYTES_PER_FLOAT
return None
def _get_value_reader(self, extension):
if extension == BigANNVectorDataSet.U8BIN_EXTENSION:
return lambda file: float(int.from_bytes(file.read(BigANNVectorDataSet.BYTES_PER_U8INT), "little"))
if extension == BigANNVectorDataSet.FBIN_EXTENSION:
return lambda file: struct.unpack('<f', file.read(BigANNVectorDataSet.BYTES_PER_FLOAT))
return None
class BigANNGroundTruthDataSet(BigANNDataSet):
""" Data-set format for neighbor data-sets for `Big ANN Benchmarks
<https://big-ann-benchmarks.com/index.html#bench-datasets>`"""
BIN_EXTENSION = "bin"
SUPPORTED_EXTENSION = [BIN_EXTENSION]
BYTES_PER_UNSIGNED_INT32 = 4
def _init_internal_params(self):
super()._init_internal_params()
# The ground truth binary files consist of the following information:
# num_queries(uint32_t) K-NN(uint32) followed by num_queries X K x sizeof(uint32_t) bytes of data
# representing the IDs of the K-nearest neighbors of the queries, followed by num_queries X K x sizeof(float)
# bytes of data representing the distances to the corresponding points.
if (self.num_bytes - BigANNDataSet.DATA_SET_HEADER_LENGTH) != 2 * (
self.rows * self.row_length * self.bytes_per_num):
raise Exception("Invalid file. File size is not matching with expected estimated "
"value based on number of queries, k and bytes per query")
def _get_supported_extension(self):
return BigANNGroundTruthDataSet.SUPPORTED_EXTENSION
def get_data_size(self, extension):
return BigANNGroundTruthDataSet.BYTES_PER_UNSIGNED_INT32
def _get_value_reader(self, extension):
return lambda file: int.from_bytes(
file.read(BigANNGroundTruthDataSet.BYTES_PER_UNSIGNED_INT32), "little")
def create_big_ann_dataset(file_path: str):
if not file_path:
raise Exception("Invalid file path")
extension = file_path.split('.')[-1]
if extension in BigANNGroundTruthDataSet.SUPPORTED_EXTENSION:
return BigANNGroundTruthDataSet(file_path)
if extension in BigANNVectorDataSet.SUPPORTED_EXTENSION:
return BigANNVectorDataSet(file_path)
raise Exception("Unsupported file")