blob: f73c14fcf72b32851c1b10ba2e578efd686f86d2 [file] [log] [blame]
m4_changequote(`<!', `!>')
# coding=utf-8
"""
@file cox_prop_hazards.py_in
@brief Cox prop Hazards: Driver functions
@namespace coxprophazards
Cox prop Hazards: Driver functions
//"""
import plpy
import re
from utilities.control import MinWarning
from utilities.validate_args import table_exists
from utilities.validate_args import table_is_empty
from utilities.validate_args import get_cols
from utilities.validate_args import is_var_valid
from utilities.utilities import unique_string
from utilities.utilities import preprocess_keyvalue_params
from utilities.utilities import _assert
from utilities.utilities import get_seg_number
from utilities.utilities import add_postfix
from utilities.utilities import py_list_to_sql_string
from utilities.validate_args import columns_exist_in_table
from utilities.utilities import __mad_version
from utilities.control import IterationController2S
from convex.utils_regularization import utils_ind_var_scales
import random
# ----------------------------------------------------------------------
# Helper object built once at import time by utilities.utilities.__mad_version.
# NOTE(review): presumably encapsulates platform/version-specific behaviour -
# confirm against utilities.utilities.
version_wrapper = __mad_version()
# Conversion function chosen by the wrapper; in this module it is applied to
# the `coef` array read back through plpy.execute (called with text=False).
madvec = version_wrapper.select_vecfunc()
def coxph_help_message(schema_madlib, message, **kwargs):
    """ Help message for Cox Proportional Hazards

    @brief Build the text returned by coxph_train() when it is called
    with no argument or with a help keyword.

    Args:
        @param schema_madlib string, Name of the schema madlib
        @param message string, Help message indicator. A false value
            (None or '') selects the summary; 'usage', 'help' or '?'
            selects the usage text; anything else yields a short
            "No such option" notice.
        @param kwargs Ignored.

    Returns:
        String. Contain the help message string
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Functionality: Cox proportional hazards regression (Breslow method)
Proportional-Hazard models enable the comparison of various survival models.
These survival models are functions describing the probability of a one-item
event (prototypically, this event is death) with respect to time.
The interval of time before death occurs is the survival time.
Let T be a random variable representing the survival time,
with a cumulative probability function P(t). Informally, P(t) is
the probability that death has happened before time t.
For more details on function usage:
SELECT {schema_madlib}.coxph_train('usage')
"""
    elif message in ['usage', 'help', '?']:
        # The summary-table column names below mirror the columns created
        # by coxph(); the output table must not exist beforehand because
        # _validate_params() rejects an existing one.
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.coxph_train(
'source_table', -- Name of data table
'output_table', -- Name of result table (must not already exist)
'dependent_variable', -- Name of column for dependent variables
'independent_variable', -- Name of column for independent variables
(can be any SQL expression Eg: '*')
'right_censoring_status', -- Name of the column containing censoring status
0/false : If the observation is censored
1/true : otherwise
Can also be an SQL expression: 'dependent_variable < 10')
(Optional, DEFAULT = TRUE)
'strata', -- The stratification column names. (Optional, DEFAULT = NULL)
'optimizer_params' -- The optimizer parameters as a comma-separated string
);
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output table ('output_table' above) has the following columns
'coef' DOUBLE PRECISION[], -- Coefficients of regression
'loglikelihood' DOUBLE PRECISION, -- Log-likelihood value
'std_err' DOUBLE PRECISION[], -- Standard errors
'z_stats' DOUBLE PRECISION[], -- z-stats of the standard errors
'p_values' DOUBLE PRECISION[], -- p-values of the standard errors
'num_iterations' INTEGER -- Number of iterations performed by the optimizer
The output summary table is named as <output_table>_summary has the following columns
'method' VARCHAR, Name of the method ('coxph')
'source_table' VARCHAR, Source table name
'out_table' VARCHAR, Output table name
'dependent_varname' VARCHAR, Dependent variable name
'independent_varname' VARCHAR, Independent variable name
'right_censoring_status' VARCHAR, Right censoring status
'strata' VARCHAR, Stratification columns
'num_processed' INTEGER, Number of rows processed during training
'num_missing_rows_skipped' INTEGER, Number of rows skipped during training
due to missing values
"""
    else:
        help_string = "No such option. Use {schema_madlib}.coxph_train()"
    return help_string.format(schema_madlib=schema_madlib)
# ------------------------------------------------------------
def coxph(
        schema_madlib,
        source_table,
        output_table,
        dependent_varname,
        independent_varname,
        right_censoring_status,
        strata,
        optimizer_params,
        *args,
        **kwargs):
    """ Cox proportional hazards regression training function

    @brief Cox proportional hazards regression, with
    stratification support.

    The session's client_min_messages is set to 'warning' while the
    function runs and is set back to its previous value before returning.

    Args:
        @param schema_madlib - MADlib schema name
        @param source_table - A string, the data table name
        @param output_table - A string, the result table name
        @param dependent_varname - A string, the survival time
            column name or a valid expression
        @param independent_varname - A string, the covariates in
            array formats. It is a valid expression.
        @param right_censoring_status - A string, a column name
            or a valid expression that has boolean values. Whether
            the row of data is censored.
        @param strata - A string, column names separated by
            commas. The columns used for stratification, or None.
        @param optimizer_params - A string, which contains
            key=value pairs separated by commas. Default values
            (see _extract_params): max_iter=100, optimizer='newton',
            tolerance=1e-8, array_agg_size=10000000,
            sample_size=1000000.

    Returns:
        None. Creates a table named by output_table (coefficients,
        log-likelihood, standard errors, z statistics, p values, ...)
        and a companion <output_table>_summary table recording the
        call arguments and the processed / skipped row counts.
    """
    def _as_sql_text(value):
        # Render a Python string as an SQL text literal. Embedded single
        # quotes are doubled so that an expression such as
        # "grp = 'a'" is stored verbatim instead of ending the literal.
        if value is None:
            return "NULL::text"
        return "'" + value.replace("'", "''") + "'"

    old_msg_level = plpy.execute(
        """
        select setting from pg_settings
        where name='client_min_messages'
        """)[0]['setting']
    plpy.execute("set client_min_messages to warning")

    _validate_params(schema_madlib=schema_madlib,
                     source_table=source_table,
                     output_table=output_table,
                     dependent_varname=dependent_varname,
                     independent_varname=independent_varname,
                     right_censoring_status=right_censoring_status,
                     strata=strata)
    # NOTE(review): the return value (the '*' expansion of the independent
    # variable) is not used further down - confirm whether that is intended.
    indepColumn = __check_args(schema_madlib, source_table,
                               independent_varname, dependent_varname,
                               right_censoring_status)
    (max_iter, optimizer, tolerance, array_agg_size,
     sample_size) = _extract_params(schema_madlib, optimizer_params)

    # Number of features, taken from the array length of a single row
    n_features = plpy.execute(
        """
        select array_upper({independent_varname}, 1)
        as n_features from {source_table} limit 1
        """.format(
            independent_varname=independent_varname,
            source_table=source_table))[0]['n_features']

    # Use the quick split to re-distribute the data
    (new_source_table, index, dep, indep, status, n_processed, n_skipped,
     real_distid, std_str) = quick_split(
        schema_madlib, source_table, dependent_varname, n_features,
        independent_varname, right_censoring_status, strata,
        array_agg_size, sample_size, True)

    if n_processed > 0:
        compute_coxph(schema_madlib, new_source_table, output_table,
                      index, dep, indep, n_features, status, strata, optimizer,
                      max_iter, tolerance, real_distid, std_str)
        plpy.execute('drop table if exists ' + new_source_table)
    else:
        # No usable rows: still produce an output table, holding one row
        # of NULL results.
        plpy.execute(
            """
            create table {output_table} (
            coef DOUBLE PRECISION[],
            loglikelihood DOUBLE PRECISION,
            std_err DOUBLE PRECISION[],
            z_stats DOUBLE PRECISION[],
            p_values DOUBLE PRECISION[],
            hessian DOUBLE PRECISION[],
            num_iterations INTEGER
            );
            """.format(output_table=output_table))
        plpy.execute(
            """
            insert into {output_table} values
            (NULL, NULL, NULL, NULL, NULL, NULL, 1);
            """.format(output_table=output_table))

    # the summary table
    output_table_summary = add_postfix(output_table, "_summary")
    plpy.execute(
        """
        create table {output_table_summary} as
        select
        'coxph'::varchar as method,
        {source_table}::varchar as source_table,
        {output_table}::varchar as out_table,
        {dependent_varname}::varchar as dependent_varname,
        {independent_varname}::varchar as independent_varname,
        {right_censoring_status}::varchar as right_censoring_status,
        {strata}::varchar as strata,
        {n_processed}::integer as num_processed,
        {n_skipped}::integer as num_missing_rows_skipped;
        """.format(
            output_table_summary=output_table_summary,
            source_table=_as_sql_text(source_table),
            output_table=_as_sql_text(output_table),
            dependent_varname=_as_sql_text(dependent_varname),
            independent_varname=_as_sql_text(independent_varname),
            right_censoring_status=_as_sql_text(right_censoring_status),
            strata=_as_sql_text(strata),
            n_processed=n_processed,
            n_skipped=n_skipped))

    plpy.execute("set client_min_messages to " + old_msg_level)
    return None
# ------------------------------------------------------------
def compute_coxph(schema_madlib, source_table, output_table, index, dep,
                  indep, n_features, status, strata, optimizer, max_iter,
                  precision, real_distid, std_str):
    """ Iteratively fit the Cox model on the re-distributed data

    @brief Repeatedly runs an ordered aggregate over source_table (the
    table produced by quick_split) until the stopping rule is met, then
    creates output_table from compute_coxph_result().

    Args:
        @param schema_madlib - MADlib schema name
        @param source_table - Re-distributed data table
        @param output_table - Name of the result table to create
        @param index - Column used in the aggregate's ORDER BY
        @param dep, indep, status - Column names in source_table
        @param n_features - Length of the coefficient vector
        @param strata - Stratification columns, or None
        @param optimizer - Must be 'newton'
        @param max_iter - Iteration limit, must be at least 1
        @param precision - Relative log-likelihood change below which
            the iteration stops
        @param real_distid - Grouping column for the non-stratified case
        @param std_str - SQL expression passed on to compute_coxph_result

    Returns:
        None
    """
    if max_iter < 1:
        plpy.error("Number of iterations must be positive")
    if optimizer not in ['newton']:
        plpy.error("Unknown optimizer requested. Must be 'newton'")

    # Statement parameters:
    #   $1 - coefficients from the previous iteration
    #   $2 - coefficient limit; the first call's result supplies it and
    #        every later call passes that value back in.
    if strata is not None:
        step_template = """
