| |
| #include "dbconnector/dbconnector.hpp" |
| #include "elastic_net_binomial_fista.hpp" |
| #include "state/fista.hpp" |
| #include "elastic_net_optimizer_fista.hpp" |
| #include "share/shared_utils.hpp" |
| |
| namespace madlib { |
| namespace modules { |
| namespace elastic_net { |
| |
| /* |
| This class contains specific methods needed by Gaussian model using FISTA |
| */ |
| class BinomialFista |
| { |
| public: |
| static void initialize(FistaState<MutableArrayHandle<double> >& state); |
| static void get_y(double& y, AnyType& args); |
| static void normal_transition(FistaState<MutableArrayHandle<double> >& state, |
| MappedColumnVector& x, double y); |
| static void active_transition(FistaState<MutableArrayHandle<double> >& state, |
| MappedColumnVector& x, double y); |
| |
| static void update_b_intercept(FistaState<MutableArrayHandle<double> >& state); |
| static void update_loglikelihood(FistaState<MutableArrayHandle<double> >& state, |
| MappedColumnVector& x, double y); |
| |
| static void update_y_intercept(FistaState<MutableArrayHandle<double> >& state, |
| double old_tk); |
| static void update_y_intercept_final(FistaState<MutableArrayHandle<double> >& state); |
| |
| static void merge_intercept(FistaState<MutableArrayHandle<double> >& state1, |
| FistaState<ArrayHandle<double> >& state2); |
| private: |
| static void backtracking_transition(FistaState<MutableArrayHandle<double> >& state, |
| MappedColumnVector& x, double y); |
| }; |
| |
| // ------------------------------------------------------------------------ |
| |
| inline void BinomialFista::update_y_intercept_final(FistaState<MutableArrayHandle<double> >& state) |
| { |
| state.gradient_intercept = state.gradient_intercept / static_cast<double>(state.totalRows); |
| } |
| // ----------------------------------------------------------------------------- |
| |
| /** |
| @brief Compute log-likelihood for one data point in binomial models |
| */ |
| inline void BinomialFista::update_loglikelihood( |
| FistaState<MutableArrayHandle<double> >& state, |
| MappedColumnVector& x, double y) { |
| |
| double r = state.intercept + sparse_dot(state.coef, x); |
| if (y > 0) |
| state.loglikelihood += std::log(1 + std::exp(-r)); |
| else |
| state.loglikelihood += std::log(1 + std::exp(r)); |
| } |
| |
| // ------------------------------------------------------------------------ |
| |
| inline void BinomialFista::merge_intercept( |
| FistaState<MutableArrayHandle<double> >& state1, |
| FistaState<ArrayHandle<double> >& state2) { |
| state1.gradient_intercept += state2.gradient_intercept; |
| } |
| |
| // ------------------------------------------------------------------------ |
| |
| inline void BinomialFista::initialize(FistaState<MutableArrayHandle<double> >& state) |
| { |
| state.coef.setZero(); |
| state.coef_y.setZero(); |
| state.intercept = 0; |
| state.intercept_y = 0; |
| } |
| |
| // ------------------------------------------------------------------------ |
| // extract dependent variable from args |
| inline void BinomialFista::get_y(double& y, AnyType& args) |
| { |
| y = args[2].getAs<bool>() ? 1. : -1.; |
| } |
| |
| // ------------------------------------------------------------------------ |
| |
| inline void BinomialFista::normal_transition(FistaState<MutableArrayHandle<double> >& state, |
| MappedColumnVector& x, double y) |
| { |
| if (state.backtracking == 0) |
| { |
| double r = state.intercept_y + sparse_dot(state.coef_y, x); |
| double u; |
| |
| if (y > 0) |
| u = - 1. / (1. + std::exp(r)); |
| else |
| u = 1. / (1. + std::exp(-r)); |
| |
| for (uint32_t i = 0; i < state.dimension; i++) |
| state.gradient(i) += x(i) * u; |
| |
| // update gradient |
| state.gradient_intercept += u; |
| } |
| else |
| backtracking_transition(state, x, y); |
| } |
| |
| // ------------------------------------------------------------------------ |
| |
| inline void BinomialFista::active_transition(FistaState<MutableArrayHandle<double> >& state, |
| MappedColumnVector& x, double y) |
| { |
| if (state.backtracking == 0) // Compute gradient for active set |
| { |
| double r = state.intercept_y + sparse_dot(state.coef_y, x); |
| double u; |
| |
| if (y > 0) |
| u = - 1. / (1. + std::exp(r)); |
| else |
| u = 1. / (1. + std::exp(-r)); |
| |
| for (uint32_t i = 0; i < state.dimension; i++) |
| if (state.coef_y(i) != 0) |
| state.gradient(i) += x(i) * u; |
| |
| // always update intercept |
| state.gradient_intercept += u; |
| } |
| else |
| backtracking_transition(state, x, y); |
| } |
| |
| // ------------------------------------------------------------------------ |
| |
| inline void BinomialFista::backtracking_transition(FistaState<MutableArrayHandle<double> >& state, |
| MappedColumnVector& x, double y) |
| { |
| // during backtracking, always use b_coef and b_intercept |
| double r = state.b_intercept + sparse_dot(state.b_coef, x); |
| |
| if (y > 0) |
| state.fn += std::log(1 + std::exp(-r)); |
| else |
| state.fn += std::log(1 + std::exp(r)); |
| |
| // Qfn only need to be calculated once in each backtracking |
| if (state.backtracking == 1) |
| { |
| r = state.intercept_y + sparse_dot(state.coef_y, x); |
| if (y > 0) |
| state.Qfn += std::log(1 + std::exp(-r)); |
| else |
| state.Qfn += std::log(1 + std::exp(r)); |
| } |
| } |
| |
| // ------------------------------------------------------------------------ |
| |
| inline void BinomialFista::update_b_intercept (FistaState<MutableArrayHandle<double> >& state) |
| { |
| state.b_intercept = state.intercept_y - state.stepsize * state.gradient_intercept; |
| } |
| |
| // ------------------------------------------------------------------------ |
| |
| inline void BinomialFista::update_y_intercept (FistaState<MutableArrayHandle<double> >& state, |
| double old_tk) |
| { |
| state.intercept_y = state.b_intercept + (old_tk - 1) * (state.b_intercept - state.intercept) |
| / state.tk; |
| } |
| |
| // ------------------------------------------------------------------------ |
| // ------------------------------------------------------------------------ |
| // ------------------------------------------------------------------------ |
| |
| /* |
| The following are the functions that are actually called by SQL |
| */ |
| |
| /** |
| @brief Perform FISTA transition step |
| |
| It is called for each tuple of (x, y) |
| */ |
| AnyType binomial_fista_transition::run (AnyType& args) |
| { |
| return Fista<BinomialFista>::fista_transition(args, *this); |
| } |
| |
| // ------------------------------------------------------------------------ |
| /** |
| @brief Perform Merge transition steps |
| */ |
| AnyType binomial_fista_merge::run (AnyType& args) |
| { |
| return Fista<BinomialFista>::fista_merge(args); |
| } |
| |
| // ------------------------------------------------------------------------ |
| /** |
| @brief Perform the final computation |
| */ |
| AnyType binomial_fista_final::run (AnyType& args) |
| { |
| return Fista<BinomialFista>::fista_final(args); |
| } |
| |
| // ------------------------------------------------------------------------ |
| |
| /** |
| * @brief Return the difference in RMSE between two states |
| */ |
| AnyType __binomial_fista_state_diff::run (AnyType& args) |
| { |
| return Fista<BinomialFista>::fista_state_diff(args); |
| } |
| |
| // ------------------------------------------------------------------------ |
| |
| /** |
| * @brief Return the coefficients and diagnostic statistics of the state |
| */ |
| AnyType __binomial_fista_result::run (AnyType& args) |
| { |
| return Fista<BinomialFista>::fista_result(args); |
| } |
| |
| } |
| } |
| } |