| """ |
| @file lda.py_in |
| |
| @brief Latent Dirichlet Allocation inference using collapsed Gibbs |
| sampling algorithm |
| |
| @namespace lda |
| |
| LDA: Driver and auxiliary functions |
| """ |
| |
| import plpy |
| # import time |
| |
| # use mad_vec to process arrays passed as strings in GPDB < 4.1 and PG < 9.0 |
| from utilities.control import OptimizerControl |
| from utilities.control import HashaggControl |
| from utilities.utilities import __mad_version, _assert, warn |
| from utilities.validate_args import output_tbl_valid |
| from utilities.validate_args import input_tbl_valid |
| from utilities.utilities import py_list_to_sql_string |
| |
# Array <-> string converters from mad_vec, needed to process arrays that are
# passed as strings in GPDB < 4.1 and PG < 9.0.
version_wrapper = __mad_version()
string_to_array, array_to_string = (version_wrapper.select_vecfunc(),
                                    version_wrapper.select_vec_return())
| |
| |
class LDATrainer:

    """
    @brief This class defines a LDATrainer.

    Trains an LDA model with collapsed Gibbs sampling. Two temporary work
    tables are used alternately: odd iterations read work_table_0 and write
    work_table_1, even iterations do the reverse.
    """

    def __init__(self, schema_madlib, data_table, model_table,
                 output_data_table, voc_size, topic_num,
                 iter_num, alpha, beta, evaluate_every, perplexity_tol):
        """
        @brief Store the training parameters and create the work, model and
        output data tables.
        @param schema_madlib MADlib schema name
        @param data_table Converted corpus table with columns
                          docid, wordcount, words, counts
        @param model_table Name of the model table to create
        @param output_data_table Name of the output data table to create
        @param voc_size Size of vocabulary
        @param topic_num Number of topics
        @param iter_num Maximum number of Gibbs sampling iterations
        @param alpha Dirichlet parameter for per-document topic multinomial
        @param beta Dirichlet parameter for per-topic word multinomial
        @param evaluate_every Compute perplexity every evaluate_every
                              iterations; values <= 0 disable it
        @param perplexity_tol Stop early once two consecutive perplexity
                              values differ by less than this
        """
        self.schema_madlib = schema_madlib
        self.data_table = data_table
        self.voc_size = voc_size
        self.topic_num = topic_num
        self.iter_num = iter_num
        self.alpha = alpha
        self.beta = beta
        self.model_table = model_table
        self.output_data_table = output_data_table
        self.work_table_0 = '__work_table_train_0__'
        self.work_table_1 = '__work_table_train_1__'
        self.evaluate_every = evaluate_every
        self.perplexity_tol = perplexity_tol
        self.perplexity = []
        # Starts equal to the tolerance so the first "< tol" check never stops
        self.perplexity_diff = self.perplexity_tol
        self.perplexity_iters = []
        self.tol_reached = False
        self.num_iterations = 0

        # Both work tables share one schema; recreate them from scratch.
        for work_table in (self.work_table_0, self.work_table_1):
            plpy.execute("DROP TABLE IF EXISTS " + work_table)
            plpy.execute("""
                CREATE TEMP TABLE {work_table}(
                    docid INT4,
                    wordcount INT4,
                    words INT4[],
                    counts INT4[],
                    doc_topic INT4[]
                )
                m4_ifdef(`__POSTGRESQL__', `',
                    `WITH (APPENDONLY=TRUE)
                    DISTRIBUTED BY (docid)')
                """.format(work_table=work_table))

        plpy.execute("""
            CREATE TABLE {model_table}(
                voc_size INT4,
                topic_num INT4,
                alpha FLOAT8,
                beta FLOAT8,
                model INT8[],
                num_iterations INT,
                perplexity DOUBLE PRECISION[],
                perplexity_iters INTEGER[]
            )
            m4_ifdef(`__POSTGRESQL__', `',
                `WITH (APPENDONLY=TRUE)
                DISTRIBUTED RANDOMLY')
            """.format(model_table=self.model_table))

        plpy.execute("""
            CREATE TABLE {output_data_table}(
                docid INT4,
                wordcount INT4,
                words INT4[],
                counts INT4[],
                topic_count INT4[],
                topic_assignment INT4[]
            )
            m4_ifdef(`__POSTGRESQL__', `',
                `WITH (APPENDONLY=TRUE)
                DISTRIBUTED BY (docid)')
            """.format(output_data_table=self.output_data_table))

    def init_random(self):
        """
        @brief Fill work_table_0 with the corpus plus an initial random topic
        assignment produced by __lda_random_assign.
        """
        plpy.execute('TRUNCATE TABLE ' + self.work_table_0)
        plpy.execute("""
            INSERT INTO {work_table}
            SELECT
                docid, wordcount, words, counts,
                {schema_madlib}.__lda_random_assign(wordcount, {topic_num}) AS topics
            FROM {data_table}
            """.format(work_table=self.work_table_0,
                       schema_madlib=self.schema_madlib,
                       topic_num=self.topic_num,
                       data_table=self.data_table))

    def _final_work_table(self):
        """
        @brief Return the work table that holds the output of iteration
        number self.num_iterations.

        Iteration k writes work_table_1 when k is odd and work_table_0 when k
        is even (see iteration()). Using num_iterations instead of iter_num
        keeps the final model/output tables consistent with the reported
        iteration count and the last stored perplexity when training stops
        early because perplexity_tol was reached.
        """
        if self.num_iterations % 2 == 0:
            return self.work_table_0
        return self.work_table_1

    def gen_final_data_tables(self):
        """
        @brief Write the final model table and output data table.

        JIRA: MADLIB-1201 - the model table is rebuilt one more time after the
        iterations so that it is in sync with the output table.
        JIRA: MADLIB-1351 - if perplexity evaluation is enabled and the
        tolerance was not reached, the perplexity of the final state is
        computed after both tables are written and stored in the model table.
        """
        work_table_final = self._final_work_table()

        self.update_model_table(work_table_final)
        self.gen_output_data_table(work_table_final)

        if self.evaluate_every > 0 and not self.tol_reached:
            self.perplexity.append(
                get_perplexity(self.schema_madlib,
                               self.model_table,
                               self.output_data_table))
            # Rewrite the model table so it also carries the perplexity
            # value computed just above.
            self.update_model_table(work_table_final)

    def iteration(self, it):
        """
        @brief Run one Gibbs sampling sweep over the corpus.
        @param it 1-based iteration number; odd iterations read work_table_0
                  and write work_table_1, even iterations do the reverse.
        """
        if it % 2 == 0:
            work_table_in, work_table_out = self.work_table_1, self.work_table_0
        else:
            work_table_in, work_table_out = self.work_table_0, self.work_table_1

        # Rebuild model_table from the previous iteration's output
        self.update_model_table(work_table_in)

        # For GPDB 4.3 and higher we disable the optimizer (ORCA) for the
        # query planner since currently ORCA doesn't support InitPlan. This
        # would have to be fixed when ORCA is the only available planner.
        with OptimizerControl(False):
            plpy.execute('TRUNCATE TABLE ' + work_table_out)
            plpy.execute("""
                INSERT INTO {work_table_out}
                SELECT
                    docid, wordcount, words, counts,
                    {schema_madlib}.__lda_gibbs_sample(
                        words, counts, doc_topic,
                        (SELECT model FROM {model_table}),
                        {alpha}, {beta}, {voc_size}, {topic_num}, 1)
                FROM
                    {work_table_in}
                """.format(work_table_out=work_table_out,
                           schema_madlib=self.schema_madlib,
                           model_table=self.model_table,
                           alpha=self.alpha,
                           beta=self.beta,
                           voc_size=self.voc_size,
                           topic_num=self.topic_num,
                           work_table_in=work_table_in))

        self.calculatePerplexity(it, work_table_in)

    def update_model_table(self, work_table_in):
        """
        @brief Rebuild the single-row model table by aggregating the topic
        assignments stored in a work table.

        The row also stores num_iterations and, once at least one perplexity
        value exists (JIRA: MADLIB-1351), the perplexity values together
        with the iterations they were computed for.
        @param work_table_in Work table whose doc_topic column is aggregated
        """
        # Optional trailing select-list items; left empty (the columns stay
        # NULL) until a perplexity value has been computed.
        perplexity_values = ""
        perplexity_iterations = ""
        if self.perplexity:
            perplexity_values = ", {0}".format(
                py_list_to_sql_string(self.perplexity))
            perplexity_iterations = ", {0}".format(
                py_list_to_sql_string(self.perplexity_iters))
        n_iterations = ", {0}".format(self.num_iterations)

        plpy.execute('TRUNCATE TABLE ' + self.model_table)
        if version_wrapper.is_gp43():
            with OptimizerControl(True):
                plpy.execute("""
                    INSERT INTO {model_table}
                    SELECT
                        {voc_size},
                        {topic_num},
                        {alpha},
                        {beta},
                        {schema_madlib}.__lda_count_topic_agg(
                            words,
                            counts,
                            doc_topic[{topic_num} + 1:array_upper(doc_topic, 1)],
                            {voc_size},
                            {topic_num}
                        ) AS model
                        {n_iterations}{perplexity_values} {perplexity_iterations}
                    FROM {work_table_in}
                    """.format(model_table=self.model_table,
                               topic_num=self.topic_num,
                               voc_size=self.voc_size,
                               alpha=self.alpha,
                               beta=self.beta,
                               perplexity_values=perplexity_values,
                               perplexity_iterations=perplexity_iterations,
                               schema_madlib=self.schema_madlib,
                               work_table_in=work_table_in,
                               n_iterations=n_iterations))
        else:
            # Work around insertion memory error (MPP-25561) by copying the
            # model to Python temporarily.
            model = plpy.execute("""
                SELECT
                    {schema_madlib}.__lda_count_topic_agg(
                        words,
                        counts,
                        doc_topic[{topic_num} + 1:array_upper(doc_topic, 1)],
                        {voc_size},
                        {topic_num}
                    ) AS model
                FROM {work_table_in}
                """.format(schema_madlib=self.schema_madlib,
                           topic_num=self.topic_num,
                           voc_size=self.voc_size,
                           work_table_in=work_table_in))[0]['model']
            # ... and insert it back immediately
            plan = plpy.prepare("""
                INSERT INTO {model_table}
                SELECT
                    {voc_size}, {topic_num}, {alpha}, {beta}, $1
                    {n_iterations}{perplexity_values} {perplexity_iterations}
                """.format(model_table=self.model_table,
                           topic_num=self.topic_num,
                           voc_size=self.voc_size,
                           alpha=self.alpha,
                           beta=self.beta,
                           perplexity_values=perplexity_values,
                           perplexity_iterations=perplexity_iterations,
                           n_iterations=n_iterations),
                ['bigint[]'])
            plpy.execute(plan, [model])

    def run(self):
        """
        @brief Train the model: random initialization, up to iter_num Gibbs
        sampling iterations (with optional perplexity-based early stopping),
        then generation of the final model and output data tables.
        """
        self.init_random()
        for it in range(1, self.iter_num + 1):
            # JIRA: MADLIB-1351
            # Stop iterating once the perplexity difference is below
            # perplexity_tol.
            if self.perplexity_diff < self.perplexity_tol:
                self.tol_reached = True
                # The perplexity that triggered the stop belongs to the
                # previous iteration (perplexity_iters runs one iteration
                # behind), so report one iteration less.
                self.num_iterations -= 1
                break

            self.iteration(it)
            self.num_iterations += 1

        # JIRA: MADLIB-1351
        # Record the last iteration; its perplexity is computed in
        # gen_final_data_tables().
        if self.evaluate_every > 0 and not self.tol_reached:
            self.perplexity_iters.append(self.iter_num)
        self.gen_final_data_tables()

    def gen_output_data_table(self, work_table_final):
        """
        @brief Rewrite the output data table from a work table, splitting
        doc_topic into the per-document topic counts (first topic_num
        entries) and the per-word topic assignments (the rest).
        @param work_table_final Work table to read from
        """
        plpy.execute("TRUNCATE TABLE " + self.output_data_table)
        plpy.execute("""
            INSERT INTO {output_data_table}
            SELECT
                docid, wordcount, words, counts, doc_topic[1:{topic_num}] topic_count,
                doc_topic[{topic_num} + 1:array_upper(doc_topic,1)] topic_assignment
            FROM
                {work_table_final}
            """.format(output_data_table=self.output_data_table,
                       topic_num=self.topic_num,
                       work_table_final=work_table_final))

    def calculatePerplexity(self, it, work_table_in):
        """
        @brief Compute perplexity every evaluate_every iterations
        (JIRA: MADLIB-1351).

        Within iteration `it` the model table is rebuilt from the previous
        iteration's output (work_table_in) before sampling, so the value
        computed here describes iteration it - 1 and is recorded as such.
        For the first iteration that state is just the random
        initialization, so the calculation is skipped until
        it > evaluate_every.
        @param it Current 1-based iteration number
        @param work_table_in Work table read by the current iteration
        """
        if it > self.evaluate_every and self.evaluate_every > 0 and (
                it - 1) % self.evaluate_every == 0:
            self.gen_output_data_table(work_table_in)
            perplexity = get_perplexity(self.schema_madlib,
                                        self.model_table,
                                        self.output_data_table)
            if self.perplexity:
                self.perplexity_diff = abs(self.perplexity[-1] - perplexity)
            self.perplexity_iters.append(it - 1)
            self.perplexity.append(perplexity)