select
(f).*
from
(
select
{schema_madlib}.coxph_improved_strata_step_outer(inner_state) as f
from (
select
{schema_madlib}.coxph_improved_strata_step_inner(
{indep}, {dep}, {status}, $1, $2
order by {index}
) as inner_state
from {source_table}
group by {strata}
) t1
) t2
"""
    else:
        step_template = """
select
(f).*
from (
select {schema_madlib}.coxph_improved_step(
{indep}, {dep}, {status}, $1, $2
order by {index}
) as f
from {source_table}
group by {real_distid} -- to avoid gathering data before aggregating
) s
"""
    step_sql = step_template.format(
        schema_madlib=schema_madlib, indep=indep, dep=dep, status=status,
        index=index, source_table=source_table, strata=strata,
        real_distid=real_distid)
    step_plan = plpy.prepare(
        step_sql, ["double precision[]", "double precision[]"])

    beta = [0] * n_features          # current coefficients (start at 0)
    prev_beta = beta                 # coefficients of the last accepted step
    coef_limit = [-1] * n_features   # limit of coefficients
    loglik = float('-inf')
    n_iter = 0

    while True:
        n_iter += 1
        # The result contains coef, L, d2L
        step = plpy.execute(step_plan, [beta, coef_limit])[0]
        prev_loglik, loglik = loglik, step['l']
        proposed_beta = step['coef']
        if n_iter == 1:
            coef_limit = step['max_coef']

        if loglik < prev_loglik:
            # log-likelihood got worse: retry from the midpoint between
            # the current and the previously accepted coefficients
            beta = [(i + j)/2 for i, j in zip(beta, prev_beta)]
        else:
            prev_beta, beta = beta, proposed_beta

        # Stop once the iteration budget is exceeded, or the
        # log-likelihood improved by no more than the tolerance.
        finished = n_iter > max_iter or (
            loglik > prev_loglik and
            abs(1 - loglik / prev_loglik) <= precision)
        if not finished:
            continue

        # Compute std_err, t_stats, p_value and store the model
        plpy.execute("""
