# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
@file knn.py_in
@brief knn: K-Nearest Neighbors for regression and classification
@namespace knn
"""
import plpy
import copy
from collections import defaultdict
from math import log
from utilities.control import MinWarning
from utilities.utilities import INTEGER
from utilities.utilities import _assert
from utilities.utilities import add_postfix
from utilities.utilities import extract_keyvalue_params
from utilities.utilities import py_list_to_sql_string
from utilities.utilities import unique_string
from utilities.utilities import NUMERIC, ONLY_ARRAY
from utilities.utilities import is_valid_psql_type
from utilities.utilities import is_pg_major_version_less_than
from utilities.utilities import num_features
from utilities.validate_args import array_col_has_no_null
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import drop_tables
from utilities.validate_args import get_cols
from utilities.validate_args import get_expr_type
from utilities.validate_args import input_tbl_valid, output_tbl_valid
from utilities.validate_args import is_col_array
from utilities.validate_args import is_var_valid
from utilities.validate_args import quote_ident
from utilities.validate_args import get_algorithm_name
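# Large finite weight used in place of 1/0 when a test point coincides with a
# training point (see the dist_inverse computation in _create_interim_tbl).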
WEIGHT_FOR_ZERO_DIST = 1e107
BRUTE_FORCE = 'brute_force'
KD_TREE = 'kd_tree'
def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
label_column_name, test_source, test_column_name,
test_id, output_table, k, output_neighbors, fn_dist,
is_brute_force, depth, leaf_nodes, **kwargs):
input_tbl_valid(point_source, 'kNN')
input_tbl_valid(test_source, 'kNN')
output_tbl_valid(output_table, 'kNN')
    _assert(label_column_name or output_neighbors,
            "kNN error: Either label_column_name or "
            "output_neighbors must be provided.")
if label_column_name and label_column_name.strip():
cols_in_tbl_valid(point_source, [label_column_name], 'kNN')
_assert(is_var_valid(point_source, point_column_name),
"kNN error: {0} is an invalid column name or "
"expression for point_column_name param".format(point_column_name))
point_col_type = get_expr_type(point_column_name, point_source)
_assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
"kNN Error: Feature column or expression '{0}' in train table is not"
" an array.".format(point_column_name))
_assert(is_var_valid(test_source, test_column_name),
"kNN error: {0} is an invalid column name or expression for "
"test_column_name param".format(test_column_name))
test_col_type = get_expr_type(test_column_name, test_source)
_assert(is_valid_psql_type(test_col_type, NUMERIC | ONLY_ARRAY),
"kNN Error: Feature column or expression '{0}' in test table is not"
" an array.".format(test_column_name))
cols_in_tbl_valid(point_source, [point_id], 'kNN')
cols_in_tbl_valid(test_source, [test_id], 'kNN')
if not array_col_has_no_null(point_source, point_column_name):
plpy.error("kNN Error: Feature column '{0}' in train table has some"
" NULL values.".format(point_column_name))
if not array_col_has_no_null(test_source, test_column_name):
plpy.error("kNN Error: Feature column '{0}' in test table has some"
" NULL values.".format(test_column_name))
if k <= 0:
plpy.error("kNN Error: k={0} is an invalid value, must be greater "
"than 0.".format(k))
bound = plpy.execute("SELECT {k} <= count(*) AS bound FROM {tbl}".
format(k=k, tbl=point_source))[0]['bound']
if not bound:
plpy.error("kNN Error: k={0} is greater than number of rows in"
" training table.".format(k))
if label_column_name:
col_type = get_expr_type(label_column_name, point_source).lower()
if col_type not in ['integer', 'double precision', 'float', 'boolean']:
plpy.error("kNN error: Invalid data type '{0}' for"
" label_column_name in table '{1}'.".
format(col_type, point_source))
col_type_test = get_expr_type(test_id, test_source).lower()
if col_type_test not in INTEGER:
plpy.error("kNN Error: Invalid data type '{0}' for"
" test_id column in table '{1}'.".
format(col_type_test, test_source))
if fn_dist:
fn_dist = fn_dist.lower().strip()
profunc = ("proisagg = TRUE"
if is_pg_major_version_less_than(schema_madlib, 11)
else "prokind = 'a'")
is_invalid_func = plpy.execute("""
SELECT prorettype != 'DOUBLE PRECISION'::regtype OR {profunc} AS OUTPUT
FROM pg_proc
WHERE oid='{fn_dist}(DOUBLE PRECISION[], DOUBLE PRECISION[])'::regprocedure;
""".format(fn_dist=fn_dist, profunc=profunc))[0]['output']
if is_invalid_func:
plpy.error("KNN error: Distance function ({0}). Either the distance"\
" function does not exist or the signature is wrong or it is"\
" not a PostgreSQL type UDF. Also note that to use a MADlib"\
" built-in distance function you must prepend with 'madlib',"\
" schema name e.g., 'madlib.dist_norm2'".format(fn_dist))
if not is_brute_force:
if depth <= 0:
plpy.error("kNN Error: depth={0} is an invalid value, must be "
"greater than 0.".format(depth))
if leaf_nodes <= 0:
plpy.error("kNN Error: leaf_nodes={0} is an invalid value, must be "
"greater than 0.".format(leaf_nodes))
if pow(2, depth) <= leaf_nodes:
plpy.error("kNN Error: depth={0}, leaf_nodes={1} is not valid. "
"The leaf_nodes value must be lower than 2^depth".
format(depth, leaf_nodes))
return k
# ------------------------------------------------------------------------------
def build_kd_tree(schema_madlib, source_table, output_table, point_column_name,
depth, r_id, **kwargs):
"""
KD-tree function to create a partitioning for KNN
Args:
@param schema_madlib Name of the Madlib Schema
@param source_table Training data table
@param output_table Name of the table to store kd tree
@param point_column_name Name of the column with training data
or expression that evaluates to a
numeric array
@param depth Depth of the kd tree
@param r_id Name of the region id column
"""
with MinWarning("error"):
validate_kd_tree(source_table, output_table, point_column_name, depth)
n_features = num_features(source_table, point_column_name)
clauses = [' 1=1 ']
centers_table = add_postfix(output_table, "_centers")
clause_counter = 0
for curr_level in range(depth):
curr_feature = (curr_level % n_features) + 1
for curr_leaf in range(pow(2,curr_level)):
clause = clauses[clause_counter]
cutoff_sql = """
SELECT percentile_disc(0.5)
WITHIN GROUP (
ORDER BY ({point_column_name})[{curr_feature}]
) AS cutoff
FROM {source_table}
WHERE {clause}
""".format(**locals())
cutoff = plpy.execute(cutoff_sql)[0]['cutoff']
cutoff = "NULL" if cutoff is None else cutoff
clause_counter += 1
clauses.append(clause +
"AND ({point_column_name})[{curr_feature}] < {cutoff} ".
format(**locals()))
clauses.append(clause +
"AND ({point_column_name})[{curr_feature}] >= {cutoff} ".
format(**locals()))
n_leaves = pow(2, depth)
case_when_clause = '\n'.join(["WHEN {0} THEN {1}::INTEGER".format(cond, i)
for i, cond in enumerate(clauses[-n_leaves:])])
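        # For depth=1 on a 1-D column "pt", the generated expression looks
        # like the following (the cutoff value 0.5 is hypothetical):
        #   WHEN  1=1 AND (pt)[1] < 0.5  THEN 0::INTEGER
        #   WHEN  1=1 AND (pt)[1] >= 0.5 THEN 1::INTEGER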
output_sql = """
CREATE TABLE {output_table} AS
SELECT *,
CASE {case_when_clause} END AS {r_id}
FROM {source_table}
""".format(**locals())
plpy.execute(output_sql)
plpy.execute("DROP TABLE IF EXISTS {0}".format(centers_table))
centers_sql = """
CREATE TABLE {centers_table} AS
SELECT {r_id}, {schema_madlib}.array_scalar_mult(
{schema_madlib}.sum({point_column_name})::DOUBLE PRECISION[],
(1.0/count(*))::DOUBLE PRECISION) AS __center__
FROM {output_table}
GROUP BY {r_id}
""".format(**locals())
plpy.execute(centers_sql)
return case_when_clause
# ------------------------------------------------------------------------------
def validate_kd_tree(source_table, output_table, point_column_name, depth):
input_tbl_valid(source_table, 'kd_tree')
output_tbl_valid(output_table, 'kd_tree')
output_tbl_valid(output_table+"_centers", 'kd_tree')
_assert(is_var_valid(source_table, point_column_name),
"kd_tree error: {0} is an invalid column name or expression for "
"point_column_name param".format(point_column_name))
point_col_type = get_expr_type(point_column_name, source_table)
_assert(is_valid_psql_type(point_col_type, NUMERIC | ONLY_ARRAY),
"kNN Error: Feature column or expression '{0}' in train table is not"
" an array.".format(point_column_name))
if depth <= 0:
plpy.error("kNN Error: depth={0} is an invalid value, must be greater "
"than 0.".format(depth))
# ------------------------------------------------------------------------------
def knn_kd_tree(schema_madlib, kd_out, test_source, test_column_name, test_id,
fn_dist, max_leaves_to_explore, depth, r_id, case_when_clause,
t_col_name, **kwargs):
"""
KNN function to find the K Nearest neighbours using kd tree
Args:
@param schema_madlib Name of the Madlib Schema
@param kd_out Name of the kd tree table
@param test_source Name of the table containing the test
data points.
@param test_column_name Name of the column with testing data
points or expression that evaluates to a
numeric array
@param test_id Name of the column having ids of data
points in test data table.
@param fn_dist Distance metrics function.
@param max_leaves_to_explore Number of leaf nodes to explore
@param depth Depth of the kd tree
@param r_id Name of the region id column
@param case_when_clause SQL string for reconstructing the
kd-tree
@param t_col_name Unique test point column name
"""
with MinWarning("error"):
centers_table = add_postfix(kd_out, "_centers")
test_view = add_postfix(kd_out, "_test_view")
plpy.execute("DROP VIEW IF EXISTS {test_view}".format(**locals()))
test_view_sql = """
CREATE VIEW {test_view} AS
SELECT {test_id},
({test_column_name})::DOUBLE PRECISION[] AS {t_col_name},
CASE
{case_when_clause}
END AS {r_id}
FROM {test_source}""".format(**locals())
plpy.execute(test_view_sql)
if max_leaves_to_explore > 1:
ext_test_view = add_postfix(kd_out, "_ext_test_view")
ext_test_view_sql = """
CREATE VIEW {ext_test_view} AS
SELECT * FROM(
SELECT
row_number() OVER (PARTITION BY {test_id}
ORDER BY __dist_center__) AS r,
{test_id},
{t_col_name},
{r_id}
FROM (
SELECT
{test_id},
{t_col_name},
{centers_table}.{r_id} AS {r_id},
{fn_dist}({t_col_name}, __center__) AS __dist_center__
FROM {test_view}, {centers_table}
) q1
) q2
WHERE r <= {max_leaves_to_explore}
""".format(**locals())
plpy.execute(ext_test_view_sql)
else:
ext_test_view = test_view
return ext_test_view
# ------------------------------------------------------------------------------
def _create_interim_tbl(schema_madlib, point_source, point_column_name, point_id,
label_name, test_source, test_column_name, test_id, interim_table, k,
fn_dist, test_id_temp, train_id, dist_inverse, comma_label_out_alias,
label_out, r_id, kd_out, train, t_col_name, dist, **kwargs):
"""
KNN function to create the interim table
Args:
@param schema_madlib Name of the Madlib Schema
@param point_source Training data table
@param point_column_name Name of the column with training data
or expression that evaluates to a
numeric array
        @param point_id Name of the column having ids of data
                        points in the train data table.
@param label_name Name of the column with labels/values
of training data points.
@param test_source Name of the table containing the test
data points.
@param test_column_name Name of the column with testing data
points or expression that evaluates to a
numeric array
@param test_id Name of the column having ids of data
points in test data table.
@param interim_table Name of the table to store interim
results.
@param k default: 1. Number of nearest
neighbors to consider
        @param fn_dist Distance metric function. Default is
                       squared_dist_norm2. The following functions
                       are supported:
                       dist_norm1, dist_norm2, squared_dist_norm2,
                       dist_angle, dist_tanimoto.
                       Or a user-defined function with the signature
                       DOUBLE PRECISION[] x, DOUBLE PRECISION[] y
                       -> DOUBLE PRECISION
        The following parameters are passed to ensure the interim table has
        an identical schema in both implementations:
@param test_id_temp
@param train_id
@param dist_inverse
@param comma_label_out_alias
@param label_out
@param r_id
@param kd_out
@param train
@param t_col_name
@param dist
"""
with MinWarning("error"):
# If r_id is None, we are using the brute force algorithm.
is_brute_force = not bool(r_id)
r_id = "NULL AS {0}".format(unique_string()) if not r_id else r_id
p_col_name = unique_string(desp='p_col_name')
x_temp_table = unique_string(desp='x_temp_table')
y_temp_table = unique_string(desp='y_temp_table')
test = unique_string(desp='test')
r = unique_string(desp='r')
if not is_brute_force:
point_source = kd_out
where_condition = "{train}.{r_id} = {test}.{r_id} ".format(**locals())
select_sql = """ {train}.{r_id} AS tr_{r_id},
{test}.{r_id} AS test_{r_id}, """.format(**locals())
t_col_cast = t_col_name
else:
where_condition = "1 = 1"
select_sql = ""
t_col_cast = "({test_column_name}) AS {t_col_name}".format(**locals())
plpy.execute("""
CREATE TABLE {interim_table} AS
SELECT *
FROM (
SELECT row_number() OVER
(PARTITION BY {test_id_temp} ORDER BY {dist}) AS {r},
{test_id_temp},
{train_id},
{dist},
CASE WHEN {dist} = 0.0 THEN {weight_for_zero_dist}
ELSE 1.0 / {dist}
END AS {dist_inverse}
{comma_label_out_alias}
FROM (
SELECT {select_sql}
{test}.{test_id} AS {test_id_temp},
{train}.{point_id} AS {train_id},
{fn_dist}({p_col_name}, {t_col_name}) AS {dist}
{label_out}
FROM
(
SELECT {point_id},
{r_id},
{point_column_name} AS {p_col_name}
{label_name}
FROM {point_source}
) {train},
(
SELECT {test_id},
{t_col_cast},
{r_id}
FROM {test_source}
) {test}
WHERE
{where_condition}
) {x_temp_table}
) {y_temp_table}
WHERE {y_temp_table}.{r} <= {k}
""".format(weight_for_zero_dist=WEIGHT_FOR_ZERO_DIST, **locals()))
# ------------------------------------------------------------------------------
def knn(schema_madlib, point_source, point_column_name, point_id,
label_column_name, test_source, test_column_name, test_id, output_table,
k, output_neighbors, fn_dist, weighted_avg, algorithm, algorithm_params,
**kwargs):
"""
KNN function to find the K Nearest neighbours
Args:
@param schema_madlib Name of the Madlib Schema
@param point_source Training data table
@param point_column_name Name of the column with training data
or expression that evaluates to a
numeric array
        @param point_id Name of the column having ids of data
                        points in the train data table.
@param label_column_name Name of the column with labels/values
of training data points.
@param test_source Name of the table containing the test
data points.
@param test_column_name Name of the column with testing data
points or expression that evaluates to a
numeric array
@param test_id Name of the column having ids of data
points in test data table.
@param output_table Name of the table to store final
results.
@param k default: 1. Number of nearest
neighbors to consider
        @param output_neighbors Outputs the list of k-nearest neighbors
                                that were used in the voting/averaging.
        @param fn_dist Distance metric function. Default is
                       squared_dist_norm2. The following functions
                       are supported:
                       dist_norm1, dist_norm2, squared_dist_norm2,
                       dist_angle, dist_tanimoto.
                       Or a user-defined function with the signature
                       DOUBLE PRECISION[] x, DOUBLE PRECISION[] y
                       -> DOUBLE PRECISION
        @param weighted_avg Calculates the regression or
                            classification of k-NN using
                            the weighted average method.
        @param algorithm The algorithm to use for knn
        @param algorithm_params The parameters for the kd-tree algorithm
"""
with MinWarning('warning'):
output_neighbors = True if output_neighbors is None else output_neighbors
if k is None:
k = 1
algorithm = get_algorithm_name(algorithm, BRUTE_FORCE,
[BRUTE_FORCE, KD_TREE], 'kNN')
# Default values for depth and leaf nodes
depth = 3
max_leaves_to_explore = 2
if algorithm_params:
params_types = {'depth': int, 'leaf_nodes': int}
default_args = {'depth': 3, 'leaf_nodes': 2}
algorithm_params_dict = extract_keyvalue_params(algorithm_params,
params_types,
default_args)
depth = algorithm_params_dict['depth']
max_leaves_to_explore = algorithm_params_dict['leaf_nodes']
knn_validate_src(schema_madlib, point_source,
point_column_name, point_id, label_column_name,
test_source, test_column_name, test_id,
output_table, k, output_neighbors, fn_dist,
algorithm == BRUTE_FORCE, depth, max_leaves_to_explore)
n_features = num_features(test_source, test_column_name)
# Unique Strings
label_col_temp = unique_string(desp='label_col_temp')
test_id_temp = unique_string(desp='test_id_temp')
train = unique_string(desp='train')
train_id = unique_string(desp='train_id')
dist_inverse = unique_string(desp='dist_inverse')
dim = unique_string(desp='dim')
t_col_name = unique_string(desp='t_col_name')
dist = unique_string(desp='dist')
if not fn_dist:
fn_dist = '{0}.squared_dist_norm2'.format(schema_madlib)
fn_dist = fn_dist.lower().strip()
interim_table = unique_string(desp='interim_table')
pred_out = ""
knn_neighbors = ""
label_out = ""
cast_to_int = ""
view_def = ""
view_join = ""
view_grp_by = ""
r_id = None
kd_output_table = None
test_data = None
if label_column_name:
label_column_type = get_expr_type(
label_column_name, point_source).lower()
if label_column_type in ['boolean', 'integer', 'text']:
is_classification = True
cast_to_int = '::INTEGER'
else:
is_classification = False
if is_classification:
if weighted_avg:
                    # This view calculates, for each test id, the label with
                    # the maximum sum of 1/distance over its neighbors. That
                    # label becomes the prediction of the classification model.
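                    # Worked example (hypothetical numbers): with k=3
                    # neighbors of a test point labeled (1, 1, 2) at distances
                    # (0.5, 0.25, 0.2), the per-label weights are
                    #   label 1: 1/0.5 + 1/0.25 = 6.0
                    #   label 2: 1/0.2          = 5.0
                    # so label 1 wins the weighted vote.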
view_def = """
WITH vw AS (
SELECT DISTINCT ON({test_id_temp})
{test_id_temp},
last_value(data_sum) OVER (
PARTITION BY {test_id_temp}
ORDER BY data_sum, {label_col_temp}
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS data_dist ,
last_value({label_col_temp}) OVER (
PARTITION BY {test_id_temp}
ORDER BY data_sum, {label_col_temp}
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS {label_col_temp}
FROM (
SELECT
{test_id_temp},
{label_col_temp},
sum({dist_inverse}) data_sum
FROM {interim_table}
GROUP BY {test_id_temp},
{label_col_temp}
) a
)
""".format(**locals())
                    # This join is needed to pull in the winning prediction
                    # calculated above.
view_join = (" JOIN vw ON knn_temp.{0} = vw.{0}".
format(test_id_temp))
view_grp_by = ", vw.{0}".format(label_col_temp)
pred_out = ", vw.{0}".format(label_col_temp)
else:
pred_out = ", {0}.mode({1})".format(
schema_madlib, label_col_temp)
else:
if weighted_avg:
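                    # Weighted regression (a worked sketch): with neighbor
                    # values y_i at distances d_i, the prediction is
                    #   sum_i(y_i / d_i) / sum_i(1 / d_i),
                    # where 1/d_i is replaced by WEIGHT_FOR_ZERO_DIST when
                    # d_i = 0.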
pred_out = (", sum({0} * {dist_inverse}) / sum({dist_inverse})".
format(label_col_temp, dist_inverse=dist_inverse))
else:
pred_out = ", avg({0})".format(label_col_temp)
pred_out += " AS prediction"
label_out = (", {train}.{label_column_name}{cast_to_int}"
" AS {label_col_temp}").format(**locals())
comma_label_out_alias = ', ' + label_col_temp
label_name = ", {label_column_name}".format(
label_column_name=label_column_name)
else:
pred_out = ""
label_out = ""
comma_label_out_alias = ""
label_name = ""
if output_neighbors:
knn_neighbors = (", array_agg(knn_temp.{train_id} ORDER BY "
"knn_temp.{dist_inverse} DESC) AS k_nearest_neighbours "
", array_agg(knn_temp.{dist} ORDER BY "
"knn_temp.{dist_inverse} DESC) AS distance").format(**locals())
else:
knn_neighbors = ''
        if algorithm == KD_TREE:
r_id = unique_string(desp='r_id')
kd_output_table = unique_string(desp='kd_tree')
case_when_clause = build_kd_tree(schema_madlib,
point_source,
kd_output_table,
point_column_name,
depth, r_id)
test_data = knn_kd_tree(schema_madlib, kd_output_table, test_source,
test_column_name, test_id, fn_dist,
max_leaves_to_explore, depth, r_id,
case_when_clause, t_col_name)
else:
test_data = test_source
# interim_table picks the 'k' nearest neighbors for each test point
_create_interim_tbl(schema_madlib, point_source, point_column_name,
point_id, label_name, test_data, test_column_name,
test_id, interim_table, k, fn_dist, test_id_temp,
train_id, dist_inverse, comma_label_out_alias,
label_out, r_id, kd_output_table, train, t_col_name,
dist)
output_sql = """
CREATE TABLE {output_table} AS
{view_def}
SELECT
knn_temp.{test_id_temp} AS id,
                {test_column_name} AS "{test_column_name}"
{pred_out}
{knn_neighbors}
FROM
{interim_table} AS knn_temp
JOIN
{test_source} AS knn_test
ON knn_temp.{test_id_temp} = knn_test.{test_id}
{view_join}
GROUP BY knn_temp.{test_id_temp},
{test_column_name}
{view_grp_by}
""".format(**locals())
plpy.execute(output_sql)
drop_tables([interim_table])
        if algorithm == KD_TREE:
centers_table = add_postfix(kd_output_table, "_centers")
test_view = add_postfix(kd_output_table, "_test_view")
ext_test_view = add_postfix(kd_output_table, "_ext_test_view")
plpy.execute("DROP VIEW IF EXISTS {0} CASCADE".format(test_view))
plpy.execute("DROP VIEW IF EXISTS {0} CASCADE".format(ext_test_view))
drop_tables([centers_table, kd_output_table])
return
# ------------------------------------------------------------------------------
def knn_help(schema_madlib, message, **kwargs):
"""
Help function for knn
Args:
@param schema_madlib
@param message: string, Help message string
@param kwargs
Returns:
String. Help/usage information
"""
if message is not None and \
message.lower() in ("usage", "help", "?"):
help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.knn(
point_source, -- Training data table having training features as vector column and labels
point_column_name, -- Name of column having feature vectors in training data table
point_id, -- Name of column having feature vector Ids in train data table
            label_column_name, -- Name of column having actual label/value for corresponding feature vector in training data table
            test_source, -- Test data table having features as vector column. An id column for the test points is mandatory
test_column_name, -- Name of column having feature vectors in test data table
test_id, -- Name of column having feature vector Ids in test data table
output_table, -- Name of output table
            k, -- Number of nearest neighbors to consider. Default is 1
            output_neighbors, -- Outputs the list of k-nearest neighbors that were used in the voting/averaging
            fn_dist, -- The name of the function used to calculate the distance between two data points
            weighted_avg, -- Calculates the regression or classification of k-NN using the weighted average method
            algorithm, -- The algorithm to use for knn
            algorithm_params -- The parameters for the kd-tree algorithm
);
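        Example call (illustrative; the table and column names are hypothetical):

        SELECT {schema_madlib}.knn('knn_train_data', 'data', 'id', 'label',
                                   'knn_test_data', 'data', 'id',
                                   'knn_result', 3, True,
                                   '{schema_madlib}.squared_dist_norm2',
                                   False, 'kd_tree', 'depth=3, leaf_nodes=2');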
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output of the KNN module is a table with the following columns:
id The ids of test data points.
test_column_name The test data points.
        prediction The output of KNN: label in case of classification, average value in case of regression.
k_nearest_neighbours The list of k-nearest neighbors that were used in the voting/averaging.
distance The list of nearest distances, sorted closest to furthest from the corresponding test point.
"""
else:
help_string = """
----------------------------------------------------------------------------
SUMMARY
----------------------------------------------------------------------------
        k-Nearest Neighbors is a method for finding the k closest points to a given
        data point in terms of a given metric. Its input consists of data points as
        features from testing examples. For a given k, it looks for the k closest
        points in the training set for each of the data points in the test set. The
        algorithm generates one output per testing example. The output of KNN depends
        on the type of task: for classification, the output is a majority vote of the
        classes of the k nearest data points, i.e. the testing example is assigned
        the most popular class among its nearest neighbors; for regression, the
        output is the average of the values of the k nearest neighbors of the given
        testing example.
--
For an overview on usage, run:
SELECT {schema_madlib}.knn('usage');
"""
return help_string.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------------------------