| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """ |
| @file knn.py_in |
| |
| @brief knn: Driver functions |
| |
| @namespace knn |
| |
| """ |
| |
| import plpy |
| from utilities.validate_args import input_tbl_valid, output_tbl_valid |
| from utilities.validate_args import cols_in_tbl_valid |
| from utilities.validate_args import is_col_array |
| from utilities.validate_args import array_col_has_no_null |
| from utilities.validate_args import get_expr_type |
| from utilities.utilities import unique_string |
| from utilities.control import MinWarning |
| |
| |
def knn_validate_src(schema_madlib, point_source, point_column_name, point_id,
                     label_column_name, test_source, test_column_name,
                     test_id, output_table, k, output_neighbors, **kwargs):
    """
    Validate the arguments of knn() and return the effective value of k.

    Args:
        @param schema_madlib        Name of the MADlib schema (not used by
                                    the checks below)
        @param point_source         Training data table
        @param point_column_name    Array column with training data points
        @param point_id             Id column of the training table
        @param label_column_name    Label column of the training table;
                                    None or '' means "no label"
        @param test_source          Test data table
        @param test_column_name     Array column with test data points
        @param test_id              Id column of the test table; must be
                                    of type integer
        @param output_table         Output table name (must be creatable)
        @param k                    Number of neighbors; None means 1
        @param output_neighbors     Not validated here

    Returns:
        The value of k to use (1 when k is None).

    Raises:
        Errors through plpy.error() when a table/column is invalid, a
        feature column is not an array or contains NULLs, k is not positive
        or exceeds the number of training rows, or a column has an
        unsupported type.
    """
    input_tbl_valid(point_source, 'kNN')
    input_tbl_valid(test_source, 'kNN')
    output_tbl_valid(output_table, 'kNN')

    has_label = label_column_name is not None and label_column_name != ''

    if has_label:
        cols_in_tbl_valid(
            point_source,
            (label_column_name,
             point_column_name),
            'kNN')
    cols_in_tbl_valid(point_source, (point_column_name, point_id), 'kNN')
    cols_in_tbl_valid(test_source, (test_column_name, test_id), 'kNN')

    # The message must be formatted before it is passed to plpy.error():
    # plpy.error() raises, so a .format() chained onto its result never runs
    # and the user would see a literal '{0}'.
    if not is_col_array(point_source, point_column_name):
        plpy.error("kNN Error: Feature column '{0}' in train table is not"
                   " an array.".format(point_column_name))
    if not is_col_array(test_source, test_column_name):
        plpy.error("kNN Error: Feature column '{0}' in test table is not"
                   " an array.".format(test_column_name))

    if not array_col_has_no_null(point_source, point_column_name):
        plpy.error("kNN Error: Feature column '{0}' in train table has some"
                   " NULL values.".format(point_column_name))
    if not array_col_has_no_null(test_source, test_column_name):
        plpy.error("kNN Error: Feature column '{0}' in test table has some"
                   " NULL values.".format(test_column_name))

    if k is None:
        k = 1
    if k <= 0:
        plpy.error("kNN Error: k={0} is an invalid value, must be greater"
                   " than 0.".format(k))
    # k may not exceed the number of training rows.
    bound = plpy.execute("SELECT {k} <= count(*) AS bound FROM {tbl}".
                         format(k=k, tbl=point_source))[0]['bound']
    if not bound:
        plpy.error("kNN Error: k={0} is greater than number of rows in"
                   " training table.".format(k))

    if has_label:
        col_type = get_expr_type(label_column_name, point_source).lower()
        if col_type not in ['integer', 'double precision', 'float', 'boolean']:
            plpy.error("kNN error: Data type '{0}' is not a valid type for"
                       " column '{1}' in table '{2}'.".
                       format(col_type, label_column_name, point_source))

    col_type_test = get_expr_type(test_id, test_source).lower()
    if col_type_test not in ['integer']:
        plpy.error("kNN Error: Data type '{0}' is not a valid type for"
                   " column '{1}' in table '{2}'.".
                   format(col_type_test, test_id, test_source))
    return k
