[RLlib] break up the learner group tests into shorter tests (ray-project#35926)

Signed-off-by: Avnish <[email protected]>
avnishn committed Jun 1, 2023
1 parent ba31fdf commit e6b5b8b
Showing 2 changed files with 194 additions and 137 deletions.
33 changes: 31 additions & 2 deletions rllib/BUILD
@@ -2001,10 +2001,39 @@ py_test(

# Learner
py_test(
name = "test_learner_group",
name = "TestLearnerGroupSyncUpdate",
main = "core/learner/tests/test_learner_group.py",
tags = ["team:rllib", "multi_gpu", "exclusive"],
size = "large",
srcs = ["core/learner/tests/test_learner_group.py"]
srcs = ["core/learner/tests/test_learner_group.py"],
args = ["TestLearnerGroupSyncUpdate"]
)

py_test(
name = "TestLearnerGroupCheckpointRestore",
main = "core/learner/tests/test_learner_group.py",
tags = ["team:rllib", "multi_gpu", "exclusive"],
size = "large",
srcs = ["core/learner/tests/test_learner_group.py"],
args = ["TestLearnerGroupCheckpointRestore"]
)

py_test(
name = "TestLearnerGroupAsyncUpdate",
main = "core/learner/tests/test_learner_group.py",
tags = ["team:rllib", "multi_gpu", "exclusive"],
size = "large",
srcs = ["core/learner/tests/test_learner_group.py"],
args = ["TestLearnerGroupAsyncUpdate"]
)

py_test(
name = "TestLearnerGroupSaveLoadState",
main = "core/learner/tests/test_learner_group.py",
tags = ["team:rllib", "multi_gpu", "exclusive"],
size = "large",
srcs = ["core/learner/tests/test_learner_group.py"],
args = ["TestLearnerGroupSaveLoadState"]
)

py_test(
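
Each of the new py_test targets above reuses the same test module via main/srcs and forwards a test class name through args, so each Bazel target runs exactly one unittest.TestCase class from test_learner_group.py. Below is a minimal sketch of that selection mechanism, mirroring the __main__ block added at the bottom of the test file in this commit (the class name mentioned in the comment is only illustrative):

import sys

import pytest

if __name__ == "__main__":
    # An optional class name on the command line narrows the run to that class.
    class_ = sys.argv[1] if len(sys.argv) > 1 else None
    # "test_file.py::TestClass" is a pytest node ID that selects a single class.
    selector = __file__ if class_ is None else __file__ + "::" + class_
    sys.exit(pytest.main(["-v", selector]))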
298 changes: 163 additions & 135 deletions rllib/core/learner/tests/test_learner_group.py
@@ -63,7 +63,7 @@ def local_training_helper(self, fw, scaling_mode) -> None:
env = gym.make("CartPole-v1")
scaling_config = LOCAL_SCALING_CONFIGS[scaling_mode]
learner_group = get_learner_group(fw, env, scaling_config, eager_tracing=True)
local_learner = get_learner(fw, env)
local_learner = get_learner(framework=fw, env=env)
local_learner.build()

# make the state of the learner and the local learner_group identical
@@ -106,7 +106,7 @@ def local_training_helper(self, fw, scaling_mode) -> None:
check(local_learner.get_state(), learner_group.get_state())


class TestLearnerGroup(unittest.TestCase):
class TestLearnerGroupSyncUpdate(unittest.TestCase):
def setUp(self) -> None:
ray.init()

@@ -128,7 +128,7 @@ def test_learner_group_local(self):

def test_update_multigpu(self):
fws = ["torch", "tf2"]
scaling_modes = REMOTE_SCALING_CONFIGS.keys()
scaling_modes = ["multi-gpu-ddp", "remote-gpu"]
test_iterator = itertools.product(fws, scaling_modes)

for fw, scaling_mode in test_iterator:
@@ -181,7 +181,7 @@ def _check_multi_worker_weights(self, results: List[Dict[str, Any]]):

def test_add_remove_module(self):
fws = ["torch", "tf2"]
scaling_modes = REMOTE_SCALING_CONFIGS.keys()
scaling_modes = ["multi-gpu-ddp"]
test_iterator = itertools.product(fws, scaling_modes)

for fw, scaling_mode in test_iterator:
@@ -243,141 +243,19 @@ def test_add_remove_module(self):
learner_group.shutdown()
del learner_group

def test_async_update(self):
"""Test that async style updates converge to the same result as sync."""
fws = ["torch", "tf2"]
# async_update only needs to be tested for the most complex case.
# so we'll only test it for multi-gpu-ddp.
scaling_modes = ["multi-gpu-ddp"]
test_iterator = itertools.product(fws, scaling_modes)

for fw, scaling_mode in test_iterator:
print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
env = gym.make("CartPole-v1")
scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
learner_group = get_learner_group(
fw, env, scaling_config, eager_tracing=True
)
reader = get_cartpole_dataset_reader(batch_size=512)
min_loss = float("inf")
batch = reader.next()
timer_sync = _Timer()
timer_async = _Timer()
with timer_sync:
learner_group.update(batch.as_multi_agent(), reduce_fn=None)
with timer_async:
result_async = learner_group.async_update(
batch.as_multi_agent(), reduce_fn=None
)
# ideally the the first async update will return nothing, and an easy
# way to check that is if the time for an async update call is faster
# than the time for a sync update call.
self.assertLess(timer_async.mean, timer_sync.mean)
self.assertIsInstance(result_async, list)
self.assertEqual(len(result_async), 0)
iter_i = 0
while True:
batch = reader.next()
async_results = learner_group.async_update(
batch.as_multi_agent(), reduce_fn=None
)
if not async_results:
continue
losses = [
np.mean(
[res[ALL_MODULES][Learner.TOTAL_LOSS_KEY] for res in results]
)
for results in async_results
]
min_loss_this_iter = min(losses)
min_loss = min(min_loss_this_iter, min_loss)
print(
f"[iter = {iter_i}] Loss: {min_loss_this_iter:.3f}, Min Loss: "
f"{min_loss:.3f}"
)
# The loss is initially around 0.69 (ln2). When it gets to around
# 0.57 the return of the policy gets to around 100.
if min_loss < 0.57:
break
for results in async_results:
for res1, res2 in zip(results, results[1:]):
self.assertEqual(
res1[DEFAULT_POLICY_ID]["mean_weight"],
res2[DEFAULT_POLICY_ID]["mean_weight"],
)
iter_i += 1
learner_group.shutdown()
self.assertLess(min_loss, 0.57)

def test_save_load_state(self):
fws = ["torch", "tf2"]
# this is expanded to more scaling modes on the release ci.
scaling_modes = REMOTE_SCALING_CONFIGS.keys()

test_iterator = itertools.product(fws, scaling_modes)
batch = SampleBatch(FAKE_BATCH)
for fw, scaling_mode in test_iterator:
print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
env = gym.make("CartPole-v1")

scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
initial_learner_group = get_learner_group(
fw, env, scaling_config, eager_tracing=True
)

# checkpoint the initial learner state for later comparison
initial_learner_checkpoint_dir = tempfile.TemporaryDirectory().name
initial_learner_group.save_state(initial_learner_checkpoint_dir)
initial_learner_group_weights = initial_learner_group.get_weights()

# do a single update
initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)

# checkpoint the learner state after 1 update for later comparison
learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name
initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)

# remove that learner, construct a new one, and load the state of the old
# learner into the new one
initial_learner_group.shutdown()
del initial_learner_group
new_learner_group = get_learner_group(
fw, env, scaling_config, eager_tracing=True
)
new_learner_group.load_state(learner_after_1_update_checkpoint_dir)

# do another update
results_with_break = new_learner_group.update(
batch.as_multi_agent(), reduce_fn=None
)
weights_after_1_update_with_break = new_learner_group.get_weights()
new_learner_group.shutdown()
del new_learner_group

# construct a new learner group and load the initial state of the learner
learner_group = get_learner_group(
fw, env, scaling_config, eager_tracing=True
)
learner_group.load_state(initial_learner_checkpoint_dir)
check(learner_group.get_weights(), initial_learner_group_weights)
learner_group.update(batch.as_multi_agent(), reduce_fn=None)
results_without_break = learner_group.update(
batch.as_multi_agent(), reduce_fn=None
)
weights_after_1_update_without_break = learner_group.get_weights()
learner_group.shutdown()
del learner_group
class TestLearnerGroupCheckpointRestore(unittest.TestCase):
def setUp(self) -> None:
ray.init()

# compare the results of the two updates
check(results_with_break, results_without_break)
check(
weights_after_1_update_with_break, weights_after_1_update_without_break
)
def tearDown(self) -> None:
ray.shutdown()

def test_load_module_state(self):
"""Test that module state can be loaded from a checkpoint."""
fws = ["torch", "tf2"]
# this is expanded to more scaling modes on the release ci.
scaling_modes = ["local-cpu", "multi-cpu-ddp", "multi-gpu-ddp"]
scaling_modes = ["local-cpu", "multi-gpu-ddp"]

test_iterator = itertools.product(fws, scaling_modes)
for fw, scaling_mode in test_iterator:
@@ -505,8 +383,158 @@ def test_load_module_state_errors(self):
del learner_group


class TestLearnerGroupSaveLoadState(unittest.TestCase):
def setUp(self) -> None:
ray.init()

def tearDown(self) -> None:
ray.shutdown()

def test_save_load_state(self):
"""Check that saving and loading learner group state works."""
fws = ["torch", "tf2"]
# this is expanded to more scaling modes on the release ci.
scaling_modes = ["multi-gpu-ddp", "local-cpu"]
test_iterator = itertools.product(fws, scaling_modes)
batch = SampleBatch(FAKE_BATCH)
for fw, scaling_mode in test_iterator:
print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
env = gym.make("CartPole-v1")

scaling_config = REMOTE_SCALING_CONFIGS.get(
scaling_mode
) or LOCAL_SCALING_CONFIGS.get(scaling_mode)
initial_learner_group = get_learner_group(
fw, env, scaling_config, eager_tracing=True
)

# checkpoint the initial learner state for later comparison
initial_learner_checkpoint_dir = tempfile.TemporaryDirectory().name
initial_learner_group.save_state(initial_learner_checkpoint_dir)
initial_learner_group_weights = initial_learner_group.get_weights()

# do a single update
initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)

# checkpoint the learner state after 1 update for later comparison
learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name
initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)

# remove that learner, construct a new one, and load the state of the old
# learner into the new one
initial_learner_group.shutdown()
del initial_learner_group
new_learner_group = get_learner_group(
fw, env, scaling_config, eager_tracing=True
)
new_learner_group.load_state(learner_after_1_update_checkpoint_dir)

# do another update
results_with_break = new_learner_group.update(
batch.as_multi_agent(), reduce_fn=None
)
weights_after_1_update_with_break = new_learner_group.get_weights()
new_learner_group.shutdown()
del new_learner_group

# construct a new learner group and load the initial state of the learner
learner_group = get_learner_group(
fw, env, scaling_config, eager_tracing=True
)
learner_group.load_state(initial_learner_checkpoint_dir)
check(learner_group.get_weights(), initial_learner_group_weights)
learner_group.update(batch.as_multi_agent(), reduce_fn=None)
results_without_break = learner_group.update(
batch.as_multi_agent(), reduce_fn=None
)
weights_after_1_update_without_break = learner_group.get_weights()
learner_group.shutdown()
del learner_group

# compare the results of the two updates
check(results_with_break, results_without_break)
check(
weights_after_1_update_with_break, weights_after_1_update_without_break
)


class TestLearnerGroupAsyncUpdate(unittest.TestCase):
def setUp(self) -> None:
ray.init()

def tearDown(self) -> None:
ray.shutdown()

def test_async_update(self):
"""Test that async style updates converge to the same result as sync."""
fws = ["torch", "tf2"]
# async_update only needs to be tested for the most complex cases,
# so we only test multi-gpu-ddp and remote-gpu.
scaling_modes = ["multi-gpu-ddp", "remote-gpu"]
test_iterator = itertools.product(fws, scaling_modes)

for fw, scaling_mode in test_iterator:
print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
env = gym.make("CartPole-v1")
scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
learner_group = get_learner_group(
fw, env, scaling_config, eager_tracing=True
)
reader = get_cartpole_dataset_reader(batch_size=512)
min_loss = float("inf")
batch = reader.next()
timer_sync = _Timer()
timer_async = _Timer()
with timer_sync:
learner_group.update(batch.as_multi_agent(), reduce_fn=None)
with timer_async:
result_async = learner_group.async_update(
batch.as_multi_agent(), reduce_fn=None
)
# ideally the first async update will return nothing, and an easy
# way to check that is if the time for an async update call is faster
# than the time for a sync update call.
self.assertLess(timer_async.mean, timer_sync.mean)
self.assertIsInstance(result_async, list)
self.assertEqual(len(result_async), 0)
iter_i = 0
while True:
batch = reader.next()
async_results = learner_group.async_update(
batch.as_multi_agent(), reduce_fn=None
)
if not async_results:
continue
losses = [
np.mean(
[res[ALL_MODULES][Learner.TOTAL_LOSS_KEY] for res in results]
)
for results in async_results
]
min_loss_this_iter = min(losses)
min_loss = min(min_loss_this_iter, min_loss)
print(
f"[iter = {iter_i}] Loss: {min_loss_this_iter:.3f}, Min Loss: "
f"{min_loss:.3f}"
)
# The loss is initially around 0.69 (ln2). When it gets to around
# 0.57 the return of the policy gets to around 100.
if min_loss < 0.57:
break
for results in async_results:
for res1, res2 in zip(results, results[1:]):
self.assertEqual(
res1[DEFAULT_POLICY_ID]["mean_weight"],
res2[DEFAULT_POLICY_ID]["mean_weight"],
)
iter_i += 1
learner_group.shutdown()
self.assertLess(min_loss, 0.57)


if __name__ == "__main__":
import pytest
import sys
import pytest

sys.exit(pytest.main(["-v", __file__]))
class_ = sys.argv[1] if len(sys.argv) > 1 else None
sys.exit(pytest.main(["-v", __file__ + ("" if class_ is None else "::" + class_)]))
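
With this change, a single test class can be run directly, for example python rllib/core/learner/tests/test_learner_group.py TestLearnerGroupSyncUpdate (assuming a checkout with pytest and RLlib's test dependencies installed); pytest then receives the node ID test_learner_group.py::TestLearnerGroupSyncUpdate. The four new BUILD targets above pass the same class names through args.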
