blob: e8b6c902ab74a7d7430663c363324713295c5950 [file] [log] [blame]
/* ----------------------------------------------------------------------- *//**
*
* @file viterbi.sql_in
* @brief Concatenate a set of input values into arrays to feed into the Viterbi C
* function, and create a human-readable view of the output
* @date February 2012
*
*
*//* ----------------------------------------------------------------------- */
m4_include(`SQLCommon.m4')
/**
* @brief Creates a human-readable view over the raw output of the Viterbi function.
*
* Each segment is joined to the label whose id equals element start_pos+1 of
* the result array for the same doc_id. The view exposes the label id plus one
* as the column id, and element max_pos+2 of the result array, divided by
* 1000000, as the column prob.
*
* @param segtbl Name of table containing all the testing sentences.
* @param labeltbl Name of table containing all the labels in the label space.
* @param result_tbl Name of table storing the best label sequence and the conditional probability.
* @param vw Name of the human-readable view of the output.
* @returns The name of the created view (the value of vw).
*/
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.vcrf_top1_view (segtbl text, labeltbl text, result_tbl text, vw text) returns text AS
$$
# The relation names are spliced into the statement as-is, so callers must pass
# trusted, valid (optionally schema-qualified) identifiers.
query = """create view """ + vw + """ AS
    select segs.doc_id, start_pos, seg_text, L.label, (L.id+1) as id,
           (result.label[max_pos+2]::float/1000000) as prob
    from """ + segtbl + """ segs
    inner join """ + result_tbl + """ result
        on segs.doc_id = result.doc_id
    inner join """ + labeltbl + """ L
        on result.label[segs.start_pos+1] = L.id
    order by doc_id, start_pos;"""
plpy.execute(query)
return vw
$$ language plpythonu strict;
/**
* @brief Viterbi algorithm implemented in C: takes the sentence to be labeled
* as input and returns the top1 labeling for that sentence.
* @param mArray Array containing the m factors
* @param rArray Array containing the r factors
* @param nlabel Total number of labels in the label space
* @returns The top1 label sequence as an integer array; the last two elements
* of the array are used to calculate the top1 probability
*/
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.vcrf_top1_label(mArray int[], rArray int[], nlabel int)
RETURNS int[]
AS 'MODULE_PATHNAME'
LANGUAGE c STRICT;
/**
* @brief Prepares the inputs for the C function vcrf_top1_label, invokes it,
* and publishes the result as a human-readable view.
*
* Side effects: drops and recreates the tables _m_factors, _r_factors and
* resulttbl_raw (the value of resulttbl with the suffix _raw), then creates a
* view named resulttbl on top of resulttbl_raw through vcrf_top1_view.
*
* @param segtbl Name of table containing all the testing sentences.
* @param factor_mtbl Name of table containing all the m factors.
* @param factor_rtbl Name of table containing all the r factors.
* @param labeltbl Name of table containing all the labels in the label space.
* @param resulttbl Name of the view to create for the output; the raw label
* arrays are stored in the table named resulttbl_raw.
* @returns Name of the view holding the output (the value of resulttbl).
*/
CREATE OR REPLACE FUNCTION
MADLIB_SCHEMA.vcrf_label(segtbl text, factor_mtbl text, factor_rtbl text, labeltbl text, resulttbl text) RETURNS text AS
$$
# Working tables. NOTE(review): the two factor tables use fixed names, so
# concurrent calls resolving to the same schema would presumably clash -- confirm.
m_factors = "_m_factors"
r_factors = "_r_factors"
resulttbl_raw = resulttbl + "_raw"
# CASCADE also removes the output view left over from an earlier run with the
# same resulttbl, because that view selects from the raw result table.
plpy.execute("DROP TABLE IF EXISTS " + m_factors +","+ r_factors + "," + resulttbl_raw +" CASCADE;")
query = """
-- for each sentence, store array representation of r_factors
m4_ifdef(`__HAS_ORDERED_AGGREGATES__', `
select doc_id, array_agg(score order by start_pos, label) as score
', `
select doc_id, array(
select score
from """ + factor_rtbl + """ factors,
""" + segtbl + """ seg
where factors.seg_text = seg.seg_text
and doc_id = ss.doc_id
order by start_pos, label
) as score
')
into """ + r_factors + """
from (select doc_id, start_pos, label, score
from """ + factor_rtbl + """ factors,
""" + segtbl + """ seg
where factors.seg_text=seg.seg_text) as ss
group by doc_id
order by doc_id;"""
plpy.execute(query)
plpy.execute("analyze " + r_factors + ";")
query = """
-- array representation of m_factor
select score
into """ + m_factors + """
from (select score
from """ + factor_mtbl + """ factors) as ss; """
plpy.execute(query)
plpy.execute("CREATE TABLE " + resulttbl_raw + " (doc_id integer, label integer[]);")
# The size of the label space is passed to the C routine as a literal.
rv = plpy.execute('SELECT COUNT(*) AS total FROM ' + labeltbl)
nlabel = rv[0]['total']
# Pair every per-document r-factor array with the m-factor rows and run the
# C Viterbi routine on each pair.
query = """INSERT INTO """ + resulttbl_raw + """ (doc_id, label)
SELECT doc_id, MADLIB_SCHEMA.vcrf_top1_label(mfactors.score, rfactors.score, """ + str(nlabel) + """)
FROM """ + m_factors + """ mfactors CROSS JOIN """ + r_factors + """ rfactors;"""
plpy.execute(query)
# The relation names are handed to vcrf_top1_view as text values, so bind them
# as parameters instead of hand-quoting them into the statement; a name that
# contains a single quote then no longer breaks the call.
view_plan = plpy.prepare(
    "SELECT MADLIB_SCHEMA.vcrf_top1_view($1, $2, $3, $4)",
    ["text", "text", "text", "text"])
plpy.execute(view_plan, [segtbl, labeltbl, resulttbl_raw, resulttbl])
# The function is declared RETURNS text: report the name of the created view.
return resulttbl
$$ LANGUAGE plpythonu STRICT;