CREATE TABLE {output_table} as
SELECT (f).*
FROM (
SELECT {schema_madlib}.compute_coxph_result(
{coef},
{L},
{d2L}, {n_iter}, {std_str}
) as f
) s;
""".format(output_table=output_table,
           schema_madlib=schema_madlib,
           coef=py_list_to_sql_string(beta),
           L=loglik,
           d2L=py_list_to_sql_string(step["d2l"]),
           std_str=std_str,
           n_iter=n_iter))
        return None
# ------------------------------------------------------------
# Quick-split method. The same approach is used by the PLANET algorithm
# for decision trees; since it has already proven itself there, it is
# reused here.
def quick_split(
        schema_madlib,
        source_table,
        split_col,
        num_features,
        indep,
        status,
        strata,
        array_agg_size,
        sample_size,
        reverse=True,
        split_col_alias=''):
    """ Quickly find the splits of the split_col
    so that we can evenly cut the data without sorting

    @brief We want to cut the table to n even pieces, and at
    the same time keep split_col sorted. For example, in ARIMA,
    we need to partition the data into multiple chunks while
    each chunk contains consecutive time series. In CoxPH, we
    need to do the same thing. This step proves to be
    time-consuming because sorting (row_number over (order by
    ...)) is really slow.
    The solution is to work on a random sample of the original
    data instead of the whole data set.

    Rows where split_col, indep or status is NULL, or where indep
    contains a NULL element, are left out. The independent variables
    are normalized with utils_normalize_data before being aggregated.

    Args:
        @param schema_madlib The MADlib schema name
        @param source_table The data table
        @param split_col Find the cuts of this column
        @param num_features The number of features
        @param indep The independent variables
        @param status The censoring status column or expression
        @param strata Stratification columns, or None
        @param array_agg_size Upper bound on the number of values
            aggregated into one chunk
        @param sample_size Number of rows to sample for the break points
        @param reverse Sort each chunk in descending order when True
        @param split_col_alias Output name of the split column; a
            unique name is generated when empty

    Returns:
        Tuple (output_table, distid, split_col_alias, indep_name,
        status_name, n_processed, n_skipped, real_distid, scale_str).
        real_distid is '' when strata is given. When no row survives
        the NULL filter no table is created and the tuple is
        ('', 1, '', '', '', n_processed, n_skipped, '', '').
    """
    # Need the total number of records
    n_rows = plpy.execute(
        "select count(*) from {source_table}".format(
            source_table=source_table))[0]['count']

    # It might be cumbersome to deal with NULL values in C++,
    # since a matrix of independent variables is passed into C++.
    filter_null = """
        {split_col} is not NULL and
        {schema_madlib}.array_contains_null({indep}) is False and
        {indep} is not NULL and
        ({status}) is not NULL
        """.format(schema_madlib=schema_madlib,
                   split_col=split_col,
                   indep=indep, status=status)

    # Number of rows to be processed
    n_processed = plpy.execute(
        "select count(*) from {source_table} where {filter_null}".format(
            source_table=source_table, filter_null=filter_null))[0]['count']
    n_skipped = n_rows - n_processed
    if n_processed == 0:
        return ('', 1, '', '',
                '', n_processed, n_skipped, '', '')

    # Rewrite the num_splits.
    # Both quantities are whole numbers, so floor division is spelled out
    # rather than relying on how '/' treats integers. At least one row per
    # chunk avoids a division by zero when array_agg_size < num_features.
    n_rows_in_chunk = max(1, array_agg_size // num_features)
    num_splits = (1 if n_rows < n_rows_in_chunk
                  else n_rows // n_rows_in_chunk)

    # So that we could compute the percentage of the sample
    # We sample a few more values to make sure we can get enough
    # samples, otherwise the number of samples might be smaller
    # than sample_size.
    # float() forces a real ratio: with integer division the ratio
    # truncates to 0 whenever n_rows > sample_size, leaving only the
    # 0.01 margin as the sampling fraction.
    percentage = float(sample_size) / n_rows + 0.01
    n_per_seg = int(sample_size / get_seg_number()) + 1
    output_table = unique_string() + '_redist'

    # normalize the data to avoid possible overflow
    x_mean = plpy.execute(
        """
        select {schema_madlib}.array_avg({indep}, false) as xmean from {source_table}
        where {filter_null}
        """.format(
            schema_madlib=schema_madlib,
            indep=indep,
            source_table=source_table,
            filter_null=filter_null))[0]['xmean']
    mean_str = "array[" + ",".join(str(v) for v in x_mean) + "]"
    scales = plpy.execute(
        """
        select {schema_madlib}.array_avg(
        {schema_madlib}.array_sub({indep}::float8[], {mean}::float8[]), true) as scale
        from {source_table} where {filter_null}
        """.format(
            schema_madlib=schema_madlib,
            indep=indep,
            mean=mean_str,
            source_table=source_table,
            filter_null=filter_null))[0]['scale']
    # A zero scale is replaced by 1 in the generated array
    scale_str = ("array[" +
                 ",".join(str(v) if v != 0 else "1" for v in scales) + "]")

    # compute the break points
    # Use a fixed number of rows of sample to approximate the
    # break points.
    # Each bin will have approximate the same number of records.
    splits = plpy.execute(
        """
        select
        {schema_madlib}._compute_splits(
        {split_col}, {n_per_seg}, {num_splits})
        as splits
        from {source_table}
        where random() <= {percentage} and {filter_null}
        """.format(
            schema_madlib=schema_madlib,
            split_col=split_col,
            n_per_seg=n_per_seg,
            num_splits=num_splits,
            source_table=source_table,
            filter_null=filter_null,
            percentage=percentage))[0]['splits']

    # Since the array of break points is not big
    # we load it into memory.
    splits_str = ('NULL::DOUBLE PRECISION[]' if splits is None
                  else "array[" + ",".join(str(split) for split in splits) + "]")

    # Use the same convention as _redistribute_data
    # ie. hard code dep, and indep as column names
    # NOTE: CoxPH will process the data in a reversed order. Thus
    # the latest data points will be grouped into the group 0.
    indep_name = unique_string() + '_indep'  # avoid name conflicts
    distid = unique_string() + '_distid'
    if split_col_alias == "":
        split_col_alias = unique_string() + '_split_alias'
    status_name = unique_string() + '_status'
    order_str = 'desc' if reverse else ''

    # Substitutions shared by both CREATE TABLE statements below
    sql_args = dict(
        output_table=output_table,
        schema_madlib=schema_madlib,
        distid=distid,
        splits=splits_str,
        reverse='true' if reverse else 'false',
        split_col_alias=split_col_alias,
        split_col=split_col,
        indep=indep,
        indep_name=indep_name,
        status_name=status_name,
        status=status,
        filter_null=filter_null,
        source_table=source_table,
        order_str=order_str,
        mean_str=mean_str,
        scale_str=scale_str)

    # Each big rows may contain different number of original small rows
    # since the quick split is approximate. But this is perfectly fine.
    if strata is None:
        real_distid = unique_string() + '_real_distid'
        strata_sql = """
            CREATE TEMP TABLE {output_table} AS
            select
            0 as {real_distid}, -- to ensure that all chunks are stored in the same segment
            {distid},
            array_agg({split_col_alias} ORDER BY {split_col_alias} {order_str},
            {status_name}) as {split_col_alias},
            array_agg({status_name} ORDER BY {split_col_alias} {order_str},
            {status_name}) as {status_name},
            {schema_madlib}.matrix_agg(
            {schema_madlib}.utils_normalize_data(
            {indep_name}, {mean_str}, {scale_str}
            )
            ORDER BY {split_col_alias} {order_str}, {status_name}) as {indep_name}
            from
            (
            select
            {schema_madlib}._compute_grpid(
            {splits}, {split_col}, {reverse})
            as {distid},
            {split_col} as {split_col_alias},
            {indep} as {indep_name},
            ({status})::INTEGER as {status_name}
            from
            {source_table}
            where {filter_null}
            ) table_split
            group by {distid}
            m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!DISTRIBUTED BY ({real_distid})!>);
            """.format(real_distid=real_distid, **sql_args)
        plpy.execute(strata_sql)
        return (output_table, distid, split_col_alias, indep_name, status_name,
                n_processed, n_skipped, real_distid, scale_str)

    plpy.execute("""
        create temp table {output_table} as
        select
        {strata},
        {distid},
        array_agg({split_col_alias} order by {split_col_alias} {order_str},
        {status_name}) as {split_col_alias},
        array_agg({status_name} order by {split_col_alias} {order_str},
        {status_name}) as {status_name},
        {schema_madlib}.matrix_agg(
        {schema_madlib}.utils_normalize_data(
        {indep_name}, {mean_str}, {scale_str}
        )
        order by {split_col_alias} {order_str}, {status_name}) as {indep_name}
        from
        (
        select
        {strata},
        {schema_madlib}._compute_grpid(
        {splits}, {split_col}, {reverse})
        as {distid},
        {split_col} as {split_col_alias},
        {indep} as {indep_name},
        ({status})::INTEGER as {status_name}
        from
        {source_table}
        where {filter_null}
        ) table_split
        group by {distid}, {strata}
        -- ensure that all chunks belonging to the same strata go to the same segment
        m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!DISTRIBUTED BY ({strata})!>)
        """.format(strata=strata, **sql_args))
    return (output_table, distid, split_col_alias, indep_name, status_name,
            n_processed, n_skipped, '', scale_str)
# ----------------------------------------------------------------------
def _validate_params(schema_madlib, source_table, output_table,
                     dependent_varname, independent_varname,
                     right_censoring_status, strata, *args, **kwargs):
    """ Validate the input parameters for coxph

    Checks that neither the output table nor its "_summary" companion
    exists yet and, when strata is given, that every stratification
    column exists in the source table. Existence of the source table
    itself is not checked here; __check_args() does that.

    Args:
        @param schema_madlib - MADlib schema name
        @param source_table - A string, the data table name
        @param output_table - A string, the result table name
        @param dependent_varname - A string, the survival time column
            name or a valid expression
        @param independent_varname - A string, the covariates in array
            formats. It is a valid expression.
        @param right_censoring_status - A string, a column name or a
            valid expression that has boolean values. Whether the row
            of data is censored.
        @param strata - A string, column names separated by commas. The
            columns used for stratification, or None.

    Throws:
        "Cox error" if any argument is invalid
    """
    _assert(not table_exists(output_table, only_first_schema=True),
            "Cox error: Output table {0}"
            " already exists!".format(str(output_table)))

    output_table_summary = add_postfix(output_table, "_summary")
    _assert(not table_exists(output_table_summary, only_first_schema=True),
            "Cox error: Output table {0}"
            " already exists!".format(str(output_table_summary)))

    if strata is not None:
        strata_cols = [a.strip() for a in strata.split(",")]
        _assert(columns_exist_in_table(source_table, strata_cols,
                                       schema_madlib),
                "Cox error: {1} columns do not exist in {0}!"
                .format(source_table, strata_cols))
    return None
# ----------------------------------------------------------------------
def _extract_params(schema_madlib, optimizer_params):
""" Extract optimizer control parameter or set the default values
@brief optimizer_params is a string with the format of
'max_iter=..., optimizer=..., tolerance=...'. The order
does not matter. If a parameter is missing, then the default
value for it is used. If optimizer_params is None or '',
then all default values are used. If the parameter specified
is none of 'max_iter', 'optimizer', or 'tolerance' then an
error is raised. This function also validates the values of
these parameters.
Throws:
"Cox error" - If the parameter is unsupported or the value is
not valid.
"""
allowed_params = set(["max_iter", "optimizer", "tolerance",
"array_agg_size", "sample_size"])
name_value = dict(max_iter=100, optimizer="newton", tolerance=1e-8,
array_agg_size=10000000, sample_size=1000000)
if optimizer_params is None or len(optimizer_params) == 0:
return (name_value['max_iter'], name_value['optimizer'],
name_value['tolerance'], name_value['array_agg_size'],
name_value['sample_size'])
for s in preprocess_keyvalue_params(optimizer_params):
items = s.split("=")
if (len(items) != 2):
plpy.error("Cox error: Optimizer parameter list has incorrect format!")
param_name = items[0].strip(" \"").lower()
param_value = items[1].strip(" \"").lower()
if param_name not in allowed_params:
plpy.error(
"""
Cox error: {param_name} is not a valid parameter name.
Run:
SELECT {schema_madlib}.coxph('usage');
to see the allowed parameters.
""".format(param_name=param_name,
schema_madlib=schema_madlib))
if param_name == "array_agg_size":
try:
name_value["array_agg_size"] = int(param_value)
except:
plpy.error("Cox error: array_agg_size must be an integer value!")
if param_name == "sample_size":
try:
name_value["sample_size"] = int(param_value)
except:
plpy.error("Cox error: sample_size must be an integer value!")
if param_name == "max_iter":
try:
name_value["max_iter"] = int(param_value)
except:
plpy.error("Cox error: max_iter must be an integer number!")
if param_name == "optimizer":
name_value["optimizer"] = param_value
if param_name == "tolerance":
try:
name_value["tolerance"] = float(param_value)
except:
plpy.error("Cox error: tolerance must be a double precision value!")
if name_value["max_iter"] <= 0:
plpy.error("Cox error: max_iter must be positive!")
if name_value["optimizer"] != "newton":
plpy.error("Cox error: this optimization method is not supported yet!")
if name_value["tolerance"] < 0:
plpy.error("Cox error: tolerance cannot be smaller than 0!")
if name_value["array_agg_size"] <= 0:
plpy.error("Cox error: array_agg_size must be positive!")
if name_value["sample_size"] <= 0:
plpy.error("Cox error: sample_size must be positive!")
return (name_value['max_iter'], name_value['optimizer'],
name_value['tolerance'], name_value['array_agg_size'],
name_value['sample_size'])
# ----------------------------------------------------------------------
def __check_args(schema_madlib, tbl_source, col_ind_var, col_dep_var, col_status):
    """ Validate the data arguments and expand a '*' independent variable

    Args:
        @param schema_madlib - MADlib schema name (not used here)
        @param tbl_source - Name of the source table
        @param col_ind_var - Independent variable expression, or '*'
        @param col_dep_var - Dependent variable column name
        @param col_status - Censoring status column or expression

    Returns:
        col_ind_var unchanged, unless it is '*': then an 'array[...]'
        expression listing every column of tbl_source except the
        dependent variable and the columns named in col_status.

    Throws:
        "Cox Proportional Hazards Error" if a check fails
    """
    _assert(tbl_source is not None,
            "Cox Proportional Hazards Error: Source table should not be NULL!")
    _assert(col_ind_var is not None,
            "Cox Proportional Hazards Error: Independent variable should not be NULL!")
    _assert(col_dep_var is not None,
            "Cox Proportional Hazards Error: Dependent variable should not be NULL!")
    _assert(table_exists(tbl_source),
            "Cox Proportional Hazards Error: Source table " + tbl_source + " does not exist!")
    _assert(not table_is_empty(tbl_source),
            "Cox Proportional Hazards Error: Source table " + tbl_source + " is empty!")
    _assert(columns_exist_in_table(tbl_source, [col_dep_var]),
            "Cox Proportional Hazards Error: Dependent variable does not exist!")
    _assert(is_var_valid(tbl_source, col_ind_var),
            "Cox Proportional Hazards Error: The independent variable does not exist!")
    _assert(is_var_valid(tbl_source, col_status),
            "Cox Proportional Hazards Error: Not a valid boolean expression for status!")

    if col_ind_var != "*":
        return col_ind_var

    # Select all columns except status and dependent variable.
    # The status argument may be an expression, so it is split into
    # identifier-like words and compared word by word. A plain substring
    # test would also drop unrelated columns whose short names merely
    # occur inside the expression text (e.g. a column "t" with "status").
    excluded = set(re.findall(r'\w+', col_status.lower()))
    excluded.add(col_dep_var.lower())
    selected = [each_col for each_col in get_cols(tbl_source)
                if each_col not in excluded]
    return 'array[%s]' % (','.join(selected))
# -----------------------------------------------------------------------
# ZPH functionality
# -----------------------------------------------------------------------
def zph_help_message(schema_madlib, message, **kwargs):
    """ Help message for function to test the proportional hazards assumption
    for a Cox regression model fit

    @brief Build the text returned by cox_zph() when it is called with
    no argument or with a help keyword.

    Args:
        @param schema_madlib string, Name of the schema madlib
        @param message string, Help message indicator. A false value
            (None or '') selects the summary; 'usage', 'help' or '?'
            the usage text; 'example' or 'examples' a runnable SQL
            example; anything else a short "No such option" notice.
        @param kwargs Ignored.

    Returns:
        String. Contain the help message string
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Functionality: Test of proportional hazards assumption
Proportional-Hazard models enable the comparison of various survival models.
See {schema_madlib}.coxph_train() for details to create a Cox PH model.
These PH models, however, assume that the hazard for a given individual
is a fixed proportion of the hazard for any other individual, and the
ratio of the hazards is constant across time.
The cox_zph() function is used to test this assumption by computing the
correlation of the residual of the Cox PH model with time.
For more details on function usage:
SELECT {schema_madlib}.cox_zph('usage')
For an example on using the function:
SELECT {schema_madlib}.cox_zph('example')
"""
    elif message in ['usage', 'help', '?']:
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.cox_zph(
'cox_model_table', -- TEXT. The name of the table containing the Cox Proportional-Hazards model
'output_table' -- TEXT. The name of the table where the test statistics are saved
);
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The <output table> ('output_table' above) has the following columns
- covariate TEXT. The names of independent variables
- rho FLOAT8[]. Vector of the correlation coefficients between
survival time and the scaled Schoenfeld residuals
- chi_square FLOAT8[]. Chi-square test statistic for the correlation analysis
- p_value FLOAT8[]. Two-side p-value for the chi-square statistic
The output residual table is named as <output_table>_residual has the following columns
- <dep_column_name> FLOAT8. Time values (dependent variable) present in the original source table.
- residual FLOAT8[]. Difference between the original covariate value and the
expectation of the covariate obtained from the coxph model.
- scaled_residual FLOAT8[]. Residual values scaled by the variance of the coefficients
"""
    elif message in ['example', 'examples']:
        # The COPY terminator is written with an escaped backslash so the
        # emitted text still ends the data block with a backslash and a dot.
        help_string = """
DROP TABLE IF EXISTS sample_data;
CREATE TABLE sample_data (
id INTEGER NOT NULL,
grp DOUBLE PRECISION,
wbc DOUBLE PRECISION,
timedeath INTEGER,
status BOOLEAN
);
-- Insert sample data
COPY sample_data FROM STDIN DELIMITER '|';
0 | 0 | 1.45 | 35 | t
1 | 0 | 1.47 | 34 | t
3 | 0 | 2.2 | 32 | t
4 | 0 | 1.78 | 25 | t
5 | 0 | 2.57 | 23 | t
6 | 0 | 2.32 | 22 | t
7 | 0 | 2.01 | 20 | t
8 | 0 | 2.05 | 19 | t
9 | 0 | 2.16 | 17 | t
10 | 0 | 3.6 | 16 | t
11 | 1 | 2.3 | 15 | t
12 | 0 | 2.88 | 13 | t
13 | 1 | 1.5 | 12 | t
14 | 0 | 2.6 | 11 | t
15 | 0 | 2.7 | 10 | t
16 | 0 | 2.8 | 9 | t
17 | 1 | 2.32 | 8 | t
18 | 0 | 4.43 | 7 | t
19 | 0 | 2.31 | 6 | t
20 | 1 | 3.49 | 5 | t
21 | 1 | 2.42 | 4 | t
22 | 1 | 4.01 | 3 | t
23 | 1 | 4.91 | 2 | t
24 | 1 | 5 | 1 | t
\\.
-- Run coxph function
SELECT {schema_madlib}.coxph_train(
'sample_data',
'sample_cox',
'timedeath',
'ARRAY[grp,wbc]',
'status');
-- Get the Cox PH model
SELECT * FROM sample_cox;
-- Run the PH assumption test and obtain the results
SELECT {schema_madlib}.cox_zph('sample_cox', 'sample_zph_output');
SELECT * FROM sample_zph_output;
"""
    else:
        help_string = "No such option. Use {schema_madlib}.cox_zph()"
    return help_string.format(schema_madlib=schema_madlib)
