blob: 0c01bd50dad2f2734481929ca3cfc08e6e94e227 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
use strict;
use warnings;
use Test::More tests => 2285;
use AI::MXNet qw(mx);
use AI::MXNet::TestUtils qw(reldiff pdl_maximum pdl_minimum);
use PDL;
sub check_bind_with_uniform
{
my ($uf, $gf, $dim, $sf, $lshape, $rshape) = @_;
my $shape = (random($dim)*int(1000**(1.0/$dim))+1)->floor->unpdl;
my $lhs = mx->symbol->Variable('lhs');
my $rhs = mx->symbol->Variable('rhs');
my $ret;
if(defined $sf)
{
$ret = &{$sf}($lhs, $rhs);
}
else
{
$ret = &{$uf}($lhs, $rhs);
}
is_deeply($ret->list_arguments(), ['lhs', 'rhs']);
$lshape //= $shape;
$rshape //= $shape;
my $lhs_arr = mx->nd->array(random(reverse (@$lshape)));
my $rhs_arr = mx->nd->array(random(reverse (@$rshape)));
my $lhs_grad = mx->nd->empty($lshape);
my $rhs_grad = mx->nd->empty($rshape);
my $executor = $ret->bind(
ctx => mx->Context('cpu'),
args => [$lhs_arr, $rhs_arr],
args_grad => [$lhs_grad, $rhs_grad]
);
my $exec3 = $ret->bind(
ctx => mx->Context('cpu'),
args => [$lhs_arr, $rhs_arr]
);
my $exec4 = $ret->bind(
ctx => mx->Context('cpu'),
args => {'rhs' => $rhs_arr, 'lhs' => $lhs_arr},
args_grad=>{'lhs' => $lhs_grad, 'rhs' => $rhs_grad}
);
$executor->forward(1);
$exec3->forward(1);
$exec4->forward(1);
my $out2 = $executor->outputs->[0]->aspdl;
my $out1 = &{$uf}($lhs_arr->aspdl, $rhs_arr->aspdl);
my $out3 = $exec3->outputs->[0]->aspdl;
my $out4 = $exec4->outputs->[0]->aspdl;
ok(reldiff($out1, $out2) < 1e-6);
ok(reldiff($out1, $out3) < 1e-6);
ok(reldiff($out1, $out4) < 1e-6);
# test gradient
my $out_grad = mx->nd->ones([reverse @{$out2->shape->unpdl}]);
my ($lhs_grad2, $rhs_grad2) = &{$gf}(
$out_grad->aspdl,
$lhs_arr->aspdl,
$rhs_arr->aspdl
);
$executor->backward([$out_grad]);
ok(reldiff($lhs_grad->aspdl, $lhs_grad2) < 1e-6);
ok(reldiff($rhs_grad->aspdl, $rhs_grad2) < 1e-6);
}
sub test_bind
{
my ($disable_bulk_exec) = @_;
my ($prev_fwd_var, $prev_bwd_var);
if($disable_bulk_exec)
{
$prev_fwd_var = $ENV{MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN}//1;
$prev_bwd_var = $ENV{MXNET_EXEC_BULK_BWD_TRAIN}//1;
$ENV{MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN} = 0;
$ENV{MXNET_EXEC_BULK_BWD_TRAIN} = 0;
}
srand(0);
my $nrepeat = 9;
my $maxdim = 3;
for my $repeat (0..$nrepeat)
{
for my $dim (1..$maxdim)
{
check_bind_with_uniform(sub { my ($x, $y) = @_; $x + $y },
sub { my ($g) = @_; ($g, $g) },
$dim);
check_bind_with_uniform(sub { my ($x, $y) = @_; $x - $y },
sub { my ($g) = @_; ($g, -$g) },
$dim);
check_bind_with_uniform(sub { my ($x, $y) = @_; $x * $y },
sub { my ($g, $x, $y) = @_; ($g*$y, $g*$x) },
$dim);
check_bind_with_uniform(sub { my ($x, $y) = @_; $x / $y },
sub { my ($g, $x, $y) = @_; ($g / $y, -$x * $g/ ($y**2)) },
$dim);
check_bind_with_uniform(sub { my ($x, $y) = @_; pdl_maximum($x, $y) },
sub { my ($g, $x, $y) = @_; ($g * ($x>$y), $g * ($y>$x)) },
$dim,
sub { $_[0]->maximum($_[1]) });
check_bind_with_uniform(sub { my ($x, $y) = @_; pdl_minimum($x, $y) },
sub { my ($g, $x, $y) = @_; ($g * ($x<$y), $g * ($y<$x)) },
$dim,
sub { $_[0]->minimum($_[1]) });
}
}
if($disable_bulk_exec)
{
$ENV{MXNET_EXEC_BULK_FWD_THRESHOLD_TRAIN} = $prev_fwd_var;
$ENV{MXNET_EXEC_BULK_BWD_TRAIN} = $prev_bwd_var;
}
}
sub test_dot
{
srand(0);
my $nrepeat = 9;
my $maxdim = 4;
for my $repeat (0..$nrepeat)
{
my $shape = (random(3)*500+1)->floor->unpdl;
check_bind_with_uniform(sub { my ($x, $y) = @_; $x x $y },
sub { my ($g, $x, $y) = @_; ($g x $y->transpose, $x->transpose x $g) },
2,
sub { mx->symbol->dot(@_) },
[@{$shape}[0, 1]],
[@{$shape}[1, 2]],
);
}
for my $repeat (0..$nrepeat)
{
my $shape = (random(1)*500+1)->floor->unpdl;
check_bind_with_uniform(sub { my ($x, $y) = @_; $x x $y->transpose },
sub { my ($g, $x, $y) = @_; ($g * $y, $g * $x) },
2,
sub { mx->symbol->dot(@_) },
[@{$shape}[0]],
[@{$shape}[0]],
);
}
}
sub test_reshape
{
my $x = mx->sym->Variable('x');
my $y = mx->sym->FullyConnected($x, num_hidden=>4);
my $exe = $y->simple_bind(ctx => mx->cpu(), shapes => { x=>[5,4] }, grad_req=>'null');
$exe->arg_arrays->[0] .= 1;
$exe->arg_arrays->[1] .= mx->nd->ones([4,4]);
$exe->arg_arrays->[2] .= 0;
my $new_exe = $exe->reshape({ x=>[3,4] });
$new_exe->forward(0);
# test sub exec forward
ok(($new_exe->outputs->[0]->aspdl == 4)->all);
# test shared memory
ok(($exe->outputs->[0]->aspdl->slice('X', [0,2]) == 4)->all);
# test base exec forward
$exe->forward(0);
ok(($new_exe->outputs->[0]->aspdl == 4)->all);
$new_exe = $exe->reshape({ x=>[6,4] }, allow_up_sizing=>1);
# data ndarray is not shared between exe and new_exe
$new_exe->arg_arrays->[0] .= 0;
ok(($exe->arg_arrays->[0]->aspdl == 1)->all);
# weight ndarray is shared between exe and new_exe
ok(($new_exe->arg_arrays->[1]->aspdl == 1)->all);
}
test_bind(0);
test_bind(1);
test_dot();
test_reshape();