Skip to content

Commit

Permalink
support long for mx.random.seed (apache#14314)
Browse files Browse the repository at this point in the history
* support long for mx.random.seed

* update test_random

* reorder

* use mx.random.uniform

* trigger CI

* retrigger CI
  • Loading branch information
wkcn authored and haohuw committed Jun 23, 2019
1 parent 131ff05 commit 4215ef8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
6 changes: 3 additions & 3 deletions python/mxnet/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import absolute_import

import ctypes
from .base import _LIB, check_call
from .base import _LIB, check_call, integer_types
from .ndarray.random import *
from .context import Context

Expand Down Expand Up @@ -90,9 +90,9 @@ def seed(seed_state, ctx="all"):
[[ 2.5020072 -1.6884501]
[-0.7931333 -1.4218881]]
"""
if not isinstance(seed_state, int):
if not isinstance(seed_state, integer_types):
raise ValueError('seed_state must be int')
seed_state = ctypes.c_int(seed_state)
seed_state = ctypes.c_int(int(seed_state))
if ctx == "all":
check_call(_LIB.MXRandomSeed(seed_state))
else:
Expand Down
29 changes: 28 additions & 1 deletion tests/python/unittest/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def test_parallel_random_seed_setting():
# Avoid excessive test cpu runtimes
num_temp_seeds = 25 if ctx.device_type == 'gpu' else 1
# To flush out a possible race condition, run multiple times

for _ in range(20):
# Create enough samples such that we get a meaningful distribution.
shape = (200, 200)
Expand Down Expand Up @@ -670,7 +671,7 @@ def gen_data(seed=None):
with random_seed(seed):
python_data = [rnd.random() for _ in range(size)]
np_data = np.random.rand(size)
mx_data = mx.nd.random_uniform(shape=shape, ctx=ctx).asnumpy()
mx_data = mx.random.uniform(shape=shape, ctx=ctx).asnumpy()
return (seed, python_data, np_data, mx_data)

# check data, expecting them to be the same or different based on the seeds
Expand Down Expand Up @@ -712,6 +713,32 @@ def check_data(a, b):
for j in range(i+1, num_seeds):
check_data(data[i],data[j])

@with_seed()
def test_random_seed():
shape = (5, 5)
seed = rnd.randint(-(1 << 31), (1 << 31))

def _assert_same_mx_arrays(a, b):
assert len(a) == len(b)
for a_i, b_i in zip(a, b):
assert (a_i.asnumpy() == b_i.asnumpy()).all()

N = 100
mx.random.seed(seed)
v1 = [mx.random.uniform(shape=shape) for _ in range(N)]

mx.random.seed(seed)
v2 = [mx.random.uniform(shape=shape) for _ in range(N)]
_assert_same_mx_arrays(v1, v2)

try:
long
mx.random.seed(long(seed))
v3 = [mx.random.uniform(shape=shape) for _ in range(N)]
_assert_same_mx_arrays(v1, v3)
except NameError:
pass

@with_seed()
def test_unique_zipfian_generator():
ctx = mx.context.current_context()
Expand Down

0 comments on commit 4215ef8

Please sign in to comment.