blob: 6f1a0681fb66aa28f7a8f5d0397d5bcac4c3a1c2 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
#
# THIS SCRIPT APPLIES THE ESTIMATED PARAMETERS OF A COX PROPORTIONAL HAZARD REGRESSION MODEL TO A NEW (TEST) DATASET
#
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# X String --- Location to read the input matrix containing the survival data with the following schema:
# - X[,1]: timestamps
# - X[,2]: whether an event occurred (1) or data is censored (0)
# - X[,3:]: feature vectors (excluding the baseline columns) used for model fitting
# RT String --- Location to read column matrix RT containing the (order preserving) recoded timestamps from X
# M String --- Location to read matrix M containing the fitted Cox model with the following schema:
# - M[,1]: betas
# - M[,2]: exp(betas)
# - M[,3]: standard error of betas
# - M[,4]: Z
# - M[,5]: p-value
# - M[,6]: lower 100*(1-alpha)% confidence interval of betas
# - M[,7]: upper 100*(1-alpha)% confidence interval of betas
# Y String --- Location to read matrix Y used for prediction
# COV String --- Location to read the variance-covariance matrix of the betas
# MF String --- Location to read column indices of X excluding the baseline factors if available
# P String --- Location to store matrix P containing the results of prediction
# fmt String "text" Matrix output format, usually "text" or "csv" (for matrices only)
# ---------------------------------------------------------------------------------------------
# OUTPUT:
# 1- A matrix P with the following schema:
# P[,1]: linear predictors relative to a baseline which contains the mean values for each feature
# i.e., (Y[3:] - colMeans (X[3:])) %*% b
# P[,2]: standard error of linear predictors
# P[,3]: risk relative to a baseline which contains the mean values for each feature
# i.e., exp ((Y[3:] - colMeans (X[3:])) %*% b)
# P[,4]: standard error of risk
# P[,5]: estimates of cumulative hazard
# P[,6]: standard error of the estimates of cumulative hazard
# -------------------------------------------------------------------------------------------
# HOW TO INVOKE THIS SCRIPT - EXAMPLE:
# hadoop jar SystemDS.jar -f cox-predict.dml -nvargs X=INPUT_DIR/X RT=INPUT_DIR/RT M=INPUT_DIR/M Y=INPUT_DIR/Y
# COV=INPUT_DIR/COV MF=INPUT_DIR/MF P=OUTPUT_DIR/P fmt=csv
# ----- script arguments -----
fileX = $X;
fileRT = $RT;
fileMF = $MF;
fileY = $Y;
fileM = $M;
fileCOV = $COV;
fileP = $P;

# Default values of some parameters
fmtO = ifdef ($fmt, "text"); # $fmt="text"

# ----- load inputs -----
X_orig = read (fileX);
RT_X = read (fileRT);
Y_orig = read (fileY);
M = read (fileM);
b = M[,1]; # betas are stored in the first column of M
COV = read (fileCOV);
col_ind = read (fileMF);

# Column selection for Y: tab[col_ind[j], j] = 1, so column j of (Y_orig %*% tab)
# is column col_ind[j] of the matrix that was read.
tab = table (col_ind, seq (1, nrow (col_ind)), ncol (Y_orig), nrow (col_ind));
Y_orig = Y_orig %*% tab;

# Y and X have the same dimensions and schema
if (ncol (Y_orig) != ncol (X_orig)) {
  stop ("Y has a wrong number of columns!");
}

# ----- split X into timestamps, event indicators and features -----
X = X_orig[,3:ncol (X_orig)];
T_X = X_orig[,1];
E_X = X_orig[,2];
N = nrow (X);

# Y is sorted by its first column (timestamps) before prediction, so the rows of
# the output P follow this sorted order rather than the order in the input file.
Y_orig = order (target = Y_orig, by = 1);
Y = Y_orig[,3:ncol (X_orig)];
T_Y = Y_orig[,1];

# Center the features of Y on the column means of the features of X
col_means = colMeans (X);
ones = matrix (1, rows = nrow (Y), cols = 1);
Y_rel = Y - (ones %*% col_means);

##### compute linear predictors
LP = Y_rel %*% b;

# compute standard error of linear predictors using the Delta method.
# Only the diagonal of Y_rel %*% COV %*% t(Y_rel) is needed; it is computed row by row
# as rowSums ((Y_rel %*% COV) * Y_rel), which avoids building a nrow(Y) x nrow(Y)
# matrix and avoids taking sqrt of off-diagonal entries that may be negative.
se_LP = sqrt (rowSums ((Y_rel %*% COV) * Y_rel));

##### compute risk
R = exp (LP);

# compute standard error of risk using the Delta method (same row-wise diagonal as above)
Y_rel_R = Y_rel * R;
se_R = sqrt (rowSums ((Y_rel_R %*% COV) * Y_rel_R)) / sqrt (R);

