################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import functools
import os
from typing import List

from pyflink.common import Row, Types
from pyflink.java_gateway import get_gateway
from pyflink.ml.linalg import Vectors, SparseVectorTypeInfo, DenseVector
from pyflink.ml.wrapper import JavaWithParams
from pyflink.ml.feature.lsh import MinHashLSH, MinHashLSHModel
from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
from pyflink.table import Table
from pyflink.table.expressions import col


class MinHashLSHTest(PyFlinkMLTestCase):
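    """Tests MinHashLSH and MinHashLSHModel."""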

    def setUp(self):
        super(MinHashLSHTest, self).setUp()
        self.data = self.t_env.from_data_stream(
            self.env.from_collection([
                (0, Vectors.sparse(6, [0, 1, 2], [1., 1., 1.])),
                (1, Vectors.sparse(6, [2, 3, 4], [1., 1., 1.])),
                (2, Vectors.sparse(6, [0, 2, 4], [1., 1., 1.])),
            ],
                type_info=Types.ROW_NAMED(
                    ['id', 'vec'],
                    [Types.INT(), SparseVectorTypeInfo()])))
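        # Expected MinHash hash values for the three input rows: 5 dense vectors
        # per row (one per hash table), each holding 3 hash values (one per hash
        # function), matching the 5 x 3 setting used in the tests below.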
        self.expected = [
            Row([
                Vectors.dense(1.73046954E8, 1.57275425E8, 6.90717571E8),
                Vectors.dense(5.02301169E8, 7.967141E8, 4.06089319E8),
                Vectors.dense(2.83652171E8, 1.97714719E8, 6.04731316E8),
                Vectors.dense(5.2181506E8, 6.36933726E8, 6.13894128E8),
                Vectors.dense(3.04301769E8, 1.113672955E9, 6.1388711E8),
            ]),
            Row([
                Vectors.dense(1.73046954E8, 1.57275425E8, 6.7798584E7),
                Vectors.dense(6.38582806E8, 1.78703694E8, 4.06089319E8),
                Vectors.dense(6.232638E8, 9.28867E7, 9.92010642E8),
                Vectors.dense(2.461064E8, 1.12787481E8, 1.92180297E8),
                Vectors.dense(2.38162496E8, 1.552933319E9, 2.77995137E8),
            ]),
            Row([
                Vectors.dense(1.73046954E8, 1.57275425E8, 6.90717571E8),
                Vectors.dense(1.453197722E9, 7.967141E8, 4.06089319E8),
                Vectors.dense(6.232638E8, 1.97714719E8, 6.04731316E8),
                Vectors.dense(2.461064E8, 1.12787481E8, 1.92180297E8),
                Vectors.dense(1.224130231E9, 1.113672955E9, 2.77995137E8),
            ])]
    def test_param(self):
        lsh = MinHashLSH()
        self.assertEqual('input', lsh.input_col)
        self.assertEqual('output', lsh.output_col)
        self.assertEqual(-1229568175, lsh.seed)
        self.assertEqual(1, lsh.num_hash_tables)
        self.assertEqual(1, lsh.num_hash_functions_per_table)

        lsh.set_input_col('test_input') \
            .set_output_col('test_output') \
            .set_seed(2022) \
            .set_num_hash_tables(3) \
            .set_num_hash_functions_per_table(4)
        self.assertEqual('test_input', lsh.input_col)
        self.assertEqual('test_output', lsh.output_col)
        self.assertEqual(2022, lsh.seed)
        self.assertEqual(3, lsh.num_hash_tables)
        self.assertEqual(4, lsh.num_hash_functions_per_table)
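
    # transform() keeps the input columns and appends the configured output column.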
    def test_output_schema(self):
        lsh = MinHashLSH() \
            .set_input_col('vec') \
            .set_output_col('hashes') \
            .set_seed(2022) \
            .set_num_hash_tables(5) \
            .set_num_hash_functions_per_table(3)
        model = lsh.fit(self.data)
        output = model.transform(self.data)[0]
        self.assertEqual(['id', 'vec', 'hashes'], output.get_schema().get_field_names())

    def test_fit_and_transform(self):
        lsh = MinHashLSH() \
            .set_input_col('vec') \
            .set_output_col('hashes') \
            .set_seed(2022) \
            .set_num_hash_tables(5) \
            .set_num_hash_functions_per_table(3)
        model = lsh.fit(self.data)
        output = model.transform(self.data)[0].select(col("hashes"))
        self.verify_output_hashes(output, self.expected)

    def test_estimator_save_load_transform(self):
        path = os.path.join(self.temp_dir, 'test_estimator_save_load_transform_minhashlsh')
        lsh = MinHashLSH() \
            .set_input_col('vec') \
            .set_output_col('hashes') \
            .set_seed(2022) \
            .set_num_hash_tables(5) \
            .set_num_hash_functions_per_table(3)
        lsh.save(path)
        lsh = MinHashLSH.load(self.t_env, path)
        model = lsh.fit(self.data)
        output = model.transform(self.data)[0].select(col(lsh.output_col))
        self.verify_output_hashes(output, self.expected)

    def test_model_save_load_transform(self):
        path = os.path.join(self.temp_dir, 'test_model_save_load_transform_minhashlsh')
        lsh = MinHashLSH() \
            .set_input_col('vec') \
            .set_output_col('hashes') \
            .set_seed(2022) \
            .set_num_hash_tables(5) \
            .set_num_hash_functions_per_table(3)
        model = lsh.fit(self.data)
        model.save(path)
        self.env.execute('save_model')
        model = MinHashLSHModel.load(self.t_env, path)
        output = model.transform(self.data)[0].select(col(lsh.output_col))
        self.verify_output_hashes(output, self.expected)
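
    # The model data table stores the LSH configuration plus one random coefficient
    # pair per hash function, i.e. numHashTables * numHashFunctionsPerTable entries
    # in each of the randCoefficientA and randCoefficientB columns.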
    def test_get_model_data(self):
        lsh = MinHashLSH() \
            .set_input_col('vec') \
            .set_output_col('hashes') \
            .set_seed(2022) \
            .set_num_hash_tables(5) \
            .set_num_hash_functions_per_table(3)
        model = lsh.fit(self.data)
        model_data_table: Table = model.get_model_data()[0]
        self.assertEqual(
            ['numHashTables', 'numHashFunctionsPerTable', 'randCoefficientA', 'randCoefficientB'],
            model_data_table.get_schema().get_field_names())

        model_data_row = list(self.t_env.to_data_stream(model_data_table).execute_and_collect())[0]
        self.assertEqual(lsh.num_hash_tables, model_data_row[0])
        self.assertEqual(lsh.num_hash_functions_per_table, model_data_row[1])
        self.assertEqual(lsh.num_hash_tables * lsh.num_hash_functions_per_table,
                         len(model_data_row[2]))
        self.assertEqual(lsh.num_hash_tables * lsh.num_hash_functions_per_table,
                         len(model_data_row[3]))

    def test_set_model_data(self):
        lsh = MinHashLSH() \
            .set_input_col('vec') \
            .set_output_col('hashes') \
            .set_seed(2022) \
            .set_num_hash_tables(5) \
            .set_num_hash_functions_per_table(3)
        model_a = lsh.fit(self.data)
        model_data_table = model_a.get_model_data()[0]
        model_b: MinHashLSHModel = MinHashLSHModel().set_model_data(model_data_table)
        self.update_existing_params(model_b, model_a)
        output = model_b.transform(self.data)[0].select(col(lsh.output_col))
        self.verify_output_hashes(output, self.expected)
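
    # approx_nearest_neighbors returns the (at most) k rows closest to the query key,
    # with their Jaccard distances exposed in the 'distCol' column.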
    def test_approx_nearest_neighbors(self):
        lsh = MinHashLSH() \
            .set_input_col('vec') \
            .set_output_col('hashes') \
            .set_seed(2022) \
            .set_num_hash_tables(5) \
            .set_num_hash_functions_per_table(1)
        expected = [
            Row(id=0, distCol=0.75),
            Row(id=1, distCol=0.75),
        ]
        model: MinHashLSHModel = lsh.fit(self.data)
        key = Vectors.sparse(6, [1, 3], [1., 1.])
        output = model.approx_nearest_neighbors(self.data, key, 2).select(col("id"), col("distCol"))
        actual_result = [r for r in self.t_env.to_data_stream(output).execute_and_collect()]
        actual_result.sort(key=lambda r: r[0])
        self.assertEqual(expected, actual_result)

    def test_approx_nearest_neighbors_dense_vector(self):
        lsh = MinHashLSH() \
            .set_input_col('vec') \
            .set_output_col('hashes') \
            .set_seed(2022) \
            .set_num_hash_tables(5) \
            .set_num_hash_functions_per_table(1)
        expected = [
            Row(id=0, distCol=0.75),
            Row(id=1, distCol=0.75),
        ]
        model: MinHashLSHModel = lsh.fit(self.data)
        key = Vectors.dense([0., 1., 0., 1., 0., 0.])
        output = model.approx_nearest_neighbors(self.data, key, 2).select(col("id"), col("distCol"))
        actual_result = [r for r in self.t_env.to_data_stream(output).execute_and_collect()]
        actual_result.sort(key=lambda r: r[0])
        self.assertEqual(expected, actual_result)
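
    # approx_similarity_join pairs up rows from the two datasets whose Jaccard
    # distance falls within the given threshold (0.6 here) and reports it as 'distCol'.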
    def test_approx_similarity_join(self):
        lsh = MinHashLSH() \
            .set_input_col('vec') \
            .set_output_col('hashes') \
            .set_seed(2022) \
            .set_num_hash_tables(5) \
            .set_num_hash_functions_per_table(1)
        data_a = self.data
        model: MinHashLSHModel = lsh.fit(data_a)

        data_b = self.t_env.from_data_stream(
            self.env.from_collection([
                (3, Vectors.sparse(6, [1, 3, 5], [1., 1., 1.])),
                (4, Vectors.sparse(6, [2, 3, 5], [1., 1., 1.])),
                (5, Vectors.sparse(6, [1, 2, 4], [1., 1., 1.])),
            ], type_info=Types.ROW_NAMED(['id', 'vec'], [Types.INT(), SparseVectorTypeInfo()])))

        expected = [
            Row(1, 4, .5),
            Row(0, 5, .5),
            Row(1, 5, .5),
            Row(2, 5, .5)
        ]
        for r in expected:
            r.set_field_names(['datasetA.id', 'datasetB.id', 'distCol'])

        output = model.approx_similarity_join(data_a, data_b, .6, "id")
        actual_result = [r for r in self.t_env.to_data_stream(output).execute_and_collect()]
        expected.sort(key=lambda r: (r[0], r[1]))
        actual_result.sort(key=lambda r: (r[0], r[1]))
        self.assertEqual(expected, actual_result)
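
    # Copies the parameter values of `source` onto `target` through the Java-side
    # ParamUtils helper, so a freshly constructed model uses the fitted model's params.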
    @classmethod
    def update_existing_params(cls, target: JavaWithParams, source: JavaWithParams):
        get_gateway().jvm.org.apache.flink.ml.util.ParamUtils \
            .updateExistingParams(target._java_obj, source._java_obj.getParamMap())
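
    # Lexicographic comparator for dense vectors: shorter vectors sort first,
    # otherwise the first differing component decides the order.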
    @classmethod
    def dense_vector_comparator(cls, dv0: DenseVector, dv1: DenseVector):
        if dv0.size() != dv1.size():
            return dv0.size() - dv1.size()
        for e0, e1 in zip(dv0.values, dv1.values):
            if e0 != e1:
                return 1 if e0 > e1 else -1
        return 0

    @classmethod
    def dense_vector_array_comparator(cls, dvs0: List[DenseVector], dvs1: List[DenseVector]):
        if len(dvs0) != len(dvs1):
            return len(dvs0) - len(dvs1)
        for dv0, dv1 in zip(dvs0, dvs1):
            cmp = cls.dense_vector_comparator(dv0, dv1)
            if cmp != 0:
                return cmp
        return 0
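
    # Compares the collected hash vectors against the expected ones irrespective of
    # row order, by sorting both sides with the dense-vector-array comparator first.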
    def verify_output_hashes(self, output: Table, expected_result: List[Row]):
        actual_result = [r for r in
                         self.t_env.to_data_stream(output).execute_and_collect()]
        actual_result.sort(
            key=lambda x: functools.cmp_to_key(self.dense_vector_array_comparator)(x[0]))
        expected_result.sort(
            key=lambda x: functools.cmp_to_key(self.dense_vector_array_comparator)(x[0]))
        for r0, r1 in zip(actual_result, expected_result):
            self.assertEqual(0, self.dense_vector_array_comparator(r0[0], r1[0]))