| # ------------------------------------------------------------------------------ |
| |
| |
def knn(schema_madlib, point_source, point_column_name, point_id, label_column_name,
        test_source, test_column_name, test_id, output_table, k, output_neighbors):
    """
    KNN function to find the K Nearest neighbours of each test point.

    For every row of test_source, the k training points with the smallest
    {schema_madlib}.squared_dist_norm2 distance are selected. If a label
    column is given, a prediction is computed from their labels:
    {schema_madlib}.mode() for integer/boolean labels (classification),
    avg() otherwise (regression).

    Args:
        @param schema_madlib        Name of the MADlib schema
        @param point_source         Training data table
        @param point_column_name    Name of the column with training data
                                    points
        @param point_id             Name of the column having ids of data
                                    points in the train data table
        @param label_column_name    Name of the column with labels/values
                                    of training data points. None or ''
                                    means no prediction is computed.
        @param test_source          Name of the table containing the test
                                    data points
        @param test_column_name     Name of the column with testing data
                                    points
        @param test_id              Name of the column having ids of data
                                    points in the test data table
        @param output_table         Name of the table to store final
                                    results
        @param k                    default: 1. Number of nearest
                                    neighbors to consider
        @param output_neighbors     If true, also output the ids of the
                                    k nearest training points used in the
                                    voting/averaging. None or '' is
                                    treated as false.

    Returns:
        None. Results are written to output_table with columns:
          - id and the test feature column;
          - prediction, when a label column is given;
          - k_nearest_neighbours, always present without a label column
            (NULL unless output_neighbors is true), and present with a
            label column only when output_neighbors is true.
    """
    with MinWarning('warning'):
        k_val = knn_validate_src(schema_madlib, point_source,
                                 point_column_name, point_id, label_column_name,
                                 test_source, test_column_name, test_id,
                                 output_table, k, output_neighbors)

        x_temp_table = unique_string(desp='x_temp_table')
        y_temp_table = unique_string(desp='y_temp_table')
        label_col_temp = unique_string(desp='label_col_temp')
        test_id_temp = unique_string(desp='test_id_temp')
        interim_table = unique_string(desp='interim_table')

        # NULL and '' both mean "do not output neighbors". (The previous
        # test `output_neighbors is None or ''` never compared against ''.)
        if output_neighbors is None or output_neighbors == '':
            output_neighbors = False

        has_label = label_column_name is not None and label_column_name != ''

        # Optional label projection for the interim table. Labels treated as
        # classification targets are cast to INTEGER before voting.
        # NOTE(review): 'text' is rejected by knn_validate_src, so only
        # boolean/integer labels actually reach the classification path.
        is_classification = False
        label_select = ''
        if has_label:
            label_column_type = get_expr_type(
                label_column_name, point_source).lower()
            is_classification = label_column_type in ('boolean', 'integer', 'text')
            label_select = ", train.{0}{1} AS {2}".format(
                label_column_name,
                '::INTEGER' if is_classification else '',
                label_col_temp)

        # Pair every test point with every training point, rank the pairs by
        # distance within each test point, and keep the k closest.
        plpy.execute(
            """
            CREATE TEMP TABLE {interim_table} AS
                SELECT * FROM
                (
                    SELECT row_number() over
                            (partition by {test_id_temp} order by dist) AS r,
                           {x_temp_table}.*
                    FROM
                    (
                        SELECT test.{test_id} AS {test_id_temp} ,
                               train.{point_id} as train_id ,
                               {schema_madlib}.squared_dist_norm2(
                                    train.{point_column_name},
                                    test.{test_column_name})
                               AS dist
                               {label_select}
                        FROM {point_source} AS train, {test_source} AS test
                    ) {x_temp_table}
                ) {y_temp_table}
                WHERE {y_temp_table}.r <= {k_val}
            """.format(**locals()))

        if not has_label:
            # Without labels only the neighbor list can be reported; the
            # column is always present and NULL unless requested.
            neighbors_flag = 'TRUE' if output_neighbors else 'FALSE'
            result_cols = ("CASE WHEN {0} THEN array_agg(knn_temp.train_id) "
                           "ELSE NULL END AS k_nearest_neighbours"
                           .format(neighbors_flag))
        else:
            if is_classification:
                result_cols = "{0}.mode({1}) AS prediction".format(
                    schema_madlib, label_col_temp)
            else:
                result_cols = "avg({0}) AS prediction".format(label_col_temp)
            if output_neighbors:
                result_cols += (", array_agg(knn_temp.train_id)"
                                " AS k_nearest_neighbours")

        plpy.execute(
            """
            CREATE TABLE {output_table} AS
                SELECT {test_id_temp} AS id, {test_column_name}, {result_cols}
                FROM pg_temp.{interim_table} AS knn_temp
                JOIN {test_source} AS knn_test
                    ON knn_temp.{test_id_temp} = knn_test.{test_id}
                GROUP BY {test_id_temp}, {test_column_name}
            """.format(**locals()))

        # Drop the interim table on every path. Previously the unlabeled
        # path returned early and left it behind.
        plpy.execute("DROP TABLE IF EXISTS pg_temp.{0}".format(interim_table))
| # ------------------------------------------------------------------------------ |