blob: 4a26c4a599063917e162548d1a01ca811a19275e [file] [log] [blame]
# coding=utf-8
"""
@file cox_prop_hazards.py_in
@brief Cox prop Hazards: Driver functions
@namespace coxprophazards
Cox prop Hazards: Driver functions
//"""
import plpy
from utilities.validate_args import table_exists
from utilities.validate_args import table_is_empty
from utilities.validate_args import get_cols
from utilities.validate_args import is_var_valid
from utilities.utilities import __unique_string
from utilities.utilities import preprocess_optimizer_params
from utilities.utilities import _assert
from utilities.validate_args import columns_exist_in_table
from utilities.utilities import __mad_version
# ----------------------------------------------------------------------
# Version-specific helpers from utilities, resolved once at import time.
version_wrapper = __mad_version()
# Array conversion function selected by the version wrapper;
# _compute_residual calls it (with text=False) on the model's coef column.
# NOTE(review): the exact conversion semantics are defined by
# utilities.__mad_version().select_vecfunc() -- confirm there.
madvec = version_wrapper.select_vecfunc()
def coxph_help_message(schema_madlib, message, **kwargs):
    """ Help message for Cox Proportional Hazards

    @brief Return the help text of coxph_train.

    Args:
        @param schema_madlib string, Name of the schema madlib
        @param message string, Help message indicator: None/'' for the
            summary, 'usage'/'help'/'?' for the usage, 'example'/'examples'
            for a worked example.

    Returns:
        String. Contain the help message string
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Functionality: Cox proportional hazards regression (Breslow method)
Proportional-Hazard models enable the comparison of various survival models.
These survival models are functions describing the probability of a one-item
event (prototypically, this event is death) with respect to time.
The interval of time before death occurs is the survival time.
Let T be a random variable representing the survival time,
with a cumulative probability function P(t). Informally, P(t) is
the probability that death has happened before time t.
For more details on function usage:
SELECT {schema_madlib}.coxph_train('usage')
For an example on using the function:
SELECT {schema_madlib}.coxph_train('example')
"""
    elif message in ['usage', 'help', '?']:
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.coxph_train(
'source_table', -- Name of data table
'output_table', -- Name of result table (must not already exist)
'dependent_variable', -- Name of the column containing the survival time
'independent_variable', -- Name of column for independent variables
(can be any SQL expression, e.g. 'ARRAY[x1, x2]', or '*' for all
columns except the dependent variable and the censoring status columns)
'right_censoring_status', -- Name of the column containing censoring status
0/false : If the observation is censored
1/true : otherwise
Can also be an SQL expression: 'dependent_variable < 10'
(Optional, DEFAULT = TRUE)
'strata', -- The stratification column names. (Optional, DEFAULT = NULL)
'optimizer_params' -- The optimizer parameters as a comma-separated string
(Optional, DEFAULT = 'max_iter=100, optimizer=newton, tolerance=1e-8')
);
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output table ('output_table' above) has the following columns
'coef' DOUBLE PRECISION[], -- Coefficients of regression
'loglikelihood' DOUBLE PRECISION, -- Log-likelihood value
'std_err' DOUBLE PRECISION[], -- Standard errors
'z_stats' DOUBLE PRECISION[], -- z-stats of the standard errors
'p_values' DOUBLE PRECISION[], -- p-values of the standard errors
'hessian' DOUBLE PRECISION[][], -- Hessian matrix
'num_iterations' INTEGER -- Number of iterations performed by the optimizer
The output summary table, named <output_table>_summary, has the following columns
'source_table' VARCHAR, Source table name
'dependent_variable' VARCHAR, Dependent variable name
'independent_variable' VARCHAR, Independent variable name
'right_censoring_status' VARCHAR, Right censoring status
'strata' VARCHAR, Stratification columns
'num_processed' INTEGER, Number of rows processed during training
'num_missing_rows_skipped' INTEGER, Number of rows skipped during training
due to missing values
"""
    elif message in ['example', 'examples']:
        help_string = """
DROP TABLE IF EXISTS sample_data;
CREATE TABLE sample_data (
id INTEGER NOT NULL,
grp DOUBLE PRECISION,
wbc DOUBLE PRECISION,
timedeath INTEGER,
status BOOLEAN
);
COPY sample_data FROM STDIN DELIMITER '|';
0 | 0 | 1.45 | 35 | t
1 | 0 | 1.47 | 34 | t
3 | 0 | 2.2 | 32 | t
4 | 0 | 1.78 | 25 | t
5 | 0 | 2.57 | 23 | t
6 | 0 | 2.32 | 22 | t
7 | 0 | 2.01 | 20 | t
8 | 0 | 2.05 | 19 | t
9 | 0 | 2.16 | 17 | t
10 | 0 | 3.6 | 16 | t
11 | 1 | 2.3 | 15 | t
12 | 0 | 2.88 | 13 | t
13 | 1 | 1.5 | 12 | t
14 | 0 | 2.6 | 11 | t
15 | 0 | 2.7 | 10 | t
16 | 0 | 2.8 | 9 | t
17 | 1 | 2.32 | 8 | t
18 | 0 | 4.43 | 7 | t
19 | 0 | 2.31 | 6 | t
20 | 1 | 3.49 | 5 | t
21 | 1 | 2.42 | 4 | t
22 | 1 | 4.01 | 3 | t
23 | 1 | 4.91 | 2 | t
24 | 1 | 5 | 1 | t
\\.
SELECT {schema_madlib}.coxph_train(
'sample_data',
'sample_cox',
'timedeath',
'ARRAY[grp,wbc]',
'status');
SELECT * FROM sample_cox;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.coxph_train()"
    return help_string.format(schema_madlib=schema_madlib)
