blob: 0d8a116d3b59a11c8567ad1f7bcfc2b26e37ac3a [file] [log] [blame]
/* ----------------------------------------------------------------------- *//**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*//* ----------------------------------------------------------------------- */
-- NOTE that the batch specific tables were created using:
-- madlib.minibatch_preprocessor(), with the regular source tables used in
-- this file.
-- Create preprocessed data that can be used with minibatch MLP:
DROP TABLE IF EXISTS iris_data_batch, iris_data_batch_summary, iris_data_batch_standardization;
CREATE TABLE iris_data_batch(
__id__ integer,
dependent_varname double precision[],
independent_varname double precision[]
);
COPY iris_data_batch (__id__, dependent_varname, independent_varname) FROM STDIN NULL '?' DELIMITER '|';
0 | {{0,1,0},{0,1,0},{0,0,1},{1,0,0},{0,1,0},{0,1,0},{0,0,1},{1,0,0},{1,0,0},{0,1,0},{1,0,0},{0,0,1},{0,0,1},{0,0,1},{1,0,0},{0,0,1},{0,0,1},{1,0,0},{1,0,0},{0,0,1},{0,1,0},{0,0,1},{0,0,1},{0,0,1},{0,0,1},{1,0,0},{0,1,0},{0,0,1},{0,0,1},{1,0,0}} | {{0.828881825720994,-0.314980522532101,0.363710790466334,0.159758615207397},{-1.08079689039279,-1.57669227467446,-0.229158821743702,-0.240110581430527},{-1.08079689039279,-1.32434992424599,0.482284712908341,0.692917544057962},{-1.46273263361555,0.442046528753317,-1.35561108494277,-1.30642843913166},{-0.0623015751321059,-0.567322872960574,0.245136868024327,0.159758615207397},{-0.189613489539692,-0.819665223389045,0.304423829245331,0.159758615207397},{0.701569911313408,-1.32434992424599,0.778719519013359,0.959497008483245},{-1.20810880480038,-0.0626381721036282,-1.35561108494277,-1.4397181713443},{-0.698861147170034,0.946731229610261,-1.35561108494277,-1.30642843913166},{-0.82617306157762,-1.32434992424599,-0.407019705406713,-0.106820849217886},{-0.698861147170034,2.71312768260957,-1.29632412372177,-1.4397181713443},{1.33812948335134,0.442046528753317,1.31230217000239,1.49265593733381},{0.319634168090651,-0.0626381721036282,0.660145596571352,0.826207276270604},{0.701569911313408,-1.32434992424599,0.778719519013359,0.959497008483245},{-0.698861147170034,1.19907358003873,-1.29632412372177,-1.30642843913166},{1.46544139775892,0.189704178324845,0.838006480234363,1.49265593733381},{1.21081756894375,-0.0626381721036282,0.897293441455367,1.49265593733381},{-0.444237318354863,1.70375828089568,-1.29632412372177,-1.30642843913166},{-0.82617306157762,1.95610063132415,-1.05917627883775,-1.03984897470638},{0.828881825720994,-0.819665223389045,0.95658040267637,0.959497008483245},{0.956193740128579,-0.567322872960574,0.541571674129345,0.42633807963268},{1.33812948335134,0.442046528753317,1.31230217000239,1.49265593733381},{0.574257996905822,0.946731229610261,1.01586736389737,1.49265593733381},{0.0650103392754793,-0.819665223389045,0.838006480234363,0.959497008483245},{0.0650103392754793,-0.819665223389045,0.838006480234363,0.959497008483245},{-1.46273263361555,0.442046528753317,-1.35561108494277,-1.30642843913166},{0.574257996905822,-2.08137697553141,0.482284712908341,0.42633807963268},{1.21081756894375,0.189704178324845,1.13444128633938,1.62594566954645},{1.97468905538926,-0.314980522532101,1.54945001488641,0.826207276270604},{-1.08079689039279,0.189704178324845,-1.29632412372177,-1.4397181713443}}
1 | {{0,1,0},{1,0,0},{0,1,0},{1,0,0},{1,0,0},{1,0,0},{1,0,0},{0,1,0},{0,0,1},{0,0,1},{1,0,0},{0,0,1},{1,0,0},{0,0,1},{0,1,0},{0,1,0},{0,1,0},{1,0,0},{1,0,0},{0,0,1},{0,1,0},{0,1,0},{0,0,1},{1,0,0},{1,0,0},{0,1,0},{1,0,0},{0,0,1},{0,1,0},{0,1,0}} | {{-0.0623015751321059,-0.0626381721036282,0.304423829245331,0.0264688829947554},{-0.316925403947277,2.96547003303804,-1.35561108494277,-1.30642843913166},{0.319634168090651,-0.819665223389045,0.838006480234363,0.559627811845321},{-0.953484975985206,1.19907358003873,-1.41489804616377,-1.17313870691902},{-0.953484975985206,0.442046528753317,-1.47418500738478,-1.30642843913166},{-1.33542071920796,0.442046528753317,-1.41489804616377,-1.30642843913166},{-1.71735646243072,-0.0626381721036282,-1.41489804616377,-1.30642843913166},{0.446946082498236,-0.0626381721036282,0.541571674129345,0.293048347420038},{1.21081756894375,-1.32434992424599,1.25301520878139,0.826207276270604},{0.701569911313408,0.694388879181789,1.3715891312234,1.75923540175909},{-1.84466837683831,-0.0626381721036282,-1.53347196860578,-1.4397181713443},{1.84737714098168,1.45141593046721,1.4308760924444,1.75923540175909},{-0.82617306157762,1.19907358003873,-1.35561108494277,-1.30642843913166},{0.701569911313408,-0.314980522532101,1.13444128633938,0.826207276270604},{1.33812948335134,-0.567322872960574,0.660145596571352,0.293048347420038},{0.192322253683066,-0.0626381721036282,0.304423829245331,0.42633807963268},{-0.189613489539692,-0.819665223389045,0.304423829245331,0.159758615207397},{-1.46273263361555,0.189704178324845,-1.29632412372177,-1.30642843913166},{-1.71735646243072,0.442046528753317,-1.41489804616377,-1.30642843913166},{0.828881825720994,0.189704178324845,1.07515432511838,0.826207276270604},{0.0650103392754793,-1.07200757381752,0.185849906803323,0.0264688829947554},{-0.953484975985206,-2.58606167638835,-0.110584899301695,-0.240110581430527},{0.192322253683066,-0.0626381721036282,0.838006480234363,0.826207276270604},{-0.953484975985206,1.19907358003873,-1.23703716250076,-0.773269510281093},{-0.82617306157762,0.946731229610261,-1.29632412372177,-1.30642843913166},{0.319634168090651,0.946731229610261,0.482284712908341,0.559627811845321},{-0.953484975985206,0.694388879181789,-1.35561108494277,-1.30642843913166},{0.192322253683066,-0.0626381721036282,0.838006480234363,0.826207276270604},{0.446946082498236,-0.314980522532101,0.600858635350349,0.293048347420038},{-0.0623015751321059,-0.567322872960574,0.482284712908341,0.159758615207397}}
2 | {{1,0,0},{1,0,0},{0,0,1},{1,0,0},{0,1,0},{0,1,0},{1,0,0},{1,0,0},{0,0,1},{0,0,1},{0,1,0},{0,1,0},{0,1,0},{0,1,0},{1,0,0},{0,0,1},{0,1,0},{1,0,0},{1,0,0},{0,0,1},{1,0,0},{0,0,1},{1,0,0},{1,0,0},{1,0,0},{1,0,0},{0,1,0},{1,0,0},{0,1,0},{0,0,1}} | {{-0.953484975985206,0.946731229610261,-1.23703716250076,-1.03984897470638},{-0.953484975985206,0.694388879181789,-1.35561108494277,-1.30642843913166},{1.21081756894375,0.694388879181789,1.19372824756038,1.75923540175909},{-1.20810880480038,0.946731229610261,-1.23703716250076,-1.30642843913166},{1.08350565453616,-0.314980522532101,0.541571674129345,0.159758615207397},{-0.189613489539692,-0.314980522532101,-0.0512979380806911,0.159758615207397},{-1.20810880480038,-0.0626381721036282,-1.35561108494277,-1.17313870691902},{-1.08079689039279,0.189704178324845,-1.29632412372177,-1.4397181713443},{0.956193740128579,-0.0626381721036282,0.897293441455367,1.09278674069589},{0.956193740128579,-0.0626381721036282,0.897293441455367,1.09278674069589},{1.46544139775892,0.189704178324845,0.719432557792356,0.42633807963268},{0.0650103392754793,-1.07200757381752,0.185849906803323,0.0264688829947554},{1.08350565453616,-0.0626381721036282,0.422997751687338,0.293048347420038},{0.319634168090651,-0.314980522532101,0.482284712908341,0.42633807963268},{-0.82617306157762,1.95610063132415,-1.23703716250076,-1.30642843913166},{0.956193740128579,-0.0626381721036282,1.25301520878139,1.35936620512117},{-0.0623015751321059,-1.07200757381752,-0.110584899301695,-0.240110581430527},{-0.571549232762449,1.70375828089568,-1.29632412372177,-1.30642843913166},{-0.571549232762449,1.70375828089568,-1.29632412372177,-1.30642843913166},{2.35662479861202,-0.0626381721036282,1.72731089854942,1.22607647290853},{-1.71735646243072,0.442046528753317,-1.41489804616377,-1.30642843913166},{1.72006522657409,-0.0626381721036282,1.31230217000239,1.22607647290853},{-0.953484975985206,0.946731229610261,-1.29632412372177,-1.30642843913166},{-1.46273263361555,0.946731229610261,-1.35561108494277,-1.17313870691902},{-1.08079689039279,-0.0626381721036282,-1.35561108494277,-1.30642843913166},{-0.953484975985206,1.45141593046721,-1.35561108494277,-1.30642843913166},{0.701569911313408,-1.82903462510294,0.422997751687338,0.159758615207397},{-0.444237318354863,2.20844298175262,-1.17775020127976,-1.03984897470638},{-0.0623015751321059,-0.314980522532101,0.304423829245331,0.159758615207397},{1.33812948335134,-0.0626381721036282,1.07515432511838,1.22607647290853}}
3 | {{0,1,0},{0,1,0},{0,1,0},{0,1,0},{1,0,0},{1,0,0},{0,0,1},{0,1,0},{1,0,0},{0,1,0},{1,0,0},{0,1,0},{0,1,0},{0,0,1},{1,0,0},{0,1,0},{0,1,0},{0,0,1},{1,0,0},{1,0,0},{0,1,0},{0,0,1},{1,0,0},{0,1,0},{0,0,1},{0,1,0},{0,0,1},{0,1,0},{0,1,0},{0,1,0}} | {{-0.953484975985206,-1.82903462510294,-0.229158821743702,-0.240110581430527},{0.319634168090651,-2.08137697553141,0.185849906803323,-0.240110581430527},{-0.189613489539692,-0.0626381721036282,0.482284712908341,0.42633807963268},{-0.316925403947277,-1.07200757381752,0.422997751687338,0.0264688829947554},{-0.953484975985206,1.19907358003873,-1.23703716250076,-0.773269510281093},{-0.316925403947277,1.19907358003873,-1.41489804616377,-1.30642843913166},{0.0650103392754793,-0.819665223389045,0.838006480234363,0.959497008483245},{0.446946082498236,-0.0626381721036282,0.541571674129345,0.293048347420038},{-0.444237318354863,0.946731229610261,-1.29632412372177,-1.03984897470638},{1.21081756894375,0.189704178324845,0.600858635350349,0.42633807963268},{-0.82617306157762,1.95610063132415,-1.23703716250076,-1.30642843913166},{-0.0623015751321059,-0.567322872960574,0.245136868024327,0.159758615207397},{-0.316925403947277,-1.82903462510294,0.185849906803323,0.159758615207397},{1.21081756894375,-0.0626381721036282,0.897293441455367,1.49265593733381},{-1.59004454802313,-1.82903462510294,-1.41489804616377,-1.17313870691902},{0.701569911313408,0.694388879181789,0.600858635350349,0.559627811845321},{-0.316925403947277,-1.57669227467446,0.00798902314031256,-0.240110581430527},{1.46544139775892,0.189704178324845,1.01586736389737,1.22607647290853},{-1.08079689039279,0.189704178324845,-1.29632412372177,-1.4397181713443},{-1.71735646243072,-0.314980522532101,-1.35561108494277,-1.30642843913166},{-0.444237318354863,-0.0626381721036282,0.482284712908341,0.42633807963268},{1.72006522657409,-0.0626381721036282,1.31230217000239,1.22607647290853},{-0.82617306157762,1.95610063132415,-1.05917627883775,-1.03984897470638},{1.21081756894375,-0.0626381721036282,0.778719519013359,0.692917544057962},{2.35662479861202,-0.0626381721036282,1.72731089854942,1.22607647290853},{-0.953484975985206,-1.82903462510294,-0.229158821743702,-0.240110581430527},{0.701569911313408,-0.314980522532101,1.13444128633938,0.826207276270604},{-0.698861147170034,-0.819665223389045,0.12656294558232,0.293048347420038},{-0.0623015751321059,-0.314980522532101,0.304423829245331,0.159758615207397},{0.574257996905822,-0.314980522532101,0.363710790466334,0.159758615207397}}
4 | {{0,0,1},{0,1,0},{0,0,1},{0,1,0},{1,0,0},{0,1,0},{0,1,0},{0,0,1},{0,0,1},{0,1,0},{0,0,1},{0,0,1},{1,0,0},{0,1,0},{0,1,0},{0,1,0},{0,1,0},{0,1,0},{0,0,1},{1,0,0},{0,1,0}} | {{1.21081756894375,0.694388879181789,1.19372824756038,1.75923540175909},{-0.82617306157762,-1.32434992424599,-0.407019705406713,-0.106820849217886},{0.701569911313408,0.694388879181789,1.3715891312234,1.75923540175909},{0.0650103392754793,-0.819665223389045,0.245136868024327,-0.240110581430527},{-1.20810880480038,0.189704178324845,-1.23703716250076,-1.30642843913166},{0.574257996905822,-0.314980522532101,0.363710790466334,0.159758615207397},{1.21081756894375,0.189704178324845,0.422997751687338,0.293048347420038},{1.97468905538926,-0.314980522532101,1.54945001488641,0.826207276270604},{-1.08079689039279,-1.32434992424599,0.482284712908341,0.692917544057962},{0.0650103392754793,-0.819665223389045,0.12656294558232,0.0264688829947554},{0.574257996905822,0.946731229610261,1.01586736389737,1.49265593733381},{0.956193740128579,0.442046528753317,0.838006480234363,1.09278674069589},{-1.20810880480038,-0.0626381721036282,-1.35561108494277,-1.17313870691902},{0.828881825720994,0.442046528753317,0.482284712908341,0.42633807963268},{-0.0623015751321059,-0.0626381721036282,0.304423829245331,0.0264688829947554},{-0.316925403947277,-1.57669227467446,0.0672759843613159,-0.106820849217886},{-0.189613489539692,-0.0626381721036282,0.245136868024327,0.159758615207397},{1.59275331216651,0.442046528753317,0.600858635350349,0.293048347420038},{0.956193740128579,-0.0626381721036282,1.25301520878139,1.35936620512117},{-1.33542071920796,0.442046528753317,-1.23703716250076,-1.30642843913166},{-0.316925403947277,-1.32434992424599,0.185849906803323,0.159758615207397}}
\.
-- Create the corresponding summary table for preprocessed data
CREATE TABLE iris_data_batch_summary(
source_table text,
output_table text,
dependent_varname text,
independent_varname text,
dependent_vartype text,
buffer_size integer,
class_values text[],
num_rows_processed integer,
num_rows_skipped integer,
grouping_cols text
);
-- The availability of the original source table should not be a condition for
-- MLP to work correctly. It should work fine even the original source table is
-- deleted (this basically ensures that all the necessary info is captured in
-- the summary table). So name the original source table as
-- 'iris_data_does_not_exist' instead of the original 'iris_data', to mimic the
-- scenario where the original source table is deleted and MLP is trained with
-- the preprocessed table.
INSERT INTO iris_data_batch_summary VALUES
('iris_data_does_not_exist', 'iris_data_batch', 'class::TEXT', 'attributes',
'text', 30, ARRAY[1,2,3], 141, 0, '');
-- Create the corresponding standardization table for preprocessed data
CREATE TABLE iris_data_batch_standardization(
mean double precision[],
std double precision[]
);
INSERT INTO iris_data_batch_standardization VALUES
-- -- TODO get real numbers by running preprocessor
(ARRAY[5.74893617021,3.02482269504,3.6865248227,1.18014184397],
ARRAY[0.785472439601,0.396287027644,1.68671151195,0.750245336531]);
-- -- minibatch without grouping and without warm_start
DROP TABLE IF EXISTS mlp_class_batch, mlp_class_batch_summary, mlp_class_batch_standardization;
SELECT mlp_classification(
'iris_data_batch', -- Source table
'mlp_class_batch', -- Desination table
'independent_varname', -- Input features
'dependent_varname', -- Label
ARRAY[5], -- Number of units per layer
'learning_rate_init=0.1,
learning_rate_policy=constant,
n_iterations=5,
n_tries=3,
tolerance=0,
n_epochs=20',
'sigmoid',
'',
False,
False
);