| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
# m4 directive: replace the default m4 quote characters with the pair <! and !>
# so that the quote characters used by the Python code below are not treated
# as m4 quoting.
| m4_changequote(`<!', `!>') |
| |
| """ |
| @file knn.py_in |
| |
| @brief knn: Driver functions |
| |
| @namespace knn |
| |
| @brief knn: Driver functions |
| """ |
| |
| import plpy |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import table_is_empty |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.validate_args import is_col_array |
| from utilities.validate_args import array_col_has_no_null |
| from utilities.validate_args import get_cols_and_types |
| |
# Build-time switches. Each m4_ifdef below is expanded by m4 before Python
# ever sees this file: it becomes True when the named macro is defined for the
# m4 run and False otherwise.
| STATE_IN_MEM = m4_ifdef(<!__HAWQ__!>, <!True!>, <!False!>) |
| HAS_FUNCTION_PROPERTIES = m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>) |
| UDF_ON_SEGMENT_NOT_ALLOWED = m4_ifdef(<!__UDF_ON_SEGMENT_NOT_ALLOWED__!>, <!True!>, <!False!>) |
| # ---------------------------------------------------------------------- |
| |
| |
def _knn_column_type(table_name, column_name):
    """
    @brief Return the data type recorded for one column of a table.

    @param table_name Table handed to get_cols_and_types().
    @param column_name Name compared against the first field of each entry.

    @return The second field of the first matching entry, or '' when no
            entry matches.
    """
    return next((entry[1] for entry in get_cols_and_types(table_name)
                 if entry[0] == column_name), '')


def knn_validate_src(schema_madlib, point_source, point_column_name, label_column_name,
                     test_source, test_column_name, id_column_name, output_table, operation, k, **kwargs):
    """
    @brief Validate the arguments of a kNN call and return the effective k.

    Every failed check is reported through plpy.error() with a message that
    starts with 'kNN Error:'. The checks run in this order: operation code,
    training table, test table, column existence, feature columns are
    arrays, feature columns pass array_col_has_no_null(), output table, k,
    label column type, id column type.

    @param schema_madlib Not used by the validation.
    @param point_source Training table; must exist and be non-empty.
    @param point_column_name Feature column of the training table; must be
           an array column accepted by array_col_has_no_null().
    @param label_column_name Label column of the training table; its type
           must be integer, double precision, float or boolean.
    @param test_source Test table; must exist and be non-empty.
    @param test_column_name Feature column of the test table; same checks
           as point_column_name.
    @param id_column_name Id column of the test table; its type must be
           integer.
    @param output_table Name for the output table; must be non-empty and
           must not exist yet.
    @param operation 'c' for classification or 'r' for regression.
    @param k Neighbour count; None is replaced by 1. Must be greater than 0
           and not larger than the number of rows in point_source.
    @param kwargs Accepted and ignored.

    @return The value of k after the None default has been applied.
    """
    # A plain membership test also rejects None and the empty string.
    if operation not in ('c', 'r'):
        plpy.error("kNN Error: operation='{0}' is an invalid value, has to be 'r' for regression OR 'c' for classification.".format(operation))

    # --- training table ---------------------------------------------------
    if not point_source:
        plpy.error("kNN Error: Invalid training table name.")
    if not table_exists(point_source):
        plpy.error("kNN Error: Training table '{0}' does not exist.".format(point_source))
    if table_is_empty(point_source):
        plpy.error("kNN Error: Training table '{0}' is empty.".format(point_source))

    # --- test table -------------------------------------------------------
    if not test_source:
        plpy.error("kNN Error: Invalid test table name.")
    if not table_exists(test_source):
        plpy.error("kNN Error: Test table '{0}' does not exist.".format(test_source))
    if table_is_empty(test_source):
        plpy.error("kNN Error: Test table '{0}' is empty.".format(test_source))

    # --- column existence -------------------------------------------------
    for column in (label_column_name, point_column_name):
        if not column:
            plpy.error("kNN Error: Invalid column name in training table.")
        if not columns_exist_in_table(point_source, [column]):
            plpy.error("kNN Error: Column '{0}' does not exist in {1}.".format(column, point_source))

    for column in (test_column_name, id_column_name):
        if not column:
            plpy.error("kNN Error: Invalid column name in test table.")
        if not columns_exist_in_table(test_source, [column]):
            plpy.error("kNN Error: Column '{0}' does not exist in {1}.".format(column, test_source))

    # --- feature columns must be arrays ------------------------------------
    if not is_col_array(point_source, point_column_name):
        plpy.error("kNN Error: Feature column '{0}' in train table is not an array.".format(point_column_name))
    if not is_col_array(test_source, test_column_name):
        plpy.error("kNN Error: Feature column '{0}' in test table is not an array.".format(test_column_name))

    # --- feature columns must pass the NULL check --------------------------
    if not array_col_has_no_null(point_source, point_column_name):
        plpy.error("kNN Error: Feature column '{0}' in train table has some NULL values.".format(point_column_name))
    if not array_col_has_no_null(test_source, test_column_name):
        plpy.error("kNN Error: Feature column '{0}' in test table has some NULL values.".format(test_column_name))

    # --- output table -----------------------------------------------------
    if not output_table:
        plpy.error("kNN Error: Invalid output table name")
    if table_exists(output_table):
        plpy.error("kNN Error: Table '{0}' already exists, cannot use it as output table.".format(output_table))

    # --- k ------------------------------------------------------------------
    if k is None:
        k = 1
    if k <= 0:
        plpy.error("kNN Error: k='{0}' is an invalid value, must be greater than 0.".format(k))
    # Let the database compare k with the training row count.
    # point_source has already passed table_exists() above.
    bound = plpy.execute(
        "SELECT {k} <= count(*) AS bound FROM {tbl}".format(k=k, tbl=point_source))[0]['bound']
    if not bound:
        plpy.error("kNN Error: k='{0}' is greater than number of rows in training table.".format(k))

    # --- column data types --------------------------------------------------
    # Both spellings of each type name are kept so the accepted set is
    # exactly what it was before.
    valid_label_types = ('INTEGER', 'integer', 'double precision', 'DOUBLE PRECISION',
                         'float', 'FLOAT', 'boolean', 'BOOLEAN')
    label_type = _knn_column_type(point_source, label_column_name)
    if label_type not in valid_label_types:
        plpy.error("kNN Error: Data type '{0}' is not a valid type for column '{1}' in table '{2}'.".format(label_type, label_column_name, point_source))

    valid_id_types = ('INTEGER', 'integer')
    id_type = _knn_column_type(test_source, id_column_name)
    if id_type not in valid_id_types:
        plpy.error("kNN Error: Data type '{0}' is not a valid type for column '{1}' in table '{2}'.".format(id_type, id_column_name, test_source))

    return k
| |
| # ---------------------------------------------------------------------- |
# m4 directive: restore the default m4 quote characters now that the Python
# code above has been emitted.
| m4_changequote(<!`!>, <!'!>) |