def zph(schema_madlib, cox_output_table, output_table):
    """ Compute the Schoenfeld residuals for a Hazards data table

    @brief Validates the arguments, reads the training settings back
    from the model's "_summary" table and hands them to
    _compute_residual().

    Args:
        @param schema_madlib: string, Name of the MADlib schema
        @param cox_output_table: string, Name of the coxph output_table
        @param output_table: string, Name of the table for the results

    Returns:
        None
    """
    _validate_zph_params(schema_madlib, cox_output_table, output_table)

    summary_table = add_postfix(cox_output_table, "_summary")
    summary_rows = plpy.execute("""
SELECT
source_table,
dependent_varname,
independent_varname,
right_censoring_status,
strata
FROM {cox_output_table_summary}
""".format(cox_output_table_summary=summary_table))

    # Only the first row of the summary table is used
    settings = summary_rows[0]
    _compute_residual(schema_madlib,
                      settings['source_table'],
                      output_table,
                      settings['dependent_varname'],
                      settings['independent_varname'],
                      cox_output_table,
                      settings['right_censoring_status'],
                      settings['strata'])
# ----------------------------------------------------------------------
def _validate_zph_params(schema_madlib, cox_model_table, output_table):
    """ Validate the arguments of zph()

    Checks that the model table and its "_summary" companion exist, that
    neither the output table nor its "_residual" companion exists yet,
    and that the summary table has every column zph() reads from it.

    Args:
        @param schema_madlib: string, Name of the MADlib schema
        @param cox_model_table: string, Table name for Cox Prop Hazards model
        @param output_table: string, Output data table name

    Returns:
        None

    Throws:
        Error on any invalid parameter
    """
    if cox_model_table is None or cox_model_table.strip() == '':
        plpy.error("Cox error: NULL/Empty model table is given!")
    if output_table is None or output_table.strip() == '':
        plpy.error("Cox error: NULL/Empty output table is given!")

    cox_model_table_summary = add_postfix(cox_model_table, "_summary")
    _assert(table_exists(cox_model_table)
            and table_exists(cox_model_table_summary),
            "Cox error: Model table {0} or summary table {1} "
            "does not exist!".format(cox_model_table, cox_model_table_summary))

    output_table_residual = add_postfix(output_table, "_residual")
    _assert((not table_exists(output_table, only_first_schema=True)) and
            (not table_exists(output_table_residual, only_first_schema=True)),
            "Cox error: Output table {0} or residual table {1} "
            "already exists!".format(output_table, output_table_residual))

    summary_columns = ["source_table", "dependent_varname",
                       "independent_varname", "right_censoring_status",
                       "strata"]
    # Check the same summary-table name that was tested for existence
    # above, i.e. the add_postfix() result, not a plain concatenation.
    _assert(columns_exist_in_table(cox_model_table_summary, summary_columns),
            "Cox error: At least one column from {0} missing in "
            "summary table {1}".format(str(summary_columns),
                                       cox_model_table_summary))
    return None
