Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] Learner API: Policies using RLModules (for sampler only) do not need loss/stats/mixins. #34445

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
47c003c
wip
sven1977 Apr 15, 2023
cc25d0c
wip
sven1977 Apr 15, 2023
e6f2d3b
merge
sven1977 May 4, 2023
0248ce8
wip
sven1977 May 4, 2023
04595b3
wip
sven1977 May 4, 2023
bf78988
wip
sven1977 May 4, 2023
ceb8caa
wip
sven1977 May 4, 2023
cc632cb
Merge branch 'master' of https://github.com/ray-project/ray into lear…
sven1977 May 5, 2023
904bfda
merge
sven1977 May 5, 2023
b36304d
wip
sven1977 May 5, 2023
a112d5b
wip
sven1977 May 5, 2023
4b600ef
wip
sven1977 May 6, 2023
6dba48b
wip
sven1977 May 6, 2023
95fd464
wip
sven1977 May 6, 2023
6e8ab3f
wip
sven1977 May 6, 2023
b224d9f
wip
sven1977 May 6, 2023
26bdf8d
fix
sven1977 May 6, 2023
07409be
fix
sven1977 May 6, 2023
74f159f
wip
sven1977 May 6, 2023
70e6127
fix
sven1977 May 6, 2023
e4a58ee
fix
sven1977 May 6, 2023
c40cb8d
fix
sven1977 May 6, 2023
a2bc97a
fix
sven1977 May 6, 2023
81bcbd5
fix
sven1977 May 6, 2023
499d818
wip
sven1977 May 6, 2023
fd61ef6
LINT
sven1977 May 6, 2023
a53b043
fix
sven1977 May 6, 2023
f175a39
LINT
sven1977 May 6, 2023
3e95159
fix
sven1977 May 6, 2023
0ad186f
fix
sven1977 May 6, 2023
3a85eae
wip
sven1977 May 6, 2023
1380cb6
Merge remote-tracking branch 'origin/learner_rlm_policies_simplificat…
sven1977 May 6, 2023
e9ad050
merge
sven1977 May 6, 2023
05103ba
fix
sven1977 May 6, 2023
f0c1145
LINT
sven1977 May 6, 2023
90ee0ab
Add new Scheduler API.
sven1977 May 7, 2023
a628043
Merge branch 'master' of https://github.com/ray-project/ray into lear…
sven1977 May 7, 2023
6e9b0cd
wip
sven1977 May 7, 2023
7f85d0a
LINT
sven1977 May 7, 2023
1bb3ec2
fix
sven1977 May 7, 2023
21486d7
Merge branch 'master' of https://github.com/ray-project/ray into lear…
sven1977 May 8, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
LINT
Signed-off-by: sven1977 <[email protected]>
  • Loading branch information
sven1977 committed May 6, 2023
commit f175a3999c420b4f4c99cbe2e3aeb433bef79f0e
6 changes: 3 additions & 3 deletions rllib/algorithms/impala/tests/test_impala_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ def test_impala_loss(self):
policy = algo.get_policy()

if fw == "tf2":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait, is this necessary? why did this test work before when SampleBatch was in numpy fmt?

train_batch = SampleBatch(tree.map_structure(
lambda x: tf.convert_to_tensor(x), FAKE_BATCH
))
train_batch = SampleBatch(
tree.map_structure(lambda x: tf.convert_to_tensor(x), FAKE_BATCH)
)
elif fw == "torch":
train_batch = convert_to_torch_tensor(SampleBatch(FAKE_BATCH))

Expand Down
88 changes: 55 additions & 33 deletions rllib/algorithms/tests/test_algorithm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,15 +348,19 @@ def get_default_rl_module_spec(self):
########################################
# This is the case where we pass in a multi-agent RLModuleSpec that asks the
# algorithm to assign a specific type of RLModule class to certain module_ids.
config = SingleAgentAlgoConfig().rl_module(
_enable_rl_module_api=True,
rl_module_spec=MultiAgentRLModuleSpec(
module_specs={
"p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1),
"p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1),
}
),
).training(_enable_learner_api=True)
config = (
SingleAgentAlgoConfig()
.rl_module(
_enable_rl_module_api=True,
rl_module_spec=MultiAgentRLModuleSpec(
module_specs={
"p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1),
"p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1),
},
),
)
.training(_enable_learner_api=True)
)
config.validate()

