| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """ |
| @file knn.py_in |
| |
| @brief knn: K-Nearest Neighbors for regression and classification |
| |
| @namespace knn |
| |
| """ |
| |
| import plpy |
| from utilities.control import MinWarning |
| from utilities.utilities import INTEGER |
| from utilities.utilities import _assert |
| from utilities.utilities import add_postfix |
| from utilities.utilities import extract_keyvalue_params |
| from utilities.utilities import unique_string |
| from utilities.utilities import NUMERIC, ONLY_ARRAY |
| from utilities.utilities import is_valid_psql_type |
| from utilities.utilities import is_pg_major_version_less_than |
| from utilities.utilities import num_features |
| from utilities.validate_args import array_col_has_no_null |
| from utilities.validate_args import cols_in_tbl_valid |
| from utilities.validate_args import drop_tables |
| from utilities.validate_args import get_expr_type |
| from utilities.validate_args import input_tbl_valid, output_tbl_valid |
| from utilities.validate_args import is_var_valid |
| from utilities.validate_args import get_algorithm_name |
| |
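# Surrogate weight used in place of 1/distance when a neighbor's distance is
# exactly zero, so that exact matches dominate any weighted vote or average.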
| WEIGHT_FOR_ZERO_DIST = 1e107 |
| BRUTE_FORCE = 'brute_force' |
| KD_TREE = 'kd_tree' |
| |
| def knn_validate_src(schema_madlib, point_source, point_column_name, point_id, |
| label_column_name, test_source, test_column_name, |
| test_id, output_table, k, output_neighbors, fn_dist, |
| is_brute_force, depth, leaf_nodes, **kwargs): |
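    """
    Validate all input parameters for kNN and return k.
    Raises (via _assert/plpy.error) on the first invalid argument.
    """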
| input_tbl_valid(point_source, 'kNN') |
| input_tbl_valid(test_source, 'kNN') |
| output_tbl_valid(output_table, 'kNN') |
| |
    _assert(label_column_name or output_neighbors,
            "kNN error: Either label_column_name or "
            "output_neighbors must be provided.")
| |
| if label_column_name and label_column_name.strip(): |
| cols_in_tbl_valid(point_source, [label_column_name], 'kNN') |
| |
| _assert(is_var_valid(point_source, point_column_name), |
| "kNN error: {0} is an invalid column name or " |
| "expression for point_column_name param".format(point_column_name)) |
| point_col_type = get_expr_type(point_column_name, point_source) |
    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
            "kNN Error: Feature column or expression '{0}' in train table is"
            " not a numeric array.".format(point_column_name))
| |
| _assert(is_var_valid(test_source, test_column_name), |
| "kNN error: {0} is an invalid column name or expression for " |
| "test_column_name param".format(test_column_name)) |
| test_col_type = get_expr_type(test_column_name, test_source) |
    _assert(is_valid_psql_type(test_col_type, NUMERIC | ONLY_ARRAY),
            "kNN Error: Feature column or expression '{0}' in test table is"
            " not a numeric array.".format(test_column_name))
| |
| cols_in_tbl_valid(point_source, [point_id], 'kNN') |
| cols_in_tbl_valid(test_source, [test_id], 'kNN') |
| |
| if not array_col_has_no_null(point_source, point_column_name): |
| plpy.error("kNN Error: Feature column '{0}' in train table has some" |
| " NULL values.".format(point_column_name)) |
| if not array_col_has_no_null(test_source, test_column_name): |
| plpy.error("kNN Error: Feature column '{0}' in test table has some" |
| " NULL values.".format(test_column_name)) |
| |
| if k <= 0: |
| plpy.error("kNN Error: k={0} is an invalid value, must be greater " |
| "than 0.".format(k)) |
| |
| bound = plpy.execute("SELECT {k} <= count(*) AS bound FROM {tbl}". |
| format(k=k, tbl=point_source))[0]['bound'] |
| if not bound: |
| plpy.error("kNN Error: k={0} is greater than number of rows in" |
| " training table.".format(k)) |
| |
| if label_column_name: |
| col_type = get_expr_type(label_column_name, point_source).lower() |
| if col_type not in ['integer', 'double precision', 'float', 'boolean']: |
| plpy.error("kNN error: Invalid data type '{0}' for" |
| " label_column_name in table '{1}'.". |
| format(col_type, point_source)) |
| |
| col_type_test = get_expr_type(test_id, test_source).lower() |
| if col_type_test not in INTEGER: |
| plpy.error("kNN Error: Invalid data type '{0}' for" |
| " test_id column in table '{1}'.". |
| format(col_type_test, test_source)) |
| |
| if fn_dist: |
| fn_dist = fn_dist.lower().strip() |
| profunc = ("proisagg = TRUE" |
| if is_pg_major_version_less_than(schema_madlib, 11) |
| else "prokind = 'a'") |
| |
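        # The regprocedure cast resolves the function name together with the
        # (DOUBLE PRECISION[], DOUBLE PRECISION[]) argument list; the matching
        # pg_proc row then lets us reject wrong return types and aggregates.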
| is_invalid_func = plpy.execute(""" |
| SELECT prorettype != 'DOUBLE PRECISION'::regtype OR {profunc} AS OUTPUT |
| FROM pg_proc |
| WHERE oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure; |
| """.format(fn_dist=fn_dist, profunc=profunc))[0]['output'] |
| |
| if is_invalid_func: |
| plpy.error("KNN error: Distance function ({0}). Either the distance"\ |
| " function does not exist or the signature is wrong or it is"\ |
| " not a PostgreSQL type UDF. Also note that to use a MADlib"\ |
| " built-in distance function you must prepend with 'madlib',"\ |
| " schema name e.g., 'madlib.dist_norm2'".format(fn_dist)) |
| if not is_brute_force: |
| if depth <= 0: |
| plpy.error("kNN Error: depth={0} is an invalid value, must be " |
| "greater than 0.".format(depth)) |
| if leaf_nodes <= 0: |
| plpy.error("kNN Error: leaf_nodes={0} is an invalid value, must be " |
| "greater than 0.".format(leaf_nodes)) |
| if pow(2, depth) <= leaf_nodes: |
| plpy.error("kNN Error: depth={0}, leaf_nodes={1} is not valid. " |
| "The leaf_nodes value must be lower than 2^depth". |
| format(depth, leaf_nodes)) |
| return k |
| # ------------------------------------------------------------------------------ |
| |
| |
| def build_kd_tree(schema_madlib, source_table, output_table, point_column_name, |
| depth, r_id, **kwargs): |
| """ |
| KD-tree function to create a partitioning for KNN |
| Args: |
| @param schema_madlib Name of the Madlib Schema |
| @param source_table Training data table |
| @param output_table Name of the table to store kd tree |
| @param point_column_name Name of the column with training data |
| or expression that evaluates to a |
| numeric array |
| @param depth Depth of the kd tree |
| @param r_id Name of the region id column |
| """ |
| with MinWarning("error"): |
| |
| validate_kd_tree(source_table, output_table, point_column_name, depth) |
| n_features = num_features(source_table, point_column_name) |
| |
| clauses = [' 1=1 '] |
| centers_table = add_postfix(output_table, "_centers") |
| clause_counter = 0 |
| for curr_level in range(depth): |
| curr_feature = (curr_level % n_features) + 1 |
            for curr_leaf in range(pow(2, curr_level)):
| clause = clauses[clause_counter] |
| cutoff_sql = """ |
| SELECT percentile_disc(0.5) |
| WITHIN GROUP ( |
| ORDER BY ({point_column_name})[{curr_feature}] |
| ) AS cutoff |
| FROM {source_table} |
| WHERE {clause} |
| """.format(**locals()) |
| |
| cutoff = plpy.execute(cutoff_sql)[0]['cutoff'] |
| cutoff = "NULL" if cutoff is None else cutoff |
| clause_counter += 1 |
| clauses.append(clause + |
| "AND ({point_column_name})[{curr_feature}] < {cutoff} ". |
| format(**locals())) |
| clauses.append(clause + |
| "AND ({point_column_name})[{curr_feature}] >= {cutoff} ". |
| format(**locals())) |
| |
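        # The last 2^depth entries of `clauses` are the leaf predicates; their
        # enumeration order assigns each leaf its region id, with the first
        # split contributing the most significant bit of the id.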
| n_leaves = pow(2, depth) |
| case_when_clause = '\n'.join(["WHEN {0} THEN {1}::INTEGER".format(cond, i) |
| for i, cond in enumerate(clauses[-n_leaves:])]) |
| output_sql = """ |
| CREATE TABLE {output_table} AS |
| SELECT *, |
| CASE {case_when_clause} END AS {r_id} |
| FROM {source_table} |
| """.format(**locals()) |
| plpy.execute(output_sql) |
| plpy.execute("DROP TABLE IF EXISTS {0}".format(centers_table)) |
| centers_sql = """ |
| CREATE TABLE {centers_table} AS |
| SELECT {r_id}, {schema_madlib}.array_scalar_mult( |
| {schema_madlib}.sum({point_column_name})::DOUBLE PRECISION[], |
| (1.0/count(*))::DOUBLE PRECISION) AS __center__ |
| FROM {output_table} |
| GROUP BY {r_id} |
| """.format(**locals()) |
| plpy.execute(centers_sql) |
| return case_when_clause |
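
# A minimal in-memory sketch of the leaf-numbering scheme above (illustration
# only, not called by the module). It assumes `cutoffs[(level, region)]` holds
# the median that the percentile_disc(0.5) query computed for each split.
def _kd_region_of_point(point, cutoffs, depth, n_features):
    region = 0
    for level in range(depth):
        feature = level % n_features  # 0-based here; the SQL arrays are 1-based
        cutoff = cutoffs[(level, region)]
        region = 2 * region + (0 if point[feature] < cutoff else 1)
    return region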
| # ------------------------------------------------------------------------------ |
| |
| |
| def validate_kd_tree(source_table, output_table, point_column_name, depth): |
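    """
    Validate the input arguments for building the kd-tree.
    """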
| |
| input_tbl_valid(source_table, 'kd_tree') |
| output_tbl_valid(output_table, 'kd_tree') |
| output_tbl_valid(output_table+"_centers", 'kd_tree') |
| |
| _assert(is_var_valid(source_table, point_column_name), |
| "kd_tree error: {0} is an invalid column name or expression for " |
| "point_column_name param".format(point_column_name)) |
| point_col_type = get_expr_type(point_column_name, source_table) |
    _assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
            "kd_tree error: Feature column or expression '{0}' in train table"
            " is not a numeric array.".format(point_column_name))
    if depth <= 0:
        plpy.error("kd_tree error: depth={0} is an invalid value, must be "
                   "greater than 0.".format(depth))
| # ------------------------------------------------------------------------------ |
| |
| |
| def knn_kd_tree(schema_madlib, kd_out, test_source, test_column_name, test_id, |
| fn_dist, max_leaves_to_explore, depth, r_id, case_when_clause, |
| t_col_name, **kwargs): |
| """ |
    KNN function to find the k nearest neighbors using a kd-tree
| Args: |
| @param schema_madlib Name of the Madlib Schema |
| @param kd_out Name of the kd tree table |
| @param test_source Name of the table containing the test |
| data points. |
| @param test_column_name Name of the column with testing data |
| points or expression that evaluates to a |
| numeric array |
| @param test_id Name of the column having ids of data |
| points in test data table. |
        @param fn_dist                  Distance metric function.
| @param max_leaves_to_explore Number of leaf nodes to explore |
| @param depth Depth of the kd tree |
| @param r_id Name of the region id column |
| @param case_when_clause SQL string for reconstructing the |
| kd-tree |
| @param t_col_name Unique test point column name |
| """ |
| with MinWarning("error"): |
| centers_table = add_postfix(kd_out, "_centers") |
| test_view = add_postfix(kd_out, "_test_view") |
| |
| plpy.execute("DROP VIEW IF EXISTS {test_view}".format(**locals())) |
| test_view_sql = """ |
| CREATE VIEW {test_view} AS |
| SELECT {test_id}, |
| ({test_column_name})::DOUBLE PRECISION[] AS {t_col_name}, |
| CASE |
| {case_when_clause} |
| END AS {r_id} |
| FROM {test_source}""".format(**locals()) |
| plpy.execute(test_view_sql) |
| |
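        # Rank the kd-tree regions for each test point by the distance from
        # the test point to the region's centroid, and keep the closest
        # max_leaves_to_explore regions as candidate leaves to search.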
| if max_leaves_to_explore > 1: |
| ext_test_view = add_postfix(kd_out, "_ext_test_view") |
| ext_test_view_sql = """ |
| CREATE VIEW {ext_test_view} AS |
| SELECT * FROM( |
| SELECT |
| row_number() OVER (PARTITION BY {test_id} |
| ORDER BY __dist_center__) AS r, |
| {test_id}, |
| {t_col_name}, |
| {r_id} |
| FROM ( |
| SELECT |
| {test_id}, |
| {t_col_name}, |
| {centers_table}.{r_id} AS {r_id}, |
| {fn_dist}({t_col_name}, __center__) AS __dist_center__ |
| FROM {test_view}, {centers_table} |
| ) q1 |
| ) q2 |
| WHERE r <= {max_leaves_to_explore} |
| """.format(**locals()) |
| plpy.execute(ext_test_view_sql) |
| else: |
| ext_test_view = test_view |
| |
| return ext_test_view |
| # ------------------------------------------------------------------------------ |
| |
| def _create_interim_tbl(schema_madlib, point_source, point_column_name, point_id, |
| label_name, test_source, test_column_name, test_id, interim_table, k, |
| fn_dist, test_id_temp, train_id, dist_inverse, comma_label_out_alias, |
| label_out, r_id, kd_out, train, t_col_name, dist, **kwargs): |
| """ |
| KNN function to create the interim table |
| Args: |
| @param schema_madlib Name of the Madlib Schema |
| @param point_source Training data table |
| @param point_column_name Name of the column with training data |
| or expression that evaluates to a |
| numeric array |
        @param point_id                 Name of the column having ids of
                                        data points in train data table.
| @param label_name Name of the column with labels/values |
| of training data points. |
| @param test_source Name of the table containing the test |
| data points. |
| @param test_column_name Name of the column with testing data |
| points or expression that evaluates to a |
| numeric array |
| @param test_id Name of the column having ids of data |
| points in test data table. |
| @param interim_table Name of the table to store interim |
| results. |
| @param k default: 1. Number of nearest |
| neighbors to consider |
        @param fn_dist                  Distance metric function. Default is
                                        squared_dist_norm2. The following
                                        functions are supported: dist_norm1,
                                        dist_norm2, squared_dist_norm2,
                                        dist_angle, dist_tanimoto.
                                        Alternatively, a user-defined function
                                        with signature
                                        DOUBLE PRECISION[] x, DOUBLE PRECISION[] y
                                        -> DOUBLE PRECISION
        The following parameters are passed through so that the interim table
        has identical features in both implementations:
| @param test_id_temp |
| @param train_id |
| @param dist_inverse |
| @param comma_label_out_alias |
| @param label_out |
| @param r_id |
| @param kd_out |
| @param train |
| @param t_col_name |
| @param dist |
| """ |
| with MinWarning("error"): |
| # If r_id is None, we are using the brute force algorithm. |
| is_brute_force = not bool(r_id) |
| r_id = "NULL AS {0}".format(unique_string()) if not r_id else r_id |
| |
| p_col_name = unique_string(desp='p_col_name') |
| x_temp_table = unique_string(desp='x_temp_table') |
| y_temp_table = unique_string(desp='y_temp_table') |
| test = unique_string(desp='test') |
| r = unique_string(desp='r') |
| |
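        # Brute force compares every test point against every training point
        # (cross join); the kd-tree variant only compares points whose region
        # ids match, using the candidate regions chosen in knn_kd_tree.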
| if not is_brute_force: |
| point_source = kd_out |
| where_condition = "{train}.{r_id} = {test}.{r_id} ".format(**locals()) |
| select_sql = """ {train}.{r_id} AS tr_{r_id}, |
| {test}.{r_id} AS test_{r_id}, """.format(**locals()) |
| t_col_cast = t_col_name |
| else: |
| where_condition = "1 = 1" |
| select_sql = "" |
| t_col_cast = "({test_column_name}) AS {t_col_name}".format(**locals()) |
| |
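        # The innermost query computes the pairwise distances, the window
        # function ranks training points per test id by distance, and the
        # outer filter keeps only the k nearest rows.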
| plpy.execute(""" |
| CREATE TABLE {interim_table} AS |
| SELECT * |
| FROM ( |
| SELECT row_number() OVER |
| (PARTITION BY {test_id_temp} ORDER BY {dist}) AS {r}, |
| {test_id_temp}, |
| {train_id}, |
| {dist}, |
| CASE WHEN {dist} = 0.0 THEN {weight_for_zero_dist} |
| ELSE 1.0 / {dist} |
| END AS {dist_inverse} |
| {comma_label_out_alias} |
| FROM ( |
| SELECT {select_sql} |
| {test}.{test_id} AS {test_id_temp}, |
| {train}.{point_id} AS {train_id}, |
| {fn_dist}({p_col_name}, {t_col_name}) AS {dist} |
| {label_out} |
| FROM |
| ( |
| SELECT {point_id}, |
| {r_id}, |
| {point_column_name} AS {p_col_name} |
| {label_name} |
| FROM {point_source} |
| ) {train}, |
| ( |
| SELECT {test_id}, |
| {t_col_cast}, |
| {r_id} |
| FROM {test_source} |
| ) {test} |
| WHERE |
| {where_condition} |
| ) {x_temp_table} |
| ) {y_temp_table} |
| WHERE {y_temp_table}.{r} <= {k} |
| """.format(weight_for_zero_dist=WEIGHT_FOR_ZERO_DIST, **locals())) |
| |
| # ------------------------------------------------------------------------------ |
| |
| def knn(schema_madlib, point_source, point_column_name, point_id, |
| label_column_name, test_source, test_column_name, test_id, output_table, |
| k, output_neighbors, fn_dist, weighted_avg, algorithm, algorithm_params, |
| **kwargs): |
| """ |
    KNN function to find the k nearest neighbors
| Args: |
| @param schema_madlib Name of the Madlib Schema |
| @param point_source Training data table |
| @param point_column_name Name of the column with training data |
| or expression that evaluates to a |
| numeric array |
        @param point_id                 Name of the column having ids of
                                        data points in train data table.
| @param label_column_name Name of the column with labels/values |
| of training data points. |
| @param test_source Name of the table containing the test |
| data points. |
| @param test_column_name Name of the column with testing data |
| points or expression that evaluates to a |
| numeric array |
| @param test_id Name of the column having ids of data |
| points in test data table. |
| @param output_table Name of the table to store final |
| results. |
| @param k default: 1. Number of nearest |
| neighbors to consider |
        @param output_neighbors         Outputs the list of k-nearest neighbors
                                        that were used in the voting/averaging.
        @param fn_dist                  Distance metric function. Default is
                                        squared_dist_norm2. The following
                                        functions are supported: dist_norm1,
                                        dist_norm2, squared_dist_norm2,
                                        dist_angle, dist_tanimoto.
                                        Alternatively, a user-defined function
                                        with signature
                                        DOUBLE PRECISION[] x, DOUBLE PRECISION[] y
                                        -> DOUBLE PRECISION
        @param weighted_avg             Calculates the regression or
                                        classification of k-NN using
                                        the weighted average method.
| @param algorithm The algorithm to use for knn |
| @param algorithm_params The parameters for kd-tree algorithm |
| """ |
| with MinWarning('warning'): |
| output_neighbors = True if output_neighbors is None else output_neighbors |
| if k is None: |
| k = 1 |
| |
| algorithm = get_algorithm_name(algorithm, BRUTE_FORCE, |
| [BRUTE_FORCE, KD_TREE], 'kNN') |
| |
| # Default values for depth and leaf nodes |
| depth = 3 |
| max_leaves_to_explore = 2 |
| |
| if algorithm_params: |
| params_types = {'depth': int, 'leaf_nodes': int} |
| default_args = {'depth': 3, 'leaf_nodes': 2} |
| algorithm_params_dict = extract_keyvalue_params(algorithm_params, |
| params_types, |
| default_args) |
| |
| depth = algorithm_params_dict['depth'] |
| max_leaves_to_explore = algorithm_params_dict['leaf_nodes'] |
| |
| knn_validate_src(schema_madlib, point_source, |
| point_column_name, point_id, label_column_name, |
| test_source, test_column_name, test_id, |
| output_table, k, output_neighbors, fn_dist, |
| algorithm == BRUTE_FORCE, depth, max_leaves_to_explore) |
| |
| n_features = num_features(test_source, test_column_name) |
| |
| # Unique Strings |
| label_col_temp = unique_string(desp='label_col_temp') |
| test_id_temp = unique_string(desp='test_id_temp') |
| |
| train = unique_string(desp='train') |
| train_id = unique_string(desp='train_id') |
| dist_inverse = unique_string(desp='dist_inverse') |
| dim = unique_string(desp='dim') |
| t_col_name = unique_string(desp='t_col_name') |
| dist = unique_string(desp='dist') |
| |
| if not fn_dist: |
| fn_dist = '{0}.squared_dist_norm2'.format(schema_madlib) |
| |
| fn_dist = fn_dist.lower().strip() |
| interim_table = unique_string(desp='interim_table') |
| |
| pred_out = "" |
| knn_neighbors = "" |
| label_out = "" |
| cast_to_int = "" |
| view_def = "" |
| view_join = "" |
| view_grp_by = "" |
| r_id = None |
| kd_output_table = None |
| test_data = None |
| |
| if label_column_name: |
| label_column_type = get_expr_type( |
| label_column_name, point_source).lower() |
| if label_column_type in ['boolean', 'integer', 'text']: |
| is_classification = True |
| cast_to_int = '::INTEGER' |
| else: |
| is_classification = False |
| |
| if is_classification: |
| if weighted_avg: |
                    # This view calculates, for each test id, the label whose
                    # sum of 1/distance over the k neighbors is largest; that
                    # label becomes the prediction of the classification model.
| view_def = """ |
| WITH vw AS ( |
| SELECT DISTINCT ON({test_id_temp}) |
| {test_id_temp}, |
| last_value(data_sum) OVER ( |
| PARTITION BY {test_id_temp} |
| ORDER BY data_sum, {label_col_temp} |
| ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING |
| ) AS data_dist , |
| last_value({label_col_temp}) OVER ( |
| PARTITION BY {test_id_temp} |
| ORDER BY data_sum, {label_col_temp} |
| ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING |
| ) AS {label_col_temp} |
| FROM ( |
| SELECT |
| {test_id_temp}, |
| {label_col_temp}, |
| sum({dist_inverse}) data_sum |
| FROM {interim_table} |
| GROUP BY {test_id_temp}, |
| {label_col_temp} |
| ) a |
| ) |
| """.format(**locals()) |
                    # This join is needed to get the max value of the
                    # prediction calculated above
| view_join = (" JOIN vw ON knn_temp.{0} = vw.{0}". |
| format(test_id_temp)) |
| view_grp_by = ", vw.{0}".format(label_col_temp) |
| pred_out = ", vw.{0}".format(label_col_temp) |
| else: |
| pred_out = ", {0}.mode({1})".format( |
| schema_madlib, label_col_temp) |
| else: |
| if weighted_avg: |
| pred_out = (", sum({0} * {dist_inverse}) / sum({dist_inverse})". |
| format(label_col_temp, dist_inverse=dist_inverse)) |
| else: |
| pred_out = ", avg({0})".format(label_col_temp) |
| |
| pred_out += " AS prediction" |
| label_out = (", {train}.{label_column_name}{cast_to_int}" |
| " AS {label_col_temp}").format(**locals()) |
| comma_label_out_alias = ', ' + label_col_temp |
| label_name = ", {label_column_name}".format( |
| label_column_name=label_column_name) |
| |
| else: |
| pred_out = "" |
| label_out = "" |
| comma_label_out_alias = "" |
| label_name = "" |
| |
| if output_neighbors: |
| knn_neighbors = (", array_agg(knn_temp.{train_id} ORDER BY " |
| "knn_temp.{dist_inverse} DESC) AS k_nearest_neighbours " |
| ", array_agg(knn_temp.{dist} ORDER BY " |
| "knn_temp.{dist_inverse} DESC) AS distance").format(**locals()) |
| else: |
| knn_neighbors = '' |
| |
        if algorithm == KD_TREE:
| r_id = unique_string(desp='r_id') |
| kd_output_table = unique_string(desp='kd_tree') |
| case_when_clause = build_kd_tree(schema_madlib, |
| point_source, |
| kd_output_table, |
| point_column_name, |
| depth, r_id) |
| test_data = knn_kd_tree(schema_madlib, kd_output_table, test_source, |
| test_column_name, test_id, fn_dist, |
| max_leaves_to_explore, depth, r_id, |
| case_when_clause, t_col_name) |
| else: |
| test_data = test_source |
| |
| # interim_table picks the 'k' nearest neighbors for each test point |
| _create_interim_tbl(schema_madlib, point_source, point_column_name, |
| point_id, label_name, test_data, test_column_name, |
| test_id, interim_table, k, fn_dist, test_id_temp, |
| train_id, dist_inverse, comma_label_out_alias, |
| label_out, r_id, kd_output_table, train, t_col_name, |
| dist) |
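        # The final output joins the k nearest rows back to the test table and
        # aggregates them per test id, using the mode/average (or the weighted
        # variants built above) as the prediction.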
| output_sql = """ |
| CREATE TABLE {output_table} AS |
| {view_def} |
| SELECT |
| knn_temp.{test_id_temp} AS id, |
| {test_column_name} as "{test_column_name}" |
| {pred_out} |
| {knn_neighbors} |
| FROM |
| {interim_table} AS knn_temp |
| JOIN |
| {test_source} AS knn_test |
| ON knn_temp.{test_id_temp} = knn_test.{test_id} |
| {view_join} |
| GROUP BY knn_temp.{test_id_temp}, |
| {test_column_name} |
| {view_grp_by} |
| """.format(**locals()) |
| plpy.execute(output_sql) |
| drop_tables([interim_table]) |
| |
        if algorithm == KD_TREE:
| centers_table = add_postfix(kd_output_table, "_centers") |
| test_view = add_postfix(kd_output_table, "_test_view") |
| ext_test_view = add_postfix(kd_output_table, "_ext_test_view") |
| plpy.execute("DROP VIEW IF EXISTS {0} CASCADE".format(test_view)) |
| plpy.execute("DROP VIEW IF EXISTS {0} CASCADE".format(ext_test_view)) |
| drop_tables([centers_table, kd_output_table]) |
| return |
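
# A minimal in-memory sketch (illustration only, not called by the module) of
# the weighting the SQL above implements: each neighbor votes with weight
# 1/distance (zero distances use the WEIGHT_FOR_ZERO_DIST surrogate);
# classification picks the label with the largest summed weight, regression
# returns the weighted average. Tie-breaking differs slightly from the SQL,
# which prefers the larger label on equal sums.
def _weighted_knn_sketch(neighbors, is_classification):
    """`neighbors` is a list of (distance, label_or_value) pairs."""
    weighted = [(WEIGHT_FOR_ZERO_DIST if d == 0.0 else 1.0 / d, v)
                for d, v in neighbors]
    if is_classification:
        totals = {}
        for w, v in weighted:
            totals[v] = totals.get(v, 0.0) + w
        return max(totals, key=totals.get)
    return sum(w * v for w, v in weighted) / sum(w for w, _ in weighted)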
| # ------------------------------------------------------------------------------ |
| |
| def knn_help(schema_madlib, message, **kwargs): |
| """ |
| Help function for knn |
| |
| Args: |
| @param schema_madlib |
| @param message: string, Help message string |
| @param kwargs |
| |
| Returns: |
| String. Help/usage information |
| """ |
| if message is not None and \ |
| message.lower() in ("usage", "help", "?"): |
| help_string = """ |
| ----------------------------------------------------------------------- |
| USAGE |
| ----------------------------------------------------------------------- |
| SELECT {schema_madlib}.knn( |
| point_source, -- Training data table having training features as vector column and labels |
| point_column_name, -- Name of column having feature vectors in training data table |
| point_id, -- Name of column having feature vector Ids in train data table |
    label_column_name, -- Name of column having actual label/value for corresponding feature vector in training data table
| test_source, -- Test data table having features as vector column. Id of features is mandatory |
| test_column_name, -- Name of column having feature vectors in test data table |
| test_id, -- Name of column having feature vector Ids in test data table |
| output_table, -- Name of output table |
    k, -- Value of k. Default is 1.
    output_neighbors, -- Outputs the list of k-nearest neighbors that were used in the voting/averaging.
    fn_dist, -- The name of the function used to compute the distance between two data points.
    weighted_avg, -- Calculates the regression or classification of k-NN using the weighted average method.
    algorithm, -- The algorithm to use for knn.
    algorithm_params -- The parameters for the kd-tree algorithm.
| ); |
| |
| ----------------------------------------------------------------------- |
| OUTPUT |
| ----------------------------------------------------------------------- |
| The output of the KNN module is a table with the following columns: |
| |
| id The ids of test data points. |
| test_column_name The test data points. |
prediction The output of KNN: label in case of classification, average value in case of regression.
| k_nearest_neighbours The list of k-nearest neighbors that were used in the voting/averaging. |
| distance The list of nearest distances, sorted closest to furthest from the corresponding test point. |
| """ |
| else: |
| help_string = """ |
| ---------------------------------------------------------------------------- |
| SUMMARY |
| ---------------------------------------------------------------------------- |
k-Nearest Neighbors is a method for finding the k closest points to a given
data point in terms of a given metric. Its input consists of data points as
features from testing examples. For a given k, it looks for the k closest
points in the training set for each of the data points in the test set. The
algorithm generates one output per testing example. The output of KNN depends
on the type of task: for classification, the output is the majority vote of
the classes of the k nearest data points (the testing example is assigned the
most popular class among its nearest neighbors); for regression, the output
is the average of the values of the k nearest neighbors of the given testing
example.
| -- |
| For an overview on usage, run: |
| SELECT {schema_madlib}.knn('usage'); |
| """ |
| return help_string.format(schema_madlib=schema_madlib) |
| # ------------------------------------------------------------------------------ |