| """ |
| @file crf_feature_gen.py_in |
| |
| @brief Conditional Random Field: Feature Extraction for Training and Testing. |
| |
| @namespace crf |
| |
| Conditional Random Field: Feature Extraction for Training and Testing. |
| """ |
| |
| import plpy |
| from utilities.validate_args import table_exists |
| from utilities.validate_args import columns_exist_in_table |
| from utilities.utilities import _assert |
| from utilities.utilities import add_postfix |
| |
def generate_train_features(schema_madlib, train_segment_tbl, regex_tbl, label_tbl,
                            dictionary_tbl, train_feature_tbl, train_featureset_tbl,
                            **kwargs):
    """
    @brief Generate the CRF training features from a table of labelled segments.

    Three output tables are created:
      - dictionary_tbl (token, total): distinct tokens and their counts, with
        numeric tokens collapsed into the single token DIGIT.
      - train_featureset_tbl (f_index, f_name, feature): the distinct features,
        numbered from 0.
      - train_feature_tbl (doc_id, f_size, sparse_r, dense_m, sparse_m): one
        row per document.

    Intermediate results live in pg_temp tables which are dropped and rebuilt
    on every call, so the function can be invoked repeatedly in one session.

    @param schema_madlib Name of the MADlib schema
    @param train_segment_tbl Input table (doc_id, start_pos, seg_text, label)
    @param regex_tbl Input table (pattern, name)
    @param label_tbl Input table (id, label)
    @param dictionary_tbl Name of the dictionary table to create
    @param train_feature_tbl Name of the per-document feature table to create
    @param train_featureset_tbl Name of the feature set table to create
    """
    _validate_train_args(train_segment_tbl, regex_tbl, label_tbl,
                         dictionary_tbl, train_feature_tbl, train_featureset_tbl)

    # Remember the caller's message level; it is restored at the very end.
    origClientMinMessages = plpy.execute(
        "SELECT setting AS setting FROM pg_settings WHERE name = 'client_min_messages';")
    plpy.execute("SET client_min_messages TO warning;")

    plpy.execute("SELECT {schema_madlib}.create_schema_pg_temp();".format(
        schema_madlib=schema_madlib))

    tmp1_feature = "pg_temp._madlib_tmp1_feature"
    tmp_rtbl = "pg_temp._madlib_tmp_rtbl"
    tmp_dense_mtbl = "pg_temp._madlib_tmp_dense_mtbl"
    dense_mtbl = "pg_temp._madlib_dense_mtbl"
    sparse_rtbl = "pg_temp._madlib_sparse_rtbl"
    sparse_mtbl = "pg_temp._madlib_sparse_mtbl"
    tmp_featureset = "pg_temp._madlib_tmp_featureset"
    tmp_segmenttbl = "pg_temp._madlib_tmp_segmenttbl"
    tmp_segcount_tbl = "pg_temp._madlib_tmp_segcount_tbl"

    # Every scratch table, including the segment count table, has to be
    # dropped up front: temp tables outlive the call, and a leftover from a
    # previous invocation in the same session would make the CREATE fail.
    scratch_tables = [tmp1_feature, tmp_rtbl, tmp_dense_mtbl, dense_mtbl,
                      sparse_rtbl, sparse_mtbl, tmp_featureset, tmp_segmenttbl,
                      tmp_segcount_tbl]
    plpy.execute("DROP TABLE IF EXISTS " + ", ".join(scratch_tables))

    plpy.execute("""CREATE TABLE {tbl} (start_pos integer, doc_id integer, f_name text, feature integer[])
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')""".format(tbl=tmp1_feature))
    plpy.execute("""CREATE TABLE {tbl} (start_pos integer, doc_id integer, feature integer[])
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')""".format(tbl=tmp_rtbl))
    plpy.execute("""CREATE TABLE {tbl} (start_pos integer, doc_id integer, feature integer[])
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')""".format(tbl=tmp_dense_mtbl))
    plpy.execute("""CREATE TABLE {tbl} (doc_id integer, dense_m integer[])
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')""".format(tbl=dense_mtbl))
    plpy.execute("""CREATE TABLE {tbl} (doc_id integer, f_size integer, sparse_r integer[])
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')""".format(tbl=sparse_rtbl))
    plpy.execute("""CREATE TABLE {tbl} (sparse_m integer[])
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (sparse_m)')""".format(tbl=sparse_mtbl))
    plpy.execute("""CREATE TABLE {tbl} (f_name text, feature integer[])
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (f_name)')""".format(tbl=tmp_featureset))
    plpy.execute("""CREATE TABLE {tbl} (start_pos int, doc_id int, seg_text text, label int)
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')""".format(tbl=tmp_segmenttbl))

    # Predicate matching integers (optionally comma grouped) and decimals.
    # It is substituted as a value, never used as a format template, because
    # the patterns contain literal braces.
    numeric_seg = ("(seg_text ~ E'^[-+]?([0-9]{1,3}[,]?)*[0-9]{1,3}$'"
                   " OR seg_text ~ E'^[-+]?[0-9]*[.][0-9]+$')")

    # Copy the segments, replacing every numeric token by the keyword DIGIT.
    plpy.execute("""INSERT INTO {tmp_segmenttbl}
        SELECT start_pos, doc_id, seg_text, label
        FROM {train_segment_tbl}
        WHERE NOT {numeric_seg};""".format(tmp_segmenttbl=tmp_segmenttbl,
                                           train_segment_tbl=train_segment_tbl,
                                           numeric_seg=numeric_seg))
    plpy.execute("""INSERT INTO {tmp_segmenttbl}
        SELECT start_pos, doc_id, 'DIGIT', label
        FROM {train_segment_tbl}
        WHERE {numeric_seg};""".format(tmp_segmenttbl=tmp_segmenttbl,
                                       train_segment_tbl=train_segment_tbl,
                                       numeric_seg=numeric_seg))

    # Dictionary of distinct (normalised) tokens with their frequencies.
    plpy.execute("""
        CREATE TABLE {dictionary_tbl} AS
        SELECT seg_text token, count(*) total
        FROM {tmp_segmenttbl}
        GROUP BY seg_text
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (token)')
        """.format(dictionary_tbl=dictionary_tbl,
                   tmp_segmenttbl=tmp_segmenttbl))

    # doc_len is the number of segments minus one; it is compared against
    # start_pos below to locate the last segment of each document.
    plpy.execute("""
        CREATE TABLE {tmp_segcount_tbl} AS
        SELECT doc_id, count(*) - 1 doc_len
        FROM {train_segment_tbl}
        GROUP BY doc_id
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')
        """.format(tmp_segcount_tbl=tmp_segcount_tbl,
                   train_segment_tbl=train_segment_tbl))

    # Edge features: label pairs of adjacent segments within a document.
    plpy.execute("""INSERT INTO {tmp1_feature}(start_pos, doc_id, f_name, feature)
        SELECT doc2.start_pos, doc2.doc_id, 'E.', ARRAY[doc1.label, doc2.label]
        FROM {tmp_segmenttbl} doc1, {tmp_segmenttbl} doc2
        WHERE doc1.doc_id = doc2.doc_id AND doc1.start_pos+1 = doc2.start_pos;""".format(
            tmp1_feature=tmp1_feature, tmp_segmenttbl=tmp_segmenttbl))

    # Regex features: one per (segment, matching pattern).
    plpy.execute("""INSERT INTO {tmp1_feature}(start_pos, doc_id, f_name, feature)
        SELECT start_pos, doc_id, 'R_' || name, ARRAY[-1, label]
        FROM {regex_tbl}, {tmp_segmenttbl}
        WHERE seg_text ~ pattern;""".format(
            tmp1_feature=tmp1_feature, regex_tbl=regex_tbl,
            tmp_segmenttbl=tmp_segmenttbl))

    # Start feature: the first segment of every document.
    plpy.execute("""INSERT INTO {tmp1_feature}(start_pos, doc_id, f_name, feature)
        SELECT start_pos, doc_id, 'S.', ARRAY[-1, label]
        FROM {tmp_segmenttbl}
        WHERE start_pos = 0;""".format(
            tmp1_feature=tmp1_feature, tmp_segmenttbl=tmp_segmenttbl))

    # End feature: the last segment of every document.
    plpy.execute("""INSERT INTO {tmp1_feature}(start_pos, doc_id, f_name, feature)
        SELECT start_pos, t.doc_id, 'End.', ARRAY[-1, label]
        FROM {tmp_segmenttbl} t, {tmp_segcount_tbl} q
        WHERE t.doc_id = q.doc_id AND t.start_pos = q.doc_len""".format(
            tmp1_feature=tmp1_feature, tmp_segmenttbl=tmp_segmenttbl,
            tmp_segcount_tbl=tmp_segcount_tbl))

    # Word feature: one per segment, keyed by the token text.
    plpy.execute("""INSERT INTO {tmp1_feature}(start_pos, doc_id, f_name, feature)
        SELECT start_pos, doc_id, 'W_' || seg_text, ARRAY[-1, label]
        FROM {tmp_segmenttbl};""".format(
            tmp1_feature=tmp1_feature, tmp_segmenttbl=tmp_segmenttbl))

    # Unknown feature: tokens that occur at most once in the dictionary.
    plpy.execute("""INSERT INTO {tmp1_feature}(start_pos, doc_id, f_name, feature)
        SELECT start_pos, doc_id, 'U', ARRAY[-1, label]
        FROM {tmp_segmenttbl} seg, {dictionary_tbl} dic
        WHERE seg.seg_text = dic.token AND dic.total <= 1;""".format(
            tmp1_feature=tmp1_feature, tmp_segmenttbl=tmp_segmenttbl,
            dictionary_tbl=dictionary_tbl))

    plpy.execute("""INSERT INTO {tmp_featureset}(f_name, feature)
        SELECT DISTINCT f_name, feature
        FROM {tmp1_feature};""".format(
            tmp_featureset=tmp_featureset, tmp1_feature=tmp1_feature))

    # Enforce ANALYZE to gather proper table statistics required to generate
    # optimized query plans.
    plpy.execute("ANALYZE {tmp1_feature}".format(tmp1_feature=tmp1_feature))

    # NOTE(review): the sequence uses the unqualified name seq, so an existing
    # sequence of that name visible to the caller is dropped and the new one
    # is left behind after the call. Consider a uniquely named sequence.
    plpy.execute("""DROP SEQUENCE IF EXISTS seq;
        CREATE SEQUENCE seq START 1 INCREMENT 1;""")

    # Number the distinct features starting from 0.
    plpy.execute("""
        CREATE table {train_featureset_tbl} AS
        SELECT CAST(nextval('seq')-1 AS INTEGER) f_index, f_name, feature
        FROM {tmp_featureset}
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (f_index)')
        """.format(train_featureset_tbl=train_featureset_tbl,
                   tmp_featureset=tmp_featureset))

    plpy.execute("ANALYZE {train_featureset_tbl}".format(
        train_featureset_tbl=train_featureset_tbl))

    rv = plpy.execute("SELECT COUNT(*) AS total_feature FROM {tbl};".format(
        tbl=train_featureset_tbl))
    total_feature = rv[0]['total_feature']

    # Single-state features: every observed (position, feature name) is paired
    # with all feature set entries of that name; the trailing 1/0 flags
    # whether the label combination is the one actually observed.
    plpy.execute("""INSERT INTO {tmp_rtbl}(start_pos, doc_id, feature)
        SELECT start_pos, doc_id,
               array_cat(fset.feature,
                         ARRAY[f_index, start_pos,
                               CASE WHEN {tmp1_feature}.feature = fset.feature THEN 1
                                    ELSE 0
                               END])
        FROM {tmp1_feature}, {train_featureset_tbl} fset
        WHERE {tmp1_feature}.f_name = fset.f_name AND fset.f_name <> 'E.';""".format(
            tmp_rtbl=tmp_rtbl, tmp1_feature=tmp1_feature,
            train_featureset_tbl=train_featureset_tbl))

    plpy.execute("""INSERT INTO {sparse_rtbl} (doc_id, f_size, sparse_r)
        SELECT doc_id, {f_size},
               {schema_madlib}.array_union(feature::integer[] order by start_pos)
        FROM {tmp_rtbl}
        GROUP BY doc_id;""".format(schema_madlib=schema_madlib,
                                   sparse_rtbl=sparse_rtbl,
                                   f_size=total_feature,
                                   tmp_rtbl=tmp_rtbl))

    # Edge features actually observed in the documents.
    plpy.execute("""INSERT INTO {tmp_dense_mtbl}(start_pos, doc_id, feature)
        SELECT start_pos, doc_id,
               array_cat(fset.feature, ARRAY[f_index, start_pos, 1])
        FROM {tmp1_feature}, {train_featureset_tbl} fset
        WHERE start_pos > 0
          AND {tmp1_feature}.f_name = fset.f_name
          AND {tmp1_feature}.feature = fset.feature
          AND fset.f_name = 'E.';""".format(
            tmp_dense_mtbl=tmp_dense_mtbl, tmp1_feature=tmp1_feature,
            train_featureset_tbl=train_featureset_tbl))

    plpy.execute("""INSERT INTO {dense_mtbl} (doc_id, dense_m)
        SELECT doc_id,
               {schema_madlib}.array_union(feature::integer[] order by start_pos)
        FROM {tmp_dense_mtbl}
        GROUP BY doc_id;""".format(schema_madlib=schema_madlib,
                                   dense_mtbl=dense_mtbl,
                                   tmp_dense_mtbl=tmp_dense_mtbl))

    # All edge features of the feature set, shared by every document.
    plpy.execute("""INSERT INTO {sparse_mtbl} (sparse_m)
        SELECT {schema_madlib}.array_union(array_cat(ARRAY[f_index], feature))
        FROM {train_featureset_tbl} fset
        WHERE f_name = 'E.';""".format(schema_madlib=schema_madlib,
                                       sparse_mtbl=sparse_mtbl,
                                       train_featureset_tbl=train_featureset_tbl))

    plpy.execute("""
        CREATE TABLE {train_feature_tbl} AS
        SELECT {sparse_rtbl}.doc_id, f_size, sparse_r, dense_m, sparse_m
        FROM {sparse_rtbl}, {dense_mtbl}, {sparse_mtbl}
        WHERE {sparse_rtbl}.doc_id = {dense_mtbl}.doc_id
        m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')
        """.format(train_feature_tbl=train_feature_tbl,
                   sparse_rtbl=sparse_rtbl,
                   sparse_mtbl=sparse_mtbl,
                   dense_mtbl=dense_mtbl))

    plpy.execute("ANALYZE {train_feature_tbl}".format(
        train_feature_tbl=train_feature_tbl))

    plpy.execute("SET client_min_messages TO " +
                 str(origClientMinMessages[0]['setting']) + ";")
