# blob: 5e2d8f070075539d6deab7c900678a3ccb321ba4 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
from utilities import _assert
from utilities import is_valid_psql_type
from utilities import NUMERIC, ONLY_ARRAY
from validate_args import get_expr_type
def validate_dependent_var_for_minibatch(table_name, var_name, expr_type=None):
    """Validate that a dependent variable column is minibatched and encoded.

    Two checks are performed:

    1. The type of ``var_name`` must satisfy
       ``is_valid_psql_type(expr_type, NUMERIC | ONLY_ARRAY)``; a failure is
       reported through ``_assert``.
    2. For one row of ``table_name``, the upper bound of the second array
       dimension of ``var_name`` must be greater than 1. A NULL result
       (for example, when the array has no second dimension) is treated
       as "not encoded"; a failure is reported through ``plpy.error``.

    Args:
        table_name: Table to inspect. Interpolated verbatim into the
            validation query.
        var_name: Dependent variable column (or expression). Interpolated
            verbatim into the validation query.
        expr_type: Optional type of ``var_name``. When falsy, it is looked
            up with ``get_expr_type(var_name, table_name)``.
    """
    # The dependent variable is always a double precision array in
    # preprocessed data, so any numeric array type is accepted here.
    if not expr_type:
        expr_type = get_expr_type(var_name, table_name)
    _assert(is_valid_psql_type(expr_type, NUMERIC | ONLY_ARRAY),
            "Dependent variable column {0} in table {1} "
            "should be a numeric array.".format(var_name, table_name))

    # NOTE(review): both names are spliced into the SQL text unescaped;
    # presumably callers pass already-validated identifiers -- confirm.
    query = ("SELECT array_upper({var_name}, 2) > 1 AS is_encoded FROM\n"
             "{table_name} LIMIT 1;").format(var_name=var_name,
                                             table_name=table_name)
    result = plpy.execute(query)

    if len(result) == 0:
        # LIMIT 1 yields no row only when the table is empty. Report that
        # explicitly instead of failing with an IndexError on result[0].
        plpy.error("Dependent variable column {0} in table {1} cannot be "
                   "validated because the table has no rows."
                   .format(var_name, table_name))
    elif not result[0]["is_encoded"]:
        plpy.error("Dependent variable column {0} in table {1} should be "
                   "minibatched and one hot encoded. You might need to re run "
                   "the minibatch_preprocessor function and make sure that "
                   "the variable is encoded.".format(var_name, table_name))
def validate_bytea_var_for_minibatch(table_name, var_name, expr_type=None):
    """Validate that ``var_name`` in ``table_name`` has the type ``bytea``.

    Args:
        table_name: Table that holds the column; used for the type lookup
            and in the failure message.
        var_name: Column (or expression) whose type is checked.
        expr_type: Optional type of ``var_name``. When falsy, it is looked
            up with ``get_expr_type(var_name, table_name)``.

    The outcome of the comparison is handed to ``_assert`` together with
    the failure message.
    """
    # Fall back to a type lookup only when the caller did not supply one.
    resolved_type = expr_type or get_expr_type(var_name, table_name)
    is_bytea = resolved_type == 'bytea'
    failure_msg = ("Dependent variable column {0} in table {1} "
                   "should be minibatched. You might need to re run "
                   "the preprocessor function.").format(var_name, table_name)
    _assert(is_bytea, failure_msg)