| package AI::MXNet::Symbol; |
| |
| =head1 NAME |
| |
| AI::MXNet::Symbol - Symbolic interface of MXNet. |
| =cut |
| |
| use strict; |
| use warnings; |
| use AI::MXNet::Base; |
| use AI::MXNet::Symbol::Base; |
| use AI::MXNet::Types; |
| use Mouse; |
| use AI::MXNet::Function::Parameters; |
| use overload |
| '""' => \&stringify, |
| '+' => \&add, |
| '-' => \&subtract, |
| '*' => \&multiply, |
| '/' => \÷, |
| '/=' => \&idivide, |
| '**' => \&power, |
| '%' => \&mod, |
| '==' => \&equal, |
| '!=' => \¬_equal, |
| '>' => \&greater, |
| '>=' => \&greater_equal, |
| '<' => \&lesser, |
| '<=' => \&lesser_equal, |
| '&{}' => sub { my $self = shift; sub { $self->call(@_) } }, |
| '@{}' => sub { my $self = shift; [map { $self->slice($_) } @{ $self->list_outputs }] }; |
| |
| extends 'AI::MXNet::Symbol::Base'; |
| has 'handle' => (is => 'rw', isa => 'SymbolHandle', required => 1); |
| |
| sub DEMOLISH |
| { |
| check_call(AI::NNVMCAPI::SymbolFree(shift->handle)); |
| } |
| |
| method STORABLE_freeze($cloning) |
| { |
| return $self->tojson(); |
| } |
| |
| method STORABLE_thaw($cloning, $json) |
| { |
| my $handle = check_call( |
| AI::MXNetCAPI::SymbolCreateFromJSON( |
| $json |
| ) |
| ); |
| $self->handle($handle); |
| } |
| |
| method stringify($other=, $reverse=) |
| { |
| my $name = $self->name; |
| sprintf("<%s %s>", ref($self), $name ? $name : 'Grouped'); |
| } |
| |
| method add(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_Plus _PlusScalar/ |
| ); |
| } |
| |
| method subtract(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_Minus _MinusScalar _RMinusScalar/, |
| $reverse |
| ); |
| } |
| |
| method multiply(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_Mul _MulScalar/ |
| ); |
| } |
| |
| method divide(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_Div _DivScalar _RDivScalar/, |
| $reverse |
| ); |
| } |
| |
| method power(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_Power _PowerScalar _RPowerScalar/, |
| $reverse |
| ); |
| } |
| |
| method equal(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_equal _equal_scalar/ |
| ); |
| } |
| |
| method not_equal(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_not_equal _not_equal_scalar/ |
| ); |
| } |
| |
| method greater(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_greater _greater_scalar _lesser_scalar/, |
| $reverse |
| ); |
| } |
| |
| method greater_equal(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_greater_equal _greater_equal_scalar _lesser_equal_scalar/, |
| $reverse |
| ); |
| } |
| |
| method lesser(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_lesser _lesser_scalar _greater_scalar/, |
| $reverse |
| ); |
| } |
| |
| method lesser_equal(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_lesser_equal _lesser_equal_scalar _greater_equal_scalar/, |
| $reverse |
| ); |
| } |
| |
| method true_divide(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return $self->divide($other, $reverse); |
| } |
| |
| method mod(AI::MXNet::Symbol|Num $other, $reverse=) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_Mod _ModScalar _RModScalar/, |
| $reverse |
| ); |
| } |
| |
| method maximum(AI::MXNet::Symbol|Num $other) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_Maximum _MaximumScalar/ |
| ); |
| } |
| |
| method minimum(AI::MXNet::Symbol|Num $other) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_Minimum _MinimumScalar/ |
| ); |
| } |
| |
| method hypot(AI::MXNet::Symbol|Num $other) |
| { |
| return _ufunc_helper( |
| $self, |
| $other, |
| qw/_Hypot _HypotScalar/ |
| ); |
| } |
| |
| method deepcopy() |
| { |
| my $handle = check_call(AI::MXNetCAPI::SymbolCopy($self->handle)); |
| return __PACKAGE__->new(handle => $handle); |
| } |
| |
| method call(@args) |
| { |
| my $s = $self->deepcopy(); |
| $s->_compose(@args); |
| return $s; |
| } |
| |
| method slice(Str|Index $index) |
| { |
| ## __getitem__ tie needs to die |
| if(not find_type_constraint('Index')->check($index)) |
| { |
| my $i = 0; |
| my $idx; |
| for my $name (@{ $self->list_outputs() }) |
| { |
| if($name eq $index) |
| { |
| if(defined $idx) |
| { |
| confess(qq/There are multiple outputs with name "$index"/); |
| } |
| $idx = $i; |
| } |
| $i++; |
| } |
| confess(qq/Cannot find output that matches name "$index"/) unless defined $idx; |
| $index = $idx; |
| } |
| elsif($index >= @{ $self->list_outputs() }) |
| { |
| confess("Index: [$index] is outside of the range of the symbol: $self outputs"); |
| } |
| my $handle = check_call(AI::MXNetCAPI::SymbolGetOutput($self->handle, $index)); |
| return __PACKAGE__->new(handle => $handle); |
| } |
| |
| =head2 name |
| |
| Get name string from the symbol, this function only works for non-grouped symbol. |
| |
| Returns |
| ------- |
| value : str |
| The name of this symbol, returns None for grouped symbol. |
| =cut |
| |
| method name() |
| { |
| my ($name, $success) = check_call(AI::MXNetCAPI::SymbolGetName($self->handle)); |
| return $success ? $name : undef; |
| } |
| |
| =head2 attr |
| |
| Get an attribute string from the symbol, this function only works for non-grouped symbol. |
| |
| Parameters |
| ---------- |
| key : str |
| The key to get attribute from. |
| |
| Returns |
| ------- |
| value : str |
| The attribute value of the key, returns None if attribute do not exist. |
| =cut |
| |
| |
| method attr(Str $key) |
| { |
| my ($attr, $success) = check_call( |
| AI::MXNetCAPI::SymbolGetAttr($self->handle, $key) |
| ); |
| return $success ? $attr : undef; |
| } |
| |
| =head2 list_attr |
| |
| Get all attributes from the symbol. |
| |
| Returns |
| ------- |
| ret : hash ref of str to str |
| a dicitonary mapping attribute keys to values |
| =cut |
| |
| method list_attr() |
| { |
| my %ret; |
| my @attrs = @{ check_call(AI::MXNetCAPI::SymbolListAttrShallow($self->handle)) }; |
| while(@attrs) |
| { |
| my $k = shift(@attrs); |
| my $v = shift(@attrs); |
| $ret{ $k } = $v; |
| } |
| return \%ret; |
| } |
| |
| =head2 attr_dict |
| |
| Recursively get all attributes from the symbol and its childrens |
| |
| Returns |
| ------- |
| ret : hash ref of str to hash ref. |
| Returns a dict whose keys are names of the symbol and its children. |
| Values of the returned dict are dictionaries that map attribute keys to values. |
| =cut |
| |
| method attr_dict() |
| { |
| my %ret; |
| my @attrs = @{ check_call(AI::MXNetCAPI::SymbolListAttr($self->handle)) }; |
| my $size = @attrs/2; |
| for (my $i = 0; $i < $size; $i++) |
| { |
| my ($name, $key) = split(/\$/, $attrs[$i*2]); |
| my $val = $attrs[$i*2+1]; |
| $ret{ $name }{ $key } = $val; |
| } |
| return \%ret; |
| } |
| |
| method _set_attr(Str @args) |
| { |
| my %kwargs = @args; |
| while(my ($key, $val) = each(%kwargs)) |
| { |
| check_call( |
| AI::MXNetCAPI::SymbolSetAttr( |
| $self->handle, $key, $val |
| ) |
| ); |
| } |
| } |
| |
| =head2 get_internals |
| |
| Get a new grouped symbol whose output contains all the internal outputs of this symbol. |
| |
| Returns |
| ------- |
| sgroup : AI::MXNet::Symbol |
| The internal symbol of the symbol. |
| =cut |
| |
| method get_internals() |
| { |
| my $handle = check_call(AI::MXNetCAPI::SymbolGetInternals($self->handle)); |
| return __PACKAGE__->new(handle => $handle); |
| } |
| |
| =head2 get_children |
| |
| Get a new grouped symbol whose output contains |
| inputs to output nodes of the original symbol |
| |
| Returns |
| ------- |
| sgroup : Symbol or undef |
| The children of the head node. If the symbol has no |
| inputs undef will be returned. |
| =cut |
| |
| |
| method get_children() |
| { |
| my $handle = check_call(AI::MXNetCAPI::SymbolGetChildren($self->handle)); |
| my $ret = __PACKAGE__->new(handle => $handle); |
| return undef unless @{ $ret->list_outputs }; |
| return $ret; |
| } |
| |
| =head2 list_arguments |
| |
| List all the arguments in the symbol. |
| |
| Returns |
| ------- |
| args : array ref of strings |
| =cut |
| |
| method list_arguments() |
| { |
| return scalar(check_call(AI::MXNetCAPI::SymbolListArguments($self->handle))); |
| } |
| |
| =head2 list_outputs() |
| |
| List all outputs in the symbol. |
| |
| Returns |
| ------- |
| $out : array ref of strings. |
| =cut |
| |
| method list_outputs() |
| { |
| return scalar(check_call(AI::MXNetCAPI::SymbolListOutputs($self->handle))); |
| } |
| |
| |
| =head2 list_auxiliary_states() |
| |
| List all auxiliary states in the symbol. |
| |
| Returns |
| ------- |
| aux_states : array ref of string |
| List the names of the auxiliary states. |
| |
| Notes |
| ----- |
| Auxiliary states are special states of symbols that do not corresponds to an argument, |
| and do not have gradient. But still be useful for the specific operations. |
| A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. |
| Most operators do not have Auxiliary states. |
| =cut |
| |
| method list_auxiliary_states() |
| { |
| return scalar(check_call(AI::MXNetCAPI::SymbolListAuxiliaryStates($self->handle))); |
| } |
| |
| |
| =head2 list_inputs |
| |
| Lists all arguments and auxiliary states of this Symbol. |
| |
| Returns |
| ------- |
| inputs : array ref of str |
| List of all inputs. |
| |
| Examples |
| -------- |
| >>> my $bn = mx->sym->BatchNorm(name=>'bn'); |
| =cut |
| |
| method list_inputs() |
| { |
| return scalar(check_call(AI::NNVMCAPI::SymbolListInputNames($self->handle, 0))); |
| } |
| |
| =head2 infer_type |
| |
| Infer the type of outputs and arguments of given known types of arguments. |
| |
| User can either pass in the known types in positional way or keyword argument way. |
| Tuple of Nones is returned if there is not enough information passed in. |
| An error will be raised if there is inconsistency found in the known types passed in. |
| |
| Parameters |
| ---------- |
| args : Array |
| Provide type of arguments in a positional way. |
| Unknown type can be marked as None |
| |
| kwargs : Hash ref, must ne ssupplied as as sole argument to the method. |
| Provide keyword arguments of known types. |
| |
| Returns |
| ------- |
| arg_types : array ref of Dtype or undef |
| List of types of arguments. |
| The order is in the same order as list_arguments() |
| out_types : array ref of Dtype or undef |
| List of types of outputs. |
| The order is in the same order as list_outputs() |
| aux_types : array ref of Dtype or undef |
| List of types of outputs. |
| The order is in the same order as list_auxiliary() |
| =cut |
| |
| |
| method infer_type(Str|Undef @args) |
| { |
| my ($positional_arguments, $kwargs, $kwargs_order) = _parse_arguments("Dtype", @args); |
| my $sdata = []; |
| my $keys = []; |
| if(@$positional_arguments) |
| { |
| @{ $sdata } = map { defined($_) ? DTYPE_STR_TO_MX->{ $_ } : -1 } @{ $positional_arguments }; |
| } |
| else |
| { |
| @{ $keys } = @{ $kwargs_order }; |
| @{ $sdata } = map { DTYPE_STR_TO_MX->{ $_ } } @{ $kwargs }{ @{ $kwargs_order } }; |
| } |
| my ($arg_type, $out_type, $aux_type, $complete) = check_call(AI::MXNetCAPI::SymbolInferType( |
| $self->handle, |
| scalar(@{ $sdata }), |
| $keys, |
| $sdata |
| ) |
| ); |
| if($complete) |
| { |
| return ( |
| [ map { DTYPE_MX_TO_STR->{ $_ } } @{ $arg_type }], |
| [ map { DTYPE_MX_TO_STR->{ $_ } } @{ $out_type }], |
| [ map { DTYPE_MX_TO_STR->{ $_ } } @{ $aux_type }] |
| ); |
| } |
| else |
| { |
| return (undef, undef, undef); |
| } |
| } |
| |
| =head2 infer_shape |
| |
| Infer the shape of outputs and arguments of given known shapes of arguments. |
| |
| User can either pass in the known shapes in positional way or keyword argument way. |
| Tuple of Nones is returned if there is not enough information passed in. |
| An error will be raised if there is inconsistency found in the known shapes passed in. |
| |
| Parameters |
| ---------- |
| *args : |
| Provide shape of arguments in a positional way. |
| Unknown shape can be marked as undef |
| |
| **kwargs : |
| Provide keyword arguments of known shapes. |
| |
| Returns |
| ------- |
| arg_shapes : array ref of Shape or undef |
| List of shapes of arguments. |
| The order is in the same order as list_arguments() |
| out_shapes : array ref of Shape or undef |
| List of shapes of outputs. |
| The order is in the same order as list_outputs() |
| aux_shapes : array ref of Shape or undef |
| List of shapes of outputs. |
| The order is in the same order as list_auxiliary() |
| =cut |
| |
| method infer_shape(Maybe[Str|Shape] @args) |
| { |
| my @res = $self->_infer_shape_impl(0, @args); |
| if(not defined $res[1]) |
| { |
| my ($arg_shapes) = $self->_infer_shape_impl(1, @args); |
| my $arg_names = $self->list_arguments; |
| my @unknowns; |
| zip(sub { |
| my ($name, $shape) = @_; |
| if(not ref $shape or not @$shape or not product(@$shape)) |
| { |
| if(@unknowns >= 10) |
| { |
| $unknowns[10] = '...'; |
| } |
| else |
| { |
| my @shape = eval { @$shape }; |
| push @unknowns, "$name @shape"; |
| } |
| } |
| }, $arg_names, $arg_shapes); |
| AI::MXNet::Logging->warning( |
| "Cannot decide shape for the following arguments " |
| ."(0s in shape means unknown dimensions). " |
| ."Consider providing them as input:\n\t" |
| ."\n\t" |
| .join(", ", @unknowns) |
| ); |
| } |
| return @res; |
| } |
| |
| =head2 infer_shape_partial |
| |
| Partially infer the shape. The same as infer_shape, except that the partial |
| results can be returned. |
| =cut |
| |
| method infer_shape_partial(Maybe[Str|Shape] @args) |
| { |
| $self->_infer_shape_impl(1, @args) |
| } |
| |
| # The actual implementation for calling shape inference API. |
| method _infer_shape_impl(Maybe[Str|Shape] @args) |
| { |
| my $partial = shift(@args); |
| my ($positional_arguments, $kwargs, $kwargs_order) = _parse_arguments("Shape", @args); |
| my $sdata = []; |
| my $indptr = [0]; |
| my $keys = []; |
| if(@{ $positional_arguments }) |
| { |
| for my $shape (grep { defined } @{ $positional_arguments }) |
| { |
| push @{ $sdata }, @{ $shape }; |
| push @{ $indptr }, scalar(@{ $sdata }); |
| } |
| } |
| { |
| for my $k (@{ $kwargs_order }) |
| { |
| push @{ $keys }, $k; |
| push @{ $sdata }, @{ $kwargs->{ $k } }; |
| push @{ $indptr }, scalar(@{ $sdata }); |
| } |
| } |
| my $infer_func = $partial ? \&AI::MXNetCAPI::SymbolInferShapePartial : \&AI::MXNetCAPI::SymbolInferShape; |
| my ($arg_shapes, $out_shapes, $aux_shapes, $complete) = check_call( |
| $infer_func->( |
| $self->handle, |
| scalar(@{ $indptr }) - 1, |
| $keys, |
| $indptr, |
| $sdata, |
| ) |
| ); |
| if($complete) |
| { |
| return $arg_shapes, $out_shapes, $aux_shapes; |
| } |
| else |
| { |
| return (undef, undef, undef); |
| } |
| } |
| |
| =head2 debug_str |
| |
| The debug string. |
| |
| Returns |
| ------- |
| debug_str : string |
| Debug string of the symbol. |
| =cut |
| |
| method debug_str() |
| { |
| return scalar(check_call(AI::MXNetCAPI::SymbolPrint($self->handle))); |
| } |
| |
| =head2 save |
| |
| Save the symbol into a file. |
| |
| You can also use Storable to do the job if you only work with Perl. |
| The advantage of load/save is the file is language agnostic. |
| This means the file saved using save can be loaded by other language binding of mxnet. |
| You also get the benefit being able to directly load/save from cloud storage(S3, HDFS) |
| |
| Parameters |
| ---------- |
| fname : str |
| The name of the file |
| - s3://my-bucket/path/my-s3-symbol |
| - hdfs://my-bucket/path/my-hdfs-symbol |
| - /path-to/my-local-symbol |
| |
| See Also |
| -------- |
| load : Used to load symbol from file. |
| =cut |
| |
| method save(Str $fname) |
| { |
| check_call(AI::MXNetCAPI::SymbolSaveToFile($self->handle, $fname)); |
| } |
| |
| =head2 tojson |
| |
| Save the symbol into a JSON string. |
| |
| See Also |
| -------- |
| load_json : Used to load symbol from JSON string. |
| =cut |
| |
| method tojson() |
| { |
| return scalar(check_call(AI::MXNetCAPI::SymbolSaveToJSON($self->handle))); |
| } |
| |
| method _get_ndarray_inputs( |
| Str $arg_key, |
| HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray] $args, |
| ArrayRef[Str] $arg_names, |
| Bool $allow_missing=0 |
| ) |
| { |
| my ($arg_handles, $arg_arrays) = ([], []); |
| if(ref $args eq 'ARRAY') |
| { |
| confess("Length of $arg_key do not match number of arguments") |
| unless @$args == @$arg_names; |
| @{ $arg_handles } = map { $_->handle } @{ $args }; |
| $arg_arrays = $args; |
| } |
| else |
| { |
| my %tmp = ((map { $_ => undef } @$arg_names), %$args); |
| if(not $allow_missing and grep { not defined } values %tmp) |
| { |
| my ($missing) = grep { not defined $tmp{ $_ } } (keys %tmp); |
| confess("key $missing is missing in $arg_key"); |
| } |
| for my $name (@$arg_names) |
| { |
| push @$arg_handles, defined($tmp{ $name }) ? $tmp{ $name }->handle : undef; |
| push @$arg_arrays, defined($tmp{ $name }) ? $tmp{ $name } : undef; |
| } |
| } |
| return ($arg_handles, $arg_arrays); |
| } |
| |
| =head2 simple_bind |
| |
| Bind current symbol to get an executor, allocate all the ndarrays needed. |
| Allows specifying data types. |
| |
| This function will ask user to pass in ndarray of position |
| they like to bind to, and it will automatically allocate the ndarray |
| for arguments and auxiliary states that user did not specify explicitly. |
| |
| Parameters |
| ---------- |
| :$ctx : AI::MXNet::Context |
| The device context the generated executor to run on. |
| |
| :$grad_req: string |
| {'write', 'add', 'null'}, or list of str or dict of str to str, optional |
| Specifies how we should update the gradient to the args_grad. |
| - 'write' means everytime gradient is write to specified args_grad NDArray. |
| - 'add' means everytime gradient is add to the specified NDArray. |
| - 'null' means no action is taken, the gradient may not be calculated. |
| |
| :$type_dict : hash ref of str->Dtype |
| Input type map, name->dtype |
| |
| :$group2ctx : hash ref of string to AI::MXNet::Context |
| The mapping of the ctx_group attribute to the context assignment. |
| |
| :$shapes : hash ref of str->Shape |
| Input shape map, name->shape |
| |
| :$shared_arg_names : Maybe[ArrayRef[Str]] |
| The argument names whose 'NDArray' of shared_exec can be reused for initializing |
| the current executor. |
| |
| :$shared_exec : Maybe[AI::MXNet::Executor] |
| The executor whose arg_arrays, arg_arrays, grad_arrays, and aux_arrays can be |
| reused for initializing the current executor. |
| |
| :$shared_buffer : Maybe[HashRef[AI::MXNet::NDArray]] |
| The dict mapping argument names to the `NDArray` that can be reused for initializing |
| the current executor. This buffer will be checked for reuse if one argument name |
| of the current executor is not found in `shared_arg_names`. |
| |
| Returns |
| ------- |
| $executor : AI::MXNet::Executor |
| The generated Executor |
| =cut |
| |
| method simple_bind( |
| AI::MXNet::Context :$ctx=AI::MXNet::Context->current_ctx, |
| GradReq|ArrayRef[GradReq]|HashRef[GradReq] :$grad_req='write', |
| Maybe[HashRef[Shape]] :$shapes=, |
| Maybe[HashRef[Dtype]] :$type_dict=, |
| Maybe[HashRef[AI::MXNet::Context]] :$group2ctx=, |
| Maybe[ArrayRef[Str]] :$shared_arg_names=, |
| Maybe[AI::MXNet::Executor] :$shared_exec=, |
| Maybe[HashRef[AI::MXNet::NDArray]] :$shared_buffer= |
| ) |
| { |
| my $num_provided_arg_types; |
| my @provided_arg_type_names; |
| my @provided_arg_type_data; |
| if(defined $type_dict) |
| { |
| while(my ($k, $v) = each %{ $type_dict }) |
| { |
| push @provided_arg_type_names, $k; |
| push @provided_arg_type_data, DTYPE_STR_TO_MX->{$v}; |
| } |
| $num_provided_arg_types = @provided_arg_type_names; |
| } |
| my @provided_arg_shape_data; |
| # argument shape index in sdata, |
| # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg |
| my @provided_arg_shape_idx = (0); |
| my @provided_arg_shape_names; |
| while(my ($k, $v) = each %{ $shapes//{} }) |
| { |
| push @provided_arg_shape_names, $k; |
| push @provided_arg_shape_data, @{ $v }; |
| push @provided_arg_shape_idx, scalar(@provided_arg_shape_data); |
| } |
| $num_provided_arg_types = @provided_arg_type_names; |
| |
| my $provided_req_type_list_len = 0; |
| my @provided_grad_req_types; |
| my @provided_grad_req_names; |
| if(defined $grad_req) |
| { |
| if(not ref $grad_req) |
| { |
| push @provided_grad_req_types, $grad_req; |
| } |
| elsif(ref $grad_req eq 'ARRAY') |
| { |
| assert((@{ $grad_req } != 0), 'grad_req in simple_bind cannot be an empty list'); |
| @provided_grad_req_types = @{ $grad_req }; |
| $provided_req_type_list_len = @provided_grad_req_types; |
| } |
| elsif(ref $grad_req eq 'HASH') |
| { |
| assert((keys %{ $grad_req } != 0), 'grad_req in simple_bind cannot be an empty hash'); |
| while(my ($k, $v) = each %{ $grad_req }) |
| { |
| push @provided_grad_req_names, $k; |
| push @provided_grad_req_types, $v; |
| } |
| $provided_req_type_list_len = @provided_grad_req_types; |
| } |
| } |
| my $num_ctx_map_keys = 0; |
| my @ctx_map_keys; |
| my @ctx_map_dev_types; |
| my @ctx_map_dev_ids; |
| if(defined $group2ctx) |
| { |
| while(my ($k, $v) = each %{ $group2ctx }) |
| { |
| push @ctx_map_keys, $k; |
| push @ctx_map_dev_types, $v->device_type_id; |
| push @ctx_map_dev_ids, $v->device_id; |
| } |
| $num_ctx_map_keys = @ctx_map_keys; |
| } |
| |
| my @shared_arg_name_list; |
| if(defined $shared_arg_names) |
| { |
| @shared_arg_name_list = @{ $shared_arg_names }; |
| } |
| my %shared_data; |
| if(defined $shared_buffer) |
| { |
| while(my ($k, $v) = each %{ $shared_buffer }) |
| { |
| $shared_data{$k} = $v->handle; |
| } |
| } |
| my $shared_exec_handle = defined $shared_exec ? $shared_exec->handle : undef; |
| my ( |
| $updated_shared_data, |
| $in_arg_handles, |
| $arg_grad_handles, |
| $aux_state_handles, |
| $exe_handle |
| ); |
| eval { |
| ($updated_shared_data, $in_arg_handles, $arg_grad_handles, $aux_state_handles, $exe_handle) |
| = |
| check_call( |
| AI::MXNetCAPI::ExecutorSimpleBind( |
| $self->handle, |
| $ctx->device_type_id, |
| $ctx->device_id, |
| $num_ctx_map_keys, |
| \@ctx_map_keys, |
| \@ctx_map_dev_types, |
| \@ctx_map_dev_ids, |
| $provided_req_type_list_len, |
| \@provided_grad_req_names, |
| \@provided_grad_req_types, |
| scalar(@provided_arg_shape_names), |
| \@provided_arg_shape_names, |
| \@provided_arg_shape_data, |
| \@provided_arg_shape_idx, |
| $num_provided_arg_types, |
| \@provided_arg_type_names, |
| \@provided_arg_type_data, |
| scalar(@shared_arg_name_list), |
| \@shared_arg_name_list, |
| defined $shared_buffer ? \%shared_data : undef, |
| $shared_exec_handle |
| ) |
| ); |
| }; |
| if($@) |
| { |
| confess( |
| "simple_bind failed: Error: $@; Arguments: ". |
| Data::Dumper->new( |
| [$shapes//{}] |
| )->Purity(1)->Deepcopy(1)->Terse(1)->Dump |
| ); |
| } |
| if(defined $shared_buffer) |
| { |
| while(my ($k, $v) = each %{ $updated_shared_data }) |
| { |
| $shared_buffer->{$k} = AI::MXNet::NDArray->new(handle => $v); |
| } |
| } |
| my @arg_arrays = map { AI::MXNet::NDArray->new(handle => $_) } @{ $in_arg_handles }; |
| my @grad_arrays = map { defined $_ ? AI::MXNet::NDArray->new(handle => $_) : undef } @{ $arg_grad_handles }; |
| my @aux_arrays = map { AI::MXNet::NDArray->new(handle => $_) } @{ $aux_state_handles }; |
| my $executor = AI::MXNet::Executor->new( |
| handle => $exe_handle, |
| symbol => $self, |
| ctx => $ctx, |
| grad_req => $grad_req, |
| group2ctx => $group2ctx |
| ); |
| $executor->arg_arrays(\@arg_arrays); |
| $executor->grad_arrays(\@grad_arrays); |
| $executor->aux_arrays(\@aux_arrays); |
| return $executor; |
| } |
| |
| =head2 bind |
| |
| Bind current symbol to get an executor. |
| |
| Parameters |
| ---------- |
| :$ctx : AI::MXNet::Context |
| The device context the generated executor to run on. |
| |
| :$args : HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray] |
| Input arguments to the symbol. |
| - If type is array ref of NDArray, the position is in the same order of list_arguments. |
| - If type is hash ref of str to NDArray, then it maps the name of arguments |
| to the corresponding NDArray. |
| - In either case, all the arguments must be provided. |
| |
| :$args_grad : Maybe[HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]] |
| When specified, args_grad provide NDArrays to hold |
| the result of gradient value in backward. |
| - If type is array ref of NDArray, the position is in the same order of list_arguments. |
| - If type is hash ref of str to NDArray, then it maps the name of arguments |
| to the corresponding NDArray. |
| - When the type is hash ref of str to NDArray, users only need to provide the dict |
| for needed argument gradient. |
| Only the specified argument gradient will be calculated. |
| |
| :$grad_req : {'write', 'add', 'null'}, or array ref of str or hash ref of str to str, optional |
| Specifies how we should update the gradient to the args_grad. |
| - 'write' means everytime gradient is write to specified args_grad NDArray. |
| - 'add' means everytime gradient is add to the specified NDArray. |
| - 'null' means no action is taken, the gradient may not be calculated. |
| |
| :$aux_states : array ref of NDArray, or hash ref of str to NDArray, optional |
| Input auxiliary states to the symbol, only need to specify when |
| list_auxiliary_states is not empty. |
| - If type is array ref of NDArray, the position is in the same order of list_auxiliary_states |
| - If type is hash ref of str to NDArray, then it maps the name of auxiliary_states |
| to the corresponding NDArray, |
| - In either case, all the auxiliary_states need to be provided. |
| |
| :$group2ctx : hash ref of string to AI::MXNet::Context |
| The mapping of the ctx_group attribute to the context assignment. |
| |
| :$shared_exec : AI::MXNet::Executor |
| Executor to share memory with. This is intended for runtime reshaping, variable length |
| sequences, etc. The returned executor shares state with shared_exec, and should not be |
| used in parallel with it. |
| |
| Returns |
| ------- |
| $executor : AI::MXNet::Executor |
| The generated Executor |
| |
| Notes |
| ----- |
| Auxiliary states are special states of symbols that do not corresponds to an argument, |
| and do not have gradient. But still be useful for the specific operations. |
| A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. |
| Most operators do not have auxiliary states and this parameter can be safely ignored. |
| |
| User can give up gradient by using a hash ref in args_grad and only specify |
| the gradient they're interested in. |
| =cut |
| |
| method bind( |
| AI::MXNet::Context :$ctx, |
| HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray] :$args, |
| Maybe[HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]] :$args_grad=, |
| Str|HashRef[Str]|ArrayRef[Str] :$grad_req='write', |
| Maybe[HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray]] :$aux_states=, |
| Maybe[HashRef[AI::MXNet::Context]] :$group2ctx=, |
| Maybe[AI::MXNet::Executor] :$shared_exec= |
| ) |
| { |
| $grad_req //= 'write'; |
| my $listed_arguments = $self->list_arguments(); |
| my ($args_handle, $args_grad_handle, $aux_args_handle) = ([], [], []); |
| ($args_handle, $args) = $self->_get_ndarray_inputs('args', $args, $listed_arguments); |
| if(not defined $args_grad) |
| { |
| @$args_grad_handle = ((undef) x (@$args)); |
| } |
| else |
| { |
| ($args_grad_handle, $args_grad) = $self->_get_ndarray_inputs( |
| 'args_grad', $args_grad, $listed_arguments, 1 |
| ); |
| } |
| |
| if(not defined $aux_states) |
| { |
| $aux_states = []; |
| } |
| ($aux_args_handle, $aux_states) = $self->_get_ndarray_inputs( |
| 'aux_states', $aux_states, $self->list_auxiliary_states() |
| ); |
| |
| # setup requirements |
| my $req_map = { null => 0, write => 1, add => 3 }; |
| my $req_array = []; |
| if(not ref $grad_req) |
| { |
| confess('grad_req must be one of "null,write,add"') |
| unless exists $req_map->{ $grad_req }; |
| @{ $req_array } = (($req_map->{ $grad_req }) x @{ $listed_arguments }); |
| } |
| elsif(ref $grad_req eq 'ARRAY') |
| { |
| @{ $req_array } = map { $req_map->{ $_ } } @{ $grad_req }; |
| } |
| else |
| { |
| for my $name (@{ $listed_arguments }) |
| { |
| if(exists $grad_req->{ $name }) |
| { |
| push @{ $req_array }, $req_map->{ $grad_req->{ $name } }; |
| } |
| else |
| { |
| push @{ $req_array }, 0; |
| } |
| } |
| } |
| |
| my $ctx_map_keys = []; |
| my $ctx_map_dev_types = []; |
| my $ctx_map_dev_ids = []; |
| |
| if(defined $group2ctx) |
| { |
| while(my ($key, $val) = each %{ $group2ctx }) |
| { |
| push @{ $ctx_map_keys } , $key; |
| push @{ $ctx_map_dev_types }, $val->device_type_id; |
| push @{ $ctx_map_dev_ids }, $val->device_id; |
| } |
| } |
| my $shared_handle = $shared_exec->handle if $shared_exec; |
| my $handle = check_call(AI::MXNetCAPI::ExecutorBindEX( |
| $self->handle, |
| $ctx->device_type_id, |
| $ctx->device_id, |
| scalar(@{ $ctx_map_keys }), |
| $ctx_map_keys, |
| $ctx_map_dev_types, |
| $ctx_map_dev_ids, |
| scalar(@{ $args }), |
| $args_handle, |
| $args_grad_handle, |
| $req_array, |
| scalar(@{ $aux_states }), |
| $aux_args_handle, |
| $shared_handle |
| ) |
| ); |
| my $executor = AI::MXNet::Executor->new( |
| handle => $handle, |
| symbol => $self, |
| ctx => $ctx, |
| grad_req => $grad_req, |
| group2ctx => $group2ctx |
| ); |
| $executor->arg_arrays($args); |
| $executor->grad_arrays($args_grad); |
| $executor->aux_arrays($aux_states); |
| return $executor; |
| } |
| |
| =head2 eval |
| |
| Evaluate a symbol given arguments |
| |
| The `eval` method combines a call to `bind` (which returns an executor) |
| with a call to `forward` (executor method). |
| For the common use case, where you might repeatedly evaluate with same arguments, |
| eval is slow. |
| In that case, you should call `bind` once and then repeatedly call forward. |
| Eval allows simpler syntax for less cumbersome introspection. |
| |
| Parameters |
| ---------- |
| :$ctx : Context |
| The device context the generated executor to run on. |
| Optional, defaults to cpu(0) |
| |
| :$args array ref of NDArray or hash ref of NDArray |
| |
| - If the type is an array ref of NDArray, the position is in the same order of list_arguments. |
| - If the type is a hash of str to NDArray, then it maps the name of the argument |
| to the corresponding NDArray. |
| - In either case, all arguments must be provided. |
| |
| Returns |
| ---------- |
| result : an array ref of NDArrays corresponding to the values |
| taken by each symbol when evaluated on given args. |
| When called on a single symbol (not a group), |
| the result will be an array ref with one element. |
| |
| Examples: |
| my $result = $symbol->eval(ctx => mx->gpu, args => {data => mx->nd->ones([5,5])}); |
| my $result = $symbol->eval(args => {data => mx->nd->ones([5,5])}); |
| |
| =cut |
| |
| method eval(:$ctx=AI::MXNet::Context->cpu, HashRef[AI::MXNet::NDArray]|ArrayRef[AI::MXNet::NDArray] :$args) |
| { |
| return $self->bind(ctx => $ctx, args => $args)->forward; |
| } |
| |
| =head2 grad |
| |
| Get the autodiff of current symbol. |
| This function can only be used if current symbol is a loss function. |
| |
| Parameters |
| ---------- |
| $wrt : Array of String |
| keyword arguments of the symbol that the gradients are taken. |
| |
| Returns |
| ------- |
| grad : AI::MXNet::Symbol |
| A gradient Symbol with returns to be the corresponding gradients. |
| =cut |
| |
| method grad(ArrayRef[Str] $wrt) |
| { |
| my $handle = check_call(AI::MXNetCAPI::SymbolGrad( |
| $self->handle, |
| scalar(@$wrt), |
| $wrt |
| ) |
| ); |
| return __PACKAGE__->new(handle => $handle); |
| } |
| |
| =head2 Variable |
| |
| Create a symbolic variable with specified name. |
| |
| Parameters |
| ---------- |
| name : str |
| Name of the variable. |
| attr : hash ref of string -> string |
| Additional attributes to set on the variable. |
| shape : array ref of positive integers |
| Optionally, one can specify the shape of a variable. This will be used during |
| shape inference. If user specified a different shape for this variable using |
| keyword argument when calling shape inference, this shape information will be ignored. |
| lr_mult : float |
| Specify learning rate muliplier for this variable. |
| wd_mult : float |
| Specify weight decay muliplier for this variable. |
| dtype : Dtype |
| Similar to shape, we can specify dtype for this variable. |
| init : initializer (mx->init->*) |
| Specify initializer for this variable to override the default initializer |
| kwargs : hash ref |
| other additional attribute variables |
| Returns |
| ------- |
| variable : Symbol |
| The created variable symbol. |
| =cut |
| |
| method Variable( |
| Str $name, |
| HashRef[Str] :$attr={}, |
| Maybe[Shape] :$shape=, |
| Maybe[Num] :$lr_mult=, |
| Maybe[Num] :$wd_mult=, |
| Maybe[Dtype] :$dtype=, |
| Maybe[AI::MXNet::Initializer] :$init=, |
| HashRef[Str] :$kwargs={}, |
| Maybe[Str] :$__layout__= |
| ) |
| { |
| my $handle = check_call(AI::MXNetCAPI::SymbolCreateVariable($name)); |
| my $ret = __PACKAGE__->new(handle => $handle); |
| $attr = AI::MXNet::Symbol::AttrScope->current->get($attr); |
| $attr->{__shape__} = "(".join(',', @{ $shape }).")" if $shape; |
| $attr->{__lr_mult__} = $lr_mult if defined $lr_mult; |
| $attr->{__wd_mult__} = $wd_mult if defined $wd_mult; |
| $attr->{__dtype__} = DTYPE_STR_TO_MX->{ $dtype } if $dtype; |
| $attr->{__init__} = "$init" if defined $init; |
| $attr->{__layout__} = $__layout__ if defined $__layout__; |
| while(my ($k, $v) = each %{ $kwargs }) |
| { |
| if($k =~ /^__/ and $k =~ /__$/) |
| { |
| $attr->{$k} = "$v"; |
| } |
| else |
| { |
| confess("Attribute name=$k is not supported.". |
| ' Additional attributes must start and end with double underscores,'. |
| ' e.g, __yourattr__' |
| ); |
| } |
| } |
| $ret->_set_attr(%{ $attr }); |
| return $ret; |
| } |
| |
| =head2 var |
| |
| A synonym to Variable. |
| =cut |
| |
| *var = \&Variable; |
| |
| =head2 Group |
| |
| Create a symbol that groups symbols together. |
| |
| Parameters |
| ---------- |
| symbols : array ref |
| List of symbols to be grouped. |
| |
| Returns |
| ------- |
| sym : Symbol |
| The created group symbol. |
| =cut |
| |
| method Group(ArrayRef[AI::MXNet::Symbol] $symbols) |
| { |
| my @handles = map { $_->handle } @{ $symbols }; |
| my $handle = check_call(AI::MXNetCAPI::SymbolCreateGroup(scalar(@handles), \@handles)); |
| return __PACKAGE__->new(handle => $handle); |
| } |
| |
| =head2 load |
| |
| Load symbol from a JSON file. |
| |
| You can also use Storable to do the job if you only work with Perl. |
| The advantage of load/save is the file is language agnostic. |
| This means the file saved using save can be loaded by other language binding of mxnet. |
| You also get the benefit being able to directly load/save from cloud storage(S3, HDFS) |
| |
| Parameters |
| ---------- |
| fname : str |
| The name of the file, examples: |
| |
| - `s3://my-bucket/path/my-s3-symbol` |
| - `hdfs://my-bucket/path/my-hdfs-symbol` |
| - `/path-to/my-local-symbol` |
| |
| Returns |
| ------- |
| sym : Symbol |
| The loaded symbol. |
| |
| See Also |
| -------- |
| AI::MXNet::Symbol->save : Used to save symbol into file. |
| =cut |
| |
| method load(Str $fname) |
| { |
| my $handle = check_call(AI::MXNetCAPI::SymbolCreateFromFile($fname)); |
| return __PACKAGE__->new(handle => $handle); |
| } |
| |
| =head2 load_json |
| Load symbol from json string. |
| |
| Parameters |
| ---------- |
| json_str : str |
| A json string. |
| |
| Returns |
| ------- |
| sym : Symbol |
| The loaded symbol. |
| |
| See Also |
| -------- |
| AI::MXNet::Symbol->tojson : Used to save symbol into json string. |
| =cut |
| |
| method load_json(Str $json) |
| { |
| my $handle = check_call(AI::MXNetCAPI::SymbolCreateFromJSON($json)); |
| return __PACKAGE__->new(handle => $handle); |
| } |
| |
| method zeros(Shape :$shape, Dtype :$dtype='float32', Maybe[Str] :$name=, Maybe[Str] :$__layout__=) |
| { |
| return __PACKAGE__->_zeros({ shape => $shape, dtype => $dtype, name => $name, ($__layout__ ? (__layout__ => $__layout__) : ()) }); |
| } |
| |
| method ones(Shape :$shape, Dtype :$dtype='float32', Maybe[Str] :$name=, Maybe[Str] :$__layout__=) |
| { |
| return __PACKAGE__->_ones({ shape => $shape, dtype => $dtype, name => $name, ($__layout__ ? (__layout__ => $__layout__) : ()) }); |
| } |
| |
| =head2 arange |
| |
| Simlar function in the MXNet ndarray as numpy.arange |
| See Also https://docs.scipy.org/doc/numpy/reference/generated/numpy.arange.html. |
| |
| Parameters |
| ---------- |
| start : number |
| Start of interval. The interval includes this value. The default start value is 0. |
| stop : number, optional |
| End of interval. The interval does not include this value. |
| step : number, optional |
| Spacing between values |
| repeat : int, optional |
| "The repeating time of all elements. |
| E.g repeat=3, the element a will be repeated three times --> a, a, a. |
| dtype : type, optional |
| The value type of the NDArray, default to np.float32 |
| |
| Returns |
| ------- |
| out : Symbol |
| The created Symbol |
| =cut |
| |
| method arange(Index :$start=0, Index :$stop=, Num :$step=1.0, Index :$repeat=1, Maybe[Str] :$name=, Dtype :$dtype='float32') |
| { |
| return __PACKAGE__->_arange({ |
| start => $start, (defined $stop ? (stop => $stop) : ()), |
| step => $step, repeat => $repeat, name => $name, dtype => $dtype |
| }); |
| } |
| |
| |
| sub _parse_arguments |
| { |
| my $type = shift; |
| my @args = @_; |
| my $type_c = find_type_constraint($type); |
| my $str_c = find_type_constraint("Str"); |
| my @positional_arguments; |
| my %kwargs; |
| my @kwargs_order; |
| my $only_dtypes_and_undefs = (@args == grep { not defined($_) or $type_c->check($_) } @args); |
| my $only_dtypes_and_strs = (@args == grep { $type_c->check($_) or $str_c->check($_) } @args); |
| if(@args % 2 and $only_dtypes_and_undefs) |
| { |
| @positional_arguments = @args; |
| } |
| else |
| { |
| if($only_dtypes_and_undefs) |
| { |
| @positional_arguments = @args; |
| } |
| elsif($only_dtypes_and_strs) |
| { |
| my %tmp = @args; |
| if(values(%tmp) == grep { $type_c->check($_) } values(%tmp)) |
| { |
| %kwargs = %tmp; |
| my $i = 0; |
| @kwargs_order = grep { $i ^= 1 } @args; |
| } |
| else |
| { |
| confess("Argument need to be of type $type"); |
| } |
| } |
| else |
| { |
| confess("Argument need to be one type $type"); |
| } |
| } |
| return (\@positional_arguments, \%kwargs, \@kwargs_order); |
| } |
| |
| sub _ufunc_helper |
| { |
| my ($lhs, $rhs, $fn_symbol, $lfn_scalar, $rfn_scalar, $reverse) = @_; |
| ($rhs, $lhs) = ($lhs, $rhs) if $reverse and $rfn_scalar; |
| if(not ref $lhs) |
| { |
| if(not $rfn_scalar) |
| { |
| return __PACKAGE__->can($lfn_scalar)->(__PACKAGE__, $rhs, { "scalar" => $lhs }); |
| } |
| else |
| { |
| return __PACKAGE__->can($rfn_scalar)->(__PACKAGE__, $rhs, { "scalar" => $lhs }); |
| } |
| } |
| elsif(not ref $rhs) |
| { |
| return __PACKAGE__->can($lfn_scalar)->(__PACKAGE__, $lhs, { "scalar" => $rhs }); |
| } |
| else |
| { |
| return __PACKAGE__->can($fn_symbol)->(__PACKAGE__, $lhs, $rhs); |
| } |
| } |
| |
| 1; |