| |
| |
| def generate_test_features(schema_madlib, test_segment_tbl, |
| dictionary_tbl, label_tbl, regex_tbl, crf_weights_tbl, |
| viterbi_mtbl, viterbi_rtbl, **kwargs): |
| |
| _validate_test_args(test_segment_tbl, dictionary_tbl, label_tbl, |
| regex_tbl, crf_weights_tbl, viterbi_mtbl, viterbi_rtbl) |
| |
| # Create m&r factor table |
| plpy.execute(""" |
| CREATE TABLE {viterbi_mtbl} (score DOUBLE PRECISION[]) |
| m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (score)') |
| """.format(viterbi_mtbl = viterbi_mtbl)); |
| |
| plpy.execute(""" |
| CREATE TABLE {viterbi_rtbl} |
| (seg_text text, label integer, score DOUBLE PRECISION) |
| m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (label)') |
| """.format(viterbi_rtbl = viterbi_rtbl)) |
| |
| # Create index for performance |
| _tablename = viterbi_rtbl.split('.') |
| if len(_tablename) == 1: |
| rtbl_name = _tablename[0] |
| else: |
| rtbl_name = _tablename[-1] |
| |
| rtbl_name_idx = add_postfix(rtbl_name, "_idx") |
| |
| plpy.execute(""" |
| CREATE INDEX {rtbl_name_idx} ON {viterbi_rtbl} (seg_text) |
| """.format(rtbl_name_idx = rtbl_name_idx, |
| viterbi_rtbl = viterbi_rtbl)) |
| |
| origClientMinMessages = plpy.execute("""SELECT setting AS setting |
| FROM pg_settings WHERE name = \'client_min_messages\';""") |
| plpy.execute("SET client_min_messages TO warning;") |
| |
| plpy.execute("SELECT {schema_madlib}.create_schema_pg_temp()".format(schema_madlib=schema_madlib)) |
| |
| prev_labeltbl = "pg_temp._madlib_prev_labeltbl" |
| segment_hashtbl = "pg_temp._madlib_segment_hashtbl" |
| unknown_segment_hashtbl = "pg_temp._madlib_unknown_segment_hashtbl" |
| rtbl = "pg_temp._madlib_rtbl" |
| mtbl = "pg_temp._madlib_mtbl" |
| tmp_segment_tbl = "pg_temp._madlib_tmp_segment_tbl" |
| tmp_dict = "pg_temp._madlib_tmp_dict" |
| |
| plpy.execute("""DROP TABLE IF EXISTS {prev_labeltbl}, |
| {segment_hashtbl}, |
| {unknown_segment_hashtbl}, |
| {rtbl}, |
| {mtbl}, |
| {tmp_segment_tbl}, |
| {tmp_dict} |
| """.format(prev_labeltbl = prev_labeltbl, |
| segment_hashtbl = segment_hashtbl, |
| unknown_segment_hashtbl = unknown_segment_hashtbl, |
| rtbl = rtbl, |
| mtbl = mtbl, |
| tmp_segment_tbl = tmp_segment_tbl, |
| tmp_dict = tmp_dict)) |
| |
| plpy.execute("""CREATE TABLE """ + prev_labeltbl + """(id int) |
| m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (id)')""") |
| |
| # Insert unique tokens into the """ + segment_hashtbl + """ |
| plpy.execute("CREATE TABLE """ + segment_hashtbl + """(seg_text text) |
| m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (seg_text)')""") |
| |
| # create a temp partial dictionary_tbl table which stores the words whose occurance |
| # is below certain threshold, refer to the CRF Package |
| plpy.execute("""CREATE TABLE """ + unknown_segment_hashtbl + """(seg_text text) |
| m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (seg_text)')""") |
| |
| # Generate a sparse matrix to store the r factors |
| plpy.execute("""CREATE TABLE """ + rtbl + """ (seg_text text NOT NULL, label integer, value double precision) |
| m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (label)')""") |
| |
| # Generate M factor table |
| plpy.execute("""CREATE TABLE """ + mtbl + """(prev_label integer, label integer, value double precision) |
| m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (label)')""") |
| |
| |
| # temp tables to keep segments and dictionary_tbl with all digits replaced by the word 'DIGIT' |
| plpy.execute("""CREATE TABLE """ + tmp_segment_tbl + """(start_pos int,doc_id int,seg_text text) |
| m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (doc_id)')""") |
| plpy.execute("""CREATE TABLE """+ tmp_dict + """(token text, total int) |
| m4_ifdef(`__POSTGRESQL__', `', `DISTRIBUTED BY (token)')""") |
| |
| plpy.execute("""SET client_min_messages TO """ + str(origClientMinMessages[0]['setting']) + """;""") |
| |
| # Calculate the number of labels in the label space |
| rv = plpy.execute("""SELECT COUNT(*) AS total_label FROM """ + label_tbl + """;""") |
| nlabel = rv[0]['total_label'] |
| |
| # replace digits with "DIGIT" keyword |
| plpy.execute("""INSERT INTO """ + tmp_segment_tbl + """ SELECT start_pos,doc_id,seg_text FROM """ + test_segment_tbl + """ WHERE |
| NOT (seg_text ~ E'^[-+]?([0-9]{1,3}[,]?)*[0-9]{1,3}$' OR seg_text ~ E'^[-+]?[0-9]*[.][0-9]+$')""") |
| plpy.execute("""INSERT INTO """ + tmp_segment_tbl + """ SELECT start_pos,doc_id,'DIGIT' FROM """ + test_segment_tbl + """ WHERE |
| seg_text ~ E'^[-+]?([0-9]{1,3}[,]?)*[0-9]{1,3}$' OR seg_text ~E'^[-+]?[0-9]*[.][0-9]+$';""") |
| |
| plpy.execute("""INSERT INTO """+ tmp_dict + """ SELECT token,sum(total) FROM """ + dictionary_tbl + """ GROUP BY token |
| HAVING (token NOT LIKE E'^[-+]?([0-9]{1,3}[,]?)*[0-9]{1,3}$' AND token NOT LIKE E'^[-+]?[0-9]*[.][0-9]+$')""") |
| plpy.execute("""INSERT INTO """+ tmp_dict + """ SELECT 'DIGIT',sum(total) FROM """ + dictionary_tbl + """ WHERE |
| (token ~ E'^[-+]?([0-9]{1,3}[,]?)*[0-9]{1,3}$' OR token ~ E'^[-+]?[0-9]*[.][0-9]+$') GROUP BY token;""") |
| |
| plpy.execute("""INSERT INTO """ + segment_hashtbl + """(seg_text) |
| SELECT DISTINCT seg_text |
| FROM """ + tmp_segment_tbl + """;""") |
| |
| plpy.execute("""INSERT INTO """ + unknown_segment_hashtbl + """(seg_text) |
| ((SELECT DISTINCT seg_text |
| FROM """ + segment_hashtbl + """) |
| EXCEPT |
| (SELECT DISTINCT token |
| FROM """+ tmp_dict + """ |
| WHERE total>1));""") |
| |
| plpy.execute("""INSERT INTO """ + prev_labeltbl + """ |
| SELECT id |
| FROM """ + label_tbl + """; |
| INSERT INTO """ + prev_labeltbl + """ VALUES(-1); |
| INSERT INTO """ + prev_labeltbl + """ VALUES( """ + str(nlabel) + """);""") |
| |
| # Generate sparse M factor table |
| plpy.execute("""INSERT INTO """ + mtbl + """(prev_label, label, value) |
| SELECT prev_label.id, label.id, 0 |
| FROM """ + label_tbl + """ AS label, |
| """ + prev_labeltbl + """ as prev_label;""") |
| |
| # EdgeFeature and startFeature, startFeature can be considered as a special edgeFeature |
| plpy.execute("""INSERT INTO """ + mtbl + """(prev_label, label, value) |
| SELECT prev_label_id,label_id,weight |
| FROM """ + crf_weights_tbl + """ AS features |
| WHERE features.prev_label_id<>-1 OR features.name = 'S.';""") |
| |
| # EndFeature, endFeature can be considered as a special edgeFeature |
| plpy.execute("""INSERT INTO """ + mtbl + """(prev_label, label, value) |
| SELECT """ + str(nlabel) + """, label_id, weight |
| FROM """ + crf_weights_tbl + """ AS features |
| WHERE features.name = 'End.';""") |
| |
| m4_ifdef(`__HAS_ORDERED_AGGREGATES__', ` |
| plpy.execute("""INSERT INTO """ + viterbi_mtbl + """ |
| SELECT array_agg(weight ORDER BY prev_label,label) |
| FROM (SELECT prev_label, label, (SUM(value)*1000)::FLOAT8 AS weight |
| FROM """ + mtbl + """ |
| GROUP BY prev_label,label |
| ORDER BY prev_label,label) as TEMP_MTBL;""".format( |
| viterbi_mtbl = viterbi_mtbl |
| )) |
| ', ` |
| plpy.execute("""INSERT INTO """ + viterbi_mtbl + """ |
| SELECT ARRAY( |
| SELECT |
| (SUM(value) * 1000)::FLOAT8 |
| FROM |
| """ + mtbl + """ |
| GROUP BY |
| prev_label, label |
| ORDER BY |
| prev_label, label |
| );""".format( |
| viterbi_mtbl = viterbi_mtbl |
| )) |
| ') |
| |
| plpy.execute("""INSERT INTO """ + rtbl + """(seg_text, label, value) |
| SELECT segment_hashtbl.seg_text, labels.id, 0 |
| FROM """ + segment_hashtbl + """ segment_hashtbl, |
| """ + label_tbl + """ AS labels;""") |
| |
| # RegExFeature |
| plpy.execute("""INSERT INTO """ + rtbl + """(seg_text, label, value) |
| SELECT segment_hashtbl.seg_text, features.label_id, features.weight |
| FROM """ + segment_hashtbl + """ AS segment_hashtbl, |
| """ + crf_weights_tbl + """ AS features, |
| """ + regex_tbl + """ AS regex |
| WHERE segment_hashtbl.seg_text ~ regex.pattern |
| AND features.name ='R_' || regex.name;""") |
| |
| # UnknownFeature |
| plpy.execute("""INSERT INTO """ + rtbl + """(seg_text, label, value) |
| SELECT segment_hashtbl.seg_text, features.label_id, features.weight |
| FROM """ + unknown_segment_hashtbl + """ AS segment_hashtbl, |
| """ + crf_weights_tbl + """ AS features |
| WHERE features.name = 'U';""") |
| |
| # Wordfeature |
| plpy.execute("""INSERT INTO """ + rtbl + """(seg_text, label, value) |
| SELECT seg_text, label_id, weight |
| FROM """ + segment_hashtbl + """, |
| """ + crf_weights_tbl + """ |
| WHERE name = 'W_' || seg_text;""") |
| |
| # Factor table |
| plpy.execute("""INSERT INTO """ + viterbi_rtbl + """(seg_text, label, score) |
| SELECT seg_text,label,(SUM(value)*1000)::FLOAT8 AS score |
| FROM """ + rtbl + """ |
| GROUP BY seg_text,label;""") |
| |
| # Enforce ANALYZE to gather proper table statistics required to generate optimized query plans |
| plpy.execute(""" ANALYZE {viterbi_mtbl} """.format(viterbi_mtbl = viterbi_mtbl)) |
| plpy.execute(""" ANALYZE {viterbi_rtbl} """.format(viterbi_rtbl = viterbi_rtbl)) |
| |
| |
def _validate_label_tbl(label_tbl):
    """
    @brief Validate that the label ids lie between 0 and (number of rows - 1).

    @param label_tbl Name of the label table; its id column is inspected
    """
    rv = plpy.execute("SELECT count(*), max(id), min(id) FROM {label_tbl}".format(
        label_tbl=label_tbl))

    count = rv[0]['count']
    max_id = rv[0]['max']
    min_id = rv[0]['min']

    # min/max are NULL (None) when the table has no rows or only NULL ids.
    # Reject that case explicitly instead of comparing None with an integer.
    _assert(min_id is not None and max_id is not None,
            "CRF error: Label table does not contain any label ids!")

    _assert(min_id >= 0 and max_id <= count - 1,
            "CRF error: Bound check failed for label table."
            " Expected id values between 0 to total number of elements in the table - 1")
| |
def _validate_columns(cols, table_name, err_msg_tbl):
    """
    @brief Assert that every column in cols is present in table_name

    @param cols List of required column names
    @param table_name Name of the table to inspect
    @param err_msg_tbl Short table description used in the error message
    """
    all_present = columns_exist_in_table(table_name, cols)
    required = ', '.join(cols)
    message = "CRF error: Missing required columns from %s table: %s" % (err_msg_tbl, required)
    _assert(all_present, message)
| |
def _validate_train_args(train_segment_tbl, regex_tbl, label_tbl,
                         dictionary_tbl, train_feature_tbl, train_featureset_tbl):
    """
    @brief Check the arguments of the training feature extraction.

    Input tables must exist and carry the required columns, the label ids
    must pass the bound check, and the output table names must be usable and
    not yet taken.
    """
    # Input tables have to exist.
    required_inputs = (
        (train_segment_tbl, "CRF error: Train segment table does not exist!"),
        (regex_tbl, "CRF error: Regex table does not exist!"),
        (label_tbl, "CRF error: Label table does not exist!"),
    )
    for input_tbl, missing_msg in required_inputs:
        _assert(table_exists(input_tbl), missing_msg)

    # Input tables have to provide the expected columns.
    required_columns = (
        (['doc_id', 'start_pos', 'seg_text', 'label'], train_segment_tbl, "segment"),
        (['pattern', 'name'], regex_tbl, "regex"),
        (['id', 'label'], label_tbl, "label"),
    )
    for col_names, input_tbl, description in required_columns:
        _validate_columns(col_names, input_tbl, description)

    _validate_label_tbl(label_tbl)

    # Output tables: (name, message for an unusable name, message for a clash).
    outputs = (
        (dictionary_tbl,
         "CRF error: Invalid dictionary table name",
         "CRF error: Dictionary table name already exist!"
         " Please provide a different table name."),
        (train_feature_tbl,
         "CRF error: Invalid train feature table name",
         "CRF error: Train feature table name already exist!"
         " Please provide a different table name."),
        (train_featureset_tbl,
         "CRF error: Invalid train fatureset table name",
         "CRF error: Train featureset table name already exist!"
         " Please provide a different table name."),
    )

    # All names are vetted first, before any of them is looked up.
    for out_tbl, invalid_msg, unused_msg in outputs:
        name_is_usable = (out_tbl is not None and
                          out_tbl.lower().strip() not in ('null', ''))
        _assert(name_is_usable, invalid_msg)

    for out_tbl, unused_msg, clash_msg in outputs:
        _assert(not table_exists(out_tbl), clash_msg)
| |
def _validate_test_args(test_segment_tbl, dictionary_tbl, label_tbl,
                        regex_tbl, crf_weights_tbl, viterbi_mtbl, viterbi_rtbl):
    """
    @brief Check the arguments of the testing feature extraction.

    Input tables must exist and carry the required columns, the label ids
    must pass the bound check, and the output table names must be usable and
    not yet taken.
    """
    # Input tables have to exist.
    required_inputs = (
        (test_segment_tbl, "CRF error: Test segment table does not exist!"),
        (dictionary_tbl, "CRF error: Dictionary table does not exist!"),
        (label_tbl, "CRF error: Label table does not exist!"),
        (regex_tbl, "CRF error: Regex table does not exist!"),
        (crf_weights_tbl, "CRF error: CRF weights table does not exist!"),
    )
    for input_tbl, missing_msg in required_inputs:
        _assert(table_exists(input_tbl), missing_msg)

    # Input tables have to provide the expected columns.
    required_columns = (
        (['doc_id', 'start_pos', 'seg_text'], test_segment_tbl, "segment"),
        (['token', 'total'], dictionary_tbl, "dictionary"),
        (['id', 'label'], label_tbl, "label"),
        (['pattern', 'name'], regex_tbl, "regex"),
        (['id', 'name', 'prev_label_id', 'label_id', 'weight'],
         crf_weights_tbl, "crf weights"),
    )
    for col_names, input_tbl, description in required_columns:
        _validate_columns(col_names, input_tbl, description)

    _validate_label_tbl(label_tbl)

    # Output tables: (name, message for an unusable name, message for a clash).
    outputs = (
        (viterbi_mtbl,
         "CRF error: Invalid viterbi mtable name",
         "CRF error: Viterbi M table name already exist!"
         " Please provide a different table name."),
        (viterbi_rtbl,
         "CRF error: Invalid viterbi rtable name",
         "CRF error: Viterbi R table name already exist!"
         " Please provide a different table name."),
    )

    # All names are vetted first, before any of them is looked up.
    for out_tbl, invalid_msg, unused_msg in outputs:
        name_is_usable = (out_tbl is not None and
                          out_tbl.lower().strip() not in ('null', ''))
        _assert(name_is_usable, invalid_msg)

    for out_tbl, unused_msg, clash_msg in outputs:
        _assert(not table_exists(out_tbl), clash_msg)