# blob: f875b3de92771749f354e9a0475b625728ac21a2 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Trains a decision tree or a random forest on a transform-encoded CSV frame
# and reports the accuracy of the model on that same data.
#
# Positional arguments:
#   $1  path of the input frame (CSV, no header line)
#   $2  path of the transform specification, read as a string scalar
#   $3  model selector: 1 trains a single decision tree; any other value
#       trains a random forest with ($3 - 1) trees
#   $4  forwarded as max_values to the training builtin
#   $5  output path for the accuracy (written as a 1x1 matrix)

F = read($1, data_type="frame", format="csv", header=FALSE);
tfspec = read($2, data_type="scalar", value_type="string");

# 1x13 row vector handed to the builtins as ctypes.
# NOTE(review): presumably one type code per feature column plus a final entry
# for the label (1 = numeric, 2 = categorical) -- confirm against the builtin docs.
R = matrix("1 1 1 1 1 1 1 1 1 1 1 1 2", rows=1, cols=13)

[X, meta] = transformencode(target=F, spec=tfspec);

# The last encoded column is the label; all preceding columns are features.
Y = X[,ncol(X)];
X = X[,1:(ncol(X)-1)];

# Replace every NaN feature entry with the constant 5 (original note: "1 val").
X = replace(target=X, pattern=NaN, replacement=5);

if( $3==1 ) {
  M = decisionTree(X=X, y=Y, ctypes=R, max_features=1, max_values=$4,
    min_split=10, min_leaf=4, seed=7, verbose=TRUE);
  yhat = decisionTreePredict(X=X, y=Y, ctypes=R, M=M)
}
else {
  # Each tree is given a 1/numTrees fraction via sample_frac.
  numTrees = $3-1;
  sf = 1.0/numTrees;
  M = randomForest(X=X, y=Y, ctypes=R, sample_frac=sf, num_trees=numTrees,
    max_features=1, max_values=$4,
    min_split=10, min_leaf=4, seed=7, verbose=TRUE);
  yhat = randomForestPredict(X=X, y=Y, ctypes=R, M=M)
}

# Fraction of rows whose prediction equals the label, wrapped as a matrix
# so it can be passed to write().
acc = as.matrix(mean(yhat == Y))
print("accuracy: "+as.scalar(acc))

write(acc, $5);