blob: bd60143eb29ef45a1dd823eb334ddc27f0da5f3c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/* ----------------------------------------------------------------------- *//**
*
* @file avg_var.cpp
*
* @brief average population variance functions
*
*//* ----------------------------------------------------------------------- */
#include <dbconnector/dbconnector.hpp>
#include <modules/shared/HandleTraits.hpp>
#include "avg_var.hpp"
namespace madlib {
namespace modules {
namespace hello_world {
template <class Handle>
class AvgVarTransitionState {
template <class OtherHandle>
friend class AvgVarTransitionState;
public:
AvgVarTransitionState(const AnyType &inArray)
: mStorage(inArray.getAs<Handle>()) {
rebind();
}
/**
* @brief Convert to backend representation
*
* We define this function so that we can use State in the
* argument list and as a return type.
*/
inline operator AnyType() const {
return mStorage;
}
/**
* @brief Update state with a new data point
*/
AvgVarTransitionState & operator+=(const double x){
double diff = (x - avg);
double normalizer = static_cast<double>(numRows + 1);
// online update mean
this->avg += diff / normalizer;
// online update variance
double new_diff = (x - avg);
double a = static_cast<double>(this->numRows) / normalizer;
this->var = (var * a) + (diff * new_diff) / normalizer;
return *this;
}
/**
* @brief Merge with another state object
*
* We update mean and variance in a online fashion
* to avoid intermediate large sum.
*/
template <class OtherHandle>
AvgVarTransitionState &operator+=(
const AvgVarTransitionState<OtherHandle> &inOtherState) {
if (mStorage.size() != inOtherState.mStorage.size())
throw std::logic_error("Internal error: Incompatible transition "
"states");
double avg_ = inOtherState.avg;
double var_ = inOtherState.var;
uint16_t numRows_ = static_cast<uint16_t>(inOtherState.numRows);
double totalNumRows = static_cast<double>(numRows + numRows_);
// we perform a weighted average between states
double w = static_cast<double>(numRows) / totalNumRows;
double w_ = static_cast<double>(numRows_) / totalNumRows;
double totalAvg = avg * w + avg_ * w_;
double a = avg - totalAvg;
double a_ = avg_ - totalAvg;
numRows += numRows_;
this->var = (w * var) + (w_ * var_) + (w * a * a) + (w_ * a_ * a_);
this->avg = totalAvg;
return *this;
}
private:
void rebind() {
avg.rebind(&mStorage[0]);
var.rebind(&mStorage[1]);
numRows.rebind(&mStorage[2]);
}
Handle mStorage;
public:
typename HandleTraits<Handle>::ReferenceToDouble avg;
typename HandleTraits<Handle>::ReferenceToDouble var;
typename HandleTraits<Handle>::ReferenceToUInt64 numRows;
};
AnyType
avg_var_transition::run(AnyType& args) {
// get current state value
AvgVarTransitionState<MutableArrayHandle<double> > state = args[0];
// update state with current row value
double x = args[1].getAs<double>();
state += x;
state.numRows ++;
return state;
}
AnyType
avg_var_merge_states::run(AnyType& args) {
AvgVarTransitionState<MutableArrayHandle<double> > stateLeft = args[0];
AvgVarTransitionState<ArrayHandle<double> > stateRight = args[1];
// Merge states together and return
stateLeft += stateRight;
return stateLeft;
}
AnyType
avg_var_final::run(AnyType& args) {
AvgVarTransitionState<MutableArrayHandle<double> > state = args[0];
// If we haven't seen any data, just return Null. This is the standard
// behavior of aggregate function on empty data sets (compare, e.g.,
// how PostgreSQL handles sum or avg on empty inputs)
if (state.numRows == 0)
return Null();
return state;
}
// -----------------------------------------------------------------------
} // namespace hello_world
} // namespace modules
} // namespace madlib