# blob: 17f809fcd7b59bd54714faa16a30518bddf63752 [file] [log] [blame]
# ----------------------------------------------------------------------
# Clustered variance for CoxPH model
# ----------------------------------------------------------------------
import plpy
from utilities.utilities import unique_string
from utilities.utilities import _assert
from utilities.validate_args import table_exists
from utilities.validate_args import columns_exist_in_table
from utilities.utilities import _string_to_array
# ----------------------------------------------------------------------
def clustered_coxph(schema_madlib, model_table, output_table,
                    clustervar, **kwargs):
    """ CoxPH regression clustered standard errors

    Reads coef and hessian from model_table, re-scans the source table that
    is recorded in {model_table}_summary, and creates output_table holding
    the model columns together with the clustered standard errors, z-stats
    and p-values.

    Args:
        @param schema_madlib - MADlib schema name
        @param model_table - A string, the model table name; a companion
                             table named <model_table>_summary is read too
        @param output_table - A string, the result table name (created here)
        @param clustervar - A string, the cluster variables
        @param kwargs - Accepted and ignored

    Returns:
        None. The result is stored in output_table.
    """
    # Remember the current message level so it can be put back at the end;
    # everything below error level is silenced while the scratch tables are
    # built.
    old_msg_level = plpy.execute("""
        select setting from pg_settings
        where name='client_min_messages'
        """)[0]['setting']
    plpy.execute("set client_min_messages to error")

    _validate_params(schema_madlib, model_table, output_table, clustervar)

    # info is a dict that contains source_table, independent_varname,
    # dependent_varname, right_censoring_status, strata
    info = plpy.execute("select * from {model_table}_summary".
                        format(model_table=model_table))[0]

    # Length of the independent variable array. The probe row is arbitrary
    # (limit 1 without ordering), so rows whose array is NULL are skipped;
    # those rows are filtered out of the computation below anyway.
    size_rows = plpy.execute(
        """
        select array_upper({independent_varname},1) -
               array_lower({independent_varname},1) + 1 as size
        from {source_table}
        where {independent_varname} is not NULL
        limit 1
        """.format(**info))
    if len(size_rows) == 0 or size_rows[0]['size'] is None:
        plpy.error("Clustered Variance Cox error: cannot determine the "
                   "number of independent variables from the source table")
    size = size_rows[0]['size']

    # table name of the result of H and S
    temp_H_S = unique_string()
    temp_A_B = unique_string()
    # generated column names, so they cannot clash with user columns
    x = unique_string()
    y = unique_string()
    status = unique_string()
    h_s = unique_string()
    h = unique_string()
    s = unique_string()

    # The stratified and unstratified computations differ only in carrying
    # the strata column(s) along and partitioning the windows by them.
    if info['strata'] is None:
        strata_select = ""
        partition_by = ""
    else:
        strata_select = "{strata},".format(strata=info['strata'])
        partition_by = "partition by {strata}".format(strata=info['strata'])

    # Create H and S table, use window function ordered desc
    plpy.execute(
        """
        create temp table {temp_H_S} as
            select
                {clustervar},
                {strata_select}
                ({independent_varname})::float8[] as {x},
                ({dependent_varname})::float8 as {y},
                ({right_censoring_status})::boolean as {status},
                {schema_madlib}.coxph_h_s(
                    {independent_varname},
                    (select coef from {model_table})
                ) over ({partition_by}
                        order by {dependent_varname} desc) as {h_s}
            from
                {source_table} s
            where
                {independent_varname} is not NULL and
                {schema_madlib}.array_contains_null({independent_varname}) is False and
                {dependent_varname} is not NULL
        """.format(temp_H_S=temp_H_S, model_table=model_table,
                   clustervar=clustervar, x=x, y=y, status=status,
                   h_s=h_s, schema_madlib=schema_madlib,
                   strata_select=strata_select, partition_by=partition_by,
                   **info))

    # Create A and B table, window ordered ascending this time. It is a
    # temp table in both the stratified and unstratified case so that no
    # scratch data is left in a user schema.
    plpy.execute(
        """
        create temp table {temp_A_B} as
            select {clustervar}, {status},
                ({h_s}).h as {h}, ({h_s}).s as {s}, {x},
                {schema_madlib}.coxph_a_b(
                    {size},
                    {status},
                    ({h_s}).h, ({h_s}).s)
                over ({partition_by} order by {y}) as f
            from
                {temp_H_S}
            where
                {x} is not NULL and
                {schema_madlib}.array_contains_null({x}) is False and
                {y} is not NULL
        """.format(temp_H_S=temp_H_S, temp_A_B=temp_A_B, size=size,
                   x=x, y=y, status=status, h_s=h_s, h=h, s=s,
                   schema_madlib=schema_madlib, clustervar=clustervar,
                   partition_by=partition_by))

    # The result table is a regular (non-temp) table: it is what the caller
    # asked for, and a temp table could not be created under a
    # schema-qualified name.
    # The sum over s1.w is wrapped in a build-time macro conditional that
    # prefixes it with the MADlib array-sum aggregate name on the PostgreSQL
    # port; that line must stay verbatim.
    plpy.execute(
        """
        create table {output_table} as
            select
                u.coef, u.loglikelihood, u.std_err,
                '{clustervar}'::TEXT as clustervar,
                (v.f).std_err as clustered_se,
                (v.f).z_stats as clustered_z,
                (v.f).p_values as clustered_p,
                u.hessian
            from (
                select {schema_madlib}.coxph_compute_clustered_stats(
                    (select coef from {model_table}),
                    (select hessian from {model_table}),
                    s3.a) as f
                from (
                    select
                        {schema_madlib}.matrix_agg(s2.w) as a
                    from (
                        select
                            m4_ifdef(`__POSTGRESQL__', `{schema_madlib}.__array_')sum(s1.w) as w
                        from (
                            select
                                {clustervar},
                                {schema_madlib}.coxph_compute_w(
                                    {x}, {status},
                                    (select coef from {model_table}),
                                    {h}, {s}, (f).a, (f).b) as w
                            from {temp_A_B}
                        ) s1
                        group by {clustervar}
                    ) s2
                ) s3
            ) v, {model_table} u
        """.format(model_table=model_table, output_table=output_table,
                   schema_madlib=schema_madlib, clustervar=clustervar,
                   temp_A_B=temp_A_B, x=x, status=status, h=h, s=s))

    plpy.execute("drop table if exists {temp_H_S}".format(temp_H_S=temp_H_S))
    plpy.execute("drop table if exists {temp_A_B}".format(temp_A_B=temp_A_B))
    plpy.execute("set client_min_messages to " + old_msg_level)
    return None