# ---------------------------------------------------------------------------
def _sql_literal(value):
    """ Return value as a single-quoted SQL string literal, or NULL for None.

    Embedded single quotes are doubled so that user expressions such as
    "event = 'death'" can be stored verbatim in a table.
    """
    if value is None:
        return 'NULL'
    return "'" + value.replace("'", "''") + "'"


def coxph(schema_madlib, source_table, output_table, dependent_varname,
          independent_varname, right_censoring_status, strata,
          optimizer_params, *args, **kwargs):
    """ Cox proportional hazards regression training function

    @brief Cox proportional hazards regression, with stratification
    support.

    Args:
        @param schema_madlib - MADlib schema name
        @param source_table - A string, the data table name
        @param output_table - A string, the result table name
        @param dependent_varname - A string, the survival time column
            name or a valid expression
        @param independent_varname - A string, the covariates in array
            formats. It is a valid expression.
        @param right_censoring_status - A string, a column name or a
            valid expression that has boolean values. Whether the row
            of data is censored. When NULL or empty, 'TRUE' is used for
            training and NULL is recorded in the summary table.
        @param strata - A string, column names separated by commas. The
            columns used for stratification. NULL, empty or blank means
            no stratification.
        @param optimizer_params - A string, which contains key=value
            pairs separated by commas. Default values (see
            _extract_params): max_iter=100, optimizer='newton',
            tolerance=1e-8.

    Returns:
        None. Creates two tables:
        * <output_table> - every column returned by
          internal_coxph_result() (e.g. coef, std_err, z_stats, p_values,
          hessian) except num_processed, plus num_iterations
        * <output_table>_summary - the training arguments together with
          num_processed and num_missing_rows_skipped
    """
    if strata is not None and not strata.strip():
        # A blank strata string means "no stratification"; otherwise the
        # stratified path would issue GROUP BY with an empty column list.
        strata = None
    # Documented default of right_censoring_status is TRUE (no row is
    # censored). The summary table still records NULL in that case.
    status_expr = right_censoring_status if right_censoring_status else 'TRUE'

    old_msg_level = plpy.execute("""
SELECT setting FROM pg_settings
WHERE name='client_min_messages'
""")[0]['setting']
    plpy.execute("set client_min_messages to error")
    all_arguments = {'schema_madlib': schema_madlib,
                     'source_table': source_table,
                     'output_table': output_table,
                     'dependent_varname': dependent_varname,
                     'independent_varname': independent_varname,
                     'right_censoring_status': right_censoring_status,
                     'strata': strata
                     }
    _validate_params(**all_arguments)
    max_iter, optimizer, tolerance = _extract_params(schema_madlib,
                                                     optimizer_params)
    if strata is None:
        theIteration, madlib_iterative_alg = compute_coxph(
            schema_madlib, source_table, independent_varname,
            dependent_varname, status_expr, optimizer,
            max_iter, tolerance)
    else:
        theIteration, madlib_iterative_alg = compute_coxph_strata(
            schema_madlib, source_table, independent_varname,
            dependent_varname, status_expr, strata,
            optimizer, max_iter, tolerance)
    total_rows = plpy.execute(""" SELECT count(*) as cnt FROM {0}""".format(
        source_table))[0]["cnt"]
    # num_processed has to be put into output_table first since the result
    # can only be extracted via SQL: fetching it in Python fails because
    # fetching a two-dim array (in this case 'hessian') is not supported.
    plpy.execute(""" CREATE TABLE {output_table} AS
SELECT (result).*,
({total_rows} - (result).num_processed)
as num_missing_rows_skipped,
{theIteration} AS num_iterations
FROM (
SELECT
{schema_madlib}.internal_coxph_result(_madlib_state)
AS result
FROM {madlib_iterative_alg}
WHERE _madlib_iteration = {theIteration}
) subq
""".format(schema_madlib=schema_madlib,
           output_table=output_table,
           theIteration=theIteration,
           madlib_iterative_alg=madlib_iterative_alg,
           total_rows=total_rows))
    result = plpy.execute("""SELECT num_processed, num_missing_rows_skipped
FROM {output_table}
""".format(output_table=output_table))[0]
    if not result["num_processed"]:
        # When no rows have been processed by the aggregate, a NULL result
        # is returned; record 0 processed rows and every row as skipped.
        result["num_processed"] = 0
        result["num_missing_rows_skipped"] = total_rows
    # The row counts belong in the summary table, not the model table.
    plpy.execute("""ALTER TABLE {output_table}
DROP num_processed,
DROP num_missing_rows_skipped
""".format(output_table=output_table))
    # All user-supplied text is embedded as escaped SQL literals so that
    # expressions containing single quotes do not break the statement.
    plpy.execute(
"""
CREATE TABLE {output_table}_summary AS
SELECT
{source_table}::VARCHAR as source_table,
{dependent_varname}::VARCHAR as dependent_variable,
{independent_varname}::VARCHAR as independent_variable,
{right_censoring_status}::VARCHAR as right_censoring_status,
{strata}::VARCHAR as strata,
{num_processed}::INTEGER as num_processed,
{num_missing_rows_skipped}::INTEGER as num_missing_rows_skipped
""".format(output_table=output_table,
           source_table=_sql_literal(source_table),
           dependent_varname=_sql_literal(dependent_varname),
           independent_varname=_sql_literal(independent_varname),
           right_censoring_status=_sql_literal(right_censoring_status or None),
           strata=_sql_literal(strata),
           num_processed=result["num_processed"],
           num_missing_rows_skipped=result["num_missing_rows_skipped"]))
    plpy.execute("set client_min_messages to " + old_msg_level)
    return None
