#!/usr/bin/env perl
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
use strict;
use warnings;
use Carp ();
use Hash::Ordered;
use AI::MXNet qw(mx);
use AI::MXNet::Base qw(pdl enumerate digitize hash array_index range);
use Getopt::Long qw(HelpMessage);
GetOptions(
'log-interval=i' => \(my $log_interval = 100),
'optimizer=s' => \(my $optimizer = 'adam'),
'cuda' => \(my $cuda = 0 ),
'num-epoch=i' => \(my $num_epoch = 10 ),
'batch-size=i' => \(my $batch_size = 100),
'lr=f' => \(my $lr = 0.001 ),
'help' => sub { HelpMessage(0) },
) or HelpMessage(1);
=head1 NAME

train.pl - Run sparse wide and deep classification

=head1 SYNOPSIS

    --log-interval  number of batches to wait before logging training status (default: 100)
    --optimizer     optimizer to use: one of 'adam', 'sgd', 'ftrl' (default: 'adam')
    --cuda          train on GPU with CUDA (default: 0)
    --num-epoch     number of epochs to train (default: 10)
    --batch-size    number of examples per batch (default: 100)
    --lr            learning rate (default: 0.001)

=cut
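=head1 EXAMPLE

A typical invocation (a sketch; it assumes the UCI Adult data files have
already been downloaded to ./data/adult.data and ./data/adult.test, the
paths configured in %ADULT below):

    perl train.pl --num-epoch 5 --optimizer sgd

=cut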
my %allowed_optimizers = qw(adam 1 sgd 1 ftrl 1);
Carp::confess("optimizer can only be one of 'adam', 'sgd', 'ftrl'")
unless exists $allowed_optimizers{ $optimizer };
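# The choices are limited to optimizers with sparse (row_sparse) update
# support in MXNet, so only the weight rows touched by a batch are updated.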
sub wide_deep_model
{
my ($num_linear_features, $num_embed_features, $num_cont_features,
$input_dims, $hidden_units) = @_;
# wide model
my $csr_data = mx->symbol->Variable("csr_data", stype=>'csr');
my $label = mx->symbol->Variable("softmax_label");
my $norm_init = mx->initializer->Normal(sigma=>0.01);
# weight with row_sparse storage type to enable sparse gradient updates
my $weight = mx->symbol->Variable("linear_weight", shape=>[$num_linear_features, 2],
init=>$norm_init, stype=>'row_sparse');
my $bias = mx->symbol->Variable("linear_bias", shape=>[2]);
my $dot = mx->symbol->sparse->dot($csr_data, $weight);
my $linear_out = mx->symbol->broadcast_add($dot, $bias);
# deep model
my $dns_data = mx->symbol->Variable("dns_data");
# embedding features
my $x = mx->symbol->slice(data=>$dns_data, begin=>[0, 0],
end=>[undef, $num_embed_features]);
my $embeds = mx->symbol->split(data=>$x, num_outputs=>$num_embed_features, squeeze_axis=>1);
# continuous features
$x = mx->symbol->slice(data=>$dns_data, begin=>[0, $num_embed_features],
end=>[undef, $num_embed_features + $num_cont_features]);
my @features = ($x);
enumerate(sub {
my ($i, $embed) = @_;
my $embed_weight = mx->symbol->Variable("embed_${i}_weight", stype=>'row_sparse');
push @features, mx->symbol->contrib->SparseEmbedding(data=>$embed, weight=>$embed_weight,
input_dim=>$input_dims->[$i], output_dim=>$hidden_units->[0]);
}, $embeds);
my $hidden = mx->symbol->concat(@features, dim=>1);
$hidden = mx->symbol->FullyConnected(data=>$hidden, num_hidden=>$hidden_units->[1]);
$hidden = mx->symbol->Activation(data=>$hidden, act_type=>'relu');
$hidden = mx->symbol->FullyConnected(data=>$hidden, num_hidden=>$hidden_units->[2]);
$hidden = mx->symbol->Activation(data=>$hidden, act_type=>'relu');
my $deep_out = mx->symbol->FullyConnected(data=>$hidden, num_hidden=>2);
my $out = mx->symbol->SoftmaxOutput($linear_out + $deep_out, $label, name=>'model');
return $out;
}
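# A minimal sketch of inspecting the composed symbol (plain AI::MXNet::Symbol
# API); handy when changing the feature layout. Uses the same constants as
# %ADULT below:
#
#   my $sym = wide_deep_model(3000, 2, 38, [1000, 1000], [8, 50, 100]);
#   print join("\n", @{ $sym->list_arguments }), "\n";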
sub preprocess_uci_adult
{
my ($data_name) = @_;
# Some tricks of feature engineering are adapted
# from tensorflow's wide and deep tutorial.
my @csv_columns = (
"age", "workclass", "fnlwgt", "education", "education_num",
"marital_status", "occupation", "relationship", "race", "gender",
"capital_gain", "capital_loss", "hours_per_week", "native_country",
"income_bracket"
);
my %vocabulary_dict = (
"gender" => [
"Female", "Male"
],
"education" => [
"Bachelors", "HS-grad", "11th", "Masters", "9th",
"Some-college", "Assoc-acdm", "Assoc-voc", "7th-8th",
"Doctorate", "Prof-school", "5th-6th", "10th", "1st-4th",
"Preschool", "12th"
],
"marital_status" => [
"Married-civ-spouse", "Divorced", "Married-spouse-absent",
"Never-married", "Separated", "Married-AF-spouse", "Widowed"
],
"relationship" => [
"Husband", "Not-in-family", "Wife", "Own-child", "Unmarried",
"Other-relative"
],
"workclass" => [
"Self-emp-not-inc", "Private", "State-gov", "Federal-gov",
"Local-gov", "?", "Self-emp-inc", "Without-pay", "Never-worked"
]
);
# wide columns
my @crossed_columns = (
["education", "occupation"],
["native_country", "occupation"],
["age_buckets", "education", "occupation"],
);
my @age_boundaries = (18, 25, 30, 35, 40, 45, 50, 55, 60, 65);
# deep columns
my @indicator_columns = ('workclass', 'education', 'gender', 'relationship');
my @embedding_columns = ('native_country', 'occupation');
my @continuous_columns = ('age', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week');
# income_bracket column is the label
my %labels = ("<=50K" => 0, ">50K" => 1);
my $hash_bucket_size = 1000;
my $csr_ncols = @crossed_columns * $hash_bucket_size;
my $dns_ncols = @continuous_columns + @embedding_columns;
for my $col (@indicator_columns)
{
$dns_ncols += @{ $vocabulary_dict{$col} };
}
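# Each dense row is laid out as [embedding ids | one-hot indicators | continuous].
# For this dataset that is 2 + (9 + 16 + 2 + 6) + 5 = 40 columns; the model
# treats the first 2 as embedding features and the remaining 38 as continuous
# inputs (num_embed_features and num_cont_features in %ADULT below).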
my @label_list;
my @csr_list;
my @dns_list;
open(my $fh, '<', $data_name) or die "Can't open $data_name: $!";
while(<$fh>)
{
chomp;
my %row;
@row{ @csv_columns } = split(/\s*,\s*/);
# UCI's adult.test labels carry a trailing period (e.g. ">50K."); strip it
# so both files map through %labels, and skip rows with unknown labels
$row{income_bracket} =~ s/\.\s*$// if defined $row{income_bracket};
next if not defined $row{income_bracket} or not exists $labels{$row{income_bracket}};
push @label_list, $labels{$row{income_bracket}};
enumerate(sub {
my ($i, $cols) = @_;
if($cols->[0] eq "age_buckets")
{
my $age_bucket = digitize($row{age}, \@age_boundaries);
my $s = join('_', map { $row{$_} } @{ $cols }[1 .. $#$cols]);
$s .= '_' . $age_bucket;
push @csr_list, [$i * $hash_bucket_size + hash($s) % $hash_bucket_size, 1];
}
else
{
my $s = join('_', map { $row{$_} } @{ $cols });
push @csr_list, [$i * $hash_bucket_size + hash($s) % $hash_bucket_size, 1];
}
}, \@crossed_columns);
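# Each row thus contributes exactly one CSR entry per crossed column (3 here),
# with value 1 and column index $i * $hash_bucket_size + bucket, so crossed
# column $i owns the index range [$i * 1000, ($i + 1) * 1000).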
my @dns_row = (0) x $dns_ncols;
my $dns_dim = 0;
for my $col (@embedding_columns)
{
$dns_row[$dns_dim] = hash($row{$col}) % $hash_bucket_size;
$dns_dim += 1;
}
for my $col (@indicator_columns)
{
$dns_row[$dns_dim + array_index($row{$col}, $vocabulary_dict{$col})] = 1;
$dns_dim += @{$vocabulary_dict{$col}};
}
for my $col (@continuous_columns)
{
$dns_row[$dns_dim] = $row{$col};   # note: $col, not the literal key 'col'
$dns_dim += 1;
}
push @dns_list, \@dns_row;
}
my @data_list = map { $_->[1] } @csr_list;
my @indices_list = map { $_->[0] } @csr_list;
my @indptr_list = range(0, @indices_list + 1, scalar @crossed_columns);
# convert to ndarrays
my $csr = mx->nd->sparse->csr_matrix([\@data_list, \@indices_list, \@indptr_list],
shape=>[scalar @label_list, $csr_ncols]);
my $dns = pdl(\@dns_list);
my $label = pdl(\@label_list);
return ($csr, $dns, $label);
}
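# A rough sketch of what preprocessing yields for n input rows (shapes follow
# from the constants above):
#
#   my ($csr, $dns, $label) = preprocess_uci_adult('./data/adult.data');
#   # $csr  : n x 3000 CSRNDArray with 3 nonzeros per row
#   # $dns  : n x 40 piddle, [2 embedding ids | 33 indicators | 5 continuous]
#   # $label: piddle of n binary labels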
# Feature-engineering details follow the preprocess function in data.py of the corresponding Python example.
my %ADULT = (
train => './data/adult.data',
test => './data/adult.test',
num_linear_features => 3000,
num_embed_features => 2,
num_cont_features => 38,
embed_input_dims => [1000, 1000],
hidden_units => [8, 50, 100],
);
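# num_linear_features = 3 crossed columns x 1000 hash buckets; hidden_units->[0]
# is the embedding output dim, while ->[1] and ->[2] size the two fully
# connected layers of the deep part.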
my $ctx = $cuda ? mx->gpu : mx->cpu;
# dataset
my ($train_csr, $train_dns, $train_label) = preprocess_uci_adult($ADULT{train});
my ($val_csr, $val_dns, $val_label) = preprocess_uci_adult($ADULT{test});
my $model = wide_deep_model(
$ADULT{num_linear_features}, $ADULT{num_embed_features},
$ADULT{num_cont_features}, $ADULT{embed_input_dims},
$ADULT{hidden_units}
);
# data iterator
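# Hash::Ordered preserves insertion order, so the arrays stay aligned with
# the data_names/label_names declared on the Module below.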
my $train_data = mx->io->NDArrayIter(
data => Hash::Ordered->new(csr_data => $train_csr, dns_data => $train_dns),
label => Hash::Ordered->new(softmax_label => $train_label),
batch_size => $batch_size,
shuffle => 1,
last_batch_handle => 'discard'
);
my $eval_data = mx->io->NDArrayIter(
data => Hash::Ordered->new(csr_data => $val_csr, dns_data => $val_dns),
label => Hash::Ordered->new(softmax_label => $val_label),
batch_size => $batch_size,
shuffle => 0,
last_batch_handle => 'discard'
);
# module
my $mod = mx->mod->Module(
symbol => $model, context => $ctx, data_names=>['csr_data', 'dns_data'],
label_names => ['softmax_label']
);
$mod->bind(data_shapes => $train_data->provide_data, label_shapes => $train_data->provide_label);
$mod->init_params();
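# rescale_grad = 1/batch_size turns the summed per-batch gradient into a
# mean, keeping the effective learning rate independent of the batch size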
my $optim = mx->optimizer->create($optimizer, learning_rate=>$lr, rescale_grad=>1/$batch_size);
$mod->init_optimizer(optimizer=>$optim);
# use accuracy as the metric
my $metric = mx->metric->create('acc');
# callback that logs throughput (samples/sec) and the training metric every $log_interval batches
my $speedometer = mx->callback->Speedometer($batch_size, $log_interval);
print "Training started ...\n";
for my $epoch (0..$num_epoch-1)
{
my $nbatch = 0;
$metric->reset;
while(my $batch = <$train_data>)
{
$nbatch++;
$mod->forward_backward($batch);
# update all parameters, including the row_sparse weights (sparse update)
$mod->update;
# update training metric
$mod->update_metric($metric, $batch->label);
my $speedometer_param = AI::MXNet::BatchEndParam->new(
epoch=>$epoch, nbatch=>$nbatch,
eval_metric=>$metric
);
$speedometer->($speedometer_param);
}
# evaluate metric on validation dataset
my $score = $mod->score($eval_data, 'acc');
printf("epoch %d, validation accuracy = %.4f\n", $epoch, $score->{accuracy});
$mod->save_checkpoint("checkpoint", $epoch, 1);
# reset the iterator for next pass of data
$train_data->reset;
}
print "Training completed.\n";