blob: c4f5113b796f045081fcc9c0f6af4ca7558eb003 [file] [log] [blame]
#include "dbconnector/dbconnector.hpp"
#include "elastic_net_binomial_fista.hpp"
#include "state/fista.hpp"
#include "elastic_net_optimizer_fista.hpp"
#include "share/shared_utils.hpp"
namespace madlib {
namespace modules {
namespace elastic_net {
/*
 Model-specific hooks for the binomial (logistic regression) elastic-net model
 solved with FISTA. The class has only static members; it is used as the
 policy type of Fista<BinomialFista> in the SQL-facing functions below.

 Labels are encoded as +1 / -1 (see get_y). For a row x with linear predictor
 r = intercept + <coef, x>, the per-row loss is the logistic loss
 log(1 + exp(-y * r)).
*/
class BinomialFista
{
public:
    // Zero coef, coef_y, intercept and intercept_y.
    static void initialize(FistaState<MutableArrayHandle<double> >& state);
    // Read the boolean dependent variable (args[2]) and map it to +1 / -1.
    static void get_y(double& y, AnyType& args);
    // Per-row transition: accumulate the loss gradient at (coef_y, intercept_y),
    // or the backtracking objective values while state.backtracking != 0.
    static void normal_transition(FistaState<MutableArrayHandle<double> >& state,
            MappedColumnVector& x, double y);
    // Like normal_transition, but gradient entries are only accumulated for
    // coordinates where coef_y is non-zero (the intercept is always updated).
    static void active_transition(FistaState<MutableArrayHandle<double> >& state,
            MappedColumnVector& x, double y);
    // Gradient step: b_intercept = intercept_y - stepsize * gradient_intercept.
    static void update_b_intercept(FistaState<MutableArrayHandle<double> >& state);
    // Add the logistic loss of one row, evaluated at (coef, intercept), to
    // state.loglikelihood.
    static void update_loglikelihood(FistaState<MutableArrayHandle<double> >& state,
            MappedColumnVector& x, double y);
    // FISTA extrapolation of the intercept from the previous step parameter
    // old_tk and the current one state.tk.
    static void update_y_intercept(FistaState<MutableArrayHandle<double> >& state,
            double old_tk);
    // Divide the accumulated intercept gradient by state.totalRows.
    static void update_y_intercept_final(FistaState<MutableArrayHandle<double> >& state);
    // Merge step: add state2's intercept gradient into state1.
    static void merge_intercept(FistaState<MutableArrayHandle<double> >& state1,
            FistaState<ArrayHandle<double> >& state2);
private:
    // Accumulate fn (loss at b_coef / b_intercept) and, when
    // state.backtracking == 1, Qfn (loss at coef_y / intercept_y).
    static void backtracking_transition(FistaState<MutableArrayHandle<double> >& state,
            MappedColumnVector& x, double y);
};
// ------------------------------------------------------------------------
/**
 @brief Average the accumulated intercept gradient over all rows

 The transition steps sum per-row contributions into gradient_intercept;
 this divides that sum by state.totalRows.
*/
inline void BinomialFista::update_y_intercept_final(FistaState<MutableArrayHandle<double> >& state)
{
    const double n_rows = static_cast<double>(state.totalRows);
    state.gradient_intercept = state.gradient_intercept / n_rows;
}
// -----------------------------------------------------------------------------
/**
@brief Accumulate the logistic loss of one data point in binomial models

With linear predictor r = intercept + <coef, x>, the contribution is
log(1 + exp(-r)) for a positive label (y > 0) and log(1 + exp(r)) otherwise,
i.e. the negative log-likelihood of the point. It is added to
state.loglikelihood.

The softplus log(1 + exp(z)) is evaluated as z + log(1 + exp(-z)) when z > 0,
so std::exp never receives a large positive argument. The naive form
overflows to +inf once z exceeds roughly 709 (e.g. a confidently
misclassified point), which would make the accumulated value infinite.
*/
inline void BinomialFista::update_loglikelihood(
        FistaState<MutableArrayHandle<double> >& state,
        MappedColumnVector& x, double y) {
    double r = state.intercept + sparse_dot(state.coef, x);
    // The loss is log(1 + exp(z)) with z = -r for positive labels, z = r otherwise.
    double z = (y > 0) ? -r : r;
    if (z > 0)
        state.loglikelihood += z + std::log(1. + std::exp(-z));
    else
        state.loglikelihood += std::log(1. + std::exp(z));
}
// ------------------------------------------------------------------------
/**
 @brief Merge the intercept part of two partial states

 Adds the intercept gradient accumulated in state2 into state1.
*/
inline void BinomialFista::merge_intercept(
        FistaState<MutableArrayHandle<double> >& state1,
        FistaState<ArrayHandle<double> >& state2)
{
    state1.gradient_intercept = state1.gradient_intercept + state2.gradient_intercept;
}
// ------------------------------------------------------------------------
/**
 @brief Reset the model parameters before the first iteration

 Both the current point (coef, intercept) and the extrapolated point
 (coef_y, intercept_y) start at zero.
*/
inline void BinomialFista::initialize(FistaState<MutableArrayHandle<double> >& state)
{
    state.intercept = 0;
    state.intercept_y = 0;

    state.coef.setZero();
    state.coef_y.setZero();
}
// ------------------------------------------------------------------------
// Read the dependent variable (the third SQL argument, a boolean) and encode
// it as +1 for true and -1 for false.
inline void BinomialFista::get_y(double& y, AnyType& args)
{
    const bool label = args[2].getAs<bool>();
    if (label)
        y = 1.;
    else
        y = -1.;
}
// ------------------------------------------------------------------------
/**
 @brief Per-row transition over all coordinates

 Outside a line search, adds this row's contribution to the gradient of the
 logistic loss, evaluated at (coef_y, intercept_y). While state.backtracking
 is non-zero, only objective values are accumulated (backtracking_transition).
*/
inline void BinomialFista::normal_transition(FistaState<MutableArrayHandle<double> >& state,
        MappedColumnVector& x, double y)
{
    if (state.backtracking != 0) {
        backtracking_transition(state, x, y);
        return;
    }

    // Derivative of the logistic loss with respect to the linear predictor.
    const double margin = state.intercept_y + sparse_dot(state.coef_y, x);
    const double dloss = (y > 0) ? -1. / (1. + std::exp(margin))
                                 : 1. / (1. + std::exp(-margin));

    // Chain rule: d(loss)/d(coef_j) = x_j * dloss, d(loss)/d(intercept) = dloss.
    for (uint32_t j = 0; j < state.dimension; ++j)
        state.gradient(j) += x(j) * dloss;
    state.gradient_intercept += dloss;
}
// ------------------------------------------------------------------------
/**
 @brief Per-row transition restricted to the active set

 Same as normal_transition, except that gradient entries are only
 accumulated for coordinates where coef_y is non-zero. The intercept
 gradient is always updated.
*/
inline void BinomialFista::active_transition(FistaState<MutableArrayHandle<double> >& state,
        MappedColumnVector& x, double y)
{
    if (state.backtracking != 0) {
        backtracking_transition(state, x, y);
        return;
    }

    // Derivative of the logistic loss with respect to the linear predictor.
    const double margin = state.intercept_y + sparse_dot(state.coef_y, x);
    const double dloss = (y > 0) ? -1. / (1. + std::exp(margin))
                                 : 1. / (1. + std::exp(-margin));

    for (uint32_t j = 0; j < state.dimension; ++j) {
        if (state.coef_y(j) == 0)
            continue;  // coordinate not in the active set
        state.gradient(j) += x(j) * dloss;
    }
    // The intercept is never dropped from the active set.
    state.gradient_intercept += dloss;
}
// ------------------------------------------------------------------------
/**
 @brief Per-row transition during a backtracking line search

 Accumulates the logistic loss of this row at the candidate point
 (b_coef, b_intercept) into state.fn. When state.backtracking == 1 it also
 accumulates the loss at (coef_y, intercept_y) into state.Qfn.

 The loss log(1 + exp(z)) is evaluated as z + log(1 + exp(-z)) for z > 0, so
 std::exp never sees a large positive argument. The naive form overflows to
 +inf once z exceeds roughly 709, which would make fn / Qfn infinite and
 break the backtracking comparison.
*/
inline void BinomialFista::backtracking_transition(FistaState<MutableArrayHandle<double> >& state,
        MappedColumnVector& x, double y)
{
    // During backtracking the candidate point is always b_coef / b_intercept.
    double r = state.b_intercept + sparse_dot(state.b_coef, x);
    // The loss is log(1 + exp(z)) with z = -r for positive labels, z = r otherwise.
    double z = (y > 0) ? -r : r;
    if (z > 0)
        state.fn += z + std::log(1. + std::exp(-z));
    else
        state.fn += std::log(1. + std::exp(z));

    // Qfn is evaluated at (coef_y, intercept_y), which the line search does
    // not move, so it is presumably only needed on the first backtracking
    // pass (state.backtracking == 1).
    if (state.backtracking == 1)
    {
        r = state.intercept_y + sparse_dot(state.coef_y, x);
        z = (y > 0) ? -r : r;
        if (z > 0)
            state.Qfn += z + std::log(1. + std::exp(-z));
        else
            state.Qfn += std::log(1. + std::exp(z));
    }
}
// ------------------------------------------------------------------------
/**
 @brief Candidate intercept for the current step size

 A plain gradient step from intercept_y; no shrinkage or thresholding is
 applied to the intercept here.
*/
inline void BinomialFista::update_b_intercept(FistaState<MutableArrayHandle<double> >& state)
{
    const double step = state.stepsize * state.gradient_intercept;
    state.b_intercept = state.intercept_y - step;
}
// ------------------------------------------------------------------------
/**
 @brief FISTA extrapolation of the intercept

 intercept_y = b_intercept + (old_tk - 1) * (b_intercept - intercept) / tk,
 where b_intercept is the newly accepted intercept, intercept the previous
 one, old_tk the previous step parameter and state.tk the current one.
*/
inline void BinomialFista::update_y_intercept(FistaState<MutableArrayHandle<double> >& state,
        double old_tk)
{
    const double delta = state.b_intercept - state.intercept;
    state.intercept_y = state.b_intercept + (old_tk - 1) * delta / state.tk;
}
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
// ------------------------------------------------------------------------
/*
The following are the functions that are actually called by SQL
*/
/**
 @brief Aggregate transition step of the binomial FISTA solver

 Called once per (x, y) tuple; delegates to Fista<BinomialFista>.
*/
AnyType
binomial_fista_transition::run(AnyType& args)
{
    typedef Fista<BinomialFista> BinomialSolver;
    return BinomialSolver::fista_transition(args, *this);
}
// ------------------------------------------------------------------------
/**
 @brief Aggregate merge step: combine two partial binomial FISTA states
*/
AnyType
binomial_fista_merge::run(AnyType& args)
{
    typedef Fista<BinomialFista> BinomialSolver;
    return BinomialSolver::fista_merge(args);
}
// ------------------------------------------------------------------------
/**
 @brief Aggregate final step: finish one binomial FISTA iteration
*/
AnyType
binomial_fista_final::run(AnyType& args)
{
    typedef Fista<BinomialFista> BinomialSolver;
    return BinomialSolver::fista_final(args);
}
// ------------------------------------------------------------------------
/**
 * @brief Return the difference between two binomial FISTA states
 *
 * Computed by Fista<BinomialFista>::fista_state_diff. The previous comment
 * described this as a difference in RMSE; confirm against fista_state_diff.
 */
AnyType
__binomial_fista_state_diff::run(AnyType& args)
{
    typedef Fista<BinomialFista> BinomialSolver;
    return BinomialSolver::fista_state_diff(args);
}
// ------------------------------------------------------------------------
/**
 * @brief Return the coefficients and diagnostic statistics of a binomial
 * FISTA state (delegates to Fista<BinomialFista>::fista_result)
 */
AnyType
__binomial_fista_result::run(AnyType& args)
{
    typedef Fista<BinomialFista> BinomialSolver;
    return BinomialSolver::fista_result(args);
}
}
}
}