| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # |
| # THIS SCRIPT APPLIES THE ESTIMATED PARAMETERS OF A COX PROPORTIONAL HAZARD REGRESSION MODEL TO A NEW (TEST) DATASET |
| # |
| # INPUT PARAMETERS: |
| # --------------------------------------------------------------------------------------------- |
| # NAME TYPE DEFAULT MEANING |
| # --------------------------------------------------------------------------------------------- |
| # X String --- Location to read the input matrix containing the survival data with the following schema: |
| # - X[,1]: timestamps |
| # - X[,2]: whether an event occurred (1) or data is censored (0) |
| # - X[,3:]: feature vectors (excluding the baseline columns) used for model fitting |
| # RT String --- Location to read column matrix RT containing the (order preserving) recoded timestamps from X |
| # M String --- Location to read matrix M containing the fitted Cox model with the following schema: |
| # - M[,1]: betas |
| # - M[,2]: exp(betas) |
| # - M[,3]: standard error of betas |
| # - M[,4]: Z |
| # - M[,5]: p-value |
| # - M[,6]: lower 100*(1-alpha)% confidence interval of betas |
| # - M[,7]: upper 100*(1-alpha)% confidence interval of betas |
| # Y String --- Location to read matrix Y used for prediction |
| # COV String --- Location to read the variance-covariance matrix of the betas |
| # MF String --- Location to read column indices of X excluding the baseline factors if available |
| # P String --- Location to store matrix P containing the results of prediction |
| # fmt String "text" Matrix output format, usually "text" or "csv" (for matrices only) |
| # --------------------------------------------------------------------------------------------- |
| # OUTPUT: |
| # 1- A matrix P with the following schema: |
| # P[,1]: linear predictors relative to a baseline which contains the mean values for each feature |
| # i.e., (Y[3:] - colMeans (X[3:])) %*% b |
| # P[,2]: standard error of linear predictors |
| # P[,3]: risk relative to a baseline which contains the mean values for each feature |
| # i.e., exp ((Y[3:] - colMeans (X[3:])) %*% b) |
| # P[,4]: standard error of risk |
| # P[,5]: estimates of cumulative hazard |
| # P[,6]: standard error of the estimates of cumulative hazard |
| # ------------------------------------------------------------------------------------------- |
| # HOW TO INVOKE THIS SCRIPT - EXAMPLE: |
| # hadoop jar SystemDS.jar -f cox-predict.dml -nvargs X=INPUT_DIR/X RT=INPUT_DIR/RT M=INPUT_DIR/M Y=INPUT_DIR/Y |
| # COV=INPUT_DIR/COV MF=INPUT_DIR/MF P=OUTPUT_DIR/P fmt=csv |
| |
| # Bind the named command-line arguments (-nvargs) to local names. |
| # Each of them is a file location; see the INPUT PARAMETERS table in the header for the expected contents. |
| fileX = $X; |
| fileRT = $RT; |
| fileMF = $MF; |
| fileY = $Y; |
| fileM = $M; |
| fileCOV = $COV; |
| fileP = $P; |
| |
| # Default values of some parameters |
| # fmt is the only optional argument: the output format falls back to "text" when it is not supplied. |
| fmtO = ifdef ($fmt, "text"); # $fmt="text" |
| |
| # Load the training data X, its recoded timestamps RT, the test data Y, the fitted model M and the covariance matrix of the betas. |
| X_orig = read (fileX); |
| RT_X = read (fileRT); |
| Y_orig = read (fileY); |
| M = read (fileM); |
| b = M[,1]; # first column of M holds the estimated betas (see the schema of M in the header) |
| COV = read (fileCOV); |
| |
| # Project Y onto the columns listed in MF. |
| # tab is an ncol(Y_orig) x nrow(col_ind) 0/1 selection matrix with tab[col_ind[i], i] = 1, |
| # so column i of (Y_orig %*% tab) is column col_ind[i] of the original Y. |
| col_ind = read (fileMF); |
| tab = table (col_ind, seq (1, nrow (col_ind)), ncol (Y_orig), nrow (col_ind)); |
| Y_orig = Y_orig %*% tab; |
| |
| |
| # After the projection Y must have the same number of columns as X (only the column count is verified here; |
| # the header states that both share the schema: timestamp, event indicator, features). |
| if (ncol (Y_orig) != ncol (X_orig)) { |
| stop ("Y has a wrong number of columns!"); |
| } |
| |
| # Split the training data into its timestamp column, its event-indicator column and the feature columns. |
| T_X = X_orig[,1]; |
| E_X = X_orig[,2]; |
| X = X_orig[,3:ncol (X_orig)]; |
| N = nrow (X); |
| D = ncol (X); |
| |
| # Sort the test records by their timestamp (column 1), then split them the same way. |
| # ncol (X_orig) is used as the upper column bound because Y_orig has just been verified to have the same number of columns. |
| Y_orig = order (target = Y_orig, by = 1); |
| T_Y = Y_orig[,1]; |
| Y = Y_orig[,3:ncol (X_orig)]; |
| |
| # Express the test features relative to the baseline given by the training feature means. |
| # The 1 x ncol(X) row vector col_means is broadcast across the rows of Y. |
| col_means = colMeans (X); |
| Y_rel = Y - col_means; |
| |
| ##### compute linear predictors |
| # LP[i] = (Y[i,] - colMeans (X)) %*% b, i.e., the linear predictor relative to the mean-feature baseline |
| LP = Y_rel %*% b; |
| # compute standard error of linear predictors using the Delta method: |
| # se_LP[i] = sqrt (Y_rel[i,] %*% COV %*% t(Y_rel[i,])) |
| # Only the diagonal of Y_rel %*% COV %*% t(Y_rel) is needed, so the row-wise quadratic forms are computed |
| # directly instead of materializing the full nrow(Y) x nrow(Y) product and extracting its diagonal. |
| se_LP = sqrt (rowSums ((Y_rel %*% COV) * Y_rel)); |
| |
| ##### compute risk |
| # risk relative to the mean-feature baseline; exp (LP) is the same value as exp (Y_rel %*% b) |
| R = exp (LP); |
| # compute standard error of risk using the Delta method: |
| # the rows of Y_rel are scaled by the risk, the row-wise quadratic form with COV is taken, |
| # and the result is divided by sqrt (R), exactly as in the original formula (again only the diagonal is required). |
| Y_rel_R = Y_rel * R; |
| se_R = sqrt (rowSums ((Y_rel_R %*% COV) * Y_rel_R)) / sqrt (R); |
| |
| ##### compute estimates of cumulative hazard together with their standard errors: |
| # 1. col contains cumulative hazard estimates |
| # 2. col contains standard errors for cumulative hazard estimates |
| |
| # d_r[k]: sum of the event indicators, i.e., number of events, at the k-th recoded timestamp |
| # e_r[k]: number of training records (events or censored) at the k-th recoded timestamp |
| d_r = aggregate (target = E_X, groups = RT_X, fn = "sum"); |
| e_r = aggregate (target = RT_X, groups = RT_X, fn = "count"); |
| # Idx[k] is the running record count up to and including group k; it is used below as the row position |
| # of the last record of group k in T_X -- assumes the rows of X are sorted by timestamp, TODO confirm |
| Idx = cumsum (e_r); |
| # the table is a 0/1 selection matrix with a 1 at (k, Idx[k]), so all_times[k] = T_X[Idx[k]] |
| all_times = table (seq (1, nrow (Idx), 1), Idx) %*% T_X; # distinct event times |
| |
| # keep only the timestamps at which at least one event occurred; rows zeroed by the (d_r > 0) mask are dropped |
| # NOTE(review): removeEmpty drops every all-zero row, so a timestamp equal to 0 that has events would be dropped as well -- confirm timestamps are strictly positive |
| event_times = removeEmpty (target = (d_r > 0) * all_times, margin = "rows"); |
| num_distinct_event = nrow (event_times); |
| |
| num_distinct = nrow (all_times); # no. of distinct timestamps censored or uncensored |
| # I_rev has ones on the anti-diagonal: multiplying by it reverses the row order |
| I_rev = table (seq (1, num_distinct, 1), seq (num_distinct, 1, -1)); |
| # cumulative record counts, accumulated from the last distinct timestamp backwards |
| e_r_rev_agg = cumsum (I_rev %*% e_r); |
| # column vector with a 1 at every position listed in e_r_rev_agg and 0 elsewhere; |
| # it is applied below to pick group-boundary rows out of a reversed per-record cumulative sum |
| select = t (colSums (table (seq (1, num_distinct), e_r_rev_agg))); |
| |
| # Clamp the test timestamps to the range of observed event times [min_event_time, max_event_time]. |
| # A true clamp is required here: adding min_event_time to a too-small timestamp (instead of replacing it) |
| # can push it past later event times, which would select the baseline hazard of the wrong event time below. |
| min_event_time = min (event_times); |
| max_event_time = max (event_times); |
| T_Y = max (T_Y, min_event_time); |
| T_Y = min (T_Y, max_event_time); |
| |
| # First, Ind[i,j] = 1 iff T_Y[i] >= event_times[j]; after clamping every row contains at least one 1. |
| # Then Ind is rebuilt as a nrow(T_Y) x num_distinct_event selection matrix with a single 1 per row at the |
| # column returned by rowIndexMax. |
| # NOTE(review): this relies on rowIndexMax returning the LAST column among tied maxima, so that the latest |
| # event time <= T_Y[i] is selected -- confirm the tie-breaking rule. |
| Ind = outer (T_Y, t (event_times), ">="); |
| Ind = table (seq (1, nrow (T_Y)), rowIndexMax (Ind), nrow (T_Y), num_distinct_event); |
| |
| # exp_Xb[i]: exp of the linear predictor of training record i |
| exp_Xb = exp (X %*% b); |
| # sum of exp_Xb over the records sharing a recoded timestamp |
| exp_Xb_agg = aggregate (target = exp_Xb, groups = RT_X, fn = "sum"); |
| # reverse, accumulate, reverse back: exp_Xb_cum[k] is the sum of exp_Xb_agg over groups k, k+1, ..., |
| # i.e., over all records whose recoded timestamp is at or after the k-th one |
| exp_Xb_cum = I_rev %*% cumsum (I_rev %*% exp_Xb_agg); |
| |
| # rows belonging to timestamps without events (d_r = 0) are zero and get dropped, so H0 and P1 have one row per distinct event time |
| # H0: running sum of d_r / exp_Xb_cum over the event times; it is scaled by exp (Y %*% b) further below to obtain the cumulative hazard |
| # P1: running sum of d_r / (exp_Xb_cum ^ 2); it feeds the first variance term of the cumulative hazard |
| H0 = cumsum (removeEmpty (target = d_r / exp_Xb_cum, margin = "rows")); |
| P1 = cumsum (removeEmpty (target = d_r / exp_Xb_cum ^ 2, margin = "rows")); |
| # each row of X scaled by the exp of its linear predictor (same value as exp_Xb above) |
| X_exp_Xb = X * exp (X %*% b); |
| |
| # N x N row-reversal matrix for per-record data |
| I_rev_all = table (seq (1, N, 1), seq (N, 1, -1)); |
| # column-wise cumulative sums of X_exp_Xb taken from the last record backwards |
| X_exp_Xb_rev_agg = cumsum (I_rev_all %*% X_exp_Xb); |
| # keep only the rows flagged by select (group boundaries in the reversed order); all other rows are zeroed and removed |
| # NOTE(review): a selected row whose cumulative sums are all exactly 0 would be removed too -- presumably not expected in practice, confirm |
| X_exp_Xb_rev_agg = removeEmpty (target = X_exp_Xb_rev_agg * select, margin = "rows"); |
| # flip back so that the rows follow the ascending order of the distinct timestamps |
| X_exp_Xb_cum = I_rev %*% X_exp_Xb_rev_agg; |
| # P2: running sum over the event times of (X_exp_Xb_cum * d_r) / (exp_Xb_cum ^ 2), one column per feature |
| P2 = cumsum (removeEmpty (target = (X_exp_Xb_cum * d_r) / exp_Xb_cum ^ 2, margin = "rows")); |
| |
| # exp of the (uncentered) linear predictor of every test record, its square, and the test features scaled by it |
| exp_Yb = exp (Y %*% b); |
| exp_Yb_2 = exp_Yb ^ 2; |
| Y_exp_Yb = Y * exp (Y %*% b); |
| |
| # estimates of cumulative hazard |
| # Ind %*% H0 picks, for every test record, the H0 row of the event time selected in Ind |
| H = exp_Yb * (Ind %*% H0); |
| |
| # term1 |
| # first variance component: exp (Y %*% b) ^ 2 times the selected row of P1 |
| term1 = exp_Yb_2 * (Ind %*% P1); |
| |
| # term2 |
| # NOTE(review): (exp_Xb_cum * d_r) / exp_Xb_cum ^ 2 reduces algebraically to d_r / exp_Xb_cum, the expression accumulated in H0 -- confirm whether P3 is intended to equal H0 |
| P3 = cumsum (removeEmpty (target = (exp_Xb_cum * d_r) / exp_Xb_cum ^ 2, margin = "rows")); |
| # P4 and P5 both have one row per test record and one column per feature |
| P4 = (Ind %*% P2) * exp_Yb; |
| P5 = Y_exp_Yb * (Ind %*% P3); |
| term2 = P4 - P5; |
| |
| # standard error of the estimates of cumulative hazard |
| # rowSums ((term2 %*% COV) * term2) is the row-wise quadratic form term2[i,] %*% COV %*% t(term2[i,]) |
| se_H = sqrt (term1 + rowSums((term2 %*% COV) * term2)); |
| |
| # Assemble the result matrix P by placing the six result vectors side by side, in the column order |
| # given by the OUTPUT schema in the header: linear predictor, its standard error, risk, its standard error, |
| # cumulative hazard, its standard error. Then store P at the requested location in the requested format. |
| P = cbind (LP, se_LP, R, se_R, H, se_H); |
| write (P, fileP, format=fmtO); |
| |