# ----------------------------------------------------------------------
def _validate_params(schema_madlib, source_table, output_table,
                     dependent_varname, independent_varname,
                     right_censoring_status, strata, *args, **kwargs):
    """ Validate the input parameters for coxph

    Args:
        @param schema_madlib - MADlib schema name
        @param source_table - A string, the data table name
        @param output_table - A string, the result table name; neither it
            nor <output_table>_summary may already exist
        @param dependent_varname - A string, the survival time column
            name or a valid expression (not checked here)
        @param independent_varname - A string, the covariates in array
            formats. It is a valid expression (not checked here).
        @param right_censoring_status - A string, a column name or a
            valid expression that has boolean values (not checked here).
        @param strata - A string, column names separated by commas. The
            columns used for stratification. Default is None.

    Throws:
        "Cox error" if any argument is invalid
    """
    _assert(source_table is not None and table_exists(source_table),
            "Cox error: Source data table does not exist!")
    # Guard before the "_summary" concatenation below, which would raise a
    # TypeError for a None name.
    _assert(output_table is not None and output_table.strip() != '',
            "Cox error: Invalid output table name!")
    _assert(not table_exists(output_table),
            "Cox error: Output table {0}"
            " already exists!".format(str(output_table)))
    _assert(not table_exists(output_table + "_summary"),
            "Cox error: Output table {0}_summary"
            " already exists!".format(str(output_table)))
    if strata is not None:
        strata_cols = [a.strip() for a in strata.split(",")]
        _assert(columns_exist_in_table(source_table, strata_cols,
                                       schema_madlib),
                "Cox error: At least one of the strata columns ({1}) "
                "does not exist in {0}!"
                .format(source_table, ', '.join(strata_cols)))
    return None
# ----------------------------------------------------------------------
def _extract_params(schema_madlib, optimizer_params):
    """ Extract optimizer control parameter or set the default values

    @brief optimizer_params is a string with the format of
    'max_iter=..., optimizer=..., tolerance=...'. The order does not
    matter. If a parameter is missing, then its default value is used
    (max_iter=100, optimizer='newton', tolerance=1e-8). If
    optimizer_params is None, '' or only whitespace, then all default
    values are used. If the parameter specified is none of 'max_iter',
    'optimizer', or 'tolerance' then an error is raised. This function
    also validates the values of these parameters.

    Returns:
        Tuple (max_iter, optimizer, tolerance)

    Throws:
        "Cox error" - If the parameter is unsupported or the value is
        not valid.
    """
    import math

    name_value = dict(max_iter=100, optimizer="newton", tolerance=1e-8)
    if optimizer_params is None or not optimizer_params.strip():
        return (name_value['max_iter'], name_value['optimizer'],
                name_value['tolerance'])
    allowed_params = set(["max_iter", "optimizer", "tolerance"])
    for s in preprocess_optimizer_params(optimizer_params):
        items = s.split("=")
        if len(items) != 2:
            plpy.error("Cox error: Optimizer parameter list has incorrect format!")
        param_name = items[0].strip(" \"").lower()
        param_value = items[1].strip(" \"").lower()
        if param_name not in allowed_params:
            plpy.error(
"""
Cox error: {param_name} is not a valid parameter name.
Run:
SELECT {schema_madlib}.coxph_train('usage');
to see the allowed parameters.
""".format(param_name=param_name,
           schema_madlib=schema_madlib))
        if param_name == "max_iter":
            try:
                name_value["max_iter"] = int(param_value)
            except ValueError:
                plpy.error("Cox error: max_iter must be an integer number!")
        elif param_name == "optimizer":
            name_value["optimizer"] = param_value
        elif param_name == "tolerance":
            try:
                name_value["tolerance"] = float(param_value)
            except ValueError:
                plpy.error("Cox error: tolerance must be a double precision value!")
    if name_value["max_iter"] <= 0:
        plpy.error("Cox error: max_iter must be positive!")
    if name_value["optimizer"] != "newton":
        plpy.error("Cox error: this optimization method is not supported yet!")
    # float() accepts 'nan', and NaN slips through the '< 0' check below.
    if math.isnan(name_value["tolerance"]):
        plpy.error("Cox error: tolerance must be a valid number!")
    if name_value["tolerance"] < 0:
        plpy.error("Cox error: tolerance cannot be smaller than 0!")
    return (name_value['max_iter'], name_value['optimizer'],
            name_value['tolerance'])