| # ------------------------------------------------------------------------------ |
| |
| |
class LDAPredictor:

    """
    @brief This class defines a LDAPredictor.

    Infers topic assignments for new documents with a trained model: the
    documents are randomly initialized into one work table, then Gibbs
    sampling with the fixed model writes the result into a second one.
    """

    # Temporary work tables used during inference (both follow the
    # __work_table_pred_N__ naming scheme).
    _WORK_TABLE_0 = '__work_table_pred_0__'
    _WORK_TABLE_1 = '__work_table_pred_1__'

    def __init__(self, schema_madlib, test_table, model_table, output_table,
                 iter_num):
        """
        @brief Read the model parameters and create the work and output
        tables.
        @param schema_madlib MADlib schema name
        @param test_table Converted corpus table with columns
                          docid, wordcount, words, counts
        @param model_table Model table with voc_size, topic_num, alpha, beta
                           and model columns
        @param output_table Name of the output table to create
        @param iter_num Number of Gibbs sampling iterations
        """
        self.schema_madlib = schema_madlib
        self.test_table = test_table
        self.model_table = model_table
        self.iter_num = iter_num

        rv = plpy.execute("""
            SELECT
                voc_size, topic_num, alpha, beta
            FROM
                {model_table}
            """.format(model_table=self.model_table))
        self.voc_size = rv[0]['voc_size']
        self.topic_num = rv[0]['topic_num']
        self.alpha = rv[0]['alpha']
        self.beta = rv[0]['beta']

        self.doc_topic = output_table
        self.work_table_0 = self._WORK_TABLE_0
        self.work_table_1 = self._WORK_TABLE_1

        # Both work tables share one schema; recreate them from scratch.
        for work_table in (self.work_table_0, self.work_table_1):
            plpy.execute("DROP TABLE IF EXISTS " + work_table)
            plpy.execute("""
                CREATE TEMP TABLE {work_table}(
                    docid INT4,
                    wordcount INT4,
                    words INT4[],
                    counts INT4[],
                    doc_topic INT4[]
                )
                m4_ifdef(`__POSTGRESQL__', `',
                    `WITH (APPENDONLY=TRUE)
                    DISTRIBUTED BY (docid)')
                """.format(work_table=work_table))

        plpy.execute("""
            CREATE TABLE {doc_topic}(
                docid INT4,
                wordcount INT4,
                words INT4[],
                counts INT4[],
                topic_count INT4[],
                topic_assignment INT4[]
            )
            m4_ifdef(`__POSTGRESQL__', `',
                `WITH (APPENDONLY=TRUE)
                DISTRIBUTED BY (docid)')
            """.format(doc_topic=self.doc_topic))

    def init_random(self):
        """
        @brief Fill work_table_0 with the test corpus plus an initial random
        topic assignment produced by __lda_random_assign.
        """
        plpy.execute('TRUNCATE TABLE ' + self.work_table_0)
        plpy.execute("""
            INSERT INTO {work_table}
            SELECT
                docid, wordcount, words, counts,
                {schema_madlib}.__lda_random_assign(wordcount, {topic_num}) AS topics
            FROM {data_table}
            """.format(work_table=self.work_table_0,
                       schema_madlib=self.schema_madlib,
                       topic_num=self.topic_num,
                       data_table=self.test_table))

    def gen_output_table(self):
        """
        @brief Rewrite the output table from work_table_1, splitting doc_topic
        into topic counts (first topic_num entries) and topic assignments.
        """
        plpy.execute("TRUNCATE TABLE " + self.doc_topic)
        plpy.execute("""
            INSERT INTO {doc_topic}
            SELECT
                docid, wordcount, words, counts, doc_topic[1:{topic_num}] topic_count,
                doc_topic[{topic_num} + 1:array_upper(doc_topic,1)] topic_assignment
            FROM {work_table_out}
            """.format(doc_topic=self.doc_topic,
                       topic_num=self.topic_num,
                       work_table_out=self.work_table_1))

    def infer(self):
        """
        @brief Run iter_num Gibbs sampling iterations against the fixed model,
        reading work_table_0 and writing work_table_1.
        """
        # For GPDB 4.3 and higher we disable the optimizer (ORCA) for the
        # query planner since currently ORCA doesn't support InitPlan. This
        # would have to be fixed when ORCA is the only available planner.
        with OptimizerControl(False):
            # Start from an empty output so repeated runs don't duplicate rows
            plpy.execute('TRUNCATE TABLE ' + self.work_table_1)
            plpy.execute("""
                INSERT INTO {work_table_out}
                SELECT
                    docid, wordcount, words, counts,
                    {schema_madlib}.__lda_gibbs_sample(
                        words, counts, doc_topic,
                        (SELECT model FROM {model_table}),
                        {alpha}, {beta},
                        {voc_size}, {topic_num}, {iter_num}
                    )
                FROM
                    {work_table_in}
                """.format(work_table_out=self.work_table_1,
                           work_table_in=self.work_table_0,
                           model_table=self.model_table,
                           schema_madlib=self.schema_madlib,
                           alpha=self.alpha,
                           beta=self.beta,
                           voc_size=self.voc_size,
                           topic_num=self.topic_num,
                           iter_num=self.iter_num))

    def run(self):
        """
        @brief Run the whole prediction: random initialization, inference,
        output table generation.
        """
        self.init_random()
        self.infer()
        self.gen_output_table()
