blob: e6e1b799ffe963a94136682ccb1caecf8754fc21 [file] [log] [blame]
use strict;
use warnings;
use Test::More tests => 38;
use AI::MXNet qw(mx);
my $shape = [4, 4];
my $keys = [5,7,9];
sub init_kv
{
# init kv
my $kv = mx->kv->create();
# single
$kv->init(3, mx->nd->zeros($shape));
# list
$kv->init($keys, [map { mx->nd->zeros($shape) } 0..@$keys-1]);
return $kv;
}
sub check_diff_to_scalar
{
# assert A == x
my ($A, $x) = @_;
ok(($A - $x)->aspdl->abs->sum == 0);
}
sub test_single_kv_pair
{
# single key-value pair push & pull
my $kv = init_kv();
$kv->push(3, mx->nd->ones($shape));
my $val = mx->nd->empty($shape);
$kv->pull(3, out => $val);
check_diff_to_scalar($val, 1);
}
sub test_init
{
my $kv = mx->kv->create();
$kv->init(3, mx->nd->ones($shape)*4);
my $a = mx->nd->zeros($shape);
$kv->pull(3, out=>$a);
check_diff_to_scalar($a, 4);
}
sub test_list_kv_pair
{
# list key-value pair push & pull
my $kv = init_kv();
$kv->push($keys, [map {mx->nd->ones($shape)*4} 0..@$keys-1]);
my $val = [map { mx->nd->empty($shape) } 0..@$keys-1];
$kv->pull($keys, out => $val);
for my $v (@$val)
{
check_diff_to_scalar($v, 4);
}
}
sub test_aggregator
{
# aggregate value on muliple devices
my $kv = init_kv();
# devices
my $num_devs = 4;
my $devs = [map { mx->cpu($_) } 0..$num_devs-1];
# single
my $vals = [map { mx->nd->ones($shape, ctx => $_) } @$devs];
$kv->push(3, $vals);
$kv->pull(3, out => $vals);
for my $v (@$vals)
{
check_diff_to_scalar($v, $num_devs);
}
# list
$vals = [map { [map { mx->nd->ones($shape, ctx => $_)*2 } @$devs] } 0..@$keys-1];
$kv->push($keys, $vals);
$kv->pull($keys, out => $vals);
for my $vv (@{ $vals })
{
for my $v (@{ $vv })
{
check_diff_to_scalar($v, $num_devs * 2);
}
}
}
sub updater
{
my ($key, $recv, $local) = @_;
$local += $recv;
}
sub test_updater
{
my ($dev) = @_;
$dev //= 'cpu';
my $kv = init_kv();
$kv->_set_updater(\&updater);
# devices
my $num_devs = 4;
my $devs = [map { mx->$dev($_) } 0..$num_devs-1];
# single
my $vals = [map { mx->nd->ones($shape, ctx => $_) } @$devs];
$kv->push(3, $vals);
$kv->pull(3, out => $vals);
for my $v (@$vals)
{
check_diff_to_scalar($v, $num_devs);
}
# list
$vals = [map { [map { mx->nd->ones($shape, ctx => $_) } @$devs] } 0..@$keys-1];
my $num_push = 10;
for my $i (0..$num_push-1)
{
$kv->push($keys, $vals);
}
$kv->pull($keys, out => $vals);
for my $vv (@{ $vals })
{
for my $v (@{ $vv })
{
check_diff_to_scalar($v, $num_devs * $num_push);
}
}
}
sub test_get_type
{
my $kvtype = 'local_allreduce_cpu';
my $kv = mx->kv->create($kvtype);
is($kv->type, $kvtype);
}
test_init();
test_get_type();
test_single_kv_pair();
test_list_kv_pair();
test_aggregator();
test_updater();