"""
@file lda.py_in
@brief Latent Dirichlet Allocation inference using the collapsed Gibbs
sampling algorithm
@namespace lda
LDA: Driver and auxiliary functions
"""
import plpy
import time
from utilities.control import EnableOptimizer
from utilities.control import EnableHashagg
from utilities.utilities import __mad_version, _assert, warn
from utilities.validate_args import output_tbl_valid
from utilities.validate_args import input_tbl_valid
# use mad_vec to process arrays passed as strings in GPDB < 4.1 and PG < 9.0
version_wrapper = __mad_version()
string_to_array = version_wrapper.select_vecfunc()
array_to_string = version_wrapper.select_vec_return()
class LDATrainer:
"""
    @brief This class defines an LDATrainer.
"""
def __init__(self, schema_madlib, data_table, model_table,
output_data_table, voc_size, topic_num,
iter_num, alpha, beta):
self.schema_madlib = schema_madlib
self.data_table = data_table
self.voc_size = voc_size
self.topic_num = topic_num
self.iter_num = iter_num
self.alpha = alpha
self.beta = beta
self.model_table = model_table
self.output_data_table = output_data_table
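        # two temp work tables used as ping-pong buffers for the per-document
        # topic state across Gibbs sampling iterations (see iteration())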
self.work_table_0 = '__work_table_train_0__'
self.work_table_1 = '__work_table_train_1__'
plpy.execute("DROP TABLE IF EXISTS " + self.work_table_0)
plpy.execute("""
CREATE TEMP TABLE {work_table_0}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
doc_topic INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (docid)')
""".format(work_table_0=self.work_table_0))
plpy.execute("DROP TABLE IF EXISTS " + self.work_table_1)
plpy.execute("""
CREATE TEMP TABLE {work_table_1}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
doc_topic INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (docid)')
""".format(work_table_1=self.work_table_1))
plpy.execute("""
CREATE TABLE {model_table}(
voc_size INT4,
topic_num INT4,
alpha FLOAT8,
beta FLOAT8,
model INT8[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED RANDOMLY')
""".format(model_table=self.model_table))
plpy.execute("""
CREATE TABLE {output_data_table}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
topic_count INT4[],
topic_assignment INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (docid)')
""".format(output_data_table=self.output_data_table))
def init_random(self):
stime = time.time()
# plpy.notice('initializing topics randomly ...')
plpy.execute('TRUNCATE TABLE ' + self.work_table_0)
plpy.execute("""
INSERT INTO {work_table}
SELECT
docid, wordcount, words, counts,
{schema_madlib}.__lda_random_assign(wordcount, {topic_num}) AS topics
FROM {data_table}
""".format(work_table=self.work_table_0,
schema_madlib=self.schema_madlib,
topic_num=self.topic_num,
data_table=self.data_table))
etime = time.time()
# plpy.notice('\t\ttime elapsed: %.2f seconds' % (etime - stime))
def gen_output_data_table(self):
stime = time.time()
# plpy.notice('\t\tgenerating output data table ...')
work_table_final = self.work_table_1
if self.iter_num % 2 == 0:
work_table_final = self.work_table_0
plpy.execute("TRUNCATE TABLE " + self.output_data_table)
plpy.execute("""
INSERT INTO {output_data_table}
SELECT
docid, wordcount, words, counts, doc_topic[1:{topic_num}] topic_count,
doc_topic[{topic_num} + 1:array_upper(doc_topic,1)] topic_assignment
FROM
{work_table_final}
""".format(output_data_table=self.output_data_table,
topic_num=self.topic_num,
work_table_final=work_table_final))
etime = time.time()
# plpy.notice('\t\t\ttime elapsed: %.2f seconds' % (etime - stime))
def iteration(self, it):
stime = time.time()
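        # ping-pong between the two work tables: odd iterations read from
        # table 0 and write to table 1, even iterations the reverse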
work_table_in = self.work_table_0
if it % 2 == 0:
work_table_in = self.work_table_1
work_table_out = self.work_table_1
if it % 2 == 0:
work_table_out = self.work_table_0
# plpy.notice('iteration [%d] ...' % (it))
# plpy.notice('\t\tupdating global model...')
plpy.execute('TRUNCATE TABLE ' + self.model_table)
if version_wrapper.is_gp43() or version_wrapper.is_hawq():
with EnableOptimizer(True):
plpy.execute("""
INSERT INTO {model_table}
SELECT
{voc_size},
{topic_num},
{alpha},
{beta},
{schema_madlib}.__lda_count_topic_agg(
words,
counts,
doc_topic[{topic_num} + 1:array_upper(doc_topic, 1)],
{voc_size},
{topic_num}
) AS model
FROM {work_table_in}
""".format(model_table=self.model_table,
topic_num=self.topic_num,
voc_size=self.voc_size,
alpha=self.alpha,
beta=self.beta,
schema_madlib=self.schema_madlib,
work_table_in=work_table_in))
else:
            # work around an insertion memory error (MPP-25561) by
            # temporarily materializing the model in Python
model = plpy.execute("""
SELECT
{schema_madlib}.__lda_count_topic_agg(
words,
counts,
doc_topic[{topic_num} + 1:array_upper(doc_topic, 1)],
{voc_size},
{topic_num}
) AS model
FROM {work_table_in}
""".format(schema_madlib=self.schema_madlib,
topic_num=self.topic_num,
voc_size=self.voc_size,
work_table_in=work_table_in))[0]['model']
# insert it back immediately
plan = plpy.prepare("""
INSERT INTO {model_table}
SELECT
{voc_size}, {topic_num}, {alpha}, {beta}, $1
""".format(model_table=self.model_table,
topic_num=self.topic_num,
voc_size=self.voc_size,
alpha=self.alpha,
beta=self.beta,
schema_madlib=self.schema_madlib),
['bigint[]'])
plpy.execute(plan, [model])
        # For GPDB 4.3 and higher we disable the optimizer (ORCA) for the query
# planner since currently ORCA doesn't support InitPlan. This would have
# to be fixed when ORCA is the only available query planner.
with EnableOptimizer(False):
# plpy.notice('\t\tsampling ...')
plpy.execute('TRUNCATE TABLE ' + work_table_out)
query = """
INSERT INTO {work_table_out}
SELECT
docid, wordcount, words, counts,
{schema_madlib}.__lda_gibbs_sample(
words, counts, doc_topic,
(SELECT model FROM {model_table}),
{alpha}, {beta}, {voc_size}, {topic_num}, 1)
FROM
{work_table_in}
""".format(work_table_out=work_table_out,
schema_madlib=self.schema_madlib,
model_table=self.model_table,
alpha=self.alpha,
beta=self.beta,
voc_size=self.voc_size,
topic_num=self.topic_num,
work_table_in=work_table_in)
plpy.execute(query)
etime = time.time()
# plpy.notice('\t\ttime elapsed: %.2f seconds' % (etime - stime))
def run(self):
stime = time.time()
# plpy.notice('start training process ...')
self.init_random()
sstime = time.time()
for it in range(1, self.iter_num + 1):
self.iteration(it)
eetime = time.time()
# plpy.notice('\t\titeration done, time elapsed: %.2f seconds' % (eetime - sstime))
self.gen_output_data_table()
etime = time.time()
# plpy.notice('finished, time elapsed: %.2f seconds' % (etime - stime))
# ------------------------------------------------------------------------------
class LDAPredictor:
"""
    @brief This class defines an LDAPredictor
"""
def __init__(self, schema_madlib, test_table, model_table, output_table,
iter_num):
self.schema_madlib = schema_madlib
self.test_table = test_table
self.model_table = model_table
self.iter_num = iter_num
rv = plpy.execute("""
SELECT
voc_size, topic_num, alpha, beta
FROM
{model_table}
""".format(model_table=self.model_table))
self.voc_size = rv[0]['voc_size']
self.topic_num = rv[0]['topic_num']
self.alpha = rv[0]['alpha']
self.beta = rv[0]['beta']
self.doc_topic = output_table
self.work_table_0 = '__work_table_pred_0__'
        self.work_table_1 = '__work_table_pred_1__'
plpy.execute("DROP TABLE IF EXISTS " + self.work_table_0)
plpy.execute("""
CREATE TEMP TABLE {work_table_0}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
doc_topic INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (docid)')
""".format(work_table_0=self.work_table_0))
plpy.execute("DROP TABLE IF EXISTS " + self.work_table_1)
plpy.execute("""
CREATE TEMP TABLE {work_table_1}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
doc_topic INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (docid)')
""".format(work_table_1=self.work_table_1))
plpy.execute("""
CREATE TABLE {doc_topic}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[],
topic_count INT4[],
topic_assignment INT4[]
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (docid)')
""".format(doc_topic=self.doc_topic))
def init_random(self):
stime = time.time()
# plpy.notice('initializing topics randomly ...')
plpy.execute('TRUNCATE TABLE ' + self.work_table_0)
plpy.execute("""
INSERT INTO {work_table}
SELECT
docid, wordcount, words, counts,
{schema_madlib}.__lda_random_assign(wordcount, {topic_num}) AS topics
FROM {data_table}
""".format(work_table=self.work_table_0,
schema_madlib=self.schema_madlib,
topic_num=self.topic_num,
data_table=self.test_table))
etime = time.time()
# plpy.notice('\t\ttime elapsed: %.2f seconds' % (etime - stime))
def gen_output_table(self):
plpy.execute("TRUNCATE TABLE " + self.doc_topic)
plpy.execute("""
INSERT INTO {doc_topic}
SELECT
docid, wordcount, words, counts, doc_topic[1:{topic_num}] topic_count,
doc_topic[{topic_num} + 1:array_upper(doc_topic,1)] topic_assignment
FROM {work_table_out}
""".format(doc_topic=self.doc_topic,
topic_num=self.topic_num,
work_table_out=self.work_table_1))
def infer(self):
stime = time.time()
        # plpy.notice('inferring ...')
        # For GPDB 4.3 and higher we disable the optimizer (ORCA) for the query
# planner since currently ORCA doesn't support InitPlan. This would have
# to be fixed when ORCA is the only available query planner.
with EnableOptimizer(False):
query = """
INSERT INTO {work_table_out}
SELECT
docid, wordcount, words, counts,
{schema_madlib}.__lda_gibbs_sample(
words, counts, doc_topic,
(SELECT model FROM {model_table}),
{alpha}, {beta},
{voc_size}, {topic_num}, {iter_num}
)
FROM
{work_table_in}
""".format(work_table_out=self.work_table_1,
work_table_in=self.work_table_0,
model_table=self.model_table,
schema_madlib=self.schema_madlib,
alpha=self.alpha,
beta=self.beta,
voc_size=self.voc_size,
topic_num=self.topic_num,
iter_num=self.iter_num)
plpy.execute(query)
etime = time.time()
# plpy.notice('\t\ttime elapsed: %.2f seconds' % (etime - stime))
def run(self):
stime = time.time()
# plpy.notice('start prediction process ...')
self.init_random()
self.infer()
self.gen_output_table()
etime = time.time()
# plpy.notice('finished, time elapsed: %.2f seconds' % (etime - stime))
# ------------------------------------------------------------------------------
def lda_train(schema_madlib, train_table, model_table, output_data_table, voc_size,
topic_num, iter_num, alpha, beta):
"""
@brief This function provides the entry for the LDA training process.
    @param schema_madlib MADlib schema
    @param train_table Training data table
@param voc_size Size of vocabulary
@param topic_num Number of topics
@param iter_num Number of iterations
@param alpha Dirichlet parameter for per-document topic multinomial
@param beta Dirichlet parameter for per-topic word multinomial
@param model_table Learned model table
@param output_data_table Output data table
"""
_assert(train_table is not None and train_table.strip() != '',
'invalid argument: train_table is not specified')
_assert(model_table is not None and model_table.strip() != '',
'invalid argument: model_table is not specified')
_assert(output_data_table is not None and output_data_table.strip() != '',
'invalid argument: output_data_table is not specified')
_assert(voc_size is not None and voc_size > 0,
'invalid argument: positive integer expected for voc_size')
_assert(topic_num is not None and topic_num > 0,
'invalid argument: positive integer expected for topic_num')
_assert(iter_num is not None and iter_num > 0,
'invalid argument: positive integer expected for iter_num')
_assert(alpha is not None and alpha > 0,
'invalid argument: positive real expected for alpha')
_assert(beta is not None and beta > 0,
'invalid argument: positive real expected for beta')
output_tbl_valid(model_table, 'LDA')
output_tbl_valid(output_data_table, 'LDA')
warn(voc_size <= 1e5,
"""the voc_size is very large: %d - make sure that the system has
enough memory or reduce the vocabulary size""" % (voc_size))
warn(topic_num <= 1e3,
"""the topic_num is large: %d - make sure that the system has enough
memory or reduce the number of topics """ % (topic_num))
_validate_data_table(train_table, voc_size)
convt_table = _convert_data_table(schema_madlib, train_table)
lt = LDATrainer(schema_madlib, convt_table, model_table,
output_data_table, voc_size, topic_num,
iter_num, alpha, beta)
lt.run()
    # __lda_check_count_ceiling returns NULL when the count ceiling is
    # not hit; otherwise it returns a sample of the wordids that hit it
examples_hit_ceiling = plpy.execute("""
SELECT
array_to_string(
{schema_madlib}.__lda_check_count_ceiling(
model,
voc_size,
topic_num)::integer[],
','
) AS examples_hit_ceiling
FROM {model_table}
""".format(**locals()))[0]['examples_hit_ceiling']
if examples_hit_ceiling is not None:
plpy.info("""
LDA warning: some words ({examples_hit_ceiling}, etc.) occur extremely
frequently (more than 2e9 times), which can add noise to the LDA model.
Consider removing these very frequent words from the data.
""".format(**locals()))
def lda_predict(schema_madlib, test_table, model_table, output_data_table,
iter_num=20):
"""
@brief This function provides the entry for the LDA prediction process.
@param test_table name of the testing dataset table
@param model_table name of the model table
@param iter_num number of iterations
    @param output_data_table name of the output table
"""
input_tbl_valid(test_table, 'LDA')
input_tbl_valid(model_table, 'LDA')
output_tbl_valid(output_data_table, 'LDA')
iter_num = 20 if iter_num is None else iter_num
_assert(
iter_num >= 0,
        'invalid argument: non-negative integer expected for iter_num')
warn(
iter_num <= 20,
"""the iter_num is large: %d - a smaller iter_num (e.g. 20) should be
good enough""" % (iter_num))
_validate_model_table(model_table)
rv = plpy.execute('SELECT voc_size FROM ' + model_table)
voc_size = rv[0]['voc_size']
_validate_data_table(test_table, voc_size)
convt_table = _convert_data_table(schema_madlib, test_table)
lp = LDAPredictor(
schema_madlib, convt_table, model_table, output_data_table, iter_num)
lp.run()
def get_topic_desc(schema_madlib, model_table, vocab_table, desc_table,
top_k=15):
"""
@brief Get the per-topic description by top-k words
@param model_table The model table generated by the training process
@param vocab_table The vocabulary table
@param top_k The top k words for topic description
@param desc_table The output table for storing the per-topic word description
"""
_assert(model_table != '' and vocab_table != '' and desc_table != '',
"invalid argument: at least one of the table names is not specified")
_assert(top_k > 0, "invalid argument: Positive integer expected for top_k")
_validate_model_table(model_table)
_validate_vocab_table(vocab_table)
output_tbl_valid(desc_table, 'LDA')
plpy.execute("DROP TABLE IF EXISTS __lda_topic_word_count__")
plpy.execute("""
CREATE TEMP TABLE __lda_topic_word_count__(
topicid INT4,
word_count INT4[],
beta FLOAT8
)
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE, COMPRESSTYPE=QUICKLZ) DISTRIBUTED BY
(topicid)')
""")
plpy.execute("""
INSERT INTO __lda_topic_word_count__
SELECT
generate_series(1, topic_num) AS topicid,
{schema_madlib}.__lda_util_unnest_transpose(
model,
voc_size,
topic_num
) AS word_count,
beta
FROM {model_table}
""".format(
schema_madlib=schema_madlib,
model_table=model_table))
plpy.execute("DROP TABLE IF EXISTS __lda_topic_word_dist__")
plpy.execute("""
CREATE TEMP TABLE __lda_topic_word_dist__(
topicid INT4,
word_dist FLOAT[]
)
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE, COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (topicid)')
""")
plpy.execute("""
INSERT INTO __lda_topic_word_dist__
SELECT
topicid,
{schema_madlib}.__lda_util_norm_with_smoothing(word_count, beta) dist
FROM
__lda_topic_word_count__
""".format(schema_madlib=schema_madlib))
plpy.execute("""
CREATE TABLE {desc_table} (
topicid INT4,
wordid INT4,
prob FLOAT,
word TEXT)
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE, COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (topicid)')
""".format(desc_table=desc_table))
plpy.execute("""
INSERT INTO {desc_table}
SELECT
topicid, t2.wordid, prob, word
FROM
(
SELECT
topicid, wordid, prob,
rank() OVER(PARTITION BY topicid ORDER BY prob DESC) r
FROM
(
SELECT
topicid,
generate_series(0, array_upper(word_dist, 1) - 1) wordid,
unnest(word_dist) prob
FROM
__lda_topic_word_dist__
) t1
) t2, {vocab_table} AS vocab
            WHERE t2.r <= {top_k} AND t2.wordid = vocab.wordid
""".format(desc_table=desc_table,
vocab_table=vocab_table, top_k=top_k))
def get_topic_word_count(schema_madlib, model_table, output_table):
"""
@brief Get the per-topic word counts from the model table
@param model_table The model table generated by the training process
@param output_table The output table for storing the per-topic word counts
"""
_assert(
model_table != '',
'invalid argument: model table name is not specified')
_validate_model_table(model_table)
output_tbl_valid(output_table, 'LDA')
plpy.execute("""
CREATE TABLE {output_table} (
topicid INT4,
word_count INT4[])
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE, COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (topicid)')
""".format(output_table=output_table))
plpy.execute("""
INSERT INTO {output_table}
SELECT
generate_series(1, topic_num) topicid,
{schema_madlib}.__lda_util_unnest_transpose(
model,
voc_size,
topic_num
) AS word_count
FROM
{model_table}
""".format(output_table=output_table, schema_madlib=schema_madlib,
model_table=model_table))
def get_word_topic_count(schema_madlib, model_table, output_table):
"""
@brief Get the per-word topic counts from the model table
@param model_table The model table generated by the training process
@param output_table The output table for storing the per-word topic counts
"""
_assert(model_table != '',
"invalid argument: model table name is not specified")
_validate_model_table(model_table)
output_tbl_valid(output_table, 'LDA')
plpy.execute("""
CREATE TABLE {output_table} (
wordid INT4,
topic_count INT4[])
m4_ifdef(
`__POSTGRESQL__', `',
`WITH (APPENDONLY=TRUE, COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (wordid)')
""".format(output_table=output_table))
plpy.execute("""
INSERT INTO {output_table}
SELECT
generate_series(0, voc_size - 1) wordid,
{schema_madlib}.__lda_util_unnest(model, voc_size, topic_num) word_count
FROM {model_table}
""".format(output_table=output_table, schema_madlib=schema_madlib,
model_table=model_table))
def get_perplexity(schema_madlib, model_table, output_data_table):
"""
@brief Get the perplexity given the prediction and model.
@param model_table The model table generated by lda_train
@param output_data_table The output data table generated by lda_predict
"""
_assert(model_table != '' and output_data_table != '',
'invalid argument: at least one of the table names is not specified')
_validate_model_table(model_table)
params = plpy.execute("""
SELECT topic_num, voc_size, alpha, beta FROM {model_table}
""".format(model_table=model_table))[0]
topic_num = params['topic_num']
voc_size = params['voc_size']
alpha = params['alpha']
beta = params['beta']
_validate_output_data_table(output_data_table, topic_num)
    # For GPDB 4.3 and higher we disable the optimizer (ORCA) for the query
# planner since currently ORCA doesn't support InitPlan. This would have
# to be fixed when ORCA is the only available query planner.
with EnableOptimizer(False):
query = """
SELECT exp(-part_perp/total_word) AS perp
FROM
(
SELECT {schema_madlib}.__lda_perplexity_agg(
words, counts, topic_count, (SELECT model FROM {model_table}),
{alpha}, {beta}, {voc_size}, {topic_num}) AS part_perp
FROM
{out_data_table}
) t1,
(
SELECT sum(wordcount) total_word FROM {out_data_table}
) t2
""".format(schema_madlib=schema_madlib,
out_data_table=output_data_table,
model_table=model_table,
alpha=alpha,
beta=beta,
topic_num=topic_num,
voc_size=voc_size)
rv = plpy.execute(query)
return rv[0]['perp']
def norm_vocab(vocab_table, out_table):
"""
    @brief Checks the vocabulary and converts non-continuous wordids into continuous
integers ranging from 0 to voc_size - 1.
@param vocab_table The vocabulary table in the form of
<wordid::INT4, word::text>
@param out_table The normalized vocabulary table
"""
_validate_vocab_table(vocab_table)
output_tbl_valid(out_table, 'LDA')
plpy.execute("""
CREATE TABLE {out_table}(
wordid INT4,
old_wordid INT4,
word TEXT
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH(APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (wordid)')
""".format(out_table=out_table))
plpy.execute("""
INSERT INTO {out_table}
SELECT r - 1, wordid, word
FROM
(
SELECT
wordid, word, rank() OVER(ORDER BY wordid) r
FROM
(
SELECT wordid, word
FROM
(
SELECT
wordid,
word,
rank() OVER(PARTITION BY wordid ORDER BY word ASC) r
FROM
{vocab_table}
) t1
WHERE r = 1
) t2
) t3
""".format(out_table=out_table, vocab_table=vocab_table))
def norm_dataset(data_table, vocab_table, output_table):
"""
@brief Normalize the data table according to the normalized vocabulary, rows
with non-positive count values will be removed
@param data_table The data table to be normalized
    @param vocab_table The normalized vocabulary table
@param output_table The normalized data table
"""
_validate_data_table_cols(data_table)
_validate_norm_vocab_table(vocab_table)
output_tbl_valid(output_table, 'LDA')
plpy.execute("""
CREATE TABLE {output_table}(
docid INT4,
wordid INT4,
count INT4
)
m4_ifdef(`__POSTGRESQL__', `',
`WITH(APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (docid)')
""".format(output_table=output_table))
plpy.execute("""
INSERT INTO {output_table}
SELECT
docid,
vocab.wordid,
count
FROM
{data_table} AS data,
{vocab_table} AS vocab
WHERE
data.wordid = vocab.old_wordid AND data.count > 0
""".format(output_table=output_table,
data_table=data_table,
vocab_table=vocab_table))
def conorm_data(data_table, vocab_table, output_data_table, output_vocab_table):
"""
@brief Co-normalize the data table and the vocabulary table
@param data_table The data table to be normalized
    @param vocab_table The vocabulary table to be normalized
@param output_data_table The normalized data table
@param output_vocab_table The normalized vocabulary table
"""
_validate_data_table_cols(data_table)
_validate_vocab_table(vocab_table)
output_tbl_valid(output_data_table, 'LDA')
output_tbl_valid(output_vocab_table, 'LDA')
plpy.execute("DROP TABLE IF EXISTS __vocab__")
plpy.execute("""
CREATE TEMP TABLE __vocab__(
wordid INT4,
word TEXT
)
m4_ifdef(`__POSTGRESQL__', `', `WITH(APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (wordid)')
""")
plpy.execute("""
INSERT INTO __vocab__
SELECT voc.wordid, voc.word
FROM
(
SELECT
wordid
FROM
{data_table} AS data
GROUP BY wordid
) tvoc, {vocab_table} voc
WHERE tvoc.wordid = voc.wordid
""".format(data_table=data_table, vocab_table=vocab_table))
norm_vocab('__vocab__', output_vocab_table)
norm_dataset(data_table, output_vocab_table, output_data_table)
def index_sort(vector):
"""
    @brief Return the 1-based indices of the array elements in ascending order
    @param vector The array to be sorted
    @return The permutation of indices that sorts the array
"""
# process arrays for GPDB < 4.1 and PG < 9.0
vector = string_to_array(vector, False)
dim = len(vector)
    idx = sorted(range(dim), key=lambda r: vector[r])
    # shift to 1-based indexing for SQL arrays
    return array_to_string([r + 1 for r in idx])
def _convert_data_table(schema_madlib, data_table):
"""
@brief Convert the format of data table from <docid, wordid, count> to <docid,
wordcount, words, counts>.
@param data_table The data table to be converted
    @return The converted table name
"""
# plpy.notice('converting the data table ...')
convt_table = '__lda_convt_corpus__'
plpy.execute("DROP TABLE IF EXISTS " + convt_table)
plpy.execute("""
CREATE TEMP TABLE {convt_table}(
docid INT4,
wordcount INT4,
words INT4[],
counts INT4[]
)
m4_ifdef(`__POSTGRESQL__', `', `WITH(APPENDONLY=TRUE,COMPRESSTYPE=QUICKLZ)
DISTRIBUTED BY (docid)')
""".format(convt_table=convt_table))
with EnableOptimizer(False):
with EnableHashagg(False):
plpy.execute("""
INSERT INTO {convt_table}
SELECT
docid,
sum(count) wordcount,
{schema_madlib}.array_agg(wordid) words,
{schema_madlib}.array_agg(count) counts
FROM
{data_table}
WHERE
(docid IS NOT NULL) AND
(wordid IS NOT NULL) AND
(count IS NOT NULL)
GROUP BY docid
""".format(convt_table=convt_table,
schema_madlib=schema_madlib,
data_table=data_table))
return convt_table
def _validate_data_table_cols(data_table):
"""
@brief Check the structure of the data table
@param data_table The data table name
"""
# plpy.notice('checking the data table ... ')
_assert(data_table is not None and data_table.strip() != '',
'no data table is specified')
try:
rv = plpy.execute("""
SELECT count(*) cnt FROM pg_attribute
WHERE
attrelid = '{data_table}'::regclass AND
((atttypid = 'INT4'::regtype AND attname = 'docid') OR
(atttypid = 'INT4'::regtype AND attname = 'wordid') OR
(atttypid = 'INT4'::regtype AND attname = 'count'))
""".format(data_table=data_table))
_assert(rv[0]['cnt'] == 3,
"Table %s should have docid::INT4, wordid::INT4, count::INT4"
"columns" % (data_table))
except:
_assert(False,
"Table %s must exist and should have docid, wordid, and "
"count columns" % (data_table))
def _validate_data_table(data_table, voc_size):
"""
@brief Check the validity of the data table
@param data_table The data table
@param voc_size The size of vocabulary
"""
# plpy.notice('validating the data table ...')
_validate_data_table_cols(data_table)
rv = plpy.execute('SELECT count(*) cnt FROM %s WHERE docid < 0' % (data_table))
warn(
rv[0]['cnt'] == 0,
"""%d rows have negative docid - use continous non-negative
integers for docid for better interpretation""" % (rv[0]['cnt']))
rv = plpy.execute('SELECT count(*) cnt FROM %s WHERE wordid < 0' % (data_table))
_assert(
rv[0]['cnt'] == 0,
"""%d rows have negative wordid - the wordid must range from 0 to
voc_size - 1""" % (rv[0]['cnt']))
rv = plpy.execute("""
SELECT
count(wordid) size,
min(wordid) min_wordid,
max(wordid) max_wordid
FROM
(
SELECT wordid
FROM {data_table}
GROUP BY wordid
) t1
""".format(data_table=data_table))
_assert(0 < rv[0]['size'], 'Table %s is empty' % (data_table))
_assert(
rv[0]['size'] <= voc_size,
"""The actual size of vocabulary %d is larger than the specified %d -
set the correct voc_size and try again""" % (rv[0]['size'], voc_size))
_assert(
rv[0]['min_wordid'] <= voc_size - 1 and
rv[0]['max_wordid'] <= voc_size - 1,
"""The wordid should be in the range of 0 to %d""" % (voc_size - 1))
warn(
rv[0]['size'] == voc_size,
"""Actual size of vocabulary %d is smaller than the specified %d -
the vocabulary and dataset normalization is highly recommended for
memory efficiency""" % (rv[0]['size'], voc_size))
warn(
rv[0]['min_wordid'] == 0,
"""The actual min wordid is large than 0 - the vocabulary and dataset
normalization is highly recommended for memory efficiency""")
warn(rv[0]['max_wordid'] == voc_size - 1,
"""Actual max word_id is smaller than the specified
(voc_size - 1 = %d) - set the voc_size to %d for memory efficiency
""" % (voc_size - 1, rv[0]['max_wordid'] + 1))
rv = plpy.execute('SELECT count(*) cnt FROM %s WHERE count <= 0' % (data_table))
_assert(
rv[0]['cnt'] == 0,
"""%d rows have zero or negative count - the value in the count column
must be positive integers""" % (rv[0]['cnt']))
def _validate_vocab_table(vocab_table):
"""
@brief Check the validity of the vocabulary table
@param vocab_table The vocabulary table name
"""
# plpy.notice('checking the vocabulary table ...')
try:
rv = plpy.execute("""
SELECT count(*) cnt FROM pg_attribute
WHERE
attrelid = '{vocab_table}'::regclass AND
((atttypid = 'INT4'::regtype AND attname = 'wordid') OR
(atttypid = 'text'::regtype AND attname = 'word'))
""".format(vocab_table=vocab_table))
_assert(
rv[0]['cnt'] == 2,
"""the %s should have wordid::INT4, word::TEXT columns""" % (vocab_table))
except:
_assert(0, "Table %s must exist and should have wordid and "
"word columnes" % (vocab_table))
def _validate_norm_vocab_table(vocab_table):
"""
@brief Check the validity of the normalized vocabulary table
@param vocab_table The normalized vocabulary table name
"""
try:
rv = plpy.execute("""
SELECT count(*) cnt FROM pg_attribute
WHERE
attrelid = '{vocab_table}'::regclass AND
((atttypid = 'INT4'::regtype AND attname = 'wordid') OR
(atttypid = 'INT4'::regtype AND attname = 'old_wordid') OR
(atttypid = 'text'::regtype AND attname = 'word'))
""".format(vocab_table=vocab_table))
_assert(
rv[0]['cnt'] == 3,
"""the %s should have wordid, old_wordid, and word columns""" %
(vocab_table))
except:
_assert(0, "Table %s must exist and should have wordid, old_wordid, "
"and word columns" % (vocab_table))
def _validate_output_data_table(output_data_table, topic_num):
"""
@brief Check the validity of the output data table
    @param output_data_table Output data table name
@param topic_num Topic number
"""
try:
rv = plpy.execute("""
SELECT count(*) cnt
FROM pg_attribute
WHERE
attrelid = '{output_data_table}'::regclass AND
((atttypid = 'INT4[]'::regtype AND attname = 'words') OR
(atttypid = 'INT4[]'::regtype AND attname = 'counts') OR
(atttypid = 'INT4[]'::regtype AND attname = 'topic_count'))
""".format(output_data_table=output_data_table))
_assert(
rv[0]['cnt'] == 3,
"Table %s must have words::INT4[], counts::INT4[], and"
" topic_count::INT4[] columns" % (output_data_table))
except:
_assert(0,
"Table %s must exist and should have words, counts, and "
" topic_count columns" % (output_data_table))
rv = plpy.execute("""
SELECT min(dim) min_dim, max(dim) max_dim
FROM
(
SELECT array_upper(topic_count, 1) dim
FROM {output_data_table}
) t1
""".format(output_data_table=output_data_table))
_assert(rv[0]['min_dim'] == rv[0]['max_dim'] and
rv[0]['min_dim'] == topic_num,
"Dimension mismatch - array_upper(topic_count, 1) <> topic_num")
def _validate_model_table(model_table):
"""
@brief Check the validity of the model table
@param model_table Model table name
"""
# plpy.notice('checking the model table ...')
try:
rv = plpy.execute("""
SELECT count(*) cnt
FROM pg_attribute
WHERE
attrelid = '{model_table}'::regclass AND
((atttypid = 'INT4'::regtype AND attname = 'voc_size') OR
(atttypid = 'INT4'::regtype AND attname = 'topic_num') OR
(atttypid = 'FLOAT8'::regtype AND attname = 'alpha') OR
(atttypid = 'FLOAT8'::regtype AND attname = 'beta') OR
(atttypid = 'INT8[]'::regtype AND attname = 'model'))
""".format(model_table=model_table))
_assert(
rv[0]['cnt'] == 5,
"""Table %s must have voc_size::INT4, topic_num::INT4, alpha::FLOAT8,
beta::FLOAT8, and model::INT8[] columns""" % (model_table))
except:
_assert(0,
"""Table %s must exist and should have voc_size, topic_num, alpha,
beta, word_topic, and corpus_topic columns""" % (model_table))
rv = plpy.execute("""
SELECT voc_size, topic_num, alpha, beta,
array_upper(model, 1) model_size
FROM %s
""" % (model_table))
_assert(len(rv) > 0, '%s is empty' % (model_table))
_assert(len(rv) == 1, '%s should have only 1 row' % (model_table))
_assert(rv[0]['voc_size'] > 0,
'voc_size in %s should be a positive integer' % (model_table))
_assert(rv[0]['topic_num'] > 0,
'topic_num in %s should be a positive integer' % (model_table))
_assert(rv[0]['alpha'] > 0,
'alpha in %s should be a positive real number' % (model_table))
_assert(rv[0]['beta'] > 0,
'beta in %s should be a positive real number' % (model_table))
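    # the expected length reads off the model layout: voc_size *
    # (topic_num + 1) 32-bit counts, presumably packed two per INT8 element
    # (the +1 rounds the integer division up)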
_assert(rv[0]['model_size'] == ((rv[0]['voc_size']) * (rv[0]['topic_num'] + 1) + 1) / 2,
"model_size mismatches with voc_size and topic_num in %s" % (model_table))