| # ------------------------------------------------------------------------------ |
| |
| |
def lda_train(schema_madlib, train_table, model_table, output_data_table, voc_size,
              topic_num, iter_num, alpha, beta, evaluate_every, perplexity_tol):
    """
    @brief This function provides the entry for the LDA training process.
    @param schema_madlib MADlib schema
    @param train_table Training data table with columns docid, wordid, count
    @param model_table Learned model table (created by this function)
    @param output_data_table Output data table (created by this function)
    @param voc_size Size of vocabulary
    @param topic_num Number of topics
    @param iter_num Maximum number of iterations
    @param alpha Dirichlet parameter for per-document topic multinomial
    @param beta Dirichlet parameter for per-topic word multinomial
    @param evaluate_every Compute perplexity every evaluate_every iterations;
                          NULL means 0, and values <= 0 disable perplexity
    @param perplexity_tol Stop training once two consecutive perplexity
                          values differ by less than this; NULL means 0.1
    """
    _assert(train_table is not None and train_table.strip() != '',
            'invalid argument: train_table is not specified')
    _assert(model_table is not None and model_table.strip() != '',
            'invalid argument: model_table is not specified')
    _assert(output_data_table is not None and output_data_table.strip() != '',
            'invalid argument: output_data_table is not specified')
    _assert(voc_size is not None and voc_size > 0,
            'invalid argument: positive integer expected for voc_size')
    _assert(topic_num is not None and topic_num > 0,
            'invalid argument: positive integer expected for topic_num')
    _assert(iter_num is not None and iter_num > 0,
            'invalid argument: positive integer expected for iter_num')
    _assert(alpha is not None and alpha > 0,
            'invalid argument: positive real expected for alpha')
    _assert(beta is not None and beta > 0,
            'invalid argument: positive real expected for beta')

    # Default values for the optional perplexity settings
    if perplexity_tol is None:
        perplexity_tol = 0.1
    if evaluate_every is None:
        evaluate_every = 0
    _assert(evaluate_every <= iter_num,
            'invalid argument: evaluate_every should not be greater than iter_num')
    _assert(perplexity_tol >= 0,
            'invalid argument: perplexity_tol should not be less than 0')

    output_tbl_valid(model_table, 'LDA')
    output_tbl_valid(output_data_table, 'LDA')

    warn(voc_size <= 1e5,
         """the voc_size is very large: %d - make sure that the system has
         enough memory or reduce the vocabulary size""" % (voc_size))
    warn(topic_num <= 1e3,
         """the topic_num is large: %d - make sure that the system has enough
         memory or reduce the number of topics """ % (topic_num))

    _validate_data_table(train_table, voc_size)
    convt_table = _convert_data_table(schema_madlib, train_table)
    lt = LDATrainer(schema_madlib, convt_table, model_table,
                    output_data_table, voc_size, topic_num,
                    iter_num, alpha, beta, evaluate_every, perplexity_tol)

    lt.run()

    # __lda_check_count_ceiling returns NULL when the count ceiling is
    # not hit, otherwise a sample of the wordids that hit it
    examples_hit_ceiling = plpy.execute("""
        SELECT
            array_to_string(
                {schema_madlib}.__lda_check_count_ceiling(
                    model,
                    voc_size,
                    topic_num)::integer[],
                ','
            ) AS examples_hit_ceiling
        FROM {model_table}
        """.format(schema_madlib=schema_madlib,
                   model_table=model_table))[0]['examples_hit_ceiling']

    if examples_hit_ceiling is not None:
        plpy.info("""
            LDA warning: some words ({examples_hit_ceiling}, etc.) occur too frequently.
            More than 2e9 times. It might add noise for the LDA model.
            Please remove the very frequent words in the data.
            """.format(examples_hit_ceiling=examples_hit_ceiling))
