blob: f595c52bd4b2f9aa35cf90e842aab3b0a8b646b0 [file]
# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
import tempfile
from unittest import TestCase
from solrorbit.utils.dataset import Context, get_data_set, HDF5DataSet, BigANNVectorDataSet
from solrorbit.utils.parse import ConfigurationError
from tests.utils.dataset_helper import create_data_set, create_ground_truth
DEFAULT_INDEX_NAME = "test-index"
DEFAULT_FIELD_NAME = "test-field"
DEFAULT_CONTEXT = Context.INDEX
DEFAULT_NUM_VECTORS = 10
DEFAULT_DIMENSION = 10
DEFAULT_RANDOM_STRING_LENGTH = 8
class DataSetTestCase(TestCase):
def testHDF5AsAcceptableDataSetFormat(self):
with tempfile.TemporaryDirectory() as data_set_dir:
valid_data_set_path = create_data_set(
DEFAULT_NUM_VECTORS,
DEFAULT_DIMENSION,
HDF5DataSet.FORMAT_NAME,
DEFAULT_CONTEXT,
data_set_dir
)
data_set_instance = get_data_set("hdf5", valid_data_set_path, Context.INDEX)
self.assertEqual(data_set_instance.FORMAT_NAME, HDF5DataSet.FORMAT_NAME)
self.assertEqual(data_set_instance.size(), DEFAULT_NUM_VECTORS)
def testBigANNAsAcceptableDataSetFormatWithFloatExtension(self):
float_extension = "fbin"
data_set_dir = tempfile.mkdtemp()
valid_data_set_path = create_data_set(
DEFAULT_NUM_VECTORS,
DEFAULT_DIMENSION,
float_extension,
DEFAULT_CONTEXT,
data_set_dir
)
data_set_instance = get_data_set("bigann", valid_data_set_path, Context.INDEX)
self.assertEqual(data_set_instance.FORMAT_NAME, BigANNVectorDataSet.FORMAT_NAME)
self.assertEqual(data_set_instance.size(), DEFAULT_NUM_VECTORS)
def testBigANNGroundTruthAsAcceptableDataSetFormat(self):
bin_extension = "bin"
data_set_dir = tempfile.mkdtemp()
valid_data_set_path = create_ground_truth(
100,
10,
bin_extension,
Context.NEIGHBORS,
data_set_dir
)
data_set_instance = get_data_set("bigann", valid_data_set_path, Context.NEIGHBORS)
self.assertEqual(data_set_instance.FORMAT_NAME, BigANNVectorDataSet.FORMAT_NAME)
self.assertEqual(data_set_instance.size(), 100)
def testUnSupportedDataSetFormat(self):
with self.assertRaises(ConfigurationError) as _:
get_data_set("random", "/some/path", Context.INDEX)