blob: c5c1066d6565c8cd51a24d013119bf3f5a108801 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Split input data X and Y into contiguous or sampled train/test sets
# ------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ------------------------------------------------------------------------------
# X Matrix --- Input feature matrix
# Y Matrix --- Input Labels
# f Double 0.7 Train set fraction, strictly between 0 and 1 (both bounds are rejected)
# cont Boolean TRUE Contiguous splits, otherwise sampled
# seed Integer -1 The seed to randomly select rows in sampled mode
# ------------------------------------------------------------------------------
# Xtrain Matrix --- Train split of feature matrix
# Xtest Matrix --- Test split of feature matrix
# Ytrain Matrix --- Train split of label matrix
# Ytest Matrix --- Test split of label matrix
# ------------------------------------------------------------------------------
m_split = function(Matrix[Double] X, Matrix[Double] Y, Double f=0.7, Boolean cont=TRUE, Integer seed=-1)
  return (Matrix[Double] Xtrain, Matrix[Double] Xtest, Matrix[Double] Ytrain, Matrix[Double] Ytest)
{
  # Argument validation: the train fraction must lie strictly between 0 and 1,
  # and features and labels must describe the same number of observations.
  if( f <= 0 | f >= 1 ) {
    stop("Invalid train/test split configuration: f="+f);
  }
  if( nrow(X) != nrow(Y) ) {
    stop("Mismatching number of rows X and Y: "+nrow(X)+" "+nrow(Y));
  }

  if( !cont ) {
    # Sampled split: draw one random number per row and flag the row for the
    # train set when the draw is <= f. The flags act as a selection vector for
    # removeEmpty; the complementary flags pick the test rows, so X and Y are
    # always partitioned by the very same mask.
    trainMask = rand(rows=nrow(X), cols=1, seed=seed) <= f;
    testMask = (trainMask == 0);
    Xtrain = removeEmpty(target=X, margin="rows", select=trainMask);
    Ytrain = removeEmpty(target=Y, margin="rows", select=trainMask);
    Xtest = removeEmpty(target=X, margin="rows", select=testMask);
    Ytest = removeEmpty(target=Y, margin="rows", select=testMask);
  }
  else {
    # Contiguous split: the leading rows up to f*nrow(X) form the train set.
    # The bound may be fractional; its conversion to a row index is left to
    # the indexing operation itself.
    lastTrain = f * nrow(X);
    Xtrain = X[1:lastTrain,];
    Ytrain = Y[1:lastTrain,];

    # The test set starts right after the rows that actually ended up in
    # Xtrain, so train and test together cover every input row exactly once.
    firstTest = nrow(Xtrain) + 1;
    Xtest = X[firstTest:nrow(X),];
    Ytest = Y[firstTest:nrow(X),];
  }
}