blob: b6b2996f3348a5391f60cd09ff1b974c492643a5 [file] [log] [blame]
/* ----------------------------------------------------------------------- *//**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*//* ----------------------------------------------------------------------- */
-- Asserts that the minibatch preprocessor packed the data so that the dependent
-- or independent variable array has the expected dimensions.
--   column_name     : 'dependent_varname' or 'independent_varname'; the packed
--                     column of minibatch_preprocessing_out to inspect
--   dimension       : 1 tests how many rows (records) were packed into one
--                     super row; 2 tests how many columns of one record were
--                     packed into the array
--   expected_result : text form of the expected array of sizes, e.g. '{2,4,4}'.
--                     Without grouping the sizes are in ascending order; with
--                     grouping they are ordered by grp first and then by size.
--   grp             : grouping column(s); NULL if there is no grouping
-- Raises an exception when the actual sizes differ from expected_result. A NULL
-- actual result (for example when the output table has no rows) also counts as
-- a mismatch instead of silently passing.
-- NOTE: column_name and grp are concatenated into dynamic SQL without quoting,
-- so only pass trusted literals such as the ones used in this file.
-- See more examples in this file to know how it works.
CREATE OR REPLACE FUNCTION assert_col_dimension(column_name VARCHAR, dimension int, expected_result text, grp text)
RETURNS void AS
$$
DECLARE
    qry text;
    result text;
BEGIN
    IF grp IS NULL THEN
        -- One array_upper() value per packed row, aggregated in ascending order.
        qry := 'select array_agg(row_count order by row_count asc) from (select array_upper(' || column_name ||' ,' || dimension ||') as row_count from minibatch_preprocessing_out order by row_count asc) s';
    ELSE
        -- Same measurement, but ordered by the grouping column(s) first.
        qry := 'select array_agg(row_count order by ' || grp || ', row_count asc) from (select array_upper(' || column_name ||' ,' || dimension ||') as row_count,' || grp || ' from minibatch_preprocessing_out order by ' || grp || ' , row_count asc) s';
    END IF;
    EXECUTE qry INTO result;
    -- IS DISTINCT FROM is NULL-safe: a plain != would evaluate to NULL when
    -- result is NULL and the IF branch would be skipped.
    IF result IS DISTINCT FROM expected_result THEN
        RAISE EXCEPTION 'Dependent/Independent Variable dimension check failed. Actual: % Expected %', result, expected_result;
    END IF;
END;
$$ LANGUAGE plpgsql;
-- Fixture: (re)create the input table and load the ten records that the
-- following tests run the preprocessor on.
DROP TABLE IF EXISTS minibatch_preprocessing_input;
CREATE TABLE minibatch_preprocessing_input (
    sex      TEXT,
    id       SERIAL NOT NULL,
    length   DOUBLE PRECISION,
    diameter DOUBLE PRECISION,
    height   DOUBLE PRECISION,
    whole    DOUBLE PRECISION,
    shucked  DOUBLE PRECISION,
    viscera  DOUBLE PRECISION,
    shell    DOUBLE PRECISION,
    rings    INTEGER
);
INSERT INTO minibatch_preprocessing_input
    (id, sex, length, diameter, height, whole, shucked, viscera, shell, rings)
VALUES
    (1040, 'F', 0.66,  0.475, 0.18,  1.3695, 0.641,  0.294,  0.335,  6),
    (3160, 'F', 0.34,  0.255, 0.085, 0.204,  0.097,  0.021,  0.05,   6),
    (3984, 'F', 0.585, 0.45,  0.125, 0.874,  0.3545, 0.2075, 0.225,  6),
    (2551, 'I', 0.28,  0.22,  0.08,  0.1315, 0.066,  0.024,  0.03,   5),
    (1246, 'I', 0.385, 0.28,  0.09,  0.228,  0.1025, 0.042,  0.0655, 5),
    (519,  'M', 0.325, 0.23,  0.09,  0.147,  0.06,   0.034,  0.045,  4),
    (2382, 'M', 0.155, 0.115, 0.025, 0.024,  0.009,  0.005,  0.0075, 5),
    (698,  'M', 0.28,  0.205, 0.1,   0.1165, 0.0545, 0.0285, 0.03,   5),
    (2381, 'M', 0.175, 0.135, 0.04,  0.0305, 0.011,  0.0075, 0.01,   5),
    (516,  'M', 0.27,  0.195, 0.08,  0.1,    0.0385, 0.0195, 0.03,   6);