# ----------------------------------------------------------------------
def __check_args(schema_madlib, tbl_source, col_ind_var, col_dep_var, col_status):
    """ Validate the training inputs and expand '*' for the covariates.

    Args:
        @param schema_madlib - MADlib schema name (unused here)
        @param tbl_source - Name of the source table; must exist and be
            non-empty
        @param col_ind_var - Independent variable expression, or '*'
        @param col_dep_var - Name of the dependent variable column
        @param col_status - Boolean censoring-status expression

    Returns:
        The independent-variable expression to use: col_ind_var unchanged,
        or, when it is '*', an 'array[...]' expression over every column of
        tbl_source except the dependent variable and the columns whose
        names appear as identifiers in col_status.

    Throws:
        "Cox Proportional Hazards Error" on any invalid argument.
    """
    import re

    _assert(tbl_source is not None,
            "Cox Proportional Hazards Error: Source table should not be NULL!")
    _assert(col_ind_var is not None,
            "Cox Proportional Hazards Error: Independent variable should not be NULL!")
    _assert(col_dep_var is not None,
            "Cox Proportional Hazards Error: Dependent variable should not be NULL!")
    _assert(table_exists(tbl_source),
            "Cox Proportional Hazards Error: Source table " + tbl_source + " does not exist!")
    _assert(not table_is_empty(tbl_source),
            "Cox Proportional Hazards Error: Source table " + tbl_source + " is empty!")
    _assert(columns_exist_in_table(tbl_source, [col_dep_var]),
            "Cox Proportional Hazards Error: Dependent variable does not exist!")
    _assert(is_var_valid(tbl_source, col_ind_var),
            "Cox Proportional Hazards Error: The independent variable does not exist!")
    _assert(is_var_valid(tbl_source, col_status),
            "Cox Proportional Hazards Error: Not a valid boolean expression for status!")
    if col_ind_var != "*":
        return col_ind_var
    # Identifiers referenced by the status expression (e.g. 'status' or
    # 'timedeath < 10'). Matching whole identifiers rather than substrings
    # keeps columns such as 'a' or 'stat' that merely occur inside the
    # status text.
    status_idents = set(re.findall(r'[a-z_][a-z0-9_$]*', col_status.lower()))
    dep_var = col_dep_var.lower()
    ind_cols = [col for col in get_cols(tbl_source)
                if col != dep_var and col not in status_idents]
    # 'array[]' would fail in SQL with an untyped empty array.
    _assert(len(ind_cols) > 0,
            "Cox Proportional Hazards Error: No columns left for the "
            "independent variable after excluding the dependent variable "
            "and status columns!")
    return 'array[%s]' % ','.join(ind_cols)
# ----------------------------------------------------------------------
def __runIterativeAlg(stateType, initialState, source, updateExpr,
                      terminateExpr, resultExpr, maxNumIterations,
                      cyclesPerIteration=1, updateExprOuter='',
                      updateExprInner='', strata='', *args, **kwargs):
    """
    Driver for an iterative algorithm

    A general driver function for most iterative algorithms: The state between
    iterations is kept in a variable of type <tt>stateType</tt>, which is
    initialized with <tt><em>initialState</em></tt>. During each iteration, the
    SQL statement <tt>updateSQL</tt> is executed in the database. Afterwards,
    the SQL query <tt>terminateSQL</tt> decides whether the algorithm
    terminates. Every state is stored in a temporary table, one row per
    iteration.

    @param stateType SQL type of the state between iterations
    @param initialState The initial value of the SQL state variable
    @param source The source relation
    @param updateExpr SQL expression that returns the new state of type
        <tt>stateType</tt>. The expression may use the replacement fields
        <tt>"{state}"</tt>, <tt>"{iteration}"</tt>, <tt>"{oldCoef}"</tt>
        and <tt>"{sourceAlias}"</tt>. Source alias is an alias for the
        source relation <tt><em>source</em></tt>. When blank, the
        stratified form built from updateExprOuter/updateExprInner is used.
    @param terminateExpr SQL expression that returns whether the algorithm should
        terminate. The expression may use the replacement fields
        <tt>"{oldState}"</tt>, <tt>"{newState}"</tt>, and
        <tt>"{iteration}"</tt>. It must return a BOOLEAN value.
    @param resultExpr The SQL query to extract the result; evaluated on the
        previous iteration's state and exposed as <tt>result</tt>.
    @param maxNumIterations Maximum number of iterations. Algorithm will then
        terminate even when <tt>terminateExpr</tt> does not evaluate to \c true
    @param cyclesPerIteration Number of aggregate function calls per iteration.
    @param updateExprOuter Used only when updateExpr is blank: SQL
        expression over the per-stratum rows <tt>inner_state</tt> (one row
        per GROUP BY <tt>strata</tt> group) that yields the new state.
    @param updateExprInner Used only when updateExpr is blank: SQL
        expression evaluated per GROUP BY <tt>strata</tt> group; may use the
        replacement field <tt>"{oldCoef}"</tt>.
    @param strata Comma-separated GROUP BY column list used only by the
        stratified (blank updateExpr) branch.

    @return Tuple (iteration, madlib_iterative_alg): the number of the last
        iteration performed and the name of the temporary table holding the
        state of every iteration (the table is not dropped here).
    """
    madlib_iterative_alg = __unique_string()
    # updateExpr is used to branch b/w stratified cox and normal cox.
    # Doubled braces survive this first .format() and are filled per
    # iteration inside the loop below.
    if updateExpr.strip() != '':
        updateSQL = """
INSERT INTO {madlib_iterative_alg}
SELECT
{{iteration}},
{updateExpr}
FROM (
SELECT {resultExpr} AS result
FROM {madlib_iterative_alg}
WHERE _madlib_iteration = {{iteration}} - 1
) st,
{{source}} as src
""".format(updateExpr=updateExpr, resultExpr=resultExpr,
           madlib_iterative_alg=madlib_iterative_alg)
    else:
        # Stratified form: the inner expression runs once per strata group,
        # the outer expression combines the per-group rows.
        updateSQL = """
INSERT INTO {madlib_iterative_alg}
SELECT
{{iteration}},
{updateExprOuter}
FROM (
SELECT {updateExprInner} AS inner_state
FROM (
SELECT {resultExpr} AS result
FROM {madlib_iterative_alg}
WHERE _madlib_iteration = {{iteration}} - 1
) s1,
{{source}} as src
GROUP BY {strata}
) s2
""".format(updateExprOuter=updateExprOuter, strata=strata,
           updateExprInner=updateExprInner, resultExpr=resultExpr,
           madlib_iterative_alg=madlib_iterative_alg)
    # Compares the state of the current iteration with the one
    # cyclesPerIteration iterations earlier.
    terminateSQL = """
SELECT
{terminateExpr} AS should_terminate
FROM (
SELECT _madlib_state
FROM {madlib_iterative_alg}
WHERE _madlib_iteration = {{iteration}} - {{cyclesPerIteration}}
) AS older,
(
SELECT _madlib_state
FROM {madlib_iterative_alg}
WHERE _madlib_iteration = {{iteration}}
) AS newer
""".format(terminateExpr=terminateExpr,
           madlib_iterative_alg=madlib_iterative_alg)
    # A NULL state stops the iteration immediately.
    checkForNullStateSQL = """
SELECT _madlib_state IS NULL AS should_terminate
FROM {madlib_iterative_alg}
WHERE _madlib_iteration = {iteration}
"""
    # Silence notices while creating the state table, then restore the
    # caller's message level.
    oldMsgLevel = plpy.execute("SELECT setting FROM pg_settings "
                               "WHERE name='client_min_messages'")[0]['setting']
    plpy.execute("""
SET client_min_messages = error;
CREATE TEMPORARY TABLE {madlib_iterative_alg} (
_madlib_iteration INTEGER PRIMARY KEY,
_madlib_state {stateType}
);
SET client_min_messages = {oldMsgLevel};
""".format(stateType=stateType,
           oldMsgLevel=oldMsgLevel,
           madlib_iterative_alg=madlib_iterative_alg))
    iteration = 0
    plpy.execute("""
INSERT INTO {madlib_iterative_alg} VALUES ({iteration}, {initialState})
""".format(iteration=iteration, initialState=initialState,
           madlib_iterative_alg=madlib_iterative_alg))
    while True:
        iteration = iteration + 1
        plpy.execute(updateSQL.format(
            source=source,
            state="(_madlib_state)",
            oldCoef="(result).coef",
            iteration=iteration,
            sourceAlias="src",
            madlib_iterative_alg=madlib_iterative_alg))
        # NOTE(review): both the max-iteration and the convergence checks
        # are only evaluated once iteration > cyclesPerIteration, so unless
        # the state becomes NULL at least cyclesPerIteration + 1 updates
        # run (e.g. 2 updates for maxNumIterations=1, cyclesPerIteration=1).
        if (plpy.execute(checkForNullStateSQL.format(iteration=iteration,
                                                     madlib_iterative_alg=madlib_iterative_alg)
                         )[0]['should_terminate'] or
                (iteration > cyclesPerIteration and
                 (iteration >= cyclesPerIteration * maxNumIterations or
                  plpy.execute(terminateSQL.format(
                      iteration=iteration,
                      cyclesPerIteration=cyclesPerIteration,
                      oldState="(older._madlib_state)",
                      newState="(newer._madlib_state)")
                  )[0]['should_terminate']))):
            break
    # Note: We do not drop the temporary table; the caller reads the final
    # state from it.
    return (iteration, madlib_iterative_alg)
