
Commit

[RLlib] APPO+new-stack (Atari benchmark) - Preparatory PR 04 - LearnerAPI changes/tf-tracing fixes. (ray-project#34959)
sven1977 committed May 11, 2023
1 parent e81963d commit 384ad04
Showing 25 changed files with 383 additions and 236 deletions.
2 changes: 1 addition & 1 deletion rllib/algorithms/appo/appo_catalog.py
@@ -10,7 +10,7 @@ class APPOCatalog(PPOCatalog):
- Value Function Head: The head used to compute the value function.
The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
for the policy and value function. See implementations of PPORLModuleBase for
for the policy and value function. See implementations of PPORLModule for
more details.
Any custom ActorCriticEncoder can be built by overriding the
10 changes: 5 additions & 5 deletions rllib/algorithms/appo/appo_learner.py
@@ -1,7 +1,7 @@
import abc
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, Mapping
from typing import Any, Mapping

from ray.rllib.algorithms.impala.impala_learner import (
ImpalaLearner,
@@ -115,7 +115,7 @@ def _update_module_target_networks(self, module_id: ModuleID) -> None:

@abc.abstractmethod
def _update_module_kl_coeff(
self, module_id: ModuleID, sampled_kls: Dict[ModuleID, float]
self, module_id: ModuleID, sampled_kl: float
) -> Mapping[str, Any]:
"""Dynamically update the KL loss coefficients of each module with.
@@ -125,7 +125,7 @@ def _update_module_kl_coeff(
Args:
module_id: The module whose KL loss coefficient to update.
sampled_kls: Mapping from Module ID to this module's KL divergence between
the action distributions of the current (most recently updated) module
and the old module version.
sampled_kl: The computed KL loss for the given Module
(KL divergence between the action distributions of the current
(most recently updated) module and the old module version).
"""
5 changes: 4 additions & 1 deletion rllib/algorithms/appo/tf/appo_tf_learner.py
@@ -132,7 +132,10 @@ def compute_loss_per_module(
total_loss = (
mean_pi_loss
+ (mean_vf_loss * self.hps.vf_loss_coeff)
+ (mean_entropy_loss * self.hps.entropy_coeff)
+ (
mean_entropy_loss
* self.entropy_coeff_scheduler.get_current_value(module_id)
)
+ (mean_kl_loss * self.curr_kl_coeffs_per_module[module_id])
)

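The loss change above (mirrored in the torch learner below) replaces the static self.hps.entropy_coeff with self.entropy_coeff_scheduler.get_current_value(module_id), so the entropy bonus can now be scheduled over training on a per-module basis. A minimal, self-contained sketch of what such a lookup amounts to, assuming a piecewise-linear [[timestep, value], ...] schedule (the schedule format and the numbers are assumptions for illustration, not this commit's scheduler implementation):

import numpy as np

# Assumed example schedule: entropy coeff decays from 0.01 to 0.0 over 1M steps.
schedule = [[0, 0.01], [1_000_000, 0.0]]

def current_value(schedule, timestep):
    # Piecewise-linear interpolation between the schedule's (timestep, value) points.
    timesteps = [t for t, _ in schedule]
    values = [v for _, v in schedule]
    return float(np.interp(timestep, timesteps, values))

print(current_value(schedule, 0))        # 0.01
print(current_value(schedule, 500_000))  # 0.005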
14 changes: 13 additions & 1 deletion rllib/algorithms/appo/tf/appo_tf_rl_module.py
Expand Up @@ -4,7 +4,7 @@
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.base import ACTOR, CRITIC, STATE_IN
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import (
RLModuleWithTargetNetworksInterface,
@@ -45,7 +45,19 @@ def output_specs_train(self) -> List[str]:
@override(PPOTfRLModule)
def _forward_train(self, batch: NestedDict):
outs = super()._forward_train(batch)

# TODO (Artur): Remove this once Policy supports RNN
batch = batch.copy()
if self.encoder.config.shared:
batch[STATE_IN] = None
else:
batch[STATE_IN] = {
ACTOR: None,
CRITIC: None,
}
batch[SampleBatch.SEQ_LENS] = None
old_pi_inputs_encoded = self.old_encoder(batch)[ENCODER_OUT][ACTOR]

old_action_dist_logits = tf.stop_gradient(self.old_pi(old_pi_inputs_encoded))
outs[OLD_ACTION_DIST_LOGITS_KEY] = old_action_dist_logits
return outs
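
A short note on the batch handling added above: per the TODO, the old (target) encoder is run statelessly until the new-stack Policy supports RNNs, so the batch copy gets STATE_IN set to None, or to {ACTOR: None, CRITIC: None} when actor and critic use separate (non-shared) encoders to mirror the ActorCriticEncoder's state layout, and SEQ_LENS is cleared before the old_encoder forward pass.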
16 changes: 11 additions & 5 deletions rllib/algorithms/appo/torch/appo_torch_learner.py
Expand Up @@ -5,7 +5,7 @@
AppoLearner,
LEARNER_RESULTS_CURR_KL_COEFF_KEY,
LEARNER_RESULTS_KL_KEY,
OLD_ACTION_DIST_KEY,
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.impala.torch.vtrace_torch_v2 import (
make_time_major,
@@ -37,12 +37,15 @@ def compute_loss_per_module(
) -> TensorType:

values = fwd_out[SampleBatch.VF_PREDS]
action_dist_cls_train = self._module[module_id].get_train_action_dist_cls()
action_dist_cls_train = (
self.module[module_id].unwrapped().get_train_action_dist_cls()
)
target_policy_dist = action_dist_cls_train.from_logits(
fwd_out[SampleBatch.ACTION_DIST_INPUTS]
)

old_target_policy_dist = fwd_out[OLD_ACTION_DIST_KEY]
old_target_policy_dist = action_dist_cls_train.from_logits(
fwd_out[OLD_ACTION_DIST_LOGITS_KEY]
)
old_target_policy_actions_logp = old_target_policy_dist.logp(
batch[SampleBatch.ACTIONS]
)
@@ -133,7 +136,10 @@ def compute_loss_per_module(
total_loss = (
mean_pi_loss
+ (mean_vf_loss * self.hps.vf_loss_coeff)
+ (mean_entropy_loss * self.hps.entropy_coeff)
+ (
mean_entropy_loss
* self.entropy_coeff_scheduler.get_current_value(module_id)
)
+ (mean_kl_loss * self.curr_kl_coeffs_per_module[module_id])
)

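On the self.module[module_id].unwrapped() change above (the same pattern is applied in the IMPALA torch learner further down): the torch module held by a Learner may be wrapped for distributed training, and such wrappers generally only forward the standard forward calls, so unwrapped() is presumably needed here to reach the underlying RLModule and call helpers like get_train_action_dist_cls(). This reading is inferred from the diff rather than stated in the commit message.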
6 changes: 2 additions & 4 deletions rllib/algorithms/appo/torch/appo_torch_rl_module.py
@@ -1,7 +1,6 @@
from typing import List

from ray.rllib.algorithms.appo.appo_learner import (
OLD_ACTION_DIST_KEY,
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
@@ -35,16 +34,15 @@ def get_target_network_pairs(self):
@override(PPOTorchRLModule)
def output_specs_train(self) -> List[str]:
return [
SampleBatch.ACTION_DIST_INPUTS,
OLD_ACTION_DIST_LOGITS_KEY,
SampleBatch.VF_PREDS,
OLD_ACTION_DIST_KEY,
]

@override(PPOTorchRLModule)
def _forward_train(self, batch: NestedDict):
outs = super()._forward_train(batch)
old_pi_inputs_encoded = self.old_encoder(batch)[ENCODER_OUT][ACTOR]
old_action_dist_logits = self.old_pi(old_pi_inputs_encoded)
old_action_dist = self.action_dist_cls.from_logits(old_action_dist_logits)
outs[OLD_ACTION_DIST_KEY] = old_action_dist
outs[OLD_ACTION_DIST_LOGITS_KEY] = old_action_dist_logits
return outs
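
The torch module now exposes the old policy's raw logits under OLD_ACTION_DIST_LOGITS_KEY instead of a constructed Distribution object, and the learner (see appo_torch_learner.py above) rebuilds the distribution itself via from_logits(). Keeping the _forward_train() outputs as plain tensors matches the tf module and keeps the output specs tensor-checkable; given the commit title's "tf-tracing fixes", it also avoids passing non-tensor Python objects through traced code paths, though that rationale is an inference from the diff rather than an explicit statement in it.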
53 changes: 30 additions & 23 deletions rllib/algorithms/impala/impala.py
@@ -7,6 +7,9 @@
import random
from typing import Callable, List, Optional, Set, Tuple, Type, Union

import numpy as np
import tree # pip install dm_tree

import ray
from ray import ObjectRef
from ray.rllib import SampleBatch
@@ -938,35 +941,39 @@ def learn_on_processed_samples(self) -> ResultDict:
Aggregated results from the learner group after an update is completed.
"""
result = {}
# There are batches on the queue -> Send them to the learner group.
# There are batches on the queue -> Send them all to the learner group.
if self.batches_to_place_on_learner:
batch = self.batches_to_place_on_learner.pop(0)
batches = self.batches_to_place_on_learner[:]
self.batches_to_place_on_learner.clear()
# If there are no learner workers and learning is directly on the driver
# Then we can't do async updates, so we need to block.
blocking = self.config.num_learner_workers == 0
lg_results = self.learner_group.update(
batch,
reduce_fn=_reduce_impala_results,
block=blocking,
num_iters=self.config.num_sgd_iter,
minibatch_size=self.config.minibatch_size,
)
# Nothing on the queue -> Don't send requests to learner group.
else:
lg_results = None

if lg_results:
self._counters[NUM_ENV_STEPS_TRAINED] += lg_results[ALL_MODULES].pop(
NUM_ENV_STEPS_TRAINED
)
self._counters[NUM_AGENT_STEPS_TRAINED] += lg_results[ALL_MODULES].pop(
NUM_AGENT_STEPS_TRAINED
)
results = []
for batch in batches:
result = self.learner_group.update(
batch,
reduce_fn=_reduce_impala_results,
block=blocking,
num_iters=self.config.num_sgd_iter,
minibatch_size=self.config.minibatch_size,
)
if result:
self._counters[NUM_ENV_STEPS_TRAINED] += result[ALL_MODULES].pop(
NUM_ENV_STEPS_TRAINED
)
self._counters[NUM_AGENT_STEPS_TRAINED] += result[ALL_MODULES].pop(
NUM_AGENT_STEPS_TRAINED
)
results.append(result)
self._counters.update(self.learner_group.get_in_queue_stats())
result = lg_results
# If there are results, reduce-mean over each individual value and return.
if results:
return tree.map_structure(lambda *x: np.mean(x), *results)

return result
# Nothing on the queue -> Don't send requests to learner group
# or no results ready (from previous `self.learner_group.update()` calls) for
# reducing.
return {}

def place_processed_samples_on_learner_thread_queue(self) -> None:
"""Place processed samples on the learner queue for training.
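For reference, the reworked tail of learn_on_processed_samples() above averages a list of (possibly nested) per-update result dicts leaf by leaf. A self-contained toy example of that exact tree.map_structure(lambda *x: np.mean(x), *results) pattern (the dict contents here are made up):

import numpy as np
import tree  # pip install dm_tree

results = [
    {"total_loss": 2.0, "stats": {"mean_kl": 0.1}},
    {"total_loss": 4.0, "stats": {"mean_kl": 0.3}},
]
# Each leaf position is averaged across all result dicts in the list.
reduced = tree.map_structure(lambda *x: np.mean(x), *results)
print(reduced["total_loss"], reduced["stats"]["mean_kl"])  # 3.0 0.2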
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/impala_learner.py
@@ -93,7 +93,7 @@ def _reduce_impala_results(results: List[ResultDict]) -> ResultDict:
"""Reduce/Aggregate a list of results from Impala Learners.
Average the values of the result dicts. Add keys for the number of agent and env
steps trained.
steps trained (on all modules).
Args:
results: result dicts to reduce.
5 changes: 3 additions & 2 deletions rllib/algorithms/impala/tf/impala_tf_learner.py
@@ -19,7 +19,7 @@ class ImpalaTfLearner(ImpalaLearner, TfLearner):
def compute_loss_per_module(
self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType]
) -> TensorType:
action_dist_class_train = self._module[module_id].get_train_action_dist_cls()
action_dist_class_train = self.module[module_id].get_train_action_dist_cls()
target_policy_dist = action_dist_class_train.from_logits(
fwd_out[SampleBatch.ACTION_DIST_INPUTS]
)
@@ -95,7 +95,8 @@ def compute_loss_per_module(
total_loss = (
pi_loss
+ vf_loss * self.hps.vf_loss_coeff
+ mean_entropy_loss * self.hps.entropy_coeff
+ mean_entropy_loss
* (self.entropy_coeff_scheduler.get_current_value(module_id))
)
return {
self.TOTAL_LOSS_KEY: total_loss,
7 changes: 5 additions & 2 deletions rllib/algorithms/impala/torch/impala_torch_learner.py
@@ -23,7 +23,9 @@ class ImpalaTorchLearner(ImpalaLearner, TorchLearner):
def compute_loss_per_module(
self, module_id: str, batch: SampleBatch, fwd_out: Mapping[str, TensorType]
) -> TensorType:
action_dist_class_train = self._module[module_id].get_train_action_dist_cls()
action_dist_class_train = (
self.module[module_id].unwrapped().get_train_action_dist_cls()
)
target_policy_dist = action_dist_class_train.from_logits(
fwd_out[SampleBatch.ACTION_DIST_INPUTS]
)
@@ -111,7 +113,8 @@ def compute_loss_per_module(
total_loss = (
pi_loss
+ vf_loss * self.hps.vf_loss_coeff
+ mean_entropy_loss * self.hps.entropy_coeff
+ mean_entropy_loss
* (self.entropy_coeff_scheduler.get_current_value(module_id))
)
return {
self.TOTAL_LOSS_KEY: total_loss,
13 changes: 7 additions & 6 deletions rllib/algorithms/ppo/ppo_catalog.py
@@ -38,7 +38,7 @@ class PPOCatalog(Catalog):
- Value Function Head: The head used to compute the value function.
The ActorCriticEncoder is a wrapper around Encoders to produce separate outputs
for the policy and value function. See implementations of PPORLModuleBase for
for the policy and value function. See implementations of PPORLModule for
more details.
Any custom ActorCriticEncoder can be built by overriding the
@@ -89,8 +89,9 @@ def __init__(
hidden_layer_dims=post_fcnet_hiddens,
hidden_layer_activation=post_fcnet_activation,
output_activation="linear",
output_dims=None, # We don't know the output dimension yet, because it
# depends on the action distribution input dimension
# We don't know the output dimension yet, because it depends on the
# action distribution input dimension.
output_dims=None,
)

self.vf_head_config = MLPHeadConfig(
@@ -106,7 +107,7 @@ def build_actor_critic_encoder(self, framework: str) -> ActorCriticEncoder:
The default behavior is to build the encoder from the encoder_config.
This can be overridden to build a custom ActorCriticEncoder as a means of
configuring the behavior of a PPORLModuleBase implementation.
configuring the behavior of a PPORLModule implementation.
Args:
framework: The framework to use. Either "torch" or "tf2".
@@ -131,7 +132,7 @@ def build_pi_head(self, framework: str) -> Model:
The default behavior is to build the head from the pi_head_config.
This can be overridden to build a custom policy head as a means of configuring
the behavior of a PPORLModuleBase implementation.
the behavior of a PPORLModule implementation.
Args:
framework: The framework to use. Either "torch" or "tf2".
@@ -156,7 +157,7 @@ def build_vf_head(self, framework: str) -> Model:
The default behavior is to build the head from the vf_head_config.
This can be overridden to build a custom value function head as a means of
configuring the behavior of a PPORLModuleBase implementation.
configuring the behavior of a PPORLModule implementation.
Args:
framework: The framework to use. Either "torch" or "tf2".
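Regarding the relocated comment about output_dims=None in the policy-head config above: the pi head's output size has to match whatever the chosen action distribution expects as input (for instance, a categorical distribution over n discrete actions needs n logits, while a diagonal Gaussian over an m-dimensional continuous space needs about 2*m outputs for means and log-stds), and that number is only known once the action space and distribution class are resolved, so it is intentionally left unset at config-construction time. The concrete distribution examples are illustrative, not taken from this diff.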
@@ -3,22 +3,19 @@
"""

import abc
from typing import Type

from ray.rllib.core.models.base import ActorCriticEncoder
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.rl_module import RLModuleConfig
from ray.rllib.models.distributions import Distribution
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.annotations import ExperimentalAPI
from ray.rllib.utils.annotations import override


@ExperimentalAPI
class PPORLModuleBase(RLModule, abc.ABC):
def __init__(self, config: RLModuleConfig):
super().__init__(config)

class PPORLModule(RLModule, abc.ABC):
def setup(self):
# __sphinx_doc_begin__
catalog = self.config.get_catalog()
@@ -33,13 +30,13 @@ def setup(self):

assert isinstance(self.encoder, ActorCriticEncoder)

def get_train_action_dist_cls(self) -> Distribution:
def get_train_action_dist_cls(self) -> Type[Distribution]:
return self.action_dist_cls

def get_exploration_action_dist_cls(self) -> Distribution:
def get_exploration_action_dist_cls(self) -> Type[Distribution]:
return self.action_dist_cls

def get_inference_action_dist_cls(self) -> Distribution:
def get_inference_action_dist_cls(self) -> Type[Distribution]:
return self.action_dist_cls

@override(RLModule)
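The tightened annotations above (Distribution to Type[Distribution]) make explicit that these getters return the distribution class, not an instance; callers then build instances from logits. A short usage fragment in the style of the learner code elsewhere in this commit (module, fwd_out, and batch are assumed to be in scope):

# The learner asks the module for the class, then constructs the distribution
# from the logits produced by the train forward pass.
action_dist_cls = module.get_train_action_dist_cls()  # a Type[Distribution]
train_dist = action_dist_cls.from_logits(fwd_out[SampleBatch.ACTION_DIST_INPUTS])
logp = train_dist.logp(batch[SampleBatch.ACTIONS])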
6 changes: 3 additions & 3 deletions rllib/algorithms/ppo/tests/test_ppo_with_rl_module.py
@@ -61,10 +61,10 @@ def on_train_result(self, *, algorithm, result: dict, **kwargs):
0.05 if algorithm.iteration == 1 else 0.0,
)

# Learning rate should decrease by 0.0001 per iteration.
# Learning rate should decrease by 0.0001/4 per iteration.
check(
stats[LEARNER_RESULTS_CURR_LR_KEY],
0.0003 if algorithm.iteration == 1 else 0.0002,
0.0000075 if algorithm.iteration == 1 else 0.000005,
)
# Compare reported curr lr vs the actual lr found in the optimizer object.
optim = algorithm.learner_group._learner._named_optimizers[DEFAULT_POLICY_ID]
@@ -94,7 +94,7 @@ def test_ppo_compilation_and_schedule_mixins(self):
.training(
num_sgd_iter=2,
# Setup lr schedule for testing lr-scheduling correctness.
lr_schedule=[[0, 0.0004], [512, 0.0]], # 512=4x128
lr_schedule=[[0, 0.00001], [512, 0.0]], # 512=4x128
# Set entropy_coeff to a faulty value to proof that it'll get
# overridden by the schedule below (which is expected).
entropy_coeff=100.0,
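A quick arithmetic check of the updated expectations in the test above: with lr_schedule=[[0, 0.00001], [512, 0.0]] and 128 sampled timesteps per training iteration (512 = 4 x 128), the learning rate drops by 0.00001 / 4 = 0.0000025 per iteration, giving the asserted 0.0000075 after iteration 1 and 0.000005 after iteration 2.

initial_lr = 0.00001
iters_until_zero = 512 / 128                   # 4 iterations until lr reaches 0.0
drop_per_iter = initial_lr / iters_until_zero  # 0.0000025
lr_after_iter_1 = initial_lr - 1 * drop_per_iter  # ~0.0000075
lr_after_iter_2 = initial_lr - 2 * drop_per_iter  # ~0.000005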