[build_base][RLlib] APPO TF with RLModule and Learner API #33310

Merged: 51 commits, Mar 26, 2023
Changes from 1 commit

Commits (51)
517e0d6 Temp (avnishn, Mar 13, 2023)
293ac57 Temp (avnishn, Mar 13, 2023)
996d6af Temp (avnishn, Mar 14, 2023)
d83f4de Temp (avnishn, Mar 14, 2023)
3f4e27c Temp (avnishn, Mar 14, 2023)
3495576 Merge branch 'master' of https://github.com/ray-project/ray into appo_tf (avnishn, Mar 14, 2023)
606f093 Make all TfModels tf.keras.Models (ArturNiederfahrenhorst, Mar 15, 2023)
134593e Merge branch 'fixtensorflowacmodels' into appo_tf (avnishn, Mar 15, 2023)
8c5ae9b Running appo (avnishn, Mar 15, 2023)
8a2a0f5 Lint, small updates (avnishn, Mar 16, 2023)
efc74fe Move adding params to learner hps to validate in order to be compatib… (avnishn, Mar 16, 2023)
13548c1 Move adding params to learner hps to validate in order to be compatib… (avnishn, Mar 16, 2023)
cd1270c Move learner_hp assignment from builder functions to validate (avnishn, Mar 16, 2023)
1f00e42 Merge branch 'move_learner_hp_assignment' into appo_tf (avnishn, Mar 16, 2023)
2ec8d08 Temp (avnishn, Mar 17, 2023)
62753d1 Temp (avnishn, Mar 17, 2023)
6c91172 Clip is ratio (avnishn, Mar 17, 2023)
61be19b Merge branch 'master' of https://github.com/ray-project/ray into appo_tf (avnishn, Mar 17, 2023)
f0ea920 Wrote appo tf policy rlm which has working loss but isn't seemingly updating? (avnishn, Mar 19, 2023)
47d4b5a Merge branch 'master' of https://github.com/ray-project/ray into appo_tf (avnishn, Mar 19, 2023)
fb69db2 Add option for minibatching in impala/appo with the learner group (avnishn, Mar 20, 2023)
bb0daeb Merge branch 'master' of https://github.com/ray-project/ray into appo_tf (avnishn, Mar 20, 2023)
68cd9df Store most recent result for results reporting (avnishn, Mar 21, 2023)
5008690 dmc wrapper types (avnishn, Mar 21, 2023)
064adee Merge branch 'appo_tf' of https://github.com/avnishn/ray into appo_tf (avnishn, Mar 21, 2023)
b015015 Add back in UpdateTargetAndKL (avnishn, Mar 21, 2023)
7842087 Merge branch 'appo_tf' of https://github.com/avnishn/ray; branch 'mas… (avnishn, Mar 21, 2023)
2b9fec4 Fix broken tests (avnishn, Mar 21, 2023)
3fb3615 More tf related fixes (avnishn, Mar 21, 2023)
31adb97 More tf related fixes (avnishn, Mar 21, 2023)
25fdcac Fix impala test (avnishn, Mar 21, 2023)
c89e151 Fixing remaining broken tests (avnishn, Mar 21, 2023)
2f0cea9 More tf related fixes (avnishn, Mar 22, 2023)
d61d198 More tf fixes with try catch (avnishn, Mar 22, 2023)
e83b0d4 Addressing comments (avnishn, Mar 22, 2023)
2a5acbe Address comments (avnishn, Mar 23, 2023)
44bf47a Add rl module with target networks mixin interface (avnishn, Mar 23, 2023)
3f113e0 Temp (avnishn, Mar 24, 2023)
5746a8f Address comments (avnishn, Mar 24, 2023)
6d28903 Address comments (avnishn, Mar 24, 2023)
2b3bfb4 Address comments (avnishn, Mar 24, 2023)
288620d Merge branch 'master' of https://github.com/ray-project/ray into appo_tf (avnishn, Mar 24, 2023)
24a1d6e Fix broken import (avnishn, Mar 24, 2023)
9594a6c Lint (avnishn, Mar 24, 2023)
b8e4ec2 Merge branch 'master' of https://github.com/ray-project/ray into appo_tf (avnishn, Mar 24, 2023)
3837d36 Touching a file (avnishn, Mar 24, 2023)
2d6ac06 triggering the tests (kouroshHakha, Mar 25, 2023)
5fabfc6 Merge branch 'master' of https://github.com/ray-project/ray into appo_tf (avnishn, Mar 26, 2023)
117d9dd Merge branch 'appo_tf' of https://github.com/avnishn/ray into appo_tf (avnishn, Mar 26, 2023)
68a19ab Merge branch 'master' into appo_tf (kouroshHakha, Mar 26, 2023)
9e0d54b Merge branch 'appo_tf' of github.com:avnishn/ray into appo_tf (kouroshHakha, Mar 26, 2023)
Wrote appo tf policy rlm which has working loss but isn't seemingly updating?

Signed-off-by: Avnish <[email protected]>
avnishn committed Mar 19, 2023
commit f0ea9206e22693032e4547ee9871710710b56045
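The message above flags that the loss computes but the policy does not appear to be updating. A generic way to sanity-check that (a sketch, not code from this PR) is to snapshot the module weights before and after a `learner_group.update(...)` call and diff them:

```python
import numpy as np

def weights_changed(weights_before, weights_after, atol=1e-8):
    """Return True if any parameter tensor moved by more than atol."""
    return any(
        not np.allclose(before, after, atol=atol)
        for before, after in zip(weights_before, weights_after)
    )
```

Here `weights_before` and `weights_after` are assumed to be parallel lists of NumPy arrays captured around a single update step.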
7 changes: 7 additions & 0 deletions rllib/BUILD
@@ -925,6 +925,13 @@ py_test(
srcs = ["algorithms/appo/tests/test_appo_off_policyness.py"]
)

py_test(
name = "test_appo_learner",
tags = ["team:rllib", "algorithms_dir"],
size = "medium",
srcs = ["algorithms/appo/tests/tf/test_appo_learner.py"]
)

# ARS
py_test(
name = "test_ars",
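Assuming Ray's standard Bazel setup, this registers `test_appo_learner` as a runnable target. Because the test file also defines a pytest `__main__` block (see its diff below), it can presumably be invoked directly as well; a sketch, with the path taken from the `srcs` entry above:

```python
# Sketch: run the new learner test via pytest from a Ray source checkout.
# The path comes from the BUILD entry above; everything else is plain pytest.
import sys

import pytest

sys.exit(pytest.main(["-v", "rllib/algorithms/appo/tests/tf/test_appo_learner.py"]))
```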
6 changes: 3 additions & 3 deletions rllib/algorithms/appo/appo.py
@@ -24,8 +24,6 @@
NUM_AGENT_STEPS_SAMPLED,
NUM_ENV_STEPS_SAMPLED,
NUM_TARGET_UPDATES,
-    NUM_AGENT_STEPS_TRAINED,
-    NUM_ENV_STEPS_TRAINED,
)
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
from ray.rllib.utils.typing import (
@@ -378,7 +376,9 @@ def get_default_policy_class(
from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2

return EagerTFPolicyV2
-        # from ray.rllib.algorithms.appo.tf.appo_tf_policy_rlm import APPOTfPolicyWithRLModule
+        # TODO(avnishn): This policy class doesn't work just yet
+        # from ray.rllib.algorithms.appo.tf.appo_tf_policy_rlm import(
+        # ) APPOTfPolicyWithRLModule
# return APPOTfPolicyWithRLModule
from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy

Empty file.
109 changes: 109 additions & 0 deletions rllib/algorithms/appo/tests/tf/test_appo_learner.py
@@ -0,0 +1,109 @@
import unittest
import numpy as np

import ray
import ray.rllib.algorithms.appo as appo
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.metrics import ALL_MODULES
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check, framework_iterator


tf1, tf, _ = try_import_tf()

tf1.enable_eager_execution()

frag_length = 32

FAKE_BATCH = {
SampleBatch.OBS: np.random.uniform(low=0, high=1, size=(frag_length, 4)).astype(
np.float32
),
SampleBatch.ACTIONS: np.random.choice(2, frag_length).astype(np.float32),
SampleBatch.REWARDS: np.random.uniform(low=-1, high=1, size=(frag_length,)).astype(
np.float32
),
SampleBatch.TERMINATEDS: np.array(
[False for _ in range(frag_length - 1)] + [True]
).astype(np.float32),
SampleBatch.VF_PREDS: np.array(
list(reversed(range(frag_length))), dtype=np.float32
),
SampleBatch.ACTION_LOGP: np.log(
np.random.uniform(low=0, high=1, size=(frag_length,))
).astype(np.float32),
}


class TestImpalaTfLearner(unittest.TestCase):
@classmethod
def setUpClass(cls):
ray.init()

@classmethod
def tearDownClass(cls):
ray.shutdown()

def test_appo_loss(self):
"""Test that appo_policy_rlm loss matches the appo learner loss."""
config = (
appo.APPOConfig()
.environment("CartPole-v1")
.rollouts(
num_rollout_workers=0,
rollout_fragment_length=frag_length,
)
.resources(num_gpus=0)
.training(
gamma=0.99,
model=dict(
fcnet_hiddens=[10, 10],
fcnet_activation="linear",
vf_share_layers=False,
),
)
.rl_module(
_enable_rl_module_api=True,
)
)

for fw in framework_iterator(config, ("tf2")):
trainer = config.build()
policy = trainer.get_policy()

if fw == "tf2":
train_batch = tf.nest.map_structure(
lambda x: tf.convert_to_tensor(x), FAKE_BATCH
)
train_batch = SampleBatch(FAKE_BATCH)
policy_loss = policy.loss(policy.model, policy.dist_class, train_batch)

algo_config = config.copy(copy_frozen=False)
algo_config.training(_enable_learner_api=True)
algo_config.validate()
algo_config.freeze()

learner_group_config = algo_config.get_learner_group_config(
SingleAgentRLModuleSpec(
module_class=algo_config.rl_module_spec.module_class,
observation_space=policy.observation_space,
action_space=policy.action_space,
model_config_dict=policy.config["model"],
catalog_class=algo_config.rl_module_spec.catalog_class,
)
)
learner_group_config.num_learner_workers = 0
learner_group = learner_group_config.build()
learner_group.set_weights(trainer.get_weights())
results = learner_group.update(train_batch.as_multi_agent())
learner_group_loss = results[ALL_MODULES]["total_loss"]

check(learner_group_loss, policy_loss)


if __name__ == "__main__":
import pytest
import sys

sys.exit(pytest.main(["-v", __file__]))
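For orientation, here is a minimal sketch, distilled from the test above and not itself part of the diff, of switching APPO onto the new stack. The `_enable_rl_module_api` and `_enable_learner_api` flags are taken verbatim from the test; the `framework("tf2")` call is standard RLlib configuration:

```python
import ray
import ray.rllib.algorithms.appo as appo

ray.init()

# Enable the new stack: RLModule for the model, Learner API for updates.
config = (
    appo.APPOConfig()
    .environment("CartPole-v1")
    .framework("tf2")
    .rollouts(num_rollout_workers=0)
    .rl_module(_enable_rl_module_api=True)
    .training(_enable_learner_api=True)
)

algo = config.build()
print(algo.train())
ray.shutdown()
```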
10 changes: 7 additions & 3 deletions rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -147,9 +147,13 @@ def compute_loss_per_module(
)

# The policy gradients loss.
-        is_ratio = tf.clip_by_value(tf.math.exp(
-            behaviour_actions_logp_time_major - old_actions_logp_time_major
-        ), 0.0, 2.0)
+        is_ratio = tf.clip_by_value(
+            tf.math.exp(
+                behaviour_actions_logp_time_major - old_actions_logp_time_major
+            ),
+            0.0,
+            2.0,
+        )
logp_ratio = is_ratio * tf.math.exp(
target_actions_logp_time_major - behaviour_actions_logp_time_major
)
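Functionally, the hunk is a pure reformatting of the clipped importance-sampling ratio between the behaviour policy and the old target policy, which is then folded into the current/behaviour log-prob ratio used by the surrogate loss. A standalone sketch with hypothetical values (names shortened; the real tensors are time-major):

```python
import tensorflow as tf

# Hypothetical per-step action log-probs standing in for the
# *_logp_time_major tensors in the learner code.
behaviour_logp = tf.constant([-0.5, -1.2])  # behaviour (sampling) policy
old_logp = tf.constant([-0.7, -0.4])        # old target policy
target_logp = tf.constant([-0.6, -1.0])     # current policy being trained

# Importance-sampling ratio pi_behaviour / pi_old, clipped to [0, 2]
# to bound the off-policy correction, as in the hunk above.
is_ratio = tf.clip_by_value(tf.math.exp(behaviour_logp - old_logp), 0.0, 2.0)

# Folded into the current/behaviour ratio for the surrogate objective.
logp_ratio = is_ratio * tf.math.exp(target_logp - behaviour_logp)
```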