# blob: 8e39a137b02e30c368bb08054fab4bd6a33dc79f [file] [log] [blame]
import plpy
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.utilities import _assert
def vcrf_top1_table(schema_madlib, segment_tbl, label_tbl, resulttbl_raw, result_tbl, **kwargs):
    """
    @brief Create result_tbl by joining the raw label arrays with the
           segment and label tables

    Two statements are executed:
      1. (Re)create the helper table pg_temp._madlib_segcount_tbl holding,
         for every doc_id in segment_tbl, max_pos = count(*) - 1.
      2. Create result_tbl with the columns doc_id, start_pos, seg_text,
         label, id, max_pos and prob. A segment is matched to a label row
         when result.label[start_pos + 1] equals the label id; prob is
         result.label[max_pos + 2] cast to float and divided by 1000000.
         NOTE(review): this presumably relies on the raw array carrying a
         scaled score right after the per-position label ids -- confirm
         against the function that fills resulttbl_raw.

    The helper count table is left in pg_temp after the call.

    @param schema_madlib Schema used to qualify create_schema_pg_temp()
    @param segment_tbl Table providing doc_id, start_pos and seg_text
    @param label_tbl Table providing id and label
    @param resulttbl_raw Table providing doc_id and the label array
    @param result_tbl Name of the table to create
    @param kwargs Accepted for interface compatibility; not used
    """
    plpy.execute("SELECT {schema_madlib}.create_schema_pg_temp();".format(schema_madlib = schema_madlib))

    # Per-document segment count; the SQL text is kept flush-left so the
    # statement sent to the server is unchanged.
    seg_count_tbl = "pg_temp._madlib_segcount_tbl"
    plpy.execute("""
DROP TABLE IF EXISTS {seg_count_tbl};
CREATE TABLE {seg_count_tbl} AS
SELECT doc_id, count(*) -1 max_pos
FROM {segment_tbl}
GROUP BY doc_id
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')
""".format(seg_count_tbl = seg_count_tbl,
           segment_tbl = segment_tbl))

    # No label count is computed here: nothing below needs it, and the join
    # reads the label ids directly, so an extra COUNT(*) scan of label_tbl
    # would be wasted work.
    query = """create table {result_tbl} AS
select segs.doc_id, start_pos, seg_text, L.label, L.id as id, count_tbl.max_pos, (result.label[count_tbl.max_pos+2]::float/1000000) as prob
from {segment_tbl} segs, {label_tbl} L, {resulttbl_raw} result, {seg_count_tbl} count_tbl
where result.label[segs.start_pos+1]=L.id and segs.doc_id=result.doc_id and segs.doc_id = count_tbl.doc_id
order by doc_id, start_pos;""".format(result_tbl = result_tbl,
                                      segment_tbl = segment_tbl,
                                      label_tbl = label_tbl,
                                      resulttbl_raw = resulttbl_raw,
                                      seg_count_tbl = seg_count_tbl)
    plpy.execute(query)
