#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

#  
# THIS SCRIPT ANALYZES SURVIVAL DATA USING KAPLAN-MEIER ESTIMATES 
#
# INPUT   PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME    TYPE     DEFAULT      MEANING
# ---------------------------------------------------------------------------------------------
# X       String   ---          Location to read the input matrix X containing the survival data: 
#								timestamps, whether event occurred (1) or data is censored (0), and a number of factors (categorical features) 
#								for grouping and/or stratifying 
# TE	  String   ---          Column indices of X which contain timestamps (first entry) and event information (second entry) 
# GI	  String   ---          Column indices of X corresponding to the factors to be used for grouping
# SI	  String   ---          Column indices of X corresponding to the factors to be used for stratifying				
# O       String   ---          Location to write the matrix containing the results of the Kaplan-Meier analysis; see below for the description
# M       String   ---          Location to write Matrix M containing the following statistic: total number of events, median and its confidence intervals; 
#								if survival data for multiple groups and strata are provided each row of M contains the above statistics per group and stratum
# T 	  String   " "			If survival data from multiple groups available and ttype=log-rank or wilcoxon, 
#								location to write the matrix containing result of the (stratified) test for comparing multiple groups
# alpha   Double   0.05         Parameter to compute 100*(1-alpha)% confidence intervals for the survivor function and its median 
# etype   String   "greenwood"  Parameter to specify the error type according to "greenwood" (the default) or "peto"
# ctype   String   "log"        Parameter to modify the confidence interval; "plain" keeps the lower and upper bound of 
#								the confidence interval unmodified,	"log" (the default) corresponds to the logarithmic transformation and 
#								"log-log" corresponds to the complementary log-log transformation 
# ttype   String   "none"   	If survival data for multiple groups is available specifies which test to perform for comparing 
#								survival data across multiple groups: "none" (the default) "log-rank" or "wilcoxon" test   
# fmt     String   "text"       The output format of results of the Kaplan-Meier analysis, such as "text" or "csv"
# ---------------------------------------------------------------------------------------------
# OUTPUT: 
# 1- Matrix KM whose dimension depends on the number of groups (denoted by g) and strata (denoted by s) in the data: 
#	each collection of 7 consecutive columns in KM corresponds to a unique combination of groups and strata in the data with the following schema
# 	1. col: timestamp
# 	2. col: no. at risk
# 	3. col: no. of events
# 	4. col: Kaplan-Meier estimate of survivor function surv
# 	5. col: standard error of surv
# 	6. col: lower 100*(1-alpha)% confidence interval for surv
# 	7. col: upper 100*(1-alpha)% confidence interval for surv
# 2- Matrix M whose dimension depends on the number of groups (g) and strata (s) in the data (k denotes the number of factors used for grouping 
#	,i.e., ncol(GI) and l denotes the number of factors used for stratifying, i.e., ncol(SI))
#	M[,1:k]: unique combination of values in the k factors used for grouping 
#	M[,(k+1):(k+l)]: unique combination of values in the l factors used for stratifying
#	M[,k+l+1]: total number of records
#	M[,k+l+2]: total number of events
#	M[,k+l+3]: median of surv
#	M[,k+l+4]: lower 100*(1-alpha)% confidence interval of the median of surv 
#	M[,k+l+5]: upper 100*(1-alpha)% confidence interval of the median of surv
#	If no factors are used for grouping and stratifying (i.e., there is a single group and a single stratum), M will have 5 columns with 
#	M[,1]: total number of records
#	M[,2]: total number of events
#	M[,3]: median of surv
#	M[,4]: lower 100*(1-alpha)% confidence interval of the median of surv 
#	M[,5]: upper 100*(1-alpha)% confidence interval of the median of surv
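# 3- If survival data from multiple groups is available and ttype=log-rank or wilcoxon, a 1 x 4 matrix T and a g x 5 matrix T_GROUPS_OE
#	(note: the script writes the g x 5 matrix, preceded by the columns of the grouping factors, to the location T 
#	and the 1 x 4 matrix to the location T + ".groups.oe") with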
#	T_GROUPS_OE[,1] = no. of events	
#	T_GROUPS_OE[,2] = observed value (O)
#	T_GROUPS_OE[,3] = expected value (E)
#	T_GROUPS_OE[,4] = (O-E)^2/E
#	T_GROUPS_OE[,5] = (O-E)^2/V 	 
#	T[1,1] = no. of groups
#	T[1,2] = degree of freedom for Chi-squared distributed test statistic
#	T[1,3] = test statistic 
#	T[1,4] = P-value
# -------------------------------------------------------------------------------------------
# HOW TO INVOKE THIS SCRIPT - EXAMPLE:
# hadoop jar SystemDS.jar -f KM.dml -nvargs X=INPUT_DIR/X TE=INPUT_DIR/TE GI=INPUT_DIR/GI SI=INPUT_DIR/SI O=OUTPUT_DIR/O
#											M=OUTPUT_DIR/M T=OUTPUT_DIR/T alpha=0.05 etype=greenwood ctype=log fmt=csv

# Mandatory arguments: location of the survival data X, of the column indices TE
# (timestamp, event), and of the two result matrices O and M.
fileX = $X;
fileTE = $TE;
fileO = $O;
fileM = $M;

# Optional locations; a single blank (" ") marks "not provided".
fileGI = ifdef ($GI, " ");               # $GI=" " (column indices of the grouping factors)
fileSI = ifdef ($SI, " ");               # $SI=" " (column indices of the stratifying factors)
fileT = ifdef ($T, " ");                 # $T=" "  (location of the test results)

# Default values of the remaining parameters.
# (The former fileG / fileS variables read from $G / $S were never used by the
# script and have been dropped.)
fmtO = ifdef ($fmt, "text");             # $fmt="text"
alpha = ifdef ($alpha, 0.05);            # $alpha=0.05
err_type  = ifdef ($etype, "greenwood"); # $etype="greenwood"
conf_type = ifdef ($ctype, "log");       # $ctype="log"
test_type = ifdef ($ttype, "none");      # $ttype="none"

# Load the survival data and the column indices of the timestamp/event columns.
X = read (fileX);
TE = read (fileTE);

# A 1 x 1 zero matrix stands for "no grouping factors"; it is replaced by the
# matrix stored at fileGI when that location was supplied.
GI = matrix (0, rows = 1, cols = 1);
if (fileGI != " ") {
	GI = read (fileGI);
}

# Same convention for the stratifying factors.
SI = matrix (0, rows = 1, cols = 1);
if (fileSI != " ") {
	SI = read (fileSI);
}

# All three index matrices are used in transposed form from here on.
TE = t(TE);
GI = t(GI);
SI = t(SI);
		
# Reject unknown values of the string-valued options.
if (!(err_type == "greenwood" | err_type == "peto")) { 
	stop (err_type + " is not a valid error type!");
}

if (!(conf_type == "plain" | conf_type == "log" | conf_type == "log-log")) { 
	stop (conf_type + " is not a valid confidence type!");
}

if (!(test_type == "log-rank" | test_type == "wilcoxon" | test_type == "none")) {
	stop (test_type + " is not a valid test type!");
}

n_group_cols = ncol (GI);
n_stratum_cols = ncol (SI);

# Inspect the first entry of GI and SI: a zero in a single-column matrix means
# "no factors given", a zero in a wider matrix is an error.
GI_1_1 = as.scalar (GI[1,1]);
SI_1_1 = as.scalar (SI[1,1]);	
if (GI_1_1 == 0) {
	if (n_group_cols == 1) { # no factors for grouping
		n_group_cols = 0;
	} else {
		stop ("Matrix GI contains zero entries!");
	}
}
if (SI_1_1 == 0) {
	if (n_stratum_cols == 1) { # no factors for stratifying
		n_stratum_cols = 0;
	} else {
		stop ("Matrix SI contains zero entries!");
	}
}

# X must provide the timestamp and event columns plus every requested factor.
if (ncol (X) < 2 + n_group_cols + n_stratum_cols) {
	stop ("X has an incorrect number of columns!");
}

# Select the relevant columns of X and arrange them as
# [timestamp, event, grouping factors, stratifying factors].
# Is lists the source column indices in that order; factors that were not
# provided (first entry equal to 0) are left out.
Is = TE;
if (GI_1_1 != 0) {
	Is = cbind (Is, GI);
}
if (SI_1_1 != 0) {
	Is = cbind (Is, SI);
}
n_sel_cols = 2 + n_group_cols + n_stratum_cols;
# selection matrix with a 1 at (source column, target column)
col_select = table (Is, seq (1, n_sel_cols), ncol (X), n_sel_cols);
X = X %*% col_select;	

num_records = nrow (X);
# defaults used when no grouping / stratifying is performed
num_groups = 1;
num_strata = 1;

### compute group id for each record
# Sorts X by the grouping factors, derives a consecutive group id (1..num_groups)
# for every record, keeps the factor values of each group in G_cols, and replaces
# the grouping-factor columns of X by the single group id column Gi.
print ("Perform grouping...");
if (n_group_cols > 0) {
	for (g in 1:n_group_cols) { # sort columns corresponding to groups sequentially
		X = order (target = X, by = 2 + g);
	}
	XG = X[,3:(2 + n_group_cols)];
	# Idx marks the first record of every group: a record starts a new group if it
	# differs from its predecessor in at least one grouping factor
	Idx = matrix (1, rows = num_records, cols = 1);
	Idx[2:num_records,] = rowMaxs (X[1:(num_records - 1),3:(2 + n_group_cols)] != X[2:num_records,3:(2 + n_group_cols)]);
	num_groups = sum (Idx);

	# Extract one row of factor values per group. Zeros are temporarily turned into
	# Inf so that a group whose factor values are all 0 is not dropped by removeEmpty;
	# Inf * 0 (rows that do not start a group) yields NaN, which is reset to 0.
	XG = replace (target = XG, pattern = 0, replacement = Inf);
	XG = XG * Idx;
	XG = replace (target = XG, pattern = NaN, replacement = 0);	
	G_cols = removeEmpty (target = XG, margin = "rows"); 
	G_cols = replace (target = G_cols, pattern = Inf, replacement = 0);	

	# A has one column per group with a 1 at the group's first record. Subtracting
	# the next column and taking the cumulative sum turns column j into the
	# indicator of "record belongs to group j".
	A = removeEmpty (target = diag (Idx), margin = "cols");
	if (ncol (A) > 1) {
		A[,1:(ncol (A) - 1)] = A[,1:(ncol (A) - 1)] - A[,2:ncol (A)];
		B = cumsum (A);
		Gi = B %*% seq(1, ncol(B)); # group ids
	} else { # there is only one group
		Gi = matrix (1, rows = num_records, cols = 1);
	}
	if (n_stratum_cols > 0) {
		# keep the stratifying factors, which start right after the grouping factors;
		# indexed with n_group_cols instead of the loop variable g, whose value
		# after the end of the loop should not be relied upon
		X = cbind (X[,1:2], Gi, X[,(3 + n_group_cols):ncol (X)]);
	} else { # no strata
		X = cbind (X[,1:2],Gi);
	}
}

### compute stratum id for each record
# Sorts X by the stratifying factors, derives a consecutive stratum id
# (1..num_strata) for every record, keeps the factor values of each stratum in
# S_cols and the index of the first record of each stratum in SB, and replaces
# the stratifying-factor columns of X by the single stratum id column Si.
print ("Perform stratifying...");
if (n_stratum_cols > 0) {
	# number of columns preceding the stratifying factors: timestamp and event,
	# plus the group id column if grouping was performed
	s_offset = 2;
	if (n_group_cols > 0) {
		s_offset = 3;
	}
	for (s in 1:n_stratum_cols) { # sort columns corresponding to strata sequentially
		X = order (target = X, by = s_offset + s);		
	}
	XS = X[,(s_offset + 1):(s_offset + n_stratum_cols)];
	# Idx marks the first record of every stratum
	Idx = matrix (1, rows = num_records, cols = 1);
	Idx[2:num_records,] = rowMaxs (X[1:(num_records - 1),(s_offset + 1):(s_offset + n_stratum_cols)] != X[2:num_records,(s_offset + 1):(s_offset + n_stratum_cols)]);
	num_strata = sum (Idx);

	# Extract one row of factor values per stratum. Zeros are temporarily turned into
	# Inf so that a stratum whose factor values are all 0 is not dropped by removeEmpty.
	# (Inf literal used for consistency with the grouping step, instead of the
	# string "Infinity".)
	XS = replace (target = XS, pattern = 0, replacement = Inf);
	XS = XS * Idx;
	XS = replace (target = XS, pattern = NaN, replacement = 0);	
	S_cols = removeEmpty (target = XS, margin = "rows"); 
	S_cols = replace (target = S_cols, pattern = Inf, replacement = 0);	

	SB = removeEmpty (target = seq (1,num_records), margin = "rows", select = Idx); # indices of stratum boundaries 
	A = removeEmpty (target = diag (Idx), margin = "cols");
	if (ncol (A) > 1) {
		A[,1:(ncol (A) - 1)] = A[,1:(ncol (A) - 1)] - A[,2:ncol (A)];
		B = cumsum (A);
		Si = B %*% seq(1, ncol(B)); # stratum ids
	} else { # there is only one stratum
		Si = matrix (1, rows = num_records, cols = 1);
	}
	# Keep only the columns preceding the stratifying factors and append the stratum id.
	# Using s_offset (instead of a fixed 3) avoids carrying along the first raw
	# stratifying factor when there is no group id column.
	X = cbind (X[,1:s_offset],Si);
}

# Bring X to four columns: timestamp, event, group id and a fourth column.
# A missing group id (and, without strata, the fourth column) is filled with 1.
# Without stratifying factors the single stratum starts at record 1 (SB).
ones_col = matrix (1, rows = num_records, cols = 1);
if (n_stratum_cols == 0) {
	SB = matrix (1, rows = 1, cols = 1);
	if (n_group_cols == 0) {
		X = cbind (X, ones_col, ones_col);
	} else {
		X = cbind (X, ones_col);
	}
} else if (n_group_cols == 0) {	
	# strata but no groups: insert the constant group id before column 3 of X
	X = cbind (X[,1:2], ones_col, X[,3]);
}

######## BEGIN KAPLAN-MEIER ANALYSIS
print ("BEGIN KAPLAN-MEIER SURVIVAL FIT SCRIPT");

# one cell per (group, stratum) combination
n_cells = num_groups * num_strata;

# KM holds 7 result columns per cell; KM_cols_select flags the columns to keep
KM = matrix (0, rows = num_records, cols = n_cells * 7);
KM_cols_select = matrix (1, rows = n_cells * 7, cols = 1);
# (group id, stratum id) of every cell that is present in the data
GSI = matrix (0, rows = n_cells, cols = 2);
# M is pre-filled with Infinity (1/0); entries never overwritten are mapped to NaN
# before M is written
inf_val = 1/0;
M = matrix (inf_val, rows = n_cells, cols = 5);
M_cols = seq (1, n_cells);
# normal quantile used for the 100*(1-alpha)% confidence intervals
z_alpha_2 = icdf (target = 1 - alpha / 2, dist = "normal");

# Buffers for the log-rank / wilcoxon test; a test needs at least two groups.
if (test_type != "none") {
	if (num_groups > 1) { 
		str = "";
		TEST = matrix (0, rows = num_groups, cols = 5);
		TEST_GROUPS_OE = matrix (0, rows = 1, cols = 4);
		U = matrix (0, rows = num_groups, cols = num_strata);
		U_OE = matrix (0, rows = num_groups, cols = num_strata);
		OBS = matrix (0, rows = num_groups, cols = num_strata);
		EXP = matrix (0, rows = num_groups, cols = num_strata);
		V_sum_total = matrix (0, rows = num_groups-1, cols = (num_groups-1) * num_strata);
		n_event_all_global = matrix(1, rows=num_groups, cols=num_strata);
	} else if (num_groups == 1) {
		stop ("Data contains only one group or no groups, at least two groups are required for test!");
	}
}

# Main loop over the strata. Stratum s consists of the records SB[s] .. SB[s+1]-1
# of X (the last stratum extends to num_records). For every group present in the
# stratum the Kaplan-Meier estimate, its standard error and confidence interval,
# and the median with its confidence interval are computed and stored in KM and M.
# If a test was requested, the per-stratum contributions to the log-rank or
# wilcoxon statistic are stored in U, U_OE, OBS, EXP and V_sum_total.
parfor (s in 1:num_strata, check = 0) {
	
	start_ind = as.scalar (SB[s,]);
	end_ind = num_records;
	if (s != num_strata) {
		end_ind = as.scalar (SB[s + 1,]) - 1;
	} 
	
	######## RECODING TIMESTAMPS PRESERVING THE ORDER
	
	X_cur = X[start_ind:end_ind,];
	range = end_ind - start_ind + 1; # no. of records in the stratum
	X_cur = order (target = X_cur, by = 1);
	Idx1 = matrix (1, rows = range, cols = 1);
	
	# RT assigns to every record the rank (1, 2, ...) of its timestamp among the
	# distinct timestamps of the stratum
	num_timestamps = 1;
	if (range == 1) {
		RT = matrix (1, rows = 1, cols = 1);
	} else {
		Idx1[2:range,1] = (X_cur[1:(range - 1),1] != X_cur[2:range,1]);
		num_timestamps = sum (Idx1);
		A1 = removeEmpty (target = diag (Idx1), margin = "cols");
		if (ncol (A1) > 1) {
			A1[,1:(ncol (A1) - 1)] = A1[,1:(ncol (A1) - 1)] - A1[,2:ncol (A1)];
			B1 = cumsum (A1);
			RT = B1 %*% seq(1, ncol(B1)); 	
		} else { # all records of the stratum share one timestamp
			RT = matrix (1, rows = range, cols = 1);
		}
	}
	
	T = X_cur[,1]; # timestamps
	E = X_cur[,2]; # event indicator
	G = X_cur[,3]; # group ids
	
	n_event_stratum = aggregate (target = E, groups = RT, fn = "sum"); # no. of uncensored events per stratum 
	n_event_all_stratum = aggregate (target = E, groups = RT, fn = "count"); # no. both censored and uncensored of events per stratum 
	Idx1 = cumsum (n_event_all_stratum); 
	time_stratum = table (seq (1, nrow (Idx1), 1), Idx1) %*% T; # distinct timestamps both censored and uncensored per stratum 
	# a timestamp equal to 0 is temporarily encoded as Inf so that it survives the
	# removeEmpty calls below
	time_stratum_has_zero = sum (time_stratum == 0) > 0;
	if (time_stratum_has_zero) {
		time_stratum = 	replace (target = time_stratum, pattern = 0, replacement = Inf);
	}
	n_time_all1 = nrow (n_event_stratum);  # no. of distinct timestamps both censored and uncensored per stratum
	n_event_all_stratum_agg = matrix (0, rows = n_time_all1, cols = 1); 
	if (n_time_all1 > 1) {
		n_event_all_stratum_agg[2:n_time_all1,] = Idx1[1:(n_time_all1 - 1),]; 
	}
	n_risk_stratum = range - n_event_all_stratum_agg; # no. at risk per stratum

	if (num_groups > 1 & test_type != "none") {	# needed for log-rank or wilcoxon test	
		# columns (2g-1, 2g) hold no. at risk and no. of events of group g
		n_risk_n_event_stratum = matrix (0, rows = n_time_all1, cols = num_groups * 2);
	}

	parfor (g in 1:num_groups, check = 0) {
	
		group_ind = (G == g);
		KM_offset = (s - 1) * num_groups * 7 + (g - 1) * 7;
		M_offset = (s - 1) * num_groups + g;
		if (sum (group_ind) != 0) { # group g is present in the stratum s

			GSI_offset = (s - 1) * num_groups + g; 
			GSI[GSI_offset,1] = g;
			GSI[GSI_offset,2] = s;		
			E_cur = E * group_ind;

			######## COMPUTE NO. AT RISK AND NO.OF EVENTS FOR EACH TIMESTAMP
			
			n_event = aggregate (target = E_cur, groups = RT, fn = "sum"); # no. of uncensored events per stratum per group
			n_event_all = aggregate (target = group_ind, groups = RT, fn = "sum"); # no. of both censored and uncensored events per stratum per group
			Idx1 = cumsum (n_event_all); 
			event_occurred = (n_event > 0);
			# timestamps at which at least one event of the group occurred
			if (time_stratum_has_zero) {
				time = replace (target = time_stratum * event_occurred, pattern = NaN, replacement = 0);
				time = removeEmpty (target = time, margin = "rows");
				time = replace (target = time, pattern = Inf, replacement = 0);
			} else {
				time = removeEmpty (target = time_stratum * event_occurred, margin = "rows");
			}
			n_time_all2 = nrow (n_event);  # no. of distinct timestamps both censored and uncensored per stratum per group
			n_event_all_agg = matrix (0, rows = n_time_all2, cols = 1); 
			if (n_time_all2 > 1) {
				n_event_all_agg[2:n_time_all2,] = Idx1[1:(n_time_all2 - 1),]; 
			}
				
			n_risk = sum (group_ind) - n_event_all_agg; # no. at risk per stratum per group
			
			if (num_groups > 1 & test_type != "none") {
				n_risk_n_event_stratum[,(g - 1) * 2 + 1] = n_risk;
				n_risk_n_event_stratum[,(g - 1) * 2 + 2] = n_event;					
			}
			
			# Extract only rows corresponding to events, i.e., for which n_event is nonzero 
			Idx1 = (n_event != 0);
			KM_1 = matrix (0, rows = n_time_all2, cols = 2);
			KM_1[,1] = n_risk;
			KM_1[,2] = n_event;
			KM_1 = removeEmpty (target = KM_1, margin = "rows", select = Idx1);
			n_risk = KM_1[,1];
			n_event = KM_1[,2];
			n_time = nrow (time);
			
			######## ESTIMATE SURVIVOR FUNCTION SURV, ITS STANDARD ERROR SE_SURV, AND ITS 100(1-ALPHA)% CONFIDENCE INTERVAL	
			surv = cumprod ((n_risk - n_event) / n_risk);
			tmp = n_event / (n_risk * (n_risk - n_event));
			se_surv = sqrt (cumsum (tmp)) * surv; # greenwood
			if (err_type == "peto") {
				se_surv = (surv * sqrt(1 - surv) / sqrt(n_risk));		
			}
		
			if (conf_type == "plain") { 
				# True survivor function is in [surv +- z_alpha_2 * se_surv], 
				# values less than 0 are replaced by 0, values larger than 1 are replaced by 1!
				CI_l = max (surv - (z_alpha_2 * se_surv), 0);  
				CI_r = min (surv + (z_alpha_2 * se_surv), 1); 
			} else if (conf_type == "log") {
				# True survivor function is in [surv * exp(+- z_alpha_2 * se_surv / surv)]
				CI_l = max (surv * exp (- z_alpha_2 * se_surv / surv), 0); 
				CI_r = min (surv * exp ( z_alpha_2 * se_surv / surv), 1); 
			} else { # conf_type == "log-log"
				# True survivor function is in [surv ^ exp(+- z_alpha_2 * se(log(-log(surv))))].
				# By the delta method se(log(-log(surv))) = se_surv / (surv * |log(surv)|);
				# se_surv is the standard error of surv itself, so it has to be divided by
				# surv as well (as in the "log" branch) and not only by log(surv).
				# log(surv) is negative, hence the "-" sign yields the lower bound.
				CI_l = max (surv ^ exp (- z_alpha_2 * (se_surv / surv) / log(surv)), 0); 
				CI_r = min (surv ^ exp ( z_alpha_2 * (se_surv / surv) / log(surv)), 1);  
			}	 
			# no confidence interval at the last timestamp if everybody at risk had an event
			if (as.scalar (n_risk[n_time,]) == as.scalar (n_event[n_time,])) {
				CI_l[n_time,] = 0/0;
				CI_r[n_time,] = 0/0;
			}	
		
			n_event_sum = sum (n_event);
			n_event_sum_all = sum(n_event_all);
			if (n_event_sum > 0) {
				# KM_offset = (s - 1) * num_groups * 7 + (g - 1) * 7;
				KM[1:n_time,KM_offset + 1] = time;
				KM[1:n_time,KM_offset + 2] = n_risk;			
				KM[1:n_time,KM_offset + 3] = n_event;
				KM[1:n_time,KM_offset + 4] = surv;
				KM[1:n_time,KM_offset + 5] = se_surv;
				KM[1:n_time,KM_offset + 6] = CI_l;
				KM[1:n_time,KM_offset + 7] = CI_r;				
			}			
						
			######## ESTIMATE MEDIAN OF SURVIVAL TIMES AND ITS 100(1-ALPHA)% CONFIDENCE INTERVAL
		
			p_5 = (surv <= 0.5); 
			pn_5 = sum (p_5);
			#M_offset = (s - 1) * num_groups + g;
			# if the estimated survivor function is larger than 0.5 for all timestamps median does not exist! 
			p_5_exists = (pn_5 != 0);
			M[M_offset,2] = n_event_sum;
			M[M_offset,1] = n_event_sum_all; 
			# The number of records of the group is needed for the test output regardless
			# of whether the median exists; recording it only when the median exists left
			# the initial value in place for all other groups.
			if (num_groups > 1 & test_type != "none") {
				n_event_all_global[g,s] = n_event_sum_all; 
			}
			if (p_5_exists) {
				if ( as.scalar (surv[n_time - pn_5 + 1,1]) == 0.5 ) { # if the estimated survivor function is exactly equal to 0.5
					if (pn_5 > 1) {
						t_5 = as.scalar ((time[n_time - pn_5 + 1,1] + time[n_time - pn_5 + 2,1])/2);
					} else {
						t_5 = as.scalar (time[n_time - pn_5 + 1,1]);
					}
				} else {
					t_5 = as.scalar (time[n_time - pn_5 + 1,1]);
				}
		
				l_ind = (CI_l <= 0.5);
				r_ind = (CI_r <= 0.5);
				l_ind_sum = sum (l_ind);
				r_ind_sum = sum (r_ind);
				l_min_ind = as.scalar (rowIndexMin (t(l_ind)));
				r_min_ind = as.scalar (rowIndexMin (t(r_ind)));		
				if (l_min_ind == n_time) {
					if (l_ind_sum > 0) {
						if (as.scalar (l_ind[n_time,1]) == 0) { # NA at last position
							M[M_offset,4] = time[n_time - l_ind_sum,1];
						} else {
							M[M_offset,4] = time[1,1];
						}
					}
				} else {
					M[M_offset,4] = time[l_min_ind + 1,1];
				}
				#
				if (r_min_ind == n_time) {
					if (r_ind_sum > 0) {
						if (as.scalar (r_ind[n_time,1]) == 0) { # NA at last position
							M[M_offset,5] = time[n_time - r_ind_sum,1];
						} else {
							M[M_offset,5] = time[1,1];
						}
					}
				} else {
					M[M_offset,5] = time[r_min_ind + 1,1];
				}
				M[M_offset,3] = t_5;
			}
		} else {
			print ("group " + g + " is not present in the stratum " + s);
			KM_cols_select[(KM_offset + 1):(KM_offset + 7),1] = matrix (0, rows = 7, cols = 1);
			M_cols[M_offset,1] = 0;
			# an absent group contributes no records to the per-group totals of the test
			if (num_groups > 1 & test_type != "none") {
				n_event_all_global[g,s] = 0;
			}
		}		
	}
	
			
	######## COMPARISON BETWEEN DIFFERENT GROUPS USING LOG-RANK OR WILCOXON TEST
		
	if (num_groups > 1 & test_type != "none") {

		V = matrix (0, rows = num_groups-1, cols = num_groups-1);
		# observed (O) and expected (E) values of every group in this stratum
		parfor (g in 0:(num_groups-1), check = 0) {
		
			n_risk = n_risk_n_event_stratum[,g * 2 + 1];			
			n_event = n_risk_n_event_stratum[,g * 2 + 2];
		
			if (test_type == "log-rank") {
				O = n_event;
				E = n_risk * n_event_stratum / n_risk_stratum;		
			} else { ### test_type == "wilcoxon"
				O = n_risk_stratum * n_event / range;
				E = n_risk * n_event_stratum / range;
			}			
			U[(g + 1),s] = sum (O - E);
			U_OE[g + 1, s] = (sum (O - E)*sum (O - E))/sum(E);
			OBS[g + 1, s] = sum(O);
			EXP[g + 1, s] = sum(E);
		}
		
		# (num_groups-1) x (num_groups-1) variance matrix of this stratum; this loop is a
		# plain sequential for loop, so the parfor-only option "check" is not passed
		for (i1 in 0:(num_groups - 2)) {
		
			n_risk = n_risk_n_event_stratum[,1 + i1 * 2]; 
			# V1 depends only on group i1 and is therefore computed once per i1
			if (test_type == "log-rank") {
				V1 = n_risk * n_event_stratum * (n_risk_stratum - n_event_stratum) / (n_risk_stratum * (n_risk_stratum - 1));
			} else { ### test_type == "wilcoxon"
				V1 = (n_risk_stratum ^ 2) * (n_risk * n_event_stratum) * (n_risk_stratum - n_event_stratum) / (n_risk_stratum * (n_risk_stratum - 1));
			}
			V1 = replace (target = V1, pattern = NaN, replacement = 0);
			
			for (i2 in 0:(num_groups - 2)) {

				n_risk_i2j = n_risk_n_event_stratum[,1 + i2 * 2]; 
				I_i1i2 = 0; # indicator of i1 == i2
				if (i1 == i2) { 
					I_i1i2 = 1;
				}
				V2 = I_i1i2 - (n_risk_i2j / n_risk_stratum);
				if (test_type == "log-rank") {
					V[(i1 + 1),(i2 + 1)] = sum (V1 * V2);
				} else { ### test_type == "wilcoxon"
					V[(i1 + 1),(i2 + 1)] = sum (V1 * V2) / (range ^ 2);
				}
			}
		}
		# store V of stratum s in its own column block of V_sum_total
		V_start_ind = (s - 1) * (num_groups - 1) + 1;
		V_sum_total[,V_start_ind:(V_start_ind + num_groups - 2)] = V;
	}
}

# Combine the per-stratum contributions into the overall test statistic and fill
# the two result matrices of the test.
if (num_groups > 1 & test_type != "none") {
	n_free = num_groups - 1;

	# add up the per-stratum variance blocks stored side by side in V_sum_total
	V_sum = matrix (0, rows = n_free, cols = n_free);
	for (s in 1:num_strata) {
		V_first = (s - 1) * n_free + 1;
		V_sum = V_sum + V_sum_total[,V_first:(V_first + n_free - 1)];
	}
		
	U_sum = rowSums (U);
	U_head = U_sum[1:n_free,1]; # the first num_groups-1 entries

	# quadratic form U' V^-1 U, compared against a Chi-squared distribution
	test_st = as.scalar (t(U_head) %*% inv(V_sum) %*% U_head);
	p_val = 1 - cdf (target = test_st, dist = "chisq", df = n_free);

	# 1 x 4 summary: no. of groups, degrees of freedom, statistic, p-value
	TEST_GROUPS_OE[1,1] = num_groups;
	TEST_GROUPS_OE[1,2] = n_free;
	TEST_GROUPS_OE[1,3] = test_st;
	TEST_GROUPS_OE[1,4] = p_val;

	# per-group values, each summed over the strata
	TEST[,1] = rowSums(n_event_all_global);
	TEST[,2] = rowSums(OBS);
	TEST[,3] = rowSums(EXP);
	TEST[,4] = rowSums(U_OE);
	# NOTE(review): the divisor is the sum over all entries of V_sum - confirm that
	# this is the intended V for the (O-E)^2/V column
	TEST[,5] = rowSums((U*U) /sum(V_sum));
	str = append (str, test_type + " test for " + num_groups + " groups: Chi-squared = " + test_st + " on " + (num_groups - 1) + " df, p = " + p_val + " ");	
}

# Keep the (group id, stratum id) pairs of the combinations present in the data
GSI = removeEmpty (target = GSI, margin = "rows");
n_present = nrow (GSI);

if (n_group_cols > 0) {
	# making a copy of unique groups before adding new rows depending on strata
	G_cols_original = G_cols;
	# one row of grouping-factor values per present (group, stratum) combination
	G_cols = table (seq (1, n_present), GSI[,1], n_present, nrow (G_cols)) %*% G_cols;
}

if (n_stratum_cols > 0) {
	# one row of stratifying-factor values per present (group, stratum) combination
	S_cols = table (seq (1, n_present), GSI[,2], n_present, nrow (S_cols)) %*% S_cols;
}

# pull out non-empty rows from M 
M_cols = removeEmpty (target = M_cols, margin = "rows");
n_kept = nrow (M_cols);
M = table (seq (1, n_kept), M_cols, n_kept, nrow (M)) %*% M;
M = replace (target = M, pattern = Inf, replacement = NaN);

# prepend the factor values: stratifying factors first, so that the grouping
# factors end up in the leading columns of M
if (n_stratum_cols > 0) {
	M = cbind (S_cols, M);
}
if (n_group_cols > 0) {
	M = cbind (G_cols, M);
	if (test_type != "none") {
		# label the per-group test results with the grouping-factor values
		TEST = cbind (G_cols_original, TEST);
	}
}

# pull out non-empty columns from KM
# KM_cols_select is appended to t(KM) as an extra column and used to zero out the
# columns of absent (group, stratum) combinations; after transposing back, the
# selector forms the last row of KM and is dropped again once the empty columns
# and rows have been removed.
KM_t = cbind (t (KM), KM_cols_select);
KM_t = KM_t * KM_cols_select;
KM = t (KM_t);
KM = removeEmpty (target = KM, margin = "cols");
KM = removeEmpty (target = KM, margin = "rows");
last_data_row = nrow (KM) - 1;
KM = KM[1:last_data_row,];

# write output matrices
write (M, fileM, format=fmtO);
write (KM, fileO, format=fmtO);

# test results go to fileT (and fileT + ".groups.oe") when a location was given,
# otherwise the summary string is printed
if (test_type != "none") {
	if (num_groups <= 1 | fileT == " ") { 
		print (str);
	} else {
		write (TEST, fileT, format=fmtO);
		write (TEST_GROUPS_OE, fileT+".groups.oe", format=fmtO);
	}	
}