# ----------------------------------------------------------------------
def compute_coxph(schema_madlib, source, indepColumn,
                  depColumn, status, optimizer,
                  maxNumIterations, precision,
                  *args, **kwargs):
    """
    Compute cox survival regression coefficients

    Runs the (non-stratified) Newton iterations of the Cox proportional
    hazards model through __runIterativeAlg, feeding the coxph_step
    aggregate with the rows ordered by depColumn in descending order.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source Name of relation containing the training data
    @param indepColumn Name of independent column in training data, or '*'
    @param depColumn Name of dependant column which captures time of death
    @param status Right censoring support for cox: a boolean expression.
        NULL or empty means 'TRUE', i.e. no observation is censored.
    @param optimizer Name of the optimizer. 'newton': newton method
    @param maxNumIterations Maximum number of iterations
    @param precision Terminate when
        <tt>internal_coxph_step_distance(new_state, old_state)</tt> is
        smaller than <tt>precision</tt>. A negative value disables this
        convergence criterion.
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.

    @return Tuple (number of iterations, name of the temporary table that
        holds the state of every iteration), as returned by
        __runIterativeAlg.
    """
    if not status:
        # Same default as _compute_residual uses for a missing status.
        status = 'TRUE'
    indepColumn = __check_args(schema_madlib, source, indepColumn, depColumn, status)
    if maxNumIterations < 1:
        plpy.error("Number of iterations must be positive")
    if optimizer not in ['newton']:
        plpy.error("Unknown optimizer requested. Must be 'newton'")
    return __runIterativeAlg(
        stateType="double precision[]",
        initialState="NULL",
        source=source,
        updateExpr="""
{schema_madlib}.coxph_step(
({indepColumn})::double precision[],
({depColumn})::double precision,
({status})::boolean,
{{oldCoef}}::double precision[]
ORDER BY {depColumn} DESC
)
""".format(schema_madlib=schema_madlib,
           indepColumn=indepColumn,
           depColumn=depColumn,
           status=status),
        terminateExpr="""
{schema_madlib}.internal_coxph_step_distance(
{{newState}}, {{oldState}}
) < {precision}
""".format(schema_madlib=schema_madlib, precision=precision),
        resultExpr="""
{schema_madlib}.internal_coxph_result({{state}})
""".format(schema_madlib=schema_madlib),
        maxNumIterations=maxNumIterations)
