# blob: 23f39b08145a7d3597267269bd76bed614b2de07 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
# Multiclass SVM trained one-vs-rest: one l2svm model per class, fitted inside
# a parfor loop.
#
# INPUT:
#   X              feature matrix
#   Y              labels; must not contain negative values. When Y has more
#                  than one column it is collapsed to a single column holding,
#                  per row, the largest (column index * cell value)
#                  (presumably Y is a 0/1 indicator matrix - confirm with callers)
#   intercept      forwarded to l2svm; adds one extra row to the model
#   epsilon        forwarded to l2svm
#   lambda         forwarded to l2svm
#   maxIterations  forwarded to l2svm
#   verbose        prints a start message and is forwarded to l2svm
#
# OUTPUT:
#   model          matrix with one column per class (max(Y) columns) and
#                  ncol(X) rows, plus one row if intercept is TRUE
msvm2 = function(Matrix[Double] X, Matrix[Double] Y, Boolean intercept = FALSE,
    Double epsilon = 0.001, Double lambda = 1.0, Integer maxIterations = 100, Boolean verbose = FALSE)
  return(Matrix[Double] model)
{
  if( min(Y) < 0 ) {
    stop("MSVM: Invalid Y input, containing negative values")
  }
  if( verbose ) {
    print("Running Multiclass-SVM")
  }

  # one weight per feature, plus one more when an intercept is requested
  numWeights = ncol(X)
  if( intercept )
    numWeights = numWeights + 1

  # collapse a multi-column label matrix into a single label column
  if( ncol(Y) > 1 ) {
    colIds = t(seq(1, ncol(Y)))
    Y = rowMaxs(Y * colIds)
  }

  # the number of classes is assumed to be the largest label contained in Y
  numClasses = max(Y)
  W = matrix(0, rows=numWeights, cols=numClasses)

  # one binary problem per class: +1 for rows of that class, -1 otherwise
  parfor(class in 1:numClasses, opt=CONSTRAINED, par=4, mode=REMOTE_SPARK) {
    Y_local = 2 * (Y == class) - 1
    W[,class] = l2svm(X=X, Y=Y_local, intercept=intercept,
      epsilon=epsilon, lambda=lambda, maxIterations=maxIterations,
      verbose=verbose, columnId=class)
  }

  model = W
}
# Driver: build a random dataset, train msvm2 on it and write out the model.
#   $1 = output path for the model, $2 = number of rows, $3 = number of columns
nclass = 10;
numRows = $2;
numCols = $3;

# fixed seeds keep the generated features and labels reproducible
X = rand(rows=numRows, cols=numCols, seed=1);
# uniform draws from [0, nclass] rounded up to whole-number labels
y = ceil(rand(rows=numRows, cols=1, min=0, max=nclass, seed=2));

model = msvm2(X=X, Y=y, intercept=FALSE);
write(model, $1);