| |
| |
def lda_predict(schema_madlib, test_table, model_table, output_data_table,
                iter_num=20):
    """
    @brief This function provides the entry for the LDA prediction process.
    @param schema_madlib MADlib schema
    @param test_table Name of the testing dataset table (docid, wordid, count)
    @param model_table Name of the model table
    @param output_data_table Name of the output table (created by this function)
    @param iter_num Number of Gibbs sampling iterations; NULL means 20 and
                    0 is accepted
    """
    input_tbl_valid(test_table, 'LDA')
    input_tbl_valid(model_table, 'LDA')
    output_tbl_valid(output_data_table, 'LDA')

    iter_num = 20 if iter_num is None else iter_num
    # 0 is allowed, so the message must not say "positive"
    _assert(
        iter_num >= 0,
        'invalid argument: non-negative integer expected for iter_num')
    warn(
        iter_num <= 20,
        """the iter_num is large: %d - a smaller iter_num (e.g. 20) should be
        good enough""" % (iter_num))

    _validate_model_table(model_table)
    rv = plpy.execute('SELECT voc_size FROM ' + model_table)
    voc_size = rv[0]['voc_size']
    _validate_data_table(test_table, voc_size)

    convt_table = _convert_data_table(schema_madlib, test_table)
    lp = LDAPredictor(
        schema_madlib, convt_table, model_table, output_data_table, iter_num)

    lp.run()
| |
| |
def get_topic_desc(schema_madlib, model_table, vocab_table, desc_table,
                   top_k=15):
    """
    @brief Get the per-topic description by top-k words
    @param schema_madlib MADlib schema name
    @param model_table The model table generated by the training process
    @param vocab_table The vocabulary table (wordid, word)
    @param desc_table The output table for storing the per-topic word
                      description (topicid, wordid, prob, word)
    @param top_k The top k words for topic description (ties share a rank,
                 so a topic may get more than k rows)
    """
    # Reject NULL and blank names up front, consistent with lda_train
    _assert(all(name is not None and name.strip() != ''
                for name in (model_table, vocab_table, desc_table)),
            "invalid argument: at least one of the table names is not specified")

    _assert(top_k is not None and top_k > 0,
            "invalid argument: Positive integer expected for top_k")

    _validate_model_table(model_table)
    _validate_vocab_table(vocab_table)

    output_tbl_valid(desc_table, 'LDA')

    # Per-topic word counts unpacked from the model
    plpy.execute("DROP TABLE IF EXISTS __lda_topic_word_count__")
    plpy.execute("""
        CREATE TEMP TABLE __lda_topic_word_count__(
            topicid INT4,
            word_count INT4[],
            beta FLOAT8
        )
        m4_ifdef(
            `__POSTGRESQL__', `',
            `WITH (APPENDONLY=TRUE) DISTRIBUTED BY
            (topicid)')
        """)

    plpy.execute("""
        INSERT INTO __lda_topic_word_count__
        SELECT
            generate_series(0, topic_num - 1) AS topicid,
            {schema_madlib}.__lda_util_unnest_transpose(
                model,
                voc_size,
                topic_num
            ) AS word_count,
            beta
        FROM {model_table}
        """.format(
            schema_madlib=schema_madlib,
            model_table=model_table))

    # Smoothed per-topic word distributions
    plpy.execute("DROP TABLE IF EXISTS __lda_topic_word_dist__")
    plpy.execute("""
        CREATE TEMP TABLE __lda_topic_word_dist__(
            topicid INT4,
            word_dist FLOAT[]
        )
        m4_ifdef(
            `__POSTGRESQL__', `',
            `WITH (APPENDONLY=TRUE)
            DISTRIBUTED BY (topicid)')
        """)

    plpy.execute("""
        INSERT INTO __lda_topic_word_dist__
        SELECT
            topicid,
            {schema_madlib}.__lda_util_norm_with_smoothing(word_count, beta) dist
        FROM
            __lda_topic_word_count__
        """.format(schema_madlib=schema_madlib))

    plpy.execute("""
        CREATE TABLE {desc_table} (
            topicid INT4,
            wordid INT4,
            prob FLOAT,
            word TEXT)
        m4_ifdef(
            `__POSTGRESQL__', `',
            `WITH (APPENDONLY=TRUE)
            DISTRIBUTED BY (topicid)')
        """.format(desc_table=desc_table))

    # Rank words per topic by probability and keep the top_k, joined to
    # their text from the vocabulary
    plpy.execute("""
        INSERT INTO {desc_table}
        SELECT
            topicid, t2.wordid, prob, word
        FROM
        (
            SELECT
                topicid, wordid, prob,
                rank() OVER(PARTITION BY topicid ORDER BY prob DESC) r
            FROM
            (
                SELECT
                    topicid,
                    generate_series(0, array_upper(word_dist, 1) - 1) wordid,
                    unnest(word_dist) prob
                FROM
                    __lda_topic_word_dist__
            ) t1
        ) t2, {vocab_table} AS vocab
        WHERE t2.r <= {top_k} AND t2.wordid = vocab.wordid
        """.format(desc_table=desc_table,
                   vocab_table=vocab_table, top_k=top_k))