# ----------------------------------------------------------------------
def compute_coxph_strata(schema_madlib, source, indepColumn,
                         depColumn, status, strata, optimizer,
                         maxNumIterations, precision,
                         *args, **kwargs):
    """
    Compute stratified cox survival regression coefficients

    Runs the Newton iterations of the stratified Cox proportional hazards
    model through __runIterativeAlg: coxph_strata_step_inner is evaluated
    per stratum (GROUP BY strata, rows ordered by depColumn descending)
    and coxph_strata_step_outer combines the per-stratum results.

    @param schema_madlib Name of the MADlib schema, properly escaped/quoted
    @param source Name of relation containing the training data
    @param indepColumn Name of independent column in training data, or '*'
    @param depColumn Name of dependant column which captures time of death
    @param status Right censoring support for cox: a boolean expression.
        NULL or empty means 'TRUE', i.e. no observation is censored.
    @param strata - A string, column names separated by commas. The
        columns used for stratification. NULL, empty or blank delegates to
        the non-stratified compute_coxph.
    @param optimizer Name of the optimizer. 'newton': newton method
    @param maxNumIterations Maximum number of iterations
    @param precision Terminate when
        <tt>internal_coxph_step_distance(new_state, old_state)</tt> is
        smaller than <tt>precision</tt>. A negative value disables this
        convergence criterion.
    @param kwargs We allow the caller to specify additional arguments (all of
        which will be ignored though). The purpose of this is to allow the
        caller to unpack a dictionary whose element set is a superset of
        the required arguments by this function.

    @return Tuple (number of iterations, name of the temporary table that
        holds the state of every iteration), as returned by
        __runIterativeAlg.
    """
    if not strata or not strata.strip():
        # GROUP BY with an empty column list is invalid SQL.
        return compute_coxph(schema_madlib, source, indepColumn,
                             depColumn, status, optimizer,
                             maxNumIterations, precision)
    if not status:
        # Same default as _compute_residual uses for a missing status.
        status = 'TRUE'
    indepColumn = __check_args(schema_madlib, source, indepColumn, depColumn, status)
    if maxNumIterations < 1:
        plpy.error("Number of iterations must be positive")
    if optimizer not in ['newton']:
        plpy.error("Unknown optimizer requested. Must be 'newton'")
    terminateExpr = """
{schema_madlib}.internal_coxph_step_distance(
{{newState}}, {{oldState}}
) < {precision}
""".format(schema_madlib=schema_madlib, precision=precision)
    resultExpr = "{schema_madlib}.internal_coxph_result({{state}})".format(
        schema_madlib=schema_madlib)
    updateExprOuter = """{schema_madlib}.coxph_strata_step_outer(inner_state)
""".format(schema_madlib=schema_madlib)
    updateExprInner = """
{schema_madlib}.coxph_strata_step_inner(
({indepColumn})::double precision[],
({depColumn})::double precision,
({status})::boolean,
{{oldCoef}}::double precision[]
ORDER BY {depColumn} DESC
)
""".format(schema_madlib=schema_madlib,
           indepColumn=indepColumn, depColumn=depColumn,
           status=status)
    # An empty updateExpr selects the stratified branch of the driver.
    return __runIterativeAlg(
        stateType="double precision[]", initialState="NULL",
        source=source, updateExpr='', terminateExpr=terminateExpr,
        resultExpr=resultExpr, maxNumIterations=maxNumIterations,
        updateExprOuter=updateExprOuter, updateExprInner=updateExprInner,
        strata=strata)
# -----------------------------------------------------------------------
# ZPH functionality
# -----------------------------------------------------------------------
def zph_help_message(schema_madlib, message, **kwargs):
    """ Help message for function to test the proportional hazards assumption
    for a Cox regression model fit

    @brief Return the help text of cox_zph.

    Args:
        @param schema_madlib string, Name of the schema madlib
        @param message string, Help message indicator: None/'' for the
            summary, 'usage'/'help'/'?' for the usage, 'example'/'examples'
            for a worked example.

    Returns:
        String. Contain the help message string
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Functionality: Test of proportional hazards assumption
Proportional-Hazard models enable the comparison of various survival models.
See {schema_madlib}.coxph_train() for details to create a Cox PH model.
These PH models, however, assume that the hazard for a given individual
is a fixed proportion of the hazard for any other individual, and the
ratio of the hazards is constant across time.
The cox_zph() function is used to test this assumption by computing the
correlation of the residual of the Cox PH model with time.
For more details on function usage:
SELECT {schema_madlib}.cox_zph('usage')
For an example on using the function:
SELECT {schema_madlib}.cox_zph('example')
"""
    elif message in ['usage', 'help', '?']:
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.cox_zph(
'cox_model_table', -- TEXT. The name of the table containing the Cox Proportional-Hazards model
(created by coxph_train; its _summary table must also exist)
'output_table' -- TEXT. The name of the table where the test statistics are saved
);
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output table ('output_table' above) has the following columns
- covariate TEXT. The names of independent variables
- rho FLOAT8[]. Vector of the correlation coefficients between
survival time and the scaled Schoenfeld residuals
- chi_square FLOAT8[]. Chi-square test statistic for the correlation analysis
- p_value FLOAT8[]. Two-sided p-value for the chi-square statistic
The output residual table, named <output_table>_residual, has the following columns
- <dep_column_name> Time values (dependent variable) present in the original source table,
with the same type as in that table.
- residual FLOAT8[]. Difference between the original covariate value and the
expectation of the covariate obtained from the coxph model.
- scaled_residual FLOAT8[]. Residual values scaled by the variance of the coefficients
"""
    elif message in ['example', 'examples']:
        help_string = """
