Commit

[RLlib] Add seeding to learner and fix rl module enabling/disabling (ray-project#35951)

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
ArturNiederfahrenhorst committed Jun 1, 2023
1 parent 7bd8886 commit 5c130a5
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 3 additions & 3 deletions rllib/algorithms/algorithm_config.py
@@ -2534,9 +2534,8 @@ def rl_module(
         if rl_module_spec is not NotProvided:
             self.rl_module_spec = rl_module_spec
 
-        if _enable_rl_module_api is not NotProvided or self._enable_rl_module_api:
-            if not self._enable_rl_module_api:
-                self._enable_rl_module_api = _enable_rl_module_api
+        if _enable_rl_module_api is not NotProvided:
+            self._enable_rl_module_api = _enable_rl_module_api
             if _enable_rl_module_api is True and self.exploration_config:
                 logger.warning(
                     "Setting `exploration_config={}` because you set "
@@ -3321,6 +3320,7 @@ def get_learner_hyperparameters(self) -> LearnerHyperparameters:
             grad_clip=self.grad_clip,
             grad_clip_by=self.grad_clip_by,
             _per_module_overrides=per_module_learner_hp_overrides,
+            seed=self.seed,
         )
 
     def __setattr__(self, key, value):
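
The enabling/disabling fix comes down to the NotProvided sentinel pattern: the _enable_rl_module_api flag should only be overwritten when the caller explicitly passes a value, so a later rl_module() call that omits the argument no longer clobbers an earlier explicit True or False. A minimal, self-contained sketch of that pattern (the Config class below is an illustrative stand-in, not RLlib's actual AlgorithmConfig):

    class _NotProvided:
        """Sentinel distinguishing "argument omitted" from an explicit None/False."""

    NotProvided = _NotProvided()

    class Config:
        def __init__(self):
            self._enable_rl_module_api = None

        def rl_module(self, _enable_rl_module_api=NotProvided):
            # Only touch the flag when the caller actually provided a value.
            if _enable_rl_module_api is not NotProvided:
                self._enable_rl_module_api = _enable_rl_module_api
            return self

    cfg = Config().rl_module(_enable_rl_module_api=True)
    cfg.rl_module()                             # argument omitted: True is preserved
    assert cfg._enable_rl_module_api is True
    cfg.rl_module(_enable_rl_module_api=False)  # explicit disable still works
    assert cfg._enable_rl_module_api is False
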
8 changes: 8 additions & 0 deletions rllib/core/learner/learner.py
@@ -38,6 +38,7 @@
     OverrideToImplementCustomLogic,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
 )
+from ray.rllib.utils.debug import update_global_seed_if_necessary
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
 from ray.rllib.utils.metrics import (
     ALL_MODULES,
@@ -124,6 +125,7 @@ class LearnerHyperparameters:
     learning_rate: LearningRateOrSchedule = None
     grad_clip: float = None
     grad_clip_by: str = None
+    seed: int = None
 
     # Maps ModuleIDs to LearnerHyperparameters that are to be used for that particular
     # module.
@@ -284,6 +286,12 @@ def __init__(
         learner_hyperparameters: Optional[LearnerHyperparameters] = None,
         framework_hyperparameters: Optional[FrameworkHyperparameters] = None,
     ):
+        # We first set seeds
+        if learner_hyperparameters and learner_hyperparameters.seed is not None:
+            update_global_seed_if_necessary(
+                self.framework, learner_hyperparameters.seed
+            )
+
         if (module_spec is None) is (module is None):
             raise ValueError(
                 "Exactly one of `module_spec` or `module` must be provided to Learner!"
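
On the Learner side, the seed travels config.seed -> get_learner_hyperparameters() -> LearnerHyperparameters.seed -> Learner.__init__, which calls update_global_seed_if_necessary(self.framework, seed) before anything stochastic happens. The helper's body is not shown in this diff, so the sketch below is only an assumption about what a global-seeding helper of this kind typically does; the function name seed_everything is illustrative, not RLlib's API:

    import random

    import numpy as np

    def seed_everything(framework, seed):
        # Assumed behavior of a global-seeding helper (illustrative sketch,
        # not the actual body of update_global_seed_if_necessary).
        random.seed(seed)          # Python stdlib RNG
        np.random.seed(seed)       # NumPy RNG used widely in RLlib
        if framework == "torch":
            import torch
            torch.manual_seed(seed)       # seeds CPU and all CUDA devices
        elif framework in ("tf", "tf2"):
            import tensorflow as tf
            tf.random.set_seed(seed)

    # Seeding before the RLModule is built makes weight initialization and any
    # sampling inside the Learner reproducible across runs with the same seed.
    seed_everything("torch", 42)   # requires torch; pass "tf2" for TensorFlow
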