# ----------------------------------------------------------------------
def _compute_residual(schema_madlib, source_table, output_table,
                      dependent_variable, independent_variable,
                      cox_output_table,
                      right_censoring_status=None,
                      strata=None, **kwargs):
    """ Compute the Schoenfeld residuals for a Hazards model

    @brief Computes the Schoenfeld residuals for a Hazards data table
    by using an aggregate-defined window function and outputs to a table

    Two tables are created:
        - '<output_table>_residual': the dependent variable, the residual
          and the scaled residual of every uncensored observation
        - '<output_table>': covariate, rho, chi_square and p_value

    Args:
        @param schema_madlib: string, Name of the MADlib schema
        @param source_table: string, Input data table name
        @param output_table: string, Output data table name
        @param dependent_variable: string, Dependent variable name
        @param independent_variable: string, Independent variable name (could also be an expression)
        @param cox_output_table: string, Output table of coxph
        @param right_censoring_status: string, Column name with right censoring status
                                       (every row is used when not given)
        @param strata: string, Comma-separated list of columns to stratify with

    Returns:
        None

    Throws:
        Error if no observation is left after dropping the NULL and
        censored rows
    """
    if not right_censoring_status:
        # No censoring status given: the status filter below keeps every row
        right_censoring_status = 'TRUE'
    partition_str = "PARTITION BY {0}".format(strata) if strata else ''

    coef = madvec(plpy.execute("SELECT coef FROM {table} ".
                               format(table=cox_output_table))[0]["coef"],
                  text=False)
    coef_str = "ARRAY" + str(coef)
    # We don't extract a copy of the Hessian 2D array, since Postgres/GPDB still
    # don't support getting a 2d array into plpython
    residual_table = unique_string()
    output_residual = add_postfix(output_table, "_residual")
    format_args = {'schema_madlib': schema_madlib,
                   'output': output_table,
                   'output_residual': output_residual,
                   'indep_column': independent_variable,
                   # The expression is also stored as a text literal, so any
                   # single quote it contains has to be doubled
                   'indep_literal': independent_variable.replace("'", "''"),
                   'dep_column': dependent_variable,
                   'status': right_censoring_status,
                   'cox_output_table': cox_output_table,
                   'source_table': source_table,
                   'residual_table': residual_table,
                   'coef_str': coef_str,
                   'partition_str': partition_str}

    # --------- Computing residuals ---------
    # The independent variable may be an arbitrary expression, hence it is
    # parenthesized everywhere before being cast.
    plpy.execute("""
        CREATE TEMP TABLE {residual_table} AS
        SELECT
            {dep_column},
            {schema_madlib}.array_sub(
                x::DOUBLE PRECISION[],
                expectation_x::DOUBLE PRECISION[]
            ) AS residual
        FROM
        (
            SELECT
                {dep_column},
                ({indep_column})::DOUBLE PRECISION[] AS x,
                ({status})::BOOLEAN as status,
                {schema_madlib}.zph_agg(
                    ({indep_column})::DOUBLE PRECISION[],
                    {coef_str}
                ) OVER ({partition_str} ORDER BY {dep_column} DESC) AS expectation_x
            FROM {source_table}
            WHERE {dep_column} IS NOT NULL AND
                NOT {schema_madlib}.array_contains_null(
                    ({indep_column})::DOUBLE PRECISION[])
        ) AS q1
        WHERE status is TRUE
        ORDER BY {dep_column} ASC
        m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!DISTRIBUTED BY ({dep_column})!>)
        """.format(**format_args))

    n_uncensored = plpy.execute("""SELECT count(*)::INTEGER as n_uncensored
                                FROM {table}
                                """.format(table=residual_table))[0]["n_uncensored"]
    if not n_uncensored:
        # avg() over zero rows is NULL; it would end up as the text 'None'
        # in the metrics query below, so stop here with a readable message
        plpy.error("Cox error: No uncensored observations without NULL values "
                   "found in {0}!".format(source_table))
    format_args['n_uncensored'] = n_uncensored

    # --------- Computing scaled residuals ---------
    plpy.execute("""
        CREATE TABLE {output_residual} AS
        SELECT
            {dep_column},
            residual as residual,
            {schema_madlib}.__coxph_scale_resid(
                {n_uncensored}::INTEGER,
                (SELECT hessian FROM {cox_output_table}),
                residual
            ) AS scaled_residual
        FROM
            {residual_table}
        m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!DISTRIBUTED BY ({dep_column})!>)
        """.format(**format_args))

    # --------- Computing metrics ---------
    # Mean of the dependent variable over the uncensored observations; it is
    # used to center the dependent variable in the metrics query
    mean = plpy.execute("""
        SELECT avg({dep_column}) AS w FROM {residual_table}
        """.format(**format_args))[0]['w']
    plpy.execute("""
        CREATE TABLE {output} AS
        SELECT
            ('{indep_literal}')::TEXT as covariate, rho,
            (f).chi_square_stat as chi_square, (f).p_value as p_value
        FROM (
            SELECT
                {schema_madlib}.array_elem_corr_agg(
                    scaled_residual,
                    ({dep_column} - {mean})::DOUBLE PRECISION)
                AS rho,
                {schema_madlib}.__coxph_resid_stat_agg(
                    ({dep_column} - {mean})::DOUBLE PRECISION,
                    residual,
                    (SELECT hessian FROM {cox_output_table}),
                    {n_uncensored}::INTEGER)
                AS f
            FROM
                {output_residual}
        ) AS q1
        m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!DISTRIBUTED RANDOMLY!>)
        """.format(mean=mean, **format_args))

    # Cleanup
    plpy.execute('DROP TABLE IF EXISTS ' + residual_table)