| |
| |
def get_topic_word_count(schema_madlib, model_table, output_table):
    """
    @brief Materialize the per-topic word counts stored in an LDA model into
    output_table(topicid, word_count), with topicid running from 0 to
    topic_num - 1.
    @param schema_madlib MADlib schema name
    @param model_table The model table generated by the training process
    @param output_table The output table for storing the per-topic word counts
    """
    _assert(model_table != '',
            'invalid argument: model table name is not specified')
    _validate_model_table(model_table)
    output_tbl_valid(output_table, 'LDA')

    create_sql = """
        CREATE TABLE {output_table} (
            topicid INT4,
            word_count INT4[])
        m4_ifdef(
            `__POSTGRESQL__', `',
            `WITH (APPENDONLY=TRUE)
            DISTRIBUTED BY (topicid)')
        """
    plpy.execute(create_sql.format(output_table=output_table))

    fill_sql = """
        INSERT INTO {output_table}
        SELECT
            generate_series(0, topic_num - 1) topicid,
            {schema_madlib}.__lda_util_unnest_transpose(
                model,
                voc_size,
                topic_num
            ) AS word_count
        FROM
            {model_table}
        """
    plpy.execute(fill_sql.format(output_table=output_table,
                                 schema_madlib=schema_madlib,
                                 model_table=model_table))
| |
| |
def get_word_topic_count(schema_madlib, model_table, output_table):
    """
    @brief Materialize the per-word topic counts stored in an LDA model into
    output_table(wordid, topic_count), with wordid running from 0 to
    voc_size - 1.
    @param schema_madlib MADlib schema name
    @param model_table The model table generated by the training process
    @param output_table The output table for storing the per-word topic counts
    """
    _assert(model_table != '',
            "invalid argument: model table name is not specified")
    _validate_model_table(model_table)
    output_tbl_valid(output_table, 'LDA')

    create_sql = """
        CREATE TABLE {output_table} (
            wordid INT4,
            topic_count INT4[])
        m4_ifdef(
            `__POSTGRESQL__', `',
            `WITH (APPENDONLY=TRUE)
            DISTRIBUTED BY (wordid)')
        """
    plpy.execute(create_sql.format(output_table=output_table))

    fill_sql = """
        INSERT INTO {output_table}
        SELECT
            generate_series(0, voc_size - 1) wordid,
            {schema_madlib}.__lda_util_unnest(model, voc_size, topic_num) word_count
        FROM {model_table}
        """
    plpy.execute(fill_sql.format(output_table=output_table,
                                 schema_madlib=schema_madlib,
                                 model_table=model_table))
| |
def get_word_topic_mapping(schema_madlib, lda_output_table, mapping_table):
    """
    @brief Build the (docid, wordid, topicid) mapping from an LDA output
    table.
    @param schema_madlib MADlib schema name
    @param lda_output_table The output table from LDA training or predicting
    @param mapping_table The result table that saves the mapping info
    """
    _assert(lda_output_table != '',
            "invalid argument: LDA output table name is not specified")
    output_tbl_valid(mapping_table, 'LDA')

    create_sql = """
        CREATE TABLE {mapping_table} (
            docid INT4,
            wordid INT4,
            topicid INT4)
        m4_ifdef(
            `__POSTGRESQL__', `',
            `WITH (APPENDONLY=TRUE)
            DISTRIBUTED BY (docid)')
        """
    plpy.execute(create_sql.format(mapping_table=mapping_table))

    # Workaround for GPDB 4.3, which cannot convert a text string to svec:
    # the sparse-vector string is built with array_to_string and then turned
    # into an svec with svec_from_string. In GPDB5 the query can be written as
    #
    #     INSERT INTO {mapping_table}
    #     SELECT docid, unnest((counts::text || ':' ||
    #            words::text)::{schema_madlib}.svec::float[]) AS wordid,
    #            unnest(topic_assignment) AS topicid
    #     FROM {lda_output_table}
    #     GROUP BY docid, wordid, topicid
    #
    # validate_lda_output() in the LDA install check applies the same
    # workaround.
    insert_sql = """
        INSERT INTO {mapping_table}
        SELECT docid,
            unnest({schema_madlib}.svec_from_string('{{' || array_to_string(counts, ',') || '}}:{{' || array_to_string(words, ',') || '}}')::float[]) AS wordid,
            unnest(topic_assignment) AS topicid
        FROM {lda_output_table}
        GROUP BY docid, wordid, topicid
        ORDER BY docid
        """
    plpy.execute(insert_sql.format(lda_output_table=lda_output_table,
                                   schema_madlib=schema_madlib,
                                   mapping_table=mapping_table))
