# blob: 6e0cb516343d6397a58d3f8e14308213fd2adfd8 [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
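# This function trains a model with the multiLogReg-function, which solves Multinomial
# Logistic Regression using the Trust Region method, predicts the labels of the same
# input, and removes the rows that are misclassified with a probability above the threshold.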
#
# INPUT:
# -------------------------------------------------------------------------------------
# X matrix of feature vectors
# Y matrix with category labels
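# threshold probability threshold: misclassified rows whose maximum predicted probability
#           exceeds it are removed; if no row qualifies, X and Y are returned unmodified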
# verbose flag specifying if logging information should be printed
# -------------------------------------------------------------------------------------
#
# OUTPUT:
# -------------------------------------------------------------------------------------
# Xout X without the rows that were misclassified above the threshold
# Yout Y without the rows that were misclassified above the threshold
# -------------------------------------------------------------------------------------
m_abstain = function(Matrix[Double] X, Matrix[Double] Y, Double threshold, Boolean verbose = FALSE)
  return (Matrix[Double] Xout, Matrix[Double] Yout)
{
  # Default result: hand the inputs back untouched.
  Xout = X
  Yout = Y

  # Filter only when Y holds more than one distinct value and no label exceeds 2.
  yLow = min(Y)
  yHigh = max(Y)
  if(yLow != yHigh & yHigh <= 2)
  {
    # Fit the classifier on the full input, then score that same input.
    # The returned accuracy is bound because the call yields three outputs; it is not used below.
    model = multiLogReg(X=X, Y=Y, icpt=1, reg=1e-4, maxi=100, maxii=0, verbose=verbose)
    [prob, yhat, accuracy] = multiLogRegPredict(X, model, Y, FALSE)

    # Flag rows whose predicted label disagrees with Y while the largest
    # value in that row of prob exceeds the threshold.
    wrong = (yhat != Y)
    confident = (rowMaxs(prob) > threshold)
    flagged = (wrong & confident)

    if(sum(flagged) > 0)
    {
      # Keep every row that was not flagged; X and Y are filtered with the same selector.
      keep = (flagged == 0)
      Xout = removeEmpty(target = X, margin = "rows", select = keep)
      Yout = removeEmpty(target = Y, margin = "rows", select = keep)
    }
  }
}