def cox_prop_hazards(schema_madlib, usage_string, **kwargs):
    """ Help message for the deprecated cox_prop_hazards function

    @brief Emits a deprecation warning and returns the requested help text
    Args:
        @param schema_madlib string, Name of the schema madlib
        @param usage_string string, Help message indicator: '' for the
                            summary; 'usage', 'help' or '?' for the usage
    Returns:
        String. The help text with the schema name filled in
    """
    plpy.warning("This function has been deprecated. Please use 'coxph_train' instead.")

    # Summary
    if usage_string == '':
        return """
Summary
------------------------------------------------------------------------------------
Functionality: Cox proprtional hazards regression (Breslow method)
SELECT {schema_madlib}.cox_prop_hazards(
'source_table',
'output_table',
'dependent_variable',
'independent_variable',
'right_censoring_status'
);
For more details on function usage:
SELECT {schema_madlib}.cox_prop_hazards('usage');
""".format(schema_madlib=schema_madlib)

    # Detailed usage
    if usage_string in ('usage', 'help', '?'):
        return """
Usage
------------------------------------------------------------------------------------
SELECT {schema_madlib}.cox_prop_hazards(
'source_table', -- Name of data table
'output_table', -- Name of result table (overwrites if exists)
'dependent_variable', -- Name of column for dependent variables
'independent_variable', -- Name of column for independent variables
(can be any SQL expression Eg: ''*'')
['right_censoring_status', -- Name of the column containing censoring status
-- 0/false : If the observation is censored
-- 1/true : otherwise
-- Default is 1/true for all observations
-- Can also be an SQL expression: 'dependent_variable < 10'
);
Output:
------------------------------------------------------------------------------------
The output table (''output_table'' above) has the following columns:
'coef' DOUBLE PRECISION[], -- Coefficients of regression
'std_err' DOUBLE PRECISION[], -- Standard errors
'z_stats' DOUBLE PRECISION[], -- z-stats of the standard errors
'p_values' DOUBLE PRECISION[], -- p-values of the standard errors
""".format(schema_madlib=schema_madlib)

    # Anything else is not a known option
    unknown_option = "No such option. Run SELECT {schema_madlib}.cox_prop_hazards()"
    return unknown_option.format(schema_madlib=schema_madlib)
