| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| |
| # Test script: fit a Gaussian mixture model on a training subset, predict |
| # labels for a held-out subset with gmmPredict, and write the mean absolute |
| # difference between sorted predicted and sorted expected cluster sizes. |
| # |
| # Arguments: |
| #   $1      input CSV file (with header) |
| #   $2..$8  forwarded to gmm: n_components, model, init_params, iter, |
| #           reg_covar, tol, seed |
| #   $9      output file for the error value |
| |
| X = read($1, data_type = "frame", format = "csv", header=TRUE) |
| # Keep columns 2 .. ncol(X)-1. Both bounds of a DML index range are full |
| # expressions, so the upper bound is (ncol(X) - 1); parenthesized for clarity. |
| # NOTE(review): presumably this drops a leading id column and a trailing |
| # label column -- confirm against the input file. |
| X = X[ , 2:(ncol(X) - 1)] |
| X = as.matrix(X) |
| |
| # divide in train and test set |
| # train: rows 1-45, 52-95 and 102-145 (45 + 44 + 44 = 133 rows) |
| train = X[1:45,] |
| train = rbind(train, X[52:95,]) |
| train = rbind(train, X[102:145,]) |
| |
| # test: the remaining rows 46-51, 96-101 and 146-150 (6 + 6 + 5 = 17 rows) |
| test = X[46:51,] |
| test = rbind(test, X[96:101,]) |
| test = rbind(test, X[146:150,]) |
| |
| # train GMM |
| # Only w, mu and prec_chol are used below (as inputs to gmmPredict). |
| [labels, prob, df, bic, mu, prec_chol, w] = gmm(X=train, n_components = $2, |
| model = $3, init_params = $4, iter = $5, reg_covar = $6, tol = $7, seed=$8, verbose=TRUE) |
| |
| # predict labels |
| [pred, pp] = gmmPredict(test, w, mu, prec_chol, $3) |
| |
| # expected clusters/predictions |
| # Expected number of test rows per cluster, one entry per cluster. |
| expected = matrix("6 6 5", 3, 1) |
| |
| # Count how many test rows were assigned to each label 1..k. Every row of |
| # resp holds the candidate labels 1..k; comparing against pred marks the |
| # predicted label of each row, and the column sums give the cluster sizes. |
| # Dimensions are derived from test and expected instead of being hard-coded, |
| # so they cannot drift from the split or the expected vector above. |
| # NOTE(review): assumes pred is an nrow(test) x 1 vector of labels in 1..k. |
| resp = matrix(1, nrow(test), nrow(expected)) * t(seq(1, nrow(expected))) |
| resp = resp == pred |
| cluster = t(colSums(resp)) |
| |
| # Sort both size vectors ascending before comparing, so the result does not |
| # depend on which label the model assigned to which group of rows. |
| cluster = order(target = cluster, by = 1, decreasing = FALSE, index.return=FALSE) |
| correct_Predictions = order(target = expected, by = 1, decreasing = FALSE, index.return=FALSE) |
| |
| # Mean absolute difference of the sorted sizes; 0 means the sizes match. |
| error = mean(abs(correct_Predictions - cluster)) |
| write(error, $9, format = "text") |