blob: 8e7f0980dfe45449440eabd679352ef4e0c2654e [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# This function splits the input data X and Y into contiguous or sampled train/test sets
#
# INPUT:
# ------------------------------------------------------------------------------------------
# X Input feature matrix
# Y Input Labels
# f Train set fraction, strictly between 0 and 1 (exclusive)
# cont contiguous splits, otherwise sampled
# seed The seed to randomly select rows in sampled mode
# ------------------------------------------------------------------------------------------
#
# OUTPUT:
# --------------------------------------------------------------------------------------------
# X_Train Train split of feature matrix
# X_Test Test split of feature matrix
# Y_Train Train split of label matrix
# Y_Test Test split of label matrix
# --------------------------------------------------------------------------------------------
m_split = function(Matrix[Double] X, Matrix[Double] Y, Double f=0.7, Boolean cont=TRUE, Integer seed=-1)
return (Matrix[Double] X_Train, Matrix[Double] X_Test, Matrix[Double] Y_Train, Matrix[Double] Y_Test)
{
  # Partitions the rows of X and Y into a train part and a test part.
  # The same row selection is applied to X and Y, so features and labels
  # stay aligned in both resulting splits.
  numRows = nrow(X);

  # argument validation: the fraction must lie strictly between 0 and 1,
  # and X and Y must have the same number of rows
  if( f >= 1 | f <= 0 ) {
    stop("Invalid train/test split configuration: f="+f);
  }
  if( numRows != nrow(Y) ) {
    stop("Mismatching number of rows X and Y: "+numRows+" "+nrow(Y));
  }

  if( !cont ) {
    # sampled mode: draw one random number per row (using the given seed)
    # and mark every row whose draw is <= f as a train row
    trainMask = rand(rows=numRows, cols=1, seed=seed) <= f;
    # the complement of the train mask marks the test rows
    testMask = (trainMask == 0);

    # pull out the marked rows via removeEmpty with a selection vector
    X_Train = removeEmpty(target=X, margin="rows", select=trainMask);
    X_Test = removeEmpty(target=X, margin="rows", select=testMask);
    Y_Train = removeEmpty(target=Y, margin="rows", select=trainMask);
    Y_Test = removeEmpty(target=Y, margin="rows", select=testMask);
  }
  else {
    # contiguous mode: the leading block of rows becomes the train set.
    # NOTE(review): f*numRows is generally fractional and is passed as-is
    # as the upper row index; how it is converted to an integer row is left
    # to the indexing operator - confirm the rounding behavior if it matters.
    lastTrainRow = f * numRows;
    X_Train = X[1:lastTrainRow,];
    Y_Train = Y[1:lastTrainRow,];

    # the test set starts right after the rows that actually landed in the
    # train set and runs to the last row
    firstTestRow = nrow(X_Train) + 1;
    X_Test = X[firstTestRow:numRows,];
    Y_Test = Y[firstTestRow:numRows,];
  }
}