blob: 8ff342079b4fa2c2335eb8f50e8643c9b1521b80 [file] [log] [blame]
#include <dbconnector/dbconnector.hpp>
#include <stdexcept>
#include "crossprod.hpp"
namespace madlib {
namespace modules {
namespace linalg {
using namespace dbal::eigen_integration;
// ----------------------------------------------------------------------
/**
 * @brief Transition step of the crossprod aggregate.
 *
 * Adds the outer product of the two input arrays to the running state.
 * The state is a flattened m x n buffer in which the index of the right
 * array varies fastest:  state[i * n + j] += left[i] * right[j].
 *
 * @param args args[0]: current state (NULL on the first row);
 *             args[1]: left array (length m);
 *             args[2]: right array (length n).
 * @return The updated state array of length m * n.
 * @throws std::invalid_argument if an existing state does not have length
 *         m * n, i.e. the input arrays changed length between rows.
 */
AnyType __pivotalr_crossprod_transition::run (AnyType& args)
{
    ArrayHandle<double> left = args[1].getAs<ArrayHandle<double> >();
    ArrayHandle<double> right = args[2].getAs<ArrayHandle<double> >();
    size_t m = left.size();
    size_t n = right.size();
    MutableArrayHandle<double> state(NULL);
    if (args[0].isNull()) {
        state = this->allocateArray<double, dbal::AggregateContext, dbal::DoZero,
                                    dbal::ThrowBadAlloc>(m * n);
        // Explicitly clear the freshly allocated accumulator.
        for (size_t i = 0; i < state.size(); i++) state[i] = 0;
    } else {
        state = args[0].getAs<MutableArrayHandle<double> >();
        // The state was sized from the first row. Without this check, a row
        // whose arrays are longer would write past the end of the buffer.
        // (A row with the same product m * n but a different shape is not
        // detected here; it cannot overrun memory, though.)
        if (state.size() != m * n)
            throw std::invalid_argument(
                "crossprod: array lengths must be the same for all rows");
    }
    // size_t rather than int: m * n may exceed INT_MAX.
    size_t count = 0;
    for (size_t i = 0; i < m; i++)
        for (size_t j = 0; j < n; j++)
            state[count++] += left[i] * right[j];
    return state;
}
// ----------------------------------------------------------------------
/**
 * @brief Merge step for crossprod aggregate states.
 *
 * Combines two partial states by element-wise addition into the first one.
 * A NULL state (a partial aggregate that saw no rows) acts as the identity.
 *
 * @param args args[0], args[1]: the two partial states (either may be NULL).
 * @return The combined state, or NULL if both inputs are NULL.
 * @throws std::invalid_argument if the two states have different lengths.
 */
AnyType __pivotalr_crossprod_merge::run (AnyType& args)
{
    // A NULL side contributes nothing; return the other (possibly NULL) side.
    if (args[0].isNull()) return args[1];
    if (args[1].isNull()) return args[0];
    MutableArrayHandle<double> state1 = args[0].getAs<MutableArrayHandle<double> >();
    ArrayHandle<double> state2 = args[1].getAs<ArrayHandle<double> >();
    // state2 is indexed with state1's bounds, so the lengths must match to
    // avoid reading past the end of state2.
    if (state1.size() != state2.size())
        throw std::invalid_argument(
            "crossprod: cannot merge states of different lengths");
    for (size_t i = 0; i < state1.size(); i++) state1[i] += state2[i];
    return state1;
}
// ----------------------------------------------------------------------
/**
 * @brief Transition step of the symmetric crossprod aggregate.
 *
 * Adds arr * arr^T to the running state. Because that product is
 * symmetric, only the lower triangle (diagonal included) is stored, packed
 * row by row: entries (0,0), (1,0), (1,1), (2,0), ... so that the state has
 * n * (n + 1) / 2 elements.
 *
 * @param args args[0]: current state (NULL on the first row);
 *             args[1]: input array (length n).
 * @return The updated packed state array.
 * @throws std::invalid_argument if an existing state does not have length
 *         n * (n + 1) / 2, i.e. the input array changed length between rows.
 */
AnyType __pivotalr_crossprod_sym_transition::run (AnyType& args)
{
    ArrayHandle<double> arr = args[1].getAs<ArrayHandle<double> >();
    size_t n = arr.size();
    size_t packed_size = n * (n + 1) / 2;
    MutableArrayHandle<double> state(NULL);
    if (args[0].isNull()) {
        state = this->allocateArray<double, dbal::AggregateContext, dbal::DoZero,
                                    dbal::ThrowBadAlloc>(packed_size);
        // Explicitly clear the freshly allocated accumulator.
        for (size_t i = 0; i < state.size(); i++) state[i] = 0;
    } else {
        state = args[0].getAs<MutableArrayHandle<double> >();
        // The state was sized from the first row. Without this check, a longer
        // array would write past the end of the buffer.
        if (state.size() != packed_size)
            throw std::invalid_argument(
                "crossprod: array lengths must be the same for all rows");
    }
    // size_t rather than int: the packed size may exceed INT_MAX.
    size_t count = 0;
    for (size_t i = 0; i < n; i++)
        for (size_t j = 0; j <= i; j++)
            state[count++] += arr[i] * arr[j];
    return state;
}
}
}
}