| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """ |
| @file pathing.py_in |
| |
| @brief Pathing functions |
| |
| @namespace utilities |
| """ |
| import plpy |
| import shlex |
| import string |
| import re |
| |
| from control import MinWarning |
| from utilities import unique_string, _assert, add_postfix |
| from utilities import py_list_to_sql_string |
| from validate_args import get_cols |
| from validate_args import input_tbl_valid, output_tbl_valid, is_var_valid |
| from validate_args import quote_ident |
| # ------------------------------------------------------------------------ |
| |
| m4_changequote(`<!', `!>') |
| |
| |
| def path(schema_madlib, source_table, output_table, partition_expr, |
| order_expr, pattern_expr, symbol_expr, agg_func, |
| persist_rows, overlapping_patterns, **kwargs): |
| """ |
| Perform regular pattern matching over a sequence of rows. |
| |
| Args: |
| @param schema_madlib: str, Name of the MADlib schema |
| @param source_table: str, Name of the input table/view |
| @param output_table: str, Name of the table to store result |
| @param partition_expr: str, Expression to partition (group) the input data |
| @param order_expr: str, Expression to order the input data |
| @param pattern_expr: str, Expression to define the pattern to search for |
| @param symbol_expr: str, Definition for each symbol, comma-separated list |
| @param agg_func: str, List of the result functions/aggregates to apply on matched patterns |
| |
| """ |
| with MinWarning("error"): |
| # check for both false and None |
| if not partition_expr: |
| partition_expr = "1 = 1" |
| if not persist_rows: |
| # persist_rows = None implies no preference |
| # persist_rows = False implies do not store the matched rows |
| persist_rows = not bool(agg_func) |
| if not overlapping_patterns: |
| overlapping_patterns = False |
| |
| _validate(source_table, output_table, partition_expr, order_expr, |
| pattern_expr, symbol_expr, agg_func, persist_rows) |
| |
| new_pattern_expr, long_sym_str, short_sym_str = _parse_symbol_str(symbol_expr, pattern_expr) |
| |
| # build variables for intermediate objects |
| input_with_id = unique_string('input_with_id') |
| matched_view = unique_string('matched_view') |
| id_col_name = unique_string('id_col') |
| matched_partitions = unique_string('matched_partitions') |
| seq_gen = unique_string('seq_gen') |
| short_sym_name_str = unique_string('short_sym') |
| long_sym_name_str = unique_string('long_sym') |
| match_to_row_id = unique_string('match_to_row_id') |
| match_id = unique_string('match_id') |
| |
| all_input_cols = [i.strip() for i in get_cols(source_table)] |
| all_input_cols_str = ', '.join(all_input_cols) |
| if persist_rows: |
| matched_rows = add_postfix(output_table, "_tuples") |
| table_or_view = 'TABLE' |
| else: |
| matched_rows = unique_string('matched_rows') |
| table_or_view = 'VIEW' |
| |
| |
| # build a new input temp table that contains a sequence and partition columns |
| split_p_cols = [i.strip() for i in partition_expr.split(',')] |
| p_col_names = [unique_string() for i in split_p_cols] |
| p_col_as_str = ','.join( |
| [i + " AS " + j for i, j in zip(split_p_cols, p_col_names)]) |
| p_col_name_str = ', '.join(p_col_names) |
| distribution = m4_ifdef(<!__POSTGRESQL__!>, <!''!>, |
| <!"DISTRIBUTED BY ({0})".format(p_col_name_str)!>) |
| plpy.execute(""" |
| CREATE TEMP TABLE {input_with_id} AS |
| SELECT |
| {p_col_as_str}, |
| *, |
| row_number() OVER() AS {id_col_name}, |
| CASE |
| {short_sym_str} |
| END AS {short_sym_name_str}, |
| CASE |
| {long_sym_str} |
| END AS {long_sym_name_str} |
| FROM {source_table} |
| {distribution} |
| """.format(**locals())) |
| # Explanation for computing the path matches: |
| # Match is performed using regular expression pattern matching on a |
| # string produced by concatenating the symbols. The exact rows that |
| # produce the match are identified by correlating the matched string |
| # indices with another array containing row ids. |
| |
| match_id_name = "__madlib_path_match_id__" if "match_id" in all_input_cols else "match_id" |
| symbol_name = "__madlib_path_symbol__" if "symbol" in all_input_cols else "symbol" |
| |
| plpy.execute(""" |
| CREATE {table_or_view} {matched_rows} AS |
| SELECT {all_input_cols_str}, |
| {long_sym_name_str} AS {symbol_name}, |
| {match_id} AS {match_id_name} |
| FROM |
| {input_with_id} as source, |
| ( |
| SELECT |
| unnest((matched).id) as {match_id}, |
| unnest((matched).row_id) as {match_to_row_id} |
| FROM |
| ( |
| SELECT |
| {m}.path_pattern_match( |
| array_to_string(array_agg({short_sym_name_str} ORDER BY {order_expr}), '')::text, |
| '{new_pattern_expr}'::text, |
| array_agg({id_col_name} ORDER BY {order_expr})::float8[], |
| {overlapping_patterns}::boolean |
| ) as matched |
| FROM {input_with_id} |
| WHERE {short_sym_name_str} is NOT NULL |
| GROUP BY {p_col_name_str} |
| ) q |
| ) as matched_rows |
| WHERE source.{id_col_name} = matched_rows.{match_to_row_id} |
| """.format(m=schema_madlib, **locals())) |
| |
| quoted_split_p_cols = [quote_ident(i) for i in split_p_cols] |
| p_col_orig_name_str = ','.join( |
| [i + " AS " + j for i, j in zip(split_p_cols, quoted_split_p_cols)]) |
| if agg_func: |
| if partition_expr == '1 = 1': |
| # no partition |
| plpy.execute(""" |
| CREATE TABLE {output_table} AS |
| SELECT |
| {match_id_name}, |
| {agg_func} |
| FROM {matched_rows} |
| GROUP BY {match_id_name} |
| """.format(**locals())) |
| else: |
| plpy.execute(""" |
| CREATE TABLE {output_table} AS |
| SELECT |
| {p_col_orig_name_str}, |
| {match_id_name}, |
| {agg_func} |
| FROM {matched_rows} |
| GROUP BY {partition_expr}, {match_id_name} |
| """.format(**locals())) |
| result = "Aggregation result available in table " + output_table |
| else: |
| result = "No aggregation table created" |
| if not persist_rows: |
| plpy.execute("DROP VIEW IF EXISTS " + matched_rows) |
| else: |
| result += "\n Matched tuples can be found in table " + matched_rows |
| plpy.execute("DROP TABLE IF EXISTS " + input_with_id) |
| return result |
| # ------------------------------------------------------------------------------ |
| |
| |
def _validate(source_table, output_table, partition_expr, order_expr,
              pattern_expr, symbol_expr, agg_func, persist_rows):
    """ Check the arguments of the path function before any work is done.

    The checks run in a fixed order and the first failing one reports the
    error through the shared validation helpers.

    Args:
        @param source_table: str, Name of the input table/view
        @param output_table: str, Name of the table to store result
        @param partition_expr: str, Partition expression (must be non-empty)
        @param order_expr: str, Order expression (must be non-empty)
        @param pattern_expr: str, Pattern expression (must be non-empty)
        @param symbol_expr: str, Symbol definitions (must be non-empty)
        @param agg_func: str, Aggregate functions (not inspected here)
        @param persist_rows: bool, If True the "_tuples" companion of
                             output_table is validated as an output table too
    """
    module_name = 'Path'

    # table level checks
    input_tbl_valid(source_table, module_name)
    output_tbl_valid(output_table, module_name)
    if persist_rows:
        tuples_table = add_postfix(output_table, "_tuples")
        output_tbl_valid(tuples_table, module_name)

    # the expressions must be neither None nor empty strings ...
    _assert(partition_expr, "Path error: Invalid partition expression")
    _assert(order_expr, "Path error: Invalid order expression")
    # ... and the partition/order expressions must be usable on the source
    exprs_usable = is_var_valid(source_table, partition_expr, order_expr)
    _assert(exprs_usable,
            "Path error: invalid partition expression or order expression")

    remaining_checks = (
        (pattern_expr, "Path error: Invalid pattern expression"),
        (symbol_expr, "Path error: Invalid symbol expression"),
    )
    for expr_value, error_msg in remaining_checks:
        _assert(expr_value, error_msg)
| # ---------------------------------------------------------------------- |
| |
| |
def _parse_symbol_str(symbol_expr, pattern_expr):
    """ Parse symbol definitions to build the bodies of two CASE statements
    and rewrite the pattern expression with single-character symbols.

    Currently only single-character symbols are allowed.
    Postgresql regular expression match functions will be used on a string
    of symbols, where each symbol represents a tuple. Only single-length
    symbols can be used to maintain 1:1 correspondence between symbol and tuple.

    To allow multi-character symbols for user convenience, each input symbol
    is mapped to a single character (taken in definition order from a-z
    followed by 0-9) and every case-insensitive occurrence of the symbol in
    the pattern expression is replaced by that character.

    Further, the WHEN clauses of two case statements are built to mark each
    tuple in the input table with the corresponding original (user-supplied)
    symbol and the new (single-character) symbol. The surrounding CASE/END
    keywords are not part of the returned strings.

    Fragments of the symbol expression that do not have the form
    <symbol> := <symbol_definition> are ignored.

    Args:
        @param symbol_expr: str, A comma-separated string containing
            symbol definitions of the form: <symbol> := <symbol_definition>
        @param pattern_expr: str, The pattern expression where the original
            symbols are used. The original symbols are replaced by the
            corresponding new symbols.

    Example:
        symbol_expr = ("BEFORE:=start >= '0:00:00' and start < '9:30:00', "
                       "MARKET:=start >= '9:30:00' and start < '16:00:00'")
        pattern_expr = "(BEFORE)*(MARKET)*"

        returns ("(a)*(b)*",
                 "WHEN start >= '0:00:00' and start < '9:30:00' THEN 'BEFORE'::text
                  WHEN start >= '9:30:00' and start < '16:00:00' THEN 'MARKET'::text",
                 "WHEN start >= '0:00:00' and start < '9:30:00' THEN 'a'::text
                  WHEN start >= '9:30:00' and start < '16:00:00' THEN 'b'::text")
    Returns:
        (str, str, str): the rewritten pattern expression, the WHEN clauses
        using the original symbols, the WHEN clauses using the new symbols

    Raises:
        An error is reported through plpy.error / _assert if a symbol has an
        empty name or definition, if a symbol is defined more than once
        (ignoring case), if more symbols are defined than single characters
        are available, or if no symbol definition is found at all.
    """
    # all_symbols is all valid single-character symbols
    all_symbols = iter(string.ascii_lowercase + string.digits)

    # Split on commas only: every printable character except quotes and the
    # comma itself is treated as part of a word.
    symbol_expr_parser = shlex.shlex(symbol_expr)
    symbol_expr_parser.wordchars = [i for i in string.printable
                                    if i not in (symbol_expr_parser.quotes + ",")]
    symbol_expr_parser.whitespace = ','
    # parse symbol expr to get the strings between commas
    sym_def_parsed = list(symbol_expr_parser)

    orig_symbols_ordered = []
    orig_sym_definitions = {}
    # lower-cased original symbol -> single-character symbol
    old_to_new = {}
    for each_sym_def in sym_def_parsed:
        # symbols are defined as a pair: <name> := <definition>
        sym_def_split = each_sym_def.split(":=")
        if len(sym_def_split) != 2:
            continue
        orig_sym, sym_def = (i.strip() for i in sym_def_split)
        # An empty name would make the substitution regex match the empty
        # string at every position of the pattern expression.
        _assert(bool(orig_sym) and bool(sym_def),
                "Path error: Invalid symbol expression")
        try:
            # next() works on both Python 2 and 3 iterators
            next_sym = next(all_symbols)
        except StopIteration:
            plpy.error("Path error: Total symbols in the symbol expression "
                       "exceed maximum number of symbols allowed.")
        # symbols are supposed to be case-insensitive. Use the lower-case
        # version to maintain a mapping from original to new symbol name
        orig_sym_lower = orig_sym.lower()
        _assert(orig_sym_lower not in old_to_new,
                "Path error: Multipe definitions of a symbol")
        orig_symbols_ordered.append(orig_sym)
        old_to_new[orig_sym_lower] = next_sym
        orig_sym_definitions[orig_sym] = sym_def

    # Without any definition the substitution regex would be empty and the
    # replacement callback would fail with an unhelpful KeyError.
    _assert(bool(old_to_new), "Path error: Invalid symbol expression")

    # replace each occurence of the original symbol with the new
    # perform this operation in descending order of length to avoid substituting
    # subset of any symbol
    old_symbols_desc = sorted(old_to_new, key=len, reverse=True)
    replace_pattern = re.compile(
        '|'.join(re.escape(sym) for sym in old_symbols_desc), re.IGNORECASE)
    new_pattern_expr = replace_pattern.sub(
        lambda m: old_to_new[m.group(0).lower()],
        pattern_expr)

    # build a case statement to search a tuple for each definition and pick the
    # appropriate symbol.
    case_stmt = "WHEN {d} THEN '{s}'::text"
    orig_sym_case_stmt = [
        case_stmt.format(s=k, d=orig_sym_definitions[k])
        for k in orig_symbols_ordered]
    new_sym_case_stmt = [
        case_stmt.format(s=old_to_new[k.lower()], d=orig_sym_definitions[k])
        for k in orig_symbols_ordered]
    return (new_pattern_expr,
            '\n'.join(orig_sym_case_stmt),
            '\n'.join(new_sym_case_stmt))
| # ---------------------------------------------------------------------- |
| |
| |
def path_help_message(schema_madlib, message, **kwargs):
    """ Return the help text of the path function.

    Args:
        @param schema_madlib: str, Name of the MADlib schema, substituted
                              into the example statements
        @param message: str, Requested help topic. An empty/None value yields
                        the summary; 'usage', 'help' or '?' (in any letter
                        case) yields the usage text; anything else yields a
                        "No such option" notice

    Returns:
        str, the selected help text
    """
    if not message:
        help_template = """
---------------------------------------------------------------------------
SUMMARY
---------------------------------------------------------------------------
Functionality: Path

The goal of the MADlib path function is to perform regular pattern matching
over a sequence of rows, and to extract useful information about the matches.
The useful information could be a simple count of matches or something more
involved like aggregation.

For more details on function usage:
SELECT {schema_madlib}.path('usage');
"""
    elif message.lower() in ('usage', 'help', '?'):
        help_template = """
---------------------------------------------------------------------------
USAGE
---------------------------------------------------------------------------
SELECT {schema_madlib}.path(
'source_table', -- Name of the table
'output_table', -- Table name to store the path results
'partition_expr', -- Partition expression to group the data table
'order_expr', -- Order expression to sort the tuples of the data table
'symbol_def', -- Definition of various symbols used in the pattern definition
'pattern_def', -- Definition of the path pattern to search for
'agg_func', -- Aggregate/window functions to be applied on the matched paths
persist_rows, -- Boolean indicating whether to output the matched
-- rows in an additional table (named <output_table>_tuples)
overlapping_patterns -- Boolean indicating whether to find every
-- overlapping occurrence of the pattern in the partition
);
"""
    else:
        help_template = """
No such option. Use "SELECT {schema_madlib}.path()" for help.
"""
    # every template refers to the schema name only
    return help_template.format(schema_madlib=schema_madlib)
| # ------------------------------------------------------------ |