# blob: 91c2f5035ffa208c8b47dc893feb2c6e10d503ec
# coding: utf-8
# pylint: disable=no-member, protected-access, unused-import, no-name-in-module
"""Random number interface of MXNet."""
from __future__ import absolute_import
import ctypes
from .base import _LIB, check_call
from ._ndarray_internal import _sample_uniform as uniform
from ._ndarray_internal import _sample_normal as normal
from ._ndarray_internal import _sample_gamma as gamma
from ._ndarray_internal import _sample_exponential as exponential
from ._ndarray_internal import _sample_poisson as poisson
from ._ndarray_internal import _sample_negbinomial as negative_binomial
from ._ndarray_internal import _sample_gennegbinomial as generalized_negative_binomial
def seed(seed_state):
"""Seeds the random number generators in MXNet.
This affects the behavior of modules in MXNet that use random number generators,
like the dropout operator and `NDArray`'s random sampling operators.
seed_state : int
The random number seed to set to all devices.
Random number generators in MXNet are device specific. Therefore, random numbers
generated from two devices can be different even if they are seeded using the same seed.
>>> print(mx.nd.random_normal(shape=(2,2)).asnumpy())
[[ 1.36481571 -0.62203991]
[-1.4962182 -0.08511394]]
>>> print(mx.nd.random_normal(shape=(2,2)).asnumpy())
[[ 1.09544981 -0.20014545]
[-0.20808885 0.2527658 ]]
>>> mx.random.seed(128)
>>> print(mx.nd.random_normal(shape=(2,2)).asnumpy())
[[ 0.47400656 -0.75213492]
[ 0.20251541 0.95352972]]
>>> mx.random.seed(128)
>>> print(mx.nd.random_normal(shape=(2,2)).asnumpy())
[[ 0.47400656 -0.75213492]
[ 0.20251541 0.95352972]]
if not isinstance(seed_state, int):
raise ValueError('sd must be int')
seed_state = ctypes.c_int(int(seed_state))