blob: f74d13f13331c1f3001d1c5cdf27f6e419b6a448 [file] [log] [blame]
# ----------------------------------------------------------------------
# Robust variance for Cox Proportional Hazards model
# ----------------------------------------------------------------------
import plpy
from utilities.utilities import unique_string
from utilities.utilities import _assert
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
# ----------------------------------------------------------------------
def rb_help_message(schema_madlib, message=None, **kwargs):
    """
    Given a help string, provide usage information

    Args:
        @param schema_madlib: String. Schema where MADlib is installed; it is
                              substituted into the SQL examples shown.
        @param message: String. Which help message to display. "usage",
                        "help" or "?" (case-insensitive) select the detailed
                        usage text; any other value, including None, selects
                        the short summary.
        @param kwargs: Ignored.

    Returns:
        String. The selected help message
    """
    if message is not None and \
            message.lower() in ("usage", "help", "?"):
        help_string = """
This function calculates the Robust statistics for
Cox Proportional Hazards Regression.
-----------------------------------------------------------------------
                                USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.robust_variance_coxph
(
    model_table  TEXT,  -- Model table name, which is generated FROM coxph_train
    output_table TEXT   -- Output table name
)
-----------------------------------------------------------------------
                                OUTPUT
-----------------------------------------------------------------------
The output table contains the following columns:
    - coef: Coefficients of regression
    - loglikelihood: Log-likelihood value
    - std_err: Standard errors
    - robust_se: Standard errors of the robust variance estimators
    - robust_z: z-stats of the robust variance estimators
    - robust_p: p-values of the robust variance estimators
    - hessian: Hessian matrix
"""
    else:
        help_string = """
This function calculates the Huber-White robust statistics for
Cox Proportional Hazards Regression.

For more details on function usage:
    SELECT {schema_madlib}.robust_variance_coxph('usage');
"""
    # Use the schema MADlib is actually installed in, not a hard-coded name.
    return help_string.format(schema_madlib=schema_madlib)
# ----------------------------------------------------------------------
def _validate_params(schema_madlib, model_table, output_table):
    """ Validate the input parameters for robust_variance_coxph

    Args:
        @param schema_madlib - MADlib schema name (used to qualify
                               array_contains_null)
        @param model_table - A string, the model table name; a companion
                             table named <model_table>_summary must also exist
        @param output_table - A string, the result table name; it must be
                              non-empty and must not already exist

    Throws:
        "Robust Variance Cox error" if any argument is invalid
    """
    _assert(model_table is not None and table_exists(model_table),
            "Robust Variance Cox error: Model data table does not exist")
    _assert(model_table is not None and
            table_exists(model_table + '_summary'),
            "Robust Variance Cox error: Model summary table does not exist")
    _assert(columns_exist_in_table(model_table,
                                   ['coef', 'loglikelihood', 'std_err',
                                    'z_stats', 'p_values', 'hessian']),
            "Robust Variance Cox error: Invalid model data table"
            " - some required columns missing")
    _assert(columns_exist_in_table(model_table + '_summary',
                                   ['source_table', 'dependent_varname',
                                    'independent_varname',
                                    'right_censoring_status', 'strata']),
            "Robust Variance Cox error: Invalid model summary table"
            " - some required columns missing")
    _assert(output_table is not None and output_table.strip() != '',
            "Robust Variance Cox error: Invalid output_table is given")
    _assert(not table_exists(output_table, only_first_schema=True),
            "Robust Variance Cox error: Output table {0}"
            " already exists!".format(str(output_table)))

    # Number of model rows with both coef and hessian present.
    not_null_coef = plpy.execute("""
        SELECT count(*) AS c FROM {model_table}
        WHERE coef IS NOT NULL AND hessian IS NOT NULL
        """.format(model_table=model_table))[0]["c"]
    _assert(not_null_coef != 0,
            "Robust Variance Cox error: No not-null coef and hessian found in "
            "model_table {model_table}".format(model_table=model_table))
    # To warn users who have multiple rows in the model table.
    if not_null_coef > 1:
        plpy.warning("Robust Variance Cox Warning: multiple rows in "
                     "model_table {model_table}".format(model_table=model_table))

    # NOTE: only the first row returned by this query is inspected.
    coef_contains_null = plpy.execute("""
        SELECT {schema_madlib}.array_contains_null(coef) OR
               {schema_madlib}.array_contains_null(hessian)
               AS contains_null
        FROM {model_table}
        """.format(schema_madlib=schema_madlib,
                   model_table=model_table))[0]["contains_null"]
    _assert(not coef_contains_null,
            "Robust Variance Cox error: coef or hessian array in {0} contains "
            "NULL values. (If the input table contains at least one row without "
            "NULLs then rerunning coxph should get correct values)".
            format(model_table))
