| #------------------------------------------------------------- |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| #------------------------------------------------------------- |
| |
| # This script applies a trained multi-class SVM (MSVM) model to a matrix of feature vectors. |
| # |
| # INPUT: |
| #------------------------------------------------------------------------------ |
| # X matrix X of feature vectors to classify |
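| # W matrix of the trained variables; must have either ncol(X) rows, |
| # or ncol(X)+1 rows, in which case the last row is added as the intercept |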
| #------------------------------------------------------------------------------ |
| # |
| # OUTPUT: |
| #------------------------------------------------------------------------------- |
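| # YRaw Raw classification scores (X %*% W, plus the intercept row if present), |
| # one column per column of W (i.e. per class); not mapped to class labels |
| # Y Predicted class labels: for each row, the index of the column of YRaw with the maximum score |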
| #------------------------------------------------------------------------------- |
| |
| m_msvmPredict = function(Matrix[Double] X, Matrix[Double] W) |
|   return(Matrix[Double] YRaw, Matrix[Double] Y) |
| { |
|   # Applies the trained weights W to the feature matrix X. |
|   #   YRaw: raw scores, X %*% W (plus the intercept row of W when present) |
|   #   Y:    per row of YRaw, the index of the column holding the maximum score |
|   # Stops with an error if nrow(W) is neither ncol(X) nor ncol(X)+1. |
| |
|   # Robustness for datasets with missing values |
|   numNaNs = sum(isNaN(X)) |
|   if( numNaNs > 0 ) { |
|     print("msvm: matrix X contains "+numNaNs+" missing values, replacing with 0.") |
|     X = replace(target=X, pattern=NaN, replacement=0); |
|   } |
| |
|   numFeatures = ncol(X) |
|   # W may carry one extra trailing row that holds the intercept |
|   hasIntercept = (nrow(W) == numFeatures + 1) |
| |
|   # Validate shapes up front; dimensions are reported as [rows,cols] |
|   if( nrow(W) != numFeatures & !hasIntercept ) { |
|     stop("MSVM Predict: Invalid shape of W ["+nrow(W)+","+ncol(W)+"] or X ["+nrow(X)+","+ncol(X)+"]") |
|   } |
| |
|   if( hasIntercept ) { |
|     # the first numFeatures rows are the weights, the last row is added to every score row |
|     YRaw = X %*% W[1:numFeatures,] + W[numFeatures+1,] |
|   } |
|   else { |
|     YRaw = X %*% W |
|   } |
| |
|   # predicted label = index of the column with the maximum raw score |
|   Y = rowIndexMax(YRaw) |
| } |