def _validate_predict(schema_madlib, model_table, source_table, id_col_name,
                      output_table, pred_type, reference):
    """ Validate the arguments of coxph_predict

    Args:
        @param schema_madlib: string, Name of the MADlib schema
        @param model_table: string, Table name of the Cox model; the table
                            '<model_table>_summary' has to exist as well
        @param source_table: string, Table name of the prediction data
        @param id_col_name: string, Id column (or expression) of the source table
        @param output_table: string, Table name for the predictions, must not exist
        @param pred_type: string, One of 'linear_predictors', 'risk', 'terms'
        @param reference: string, One of 'overall', 'strata'
    Returns:
        None
    Throws:
        Error on the first invalid parameter
    """
    _assert(pred_type in ('linear_predictors', 'risk', 'terms'),
            "Cox predict error: Invalid prediction type.")
    _assert(reference in ('overall', 'strata'),
            "Cox predict error: Invalid reference type.")

    _assert(model_table and model_table.strip().lower() not in ('null', ''),
            "Cox predict error: Invalid model table.")
    _assert(table_exists(model_table),
            "Cox predict error: Model table is missing.")

    _assert(source_table and source_table.strip().lower() not in ('null', ''),
            "Cox predict error: Invalid source table.")
    _assert(table_exists(source_table),
            "Cox predict error: source table is missing.")

    # The id column is placed as-is into the prediction query, so it must
    # at least be a non-empty string
    _assert(id_col_name and id_col_name.strip().lower() not in ('null', ''),
            "Cox predict error: Invalid id column name.")

    # Reject a NULL/empty output name before asking the database about it
    _assert(output_table and output_table.strip().lower() not in ('null', ''),
            "Cox predict error: Invalid output table.")
    _assert(not table_exists(output_table, only_first_schema=True),
            "Cox predict error: output table already exists.")

    _assert(columns_exist_in_table(model_table, ["coef"], schema_madlib),
            "Cox predict error: Invalid model table ({0})".format(model_table))

    # add_postfix keeps the postfix inside the quotes of a quoted table name
    model_summary = add_postfix(model_table, "_summary")
    _assert(table_exists(model_summary),
            "Cox predict error: Model summary table is missing.")
    _assert(columns_exist_in_table(model_summary,
                                   ["independent_varname", "strata"],
                                   schema_madlib),
            "Cox predict error: Invalid summary table ({0})".format(model_summary))
