blob: c86e8fffd231d1e7db12e2fef9acbd0a7e3e2165 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# NDArray functions dynamically imported from libmxnet
"""
    _invoke_mxfunction(func_handle, use_vars, scalars, mut_vars; kwargs...)

Call `MXFuncInvokeEx` through `@mxcall` for the function identified by `func_handle`.

`use_vars`, `scalars` and `mut_vars` are forwarded unchanged. Every keyword argument
is converted with `string` and passed as two parallel string arrays (keys and values),
preceded by their count. The value of the `@mxcall` expression is returned.
"""
function _invoke_mxfunction(func_handle::MX_handle, use_vars, scalars, mut_vars; kwargs...)
    # Keys and values are stringified separately so the C side receives two
    # parallel arrays of equal length.
    kw_names = String[string(key) for (key, val) in kwargs]
    kw_values = String[string(val) for (key, val) in kwargs]
    nkw = length(kw_names)

    @mxcall(:MXFuncInvokeEx,
            (MX_handle, Ptr{MX_handle}, Ptr{MX_float}, Ptr{MX_handle},
             Cint, char_pp, char_pp),
            func_handle, use_vars, scalars, mut_vars,
            nkw, kw_names, kw_values)
end
# Flag values of the libmxnet function type mask: the first is bit 0 (value 1), the
# second is bit 2 (value 4).
# NOTE(review): the meaning of each flag is defined by libmxnet; nothing in this file
# reads them -- confirm against the C API before relying on them.
@enum LIBMX_FUNC_TYPE_MASK NDARRAY_ARG_BEFORE_SCALAR=1 ACCEPT_EMPTY_MUTATE_TARGET=(1 << 2)
# Import corresponding math functions from base so the automatically defined libmxnet
# functions can overload them
import Base: sqrt
"""
    _get_ndarray_function_def(name::String)

Build the method definitions that expose the libmxnet operator `name` as a Julia
function on `NDArray`s. Nothing is evaluated here; a tuple `(func_def, func_def2)` of
quoted expressions is returned:

* `func_def` defines `name(::Type{<:NDArray}, args::NDArray...; out=nothing, kwargs...)`,
  the method that performs the `MXImperativeInvoke` call.
* `func_def2` defines `name(args::NDArray...; out=nothing, kwargs...)`, which forwards
  to the method above with `NDArray` as the first argument.

The generated functions are called as

```julia
func_name(arg_in1, arg_in2, ...; out=nothing, kwargs...)
```

* Keyword arguments other than `out` are handed to libmxnet as strings: keys are
  converted with `string`, values with `dump_mx_param`.
* If `out` is given (a single `NDArray` or a collection of them) it is passed to
  libmxnet as the output target and returned as is.
* If `out` is omitted, the output handles reported back by libmxnet are wrapped in
  `NDArray`s. If there is only one output (the typical situation), that `NDArray` is
  returned. Otherwise, a vector containing all the outputs is returned.

The `LIBMX_FUNC_TYPE_MASK` flags are not consulted by the generated methods.
"""
function _get_ndarray_function_def(name::String)
    func_name = Symbol(name)

    # Method taking the array type as first argument; it does the actual FFI call.
    func_def = quote
        function $func_name(::Type{<:NDArray}, args::NDArray...; out=nothing, kwargs...)
            # Normalize `out` to a collection of outputs; a lone NDArray is wrapped.
            # `!==` is an identity test, so no `==`/`!=` method defined for the type
            # of `out` is ever invoked just to detect "not given".
            if out !== nothing
                output_vars = isa(out, NDArray) ? NDArray[out] : out
                num_outputs = length(output_vars)
            else
                output_vars = NDArray[]
                num_outputs = 0
            end

            # With no inputs, pass an explicitly typed empty handle array instead of
            # the collected empty tuple.
            inputs = isempty(args) ? MX_handle[] : collect(args)

            # One-element arrays stand in for the in/out parameters of the C call.
            # Without caller-supplied outputs a null pointer is passed.
            output_handles_pp = if num_outputs > 0
                [map(x -> x.handle, output_vars)]
            else
                [Ptr{MX_handle}(C_NULL)]
            end
            num_outputs_p = Cint[num_outputs]

            kw_keys_str = String[string(k) for (k, v) in kwargs]
            kw_vals_str = String[dump_mx_param(v) for (k, v) in kwargs]

            op_handle = _get_cached_libmx_op_handle($(name))
            @mxcall(:MXImperativeInvoke,
                    (MX_handle, Cint, Ptr{MX_handle},
                     Ptr{Cint}, Ptr{Ptr{MX_handle}},
                     Cint, char_pp, char_pp),
                    op_handle, length(inputs), inputs,
                    num_outputs_p, output_handles_pp,
                    length(kwargs), kw_keys_str, kw_vals_str)

            # Caller-supplied outputs are returned exactly as they were passed in.
            if out !== nothing
                return out
            end

            # Otherwise read back the output count and handle array written by the
            # call and wrap each handle.
            n = num_outputs_p[]
            hdls = unsafe_wrap(Array{MX_handle}, output_handles_pp[], n)
            xs = NDArray[NDArray(MX_NDArrayHandle(x)) for x in hdls]
            return n == 1 ? xs[] : xs
        end
    end

    # Convenience method without the leading type argument.
    func_def2 = quote
        function $func_name(args::NDArray...; out=nothing, kwargs...)
            $func_name(NDArray, args...; out=out, kwargs...)
        end
    end

    return func_def, func_def2