DROP TABLE IF EXISTS sample_data;
CREATE TABLE sample_data (
id INTEGER NOT NULL,
grp DOUBLE PRECISION,
wbc DOUBLE PRECISION,
timedeath INTEGER,
status BOOLEAN
);
-- Insert sample data
COPY sample_data FROM STDIN DELIMITER '|';
0 | 0 | 1.45 | 35 | t
1 | 0 | 1.47 | 34 | t
3 | 0 | 2.2 | 32 | t
4 | 0 | 1.78 | 25 | t
5 | 0 | 2.57 | 23 | t
6 | 0 | 2.32 | 22 | t
7 | 0 | 2.01 | 20 | t
8 | 0 | 2.05 | 19 | t
9 | 0 | 2.16 | 17 | t
10 | 0 | 3.6 | 16 | t
11 | 1 | 2.3 | 15 | t
12 | 0 | 2.88 | 13 | t
13 | 1 | 1.5 | 12 | t
14 | 0 | 2.6 | 11 | t
15 | 0 | 2.7 | 10 | t
16 | 0 | 2.8 | 9 | t
17 | 1 | 2.32 | 8 | t
18 | 0 | 4.43 | 7 | t
19 | 0 | 2.31 | 6 | t
20 | 1 | 3.49 | 5 | t
21 | 1 | 2.42 | 4 | t
22 | 1 | 4.01 | 3 | t
23 | 1 | 4.91 | 2 | t
24 | 1 | 5 | 1 | t
\\.
-- Run coxph function
SELECT {schema_madlib}.coxph_train(
'sample_data',
'sample_cox',
'timedeath',
'ARRAY[grp,wbc]',
'status');
-- Get the Cox PH model
SELECT * FROM sample_cox;
-- Run the PH assumption test and obtain the results
SELECT {schema_madlib}.cox_zph('sample_cox', 'sample_zph_output');
SELECT * FROM sample_zph_output;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.cox_zph()"
    return help_string.format(schema_madlib=schema_madlib)
def zph(schema_madlib, cox_output_table, output_table):
    """ Test the proportional hazards assumption of a fitted Cox model

    @brief Compute the Schoenfeld residuals for a Hazards data table
    by using an aggregate-defined window function, and their correlation
    with time. The training arguments are read from
    <cox_output_table>_summary.

    Args:
        @param schema_madlib: string, Name of the MADlib schema
        @param cox_output_table: string, Name of the coxph output_table
        @param output_table: string, Name of the table for the test
            statistics; the residuals go to <output_table>_residual

    Returns:
        None
    """
    _validate_zph_params(schema_madlib, cox_output_table, output_table)
    rv = plpy.execute("""
SELECT
source_table,
dependent_variable,
independent_variable,
right_censoring_status,
strata
FROM {cox_output_table}_summary
""".format(cox_output_table=cox_output_table))
    # Fail with a clear message rather than an IndexError on rv[0].
    _assert(len(rv) > 0,
            "Cox error: Summary table {0}_summary is empty!".format(
                cox_output_table))
    summary = rv[0]
    source_table = summary['source_table']
    # The training data is re-read to compute the residuals.
    _assert(table_exists(source_table),
            "Cox error: Source table {0} used to train the model "
            "does not exist!".format(source_table))
    _compute_residual(schema_madlib, source_table, output_table,
                      summary['dependent_variable'],
                      summary['independent_variable'],
                      cox_output_table, summary['right_censoring_status'],
                      summary['strata'])
# ----------------------------------------------------------------------
def _validate_zph_params(schema_madlib, cox_model_table, output_table):
    """ Validate the inputs of the proportional hazards test (cox_zph)

    Args:
        @param schema_madlib: string, Name of the MADlib schema
        @param cox_model_table: string, Table name for Cox Prop Hazards
            model; it and <cox_model_table>_summary must exist
        @param output_table: string, Output data table name; neither it
            nor <output_table>_residual may already exist

    Returns:
        None

    Throws:
        Error on any invalid parameter
    """
    _assert(cox_model_table is not None and table_exists(cox_model_table)
            and table_exists(cox_model_table + "_summary"),
            "Cox error: Model table {0} or summary table {0}_summary "
            "does not exist!".format(cox_model_table))
    # Guard before the "_residual" concatenation below, which would raise a
    # TypeError for a None name.
    _assert(output_table is not None and output_table.strip() != '',
            "Cox error: Invalid output table name!")
    _assert((not table_exists(output_table)) and
            (not table_exists(output_table + "_residual")),
            "Cox error: Output table {0} or residual table {0}_residual "
            "already exists!".format(output_table))
    summary_columns = ["source_table", "dependent_variable",
                       "independent_variable", "right_censoring_status",
                       "strata"]
    _assert(columns_exist_in_table(cox_model_table + "_summary", summary_columns),
            "Cox error: At least one column from {0} missing in "
            "summary table {1}_summary".format(str(summary_columns),
                                               cox_model_table))
    # _compute_residual reads coef and hessian from the model table.
    _assert(columns_exist_in_table(cox_model_table, ["coef", "hessian"]),
            "Cox error: Model table {0} must contain the columns "
            "coef and hessian!".format(cox_model_table))
    return None
