blob: 63dfa5ef61abd3cd95c3dc83f01db4649eae0303 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Implements builtin multiclass SVM with squared slack variables,
# learns one-against-the-rest binary-class classifiers by making a function call to l2SVM
# INPUT PARAMETERS:
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# X Double --- matrix X of feature vectors
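# Y Double --- matrix Y of class labels: a single column of non-negative labels, or a multi-column matrix that is collapsed to labels (row-wise max of entry times 1-based column index)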
# intercept Boolean False No intercept (if set to TRUE, a constant bias column is added to X)
# NOTE: the number of classes is not an input parameter; it is inferred as max(Y)
# epsilon Double 0.001 Procedure terminates early if the reduction in objective function
# value is less than epsilon (tolerance) times the initial objective function value.
# lambda Double 1.0 Regularization parameter (lambda) for L2 regularization
# maxIterations Int 100 Maximum number of conjugate gradient iterations
# verbose Boolean False Set to true to print while training.
# ---------------------------------------------------------------------------------------------
#Output(s)
# ---------------------------------------------------------------------------------------------
# NAME TYPE DEFAULT MEANING
# ---------------------------------------------------------------------------------------------
# model Double --- model matrix of weights, one column per class; ncol(X) rows, plus one row if intercept is TRUE
m_msvm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE,
    Double epsilon = 0.001, Double lambda = 1.0, Integer maxIterations = 100, Boolean verbose = FALSE)
  return(Matrix[Double] model)
{
  # Trains one binary l2svm classifier per class (one-against-the-rest) and
  # returns the per-class weight vectors side by side as the columns of model.

  # Reject negative entries in Y before doing any work.
  if(min(Y) < 0) {
    stop("MSVM: Invalid Y input, containing negative values")
  }

  if(verbose) {
    print("Running Multiclass-SVM")
  }

  # A multi-column Y is collapsed into a single label column: every entry is
  # scaled by its 1-based column index and the row-wise maximum is kept.
  if(ncol(Y) > 1) {
    Y = rowMaxs(Y * t(seq(1, ncol(Y))))
  }

  # One weight per feature column of X, plus one extra row when an intercept
  # is requested.
  if(intercept) {
    num_weights = ncol(X) + 1
  }
  else {
    num_weights = ncol(X)
  }

  # Assuming number of classes to be max contained in Y
  num_classes = max(Y)
  weights = matrix(0, rows=num_weights, cols=num_classes)

  # Each iteration writes only its own column of weights.
  parfor(k in 1:num_classes) {
    # Recode labels for the binary sub-problem: +1 for class k, -1 otherwise.
    Y_binary = 2 * (Y == k) - 1
    weights[,k] = l2svm(X=X, Y=Y_binary, intercept=intercept,
        epsilon=epsilon, lambda=lambda, maxIterations=maxIterations,
        verbose=verbose, columnId=k)
  }

  model = weights
}