blob: 91c2f5035ffa208c8b47dc893feb2c6e10d503ec [file] [log] [blame]
# coding: utf-8
# pylint: disable=no-member, protected-access, unused-import, no-name-in-module
"""Random number interface of MXNet."""
from __future__ import absolute_import
import ctypes
from .base import _LIB, check_call
from ._ndarray_internal import _sample_uniform as uniform
from ._ndarray_internal import _sample_normal as normal
from ._ndarray_internal import _sample_gamma as gamma
from ._ndarray_internal import _sample_exponential as exponential
from ._ndarray_internal import _sample_poisson as poisson
from ._ndarray_internal import _sample_negbinomial as negative_binomial
from ._ndarray_internal import _sample_gennegbinomial as generalized_negative_binomial
def seed(seed_state):
"""Seeds the random number generators in MXNet.
This affects the behavior of modules in MXNet that uses random number generators,
like the dropout operator and `NDArray`'s random sampling operators.
Parameters
----------
seed_state : int
The random number seed to set to all devices.
Notes
-----
Random number generators in MXNet are device specific. Therefore, random numbers
generated from two devices can be different even if they are seeded using the same seed.
Example
-------
>>> print(mx.nd.random_normal(shape=(2,2)).asnumpy())
[[ 1.36481571 -0.62203991]
[-1.4962182 -0.08511394]]
>>> print(mx.nd.random_normal(shape=(2,2)).asnumpy())
[[ 1.09544981 -0.20014545]
[-0.20808885 0.2527658 ]]
>>>
>>> mx.random.seed(128)
>>> print(mx.nd.random_normal(shape=(2,2)).asnumpy())
[[ 0.47400656 -0.75213492]
[ 0.20251541 0.95352972]]
>>> mx.random.seed(128)
>>> print(mx.nd.random_normal(shape=(2,2)).asnumpy())
[[ 0.47400656 -0.75213492]
[ 0.20251541 0.95352972]]
"""
if not isinstance(seed_state, int):
raise ValueError('sd must be int')
seed_state = ctypes.c_int(int(seed_state))
check_call(_LIB.MXRandomSeed(seed_state))