# blob: 2b60dca4f7830fe40f489717f65832c0a1952dbd [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
################################################################################
# Common types used in MXNet API
################################################################################
# C-compatible type aliases used in the signatures of libmxnet calls.
const MX_uint   = Cuint         # alias of the C `unsigned int` type
const MX_float  = Cfloat        # alias of the C `float` type
const MX_handle = Ptr{Cvoid}    # untyped pointer; stored by the handle wrapper types
const char_p    = Ptr{UInt8}    # pointer to bytes
const char_pp   = Ptr{char_p}   # pointer to `char_p`
################################################################################
# Enumeration from MXNet headers
################################################################################
# Gradient request kinds; see OpReqType in include/mxnet/op_attr_types.h
@enum GRAD_REQ begin
    GRAD_NOP = 0
    GRAD_WRITE = 1
    GRAD_INPLACE = 2
    GRAD_ADD = 3
end

# Lookup table from a symbolic request name to the matching `GRAD_REQ` value.
const grad_req_map = Dict{Symbol,GRAD_REQ}(
    :nop     => GRAD_NOP,      # no operation, do not write anything
    :write   => GRAD_WRITE,    # write gradient to provided space
    :inplace => GRAD_INPLACE,  # perform an inplace write
    :add     => GRAD_ADD,      # add to the provided space
)
################################################################################
# Initialization and library API entrance
################################################################################
# Locate libmxnet (see build.jl). The directories handed to `find_library` are
# `$MXNET_HOME/lib`, `$MXNET_HOME` (both based on "" when the variable is unset)
# and `../deps/usr/lib` relative to the directory of this file.
const MXNET_LIB = let mxnet_home = get(ENV, "MXNET_HOME", "")
    lib_names = ["libmxnet.$(Libdl.dlext)", "libmxnet.so"]
    search_dirs = [joinpath(mxnet_home, "lib"),
                   mxnet_home,
                   joinpath(@__DIR__, "..", "deps", "usr", "lib")]
    Libdl.find_library(lib_names, search_dirs)
end

# Storage for the libmxnet version number; starts at 0 and is filled in by
# `_get_lib_version!`.
const LIB_VERSION = Ref{Cint}(0)

if !isempty(MXNET_LIB)
    include_dependency(MXNET_LIB)
else
    # touch this file, so that after the user properly builds libmxnet, the
    # precompiled MXNet.ji will be re-compiled to get MXNET_LIB properly.
    touch(@__FILE__)
    error("Cannot find or load libmxnet.$(Libdl.dlext). " *
          "Please see the document on how to build it.")
end
"""
    __init__()

Module initialization hook. Calls `_get_libmx_op_names`,
`_populate_iter_creator_cache!` and `_get_lib_version!` in that order, then
registers an `atexit` hook that calls `MXNotifyShutdown` in libmxnet.
"""
function __init__()
    # TODO: bug in nnvm, if do not call this, call get handle "_copyto" will fail
    _get_libmx_op_names()
    _populate_iter_creator_cache!()
    _get_lib_version!()

    # notify libmxnet we are shutting down
    atexit(() -> ccall(("MXNotifyShutdown", MXNET_LIB), Cint, ()))
end
"""
    mx_get_last_error()

Return the string that libmxnet's `MXGetLastError` points to.

Throws `MXError` when the library returns a null pointer.
"""
function mx_get_last_error()
    msg_ptr = ccall(("MXGetLastError", MXNET_LIB), char_p, ())
    msg_ptr == C_NULL && throw(MXError("Failed to get last error message"))
    return unsafe_string(msg_ptr)
end
"""
    @mxcall f argtypes args...

Utility macro to call MXNet API functions.

Expands to a `ccall` of `f` in `MXNET_LIB` with return type `Cint`. When the
call returns a non-zero value, the message from `mx_get_last_error()` is
thrown as an `MXError`.
"""
macro mxcall(f, argtypes, args...)
    # Argument types and values are evaluated in the caller's scope.
    call_types = esc(argtypes)
    call_args = map(esc, args)
    return quote
        _mxret = ccall(($f, $MXNET_LIB), Cint, $call_types, $(call_args...))
        # Kept as an `if` statement so the expansion evaluates to `nothing`
        # when the call succeeds.
        if _mxret != 0
            throw(MXError(mx_get_last_error()))
        end
    end
end
"""
    _get_lib_version!()

Get the libmxnet version.

Calls `MXGetVersion`, which writes into the global `LIB_VERSION` reference,
and returns the stored value.
"""
function _get_lib_version!()
    @mxcall(:MXGetVersion, (Ref{Cint},), LIB_VERSION)
    return LIB_VERSION[]