-- An integer dependent variable ('rings') is passed together with TRUE as the
-- last argument; the packed dependent array must then be wider than one column
-- (upper bound of dimension 2 greater than 1), i.e. one-hot encoded.
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'rings',
    'ARRAY[diameter,height,whole,shucked,viscera,shell]',
    NULL,
    4,
    TRUE
);
SELECT assert(
    array_upper(dependent_varname, 2) > 1,
    'One hot encoding with one_hot_encode_int_dep_var=TRUE is not working'
)
FROM minibatch_preprocessing_out;
-- A double precision dependent variable must not be one-hot encoded even when
-- TRUE is passed as the last argument. Every ring count in the fixture is
-- greater than 1, so the check requires every packed value to exceed 1.
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'rings::double precision',
    'ARRAY[diameter,height,whole,shucked,viscera,shell]',
    NULL,
    4,
    TRUE
);
SELECT assert(
    1 < ALL(dependent_varname),
    'Double precision values are being incorrectly encoded'
)
FROM minibatch_preprocessing_out;
-- 10 input rows packed with buffer_size = 4 must produce ceil(10/4) = 3 rows.
\set expected_row_count 3
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'length>0.2',
    'ARRAY[diameter,height,whole,shucked,viscera,shell]',
    NULL,
    4
);
-- The failure message literal spans two lines; its continuation stays at
-- column 0 so the message text is unchanged.
SELECT assert(
    row_count = :expected_row_count,
    'Row count validation failed.
Expected:' || :expected_row_count || ' Actual: ' || row_count
)
FROM (
    SELECT count(*) AS row_count
    FROM minibatch_preprocessing_out
) AS packed;
-- Assert the dimensions of both the dependent and the independent variable for
-- the run without grouping: super rows hold 2, 4 and 4 records, the dependent
-- array is 2 columns wide and the independent array 6 columns wide.
SELECT assert_col_dimension('dependent_varname', 1, '{2,4,4}', NULL);
SELECT assert_col_dimension('dependent_varname', 2, '{2,2,2}', NULL);
SELECT assert_col_dimension('independent_varname', 1, '{2,4,4}', NULL);
SELECT assert_col_dimension('independent_varname', 2, '{6,6,6}', NULL);
-- Validate the summary table. dependent_vartype is checked here as well so
-- that this assertion covers the same columns as the summary check of the
-- grouping run on the same dependent expression.
SELECT assert(
    source_table = 'minibatch_preprocessing_input' AND
    output_table = 'minibatch_preprocessing_out' AND
    dependent_varname = 'length>0.2' AND
    independent_varname = 'ARRAY[diameter,height,whole,shucked,viscera,shell]' AND
    dependent_vartype = 'boolean' AND
    buffer_size = 4 AND
    class_values = '{f,t}' AND -- we sort the class values in python
    num_rows_processed = 10 AND
    num_missing_rows_skipped = 0 AND
    grouping_cols IS NULL,
    'Summary Validation failed. Actual:' || __to_char(summary)
)
FROM (SELECT * FROM minibatch_preprocessing_out_summary) summary;
-- Grouping by rings on the same dataset: the sorted number of packed rows per
-- group is expected to be 1,1,2.
\set expected_grouping_row_count '\'' 1,1,2 '\''
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'length>0.2',
    'ARRAY[diameter,height,whole,shucked,viscera,shell]',
    'rings',
    4
);
-- Count packed rows per group, sort the counts and join them into one
-- comma-separated string that is compared against the expected value.
SELECT assert(
    str_row_count = :expected_grouping_row_count,
    'Row count validation failed for minibatch_preprocessing grouping.
Expected:' || :expected_grouping_row_count || ' Actual: ' || str_row_count
)
FROM (
    SELECT array_to_string(array_agg(row_count ORDER BY row_count ASC), ',') AS str_row_count
    FROM (
        SELECT count(*) AS row_count
        FROM minibatch_preprocessing_out
        GROUP BY rings
        ORDER BY rings
    ) AS per_group
) AS joined_counts;
-- Dimension checks for the grouped run. Each helper call orders its sizes by
-- the grouping column and then by size, so one array string describes all
-- groups at once.
SELECT assert_col_dimension('dependent_varname', 1, '{1,1,4,4}', 'rings');
SELECT assert_col_dimension('dependent_varname', 2, '{2,2,2,2}', 'rings');
SELECT assert_col_dimension('independent_varname', 1, '{1,1,4,4}', 'rings');
SELECT assert_col_dimension('independent_varname', 2, '{6,6,6,6}', 'rings');
-- Summary table contents for the grouped run.
SELECT assert(
    source_table = 'minibatch_preprocessing_input' AND
    output_table = 'minibatch_preprocessing_out' AND
    dependent_varname = 'length>0.2' AND
    independent_varname = 'ARRAY[diameter,height,whole,shucked,viscera,shell]' AND
    dependent_vartype = 'boolean' AND
    buffer_size = 4 AND
    class_values = '{f,t}' AND -- we sort the class values in python
    num_rows_processed = 10 AND
    num_missing_rows_skipped = 0 AND
    grouping_cols = 'rings',
    'Summary Validation failed for grouping col. Actual:' || __to_char(summary)
)
FROM (SELECT * FROM minibatch_preprocessing_out_summary) summary;
-- Both queries fail if the corresponding table was not created.
-- Standardization table:
SELECT count(*)
FROM minibatch_preprocessing_out_standardization;
-- Summary table:
SELECT count(*)
FROM minibatch_preprocessing_out_summary;
-- Test null values in x and y, first without grouping. The input is reloaded
-- with nine records, six of which carry a NULL in sex or in one of the
-- independent columns.
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
TRUNCATE TABLE minibatch_preprocessing_input;
INSERT INTO minibatch_preprocessing_input
    (id, sex, length, diameter, height, whole, shucked, viscera, shell, rings)