# ----------------------------------------------------------------------
def rb_coxph(schema_madlib, model_table, output_table, **kwargs):
    """ Compute the Huber-White robust sandwich estimator for CoxPH model

    @brief Huber-White robust sandwich estimator for Cox proportional
    hazards regression.

    Args:
        @param schema_madlib - MADlib schema name
        @param model_table - A string, the table name of a trained cox model;
                             a companion table <model_table>_summary must
                             also exist
        @param output_table - A string, the result table name

    Returns:
        None

    Side effect:
        A table named by output_table, which contains the following
        columns:
            * coef - the input coef
            * loglikelihood - the input log-likelihood
            * std_err - the input standard errors
            * robust_se - standard errors of robust variance estimators
            * robust_z - z statistics of robust variance estimators
            * robust_p - p value of robust variance estimators
            * hessian - the input hessian
    """
    # Validate while the caller's client_min_messages is still in effect:
    # _validate_params can emit a plpy.warning (multiple model rows) that
    # would be suppressed once the level is raised to 'error' below.
    _validate_params(schema_madlib, model_table, output_table)

    old_msg_level = plpy.execute("""
        SELECT setting FROM pg_settings
        WHERE name='client_min_messages'
        """)[0]['setting']
    # Silence notices from the statements below. The previous level is
    # restored explicitly at the end; if a statement raises instead, the
    # transaction rollback undoes this SET.
    plpy.execute("set client_min_messages to error")

    # info is a dict that contains source_table, independent_varname,
    # dependent_varname, right_censoring_status, strata
    info = plpy.execute("SELECT * FROM {model_table}_summary".
                        format(model_table=model_table))[0]

    # table name of the result of H and S
    temp_H_S = unique_string()

    if info['strata'] is None:
        # Create H and S table, use window function ordered desc
        plpy.execute(
            """
            CREATE TEMP TABLE {temp_H_S} as
                SELECT
                    ({independent_varname})::FLOAT8[] AS x,
                    ({dependent_varname})::FLOAT8 AS y,
                    ({right_censoring_status})::BOOLEAN AS status,
                    {schema_madlib}.coxph_h_s(
                        {independent_varname},
                        (SELECT coef FROM {model_table})
                    ) OVER (ORDER BY {dependent_varname} desc) AS h_s
                FROM
                    {source_table} s
            """.format(temp_H_S=temp_H_S, model_table=model_table,
                       schema_madlib=schema_madlib, **info))

        # W is computed in ascending order
        plpy.execute(
            """
            CREATE TABLE {output_table} AS
                SELECT
                    u.coef, u.loglikelihood, u.std_err,
                    (v.f).std_err AS robust_se, (v.f).stats AS robust_z,
                    (v.f).p_values AS robust_p, u.hessian
                FROM
                    {model_table} u,
                    ( SELECT {schema_madlib}.rb_coxph_step(
                            x, y, status,
                            (SELECT coef FROM {model_table}),
                            (SELECT hessian FROM {model_table}),
                            (h_s).h, (h_s).s
                            ORDER BY y) AS f
                      FROM
                            {temp_H_S}
                    ) v
            """.format(output_table=output_table, model_table=model_table,
                       schema_madlib=schema_madlib, temp_H_S=temp_H_S,
                       **info))
    else:
        # strata can have multiple columns, so cannot use '{strata} AS strata'.
        # To avoid name conflicts with the strata columns, the derived columns
        # get unique aliases.
        x = unique_string()
        y = unique_string()
        status = unique_string()
        h_s = unique_string()

        # Create H and S table, use window function ordered desc,
        # computed separately within each stratum
        plpy.execute(
            """
            CREATE TEMP TABLE {temp_H_S} AS
                SELECT
                    ({independent_varname})::FLOAT8[] AS {x},
                    ({dependent_varname})::FLOAT8 AS {y},
                    ({right_censoring_status})::BOOLEAN AS {status},
                    {strata},
                    {schema_madlib}.coxph_h_s(
                        {independent_varname},
                        (SELECT coef FROM {model_table})
                    ) OVER (partition by {strata}
                            ORDER BY {dependent_varname} desc) AS {h_s}
                FROM
                    {source_table} s
            """.format(temp_H_S=temp_H_S, model_table=model_table,
                       x=x, y=y, status=status, h_s=h_s,
                       schema_madlib=schema_madlib,
                       **info))

        # W is computed with ascending order; per-stratum aggregate states
        # are combined by rb_sum_strata
        plpy.execute(
            """
            CREATE TABLE {output_table} AS
                SELECT
                    u.coef, u.loglikelihood, u.std_err,
                    (v.f).std_err AS robust_se, (v.f).stats AS robust_z,
                    (v.f).p_values AS robust_p, u.hessian
                FROM
                    {model_table} u,
                    (
                        SELECT {schema_madlib}.rb_sum_strata(in_state) AS f
                        FROM (
                            SELECT {schema_madlib}.rb_coxph_strata_step(
                                {x}, {y}, {status},
                                (SELECT coef FROM {model_table}),
                                (SELECT hessian FROM {model_table}),
                                ({h_s}).h, ({h_s}).s
                                ORDER BY {y}) AS in_state
                            FROM
                                {temp_H_S}
                            group by {strata}
                        ) w
                    ) v
            """.format(output_table=output_table, model_table=model_table,
                       x=x, y=y, status=status, h_s=h_s,
                       schema_madlib=schema_madlib,
                       temp_H_S=temp_H_S, **info))

    plpy.execute("DROP TABLE IF EXISTS {temp_H_S}".format(temp_H_S=temp_H_S))
    plpy.execute("SET client_min_messages TO " + old_msg_level)
    return None