# blob: 1a91ae4d0a9af4e44ff5108726aede12584b1e6d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
use strict;
use warnings;
use Test::More tests => 14;
use AI::MXNet qw(mx);
use Storable;
# contains($expected, $actual)
#
# Returns 1 when every key of the hashref $expected is present in the
# hashref $actual with a matching value, and 0 otherwise. Where the value in
# $actual is a hashref, the value in $expected must also be a hashref and is
# checked recursively. This means $expected only needs to list the nested
# keys it cares about. Other values are compared as strings, and undef only
# matches undef.
#
# Iterates with keys() rather than each(). each() keeps a per-hash iterator
# that an early return would leave part-way through, so a later call on the
# same hash would silently skip keys.
sub contains
{
    my ($x, $y) = @_;
    for my $k (keys %$x)
    {
        return 0 unless exists $y->{$k};
        my ($want, $got) = ($x->{$k}, $y->{$k});
        if (ref $got and ref $got eq 'HASH')
        {
            return 0 unless (ref $want and ref $want eq 'HASH');
            return 0 unless contains($want, $got);
        }
        elsif (defined $want or defined $got)
        {
            # Reject a defined/undef pair explicitly instead of letting 'ne'
            # stringify undef to '' (with a warning).
            return 0 unless (defined $want and defined $got);
            return 0 if $got ne $want;
        }
    }
    return 1;
}
# Checks that:
# - an AttrScope supplies default attributes to Variables created inside it;
# - attributes passed explicitly via attr => {...} take precedence over the
#   scope;
# - lr_mult and force_mirroring can be read both by their plain name and in
#   __name__ form;
# - attributes survive a Storable freeze/thaw round trip.
# Emits 7 tests.
sub test_attr_basic
{
    my ($data, $gdata);
    {
        # local() restores the previous scope when this block exits.
        local($mx::AttrScope) = mx->AttrScope(group=>'4', data=>'great');
        $data = mx->symbol->Variable(
            'data',
            attr => {
                qw/ dtype data
                    group 1
                    force_mirroring 1/
            },
            lr_mult => 1);
        $gdata = mx->symbol->Variable('data2');
    }
    # Numeric comparison via cmp_ok; on failure it reports got/expected.
    cmp_ok($gdata->attr('group'), '==', 4,
        'variable inherits group from AttrScope');
    cmp_ok($data->attr('group'), '==', 1,
        'explicit attr overrides AttrScope group');
    cmp_ok($data->attr('lr_mult'), '==', 1,
        'lr_mult readable by plain name');
    cmp_ok($data->attr('__lr_mult__'), '==', 1,
        'lr_mult readable as __lr_mult__');
    cmp_ok($data->attr('force_mirroring'), '==', 1,
        'force_mirroring readable by plain name');
    cmp_ok($data->attr('__force_mirroring__'), '==', 1,
        'force_mirroring readable as __force_mirroring__');
    my $data2 = Storable::thaw(Storable::freeze($data));
    is($data2->attr('dtype'), $data->attr('dtype'),
        'dtype survives Storable round trip');
}
# Checks that:
# - attributes from nested AttrScopes reach the operators created inside
#   them;
# - a symbol's JSON serialization is unchanged by a Storable freeze/thaw
#   round trip;
# - the fc2_weight internal can be sliced out.
# Emits 5 tests.
sub test_operator
{
    my $data = mx->symbol->Variable('data');
    my ($fc1, $fc2);
    {
        local($mx::AttrScope) = mx->AttrScope(__group__=>'4', __data__=>'great');
        $fc1 = mx->symbol->Activation($data, act_type=>'relu');
        {
            # The inner scope restates the outer scope's __group__ and
            # __data__ and adds __init_bias__.
            local($mx::AttrScope) = mx->AttrScope(__init_bias__ => 0,
                __group__=>'4', __data__=>'great');
            $fc2 = mx->symbol->FullyConnected($fc1, num_hidden=>10, name=>'fc2');
        }
    }
    is($fc1->attr('__data__'), 'great',
        'Activation picks up __data__ from AttrScope');
    is($fc2->attr('__data__'), 'great',
        'FullyConnected picks up __data__ from nested AttrScope');
    cmp_ok($fc2->attr('__init_bias__'), '==', 0,
        'FullyConnected picks up __init_bias__ from nested AttrScope');
    my $fc2copy = Storable::thaw(Storable::freeze($fc2));
    is($fc2copy->tojson(), $fc2->tojson(),
        'JSON unchanged after Storable round trip');
    ok($fc2->get_internals()->slice('fc2_weight'),
        'fc2_weight can be sliced from internals');
}
# Checks that list_attr() on a Convolution includes the attributes passed via
# attr => {...} and also reports wd_mult as __wd_mult__.
# Only the listed keys are checked (see contains). Emits 1 test.
sub test_list_attr
{
    my $data = mx->sym->Variable('data', attr=>{'mood' => 'angry'});
    my $op = mx->sym->Convolution(
        data=>$data, name=>'conv', kernel=>[1, 1],
        num_filter=>1, attr => {'__mood__'=> 'so so', 'wd_mult'=> 'x'}
    );
    ok(
        contains(
            {'__mood__'=> 'so so', 'wd_mult'=> 'x', '__wd_mult__'=> 'x'},
            $op->list_attr()
        ),
        'list_attr includes explicit attrs and the __wd_mult__ alias'
    );
}
# Checks attr_dict() on a Convolution. Expectations:
# - the input Variable 'data' keeps its own attrs;
# - the conv_weight and conv_bias entries carry the operator's __mood__;
# - the 'conv' entry reports its parameters in string form (kernel as
#   '(1, 1)'), its attrs, and lr_mult under both plain and __lr_mult__
#   names.
# Only the listed keys are checked (see contains). Emits 1 test.
sub test_attr_dict
{
    my $data = mx->sym->Variable('data', attr=>{'mood'=> 'angry'});
    my $op = mx->sym->Convolution(
        data=>$data, name=>'conv', kernel=>[1, 1],
        num_filter=>1, attr=>{'__mood__'=> 'so so'}, lr_mult=>1
    );
    my $expected = {
        'data'        => {'mood'=> 'angry'},
        'conv_weight' => {'__mood__'=> 'so so'},
        'conv'        => {
            'kernel'     => '(1, 1)', '__mood__'   => 'so so',
            'num_filter' => '1',      'lr_mult'    => '1',
            '__lr_mult__'=> '1'
        },
        'conv_bias'   => {'__mood__'=> 'so so'}
    };
    ok(contains($expected, $op->attr_dict()),
        'attr_dict reports per-node attributes for data, conv and its params');
}
# Run each test group in declaration order. Together they emit the 14 tests
# declared by the plan in the 'use Test::More' line.
for my $test_group (\&test_attr_basic, \&test_operator,
                    \&test_list_attr, \&test_attr_dict)
{
    $test_group->();
}