[RLlib] APPO+new-stack (Atari benchmark) - Preparatory PR 02. (ray-pr…
sven1977 committed May 1, 2023
1 parent b5e5bd7 commit e399fb8
Showing 28 changed files with 615 additions and 516 deletions.
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -603,7 +603,7 @@ py_test(
py_test(
name = "learning_tests_pendulum_ppo_with_rl_module",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "torch_only"],
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "no_tf_static_graph"],
size = "large", # bazel may complain about it being too long sometimes - large is on purpose as some frameworks take longer
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/ppo/pendulum-ppo-with-rl-module.yaml"],
58 changes: 36 additions & 22 deletions rllib/algorithms/algorithm_config.py
@@ -1,5 +1,4 @@
import copy
import dataclasses
import logging
import math
import os
@@ -18,7 +17,7 @@

import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core.learner.learner import LearnerHPs
from ray.rllib.core.learner.learner import LearnerHyperparameters
from ray.rllib.core.learner.learner_group_config import (
LearnerGroupConfig,
ModuleSpec,
@@ -322,12 +321,8 @@ def __init__(self, algo_class=None):
self.model = copy.deepcopy(MODEL_DEFAULTS)
self.optimizer = {}
self.max_requests_in_flight_per_sampler_worker = 2
self.learner_class = None
self._learner_class = None
self._enable_learner_api = False
# experimental: this will contain the hyper-parameters that are passed to the
# Learner, for computing loss, etc. New algorithms have to set this to their
# own default. .training() will modify the fields of this object.
self._learner_hps = LearnerHPs()

# `self.callbacks()`
self.callbacks_class = DefaultCallbacks
@@ -469,10 +464,6 @@ def __init__(self, algo_class=None):
self.soft_horizon = DEPRECATED_VALUE
self.no_done_at_end = DEPRECATED_VALUE

@property
def learner_hps(self) -> LearnerHPs:
return self._learner_hps

def to_dict(self) -> AlgorithmConfigDict:
"""Converts all settings into a legacy config dict for backward compatibility.
@@ -1039,11 +1030,6 @@ def validate(self) -> None:
"(i.e. num_learner_workers = 0)"
)

# Resolve learner class.
if self._enable_learner_api and self.learner_class is None:
learner_class_path = self.get_default_learner_class()
self.learner_class = deserialize_type(learner_class_path)

def build(
self,
env: Optional[Union[str, EnvType]] = None,
@@ -1706,7 +1692,7 @@ def training(
if _enable_learner_api is not NotProvided:
self._enable_learner_api = _enable_learner_api
if learner_class is not NotProvided:
self.learner_class = learner_class
self._learner_class = learner_class

return self

@@ -2544,6 +2530,20 @@ def experimental(

return self

@property
def learner_class(self) -> Type["Learner"]:
"""Returns the Learner sub-class to use by this Algorithm.
Either
a) User sets a specific learner class via calling `.training(learner_class=...)`
b) User leaves learner class unset (None) and the AlgorithmConfig itself
figures out the actual learner class by calling its own
`.get_default_learner_class()` method.
"""
return self._learner_class or self.get_default_learner_class()

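# Illustrative usage (a sketch, not from this diff; `SomeAlgoConfig` and
# `MyLearner` are hypothetical stand-ins):
#
#   config = SomeAlgoConfig()
#   config.learner_class                      # -> whatever `get_default_learner_class()` returns
#   config.training(learner_class=MyLearner)
#   config.learner_class                      # -> MyLearner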
# TODO: Make `rollout_fragment_length` a read-only property and turn the current
# `self.rollout_fragment_length` into a private variable.
def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
"""Automatically infers a proper rollout_fragment_length setting if "auto".
@@ -2579,6 +2579,8 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
else:
return self.rollout_fragment_length

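# Rough illustration of the "auto" inference (an assumption about the logic in
# the collapsed part of this hunk, not code from this commit): with
# `train_batch_size=4000`, `num_rollout_workers=8` and `num_envs_per_worker=1`,
# "auto" would resolve to roughly 4000 / (8 * 1) = 500 timesteps per worker.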
# TODO: Make `evaluation_config` a read-only property and turn the current
# `self.evaluation_config` into a private variable.
def get_evaluation_config_object(
self,
) -> Optional["AlgorithmConfig"]:
@@ -2872,6 +2874,8 @@ def is_policy_to_train(pid, batch=None):

return policies, is_policy_to_train

# TODO: Move this to those algorithms that really need this, which is currently
# only A2C and PG.
def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:
"""Detects mismatches for `train_batch_size` vs `rollout_fragment_length`.
@@ -3130,7 +3134,7 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfi
"grad_clip": self.grad_clip,
"grad_clip_by": self.grad_clip_by,
},
learner_hps=self.learner_hps,
learner_hyperparameters=self.get_learner_hyperparameters(),
)
.resources(
num_learner_workers=self.num_learner_workers,
@@ -3143,6 +3147,20 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfi

return config

def get_learner_hyperparameters(self) -> LearnerHyperparameters:
"""Returns a new LearnerHyperparameters instance for the respective Learner.
The LearnerHyperparameters is a dataclass containing only those config settings
from AlgorithmConfig that are used by the algorithm's specific Learner
sub-class. They allow distributing only those settings relevant for learning
across a set of learner workers (instead of having to distribute the entire
AlgorithmConfig object).
Note that LearnerHyperparameters should always be derived directly from a
AlgorithmConfig object's own settings and considered frozen/read-only.
"""
return LearnerHyperparameters()

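# Algorithm-specific configs are expected to extend this by building their own
# hyperparameters dataclass on top of the base one (a sketch mirroring the
# APPOConfig override further down in this PR; `MyAlgoHyperparameters` and
# `my_setting` are hypothetical):
#
#   def get_learner_hyperparameters(self) -> MyAlgoHyperparameters:
#       base_hps = super().get_learner_hyperparameters()
#       return MyAlgoHyperparameters(
#           my_setting=self.my_setting,
#           **dataclasses.asdict(base_hps),
#       )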
def __setattr__(self, key, value):
"""Gatekeeper in case we are in frozen state and need to error."""

@@ -3247,10 +3265,6 @@ def _serialize_dict(config):
config["model"]["custom_model"]
)

# Serialize dataclasses.
if isinstance(config.get("_learner_hps"), LearnerHPs):
config["_learner_hps"] = dataclasses.asdict(config["_learner_hps"])

# List'ify `policies`, iff a set or tuple (these types are not JSON'able).
ma_config = config.get("multiagent")
if ma_config is not None:
40 changes: 20 additions & 20 deletions rllib/algorithms/appo/appo.py
@@ -9,12 +9,16 @@
Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#appo
"""
import dataclasses
from typing import Optional, Type
import logging

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.appo.appo_learner import (
AppoHyperparameters,
LEARNER_RESULTS_KL_KEY,
)
from ray.rllib.algorithms.impala.impala import Impala, ImpalaConfig
from ray.rllib.algorithms.appo.tf.appo_tf_learner import AppoHPs, LEARNER_RESULTS_KL_KEY
from ray.rllib.algorithms.ppo.ppo import UpdateKL
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.policy.policy import Policy
@@ -77,7 +81,6 @@ def __init__(self, algo_class=None):
# __sphinx_doc_begin__

# APPO specific settings:
self._learner_hps = AppoHPs()
self.vtrace = True
self.use_critic = True
self.use_gae = True
@@ -195,24 +198,20 @@ def training(
self.lambda_ = lambda_
if clip_param is not NotProvided:
self.clip_param = clip_param
self._learner_hps.clip_param = clip_param
if use_kl_loss is not NotProvided:
self.use_kl_loss = use_kl_loss
if kl_coeff is not NotProvided:
self.kl_coeff = kl_coeff
self._learner_hps.kl_coeff = kl_coeff
if kl_target is not NotProvided:
self.kl_target = kl_target
self._learner_hps.kl_target = kl_target
if tau is not NotProvided:
self.tau = tau
self._learner_hps.tau = tau
if target_update_frequency is not NotProvided:
self.target_update_frequency = target_update_frequency

return self

@override(AlgorithmConfig)
@override(ImpalaConfig)
def get_default_learner_class(self):
if self.framework_str == "tf2":
from ray.rllib.algorithms.appo.tf.appo_tf_learner import APPOTfLearner
@@ -221,7 +220,7 @@ def get_default_learner_class():
else:
raise ValueError(f"The framework {self.framework_str} is not supported.")

@override(AlgorithmConfig)
@override(ImpalaConfig)
def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec:
if self.framework_str == "tf2":
from ray.rllib.algorithms.appo.appo_catalog import APPOCatalog
@@ -234,20 +233,23 @@ def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec:
raise ValueError(f"The framework {self.framework_str} is not supported.")

@override(ImpalaConfig)
def validate(self) -> None:
super().validate()
self._learner_hps.tau = self.tau
self._learner_hps.kl_target = self.kl_target
self._learner_hps.kl_coeff = self.kl_coeff
self._learner_hps.clip_param = self.clip_param
def get_learner_hyperparameters(self) -> AppoHyperparameters:
base_hps = super().get_learner_hyperparameters()
return AppoHyperparameters(
use_kl_loss=self.use_kl_loss,
kl_target=self.kl_target,
kl_coeff=self.kl_coeff,
clip_param=self.clip_param,
tau=self.tau,
**dataclasses.asdict(base_hps),
)


# Still used by one of the old checkpoints in tests.
# Keep a shim version of this around.
class UpdateTargetAndKL:
def __init__(self, workers, config):
self.workers = workers
self.config = config
pass


class APPO(Impala):
@@ -277,9 +279,8 @@ def setup(self, config: AlgorithmConfig):
def after_train_step(self, train_results: ResultDict) -> None:
"""Updates the target network and the KL coefficient for the APPO-loss.
This method is called from within the `training_iteration` method after each
train update.
This method is called from within the `training_step` method after each train
update.
The target network update frequency is calculated automatically by the product
of `num_sgd_iter` setting (usually 1 for APPO) and `minibatch_buffer_size`.
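For example, with `num_sgd_iter=1` and `minibatch_buffer_size=4`, the target
network would be updated once every 4 train updates.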
@@ -407,7 +408,6 @@ def get_default_policy_class(
return APPOTF1Policy
else:
if config._enable_rl_module_api:
# TODO(avnishn): This policy class doesn't work just yet
from ray.rllib.algorithms.appo.tf.appo_tf_policy_rlm import (
APPOTfPolicyWithRLModule,
)
103 changes: 103 additions & 0 deletions rllib/algorithms/appo/appo_learner.py
@@ -0,0 +1,103 @@
import abc
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Mapping

import numpy as np

from ray.rllib.algorithms.impala.impala_learner import (
ImpalaLearner,
ImpalaHyperparameters,
)
from ray.rllib.core.rl_module.marl_module import ModuleID
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import get_variable


LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
OLD_ACTION_DIST_KEY = "old_action_dist"
OLD_ACTION_DIST_LOGITS_KEY = "old_action_dist_logits"


@dataclass
class AppoHyperparameters(ImpalaHyperparameters):
"""Hyperparameters for the APPOLearner sub-classes (framework specific).
These should never be set directly by the user. Instead, use the APPOConfig
class to configure your algorithm.
See `ray.rllib.algorithms.appo.appo::APPOConfig::training()` for more details on the
individual properties.
"""

use_kl_loss: bool = None
kl_coeff: float = None
kl_target: float = None
clip_param: float = None
tau: float = None


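# Illustrative sketch (not part of this commit): these fields are normally
# filled in by `APPOConfig.get_learner_hyperparameters()` rather than set by
# hand, e.g.:
#
#   hps = APPOConfig().training(kl_coeff=0.2, kl_target=0.01, tau=1.0).get_learner_hyperparameters()
#   assert isinstance(hps, AppoHyperparameters)
#   assert hps.kl_coeff == 0.2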
class AppoLearner(ImpalaLearner):
"""Adds KL coeff updates via `additional_updates_per_module()` to Impala logic.
Framework-specific sub-classes must override `_update_module_target_networks()`
and `_update_module_kl_coeff()`
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Create framework-specific variables (simple python vars for torch).
self.kl_coeffs = defaultdict(
lambda: get_variable(
self._hps.kl_coeff,
framework=self.framework,
trainable=False,
dtype=np.float32,
)
)

@override(ImpalaLearner)
def remove_module(self, module_id: str):
super().remove_module(module_id)
self.kl_coeffs.pop(module_id)

@override(ImpalaLearner)
def additional_update_per_module(
self, module_id: ModuleID, sampled_kls: Dict[ModuleID, float], **kwargs
) -> Mapping[str, Any]:
"""Updates the target networks and KL loss coefficients (per module).
Args:
module_id:
"""
self._update_module_target_networks(module_id)
if self._hps.use_kl_loss:
self._update_module_kl_coeff(module_id, sampled_kls)
return {}

@abc.abstractmethod
def _update_module_target_networks(self, module_id: ModuleID) -> None:
"""Update the target policy of each module with the current policy.
Do that update via polyak averaging.
Args:
module_id: The module ID, whose target network(s) need to be updated.
"""

@abc.abstractmethod
def _update_module_kl_coeff(
self, module_id: ModuleID, sampled_kls: Dict[ModuleID, float]
) -> None:
"""Dynamically update the KL loss coefficients of each module with.
The update is completed using the mean KL divergence between the action
distributions current policy and old policy of each module. That action
distribution is computed during the most recent update/call to `compute_loss`.
Args:
module_id: The module whose KL loss coefficient to update.
sampled_kls: The KL divergence between the action distributions of
the current policy and old policy of each module.
"""
7 changes: 4 additions & 3 deletions rllib/algorithms/appo/tests/tf/test_appo_learner.py
@@ -114,6 +114,7 @@ def test_kl_coeff_changes(self):
config = (
appo.APPOConfig()
.environment("CartPole-v1")
.framework(eager_tracing=True)
.rollouts(
num_rollout_workers=0,
rollout_fragment_length=frag_length,
@@ -134,13 +135,13 @@
)
.exploration(exploration_config={})
)
for _ in framework_iterator(config, "tf2", with_eager_tracing=True):
for _ in framework_iterator(config, frameworks="tf2"):
algo = config.build()
# Call `train()` repeatedly until results are returned: this is an
# asynchronous algorithm, so results may not be available on every iteration.
while 1:
while True:
results = algo.train()
if results and "info" in results and LEARNER_INFO in results["info"]:
if results.get("info", {}).get(LEARNER_INFO, {}).get(DEFAULT_POLICY_ID):
break
curr_kl_coeff = results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][
LEARNER_STATS_KEY