blob: 283db7c0761664779a3cc6d1dc0ab78ebc8a5145 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Load the CSV (with header) as a frame and keep only the inner columns,
# i.e. drop the first and the last column, then convert to a numeric matrix.
# NOTE: inside DML indexing ':' is a range delimiter, so the upper bound is the
# whole expression ncol(X) - 1; the parentheses only make that explicit for
# readers used to R, where ':' binds tighter than '-'.
X = read($1, data_type = "frame", format = "csv", header=TRUE)
X = as.matrix(X[ , 2:(ncol(X) - 1)])
# Split rows into train (1-45, 52-95, 102-145 -> 133 rows) and
# test (46-51, 96-101, 146-150 -> 17 rows).
# Assumes the input has at least 150 data rows, presumably three contiguous
# 50-row classes (iris-like layout) -- TODO confirm against the test driver.
train = rbind(rbind(X[1:45, ], X[52:95, ]), X[102:145, ])
test = rbind(rbind(X[46:51, ], X[96:101, ]), X[146:150, ])
# Fit a Gaussian mixture model on the training rows. All hyper-parameters are
# taken from script arguments: $2 n_components, $3 model, $4 init_params,
# $5 iter, $6 reg_covar, $7 tol, $8 seed.
[labels, prob, df, bic, mu, prec_chol, w] = gmm(
  X = train, n_components = $2, model = $3, init_params = $4,
  iter = $5, reg_covar = $6, tol = $7, seed = $8, verbose = TRUE)
# Predict a component for every held-out row from the fitted parameters
# (w, mu, prec_chol); the same model argument $3 is passed as used for fitting.
[pred, pp] = gmmPredict(test, w, mu, prec_chol, $3)
# Evaluate the predictions by comparing cluster sizes.
# expected holds the class sizes of the test split. Both the predicted and the
# expected size vectors are sorted before comparison, so the score does not
# depend on which component index the GMM assigned to which class.
expected = matrix("6 6 5", 3, 1)
K = nrow(expected)
# Derive the dimensions from the data rather than hard-coding 17 (test rows)
# and 3 (clusters): a change to the split no longer breaks the broadcast below.
# Assumes pred is a column vector of component indices -- TODO confirm.
# resp[i, k] = 1 iff test row i was predicted as component k (for k in 1..K).
resp = matrix(1, nrow(pred), K) * t(seq(1, K))
resp = resp == pred
cluster = t(colSums(resp))
cluster = order(target = cluster, by = 1, decreasing = FALSE, index.return=FALSE)
correct_Predictions = order(target = expected, by = 1, decreasing = FALSE, index.return=FALSE)
# Mean absolute difference between the sorted predicted and expected cluster
# sizes. 0 means the sizes match up to relabeling; per-row agreement with the
# true classes is NOT checked by this metric.
error = mean(abs(correct_Predictions - cluster))
write(error, $9, format = "text")