VALUES
    (1040, 'F',  0.66,  0.475, 0.18, NULL,   0.641,  0.294,  0.335,  5),
    (3984, NULL, 0.585, 0.45,  0.25, 0.874,  0.3545, 0.2075, 0.225,  5),
    (861,  'M',  0.595, 0.475, NULL, 1.1405, 0.547,  0.231,  0.271,  5),
    (932,  NULL, 0.445, 0.335, 0.11, 0.4355, 0.2025, 0.1095, 0.1195, 5),
    (698,  'F',  0.445, 0.335, 0.11, 0.4355, 0.2025, 0.1095, 0.1195, 6),
    (698,  'F',  0.445, 0.335, 0.11, 0.4355, NULL,   0.1095, 0.1195, 6),
    (698,  'F',  0.445, 0.335, 0.11, 0.4355, 0.2025, 0.1095, 0.1195, 6),
    (922,  NULL, 0.445, 0.335, 0.11, NULL,   0.2025, 0.1095, 0.1195, 5),
    (942,  'F',  0.445, 0.335, 0.11, 0.25,   0.2025, 0.1095, 0.1195, 5);
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'sex',
    'ARRAY[length,diameter,height,whole,shucked,viscera,shell]',
    NULL,
    2
);
-- Two packed rows are expected with buffer_size = 2.
SELECT assert(
    row_count = 2,
    'Row count validation failed with null values.
Expected:' || 2 || ' Actual: ' || row_count
)
FROM (
    SELECT count(*) AS row_count
    FROM minibatch_preprocessing_out
) AS packed;
-- The summary must report 3 processed rows and 6 skipped rows.
SELECT assert(
    num_rows_processed = 3 AND num_missing_rows_skipped = 6,
    'Rows processed/skipped validation failed_summary.
Actual num_rows_processed:' || num_rows_processed || ', Actual num_missing_rows_skipped: ' || num_missing_rows_skipped
)
FROM (SELECT * FROM minibatch_preprocessing_out_summary) AS summary_row;
-- Same NULL-laden input as the preceding test, now grouped by rings; the same
-- row count and processed/skipped numbers are expected.
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'sex',
    'ARRAY[length,diameter,height,whole,shucked,viscera,shell]',
    'rings',
    2
);
SELECT assert(
    row_count = 2,
    'Row count validation failed with null values and grouping.
Expected:' || 2 || ' Actual: ' || row_count
)
FROM (
    SELECT count(*) AS row_count
    FROM minibatch_preprocessing_out
) AS packed;
SELECT assert(
    num_rows_processed = 3 AND num_missing_rows_skipped = 6,
    'Rows processed/skipped validation failed_summary with grouping.
Actual num_rows_processed:' || num_rows_processed || ', Actual num_missing_rows_skipped: ' || num_missing_rows_skipped
)
FROM (SELECT * FROM minibatch_preprocessing_out_summary) AS summary_row;
-- Test standardization on a two-row, two-feature input.
DROP TABLE IF EXISTS minibatch_preprocessing_input;
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
CREATE TABLE minibatch_preprocessing_input (
    x1 INTEGER,
    x2 INTEGER,
    y  TEXT
);
INSERT INTO minibatch_preprocessing_input (x1, x2, y)
VALUES
    (2, 10, 'y1'),
    (4, 30, 'y2');
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'y',
    'ARRAY[x1,x2]',
    NULL,
    2
);
-- The order in which the two records get packed is not deterministic, so both
-- possible orders of the standardized values are accepted.
\set expected_normalized_independent_var1 '\'' {{-1, -1},{1, 1}} '\''
\set expected_normalized_independent_var2 '\'' {{1, 1},{-1, -1}} '\''
SELECT assert(
    independent_varname = :expected_normalized_independent_var1 OR
    independent_varname = :expected_normalized_independent_var2,
    'Standardization check failed. Actual: ' || independent_varname
)
FROM (
    SELECT __to_char(independent_varname) AS independent_varname
    FROM minibatch_preprocessing_out
) AS standardized;
-- Both queries fail if the corresponding table was not created.
-- Standardization table:
SELECT count(*)
FROM minibatch_preprocessing_out_standardization;
-- Summary table:
SELECT count(*)
FROM minibatch_preprocessing_out_summary;
-- Test for array values in indep column: the independent variable is an
-- existing double precision[] column instead of an ARRAY[...] expression.
DROP TABLE IF EXISTS minibatch_preprocessing_input;
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
-- The table is freshly created here, so no TRUNCATE is needed before loading.
CREATE TABLE minibatch_preprocessing_input (
    id         INTEGER,
    sex        TEXT,
    attributes DOUBLE PRECISION[],
    rings      INTEGER
);
-- rings is not part of the column list and therefore stays NULL in every row.
INSERT INTO minibatch_preprocessing_input (id, sex, attributes)
VALUES
    (1040, 'F',  ARRAY[0.66,0.475,0.18,NULL,0.641,0.294,0.335]),
    (3160, 'F',  ARRAY[0.34,0.35,0.085,0.204,0.097,0.021,0.05]),
    (3984, NULL, ARRAY[0.585,0.45,0.25,0.874,0.3545,0.2075,0.225]),
    (861,  'M',  ARRAY[0.595,0.475,NULL,1.1405,0.547,0.231,0.271]),
    (932,  NULL, ARRAY[0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195]),
    (NULL, 'F',  ARRAY[0.445,0.335,0.11,0.4355,0.2025,0.1095,0.1195]),
    (922,  NULL, ARRAY[0.445,0.335,0.11,NULL,0.2025,0.1095,0.1195]);
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'sex',
    'attributes',
    NULL,
    1
);
-- With buffer_size = 1 two packed rows are expected.
SELECT assert(
    row_count = 2,
    'Row count validation failed with array values in independent variable.
Expected:' || 2 || ' Actual: ' || row_count
)
FROM (
    SELECT count(*) AS row_count
    FROM minibatch_preprocessing_out
) s;
-- Test for array values in dep column: the array column 'attributes' is used
-- as the dependent variable and ARRAY[id] as the independent variable.
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'attributes',
    'ARRAY[id]',
    NULL,
    1
);
-- With buffer_size = 1 three packed rows are expected.
SELECT assert(
    row_count = 3,
    'Row count validation failed with array values in dependent variable.
Expected:' || 3 || ' Actual: ' || row_count
)
FROM (
    SELECT count(*) AS row_count
    FROM minibatch_preprocessing_out
) AS packed;
-- Test for null buffer size: the buffer_size argument is omitted, and the
-- buffer_size recorded in the summary table must equal the largest number of
-- records packed into a row of the output table.
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'attributes',
    'ARRAY[id]'
);
-- Every output row is paired with every summary row, then the maxima are
-- taken per buffer_size.
SELECT assert(
    ind_var_rows = dep_var_rows AND ind_var_rows = buffer_size,
    'Row count validation failed for null buffer size.
Buffer size from summary table: ' || buffer_size || ' does not match the output table:'
    || ind_var_rows
)
FROM (
    SELECT
        max(array_upper(o.dependent_varname, 1))   AS dep_var_rows,
        max(array_upper(o.independent_varname, 1)) AS ind_var_rows,
        s1.buffer_size
    FROM minibatch_preprocessing_out AS o
    CROSS JOIN minibatch_preprocessing_out_summary AS s1
    GROUP BY s1.buffer_size
) AS sizes;
-- Test default buffer size calculator: a single input row and no buffer_size
-- argument; the summary table is expected to record buffer_size = 1.
DROP TABLE IF EXISTS minibatch_preprocessing_input;
CREATE TABLE minibatch_preprocessing_input (
    sex    TEXT,
    length DOUBLE PRECISION[],
    rings  INTEGER
);
INSERT INTO minibatch_preprocessing_input (sex, length, rings)
VALUES
    ('F', ARRAY[0.66, 0.5], 6);
