#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# This builtin function implements a multi-class Support Vector Machine (SVM)
# with squared slack variables. The trained model comprises one one-against-the-rest
# binary l2svm classification model per class (i.e., max(Y) model columns).
#
# INPUT:
#-------------------------------------------------------------------------------
# X              Feature matrix X (shape: m x n)
# Y              Label vector y of class labels (shape: m x 1),
#                where max(Y) is assumed to be the number of classes
# intercept      Indicator if a bias column should be added to X and the model
# epsilon        Tolerance for early termination if the reduction of objective
#                function is less than epsilon times the initial objective
# reg            Regularization parameter (lambda) for L2 regularization
# maxIterations  Maximum number of conjugate gradient (outer l2svm) iterations
# verbose        Indicator if training details should be printed
# ------------------------------------------------------------------------------
#
# OUTPUT:
#-------------------------------------------------------------------------------
# model          Trained model/weights (shape: n x max(Y); n+1 x max(Y) if intercept)
#-------------------------------------------------------------------------------
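#
# A minimal usage sketch (hypothetical data; assumes this builtin is invoked as
# msvm() from a user-level DML script, with class labels encoded as integers 1..k):
#
#   X = rand(rows=100, cols=10)
#   Y = round(rand(rows=100, cols=1, min=1, max=4))  # hypothetical labels in {1,...,4}
#   model = msvm(X=X, Y=Y, intercept=TRUE, epsilon=0.001,
#     reg=1.0, maxIterations=100, verbose=TRUE)
#
# A scoring sketch for applying the trained one-against-the-rest models follows
# the function body below.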
m_msvm = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE,
    Double epsilon = 0.001, Double reg = 1.0, Integer maxIterations = 100,
    Boolean verbose = FALSE)
  return(Matrix[Double] model)
{
  if(min(Y) < 0)
    stop("MSVM: Invalid Y input, containing negative values")
  if(verbose)
    print("Running Multiclass-SVM")

  # Robustness for datasets with missing values (causing NaN gradients)
  numNaNs = sum(isNaN(X))
  if( numNaNs > 0 ) {
    print("msvm: matrix X contains "+numNaNs+" missing values, replacing with 0.")
    X = replace(target=X, pattern=NaN, replacement=0);
  }

  # append the bias column once, and call l2svm always with intercept=FALSE
  if(intercept) {
    ones = matrix(1, rows=nrow(X), cols=1)
    X = cbind(X, ones);
  }

  # convert one-hot encoded labels into a single column of class indexes
  if(ncol(Y) > 1)
    Y = rowIndexMax(Y)

  # assuming the number of classes to be the max value contained in Y
  w = matrix(0, rows=ncol(X), cols=max(Y))
  parfor(class in 1:max(Y)) {
    # extract the class' binary labels and convert to -1/+1
    Y_local = 2 * (Y == class) - 1
    # train the l2svm model, w/ robustness for non-existing classes
    nnzY = sum(Y == class);
    if( nnzY > 0 ) {
      w[,class] = l2svm(X=X, Y=Y_local, intercept=FALSE,
        epsilon=epsilon, reg=reg, maxIterations=maxIterations,
        verbose=verbose, columnId=class)
    }
    else {
      w[,class] = matrix(-Inf, ncol(X), 1);
    }
  }
  model = w
}
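
# A hypothetical scoring sketch for applying the trained one-against-the-rest
# models: each column of the returned model is one binary l2svm, so per-class
# scores reduce to a single matrix multiplication and the predicted class is the
# column with the largest score (if the model was trained with intercept=TRUE,
# append a ones column to X first):
#
#   scores = X %*% model
#   predictions = rowIndexMax(scores)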