# ----------------------------------------------------------------------
def _validate_params(schema_madlib, model_table, output_table, clustervar):
    """ Validate the input parameters for coxph

    Checks, in order: clustervar is given; the model table and its
    _summary table exist and have the required columns; output_table is
    named and does not exist yet; the model holds non-NULL coef and
    hessian; the cluster column(s) exist in the source table. NULL values
    in the cluster column(s) are reported but are not an error.

    Args:
        @param schema_madlib - MADlib schema name
        @param model_table - A string, the model table name
        @param output_table - A string, the result table name
        @param clustervar - A string, the cluster variables

    Throws:
        "Clustered Variance Cox error" if any argument is invalid
    """
    if clustervar is None or clustervar.lower().strip() in ("", "null"):
        plpy.error("CoxPH clustered variance error: clustervar must be specified!")

    _assert(model_table is not None and table_exists(model_table),
            "Clustered Variance Cox error: Model data table does not exist")
    _assert(model_table is not None and table_exists(model_table + '_summary'),
            "Clustered Variance Cox error: Model summary table does not exist")
    _assert(columns_exist_in_table(model_table,
                                   ['coef', 'loglikelihood', 'std_err',
                                    'z_stats', 'p_values', 'hessian']),
            "Clustered Variance Cox error: Invalid model data table"
            " - some required columns missing")
    _assert(columns_exist_in_table(model_table + '_summary',
                                   ['source_table', 'dependent_varname',
                                    'independent_varname',
                                    'right_censoring_status', 'strata']),
            "Clustered Variance Cox error: Invalid model summary table"
            " - some required columns missing")

    _assert(output_table is not None and output_table.strip() != '',
            "Clustered Variance Cox error: Invalid output_table is given")
    _assert(not table_exists(output_table, only_first_schema=True),
            "Clustered Variance Cox error: Output table {0}"
            " already exists!".format(str(output_table)))

    not_null_coef = plpy.execute("""
        SELECT count(*) AS c FROM {model_table}
        WHERE coef IS NOT NULL AND hessian IS NOT NULL
        """.format(model_table=model_table))[0]["c"]
    _assert(not_null_coef != 0,
            "Clustered Variance Cox error: No not-null coef and hessian found in "
            "model_table {model_table}".format(model_table=model_table))

    # Warn users who have multiple rows in the model table.
    # This is sent at INFO level on purpose: clustered_coxph sets
    # client_min_messages to error before calling this function, which
    # would hide a WARNING-level message from the client.
    if not_null_coef > 1:
        plpy.info("Clustered Variance Cox Warning: multiple rows in "
                  "model_table {model_table}".format(model_table=model_table))

    # Only the first returned row of the model table is inspected here.
    coef_contains_null = plpy.execute("""
        SELECT {schema_madlib}.array_contains_null(coef) OR
               {schema_madlib}.array_contains_null(hessian)
               AS contains_null
        FROM {model_table}
        """.format(schema_madlib=schema_madlib,
                   model_table=model_table))[0]["contains_null"]
    _assert(not coef_contains_null,
            "Clustered Variance Cox error: coef or hessian array in {0} contains "
            "NULL values. (If the input table contains at least one row without "
            " NULLS then rerunning coxph should get correct values)".
            format(model_table))

    # clustervar is a non-empty string at this point; the first check in
    # this function raises otherwise.
    source_table = plpy.execute("select source_table from "
                                "{model_table}_summary".format(
                                    model_table=model_table))[0]['source_table']
    cluster_cols = _string_to_array(clustervar)
    if not columns_exist_in_table(source_table, cluster_cols, schema_madlib):
        plpy.error("Clustered Variance Cox error: Cluster column does not exist!")

    # Count source rows in which any cluster column is NULL.
    where_str = " or ".join("\"" + cl + "\" is NULL" for cl in cluster_cols)
    cluster_null_count = plpy.execute(
        """
        select count(*) from {source_table}
        where {where_str}
        """.format(source_table=source_table, where_str=where_str))[0]['count']
    if cluster_null_count != 0:
        plpy.info(
            """
MADlib WARNING: There are NULL values in the columns for clustering!
The result might be different from the default results of R or Stata,
because MADlib treats the NULL as a valid value for clustering.
In order to obtain the same result as the default result of R's coxph
function or Stata's stcox function, all rows that contain NULL values
in the clustering columns need to be filtered out before using
MADlib's coxph_train function.
Or, one can use the R code that is similar to the following example
to filter the NULL values only in the independent variables and
produce the same result as MADlib:
coxph(Surv(time) ~ age + ph.ecog + cluster(ph.karno), data = rossi,
ties = "breslow",
na.action = function(object, ...) {
object[with(object,
!(is.na(age) | is.na(ph.ecog))),]
}
)
""")
# ----------------------------------------------------------------------
def cl_coxph_help_message(schema_madlib, message=None, **kwargs):
    """ The help message for clustered variance of CoxPH

    Args:
        @param schema_madlib - MADlib schema name, shown in the example calls
        @param message - A string or None; "usage", "help" or "?" (any case,
                         surrounding whitespace ignored) selects the detailed
                         text, anything else the short summary
        @param kwargs - Accepted and ignored

    Returns: A string, which contains a short help message.
    """
    # The example calls use the schema MADlib is actually installed in,
    # which is not necessarily named madlib.
    if message is not None and message.strip().lower() in ("usage", "help", "?"):
        return """
This function calculates the clustered robust sandwich estimator for
Cox Proportional Hazards Regression. The cluster variable(s) identifies
correlated groups of observations.
-----------------------------------------------------------------------
USAGE
-----------------------------------------------------------------------
SELECT {schema_madlib}.clustered_variance_coxph
(
model_table TEXT, -- Model table name, which is generated FROM coxph_train
output_table TEXT, -- Output table name
clustervar TEXT -- The cluster column names, separated by comma
)
-----------------------------------------------------------------------
OUTPUT
-----------------------------------------------------------------------
The output table contains the following columns:
- coef: Coefficients of regression
- loglikelihood: Log-likelihood value
- std_err: Standard errors
- clustervar: A string, the clustering variables separated by comma
- clustered_se: Standard errors of the clustered robust variance estimators
- clustered_z: z-stats of the clustered robust variance estimators
- clustered_p: p-values of the clustered robust variance estimators
- hessian: Hessian matrix
""".format(schema_madlib=schema_madlib)

    return """
This function calculates the clustered robust statistics for
Cox Proportional Hazards Regression.
For more details on function usage:
SELECT {schema_madlib}.clustered_variance_coxph('usage');
""".format(schema_madlib=schema_madlib)