| from utilities import add_postfix |
| from validate_args import input_tbl_valid, output_tbl_valid, cols_in_tbl_valid |
| |
| if __name__ != "__main__": |
| import plpy |
| |
| |
def _create_vocab_table(input_table, doc_id_col, word_vec_col, output_table):
    """ Build a vocabulary table out of the word arrays stored in a table.

    The created table has the form <wordid, word>: one row for every
    distinct element obtained by unnesting word_vec_col (cast to TEXT[])
    over all rows of input_table. wordid is an INTEGER starting at 0 and
    numbered in the sort order of word.

    Args:
        @param input_table: str, Name of the table containing the documents
        @param doc_id_col: str, Name of the document identifier column
                            (not used by this function)
        @param word_vec_col: str, Name of the column containing words in
                            document (must be castable to TEXT[])
        @param output_table: str, Name of the vocabulary table to create

    Returns: None
    """
    vocab_sql = """
        CREATE TABLE {output_table} AS
        SELECT (row_number() OVER (order by word))::INTEGER - 1 as wordid,
                word::TEXT
        FROM (
            SELECT distinct(words) as word
            FROM (
                SELECT unnest({word_vec_col}::TEXT[]) as words
                FROM {input_table}
            ) q1
        ) q2
        """
    # Only the placeholders present in the template are supplied.
    plpy.execute(vocab_sql.format(output_table=output_table,
                                  word_vec_col=word_vec_col,
                                  input_table=input_table))
| # ------------------------------------------------------------------------------ |
| |
| |
def _create_tf_table(input_table, doc_id_col, word_vec_col,
                     output_table, vocab_table=None):
    """ Create a term-frequency table (<docid, word, count>) from documents.

    The output table is created first and then filled with one row per
    (document, word) pair together with the number of times the word occurs
    in that document. Rows of the input whose doc_id_col is NULL are ignored.

    Args:
        @param input_table: str, The data table name to be converted
        @param doc_id_col: str, Name of the column containing document
                            identifier (the output column is declared INTEGER)
        @param word_vec_col: str, Name of the column containing words in
                            document (Column must be of a type that can be
                            cast to TEXT[])
        @param output_table: str, Name of the term-frequency table to create
        @param vocab_table: str, Table containing the dictionary/vocabulary
                            (This table is assumed to be of form:
                            <wordid, word>). When given, the output stores
                            the INTEGER wordid looked up by joining on word
                            instead of the TEXT word itself.

    Returns: None
    """
    if vocab_table:
        # A vocabulary is available: emit word ids, resolved through a join
        # of the per-document counts against the vocabulary table.
        word_name = 'wordid'
        word_type = 'INTEGER'
        word_select = "w.wordid"
        inner_query = """
            , {0} as w
            WHERE
                q2.word = w.word
            """.format(vocab_table)
    else:
        # No vocabulary: keep the raw words.
        word_name = 'word'
        word_type = 'TEXT'
        word_select = "word"
        inner_query = ''

    create_sql = """
        CREATE TABLE {output_table}(
            {doc_id_col} INTEGER,
            {word_name} {word_type},
            count INTEGER
        )
        """.format(output_table=output_table,
                   doc_id_col=doc_id_col,
                   word_name=word_name,
                   word_type=word_type)
    plpy.execute(create_sql)

    insert_sql = """
        INSERT INTO {output_table}
            SELECT {doc_id_col}, {word_select} as {word_name}, word_count as count
            FROM (
                SELECT {doc_id_col}, word::TEXT, count(*) as word_count
                FROM
                (
                    SELECT {doc_id_col}, unnest({word_vec_col}::TEXT[]) as word
                    FROM {input_table}
                    WHERE
                        {doc_id_col} IS NOT NULL
                ) q1
                GROUP BY {doc_id_col}, word
            ) q2
            {inner_query}
        """.format(output_table=output_table,
                   doc_id_col=doc_id_col,
                   word_select=word_select,
                   word_name=word_name,
                   word_vec_col=word_vec_col,
                   input_table=input_table,
                   inner_query=inner_query)
    plpy.execute(insert_sql)
| # ------------------------------------------------------------------------------ |
| |
| |
def term_frequency(input_table, doc_id_col, word_vec_col,
                   output_table, compute_vocab=False):
    """ Compute the term frequency of every word in every document.

    The arguments are first checked with input_tbl_valid, output_tbl_valid
    and cols_in_tbl_valid, and word_vec_col is probed with a one-row query
    to confirm it can be cast to TEXT[]. The term-frequency table is then
    written to output_table. When compute_vocab is true, a vocabulary table
    named add_postfix(output_table, "_vocabulary") is created first and the
    term-frequency table refers to words by their wordid.

    Args:
        @param input_table: str, Name of the table containing the documents
        @param doc_id_col: str, Name of the column containing document
                            identifier
        @param word_vec_col: str, Name of the column containing words in
                            document (must be castable to TEXT[])
        @param output_table: str, Name of the term-frequency table to create
        @param compute_vocab: bool, If true, also create the vocabulary
                            table and store word ids instead of words

    Returns: str, Message naming the output table and, if one was created,
             the vocabulary table

    Errors: plpy.error is called when word_vec_col cannot be cast to a
            TEXT array.
    """
    input_tbl_valid(input_table, "Term frequency")
    output_tbl_valid(output_table, "Term frequency")
    cols_in_tbl_valid(input_table, [doc_id_col, word_vec_col],
                      'Term frequency')
    try:
        plpy.execute("SELECT {0}::TEXT[] FROM {1} LIMIT 1".
                     format(word_vec_col, input_table))
    except Exception:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt/SystemExit are not swallowed and misreported
        # as a cast failure.
        plpy.error("Term frequency error: Word vector input cannot be "
                   "cast to a TEXT array")

    if compute_vocab:
        vocab_table = add_postfix(output_table, "_vocabulary")
        # The derived vocabulary table name is checked like any output table.
        output_tbl_valid(vocab_table, "Term frequency")
        _create_vocab_table(input_table, doc_id_col, word_vec_col, vocab_table)
    else:
        vocab_table = None

    _create_tf_table(input_table, doc_id_col, word_vec_col,
                     output_table, vocab_table=vocab_table)
    return_str = "Term frequency output in table {0}".format(output_table)
    if vocab_table:
        return_str += ", vocabulary in table {0}".format(vocab_table)
    return return_str
| # ------------------------------------------------------------------------------ |