wip
diff --git a/julia/src/MXNet.jl b/julia/src/MXNet.jl
index 89ec88b..6a01fb0 100644
--- a/julia/src/MXNet.jl
+++ b/julia/src/MXNet.jl
@@ -64,6 +64,20 @@
broadcast_axis,
broadcast_axes
+# autograd.jl
+export attach_grad!,
+ backward!,
+ getgrad,
+ is_recording,
+ is_training,
+ mark_variables,
+ pause,
+ predict_mode,
+ record,
+ symbol,
+ train_mode,
+ @custom
+
# executor.jl
export Executor,
bind,
diff --git a/julia/src/autograd.jl b/julia/src/autograd.jl
index 8b5edae..3a32c08 100644
--- a/julia/src/autograd.jl
+++ b/julia/src/autograd.jl
@@ -19,6 +19,9 @@
# this is a port of Python's autograd module
# https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py
+using Base.Meta: isexpr
+using Base.GC # FIXME
+
###############################################################################
# Private util functions
###############################################################################
@@ -211,7 +214,7 @@
- `head::NDArray`: output NDArray
-- `head_grad::NDArray` or `Cvoid`: gradient coefficient with respect to head.
+- `head_grad::NDArray` or `Nothing`: gradient coefficient with respect to head.
- `heads::Vector{NDArray}`: a list of output NDArray
@@ -227,11 +230,12 @@
backward!(head::NDArray, head_grad::NDArray; kws...) =
backward!([head], [head_grad]; kws...)
-backward!(head::NDArray, head_grad::Cvoid = nothing; kws...) =
+backward!(head::NDArray, head_grad::Nothing = nothing; kws...) =
backward!([head], head_grad; kws...)
-function backward!(heads::VecOfNDArray, head_grad::Cvoid;
+function backward!(heads::VecOfNDArray, ::Nothing;
retain_graph::Bool = false, train_mode::Bool = true)
+ # TODO check MXAutogradBackwardEx usage
@mxcall(
:MXAutogradBackwardEx,
(MX_uint,
@@ -242,8 +248,8 @@
Cint,
Cint,
Cint,
- Ptr{MX_handle},
- Ptr{MX_handle}),
+ Ptr{Ptr{MX_handle}},
+ Ptr{Ptr{Cint}}),
length(heads),
map(x -> x.handle, heads),
C_NULL,
@@ -279,8 +285,8 @@
Cint,
Cint,
Cint,
- Ptr{MX_handle},
- Ptr{MX_handle}),
+ Ptr{Ptr{MX_handle}},
+ Ptr{Ptr{Cint}}),
length(output_handles),
output_handles,
ograd_handles,
@@ -400,5 +406,219 @@
end
###############################################################################
-# TODO: User-defined differentiable function
+# User-defined differentiable function
###############################################################################
+
+
+# Module-level holders: everything handed to libmxnet by pointer is kept reachable
+# here so the Julia GC cannot collect it while C code may still use it.
+#
+# Two callback slots each: [1] backward callback, [2] delete callback
+# (both filled in by `_init_customfunc`).
+const _cbs_r = [Ref{Ptr{Cvoid}}(C_NULL) for _ in 1:2]
+const _cbs = fill(C_NULL, 2)
+# Pointer to the first element of `_cbs`; used as the `cbs` field of `MXCallbackList`.
+const _cbsref = Ref{Ptr{Ptr{Cvoid}}}(C_NULL)
+# custom-function instance => (Ref(inputs), Ref(outputs)) of its recorded call
+const _frefs = Dict{Any,Any}()
+# AsyncConditions created per recorded call, kept alive for their libuv handles
+const _conds = Any[]
+
+# Planned synchronous backward (disabled). It treats the state pointer as a pointer
+# to the custom-function object, whereas `MXCallbackList` currently stores a libuv
+# async handle there, so it cannot be enabled as-is:
+#
+#     hdls = unsafe_wrap(Array, ptrs, num_ograds + num_igrads)
+#     ograds = map(x -> NDArray(MX_NDArrayHandle(x), false), hdls[1:num_ograds])
+#     igrads = map(NDArray ∘ MX_NDArrayHandle, hdls[num_ograds+1:num_ograds+num_igrads])
+#     reqs = unsafe_wrap(Array, reqs, num_igrads)
+#     f = unsafe_pointer_to_objref(fptr)  # passing closure via raw pointer
+#     Δs = backward!(f, ograds...)
+#     Δs = Δs isa NDArray ? [Δs] : Δs
+#     for (i, Δ, req) ∈ zip(igrads, Δs, reqs)  # update gradient
+#         req = GRAD_REQ(req)
+#         if req == GRAD_NOP
+#             continue
+#         elseif req ∈ (GRAD_WRITE, GRAD_INPLACE)
+#             i[:] = Δ
+#         elseif req == GRAD_ADD
+#             i[:] += Δ
+#         end
+#     end
+#     delete!(_frefs, f)  # release ref for gc
+
+"""
+    _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, handle::Ptr{Cvoid})
+
+Backward callback handed to libmxnet (wrapped by `@cfunction` in `_init_customfunc`).
+
+Because its last argument is annotated `::Ptr{Cvoid}`, this method is more specific
+than the untyped six-argument `_back_wrapper` method, so it is the one `@cfunction`
+binds to. `handle` is the callback's state pointer; it is treated as a libuv async
+handle and only signalled, so any Julia-side work happens later on the event loop.
+
+Returns `Cint(1)` if `uv_async_send` succeeded (returned 0) and `Cint(0)` otherwise.
+NOTE(review): this may be invoked from a libmxnet worker thread; presumably that is
+why the work is deferred via `uv_async_send` — confirm.
+NOTE(review): assumes libmxnet treats a nonzero return as success — confirm against
+`src/c_api/c_api_function.cc`.
+"""
+function _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, handle::Ptr{Cvoid})
+    # libuv reports success as 0; translate it into a truthy status for libmxnet.
+    status = ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle)
+    return Cint(status == 0)
+end
+
+# Untyped six-argument fallback: signal the libuv async handle `handle`.
+# Note that `@cfunction` in `_init_customfunc` resolves to the more specific
+# `_back_wrapper` method whose last argument is annotated `::Ptr{Cvoid}`, so this one
+# is only reached for calls whose last argument is not a `Ptr{Cvoid}`.
+function _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, handle)
+    # uv_async_send returns 0 on success; report success as a nonzero Cint.
+    return Cint(ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle) == 0)
+end
+
+"""
+    _del_wrapper(handle)
+
+Delete callback handed to libmxnet (wrapped by `@cfunction` in `_init_customfunc`).
+`handle` is expected to be a libuv async handle (e.g. an `AsyncCondition`'s
+`.handle`); it is only signalled, so any Julia-side cleanup runs later on the event loop.
+
+Returns `Cint(1)` if `uv_async_send` succeeded (returned 0) and `Cint(0)` otherwise.
+NOTE(review): assumes libmxnet treats a nonzero return as success — confirm.
+"""
+function _del_wrapper(handle)
+    # libuv reports success as 0; translate it into a truthy status for libmxnet.
+    status = ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle)
+    return Cint(status == 0)
+end
+
+# Debug probe: signal the libuv async handle `handle` and return libuv's raw status
+# (0 on success). Not registered as a callback by `_init_customfunc`.
+_wtf_wrapper(handle) = ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle)
+
+"""
+    _init_customfunc()
+
+Create the C-callable trampolines used by custom autograd functions and store them
+in the GC-rooted holders `_cbs` / `_cbs_r`: slot 1 is the backward callback
+(`_back_wrapper`), slot 2 the delete callback (`_del_wrapper`). `_cbsref[]` is set to
+a pointer to the first element of `_cbs`, which `MXCallbackList` uses as its `cbs`.
+
+Called from `__init__`, because raw function pointers created by `@cfunction` must
+be produced at load time rather than baked in during precompilation.
+"""
+function _init_customfunc()
+    # Callback prototypes:
+    # https://github.com/apache/incubator-mxnet/blob/ca565a00285d4fb0ca77ba9dc651a07ce1f01b24/include/mxnet/c_api.h#L209-L212
+    # [1] backward: int (int num_ograds, int num_igrads, void** ptrs,
+    #                    const int* reqs, int is_train, void* state)
+    _cbs_r[1][] = _cbs[1] = @cfunction(_back_wrapper, Cint,
+                                       (Cint, Cint, Ptr{Ptr{Cvoid}}, Ptr{Cint},
+                                        Cint, Ptr{Cvoid}))
+    # [2] delete: int (void* state)
+    _cbs_r[2][] = _cbs[2] = @cfunction(_del_wrapper, Cint, (Ptr{Cvoid},))
+
+    # `_cbs` is a const global that is never resized, so this pointer stays valid.
+    _cbsref[] = Base.unsafe_convert(Ptr{Ptr{Cvoid}}, _cbs)
+    # runs on every package load; keep it out of the default log output
+    @debug "_init_customfunc" _cbs _cbsref[]
+    return nothing
+end
+
+"""
+    MXCallbackList(f) -> Ref{MXCallbackList}
+
+Julia mirror of libmxnet's `MXCallbackList` C struct
+(https://github.com/apache/incubator-mxnet/blob/2f8c1e83f94e84a25a48d2cd43136030fb3f2d1e/include/mxnet/c_api.h#L174-L182):
+
+- `n`: number of callbacks
+- `cbs`: array of callback function pointers (`_cbsref[]`, filled by
+  `_init_customfunc`): `[1]` backward (`_back_wrapper`), `[2]` delete (`_del_wrapper`)
+- `ctxs`: array with the same length as `cbs`; `ctxs[i]` is passed as the `state`
+  (last) argument of callback `i`
+
+The constructor builds the list for one recorded call of the custom-function
+instance `f` and returns a `Ref` to it (not the struct itself), suitable for
+passing as `Ref{MXCallbackList}` to `MXCustomFunctionRecord`.
+
+Each `ctxs[i]` is the libuv handle of a `Base.AsyncCondition`, so the C callbacks
+only signal a condition and the Julia-side handler runs later on the event loop.
+`f`, both conditions, the `ctxs` array and the returned `Ref` are stored in
+module-level holders so they are not garbage-collected while libmxnet may use them.
+
+TODO: the backward handler does not compute gradients yet, and nothing removes the
+holder entries, so they live for the whole session.
+"""
+struct MXCallbackList
+    n::Cint                # int num_callbacks;
+    cbs::Ptr{Ptr{Cvoid}}   # int (**callbacks)(Cvoid);
+    ctxs::Ptr{Ptr{Cvoid}}  # void **contexts;
+
+    function MXCallbackList(f)
+        # keep the custom-function instance alive
+        push!(_fholder, Ref(f))
+
+        # Julia-side handlers, triggered asynchronously by the C callbacks.
+        back_cond = Base.AsyncCondition() do _
+            @debug "real back callback"
+        end
+        del_cond = Base.AsyncCondition() do _
+            @debug "real del callback"
+        end
+        push!(_conds, back_cond, del_cond)
+
+        # `ctxs[2]` must stay the delete condition's handle: `_del_wrapper` passes
+        # its state argument straight to `uv_async_send`.
+        ctxs = [back_cond.handle, del_cond.handle]
+        ctxsptr = Base.unsafe_convert(Ptr{Ptr{Cvoid}}, ctxs)
+        cblist_ref = Ref{MXCallbackList}(new(length(ctxs), _cbsref[], ctxsptr))
+
+        # root the list and its `ctxs` array (which `ctxsptr` points into)
+        _cblists[cblist_ref] = Ref(ctxs)
+        return cblist_ref
+    end
+end
+
+# GC roots for `MXCallbackList`: each key is a `Ref{MXCallbackList}` returned by
+# `MXCallbackList(f)`, each value a `Ref` to that list's `ctxs` array.
+const _cblists = Dict{Ref{MXCallbackList},Ref}()
+# GC roots for the custom-function instances: one `Ref(f)` per recorded call.
+const _fholder = Any[]
+
+"""
+    @custom function name(args...; kwargs...)
+        ...
+    end
+
+Create a callable custom differentiable function.
+
+All positional arguments should be `NDArray`s, and the function body should return
+an instance of your custom type, which must overload `forward` (and `backward!`).
+
+Calling the generated function:
+1. evaluates the body to obtain the custom-function instance `f`,
+2. computes the outputs with `forward(f, args...; kwargs...)`,
+3. only if `is_recording()`, records the call with libmxnet through
+   `MXCustomFunctionRecord` (inputs, outputs and an `MXCallbackList`).
+
+It returns the outputs of `forward`.
+Please check out `examples/autograd/customfunc.jl` for an example.
+"""
+macro custom(ex::Expr)
+    fdef = splitdef(ex)  # by MacroTools
+    sig = ex.args[1]
+    body = esc(Expr(:let, Expr(:block), ex.args[2]))  # create a new scope via `let`
+
+    # only extract symbols, get rid of all annotations and default values
+    args = map(x -> esc(splitarg(x)[1]), fdef[:args])
+    # forward(f, xs...)
+    forward_expr = Expr(:call, :forward, :f, args...)
+    # insert keyword args
+    if !isempty(fdef[:kwargs])
+        # only extract symbols, get rid of all annotations and default values
+        kwargs = map(fdef[:kwargs]) do x
+            sym = splitarg(x)[1]
+            Expr(:kw, sym, esc(sym))
+        end
+        append!(forward_expr.args, kwargs)
+    end
+
+    # FIXME: assumes each positional argument is exactly one NDArray
+    xs_len = length(args)
+    xs_expr = Expr(:vect, args...)
+
+    body′ = quote
+        # NOTE(review): presumably `_record(false, nothing)` runs the block with
+        # recording disabled and train mode unchanged — confirm
+        f, ys = _record(false, nothing) do
+            f = $body  # f is the custom-function instance
+            ys = $forward_expr
+            f, ys
+        end
+
+        !is_recording() && return ys
+
+        xs = $xs_expr
+        ys′ = ys isa NDArray ? [ys] : ys
+
+        cblist_ref = MXCallbackList(f)
+
+        # root the inputs/outputs handed to libmxnet
+        xsr, ysr = Ref(xs), Ref(ys′)
+        _frefs[f] = (xsr, ysr)
+
+        @mxcall(
+            :MXCustomFunctionRecord,
+            (Cint,                  # num_inputs
+             Ref{MX_handle},        # inputs
+             Cint,                  # num_outputs
+             Ref{MX_handle},        # outputs
+             Ref{MXCallbackList}),  # callbacks
+            $xs_len,
+            xs,
+            length(ys′),
+            ys′,
+            cblist_ref)
+
+        ys
+    end
+
+    # No `GC.enable(false)` here: it would run at macro-expansion time, disabling the
+    # GC process-wide as a side effect of expanding (not calling) the function, and
+    # would not run at all at load time for precompiled user code. The objects handed
+    # to libmxnet are rooted via `_frefs` and `MXCallbackList`'s holders instead.
+    Expr(:function, esc(sig), body′)
+end
+
+# Generic functions that custom-function types (see `@custom`) should overload:
+# `forward(f, xs...; kwargs...)` computes the outputs of a call, and `backward!`
+# receives one gradient argument per value returned by `forward`.
+function forward end
+function backward! end
diff --git a/julia/src/base.jl b/julia/src/base.jl
index 6831464..d10be39 100644
--- a/julia/src/base.jl
+++ b/julia/src/base.jl
@@ -69,6 +69,7 @@
_get_libmx_op_names()
_populate_iter_creator_cache!()
_get_lib_version!()
+ _init_customfunc()
atexit() do
# notify libmxnet we are shutting down