| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import plpy |
| from utilities.utilities import is_psql_char_type |
| from utilities.validate_args import get_col_dimension |
| from utilities.validate_args import get_expr_type |
| |
# m4 preprocessing directive (not Python): switch the m4 quote delimiters to
# <! and !> for the remainder of this file.
# NOTE(review): presumably done so that backticks and apostrophes in the
# Python code below are not treated as m4 quotes -- confirm against the build.
| m4_changequote(`<!', `!>') |
| |
# Tag placed on both sides of a string by quote_literal() when it falls back
# to dollar-quoting because plpy.quote_literal is unavailable.
| QUOTE_DELIMITER="$__madlib__$" |
| |
| |
def get_distinct_col_levels(source_table, col_name, col_type=None, include_nulls=False):
    """
    Return the sorted distinct values (levels) of a column or expression.

    When the type of col_name is a character type, each level is wrapped in
    the SQL function quote_literal, so the returned values are already quoted
    and escaped and can be spliced into generated SQL.

    :param source_table: Table to read the levels from.
    :param col_name: Column name or SQL expression whose levels are wanted.
    :param col_type: SQL type of col_name; looked up with get_expr_type when
                     it is not supplied.
    :param include_nulls: When True, NULL rows are not filtered out, so NULL
                          can appear as a level (returned as None).

    :return: Sorted list of the distinct levels; None, if present, is first.
    """
    if not col_type:
        col_type = get_expr_type(col_name, source_table)

    if is_psql_char_type(col_type):
        dep_var_text_patched = "quote_literal({0})".format(col_name)
    else:
        dep_var_text_patched = col_name

    if include_nulls:
        where_clause = ''
    else:
        where_clause = 'WHERE ({0}) is NOT NULL'.format(col_name)

    rows = plpy.execute("""
        SELECT DISTINCT {dep_var_text_patched} AS levels
        FROM {source_table}
        {where_clause}
        """.format(dep_var_text_patched=dep_var_text_patched,
                   source_table=source_table,
                   where_clause=where_clause))

    # A NULL level arrives as None. Python 3 cannot order None against other
    # values, so sort on a key that places None first; this is also where
    # Python 2 puts it, so the resulting order is the same on both versions.
    return sorted((row["levels"] for row in rows),
                  key=lambda level: (level is not None, level))
| |
| |
def get_one_hot_encoded_expr(col_name, col_levels):
    """
    Build a SQL expression that one-hot encodes col_name over col_levels.

    Each level is spliced verbatim into the expression, so character levels
    should already have been quoted and escaped with the sql function
    quote_literal. A level of None stands for SQL NULL and is tested with
    IS NULL, because an equality comparison cannot match NULL.

    :param col_name: Column name or SQL expression to encode.
    :param col_levels: Iterable of levels, in the order the encoded array
                       should follow.

    :return: String of the form
             ARRAY[(col) = level1, (col) = level2, ...]::INTEGER[]
    """
    comparisons = []
    for level in col_levels:
        if level is None:
            comparisons.append("({0}) IS NULL".format(col_name))
        else:
            comparisons.append("({0}) = {1}".format(col_name, level))
    return 'ARRAY[{0}]::INTEGER[]'.format(', '.join(comparisons))
| # ------------------------------------------------------------------------------ |
| |
| |
def quote_literal(input_str):
    """ Quote input_str so it can be used as a string literal in SQL text.

    The work is delegated to plpy.quote_literal. That function is not
    available in GPDB 4.3, so on that platform the attribute lookup fails and
    this proxy returns a dollar-quoted string instead, using QUOTE_DELIMITER
    as the (deliberately obscure) tag on both sides.
    """
    try:
        quoted = plpy.quote_literal(str(input_str))
    except AttributeError:
        # plpy.quote_literal is missing: fall back to dollar-quoting.
        quoted = "{0}{1}{0}".format(QUOTE_DELIMITER, input_str)
    return quoted
| # ------------------------------------------------------------------------------ |
| |
def is_col_1d_array(source_table, col_name):
    """
    Report whether col_name has no second array dimension.

    A single row of source_table is sampled (LIMIT 1) and the result of
    array_upper(col_name, 2) IS NULL for that row is returned.

    :param source_table: Table containing the array column.
    :param col_name: Array column or expression to inspect.

    :return: Value of the n_y column of the first row returned.
    """
    rows = plpy.execute("""
        SELECT array_upper({0}, 2) IS NULL AS n_y
        FROM {1}
        LIMIT 1
        """.format(col_name, source_table))
    first_row = rows[0]
    return first_row["n_y"]
| |
| # ------------------------------------------------------------------------------ |
def get_ndims(source_table, col_name):
    """
    Return the number of dimensions of an array column.

    Runs the postgres array_ndims function on col_name and returns its value
    for the first row fetched (the query is executed with a row limit of 1).
    For example, an array of shape 32 x 32 x 3 has 3 dimensions.

    :param source_table: Table containing the array column.
    :param col_name: Array column or expression to inspect.

    :return: Value of the ndims column of the first row returned.
    """
    ndims_query = """
        SELECT array_ndims({0}) AS ndims
        FROM {1}
        """.format(col_name, source_table)
    first_row = plpy.execute(ndims_query, 1)[0]
    return first_row['ndims']
| |
def get_product_of_dimensions(source_table, col_name):
    """
    Return the product of the sizes of every dimension of an array column,
    i.e. the total length of a multi dimensional array.

    The number of dimensions comes from get_ndims; get_col_dimension is then
    called once per dimension, numbered from 1, and the results are
    multiplied together. For example, an array of shape 32 x 32 x 3 is
    expected to give 32*32*3 (assuming get_col_dimension returns the length
    of the requested dimension -- confirm against its definition).

    :param source_table: Table containing the array column.
    :param col_name: Array column or expression to inspect.

    :return: Product of the per-dimension values.
    """
    element_count = 1
    for axis in range(1, get_ndims(source_table, col_name) + 1):
        element_count *= get_col_dimension(source_table, col_name, axis)
    return element_count