blob: dec3f1f16c3ec6f89d9d00ece81896e4f8b2593a [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
@file knn.py_in
@brief knn: Driver functions
@namespace knn
"""
import plpy
from utilities.validate_args import input_tbl_valid, output_tbl_valid
from utilities.validate_args import cols_in_tbl_valid
from utilities.validate_args import is_col_array
from utilities.validate_args import array_col_has_no_null
from utilities.validate_args import get_expr_type
from utilities.utilities import unique_string
from utilities.control import MinWarning
def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
                     label_column_name, test_source, test_column_name,
                     test_id, output_table, k, output_neighbors, **kwargs):
    """
    Validate the arguments of the kNN driver and resolve the value of k.
    Args:
        @param schema_madlib        Name of the Madlib Schema (not used by
                                    the validation itself).
        @param point_source         Training data table.
        @param point_column_name    Name of the feature column in the
                                    training table; must be an array column
                                    without NULL values.
        @param point_id             Name of the id column in the training
                                    table.
        @param label_column_name    Name of the label column in the training
                                    table. None or '' means "no labels"; the
                                    label checks are then skipped.
        @param test_source          Test data table.
        @param test_column_name     Name of the feature column in the test
                                    table; must be an array column without
                                    NULL values.
        @param test_id              Name of the id column in the test table;
                                    its type must be 'integer'.
        @param output_table         Name of the output table, checked with
                                    output_tbl_valid.
        @param k                    Number of nearest neighbors. None is
                                    replaced by 1. Must be greater than 0
                                    and not larger than the number of rows
                                    in the training table.
        @param output_neighbors     Accepted for signature compatibility;
                                    not validated here.
    Returns:
        The value of k to use (1 when k is None, otherwise k unchanged).
    Every failed check is reported through plpy.error or through the
    utilities.validate_args helpers.
    """
    input_tbl_valid(point_source, 'kNN')
    input_tbl_valid(test_source, 'kNN')
    output_tbl_valid(output_table, 'kNN')

    has_label = label_column_name is not None and label_column_name != ''

    if has_label:
        cols_in_tbl_valid(
            point_source,
            (label_column_name, point_column_name),
            'kNN')
    cols_in_tbl_valid(point_source, (point_column_name, point_id), 'kNN')
    cols_in_tbl_valid(test_source, (test_column_name, test_id), 'kNN')

    # The message must be formatted *before* it is handed to plpy.error:
    # plpy.error does not return, so a trailing .format() on its result
    # would never run and the user would see a literal '{0}'.
    if not is_col_array(point_source, point_column_name):
        plpy.error("kNN Error: Feature column '{0}' in train table is not"
                   " an array.".format(point_column_name))
    if not is_col_array(test_source, test_column_name):
        plpy.error("kNN Error: Feature column '{0}' in test table is not"
                   " an array.".format(test_column_name))

    if not array_col_has_no_null(point_source, point_column_name):
        plpy.error("kNN Error: Feature column '{0}' in train table has some"
                   " NULL values.".format(point_column_name))
    if not array_col_has_no_null(test_source, test_column_name):
        plpy.error("kNN Error: Feature column '{0}' in test table has some"
                   " NULL values.".format(test_column_name))

    if k is None:
        k = 1
    if k <= 0:
        plpy.error("kNN Error: k={0} is an invalid value, must be greater"
                   " than 0.".format(k))

    # k neighbors can only be found if the training table has at least k rows.
    bound = plpy.execute("SELECT {k} <= count(*) AS bound FROM {tbl}".
                         format(k=k, tbl=point_source))[0]['bound']
    if not bound:
        plpy.error("kNN Error: k={0} is greater than number of rows in"
                   " training table.".format(k))

    if has_label:
        col_type = get_expr_type(label_column_name, point_source).lower()
        if col_type not in ['integer', 'double precision', 'float', 'boolean']:
            plpy.error("kNN error: Data type '{0}' is not a valid type for"
                       " column '{1}' in table '{2}'.".
                       format(col_type, label_column_name, point_source))

    col_type_test = get_expr_type(test_id, test_source).lower()
    if col_type_test not in ['integer']:
        plpy.error("kNN Error: Data type '{0}' is not a valid type for"
                   " column '{1}' in table '{2}'.".
                   format(col_type_test, test_id, test_source))
    return k
# ------------------------------------------------------------------------------
def knn(schema_madlib, point_source, point_column_name, point_id, label_column_name,
        test_source, test_column_name, test_id, output_table, k, output_neighbors):
    """
    KNN function to find the K Nearest neighbours
    Args:
        @param schema_madlib        Name of the Madlib Schema
        @param point_source         Training data table
        @param point_column_name    Name of the column with training data
                                    points.
        @param point_id             Name of the column having ids of data
                                    points in train data table.
        @param label_column_name    Name of the column with labels/values
                                    of training data points. If None or '',
                                    no prediction is computed and only the
                                    neighbors are reported.
        @param test_source          Name of the table containing the test
                                    data points.
        @param test_column_name     Name of the column with testing data
                                    points.
        @param test_id              Name of the column having ids of data
                                    points in test data table.
        @param output_table         Name of the table to store final
                                    results.
        @param k                    default: 1. Number of nearest
                                    neighbors to consider
        @param output_neighbors     Outputs the list of k-nearest neighbors
                                    that were used in the voting/averaging.
                                    None or '' is treated as False.
    Returns:
        None. The results are written to output_table:
          - with a label column: columns id, <test_column_name>, prediction
            and, only if output_neighbors is true, k_nearest_neighbours.
            The prediction is {schema_madlib}.mode() of the neighbor labels
            when the label type is boolean, integer or text
            (classification) and avg() otherwise (regression).
          - without a label column: columns id, <test_column_name> and
            k_nearest_neighbours (NULL unless output_neighbors is true).
    """
    with MinWarning('warning'):
        k_val = knn_validate_src(schema_madlib, point_source,
                                 point_column_name, point_id, label_column_name,
                                 test_source, test_column_name, test_id,
                                 output_table, k, output_neighbors)

        x_temp_table = unique_string(desp='x_temp_table')
        y_temp_table = unique_string(desp='y_temp_table')
        label_col_temp = unique_string(desp='label_col_temp')
        test_id_temp = unique_string(desp='test_id_temp')

        # An unset flag means "do not output the neighbors". (The previous
        # test `is None or ''` never matched the empty string.)
        if output_neighbors is None or output_neighbors == '':
            output_neighbors = False

        interim_table = unique_string(desp='interim_table')

        has_label = not (label_column_name is None or label_column_name == '')

        # Extra column of the distance sub-query carrying the training label;
        # empty when there is no label column.
        label_select = ''
        is_classification = False
        if has_label:
            label_column_type = get_expr_type(
                label_column_name, point_source).lower()
            # NOTE(review): knn_validate_src rejects 'text' labels, so the
            # 'text' entry looks unreachable after validation -- confirm.
            is_classification = label_column_type in [
                'boolean', 'integer', 'text']
            # Classification labels are cast to INTEGER before aggregation.
            cast_to_int = '::INTEGER' if is_classification else ''
            label_select = """,
                        train.{label_column_name}{cast_to_int}
                            AS {label_col_temp}""".format(
                label_column_name=label_column_name,
                cast_to_int=cast_to_int,
                label_col_temp=label_col_temp)

        # Rank every (test, train) pair by distance within each test point
        # and keep the k_val closest training points per test point.
        plpy.execute(
            """
            CREATE TEMP TABLE {interim_table} AS
            SELECT * FROM
                (
                SELECT row_number() over
                        (partition by {test_id_temp} order by dist) AS r,
                        {x_temp_table}.*
                FROM
                    (
                    SELECT test.{test_id} AS {test_id_temp} ,
                        train.{point_id} as train_id ,
                        {schema_madlib}.squared_dist_norm2(
                            train.{point_column_name},
                            test.{test_column_name})
                        AS dist{label_select}
                    FROM {point_source} AS train, {test_source} AS test
                    ) {x_temp_table}
                ) {y_temp_table}
            WHERE {y_temp_table}.r <= {k_val}
            """.format(interim_table=interim_table,
                       test_id_temp=test_id_temp,
                       x_temp_table=x_temp_table,
                       y_temp_table=y_temp_table,
                       test_id=test_id,
                       point_id=point_id,
                       schema_madlib=schema_madlib,
                       point_column_name=point_column_name,
                       test_column_name=test_column_name,
                       label_select=label_select,
                       point_source=point_source,
                       test_source=test_source,
                       k_val=k_val))

        # Aggregated output columns that follow id and the test features.
        if has_label:
            if is_classification:
                result_cols = "{0}.mode({1}) AS prediction".format(
                    schema_madlib, label_col_temp)
            else:
                result_cols = "avg({0}) AS prediction".format(label_col_temp)
            if output_neighbors:
                result_cols += (", array_agg(knn_temp.train_id)"
                                " AS k_nearest_neighbours")
        else:
            # Without labels the neighbor column is always present and is
            # NULL when the neighbors were not requested.
            result_cols = ("CASE WHEN {0}"
                           " THEN array_agg(knn_temp.train_id)"
                           " ELSE NULL END AS k_nearest_neighbours"
                           ).format(output_neighbors)

        plpy.execute(
            """
            CREATE TABLE {output_table} AS
                SELECT {test_id_temp} AS id, {test_column_name},
                    {result_cols}
                FROM pg_temp.{interim_table} AS knn_temp
                    JOIN {test_source} AS knn_test
                    ON knn_temp.{test_id_temp} = knn_test.{test_id}
                GROUP BY {test_id_temp}, {test_column_name}
            """.format(output_table=output_table,
                       test_id_temp=test_id_temp,
                       test_column_name=test_column_name,
                       result_cols=result_cols,
                       interim_table=interim_table,
                       test_source=test_source,
                       test_id=test_id))

        # Clean up the interim table on every path (the label-less path
        # used to return before reaching this statement).
        plpy.execute("DROP TABLE IF EXISTS {0}".format(interim_table))
# ------------------------------------------------------------------------------