| |
def get_perplexity(schema_madlib, model_table, output_data_table):
    """
    @brief Get the perplexity given the prediction and model, computed as
    exp(-(aggregated log-likelihood part) / (total word count)).
    @param schema_madlib MADlib schema name
    @param model_table The model table generated by lda_train
    @param output_data_table An output data table with topic counts
                             (as produced by lda_train or lda_predict)
    @return The perplexity value
    """
    _assert(model_table != '' and output_data_table != '',
            'invalid argument: at least one of the table names is not specified')

    _validate_model_table(model_table)
    row = plpy.execute("""
        SELECT topic_num, voc_size, alpha, beta FROM {model_table}
        """.format(model_table=model_table))[0]
    topic_num, voc_size = row['topic_num'], row['voc_size']
    alpha, beta = row['alpha'], row['beta']
    _validate_output_data_table(output_data_table, topic_num)

    perplexity_sql = """
        SELECT exp(-part_perp/total_word) AS perp
        FROM
        (
            SELECT {schema_madlib}.__lda_perplexity_agg(
                words, counts, topic_count, (SELECT model FROM {model_table}),
                {alpha}, {beta}, {voc_size}, {topic_num}) AS part_perp
            FROM
                {out_data_table}
        ) t1,
        (
            SELECT sum(wordcount) total_word FROM {out_data_table}
        ) t2
        """.format(schema_madlib=schema_madlib,
                   out_data_table=output_data_table,
                   model_table=model_table,
                   alpha=alpha,
                   beta=beta,
                   topic_num=topic_num,
                   voc_size=voc_size)

    # For GPDB 4.3 and higher the optimizer (ORCA) is disabled for this query
    # since ORCA currently doesn't support InitPlan. This would have to be
    # fixed when ORCA is the only available query planner.
    with OptimizerControl(False):
        perplexity_rows = plpy.execute(perplexity_sql)
    return perplexity_rows[0]['perp']
| |
| |
def norm_vocab(vocab_table, out_table):
    """
    @brief Check the vocabulary and map possibly non-continuous wordids to
    continuous integers from 0 to voc_size - 1 (in wordid order). When a
    wordid has several words, only the first one in ascending order is kept.
    @param vocab_table The vocabulary table in the form of
                       <wordid::INT4, word::text>
    @param out_table The normalized vocabulary table
                     (wordid, old_wordid, word)
    """
    _validate_vocab_table(vocab_table)
    output_tbl_valid(out_table, 'LDA')

    create_sql = """
        CREATE TABLE {out_table}(
            wordid INT4,
            old_wordid INT4,
            word TEXT
        )
        m4_ifdef(`__POSTGRESQL__', `',
            `WITH(APPENDONLY=TRUE)
            DISTRIBUTED BY (wordid)')
        """
    plpy.execute(create_sql.format(out_table=out_table))

    renumber_sql = """
        INSERT INTO {out_table}
        SELECT r - 1, wordid, word
        FROM
        (
            SELECT
                wordid, word, rank() OVER(ORDER BY wordid) r
            FROM
            (
                SELECT wordid, word
                FROM
                (
                    SELECT
                        wordid,
                        word,
                        rank() OVER(PARTITION BY wordid ORDER BY word ASC) r
                    FROM
                        {vocab_table}
                ) t1
                WHERE r = 1
            ) t2
        ) t3
        """
    plpy.execute(renumber_sql.format(out_table=out_table,
                                     vocab_table=vocab_table))
| |
| |
def norm_dataset(data_table, vocab_table, output_table):
    """
    @brief Normalize the data table according to a normalized vocabulary:
    every old wordid is replaced by its new wordid, and rows with
    non-positive counts are dropped.
    @param data_table The data table to be normalized
    @param vocab_table The normalized vocabulary table
                       (wordid, old_wordid, ...)
    @param output_table The normalized data table (docid, wordid, count)
    """
    _validate_data_table_cols(data_table)
    _validate_norm_vocab_table(vocab_table)
    output_tbl_valid(output_table, 'LDA')

    create_sql = """
        CREATE TABLE {output_table}(
            docid INT4,
            wordid INT4,
            count INT4
        )
        m4_ifdef(`__POSTGRESQL__', `',
            `WITH(APPENDONLY=TRUE)
            DISTRIBUTED BY (docid)')
        """
    plpy.execute(create_sql.format(output_table=output_table))

    remap_sql = """
        INSERT INTO {output_table}
        SELECT
            docid,
            vocab.wordid,
            count
        FROM
            {data_table} AS data,
            {vocab_table} AS vocab
        WHERE
            data.wordid = vocab.old_wordid AND data.count > 0
        """
    plpy.execute(remap_sql.format(output_table=output_table,
                                  data_table=data_table,
                                  vocab_table=vocab_table))
| |
| |
def conorm_data(data_table, vocab_table, output_data_table, output_vocab_table):
    """
    @brief Co-normalize the data table and the vocabulary table: the
    vocabulary is first restricted to the wordids that occur in the data,
    then renumbered (norm_vocab), and finally the data is remapped to the
    new wordids (norm_dataset).
    @param data_table The data table to be normalized
    @param vocab_table The vocabulary table to be normalized
    @param output_data_table The normalized data table
    @param output_vocab_table The normalized vocabulary table
    """
    _validate_data_table_cols(data_table)
    _validate_vocab_table(vocab_table)
    for out_name in (output_data_table, output_vocab_table):
        output_tbl_valid(out_name, 'LDA')

    # Scratch vocabulary holding only the words that appear in the data
    plpy.execute("DROP TABLE IF EXISTS __vocab__")
    plpy.execute("""
        CREATE TEMP TABLE __vocab__(
            wordid INT4,
            word TEXT
        )
        m4_ifdef(`__POSTGRESQL__', `', `WITH(APPENDONLY=TRUE)
        DISTRIBUTED BY (wordid)')
        """)

    used_words_sql = """
        INSERT INTO __vocab__
        SELECT voc.wordid, voc.word
        FROM
        (
            SELECT
                wordid
            FROM
                {data_table} AS data
            GROUP BY wordid
        ) tvoc, {vocab_table} voc
        WHERE tvoc.wordid = voc.wordid
        """
    plpy.execute(used_words_sql.format(data_table=data_table,
                                       vocab_table=vocab_table))

    norm_vocab('__vocab__', output_vocab_table)
    norm_dataset(data_table, output_vocab_table, output_data_table)
