| import plpy |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.utilities import _assert |
| |
def vcrf_top1_table(schema_madlib, segment_tbl, label_tbl, resulttbl_raw, result_tbl, **kwargs):
    """
    @brief Turn the raw Viterbi output into a per-token result table.

    Joins resulttbl_raw (one integer array per doc_id) back to the segment
    table and the label table, creating result_tbl with one row per segment:
    doc_id, start_pos, seg_text, label, id, max_pos and prob.

    @param schema_madlib  Schema where MADlib is installed
    @param segment_tbl    Segment table (doc_id, start_pos, seg_text)
    @param label_tbl      Label table (id, label)
    @param resulttbl_raw  Table of (doc_id, label integer[]) from the Viterbi step
    @param result_tbl     Name of the table to create
    """
    plpy.execute("SELECT {schema_madlib}.create_schema_pg_temp();".format(
        schema_madlib=schema_madlib))

    # Scratch table: number of segments per document minus one. The label
    # array is indexed with start_pos+1 and the value right after the last
    # label (index max_pos+2) is read as prob. This presumably assumes start_pos
    # runs 0..max_pos without gaps -- NOTE(review): confirm.
    seg_count_tbl = "pg_temp._madlib_segcount_tbl"
    plpy.execute("""
        DROP TABLE IF EXISTS {seg_count_tbl};
        CREATE TABLE {seg_count_tbl} AS
        SELECT doc_id, count(*) -1 max_pos
        FROM {segment_tbl}
        GROUP BY doc_id
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')
        """.format(seg_count_tbl=seg_count_tbl,
                   segment_tbl=segment_tbl))

    # prob is divided by 1000000; it appears to be stored as an integer
    # scaled by 10^6 -- NOTE(review): confirm against vcrf_top1_label.
    plpy.execute("""
        CREATE TABLE {result_tbl} AS
        SELECT segs.doc_id, start_pos, seg_text, L.label, L.id AS id,
               count_tbl.max_pos,
               (result.label[count_tbl.max_pos+2]::float/1000000) AS prob
        FROM {segment_tbl} segs, {label_tbl} L,
             {resulttbl_raw} result, {seg_count_tbl} count_tbl
        WHERE result.label[segs.start_pos+1] = L.id
          AND segs.doc_id = result.doc_id
          AND segs.doc_id = count_tbl.doc_id
        ORDER BY doc_id, start_pos;
        """.format(result_tbl=result_tbl,
                   segment_tbl=segment_tbl,
                   label_tbl=label_tbl,
                   resulttbl_raw=resulttbl_raw,
                   seg_count_tbl=seg_count_tbl))

    # The scratch table is only needed to build result_tbl.
    plpy.execute("DROP TABLE IF EXISTS {seg_count_tbl};".format(
        seg_count_tbl=seg_count_tbl))
