blob: 2cca47f9ab4d5dce90b658898884868b5d6d9b7f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
package AI::MXNet::Context;
use strict;
use warnings;
use Mouse;
use AI::MXNet::NS;
use AI::MXNet::Base;
use AI::MXNet::Types;
use AI::MXNet::Function::Parameters;
# Lookup tables between numeric device-type ids (as held in device_type_id)
# and their string names. The two maps are mirror images of each other and
# must be kept in sync.
use constant {
    devtype2str => {
        1 => 'cpu',
        2 => 'gpu',
        3 => 'cpu_pinned',
    },
    devstr2type => {
        cpu        => 1,
        gpu        => 2,
        cpu_pinned => 3,
    },
};
# Normalise the supported constructor call styles into Mouse's
# named-argument form:
#   AI::MXNet::Context->new('gpu')           device_type only (device_id defaults to 0)
#   AI::MXNet::Context->new($ctx)            copy device_type/device_id from a blessed context
#   AI::MXNet::Context->new('gpu', 1)        positional device_type, device_id
#   AI::MXNet::Context->new(%named_args)     passed through unchanged
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    # Anchored, so named-argument keys such as 'device_type' never qualify.
    my $device_type_re = qr/^(?:cpu|gpu|cpu_pinned)$/;
    return $class->$orig(device_type => $_[0])
        if @_ == 1 and $_[0] =~ $device_type_re;
    return $class->$orig(
        device_type => $_[0]->device_type,
        device_id   => $_[0]->device_id
    ) if @_ == 1 and blessed $_[0];
    # Positional form: the second argument is the device id.
    return $class->$orig(device_type => $_[0], device_id => $_[1])
        if @_ == 2 and $_[0] =~ $device_type_re;
    return $class->$orig(@_);
};
# Kind of device: 'cpu', 'gpu' or 'cpu_pinned' (default 'cpu').
# Writing a new value (through the constructor or the rw accessor) clears the
# cached device_type_id, so the numeric id is re-derived on next access
# instead of silently keeping the id of the previous device type.
has 'device_type' => (
    is      => 'rw',
    isa     => enum([qw[cpu gpu cpu_pinned]]),
    default => 'cpu',
    trigger => sub {
        my ($self) = @_;
        $self->clear_device_type_id;
    },
);
# Numeric id for device_type (1 = cpu, 2 = gpu, 3 = cpu_pinned), computed
# lazily from the devstr2type table on first read.
has 'device_type_id' => (
    is      => 'rw',
    isa     => enum([1, 2, 3]),
    default => sub { devstr2type->{ shift->device_type } },
    lazy    => 1,
    clearer => 'clear_device_type_id',
);
# Device ordinal (default 0); needed to select a specific GPU.
has 'device_id' => (
    is      => 'rw',
    isa     => 'Int',
    default => 0
);
# Operator overloading:
#   ==  two contexts are equal when they stringify identically, i.e. same
#       device_type and device_id. Anything that is not an
#       AI::MXNet::Context (including undef) compares unequal.
#   !=  the negation of ==. It has to be spelled out: overload does not
#       derive != from ==. With fallback => 1, Perl would otherwise numify
#       the "cpu(0)"-style strings (all become 0), so != would always be false.
#   ""  renders as "device_type(device_id)", e.g. "gpu(1)".
use overload
    '==' => sub {
        my ($self, $other) = @_;
        return 0 unless blessed($other) and $other->isa(__PACKAGE__);
        return "$self" eq "$other";
    },
    '!=' => sub {
        my ($self, $other) = @_;
        return !($self == $other);
    },
    '""' => sub {
        my ($self) = @_;
        return sprintf("%s(%s)", $self->device_type, $self->device_id);
    },
    fallback => 1;
