blob: c0d9cd7be094405a32c5e761474aa63b72254497 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
m4_changequote(`<!', `!>')
"""
@file knn.py_in
@brief knn: Driver functions
@namespace knn
@brief knn: Driver functions
"""
import plpy
from utilities.validate_args import table_exists
from utilities.validate_args import table_is_empty
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import columns_exist_in_table
from utilities.validate_args import is_col_array
from utilities.validate_args import array_col_has_no_null
from utilities.validate_args import get_cols_and_types
# Build-time flags resolved by the m4 preprocessor before this file is run as
# Python: each assignment expands to True when the corresponding platform
# macro is defined for the build, and to False otherwise.
STATE_IN_MEM = m4_ifdef(<!__HAWQ__!>, <!True!>, <!False!>)
HAS_FUNCTION_PROPERTIES = m4_ifdef(<!__HAS_FUNCTION_PROPERTIES__!>, <!True!>, <!False!>)
UDF_ON_SEGMENT_NOT_ALLOWED = m4_ifdef(<!__UDF_ON_SEGMENT_NOT_ALLOWED__!>, <!True!>, <!False!>)
# ----------------------------------------------------------------------
def _knn_column_type(table_name, column_name):
    """
    Return the type that get_cols_and_types reports for column_name in
    table_name, or an empty string when the column is not listed.
    """
    return next((entry[1] for entry in get_cols_and_types(table_name)
                 if entry[0] == column_name), '')


def knn_validate_src(schema_madlib, point_source, point_column_name, label_column_name,
                     test_source, test_column_name, id_column_name, output_table, operation, k, **kwargs):
    """
    Validate the arguments of a kNN call and return the effective value of k.

    Every failed check is reported through plpy.error with a message
    prefixed by "kNN Error: ". The checks run in the order listed below, so
    the first violated rule determines the message.

    Args:
        @param schema_madlib      Name of the MADlib schema (not used by the checks).
        @param point_source       Training table; must exist and be non-empty.
        @param point_column_name  Feature column of the training table; must
                                  be an array column without NULL values.
        @param label_column_name  Label column of the training table; must be
                                  of an integer, double precision, float or
                                  boolean type.
        @param test_source        Test table; must exist and be non-empty.
        @param test_column_name   Feature column of the test table; must be an
                                  array column without NULL values.
        @param id_column_name     Id column of the test table; must be integer.
        @param output_table       Name of the output table; must not exist yet.
        @param operation          'c' for classification or 'r' for regression.
        @param k                  Number of neighbours; None means 1. Must be
                                  positive and not larger than the number of
                                  rows in the training table.
        @param kwargs             Extra keyword arguments, ignored.

    Returns:
        k, with None replaced by the default of 1.
    """
    if not operation or operation not in ['c', 'r']:
        plpy.error("kNN Error: operation='{0}' is an invalid value, has to be 'r' for regression OR 'c' for classification.".format(operation))

    # Training table: named, present and populated.
    if not point_source:
        plpy.error("kNN Error: Invalid training table name.")
    if not table_exists(point_source):
        plpy.error("kNN Error: Training table '{0}' does not exist.".format(point_source))
    if table_is_empty(point_source):
        plpy.error("kNN Error: Training table '{0}' is empty.".format(point_source))

    # Test table: named, present and populated.
    if not test_source:
        plpy.error("kNN Error: Invalid test table name.")
    if not table_exists(test_source):
        plpy.error("kNN Error: Test table '{0}' does not exist.".format(test_source))
    if table_is_empty(test_source):
        plpy.error("kNN Error: Test table '{0}' is empty.".format(test_source))

    # Required columns must be named and present in their tables.
    for col in (label_column_name, point_column_name):
        if not col:
            plpy.error("kNN Error: Invalid column name in training table.")
        if not columns_exist_in_table(point_source, [col]):
            plpy.error("kNN Error: Column '{0}' does not exist in {1}.".format(col, point_source))
    for col in (test_column_name, id_column_name):
        if not col:
            plpy.error("kNN Error: Invalid column name in test table.")
        if not columns_exist_in_table(test_source, [col]):
            plpy.error("kNN Error: Column '{0}' does not exist in {1}.".format(col, test_source))

    # Feature columns must be arrays ...
    if not is_col_array(point_source, point_column_name):
        plpy.error("kNN Error: Feature column '{0}' in train table is not an array.".format(point_column_name))
    if not is_col_array(test_source, test_column_name):
        plpy.error("kNN Error: Feature column '{0}' in test table is not an array.".format(test_column_name))

    # ... and must not contain NULLs.
    if not array_col_has_no_null(point_source, point_column_name):
        plpy.error("kNN Error: Feature column '{0}' in train table has some NULL values.".format(point_column_name))
    if not array_col_has_no_null(test_source, test_column_name):
        plpy.error("kNN Error: Feature column '{0}' in test table has some NULL values.".format(test_column_name))

    # The output table is created by the caller, so it must not exist yet.
    if not output_table:
        plpy.error("kNN Error: Invalid output table name")
    if table_exists(output_table):
        plpy.error("kNN Error: Table '{0}' already exists, cannot use it as output table.".format(output_table))

    # k defaults to 1 and cannot exceed the number of training rows.
    if k is None:
        k = 1
    if k <= 0:
        plpy.error("kNN Error: k='{0}' is an invalid value, must be greater than 0.".format(k))
    bound = plpy.execute("""SELECT {k} <= count(*)
                            AS bound FROM {tbl}""".format(k=k, tbl=point_source))[0]['bound']
    if not bound:
        plpy.error("kNN Error: k='{0}' is greater than number of rows in training table.".format(k))

    # Label column type check.
    valid_label_types = ['INTEGER', 'integer', 'double precision', 'DOUBLE PRECISION',
                         'float', 'FLOAT', 'boolean', 'BOOLEAN']
    label_type = _knn_column_type(point_source, label_column_name)
    if label_type not in valid_label_types:
        plpy.error("kNN Error: Data type '{0}' is not a valid type for column '{1}' in table '{2}'.".format(label_type, label_column_name, point_source))

    # Id column type check.
    id_type = _knn_column_type(test_source, id_column_name)
    if id_type not in ['INTEGER', 'integer']:
        plpy.error("kNN Error: Data type '{0}' is not a valid type for column '{1}' in table '{2}'.".format(id_type, id_column_name, test_source))

    return k
# ----------------------------------------------------------------------
m4_changequote(<!`!>, <!'!>)