end
# Import deny list: libmxnet operators that must NOT be auto-imported.
# `@_import_ndarray_functions` compares the lowercased operator name against these
# entries; every entry here is written in lowercase.
const _op_import_bl = String[
    # we already have `mx.fill`, `mx.ones` and `mx.zeros`
    "_full", "_ones", "_zeros",
    "clip", "expand_dims",
    # arithmetic
    "_plus", "_minus", "_mod", "_mod_scalar", "_rmod_scalar",
    "dot", "max", "max_axis", "mean", "min", "min_axis",
    "prod", "reshape", "sum", "transpose",
    # trigonometric
    "sin", "cos", "tan", "arcsin", "arccos", "arctan",
    # hyperbolic
    "sinh", "cosh", "tanh", "arcsinh", "arccosh", "arctanh",
    # activation
    "sigmoid", "relu", "softmax", "log_softmax",
    # broadcast
    "broadcast_add", "broadcast_plus", "broadcast_minus", "broadcast_sub",
    "broadcast_mul", "broadcast_div", "broadcast_mod", "broadcast_power",
    "broadcast_equal", "broadcast_not_equal",
    "broadcast_greater", "broadcast_greater_equal",
    "broadcast_lesser", "broadcast_lesser_equal",
    "broadcast_maximum", "broadcast_minimum",
    "broadcast_to", "broadcast_axis", "broadcast_axes", "broadcast_hypot",
    # reduction
    "argmax", "argmin",
]
# Generate the NDArray wrapper methods for every operator name returned by
# `_get_libmx_op_names()` whose lowercased form is not listed in `_op_import_bl`.
# For each remaining operator the emitted code consists of, in order: the expression
# from `_import_expr`, the typed method, and the convenience method carrying the
# operator description as its docstring. The whole result is escaped, so the
# definitions land in the scope where the macro is invoked.
macro _import_ndarray_functions()
    op_names = filter(_get_libmx_op_names()) do op
        lowercase(op) ∉ _op_import_bl
    end

    func_exprs = Expr[]
    for name in op_names
        handle = _get_libmx_op_handle(name)
        # Only the description is needed here; the second returned value is unused.
        desc, _ = _get_libmx_op_description(name, handle)
        typed_def, plain_def = _get_ndarray_function_def(name)
        import_stmt = _import_expr(Symbol(name))

        # `@doc $desc` must stay directly above the definition it documents.
        wrapper = quote
            $import_stmt
            $typed_def
            @doc $desc
            $plain_def
        end
        push!(func_exprs, wrapper)
    end

    return esc(quote
        $(func_exprs...)
    end)
end
# Expand the macro at top level so the generated NDArray wrapper methods are defined
# when this file is evaluated.
@_import_ndarray_functions