#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Weighted-ensemble search workload.
# Trains an L2-SVM and a multinomial logistic regression model on synthetic
# data, then tries `nweights` random (svm, mlr) weight pairs, combining the two
# models' per-class scores and keeping the best training accuracy. The best
# accuracy (in percent) is written to the path passed as $1.

M = 10000;       # number of rows in X
N = 200;         # number of columns in X
sp = 1.0;        # sparsity of X (author's inline note: 1.0)
nweights = 10;   # number of candidate weight pairs (author's inline note: 3000)
nclass = 2;      # number of label classes; width of the vote accumulator
k = 2;           # number of accumulation rounds per candidate

# Synthetic inputs. Labels are drawn from [0, 2] and rounded up, so they fall
# in {1, 2} (0 only if a draw is exactly 0).
X = rand(rows=M, cols=N, sparsity=sp, seed=42);
y = rand(rows=M, cols=1, min=0, max=2, seed=42);
y = ceil(y);

# Train both base models on the full data set.
model_svm = l2svm(X=X, Y=y, intercept=TRUE, epsilon=1e-12,
  reg=0.001, maxIterations=20, verbose=FALSE);
model_mlr = multiLogReg(X=X, Y=y, icpt=2, tol=1e-6, reg=0.001, maxi=20, maxii=20, verbose=FALSE);

# Candidate weights: row 1 scales the SVM scores, row 2 the MLR probabilities;
# each column is one candidate pair.
weights = rand(rows=2, cols=nweights, min=0, max=1, seed=42);

# Best result so far. bestWeights is initialised up front so that it is
# defined even when no candidate beats an accuracy of 0.
bestAcc = 0;
bestWeights = matrix(0, rows=2, cols=1);

for (wi in 1:nweights) {
  w_svm = as.scalar(weights[1, wi]);
  w_mlr = as.scalar(weights[2, wi]);
  weightedClassProb = matrix(0, rows=M, cols=nclass);
  for (i in 1:k) {
    # NOTE(review): both predict calls are loop-invariant (same X and models in
    # every iteration). They are kept inside the loops because this redundancy
    # may be what the workload is meant to exercise -- confirm before hoisting.
    [yRaw, yPred] = l2svmPredict(X=X, W=model_svm, verbose=FALSE);
    # NOTE(review): if yRaw has a single column, yRaw / rowSums(yRaw) is 1 for
    # every non-zero entry and carries no class information -- confirm the
    # shape returned by l2svmPredict.
    probs_svm = yRaw / rowSums(yRaw);
    # The accuracy returned here is the MLR model's own; it gets its own name
    # so it is not confused with the ensemble accuracy computed below.
    [prob_mlr, Y_mlr, acc_mlr] = multiLogRegPredict(X=X, B=model_mlr, Y=y, verbose=FALSE);

    # Accumulate the weighted scores and vote for the highest-scoring class.
    weightedClassProb = weightedClassProb + w_svm * probs_svm + w_mlr * prob_mlr;
    y_voted = rowIndexMax(weightedClassProb);
    acc = sum(y_voted == y) / M * 100;

    if (acc > bestAcc) {
      # Record only the winning candidate column, not the whole weight matrix.
      bestWeights = weights[, wi];
      bestAcc = acc;
    }
  }
}

R = bestAcc;
write(R, $1, format="text");
| |