def vcrf_label(schema_madlib, segment_tbl, factor_mtbl, factor_rtbl, label_tbl, result_tbl, **kwargs):
    """
    @brief Run the top-1 labelling pipeline over segment_tbl and write the
           outcome to result_tbl

    Steps, in execution order:
      1. Validate the arguments with _validate_args.
      2. Save the current client_min_messages setting, switch it to
         warning, and restore the saved value as the very last statement.
         NOTE(review): the restore is skipped when an earlier statement
         raises; presumably the aborted transaction undoes the SET --
         confirm.
      3. (Re)create four working tables in pg_temp.
      4. Copy the segments, replacing every seg_text that matches one of
         the two numeric patterns with the keyword DIGIT.
      5. Collect, per doc_id, the R factor scores ordered by start_pos and
         label into one array; copy the M factor scores.
      6. Count the rows of label_tbl and call vcrf_top1_label for each
         document, storing the returned array in the raw result table.
      7. Call vcrf_top1_table through SQL to create result_tbl.

    @param schema_madlib Schema used to qualify the MADlib function calls
    @param segment_tbl Table providing doc_id, start_pos and seg_text
    @param factor_mtbl Table providing the M factor score column
    @param factor_rtbl Table providing seg_text, label and score
    @param label_tbl Table providing id and label
    @param result_tbl Name of the output table to create
    @param kwargs Accepted for interface compatibility; not used
    """
    _validate_args(segment_tbl, factor_mtbl, factor_rtbl, label_tbl, result_tbl)

    # Quieten notices for the duration of the call; remember the old level.
    saved_msg_level = plpy.execute("SELECT setting AS setting FROM pg_settings WHERE name = \'client_min_messages\';")
    plpy.execute("SET client_min_messages TO warning;")

    plpy.execute("SELECT {schema_madlib}.create_schema_pg_temp();".format(schema_madlib = schema_madlib))

    # Working tables, all living in the session-local pg_temp schema.
    m_scores_tbl = "pg_temp._madlib_m_factors"
    r_scores_tbl = "pg_temp._madlib_r_factors"
    digits_tbl = "pg_temp._madlib_segment_tbl_digits"
    raw_labels_tbl = "pg_temp._madlib_resulttbl_raw"

    working_tables = [m_scores_tbl, r_scores_tbl, digits_tbl, raw_labels_tbl]
    plpy.execute("DROP TABLE IF EXISTS " + ",".join(working_tables) + ";")

    plpy.execute("""CREATE TABLE {raw_labels_tbl}(doc_id integer, label integer[])
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')""".format(raw_labels_tbl = raw_labels_tbl))

    # replace digits with "DIGIT" keyword
    # The two statements below contain regex quantifiers written with curly
    # braces, so they are assembled by concatenation instead of format().
    # First the segments that do NOT look numeric are copied unchanged ...
    copy_text_sql = ("""CREATE TABLE """ + digits_tbl + """ AS SELECT start_pos, doc_id, seg_text FROM """ + segment_tbl + """ WHERE
NOT (seg_text ~ E'^[-+]?([0-9]{1,3}[,]?)*[0-9]{1,3}$' OR seg_text ~ E'^[-+]?[0-9]*[.][0-9]+$')
m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')""")
    plpy.execute(copy_text_sql)

    # ... then the numeric-looking ones are added with the DIGIT keyword.
    copy_digit_sql = ("""INSERT INTO """ + digits_tbl + """ SELECT start_pos,doc_id,'DIGIT' FROM """ + segment_tbl + """ WHERE
seg_text ~ E'^[-+]?([0-9]{1,3}[,]?)*[0-9]{1,3}$' OR seg_text ~E'^[-+]?[0-9]*[.][0-9]+$';""")
    plpy.execute(copy_digit_sql)

    # One score array per document; the macro picks an ordered aggregate
    # when the platform has one and a correlated array subquery otherwise.
    r_scores_sql = """
-- for each sentence, store array representation of r_factors
m4_ifdef(`__HAS_ORDERED_AGGREGATES__', `
select doc_id, array_agg(score order by start_pos, label) as score
', `
select doc_id, array(
select score
from {factor_rtbl} factors,
{digits_tbl} seg
where factors.seg_text = seg.seg_text
and doc_id = ss.doc_id
order by start_pos, label
) as score
')
into {r_scores_tbl}
from (select doc_id, start_pos, label, score
from {factor_rtbl} factors,
{digits_tbl} seg
where factors.seg_text=seg.seg_text) as ss
group by doc_id
order by doc_id
""".format(factor_rtbl = factor_rtbl,
           digits_tbl = digits_tbl,
           r_scores_tbl = r_scores_tbl)
    plpy.execute(r_scores_sql)
    plpy.execute("analyze " + r_scores_tbl + ";")

    m_scores_sql = """
-- array representation of m_factor
select score
into {m_scores_tbl}
from (select score
from {factor_mtbl} factors) as ss
""".format(m_scores_tbl = m_scores_tbl,
           factor_mtbl = factor_mtbl)
    plpy.execute(m_scores_sql)

    # The number of labels is handed to the labelling function below.
    label_count = plpy.execute('SELECT COUNT(*) AS total FROM ' + label_tbl)[0]['total']

    label_sql = """
INSERT INTO {resulttbl_raw}
SELECT doc_id, {schema_madlib}.vcrf_top1_label(mfactors.score, rfactors.score, {nlabel})
FROM {m_factors} mfactors, {r_factors} rfactors
""".format(schema_madlib = schema_madlib,
           resulttbl_raw = raw_labels_tbl,
           m_factors = m_scores_tbl,
           r_factors = r_scores_tbl,
           nlabel = str(label_count))
    plpy.execute(label_sql)

    # Turn the raw arrays into the user-facing table via the SQL wrapper.
    publish_sql = "SELECT * FROM {schema_madlib}.vcrf_top1_table('{segment_tbl}', '{label_tbl}', '{resulttbl_raw}', '{result_tbl}')".format(
        schema_madlib = schema_madlib,
        segment_tbl = segment_tbl,
        label_tbl = label_tbl,
        resulttbl_raw = raw_labels_tbl,
        result_tbl = result_tbl)
    plpy.execute(publish_sql)

    # Put the message level back to what the caller had.
    previous_level = str(saved_msg_level[0]['setting'])
    plpy.execute("SET client_min_messages TO " + previous_level + ";")
