# blob: ac7f4fc71653b084152d09a51d0dd54c92bacb6d [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
################################################################################
# Dataset related utilities
################################################################################
# Return the path of the `data` directory that sits next to this file's own
# directory (`<dir of this file>/../data`), creating it first when needed.
function get_data_dir()
    dir = joinpath(@__DIR__, "..", "data")
    mkpath(dir)
    return dir
end
# Make sure the four MNIST idx-ubyte files exist under `<data dir>/mnist` and
# return a `Dict` mapping `:train_data`, `:train_label`, `:test_data` and
# `:test_label` to their full paths.
#
# When any of the files is missing, `mnist.zip` is downloaded into that
# directory and extracted with `unzip`, falling back to `7z` when `unzip`
# fails. An error is raised when both extraction attempts fail.
function get_mnist_ubyte()
    data_dir = get_data_dir()
    mnist_dir = joinpath(data_dir, "mnist")
    mkpath(mnist_dir)
    basenames = Dict(:train_data  => "train-images-idx3-ubyte",
                     :train_label => "train-labels-idx1-ubyte",
                     :test_data   => "t10k-images-idx3-ubyte",
                     :test_label  => "t10k-labels-idx1-ubyte")
    # Turn the bare file names into full paths inside the MNIST directory.
    filenames = Dict(k => joinpath(mnist_dir, v) for (k, v) in basenames)
    if !all(isfile, values(filenames))
        # Work inside the target directory so that the archive and its
        # extracted contents end up there.
        cd(mnist_dir) do
            data = download("http://data.mxnet.io/mxnet/data/mnist.zip", "mnist.zip")
            try
                run(`unzip -u $data`)
            catch
                # `unzip` failed (e.g. not installed): try 7-Zip before giving up.
                try
                    run(pipeline(`7z x $data`, stdout = devnull))
                catch
                    error("Extraction Failed:No extraction program found in path")
                end
            end
        end
    end
    return filenames
end
# Make sure the CIFAR-10 record files exist under `<data dir>/cifar10` and
# return a `Dict` mapping `:train`, `:test` and `:mean` to their full paths.
#
# When the train or test file is missing, `cifar10.zip` is downloaded into
# that directory and extracted with `unzip`, falling back to `7z` when `unzip`
# fails. An error is raised when both extraction attempts fail. The `:mean`
# entry is added after the existence check, so its file is not checked here.
function get_cifar10()
    data_dir = get_data_dir()
    cifar10_dir = joinpath(data_dir, "cifar10")
    mkpath(cifar10_dir)
    relpaths = Dict(:train => "cifar/train.rec", :test => "cifar/test.rec")
    # `map` is not defined on dictionaries, so build the path table from a
    # generator of pairs instead.
    filenames = Dict(k => joinpath(cifar10_dir, v) for (k, v) in relpaths)
    if !all(isfile, values(filenames))
        # Work inside the target directory so that the archive and its
        # extracted contents end up there.
        cd(cifar10_dir) do
            download("http://data.mxnet.io/mxnet/data/cifar10.zip", "cifar10.zip")
            try
                run(`unzip -u cifar10.zip`)
            catch
                # `unzip` failed (e.g. not installed): try 7-Zip before giving up.
                try
                    run(pipeline(`7z x cifar10.zip`, stdout = devnull))
                catch
                    error("Extraction Failed:No extraction program found in path")
                end
            end
        end
    end
    filenames[:mean] = joinpath(cifar10_dir, "cifar/cifar_mean.bin")
    return filenames
end
################################################################################
# Internal Utilities
################################################################################
# Ask libmxnet (`MXListAllOpNames`) for the names of all operators and return
# them as a vector of Julia strings.
function _get_libmx_op_names()
    count_ref = Ref{MX_uint}(0)
    names_ref = Ref{char_pp}(0)
    @mxcall(:MXListAllOpNames, (Ref{MX_uint}, Ref{char_pp}), count_ref, names_ref)
    # View the returned C array of C strings without copying, then copy each
    # entry into a Julia `String`.
    name_ptrs = unsafe_wrap(Array, names_ref[], count_ref[])
    return map(unsafe_string, name_ptrs)
