-- blob: 52e56ef1ee01bc146c5cbac109cb80c72611bcf0 [file] [log] [blame]
---------------------------------------------------------------------------
-- Rules:
-- ------
-- 1) Create all DB objects without a schema prefix,
-- since this file is executed in a separate schema context.
-- 2) Do not put any DROP statements in this script: all objects
-- created in the default schema are cleaned up externally.
---------------------------------------------------------------------------
---------------------------------------------------------------------------
-- Setup:
---------------------------------------------------------------------------
-- Install check for MADLIB_SCHEMA.conjugate_gradient.
-- Builds two small coefficient tables, calls the solver on each, and raises
-- an exception when a returned solution fails its check. A NULL solution
-- (or a missing component) is treated as a failure, so a solver that returns
-- NULL can no longer make the checks pass silently.
CREATE FUNCTION install_test() RETURNS VOID AS $$
DECLARE
    result FLOAT;
    x      FLOAT[];
BEGIN
    -- Case 1: 10x10 symmetric coefficient table. Each table row holds one
    -- matrix row: the row number is in "row" and the values are in "val".
    CREATE TABLE A(
        row INT,
        val FLOAT[]
    );
    INSERT INTO A VALUES(1, ARRAY[1.0175001, 0.45604107, 0.32282152, 0.25168270, 0.20694042, 0.17602822, 0.15331823, 0.13589278, 0.12208089, 0.11085359]);
    INSERT INTO A VALUES(2, ARRAY[0.4560411, 0.48234387, 0.21190993, 0.17177051, 0.14528005, 0.12627935, 0.11189286, 0.10057578, 0.09141562, 0.08383503]);
    INSERT INTO A VALUES(3, ARRAY[0.3228215, 0.21190993, 0.36372483, 0.13511499, 0.11570299, 0.10149569, 0.09057135, 0.08187212, 0.07476054, 0.06882649]);
    INSERT INTO A VALUES(4, ARRAY[0.2516827, 0.17177051, 0.13511499, 0.31270305, 0.09720943, 0.08572070, 0.07680054, 0.06964365, 0.06375766, 0.05882206]);
    INSERT INTO A VALUES(5, ARRAY[0.2069404, 0.14528005, 0.11570299, 0.09720943, 0.28424848, 0.07454783, 0.06696472, 0.06084900, 0.05579869, 0.05154975]);
    INSERT INTO A VALUES(6, ARRAY[0.1760282, 0.12627935, 0.10149569, 0.08572070, 0.07454783, 0.26612646, 0.05951006, 0.05415360, 0.04971703, 0.04597544]);
    INSERT INTO A VALUES(7, ARRAY[0.1533182, 0.11189286, 0.09057135, 0.07680054, 0.06696472, 0.05951006, 0.25363014, 0.04885589, 0.04489247, 0.04154376]);
    INSERT INTO A VALUES(8, ARRAY[0.1358928, 0.10057578, 0.08187212, 0.06964365, 0.06084900, 0.05415360, 0.04885589, 0.24454426, 0.04095839, 0.03792426]);
    INSERT INTO A VALUES(9, ARRAY[0.1220809, 0.09141562, 0.07476054, 0.06375766, 0.05579869, 0.04971703, 0.04489247, 0.04095839, 0.23768165, 0.03490583]);
    INSERT INTO A VALUES(10,ARRAY[0.1108536, 0.08383503, 0.06882649, 0.05882206, 0.05154975, 0.04597544, 0.04154376, 0.03792426, 0.03490583, 0.23234632]);
    -- The right-hand side is 10 random() values, so this call's inputs differ
    -- on every run.
    -- NOTE(review): the meaning of the trailing argument (2) is not shown in
    -- this file; confirm against the conjugate_gradient signature.
    SELECT INTO x MADLIB_SCHEMA.conjugate_gradient('A', 'val', 'row', ARRAY(SELECT random() FROM generate_series(1,10)), .000001,2);
    -- The check requires x[1]^2 <= 1e-8.
    -- NOTE(review): why the first component is expected to vanish for an
    -- arbitrary right-hand side is not evident here; confirm the intent.
    result = x[1]*x[1];
    -- If x is NULL or has no first element, result is NULL. "result > ..."
    -- would then be NULL, and IF treats that as false, so test for NULL
    -- explicitly.
    IF result IS NULL OR result > .00000001 THEN
        RAISE EXCEPTION 'Incorrect results, got %',result;
    END IF;

    -- Case 2: simple symmetric positive-definite system
    --   [[2,1],[1,4]] * x = [2,1],  whose exact solution is x = (1, 0).
    CREATE TABLE data(row_num INT, row_val FLOAT[]);
    INSERT INTO data VALUES (1,'{2,1}');
    INSERT INTO data VALUES (2,'{1,4}');
    SELECT INTO x MADLIB_SCHEMA.conjugate_gradient('data','row_val','row_num','{2,1}',1E-6,1);
    -- IS DISTINCT FROM is NULL-safe: a NULL or missing component counts as a
    -- mismatch instead of making the whole condition NULL (i.e. "passing").
    IF round(x[1]) IS DISTINCT FROM 1 OR round(x[2]) IS DISTINCT FROM 0 THEN
        RAISE EXCEPTION 'Incorrect multivariate results, got %',x;
    END IF;

    RAISE INFO 'Conjugate gradient install checks passed';
    RETURN;
END
$$ language plpgsql;
---------------------------------------------------------------------------
-- Test: run the install check. install_test() raises an exception when a
-- check fails, so returning normally means every check passed.
---------------------------------------------------------------------------
SELECT install_test() AS install_test;