blob: d1a2c30ea2aac9c1e22cee3aa62abc03a9594cc6 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Deferred Compute for NDArray."""
import ctypes
import contextlib
from .base import _LIB, check_call, SymbolHandle, _as_list
from .symbol import Symbol
# An empty ``__all__`` means ``from <this module> import *`` exports no names;
# the functions below remain reachable as attributes of the module.
__all__ = []
def is_deferred_compute():
    """Get status of deferred compute mode.

    Returns
    -------
    bool
        True if deferred compute mode is currently enabled.
    """
    # Use the same zero-initialised c_int out-parameter that
    # set_deferred_compute uses for ``prev`` and normalise it with bool().
    # A c_bool only exposes the first byte of its buffer, which reads the
    # wrong byte on big-endian hosts if the backend writes a full int. A
    # zero-initialised c_int passed through bool() is correct whether the
    # backend writes an int or a single bool byte.
    # NOTE(review): confirm the out-parameter type of
    # MXNDArrayIsDeferredCompute in c_api.h.
    curr = ctypes.c_int()
    check_call(_LIB.MXNDArrayIsDeferredCompute(ctypes.byref(curr)))
    return bool(curr.value)
def set_deferred_compute(state):
    """Turn deferred compute mode on or off.

    Parameters
    ----------
    state: bool
        Desired deferred compute state.

    Returns
    -------
    bool
        The deferred compute state that was active before this call, so
        callers can restore it later.
    """
    previous = ctypes.c_int()
    requested = ctypes.c_int(state)
    status = _LIB.MXNDArraySetIsDeferredCompute(requested, ctypes.byref(previous))
    check_call(status)
    return bool(previous.value)
@contextlib.contextmanager
def context(state=True):
    """Temporarily set the deferred compute state.

    The state that was active on entry is restored on exit, including when
    the body of the ``with`` block raises.

    Parameters
    ----------
    state: bool
        Deferred compute state to use inside the ``with`` block (default True).
    """
    # Like other MXNet context managers, this does not use Context Variables,
    # so its state can bleed across concurrent code: "Context managers that
    # have state should use Context Variables instead of threading.local() to
    # prevent their state from bleeding to other code unexpectedly, when used
    # in concurrent code."
    # https://github.com/apache/incubator-mxnet/issues/17495#issuecomment-585461965
    with contextlib.ExitStack() as restore:
        # set_deferred_compute returns the previous state; register it to be
        # written back when the stack unwinds.
        restore.callback(set_deferred_compute, set_deferred_compute(state))
        yield
def get_symbol(output_arrays, *, sym_cls=Symbol):
    """Build a symbol from the computation recorded in deferred compute mode.

    Parameters
    ----------
    output_arrays: NDArray or List[NDArray]
        Output array(s) whose recorded computation is exported.
    sym_cls: class used to construct Symbol
        Called with the resulting SymbolHandle.

    Returns
    -------
    Symbol of sym_cls
    """
    outputs = _as_list(output_arrays)
    num_outputs = len(outputs)
    # Pack the raw NDArray handles into a C array of void pointers.
    handles = (ctypes.c_void_p * num_outputs)(*(out.handle for out in outputs))
    sym_handle = SymbolHandle()
    status = _LIB.MXNDArrayGetDeferredComputeSymbol(handles, num_outputs,
                                                    ctypes.byref(sym_handle))
    check_call(status)
    return sym_cls(sym_handle)
def set_variable(arrays, variables):
    """Associate variables with arrays.

    Each array in ``arrays`` is paired with the symbol at the same position
    in ``variables``.

    Parameters
    ----------
    arrays: NDArray or List[NDArray]
    variables: Symbol or List[Symbol] of variables

    Raises
    ------
    ValueError
        If ``arrays`` and ``variables`` do not contain the same number of
        elements.
    """
    arrays = _as_list(arrays)
    variables = _as_list(variables)
    # Both ctypes buffers share one length. Without this check, a shorter
    # ``variables`` would be silently zero-padded (NULL handles passed to the
    # backend), and a longer one would fail inside the ctypes constructor with
    # an unrelated-looking error.
    if len(arrays) != len(variables):
        raise ValueError(
            'set_variable expects the same number of arrays and variables, '
            'got {} arrays and {} variables'.format(len(arrays), len(variables)))
    num = len(arrays)
    handles_type = ctypes.c_void_p * num
    # Convert Python objects to raw handles.
    array_handles = handles_type(*[array.handle for array in arrays])
    variable_handles = handles_type(*[symbol.handle for symbol in variables])
    check_call(_LIB.MXNDArraySetDeferredComputeVariable(array_handles, variable_handles, num))
def clear(arrays):
    """Clear the deferred compute (dc) info node associated with the given arrays.

    Parameters
    ----------
    arrays: NDArray or List[NDArray]
    """
    targets = _as_list(arrays)
    count = len(targets)
    # C array of the raw NDArray handles to clear.
    handle_buf = (ctypes.c_void_p * count)(*(arr.handle for arr in targets))
    status = _LIB.MXNDArrayClearDeferredCompute(handle_buf, count)
    check_call(status)