| |
| |
def index_sort(arr, **kwargs):
    """
    @brief Return the 1-based positions of the elements of an array, ordered
    by ascending element value (ties keep their original relative order)
    @param arr The array to be sorted; converted with string_to_array first
               because it may arrive as a string on GPDB < 4.1 and PG < 9.0
    @param kwargs Unused; accepted for call compatibility
    @return The 1-based indices of the elements in sorted order, passed
            through array_to_string
    """
    # process arrays for GPDB < 4.1 and PG < 9.0
    arr = string_to_array(arr, False)
    # sorted() over range() works on both Python 2 and Python 3, whereas
    # range(...).sort() only works on Python 2 where range returns a list.
    # sorted() is stable, exactly like list.sort(), so ties are unchanged.
    idx = sorted(range(len(arr)), key=lambda r: arr[r])
    # Build a real list (map returns an iterator on Python 3) and shift the
    # 0-based positions to 1-based ones.
    return array_to_string([r + 1 for r in idx])
| |
| |
def _convert_data_table(schema_madlib, data_table):
    """
    @brief Convert the data table from the one-row-per-word layout
    <docid, wordid, count> into the one-row-per-document layout
    <docid, wordcount, words, counts>. Rows with a NULL docid, wordid or
    count are ignored
    @param schema_madlib MADlib schema name (not referenced by the SQL here)
    @param data_table The data table to be converted
    @return The name of the temporary table holding the converted corpus
    """
    corpus_tbl = '__lda_convt_corpus__'
    plpy.execute("DROP TABLE IF EXISTS %s" % corpus_tbl)
    plpy.execute("""
        CREATE TEMP TABLE {convt_table}(
            docid INT4,
            wordcount INT4,
            words INT4[],
            counts INT4[]
        )
    m4_ifdef(`__POSTGRESQL__', `', `WITH(APPENDONLY=TRUE)
        DISTRIBUTED BY (docid)')
    """.format(convt_table=corpus_tbl))

    # The aggregation runs with OptimizerControl(False) and
    # HashaggControl(False) in effect; presumably these switch off the
    # GPORCA optimizer and hash aggregation for this statement only
    # (confirm in utilities.control).
    with OptimizerControl(False):
        with HashaggControl(False):
            plpy.execute("""
                INSERT INTO {convt_table}
                SELECT
                    docid,
                    sum(count) wordcount,
                    array_agg(wordid) words,
                    array_agg(count) counts
                FROM
                    {data_table}
                WHERE
                    (docid IS NOT NULL) AND
                    (wordid IS NOT NULL) AND
                    (count IS NOT NULL)
                GROUP BY docid
            """.format(convt_table=corpus_tbl, data_table=data_table))
    return corpus_tbl
| |
| |
def _validate_data_table_cols(data_table):
    """
    @brief Check the structure of the data table: the name must be non-empty,
    the table must exist, and it must have docid, wordid and count columns,
    all of type INT4
    @param data_table The data table name
    """
    _assert(data_table is not None and data_table.strip() != '',
            'no data table is specified')
    try:
        rv = plpy.execute("""
            SELECT count(*) cnt FROM pg_attribute
            WHERE
                attrelid = '{data_table}'::regclass AND
                ((atttypid = 'INT4'::regtype AND attname = 'docid') OR
                 (atttypid = 'INT4'::regtype AND attname = 'wordid') OR
                 (atttypid = 'INT4'::regtype AND attname = 'count'))
            """.format(data_table=data_table))
    except Exception:
        # The catalog query failed, most likely because the ::regclass cast
        # could not resolve the table name.
        _assert(False,
                "Table %s must exist and should have docid, wordid, and "
                "count columns" % (data_table))
    else:
        # Checked outside the try block so this more specific failure is not
        # caught by the except clause and replaced by the generic message.
        _assert(rv[0]['cnt'] == 3,
                "Table %s should have docid::INT4, wordid::INT4, count::INT4 "
                "columns" % (data_table))
| |
| |
def _validate_data_table(data_table, voc_size):
    """
    @brief Check the validity of the data table against the vocabulary size.
    Hard failures go through _assert: an empty table, negative wordids,
    wordids above voc_size - 1, more distinct wordids than voc_size, and
    zero or negative counts. Softer issues go through warn: negative docids,
    and a vocabulary that is sparser than voc_size suggests
    @param data_table The data table
    @param voc_size The size of vocabulary
    """
    _validate_data_table_cols(data_table)

    rv = plpy.execute('SELECT count(*) cnt FROM %s WHERE docid < 0' % (data_table))
    warn(
        rv[0]['cnt'] == 0,
        """%d rows have negative docid - use continuous non-negative
        integers for docid for better interpretation""" % (rv[0]['cnt']))

    rv = plpy.execute('SELECT count(*) cnt FROM %s WHERE wordid < 0' % (data_table))
    _assert(
        rv[0]['cnt'] == 0,
        """%d rows have negative wordid - the wordid must range from 0 to
        voc_size - 1""" % (rv[0]['cnt']))

    # Number of distinct non-NULL wordids and their range
    stats = plpy.execute("""
        SELECT
            count(wordid) size,
            min(wordid) min_wordid,
            max(wordid) max_wordid
        FROM
        (
            SELECT wordid
            FROM {data_table}
            GROUP BY wordid
        ) t1
        """.format(data_table=data_table))[0]
    # Must come first: min/max are NULL when there are no wordids, so the
    # comparisons below would be meaningless on an empty table.
    _assert(0 < stats['size'], 'Table %s is empty' % (data_table))
    _assert(
        stats['size'] <= voc_size,
        """The actual size of vocabulary %d is larger than the specified %d -
        set the correct voc_size and try again""" % (stats['size'], voc_size))
    _assert(
        stats['min_wordid'] <= voc_size - 1 and
        stats['max_wordid'] <= voc_size - 1,
        """The wordid should be in the range of 0 to %d""" % (voc_size - 1))
    warn(
        stats['size'] == voc_size,
        """Actual size of vocabulary %d is smaller than the specified %d -
        the vocabulary and dataset normalization is highly recommended for
        memory efficiency""" % (stats['size'], voc_size))
    warn(
        stats['min_wordid'] == 0,
        """The actual min wordid is larger than 0 - the vocabulary and dataset
        normalization is highly recommended for memory efficiency""")
    warn(stats['max_wordid'] == voc_size - 1,
         """Actual max word_id is smaller than the specified
         (voc_size - 1 = %d) - set the voc_size to %d for memory efficiency
         """ % (voc_size - 1, stats['max_wordid'] + 1))

    rv = plpy.execute('SELECT count(*) cnt FROM %s WHERE count <= 0' % (data_table))
    _assert(
        rv[0]['cnt'] == 0,
        """%d rows have zero or negative count - the value in the count column
        must be positive integers""" % (rv[0]['cnt']))
