diff --git a/rllib/BUILD b/rllib/BUILD index 20d81511ac161..a571563201127 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -805,7 +805,7 @@ py_test( py_test( name = "test_algorithm_config", tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"], - size = "small", + size = "medium", srcs = ["algorithms/tests/test_algorithm_config.py"], ) diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index d83b5705a0714..241645f632a7a 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -52,6 +52,7 @@ try_import_gymnasium_and_gym, ) from ray.rllib.utils.policy import validate_policy_id +from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.serialization import ( deserialize_type, NOT_SERIALIZABLE, @@ -846,7 +847,7 @@ def validate(self) -> None: error=True, ) - # RLModule API only works with connectors. + # RLModule API only works with connectors and with Learner API. if not self.enable_connectors and self._enable_rl_module_api: raise ValueError( "RLModule API only works with connectors. " @@ -855,19 +856,26 @@ def validate(self) -> None: ) # Learner API requires RLModule API. - if self._enable_learner_api and not self._enable_rl_module_api: + if self._enable_learner_api is not self._enable_rl_module_api: raise ValueError( - "Learner API requires RLModule API. " - "Please enable RLModule API via " - "`config.training(_enable_rl_module_api=True)`." + "Learner API requires RLModule API and vice-versa! " + "Enable RLModule API via " + "`config.rl_module(_enable_rl_module_api=True)` and the Learner API " + "via `config.training(_enable_learner_api=True)` (or set both to " + "False)." ) if bool(os.environ.get("RLLIB_ENABLE_RL_MODULE", False)): - # enable RLModule API and connectors if env variable is set + # Enable RLModule API and connectors if env variable is set # (to be used in unittesting) self.rl_module(_enable_rl_module_api=True) + self.training(_enable_learner_api=True) self.enable_connectors = True + # LR-schedule checking. + if self._enable_learner_api: + Scheduler.validate(self.lr_schedule, "lr_schedule", "learning rate") + # Validate grad clipping settings. if self.grad_clip_by not in ["value", "norm", "global_norm"]: raise ValueError( @@ -1587,7 +1595,8 @@ def training( lr_schedule: Learning rate schedule. In the format of [[timestep, lr-value], [timestep, lr-value], ...] Intermediary timesteps will be assigned to interpolated learning rate - values. A schedule should normally start from timestep 0. + values. A schedule config's first entry must start with timestep 0, + i.e.: [[0, initial_value], [...]]. grad_clip: The value to use for gradient clipping. Depending on the `grad_clip_by` setting, gradients will either be clipped by value, norm, or global_norm (see docstring on `grad_clip_by` below for more @@ -1664,7 +1673,7 @@ def training( deprecation_warning( old="AlgorithmConfig.training(_use_default_native_models=True)", help="_use_default_native_models is not supported " - "anymore. To get rid of this error, set `experimental(" + "anymore. To get rid of this error, set `rl_module(" "_enable_rl_module_api` to True. 
Native models will " "be better supported by the upcoming RLModule API.", # Error out if user tries to enable this diff --git a/rllib/algorithms/appo/appo.py b/rllib/algorithms/appo/appo.py index 5b503f239192e..8b27ef5100cc4 100644 --- a/rllib/algorithms/appo/appo.py +++ b/rllib/algorithms/appo/appo.py @@ -381,16 +381,9 @@ def get_default_policy_class( cls, config: AlgorithmConfig ) -> Optional[Type[Policy]]: if config["framework"] == "torch": - if config._enable_rl_module_api: - from ray.rllib.algorithms.appo.torch.appo_torch_policy_rlm import ( - APPOTorchPolicyWithRLModule, - ) - - return APPOTorchPolicyWithRLModule - else: - from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy + from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy - return APPOTorchPolicy + return APPOTorchPolicy elif config["framework"] == "tf": if config._enable_rl_module_api: raise ValueError( @@ -402,12 +395,6 @@ def get_default_policy_class( return APPOTF1Policy else: - if config._enable_rl_module_api: - from ray.rllib.algorithms.appo.tf.appo_tf_policy_rlm import ( - APPOTfPolicyWithRLModule, - ) - - return APPOTfPolicyWithRLModule from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy return APPOTF2Policy diff --git a/rllib/algorithms/appo/appo_tf_policy.py b/rllib/algorithms/appo/appo_tf_policy.py index d91b4516bfd7a..8441f8032ede8 100644 --- a/rllib/algorithms/appo/appo_tf_policy.py +++ b/rllib/algorithms/appo/appo_tf_policy.py @@ -81,10 +81,15 @@ def __init__( # First thing first, enable eager execution if necessary. base.enable_eager_execution_if_necessary() - # Although this is a no-op, we call __init__ here to make it clear - # that base.__init__ will use the make_model() call. - VTraceClipGradients.__init__(self) - VTraceOptimizer.__init__(self) + # If Learner API is used, we don't need any loss-specific mixins. + # However, we also would like to avoid creating special Policy-subclasses + # for this as the entire Policy concept will soon not be used anymore with + # the new Learner- and RLModule APIs. + if not config.get("_enable_learner_api", False): + # Although this is a no-op, we call __init__ here to make it clear + # that base.__init__ will use the make_model() call. + VTraceClipGradients.__init__(self) + VTraceOptimizer.__init__(self) # Initialize base class. base.__init__( @@ -104,7 +109,9 @@ def __init__( ) ValueNetworkMixin.__init__(self, config) KLCoeffMixin.__init__(self, config) - GradStatsMixin.__init__(self) + + if not config.get("_enable_learner_api", False): + GradStatsMixin.__init__(self) # Note: this is a bit ugly, but loss and optimizer initialization must # happen after all the MixIns are initialized. diff --git a/rllib/algorithms/appo/appo_torch_policy.py b/rllib/algorithms/appo/appo_torch_policy.py index b92b7c32fd510..4a7754830f321 100644 --- a/rllib/algorithms/appo/appo_torch_policy.py +++ b/rllib/algorithms/appo/appo_torch_policy.py @@ -69,9 +69,15 @@ class APPOTorchPolicy( def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.algorithms.appo.appo.APPOConfig().to_dict(), **config) - # Although this is a no-op, we call __init__ here to make it clear - # that base.__init__ will use the make_model() call. - VTraceOptimizer.__init__(self) + # If Learner API is used, we don't need any loss-specific mixins. + # However, we also would like to avoid creating special Policy-subclasses + # for this as the entire Policy concept will soon not be used anymore with + # the new Learner- and RLModule APIs. 
+ if not config.get("_enable_learner_api", False): + # Although this is a no-op, we call __init__ here to make it clear + # that base.__init__ will use the make_model() call. + VTraceOptimizer.__init__(self) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) TorchPolicyV2.__init__( @@ -88,7 +94,6 @@ def __init__(self, observation_space, action_space, config): ValueNetworkMixin.__init__(self, config) KLCoeffMixin.__init__(self, config) - # TODO: Don't require users to call this manually. self._initialize_loss_from_dummy_batch() # Initiate TargetNetwork ops after loss initialization. diff --git a/rllib/algorithms/appo/tests/test_appo_learner.py b/rllib/algorithms/appo/tests/test_appo_learner.py index af954bf701e7a..1bc1bd1b0a087 100644 --- a/rllib/algorithms/appo/tests/test_appo_learner.py +++ b/rllib/algorithms/appo/tests/test_appo_learner.py @@ -92,7 +92,6 @@ def test_appo_loss(self): ) algo_config = config.copy(copy_frozen=False) - algo_config.training(_enable_learner_api=True) algo_config.validate() algo_config.freeze() diff --git a/rllib/algorithms/appo/tf/appo_tf_policy_rlm.py b/rllib/algorithms/appo/tf/appo_tf_policy_rlm.py deleted file mode 100644 index 24a4ec9cb6493..0000000000000 --- a/rllib/algorithms/appo/tf/appo_tf_policy_rlm.py +++ /dev/null @@ -1,227 +0,0 @@ -import logging -from typing import Dict, List, Union - -from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_mixins import ( - EntropyCoeffSchedule, - LearningRateSchedule, - KLCoeffMixin, - GradStatsMixin, - TargetNetworkMixin, -) -from ray.rllib.algorithms.impala.impala_tf_policy import ( - VTraceClipGradients, - VTraceOptimizer, -) -from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 -from ray.rllib.utils.annotations import override -from ray.rllib.utils.deprecation import Deprecated -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_utils import ( - explained_variance, -) - - -from ray.rllib.algorithms.impala.tf.vtrace_tf_v2 import make_time_major, vtrace_tf2 -from ray.rllib.utils.typing import TensorType - -tf1, tf, tfv = try_import_tf() - -logger = logging.getLogger(__name__) - - -class APPOTfPolicyWithRLModule( - VTraceClipGradients, - VTraceOptimizer, - LearningRateSchedule, - KLCoeffMixin, - EntropyCoeffSchedule, - TargetNetworkMixin, - GradStatsMixin, - EagerTFPolicyV2, -): - def __init__(self, observation_space, action_space, config): - validate_config(config) - EagerTFPolicyV2.enable_eager_execution_if_necessary() - # Initialize MixIns before super().__init__ because base class will call - # self.loss, which requires these MixIns to be initialized. - LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - self, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - # Although this is a no-op, we call __init__ here to make it clear - # that base.__init__ will use the make_model() call. - VTraceClipGradients.__init__(self) - VTraceOptimizer.__init__(self) - self.framework = "tf2" - KLCoeffMixin.__init__(self, config) - GradStatsMixin.__init__(self) - EagerTFPolicyV2.__init__(self, observation_space, action_space, config) - # Construct the target model and make its weights the same as the model. 
- self.target_model = self.make_rl_module() - self.target_model.set_weights(self.model.get_weights()) - - # Initiate TargetNetwork ops after loss initialization. - self.maybe_initialize_optimizer_and_loss() - TargetNetworkMixin.__init__(self) - - @Deprecated(new="APPOTfLearner.compute_loss_per_module()", error=False) - @override(EagerTFPolicyV2) - def loss( - self, - model: Union[ModelV2, "tf.keras.Model"], - dist_class, - train_batch: SampleBatch, - ) -> Union[TensorType, List[TensorType]]: - train_batch[SampleBatch.ACTIONS] - train_batch[SampleBatch.ACTION_LOGP] - train_batch[SampleBatch.REWARDS] - train_batch[SampleBatch.TERMINATEDS] - - seqs_len = train_batch.get(SampleBatch.SEQ_LENS) - rollout_frag_or_episode_len = ( - self.config["rollout_fragment_length"] if not seqs_len else None - ) - drop_last = self.config["vtrace_drop_last_ts"] - - target_policy_fwd_out = model.forward_train(train_batch) - values = target_policy_fwd_out[SampleBatch.VF_PREDS] - target_policy_dist = target_policy_fwd_out[SampleBatch.ACTION_DIST] - - old_target_policy_fwd_out = self.target_model.forward_train(train_batch) - old_target_policy_dist = old_target_policy_fwd_out[SampleBatch.ACTION_DIST] - - behaviour_actions_logp = train_batch[SampleBatch.ACTION_LOGP] - target_actions_logp = target_policy_dist.logp(train_batch[SampleBatch.ACTIONS]) - old_target_actions_logp = old_target_policy_dist.logp( - train_batch[SampleBatch.ACTIONS] - ) - behaviour_actions_logp_time_major = make_time_major( - behaviour_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - target_actions_logp_time_major = make_time_major( - target_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - old_target_actions_logp_time_major = make_time_major( - old_target_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - values_time_major = make_time_major( - values, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - bootstrap_value = values_time_major[-1] - rewards_time_major = make_time_major( - train_batch[SampleBatch.REWARDS], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - - # how to compute discouts? - # should they be pre computed? - discounts_time_major = ( - 1.0 - - tf.cast( - make_time_major( - train_batch[SampleBatch.TERMINATEDS], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ), - dtype=tf.float32, - ) - ) * self.config["gamma"] - - # Note that vtrace will compute the main loop on the CPU for better performance. 
- vtrace_adjusted_target_values, pg_advantages = vtrace_tf2( - target_action_log_probs=old_target_actions_logp_time_major, - behaviour_action_log_probs=behaviour_actions_logp_time_major, - discounts=discounts_time_major, - rewards=rewards_time_major, - values=values_time_major, - bootstrap_value=bootstrap_value, - clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], - clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], - ) - - is_ratio = tf.clip_by_value( - tf.math.exp( - behaviour_actions_logp_time_major - target_actions_logp_time_major - ), - 0.0, - 2.0, - ) - logp_ratio = is_ratio * tf.math.exp( - target_actions_logp_time_major - behaviour_actions_logp_time_major - ) - - clip_param = self.config["clip_param"] - surrogate_loss = tf.math.minimum( - pg_advantages * logp_ratio, - ( - pg_advantages - * tf.clip_by_value(logp_ratio, 1 - clip_param, 1 + clip_param) - ), - ) - action_kl = old_target_policy_dist.kl(target_policy_dist) - mean_kl_loss = tf.math.reduce_mean(action_kl) - mean_pi_loss = -tf.math.reduce_mean(surrogate_loss) - - # The baseline loss. - delta = values_time_major - vtrace_adjusted_target_values - mean_vf_loss = 0.5 * tf.math.reduce_mean(delta**2) - - # The entropy loss. - mean_entropy_loss = -tf.math.reduce_mean(target_policy_dist.entropy()) - - # The summed weighted loss. - total_loss = ( - mean_pi_loss - + (mean_vf_loss * self.config["vf_loss_coeff"]) - + (mean_entropy_loss * self.entropy_coeff) - + (mean_kl_loss * self.kl_coeff) - ) - - self.stats = { - "total_loss": total_loss, - "policy_loss": mean_pi_loss, - "vf_loss": mean_vf_loss, - "values": values_time_major, - "entropy_loss": mean_entropy_loss, - "vtrace_adjusted_target_values": vtrace_adjusted_target_values, - "mean_kl": mean_kl_loss, - } - return total_loss - - @override(EagerTFPolicyV2) - def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: - return { - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "policy_loss": self.stats["policy_loss"], - "entropy": self.stats["entropy_loss"], - "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), - "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables), - "vf_loss": self.stats["vf_loss"], - "vf_explained_var": explained_variance( - tf.reshape(self.stats["vtrace_adjusted_target_values"], [-1]), - tf.reshape(self.stats["values"], [-1]), - ), - "mean_kl": self.stats["mean_kl"], - } - - @override(EagerTFPolicyV2) - def get_batch_divisibility_req(self) -> int: - return self.config["rollout_fragment_length"] diff --git a/rllib/algorithms/appo/torch/appo_torch_policy_rlm.py b/rllib/algorithms/appo/torch/appo_torch_policy_rlm.py deleted file mode 100644 index 81bc072eca43b..0000000000000 --- a/rllib/algorithms/appo/torch/appo_torch_policy_rlm.py +++ /dev/null @@ -1,213 +0,0 @@ -import logging - -from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( - make_time_major, - vtrace_torch, -) -from ray.rllib.policy.torch_mixins import ( - EntropyCoeffSchedule, - LearningRateSchedule, - KLCoeffMixin, - TargetNetworkMixin, -) -from ray.rllib.algorithms.impala.impala_torch_policy import ( - VTraceOptimizer, -) -from ray.rllib.algorithms.ppo.ppo_torch_policy import validate_config -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_utils import ( - convert_to_torch_tensor, - explained_variance, - global_norm, -) - -torch, _ = 
try_import_torch() - -logger = logging.getLogger(__name__) - - -# TODO: Remove once we have a RLModule capable sampler class that can replace -# `Policy.compute_actions_from_input_dict()`. -class APPOTorchPolicyWithRLModule( - VTraceOptimizer, - LearningRateSchedule, - KLCoeffMixin, - EntropyCoeffSchedule, - TargetNetworkMixin, - TorchPolicyV2, -): - def __init__(self, observation_space, action_space, config): - validate_config(config) - # Initialize MixIns before super().__init__ because base class will call - # self.loss, which requires these MixIns to be initialized. - LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - self, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - # Although this is a no-op, we call __init__ here to make it clear - # that base.__init__ will use the make_model() call. - # VTraceClipGradients.__init__(self) - VTraceOptimizer.__init__(self) - self.framework = "tf2" - KLCoeffMixin.__init__(self, config) - # GradStatsMixin.__init__(self) - TorchPolicyV2.__init__(self, observation_space, action_space, config) - # Construct the target model and make its weights the same as the model. - self.target_model = self.make_rl_module() - self.target_model.load_state_dict(self.model.state_dict()) - - # Initiate TargetNetwork ops after loss initialization. - self._initialize_loss_from_dummy_batch() - TargetNetworkMixin.__init__(self) - - @override(TorchPolicyV2) - def loss(self, model, dist_class, train_batch): - train_batch[SampleBatch.ACTION_LOGP] - train_batch[SampleBatch.ACTIONS] - train_batch[SampleBatch.REWARDS] - train_batch[SampleBatch.TERMINATEDS] - - seqs_len = train_batch.get(SampleBatch.SEQ_LENS) - rollout_frag_or_episode_len = ( - self.config["rollout_fragment_length"] if not seqs_len else None - ) - drop_last = self.config["vtrace_drop_last_ts"] - - target_policy_fwd_out = model.forward_train(train_batch) - values = target_policy_fwd_out[SampleBatch.VF_PREDS] - target_policy_dist = target_policy_fwd_out[SampleBatch.ACTION_DIST] - - old_target_policy_fwd_out = self.target_model.forward_train(train_batch) - old_target_policy_dist = old_target_policy_fwd_out[SampleBatch.ACTION_DIST] - - behaviour_actions_logp = train_batch[SampleBatch.ACTION_LOGP] - target_actions_logp = target_policy_dist.logp(train_batch[SampleBatch.ACTIONS]) - old_target_actions_logp = old_target_policy_dist.logp( - train_batch[SampleBatch.ACTIONS] - ) - behaviour_actions_logp_time_major = make_time_major( - behaviour_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - target_actions_logp_time_major = make_time_major( - target_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - old_target_actions_logp_time_major = make_time_major( - old_target_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - values_time_major = make_time_major( - values, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - bootstrap_value = values_time_major[-1] - rewards_time_major = make_time_major( - train_batch[SampleBatch.REWARDS], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ) - - # how to compute discouts? - # should they be pre computed? 
- discounts_time_major = ( - 1.0 - - make_time_major( - train_batch[SampleBatch.TERMINATEDS], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seqs_len, - drop_last=drop_last, - ).float() - ) * self.config["gamma"] - - # Note that vtrace will compute the main loop on the CPU for better performance. - vtrace_adjusted_target_values, pg_advantages = vtrace_torch( - target_action_log_probs=old_target_actions_logp_time_major, - behaviour_action_log_probs=behaviour_actions_logp_time_major, - discounts=discounts_time_major, - rewards=rewards_time_major, - values=values_time_major, - bootstrap_value=bootstrap_value, - clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], - clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], - ) - - is_ratio = torch.clip( - torch.exp( - behaviour_actions_logp_time_major - target_actions_logp_time_major - ), - 0.0, - 2.0, - ) - logp_ratio = is_ratio * torch.exp( - target_actions_logp_time_major - behaviour_actions_logp_time_major - ) - - clip_param = self.config["clip_param"] - surrogate_loss = torch.minimum( - pg_advantages * logp_ratio, - (pg_advantages * torch.clip(logp_ratio, 1 - clip_param, 1 + clip_param)), - ) - action_kl = old_target_policy_dist.kl(target_policy_dist) - mean_kl_loss = torch.mean(action_kl) - mean_pi_loss = -torch.mean(surrogate_loss) - - # The baseline loss. - delta = values_time_major - vtrace_adjusted_target_values - mean_vf_loss = 0.5 * torch.mean(delta**2) - - # The entropy loss. - mean_entropy_loss = -torch.mean(target_policy_dist.entropy()) - - # The summed weighted loss. - total_loss = ( - mean_pi_loss - + (mean_vf_loss * self.config["vf_loss_coeff"]) - + (mean_entropy_loss * self.entropy_coeff) - + (mean_kl_loss * self.kl_coeff) - ) - - self.stats = { - "total_loss": total_loss, - "policy_loss": mean_pi_loss, - "vf_loss": mean_vf_loss, - "values": values_time_major, - "entropy_loss": mean_entropy_loss, - "vtrace_adjusted_target_values": vtrace_adjusted_target_values, - "mean_kl": mean_kl_loss, - } - return total_loss - - @override(TorchPolicyV2) - def stats_fn(self, train_batch: SampleBatch): - return { - "cur_lr": convert_to_torch_tensor(self.cur_lr).type(torch.float64), - "policy_loss": self.stats["policy_loss"], - "entropy": self.stats["entropy_loss"], - "entropy_coeff": convert_to_torch_tensor(self.entropy_coeff).type( - torch.float64 - ), - "var_gnorm": global_norm(self.model.parameters()), - "vf_loss": self.stats["vf_loss"], - "vf_explained_var": explained_variance( - torch.reshape(self.stats["vtrace_adjusted_target_values"], [-1]), - torch.reshape(self.stats["values"], [-1]), - ), - "mean_kl": self.stats["mean_kl"], - } - - @override(TorchPolicyV2) - def get_batch_divisibility_req(self) -> int: - return self.config["rollout_fragment_length"] diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index 57c47e3ec9403..d1ecfda9d6e4a 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -46,10 +46,10 @@ SYNCH_WORKER_WEIGHTS_TIMER, SAMPLE_TIMER, ) +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder from ray.rllib.utils.replay_buffers.multi_agent_replay_buffer import ReplayMode from ray.rllib.utils.replay_buffers.replay_buffer import _ALL_POLICIES - -from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder +from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ( PartialAlgorithmConfigDict, PolicyID, @@ -371,6 +371,13 @@ def validate(self) -> None: # Check 
`entropy_coeff` for correctness. if self.entropy_coeff < 0.0: raise ValueError("`entropy_coeff` must be >= 0.0!") + # Entropy coeff schedule checking. + if self._enable_learner_api: + Scheduler.validate( + self.entropy_coeff_schedule, + "entropy_coeff_schedule", + "entropy coefficient", + ) # Check whether worker to aggregation-worker ratio makes sense. if self.num_aggregation_workers > self.num_rollout_workers: @@ -561,58 +568,39 @@ def get_default_policy_class( if not config["vtrace"]: raise ValueError("IMPALA with the learner API does not support non-VTrace ") - if config._enable_rl_module_api: - if config["framework"] == "tf2": - from ray.rllib.algorithms.impala.tf.impala_tf_policy_rlm import ( - ImpalaTfPolicyWithRLModule, + if config["framework"] == "torch": + if config["vtrace"]: + from ray.rllib.algorithms.impala.impala_torch_policy import ( + ImpalaTorchPolicy, ) - return ImpalaTfPolicyWithRLModule - if config["framework"] == "torch": - from ray.rllib.algorithms.impala.torch.impala_torch_policy_rlm import ( - ImpalaTorchPolicyWithRLModule, - ) - - return ImpalaTorchPolicyWithRLModule + return ImpalaTorchPolicy else: - raise ValueError( - f"IMPALA with the learner API does not support framework " - f"{config['framework']} " - ) - else: - if config["framework"] == "torch": - if config["vtrace"]: - from ray.rllib.algorithms.impala.impala_torch_policy import ( - ImpalaTorchPolicy, - ) + from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy - return ImpalaTorchPolicy - else: - from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy + return A3CTorchPolicy + elif config["framework"] == "tf": + if config["vtrace"]: + from ray.rllib.algorithms.impala.impala_tf_policy import ( + ImpalaTF1Policy, + ) - return A3CTorchPolicy - elif config["framework"] == "tf": - if config["vtrace"]: - from ray.rllib.algorithms.impala.impala_tf_policy import ( - ImpalaTF1Policy, - ) + return ImpalaTF1Policy + else: + from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy - return ImpalaTF1Policy - else: - from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy + return A3CTFPolicy + else: + if config["vtrace"]: + from ray.rllib.algorithms.impala.impala_tf_policy import ( + ImpalaTF2Policy, + ) - return A3CTFPolicy + return ImpalaTF2Policy else: - if config["vtrace"]: - from ray.rllib.algorithms.impala.impala_tf_policy import ( - ImpalaTF2Policy, - ) - - return ImpalaTF2Policy - else: - from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy + from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTFPolicy - return A3CTFPolicy + return A3CTFPolicy @override(Algorithm) def setup(self, config: AlgorithmConfig): diff --git a/rllib/algorithms/impala/impala_learner.py b/rllib/algorithms/impala/impala_learner.py index db0d28e1f9a9f..7e1153f4acc40 100644 --- a/rllib/algorithms/impala/impala_learner.py +++ b/rllib/algorithms/impala/impala_learner.py @@ -1,11 +1,11 @@ -from collections import defaultdict from dataclasses import dataclass -from typing import Any, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union import numpy as np import tree # pip install dm_tree from ray.rllib.core.learner.learner import Learner, LearnerHyperparameters +from ray.rllib.core.rl_module.rl_module import ModuleID from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.metrics import ( @@ -13,10 +13,13 @@ NUM_AGENT_STEPS_TRAINED, NUM_ENV_STEPS_TRAINED, ) -from 
ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule +from ray.rllib.utils.schedules.scheduler import Scheduler from ray.rllib.utils.typing import ResultDict +LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY = "curr_entropy_coeff" + + @dataclass class ImpalaHyperparameters(LearnerHyperparameters): """Hyperparameters for the ImpalaLearner sub-classes (framework specific). @@ -49,20 +52,26 @@ def build(self) -> None: super().build() # Build entropy coeff scheduling tools. - self.entropy_coeff_scheduler = None - if self.hps.entropy_coeff_schedule: - # Custom schedule, based on list of - # ([ts], [value to be reached by ts])-tuples. - self.entropy_coeff_schedule_per_module = defaultdict( - lambda: PiecewiseSchedule( - self.hps.entropy_coeff_schedule, - outside_value=self.hps.entropy_coeff_schedule[-1][-1], - framework=None, - ) - ) - self.curr_entropy_coeffs_per_module = defaultdict( - lambda: self._get_tensor_variable(self.hps.entropy_coeff) - ) + self.entropy_coeff_scheduler = Scheduler( + fixed_value=self.hps.entropy_coeff, + schedule=self.hps.entropy_coeff_schedule, + framework=self.framework, + device=self._device, + ) + + @override(Learner) + def additional_update_per_module( + self, module_id: ModuleID, timestep: int + ) -> Dict[str, Any]: + results = super().additional_update_per_module(module_id, timestep=timestep) + + # Update entropy coefficient via our Scheduler. + new_entropy_coeff = self.entropy_coeff_scheduler.update( + module_id, timestep=timestep + ) + results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: new_entropy_coeff}) + + return results @override(Learner) def compile_results( diff --git a/rllib/algorithms/impala/impala_tf_policy.py b/rllib/algorithms/impala/impala_tf_policy.py index e1b66f5332127..d8b830ef7653b 100644 --- a/rllib/algorithms/impala/impala_tf_policy.py +++ b/rllib/algorithms/impala/impala_tf_policy.py @@ -297,13 +297,18 @@ def __init__( existing_model=existing_model, ) - GradStatsMixin.__init__(self) - VTraceClipGradients.__init__(self) - VTraceOptimizer.__init__(self) - LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - self, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) + # If Learner API is used, we don't need any loss-specific mixins. + # However, we also would like to avoid creating special Policy-subclasses + # for this as the entire Policy concept will soon not be used anymore with + # the new Learner- and RLModule APIs. + if not self.config.get("_enable_learner_api"): + GradStatsMixin.__init__(self) + VTraceClipGradients.__init__(self) + VTraceOptimizer.__init__(self) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) # Note: this is a bit ugly, but loss and optimizer initialization must # happen after all the MixIns are initialized. diff --git a/rllib/algorithms/impala/impala_torch_policy.py b/rllib/algorithms/impala/impala_torch_policy.py index 73d4b3c7bd121..71aed03206015 100644 --- a/rllib/algorithms/impala/impala_torch_policy.py +++ b/rllib/algorithms/impala/impala_torch_policy.py @@ -201,13 +201,18 @@ def __init__(self, observation_space, action_space, config): ray.rllib.algorithms.impala.impala.ImpalaConfig().to_dict(), **config ) - VTraceOptimizer.__init__(self) - # Need to initialize learning rate variable before calling - # TorchPolicyV2.__init__. 
- LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - self, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) + # If Learner API is used, we don't need any loss-specific mixins. + # However, we also would like to avoid creating special Policy-subclasses + # for this as the entire Policy concept will soon not be used anymore with + # the new Learner- and RLModule APIs. + if not config.get("_enable_learner_api"): + VTraceOptimizer.__init__(self) + # Need to initialize learning rate variable before calling + # TorchPolicyV2.__init__. + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) TorchPolicyV2.__init__( self, @@ -217,7 +222,6 @@ def __init__(self, observation_space, action_space, config): max_seq_len=config["model"]["max_seq_len"], ) - # TODO: Don't require users to call this manually. self._initialize_loss_from_dummy_batch() @override(TorchPolicyV2) diff --git a/rllib/algorithms/impala/tests/test_impala_learner.py b/rllib/algorithms/impala/tests/test_impala_learner.py index 6adc936a45baa..8725434b97515 100644 --- a/rllib/algorithms/impala/tests/test_impala_learner.py +++ b/rllib/algorithms/impala/tests/test_impala_learner.py @@ -1,6 +1,7 @@ import unittest import numpy as np +import tree # pip install dm_tree import ray from ray.rllib.algorithms.impala import ImpalaConfig @@ -8,6 +9,7 @@ from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.framework import try_import_torch, try_import_tf from ray.rllib.utils.test_utils import framework_iterator +from ray.rllib.utils.torch_utils import convert_to_torch_tensor torch, nn = try_import_torch() tf1, tf, _ = try_import_tf() @@ -77,11 +79,17 @@ def test_impala_loss(self): # Deprecate the current default and set it to {}. 
config.exploration_config = {} - for _ in framework_iterator(config, frameworks=["tf2", "torch"]): + for fw in framework_iterator(config, frameworks=["tf2", "torch"]): algo = config.build() policy = algo.get_policy() - train_batch = SampleBatch(FAKE_BATCH) + if fw == "tf2": + train_batch = SampleBatch( + tree.map_structure(lambda x: tf.convert_to_tensor(x), FAKE_BATCH) + ) + elif fw == "torch": + train_batch = convert_to_torch_tensor(SampleBatch(FAKE_BATCH)) + algo_config = config.copy(copy_frozen=False) algo_config.validate() algo_config.freeze() diff --git a/rllib/algorithms/impala/tests/test_impala_off_policyness.py b/rllib/algorithms/impala/tests/test_impala_off_policyness.py index 09600ff3f046b..82a92916172f6 100644 --- a/rllib/algorithms/impala/tests/test_impala_off_policyness.py +++ b/rllib/algorithms/impala/tests/test_impala_off_policyness.py @@ -1,4 +1,3 @@ -import itertools import unittest import ray @@ -6,7 +5,6 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.test_utils import ( check_compute_single_action, - check_off_policyness, framework_iterator, ) @@ -28,43 +26,29 @@ def test_impala_off_policyness(self): .environment("CartPole-v1") .resources(num_gpus=0) .rollouts(num_rollout_workers=4) + .training(_enable_learner_api=True) + .rl_module(_enable_rl_module_api=True) ) num_iterations = 3 num_aggregation_workers_options = [0, 1] - enable_rlm_learner_group_options = [True, False] - - default_exploration_config = config.exploration_config.copy() - - for permutation in itertools.product( - num_aggregation_workers_options, enable_rlm_learner_group_options - ): - num_aggregation_workers, enable_learner_api = permutation - for fw in framework_iterator( - config, with_eager_tracing=True, frameworks=["tf2"] + for num_aggregation_workers in num_aggregation_workers_options: + for _ in framework_iterator( + config, frameworks=("tf2", "torch"), with_eager_tracing=True ): - # TODO(avnishn): Enable this for torch when we merge the torch learner. - if enable_learner_api and fw != "tf2": - continue - config.training(_enable_learner_api=enable_learner_api) - config.rl_module(_enable_rl_module_api=enable_learner_api) - if enable_learner_api: - # We have to set exploration_config here manually because setting - # it through config.exploration() only deepupdates it - config.exploration_config = {} - else: - config.exploration_config = default_exploration_config + # We have to set exploration_config here manually because setting + # it through config.exploration() only deepupdates it + config.exploration_config = {} config.num_aggregation_workers = num_aggregation_workers print("aggregation-workers={}".format(config.num_aggregation_workers)) algo = config.build() for i in range(num_iterations): - results = algo.train() + algo.train() # TODO (Avnish): Add off-policiness check when the metrics are - # added back to the IMPALA Learner - if not enable_learner_api: - off_policy_ness = check_off_policyness(results, upper_limit=2.0) - print(f"off-policy'ness={off_policy_ness}") + # added back to the IMPALA Learner. 
+ # off_policy_ness = check_off_policyness(results, upper_limit=2.0) + # print(f"off-policy'ness={off_policy_ness}") check_compute_single_action( algo, diff --git a/rllib/algorithms/impala/tf/impala_tf_learner.py b/rllib/algorithms/impala/tf/impala_tf_learner.py index fa2481b5bbd5b..fd04a6bf8d23c 100644 --- a/rllib/algorithms/impala/tf/impala_tf_learner.py +++ b/rllib/algorithms/impala/tf/impala_tf_learner.py @@ -1,11 +1,9 @@ -from typing import Any, Dict, Mapping +from typing import Mapping from ray.rllib.algorithms.impala.impala_learner import ImpalaLearner from ray.rllib.algorithms.impala.tf.vtrace_tf_v2 import make_time_major, vtrace_tf2 -from ray.rllib.algorithms.ppo.ppo_learner import LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY from ray.rllib.core.learner.learner import ENTROPY_KEY from ray.rllib.core.learner.tf.tf_learner import TfLearner -from ray.rllib.core.rl_module.rl_module import ModuleID from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_tf @@ -102,21 +100,3 @@ def compute_loss_per_module( "vf_loss": mean_vf_loss, ENTROPY_KEY: -mean_entropy_loss, } - - @override(ImpalaLearner) - def additional_update_per_module( - self, module_id: ModuleID, timestep: int - ) -> Dict[str, Any]: - results = super().additional_update_per_module( - module_id, - timestep=timestep, - ) - - # Update entropy coefficient. - value = self.hps.entropy_coeff - if self.hps.entropy_coeff_schedule is not None: - value = self.entropy_coeff_schedule_per_module[module_id].value(t=timestep) - self.curr_entropy_coeffs_per_module[module_id].assign(value) - results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: value}) - - return results diff --git a/rllib/algorithms/impala/tf/impala_tf_policy_rlm.py b/rllib/algorithms/impala/tf/impala_tf_policy_rlm.py deleted file mode 100644 index f24c6e88c9a13..0000000000000 --- a/rllib/algorithms/impala/tf/impala_tf_policy_rlm.py +++ /dev/null @@ -1,165 +0,0 @@ -import logging -from typing import Dict, List, Union - -from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_mixins import ( - EntropyCoeffSchedule, - LearningRateSchedule, -) -from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 -from ray.rllib.utils.annotations import override -from ray.rllib.utils.deprecation import Deprecated -from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.tf_utils import ( - explained_variance, -) -from ray.rllib.algorithms.impala.tf.vtrace_tf_v2 import make_time_major, vtrace_tf2 -from ray.rllib.utils.typing import TensorType - -tf1, tf, tfv = try_import_tf() - -logger = logging.getLogger(__name__) - - -class ImpalaTfPolicyWithRLModule( - LearningRateSchedule, - EntropyCoeffSchedule, - EagerTFPolicyV2, -): - def __init__(self, observation_space, action_space, config): - validate_config(config) - EagerTFPolicyV2.enable_eager_execution_if_necessary() - # Initialize MixIns before super().__init__ because base class will call - # self.loss, which requires these MixIns to be initialized. 
- LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - self, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - EagerTFPolicyV2.__init__(self, observation_space, action_space, config) - - self.maybe_initialize_optimizer_and_loss() - - @Deprecated(new="ImpalaTfLearner.compute_loss_per_module()", error=False) - @override(EagerTFPolicyV2) - def loss( - self, - model: PPOTfRLModule, - dist_class, - train_batch: SampleBatch, - ) -> Union[TensorType, List[TensorType]]: - seq_len = train_batch.get(SampleBatch.SEQ_LENS) - rollout_frag_or_episode_len = ( - self.config["rollout_fragment_length"] if not seq_len else None - ) - drop_last = self.config["vtrace_drop_last_ts"] - - fwd_out = model.forward_train(train_batch) - - values = fwd_out[SampleBatch.VF_PREDS] - target_policy_dist = fwd_out[SampleBatch.ACTION_DIST] - - # this is probably a horribly inefficient way to do this. I should be able to - # compute this in a batch fashion - behaviour_actions_logp = train_batch[SampleBatch.ACTION_LOGP] - target_actions_logp = target_policy_dist.logp(train_batch[SampleBatch.ACTIONS]) - behaviour_actions_logp_time_major = make_time_major( - behaviour_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ) - target_actions_logp_time_major = make_time_major( - target_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ) - values_time_major = make_time_major( - values, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ) - bootstrap_value = values_time_major[-1] - rewards_time_major = make_time_major( - train_batch[SampleBatch.REWARDS], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ) - - # how to compute discouts? - # should they be pre computed? - discounts_time_major = ( - 1.0 - - tf.cast( - make_time_major( - train_batch[SampleBatch.TERMINATEDS], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ), - dtype=tf.float32, - ) - ) * self.config["gamma"] - - # Note that vtrace will compute the main loop on the CPU for better performance. - vtrace_adjusted_target_values, pg_advantages = vtrace_tf2( - target_action_log_probs=target_actions_logp_time_major, - behaviour_action_log_probs=behaviour_actions_logp_time_major, - discounts=discounts_time_major, - rewards=rewards_time_major, - values=values_time_major, - bootstrap_value=bootstrap_value, - clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], - clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], - ) - - # The policy gradients loss. - pi_loss = -tf.reduce_sum(target_actions_logp_time_major * pg_advantages) - mean_pi_loss = -tf.reduce_mean(target_actions_logp_time_major * pg_advantages) - - # The baseline loss. - delta = values_time_major - vtrace_adjusted_target_values - vf_loss = 0.5 * tf.reduce_sum(tf.math.pow(delta, 2.0)) - mean_vf_loss = 0.5 * tf.reduce_mean(tf.math.pow(delta, 2.0)) - - # The entropy loss. - mean_entropy_loss = -tf.reduce_mean(target_policy_dist.entropy()) - - # The summed weighted loss. 
- total_loss = ( - pi_loss - + vf_loss * self.config["vf_loss_coeff"] - + mean_entropy_loss * self.entropy_coeff - ) - self.stats = { - "total_loss": total_loss, - "pi_loss": mean_pi_loss, - "vf_loss": mean_vf_loss, - "values": values_time_major, - "entropy_loss": mean_entropy_loss, - "vtrace_adjusted_target_values": vtrace_adjusted_target_values, - } - return total_loss - - @override(EagerTFPolicyV2) - def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: - return { - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "policy_loss": self.stats["pi_loss"], - "entropy": self.stats["entropy_loss"], - "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), - "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables), - "vf_loss": self.stats["vf_loss"], - "vf_explained_var": explained_variance( - tf.reshape(self.stats["vtrace_adjusted_target_values"], [-1]), - tf.reshape(self.stats["values"], [-1]), - ), - } - - @override(EagerTFPolicyV2) - def get_batch_divisibility_req(self) -> int: - return self.config["rollout_fragment_length"] diff --git a/rllib/algorithms/impala/torch/impala_torch_learner.py b/rllib/algorithms/impala/torch/impala_torch_learner.py index 907c2d4e32611..50dd9911823fb 100644 --- a/rllib/algorithms/impala/torch/impala_torch_learner.py +++ b/rllib/algorithms/impala/torch/impala_torch_learner.py @@ -1,14 +1,12 @@ -from typing import Any, Dict, Mapping +from typing import Mapping from ray.rllib.algorithms.impala.impala_learner import ImpalaLearner from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( vtrace_torch, make_time_major, ) -from ray.rllib.algorithms.ppo.ppo_learner import LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY from ray.rllib.core.learner.learner import ENTROPY_KEY from ray.rllib.core.learner.torch.torch_learner import TorchLearner -from ray.rllib.core.rl_module.rl_module import ModuleID from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import override from ray.rllib.utils.framework import try_import_torch @@ -118,21 +116,3 @@ def compute_loss_per_module( "vf_loss": mean_vf_loss, ENTROPY_KEY: -mean_entropy_loss, } - - @override(ImpalaLearner) - def additional_update_per_module( - self, module_id: ModuleID, timestep: int - ) -> Dict[str, Any]: - results = super().additional_update_per_module( - module_id, - timestep=timestep, - ) - - # Update entropy coefficient. 
- value = self.hps.entropy_coeff - if self.hps.entropy_coeff_schedule is not None: - value = self.entropy_coeff_schedule_per_module[module_id].value(t=timestep) - self.curr_entropy_coeffs_per_module[module_id].data = torch.tensor(value) - results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: value}) - - return results diff --git a/rllib/algorithms/impala/torch/impala_torch_policy_rlm.py b/rllib/algorithms/impala/torch/impala_torch_policy_rlm.py deleted file mode 100644 index 6e10a0a238395..0000000000000 --- a/rllib/algorithms/impala/torch/impala_torch_policy_rlm.py +++ /dev/null @@ -1,167 +0,0 @@ -import logging -from typing import Dict, List, Union - -from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import ( - make_time_major, - vtrace_torch, -) -from ray.rllib.algorithms.ppo.ppo_torch_policy import validate_config -from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.torch_utils import convert_to_torch_tensor -from ray.rllib.policy.torch_mixins import ( - EntropyCoeffSchedule, - LearningRateSchedule, -) -from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 -from ray.rllib.utils.annotations import override, Deprecated -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.torch_utils import ( - explained_variance, - global_norm, -) -from ray.rllib.utils.typing import TensorType - -torch, nn = try_import_torch() - -logger = logging.getLogger(__name__) - - -class ImpalaTorchPolicyWithRLModule( - LearningRateSchedule, - EntropyCoeffSchedule, - TorchPolicyV2, -): - def __init__(self, observation_space, action_space, config): - validate_config(config) - TorchPolicyV2.__init__(self, observation_space, action_space, config) - # Initialize MixIns. - LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - self, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - - # TODO: Don't require users to call this manually. - self._initialize_loss_from_dummy_batch() - - @Deprecated(new="ImpalaTorchLearner.compute_loss_per_module()", error=False) - @override(TorchPolicyV2) - def loss( - self, - model: PPOTorchRLModule, - dist_class, - train_batch: SampleBatch, - ) -> Union[TensorType, List[TensorType]]: - seq_len = train_batch.get(SampleBatch.SEQ_LENS) - rollout_frag_or_episode_len = ( - self.config["rollout_fragment_length"] if not seq_len else None - ) - drop_last = self.config["vtrace_drop_last_ts"] - - fwd_out = model.forward_train(train_batch) - - values = fwd_out[SampleBatch.VF_PREDS] - target_policy_dist = fwd_out[SampleBatch.ACTION_DIST] - - # this is probably a horribly inefficient way to do this. 
I should be able to - # compute this in a batch fashion - behaviour_actions_logp = train_batch[SampleBatch.ACTION_LOGP] - target_actions_logp = target_policy_dist.logp(train_batch[SampleBatch.ACTIONS]) - behaviour_actions_logp_time_major = make_time_major( - behaviour_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ) - target_actions_logp_time_major = make_time_major( - target_actions_logp, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ) - values_time_major = make_time_major( - values, - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ) - bootstrap_value = values_time_major[-1] - rewards_time_major = make_time_major( - train_batch[SampleBatch.REWARDS], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ) - - # how to compute discouts? - # should they be pre computed? - discounts_time_major = ( - 1.0 - - make_time_major( - train_batch[SampleBatch.TERMINATEDS], - trajectory_len=rollout_frag_or_episode_len, - recurrent_seq_len=seq_len, - drop_last=drop_last, - ).type(dtype=torch.float32) - ) * self.config["gamma"] - - # Note that vtrace will compute the main loop on the CPU for better performance. - vtrace_adjusted_target_values, pg_advantages = vtrace_torch( - target_action_log_probs=target_actions_logp_time_major, - behaviour_action_log_probs=behaviour_actions_logp_time_major, - discounts=discounts_time_major, - rewards=rewards_time_major, - values=values_time_major, - bootstrap_value=bootstrap_value, - clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"], - clip_rho_threshold=self.config["vtrace_clip_rho_threshold"], - ) - - # The policy gradients loss. - pi_loss = -torch.sum(target_actions_logp_time_major * pg_advantages) - mean_pi_loss = -torch.mean(target_actions_logp_time_major * pg_advantages) - - # The baseline loss. - delta = values_time_major - vtrace_adjusted_target_values - vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0)) - mean_vf_loss = 0.5 * torch.mean(torch.pow(delta, 2.0)) - - # The entropy loss. - mean_entropy_loss = -torch.mean(target_policy_dist.entropy()) - - # The summed weighted loss. 
- total_loss = ( - pi_loss - + vf_loss * self.config["vf_loss_coeff"] - + mean_entropy_loss * self.entropy_coeff - ) - self.stats = { - "total_loss": total_loss, - "pi_loss": mean_pi_loss, - "vf_loss": mean_vf_loss, - "values": values_time_major, - "entropy_loss": mean_entropy_loss, - "vtrace_adjusted_target_values": vtrace_adjusted_target_values, - } - return total_loss - - @override(TorchPolicyV2) - def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: - return { - "cur_lr": convert_to_torch_tensor(self.cur_lr).type(torch.float64), - "policy_loss": self.stats["pi_loss"], - "entropy": self.stats["entropy_loss"], - "entropy_coeff": convert_to_torch_tensor(self.entropy_coeff).type( - torch.float64 - ), - "var_gnorm": global_norm(self.model.parameters()), - "vf_loss": self.stats["vf_loss"], - "vf_explained_var": explained_variance( - torch.reshape(self.stats["vtrace_adjusted_target_values"], [-1]), - torch.reshape(self.stats["values"], [-1]), - ), - } - - @override(TorchPolicyV2) - def get_batch_divisibility_req(self) -> int: - return self.config["rollout_fragment_length"] diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 89481a8352572..05f531117901e 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -13,7 +13,6 @@ import logging from typing import List, Optional, Type, Union, TYPE_CHECKING -from ray.util.debug import log_once from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.pg import PGConfig @@ -25,21 +24,19 @@ from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec from ray.rllib.execution.rollout_ops import ( standardize_fields, + synchronous_parallel_sample, ) from ray.rllib.execution.train_ops import ( train_one_step, multi_gpu_train_one_step, ) -from ray.rllib.utils.annotations import ExperimentalAPI from ray.rllib.policy.policy import Policy -from ray.rllib.utils.annotations import override +from ray.rllib.utils.annotations import ExperimentalAPI, override from ray.rllib.utils.deprecation import ( DEPRECATED_VALUE, deprecation_warning, ) from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY -from ray.rllib.utils.typing import ResultDict -from ray.rllib.execution.rollout_ops import synchronous_parallel_sample from ray.rllib.utils.metrics import ( NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED, @@ -47,6 +44,9 @@ SAMPLE_TIMER, ALL_MODULES, ) +from ray.rllib.utils.schedules.scheduler import Scheduler +from ray.rllib.utils.typing import ResultDict +from ray.util.debug import log_once if TYPE_CHECKING: from ray.rllib.core.learner.learner import Learner @@ -325,6 +325,13 @@ def validate(self) -> None: # Check `entropy_coeff` for correctness. if self.entropy_coeff < 0.0: raise ValueError("`entropy_coeff` must be >= 0.0") + # Entropy coeff schedule checking. 
+ if self._enable_learner_api: + Scheduler.validate( + self.entropy_coeff_schedule, + "entropy_coeff_schedule", + "entropy coefficient", + ) class UpdateKL: @@ -371,32 +378,17 @@ def get_default_policy_class( ) -> Optional[Type[Policy]]: if config["framework"] == "torch": - if config._enable_rl_module_api: - from ray.rllib.algorithms.ppo.torch.ppo_torch_policy_rlm import ( - PPOTorchPolicyWithRLModule, - ) - - return PPOTorchPolicyWithRLModule - else: - from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy + from ray.rllib.algorithms.ppo.ppo_torch_policy import PPOTorchPolicy - return PPOTorchPolicy + return PPOTorchPolicy elif config["framework"] == "tf": from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy return PPOTF1Policy else: - if config._enable_rl_module_api: - from ray.rllib.algorithms.ppo.tf.ppo_tf_policy_rlm import ( - PPOTfPolicyWithRLModule, - ) - - return PPOTfPolicyWithRLModule - else: - - from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF2Policy + from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF2Policy - return PPOTF2Policy + return PPOTF2Policy @ExperimentalAPI def training_step(self) -> ResultDict: diff --git a/rllib/algorithms/ppo/ppo_learner.py b/rllib/algorithms/ppo/ppo_learner.py index 576471b285edd..be16bbb531121 100644 --- a/rllib/algorithms/ppo/ppo_learner.py +++ b/rllib/algorithms/ppo/ppo_learner.py @@ -1,11 +1,12 @@ from collections import defaultdict from dataclasses import dataclass -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union from ray.rllib.core.learner.learner import LearnerHyperparameters from ray.rllib.core.learner.learner import Learner +from ray.rllib.core.rl_module.rl_module import ModuleID from ray.rllib.utils.annotations import override -from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule +from ray.rllib.utils.schedules.scheduler import Scheduler LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY = "vf_loss_unclipped" @@ -34,10 +35,6 @@ class to configure your algorithm. entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = None vf_loss_coeff: float = None - # TODO: Move to base LearnerHyperparameter class (and handling of this setting - # into base Learners). - lr_schedule: Optional[List[List[Union[int, float]]]] = None - class PPOLearner(Learner): @override(Learner) @@ -45,25 +42,12 @@ def build(self) -> None: super().build() # Build entropy coeff scheduling tools. - self.entropy_coeff_scheduler = None - if self.hps.entropy_coeff_schedule: - # Custom schedule, based on list of - # ([ts], [value to be reached by ts])-tuples. - self.entropy_coeff_schedule_per_module = defaultdict( - lambda: PiecewiseSchedule( - self.hps.entropy_coeff_schedule, - outside_value=self.hps.entropy_coeff_schedule[-1][-1], - framework=None, - ) - ) - self.curr_entropy_coeffs_per_module = defaultdict( - lambda: self._get_tensor_variable(self.hps.entropy_coeff) - ) - # If no schedule, pin entropy coeff to its given (fixed) value. - else: - self.curr_entropy_coeffs_per_module = defaultdict( - lambda: self.hps.entropy_coeff - ) + self.entropy_coeff_scheduler = Scheduler( + fixed_value=self.hps.entropy_coeff, + schedule=self.hps.entropy_coeff_schedule, + framework=self.framework, + device=self._device, + ) # Set up KL coefficient variables (per module). 
# Note that the KL coeff is not controlled by a schedul, but seeks @@ -71,3 +55,21 @@ def build(self) -> None: self.curr_kl_coeffs_per_module = defaultdict( lambda: self._get_tensor_variable(self.hps.kl_coeff) ) + + @override(Learner) + def additional_update_per_module( + self, module_id: ModuleID, sampled_kl_values: dict, timestep: int + ) -> Dict[str, Any]: + results = super().additional_update_per_module( + module_id, + sampled_kl_values=sampled_kl_values, + timestep=timestep, + ) + + # Update entropy coefficient via our Scheduler. + new_entropy_coeff = self.entropy_coeff_scheduler.update( + module_id, timestep=timestep + ) + results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: new_entropy_coeff}) + + return results diff --git a/rllib/algorithms/ppo/ppo_tf_policy.py b/rllib/algorithms/ppo/ppo_tf_policy.py index a00f8c037eb6e..76e8d0161689a 100644 --- a/rllib/algorithms/ppo/ppo_tf_policy.py +++ b/rllib/algorithms/ppo/ppo_tf_policy.py @@ -89,11 +89,11 @@ def __init__( # Initialize MixIns. ValueNetworkMixin.__init__(self, config) - KLCoeffMixin.__init__(self, config) EntropyCoeffSchedule.__init__( self, config["entropy_coeff"], config["entropy_coeff_schedule"] ) LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + KLCoeffMixin.__init__(self, config) # Note: this is a bit ugly, but loss and optimizer initialization must # happen after all the MixIns are initialized. diff --git a/rllib/algorithms/ppo/ppo_torch_policy.py b/rllib/algorithms/ppo/ppo_torch_policy.py index df45eefeb14c8..26a52dbe4d2b1 100644 --- a/rllib/algorithms/ppo/ppo_torch_policy.py +++ b/rllib/algorithms/ppo/ppo_torch_policy.py @@ -44,8 +44,6 @@ class PPOTorchPolicy( def __init__(self, observation_space, action_space, config): config = dict(ray.rllib.algorithms.ppo.ppo.PPOConfig().to_dict(), **config) - # TODO: Move into Policy API, if needed at all here. Why not move this into - # `PPOConfig`?. validate_config(config) TorchPolicyV2.__init__( @@ -63,7 +61,6 @@ def __init__(self, observation_space, action_space, config): ) KLCoeffMixin.__init__(self, config) - # TODO: Don't require users to call this manually. self._initialize_loss_from_dummy_batch() @override(TorchPolicyV2) diff --git a/rllib/algorithms/ppo/tests/test_ppo_learner.py b/rllib/algorithms/ppo/tests/test_ppo_learner.py index ceb726e7deac1..e16ea35641e68 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_learner.py +++ b/rllib/algorithms/ppo/tests/test_ppo_learner.py @@ -67,6 +67,7 @@ def test_loss(self): fcnet_activation="linear", vf_share_layers=False, ), + _enable_learner_api=True, ) .rl_module( _enable_rl_module_api=True, @@ -86,13 +87,11 @@ def test_loss(self): lambda x: torch.as_tensor(x).float(), train_batch ) else: - # tf train_batch = tree.map_structure( lambda x: tf.convert_to_tensor(x), train_batch ) algo_config = config.copy(copy_frozen=False) - algo_config.training(_enable_learner_api=True) algo_config.validate() algo_config.freeze() diff --git a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py index 3f2fc1d007b1c..c700ff7ab16e0 100644 --- a/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py +++ b/rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py @@ -126,14 +126,22 @@ def test_ppo_compilation_and_schedule_mixins(self): config.training(model=get_model_config(fw, lstm=lstm)) algo = config.build(env=env) + # TODO: Maybe add an API to get the Learner(s) instances within + # a learner group, remote or not. 
+ learner = algo.learner_group._learner optim = algo.learner_group._learner._named_optimizers[ DEFAULT_POLICY_ID ] - entropy_coeff = algo.get_policy().entropy_coeff + # Check initial LR directly set in optimizer vs the first (ts=0) + # value from the schedule. lr = optim.param_groups[0]["lr"] if fw == "torch" else optim.lr + check(lr, config.lr_schedule[0][1]) + + # Check current entropy coeff value using the respective Scheduler. + entropy_coeff = learner.entropy_coeff_scheduler.get_current_value( + DEFAULT_POLICY_ID + ) check(entropy_coeff, 0.1) - # Check initial LR directly set in optimizer. - check(lr, config.lr) for i in range(num_iterations): results = algo.train() @@ -159,6 +167,7 @@ def test_ppo_exploration_setup(self): enable_connectors=True, ) .rl_module(_enable_rl_module_api=True) + .training(_enable_learner_api=True) ) obs = np.array(0) @@ -166,14 +175,14 @@ def test_ppo_exploration_setup(self): config, frameworks=("torch", "tf2"), with_eager_tracing=True ): # Default Agent should be setup with StochasticSampling. - trainer = config.build() + algo = config.build() # explore=False, always expect the same (deterministic) action. - a_ = trainer.compute_single_action( + a_ = algo.compute_single_action( obs, explore=False, prev_action=np.array(2), prev_reward=np.array(1.0) ) for _ in range(50): - a = trainer.compute_single_action( + a = algo.compute_single_action( obs, explore=False, prev_action=np.array(2), @@ -185,12 +194,12 @@ def test_ppo_exploration_setup(self): actions = [] for _ in range(300): actions.append( - trainer.compute_single_action( + algo.compute_single_action( obs, prev_action=np.array(2), prev_reward=np.array(1.0) ) ) check(np.mean(actions), 1.5, atol=0.2) - trainer.stop() + algo.stop() def test_ppo_free_log_std_with_rl_modules(self): """Tests the free log std option works.""" @@ -216,25 +225,22 @@ def test_ppo_free_log_std_with_rl_modules(self): ) for fw in framework_iterator(config, frameworks=("torch", "tf2")): - trainer = config.build() - policy = trainer.get_policy() + algo = config.build() + policy = algo.get_policy() + learner = algo.learner_group._learner + module = learner.module[DEFAULT_POLICY_ID] # Check the free log std var is created. if fw == "torch": - matching = [ - v for (n, v) in policy.model.named_parameters() if "log_std" in n - ] + matching = [v for (n, v) in module.named_parameters() if "log_std" in n] else: matching = [ - v for v in policy.model.trainable_variables if "log_std" in str(v) + v for v in module.trainable_variables if "log_std" in str(v) ] assert len(matching) == 1, matching log_std_var = matching[0] - # linter yells at you if you don't pass in the parameters. - # reason: https://docs.python-guide.org/writing/gotchas/ - # #late-binding-closures - def get_value(fw=fw, policy=policy, log_std_var=log_std_var): + def get_value(): if fw == "torch": return log_std_var.detach().cpu().numpy()[0] else: @@ -244,14 +250,13 @@ def get_value(fw=fw, policy=policy, log_std_var=log_std_var): init_std = get_value() assert init_std == 0.0, init_std batch = compute_gae_for_sample_batch(policy, PENDULUM_FAKE_BATCH.copy()) - if fw == "torch": - batch = policy._lazy_tensor_dict(batch) - policy.learn_on_batch(batch) + batch = policy._lazy_tensor_dict(batch) + algo.learner_group.update(batch.as_multi_agent()) # Check the variable is updated. 
post_std = get_value() assert post_std != 0.0, post_std - trainer.stop() + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/ppo/tf/ppo_tf_learner.py b/rllib/algorithms/ppo/tf/ppo_tf_learner.py index e65f658688252..6ab8960a326a5 100644 --- a/rllib/algorithms/ppo/tf/ppo_tf_learner.py +++ b/rllib/algorithms/ppo/tf/ppo_tf_learner.py @@ -3,7 +3,6 @@ from ray.rllib.algorithms.ppo.ppo_learner import ( LEARNER_RESULTS_KL_KEY, - LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY, LEARNER_RESULTS_CURR_KL_COEFF_KEY, LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY, LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY, @@ -97,7 +96,8 @@ def compute_loss_per_module( total_loss = tf.reduce_mean( -surrogate_loss + self.hps.vf_loss_coeff * vf_loss_clipped - - self.curr_entropy_coeffs_per_module[module_id] * curr_entropy + - self.entropy_coeff_scheduler.get_current_value(module_id) * curr_entropy + # - self.curr_entropy_coeffs_per_module[module_id] * curr_entropy ) # Add mean_kl_loss (already processed through `reduce_mean_valid`), @@ -139,11 +139,4 @@ def additional_update_per_module( curr_var.assign(curr_var * 0.5) results.update({LEARNER_RESULTS_CURR_KL_COEFF_KEY: curr_var.numpy()}) - # Update entropy coefficient. - value = self.hps.entropy_coeff - if self.hps.entropy_coeff_schedule is not None: - value = self.entropy_coeff_schedule_per_module[module_id].value(t=timestep) - self.curr_entropy_coeffs_per_module[module_id].assign(value) - results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: value}) - return results diff --git a/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py b/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py deleted file mode 100644 index c99a17ad840a9..0000000000000 --- a/rllib/algorithms/ppo/tf/ppo_tf_policy_rlm.py +++ /dev/null @@ -1,185 +0,0 @@ -import logging -from typing import Dict, List, Union - -from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config -from ray.rllib.evaluation.postprocessing import ( - Postprocessing, - compute_gae_for_sample_batch, -) -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.tf_mixins import ( - EntropyCoeffSchedule, - KLCoeffMixin, - LearningRateSchedule, -) -from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.nested_dict import NestedDict -from ray.rllib.utils.tf_utils import ( - explained_variance, - warn_if_infinite_kl_divergence, -) - -from ray.rllib.utils.typing import TensorType - -tf1, tf, tfv = try_import_tf() - -logger = logging.getLogger(__name__) - - -class PPOTfPolicyWithRLModule( - LearningRateSchedule, - EntropyCoeffSchedule, - KLCoeffMixin, - EagerTFPolicyV2, -): - """PyTorch policy class used with PPO. - - This class is copied from PPOTFPolicy and is modified to support RLModules. - Some subtle differences: - - if config._enable_rl_module api is true make_rl_module should be implemented by - the policy the policy is assumed to be compatible with rl_modules (i.e. self.model - would be an RLModule) - - Tower stats no longer belongs to the model (i.e. RLModule) instead it belongs to - the policy itself. - - Connectors should be enabled to use this policy - - So far it only works for vectorized obs and action spaces (Fully connected neural - networks). we need model catalog to work for other obs and action spaces. - - # TODO: In the future we will deprecate doing all phases of training, exploration, - # and inference via one policy abstraction. 
Instead, we will use separate - # abstractions for each phase. For training (i.e. gradient updates, given the - # sample that have been collected) we will use Learner which will own one or - # possibly many RLModules, and RLOptimizer. For exploration, we will use RLSampler - # which will own RLModule, and RLTrajectoryProcessor. The exploration and inference - # phase details are TBD but the whole point is to make rllib extremely modular. - """ - - def __init__(self, observation_space, action_space, config): - # TODO: Move into Policy API, if needed at all here. Why not move this into - # `PPOConfig`?. - self.framework = "tf2" - EagerTFPolicyV2.enable_eager_execution_if_necessary() - validate_config(config) - # Initialize MixIns. - LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - self, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - KLCoeffMixin.__init__(self, config) - EagerTFPolicyV2.__init__(self, observation_space, action_space, config) - - self.maybe_initialize_optimizer_and_loss() - - @override(EagerTFPolicyV2) - def loss( - self, - model: Union[ModelV2, "tf.keras.Model"], - dist_class, - train_batch: SampleBatch, - ) -> Union[TensorType, List[TensorType]]: - - if not isinstance(train_batch, NestedDict): - train_batch = NestedDict(train_batch) - fwd_out = model.forward_train(train_batch) - curr_action_dist = fwd_out[SampleBatch.ACTION_DIST] - - action_dist_class = type(fwd_out[SampleBatch.ACTION_DIST]) - prev_action_dist = action_dist_class.from_logits( - train_batch[SampleBatch.ACTION_DIST_INPUTS] - ) - - logp_ratio = tf.exp( - fwd_out[SampleBatch.ACTION_LOGP] - train_batch[SampleBatch.ACTION_LOGP] - ) - - # Only calculate kl loss if necessary (kl-coeff > 0.0). - if self.config["kl_coeff"] > 0.0: - action_kl = prev_action_dist.kl(curr_action_dist) - mean_kl_loss = tf.reduce_mean(action_kl) - warn_if_infinite_kl_divergence(self, mean_kl_loss) - else: - mean_kl_loss = tf.constant(0.0) - - curr_entropy = fwd_out["entropy"] - mean_entropy = tf.reduce_mean(curr_entropy) - - surrogate_loss = tf.minimum( - train_batch[Postprocessing.ADVANTAGES] * logp_ratio, - train_batch[Postprocessing.ADVANTAGES] - * tf.clip_by_value( - logp_ratio, - 1 - self.config["clip_param"], - 1 + self.config["clip_param"], - ), - ) - - # Compute a value function loss. - if self.config["use_critic"]: - value_fn_out = fwd_out[SampleBatch.VF_PREDS] - vf_loss = tf.math.square( - value_fn_out - train_batch[Postprocessing.VALUE_TARGETS] - ) - vf_loss_clipped = tf.clip_by_value( - vf_loss, - 0, - self.config["vf_clip_param"], - ) - mean_vf_loss = tf.reduce_mean(vf_loss_clipped) - mean_vf_unclipped_loss = tf.reduce_mean(vf_loss) - # Ignore the value function. - else: - mean_vf_unclipped_loss = tf.constant(0.0) - value_fn_out = vf_loss_clipped = mean_vf_loss = tf.constant(0.0) - - total_loss = tf.reduce_mean( - -surrogate_loss - + self.config["vf_loss_coeff"] * vf_loss_clipped - - self.entropy_coeff * curr_entropy - ) - # Add mean_kl_loss (already processed through `reduce_mean_valid`), - # if necessary. - if self.config["kl_coeff"] > 0.0: - total_loss += self.kl_coeff * mean_kl_loss - - # Store stats in policy for stats_fn. - self._total_loss = total_loss - self._mean_policy_loss = tf.reduce_mean(-surrogate_loss) - self._mean_vf_loss = mean_vf_loss - self._unclipped_mean_vf_loss = mean_vf_unclipped_loss - self._mean_entropy = mean_entropy - # Backward compatibility: Deprecate self._mean_kl. 
- self._mean_kl_loss = self._mean_kl = mean_kl_loss - self._value_fn_out = value_fn_out - self._value_mean = tf.reduce_mean(value_fn_out) - - return total_loss - - @override(EagerTFPolicyV2) - def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: - return { - "cur_kl_coeff": tf.cast(self.kl_coeff, tf.float64), - "cur_lr": tf.cast(self.cur_lr, tf.float64), - "total_loss": self._total_loss, - "policy_loss": self._mean_policy_loss, - "vf_loss": self._mean_vf_loss, - "unclipped_vf_loss": self._unclipped_mean_vf_loss, - "vf_explained_var": explained_variance( - train_batch[Postprocessing.VALUE_TARGETS], self._value_fn_out - ), - "kl": self._mean_kl_loss, - "entropy": self._mean_entropy, - "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), - "value_mean": tf.cast(self._value_mean, tf.float64), - } - - @override(EagerTFPolicyV2) - def postprocess_trajectory( - self, sample_batch, other_agent_batches=None, episode=None - ): - sample_batch = super().postprocess_trajectory(sample_batch) - return compute_gae_for_sample_batch( - self, sample_batch, other_agent_batches, episode - ) diff --git a/rllib/algorithms/ppo/torch/ppo_torch_learner.py b/rllib/algorithms/ppo/torch/ppo_torch_learner.py index 675539f50f5c8..aa32a956d9780 100644 --- a/rllib/algorithms/ppo/torch/ppo_torch_learner.py +++ b/rllib/algorithms/ppo/torch/ppo_torch_learner.py @@ -3,7 +3,6 @@ from ray.rllib.algorithms.ppo.ppo_learner import ( LEARNER_RESULTS_KL_KEY, - LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY, LEARNER_RESULTS_CURR_KL_COEFF_KEY, LEARNER_RESULTS_VF_EXPLAINED_VAR_KEY, LEARNER_RESULTS_VF_LOSS_UNCLIPPED_KEY, @@ -93,7 +92,7 @@ def compute_loss_per_module( total_loss = torch.mean( -surrogate_loss + self.hps.vf_loss_coeff * vf_loss_clipped - - self.curr_entropy_coeffs_per_module[module_id] * curr_entropy + - self.entropy_coeff_scheduler.get_current_value(module_id) * curr_entropy ) # Add mean_kl_loss (already processed through `reduce_mean_valid`), @@ -135,11 +134,4 @@ def additional_update_per_module( curr_var.data *= 0.5 results.update({LEARNER_RESULTS_CURR_KL_COEFF_KEY: curr_var.item()}) - # Update entropy coefficient. 
- value = self.hps.entropy_coeff - if self.hps.entropy_coeff_schedule is not None: - value = self.entropy_coeff_schedule_per_module[module_id].value(t=timestep) - self.curr_entropy_coeffs_per_module[module_id].data = torch.tensor(value) - results.update({LEARNER_RESULTS_CURR_ENTROPY_COEFF_KEY: value}) - return results diff --git a/rllib/algorithms/ppo/torch/ppo_torch_policy_rlm.py b/rllib/algorithms/ppo/torch/ppo_torch_policy_rlm.py deleted file mode 100644 index 04ac92fc2ba5b..0000000000000 --- a/rllib/algorithms/ppo/torch/ppo_torch_policy_rlm.py +++ /dev/null @@ -1,246 +0,0 @@ -import logging -from typing import Dict, List, Type, Union - -from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config -from ray.rllib.evaluation.postprocessing import ( - Postprocessing, - compute_gae_for_sample_batch, -) -from ray.rllib.models.action_dist import ActionDistribution -from ray.rllib.models.modelv2 import ModelV2 -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.policy.torch_mixins import ( - EntropyCoeffSchedule, - KLCoeffMixin, - LearningRateSchedule, -) -from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 -from ray.rllib.utils.annotations import override -from ray.rllib.utils.framework import try_import_torch -from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.torch_utils import ( - apply_grad_clipping, - explained_variance, - sequence_mask, - warn_if_infinite_kl_divergence, -) -from ray.rllib.utils.typing import TensorType - -torch, nn = try_import_torch() - -logger = logging.getLogger(__name__) - - -class PPOTorchPolicyWithRLModule( - LearningRateSchedule, - EntropyCoeffSchedule, - KLCoeffMixin, - TorchPolicyV2, -): - """PyTorch policy class used with PPO. - - This class is copied from PPOTorchPolicyV2 and is modified to support RLModules. - Some subtle differences: - - if config._enable_rl_module api is true make_rl_module should be implemented by - the policy the policy is assumed to be compatible with rl_modules (i.e. self.model - would be an RLModule) - - Tower stats no longer belongs to the model (i.e. RLModule) instead it belongs to - the policy itself. - - Connectors should be enabled to use this policy - - So far it only works for vectorized obs and action spaces (Fully connected neural - networks). we need model catalog to work for other obs and action spaces. - - # TODO: In the future we will deprecate doing all phases of training, exploration, - # and inference via one policy abstraction. Instead, we will use separate - # abstractions for each phase. For training (i.e. gradient updates, given the - # sample that have been collected) we will use Learner which will own one or - # possibly many RLModules, and RLOptimizer. For exploration, we will use RLSampler - # which will own RLModule, and RLTrajectoryProcessor. The exploration and inference - # phase details are TBD but the whole point is to make rllib extremely modular. - """ - - def __init__(self, observation_space, action_space, config): - # TODO: Move into Policy API, if needed at all here. Why not move this into - # `PPOConfig`?. - validate_config(config) - - TorchPolicyV2.__init__( - self, - observation_space, - action_space, - config, - max_seq_len=config["model"]["max_seq_len"], - ) - - LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) - EntropyCoeffSchedule.__init__( - self, config["entropy_coeff"], config["entropy_coeff_schedule"] - ) - KLCoeffMixin.__init__(self, config) - - # TODO: Don't require users to call this manually. 
- self._initialize_loss_from_dummy_batch() - - @override(TorchPolicyV2) - def loss( - self, - model: ModelV2, - dist_class: Type[ActionDistribution], - train_batch: SampleBatch, - ) -> Union[TensorType, List[TensorType]]: - """Compute loss for Proximal Policy Objective. - - Args: - model: The Model to calculate the loss for. - dist_class: The action distr. class. - train_batch: The training data. - - Returns: - The PPO loss tensor given the input batch. - """ - - fwd_out = model.forward_train(train_batch) - curr_action_dist = fwd_out[SampleBatch.ACTION_DIST] - state = fwd_out.get("state_out", {}) - - # TODO (Kourosh): come back to RNNs later - # RNN case: Mask away 0-padded chunks at end of time axis. - if state: - B = len(train_batch[SampleBatch.SEQ_LENS]) - max_seq_len = train_batch[SampleBatch.OBS].shape[0] // B - mask = sequence_mask( - train_batch[SampleBatch.SEQ_LENS], - max_seq_len, - time_major=self.config["model"]["_time_major"], - ) - mask = torch.reshape(mask, [-1]) - num_valid = torch.sum(mask) - - def reduce_mean_valid(t): - return torch.sum(t[mask]) / num_valid - - # non-RNN case: No masking. - else: - mask = None - reduce_mean_valid = torch.mean - - action_dist_class = type(fwd_out[SampleBatch.ACTION_DIST]) - prev_action_dist = action_dist_class.from_logits( - train_batch[SampleBatch.ACTION_DIST_INPUTS] - ) - - logp_ratio = torch.exp( - fwd_out[SampleBatch.ACTION_LOGP] - train_batch[SampleBatch.ACTION_LOGP] - ) - - # Only calculate kl loss if necessary (kl-coeff > 0.0). - if self.config["kl_coeff"] > 0.0: - action_kl = prev_action_dist.kl(curr_action_dist) - mean_kl_loss = reduce_mean_valid(action_kl) - # TODO smorad: should we do anything besides warn? Could discard KL term - # for this update - warn_if_infinite_kl_divergence(self, mean_kl_loss) - else: - mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device) - - curr_entropy = fwd_out["entropy"] - mean_entropy = reduce_mean_valid(curr_entropy) - - surrogate_loss = torch.min( - train_batch[Postprocessing.ADVANTAGES] * logp_ratio, - train_batch[Postprocessing.ADVANTAGES] - * torch.clamp( - logp_ratio, 1 - self.config["clip_param"], 1 + self.config["clip_param"] - ), - ) - - # Compute a value function loss. - if self.config["use_critic"]: - value_fn_out = fwd_out[SampleBatch.VF_PREDS] - vf_loss = torch.pow( - value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0 - ) - vf_loss_clipped = torch.clamp(vf_loss, 0, self.config["vf_clip_param"]) - mean_vf_loss = reduce_mean_valid(vf_loss_clipped) - mean_vf_unclipped_loss = reduce_mean_valid(vf_loss) - # Ignore the value function. - else: - value_fn_out = torch.tensor(0.0).to(surrogate_loss.device) - mean_vf_unclipped_loss = vf_loss_clipped = mean_vf_loss = torch.tensor( - 0.0 - ).to(surrogate_loss.device) - - total_loss = reduce_mean_valid( - -surrogate_loss - + self.config["vf_loss_coeff"] * vf_loss_clipped - - self.entropy_coeff * curr_entropy - ) - - # Add mean_kl_loss (already processed through `reduce_mean_valid`), - # if necessary. - if self.config["kl_coeff"] > 0.0: - total_loss += self.kl_coeff * mean_kl_loss - - # TODO (Kourosh) Where would tower_stats go? How should stats_fn be implemented - # here? - # Store values for stats function in model (tower), such that for - # multi-GPU, we do not override them during the parallel loss phase. 
- self.tower_stats[model]["total_loss"] = total_loss - self.tower_stats[model]["mean_policy_loss"] = reduce_mean_valid(-surrogate_loss) - self.tower_stats[model]["mean_vf_loss"] = mean_vf_loss - self.tower_stats[model]["unclipped_vf_loss"] = mean_vf_unclipped_loss - self.tower_stats[model]["vf_explained_var"] = explained_variance( - train_batch[Postprocessing.VALUE_TARGETS], value_fn_out - ) - self.tower_stats[model]["mean_entropy"] = mean_entropy - self.tower_stats[model]["mean_kl_loss"] = mean_kl_loss - - return total_loss - - # TODO: Make this an event-style subscription (e.g.: - # "after_gradients_computed"). - @override(TorchPolicyV2) - def extra_grad_process(self, local_optimizer, loss): - return apply_grad_clipping(self, local_optimizer, loss) - - @override(TorchPolicyV2) - def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: - return convert_to_numpy( - { - "cur_kl_coeff": self.kl_coeff, - "cur_lr": self.cur_lr, - "total_loss": torch.mean( - torch.stack(self.get_tower_stats("total_loss")) - ), - "policy_loss": torch.mean( - torch.stack(self.get_tower_stats("mean_policy_loss")) - ), - "vf_loss": torch.mean( - torch.stack(self.get_tower_stats("mean_vf_loss")) - ), - "vf_explained_var": torch.mean( - torch.stack(self.get_tower_stats("vf_explained_var")) - ), - "kl": torch.mean(torch.stack(self.get_tower_stats("mean_kl_loss"))), - "entropy": torch.mean( - torch.stack(self.get_tower_stats("mean_entropy")) - ), - "entropy_coeff": self.entropy_coeff, - "unclipped_vf_loss": torch.mean( - torch.stack(self.get_tower_stats("unclipped_vf_loss")) - ), - } - ) - - @override(TorchPolicyV2) - def postprocess_trajectory( - self, sample_batch, other_agent_batches=None, episode=None - ): - # Do all post-processing always with no_grad(). - # Not using this here will introduce a memory leak - # in torch (issue #6962). - # TODO: no_grad still necessary? - with torch.no_grad(): - return compute_gae_for_sample_batch( - self, sample_batch, other_agent_batches, episode - ) diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py index cc949c45f0e5a..7b6804ea92cc6 100644 --- a/rllib/algorithms/tests/test_algorithm_config.py +++ b/rllib/algorithms/tests/test_algorithm_config.py @@ -179,6 +179,7 @@ def test_rl_module_api(self): .framework("torch") .rollouts(enable_connectors=True) .rl_module(_enable_rl_module_api=True) + .training(_enable_learner_api=True) ) config.validate() @@ -328,7 +329,11 @@ def get_default_rl_module_spec(self): ######################################## # This is the simplest case where we have to construct the marl module based on # the default specs only. - config = SingleAgentAlgoConfig().rl_module(_enable_rl_module_api=True) + config = ( + SingleAgentAlgoConfig() + .rl_module(_enable_rl_module_api=True) + .training(_enable_learner_api=True) + ) config.validate() spec, expected = self._get_expected_marl_spec(config, DiscreteBCTorchModule) @@ -343,14 +348,18 @@ def get_default_rl_module_spec(self): ######################################## # This is the case where we pass in a multi-agent RLModuleSpec that asks the # algorithm to assign a specific type of RLModule class to certain module_ids. 
- config = SingleAgentAlgoConfig().rl_module( - _enable_rl_module_api=True, - rl_module_spec=MultiAgentRLModuleSpec( - module_specs={ - "p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1), - "p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1), - } - ), + config = ( + SingleAgentAlgoConfig() + .rl_module( + _enable_rl_module_api=True, + rl_module_spec=MultiAgentRLModuleSpec( + module_specs={ + "p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1), + "p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1), + }, + ), + ) + .training(_enable_learner_api=True) ) config.validate() @@ -360,9 +369,13 @@ def get_default_rl_module_spec(self): ######################################## # This is the case where we ask the algorithm to assign a specific type of # RLModule class to ALL module_ids. - config = SingleAgentAlgoConfig().rl_module( - _enable_rl_module_api=True, - rl_module_spec=SingleAgentRLModuleSpec(module_class=CustomRLModule1), + config = ( + SingleAgentAlgoConfig() + .rl_module( + _enable_rl_module_api=True, + rl_module_spec=SingleAgentRLModuleSpec(module_class=CustomRLModule1), + ) + .training(_enable_learner_api=True) ) config.validate() @@ -377,11 +390,15 @@ def get_default_rl_module_spec(self): ######################################## # This is an alternative way to ask the algorithm to assign a specific type of # RLModule class to ALL module_ids. - config = SingleAgentAlgoConfig().rl_module( - _enable_rl_module_api=True, - rl_module_spec=MultiAgentRLModuleSpec( - module_specs=SingleAgentRLModuleSpec(module_class=CustomRLModule1) - ), + config = ( + SingleAgentAlgoConfig() + .rl_module( + _enable_rl_module_api=True, + rl_module_spec=MultiAgentRLModuleSpec( + module_specs=SingleAgentRLModuleSpec(module_class=CustomRLModule1) + ), + ) + .training(_enable_learner_api=True) ) config.validate() @@ -398,15 +415,19 @@ def get_default_rl_module_spec(self): # This is not only assigning a specific type of RLModule class to EACH # module_id, but also defining a new custom MultiAgentRLModule class to be used # in the multi-agent scenario. - config = SingleAgentAlgoConfig().rl_module( - _enable_rl_module_api=True, - rl_module_spec=MultiAgentRLModuleSpec( - marl_module_class=CustomMARLModule1, - module_specs={ - "p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1), - "p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1), - }, - ), + config = ( + SingleAgentAlgoConfig() + .rl_module( + _enable_rl_module_api=True, + rl_module_spec=MultiAgentRLModuleSpec( + marl_module_class=CustomMARLModule1, + module_specs={ + "p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1), + "p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1), + }, + ), + ) + .training(_enable_learner_api=True) ) config.validate() @@ -435,8 +456,10 @@ def get_default_rl_module_spec(self): # This is the case where we ask the algorithm to use its default # MultiAgentRLModuleSpec, but the MultiAgentRLModuleSpec has not defined its # SingleAgentRLmoduleSpecs. - config = MultiAgentAlgoConfigWithNoSingleAgentSpec().rl_module( - _enable_rl_module_api=True + config = ( + MultiAgentAlgoConfigWithNoSingleAgentSpec() + .rl_module(_enable_rl_module_api=True) + .training(_enable_learner_api=True) ) self.assertRaisesRegex( @@ -449,7 +472,11 @@ def get_default_rl_module_spec(self): # This is the case where we ask the algorithm to use its default # MultiAgentRLModuleSpec, and the MultiAgentRLModuleSpec has defined its # SingleAgentRLmoduleSpecs. 
- config = MultiAgentAlgoConfig().rl_module(_enable_rl_module_api=True) + config = ( + MultiAgentAlgoConfig() + .rl_module(_enable_rl_module_api=True) + .training(_enable_learner_api=True) + ) config.validate() spec, expected = self._get_expected_marl_spec( diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index 688d1574d0342..ef2e3fc012174 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -47,7 +47,7 @@ OverrideToImplementCustomLogic, OverrideToImplementCustomLogic_CallToSuperRecommended, ) -from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule +from ray.rllib.utils.schedules.scheduler import Scheduler torch, _ = try_import_torch() tf1, tf, tfv = try_import_tf() @@ -245,6 +245,7 @@ def __init__( self._module_obj = module self._optimizer_config = optimizer_config self._hps = learner_hyperparameters or LearnerHyperparameters() + self._device = None # pick the configs that we need for the learner from scaling config self._learner_group_scaling_config = ( @@ -622,24 +623,15 @@ def build(self) -> None: return self._is_built = True - # Generic LR scheduling tools. - self.lr_scheduler = None - if self.hps.lr_schedule is not None: - # Custom schedule, based on list of - # ([ts], [value to be reached by ts])-tuples. - self.lr_schedule_per_module = defaultdict( - lambda: PiecewiseSchedule( - self.hps.lr_schedule, - outside_value=self.hps.lr_schedule[-1][-1], - framework=None, - ) - ) - self.curr_lr_per_module = defaultdict( - lambda: self._get_tensor_variable(self._optimizer_config["lr"]) - ) - # If no schedule, pin learning rate to its given (fixed) value. - else: - self.curr_lr_per_module = defaultdict(lambda: self._optimizer_config["lr"]) + # Build learning rate scheduling tools. + # TODO (sven): Move lr from optimizer config to Learner HPs? + # We might not need optimizer config. + self.lr_scheduler = Scheduler( + fixed_value=self._optimizer_config["lr"], + schedule=self.hps.lr_schedule, + framework=self.framework, + device=self._device, + ) self._module = self._make_module() for param_seq, optimizer in self.configure_optimizers(): @@ -773,7 +765,7 @@ def additional_update_per_module(self, module_id: ModuleID, tau: float): return results_all_modules - @OverrideToImplementCustomLogic + @OverrideToImplementCustomLogic_CallToSuperRecommended def additional_update_per_module( self, module_id: ModuleID, **kwargs ) -> Dict[str, Any]: diff --git a/rllib/core/learner/tf/tf_learner.py b/rllib/core/learner/tf/tf_learner.py index 45f98d4f75205..bfe03311cadca 100644 --- a/rllib/core/learner/tf/tf_learner.py +++ b/rllib/core/learner/tf/tf_learner.py @@ -88,9 +88,7 @@ def configure_optimizer_per_module( self, module_id: ModuleID ) -> Union[ParamOptimizerPair, NamedParamOptimizerPairs]: module = self._module[module_id] - # TODO (sven): Move lr from optimizer config to Learner HPs? - # We might not need optimizer config. - lr = self.curr_lr_per_module[module_id] + lr = self.lr_scheduler.get_current_value(module_id) optim = tf.keras.optimizers.Adam(learning_rate=lr) pair: ParamOptimizerPair = ( self.get_parameters(module), @@ -528,17 +526,23 @@ def filter_fwd_out(x): def additional_update_per_module( self, module_id: ModuleID, *, timestep: int, **kwargs ) -> Mapping[str, Any]: + + results = super().additional_update_per_module(module_id, timestep=timestep) + # Handle lr scheduling updates and apply new learning rates to the optimizers. 
+ new_lr = self.lr_scheduler.update(module_id=module_id, timestep=timestep) + + # Not sure why we need to do this here besides setting the original + # tf Variable `self.curr_lr_per_module[module_id]`. But when tf creates the + # optimizer, it seems to detach its lr value from the given variable. + # Updating this variable is NOT sufficient to update the actual optimizer's + # learning rate, so we have to explicitly set it here. if self.hps.lr_schedule is not None: - value = self.lr_schedule_per_module[module_id].value(t=timestep) - self.curr_lr_per_module[module_id].assign(value) - # Not sure why we need to do this here besides setting the original - # tf Variable `self.curr_lr_per_module[module_id]`. When tf creates the - # optimizer, maybe it detaches its lr value from the given variable? - self._named_optimizers[module_id].lr = value - return { - LEARNER_RESULTS_CURR_LR_KEY: self._named_optimizers[module_id].lr.numpy() - } + self._named_optimizers[module_id].lr = new_lr + + results.update({LEARNER_RESULTS_CURR_LR_KEY: new_lr}) + + return results @override(Learner) def _get_tensor_variable(self, value, dtype=None, trainable=False) -> "tf.Tensor": diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py index 43eb0ac9910be..29c4a2e4fc456 100644 --- a/rllib/core/learner/torch/torch_learner.py +++ b/rllib/core/learner/torch/torch_learner.py @@ -74,9 +74,7 @@ def configure_optimizer_per_module( self, module_id: ModuleID ) -> Union[ParamOptimizerPair, NamedParamOptimizerPairs]: module = self._module[module_id] - # TODO (sven): Move lr from optimizer config to Learner HPs? - # We might not need optimizer config. - lr = self.curr_lr_per_module[module_id] + lr = self.lr_scheduler.get_current_value(module_id) pair: ParamOptimizerPair = ( self.get_parameters(module), torch.optim.Adam(self.get_parameters(module), lr=lr), @@ -100,12 +98,13 @@ def compute_gradients( def additional_update_per_module( self, module_id: ModuleID, *, timestep: int, **kwargs ) -> Mapping[str, Any]: + results = super().additional_update_per_module(module_id, timestep=timestep) + # Handle lr scheduling updates and apply new learning rates to the optimizers. 
- value = self._optimizer_config["lr"] - if self.hps.lr_schedule is not None: - value = self.lr_schedule_per_module[module_id].value(t=timestep) - self.curr_lr_per_module[module_id].data = torch.tensor(value) - return {LEARNER_RESULTS_CURR_LR_KEY: value} + new_lr = self.lr_scheduler.update(module_id=module_id, timestep=timestep) + results.update({LEARNER_RESULTS_CURR_LR_KEY: new_lr}) + + return results @override(Learner) def postprocess_gradients( diff --git a/rllib/core/models/tests/test_catalog.py b/rllib/core/models/tests/test_catalog.py index bb3bb52c71106..9d17766ee8c4f 100644 --- a/rllib/core/models/tests/test_catalog.py +++ b/rllib/core/models/tests/test_catalog.py @@ -385,6 +385,7 @@ def build_vf_head(self, framework): _enable_rl_module_api=True, rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyCatalog), ) + .training(_enable_learner_api=True) .framework("torch") ) diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index 0b81dc45bd3e2..6bcf6225444a4 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -7,9 +7,6 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks import ray.rllib.algorithms.dqn as dqn import ray.rllib.algorithms.ppo as ppo -from ray.rllib.algorithms.ppo.torch.ppo_torch_policy_rlm import ( - PPOTorchPolicyWithRLModule, -) from ray.rllib.examples.env.debug_counter_env import MultiAgentDebugCounterEnv from ray.rllib.examples.env.multi_agent import MultiAgentPendulum from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -236,12 +233,9 @@ def test_traj_view_next_action(self): .rollouts(rollout_fragment_length=200, num_rollout_workers=0) ) config.validate() - enable_rl_module_api = config._enable_rl_module_api rollout_worker_w_api = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v1"), - default_policy_class=PPOTorchPolicyWithRLModule - if enable_rl_module_api - else ppo.PPOTorchPolicy, + default_policy_class=ppo.PPOTorchPolicy, config=config, ) # Add the next action (a') and 2nd next action (a'') to the view diff --git a/rllib/examples/self_play_with_open_spiel.py b/rllib/examples/self_play_with_open_spiel.py index 3c5360de15ded..f611cac7d1556 100644 --- a/rllib/examples/self_play_with_open_spiel.py +++ b/rllib/examples/self_play_with_open_spiel.py @@ -274,6 +274,7 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): # Train the "main" policy to play really well using self-play. results = None if not args.from_checkpoint: + create_checkpoints = not bool(os.environ.get("RLLIB_ENABLE_RL_MODULE", False)) results = tune.Tuner( "PPO", param_space=config, @@ -294,8 +295,8 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): sort_by_metric=True, ), checkpoint_config=air.CheckpointConfig( - checkpoint_at_end=True, - checkpoint_frequency=10, + checkpoint_at_end=create_checkpoints, + checkpoint_frequency=10 if create_checkpoints else 0, ), ), ).fit() diff --git a/rllib/models/tests/test_preprocessors.py b/rllib/models/tests/test_preprocessors.py index d26ba8b028ba6..4093260634589 100644 --- a/rllib/models/tests/test_preprocessors.py +++ b/rllib/models/tests/test_preprocessors.py @@ -39,6 +39,7 @@ def tearDownClass(cls) -> None: def test_rlms_and_preprocessing(self): config = ( ppo.PPOConfig() + .framework("tf2") .environment( env="ray.rllib.examples.env.random_env.RandomEnv", env_config={ @@ -48,18 +49,18 @@ def test_rlms_and_preprocessing(self): }, ) # Run this very quickly locally. 
- .rollouts(rollout_fragment_length=10) - .rollouts(num_rollout_workers=0) - .training(train_batch_size=10, sgd_minibatch_size=1, num_sgd_iter=1) + .rollouts(num_rollout_workers=0, rollout_fragment_length=10) + .training( + train_batch_size=10, + sgd_minibatch_size=1, + num_sgd_iter=1, + _enable_learner_api=True, + ) + .rl_module(_enable_rl_module_api=True) # Set this to True to enforce no preprocessors being used. .experimental(_disable_preprocessor_api=True) - .framework("tf2") ) - # TODO (Artur): No need to manually enable RLModules here since we have not - # fully migrated. Clear this up after migration. - config.rl_module(_enable_rl_module_api=True) - for _ in framework_iterator(config, frameworks=("torch", "tf2")): algo = config.build() results = algo.train() diff --git a/rllib/policy/eager_tf_policy_v2.py b/rllib/policy/eager_tf_policy_v2.py index 7e1e543b08ab0..161a4bea0a225 100644 --- a/rllib/policy/eager_tf_policy_v2.py +++ b/rllib/policy/eager_tf_policy_v2.py @@ -432,16 +432,17 @@ def _init_view_requirements(self): self.view_requirements[SampleBatch.INFOS].used_for_training = False def maybe_initialize_optimizer_and_loss(self): - optimizers = force_list(self.optimizer()) - if self.exploration: - # Policies with RLModules don't have an exploration object. - optimizers = self.exploration.get_exploration_optimizer(optimizers) - - # The list of local (tf) optimizers (one per loss term). - self._optimizers: List[LocalOptimizer] = optimizers - # Backward compatibility: A user's policy may only support a single - # loss term and optimizer (no lists). - self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None + if not self.config.get("_enable_learner_api", False): + optimizers = force_list(self.optimizer()) + if self.exploration: + # Policies with RLModules don't have an exploration object. + optimizers = self.exploration.get_exploration_optimizer(optimizers) + + # The list of local (tf) optimizers (one per loss term). + self._optimizers: List[LocalOptimizer] = optimizers + # Backward compatibility: A user's policy may only support a single + # loss term and optimizer (no lists). + self._optimizer: LocalOptimizer = optimizers[0] if optimizers else None self._initialize_loss_from_dummy_batch( auto_remove_unneeded_view_reqs=True, diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index bb534497e3d28..c1ca60b83904a 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -1470,35 +1470,39 @@ def _initialize_loss_from_dummy_batch( seq_len = sample_batch_size // B seq_lens = np.array([seq_len for _ in range(B)], dtype=np.int32) postprocessed_batch[SampleBatch.SEQ_LENS] = seq_lens - # Switch on lazy to-tensor conversion on `postprocessed_batch`. - train_batch = self._lazy_tensor_dict(postprocessed_batch) - # Calling loss, so set `is_training` to True. - train_batch.set_training(True) - if seq_lens is not None: - train_batch[SampleBatch.SEQ_LENS] = seq_lens - train_batch.count = self._dummy_batch.count - # Call the loss function, if it exists. - # TODO(jungong) : clean up after all agents get migrated. - # We should simply do self.loss(...) here. - if self._loss is not None: - self._loss(self, self.model, self.dist_class, train_batch) - elif ( - is_overridden(self.loss) or self.config.get("_enable_rl_module_api", False) - ) and not self.config["in_evaluation"]: - self.loss(self.model, self.dist_class, train_batch) - # Call the stats fn, if given. - # TODO(jungong) : clean up after all agents get migrated. 
- # We should simply do self.stats_fn(train_batch) here. - if stats_fn is not None: - stats_fn(self, train_batch) - if hasattr(self, "stats_fn") and not self.config["in_evaluation"]: - self.stats_fn(train_batch) + + if not self.config.get("_enable_learner_api"): + # Switch on lazy to-tensor conversion on `postprocessed_batch`. + train_batch = self._lazy_tensor_dict(postprocessed_batch) + # Calling loss, so set `is_training` to True. + train_batch.set_training(True) + if seq_lens is not None: + train_batch[SampleBatch.SEQ_LENS] = seq_lens + train_batch.count = self._dummy_batch.count + + # Call the loss function, if it exists. + # TODO(jungong) : clean up after all agents get migrated. + # We should simply do self.loss(...) here. + if self._loss is not None: + self._loss(self, self.model, self.dist_class, train_batch) + elif is_overridden(self.loss) and not self.config["in_evaluation"]: + self.loss(self.model, self.dist_class, train_batch) + # Call the stats fn, if given. + # TODO(jungong) : clean up after all agents get migrated. + # We should simply do self.stats_fn(train_batch) here. + if stats_fn is not None: + stats_fn(self, train_batch) + if hasattr(self, "stats_fn") and not self.config["in_evaluation"]: + self.stats_fn(train_batch) # Re-enable tracing. self._no_tracing = False # Add new columns automatically to view-reqs. - if auto_remove_unneeded_view_reqs: + if ( + not self.config.get("_enable_learner_api") + and auto_remove_unneeded_view_reqs + ): # Add those needed for postprocessing and training. all_accessed_keys = ( train_batch.accessed_keys diff --git a/rllib/policy/tf_mixins.py b/rllib/policy/tf_mixins.py index 8ce18df5e9796..fe5e23a330e84 100644 --- a/rllib/policy/tf_mixins.py +++ b/rllib/policy/tf_mixins.py @@ -33,7 +33,9 @@ class LearningRateSchedule: @DeveloperAPI def __init__(self, lr, lr_schedule): self._lr_schedule = None - if lr_schedule is None: + # Disable any scheduling behavior related to learning if Learner API is active. + # Schedules are handled by Learner class. + if lr_schedule is None or self.config.get("_enable_learner_api", False): self.cur_lr = tf1.get_variable("lr", initializer=lr, trainable=False) else: self._lr_schedule = PiecewiseSchedule( @@ -78,7 +80,11 @@ class EntropyCoeffSchedule: @DeveloperAPI def __init__(self, entropy_coeff, entropy_coeff_schedule): self._entropy_coeff_schedule = None - if entropy_coeff_schedule is None: + # Disable any scheduling behavior related to learning if Learner API is active. + # Schedules are handled by Learner class. + if entropy_coeff_schedule is None or ( + self.config.get("_enable_learner_api", False) + ): self.entropy_coeff = get_variable( entropy_coeff, framework="tf", tf_name="entropy_coeff", trainable=False ) @@ -208,37 +214,32 @@ class TargetNetworkMixin: """ def __init__(self): - if self.config.get("_enable_rl_module_api", False): - # In order to access the variables for rl modules, we need to - # use the underlying keras api model.trainable_variables. 
- model_vars = self.model.trainable_variables - target_model_vars = self.target_model.trainable_variables - else: + if not self.config.get("_enable_rl_module_api", False): model_vars = self.model.trainable_variables() target_model_vars = self.target_model.trainable_variables() - @make_tf_callable(self.get_session()) - def update_target_fn(tau): - tau = tf.convert_to_tensor(tau, dtype=tf.float32) - update_target_expr = [] - assert len(model_vars) == len(target_model_vars), ( - model_vars, - target_model_vars, - ) - for var, var_target in zip(model_vars, target_model_vars): - update_target_expr.append( - var_target.assign(tau * var + (1.0 - tau) * var_target) + @make_tf_callable(self.get_session()) + def update_target_fn(tau): + tau = tf.convert_to_tensor(tau, dtype=tf.float32) + update_target_expr = [] + assert len(model_vars) == len(target_model_vars), ( + model_vars, + target_model_vars, ) - logger.debug("Update target op {}".format(var_target)) - return tf.group(*update_target_expr) - - # Hard initial update. - self._do_update = update_target_fn - # TODO: The previous SAC implementation does an update(1.0) here. - # If this is changed to tau != 1.0 the sac_loss_function test fails. Why? - # Also the test is not very maintainable, we need to change that unittest - # anyway. - self.update_target(tau=1.0) # self.config.get("tau", 1.0)) + for var, var_target in zip(model_vars, target_model_vars): + update_target_expr.append( + var_target.assign(tau * var + (1.0 - tau) * var_target) + ) + logger.debug("Update target op {}".format(var_target)) + return tf.group(*update_target_expr) + + # Hard initial update. + self._do_update = update_target_fn + # TODO: The previous SAC implementation does an update(1.0) here. + # If this is changed to tau != 1.0 the sac_loss_function test fails. Why? + # Also the test is not very maintainable, we need to change that unittest + # anyway. + self.update_target(tau=1.0) # self.config.get("tau", 1.0)) @property def q_func_vars(self): @@ -276,7 +277,8 @@ def set_weights(self, weights): EagerTFPolicyV2.set_weights(self, weights) elif isinstance(self, EagerTFPolicy): # Handle TF2 policies. EagerTFPolicy.set_weights(self, weights) - self.update_target(self.config.get("tau", 1.0)) + if not self.config.get("_enable_rl_module_api", False): + self.update_target(self.config.get("tau", 1.0)) class ValueNetworkMixin: diff --git a/rllib/policy/torch_mixins.py b/rllib/policy/torch_mixins.py index d6c4b03a935d3..b258c1d74560f 100644 --- a/rllib/policy/torch_mixins.py +++ b/rllib/policy/torch_mixins.py @@ -8,8 +8,6 @@ torch, nn = try_import_torch() -# TODO: (sven) Unify hyperparam annealing procedures across RLlib (tf/torch) -# and for all possible hyperparams, not just lr. @DeveloperAPI class LearningRateSchedule: """Mixin for TorchPolicy that adds a learning rate schedule.""" @@ -17,6 +15,8 @@ class LearningRateSchedule: @DeveloperAPI def __init__(self, lr, lr_schedule): self._lr_schedule = None + # Disable any scheduling behavior related to learning if Learner API is active. + # Schedules are handled by Learner class. 
if lr_schedule is None: self.cur_lr = lr else: @@ -28,7 +28,7 @@ def __init__(self, lr, lr_schedule): @override(Policy) def on_global_var_update(self, global_vars): super().on_global_var_update(global_vars) - if self._lr_schedule: + if self._lr_schedule and not self.config.get("_enable_learner_api", False): self.cur_lr = self._lr_schedule.value(global_vars["timestep"]) for opt in self._optimizers: for p in opt.param_groups: @@ -42,7 +42,11 @@ class EntropyCoeffSchedule: @DeveloperAPI def __init__(self, entropy_coeff, entropy_coeff_schedule): self._entropy_coeff_schedule = None - if entropy_coeff_schedule is None: + # Disable any scheduling behavior related to learning if Learner API is active. + # Schedules are handled by Learner class. + if entropy_coeff_schedule is None or ( + self.config.get("_enable_learner_api", False) + ): self.entropy_coeff = entropy_coeff else: # Allows for custom schedule similar to lr_schedule format diff --git a/rllib/policy/torch_policy_v2.py b/rllib/policy/torch_policy_v2.py index 0d58dbc55c2b3..1962a04adc605 100644 --- a/rllib/policy/torch_policy_v2.py +++ b/rllib/policy/torch_policy_v2.py @@ -182,29 +182,32 @@ def __init__( self.exploration = None else: self.exploration = self._create_exploration() - self._optimizers = force_list(self.optimizer()) - - # Backward compatibility workaround so Policy will call self.loss() directly. - # TODO(jungong): clean up after all policies are migrated to new sub-class - # implementation. - self._loss = None - - # Store, which params (by index within the model's list of - # parameters) should be updated per optimizer. - # Maps optimizer idx to set or param indices. - self.multi_gpu_param_groups: List[Set[int]] = [] - main_params = {p: i for i, p in enumerate(self.model.parameters())} - for o in self._optimizers: - param_indices = [] - for pg_idx, pg in enumerate(o.param_groups): - for p in pg["params"]: - param_indices.append(main_params[p]) - self.multi_gpu_param_groups.append(set(param_indices)) - - # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each - # one with m towers (num_gpus). - num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1) - self._loaded_batches = [[] for _ in range(num_buffers)] + + if not self.config.get("_enable_learner_api", False): + self._optimizers = force_list(self.optimizer()) + + # Backward compatibility workaround so Policy will call self.loss() + # directly. + # TODO (jungong): clean up after all policies are migrated to new sub-class + # implementation. + self._loss = None + + # Store, which params (by index within the model's list of + # parameters) should be updated per optimizer. + # Maps optimizer idx to set or param indices. + self.multi_gpu_param_groups: List[Set[int]] = [] + main_params = {p: i for i, p in enumerate(self.model.parameters())} + for o in self._optimizers: + param_indices = [] + for pg_idx, pg in enumerate(o.param_groups): + for p in pg["params"]: + param_indices.append(main_params[p]) + self.multi_gpu_param_groups.append(set(param_indices)) + + # Create n sample-batch buffers (num_multi_gpu_tower_stacks), each + # one with m towers (num_gpus). + num_buffers = self.config.get("num_multi_gpu_tower_stacks", 1) + self._loaded_batches = [[] for _ in range(num_buffers)] # If set, means we are using distributed allreduce during learning. 
self.distributed_world_size = None @@ -1104,7 +1107,7 @@ def _compute_action_helper( if self.model: self.model.eval() - extra_fetches = {} + extra_fetches = None if isinstance(self.model, RLModule): if explore: fwd_out = self.model.forward_exploration(input_dict) @@ -1166,7 +1169,7 @@ def _compute_action_helper( ) # Add default and custom fetches. - if not extra_fetches: + if extra_fetches is None: extra_fetches = self.extra_action_out( input_dict, state_batches, self.model, action_dist ) diff --git a/rllib/tuned_examples/impala/cartpole-impala.yaml b/rllib/tuned_examples/impala/cartpole-impala.yaml index 46c37c52ea697..1df02c4313cb5 100644 --- a/rllib/tuned_examples/impala/cartpole-impala.yaml +++ b/rllib/tuned_examples/impala/cartpole-impala.yaml @@ -8,8 +8,8 @@ cartpole-impala: # Works for both torch and tf. framework: tf2 num_gpus: 0 - _enable_rl_module_api: True - _enable_learner_api: True + _enable_rl_module_api: true + _enable_learner_api: true grad_clip: 40 num_workers: 2 num_learner_workers: 1 diff --git a/rllib/tuned_examples/ppo/cartpole-ppo-with-rl-module.yaml b/rllib/tuned_examples/ppo/cartpole-ppo-with-rl-module.yaml index fbfb6905b4ace..2f6afebd53efa 100644 --- a/rllib/tuned_examples/ppo/cartpole-ppo-with-rl-module.yaml +++ b/rllib/tuned_examples/ppo/cartpole-ppo-with-rl-module.yaml @@ -19,5 +19,5 @@ cartpole-ppo: vf_share_layers: true enable_connectors: true _enable_rl_module_api: true - _enable_learner_api: false + _enable_learner_api: true eager_tracing: false \ No newline at end of file diff --git a/rllib/tuned_examples/ppo/pendulum-ppo-with-rl-module.yaml b/rllib/tuned_examples/ppo/pendulum-ppo-with-rl-module.yaml index 5b2888d709f79..98da67a36fdb9 100644 --- a/rllib/tuned_examples/ppo/pendulum-ppo-with-rl-module.yaml +++ b/rllib/tuned_examples/ppo/pendulum-ppo-with-rl-module.yaml @@ -21,6 +21,7 @@ pendulum-ppo: enable_connectors: true model: fcnet_activation: relu + _enable_learner_api: true _enable_rl_module_api: true # Need to unset this b/c we are using the RLModule API, which # provides exploration control via the RLModule's `forward_exploration` method. diff --git a/rllib/utils/schedules/scheduler.py b/rllib/utils/schedules/scheduler.py new file mode 100644 index 0000000000000..7d349329791bd --- /dev/null +++ b/rllib/utils/schedules/scheduler.py @@ -0,0 +1,157 @@ +from collections import defaultdict +from typing import List, Optional, Tuple + +from ray.rllib.core.rl_module.rl_module import ModuleID +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule +from ray.rllib.utils.typing import TensorType + + +_, tf, _ = try_import_tf() +torch, _ = try_import_torch() + + +class Scheduler: + """Class to manage a scheduled (framework-dependent) tensor variable. + + Uses the PiecewiseSchedule (for maximum configuration flexibility) + """ + + def __init__( + self, + *, + fixed_value: Optional[float] = None, + schedule: Optional[List[Tuple[int, float]]] = None, + framework: str = "torch", + device: Optional[str] = None, + ): + """Initializes a Scheduler instance. + + Args: + fixed_value: A fixed, constant value (in case no schedule should be used). + Set `schedule` to None to always just use this fixed value. + If `fixed_value` is None, `schedule` must be provided. + schedule: The schedule configuration to use. In the format of + [[timestep, value], [timestep, value], ...] + Intermediary timesteps will be assigned to interpolated values (linear + interpolation will be used). 
A schedule config's first entry must + start with timestep 0, i.e.: [[0, initial_value], [...]]. + framework: The framework string, for which to create the tensor variable + that holds the current value. This is the variable that can be used in + the graph, e.g. in a loss function. + device: Optional device (for torch) to place the tensor variable on. + """ + self.use_schedule = schedule is not None + self.framework = framework + self.device = device + + if self.use_schedule: + # Custom schedule, based on list of + # ([ts], [value to be reached by ts])-tuples. + self.schedule_per_module = defaultdict( + lambda: PiecewiseSchedule( + schedule, + outside_value=schedule[-1][-1], + framework=None, + ) + ) + # As initial tensor value, use the first timestep's (must be 0) value. + self.curr_value_per_module = defaultdict( + lambda: self._create_tensor_variable(initial_value=schedule[0][1]) + ) + # If no schedule, pin (fix) given value. + else: + self.curr_value_per_module = defaultdict(lambda: fixed_value) + + @staticmethod + def validate( + schedule: Optional[List[Tuple[int, float]]], + schedule_name: str, + value_name: str, + ) -> None: + """Performs checking of a given schedule configuration. + + The first entry in `schedule` must have a timestep of 0. + + Args: + schedule: The schedule configuration to check. In the format of + [[timestep, value], [timestep, value], ...] + Intermediary timesteps will be assigned to interpolated values (linear + interpolation will be used). A schedule config's first entry must + start with timestep 0, i.e.: [[0, initial_value], [...]]. + schedule_name: The name of the schedule, e.g. `lr_schedule`. + value_name: A full text description of the variable that's being scheduled, + e.g. `learning rate`. + + Raises: + ValueError: In case errors are found in the schedule's format. + """ + if schedule is not None: + if not isinstance(schedule, (list, tuple)) or (len(schedule) < 2): + raise ValueError( + f"Invalid `{schedule_name}` ({schedule}) specified! Must be a " + "list of at least 2 tuples, each of the form " + f"(`timestep`, `{value_name} to reach`), e.g. " + "`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`." + ) + elif schedule[0][0] != 0: + raise ValueError( + f"When providing a `{schedule_name}`, the first timestep must be 0 " + f"and the corresponding value is the initial {value_name}! You " + f"provided ts={schedule[0][0]} {value_name}={schedule[0][1]}." + ) + + def get_current_value(self, module_id: ModuleID) -> TensorType: + """Returns the current value (as a tensor variable), given a ModuleID. + + Args: + module_id: The module ID, for which to retrieve the current tensor value. + + Returns: + The tensor variable (holding the current value to be used). + """ + return self.curr_value_per_module[module_id] + + def update(self, module_id: ModuleID, timestep: int) -> float: + """Updates the underlying (framework-specific) tensor variable. + + Args: + module_id: The module ID, for which to update the tensor variable. + timestep: The current timestep. + + Returns: + The current value of the tensor variable as a python float.
+ """ + if self.use_schedule: + python_value = self.schedule_per_module[module_id].value(t=timestep) + if self.framework == "torch": + self.curr_value_per_module[module_id].data = torch.tensor(python_value) + else: + self.curr_value_per_module[module_id].assign(python_value) + else: + python_value = self.curr_value_per_module[module_id] + + return python_value + + def _create_tensor_variable(self, initial_value: float) -> TensorType: + """Creates a framework-specific tensor variable to be scheduled. + + Args: + initial_value: The initial (float) value for the variable to hold. + + Returns: + The created framework-specific tensor variable. + """ + if self.framework == "torch": + return torch.tensor( + initial_value, + requires_grad=False, + dtype=torch.float32, + device=self.device, + ) + else: + return tf.Variable( + initial_value, + trainable=False, + dtype=tf.float32, + )