spec, expected = self._get_expected_marl_spec(config, CustomRLModule1)
Expand All @@ -365,10 +369,14 @@ def get_default_rl_module_spec(self):
########################################
# This is the case where we ask the algorithm to assign a specific type of
# RLModule class to ALL module_ids.
config = SingleAgentAlgoConfig().rl_module(
_enable_rl_module_api=True,
rl_module_spec=SingleAgentRLModuleSpec(module_class=CustomRLModule1),
).training(_enable_learner_api=True)
config = (
SingleAgentAlgoConfig()
.rl_module(
_enable_rl_module_api=True,
rl_module_spec=SingleAgentRLModuleSpec(module_class=CustomRLModule1),
)
.training(_enable_learner_api=True)
)
config.validate()

spec, expected = self._get_expected_marl_spec(config, CustomRLModule1)
Expand All @@ -382,12 +390,16 @@ def get_default_rl_module_spec(self):
########################################
# This is an alternative way to ask the algorithm to assign a specific type of
# RLModule class to ALL module_ids.
config = SingleAgentAlgoConfig().rl_module(
_enable_rl_module_api=True,
rl_module_spec=MultiAgentRLModuleSpec(
module_specs=SingleAgentRLModuleSpec(module_class=CustomRLModule1)
),
).training(_enable_learner_api=True)
config = (
SingleAgentAlgoConfig()
.rl_module(
_enable_rl_module_api=True,
rl_module_spec=MultiAgentRLModuleSpec(
module_specs=SingleAgentRLModuleSpec(module_class=CustomRLModule1)
),
)
.training(_enable_learner_api=True)
)
config.validate()

spec, expected = self._get_expected_marl_spec(config, CustomRLModule1)
Expand All @@ -403,16 +415,20 @@ def get_default_rl_module_spec(self):
# This is not only assigning a specific type of RLModule class to EACH
# module_id, but also defining a new custom MultiAgentRLModule class to be used
# in the multi-agent scenario.
config = SingleAgentAlgoConfig().rl_module(
_enable_rl_module_api=True,
rl_module_spec=MultiAgentRLModuleSpec(
marl_module_class=CustomMARLModule1,
module_specs={
"p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1),
"p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1),
},
),
).training(_enable_learner_api=True)
config = (
SingleAgentAlgoConfig()
.rl_module(
_enable_rl_module_api=True,
rl_module_spec=MultiAgentRLModuleSpec(
marl_module_class=CustomMARLModule1,
module_specs={
"p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1),
"p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1),
},
),
)
.training(_enable_learner_api=True)
)
config.validate()

spec, expected = self._get_expected_marl_spec(
Expand Down Expand Up @@ -440,9 +456,11 @@ def get_default_rl_module_spec(self):
# This is the case where we ask the algorithm to use its default
# MultiAgentRLModuleSpec, but the MultiAgentRLModuleSpec has not defined its
# SingleAgentRLmoduleSpecs.
config = MultiAgentAlgoConfigWithNoSingleAgentSpec().rl_module(
_enable_rl_module_api=True
).training(_enable_learner_api=True)
config = (
MultiAgentAlgoConfigWithNoSingleAgentSpec()
.rl_module(_enable_rl_module_api=True)
.training(_enable_learner_api=True)
)

self.assertRaisesRegex(
ValueError,
Expand All @@ -454,7 +472,11 @@ def get_default_rl_module_spec(self):
# This is the case where we ask the algorithm to use its default
# MultiAgentRLModuleSpec, and the MultiAgentRLModuleSpec has defined its
# SingleAgentRLmoduleSpecs.
config = MultiAgentAlgoConfig().rl_module(_enable_rl_module_api=True).training(_enable_learner_api=True)
config = (
MultiAgentAlgoConfig()
.rl_module(_enable_rl_module_api=True)
.training(_enable_learner_api=True)
)
config.validate()

spec, expected = self._get_expected_marl_spec(
Expand Down