# ----------------------------------------------------------------------
def _compute_residual(schema_madlib, source_table, output_table,
                      dependent_variable, independent_variable,
                      cox_output_table,
                      right_censoring_status=None,
                      strata=None, **kwargs):
    """ Compute the Schoenfeld residuals for a Hazards model

    @brief Computes the Schoenfeld residuals for a Hazards data table
    by using an aggregate-defined window function and outputs to a table

    Creates <output_table>_residual (dependent variable, residual,
    scaled_residual for every uncensored row without missing values) and
    <output_table> (covariate, rho, chi_square, p_value).

    Args:
        @param schema_madlib: string, Name of the MADlib schema
        @param source_table: string, Input data table name
        @param output_table: string, Output data table name
        @param dependent_variable: string, Dependent variable name
        @param independent_variable: string, Independent variable name (could also be an expression)
        @param cox_output_table: string, Output table of coxph
        @param right_censoring_status: string, Column name with right
            censoring status; NULL/empty means 'TRUE'
        @param strata: string, Comma-separated list of columns to stratify with

    Returns:
        None

    Throws:
        "Cox error" if no uncensored row without missing values exists.
    """
    if not right_censoring_status:
        right_censoring_status = 'TRUE'
    if strata:
        partition_str = "PARTITION BY {0}".format(strata)
    else:
        partition_str = ''
    coef = madvec(plpy.execute("SELECT coef FROM {table} ".
                               format(table=cox_output_table))[0]["coef"],
                  text=False)
    coef_str = "ARRAY" + str(coef)
    # We don't extract a copy of the Hessian 2D array, since Postgres/GPDB
    # still don't support getting a 2d array into plpython; the queries
    # below read it with a scalar subquery instead.
    residual_table = __unique_string()
    format_args = {'schema_madlib': schema_madlib,
                   'output': output_table,
                   'indep_column': independent_variable,
                   # The covariate expression is stored as text; quotes are
                   # doubled so expressions containing literals stay valid.
                   'indep_literal': "'" + independent_variable.replace("'", "''") + "'",
                   'dep_column': dependent_variable,
                   'status': right_censoring_status,
                   'cox_output_table': cox_output_table,
                   'source_table': source_table,
                   'residual_table': residual_table,
                   'coef_str': coef_str,
                   'partition_str': partition_str}
    # Residuals of uncensored rows: covariates minus their expectation
    # accumulated by zph_agg over rows in descending time order.
    plpy.execute("""
CREATE TEMP TABLE {residual_table} AS
SELECT
{dep_column},
{schema_madlib}.array_sub(
x::DOUBLE PRECISION[],
expectation_x::DOUBLE PRECISION[]
) AS residual
FROM
(
SELECT
{dep_column},
({indep_column})::DOUBLE PRECISION[] AS x,
({status})::BOOLEAN as status,
{schema_madlib}.zph_agg(
({indep_column})::DOUBLE PRECISION[],
{coef_str}
) OVER ({partition_str} ORDER BY {dep_column} DESC) AS expectation_x
FROM {source_table}
WHERE {dep_column} IS NOT NULL AND
NOT {schema_madlib}.array_contains_null(
({indep_column})::DOUBLE PRECISION[])
) AS q1
WHERE status is TRUE
ORDER BY {dep_column} ASC
m4_ifdef(`__GREENPLUM__', `DISTRIBUTED BY ({dep_column})')
""".format(**format_args))
    n_uncensored = plpy.execute("""SELECT count(*)::INTEGER as m
FROM {table}
""".format(table=residual_table))[0]["m"]
    # Without any residual row the mean below is NULL and the statistics
    # query would fail with an obscure SQL error.
    _assert(n_uncensored > 0,
            "Cox error: No uncensored rows without missing values found in "
            "{0}; cannot compute the residuals!".format(source_table))
    format_args['m'] = n_uncensored
    plpy.execute("""
CREATE TABLE {output}_residual AS
SELECT
{dep_column},
residual as residual,
{schema_madlib}.__coxph_scale_resid(
{m}::INTEGER,
(SELECT hessian FROM {cox_output_table}),
residual
) AS scaled_residual
FROM
{residual_table}
m4_ifdef(`__GREENPLUM__', `DISTRIBUTED BY ({dep_column})')
""".format(**format_args))
    mean = plpy.execute("""
SELECT avg({dep_column}) AS w FROM {residual_table}
""".format(**format_args))[0]['w']
    plpy.execute("""
CREATE TABLE {output} AS
SELECT
({indep_literal})::TEXT as covariate, rho,
(f).chi_square_stat as chi_square, (f).p_value as p_value
FROM (
SELECT
{schema_madlib}.array_elem_corr_agg(
scaled_residual,
({dep_column} - {mean})::DOUBLE PRECISION)
AS rho,
{schema_madlib}.__coxph_resid_stat_agg(
({dep_column} - {mean})::DOUBLE PRECISION,
residual,
(SELECT hessian FROM {cox_output_table}),
{m}::INTEGER)
AS f
FROM
{output}_residual
) AS q1
m4_ifdef(`__GREENPLUM__', `DISTRIBUTED RANDOMLY')
""".format(mean=mean, **format_args))
    # Cleanup
    plpy.execute('DROP TABLE IF EXISTS ' + residual_table)