| |
def vcrf_label(schema_madlib, segment_tbl, factor_mtbl, factor_rtbl, label_tbl, result_tbl, **kwargs):
    """
    @brief Compute the top-1 Viterbi labeling for every document.

    Steps (intermediate tables live in pg_temp and are dropped at the end):
      1. copy segment_tbl, replacing numeric tokens with the token DIGIT;
      2. per document, collect the R-factor scores of its tokens into one
         array ordered by (start_pos, label);
      3. copy the M-factor scores;
      4. call the SQL function vcrf_top1_label on each document;
      5. build result_tbl through the SQL function vcrf_top1_table.

    @param schema_madlib  Schema where MADlib is installed
    @param segment_tbl    Segment table (doc_id, start_pos, seg_text)
    @param factor_mtbl    M-factor table (score)
    @param factor_rtbl    R-factor table (seg_text, label, score)
    @param label_tbl      Label table (id, label)
    @param result_tbl     Name of the output table; must not already exist
    """
    _validate_args(segment_tbl, factor_mtbl, factor_rtbl, label_tbl, result_tbl)

    # Suppress NOTICE-level messages while running; restored at the end.
    orig_min_messages = plpy.execute(
        "SELECT setting AS setting FROM pg_settings "
        "WHERE name = 'client_min_messages';")[0]['setting']
    plpy.execute("SET client_min_messages TO warning;")

    plpy.execute("SELECT {schema_madlib}.create_schema_pg_temp();".format(
        schema_madlib=schema_madlib))

    m_factors = "pg_temp._madlib_m_factors"
    r_factors = "pg_temp._madlib_r_factors"
    segment_tbl_digits = "pg_temp._madlib_segment_tbl_digits"
    resulttbl_raw = "pg_temp._madlib_resulttbl_raw"
    scratch_tables = ", ".join(
        [m_factors, r_factors, segment_tbl_digits, resulttbl_raw])

    plpy.execute("DROP TABLE IF EXISTS " + scratch_tables + ";")
    plpy.execute("""
        CREATE TABLE {resulttbl_raw} (doc_id integer, label integer[])
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')
        """.format(resulttbl_raw=resulttbl_raw))

    # Replace numeric tokens with the DIGIT keyword. A token is numeric if it
    # is an optionally signed run of 1-3 digit groups (commas optional) or an
    # optionally signed decimal. The predicate is defined once so both
    # queries below stay in sync. It is passed as a format argument, so the
    # regex braces are not interpreted by str.format.
    is_number = ("seg_text ~ E'^[-+]?([0-9]{1,3}[,]?)*[0-9]{1,3}$' OR "
                 "seg_text ~ E'^[-+]?[0-9]*[.][0-9]+$'")
    plpy.execute("""
        CREATE TABLE {segment_tbl_digits} AS
        SELECT start_pos, doc_id, seg_text FROM {segment_tbl}
        WHERE NOT ({is_number})
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')
        """.format(segment_tbl_digits=segment_tbl_digits,
                   segment_tbl=segment_tbl,
                   is_number=is_number))
    plpy.execute("""
        INSERT INTO {segment_tbl_digits}
        SELECT start_pos, doc_id, 'DIGIT' FROM {segment_tbl}
        WHERE {is_number};
        """.format(segment_tbl_digits=segment_tbl_digits,
                   segment_tbl=segment_tbl,
                   is_number=is_number))

    # For each document, store the array representation of its r_factors.
    plpy.execute("""
        -- for each sentence, store array representation of r_factors
        m4_ifdef(`__HAS_ORDERED_AGGREGATES__', `
        select doc_id, array_agg(score order by start_pos, label) as score
        ', `
        select doc_id, array(
            select score
            from {factor_rtbl} factors,
                 {segment_tbl_digits} seg
            where factors.seg_text = seg.seg_text
              and doc_id = ss.doc_id
            order by start_pos, label
        ) as score
        ')
        into {r_factors}
        from (select doc_id, start_pos, label, score
              from {factor_rtbl} factors,
                   {segment_tbl_digits} seg
              where factors.seg_text=seg.seg_text) as ss
        group by doc_id
        order by doc_id
        """.format(factor_rtbl=factor_rtbl,
                   segment_tbl_digits=segment_tbl_digits,
                   r_factors=r_factors))
    plpy.execute("analyze " + r_factors + ";")

    # Copy of the M-factor scores.
    plpy.execute("""
        -- array representation of m_factor
        select score
        into {m_factors}
        from (select score
              from {factor_mtbl} factors) as ss
        """.format(m_factors=m_factors, factor_mtbl=factor_mtbl))

    nlabel = plpy.execute(
        'SELECT COUNT(*) AS total FROM ' + label_tbl)[0]['total']

    # Cross join: this assumes factor_mtbl holds a single row; otherwise each
    # document would be labelled once per M-factor row.
    # NOTE(review): confirm.
    plpy.execute("""
        INSERT INTO {resulttbl_raw}
        SELECT doc_id, {schema_madlib}.vcrf_top1_label(mfactors.score, rfactors.score, {nlabel})
        FROM {m_factors} mfactors, {r_factors} rfactors
        """.format(schema_madlib=schema_madlib,
                   resulttbl_raw=resulttbl_raw,
                   m_factors=m_factors,
                   r_factors=r_factors,
                   nlabel=str(nlabel)))

    query = "SELECT * FROM {schema_madlib}.vcrf_top1_table('{segment_tbl}', '{label_tbl}', '{resulttbl_raw}', '{result_tbl}')"
    plpy.execute(query.format(schema_madlib=schema_madlib,
                              segment_tbl=segment_tbl,
                              label_tbl=label_tbl,
                              resulttbl_raw=resulttbl_raw,
                              result_tbl=result_tbl))

    # result_tbl is now materialized; the scratch tables are no longer needed.
    # Drop them before restoring the message level so no NOTICEs leak out.
    plpy.execute("DROP TABLE IF EXISTS " + scratch_tables + ";")
    plpy.execute("SET client_min_messages TO " + str(orig_min_messages) + ";")
| |
| |
def _validate_columns(cols, table_name, err_msg_tbl):
    """
    @brief Ensure that all columns in cols are present in table_name.

    @param cols         List of required column names
    @param table_name   Table to inspect
    @param err_msg_tbl  Short table description used in the error message

    The failure is reported through _assert. The message lists every
    required column, not only the missing ones.
    """
    all_present = columns_exist_in_table(table_name, cols)
    message = "CRF error: Missing required columns from {0} table: {1}".format(
        err_msg_tbl, ', '.join(cols))
    _assert(all_present, message)
| |
def _validate_args(segment_tbl, factor_mtbl, factor_rtbl, label_tbl, result_tbl):
    """
    @brief Validate the arguments of vcrf_label before any work is done.

    Checks, in order:
      - each input table exists;
      - each input table has its required columns;
      - result_tbl is a usable name (not None, empty or 'null');
      - result_tbl does not already exist in the first schema of the path.
    """
    input_tables = (
        (segment_tbl, "CRF error: Segment table does not exist!"),
        (factor_mtbl, "CRF error: M Factor table does not exist!"),
        (factor_rtbl, "CRF error: R Factor table does not exist!"),
        (label_tbl, "CRF error: Label table does not exist!"),
    )
    for tbl, missing_msg in input_tables:
        _assert(table_exists(tbl), missing_msg)

    required_columns = (
        (['doc_id', 'start_pos', 'seg_text'], segment_tbl, "segment"),
        (['seg_text', 'label', 'score'], factor_rtbl, "R factor"),
        (['score'], factor_mtbl, "M factor"),
        (['id', 'label'], label_tbl, "label"),
    )
    for cols, tbl, desc in required_columns:
        _validate_columns(cols, tbl, desc)

    has_valid_name = (result_tbl is not None and
                      result_tbl.lower().strip() not in ('null', ''))
    _assert(has_valid_name, "CRF error: Invalid result table name")

    already_there = table_exists(result_tbl, only_first_schema=True)
    _assert(not already_there,
            "CRF error: Result table name already exist!"
            " Please provide a different table name.")