-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib] Change default framework from tf to torch #33604
[RLlib] Change default framework from tf to torch #33604
Conversation
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
@@ -261,7 +261,7 @@ def __init__(self, algo_class=None): | |||
self.placement_strategy = "PACK" | |||
|
|||
# `self.framework()` | |||
self.framework_str = "tf" | |||
self.framework_str = "torch" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:fingers-crossed: :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Some example scripts used to run only on tf and will now only run on torch, but I guess that's ok. E.g. custom_metrics_and_callbacks
.
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
@@ -297,6 +298,6 @@ def new_policy_mapping_fn(agent_id, episode, worker, **kwargs): | |||
|
|||
# __export-models-as-onnx-begin__ | |||
# Using the same Policy object, we can also export our NN Model in the ONNX format: | |||
ppo_policy.export_model("/tmp/my_nn_model", onnx=True) | |||
ppo_policy.export_model("/tmp/my_nn_model", onnx=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
update comment?
@@ -1618,7 +1618,7 @@ py_test( | |||
py_test( | |||
name = "connectors/tests/test_agent", | |||
tags = ["team:rllib", "connector"], | |||
size = "small", | |||
size = "medium", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is from some other pr right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The other PR will get merged and this difference will go away. I wanted to make sure the tests on CI doesn't get red b/c of time outs.
@@ -449,6 +449,7 @@ def compute_actions_from_input_dict( | |||
env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}), | |||
default_policy_class=ModelBasedPolicy, | |||
config=DQNConfig() | |||
.framework("tf") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
curious, multi-agent env doesn't work with torch?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does. This test is overfitted to tf.
ok ok |
* changed default in algo config * implicitly added tf framework to the test scripts Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
* changed default in algo config * implicitly added tf framework to the test scripts Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
* changed default in algo config * implicitly added tf framework to the test scripts Signed-off-by: Kourosh Hakhamaneshi <[email protected]> Signed-off-by: bhuang <[email protected]>
* changed default in algo config * implicitly added tf framework to the test scripts Signed-off-by: Kourosh Hakhamaneshi <[email protected]> Signed-off-by: Jonathan Carter <[email protected]>
* changed default in algo config * implicitly added tf framework to the test scripts Signed-off-by: Kourosh Hakhamaneshi <[email protected]> Signed-off-by: elliottower <[email protected]>
* changed default in algo config * implicitly added tf framework to the test scripts Signed-off-by: Kourosh Hakhamaneshi <[email protected]> Signed-off-by: Jack He <[email protected]>
Why are these changes needed?
This PR changes the default framework_str from tf to either torch or tf2. First step towards hopefully deprecating tf1 stack.
Related issue number
Checks
git commit -s
) in this PR.scripts/format.sh
to lint the changes in this PR.method in Tune, I've added it in
doc/source/tune/api/
under thecorresponding
.rst
file.