=head1 NAME
AI::MXNet::Context - A device context.
=cut
=head1 DESCRIPTION
This class governs the device context of AI::MXNet::NDArray objects.
=cut
=head1 SYNOPSIS
use AI::MXNet qw(mx);
print nd->array([[1,2],[3,4]], ctx => mx->cpu)->aspdl;
my $arr_gpu = nd->random->uniform(shape => [10, 10], ctx => mx->gpu(0));
=cut
=head2 new
Constructs a context.
Parameters
----------
device_type : {'cpu', 'gpu', 'cpu_pinned'} or AI::MXNet::Context
String representing the device type, or an existing context whose
device type and device id are copied.
device_id : int (default=0)
The device id of the device, needed for GPU.
=cut
=head2 cpu
Returns a CPU context.
Parameters
----------
device_id : int, optional (default 0)
The device id of the device. It is not needed for CPU and exists only
so that the interface matches gpu().
Returns
-------
context : AI::MXNet::Context
A context whose device_type is 'cpu'.
=cut
method cpu(Int $device_id=0)
{
    # Build the constructor arguments first, then hand them to new().
    my %spec = (device_type => 'cpu', device_id => $device_id);
    return $self->new(%spec);
}
=head2 cpu_pinned
Returns a CPU pinned-memory context. Copying from CPU pinned memory to a
GPU is faster than copying from normal CPU memory.
Parameters
----------
device_id : int, optional (default 0)
The device id of the device. It is not needed for CPU and exists only
so that the interface matches gpu().
Returns
-------
context : AI::MXNet::Context
A context whose device_type is 'cpu_pinned'.
=cut
method cpu_pinned(Int $device_id=0)
{
    my $pinned_type = 'cpu_pinned';
    return $self->new(device_id => $device_id, device_type => $pinned_type);
}
=head2 gpu
Returns a GPU context.
Parameters
----------
device_id : int, optional (default 0)
The device id of the GPU to use.
Returns
-------
context : AI::MXNet::Context
A context whose device_type is 'gpu'.
=cut
method gpu(Int $device_id=0)
{
    my $gpu_type = 'gpu';
    return $self->new(device_id => $device_id, device_type => $gpu_type);
}
=head2 current_context
Returns the current context.
Returns
-------
$default_ctx : AI::MXNet::Context
=cut
=head2 num_gpus
Queries CUDA for the number of GPUs present.
Raises
------
Will raise an exception on any CUDA error.
Returns
-------
count : int
The number of GPUs.
=cut
method num_gpus()
{
    # Capture the raw C API reply, pass it through check_call, and call
    # check_call in scalar context so that exactly one value (the count)
    # comes back.
    my @reply = AI::MXNetCAPI::GetGPUCount();
    my $count = check_call(@reply);
    return $count;
}
=head2 gpu_memory_info
Queries CUDA for the free and total bytes of GPU global memory.
Parameters
----------
$device_id=0 : int, optional
The device id of the GPU device.
Raises
------
Will raise an exception on any CUDA error.
Returns
-------
($free, $total) : (int, int)
Free and total memory in bytes.
=cut
method gpu_memory_info($device_id=0)
{
    my @reply = AI::MXNetCAPI::GetGPUMemoryInformation64($device_id);
    # Return check_call's result directly, with no intermediate variable, so
    # it runs in the caller's context (list context yields ($free, $total)).
    return check_call(@reply);
}
# Returns the current context, which is stored in the package variable
# $AI::MXNet::Context. If nothing has stored one yet, a cpu(0) context is
# created and remembered, so callers never receive undef.
method current_ctx()
{
    $AI::MXNet::Context //= __PACKAGE__->new(device_type => 'cpu', device_id => 0);
    return $AI::MXNet::Context;
}
# Replaces the current context. The argument must be an AI::MXNet::Context
# (enforced by the signature's type constraint).
method set_current(AI::MXNet::Context $current)
{
    $AI::MXNet::Context = $current;
}
# current_context is an alias for current_ctx.
*current_context = \&current_ctx;
# Returns a new, independent AI::MXNet::Context with the same device_type and
# device_id. device_type_id is not copied; it is lazy, so the copy derives it
# from its own device_type on first access.
method deepcopy()
{
    my %attrs = map { ( $_ => $self->$_ ) } qw(device_type device_id);
    return __PACKAGE__->new(%attrs);
}
# Register this class with AI::MXNet::NS under the 'AI::MXNet' namespace.
# NOTE(review): presumably this is what makes mx->cpu / mx->gpu (see SYNOPSIS)
# resolve to this class — confirm against AI::MXNet::NS.
__PACKAGE__->AI::MXNet::NS::register('AI::MXNet');
# A Perl module must end by returning a true value.
1;