[RLlib] break up the learner group tests into shorter tests #35926

Merged
33 changes: 31 additions & 2 deletions rllib/BUILD
@@ -1994,10 +1994,39 @@ py_test(

 # Learner
 py_test(
-    name = "test_learner_group",
+    name = "TestLearnerGroupSyncUpdate",
     main = "core/learner/tests/test_learner_group.py",
     tags = ["team:rllib", "multi_gpu", "exclusive"],
     size = "large",
-    srcs = ["core/learner/tests/test_learner_group.py"]
+    srcs = ["core/learner/tests/test_learner_group.py"],
+    args = ["TestLearnerGroupSyncUpdate"]
 )

+py_test(
+    name = "TestLearnerGroupCheckpointRestore",
+    main = "core/learner/tests/test_learner_group.py",
+    tags = ["team:rllib", "multi_gpu", "exclusive"],
+    size = "large",
+    srcs = ["core/learner/tests/test_learner_group.py"],
+    args = ["TestLearnerGroupCheckpointRestore"]
+)
+
+py_test(
+    name = "TestLearnerGroupAsyncUpdate",
+    main = "core/learner/tests/test_learner_group.py",
+    tags = ["team:rllib", "multi_gpu", "exclusive"],
+    size = "large",
+    srcs = ["core/learner/tests/test_learner_group.py"],
+    args = ["TestLearnerGroupAsyncUpdate"]
+)
+
+py_test(
+    name = "TestLearnerGroupSaveLoadState",
+    main = "core/learner/tests/test_learner_group.py",
+    tags = ["team:rllib", "multi_gpu", "exclusive"],
+    size = "large",
+    srcs = ["core/learner/tests/test_learner_group.py"],
+    args = ["TestLearnerGroupSaveLoadState"]
+)
+
 py_test(
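Note: the four targets above share the same main and srcs and differ only in name and the single args entry; how that argument selects a test class is decided by the test script itself (see the sketch after the __main__ change at the end of this diff).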
298 changes: 163 additions & 135 deletions rllib/core/learner/tests/test_learner_group.py
@@ -63,7 +63,7 @@ def local_training_helper(self, fw, scaling_mode) -> None:
         env = gym.make("CartPole-v1")
         scaling_config = LOCAL_SCALING_CONFIGS[scaling_mode]
         learner_group = get_learner_group(fw, env, scaling_config, eager_tracing=True)
-        local_learner = get_learner(fw, env)
+        local_learner = get_learner(framework=fw, env=env)
         local_learner.build()

         # make the state of the learner and the local learner_group identical
@@ -106,7 +106,7 @@ def local_training_helper(self, fw, scaling_mode) -> None:
         check(local_learner.get_state(), learner_group.get_state())


-class TestLearnerGroup(unittest.TestCase):
+class TestLearnerGroupSyncUpdate(unittest.TestCase):
     def setUp(self) -> None:
         ray.init()
@@ -128,7 +128,7 @@ def test_learner_group_local(self):

     def test_update_multigpu(self):
         fws = ["torch", "tf2"]
-        scaling_modes = REMOTE_SCALING_CONFIGS.keys()
+        scaling_modes = ["multi-gpu-ddp", "remote-gpu"]
Review comment (Contributor):
@avnishn Can you add explanations on why these changes are made? REMOTE_SCALING_CONFIGS.keys() --> ["multi-gpu-ddp", "remote-gpu"]?

Reply (Member, PR author):
We can reduce the number of tests that we need to run while getting full test coverage by doing this.
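A small illustration of this point (the four-mode list below is an assumption for the sake of the example; the exact keys of REMOTE_SCALING_CONFIGS are not shown in this diff): the tests iterate over the full framework x scaling-mode cross product, so trimming the mode list shrinks the matrix multiplicatively.

import itertools

fws = ["torch", "tf2"]
# Assumed full set of scaling modes, for illustration only.
all_modes = ["local-cpu", "multi-cpu-ddp", "multi-gpu-ddp", "remote-gpu"]
kept_modes = ["multi-gpu-ddp", "remote-gpu"]

print(len(list(itertools.product(fws, all_modes))))   # 2 x 4 = 8 combinations
print(len(list(itertools.product(fws, kept_modes))))  # 2 x 2 = 4 combinations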

         test_iterator = itertools.product(fws, scaling_modes)

         for fw, scaling_mode in test_iterator:
@@ -181,7 +181,7 @@ def _check_multi_worker_weights(self, results: List[Dict[str, Any]]):

     def test_add_remove_module(self):
         fws = ["torch", "tf2"]
-        scaling_modes = REMOTE_SCALING_CONFIGS.keys()
+        scaling_modes = ["multi-gpu-ddp"]
         test_iterator = itertools.product(fws, scaling_modes)

         for fw, scaling_mode in test_iterator:
@@ -243,141 +243,19 @@ def test_add_remove_module(self):
             learner_group.shutdown()
             del learner_group

-    def test_async_update(self):
-        """Test that async style updates converge to the same result as sync."""
-        fws = ["torch", "tf2"]
-        # async_update only needs to be tested for the most complex case.
-        # so we'll only test it for multi-gpu-ddp.
-        scaling_modes = ["multi-gpu-ddp"]
-        test_iterator = itertools.product(fws, scaling_modes)
-
-        for fw, scaling_mode in test_iterator:
-            print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
-            env = gym.make("CartPole-v1")
-            scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
-            learner_group = get_learner_group(
-                fw, env, scaling_config, eager_tracing=True
-            )
-            reader = get_cartpole_dataset_reader(batch_size=512)
-            min_loss = float("inf")
-            batch = reader.next()
-            timer_sync = _Timer()
-            timer_async = _Timer()
-            with timer_sync:
-                learner_group.update(batch.as_multi_agent(), reduce_fn=None)
-            with timer_async:
-                result_async = learner_group.async_update(
-                    batch.as_multi_agent(), reduce_fn=None
-                )
-            # ideally the the first async update will return nothing, and an easy
-            # way to check that is if the time for an async update call is faster
-            # than the time for a sync update call.
-            self.assertLess(timer_async.mean, timer_sync.mean)
-            self.assertIsInstance(result_async, list)
-            self.assertEqual(len(result_async), 0)
-            iter_i = 0
-            while True:
-                batch = reader.next()
-                async_results = learner_group.async_update(
-                    batch.as_multi_agent(), reduce_fn=None
-                )
-                if not async_results:
-                    continue
-                losses = [
-                    np.mean(
-                        [res[ALL_MODULES][Learner.TOTAL_LOSS_KEY] for res in results]
-                    )
-                    for results in async_results
-                ]
-                min_loss_this_iter = min(losses)
-                min_loss = min(min_loss_this_iter, min_loss)
-                print(
-                    f"[iter = {iter_i}] Loss: {min_loss_this_iter:.3f}, Min Loss: "
-                    f"{min_loss:.3f}"
-                )
-                # The loss is initially around 0.69 (ln2). When it gets to around
-                # 0.57 the return of the policy gets to around 100.
-                if min_loss < 0.57:
-                    break
-                for results in async_results:
-                    for res1, res2 in zip(results, results[1:]):
-                        self.assertEqual(
-                            res1[DEFAULT_POLICY_ID]["mean_weight"],
-                            res2[DEFAULT_POLICY_ID]["mean_weight"],
-                        )
-                iter_i += 1
-            learner_group.shutdown()
-            self.assertLess(min_loss, 0.57)
-
-    def test_save_load_state(self):
-        fws = ["torch", "tf2"]
-        # this is expanded to more scaling modes on the release ci.
-        scaling_modes = REMOTE_SCALING_CONFIGS.keys()
-
-        test_iterator = itertools.product(fws, scaling_modes)
-        batch = SampleBatch(FAKE_BATCH)
-        for fw, scaling_mode in test_iterator:
-            print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
-            env = gym.make("CartPole-v1")
-
-            scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
-            initial_learner_group = get_learner_group(
-                fw, env, scaling_config, eager_tracing=True
-            )
-
-            # checkpoint the initial learner state for later comparison
-            initial_learner_checkpoint_dir = tempfile.TemporaryDirectory().name
-            initial_learner_group.save_state(initial_learner_checkpoint_dir)
-            initial_learner_group_weights = initial_learner_group.get_weights()
-
-            # do a single update
-            initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)
-
-            # checkpoint the learner state after 1 update for later comparison
-            learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name
-            initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)
-
-            # remove that learner, construct a new one, and load the state of the old
-            # learner into the new one
-            initial_learner_group.shutdown()
-            del initial_learner_group
-            new_learner_group = get_learner_group(
-                fw, env, scaling_config, eager_tracing=True
-            )
-            new_learner_group.load_state(learner_after_1_update_checkpoint_dir)
-
-            # do another update
-            results_with_break = new_learner_group.update(
-                batch.as_multi_agent(), reduce_fn=None
-            )
-            weights_after_1_update_with_break = new_learner_group.get_weights()
-            new_learner_group.shutdown()
-            del new_learner_group
-
-            # construct a new learner group and load the initial state of the learner
-            learner_group = get_learner_group(
-                fw, env, scaling_config, eager_tracing=True
-            )
-            learner_group.load_state(initial_learner_checkpoint_dir)
-            check(learner_group.get_weights(), initial_learner_group_weights)
-            learner_group.update(batch.as_multi_agent(), reduce_fn=None)
-            results_without_break = learner_group.update(
-                batch.as_multi_agent(), reduce_fn=None
-            )
-            weights_after_1_update_without_break = learner_group.get_weights()
-            learner_group.shutdown()
-            del learner_group
-
-            # compare the results of the two updates
-            check(results_with_break, results_without_break)
-            check(
-                weights_after_1_update_with_break, weights_after_1_update_without_break
-            )
+class TestLearnerGroupCheckpointRestore(unittest.TestCase):
+    def setUp(self) -> None:
+        ray.init()
+
+    def tearDown(self) -> None:
+        ray.shutdown()

     def test_load_module_state(self):
         """Test that module state can be loaded from a checkpoint."""
         fws = ["torch", "tf2"]
         # this is expanded to more scaling modes on the release ci.
-        scaling_modes = ["local-cpu", "multi-cpu-ddp", "multi-gpu-ddp"]
+        scaling_modes = ["local-cpu", "multi-gpu-ddp"]

         test_iterator = itertools.product(fws, scaling_modes)
         for fw, scaling_mode in test_iterator:
@@ -505,8 +383,158 @@ def test_load_module_state_errors(self):
         del learner_group


+class TestLearnerGroupSaveLoadState(unittest.TestCase):
+    def setUp(self) -> None:
+        ray.init()
+
+    def tearDown(self) -> None:
+        ray.shutdown()
+
+    def test_save_load_state(self):
+        """Check that saving and loading learner group state works."""
+        fws = ["torch", "tf2"]
+        # this is expanded to more scaling modes on the release ci.
+        scaling_modes = ["multi-gpu-ddp", "local-cpu"]
+        test_iterator = itertools.product(fws, scaling_modes)
+        batch = SampleBatch(FAKE_BATCH)
+        for fw, scaling_mode in test_iterator:
+            print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
+            env = gym.make("CartPole-v1")
+
+            scaling_config = REMOTE_SCALING_CONFIGS.get(
+                scaling_mode
+            ) or LOCAL_SCALING_CONFIGS.get(scaling_mode)
+            initial_learner_group = get_learner_group(
+                fw, env, scaling_config, eager_tracing=True
+            )
+
+            # checkpoint the initial learner state for later comparison
+            initial_learner_checkpoint_dir = tempfile.TemporaryDirectory().name
+            initial_learner_group.save_state(initial_learner_checkpoint_dir)
+            initial_learner_group_weights = initial_learner_group.get_weights()
+
+            # do a single update
+            initial_learner_group.update(batch.as_multi_agent(), reduce_fn=None)
+
+            # checkpoint the learner state after 1 update for later comparison
+            learner_after_1_update_checkpoint_dir = tempfile.TemporaryDirectory().name
+            initial_learner_group.save_state(learner_after_1_update_checkpoint_dir)
+
+            # remove that learner, construct a new one, and load the state of the old
+            # learner into the new one
+            initial_learner_group.shutdown()
+            del initial_learner_group
+            new_learner_group = get_learner_group(
+                fw, env, scaling_config, eager_tracing=True
+            )
+            new_learner_group.load_state(learner_after_1_update_checkpoint_dir)
+
+            # do another update
+            results_with_break = new_learner_group.update(
+                batch.as_multi_agent(), reduce_fn=None
+            )
+            weights_after_1_update_with_break = new_learner_group.get_weights()
+            new_learner_group.shutdown()
+            del new_learner_group
+
+            # construct a new learner group and load the initial state of the learner
+            learner_group = get_learner_group(
+                fw, env, scaling_config, eager_tracing=True
+            )
+            learner_group.load_state(initial_learner_checkpoint_dir)
+            check(learner_group.get_weights(), initial_learner_group_weights)
+            learner_group.update(batch.as_multi_agent(), reduce_fn=None)
+            results_without_break = learner_group.update(
+                batch.as_multi_agent(), reduce_fn=None
+            )
+            weights_after_1_update_without_break = learner_group.get_weights()
+            learner_group.shutdown()
+            del learner_group
+
+            # compare the results of the two updates
+            check(results_with_break, results_without_break)
+            check(
+                weights_after_1_update_with_break, weights_after_1_update_without_break
+            )
+
+
+class TestLearnerGroupAsyncUpdate(unittest.TestCase):
+    def setUp(self) -> None:
+        ray.init()
+
+    def tearDown(self) -> None:
+        ray.shutdown()
+
+    def test_async_update(self):
+        """Test that async style updates converge to the same result as sync."""
+        fws = ["torch", "tf2"]
+        # async_update only needs to be tested for the most complex cases,
+        # so we only test it for multi-gpu-ddp and remote-gpu.
+        scaling_modes = ["multi-gpu-ddp", "remote-gpu"]
+        test_iterator = itertools.product(fws, scaling_modes)
+
+        for fw, scaling_mode in test_iterator:
+            print(f"Testing framework: {fw}, scaling mode: {scaling_mode}.")
+            env = gym.make("CartPole-v1")
+            scaling_config = REMOTE_SCALING_CONFIGS[scaling_mode]
+            learner_group = get_learner_group(
+                fw, env, scaling_config, eager_tracing=True
+            )
+            reader = get_cartpole_dataset_reader(batch_size=512)
+            min_loss = float("inf")
+            batch = reader.next()
+            timer_sync = _Timer()
+            timer_async = _Timer()
+            with timer_sync:
+                learner_group.update(batch.as_multi_agent(), reduce_fn=None)
+            with timer_async:
+                result_async = learner_group.async_update(
+                    batch.as_multi_agent(), reduce_fn=None
+                )
+            # Ideally the first async update will return nothing, and an easy
+            # way to check that is whether an async update call returns faster
+            # than a sync update call.
+            self.assertLess(timer_async.mean, timer_sync.mean)
+            self.assertIsInstance(result_async, list)
+            self.assertEqual(len(result_async), 0)
+            iter_i = 0
+            while True:
+                batch = reader.next()
+                async_results = learner_group.async_update(
+                    batch.as_multi_agent(), reduce_fn=None
+                )
+                if not async_results:
+                    continue
+                losses = [
+                    np.mean(
+                        [res[ALL_MODULES][Learner.TOTAL_LOSS_KEY] for res in results]
+                    )
+                    for results in async_results
+                ]
+                min_loss_this_iter = min(losses)
+                min_loss = min(min_loss_this_iter, min_loss)
+                print(
+                    f"[iter = {iter_i}] Loss: {min_loss_this_iter:.3f}, Min Loss: "
+                    f"{min_loss:.3f}"
+                )
+                # The loss is initially around 0.69 (ln 2). When it gets to around
+                # 0.57, the return of the policy gets to around 100.
+                if min_loss < 0.57:
+                    break
+                for results in async_results:
+                    for res1, res2 in zip(results, results[1:]):
+                        self.assertEqual(
+                            res1[DEFAULT_POLICY_ID]["mean_weight"],
+                            res2[DEFAULT_POLICY_ID]["mean_weight"],
+                        )
+                iter_i += 1
+            learner_group.shutdown()
+            self.assertLess(min_loss, 0.57)


 if __name__ == "__main__":
-    import pytest
     import sys
+    import pytest

-    sys.exit(pytest.main(["-v", __file__]))
+    class_ = sys.argv[1] if len(sys.argv) > 1 else None
+    sys.exit(pytest.main(["-v", __file__ + ("" if class_ is None else "::" + class_)]))
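To make the selection mechanics concrete: each Bazel target passes one test-class name via args, the __main__ block reads it from sys.argv, and pytest's "file::Class" node-id syntax restricts the run to that class. A minimal, self-contained sketch of that logic (the helper name pytest_target is ours, not part of the PR):

from typing import Optional


def pytest_target(file: str, class_: Optional[str]) -> str:
    """Build the pytest node id for an optional test-class name."""
    return file if class_ is None else f"{file}::{class_}"


# No argument: the whole file runs, preserving the old single-target behavior.
assert pytest_target("test_learner_group.py", None) == "test_learner_group.py"
# With an argument (e.g. from a BUILD target's args), only that class runs.
assert (
    pytest_target("test_learner_group.py", "TestLearnerGroupSyncUpdate")
    == "test_learner_group.py::TestLearnerGroupSyncUpdate"
)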