DROP TABLE IF EXISTS
    minibatch_preprocessing_out,
    minibatch_preprocessing_out_standardization,
    minibatch_preprocessing_out_summary;
SELECT minibatch_preprocessor(
    'minibatch_preprocessing_input',
    'minibatch_preprocessing_out',
    'sex',
    'length',
    'rings'
);
-- The failure message names this test (default buffer size) so that a failure
-- is not mistaken for the special-characters test, which used the same text.
SELECT assert(
    source_table = 'minibatch_preprocessing_input' AND
    output_table = 'minibatch_preprocessing_out' AND
    dependent_varname = 'sex' AND
    independent_varname = 'length' AND
    buffer_size = 1 AND
    class_values = '{F}' AND -- we sort the class values in python
    num_rows_processed = 1 AND
    num_missing_rows_skipped = 0 AND
    grouping_cols = 'rings',
    'Summary Validation failed for default buffer size. Actual:' || __to_char(summary)
)
FROM (SELECT * FROM minibatch_preprocessing_out_summary) summary;
-- Test special characters and unicode characters in independent_var, dependent_var and grouping_cols, both for column name and column value
-- The three quoted column names below contain dollar signs, single quotes,
-- punctuation and a Cyrillic letter.
DROP TABLE IF EXISTS minibatch_preprocessing_input;
CREATE TABLE minibatch_preprocessing_input(
"se$$''x" TEXT,
"len$$'%*()gth" DOUBLE PRECISION[],
"rin$$Ж!#'gs" INTEGER);
-- Seven rows whose dependent values contain quotes, commas, braces, a pipe and
-- a Cyrillic letter; all rows share the same array and grouping value.
INSERT INTO minibatch_preprocessing_input VALUES
('M''M',ARRAY[0.66, 0.5],6),
('''M''M''',ARRAY[0.66, 0.5],6),
('M|$$M',ARRAY[0.66, 0.5],6),
('M,M',ARRAY[0.66, 0.5],6),
('M@[}(:*;M',ARRAY[0.66, 0.5],6),
('M"M',ARRAY[0.66, 0.5],6),
('MЖM',ARRAY[0.66, 0.5],6);
DROP TABLE IF EXISTS minibatch_preprocessing_out, minibatch_preprocessing_out_standardization, minibatch_preprocessing_out_summary;
-- The column names are passed as string arguments with their double quotes;
-- every single quote inside a name is doubled because it sits in a string
-- literal.
SELECT minibatch_preprocessor('minibatch_preprocessing_input', 'minibatch_preprocessing_out', '"se$$''''x"', '"len$$''%*()gth"', '"rin$$Ж!#''gs"', 4);
-- The summary must echo the quoted names exactly as passed and list all seven
-- class values. The expected class_values literal is dollar-quoted with a
-- custom tag so that the quotes inside it need no escaping.
-- NOTE(review): the ORDER BY at the end applies to the summary table, which
-- presumably holds a single row here - confirm whether it is needed.
SELECT assert(
source_table = 'minibatch_preprocessing_input' AND
output_table = 'minibatch_preprocessing_out' AND
dependent_varname = '"se$$''''x"' AND
independent_varname = '"len$$''%*()gth"' AND
buffer_size = 4 AND
class_values = $__madlib__${ 'M'M', M\"M, M'M, "M,M", "M@[}(:*;M", "M|$$M", MЖM }$__madlib__$ AND
-- we sort the class values in python
num_rows_processed = 7 AND
num_missing_rows_skipped = 0 AND
grouping_cols = '"rin$$Ж!#''gs"',
'Summary Validation failed for special chars. Actual:' || __to_char(summary)
) from (select * from minibatch_preprocessing_out_summary order by class_values) summary;