blob: 64c532da103011f4d71bce32c9117241556f23db [file] [log] [blame]
/* -----------------------------------------------------------------------------
*
* matrix.c
*
* Fundamental matrix operations.
*
* Copyright (c) 2010, EMC
*
* -----------------------------------------------------------------------------
*/
#include "postgres.h"
#include "funcapi.h"
#include "catalog/pg_type.h"
#include "utils/array.h"
#include "utils/lsyscache.h"
#include "matrix.h"
/* compatibility macros */
#ifndef GP_VERSION
#define INT2ARRAYOID 1005
#define INT8ARRAYOID 1016
#define FLOAT8ARRAYOID 1022
#endif
/* Indicate "version 1" calling conventions for all exported functions. */
PG_FUNCTION_INFO_V1(matrix_transpose);
PG_FUNCTION_INFO_V1(matrix_multiply);
PG_FUNCTION_INFO_V1(matrix_add);
PG_FUNCTION_INFO_V1(int8_matrix_smultiply);
PG_FUNCTION_INFO_V1(float8_matrix_smultiply);
/*
* matrix_transpose - transpose an array[x][y] into an array[y][x].
*/
Datum
matrix_transpose(PG_FUNCTION_ARGS)
{
ArrayType *m, *result;
Oid eltype;
int elsize, size, ndim, databytes;
int i,j;
int16 typlen;
bool typbyval;
char typalign;
char *data_m, *data_r;
m = PG_GETARG_ARRAYTYPE_P(0);
ndim = ARR_NDIM(m);
/* Sanity check, transpose only valid for two dimensional arrays. */
if (ndim != 2)
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("argument must be a two dimensional array")));
if (m->dataoffset || ARR_NULLBITMAP(m))
ereport(ERROR,
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
errmsg("null array element not allowed in this context")));
if (ARR_ELEMTYPE(m) == InvalidOid)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("could not determine anyarray/anyelement type because "
"input has type \"unknown\"")));
/* datatype and size of input matrix */
eltype = ARR_ELEMTYPE(m);
size = ARR_SIZE(m);
databytes = size - ARR_DATA_OFFSET(m);
/* Allocate the result matrix */
result = (ArrayType *) palloc(size);
SET_VARSIZE(result, size);
result->ndim = ndim;
result->dataoffset = 0; /* We disallowed arrays with NULLS */
result->elemtype = eltype;
ARR_DIMS(result)[1] = ARR_DIMS(m)[0];
ARR_DIMS(result)[0] = ARR_DIMS(m)[1];
ARR_LBOUND(result)[1] = ARR_LBOUND(m)[0];
ARR_LBOUND(result)[0] = ARR_LBOUND(m)[1];
/*
* If one of the dimensions was length 1 then it is a simple transpose
* and we can directly copy the data
*/
if (ARR_DIMS(m)[0] == 1 || ARR_DIMS(m)[1] == 1)
{
memcpy(ARR_DATA_PTR(result), ARR_DATA_PTR(m), databytes);
PG_RETURN_ARRAYTYPE_P(result);
}
/* Otherwise we have a more complicated data reshuffle */
get_typlenbyvalalign(eltype, &typlen, &typbyval, &typalign);
elsize = att_align_nominal(typlen, typalign);
data_m = ARR_DATA_PTR(m);
data_r = ARR_DATA_PTR(result);
for (i = 0; i < ARR_DIMS(m)[1]; i++)
{
for (j = 0; j < ARR_DIMS(m)[0]; j++)
{
int index_m = i + j * ARR_DIMS(m)[1];
if (typlen > 0)
{
char *seek = data_m + (elsize * index_m);
memcpy(data_r, seek, elsize);
data_r = (char *) att_align_nominal(data_r+elsize, typalign);
}
else
{
char *seek = data_m;
int k;
for (k = 0; k < index_m; k++)
{
seek += VARSIZE(seek);
seek = (char*) att_align_nominal(seek, typalign);
}
memcpy(data_r, seek, VARSIZE(seek));
data_r = (char *) att_align_nominal(data_r+VARSIZE(seek), typalign);
}
}
}
Assert(data_r - ARR_DATA_PTR(result) == databytes);
PG_RETURN_ARRAYTYPE_P(result);
}
/*
* matrix_multipy - array multiplication of two input arrays
*/
Datum matrix_multiply(PG_FUNCTION_ARGS)
{
ArrayType *m, *n, *result;
Oid eltype;
int elsize, size;
int i,j,k;
int16 typlen;
bool typbyval;
char typalign;
char *data_m, *data_n, *data_r;
m = PG_GETARG_ARRAYTYPE_P(0);
n = PG_GETARG_ARRAYTYPE_P(1);
eltype = ARR_ELEMTYPE(m);
/* Do all error checking */
if (ARR_NDIM(m) != 2 || ARR_NDIM(n) != 2)
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("argument must be a two dimensional array")));
if (ARR_DIMS(m)[1] != ARR_DIMS(n)[0])
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("non-conformable arrays")));
if (m->dataoffset || ARR_NULLBITMAP(m) ||
n->dataoffset || ARR_NULLBITMAP(n))
ereport(ERROR,
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
errmsg("null array element not allowed in this context")));
if (!OidIsValid(ARR_ELEMTYPE(m)) || !OidIsValid(ARR_ELEMTYPE(n)))
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("could not determine anyarray/anyelement type because "
"input has type \"unknown\"")));
if (ARR_ELEMTYPE(m) != ARR_ELEMTYPE(n))
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("cannot multiply arrays of different element types")));
if (eltype != INT2OID && eltype != INT4OID && eltype != INT8OID &&
eltype != FLOAT4OID && eltype != FLOAT8OID /* && eltype != NUMERICOID */)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("datatype not supported for array multiplication")));
/* datatype and size of input matrix */
get_typlenbyvalalign(eltype, &typlen, &typbyval, &typalign);
elsize = att_align_nominal(typlen, typalign);
/* Output array is always either int8[] or float8[] */
size = ARR_OVERHEAD_NONULLS(2); /* overhead for two dimensional array */
size += ARR_DIMS(m)[0] * ARR_DIMS(n)[1] * sizeof(int64);
result = (ArrayType *) palloc(size);
SET_VARSIZE(result, size);
result->ndim = 2;
result->dataoffset = 0; /* We dissallowed arrays with NULLS */
ARR_DIMS(result)[0] = ARR_DIMS(m)[0];
ARR_DIMS(result)[1] = ARR_DIMS(n)[1];
ARR_LBOUND(result)[0] = 1; /* may need better lbound handling? */
ARR_LBOUND(result)[1] = 1;
switch (eltype) {
case INT2OID:
case INT4OID:
case INT8OID:
result->elemtype = INT8OID;
break;
case FLOAT4OID:
case FLOAT8OID:
result->elemtype = FLOAT8OID;
break;
default:
Assert(false);
}
data_r = ARR_DATA_PTR(result);
memset(data_r, 0, ARR_DIMS(m)[0] * ARR_DIMS(n)[1] * sizeof(int64));
/*
* Three dimensional loop over result arrays two dimensional space and
* over the intersection dimension of the two input matrices.
* Locality is preserved for the m matrix.
*
* This will need to be reworked for variable length numeric elements.
*/
for (i = 0; i < ARR_DIMS(result)[0]; i++)
{
for (j = 0; j < ARR_DIMS(result)[1]; j++)
{
data_m = ARR_DATA_PTR(m) + i * ARR_DIMS(m)[1] * elsize;
for (k = 0; k < ARR_DIMS(m)[1]; k++)
{
data_n = ARR_DATA_PTR(n) + k * ARR_DIMS(n)[1] * elsize;
switch (eltype) {
case INT2OID:
*(int64*)data_r +=
((int16*)data_n)[j] * ((int16*)data_m)[k];
break;
case INT4OID:
*(int64*)data_r +=
((int32*)data_n)[j] * ((int32*)data_m)[k];
break;
case INT8OID:
*(int64*)data_r +=
((int64*)data_n)[j] * ((int64*)data_m)[k];
break;
case FLOAT4OID:
*(float8*)data_r +=
((float*)data_n)[j] * ((float*)data_m)[k];
break;
case FLOAT8OID:
*(float8*)data_r +=
((float8*)data_n)[j] * ((float8*)data_m)[k];
break;
default:
Assert(false);
}
}
data_r += 8;
}
}
PG_RETURN_ARRAYTYPE_P(result);
}
/*
* matrix_add - array summation over two input arrays
*/
Datum matrix_add(PG_FUNCTION_ARGS)
{
ArrayType *m, *n;
Oid mtype, ntype;
int i, ndim, len;
bool transition_function;
/* If we're in a transition function we can be smarter */
transition_function = fcinfo->context && IsA(fcinfo->context, AggState);
/* Validate arguments */
if (PG_NARGS() != 2)
{
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("matrix_add called with %d arguments", PG_NARGS())));
}
/*
* This function is sometimes strict, and sometimes not in order to deal
* with needing to upconvert datatypes in an aggregate function.
*/
if (fcinfo->flinfo != NULL && fcinfo->flinfo->fn_strict &&
(PG_ARGISNULL(0) || PG_ARGISNULL(1)))
PG_RETURN_NULL();
/*
* When we are upconverting we always upconvert to the datatype of the
* first argument, so the first argument is a safe return value
*/
if (PG_ARGISNULL(1))
{
if (PG_ARGISNULL(0))
PG_RETURN_NULL();
else
PG_RETURN_ARRAYTYPE_P(PG_GETARG_ARRAYTYPE_P(0));
}
n = PG_GETARG_ARRAYTYPE_P(1);
ndim = ARR_NDIM(n);
ntype = ARR_ELEMTYPE(n);
/* Typecheck the input arrays, we only handle fixed length numeric data. */
if (ntype != INT2OID && ntype != INT4OID && ntype != INT8OID &&
ntype != FLOAT4OID && ntype != FLOAT8OID)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_add: unsupported datatype")));
/* count total number of elements */
for (len = 1, i = 0; i < ndim; i++)
len *= ARR_DIMS(n)[i];
if (PG_ARGISNULL(0))
{
int size;
int elsize;
Oid returntype;
TupleDesc tupdesc;
/* Determine what our return type should be */
get_call_result_type(fcinfo, &returntype, &tupdesc);
switch (returntype)
{
case INT2ARRAYOID:
mtype = INT2OID;
elsize = sizeof(int16);
break;
case INT4ARRAYOID:
mtype = INT4OID;
elsize = sizeof(int32);
break;
case INT8ARRAYOID:
mtype = INT8OID;
elsize = sizeof(int64);
break;
case FLOAT4ARRAYOID:
mtype = FLOAT4OID;
elsize = sizeof(float4);
break;
case FLOAT8ARRAYOID:
mtype = FLOAT8OID;
elsize = sizeof(float8);
break;
default:
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_add: return datatype lookup failure")));
/* Completely useless code that fixes compiler warnings */
mtype = INT2OID;
elsize = sizeof(int16);
}
/* Allocate the state matrix */
size = ARR_OVERHEAD_NONULLS(ndim) + len * elsize;
m = (ArrayType *) palloc(size);
SET_VARSIZE(m, size);
m->ndim = ndim;
m->dataoffset = 0;
m->elemtype = mtype;
for (i = 0; i < ndim; i++)
{
ARR_DIMS(m)[i] = ARR_DIMS(n)[i];
ARR_LBOUND(m)[i] = 1;
}
memset(ARR_DATA_PTR(m), 0, len * elsize);
}
else
{
m = PG_GETARG_ARRAYTYPE_P(0);
mtype = ARR_ELEMTYPE(m);
if (ndim != ARR_NDIM(m))
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("matrix_add: Dimensionality of both arrays must match")));
for (i = 0; i < ndim; i++)
{
if (ARR_DIMS(m)[i] != ARR_DIMS(n)[i])
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("matrix_add: non-conformable arrays")));
}
if (ARR_NULLBITMAP(m) || ARR_NULLBITMAP(n))
ereport(ERROR,
(errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
errmsg("matrix_add: null array element not allowed in this context")));
/* Typecheck the input arrays, we only handle fixed length numeric data. */
if (ntype != INT2OID && ntype != INT4OID && ntype != INT8OID &&
ntype != FLOAT4OID && ntype != FLOAT8OID)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_add: unsupported datatype")));
}
/* Transition function updates in place, otherwise allocate result */
if (transition_function)
{
switch (mtype)
{
case INT2OID:
{
int16 *data_m = (int16*) ARR_DATA_PTR(m);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
default:
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_add: can not downconvert state")));
}
break;
}
case INT4OID:
{
int32 *data_m = (int32*) ARR_DATA_PTR(m);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case INT4OID:
{
int32 *data_n = (int32*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
default:
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_add: can not downconvert state")));
}
break;
}
case INT8OID:
{
int64 *data_m = (int64*) ARR_DATA_PTR(m);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case INT4OID:
{
int32 *data_n = (int32*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case INT8OID:
{
int64 *data_n = (int64*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
default:
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_add: can not downconvert state")));
}
break;
}
case FLOAT4OID:
{
float4 *data_m = (float4*) ARR_DATA_PTR(m);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case INT4OID:
{
int32 *data_n = (int32*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case INT8OID:
{
int64 *data_n = (int64*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case FLOAT4OID:
{
float4 *data_n = (float4*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
default:
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_add: can not downconvert state")));
}
break;
}
case FLOAT8OID:
{
float8 *data_m = (float8*) ARR_DATA_PTR(m);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case INT4OID:
{
int32 *data_n = (int32*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case INT8OID:
{
int64 *data_n = (int64*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case FLOAT4OID:
{
float4 *data_n = (float4*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
case FLOAT8OID:
{
float8 *data_n = (float8*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_m[i] += data_n[i];
break;
}
default:
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_add: can not downconvert state")));
}
break;
}
default:
Assert(false);
}
PG_RETURN_ARRAYTYPE_P(m);
}
else
{
ArrayType *result;
Oid rtype = InvalidOid;
int elsize = 0;
int size;
/* Result type for non-transition function is the higher of the two input types */
if (ntype == FLOAT8OID || mtype == FLOAT8OID)
{
rtype = FLOAT8OID;
elsize = sizeof(float8);
}
else if (ntype == FLOAT4OID || mtype == FLOAT4OID)
{
rtype = FLOAT4OID;
elsize = sizeof(float4);
}
else if (ntype == INT8OID || mtype == INT8OID)
{
rtype = INT8OID;
elsize = sizeof(int64);
}
else if (ntype == INT4OID || mtype == INT4OID)
{
rtype = INT4OID;
elsize = sizeof(int32);
}
else if (ntype == INT2OID || mtype == INT2OID)
{
rtype = INT2OID;
elsize = sizeof(int16);
}
Assert(rtype != InvalidOid && elsize > 0);
size = ARR_OVERHEAD_NONULLS(ndim) + len * elsize;
result = (ArrayType *) palloc(size);
SET_VARSIZE(result, size);
result->ndim = ndim;
result->dataoffset = 0; /* We dissallowed arrays with NULLS */
result->elemtype = rtype;
for (i = 0; i < ndim; i++)
{
ARR_DIMS(result)[i] = ARR_DIMS(n)[i];
ARR_LBOUND(result)[i] = 1;
}
switch (mtype)
{
case INT2OID:
{
int16 *data_m = (int16*) ARR_DATA_PTR(m);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
int16 *data_r = (int16*) ARR_DATA_PTR(result);
Assert(rtype == INT2OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT4OID:
{
int32 *data_n = (int32*) ARR_DATA_PTR(n);
int32 *data_r = (int32*) ARR_DATA_PTR(result);
Assert(rtype == INT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT8OID:
{
int64 *data_n = (int64*) ARR_DATA_PTR(n);
int64 *data_r = (int64*) ARR_DATA_PTR(result);
Assert(rtype == INT8OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT4OID:
{
float4 *data_n = (float4*) ARR_DATA_PTR(n);
float4 *data_r = (float4*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT8OID:
{
float8 *data_n = (float8*) ARR_DATA_PTR(n);
float8 *data_r = (float8*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT8OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
default:
Assert(false);
}
break;
}
case INT4OID:
{
int32 *data_m = (int32*) ARR_DATA_PTR(m);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
int32 *data_r = (int32*) ARR_DATA_PTR(result);
Assert(rtype == INT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT4OID:
{
int32 *data_n = (int32*) ARR_DATA_PTR(n);
int32 *data_r = (int32*) ARR_DATA_PTR(result);
Assert(rtype == INT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT8OID:
{
int64 *data_n = (int64*) ARR_DATA_PTR(n);
int64 *data_r = (int64*) ARR_DATA_PTR(result);
Assert(rtype == INT8OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT4OID:
{
float4 *data_n = (float4*) ARR_DATA_PTR(n);
float4 *data_r = (float4*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT8OID:
{
float8 *data_n = (float8*) ARR_DATA_PTR(n);
float8 *data_r = (float8*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT8OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
default:
Assert(false);
}
break;
}
case INT8OID:
{
int64 *data_m = (int64*) ARR_DATA_PTR(m);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
int64 *data_r = (int64*) ARR_DATA_PTR(result);
Assert(rtype == INT8OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT4OID:
{
int32 *data_n = (int32*) ARR_DATA_PTR(n);
int64 *data_r = (int64*) ARR_DATA_PTR(result);
Assert(rtype == INT8OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT8OID:
{
int64 *data_n = (int64*) ARR_DATA_PTR(n);
int64 *data_r = (int64*) ARR_DATA_PTR(result);
Assert(rtype == INT8OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT4OID:
{
float4 *data_n = (float4*) ARR_DATA_PTR(n);
float4 *data_r = (float4*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT8OID:
{
float8 *data_n = (float8*) ARR_DATA_PTR(n);
float8 *data_r = (float8*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT8OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
default:
Assert(false);
}
break;
}
case FLOAT4OID:
{
float4 *data_m = (float4*) ARR_DATA_PTR(m);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
float4 *data_r = (float4*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT4OID:
{
int32 *data_n = (int32*) ARR_DATA_PTR(n);
float4 *data_r = (float4*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT8OID:
{
int64 *data_n = (int64*) ARR_DATA_PTR(n);
float4 *data_r = (float4*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT4OID:
{
float4 *data_n = (float4*) ARR_DATA_PTR(n);
float4 *data_r = (float4*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT4OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT8OID:
{
float8 *data_n = (float8*) ARR_DATA_PTR(n);
float8 *data_r = (float8*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT8OID);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
default:
Assert(false);
}
break;
}
case FLOAT8OID:
{
float8 *data_m = (float8*) ARR_DATA_PTR(m);
float8 *data_r = (float8*) ARR_DATA_PTR(result);
Assert(rtype == FLOAT8OID);
switch (ntype)
{
case INT2OID:
{
int16 *data_n = (int16*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT4OID:
{
int32 *data_n = (int32*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case INT8OID:
{
int64 *data_n = (int64*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT4OID:
{
float4 *data_n = (float4*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
case FLOAT8OID:
{
float8 *data_n = (float8*) ARR_DATA_PTR(n);
for (i = 0; i < len; i++)
data_r[i] = data_m[i] + data_n[i];
break;
}
default:
Assert(false);
}
break;
}
default:
Assert(false);
}
PG_RETURN_ARRAYTYPE_P(result);
}
}
/*
* int8_matrix_smultiply - scalar multiple of input array by an int8
*/
Datum int8_matrix_smultiply(PG_FUNCTION_ARGS)
{
ArrayType *m, *result;
int64 *data_r;
int64 scalar;
int size, nelem, ndim;
int i;
if (PG_NARGS() != 2)
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("matrix_smultiply_int8 called with %d arguments",
PG_NARGS())));
if (PG_ARGISNULL(0) || PG_ARGISNULL(1))
PG_RETURN_NULL();
scalar = PG_GETARG_INT64(1);
m = PG_GETARG_ARRAYTYPE_P(0);
ndim = ARR_NDIM(m);
/* Do all error checking */
if (ARR_ELEMTYPE(m) != INT8OID)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_smultiply_int8 datatype mismatch")));
/* count number of elements */
for (nelem = 1, i = 0; i < ndim; i++)
nelem *= ARR_DIMS(m)[i];
/* Make a copy of the input matrix */
size = VARSIZE(m);
result = (ArrayType *) palloc(size);
SET_VARSIZE(result, size);
memcpy(result, m, size);
/* And multiply by the scalar in place */
data_r = (int64*) ARR_DATA_PTR(result);
for (i = 0; i < nelem; i++)
data_r[i] *= scalar;
PG_RETURN_ARRAYTYPE_P(result);
}
/*
* float8_matrix_smultiply - scalar multiple of input array by a float8
*/
Datum float8_matrix_smultiply(PG_FUNCTION_ARGS)
{
ArrayType *m, *result;
float8 *data_r;
float8 scalar;
int size, nelem, ndim;
int i;
if (PG_NARGS() != 2)
ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("matrix_smultiply_float8 called with %d arguments",
PG_NARGS())));
if (PG_ARGISNULL(0) || PG_ARGISNULL(1))
PG_RETURN_NULL();
scalar = PG_GETARG_FLOAT8(1);
m = PG_GETARG_ARRAYTYPE_P(0);
ndim = ARR_NDIM(m);
/* Do all error checking */
if (ARR_ELEMTYPE(m) != FLOAT8OID)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("matrix_smultiply_float8 datatype mismatch")));
/* count number of elements */
for (nelem = 1, i = 0; i < ndim; i++)
nelem *= ARR_DIMS(m)[i];
/* Make a copy of the input matrix */
size = VARSIZE(m);
result = (ArrayType *) palloc(size);
SET_VARSIZE(result, size);
memcpy(result, m, size);
/* And multiply by the scalar in place */
data_r = (float8*) ARR_DATA_PTR(result);
for (i = 0; i < nelem; i++)
data_r[i] *= scalar;
PG_RETURN_ARRAYTYPE_P(result);
}