"""
@file lda.py_in
@brief Latent Dirichlet Allocation inference using collapsed Gibbs
sampling algorithm
@namespace lda
LDA: Driver and auxiliary functions
"""
import plpy
# import time
from utilities.control import OptimizerControl
from utilities.control import HashaggControl
from utilities.utilities import __mad_version, _assert, warn
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import input_tbl_valid
from utilities.utilities import py_list_to_sql_string
# use mad_vec to process arrays passed as strings in GPDB < 4.1 and PG < 9.0
version_wrapper = __mad_version()
string_to_array = version_wrapper.select_vecfunc()
array_to_string = version_wrapper.select_vec_return()
class LDATrainer:
"""
@brief This class defines a LDATrainer.
"""
def __init__(self, schema_madlib, data_table, model_table,
output_data_table, voc_size, topic_num,
                 iter_num, alpha, beta, evaluate_every, perplexity_tol):
self.schema_madlib = schema_madlib
self.data_table = data_table
self.voc_size = voc_size
self.topic_num = topic_num
self.iter_num = iter_num
self.alpha = alpha
self.beta = beta
self.model_table = model_table
self.output_data_table = output_data_table
self.work_table_0 = '__work_table_train_0__'
self.work_table_1 = '__work_table_train_1__'
self.evaluate_every = evaluate_every
self.perplexity_tol = perplexity_tol
self.perplexity = []
self.perplexity_diff = self.perplexity_tol
self.perplexity_iters = []
self.tol_reached = False
self.num_iterations = 0
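        # Two work tables serve as ping-pong buffers across Gibbs sampling
        # iterations: each iteration reads the previous topic assignments
        # from one table and writes the updated assignments into the other.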
plpy.execute("DROP TABLE IF EXISTS " + self.work_table_0)
plpy.execute("""
CREATE TEMP TABLE {work_table_0}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
doc_topic INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (docid)')
""".format(work_table_0=self.work_table_0))
plpy.execute("DROP TABLE IF EXISTS " + self.work_table_1)
plpy.execute("""
CREATE TEMP TABLE {work_table_1}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
doc_topic INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (docid)')
""".format(work_table_1=self.work_table_1))
plpy.execute("""
CREATE TABLE {model_table}(
voc_size INT4,
topic_num INT4,
alpha FLOAT8,
beta FLOAT8,
model INT8[],
num_iterations INT,
perplexity DOUBLE PRECISION[],
perplexity_iters INTEGER[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED RANDOMLY')
""".format(model_table=self.model_table))
plpy.execute("""
CREATE TABLE {output_data_table}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
topic_count INT4[],
topic_assignment INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (docid)')
""".format(output_data_table=self.output_data_table))
def init_random(self):
# stime = time.time()
# plpy.notice('initializing topics randomly ...')
plpy.execute('TRUNCATE TABLE ' + self.work_table_0)
plpy.execute("""
INSERT INTO {work_table}
SELECT
docid, wordcount, words, counts,
{schema_madlib}.__lda_random_assign(wordcount, {topic_num}) AS topics
FROM {data_table}
""".format(work_table=self.work_table_0,
schema_madlib=self.schema_madlib,
topic_num=self.topic_num,
data_table=self.data_table))
# etime = time.time()
# plpy.notice('\t\ttime elapsed: %.2f seconds' % (etime - stime))
def gen_final_data_tables(self):
# stime = time.time()
# plpy.notice('\t\tgenerating output data table ...')
        # This function updates two tables: the model table and the
        # output data table.
work_table_final = self.work_table_0 if self.iter_num % 2 == 0 \
else self.work_table_1
# Update model table
        # JIRA: MADLIB-1201 - we have to update the model table once more
        # after the iterations to sync up the output table and the model table
self.update_model_table(work_table_final)
self.gen_output_data_table(work_table_final)
# JIRA: MADLIB-1351
# Calculate Perplexity after the final update of
# the Model and Output Table
if self.evaluate_every > 0 and not self.tol_reached:
self.perplexity.append(
get_perplexity(self.schema_madlib,
self.model_table,
self.output_data_table))
# Need to update Model Table one more time to update
# last calculated value of perplexity to it
self.update_model_table(work_table_final)
def iteration(self, it):
# stime = time.time()
work_table_in = self.work_table_0
work_table_out = self.work_table_1
if it % 2 == 0:
work_table_in = self.work_table_1
work_table_out = self.work_table_0
        # Update model_table from work_table_in, which holds the output of
        # the previous iteration
self.update_model_table(work_table_in)
        # For GPDB 4.3 and higher we disable the optimizer (ORCA) for the
        # query planner since currently ORCA doesn't support InitPlan. This
        # would have to be fixed when ORCA is the only available query planner.
with OptimizerControl(False):
# plpy.notice('\t\tsampling ...')
plpy.execute('TRUNCATE TABLE ' + work_table_out)
query = """
INSERT INTO {work_table_out}
SELECT
docid, wordcount, words, counts,
{schema_madlib}.__lda_gibbs_sample(
words, counts, doc_topic,
(SELECT model FROM {model_table}),
{alpha}, {beta}, {voc_size}, {topic_num}, 1)
FROM
{work_table_in}
""".format(work_table_out=work_table_out,
schema_madlib=self.schema_madlib,
model_table=self.model_table,
alpha=self.alpha,
beta=self.beta,
voc_size=self.voc_size,
topic_num=self.topic_num,
work_table_in=work_table_in)
plpy.execute(query)
# etime = time.time()
# plpy.notice('\t\ttime elapsed: %.2f seconds' % (etime - stime))
        self.calculate_perplexity(it, work_table_in)
def update_model_table(self, work_table_in):
# JIRA: MADLIB-1351
# Create a string based on the value of self.perplexity
perplexity_values = ""
perplexity_iterations = ""
if len(self.perplexity) >= 1:
perplexity_values = ", {0}".format(py_list_to_sql_string(self.perplexity))
perplexity_iterations = ", {0}".format(py_list_to_sql_string(self.perplexity_iters))
n_iterations=", {0}".format(self.num_iterations)
plpy.execute('TRUNCATE TABLE ' + self.model_table)
if version_wrapper.is_gp43():
with OptimizerControl(True):
plpy.execute("""
INSERT INTO {model_table}
SELECT
{voc_size},
{topic_num},
{alpha},
{beta},
{schema_madlib}.__lda_count_topic_agg(
words,
counts,
doc_topic[{topic_num} + 1:array_upper(doc_topic, 1)],
{voc_size},
{topic_num}
) AS model
{n_iterations}{perplexity_values} {perplexity_iterations}
FROM {work_table_in}
""".format(model_table=self.model_table,
topic_num=self.topic_num,
voc_size=self.voc_size,
alpha=self.alpha,
beta=self.beta,
perplexity_values=perplexity_values,
perplexity_iterations=perplexity_iterations,
schema_madlib=self.schema_madlib,
work_table_in=work_table_in,
n_iterations=n_iterations))
else:
# work around insertion memory error (MPP-25561)
# by copying the model to Python temporarily
model = plpy.execute("""
SELECT
{schema_madlib}.__lda_count_topic_agg(
words,
counts,
doc_topic[{topic_num} + 1:array_upper(doc_topic, 1)],
{voc_size},
{topic_num}
) AS model
FROM {work_table_in}
""".format(schema_madlib=self.schema_madlib,
topic_num=self.topic_num,
voc_size=self.voc_size,
work_table_in=work_table_in))[0]['model']
# insert it back immediately
plan = plpy.prepare("""
INSERT INTO {model_table}
SELECT
{voc_size}, {topic_num}, {alpha}, {beta}, $1
{n_iterations}{perplexity_values} {perplexity_iterations}
""".format(model_table=self.model_table,
topic_num=self.topic_num,
voc_size=self.voc_size,
alpha=self.alpha,
beta=self.beta,
perplexity_values=perplexity_values,
perplexity_iterations=perplexity_iterations,
schema_madlib=self.schema_madlib,
n_iterations=n_iterations),
['bigint[]'])
plpy.execute(plan, [model])
def run(self):
# stime = time.time()
# plpy.notice('start training process ...')
self.init_random()
# sstime = time.time()
for it in range(1, self.iter_num + 1):
# JIRA: MADLIB-1351
# If the Perplexity_diff is less than the perplexity_tol,
# Stop the iteration
if self.perplexity_diff < self.perplexity_tol:
self.tol_reached = True
                # When the tolerance is reached before the iteration limit,
                # reduce num_iterations by 1 since perplexity_iters runs one
                # iteration behind in this case.
                self.num_iterations -= 1
break
self.iteration(it)
            self.num_iterations += 1
# eetime = time.time()
# plpy.notice('\t\titeration done, time elapsed: %.2f seconds' % (eetime - sstime))
# JIRA: MADLIB-1351
# Add the last iteration value to the array
if self.evaluate_every > 0 and not self.tol_reached:
self.perplexity_iters.append(self.iter_num)
self.gen_final_data_tables()
# etime = time.time()
# plpy.notice('finished, time elapsed: %.2f seconds' % (etime - stime))
# Update output table
def gen_output_data_table(self, work_table_final):
plpy.execute("TRUNCATE TABLE " + self.output_data_table)
plpy.execute("""
INSERT INTO {output_data_table}
SELECT
docid, wordcount, words, counts, doc_topic[1:{topic_num}] topic_count,
doc_topic[{topic_num} + 1:array_upper(doc_topic,1)] topic_assignment
FROM
{work_table_final}
""".format(output_data_table=self.output_data_table,
topic_num=self.topic_num,
work_table_final=work_table_final))
# etime = time.time()
# plpy.notice('\t\t\ttime elapsed: %.2f seconds' % (etime - stime))
    def calculate_perplexity(self, it, work_table_in):
        # JIRA: MADLIB-1351
        # Calculate perplexity every evaluate_every iterations, skipping the
        # first iteration. In each iteration the model table is updated first
        # (at iteration 1 this is the random model; for iteration > 1 it is
        # the model learnt in the previous iteration), then
        # __lda_count_topic_agg and __lda_gibbs_sample are called. The Gibbs
        # sampler learns an updated model, but that model is not passed back
        # to Python; it only reaches the model table at the start of the next
        # iteration. Because of this workflow the first perplexity value can
        # safely be ignored: e.g. with evaluate_every = 2, perplexity is
        # computed at iterations 3, 5, 7, ... and reported for iterations
        # 2, 4, 6, ...
if it > self.evaluate_every and self.evaluate_every > 0 and (
it - 1) % self.evaluate_every == 0:
self.gen_output_data_table(work_table_in)
perplexity = get_perplexity(self.schema_madlib,
self.model_table,
self.output_data_table)
if len(self.perplexity) > 0:
self.perplexity_diff = abs(self.perplexity[-1] - perplexity)
self.perplexity_iters.append(it - 1)
self.perplexity.append(perplexity)
# ------------------------------------------------------------------------------
class LDAPredictor:
"""
@brief This class defines a LDAPredictor
"""
def __init__(self, schema_madlib, test_table, model_table, output_table,
iter_num):
self.schema_madlib = schema_madlib
self.test_table = test_table
self.model_table = model_table
self.iter_num = iter_num
rv = plpy.execute("""
SELECT
voc_size, topic_num, alpha, beta
FROM
{model_table}
""".format(model_table=self.model_table))
self.voc_size = rv[0]['voc_size']
self.topic_num = rv[0]['topic_num']
self.alpha = rv[0]['alpha']
self.beta = rv[0]['beta']
self.doc_topic = output_table
self.work_table_0 = '__work_table_pred_0__'
        self.work_table_1 = '__work_table_pred_1__'
plpy.execute("DROP TABLE IF EXISTS " + self.work_table_0)
plpy.execute("""
CREATE TEMP TABLE {work_table_0}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
doc_topic INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (docid)')
""".format(work_table_0=self.work_table_0))
plpy.execute("DROP TABLE IF EXISTS " + self.work_table_1)
plpy.execute("""
CREATE TEMP TABLE {work_table_1}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
doc_topic INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (docid)')
""".format(work_table_1=self.work_table_1))
plpy.execute("""
CREATE TABLE {doc_topic}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
topic_count INT4[],
topic_assignment INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (docid)')
""".format(doc_topic=self.doc_topic))
def init_random(self):
# stime = time.time()
# plpy.notice('initializing topics randomly ...')
plpy.execute('TRUNCATE TABLE ' + self.work_table_0)
plpy.execute("""
INSERT INTO {work_table}
SELECT
docid, wordcount, words, counts,
{schema_madlib}.__lda_random_assign(wordcount, {topic_num}) AS topics
FROM {data_table}
""".format(work_table=self.work_table_0,
schema_madlib=self.schema_madlib,
topic_num=self.topic_num,
data_table=self.test_table))
# etime = time.time()
# plpy.notice('\t\ttime elapsed: %.2f seconds' % (etime - stime))
def gen_output_table(self):
plpy.execute("TRUNCATE TABLE " + self.doc_topic)
plpy.execute("""
INSERT INTO {doc_topic}
SELECT
docid, wordcount, words, counts, doc_topic[1:{topic_num}] topic_count,
doc_topic[{topic_num} + 1:array_upper(doc_topic,1)] topic_assignment
FROM {work_table_out}
""".format(doc_topic=self.doc_topic,
topic_num=self.topic_num,
work_table_out=self.work_table_1))
def infer(self):
# stime = time.time()
        # plpy.notice('inferring ...')
        # For GPDB 4.3 and higher we disable the optimizer (ORCA) for the
        # query planner since currently ORCA doesn't support InitPlan. This
        # would have to be fixed when ORCA is the only available query planner.
with OptimizerControl(False):
query = """
INSERT INTO {work_table_out}
SELECT
docid, wordcount, words, counts,
{schema_madlib}.__lda_gibbs_sample(
words, counts, doc_topic,
(SELECT model FROM {model_table}),
{alpha}, {beta},
{voc_size}, {topic_num}, {iter_num}
)
FROM
{work_table_in}
""".format(work_table_out=self.work_table_1,
work_table_in=self.work_table_0,
model_table=self.model_table,
schema_madlib=self.schema_madlib,
alpha=self.alpha,
beta=self.beta,
voc_size=self.voc_size,
topic_num=self.topic_num,
iter_num=self.iter_num)
plpy.execute(query)
# etime = time.time()
# plpy.notice('\t\ttime elapsed: %.2f seconds' % (etime - stime))
def run(self):
# stime = time.time()
# plpy.notice('start prediction process ...')
self.init_random()
self.infer()
self.gen_output_table()
# etime = time.time()
# plpy.notice('finished, time elapsed: %.2f seconds' % (etime - stime))
# ------------------------------------------------------------------------------
def lda_train(schema_madlib, train_table, model_table, output_data_table, voc_size,
topic_num, iter_num, alpha, beta, evaluate_every, perplexity_tol):
"""
    @brief This function provides the entry for the LDA training process.
    @param schema_madlib MADlib schema
    @param train_table Training data table
    @param model_table Learned model table
    @param output_data_table Output data table
    @param voc_size Size of vocabulary
    @param topic_num Number of topics
    @param iter_num Number of iterations
    @param alpha Dirichlet parameter for per-document topic multinomial
    @param beta Dirichlet parameter for per-topic word multinomial
    @param evaluate_every Interval (in iterations) between perplexity evaluations
    @param perplexity_tol Perplexity tolerance for early stopping
    """
_assert(train_table is not None and train_table.strip() != '',
'invalid argument: train_table is not specified')
_assert(model_table is not None and model_table.strip() != '',
'invalid argument: model_table is not specified')
_assert(output_data_table is not None and output_data_table.strip() != '',
'invalid argument: output_data_table is not specified')
_assert(voc_size is not None and voc_size > 0,
'invalid argument: positive integer expected for voc_size')
_assert(topic_num is not None and topic_num > 0,
'invalid argument: positive integer expected for topic_num')
_assert(iter_num is not None and iter_num > 0,
'invalid argument: positive integer expected for iter_num')
_assert(alpha is not None and alpha > 0,
'invalid argument: positive real expected for alpha')
_assert(beta is not None and beta > 0,
'invalid argument: positive real expected for beta')
# Setting the default values for perplexity_tol and evaluate_every
if perplexity_tol is None:
perplexity_tol = 0.1
if evaluate_every is None:
evaluate_every = 0
_assert(evaluate_every <= iter_num,
'invalid argument: evaluate_every should not be greater than iter_num')
_assert(perplexity_tol is not None and perplexity_tol >= 0,
'invalid argument: perplexity_tol should not be less than 0')
output_tbl_valid(model_table, 'LDA')
output_tbl_valid(output_data_table, 'LDA')
warn(voc_size <= 1e5,
"""the voc_size is very large: %d - make sure that the system has
enough memory or reduce the vocabulary size""" % (voc_size))
warn(topic_num <= 1e3,
"""the topic_num is large: %d - make sure that the system has enough
memory or reduce the number of topics """ % (topic_num))
_validate_data_table(train_table, voc_size)
convt_table = _convert_data_table(schema_madlib, train_table)
lt = LDATrainer(schema_madlib, convt_table, model_table,
output_data_table, voc_size, topic_num,
                    iter_num, alpha, beta, evaluate_every, perplexity_tol)
lt.run()
    # __lda_check_count_ceiling returns NULL when the count ceiling is
    # not hit; otherwise it returns a sample of the wordids that hit it
examples_hit_ceiling = plpy.execute("""
SELECT
array_to_string(
{schema_madlib}.__lda_check_count_ceiling(
model,
voc_size,
topic_num)::integer[],
','
) AS examples_hit_ceiling
FROM {model_table}
""".format(**locals()))[0]['examples_hit_ceiling']
if examples_hit_ceiling is not None:
plpy.info("""
LDA warning: some words ({examples_hit_ceiling}, etc.) occur too frequently.
More than 2e9 times. It might add noise for the LDA model.
Please remove the very frequent words in the data.
""".format(**locals()))
def lda_predict(schema_madlib, test_table, model_table, output_data_table,
iter_num=20):
"""
@brief This function provides the entry for the LDA prediction process.
@param test_table name of the testing dataset table
@param model_table name of the model table
@param iter_num number of iterations
@param output_table name of output table
"""
input_tbl_valid(test_table, 'LDA')
input_tbl_valid(model_table, 'LDA')
output_tbl_valid(output_data_table, 'LDA')
iter_num = 20 if iter_num is None else iter_num
    _assert(
        iter_num >= 0,
        'invalid argument: non-negative integer expected for iter_num')
warn(
iter_num <= 20,
"""the iter_num is large: %d - a smaller iter_num (e.g. 20) should be
good enough""" % (iter_num))
_validate_model_table(model_table)
rv = plpy.execute('SELECT voc_size FROM ' + model_table)
voc_size = rv[0]['voc_size']
_validate_data_table(test_table, voc_size)
convt_table = _convert_data_table(schema_madlib, test_table)
lp = LDAPredictor(
schema_madlib, convt_table, model_table, output_data_table, iter_num)
lp.run()
def get_topic_desc(schema_madlib, model_table, vocab_table, desc_table,
top_k=15):
"""
@brief Get the per-topic description by top-k words
    @param model_table The model table generated by the training process
    @param vocab_table The vocabulary table
    @param desc_table The output table for storing the per-topic word description
    @param top_k The top k words for topic description
"""
_assert(model_table != '' and vocab_table != '' and desc_table != '',
"invalid argument: at least one of the table names is not specified")
_assert(top_k > 0, "invalid argument: Positive integer expected for top_k")
_validate_model_table(model_table)
_validate_vocab_table(vocab_table)
output_tbl_valid(desc_table, 'LDA')
plpy.execute("DROP TABLE IF EXISTS __lda_topic_word_count__")
plpy.execute("""
CREATE TEMP TABLE __lda_topic_word_count__(
topicid INT4,
word_count INT4[],
beta FLOAT8
)
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE) DISTRIBUTED BY
(topicid)')
""")
plpy.execute("""
INSERT INTO __lda_topic_word_count__
SELECT
generate_series(0, topic_num - 1) AS topicid,
{schema_madlib}.__lda_util_unnest_transpose(
model,
voc_size,
topic_num
) AS word_count,
beta
FROM {model_table}
""".format(
schema_madlib=schema_madlib,
model_table=model_table))
plpy.execute("DROP TABLE IF EXISTS __lda_topic_word_dist__")
plpy.execute("""
CREATE TEMP TABLE __lda_topic_word_dist__(
topicid INT4,
word_dist FLOAT[]
)
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (topicid)')
""")
plpy.execute("""
INSERT INTO __lda_topic_word_dist__
SELECT
topicid,
{schema_madlib}.__lda_util_norm_with_smoothing(word_count, beta) dist
FROM
__lda_topic_word_count__
""".format(schema_madlib=schema_madlib))
plpy.execute("""
CREATE TABLE {desc_table} (
topicid INT4,
wordid INT4,
prob FLOAT,
word TEXT)
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (topicid)')
""".format(desc_table=desc_table))
plpy.execute("""
INSERT INTO {desc_table}
SELECT
topicid, t2.wordid, prob, word
FROM
(
SELECT
topicid, wordid, prob,
rank() OVER(PARTITION BY topicid ORDER BY prob DESC) r
FROM
(
SELECT
topicid,
generate_series(0, array_upper(word_dist, 1) - 1) wordid,
unnest(word_dist) prob
FROM
__lda_topic_word_dist__
) t1
) t2, {vocab_table} AS vocab
WHERE t2.r <= {top_k} AND t2.wordid = vocab.wordid
""".format(desc_table=desc_table,
vocab_table=vocab_table, top_k=top_k))
def get_topic_word_count(schema_madlib, model_table, output_table):
"""
@brief Get the per-topic word counts from the model table
@param model_table The model table generated by the training process
@param output_table The output table for storing the per-topic word counts
"""
_assert(
model_table != '',
'invalid argument: model table name is not specified')
_validate_model_table(model_table)
output_tbl_valid(output_table, 'LDA')
plpy.execute("""
CREATE TABLE {output_table} (
topicid INT4,
word_count INT4[])
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (topicid)')
""".format(output_table=output_table))
plpy.execute("""
INSERT INTO {output_table}
SELECT
generate_series(0, topic_num - 1) topicid,
{schema_madlib}.__lda_util_unnest_transpose(
model,
voc_size,
topic_num
) AS word_count
FROM
{model_table}
""".format(output_table=output_table, schema_madlib=schema_madlib,
model_table=model_table))
def get_word_topic_count(schema_madlib, model_table, output_table):
"""
@brief Get the per-word topic counts from the model table
@param model_table The model table generated by the training process
@param output_table The output table for storing the per-word topic counts
"""
_assert(model_table != '',
"invalid argument: model table name is not specified")
_validate_model_table(model_table)
output_tbl_valid(output_table, 'LDA')
plpy.execute("""
CREATE TABLE {output_table} (
wordid INT4,
topic_count INT4[])
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (wordid)')
""".format(output_table=output_table))
plpy.execute("""
INSERT INTO {output_table}
SELECT
generate_series(0, voc_size - 1) wordid,
{schema_madlib}.__lda_util_unnest(model, voc_size, topic_num) word_count
FROM {model_table}
""".format(output_table=output_table, schema_madlib=schema_madlib,
model_table=model_table))
def get_word_topic_mapping(schema_madlib, lda_output_table, mapping_table):
"""
    @brief Get the wordid - topicid mapping from the LDA training output table
    @param lda_output_table The output table from LDA training or prediction
    @param mapping_table The result table that stores the mapping info
"""
_assert(lda_output_table != '',
"invalid argument: LDA output table name is not specified")
output_tbl_valid(mapping_table, 'LDA')
plpy.execute("""
CREATE TABLE {mapping_table} (
docid INT4,
wordid INT4,
topicid INT4)
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE)
DISTRIBUTED BY (docid)')
""".format(mapping_table=mapping_table))
## The following query is a workaround for GPDB 4.3 because it cannot
## convert text string to svec (that's why we have to call
## array_to_string first to form a string and then call
## {schema_madlib}.svec_from_string to convert it to svec format).
## In GPDB5, the query can be written as
## ```
## INSERT INTO {mapping_table}
## SELECT docid, unnest((counts::text || ':' ||
## words::text)::{schema_madlib}.svec::float[]) AS wordid,
## unnest(topic_assignment) AS topicid
## FROM {lda_output_table}
## GROUP BY docid, wordid, topicid
## ```
## Also look at validate_lda_output() function in lda install check
## which applies the same workaround
plpy.execute("""
INSERT INTO {mapping_table}
SELECT docid,
unnest({schema_madlib}.svec_from_string('{{' || array_to_string(counts, ',') || '}}:{{' || array_to_string(words, ',') || '}}')::float[]) AS wordid,
unnest(topic_assignment) AS topicid
FROM {lda_output_table}
GROUP BY docid, wordid, topicid
ORDER BY docid
""".format(lda_output_table=lda_output_table,
schema_madlib=schema_madlib, mapping_table=mapping_table))
def get_perplexity(schema_madlib, model_table, output_data_table):
"""
@brief Get the perplexity given the prediction and model.
@param model_table The model table generated by lda_train
@param output_data_table The output data table generated by lda_predict
"""
_assert(model_table != '' and output_data_table != '',
'invalid argument: at least one of the table names is not specified')
_validate_model_table(model_table)
params = plpy.execute("""
SELECT topic_num, voc_size, alpha, beta FROM {model_table}
""".format(model_table=model_table))[0]
topic_num = params['topic_num']
voc_size = params['voc_size']
alpha = params['alpha']
beta = params['beta']
_validate_output_data_table(output_data_table, topic_num)
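    # Perplexity is computed as exp(-L / N), where L is the corpus
    # log-likelihood accumulated by __lda_perplexity_agg and N is the
    # total word count (see the query below).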
    # For GPDB 4.3 and higher we disable the optimizer (ORCA) for the query
    # planner since currently ORCA doesn't support InitPlan. This would have
    # to be fixed when ORCA is the only available query planner.
with OptimizerControl(False):
query = """
SELECT exp(-part_perp/total_word) AS perp
FROM
(
SELECT {schema_madlib}.__lda_perplexity_agg(
words, counts, topic_count, (SELECT model FROM {model_table}),
{alpha}, {beta}, {voc_size}, {topic_num}) AS part_perp
FROM
{out_data_table}
) t1,
(
SELECT sum(wordcount) total_word FROM {out_data_table}
) t2
""".format(schema_madlib=schema_madlib,
out_data_table=output_data_table,
model_table=model_table,
alpha=alpha,
beta=beta,
topic_num=topic_num,
voc_size=voc_size)
rv = plpy.execute(query)
return rv[0]['perp']
def norm_vocab(vocab_table, out_table):
"""
    @brief Checks the vocabulary and converts non-contiguous wordids into
           contiguous integers ranging from 0 to voc_size - 1.
@param vocab_table The vocabulary table in the form of
<wordid::INT4, word::text>
@param out_table The normalized vocabulary table
"""
_validate_vocab_table(vocab_table)
output_tbl_valid(out_table, 'LDA')
plpy.execute("""
CREATE TABLE {out_table}(
wordid INT4,
old_wordid INT4,
word TEXT
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH(APPENDONLY=TRUE)
DISTRIBUTED BY (wordid)')
""".format(out_table=out_table))
plpy.execute("""
INSERT INTO {out_table}
SELECT r - 1, wordid, word
FROM
(
SELECT
wordid, word, rank() OVER(ORDER BY wordid) r
FROM
(
SELECT wordid, word
FROM
(
SELECT
wordid,
word,
rank() OVER(PARTITION BY wordid ORDER BY word ASC) r
FROM
{vocab_table}
) t1
WHERE r = 1
) t2
) t3
""".format(out_table=out_table, vocab_table=vocab_table))
def norm_dataset(data_table, vocab_table, output_table):
"""
    @brief Normalize the data table according to the normalized vocabulary;
           rows with non-positive count values will be removed
    @param data_table The data table to be normalized
    @param vocab_table The normalized vocabulary table
@param output_table The normalized data table
"""
_validate_data_table_cols(data_table)
_validate_norm_vocab_table(vocab_table)
output_tbl_valid(output_table, 'LDA')
plpy.execute("""
CREATE TABLE {output_table}(
docid INT4,
wordid INT4,
count INT4
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH(APPENDONLY=TRUE)
DISTRIBUTED BY (docid)')
""".format(output_table=output_table))
plpy.execute("""
INSERT INTO {output_table}
SELECT
docid,
vocab.wordid,
count
FROM
{data_table} AS data,
{vocab_table} AS vocab
WHERE
data.wordid = vocab.old_wordid AND data.count > 0
""".format(output_table=output_table,
data_table=data_table,
vocab_table=vocab_table))
def conorm_data(data_table, vocab_table, output_data_table, output_vocab_table):
"""
@brief Co-normalize the data table and the vocabulary table
@param data_table The data table to be normalized
    @param vocab_table The vocabulary table to be normalized
@param output_data_table The normalized data table
@param output_vocab_table The normalized vocabulary table
"""
_validate_data_table_cols(data_table)
_validate_vocab_table(vocab_table)
output_tbl_valid(output_data_table, 'LDA')
output_tbl_valid(output_vocab_table, 'LDA')
plpy.execute("DROP TABLE IF EXISTS __vocab__")
plpy.execute("""
CREATE TEMP TABLE __vocab__(
wordid INT4,
word TEXT
)
m4_ifdef(`__POSTGRESQL__', `', `WITH(APPENDONLY=TRUE)
DISTRIBUTED BY (wordid)')
""")
plpy.execute("""
INSERT INTO __vocab__
SELECT voc.wordid, voc.word
FROM
(
SELECT
wordid
FROM
{data_table} AS data
GROUP BY wordid
) tvoc, {vocab_table} voc
WHERE tvoc.wordid = voc.wordid
""".format(data_table=data_table, vocab_table=vocab_table))
norm_vocab('__vocab__', output_vocab_table)
norm_dataset(data_table, output_vocab_table, output_data_table)
def index_sort(arr, **kwargs):
"""
@brief Return the index of elements in a sorted order
@param arr The array to be sorted
@return The index of elements
"""
    # process arrays for GPDB < 4.1 and PG < 9.0
    arr = string_to_array(arr, False)
    dim = len(arr)
    # sort the positions by their corresponding values; sorted() over
    # range() works on both Python 2 and 3, unlike range().sort()
    idx = sorted(range(dim), key=lambda r: arr[r])
    # convert to 1-based indices before returning
    return array_to_string([r + 1 for r in idx])
def _convert_data_table(schema_madlib, data_table):
"""
@brief Convert the format of data table from <docid, wordid, count> to <docid,
wordcount, words, counts>.
@param data_table The data table to be converted
@param return The converted table name
"""
# plpy.notice('converting the data table ...')
convt_table = '__lda_convt_corpus__'
plpy.execute("DROP TABLE IF EXISTS " + convt_table)
plpy.execute("""
CREATE TEMP TABLE {convt_table}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[]
)
m4_ifdef(`__POSTGRESQL__', `', `WITH(APPENDONLY=TRUE)
DISTRIBUTED BY (docid)')
""".format(convt_table=convt_table))
with OptimizerControl(False):
with HashaggControl(False):
plpy.execute("""
INSERT INTO {convt_table}
SELECT
docid,
sum(count) wordcount,
array_agg(wordid) words,
array_agg(count) counts
FROM
{data_table}
WHERE
(docid IS NOT NULL) AND
(wordid IS NOT NULL) AND
(count IS NOT NULL)
GROUP BY docid
""".format(convt_table=convt_table,
schema_madlib=schema_madlib,
data_table=data_table))
return convt_table
def _validate_data_table_cols(data_table):
"""
@brief Check the structure of the data table
@param data_table The data table name
"""
# plpy.notice('checking the data table ... ')
_assert(data_table is not None and data_table.strip() != '',
'no data table is specified')
try:
rv = plpy.execute("""
SELECT count(*) cnt FROM pg_attribute
WHERE
attrelid = '{data_table}'::regclass AND
((atttypid = 'INT4'::regtype AND attname = 'docid') OR
(atttypid = 'INT4'::regtype AND attname = 'wordid') OR
(atttypid = 'INT4'::regtype AND attname = 'count'))
""".format(data_table=data_table))
        _assert(rv[0]['cnt'] == 3,
                "Table %s should have docid::INT4, wordid::INT4, and "
                "count::INT4 columns" % (data_table))
except:
_assert(False,
"Table %s must exist and should have docid, wordid, and "
"count columns" % (data_table))
def _validate_data_table(data_table, voc_size):
"""
@brief Check the validity of the data table
@param data_table The data table
@param voc_size The size of vocabulary
"""
# plpy.notice('validating the data table ...')
_validate_data_table_cols(data_table)
rv = plpy.execute('SELECT count(*) cnt FROM %s WHERE docid < 0' % (data_table))
    warn(
        rv[0]['cnt'] == 0,
        """%d rows have negative docid - use contiguous non-negative
        integers for docid for better interpretability""" % (rv[0]['cnt']))
rv = plpy.execute('SELECT count(*) cnt FROM %s WHERE wordid < 0' % (data_table))
_assert(
rv[0]['cnt'] == 0,
"""%d rows have negative wordid - the wordid must range from 0 to
voc_size - 1""" % (rv[0]['cnt']))
rv = plpy.execute("""
SELECT
count(wordid) size,
min(wordid) min_wordid,
max(wordid) max_wordid
FROM
(
SELECT wordid
FROM {data_table}
GROUP BY wordid
) t1
""".format(data_table=data_table))
_assert(0 < rv[0]['size'], 'Table %s is empty' % (data_table))
_assert(
rv[0]['size'] <= voc_size,
"""The actual size of vocabulary %d is larger than the specified %d -
set the correct voc_size and try again""" % (rv[0]['size'], voc_size))
_assert(
rv[0]['min_wordid'] <= voc_size - 1 and
rv[0]['max_wordid'] <= voc_size - 1,
"""The wordid should be in the range of 0 to %d""" % (voc_size - 1))
warn(
rv[0]['size'] == voc_size,
"""Actual size of vocabulary %d is smaller than the specified %d -
the vocabulary and dataset normalization is highly recommended for
memory efficiency""" % (rv[0]['size'], voc_size))
warn(
rv[0]['min_wordid'] == 0,
"""The actual min wordid is large than 0 - the vocabulary and dataset
normalization is highly recommended for memory efficiency""")
warn(rv[0]['max_wordid'] == voc_size - 1,
"""Actual max word_id is smaller than the specified
(voc_size - 1 = %d) - set the voc_size to %d for memory efficiency
""" % (voc_size - 1, rv[0]['max_wordid'] + 1))
rv = plpy.execute('SELECT count(*) cnt FROM %s WHERE count <= 0' % (data_table))
_assert(
rv[0]['cnt'] == 0,
"""%d rows have zero or negative count - the value in the count column
must be positive integers""" % (rv[0]['cnt']))
def _validate_vocab_table(vocab_table):
"""
@brief Check the validity of the vocabulary table
@param vocab_table The vocabulary table name
"""
# plpy.notice('checking the vocabulary table ...')
try:
rv = plpy.execute("""
SELECT count(*) cnt FROM pg_attribute
WHERE
attrelid = '{vocab_table}'::regclass AND
((atttypid = 'INT4'::regtype AND attname = 'wordid') OR
(atttypid = 'text'::regtype AND attname = 'word'))
""".format(vocab_table=vocab_table))
_assert(
rv[0]['cnt'] == 2,
"""the %s should have wordid::INT4, word::TEXT columns""" % (vocab_table))
except:
_assert(0, "Table %s must exist and should have wordid and "
"word columnes" % (vocab_table))
def _validate_norm_vocab_table(vocab_table):
"""
@brief Check the validity of the normalized vocabulary table
@param vocab_table The normalized vocabulary table name
"""
try:
rv = plpy.execute("""
SELECT count(*) cnt FROM pg_attribute
WHERE
attrelid = '{vocab_table}'::regclass AND
((atttypid = 'INT4'::regtype AND attname = 'wordid') OR
(atttypid = 'INT4'::regtype AND attname = 'old_wordid') OR
(atttypid = 'text'::regtype AND attname = 'word'))
""".format(vocab_table=vocab_table))
_assert(
rv[0]['cnt'] == 3,
"""the %s should have wordid, old_wordid, and word columns""" %
(vocab_table))
except:
_assert(0, "Table %s must exist and should have wordid, old_wordid, "
"and word columns" % (vocab_table))
def _validate_output_data_table(output_data_table, topic_num):
"""
@brief Check the validity of the output data table
    @param output_data_table Output data table name
@param topic_num Topic number
"""
try:
rv = plpy.execute("""
SELECT count(*) cnt
FROM pg_attribute
WHERE
attrelid = '{output_data_table}'::regclass AND
((atttypid = 'INT4[]'::regtype AND attname = 'words') OR
(atttypid = 'INT4[]'::regtype AND attname = 'counts') OR
(atttypid = 'INT4[]'::regtype AND attname = 'topic_count'))
""".format(output_data_table=output_data_table))
_assert(
rv[0]['cnt'] == 3,
"Table %s must have words::INT4[], counts::INT4[], and"
" topic_count::INT4[] columns" % (output_data_table))
except:
_assert(0,
"Table %s must exist and should have words, counts, and "
" topic_count columns" % (output_data_table))
rv = plpy.execute("""
SELECT min(dim) min_dim, max(dim) max_dim
FROM
(
SELECT array_upper(topic_count, 1) dim
FROM {output_data_table}
) t1
""".format(output_data_table=output_data_table))
_assert(rv[0]['min_dim'] == rv[0]['max_dim'] and
rv[0]['min_dim'] == topic_num,
"Dimension mismatch - array_upper(topic_count, 1) <> topic_num")
def _validate_model_table(model_table):
"""
@brief Check the validity of the model table
@param model_table Model table name
"""
# plpy.notice('checking the model table ...')
try:
rv = plpy.execute("""
SELECT count(*) cnt
FROM pg_attribute
WHERE
attrelid = '{model_table}'::regclass AND
((atttypid = 'INT4'::regtype AND attname = 'voc_size') OR
(atttypid = 'INT4'::regtype AND attname = 'topic_num') OR
(atttypid = 'FLOAT8'::regtype AND attname = 'alpha') OR
(atttypid = 'FLOAT8'::regtype AND attname = 'beta') OR
(atttypid = 'INT8[]'::regtype AND attname = 'model'))
""".format(model_table=model_table))
_assert(
rv[0]['cnt'] == 5,
"""Table %s must have voc_size::INT4, topic_num::INT4, alpha::FLOAT8,
beta::FLOAT8, and model::INT8[] columns""" % (model_table))
except:
        _assert(0,
                """Table %s must exist and should have voc_size, topic_num,
                alpha, beta, and model columns""" % (model_table))
rv = plpy.execute("""
SELECT voc_size, topic_num, alpha, beta,
array_upper(model, 1) model_size
FROM %s
""" % (model_table))
_assert(len(rv) > 0, '%s is empty' % (model_table))
_assert(len(rv) == 1, '%s should have only 1 row' % (model_table))
_assert(rv[0]['voc_size'] > 0,
'voc_size in %s should be a positive integer' % (model_table))
_assert(rv[0]['topic_num'] > 0,
'topic_num in %s should be a positive integer' % (model_table))
_assert(rv[0]['alpha'] > 0,
'alpha in %s should be a positive real number' % (model_table))
_assert(rv[0]['beta'] > 0,
'beta in %s should be a positive real number' % (model_table))
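    # The model array appears to pack voc_size * (topic_num + 1) INT4 counts
    # two per INT8 element (an assumption inferred from the formula below),
    # so the expected length is ceil(voc_size * (topic_num + 1) / 2).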
    _assert(rv[0]['model_size'] == ((rv[0]['voc_size']) * (rv[0]['topic_num'] + 1) + 1) // 2,
"model_size mismatches with voc_size and topic_num in %s" % (model_table))