end
################################################################################
# Handle types
################################################################################
"""
    mx_define_handle_t(name, destructor)

Define, via `@eval`, a mutable wrapper type called `name` that stores a raw
`MX_handle` in its `value` field, together with `Base.unsafe_convert`,
`Base.convert`, `Base.cconvert` and `MX_handle(x)` methods that return the
stored pointer.

# Arguments
- `name`: `Symbol` naming the type to create.
- `destructor`: `Symbol` naming the libmxnet function that frees the handle,
  or `nothing`. When a destructor is given, a `delete!` method is defined for
  the new type and registered as a finalizer on every instance. When it is
  `nothing`, neither the finalizer nor the `delete!` method is generated.
"""
function mx_define_handle_t(name, destructor)
    # Both generated pieces below must depend on the same condition. The
    # `delete!` method was previously guarded by `finalizer != nothing`, which
    # compares the Base function and is always true.
    has_destructor = destructor !== nothing

    @eval begin
        mutable struct $name
            value::MX_handle

            function $name(value = C_NULL)
                hdr = new(value)
                $(if has_destructor
                    :(finalizer(delete!, hdr))
                end)
                return hdr
            end
        end

        $(if has_destructor
            quote
                function delete!(h :: $name)
                    # Free the handle once; the null check makes repeated
                    # calls a no-op.
                    if h.value != C_NULL
                        @mxcall($(QuoteNode(destructor)), (MX_handle,), h.value)
                        h.value = C_NULL
                    end
                end
            end
        end)

        function Base.unsafe_convert(::Type{MX_handle}, obj::$name)
            obj.value
        end
        Base.convert(t::Type{MX_handle}, obj::$name) = Base.unsafe_convert(t, obj)
        Base.cconvert(t::Type{MX_handle}, obj::$name) = Base.unsafe_convert(t, obj)

        MX_handle(x::$name) = Base.convert(MX_handle, x)
    end
end
# Handle wrapper types, each paired with the libmxnet function that frees it
# (`nothing` when no destructor is passed).
for (handle_name, destructor) in (
        (:MX_NDArrayHandle,  :MXNDArrayFree),
        (:MX_OpHandle,       nothing),
        (:MX_SymbolHandle,   :MXSymbolFree),
        (:MX_ExecutorHandle, :MXExecutorFree),
        (:MX_DataIterHandle, :MXDataIterFree),
        (:MX_KVStoreHandle,  :MXKVStoreFree),
    )
    mx_define_handle_t(handle_name, destructor)
end
################################################################################
# MXNet Params
#
# The MXNet API uses strings to pass some common parameters, such as the
# configuration options given when defining layers. Typically, string(obj) is
# enough to get a representation that libmxnet recognizes. However, there is
# currently one caveat:
#
# Julia uses column-major ordering for tensors, so in order to interact properly
# with Julia Arrays, shapes look "reversed" from the Julia side. For example, a
# typical MNIST mini-batch tensor has shape (28,28,1,100) from the Julia side,
# while the shape of the same piece of memory should be interpreted as
# (100,1,28,28) from the C/C++/Python side.
#
# Therefore, when passing parameters to libmxnet, we should reverse shape
# parameters, for example when the user specifies a non-square kernel size for
# a convolution or pooling layer. Unfortunately, those operators are imported
# automatically, and the information about the type of each parameter is limited.
# One hacky way is to match the type description against the string "Shape(tuple)"
# when importing operators. But currently we simply reverse **all** integer
# tuples (NTuple{N,<:Integer}) passed to libmxnet.
#
# TODO: find a better solution in case this causes issues in the future.
# I made `@_remap` in `ndarray.jl`. (Iblis Lin)
################################################################################
"""
    dump_mx_param(val)

Convert `val` to the string form passed to libmxnet.

The fallback is `string(val)`. Floating-point and irrational values are
printed in scientific notation with a fixed number of digits after the decimal
point. Tuples of integers are reversed before being converted to a string.
"""
function dump_mx_param(val)
    return string(val)
end

# 16 digits after the decimal point
function dump_mx_param(val::Float64)
    return @sprintf("%.16e", val)
end

# 8 digits after the decimal point
function dump_mx_param(val::Float32)
    return @sprintf("%.8e", val)
end

# 4 digits after the decimal point
function dump_mx_param(val::Float16)
    return @sprintf("%.4e", val)
end

# irrational constants use the same format as Float64
function dump_mx_param(val::Irrational)
    return @sprintf("%.16e", val)
end

# shapes are reversed before being handed to libmxnet
function dump_mx_param(shape::NTuple{N,<:Integer}) where {N}
    flipped = reverse(shape)
    return string(flipped)