##### compute estimates of cumulative hazard together with their standard errors:
# 1. col contains cumulative hazard estimates
# 2. col contains standard errors for cumulative hazard estimates

# per recoded timestamp: sum of the event indicators (d_r) and number of rows (e_r)
d_r = aggregate (target = E_X, groups = RT_X, fn = "sum");
e_r = aggregate (target = RT_X, groups = RT_X, fn = "count");

# Idx[g] is the running row count through group g; the table () below selects row Idx[g] of T_X.
# presumably X is sorted by timestamp so that each group is a contiguous run of rows -- TODO confirm
Idx = cumsum (e_r);
all_times = table (seq (1, nrow (Idx), 1), Idx) %*% T_X; # one timestamp per recoded timestamp group

# keep only the timestamps of groups with d_r > 0 (removeEmpty drops the rows that are zero)
event_times = removeEmpty (target = (d_r > 0) * all_times, margin = "rows");
num_distinct_event = nrow (event_times);
num_distinct = nrow (all_times); # no. of distinct timestamps censored or uncensored

# I_rev is the permutation matrix that reverses the order of the num_distinct groups
I_rev = table (seq (1, num_distinct, 1), seq (num_distinct, 1, -1));
e_r_rev_agg = cumsum (I_rev %*% e_r);
# select is a 0/1 column with a 1 at every position listed in e_r_rev_agg, i.e. at the
# row counts (taken from the last group backwards) at which a group is complete
select = t (colSums (table (seq (1, num_distinct), e_r_rev_agg)));

min_event_time = min (event_times);
max_event_time = max (event_times);
# NOTE(review): the next two statements ADD the bound to the affected entries rather than
# clamp to it (an entry with T_Y < min_event_time becomes T_Y + min_event_time) -- confirm
# that this is the intended handling of timestamps outside the observed event-time range.
T_Y = T_Y + (min_event_time * (T_Y < min_event_time));
T_Y = T_Y + (max_event_time * (T_Y > max_event_time));

# Ind[i,j] = (T_Y[i] >= event_times[j]); it is then replaced by a one-hot row per entry of
# T_Y with the 1 in the column returned by rowIndexMax
Ind = outer (T_Y, t (event_times), ">=");
Ind = table (seq (1, nrow (T_Y)), rowIndexMax (Ind), nrow (T_Y), num_distinct_event);

# exp (X %*% b) summed per group, then accumulated from the last group backwards
# (reverse, cumsum, reverse back)
exp_Xb = exp (X %*% b);
exp_Xb_agg = aggregate (target = exp_Xb, groups = RT_X, fn = "sum");
exp_Xb_cum = I_rev %*% cumsum (I_rev %*% exp_Xb_agg);

# running sums over the groups whose ratio is non-zero
H0 = cumsum (removeEmpty (target = d_r / exp_Xb_cum, margin = "rows"));
P1 = cumsum (removeEmpty (target = d_r / exp_Xb_cum ^ 2, margin = "rows"));

# Same backwards accumulation for X * exp (X %*% b): reverse all N rows, cumsum, keep only
# the rows marked by select (one per group), then reverse the groups back
X_exp_Xb = X * exp_Xb;
I_rev_all = table (seq (1, N, 1), seq (N, 1, -1));
X_exp_Xb_rev_agg = cumsum (I_rev_all %*% X_exp_Xb);
X_exp_Xb_rev_agg = removeEmpty (target = X_exp_Xb_rev_agg * select, margin = "rows");
X_exp_Xb_cum = I_rev %*% X_exp_Xb_rev_agg;
P2 = cumsum (removeEmpty (target = (X_exp_Xb_cum * d_r) / exp_Xb_cum ^ 2, margin = "rows"));

exp_Yb = exp (Y %*% b);
exp_Yb_2 = exp_Yb ^ 2;
Y_exp_Yb = Y * exp_Yb;

# estimates of cumulative hazard
H = exp_Yb * (Ind %*% H0);

# term1
term1 = exp_Yb_2 * (Ind %*% P1);

# term2
P3 = cumsum (removeEmpty (target = (exp_Xb_cum * d_r) / exp_Xb_cum ^ 2, margin = "rows"));
P4 = (Ind %*% P2) * exp_Yb;
P5 = Y_exp_Yb * (Ind %*% P3);
term2 = P4 - P5;

# standard error of the estimates of cumulative hazard
se_H = sqrt (term1 + rowSums((term2 %*% COV) * term2));

# prepare output matrix
P = matrix (0, rows = nrow (Y), cols = 6);
P[,1] = LP;
P[,2] = se_LP;
P[,3] = R;
P[,4] = se_R;
P[,5] = H;
P[,6] = se_H;
write (P, fileP, format=fmtO);