
[build_base][RLlib] APPO TF with RLModule and Learner API #33310

Merged: 51 commits from avnishn:appo_tf into ray-project:master on Mar 26, 2023
Changes from 1 commit (51 commits total)
517e0d6
Temp
avnishn Mar 13, 2023
293ac57
Temp
avnishn Mar 13, 2023
996d6af
Temp
avnishn Mar 14, 2023
d83f4de
Temp
avnishn Mar 14, 2023
3f4e27c
Temp
avnishn Mar 14, 2023
3495576
Merge branch 'master' of https://github.com/ray-project/ray into appo_tf
avnishn Mar 14, 2023
606f093
Make all TfModels tf.keras.Models
ArturNiederfahrenhorst Mar 15, 2023
134593e
Merge branch 'fixtensorflowacmodels' into appo_tf
avnishn Mar 15, 2023
8c5ae9b
Running appo
avnishn Mar 15, 2023
8a2a0f5
Lint, small updates
avnishn Mar 16, 2023
efc74fe
Move adding params to learner hps to validate in order to be compatib…
avnishn Mar 16, 2023
13548c1
Move adding params to learner hps to validate in order to be compatib…
avnishn Mar 16, 2023
cd1270c
Move learner_hp assignment from builder functions to validate
avnishn Mar 16, 2023
1f00e42
Merge branch 'move_learner_hp_assignment' into appo_tf
avnishn Mar 16, 2023
2ec8d08
Temp
avnishn Mar 17, 2023
62753d1
Temp
avnishn Mar 17, 2023
6c91172
Clip is ratio
avnishn Mar 17, 2023
61be19b
Merge branch 'master' of https://github.com/ray-project/ray into appo_tf
avnishn Mar 17, 2023
f0ea920
Wrote appo tf policy rlm which has working loss but isn't seemingly u…
avnishn Mar 19, 2023
47d4b5a
Merge branch 'master' of https://github.com/ray-project/ray into appo_tf
avnishn Mar 19, 2023
fb69db2
Add option for minibatching in impala/appo with the learner group
avnishn Mar 20, 2023
bb0daeb
Merge branch 'master' of https://github.com/ray-project/ray into appo_tf
avnishn Mar 20, 2023
68cd9df
Store most recent result for results reporting
avnishn Mar 21, 2023
5008690
dmc wrapper types
avnishn Mar 21, 2023
064adee
Merge branch 'appo_tf' of https://github.com/avnishn/ray into appo_tf
avnishn Mar 21, 2023
b015015
ADd back in UpdateTargetAndKL
avnishn Mar 21, 2023
7842087
Merge branch 'appo_tf' of https://github.com/avnishn/ray; branch 'mas…
avnishn Mar 21, 2023
2b9fec4
Fix broken tests
avnishn Mar 21, 2023
3fb3615
More tf related fixes
avnishn Mar 21, 2023
31adb97
More tf related fixes
avnishn Mar 21, 2023
25fdcac
Fix impala test
avnishn Mar 21, 2023
c89e151
Fixing remaining broken tests
avnishn Mar 21, 2023
2f0cea9
More tf related fixes
avnishn Mar 22, 2023
d61d198
More tf fixes with try catch
avnishn Mar 22, 2023
e83b0d4
Addressing comments
avnishn Mar 22, 2023
2a5acbe
Address comments
avnishn Mar 23, 2023
44bf47a
Ad rl module with target networks mixin interface
avnishn Mar 23, 2023
3f113e0
Temp
avnishn Mar 24, 2023
5746a8f
Address comments
avnishn Mar 24, 2023
6d28903
Address comments
avnishn Mar 24, 2023
2b3bfb4
Address comments
avnishn Mar 24, 2023
288620d
Merge branch 'master' of https://github.com/ray-project/ray into appo_tf
avnishn Mar 24, 2023
24a1d6e
Fix broken import
avnishn Mar 24, 2023
9594a6c
Lint
avnishn Mar 24, 2023
b8e4ec2
Merge branch 'master' of https://github.com/ray-project/ray into appo_tf
avnishn Mar 24, 2023
3837d36
Touching a file
avnishn Mar 24, 2023
2d6ac06
triggering the tests
kouroshHakha Mar 25, 2023
5fabfc6
Merge branch 'master' of https://github.com/ray-project/ray into appo_tf
avnishn Mar 26, 2023
117d9dd
Merge branch 'appo_tf' of https://github.com/avnishn/ray into appo_tf
avnishn Mar 26, 2023
68a19ab
Merge branch 'master' into appo_tf
kouroshHakha Mar 26, 2023
9e0d54b
Merge branch 'appo_tf' of github.com:avnishn/ray into appo_tf
kouroshHakha Mar 26, 2023
Temp
Signed-off-by: Avnish <[email protected]>
avnishn committed Mar 17, 2023
commit 2ec8d088bc15fb4a6bd524a87c391121fac9d442
61 changes: 34 additions & 27 deletions rllib/algorithms/appo/appo.py
@@ -262,33 +262,38 @@ def after_train_step(self, train_results: ResultDict) -> None:
last_update = self._counters[LAST_TARGET_UPDATE_TS]

if self.config._enable_learner_api:
if train_results:
# using steps trained here instead of sampled ... I'm not sure why the
# other implemenetation uses sampled.
# to be quite frank, im not sure if I understand how their target update
# freq would work. The difference in steps sampled/trained is pretty
# much always going to be larger than self.config.num_sgd_iter *
# self.config.minibatch_buffer_size unless the number of steps collected
# is really small. The thing is that the default rollout fragment length
# is 50, so the minibatch buffer size * num_sgd_iter is going to be
# have to be 50 to even meet the threshold of having delayed target
# updates.
# we should instead have the target / kl threshold update be based off
# of the train_batch_size * some target update frequency * num_sgd_iter.
cur_ts = self._counters[
NUM_ENV_STEPS_TRAINED
if self.config.count_steps_by == "env_steps"
else NUM_AGENT_STEPS_TRAINED
]
target_update_steps_freq = (
self.config.num_sgd_iter
* self.config.train_batch_size
* self.config.target_update_frequency
)
if cur_ts - last_update > target_update_steps_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
self.learner_group.additional_update()
# using steps trained here instead of sampled ... I'm not sure why the
# other implemenetation uses sampled.
# to be quite frank, im not sure if I understand how their target update
# freq would work. The difference in steps sampled/trained is pretty
# much always going to be larger than self.config.num_sgd_iter *
# self.config.minibatch_buffer_size unless the number of steps collected
# is really small. The thing is that the default rollout fragment length
# is 50, so the minibatch buffer size * num_sgd_iter is going to be
# have to be 50 to even meet the threshold of having delayed target
# updates.
# we should instead have the target / kl threshold update be based off
# of the train_batch_size * some target update frequency * num_sgd_iter.
# cur_ts = self._counters[
# NUM_ENV_STEPS_TRAINED
# if self.config.count_steps_by == "env_steps"
# else NUM_AGENT_STEPS_TRAINED
# ]
# target_update_steps_freq = (
# self.config.num_sgd_iter
# * self.config.train_batch_size
# * self.config.target_update_frequency
# )
cur_ts = self._counters[
NUM_AGENT_STEPS_SAMPLED
if self.config.count_steps_by == "agent_steps"
else NUM_ENV_STEPS_SAMPLED
]
target_update_steps_freq = 1
Contributor: hardcoded number?

Member Author: fixed

if cur_ts - last_update > target_update_steps_freq:
self._counters[NUM_TARGET_UPDATES] += 1
self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
self.learner_group.additional_update()

else:
cur_ts = self._counters[
@@ -374,6 +379,8 @@ def get_default_policy_class(
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2

return EagerTFPolicyV2
# from ray.rllib.algorithms.appo.tf.appo_tf_policy_rlm import APPOTfPolicyWithRLModule
# return APPOTfPolicyWithRLModule
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy

return APPOTF2Policy
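The diff's inline comments argue that the delayed target-network update should be keyed off `train_batch_size * num_sgd_iter * target_update_frequency` in trained steps, rather than sampled steps. That threshold check can be sketched as follows (a minimal standalone sketch for illustration, not RLlib's actual implementation; the function name and simplified parameters are assumptions):

```python
# Minimal sketch of the delayed target-network update check discussed in the
# diff comments above. NOT RLlib's actual code; names are simplified.

def should_update_target(
    cur_ts: int,
    last_update_ts: int,
    num_sgd_iter: int,
    train_batch_size: int,
    target_update_frequency: int,
) -> bool:
    """Return True once enough trained steps have elapsed since the last
    target-network update."""
    target_update_steps_freq = (
        num_sgd_iter * train_batch_size * target_update_frequency
    )
    return cur_ts - last_update_ts > target_update_steps_freq

# Example: with num_sgd_iter=6, train_batch_size=500, and frequency=1, the
# target networks update only after more than 3000 trained steps elapse.
print(should_update_target(3500, 0, 6, 500, 1))  # True
```

This is the scheme the comment block proposes; the committed code in this intermediate commit instead falls back to sampled-step counters with a hardcoded `target_update_steps_freq = 1`, which the review comment above flags.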
7 changes: 4 additions & 3 deletions rllib/algorithms/impala/impala.py
@@ -686,9 +686,9 @@ def training_step(self) -> ResultDict:
timeout_seconds=self.config.worker_health_probe_timeout_s,
mark_healthy=True,
)
if not train_results:
# adding this allows results to be properly logged by ray tune.
time.sleep(1e-1)
# if not train_results:
# # adding this allows results to be properly logged by ray tune.
# time.sleep(1e-1)
return train_results

@classmethod
@@ -865,6 +865,7 @@ def learn_on_processed_samples(self) -> ResultDict:
reduce_fn=_reduce_impala_results,
block=blocking,
num_iters=self.config.num_sgd_iter,
# minibatch_size=(2 * self.config.rollout_fragment_length)
)
else:
lg_results = None
18 changes: 14 additions & 4 deletions rllib/tuned_examples/appo/cartpole-appo.yaml
@@ -7,19 +7,29 @@ cartpole-appo:
config:
# Works for both torch and tf.
framework: tf2
num_workers: 4
num_workers:
grid_search:
- 3
num_gpus: 0
observation_filter: MeanStdFilter
num_sgd_iter: 6
vf_loss_coeff: 0.01
vtrace: True
grad_clip: 0

num_learner_workers: 1
model:
fcnet_hiddens: [32]
fcnet_activation: linear
vf_share_layers: true
enable_connectors: True
_enable_learner_api: True
_enable_rl_module_api: True
eager_tracing: False
# lr: 0.001
eager_tracing: True
lr: 0.001
seed:
grid_search:
- 1
- 2
- 3
- 4
- 5
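The `grid_search` entries added to the YAML above make Tune launch one trial per combination of values; with one `num_workers` value and five seeds, that is five trials. The expansion behaves like a Cartesian product (a rough illustration of the trial count only, not Ray Tune's internals; the variable names are assumptions):

```python
from itertools import product

# Hypothetical sketch: how grid_search values in the tuned-example YAML
# expand into trials (Cartesian product over all grid-searched keys).
num_workers_values = [3]           # grid_search over num_workers
seed_values = [1, 2, 3, 4, 5]      # grid_search over seed

trials = list(product(num_workers_values, seed_values))
print(len(trials))  # 5 trials: one per (num_workers, seed) combination
```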