# ------------------------------------------------------------------------------
def coxph_predict(schema_madlib, model_table, source_table, id_col_name, output_table,
                  pred_type, reference, **kwargs):
    """ Cox prediction function

    @brief Creates the output table with one prediction per row of the
    source table, using the coefficients stored in the model table
    Args:
        @param schema_madlib: string, Name of the MADlib schema
        @param model_table: string, Table name of the Cox model
        @param source_table: string, Table name of the prediction data
        @param id_col_name: string, Id column (or expression) of the source table
        @param output_table: string, Table name for the predictions
        @param pred_type: string, 'linear_predictors' (default when None),
                          'risk' or 'terms'
        @param reference: string, 'strata' (default when None) or 'overall'
    Returns:
        None
    Throws:
        Error on any invalid parameter or on an empty model summary table
    """
    if pred_type is None:
        pred_type = "linear_predictors"
    if reference is None:
        reference = "strata"
    _validate_predict(schema_madlib, model_table, source_table,
                      id_col_name, output_table, pred_type, reference)

    # add_postfix keeps the postfix inside the quotes of a quoted table name
    model_table_summary = add_postfix(model_table, "_summary")
    summary_rows = plpy.execute("SELECT strata, independent_varname FROM {0}"
                                .format(model_table_summary))
    if len(summary_rows) == 0:
        plpy.error("Cox predict error: Model summary table ({0}) is empty."
                   .format(model_table_summary))
    strata = summary_rows[0]['strata']
    independent_varname = summary_rows[0]['independent_varname']

    # Name of the column holding the averaged independent variables
    term_avg = unique_string()

    # 'terms' predictions never use the strata for centering
    if pred_type == "terms":
        strata = None

    select_strata = ""
    group_by_strata = ""
    where_strata = ""
    if strata and reference == "strata":
        # Average per stratum, then match every source row (s) with the
        # average (t) of its own stratum
        strata_cols = [col.strip() for col in strata.split(",")]
        select_strata = strata + ","
        group_by_strata = "group by " + strata
        where_strata = "where " + " and ".join(
            "t.{0} = s.{0}".format(col) for col in strata_cols)

    if columns_exist_in_table(source_table, [id_col_name], schema_madlib):
        coxph_predict_id = id_col_name
    else:
        # id_col_name is not a plain column name of the source table, so
        # the output column gets a fixed name
        coxph_predict_id = 'coxph_predict_id'

    # resolve name conflicts in output table
    output_name = pred_type + "_output" if id_col_name == pred_type else pred_type

    # Both prediction functions take the coefficients, the independent
    # variables and their average; only the response variant also takes
    # the prediction type
    if pred_type != "terms":
        predict_call = """{schema_madlib}._coxph_predict_resp(
                coef,
                {independent_varname},
                {term_avg},
                '{pred_type}'::TEXT)"""
    else:
        predict_call = """{schema_madlib}._coxph_predict_terms(
                coef,
                {independent_varname},
                {term_avg}
                )"""
    predict_call = predict_call.format(schema_madlib=schema_madlib,
                                       independent_varname=independent_varname,
                                       term_avg=term_avg,
                                       pred_type=pred_type)

    sql_predict = """
        CREATE TABLE {output_table} AS
        SELECT
            {id_col_name} AS {coxph_predict_id},
            {predict_call} AS {output_name}
        FROM (
            SELECT
                {select_strata}
                {schema_madlib}.avg({independent_varname}) as {term_avg}
            FROM {source_table}
            {group_by_strata}
            ) t,
            {model_table} m,
            {source_table} s
        {where_strata}
        """.format(output_table=output_table,
                   id_col_name=id_col_name,
                   coxph_predict_id=coxph_predict_id,
                   predict_call=predict_call,
                   output_name=output_name,
                   select_strata=select_strata,
                   schema_madlib=schema_madlib,
                   independent_varname=independent_varname,
                   term_avg=term_avg,
                   source_table=source_table,
                   group_by_strata=group_by_strata,
                   model_table=model_table,
                   where_strata=where_strata)

    with MinWarning('warning'):
        plpy.notice("sql_predict:\n" + sql_predict)
        plpy.execute(sql_predict)
    return None
def coxph_predict_help_message(schema_madlib, message, **kwargs):
    """ Help message for prediction using a CoxPH model

    @brief Returns the summary, the usage or an example of coxph_predict
    Args:
        @param schema_madlib string, Name of the schema madlib
        @param message string, Help message indicator: empty/None for the
                       summary; 'usage', 'help' or '?' for the usage;
                       'example' or 'examples' for an example
    Returns:
        String. Help message string with the schema name filled in
    """
    if not message:
        help_string = """
-----------------------------------------------------------------------
SUMMARY
-----------------------------------------------------------------------
Functionality: Prediction using a CoxPH model
The prediction function is provided to calculate the linear
predictors, risk or the linear terms for the given prediction data.
For more details on function usage:
SELECT {schema_madlib}.coxph_predict('usage')
For an example on using the function:
SELECT {schema_madlib}.coxph_predict('example')
"""
    elif message in ['usage', 'help', '?']:
        help_string = """
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.coxph_predict(
'model_table', -- TEXT. Name of the table containing the cox model.
'source_table', -- TEXT. Name of the table containing the prediction data.
'id_col_name', -- TEXT. Name of the id column in the source table.
'output_table', -- TEXT. Name of the table to store the prediction results in.
'pred_type', -- TEXT. Type of prediction. Can be one of 'linear_predictors',
'risk' or 'terms'. Default = 'linear_predictors'.
'linear_predictors' calculates the dot product of the
independent variables and the coefficients.
'risk' is the exponentiated value of the linear prediction.
'terms' correspond to the linear terms obtained by multiplying
the independent variables with their corresponding coefficients
values (without further calculating the sum of these terms)
The resulting predictions, in all of the above cases, are then
centered around a reference level.
'reference' -- TEXT. Reference level to use for centering the predictions.
Can be one of 'strata' or 'overall'. Default = 'strata'.
Cox model is a relative risk model wherein the predictions are
relative to the sample that they are taken from. Therefore, all
predictions are centered around the mean of the covariates within
each stratum by default. If it is instead desired to use the
mean over all samples, reference='overall' can be
specified. If there was no stratification involved, the reference
parameter does not have any effect.
Note 1: For pred_type = 'terms', the predictions are always
centered around the overall mean values of the covariates
independent of stratification.
Note 2: R uses 'sample' instead of 'overall' when referring to
the overall mean value of the covariates as being the reference
level.
)
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The <output table> ('output_table' above) has the following columns
- id TEXT. The id column name from the source table
- predicted_result DOUBLE PRECISION. Result of prediction based of the value of
the pred_type parameter
"""
    elif message in ['example', 'examples']:
        # The backslash of the COPY end-of-data marker is escaped so that the
        # emitted text still reads backslash-dot without relying on an
        # unrecognized escape sequence.
        help_string = """
DROP TABLE IF EXISTS sample_data;
CREATE TABLE sample_data (
id INTEGER NOT NULL,
grp DOUBLE PRECISION,
wbc DOUBLE PRECISION,
timedeath INTEGER,
status BOOLEAN
);
-- Insert sample data
COPY sample_data FROM STDIN DELIMITER '|';
0 | 0 | 1.45 | 35 | t
1 | 0 | 1.47 | 34 | t
3 | 0 | 2.2 | 32 | t
4 | 0 | 1.78 | 25 | t
5 | 0 | 2.57 | 23 | t
6 | 0 | 2.32 | 22 | t
7 | 0 | 2.01 | 20 | t
8 | 0 | 2.05 | 19 | t
9 | 0 | 2.16 | 17 | t
10 | 0 | 3.6 | 16 | t
11 | 1 | 2.3 | 15 | t
12 | 0 | 2.88 | 13 | t
13 | 1 | 1.5 | 12 | t
14 | 0 | 2.6 | 11 | t
15 | 0 | 2.7 | 10 | t
16 | 0 | 2.8 | 9 | t
17 | 1 | 2.32 | 8 | t
18 | 0 | 4.43 | 7 | t
19 | 0 | 2.31 | 6 | t
20 | 1 | 3.49 | 5 | t
21 | 1 | 2.42 | 4 | t
22 | 1 | 4.01 | 3 | t
23 | 1 | 4.91 | 2 | t
24 | 1 | 5 | 1 | t
\\.
-- Run coxph function
SELECT {schema_madlib}.coxph_train(
'sample_data',
'sample_cox',
'timedeath',
'ARRAY[grp,wbc]',
'status');
-- View the Cox PH model
SELECT * FROM sample_cox;
-- Predict back on the original dataset
SELECT {schema_madlib}.coxph_predict('sample_cox',
'sample_data',
'id',
'sample_pred',
'risk');
"""
    else:
        help_string = "No such option. Use {schema_madlib}.coxph_predict()"
    return help_string.format(schema_madlib=schema_madlib)