| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| |
| import typing |
| from abc import ABC |
| from pyflink.java_gateway import get_gateway |
| from pyflink.table import Table |
| from pyflink.util.java_utils import to_jarray |
| |
| from pyflink.ml.core.linalg import Vector, DenseVector, SparseVector |
| from pyflink.ml.core.param import Param, IntParam, ParamValidators |
| from pyflink.ml.core.wrapper import JavaWithParams |
| from pyflink.ml.lib.feature.common import JavaFeatureEstimator, JavaFeatureModel |
| from pyflink.ml.lib.param import HasInputCol, HasOutputCol, HasSeed |
| |
| |
| class _LSHModelParams(JavaWithParams, |
| HasInputCol, |
| HasOutputCol): |
| """ |
| Params for :class:`LSHModel` |
| """ |
| |
| def __init__(self, java_params): |
| super(_LSHModelParams, self).__init__(java_params) |
| |
| |
| class _LSHParams(_LSHModelParams): |
| """ |
| Params for :class:`LSH` |
| """ |
| |
| """ |
| Param for the number of hash tables used in LSH OR-amplification. |
| |
| OR-amplification can be used to reduce the false negative rate. Higher values of this |
| param lead to a reduced false negative rate, at the expense of added computational |
| complexity. |
| """ |
| NUM_HASH_TABLES: Param[int] = IntParam( |
| "num_hash_tables", "Number of hash tables.", 1, ParamValidators.gt_eq(1) |
| ) |
| |
| """ |
| Param for the number of hash functions per hash table used in LSH AND-amplification. |
| |
| AND-amplification can be used to reduce the false positive rate. Higher values of this |
| param lead to a reduced false positive rate, at the expense of added computational |
| complexity. |
| """ |
| NUM_HASH_FUNCTIONS_PER_TABLE: Param[int] = IntParam( |
| "num_hash_functions_per_table", |
| "Number of hash functions per table.", |
| 1, |
| ParamValidators.gt_eq(1)) |
| |
| def __init__(self, java_params): |
| super(_LSHParams, self).__init__(java_params) |
| |
| def set_num_hash_tables(self, value: int): |
| return typing.cast(_LSHParams, self.set(self.NUM_HASH_TABLES, value)) |
| |
| def get_num_hash_tables(self): |
| return self.get(self.NUM_HASH_TABLES) |
| |
| @property |
| def num_hash_tables(self): |
| return self.get_num_hash_tables() |
| |
| def set_num_hash_functions_per_table(self, value: int): |
| return typing.cast(_LSHParams, self.set(self.NUM_HASH_FUNCTIONS_PER_TABLE, value)) |
| |
| def get_num_hash_functions_per_table(self): |
| return self.get(self.NUM_HASH_FUNCTIONS_PER_TABLE) |
| |
| @property |
| def num_hash_functions_per_table(self): |
| return self.get_num_hash_functions_per_table() |
| |
| |
| class _MinHashLSHParams(_LSHParams, HasSeed): |
| """ |
| Params for :class:`MinHashLSH` |
| """ |
| |
| def __init__(self, java_params): |
| super(_MinHashLSHParams, self).__init__(java_params) |
| |
| |
| class _LSH(JavaFeatureEstimator, ABC): |
| """ |
| Base class for estimators that support LSH (Locality-sensitive hashing) algorithm for different |
| metrics (e.g., Jaccard distance). |
| |
| The basic idea of LSH is to use a set of hash functions to map input vectors into different |
| buckets, where closer vectors are expected to be in the same bucket with higher probabilities. |
| In detail, each input vector is hashed by all functions. To decide whether two input vectors |
| are mapped into the same bucket, two mechanisms for assigning buckets are proposed as follows. |
| |
| <ul> |
| <li>AND-amplification: The two input vectors are defined to be in the same bucket as long as |
| ALL of the hash value matches. |
| <li>OR-amplification: The two input vectors are defined to be in the same bucket as long as |
| ANY of the hash value matches. |
| </ul> |
| |
| See: <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing"> |
| Locality-sensitive_hashing</a>. |
| """ |
| |
| def __init__(self): |
| super(_LSH, self).__init__() |
| |
| @classmethod |
| def _java_estimator_package_name(cls) -> str: |
| return "lsh" |
| |
| |
| class _LSHModel(JavaFeatureModel, ABC): |
| """ |
| Base class for LSH model. |
| |
| In addition to transforming input feature vectors to multiple hash values, it also supports |
| approximate nearest neighbors search within a dataset regarding a key vector and approximate |
| similarity join between two datasets. |
| """ |
| |
| def __init__(self, java_model): |
| super(_LSHModel, self).__init__(java_model) |
| |
| @classmethod |
| def _java_model_package_name(cls) -> str: |
| return "lsh" |
| |
| def approx_nearest_neighbors(self, dataset: Table, key: Vector, k: int, |
| dist_col: str = 'distCol'): |
| """ |
| Approximately finds at most k items from a dataset which have the closest distance to a |
| given item. If the `outputCol` is missing in the given dataset, this method transforms the |
| dataset with the model at first. |
| |
| :param dataset: The dataset in which to to search for nearest neighbors. |
| :param key: The item to search for. |
| :param k: The maximum number of nearest neighbors. |
| :param dist_col: The output column storing the distance between each neighbor and the key. |
| :return: A dataset containing at most k items closest to the key with a column named |
| `distCol` appended. |
| """ |
| j_vectors = get_gateway().jvm.org.apache.flink.ml.linalg.Vectors |
| if isinstance(key, (DenseVector,)): |
| j_key = j_vectors.dense(to_jarray(get_gateway().jvm.double, key.values.tolist())) |
| elif isinstance(key, (SparseVector,)): |
| # noinspection PyProtectedMember |
| j_key = j_vectors.sparse( |
| key.size(), |
| to_jarray(get_gateway().jvm.int, key._indices.tolist()), |
| to_jarray(get_gateway().jvm.double, key._values.tolist()) |
| ) |
| else: |
| raise TypeError(f'Key {key} must be an instance of Vector.') |
| |
| # noinspection PyProtectedMember |
| return Table(self._java_obj.approxNearestNeighbors( |
| dataset._j_table, j_key, k, dist_col), self._t_env) |
| |
| def approx_similarity_join(self, dataset_a: Table, dataset_b: Table, threshold: float, |
| id_col: str, dist_col: str = 'distCol'): |
| """ |
| Joins two datasets to approximately find all pairs of rows whose distance are smaller than |
| or equal to the threshold. If the `outputCol` is missing in either dataset, this method |
| transforms the dataset at first. |
| |
| :param dataset_a: One dataset. |
| :param dataset_b: The other dataset. |
| :param threshold: The distance threshold. |
| :param id_col: A column in the two datasets to identify each row. |
| :param dist_col: The output column storing the distance between each pair of rows. |
| :return: A joined dataset containing pairs of rows. The original rows are in columns |
| "dataset_a" and "dataset_b", and a column "distCol" is added to show the distance |
| between each pair. |
| """ |
| # noinspection PyProtectedMember |
| return Table(self._java_obj.approxSimilarityJoin( |
| dataset_a._j_table, dataset_b._j_table, |
| threshold, id_col, dist_col), self._t_env) |
| |
| |
| class MinHashLSHModel(_LSHModel, _LSHModelParams): |
| """ |
| A Model which generates hash values using the model data computed by :class:`MinHashLSH`. |
| """ |
| |
| def __init__(self, java_model=None): |
| super(MinHashLSHModel, self).__init__(java_model) |
| |
| @classmethod |
| def _java_model_class_name(cls) -> str: |
| return "MinHashLSHModel" |
| |
| |
| class MinHashLSH(_LSH, _MinHashLSHParams): |
| """ |
| An Estimator that implements the MinHash LSH algorithm, which supports LSH for Jaccard distance. |
| |
| The input could be dense or sparse vectors. Each input vector must have at least one non-zero |
| index and all non-zero values are treated as binary "1" values. The sizes of input vectors |
| should be same and not larger than a predefined prime (i.e., 2038074743). |
| |
| See: <a href="https://en.wikipedia.org/wiki/MinHash">MinHash</a>. |
| """ |
| |
| def __init__(self): |
| super(MinHashLSH, self).__init__() |
| |
| @classmethod |
| def _create_model(cls, java_model) -> MinHashLSHModel: |
| return MinHashLSHModel(java_model) |
| |
| @classmethod |
| def _java_estimator_class_name(cls) -> str: |
| return "MinHashLSH" |