end
# Look up the operator called `name` through `NNGetOpHandle` and wrap the
# resulting raw handle in an `MX_OpHandle`.
function _get_libmx_op_handle(name::String)
    raw = Ref{MX_handle}(0)
    @mxcall(:NNGetOpHandle, (char_p, Ref{MX_handle}), name, raw)
    return MX_OpHandle(raw[])
end
# Operator handles are looked up again every time Julia runs and memoized in
# this cache, instead of being baked in at pre-compilation time with a macro,
# because the actual handle might change between runs.
const _libmx_op_cache = Dict{String,MX_OpHandle}()

# Return the handle of operator `name`, fetching it from libmxnet on first use
# and serving it from `_libmx_op_cache` afterwards.
function _get_cached_libmx_op_handle(name::String)
    return get!(_libmx_op_cache, name) do
        _get_libmx_op_handle(name)
    end
end
# Query libmxnet (`MXSymbolGetAtomicSymbolInfo`) for the human-readable
# metadata of the operator behind `handle` and assemble a description text.
#
# Returns the tuple `(desc, key_narg)`. `desc` consists of a signature line, an
# alias remark when the real operator name differs from `name`, a note about
# variable positional inputs when `key_narg` is non-empty, the operator's own
# description and an `# Arguments` section. `key_narg` is the key string
# reported by libmxnet.
function _get_libmx_op_description(name::String, handle::MX_OpHandle)
    # Output slots filled in by the C call below.
    real_name_ref = Ref{char_p}(0)
    desc_ref      = Ref{char_p}(0)
    narg_ref      = Ref{MX_uint}(0)
    arg_names_ref = Ref{char_pp}(0)
    arg_types_ref = Ref{char_pp}(0)
    arg_descs_ref = Ref{char_pp}(0)
    key_narg_ref  = Ref{char_p}(0)
    ret_type_ref  = Ref{char_p}(0)

    @mxcall(:MXSymbolGetAtomicSymbolInfo,
            (MX_handle, Ref{char_p}, Ref{char_p}, Ref{MX_uint}, Ref{char_pp},
             Ref{char_pp}, Ref{char_pp}, Ref{char_p}, Ref{char_p}),
            handle, real_name_ref, desc_ref, narg_ref, arg_names_ref,
            arg_types_ref, arg_descs_ref, key_narg_ref, ret_type_ref)

    narg = Int(narg_ref[])
    real_name = unsafe_string(real_name_ref[])
    signature = _format_signature(narg, arg_names_ref)

    out = IOBuffer()
    print(out, " ", name, "(", signature, ")\n\n")

    if real_name != name
        print(out, name, " is an alias of ", real_name, ".\n\n")
    end

    key_narg = unsafe_string(key_narg_ref[])
    if !isempty(key_narg)
        print(out, "**Note**: ", name, " takes variable number of positional inputs. ")
        print(out, "So instead of calling as ", name, "([x, y, z], ", key_narg, "=3), ")
        print(out, "one should call via ", name, "(x, y, z), and ", key_narg, " will be ")
        print(out, "determined automatically.\n\n")
    end

    print(out, unsafe_string(desc_ref[]), "\n\n")
    print(out, "# Arguments\n")
    print(out, _format_docstring(narg, arg_names_ref, arg_types_ref, arg_descs_ref))

    return String(take!(out)), key_narg
end
# Rewrite every whole-word occurrence of `Symbol` in a type description to
# `SymbolicNode`; everything else is returned unchanged.
function _format_typestring(s::String)
    return replace(s, r"\bSymbol\b" => "SymbolicNode")
end
# Build the Markdown bullet list that documents an operator's arguments.
#
# `arg_names`, `arg_types` and `arg_descs` each reference a C array of `narg`
# C strings. One bullet of the form "* `name::type`: description" is produced
# per argument, with `Symbol` in the type text rewritten by
# `_format_typestring`. When `remove_dup` is true (the default), an argument
# name that has already been emitted is skipped. The bullets are joined with
# newlines and returned as a single `String`.
function _format_docstring(narg::Int, arg_names::Ref{char_pp}, arg_types::Ref{char_pp},
                           arg_descs::Ref{char_pp}, remove_dup::Bool=true)
    # Non-copying views over the C arrays of C-string pointers.
    name_ptrs = unsafe_wrap(Array, arg_names[], narg)
    type_ptrs = unsafe_wrap(Array, arg_types[], narg)
    desc_ptrs = unsafe_wrap(Array, arg_descs[], narg)

    seen = Set{String}()
    docstrings = String[]
    for i in 1:narg
        arg_name = unsafe_string(name_ptrs[i])
        # Skip names that were already documented unless duplicates are wanted.
        if remove_dup && arg_name in seen
            continue
        end
        push!(seen, arg_name)

        arg_type = _format_typestring(unsafe_string(type_ptrs[i]))
        arg_desc = unsafe_string(desc_ptrs[i])
        push!(docstrings, "* `$arg_name::$arg_type`: $arg_desc\n")
    end
    return join(docstrings, "\n")
