Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

support long for mx.random.seed #14314

Merged
merged 7 commits into from
Mar 5, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/mxnet/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import absolute_import

import ctypes
from .base import _LIB, check_call
from .base import _LIB, check_call, integer_types
from .ndarray.random import *
from .context import Context

Expand Down Expand Up @@ -90,9 +90,9 @@ def seed(seed_state, ctx="all"):
[[ 2.5020072 -1.6884501]
[-0.7931333 -1.4218881]]
"""
if not isinstance(seed_state, int):
if not isinstance(seed_state, integer_types):
raise ValueError('seed_state must be int')
seed_state = ctypes.c_int(seed_state)
seed_state = ctypes.c_int(int(seed_state))
if ctx == "all":
check_call(_LIB.MXRandomSeed(seed_state))
else:
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,32 @@ def check_data(a, b):
for j in range(i+1, num_seeds):
check_data(data[i],data[j])

@with_seed()
def test_random_seed():
shape = (5, 5)
seed = rnd.randint(-(1 << 31), (1 << 31))
mx.random.seed(seed)

def _assert_same_mx_arrays(a, b):
assert len(a) == len(b)
for a_i, b_i in zip(a, b):
assert (a_i.asnumpy() == b_i.asnumpy()).all()

N = 100
v1 = [mx.nd.random_uniform(shape=shape) for _ in range(N)]
wkcn marked this conversation as resolved.
Show resolved Hide resolved

mx.random.seed(seed)
v2 = [mx.nd.random_uniform(shape=shape) for _ in range(N)]
_assert_same_mx_arrays(v1, v2)

try:
long
mx.random.seed(long(seed))
v3 = [mx.nd.random_uniform(shape=shape) for _ in range(N)]
_assert_same_mx_arrays(v1, v3)
except NameError:
pass

@with_seed()
def test_unique_zipfian_generator():
ctx = mx.context.current_context()
Expand Down