diff --git a/rllib/BUILD b/rllib/BUILD index afd1202f8a03c..4f6ce678c38aa 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -927,6 +927,13 @@ py_test( srcs = ["algorithms/appo/tests/test_appo_off_policyness.py"] ) +py_test( + name = "test_appo_learner", + tags = ["team:rllib", "algorithms_dir"], + size = "medium", + srcs = ["algorithms/appo/tests/tf/test_appo_learner.py"] +) + # ARS py_test( name = "test_ars", diff --git a/rllib/algorithms/appo/appo.py b/rllib/algorithms/appo/appo.py index eba5fdbb01b58..9202a4395f694 100644 --- a/rllib/algorithms/appo/appo.py +++ b/rllib/algorithms/appo/appo.py @@ -14,8 +14,10 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.impala.impala import Impala, ImpalaConfig +from ray.rllib.algorithms.appo.tf.appo_tf_learner import AppoHPs, LEARNER_RESULTS_KL_KEY from ray.rllib.algorithms.ppo.ppo import UpdateKL from ray.rllib.execution.common import _get_shared_metrics, STEPS_SAMPLED_COUNTER +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override from ray.rllib.utils.metrics import ( @@ -23,8 +25,10 @@ NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED, NUM_TARGET_UPDATES, + NUM_ENV_STEPS_TRAINED, + NUM_AGENT_STEPS_TRAINED, ) -from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.metrics import ALL_MODULES, LEARNER_STATS_KEY from ray.rllib.utils.typing import ( ResultDict, ) @@ -74,6 +78,7 @@ def __init__(self, algo_class=None): # __sphinx_doc_begin__ # APPO specific settings: + self._learner_hps = AppoHPs() self.vtrace = True self.use_critic = True self.use_gae = True @@ -92,6 +97,7 @@ def __init__(self, algo_class=None): self.num_multi_gpu_tower_stacks = 1 self.minibatch_buffer_size = 1 self.num_sgd_iter = 1 + self.target_update_frequency = 1 self.replay_proportion = 0.0 self.replay_buffer_num_slots = 100 self.learner_queue_size = 16 @@ -108,6 +114,8 @@ def __init__(self, algo_class=None): self.vf_loss_coeff = 0.5 self.entropy_coeff = 0.01 self.entropy_coeff_schedule = None + self.tau = 1.0 + # __sphinx_doc_end__ # fmt: on @@ -123,6 +131,8 @@ def training( use_kl_loss: Optional[bool] = NotProvided, kl_coeff: Optional[float] = NotProvided, kl_target: Optional[float] = NotProvided, + tau: Optional[float] = NotProvided, + target_update_frequency: Optional[int] = NotProvided, **kwargs, ) -> "APPOConfig": """Sets the training related configuration. @@ -141,6 +151,19 @@ def training( kl_coeff: Coefficient for weighting the KL-loss term. kl_target: Target term for the KL-term to reach (via adjusting the `kl_coeff` automatically). + tau: The factor by which to update the target policy network towards + the current policy network. Can range between 0 and 1. + e.g. updated_param = tau * current_param + (1 - tau) * target_param + target_update_frequency: The frequency to update the target policy and + tune the kl loss coefficients that are used during training. After + setting this parameter, the algorithm waits for at least + `target_update_frequency * minibatch_size * num_sgd_iter` number of + samples to be trained on by the learner group before updating the target + networks and tuned the kl loss coefficients that are used during + training. + NOTE: this parameter is only applicable when using the learner api + (_enable_learner_api=True and _enable_rl_module_api=True). + Returns: This updated AlgorithmConfig object. 
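For context, a minimal configuration sketch showing how the new `tau` and `target_update_frequency` options from the docstring above plug into the learner-API stack. The option names come from this diff; the environment and the numeric values are purely illustrative, and per `get_default_learner_class()` below only framework="tf2" is supported on this path.

from ray.rllib.algorithms.appo import APPOConfig

config = (
    APPOConfig()
    .environment("CartPole-v1")
    .framework("tf2", eager_tracing=True)
    .rollouts(num_rollout_workers=2, rollout_fragment_length=50)
    .training(
        train_batch_size=500,
        num_sgd_iter=1,
        use_kl_loss=True,
        kl_coeff=1.0,
        kl_target=0.04,
        # Polyak factor for the target-network update; 1.0 means a hard copy.
        tau=1.0,
        # Target nets / KL coefficients are refreshed only after roughly
        # target_update_frequency * train_batch_size * num_sgd_iter trained
        # steps (see after_train_step() further down in this file).
        target_update_frequency=1,
        _enable_learner_api=True,
    )
    .rl_module(_enable_rl_module_api=True)
)

algo = config.build()
print(algo.train())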
@@ -158,15 +181,52 @@ def training( self.lambda_ = lambda_ if clip_param is not NotProvided: self.clip_param = clip_param + self._learner_hps.clip_param = clip_param if use_kl_loss is not NotProvided: self.use_kl_loss = use_kl_loss if kl_coeff is not NotProvided: self.kl_coeff = kl_coeff + self._learner_hps.kl_coeff = kl_coeff if kl_target is not NotProvided: self.kl_target = kl_target + self._learner_hps.kl_target = kl_target + if tau is not NotProvided: + self.tau = tau + self._learner_hps.tau = tau + if target_update_frequency is not NotProvided: + self.target_update_frequency = target_update_frequency return self + @override(AlgorithmConfig) + def get_default_learner_class(self): + if self.framework_str == "tf2": + from ray.rllib.algorithms.appo.tf.appo_tf_learner import APPOTfLearner + + return APPOTfLearner + else: + raise ValueError(f"The framework {self.framework_str} is not supported.") + + @override(AlgorithmConfig) + def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec: + if self.framework_str == "tf2": + from ray.rllib.algorithms.appo.appo_catalog import APPOCatalog + from ray.rllib.algorithms.appo.tf.appo_tf_rl_module import APPOTfRLModule + + return SingleAgentRLModuleSpec( + module_class=APPOTfRLModule, catalog_class=APPOCatalog + ) + else: + raise ValueError(f"The framework {self.framework_str} is not supported.") + + @override(ImpalaConfig) + def validate(self) -> None: + super().validate() + self._learner_hps.tau = self.tau + self._learner_hps.kl_target = self.kl_target + self._learner_hps.kl_coeff = self.kl_coeff + self._learner_hps.clip_param = self.clip_param + class UpdateTargetAndKL: def __init__(self, workers, config): @@ -199,15 +259,23 @@ def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) # After init: Initialize target net. - self.workers.local_worker().foreach_policy_to_train( - lambda p, _: p.update_target() - ) + + # TODO(avnishn): + # does this need to happen in __init__? I think we can move it to setup() + if not self.config._enable_rl_module_api: + self.workers.local_worker().foreach_policy_to_train( + lambda p, _: p.update_target() + ) @override(Impala) def setup(self, config: AlgorithmConfig): super().setup(config) - self.update_kl = UpdateKL(self.workers) + # TODO(avnishn): + # this attribute isn't used anywhere else in the code. I think we can safely + # delete it. + if not self.config._enable_rl_module_api: + self.update_kl = UpdateKL(self.workers) def after_train_step(self, train_results: ResultDict) -> None: """Updates the target network and the KL coefficient for the APPO-loss. @@ -222,45 +290,84 @@ def after_train_step(self, train_results: ResultDict) -> None: train_results: The results dict collected during the most recent training step. """ - cur_ts = self._counters[ - NUM_AGENT_STEPS_SAMPLED - if self.config.count_steps_by == "agent_steps" - else NUM_ENV_STEPS_SAMPLED - ] + last_update = self._counters[LAST_TARGET_UPDATE_TS] - target_update_freq = ( - self.config.num_sgd_iter * self.config.minibatch_buffer_size - ) - if cur_ts - last_update > target_update_freq: - self._counters[NUM_TARGET_UPDATES] += 1 - self._counters[LAST_TARGET_UPDATE_TS] = cur_ts - # Update our target network. - self.workers.local_worker().foreach_policy_to_train( - lambda p, _: p.update_target() + if self.config._enable_learner_api and train_results: + # using steps trained here instead of sampled ... I'm not sure why the + # other implemenetation uses sampled. 
+ # to be quite frank, im not sure if I understand how their target update + # freq would work. The difference in steps sampled/trained is pretty + # much always going to be larger than self.config.num_sgd_iter * + # self.config.minibatch_buffer_size unless the number of steps collected + # is really small. The thing is that the default rollout fragment length + # is 50, so the minibatch buffer size * num_sgd_iter is going to be + # have to be 50 to even meet the threshold of having delayed target + # updates. + # we should instead have the target / kl threshold update be based off + # of the train_batch_size * some target update frequency * num_sgd_iter. + cur_ts = self._counters[ + NUM_ENV_STEPS_TRAINED + if self.config.count_steps_by == "env_steps" + else NUM_AGENT_STEPS_TRAINED + ] + target_update_steps_freq = ( + self.config.train_batch_size + * self.config.num_sgd_iter + * self.config.target_update_frequency ) + if (cur_ts - last_update) >= target_update_steps_freq: + kls_to_update = {} + for module_id, module_results in train_results.items(): + if module_id != ALL_MODULES: + kls_to_update[module_id] = module_results[LEARNER_STATS_KEY][ + LEARNER_RESULTS_KL_KEY + ] + self._counters[NUM_TARGET_UPDATES] += 1 + self._counters[LAST_TARGET_UPDATE_TS] = cur_ts + self.learner_group.additional_update(sampled_kls=kls_to_update) - # Also update the KL-coefficient for the APPO loss, if necessary. - if self.config.use_kl_loss: - - def update(pi, pi_id): - assert LEARNER_STATS_KEY not in train_results, ( - "{} should be nested under policy id key".format( - LEARNER_STATS_KEY - ), - train_results, - ) - if pi_id in train_results: - kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl") - assert kl is not None, (train_results, pi_id) - # Make the actual `Policy.update_kl()` call. - pi.update_kl(kl) - else: - logger.warning("No data for {}, not updating kl".format(pi_id)) - - # Update KL on all trainable policies within the local (trainer) - # Worker. - self.workers.local_worker().foreach_policy_to_train(update) + else: + cur_ts = self._counters[ + NUM_AGENT_STEPS_SAMPLED + if self.config.count_steps_by == "agent_steps" + else NUM_ENV_STEPS_SAMPLED + ] + target_update_freq = ( + self.config.num_sgd_iter * self.config.minibatch_buffer_size + ) + if cur_ts - last_update > target_update_freq: + self._counters[NUM_TARGET_UPDATES] += 1 + self._counters[LAST_TARGET_UPDATE_TS] = cur_ts + + # Update our target network. + self.workers.local_worker().foreach_policy_to_train( + lambda p, _: p.update_target() + ) + + # Also update the KL-coefficient for the APPO loss, if necessary. + if self.config.use_kl_loss: + + def update(pi, pi_id): + assert LEARNER_STATS_KEY not in train_results, ( + "{} should be nested under policy id key".format( + LEARNER_STATS_KEY + ), + train_results, + ) + if pi_id in train_results: + kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl") + assert kl is not None, (train_results, pi_id) + # Make the actual `Policy.update_kl()` call. + pi.update_kl(kl) + else: + logger.warning( + "No data for {}, not updating kl".format(pi_id) + ) + + # Update KL on all trainable policies within the local (trainer) + # Worker. 
+ self.workers.local_worker().foreach_policy_to_train(update) @override(Impala) def training_step(self) -> ResultDict: @@ -282,14 +389,33 @@ def get_default_policy_class( cls, config: AlgorithmConfig ) -> Optional[Type[Policy]]: if config["framework"] == "torch": - from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy - - return APPOTorchPolicy + if config._enable_rl_module_api: + raise ValueError( + "APPO with the torch backend is not yet supported by " + " the RLModule and Learner API." + ) + else: + from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy + + return APPOTorchPolicy elif config["framework"] == "tf": + if config._enable_rl_module_api: + raise ValueError( + "RLlib's RLModule and Learner API is not supported for" + " tf1. Use " + "framework='tf2' instead." + ) from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy return APPOTF1Policy else: + if config._enable_rl_module_api: + # TODO(avnishn): This policy class doesn't work just yet + from ray.rllib.algorithms.appo.tf.appo_tf_policy_rlm import ( + APPOTfPolicyWithRLModule, + ) + + return APPOTfPolicyWithRLModule from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy return APPOTF2Policy diff --git a/rllib/algorithms/appo/appo_catalog.py b/rllib/algorithms/appo/appo_catalog.py new file mode 100644 index 0000000000000..b675cba4b9cdb --- /dev/null +++ b/rllib/algorithms/appo/appo_catalog.py @@ -0,0 +1,24 @@ +from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog + + +class APPOCatalog(PPOCatalog): + """The Catalog class used to build models for APPO. + + PPOCatalog provides the following models: + - ActorCriticEncoder: The encoder used to encode the observations. + - Pi Head: The head used to compute the policy logits. + - Value Function Head: The head used to compute the value function. + + The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs + for the policy and value function. See implementations of PPORLModuleBase for + more details. + + Any custom ActorCriticEncoder can be built by overriding the + build_actor_critic_encoder() method. Alternatively, the ActorCriticEncoderConfig + at PPOCatalog.actor_critic_encoder_config can be overridden to build a custom + ActorCriticEncoder during RLModule runtime. + + Any custom head can be built by overriding the build_pi_head() and build_vf_head() + methods. Alternatively, the PiHeadConfig and VfHeadConfig can be overridden to + build custom heads during RLModule runtime. 
+ """ diff --git a/rllib/algorithms/appo/tests/tf/__init__.py b/rllib/algorithms/appo/tests/tf/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/rllib/algorithms/appo/tests/tf/test_appo_learner.py b/rllib/algorithms/appo/tests/tf/test_appo_learner.py new file mode 100644 index 0000000000000..0a2822cde1013 --- /dev/null +++ b/rllib/algorithms/appo/tests/tf/test_appo_learner.py @@ -0,0 +1,109 @@ +import unittest +import numpy as np + +import ray +import ray.rllib.algorithms.appo as appo +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.metrics import ALL_MODULES +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.test_utils import check, framework_iterator + + +tf1, tf, _ = try_import_tf() + +tf1.enable_eager_execution() + +frag_length = 50 + +FAKE_BATCH = { + SampleBatch.OBS: np.random.uniform(low=0, high=1, size=(frag_length, 4)).astype( + np.float32 + ), + SampleBatch.ACTIONS: np.random.choice(2, frag_length).astype(np.float32), + SampleBatch.REWARDS: np.random.uniform(low=-1, high=1, size=(frag_length,)).astype( + np.float32 + ), + SampleBatch.TERMINATEDS: np.array( + [False for _ in range(frag_length - 1)] + [True] + ).astype(np.float32), + SampleBatch.VF_PREDS: np.array( + list(reversed(range(frag_length))), dtype=np.float32 + ), + SampleBatch.ACTION_LOGP: np.log( + np.random.uniform(low=0, high=1, size=(frag_length,)) + ).astype(np.float32), +} + + +class TestImpalaTfLearner(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init() + + @classmethod + def tearDownClass(cls): + ray.shutdown() + + def test_appo_loss(self): + """Test that appo_policy_rlm loss matches the appo learner loss.""" + config = ( + appo.APPOConfig() + .environment("CartPole-v1") + .rollouts( + num_rollout_workers=0, + rollout_fragment_length=frag_length, + ) + .resources(num_gpus=0) + .training( + gamma=0.99, + model=dict( + fcnet_hiddens=[10, 10], + fcnet_activation="linear", + vf_share_layers=False, + ), + ) + .rl_module( + _enable_rl_module_api=True, + ) + ) + + for fw in framework_iterator(config, ("tf2")): + trainer = config.build() + policy = trainer.get_policy() + + if fw == "tf2": + train_batch = tf.nest.map_structure( + lambda x: tf.convert_to_tensor(x), FAKE_BATCH + ) + train_batch = SampleBatch(FAKE_BATCH) + policy_loss = policy.loss(policy.model, policy.dist_class, train_batch) + + algo_config = config.copy(copy_frozen=False) + algo_config.training(_enable_learner_api=True) + algo_config.validate() + algo_config.freeze() + + learner_group_config = algo_config.get_learner_group_config( + SingleAgentRLModuleSpec( + module_class=algo_config.rl_module_spec.module_class, + observation_space=policy.observation_space, + action_space=policy.action_space, + model_config_dict=policy.config["model"], + catalog_class=algo_config.rl_module_spec.catalog_class, + ) + ) + learner_group_config.num_learner_workers = 0 + learner_group = learner_group_config.build() + learner_group.set_weights(trainer.get_weights()) + results = learner_group.update(train_batch.as_multi_agent()) + learner_group_loss = results[ALL_MODULES]["total_loss"] + + check(learner_group_loss, policy_loss) + + +if __name__ == "__main__": + import pytest + import sys + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/algorithms/appo/tf/__init__.py b/rllib/algorithms/appo/tf/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git 
a/rllib/algorithms/appo/tf/appo_tf_learner.py b/rllib/algorithms/appo/tf/appo_tf_learner.py new file mode 100644 index 0000000000000..cadda7c4e6956 --- /dev/null +++ b/rllib/algorithms/appo/tf/appo_tf_learner.py @@ -0,0 +1,257 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Dict, Mapping + +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.algorithms.appo.tf.appo_tf_rl_module import OLD_ACTION_DIST_KEY +from ray.rllib.algorithms.impala.tf.vtrace_tf_v2 import make_time_major, vtrace_tf2 +from ray.rllib.algorithms.impala.impala_base_learner import ImpalaHPs +from ray.rllib.algorithms.impala.tf.impala_tf_learner import ImpalaTfLearner +from ray.rllib.core.learner.learner import POLICY_LOSS_KEY, VF_LOSS_KEY, ENTROPY_KEY +from ray.rllib.core.rl_module.marl_module import ModuleID +from ray.rllib.core.learner.tf.tf_learner import TfLearner +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import TensorType + +_, tf, _ = try_import_tf() + + +LEARNER_RESULTS_KL_KEY = "mean_kl_loss" + + +@dataclass +class AppoHPs(ImpalaHPs): + """Hyper-parameters for APPO. + + Attributes: + rollout_frag_or_episode_len: The length of a rollout fragment or episode. + Used when making SampleBatches time major for computing loss. + recurrent_seq_len: The length of a recurrent sequence. Used when making + SampleBatches time major for computing loss. + discount_factor: The discount factor to use for computing returns. + vtrace_clip_rho_threshold: The rho threshold to use for clipping the + importance weights. + vtrace_clip_pg_rho_threshold: The rho threshold to use for clipping the + importance weights when computing the policy_gradient loss. + vtrace_drop_last_ts: Whether to drop the last timestep when computing the loss. + This is useful for stabilizing the loss. + NOTE: This shouldn't be True when training on environments where the rewards + come at the end of the episode. + vf_loss_coeff: The amount to weight the value function loss by when computing + the total loss. + entropy_coeff: The amount to weight the average entropy of the actions in the + SampleBatch towards the total_loss for module updates. The higher this + coefficient, the more that the policy network will be encouraged to output + distributions with higher entropy/std deviation, which will encourage + greater exploration. + kl_target: The target kl divergence loss coefficient to use for the KL loss. + kl_coeff: The coefficient to weight the KL divergence between the old policy + and the target policy towards the total loss for module updates. + tau: The factor by which to update the target policy network towards + the current policy network. Can range between 0 and 1. + e.g. updated_param = tau * current_param + (1 - tau) * target_param + + """ + + kl_target: float = 0.01 + kl_coeff: float = 0.1 + clip_param = 0.2 + tau = 1.0 + + +class APPOTfLearner(ImpalaTfLearner): + """Implements APPO loss / update logic on top of ImpalaTfLearner. 
+ + This class implements the APPO loss under `_compute_loss_per_module()` and + implements the target network and KL coefficient updates under + `additional_updates_per_module()` + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.kl_target = self._hps.kl_target + self.clip_param = self._hps.clip_param + self.kl_coeffs = defaultdict(lambda: self._hps.kl_coeff) + self.kl_coeff = self._hps.kl_coeff + self.tau = self._hps.tau + + @override(TfLearner) + def compute_loss_per_module( + self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType] + ) -> TensorType: + values = fwd_out[SampleBatch.VF_PREDS] + target_policy_dist = fwd_out[SampleBatch.ACTION_DIST] + old_target_policy_dist = fwd_out[OLD_ACTION_DIST_KEY] + + old_target_policy_actions_logp = old_target_policy_dist.logp( + batch[SampleBatch.ACTIONS] + ) + behaviour_actions_logp = batch[SampleBatch.ACTION_LOGP] + target_actions_logp = target_policy_dist.logp(batch[SampleBatch.ACTIONS]) + + behaviour_actions_logp_time_major = make_time_major( + behaviour_actions_logp, + trajectory_len=self.rollout_frag_or_episode_len, + recurrent_seq_len=self.recurrent_seq_len, + drop_last=self.vtrace_drop_last_ts, + ) + target_actions_logp_time_major = make_time_major( + target_actions_logp, + trajectory_len=self.rollout_frag_or_episode_len, + recurrent_seq_len=self.recurrent_seq_len, + drop_last=self.vtrace_drop_last_ts, + ) + old_actions_logp_time_major = make_time_major( + old_target_policy_actions_logp, + trajectory_len=self.rollout_frag_or_episode_len, + recurrent_seq_len=self.recurrent_seq_len, + drop_last=self.vtrace_drop_last_ts, + ) + values_time_major = make_time_major( + values, + trajectory_len=self.rollout_frag_or_episode_len, + recurrent_seq_len=self.recurrent_seq_len, + drop_last=self.vtrace_drop_last_ts, + ) + bootstrap_value = values_time_major[-1] + rewards_time_major = make_time_major( + batch[SampleBatch.REWARDS], + trajectory_len=self.rollout_frag_or_episode_len, + recurrent_seq_len=self.recurrent_seq_len, + drop_last=self.vtrace_drop_last_ts, + ) + + # the discount factor that is used should be gamma except for timesteps where + # the episode is terminated. In that case, the discount factor should be 0. + discounts_time_major = ( + 1.0 + - tf.cast( + make_time_major( + batch[SampleBatch.TERMINATEDS], + trajectory_len=self.rollout_frag_or_episode_len, + recurrent_seq_len=self.recurrent_seq_len, + drop_last=self.vtrace_drop_last_ts, + ), + dtype=tf.float32, + ) + ) * self.discount_factor + vtrace_adjusted_target_values, pg_advantages = vtrace_tf2( + target_action_log_probs=old_actions_logp_time_major, + behaviour_action_log_probs=behaviour_actions_logp_time_major, + rewards=rewards_time_major, + values=values_time_major, + bootstrap_value=bootstrap_value, + clip_pg_rho_threshold=self.vtrace_clip_pg_rho_threshold, + clip_rho_threshold=self.vtrace_clip_rho_threshold, + discounts=discounts_time_major, + ) + + # The policy gradients loss. 
+ is_ratio = tf.clip_by_value( + tf.math.exp( + behaviour_actions_logp_time_major - old_actions_logp_time_major + ), + 0.0, + 2.0, + ) + logp_ratio = is_ratio * tf.math.exp( + target_actions_logp_time_major - behaviour_actions_logp_time_major + ) + + surrogate_loss = tf.math.minimum( + pg_advantages * logp_ratio, + ( + pg_advantages + * tf.clip_by_value(logp_ratio, 1 - self.clip_param, 1 + self.clip_param) + ), + ) + + action_kl = old_target_policy_dist.kl(target_policy_dist) + mean_kl_loss = tf.math.reduce_mean(action_kl) + mean_pi_loss = -tf.math.reduce_mean(surrogate_loss) + + # The baseline loss. + delta = values_time_major - vtrace_adjusted_target_values + mean_vf_loss = 0.5 * tf.math.reduce_mean(delta**2) + + # The entropy loss. + mean_entropy_loss = -tf.math.reduce_mean(target_actions_logp_time_major) + + # The summed weighted loss. + total_loss = ( + mean_pi_loss + + (mean_vf_loss * self.vf_loss_coeff) + + (mean_entropy_loss * self.entropy_coeff) + + (mean_kl_loss * self.kl_coeffs[module_id]) + ) + + return { + self.TOTAL_LOSS_KEY: total_loss, + POLICY_LOSS_KEY: mean_pi_loss, + VF_LOSS_KEY: mean_vf_loss, + ENTROPY_KEY: mean_entropy_loss, + LEARNER_RESULTS_KL_KEY: mean_kl_loss, + } + + @override(ImpalaTfLearner) + def remove_module(self, module_id: str): + super().remove_module(module_id) + self.kl_coeffs.pop(module_id) + + def _update_module_target_networks(self, module_id: ModuleID): + """Update the target policy of each module with the current policy. + + Do that update via polyak averaging. + + Args: + module_id: The module whose target networks need to be updated. + + """ + module = self.module[module_id] + + target_current_network_pairs = module.get_target_network_pairs() + for target_network, current_network in target_current_network_pairs: + for old_var, current_var in zip( + target_network.variables, current_network.variables + ): + updated_var = self.tau * current_var + (1.0 - self.tau) * old_var + old_var.assign(updated_var) + + def _update_module_kl_coeff( + self, module_id: ModuleID, sampled_kls: Dict[ModuleID, float] + ): + """Dynamically update the KL loss coefficients of each module with. + + The update is completed using the mean KL divergence between the action + distributions current policy and old policy of each module. That action + distribution is computed during the most recent update/call to `compute_loss`. + + Args: + module_id: The module whose KL loss coefficient to update. + sampled_kls: The KL divergence between the action distributions of + the current policy and old policy of each module. + + """ + if module_id in sampled_kls: + sampled_kl = sampled_kls[module_id] + # Update the current KL value based on the recently measured value. + # Increase. + if sampled_kl > 2.0 * self.kl_target: + self.kl_coeffs[module_id] *= 1.5 + # Decrease. + elif sampled_kl < 0.5 * self.kl_target: + self.kl_coeffs[module_id] *= 0.5 + + @override(ImpalaTfLearner) + def additional_update_per_module( + self, module_id: ModuleID, sampled_kls: Dict[ModuleID, float], **kwargs + ) -> Mapping[str, Any]: + """Update the target networks and KL loss coefficients of each module. 
+ + Args: + + """ + self._update_module_target_networks(module_id) + self._update_module_kl_coeff(module_id, sampled_kls) + return {} diff --git a/rllib/algorithms/appo/tf/appo_tf_policy_rlm.py b/rllib/algorithms/appo/tf/appo_tf_policy_rlm.py new file mode 100644 index 0000000000000..f01235834d852 --- /dev/null +++ b/rllib/algorithms/appo/tf/appo_tf_policy_rlm.py @@ -0,0 +1,227 @@ +import logging +from typing import Dict, List, Union + +from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import ( + EntropyCoeffSchedule, + LearningRateSchedule, + KLCoeffMixin, + GradStatsMixin, + TargetNetworkMixin, +) + +from ray.rllib.algorithms.impala.impala_tf_policy import ( + VTraceClipGradients, + VTraceOptimizer, +) + +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_utils import ( + explained_variance, +) + + +from ray.rllib.algorithms.impala.tf.vtrace_tf_v2 import make_time_major, vtrace_tf2 +from ray.rllib.utils.typing import TensorType + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + + +class APPOTfPolicyWithRLModule( + VTraceClipGradients, + VTraceOptimizer, + LearningRateSchedule, + KLCoeffMixin, + EntropyCoeffSchedule, + TargetNetworkMixin, + GradStatsMixin, + EagerTFPolicyV2, +): + def __init__(self, observation_space, action_space, config): + validate_config(config) + EagerTFPolicyV2.enable_eager_execution_if_necessary() + # Initialize MixIns before super().__init__ because base class will call + # self.loss, which requires these MixIns to be initialized. + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + # Although this is a no-op, we call __init__ here to make it clear + # that base.__init__ will use the make_model() call. + VTraceClipGradients.__init__(self) + VTraceOptimizer.__init__(self) + self.framework = "tf2" + KLCoeffMixin.__init__(self, config) + GradStatsMixin.__init__(self) + EagerTFPolicyV2.__init__(self, observation_space, action_space, config) + # construct the target model and make its weights the same as the model + self.target_model = self.make_rl_module() + self.target_model.set_weights(self.model.get_weights()) + + # Initiate TargetNetwork ops after loss initialization. 
+ self.maybe_initialize_optimizer_and_loss() + TargetNetworkMixin.__init__(self) + + @Deprecated(new="APPOTfLearner.compute_loss_per_module()", error=False) + @override(EagerTFPolicyV2) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class, + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + train_batch[SampleBatch.ACTIONS] + train_batch[SampleBatch.ACTION_LOGP] + train_batch[SampleBatch.REWARDS] + train_batch[SampleBatch.TERMINATEDS] + + seqs_len = train_batch.get(SampleBatch.SEQ_LENS) + rollout_frag_or_episode_len = ( + self.config["rollout_fragment_length"] if not seqs_len else None + ) + drop_last = self.config["vtrace_drop_last_ts"] + + target_policy_fwd_out = model.forward_train(train_batch) + values = target_policy_fwd_out[SampleBatch.VF_PREDS] + target_policy_dist = target_policy_fwd_out[SampleBatch.ACTION_DIST] + + old_target_policy_fwd_out = self.target_model.forward_train(train_batch) + old_target_policy_dist = old_target_policy_fwd_out[SampleBatch.ACTION_DIST] + + behaviour_actions_logp = train_batch[SampleBatch.ACTION_LOGP] + target_actions_logp = target_policy_dist.logp(train_batch[SampleBatch.ACTIONS]) + old_target_actions_logp = old_target_policy_dist.logp( + train_batch[SampleBatch.ACTIONS] + ) + behaviour_actions_logp_time_major = make_time_major( + behaviour_actions_logp, + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=seqs_len, + drop_last=drop_last, + ) + target_actions_logp_time_major = make_time_major( + target_actions_logp, + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=seqs_len, + drop_last=drop_last, + ) + old_target_actions_logp_time_major = make_time_major( + old_target_actions_logp, + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=seqs_len, + drop_last=drop_last, + ) + values_time_major = make_time_major( + values, + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=seqs_len, + drop_last=drop_last, + ) + bootstrap_value = values_time_major[-1] + rewards_time_major = make_time_major( + train_batch[SampleBatch.REWARDS], + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=seqs_len, + drop_last=drop_last, + ) + + # how to compute discouts? + # should they be pre computed? 
+ discounts_time_major = ( + 1.0 + - tf.cast( + make_time_major( + train_batch[SampleBatch.TERMINATEDS], + trajectory_len=rollout_frag_or_episode_len, + recurrent_seq_len=seqs_len, + drop_last=drop_last, + ), + dtype=tf.float32, + ) + ) * self.config["gamma"] + vtrace_adjusted_target_values, pg_advantages = vtrace_tf2( + target_action_log_probs=old_target_actions_logp_time_major, + behaviour_action_log_probs=behaviour_actions_logp_time_major, + rewards=rewards_time_major, + values=values_time_major, + bootstrap_value=bootstrap_value, + clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], + clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], + discounts=discounts_time_major, + ) + + is_ratio = tf.clip_by_value( + tf.math.exp( + behaviour_actions_logp_time_major - target_actions_logp_time_major + ), + 0.0, + 2.0, + ) + logp_ratio = is_ratio * tf.math.exp( + target_actions_logp_time_major - behaviour_actions_logp_time_major + ) + + clip_param = self.config["clip_param"] + surrogate_loss = tf.math.minimum( + pg_advantages * logp_ratio, + ( + pg_advantages + * tf.clip_by_value(logp_ratio, 1 - clip_param, 1 + clip_param) + ), + ) + action_kl = old_target_policy_dist.kl(target_policy_dist) + mean_kl_loss = tf.math.reduce_mean(action_kl) + mean_pi_loss = -tf.math.reduce_mean(surrogate_loss) + + # The baseline loss. + delta = values_time_major - vtrace_adjusted_target_values + mean_vf_loss = 0.5 * tf.math.reduce_mean(delta**2) + + # The entropy loss. + mean_entropy_loss = -tf.math.reduce_mean(target_actions_logp_time_major) + + # The summed weighted loss. + total_loss = ( + mean_pi_loss + + (mean_vf_loss * self.config["vf_loss_coeff"]) + + (mean_entropy_loss * self.entropy_coeff) + + (mean_kl_loss * self.kl_coeff) + ) + + self.stats = { + "total_loss": total_loss, + "policy_loss": mean_pi_loss, + "vf_loss": mean_vf_loss, + "values": values_time_major, + "entropy_loss": mean_entropy_loss, + "vtrace_adjusted_target_values": vtrace_adjusted_target_values, + "mean_kl": mean_kl_loss, + } + return total_loss + + @override(EagerTFPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + return { + "cur_lr": tf.cast(self.cur_lr, tf.float64), + "policy_loss": self.stats["policy_loss"], + "entropy": self.stats["entropy_loss"], + "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), + "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables), + "vf_loss": self.stats["vf_loss"], + "vf_explained_var": explained_variance( + tf.reshape(self.stats["vtrace_adjusted_target_values"], [-1]), + tf.reshape(self.stats["values"], [-1]), + ), + "mean_kl": self.stats["mean_kl"], + } + + @override(EagerTFPolicyV2) + def get_batch_divisibility_req(self) -> int: + return self.config["rollout_fragment_length"] diff --git a/rllib/algorithms/appo/tf/appo_tf_rl_module.py b/rllib/algorithms/appo/tf/appo_tf_rl_module.py new file mode 100644 index 0000000000000..0131e33622d02 --- /dev/null +++ b/rllib/algorithms/appo/tf/appo_tf_rl_module.py @@ -0,0 +1,53 @@ +from typing import List + + +from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule +from ray.rllib.core.models.base import ACTOR +from ray.rllib.core.models.tf.encoder import ENCODER_OUT +from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import ( + RLModuleWithTargetNetworksInterface, +) +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.nested_dict 
import NestedDict + +_, tf, _ = try_import_tf() + +OLD_ACTION_DIST_KEY = "old_action_dist" +OLD_ACTION_DIST_LOGITS_KEY = "old_action_dist_logits" + + +class APPOTfRLModule(PPOTfRLModule, RLModuleWithTargetNetworksInterface): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + catalog = self.config.get_catalog() + # old pi and old encoder are the "target networks" that are used for + # the stabilization of the updates of the current pi and encoder. + self.old_pi = catalog.build_pi_head(framework=self.framework) + self.old_encoder = catalog.build_actor_critic_encoder(framework=self.framework) + self.old_pi.set_weights(self.pi.get_weights()) + self.old_encoder.set_weights(self.encoder.get_weights()) + self.old_pi.trainable = False + self.old_encoder.trainable = False + + @override(RLModuleWithTargetNetworksInterface) + def get_target_network_pairs(self): + return [(self.old_pi, self.pi), (self.old_encoder, self.encoder)] + + @override(PPOTfRLModule) + def output_specs_train(self) -> List[str]: + return [ + SampleBatch.ACTION_DIST, + SampleBatch.VF_PREDS, + OLD_ACTION_DIST_KEY, + ] + + def _forward_train(self, batch: NestedDict): + outs = super()._forward_train(batch) + old_pi_inputs_encoded = self.old_encoder(batch)[ENCODER_OUT][ACTOR] + old_action_dist_logits = self.old_pi(old_pi_inputs_encoded) + old_action_dist = self.action_dist_cls.from_logits(old_action_dist_logits) + outs[OLD_ACTION_DIST_KEY] = old_action_dist + outs[OLD_ACTION_DIST_LOGITS_KEY] = old_action_dist_logits + return outs diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index 92e609cbaa85e..7c23cf614e281 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -141,6 +141,7 @@ def __init__(self, algo_class=None): # Override some of AlgorithmConfig's default values with ARS-specific values. self.rollout_fragment_length = 50 self.train_batch_size = 500 + self.minibatch_size = self.train_batch_size self.num_rollout_workers = 2 self.num_gpus = 1 self.lr = 0.0005 @@ -163,6 +164,7 @@ def training( gamma: Optional[float] = NotProvided, num_multi_gpu_tower_stacks: Optional[int] = NotProvided, minibatch_buffer_size: Optional[int] = NotProvided, + minibatch_size: Optional[int] = NotProvided, num_sgd_iter: Optional[int] = NotProvided, replay_proportion: Optional[float] = NotProvided, replay_buffer_num_slots: Optional[int] = NotProvided, @@ -215,6 +217,11 @@ def training( is performing gradient calculations. minibatch_buffer_size: How many train batches should be retained for minibatching. This conf only has an effect if `num_sgd_iter > 1`. + minibatch_size: The size of minibatches that are trained over during + each SGD iteration. Note this only has an effect if + `_enable_learner_api` == True. + Note: minibatch_size must be a multiple of rollout_fragment_length or + sequence_length and smaller than or equal to train_batch_size. num_sgd_iter: Number of passes to make over each train batch. replay_proportion: Set >0 to enable experience replay. Saved samples will be replayed with a p:1 proportion to new data samples. @@ -329,6 +336,8 @@ def training( self.after_train_step = after_train_step if gamma is not NotProvided: self.gamma = gamma + if minibatch_size is not NotProvided: + self.minibatch_size = minibatch_size return self @@ -377,6 +386,18 @@ def validate(self) -> None: "term/optimizer! Try setting config.training(" "_tf_policy_handles_more_than_one_loss=True)." 
) + if self._enable_learner_api: + if not ( + (self.minibatch_size % self.rollout_fragment_length == 0) + and self.minibatch_size <= self.train_batch_size + ): + raise ValueError( + "minibatch_size must be a multiple of rollout_fragment_length and " + "must be smaller than or equal to train_batch_size. Got" + f" minibatch_size={self.minibatch_size}, train_batch_size=" + f"{self.train_batch_size}, and rollout_fragment_length=" + f"{self.get_rollout_fragment_length()}" + ) # learner hps need to be updated inside of config.validate in order to have # the correct values for when a user starts an experiment from a dict. This is # as oppposed to assigning the values inthe builder functions such as `training` @@ -622,6 +643,9 @@ def setup(self, config: AlgorithmConfig): ) self._aggregator_actor_manager = None + # This variable is used to keep track of the statistics from the most recent + # update of the learner group + self._results = {} self._timeout_s_sampler_manager = self.config.timeout_s_sampler_manager if not self.config._enable_learner_api: @@ -703,7 +727,17 @@ def training_step(self) -> ResultDict: timeout_seconds=self.config.worker_health_probe_timeout_s, mark_healthy=True, ) - return train_results + + if self.config._enable_learner_api: + if train_results: + # store the most recent result and return it if no new result is + # available. This keeps backwards compatibility with the old + # training stack / results reporting stack. This is necessary + # any time we develop an asynchronous algorithm. + self._results = train_results + return self._results + else: + return train_results @classmethod @override(Algorithm) @@ -879,6 +913,7 @@ def learn_on_processed_samples(self) -> ResultDict: reduce_fn=_reduce_impala_results, block=blocking, num_iters=self.config.num_sgd_iter, + minibatch_size=self.config.minibatch_size, ) else: lg_results = None @@ -1062,7 +1097,6 @@ def update_workers_from_learner_group( self._counters[NUM_TRAINING_STEP_CALLS_SINCE_LAST_SYNCH_WORKER_WEIGHTS] = 0 self._counters[NUM_SYNCH_WORKER_WEIGHTS] += 1 weights = self.learner_group.get_weights(policy_ids) - if self.config.num_rollout_workers == 0: worker = self.workers.local_worker() worker.set_weights(weights) diff --git a/rllib/algorithms/impala/impala_tf_policy.py b/rllib/algorithms/impala/impala_tf_policy.py index 57b83aa37cc45..e0e005da69a26 100644 --- a/rllib/algorithms/impala/impala_tf_policy.py +++ b/rllib/algorithms/impala/impala_tf_policy.py @@ -179,15 +179,19 @@ def compute_gradients_fn( self, optimizer: LocalOptimizer, loss: TensorType ) -> ModelGradients: # Supporting more than one loss/optimizer. + if self.config.get("_enable_rl_module_api", False): + # In order to access the variables for rl modules, we need to + # use the underlying keras api model.trainable_variables. 
+ trainable_variables = self.model.trainable_variables + else: + trainable_variables = self.model.trainable_variables() if self.config["_tf_policy_handles_more_than_one_loss"]: optimizers = force_list(optimizer) losses = force_list(loss) assert len(optimizers) == len(losses) clipped_grads_and_vars = [] for optim, loss_ in zip(optimizers, losses): - grads_and_vars = optim.compute_gradients( - loss_, self.model.trainable_variables() - ) + grads_and_vars = optim.compute_gradients(loss_, trainable_variables) clipped_g_and_v = [] for g, v in grads_and_vars: if g is not None: @@ -205,9 +209,7 @@ def compute_gradients_fn( ) grads = [g for (g, v) in grads_and_vars] self.grads, _ = tf.clip_by_global_norm(grads, self.config["grad_clip"]) - clipped_grads_and_vars = list( - zip(self.grads, self.model.trainable_variables()) - ) + clipped_grads_and_vars = list(zip(self.grads, trainable_variables)) return clipped_grads_and_vars diff --git a/rllib/algorithms/impala/tests/test_impala_learner.py b/rllib/algorithms/impala/tests/test_impala_learner.py index a982b6f7a1a89..bfcb3e3769c2e 100644 --- a/rllib/algorithms/impala/tests/test_impala_learner.py +++ b/rllib/algorithms/impala/tests/test_impala_learner.py @@ -16,8 +16,7 @@ tf1, tf, _ = try_import_tf() tf1.enable_eager_execution() - -frag_length = 32 +frag_length = 50 FAKE_BATCH = { SampleBatch.OBS: np.random.uniform(low=0, high=1, size=(frag_length, 4)).astype( diff --git a/rllib/algorithms/impala/tf/impala_tf_policy_rlm.py b/rllib/algorithms/impala/tf/impala_tf_policy_rlm.py index b63bd1815858e..0244a96c0ac21 100644 --- a/rllib/algorithms/impala/tf/impala_tf_policy_rlm.py +++ b/rllib/algorithms/impala/tf/impala_tf_policy_rlm.py @@ -31,12 +31,13 @@ class ImpalaTfPolicyWithRLModule( def __init__(self, observation_space, action_space, config): validate_config(config) EagerTFPolicyV2.enable_eager_execution_if_necessary() - EagerTFPolicyV2.__init__(self, observation_space, action_space, config) - # Initialize MixIns. + # Initialize MixIns before super().__init__ because base class will call + # self.loss, which requires these MixIns to be initialized. LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) EntropyCoeffSchedule.__init__( self, config["entropy_coeff"], config["entropy_coeff_schedule"] ) + EagerTFPolicyV2.__init__(self, observation_space, action_space, config) self.maybe_initialize_optimizer_and_loss() diff --git a/rllib/algorithms/ppo/ppo_base_learner.py b/rllib/algorithms/ppo/ppo_base_learner.py index e9347e71651c8..e660dcf8aae2f 100644 --- a/rllib/algorithms/ppo/ppo_base_learner.py +++ b/rllib/algorithms/ppo/ppo_base_learner.py @@ -31,7 +31,6 @@ def __init__(self, *args, **kwargs): def additional_update_per_module( self, module_id: ModuleID, sampled_kl_values: dict, timestep: int ) -> Mapping[str, Any]: - assert sampled_kl_values, "Sampled KL values are empty." 
sampled_kl = sampled_kl_values[module_id] diff --git a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py index d648d99c5b5d9..45959034f6b41 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_rl_module.py @@ -221,13 +221,16 @@ def test_forward_train(self): # input_batch[SampleBatch.SEQ_LENS] = np.array([1]) fwd_out = module.forward_exploration(input_batch) - action = convert_to_numpy(fwd_out["action_dist"].sample()[0]) + _action = fwd_out["action_dist"].sample() + action = convert_to_numpy(_action[0]) + action_logp = convert_to_numpy(fwd_out["action_dist"].logp(_action)[0]) new_obs, reward, terminated, truncated, _ = env.step(action) new_obs = preprocessor.transform(new_obs) output_batch = { SampleBatch.OBS: obs, SampleBatch.NEXT_OBS: new_obs, SampleBatch.ACTIONS: action, + SampleBatch.ACTION_LOGP: action_logp, SampleBatch.REWARDS: np.array(reward), SampleBatch.TERMINATEDS: np.array(terminated), SampleBatch.TRUNCATEDS: np.array(truncated), diff --git a/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py b/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py index b5e759ebab0d0..d19355f0b3eb7 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py @@ -60,15 +60,16 @@ class PPOTfPolicyWithRLModule( def __init__(self, observation_space, action_space, config): # TODO: Move into Policy API, if needed at all here. Why not move this into # `PPOConfig`?. - validate_config(config) + self.framework = "tf2" EagerTFPolicyV2.enable_eager_execution_if_necessary() - EagerTFPolicyV2.__init__(self, observation_space, action_space, config) + validate_config(config) # Initialize MixIns. LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) EntropyCoeffSchedule.__init__( self, config["entropy_coeff"], config["entropy_coeff_schedule"] ) KLCoeffMixin.__init__(self, config) + EagerTFPolicyV2.__init__(self, observation_space, action_space, config) self.maybe_initialize_optimizer_and_loss() diff --git a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py index b28747fdc715b..faf6cb311b600 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_rl_module.py @@ -7,14 +7,12 @@ from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule from ray.rllib.models.specs.specs_dict import SpecDict -from ray.rllib.models.specs.specs_tf import TFTensorSpecs from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.nested_dict import NestedDict tf1, tf, _ = try_import_tf() -tf1.enable_eager_execution() class PPOTfRLModule(PPORLModuleBase, TfRLModule): @@ -34,19 +32,17 @@ def __init__(self, *args, **kwargs): @override(RLModule) def input_specs_train(self) -> List[str]: - return [SampleBatch.OBS, SampleBatch.ACTIONS] + return [SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.ACTION_LOGP] @override(RLModule) - def output_specs_train(self) -> SpecDict: - spec = SpecDict( - { - SampleBatch.ACTION_DIST: Distribution, - SampleBatch.ACTION_LOGP: TFTensorSpecs("b", dtype=tf.float32), - SampleBatch.VF_PREDS: TFTensorSpecs("b", dtype=tf.float32), - "entropy": TFTensorSpecs("b", dtype=tf.float32), - } - ) - return spec + def output_specs_train(self) -> List[str]: + return [ + SampleBatch.ACTION_DIST_INPUTS, + SampleBatch.ACTION_DIST, 
+ SampleBatch.ACTION_LOGP, + SampleBatch.VF_PREDS, + "entropy", + ] @override(RLModule) def input_specs_exploration(self): diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index b9c1bd746fe68..818204da963ef 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -55,6 +55,11 @@ ParamRef = Hashable ParamDictType = Dict[ParamRef, ParamType] +# COMMON LEARNER LOSS_KEYS +POLICY_LOSS_KEY = "policy_loss" +VF_LOSS_KEY = "vf_loss" +ENTROPY_KEY = "entropy" + @dataclass class FrameworkHPs: diff --git a/rllib/core/learner/tf/tf_learner.py b/rllib/core/learner/tf/tf_learner.py index 4eba5224cee48..5fea170ca4ef6 100644 --- a/rllib/core/learner/tf/tf_learner.py +++ b/rllib/core/learner/tf/tf_learner.py @@ -52,9 +52,18 @@ def __init__( framework_hyperparameters: Optional[FrameworkHPs] = FrameworkHPs(), **kwargs, ): - super().__init__(framework_hyperparameters=framework_hyperparameters, **kwargs) - tf1.enable_eager_execution() + # by default in rllib we disable tf2 behavior + # This call re-enables it as it is needed for using + # this class. + try: + tf1.enable_v2_behavior() + except ValueError: + # This is a hack to avoid the error that happens when calling + # enable_v2_behavior after variables have already been created. + pass + + super().__init__(framework_hyperparameters=framework_hyperparameters, **kwargs) self._enable_tf_function = framework_hyperparameters.eager_tracing # the default strategy is a no-op that can be used in the local mode diff --git a/rllib/core/rl_module/rl_module_with_target_networks_interface.py b/rllib/core/rl_module/rl_module_with_target_networks_interface.py new file mode 100644 index 0000000000000..0a247739c2831 --- /dev/null +++ b/rllib/core/rl_module/rl_module_with_target_networks_interface.py @@ -0,0 +1,23 @@ +import abc +from typing import List, Tuple + +from ray.rllib.utils.typing import NetworkType + + +class RLModuleWithTargetNetworksInterface(abc.ABC): + """An RLModule Mixin for adding an interface for target networks. + + This is used for identifying the target networks that are used for stabilizing + the updates of the current trainable networks of this RLModule. + """ + + @abc.abstractmethod + def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]: + """Returns a list of (target, current) networks. + + This is used for identifying the target networks that are used for stabilizing + the updates of the current trainable networks of this RLModule. + + Returns: + A list of (target, current) networks. 
+ """ diff --git a/rllib/env/wrappers/dm_control_wrapper.py b/rllib/env/wrappers/dm_control_wrapper.py index 1e18f40542ac6..bc743341cacf7 100644 --- a/rllib/env/wrappers/dm_control_wrapper.py +++ b/rllib/env/wrappers/dm_control_wrapper.py @@ -190,13 +190,13 @@ def step(self, action): assert self._norm_action_space.contains(action) action = self._convert_action(action) assert self._true_action_space.contains(action) - reward = 0 + reward = 0.0 extra = {"internal_state": self._env.physics.get_state().copy()} terminated = truncated = False for _ in range(self._frame_skip): time_step = self._env.step(action) - reward += time_step.reward or 0 + reward += time_step.reward or 0.0 terminated = False truncated = time_step.last() if terminated or truncated: diff --git a/rllib/models/specs/specs_tf.py b/rllib/models/specs/specs_tf.py index f438880c2d893..28c0f1ee40362 100644 --- a/rllib/models/specs/specs_tf.py +++ b/rllib/models/specs/specs_tf.py @@ -14,15 +14,17 @@ def get_type(cls) -> Type: return tf.Tensor @override(TensorSpec) - def get_shape(self, tensor: tf.Tensor) -> Tuple[int]: + def get_shape(self, tensor: "tf.Tensor") -> Tuple[int]: return tuple(tensor.shape) @override(TensorSpec) - def get_dtype(self, tensor: tf.Tensor) -> Any: + def get_dtype(self, tensor: "tf.Tensor") -> Any: return tensor.dtype @override(TensorSpec) - def _full(self, shape: Tuple[int], fill_value: Union[float, int] = 0) -> tf.Tensor: + def _full( + self, shape: Tuple[int], fill_value: Union[float, int] = 0 + ) -> "tf.Tensor": if self.dtype: return tf.ones(shape, dtype=self.dtype) * fill_value return tf.fill(shape, fill_value) diff --git a/rllib/models/tf/tf_distributions.py b/rllib/models/tf/tf_distributions.py index 7748f59ddec52..8d18a139f59ff 100644 --- a/rllib/models/tf/tf_distributions.py +++ b/rllib/models/tf/tf_distributions.py @@ -94,8 +94,8 @@ class TfCategorical(TfDistribution): @override(TfDistribution) def __init__( self, - probs: tf.Tensor = None, - logits: tf.Tensor = None, + probs: "tf.Tensor" = None, + logits: "tf.Tensor" = None, temperature: float = 1.0, ) -> None: # We assert this here because to_deterministic makes this assumption. @@ -124,8 +124,8 @@ def logp(self, value: TensorType, **kwargs) -> TensorType: @override(TfDistribution) def _get_tf_distribution( self, - probs: tf.Tensor = None, - logits: tf.Tensor = None, + probs: "tf.Tensor" = None, + logits: "tf.Tensor" = None, temperature: float = 1.0, ) -> "tfp.distributions.Distribution": if logits is not None: @@ -189,8 +189,8 @@ class TfDiagGaussian(TfDistribution): @override(TfDistribution) def __init__( self, - loc: Union[float, tf.Tensor], - scale: Optional[Union[float, tf.Tensor]], + loc: Union[float, TensorType], + scale: Optional[Union[float, TensorType]] = None, ): self.loc = loc super().__init__(loc=loc, scale=scale) @@ -253,7 +253,7 @@ class TfDeterministic(Distribution): """ @override(Distribution) - def __init__(self, loc: tf.Tensor) -> None: + def __init__(self, loc: "tf.Tensor") -> None: super().__init__() self.loc = loc diff --git a/rllib/policy/eager_tf_policy_v2.py b/rllib/policy/eager_tf_policy_v2.py index ada4b5fe5e17e..06ca36e8373e8 100644 --- a/rllib/policy/eager_tf_policy_v2.py +++ b/rllib/policy/eager_tf_policy_v2.py @@ -185,7 +185,7 @@ def loss( # sampler will include those keys in the sample batches it returns. This means # that the correct sample batch keys will be available when using the learner # group API. 
- if self.config._enable_learner_api: + if self.config.get("_enable_rl_module_api", False): for k in model.input_specs_train(): train_batch[k] return None @@ -717,11 +717,12 @@ def get_state(self) -> PolicyState: # Legacy Policy state (w/o keras model and w/o PolicySpec). state = super().get_state() - state["global_timestep"] = state["global_timestep"].numpy() - if self._optimizer and len(self._optimizer.variables()) > 0: - state["_optimizer_variables"] = self._optimizer.variables() - # Add exploration state. - state["_exploration_state"] = self.exploration.get_state() + if not self.config.get("_enable_rl_module_api", False): + state["global_timestep"] = state["global_timestep"].numpy() + if self._optimizer and len(self._optimizer.variables()) > 0: + state["_optimizer_variables"] = self._optimizer.variables() + # Add exploration state. + state["_exploration_state"] = self.exploration.get_state() return state @override(Policy) diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 2a7bcc228c326..eb88e3519f2b9 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -1472,7 +1472,9 @@ def _initialize_loss_from_dummy_batch( # We should simply do self.loss(...) here. if self._loss is not None: self._loss(self, self.model, self.dist_class, train_batch) - elif is_overridden(self.loss) and not self.config["in_evaluation"]: + elif ( + is_overridden(self.loss) or self.config.get("_enable_rl_module_api", False) + ) and not self.config["in_evaluation"]: self.loss(self.model, self.dist_class, train_batch) # Call the stats fn, if given. # TODO(jungong) : clean up after all agents get migrated. diff --git a/rllib/policy/tf_mixins.py b/rllib/policy/tf_mixins.py index 35d0c2f4d2a84..8ce18df5e9796 100644 --- a/rllib/policy/tf_mixins.py +++ b/rllib/policy/tf_mixins.py @@ -208,9 +208,14 @@ class TargetNetworkMixin: """ def __init__(self): - - model_vars = self.model.trainable_variables() - target_model_vars = self.target_model.trainable_variables() + if self.config.get("_enable_rl_module_api", False): + # In order to access the variables for rl modules, we need to + # use the underlying keras api model.trainable_variables. + model_vars = self.model.trainable_variables + target_model_vars = self.target_model.trainable_variables + else: + model_vars = self.model.trainable_variables() + target_model_vars = self.target_model.trainable_variables() @make_tf_callable(self.get_session()) def update_target_fn(tau): @@ -238,13 +243,19 @@ def update_target_fn(tau): @property def q_func_vars(self): if not hasattr(self, "_q_func_vars"): - self._q_func_vars = self.model.variables() + if self.config.get("_enable_rl_module_api", False): + self._q_func_vars = self.model.variables + else: + self._q_func_vars = self.model.variables() return self._q_func_vars @property def target_q_func_vars(self): if not hasattr(self, "_target_q_func_vars"): - self._target_q_func_vars = self.target_model.variables() + if self.config.get("_enable_rl_module_api", False): + self._target_q_func_vars = self.target_model.variables + else: + self._target_q_func_vars = self.target_model.variables() return self._target_q_func_vars # Support both hard and soft sync. 
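As an aside, the branching repeated throughout TargetNetworkMixin above reduces to the following pattern; a minimal sketch assuming a policy object with a `config` dict and a `model` attribute (the helper name is hypothetical and not part of this change).

def get_trainable_variables(policy):
    if policy.config.get("_enable_rl_module_api", False):
        # RLModules expose Keras-style properties: a plain list of variables.
        return policy.model.trainable_variables
    # Legacy ModelV2 exposes a callable instead.
    return policy.model.trainable_variables()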
@@ -253,7 +264,10 @@ def update_target(self, tau: int = None) -> None: @override(TFPolicy) def variables(self) -> List[TensorType]: - return self.model.variables() + if self.config.get("_enable_rl_module_api", False): + return self.model.variables + else: + return self.model.variables() def set_weights(self, weights): if isinstance(self, TFPolicy): diff --git a/rllib/tuned_examples/appo/cartpole-appo-learner.yaml b/rllib/tuned_examples/appo/cartpole-appo-learner.yaml new file mode 100644 index 0000000000000..b87c51f7f2b20 --- /dev/null +++ b/rllib/tuned_examples/appo/cartpole-appo-learner.yaml @@ -0,0 +1,29 @@ +cartpole-appo-learner: + env: CartPole-v1 + run: APPO + stop: + episode_reward_mean: 150 + timesteps_total: 200000 + config: + # Works for both torch and tf. + framework: tf2 + num_workers: + grid_search: + - 3 + num_gpus: 0 + observation_filter: MeanStdFilter + num_sgd_iter: + grid_search: + - 6 + vf_loss_coeff: 0.01 + vtrace: True + model: + fcnet_hiddens: [32] + fcnet_activation: linear + vf_share_layers: true + enable_connectors: True + _enable_learner_api: True + _enable_rl_module_api: True + eager_tracing: True + lr: 0.001 + entropy_coeff: 0.1 diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py index 09e2372a01058..af6a31b1b8dac 100644 --- a/rllib/utils/typing.py +++ b/rllib/utils/typing.py @@ -38,6 +38,9 @@ # A shape of a tensor. TensorShape = Union[Tuple[int], List[int]] +# A neural network +NetworkType = Union["torch.nn.Module", "tf.keras.Module"] + # Represents a fully filled out config of a Algorithm class. # Note: Policy config dicts are usually the same as AlgorithmConfigDict, but # parts of it may sometimes be altered in e.g. a multi-agent setup,
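Finally, a sketch of how a custom module might implement the new RLModuleWithTargetNetworksInterface together with the NetworkType alias added above. The class and attribute names are hypothetical; the sketch only illustrates the (target, current) pairing convention used by APPOTfRLModule.

from typing import List, Tuple

from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import (
    RLModuleWithTargetNetworksInterface,
)
from ray.rllib.utils.typing import NetworkType


class MyModuleWithTargets(RLModuleWithTargetNetworksInterface):
    """Hypothetical module holding one trainable network and its target copy."""

    def __init__(self, pi: NetworkType, target_pi: NetworkType):
        self.pi = pi
        self.target_pi = target_pi

    def get_target_network_pairs(self) -> List[Tuple[NetworkType, NetworkType]]:
        # (target, current) ordering, matching APPOTfRLModule.get_target_network_pairs().
        return [(self.target_pi, self.pi)]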