end
# Join the `narg` argument names referenced by `arg_names` (a C array of
# C strings) into a single comma-separated string.
function _format_signature(narg::Int, arg_names::Ref{char_pp})
    name_ptrs = unsafe_wrap(Array, arg_names[], narg)
    return join(map(unsafe_string, name_ptrs), ", ")
end
"""
Extract the line of `Defined in ...`
julia> mx._getdocdefine("sgd_update")
"Defined in `src/operator/optimizer_op.cc:L53`"
"""
function _getdocdefine(name::String)
    op = _get_libmx_op_handle(name)
    # Only the description text (first element of the returned tuple) is needed.
    str = _get_libmx_op_description(name, op)[1]
    # Return the first line that consists solely of `Defined in <location>`,
    # with the location wrapped in backticks.
    for line in split(str, '\n')
        m = match(r"^Defined in ([\S]+)$", line)
        m !== nothing && return "Defined in `$(m.captures[1])`"
    end
    # No such line in the description.
    return ""
end
"""
libmxnet operators signature checker.
C/Python have different conventions for accessing arrays. Those languages
handle arrays in row-major order with zero-based indexing, which differs from Julia's
column-major order and 1-based indexing.
This function scans the docstrings of NDArray's APIs and
picks out the signatures which contain `axis`, `axes`, `keepdims` or `shape`
as a function argument.
We invoke this checker in the Travis CI build and pop up a warning message
if a function does not get manually mapped
(implying its dimension referencing may look weird).
If you find any warning in the Travis CI build, please open an issue on GitHub.
"""
function _sig_checker()
    # Operators whose lowercased name appears in `_op_import_bl` are skipped.
    candidates = filter(n -> !(lowercase(n) in _op_import_bl), _get_libmx_op_names())
    for name in candidates
        handle = _get_libmx_op_handle(name)
        desc = first(_get_libmx_op_description(name, handle))
        # The signature is the first line of the generated description.
        sig = strip(first(split(desc, '\n')))
        # Warn about signatures that mention a dimension-related argument.
        if occursin(r"(axis|axes|keepdims|shape)", sig)
            @warn(sig)
        end
    end
end
"""
Get first position argument from function sig
"""
# Walk a function-signature expression down to the name of its first
# positional argument. Returns `nothing` for expressions that are neither a
# `where`/`::` wrapper nor a call.
function _firstarg(sig::Expr)
    if sig.head in (:where, :(::))
        # `f(...) where T` and `x::T` both keep the part of interest in args[1].
        return _firstarg(sig.args[1])
    elseif sig.head == :call
        i = if sig.args[2] isa Expr && sig.args[2].head == :parameters
            # keyword arguments are stored at args[2]; skip over them
            3
        elseif sig.args[1] === :broadcast_
            # case of broadcasting, skip the first arg `::typeof(...)`
            3
        else
            2
        end
        return _firstarg(sig.args[i])
    end
    return nothing
end

# A bare symbol is already the argument name.
_firstarg(s::Symbol) = s
# Functions that have to be imported from a module other than `Base`.
# A `missing` entry makes `_import_expr` produce the empty expression.
const _import_map = Dict{Symbol,Union{Missing,Module}}(
    :diag    => LinearAlgebra,
    :dot     => LinearAlgebra,
    :norm    => LinearAlgebra,
    :shuffle => Random,
    :mean    => Statistics,
    :gamma   => missing,
)

# Build the `import Module: func_name` expression for `func_name`, taking the
# module from `_import_map` and defaulting to `Base`. The empty expression
# `:()` is returned when `func_name` is not defined in that module.
function _import_expr(func_name::Symbol)
    owner = get(_import_map, func_name, Base)
    if !isdefined(owner, func_name)
        return :()
    end
    return :(import $(Symbol(owner)): $func_name)
end