blob: bc18c25007a343440a12e157f722a31675393dac [file] [log] [blame]
/* ----------------------------------------------------------------------- *//**
*
* @file ordinal.cpp
*
* @brief ordinal linear model functions
*
*//* ----------------------------------------------------------------------- */
#include <dbconnector/dbconnector.hpp>
#include <modules/prob/student.hpp>
#include "Ordinal_proto.hpp"
#include "Ordinal_impl.hpp"
#include "ordinal.hpp"
namespace madlib {
namespace modules {
namespace glm {
// State types shared by the link-independent functions in this file.
// OrdinalState is the variant constructed from an immutable ByteString and is
// only read from here; OrdinalMutableState is constructed from a
// MutableByteString and is the variant that gets updated and handed back to
// the caller through storage().
typedef OrdinalAccumulator<RootContainer> OrdinalState;
typedef OrdinalAccumulator<MutableRootContainer> OrdinalMutableState;
// Logit link
// ------------------------------------------------------------------------
typedef OrdinalAccumulator<MutableRootContainer,Multinomial,OrdinalLogit>
    MutableOrdinalLogitState;

/**
 * @brief Per-row transition step of the ordinal aggregate with a logit link.
 *
 * Arguments, in the order they are read below:
 *   - args[0]: running state, as a MutableByteString
 *   - args[1]: response value, read as double
 *   - args[2]: independent variables, read as a MappedColumnVector
 *   - args[3]: an earlier state as a ByteString, or NULL
 *              (presumably the result of the previous iteration)
 *   - args[4]: number of categories, read as uint16_t
 *
 * @return The storage of the updated state. args[0] is returned untouched
 *         when the state is already terminated, when the response or the
 *         independent variables are NULL, or when reading the independent
 *         variables raises ArrayWithNullException.
 */
AnyType
ordinal_logit_transition::run(AnyType& args) {
    MutableOrdinalLogitState state = args[0].getAs<MutableByteString>();

    // Rows that cannot contribute leave the incoming state as it is.
    if (state.terminated) {
        return args[0];
    }
    if (args[1].isNull() || args[2].isNull()) {
        return args[0];
    }

    double response = args[1].getAs<double>();

    // Bind the feature vector to the argument's memory; an array holding
    // NULLs is reported through an exception and the row is skipped.
    MappedColumnVector features;
    try {
        MappedColumnVector mapped = args[2].getAs<MappedColumnVector>();
        features.rebind(mapped.memoryHandle(), mapped.size());
    } catch (const ArrayWithNullException &) {
        return args[0];
    }

    if (state.empty()) {
        // The state has not been sized yet: derive its dimensions from this
        // row and from the category count before allocating.
        state.num_features = static_cast<uint16_t>(features.size());
        state.num_categories = args[4].getAs<uint16_t>();
        state.optimizer.num_coef = static_cast<uint16_t>(
            state.num_features + (state.num_categories - 1));
        state.resize();

        // When an earlier state is supplied, start from a copy of it and
        // then call reset() on the copy.
        if (!args[3].isNull()) {
            OrdinalState previous = args[3].getAs<ByteString>();
            state = previous;
            state.reset();
        }
    }

    state << OrdinalMutableState::tuple_type(features, response);
    return state.storage();
}
// ------------------------------------------------------------------------
// Probit link
// ------------------------------------------------------------------------
typedef OrdinalAccumulator<MutableRootContainer,Multinomial,OrdinalProbit>
    MutableOrdinalProbitState;

/**
 * @brief Per-row transition step of the ordinal aggregate with a probit link.
 *
 * Arguments, in the order they are read below:
 *   - args[0]: running state, as a MutableByteString
 *   - args[1]: response value, read as double
 *   - args[2]: independent variables, read as a MappedColumnVector
 *   - args[3]: an earlier state as a ByteString, or NULL
 *              (presumably the result of the previous iteration)
 *   - args[4]: number of categories, read as uint16_t
 *
 * @return The storage of the updated state. args[0] is returned untouched
 *         when the state is already terminated, when the response or the
 *         independent variables are NULL, or when reading the independent
 *         variables raises ArrayWithNullException.
 */
AnyType
ordinal_probit_transition::run(AnyType& args) {
    MutableOrdinalProbitState state = args[0].getAs<MutableByteString>();

    // Skip the row without touching the state if nothing can be accumulated.
    if (state.terminated) {
        return args[0];
    }
    if (args[1].isNull() || args[2].isNull()) {
        return args[0];
    }

    double response = args[1].getAs<double>();

    // An independent-variable array that contains NULLs surfaces as an
    // exception while mapping it; such rows are ignored.
    MappedColumnVector features;
    try {
        MappedColumnVector mapped = args[2].getAs<MappedColumnVector>();
        features.rebind(mapped.memoryHandle(), mapped.size());
    } catch (const ArrayWithNullException &) {
        return args[0];
    }

    if (state.empty()) {
        // Size the state from the first accepted row: one coefficient per
        // feature plus (num_categories - 1) further coefficients.
        state.num_features = static_cast<uint16_t>(features.size());
        state.num_categories = args[4].getAs<uint16_t>();
        state.optimizer.num_coef = static_cast<uint16_t>(
            state.num_features + (state.num_categories - 1));
        state.resize();

        // Seed from the supplied earlier state, if any, then reset() it.
        if (!args[3].isNull()) {
            OrdinalState previous = args[3].getAs<ByteString>();
            state = previous;
            state.reset();
        }
    }

    state << OrdinalMutableState::tuple_type(features, response);
    return state.storage();
}
// ------------------------------------------------------------------------
/**
 * @brief Merge two partial aggregate states.
 *
 * args[0] is taken as the mutable destination state and args[1] as the
 * read-only state folded into it via operator<<.
 *
 * @return The storage of the destination state after the merge.
 */
AnyType
ordinal_merge_states::run(AnyType& args) {
    OrdinalMutableState merged = args[0].getAs<MutableByteString>();
    OrdinalState incoming = args[1].getAs<ByteString>();

    merged << incoming;

    return merged.storage();
}
/**
 * @brief Final step of the aggregate: apply the accumulated state.
 *
 * @return Null when the state is empty or terminated, either before or
 *         after apply() has run; otherwise the storage of the state.
 */
AnyType
ordinal_final::run(AnyType& args) {
    OrdinalMutableState state = args[0].getAs<MutableByteString>();

    // Without any valid data there is nothing to report, so return Null.
    // This mirrors what standard aggregates do on empty data sets (compare,
    // e.g., how PostgreSQL handles sum or avg on empty inputs).
    if (state.empty()) {
        return Null();
    }
    if (state.terminated) {
        return Null();
    }

    state.apply();

    // Check again after apply(): the state is only returned if it is still
    // non-empty and not terminated.
    if (state.empty()) {
        return Null();
    }
    if (state.terminated) {
        return Null();
    }

    return state.storage();
}
// ------------------------------------------------------------------------
/**
 * @brief Turn a final state into the result tuple.
 *
 * @return Null when args[0] is NULL. Otherwise a tuple holding, in this
 *         order: coef_alpha, std_err_alpha, z_stats_alpha, p_values_alpha,
 *         loglik, coef_beta, std_err_beta, z_stats_beta, p_values_beta,
 *         num_rows_processed.
 */
AnyType
ordinal_result::run(AnyType& args) {
    if (args[0].isNull()) {
        return Null();
    }

    OrdinalState state = args[0].getAs<ByteString>();
    OrdinalResult summary(state);

    // Field order defines the layout of the returned tuple; keep it stable.
    AnyType row;

    // alpha block
    row << summary.coef_alpha;
    row << summary.std_err_alpha;
    row << summary.z_stats_alpha;
    row << summary.p_values_alpha;

    // log-likelihood sits between the alpha and beta blocks
    row << summary.loglik;

    // beta block
    row << summary.coef_beta;
    row << summary.std_err_beta;
    row << summary.z_stats_beta;
    row << summary.p_values_beta;

    row << summary.num_rows_processed;

    return row;
}
// ------------------------------------------------------------------------
/**
 * @brief Relative difference between the log-likelihoods of two states.
 *
 * @return Infinity when either argument is NULL; 0 when either
 *         log-likelihood is non-negative; otherwise
 *         |a - b| / min(|a|, |b|), where a and b are the log-likelihoods
 *         stored in args[0] and args[1].
 */
AnyType
ordinal_loglik_diff::run(AnyType& args) {
    // With a state missing there is nothing to compare against.
    if (args[0].isNull() || args[1].isNull()) {
        return std::numeric_limits<double>::infinity();
    }

    OrdinalState first = args[0].getAs<ByteString>();
    OrdinalState second = args[1].getAs<ByteString>();

    double loglikFirst = first.loglik;
    double loglikSecond = second.loglik;

    // A non-negative log-likelihood corresponds to probability = 1.
    if (loglikFirst >= 0. || loglikSecond >= 0.) {
        return 0.;
    }

    double gap = std::abs(loglikFirst - loglikSecond);
    double scale = std::min(std::abs(loglikFirst), std::abs(loglikSecond));
    return gap / scale;
}
} // namespace glm
} // namespace modules
} // namespace madlib