blob: d39c845a0ffa1270e04430e8d64720a624e263a8 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import plpy
from utilities.utilities import is_psql_char_type
from utilities.validate_args import get_col_dimension
from utilities.validate_args import get_expr_type
m4_changequote(`<!', `!>')
# Tag placed on both sides of a value by quote_literal() to build a
# dollar-quoted string when plpy.quote_literal is not available.
QUOTE_DELIMITER="$__madlib__$"
def get_distinct_col_levels(source_table, col_name, col_type=None, include_nulls=False):
    """
    Return the sorted distinct values (levels) of a column or expression.

    When is_psql_char_type reports col_type as a character type, the values
    are selected through the SQL function quote_literal, so they come back
    already quoted and escaped. For any other type the values are selected
    unchanged.

    :param source_table: Name of the table to query.
    :param col_name: Column name or SQL expression whose levels are wanted.
    :param col_type: SQL type of col_name. Looked up with get_expr_type when
        it is not supplied (or is empty).
    :param include_nulls: When True, rows where col_name is NULL are kept, so
        the result may contain None. When False (default) they are filtered
        out in the query.
    :return: Sorted list of the distinct levels; None, if present, is first.
    """
    if not col_type:
        col_type = get_expr_type(col_name, source_table)

    if is_psql_char_type(col_type):
        levels_expr = "quote_literal({0})".format(col_name)
    else:
        levels_expr = col_name

    if include_nulls:
        where_clause = ''
    else:
        where_clause = 'WHERE ({0}) is NOT NULL'.format(col_name)

    rows = plpy.execute("""
        SELECT DISTINCT {levels_expr} AS levels
        FROM {source_table}
        {where_clause}
        """.format(levels_expr=levels_expr,
                   source_table=source_table,
                   where_clause=where_clause))

    # With include_nulls=True a NULL level arrives as None. Python 3 refuses
    # to order None against other values, so sort on an explicit key that
    # places None first (the same position Python 2 gave it implicitly).
    return sorted((row["levels"] for row in rows),
                  key=lambda level: (level is not None, level))
def get_one_hot_encoded_expr(col_name, col_levels):
    """
    Build the SQL text of an integer array that one-hot encodes col_name.

    One equality comparison against col_name is emitted per level, in the
    order given, and the comparisons are wrapped in an ARRAY cast to
    INTEGER[]. The levels are inserted into the SQL text as-is, so string
    levels must already have been quoted and escaped with the SQL function
    quote_literal.

    :param col_name: Column name or SQL expression to encode.
    :param col_levels: Iterable of levels, already formatted as SQL literals.
    :return: SQL expression string for the one-hot encoded array.
    """
    comparisons = []
    for level in col_levels:
        comparisons.append("({0}) = {1}".format(col_name, level))
    array_body = ', '.join(comparisons)
    return 'ARRAY[{0}]::INTEGER[]'.format(array_body)
# ------------------------------------------------------------------------------
def quote_literal(input_str):
    """ Quote a value so it can be embedded as a string literal in SQL text.

    The value is converted with str() and handed to plpy.quote_literal.
    plpy.quote_literal is not available in GPDB 4.3; on such a platform the
    resulting AttributeError is caught and the value is wrapped in a
    dollar-quoted string that uses QUOTE_DELIMITER as its tag.

    :param input_str: Value to quote.
    :return: SQL text of the quoted literal.
    """
    try:
        quoted = plpy.quote_literal(str(input_str))
    except AttributeError:
        # No plpy.quote_literal here: work around it with a dollar-quoted
        # string whose tag is obscure enough not to clash with the content.
        fallback_parts = dict(qd=QUOTE_DELIMITER, input_str=input_str)
        quoted = "{qd}{input_str}{qd}".format(**fallback_parts)
    return quoted
# ------------------------------------------------------------------------------
def quote_nullable(input_str):
    """ Quote a value for use in SQL text, mapping None to NULL.

    :param input_str: Value to quote, or None.
    :return: The string 'NULL' when input_str is None, otherwise the result
        of quote_literal(input_str).
    """
    if input_str is None:
        return 'NULL'
    return quote_literal(input_str)
# ------------------------------------------------------------------------------
def is_col_1d_array(source_table, col_name):
    """ Report whether col_name has no second array dimension.

    Runs array_upper(col_name, 2) IS NULL against source_table with LIMIT 1
    and returns that value from the first row of the result.

    :param source_table: Name of the table to query.
    :param col_name: Array column (or expression) to inspect.
    :return: Value of the n_y column of the first result row.
    """
    rows = plpy.execute("""
        SELECT array_upper({0}, 2) IS NULL AS n_y
        FROM {1}
        LIMIT 1
        """.format(col_name, source_table))
    first_row = rows[0]
    return first_row["n_y"]
# ------------------------------------------------------------------------------
# Run the postgres array_ndims function on col_name and return its value
# for the first row fetched from source_table. array_ndims yields the
# number of dimensions as an integer, so an array shaped [32,32,3]
# gives 3.
def get_ndims(source_table, col_name):
    ndims_query = """
        SELECT array_ndims({0}) AS ndims
        FROM {1}
        """.format(col_name, source_table)
    # The second argument to plpy.execute caps the rows returned at one.
    first_row = plpy.execute(ndims_query, 1)[0]
    return first_row['ndims']
# Calculate the total length of a multi dimensional array column: the
# product of the values get_col_dimension returns for dimensions
# 1 .. get_ndims(). For example, an array shaped [32,32,3] has 3
# dimensions and (presumably, given get_col_dimension returns each
# dimension's size) yields 32*32*3.
def get_product_of_dimensions(source_table, col_name):
    n_dims = get_ndims(source_table, col_name)
    total_length = 1
    # Postgres array dimensions are numbered from 1.
    for axis in range(1, n_dims + 1):
        total_length *= get_col_dimension(source_table, col_name, axis)
    return total_length