[RLlib] Learner API (+DreamerV3 prep): Learner.register_metrics API, cleanup, etc. (ray-project#35573)
sven1977 committed May 22, 2023
1 parent decc28d commit 0fd06ad
Showing 31 changed files with 626 additions and 462 deletions.
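Note on the headline change: framework-specific Learners no longer return a dict of per-module loss stats from their loss method. They log those stats through the new `Learner.register_metrics()` call and return only the total loss. A minimal sketch of the new pattern, assuming the base-class import path; only `register_metrics()`, `compute_loss_for_module()`, and the `NestedDict`/`TensorType` imports are taken from this commit, everything else is illustrative:

    from typing import Mapping

    from ray.rllib.core.learner.torch.torch_learner import TorchLearner  # path assumed
    from ray.rllib.utils.nested_dict import NestedDict
    from ray.rllib.utils.typing import TensorType


    class MySketchTorchLearner(TorchLearner):  # hypothetical subclass
        def compute_loss_for_module(
            self, module_id: str, batch: NestedDict, fwd_out: Mapping[str, TensorType]
        ) -> TensorType:
            # ... compute the individual loss terms from `batch` and `fwd_out` ...
            pi_loss = ...  # placeholder: policy loss
            vf_loss = ...  # placeholder: value-function loss
            total_loss = pi_loss + vf_loss

            # New in this commit: per-module stats are registered, not returned.
            self.register_metrics(
                module_id,
                {"pi_loss": pi_loss, "vf_loss": vf_loss},
            )
            # Only the total loss tensor is returned.
            return total_loss
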
18 changes: 12 additions & 6 deletions rllib/algorithms/appo/appo.py
@@ -15,7 +15,7 @@

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.appo.appo_learner import (
AppoHyperparameters,
AppoLearnerHyperparameters,
LEARNER_RESULTS_KL_KEY,
)
from ray.rllib.algorithms.impala.impala import Impala, ImpalaConfig
@@ -222,7 +222,10 @@ def get_default_learner_class(self):

return APPOTfLearner
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use either 'torch' or 'tf2'."
)

@override(ImpalaConfig)
def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec:
@@ -235,16 +238,19 @@ def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec:
APPOTfRLModule as RLModule,
)
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use either 'torch' or 'tf2'."
)

from ray.rllib.algorithms.appo.appo_catalog import APPOCatalog

return SingleAgentRLModuleSpec(module_class=RLModule, catalog_class=APPOCatalog)

@override(ImpalaConfig)
def get_learner_hyperparameters(self) -> AppoHyperparameters:
def get_learner_hyperparameters(self) -> AppoLearnerHyperparameters:
base_hps = super().get_learner_hyperparameters()
return AppoHyperparameters(
return AppoLearnerHyperparameters(
use_kl_loss=self.use_kl_loss,
kl_target=self.kl_target,
kl_coeff=self.kl_coeff,
@@ -355,7 +361,7 @@ def _get_additional_update_kwargs(self, train_results) -> dict:
return dict(
last_update=self._counters[LAST_TARGET_UPDATE_TS],
mean_kl_loss_per_module={
mid: r[LEARNER_STATS_KEY][LEARNER_RESULTS_KL_KEY]
mid: r[LEARNER_RESULTS_KL_KEY]
for mid, r in train_results.items()
if mid != ALL_MODULES
},
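Usage-wise nothing changes for the renamed hyperparameters class: it is still produced by the config object rather than instantiated by hand. A quick sketch; the `framework()`/`training()` calls are the regular config API, not part of this diff, and the import path is assumed:

    from ray.rllib.algorithms.appo import APPOConfig  # import path assumed

    config = (
        APPOConfig()
        .framework("torch")
        .training(use_kl_loss=True, kl_coeff=0.2, kl_target=0.01)
    )

    # Builds the (renamed) AppoLearnerHyperparameters dataclass from the config.
    hps = config.get_learner_hyperparameters()
    print(type(hps).__name__, hps.use_kl_loss, hps.kl_coeff, hps.kl_target)
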
14 changes: 8 additions & 6 deletions rllib/algorithms/appo/appo_learner.py
@@ -5,7 +5,7 @@

from ray.rllib.algorithms.impala.impala_learner import (
ImpalaLearner,
ImpalaHyperparameters,
ImpalaLearnerHyperparameters,
)
from ray.rllib.core.rl_module.marl_module import ModuleID
from ray.rllib.utils.annotations import override
@@ -19,7 +19,7 @@


@dataclass
class AppoHyperparameters(ImpalaHyperparameters):
class AppoLearnerHyperparameters(ImpalaLearnerHyperparameters):
"""Hyperparameters for the APPOLearner sub-classes (framework specific).
These should never be set directly by the user. Instead, use the APPOConfig
@@ -37,7 +37,7 @@ class to configure your algorithm.


class AppoLearner(ImpalaLearner):
"""Adds KL coeff updates via `additional_updates_per_module()` to Impala logic.
"""Adds KL coeff updates via `additional_update_for_module()` to Impala logic.
Framework-specific sub-classes must override `_update_module_target_networks()`
and `_update_module_kl_coeff()`
@@ -59,10 +59,10 @@ def remove_module(self, module_id: str):
self.curr_kl_coeffs_per_module.pop(module_id)

@override(ImpalaLearner)
def additional_update_per_module(
def additional_update_for_module(
self,
module_id: ModuleID,
*,
module_id: ModuleID,
last_update: int,
mean_kl_loss_per_module: dict,
timestep: int,
@@ -84,7 +84,9 @@ def additional_update_per_module(
# updates.
# We should instead have the target / kl threshold update be based off
# of the train_batch_size * some target update frequency * num_sgd_iter.
results = super().additional_update_per_module(module_id, timestep=timestep)
results = super().additional_update_for_module(
module_id=module_id, timestep=timestep
)

if (timestep - last_update) >= self.hps.target_update_frequency_ts:
self._update_module_target_networks(module_id)
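For reference, a skeleton of what a framework-specific APPO learner has to supply on top of the shared KL/target-network logic above. The `_update_module_kl_coeff()` argument list is an assumption here; only the method names appear in this diff:

    from ray.rllib.algorithms.appo.appo_learner import AppoLearner
    from ray.rllib.core.rl_module.marl_module import ModuleID


    class MyFrameworkAppoLearner(AppoLearner):  # hypothetical subclass
        def _update_module_target_networks(self, module_id: ModuleID) -> None:
            # Sync (or Polyak-average) this module's target network weights
            # from its current network weights.
            ...

        def _update_module_kl_coeff(self, module_id: ModuleID, sampled_kl: float) -> None:
            # Move self.curr_kl_coeffs_per_module[module_id] up or down,
            # depending on how `sampled_kl` compares to self.hps.kl_target.
            ...
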
30 changes: 18 additions & 12 deletions rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -13,6 +13,7 @@
from ray.rllib.core.rl_module.marl_module import ModuleID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import TensorType

_, tf, _ = try_import_tf()
@@ -22,8 +23,8 @@ class APPOTfLearner(AppoLearner, TfLearner):
"""Implements APPO loss / update logic on top of ImpalaTfLearner."""

@override(TfLearner)
def compute_loss_per_module(
self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType]
def compute_loss_for_module(
self, module_id: str, batch: NestedDict, fwd_out: Mapping[str, TensorType]
) -> TensorType:
values = fwd_out[SampleBatch.VF_PREDS]
action_dist_cls_train = self._module[module_id].get_train_action_dist_cls()
@@ -139,16 +140,21 @@ def compute_loss_per_module(
+ (mean_kl_loss * self.curr_kl_coeffs_per_module[module_id])
)

return {
self.TOTAL_LOSS_KEY: total_loss,
POLICY_LOSS_KEY: mean_pi_loss,
VF_LOSS_KEY: mean_vf_loss,
ENTROPY_KEY: -mean_entropy_loss,
LEARNER_RESULTS_KL_KEY: mean_kl_loss,
LEARNER_RESULTS_CURR_KL_COEFF_KEY: (
self.curr_kl_coeffs_per_module[module_id]
),
}
# Register important loss stats.
self.register_metrics(
module_id,
{
POLICY_LOSS_KEY: mean_pi_loss,
VF_LOSS_KEY: mean_vf_loss,
ENTROPY_KEY: -mean_entropy_loss,
LEARNER_RESULTS_KL_KEY: mean_kl_loss,
LEARNER_RESULTS_CURR_KL_COEFF_KEY: (
self.curr_kl_coeffs_per_module[module_id]
),
},
)
# Return the total loss.
return total_loss

@override(AppoLearner)
def _update_module_target_networks(self, module_id: ModuleID):
27 changes: 18 additions & 9 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
@@ -23,6 +23,7 @@
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()
@@ -32,8 +33,8 @@ class APPOTorchLearner(AppoLearner, TorchLearner):
"""Implements APPO loss / update logic on top of ImpalaTorchLearner."""

@override(TorchLearner)
def compute_loss_per_module(
self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType]
def compute_loss_for_module(
self, module_id: str, batch: NestedDict, fwd_out: Mapping[str, TensorType]
) -> TensorType:

values = fwd_out[SampleBatch.VF_PREDS]
@@ -143,13 +144,21 @@ def compute_loss_per_module(
+ (mean_kl_loss * self.curr_kl_coeffs_per_module[module_id])
)

return {
self.TOTAL_LOSS_KEY: total_loss,
POLICY_LOSS_KEY: mean_pi_loss,
VF_LOSS_KEY: mean_vf_loss,
ENTROPY_KEY: -mean_entropy_loss,
LEARNER_RESULTS_KL_KEY: mean_kl_loss,
}
# Register important loss stats.
self.register_metrics(
module_id,
{
POLICY_LOSS_KEY: mean_pi_loss,
VF_LOSS_KEY: mean_vf_loss,
ENTROPY_KEY: -mean_entropy_loss,
LEARNER_RESULTS_KL_KEY: mean_kl_loss,
LEARNER_RESULTS_CURR_KL_COEFF_KEY: (
self.curr_kl_coeffs_per_module[module_id]
),
},
)
# Return the total loss.
return total_loss

@override(TorchLearner)
def _make_modules_ddp_if_necessary(self) -> None:
20 changes: 13 additions & 7 deletions rllib/algorithms/impala/impala.py
@@ -16,7 +16,7 @@
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.impala.impala_learner import (
ImpalaHyperparameters,
ImpalaLearnerHyperparameters,
_reduce_impala_results,
)
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
@@ -427,9 +427,9 @@ def validate(self) -> None:
)

@override(AlgorithmConfig)
def get_learner_hyperparameters(self) -> ImpalaHyperparameters:
def get_learner_hyperparameters(self) -> ImpalaLearnerHyperparameters:
base_hps = super().get_learner_hyperparameters()
learner_hps = ImpalaHyperparameters(
learner_hps = ImpalaLearnerHyperparameters(
rollout_frag_or_episode_len=self.get_rollout_fragment_length(),
discount_factor=self.gamma,
entropy_coeff=self.entropy_coeff,
@@ -446,7 +446,7 @@ def get_learner_hyperparameters(self) -> ImpalaHyperparameters:
learner_hps.recurrent_seq_len is None
), (
"One of `rollout_frag_or_episode_len` or `recurrent_seq_len` must be not "
"None in ImpalaHyperparameters!"
"None in ImpalaLearnerHyperparameters!"
)
return learner_hps

@@ -481,7 +481,10 @@ def get_default_learner_class(self):

return ImpalaTfLearner
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use either 'torch' or 'tf2'."
)

@override(AlgorithmConfig)
def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec:
@@ -500,7 +503,10 @@ def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec:
module_class=PPOTorchRLModule, catalog_class=PPOCatalog
)
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")
raise ValueError(
f"The framework {self.framework_str} is not supported. "
"Use either 'torch' or 'tf2'."
)


def make_learner_thread(local_worker, config):
@@ -1221,7 +1227,7 @@ def _get_additional_update_kwargs(self, train_results: dict) -> dict:
"""Returns the kwargs to `LearnerGroup.additional_update()`.
Should be overridden by subclasses to specify wanted/needed kwargs for
their own implementation of `Learner.additional_update_per_module()`.
their own implementation of `Learner.additional_update_for_module()`.
"""
return {}

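To make the renamed hooks concrete: whatever `_get_additional_update_kwargs()` returns on the Algorithm side is forwarded through `LearnerGroup.additional_update()` and ends up as keyword arguments of each `Learner.additional_update_for_module()` call. A simplified sketch of that flow; the exact plumbing lives in the Learner/LearnerGroup base classes and is not shown in this diff:

    # Algorithm side -- APPO's override earlier in this diff returns e.g.
    # {"last_update": ..., "mean_kl_loss_per_module": {...}}.
    extra_kwargs = algo._get_additional_update_kwargs(train_results)

    # The LearnerGroup fans the kwargs out to its Learner(s) ...
    learner_group.additional_update(timestep=timestep, **extra_kwargs)

    # ... and each Learner invokes the per-module hook once per RLModule,
    # now with `module_id` passed as a keyword argument:
    #   self.additional_update_for_module(
    #       module_id=mid, timestep=timestep, **extra_kwargs
    #   )
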
30 changes: 8 additions & 22 deletions rllib/algorithms/impala/impala_learner.py
@@ -1,12 +1,11 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np
import tree # pip install dm_tree

from ray.rllib.core.learner.learner import Learner, LearnerHyperparameters
from ray.rllib.core.rl_module.rl_module import ModuleID
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.metrics import (
ALL_MODULES,
@@ -21,8 +20,8 @@


@dataclass
class ImpalaHyperparameters(LearnerHyperparameters):
"""Hyperparameters for the ImpalaLearner sub-classes (framework specific).
class ImpalaLearnerHyperparameters(LearnerHyperparameters):
"""LearnerHyperparameters for the ImpalaLearner sub-classes (framework specific).
These should never be set directly by the user. Instead, use the IMPALAConfig
class to configure your algorithm.
@@ -60,10 +59,12 @@ def build(self) -> None:
)

@override(Learner)
def additional_update_per_module(
self, module_id: ModuleID, timestep: int
def additional_update_for_module(
self, *, module_id: ModuleID, timestep: int
) -> Dict[str, Any]:
results = super().additional_update_per_module(module_id, timestep=timestep)
results = super().additional_update_for_module(
module_id=module_id, timestep=timestep
)

# Update entropy coefficient via our Scheduler.
new_entropy_coeff = self.entropy_coeff_scheduler.update(
@@ -73,21 +74,6 @@ def additional_update_per_module(

return results

@override(Learner)
def compile_results(
self,
batch: MultiAgentBatch,
fwd_out: Mapping[str, Any],
postprocessed_loss: Mapping[str, Any],
postprocessed_gradients: Mapping[str, Any],
) -> Mapping[str, Any]:
results = super().compile_results(
batch, fwd_out, postprocessed_loss, postprocessed_gradients
)
results[ALL_MODULES][NUM_AGENT_STEPS_TRAINED] = batch.agent_steps()
results[ALL_MODULES][NUM_ENV_STEPS_TRAINED] = batch.env_steps()
return results


def _reduce_impala_results(results: List[ResultDict]) -> ResultDict:
"""Reduce/Aggregate a list of results from Impala Learners.
8 changes: 4 additions & 4 deletions rllib/algorithms/impala/tests/test_impala_learner.py
@@ -79,16 +79,16 @@ def test_impala_loss(self):
# Deprecate the current default and set it to {}.
config.exploration_config = {}

for fw in framework_iterator(config, frameworks=["tf2", "torch"]):
for fw in framework_iterator(config, frameworks=["torch", "tf2"]):
algo = config.build()
policy = algo.get_policy()

if fw == "tf2":
if fw == "torch":
train_batch = convert_to_torch_tensor(SampleBatch(FAKE_BATCH))
else:
train_batch = SampleBatch(
tree.map_structure(lambda x: tf.convert_to_tensor(x), FAKE_BATCH)
)
elif fw == "torch":
train_batch = convert_to_torch_tensor(SampleBatch(FAKE_BATCH))

algo_config = config.copy(copy_frozen=False)
algo_config.validate()
23 changes: 15 additions & 8 deletions rllib/algorithms/impala/tf/impala_tf_learner.py
@@ -7,6 +7,7 @@
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import TensorType

_, tf, _ = try_import_tf()
@@ -16,8 +17,8 @@ class ImpalaTfLearner(ImpalaLearner, TfLearner):
"""Implements the IMPALA loss function in tensorflow."""

@override(TfLearner)
def compute_loss_per_module(
self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType]
def compute_loss_for_module(
self, module_id: str, batch: NestedDict, fwd_out: Mapping[str, TensorType]
) -> TensorType:
action_dist_class_train = self.module[module_id].get_train_action_dist_cls()
target_policy_dist = action_dist_class_train.from_logits(
@@ -98,9 +99,15 @@ def compute_loss_per_module(
+ mean_entropy_loss
* (self.entropy_coeff_scheduler.get_current_value(module_id))
)
return {
self.TOTAL_LOSS_KEY: total_loss,
"pi_loss": mean_pi_loss,
"vf_loss": mean_vf_loss,
ENTROPY_KEY: -mean_entropy_loss,
}

# Register important loss stats.
self.register_metrics(
module_id,
{
"pi_loss": mean_pi_loss,
"vf_loss": mean_vf_loss,
ENTROPY_KEY: -mean_entropy_loss,
},
)
# Return the total loss.
return total_loss