# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Autograd for NDArray
# this is a port of Python's autograd module
# https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py
using Base.Meta: isexpr
using Base.GC # FIXME
###############################################################################
# Private util functions
###############################################################################
"""
_set_recording(state::Bool)::Bool
Set status to recording/not recording. When recording, graph will be constructed
for gradient computation.
## Parameters
* `state::Bool`
## Returns
Previous state before this set
"""
function _set_recording(state::Bool)::Bool
prev = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradSetIsRecording, (Cint, Ref{Cint}), state, prev)
prev[]
end
_set_recording(::Cvoid) = nothing
"""
_set_training(train_mode::Bool)::Bool
Set status to training/predicting.
For example, Dropout will drop inputs randomly when
`train_mode = true` while simply passing through if `train_mode = false`.
## Parameters
* `train_mode::Bool`
## Returns
Previous state before this set.
"""
function _set_training(train_mode::Bool)::Bool
prev = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradSetIsTraining, (Cint, Ref{Cint}), train_mode, prev)
prev[]
end
_set_training(::Cvoid) = nothing
###############################################################################
# Public API
###############################################################################
"""
is_recording()::Bool
Get status on recording/not recording.
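A quick REPL sketch (assuming the default, non-recording state):
```julia
julia> mx.is_recording()
false

julia> mx.record(() -> mx.is_recording())
true
```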
"""
function is_recording()::Bool
state = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradIsRecording, (Ref{Cint},), state)
state[]
end
"""
is_training()::Bool
Get status on training/predicting.
"""
function is_training()::Bool
state = Ref{Cint}(C_NULL)
@mxcall(:MXAutogradIsTraining, (Ref{Cint},), state)
state[]
end
@inline function _record(f, is_record::Union{Cvoid,Bool}, train_mode::Union{Cvoid,Bool})
# Port from Python's `_RecordingStateScope` context manager
# __enter__
prev_is_record = _set_recording(is_record)
prev_train_mode = _set_training(train_mode)
try
f()
finally
# __exit__
if is_record != nothing && prev_is_record != is_record
_set_recording(prev_is_record)
end
if train_mode != nothing && prev_train_mode != train_mode
_set_training(prev_train_mode)
end
end
end
"""
record(f, train_mode = true)
record(train_mode = true) do
...
end
Returns an autograd recording scope context to be used in `do` block
and captures code that needs gradients to be calculated.
Parameter `train_mode::Bool` controls whether the forward pass is in training
or predicting mode.
This controls the behavior of some layers such as `Dropout`, `BatchNorm`.
!!! note
When forwarding with `train_mode = false`, the corresponding backward pass
should also use `train_mode = false`; otherwise the gradient is undefined.
```julia
x = mx.NDArray([1 2; 3 4])
∇ = mx.attach_grad!(x)
y = mx.record() do
2x
end
mx.backward!(y)

julia> ∇
2×2 mx.NDArray{Int64,2} @ CPU0:
 2  2
 2  2
```
"""
record(f, train_mode::Bool = true) = _record(f, true, train_mode)
"""
pause(f, train_mode = false)
pause(train_mode = false) do
...
end
Create a scope context for code that does not need gradients to be calculated.
```julia
record() do
...
pause() do
# testing, IO, gradient updates...
end
end
```
"""
pause(f, train_mode::Bool = false) = _record(f, false, train_mode)
"""
train_mode(f)
train_mode() do
...
end
Create a scope context in which forward pass behavior is set to training mode,
without changing the recording state.
```julia
y = model(x)
train_mode() do
z = mx.Dropout(y)
...
end
```
"""
train_mode(f) = _record(f, nothing, true)
"""
predict_mode(f)
predict_mode() do
...
end
Create a scope context in which forward pass behavior is set to inference mode,
without changing the recording state.
```julia
record() do
y = model(x)
predict_mode() do
y = sampling(y)
end
end
```
"""
predict_mode(f) = _record(f, nothing, false)
"""
backward!(head, head_grad; retain_graph = false, train_mode = true)
backward!(heads, head_grads; retain_graph = false, train_mode = true)
Compute the gradients of `heads` w.r.t. previously marked variables.
## Parameters
- `head::NDArray`: output NDArray
- `head_grad::NDArray` or `Nothing`: gradient coefficient with respect to head.
- `heads::Vector{NDArray}`: a list of output NDArray
- `head_grads::Vector`: a list of gradient coefficients with respect to `heads`;
each element should be an `NDArray` or `nothing`
- `retain_graph::Bool`: whether to keep the graph after backward. e.g:
If you want to differentiate the same graph twice,
you need to pass `retain_graph=true`.
- `train_mode::Bool`: whether to do backward for training or predicting.
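A small sketch passing an explicit `head_grad` (the values are illustrative):
```julia
x = mx.NDArray([1.0 2.0; 3.0 4.0])
∇ = mx.attach_grad!(x)
y = mx.record() do
  3x
end
# scale the gradient of each element by the given coefficients
mx.backward!(y, mx.NDArray([1.0 1.0; 2.0 2.0]))
# ∇ now holds 3 .* [1 1; 2 2]
```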
"""
backward!(head::NDArray, head_grad::NDArray; kws...) =
backward!([head], [head_grad]; kws...)
backward!(head::NDArray, head_grad::Nothing = nothing; kws...) =
backward!([head], head_grad; kws...)
function backward!(heads::VecOfNDArray, ::Nothing;
retain_graph::Bool = false, train_mode::Bool = true)
# TODO check MXAutogradBackwardEx usage
@mxcall(
:MXAutogradBackwardEx,
(MX_uint,
Ptr{MX_handle},
Ptr{MX_handle},
MX_uint,
Ptr{MX_handle},
Cint,
Cint,
Cint,
Ptr{Ptr{MX_handle}},
Ptr{Ptr{Cint}}),
length(heads),
map(x -> x.handle, heads),
C_NULL,
0,
C_NULL,
retain_graph,
false, # create_graph
train_mode,
C_NULL,
C_NULL)
end
function backward!(heads::VecOfNDArray, head_grads::Vector;
retain_graph::Bool = false, train_mode::Bool = true)
output_handles = map(x -> x.handle, heads)
ograd_handles = map(head_grads) do x
if x isa NDArray
x.handle
elseif x ≡ nothing # faster than `x isa Cvoid` in Julia 0.7
MX_handle(C_NULL)
else
throw(ArgumentError("element of head_grads should be NDArray or Cvoid"))
end
end
@assert length(output_handles) == length(ograd_handles)
@mxcall(
:MXAutogradBackwardEx,
(MX_uint,
Ptr{MX_handle},
Ptr{MX_handle},
MX_uint,
Ptr{MX_handle},
Cint,
Cint,
Cint,
Ptr{Ptr{MX_handle}},
Ptr{Ptr{Cint}}),
length(output_handles),
output_handles,
ograd_handles,
0,
C_NULL,
retain_graph,
false, # create_graph
train_mode,
C_NULL,
C_NULL)
end
"""
getgrad(arr::NDArray)
Returns the gradient buffer attached to this `NDArray`.
If the gradient buffer isn't attached yet, return `nothing`.
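A minimal sketch:
```julia
x = mx.NDArray([1 2; 3 4])
mx.getgrad(x)      # `nothing`: no gradient buffer attached yet
mx.attach_grad!(x)
∇ = mx.getgrad(x)  # the attached, zero-initialized gradient buffer
```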
"""
function getgrad(arr::NDArray)
out = Ref{MX_handle}(C_NULL)
@mxcall(:MXNDArrayGetGrad, (MX_handle, Ref{MX_handle}), arr.handle, out)
(out[] == C_NULL) ? nothing : NDArray(MX_NDArrayHandle(out[]))
end
"""
attach_grad!(x::NDArray, grad_req::Symbol = :write)
Attach a gradient buffer to this `NDArray`,
so that [`backward!`](@ref) can compute gradient with respect to it.
## Parameters
- `x::NDArray`
- `grad_req::Symbol` (default is `:write`)
## Return
The attached gradient buffer
## See also
- [`getgrad`](@ref)
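A minimal sketch of the intended workflow (mirroring the example in [`record`](@ref)):
```julia
x = mx.NDArray([1 2; 3 4])
∇ = mx.attach_grad!(x)   # zero-initialized buffer with the same shape as x
y = mx.record() do
  2x
end
mx.backward!(y)          # ∇ now holds the gradient, i.e. [2 2; 2 2]
```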
"""
function attach_grad!(x::NDArray, grad_req::Symbol = :write)
# TODO: support storage type (stype in Python)
# TODO: make sure it works with gpu array
grad = zeros_like(x)
_mark_variables!([x], [grad], grad_req)
grad
end
"""
mark_variables!(var, grad, grad_req)
mark_variables!(vars, grads, grad_reqs)
Mark `NDArrays` as variables to compute gradient for autograd.
## Parameters
- `var::NDArray`
- `grad::NDArray`
- `grad_req::Symbol`: `:nop`, `:write`, `:inplace` or `:add`
- `vars::Vector{NDArray}`
- `grads::Vector{NDArray}`
- `grad_reqs::Vector{Symbol}`
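A sketch of manual marking (essentially what [`attach_grad!`](@ref) does internally):
```julia
x = mx.NDArray([1 2; 3 4])
∇ = mx.zeros_like(x)
mx.mark_variables!(x, ∇)  # grad_req defaults to :write
y = mx.record() do
  2x
end
mx.backward!(y)           # the gradient is written into ∇
```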
"""
mark_variables!(var::NDArray, grad::NDArray, grad_reqs::Symbol = :write) =
_mark_variables!([var], [grad], grad_reqs)
mark_variables!(var::VecOfNDArray, grads::VecOfNDArray, grad_reqs = :write) =
_mark_variables!(var, grads, grad_reqs)
@inline function _getgrad_req(x::Symbol)::GRAD_REQ
val = get(grad_req_map, x, false)
if val == false
throw(ArgumentError("invalid grad_reqs $x"))
end
val
end
@inline _getgrad_reqs(x::Symbol, n::Int) =
map((_) -> MX_uint(_getgrad_req(x)), Base.OneTo(n))
@inline function _getgrad_reqs(xs::Vector{Symbol}, n::Int)
if length(xs) != n
throw(ArgumentError("number of variables and grad_reqs not matched"))
end
map(MX_uint ∘ _getgrad_req, xs)
end
@inline function _mark_variables!(vars::VecOfNDArray, grads::VecOfNDArray,
grad_reqs = :write)
n = length(vars)
if n != length(grads)
throw(ArgumentError("number of variables and gradients not matched"))
end
var_hdls = map(x -> x.handle, vars)
grad_hdls = map(x -> x.handle, grads)
grad_reqs = _getgrad_reqs(grad_reqs, n)
@mxcall(:MXAutogradMarkVariables,
(MX_uint, Ref{MX_handle}, Ptr{MX_uint}, Ref{MX_handle}),
length(vars), var_hdls, grad_reqs, grad_hdls)
end
"""
symbol(x::NDArray)
Retrieve recorded computation history as `SymbolicNode`,
where `x` is an `NDArray` representing the head of the computation graph.
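A minimal sketch:
```julia
x = mx.NDArray([1 2; 3 4])
mx.attach_grad!(x)
y = mx.record() do
  2x
end
s = mx.symbol(y)  # SymbolicNode describing the recorded computation of y
```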
"""
function symbol(x::NDArray)
ref = Ref{MX_handle}(C_NULL)
@mxcall(:MXAutogradGetSymbol, (MX_handle, Ref{MX_handle}), x, ref)
SymbolicNode(MX_SymbolHandle(ref[]))
end
###############################################################################
# User-defined differentiable function
###############################################################################
# gc-free holder
const _cbs_r = [Ref{Ptr{Cvoid}}(C_NULL), Ref{Ptr{Cvoid}}(C_NULL)]
const _cbs = [Ptr{Cvoid}(C_NULL), Ptr{Cvoid}(C_NULL)]
const _cbsref = Ref{Ptr{Ptr{Cvoid}}}(C_NULL)
const _frefs = Dict() # hold custom function instance and its args
const _conds = []
function _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, fptr::Ptr{Cvoid})
# @info "_back_wrapper"
# hdls = unsafe_wrap(Array, ptrs, num_ograds + num_igrads)
# @info "_back_wrapper" hdls
# ograds = map(x -> NDArray(MX_NDArrayHandle(x), false), hdls[1:num_ograds])
# @info "_back_wrapper" ograds
# igrads = map(NDArray ∘ MX_NDArrayHandle, hdls[num_ograds+1:num_ograds+num_igrads])
# @info "_back_wrapper" igrads
# reqs = unsafe_wrap(Array, reqs, num_igrads)
# @info "_back_wrapper" reqs
#
# # passing closure via raw pointer
# f = unsafe_pointer_to_objref(fptr)
#
# Δs = backward!(f, ograds...)
# Δs = Δs isa NDArray ? [Δs] : Δs
#
# # update gradient
# for (i, Δ, req) ∈ zip(igrads, Δs, reqs)
# req = GRAD_REQ(req)
# if req == GRAD_NOP
# continue
# elseif req ∈ (GRAD_WRITE, GRAD_INPLACE)
# i[:] = Δ
# elseif req == GRAD_ADD
# i[:] += Δ
# end
# end
#
# # release ref for gc
# delete!(_frefs, f)
Cint(true)
end
function _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, handle)
ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle)
end
function _del_wrapper(handle)
ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle)
end
function _wtf_wrapper(handle)
ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle)
end
function _init_customfunc() # will be invoked in __init__
global _cbs_r
global _cbs
global _cbsref
# the callback function prototype:
# https://github.com/apache/incubator-mxnet/blob/ca565a00285d4fb0ca77ba9dc651a07ce1f01b24/include/mxnet/c_api.h#L209-L212
_cbs_r[1][] = _cbs[1] = @cfunction(_back_wrapper, Cint,
(Cint, Cint, Ptr{Ptr{Cvoid}}, Ptr{Cint},
Cint, Ptr{Cvoid}))
# _cbs_r[1][] = _cbs[1] = @cfunction(_wtf_wrapper, Cvoid, (Ptr{Cvoid},))
_cbs_r[2][] = _cbs[2] = @cfunction(_del_wrapper, Cint, (Ptr{Cvoid},))
_cbsref[] = Base.unsafe_convert(Ptr{Ptr{Cvoid}}, _cbs)
@info "_init_customfunc" _cbs _cbsref[]
end
struct MXCallbackList
n::Cint # int num_callbacks;
cbs::Ptr{Ptr{Cvoid}} # int (**callbacks)(Cvoid);
ctxs::Ptr{Ptr{Cvoid}} # void **contexts;
# we must provide two callback functions
# the first is backward function `_back_wrapper`
# the second is delete callback `_del_wrapper`
# https://github.com/apache/incubator-mxnet/blob/2f8c1e83f94e84a25a48d2cd43136030fb3f2d1e/include/mxnet/c_api.h#L174-L182
# `ctxs` is a array which is same size as `cbs`
# its elements will be passed as `state` for callback functions,
# usually the last argument.
# In our case, we will push the pointer of custom func instance as
# first element of `ctxs`; the pointer of MXCallbackList instance as
# the second element.
# The purpose of first pointer is to pass closure into `cfunction`.
# The second pointer is to free the reference of MXCallbackList,
# and let the function instance be GC-ed properly.
function MXCallbackList(f) # where all args are Refs
fr = Ref(f)
push!(_fholder, fr)
@info "f ref" Base.unsafe_convert(Ptr{Cvoid}, fr)
cond = Base.AsyncCondition() do cond
@info "real back callback"
A = ones(10000000)
for i ∈ 1:10000
B = A .* A
end
@info "long run op end"
end
cond2 = Base.AsyncCondition() do cond
@info "real del callback"
end
push!(_conds, cond)
push!(_conds, cond2)
@info "conds" cond.handle cond2.handle
ctxs = [
cond.handle,
cond2.handle,
]
ctxsptr = Base.unsafe_convert(Ptr{Ptr{Cvoid}}, ctxs)
cblist = new(length(ctxs), _cbsref[], ctxsptr)
# get the reference, and make a self-reference in ctxs[2]
cblist_ref = Ref{MXCallbackList}(cblist)
ctxs[2] = Base.unsafe_convert(Ptr{Cvoid}, cblist_ref)
# insert ref into a holder to prevent from being GC-ed.
# hold `xs` and `ys` which is passed into `MXCustomFunctionRecord`.
_cblists[cblist_ref] = Ref(ctxs)
cblist_ref
end
end
# hold MXCallbackList to prevent from gc
const _cblists = Dict{Ref{MXCallbackList},Ref}()
const _fholder = []
"""
@custom
Create a callable custom function.
All the positional arguments should be `NDArray`s.
The return value should be an instance of your custom type.
Please check out `examples/autograd/customfunc.jl` for an example.
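A purely illustrative sketch of the intended shape of the API (the type and
function names below are made up; see the example file above for the real one):
```julia
# custom type carrying whatever the backward pass needs
struct MyMul
  x::mx.NDArray
  y::mx.NDArray
end

# forward: called with the instance and the positional NDArray arguments
mx.forward(::MyMul, x, y) = x .* y

# backward!: called with the instance and the output gradient;
# returns the input gradients in the same order as the inputs
mx.backward!(f::MyMul, dz) = (dz .* f.y, dz .* f.x)

# the wrapped function body constructs and returns the instance
mx.@custom mymul(x::mx.NDArray, y::mx.NDArray) = MyMul(x, y)
```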
"""
macro custom(ex::Expr)
fdef = splitdef(ex) # by MacroTools
sig = ex.args[1]
body = esc(Expr(:let, Expr(:block), ex.args[2])) # create a new scope via `let`
# only extract symbols, get rid of all annotations and default values
args = map(x -> esc(splitarg(x)[1]), fdef[:args])
# forward(f, xs...)
forward_expr = Expr(:call, :forward, :f, args...)
# insert keyword args
if !isempty(fdef[:kwargs])
# only extract symbols, get rid of all annotations and default values
kwargs = map(fdef[:kwargs]) do x
sym = splitarg(x)[1]
Expr(:kw, sym, esc(sym))
end
append!(forward_expr.args, kwargs)
end
# xs, FIXME: a list of NDArray from positional argument
xs_len = length(args)
xs_expr = Expr(:vect, args...)
body′ = quote
f, ys = _record(false, nothing) do
f = $body # f is the object instance
ys = $forward_expr
f, ys
end
!is_recording() && return ys
xs = $xs_expr
ys′ = ys isa NDArray ? [ys] : ys
# struct MXCallbackList
cblist_ref = MXCallbackList(f)
# gc-free
xsr, ysr = Ref(xs), Ref(ys′)
_frefs[f] = (xsr, ysr)
# @info _frefs
@mxcall(
:MXCustomFunctionRecord,
(Cint, # num_inputs
Ref{MX_handle}, # inputs
Cint, # num_outputs
Ref{MX_handle}, # outputs
Ref{MXCallbackList}), # callbacks
$xs_len,
xs,
length(ys′),
ys′,
cblist_ref)
@info "inputs xs" Base.unsafe_convert(Ref{MX_handle}, xs)
@info "outputs ys" Base.unsafe_convert(Ref{MX_handle}, ys′)
ys
end
GC.enable(false) # FIXME
Expr(:function, esc(sig), body′)
end
# A custom function should overload these functions.
# The number of values returned by `forward` determines how many output
# gradients are passed to `backward!`.
function forward end
function backward! end