Skip to content

Commit

Permalink
Revert accidental changes to test file. (#6681)
Browse files Browse the repository at this point in the history
  • Loading branch information
robertnishihara authored and pcmoritz committed Jan 3, 2020
1 parent b8669bc commit 80e77f7
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions python/ray/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ def test_simple_serialization(ray_start_regular):
np.float64(1.9),
]

if sys.version_info < (3, 0):
primitive_objects.append(long(0)) # noqa: E501,F821

composite_objects = (
[[obj]
for obj in primitive_objects] + [(obj, )
Expand All @@ -91,6 +88,25 @@ def f(x):
assert type(obj) == type(new_obj_2)


def test_background_tasks_with_max_calls(shutdown_only):
ray.init(num_cpus=2)

@ray.remote
def g():
time.sleep(.1)
return 0

@ray.remote(max_calls=1, max_retries=0)
def f():
return [g.remote()]

nested = ray.get([f.remote() for _ in range(10)])

# Should still be able to retrieve these objects, since f's workers will
# wait for g to finish before exiting.
ray.get([x[0] for x in nested])


def test_fair_queueing(shutdown_only):
ray.init(
num_cpus=1, _internal_config=json.dumps({
Expand Down Expand Up @@ -166,17 +182,7 @@ def assert_equal(obj1, obj2):
assert obj1 == obj2, "Objects {} and {} are different.".format(
obj1, obj2)

if sys.version_info >= (3, 0):
long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]
else:

long_extras = [
long(0), # noqa: E501,F821
np.array([
["hi", u"hi"],
[1.3, long(1)] # noqa: E501,F821
])
]
long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]

PRIMITIVE_OBJECTS = [
0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999, [1 << 100, [1 << 100]], "a",
Expand Down Expand Up @@ -791,8 +797,6 @@ def f3(x):
assert ray.get(f3.remote(4)) == 4


@pytest.mark.skipif(
sys.version_info < (3, 0), reason="This test requires Python 3.")
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
Expand Down Expand Up @@ -828,8 +832,6 @@ def test_function(fn, remote_fn):
ray.get(remote_test_function.remote(local_method, actor_method))


@pytest.mark.skipif(
sys.version_info < (3, 0), reason="This test requires Python 3.")
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
Expand Down Expand Up @@ -871,8 +873,6 @@ def test_function(fn, remote_fn):
ray.get(remote_test_function.remote(local_method, actor_method))


@pytest.mark.skipif(
sys.version_info < (3, 0), reason="This test requires Python 3.")
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
Expand Down Expand Up @@ -1636,5 +1636,4 @@ def f(delay):

if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))

0 comments on commit 80e77f7

Please sign in to comment.