| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| from utilities import _assert |
| from utilities import is_valid_psql_type |
| from utilities import NUMERIC, ONLY_ARRAY |
| from validate_args import get_expr_type |
| |
def validate_dependent_var_for_minibatch(table_name, var_name, expr_type=None):
    """Validate that a dependent variable column looks minibatched and encoded.

    Two checks are performed:

    1. The column type must satisfy
       ``is_valid_psql_type(expr_type, NUMERIC | ONLY_ARRAY)``.
    2. For the first row returned from ``table_name``,
       ``array_upper(var_name, 2)`` must be greater than 1.

    Args:
        table_name: Name of the table holding the column. It is interpolated
            verbatim into the SQL text, so it must already be a valid (and
            suitably quoted) identifier.
        var_name: Name of the dependent variable column (or expression); it
            is interpolated verbatim into the SQL text as well.
        expr_type: Type of ``var_name``. When falsy (``None`` or empty) it is
            looked up with ``get_expr_type(var_name, table_name)``.

    Errors are reported through ``_assert`` (type check) and ``plpy.error``
    (empty table, or second array dimension not greater than 1).
    """
    # The dependent variable is always a double precision array in
    # preprocessed data (so check for numeric types)
    if not expr_type:
        expr_type = get_expr_type(var_name, table_name)
    _assert(is_valid_psql_type(expr_type, NUMERIC | ONLY_ARRAY),
            "Dependent variable column {0} in table {1} "
            "should be a numeric array.".format(var_name, table_name))

    # Only a single row is sampled (LIMIT 1). The comparison yields NULL
    # (None in Python) when the array has no second dimension, which is
    # treated the same as "not encoded" below.
    query = """SELECT array_upper({var_name}, 2) > 1 AS is_encoded FROM
                {table_name} LIMIT 1;""".format(var_name=var_name,
                                                table_name=table_name)
    result = plpy.execute(query)
    if len(result) == 0:
        # Without this guard, result[0] fails with an opaque index error
        # that tells the user nothing about the actual problem.
        plpy.error("Dependent variable column {0} in table {1} could not be "
                   "validated because the table has no rows."
                   .format(var_name, table_name))
    elif not result[0]["is_encoded"]:
        plpy.error("Dependent variable column {0} in table {1} should be "
                   "minibatched and one hot encoded. You might need to re run "
                   "the minibatch_preprocessor function and make sure that "
                   "the variable is encoded.".format(var_name, table_name))
| |
def validate_bytea_var_for_minibatch(table_name, var_name, expr_type=None):
    """Assert that column ``var_name`` of ``table_name`` has type ``bytea``.

    Args:
        table_name: Name of the table that holds the column.
        var_name: Name of the column (or expression) to check.
        expr_type: Type of ``var_name``. When falsy it is looked up with
            ``get_expr_type(var_name, table_name)``.

    The check is delegated to ``_assert`` with a message suggesting that the
    preprocessor function be re-run.
    """
    # Fall back to a catalog lookup only when the caller gave no usable type.
    resolved_type = expr_type or get_expr_type(var_name, table_name)
    failure_msg = ("Dependent variable column {0} in table {1} "
                   "should be minibatched. You might need to re run "
                   "the preprocessor function.").format(var_name, table_name)
    _assert(resolved_type == 'bytea', failure_msg)