end
"""
    @defstruct Name (fields...)
    @defstruct Name <: SuperType (fields...)

A convenient macro copied from Mocha.jl that could be used to define mutable
structs with default values and type checks. For example

```julia
@defstruct MyStruct <: Any (
  field1 :: Int = 0,
  (field2 :: AbstractString = "", !isempty(field2))
)
```

The first argument is either a bare type name (the supertype is then `Any`)
or an expression of the form `Name <: SuperType`. Each field could be either

```julia
field_name :: field_type = default_value
```

or put within a tuple, with the second element specifying a validation check
on the field value. A field written without a default value
(`field_name :: field_type`) is required. In the example above, the default
value for `field2` does not satisfy the assertion; this could be used to force
the user to provide a valid value when no meaningful default value is
available.

The macro will define a constructor that accepts the fields as keyword
arguments, converts each value to the declared field type and then runs the
validation checks.
"""
macro defstruct(name, fields)
    return _defstruct_impl(false, name, fields)
end
"""
    @defimmutable Name (fields...)
    @defimmutable Name <: SuperType (fields...)

A convenient macro to define immutable structs. It takes the same arguments
as `@defstruct`; the only difference is that the defined type is immutable.
"""
macro defimmutable(name, fields)
    return _defstruct_impl(true, name, fields)
end
"""
Internal use only. An instance of this type is used as the default value of a
required field, to indicate that no value was specified for it.
"""
struct __Undefined
end

"""
    _defstruct_impl(is_immutable, name, fields)

Build the expression returned by `@defstruct` and `@defimmutable`.

# Arguments
- `is_immutable`: when `true` the generated type is a `struct`, otherwise a
  `mutable struct`.
- `name`: either a `Symbol` (the supertype is then `Any`) or an expression of
  the form `Name <: SuperType` where both sides are symbols.
- `fields`: a single field expression or a tuple expression of them. Each
  field is `fname :: ftype = default`, `fname :: ftype` (a required field), or
  a 2-tuple `(field, check)` whose second element is a validation expression.

# Returns
A block expression containing the type definition and a keyword-argument
constructor. The constructor asserts that every required field was supplied,
converts each value to its declared type, runs the validation checks, and
then calls the positional constructor.
"""
function _defstruct_impl(is_immutable, name, fields)
    if isa(fields, Expr) && fields.head == :tuple
        fields = fields.args
    else
        fields = [fields]
    end
    @assert length(fields) > 0

    if isa(name, Symbol)
        name = esc(name)
        super_name = :Any
    else
        @assert(isa(name, Expr) && name.head == :(<:) && length(name.args) == 2 &&
                isa(name.args[1], Symbol) && isa(name.args[2], Symbol),
                "name must be of form 'Name <: SuperType'")
        super_name = esc(name.args[2])
        name = esc(name.args[1])
    end

    n = length(fields)
    field_defs     = Vector{Expr}(undef, n)  # :(field2 :: Int)
    field_names    = Vector{Expr}(undef, n)  # :field2
    field_defaults = Vector{Expr}(undef, n)  # :(field2 = 0)
    field_types    = Vector{Expr}(undef, n)  # Int
    # Slots are only assigned for fields that carry a validation check.
    field_asserts  = Vector{Expr}(undef, n)  # :(field2 >= 0)
    required_field = Symbol[]

    for i in eachindex(fields)
        field = fields[i]
        if field.head == :tuple
            # (field, check): remember the check and unwrap the field
            field_asserts[i] = esc(field.args[2])
            field = field.args[1]
        end
        if field.head == :(=)
            fname = field.args[1].args[1]
            field_defs[i] = esc(field.args[1])
            field_names[i] = esc(fname)
            field_types[i] = esc(field.args[1].args[2])
            field_defaults[i] = Expr(:kw, fname, esc(field.args[2]))
        else
            # no default value provided, required field
            fname = field.args[1]
            field_defs[i] = esc(field)
            field_names[i] = esc(fname)
            field_types[i] = esc(field.args[2])
            field_defaults[i] = Expr(:kw, fname, __Undefined())
            push!(required_field, fname)
        end
    end

    # body of layer type, defining fields
    type_body = Expr(:block, field_defs...)

    # constructor
    requires = map(required_field) do fname
        # Build the message from the field *name* while generating the code.
        # Calling `string` on the variable at run time would print the
        # `__Undefined()` sentinel instead of the name.
        msg = "value for " * string(fname) * " is required"
        :(@assert(!isa($fname, __Undefined), $msg))
    end
    converts = map(zip(field_names, field_types)) do param
        f_name, f_type = param
        :($f_name = convert($f_type, $f_name))
    end
    asserts = map(filter(i -> isassigned(field_asserts, i), eachindex(fields))) do i
        :(@assert($(field_asserts[i])))
    end
    construct = Expr(:call, name, field_names...)
    ctor_body = Expr(:block, requires..., converts..., asserts..., construct)
    ctor_def = Expr(:call, name, Expr(:parameters, field_defaults...))
    ctor = Expr(:(=), ctor_def, ctor_body)

    if is_immutable
        return quote
            struct $(name) <: $(super_name)
                $type_body
            end

            $ctor
        end
    else
        return quote
            mutable struct $(name) <: $(super_name)
                $type_body
            end

            $ctor
        end
    end
end