| |
| |
def _validate_vocab_table(vocab_table):
    """
    @brief Check the validity of the vocabulary table: it must exist and have
    a wordid::INT4 column and a word::TEXT column
    @param vocab_table The vocabulary table name
    """
    try:
        rv = plpy.execute("""
            SELECT count(*) cnt FROM pg_attribute
            WHERE
                attrelid = '{vocab_table}'::regclass AND
                ((atttypid = 'INT4'::regtype AND attname = 'wordid') OR
                 (atttypid = 'text'::regtype AND attname = 'word'))
            """.format(vocab_table=vocab_table))
    except Exception:
        # The catalog query failed, most likely because the ::regclass cast
        # could not resolve the table name.
        _assert(0, "Table %s must exist and should have wordid and "
                "word columns" % (vocab_table))
    else:
        # Checked outside the try block so this more specific failure is not
        # caught by the except clause and replaced by the generic message.
        _assert(
            rv[0]['cnt'] == 2,
            """the %s should have wordid::INT4, word::TEXT columns""" % (vocab_table))
| |
| |
def _validate_norm_vocab_table(vocab_table):
    """
    @brief Check the validity of the normalized vocabulary table: it must
    exist and have wordid::INT4, old_wordid::INT4 and word::TEXT columns
    @param vocab_table The normalized vocabulary table name
    """
    try:
        rv = plpy.execute("""
            SELECT count(*) cnt FROM pg_attribute
            WHERE
                attrelid = '{vocab_table}'::regclass AND
                ((atttypid = 'INT4'::regtype AND attname = 'wordid') OR
                 (atttypid = 'INT4'::regtype AND attname = 'old_wordid') OR
                 (atttypid = 'text'::regtype AND attname = 'word'))
            """.format(vocab_table=vocab_table))
    except Exception:
        # The catalog query failed, most likely because the ::regclass cast
        # could not resolve the table name.
        _assert(0, "Table %s must exist and should have wordid, old_wordid, "
                "and word columns" % (vocab_table))
    else:
        # Checked outside the try block so this more specific failure is not
        # caught by the except clause and replaced by the generic message.
        _assert(
            rv[0]['cnt'] == 3,
            """the %s should have wordid, old_wordid, and word columns""" %
            (vocab_table))
| |
| |
def _validate_output_data_table(output_data_table, topic_num):
    """
    @brief Check the validity of the output data table: it must exist, have
    words::INT4[], counts::INT4[] and topic_count::INT4[] columns, and every
    topic_count array must have an upper bound equal to topic_num
    @param output_data_table Output data table name
    @param topic_num Topic number
    """
    try:
        rv = plpy.execute("""
            SELECT count(*) cnt
            FROM pg_attribute
            WHERE
                attrelid = '{output_data_table}'::regclass AND
                ((atttypid = 'INT4[]'::regtype AND attname = 'words') OR
                 (atttypid = 'INT4[]'::regtype AND attname = 'counts') OR
                 (atttypid = 'INT4[]'::regtype AND attname = 'topic_count'))
            """.format(output_data_table=output_data_table))
    except Exception:
        # The catalog query failed, most likely because the ::regclass cast
        # could not resolve the table name.
        _assert(0,
                "Table %s must exist and should have words, counts, and "
                "topic_count columns" % (output_data_table))
    else:
        # Checked outside the try block so this more specific failure is not
        # caught by the except clause and replaced by the generic message.
        _assert(
            rv[0]['cnt'] == 3,
            "Table %s must have words::INT4[], counts::INT4[], and"
            " topic_count::INT4[] columns" % (output_data_table))

    # All rows must agree on the topic_count length, and it must be topic_num
    rv = plpy.execute("""
        SELECT min(dim) min_dim, max(dim) max_dim
        FROM
        (
            SELECT array_upper(topic_count, 1) dim
            FROM {output_data_table}
        ) t1
        """.format(output_data_table=output_data_table))
    _assert(rv[0]['min_dim'] == rv[0]['max_dim'] and
            rv[0]['min_dim'] == topic_num,
            "Dimension mismatch - array_upper(topic_count, 1) <> topic_num")
| |
| |
def _validate_model_table(model_table):
    """
    @brief Check the validity of the model table: it must exist, have
    voc_size::INT4, topic_num::INT4, alpha::FLOAT8, beta::FLOAT8 and
    model::INT8[] columns, contain exactly one row with positive
    hyper-parameters, and its model array must have the length implied by
    voc_size and topic_num
    @param model_table Model table name
    """
    try:
        rv = plpy.execute("""
            SELECT count(*) cnt
            FROM pg_attribute
            WHERE
                attrelid = '{model_table}'::regclass AND
                ((atttypid = 'INT4'::regtype AND attname = 'voc_size') OR
                 (atttypid = 'INT4'::regtype AND attname = 'topic_num') OR
                 (atttypid = 'FLOAT8'::regtype AND attname = 'alpha') OR
                 (atttypid = 'FLOAT8'::regtype AND attname = 'beta') OR
                 (atttypid = 'INT8[]'::regtype AND attname = 'model'))
            """.format(model_table=model_table))
    except Exception:
        # The catalog query failed, most likely because the ::regclass cast
        # could not resolve the table name.
        _assert(0,
                """Table %s must exist and should have voc_size, topic_num, alpha,
                beta, and model columns""" % (model_table))
    else:
        # Checked outside the try block so this more specific failure is not
        # caught by the except clause and replaced by the generic message.
        _assert(
            rv[0]['cnt'] == 5,
            """Table %s must have voc_size::INT4, topic_num::INT4, alpha::FLOAT8,
            beta::FLOAT8, and model::INT8[] columns""" % (model_table))

    rv = plpy.execute("""
        SELECT voc_size, topic_num, alpha, beta,
            array_upper(model, 1) model_size
        FROM %s
        """ % (model_table))
    _assert(len(rv) > 0, '%s is empty' % (model_table))
    _assert(len(rv) == 1, '%s should have only 1 row' % (model_table))
    row = rv[0]
    _assert(row['voc_size'] > 0,
            'voc_size in %s should be a positive integer' % (model_table))
    _assert(row['topic_num'] > 0,
            'topic_num in %s should be a positive integer' % (model_table))
    _assert(row['alpha'] > 0,
            'alpha in %s should be a positive real number' % (model_table))
    _assert(row['beta'] > 0,
            'beta in %s should be a positive real number' % (model_table))
    # Floor division keeps the original Python 2 semantics on Python 3 too:
    # with true division an odd numerator yields a fraction that can never
    # equal the integer returned by array_upper.
    expected_size = (row['voc_size'] * (row['topic_num'] + 1) + 1) // 2
    _assert(row['model_size'] == expected_size,
            "model_size mismatches with voc_size and topic_num in %s" % (model_table))