def _validate_columns(cols, table_name, err_msg_tbl):
    """
    @brief Assert that all columns in cols exist in table_name

    @param cols List of column names that must be present
    @param table_name Table that is checked for the columns
    @param err_msg_tbl Short table description inserted into the error text
    """
    all_present = columns_exist_in_table(table_name, cols)
    failure_text = "CRF error: Missing required columns from %s table: %s" % (
        err_msg_tbl, ', '.join(cols))
    _assert(all_present, failure_text)
def _validate_args(segment_tbl, factor_mtbl, factor_rtbl, label_tbl, result_tbl):
    """
    @brief Check the table arguments of the labelling entry point

    Checks run in this order and stop at the first failing assertion:
    the four input tables exist, each input table has its required columns,
    the result table name is neither None nor blank nor the word null
    (case-insensitive), and no table with that name exists yet.

    @param segment_tbl Must exist with doc_id, start_pos and seg_text
    @param factor_mtbl Must exist with score
    @param factor_rtbl Must exist with seg_text, label and score
    @param label_tbl Must exist with id and label
    @param result_tbl Output table name; must be usable and not yet taken
    """
    # Check existence of input tables.
    existence_checks = (
        (segment_tbl, "CRF error: Segment table does not exist!"),
        (factor_mtbl, "CRF error: M Factor table does not exist!"),
        (factor_rtbl, "CRF error: R Factor table does not exist!"),
        (label_tbl, "CRF error: Label table does not exist!"),
    )
    for tbl, missing_text in existence_checks:
        _assert(table_exists(tbl), missing_text)

    # validate required column names for existence
    column_checks = (
        (['doc_id', 'start_pos', 'seg_text'], segment_tbl, "segment"),
        (['seg_text', 'label', 'score'], factor_rtbl, "R factor"),
        (['score'], factor_mtbl, "M factor"),
        (['id', 'label'], label_tbl, "label"),
    )
    for required_cols, tbl, description in column_checks:
        _validate_columns(required_cols, tbl, description)

    # The output name has to be a real name ...
    name_is_usable = (result_tbl is not None and
                      result_tbl.lower().strip() not in ('null', ''))
    _assert(name_is_usable, "CRF error: Invalid result table name")

    # ... and must not collide with an existing table.
    name_is_free = not table_exists(result_tbl, only_first_schema=True)
    _assert(name_is_free,
            "CRF error: Result table name already exist!"
            " Please provide a different table name.")