[RLlib] Replace remaining mentions of "trainer" by "algorithm". #36557

Merged
2 changes: 1 addition & 1 deletion rllib/BUILD
@@ -3693,7 +3693,7 @@ py_test(
)

# Taking out this test for now: Mixed torch- and tf- policies within the same
# Trainer never really worked.
# Algorithm never really worked.

# py_test(
# name = "examples/multi_agent_two_trainers_mixed_torch_tf",
# main = "examples/multi_agent_two_trainers.py",
6 changes: 3 additions & 3 deletions rllib/algorithms/a3c/tests/test_a3c.py
@@ -64,14 +64,14 @@ def test_a3c_entropy_coeff_schedule(self):
min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=20
)

def _step_n_times(trainer, n: int):
"""Step trainer n times.
def _step_n_times(algo, n: int):
"""Step Algorithm n times.

Returns:
learning rate at the end of the execution.
"""
for _ in range(n):
results = trainer.train()
results = algo.train()
return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
"entropy_coeff"
]
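For context on the rename above, here is a minimal illustrative sketch (not part of this PR) of how a `_step_n_times(algo, n)` helper is typically driven: build an Algorithm from an A3CConfig with an entropy-coeff schedule, step it, and read the scheduled value back from the learner stats. The environment, schedule values, and config settings below are assumptions chosen for demonstration.

from ray.rllib.algorithms.a3c import A3CConfig
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY


def _step_n_times(algo, n: int) -> float:
    """Step the given Algorithm `n` times and return the final entropy coeff."""
    for _ in range(n):
        results = algo.train()
    return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][
        "entropy_coeff"
    ]


config = (
    A3CConfig()
    .environment("CartPole-v1")
    .training(entropy_coeff=0.01, entropy_coeff_schedule=[[0, 0.01], [120, 0.0001]])
    .reporting(min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=20)
)
algo = config.build()
print(_step_n_times(algo, 2))  # entropy coeff partway down the schedule
algo.stop()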
120 changes: 27 additions & 93 deletions rllib/algorithms/algorithm.py
@@ -362,7 +362,9 @@ def __init__(
# Last resort: Create core AlgorithmConfig from merged dicts.
if isinstance(default_config, dict):
config = AlgorithmConfig.from_dict(
config_dict=self.merge_trainer_configs(default_config, config, True)
config_dict=self.merge_algorithm_configs(
default_config, config, True
)
)
# Default config is an AlgorithmConfig -> update its properties
# from the given config dict.
@@ -569,17 +571,17 @@ def setup(self, config: AlgorithmConfig) -> None:
)
self.config.off_policy_estimation_methods = ope_dict

# Deprecated way of implementing Trainer sub-classes (or "templates"
# Deprecated way of implementing Algorithm sub-classes (or "templates"
# via the `build_trainer` utility function).
# Instead, sub-classes should override the Trainable's `setup()`
# method and call super().setup() from within that override at some
# point.
# Old design: Override `Trainer._init`.
# Old design: Override `Algorithm._init`.
_init = False
try:
self._init(self.config, self.env_creator)
_init = True
# New design: Override `Trainable.setup()` (as indented by tune.Trainable)
# New design: Override `Algorithm.setup()` (as indented by tune.Trainable)
# and do or don't call `super().setup()` from within your override.
# By default, `super().setup()` will create both worker sets:
# "rollout workers" for collecting samples for training and - if
@@ -731,7 +733,7 @@ def setup(self, config: AlgorithmConfig) -> None:
# Run `on_algorithm_init` callback after initialization is done.
self.callbacks.on_algorithm_init(algorithm=self)

# TODO: Deprecated: In your sub-classes of Trainer, override `setup()`
# TODO: Deprecated: In your sub-classes of Algorithm, override `setup()`
# directly and call super().setup() from within it if you would like the
# default setup behavior plus some own setup logic.
# If you don't need the env/workers/config/etc.. setup for you by super,
@@ -755,13 +757,13 @@ def get_default_policy_class(

@override(Trainable)
def step(self) -> ResultDict:
"""Implements the main `Trainer.train()` logic.
"""Implements the main `Algorithm.train()` logic.

Takes n attempts to perform a single training step. Thereby
catches RayErrors resulting from worker failures. After n attempts,
fails gracefully.

Override this method in your Trainer sub-classes if you would like to
Override this method in your Algorithm sub-classes if you would like to
handle worker failures yourself.
Otherwise, override only `training_step()` to implement the core
algorithm logic.
@@ -803,7 +805,7 @@ def step(self) -> ResultDict:
if not evaluate_this_iter and self.config.always_attach_evaluation_results:
assert isinstance(
self.evaluation_metrics, dict
), "Trainer.evaluate() needs to return a dict."
), "Algorithm.evaluate() needs to return a dict."
results.update(self.evaluation_metrics)

if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
@@ -853,9 +855,6 @@ def evaluate(
) -> dict:
"""Evaluates current policy under `evaluation_config` settings.

Note that this default implementation does not do anything beyond
merging evaluation_config with the normal trainer config.

Args:
duration_fn: An optional callable taking the already run
num episodes as only arg and returning the number of
@@ -902,7 +901,7 @@ def evaluate(
):
raise ValueError(
"Cannot evaluate w/o an evaluation worker set in "
"the Trainer or w/o an env on the local worker!\n"
"the Algorithm or w/o an env on the local worker!\n"
"Try one of the following:\n1) Set "
"`evaluation_interval` >= 0 to force creating a "
"separate evaluation worker set.\n2) Set "
@@ -1093,7 +1092,7 @@ def duration_fn(num_units_done):
metrics["off_policy_estimator"][name] = avg_estimate

# Evaluation does not run for every step.
# Save evaluation metrics on trainer, so it can be attached to
# Save evaluation metrics on Algorithm, so it can be attached to
# subsequent step results as latest evaluation result.
self.evaluation_metrics = {"evaluation": metrics}

@@ -1286,7 +1285,7 @@ def remote_fn(worker):
metrics["off_policy_estimator"][name] = estimates

# Evaluation does not run for every step.
# Save evaluation metrics on trainer, so it can be attached to
# Save evaluation metrics on Algorithm, so it can be attached to
# subsequent step results as latest evaluation result.
self.evaluation_metrics = {"evaluation": metrics}

@@ -1348,7 +1347,7 @@ def training_step(self) -> ResultDict:
"""Default single iteration logic of an algorithm.

- Collect on-policy samples (SampleBatches) in parallel using the
Trainer's RolloutWorkers (@ray.remote).
Algorithm's RolloutWorkers (@ray.remote).
- Concatenate collected SampleBatches into one train batch.
- Note that we may have more than one policy in the multi-agent case:
Call the different policies' `learn_on_batch` (simple optimizer) OR
@@ -1419,10 +1418,10 @@ def training_step(self) -> ResultDict:
@staticmethod
def execution_plan(workers, config, **kwargs):
raise NotImplementedError(
"It is not longer recommended to use Trainer's `execution_plan` method/API."
"It is no longer supported to use the `Algorithm.execution_plan()` API!"
" Set `_disable_execution_plan_api=True` in your config and override the "
"`Trainer.training_step()` method with your algo's custom "
"execution logic."
"`Algorithm.training_step()` method with your algo's custom "
"execution logic instead."
)

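As the updated error message above says, custom per-iteration logic now lives in `training_step()` rather than `execution_plan()`. Below is a rough, hypothetical sketch of such an override (not from this diff); it subclasses PPO purely for illustration, and the specific sampling and training helpers used are assumptions based on RLlib's public `execution` utilities.

from ray.rllib.algorithms.ppo import PPO
from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
from ray.rllib.execution.train_ops import train_one_step
from ray.rllib.utils.typing import ResultDict


class MyCustomAlgo(PPO):
    def training_step(self) -> ResultDict:
        # 1) Collect one round of on-policy samples from the rollout workers.
        train_batch = synchronous_parallel_sample(worker_set=self.workers)
        # 2) Run a single (simple-optimizer) update on the local worker.
        train_results = train_one_step(self, train_batch)
        # 3) Broadcast the updated weights back to the remote rollout workers.
        self.workers.sync_weights()
        return train_results

With `_disable_execution_plan_api=True` set in the config, `Algorithm.step()` invokes this method once per training iteration.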
@PublicAPI
@@ -1442,9 +1441,6 @@ def compute_single_action(
episode: Optional[Episode] = None,
unsquash_action: Optional[bool] = None,
clip_action: Optional[bool] = None,
# Deprecated args.
unsquash_actions=DEPRECATED_VALUE,
clip_actions=DEPRECATED_VALUE,
# Kwargs placeholder for future compatibility.
**kwargs,
) -> Union[
@@ -1494,24 +1490,9 @@
or we have an RNN-based Policy.

Raises:
KeyError: If the `policy_id` cannot be found in this Trainer's
local worker.
KeyError: If the `policy_id` cannot be found in this Algorithm's local
worker.
"""
if clip_actions != DEPRECATED_VALUE:
deprecation_warning(
old="Trainer.compute_single_action(`clip_actions`=...)",
new="Trainer.compute_single_action(`clip_action`=...)",
error=True,
)
clip_action = clip_actions
if unsquash_actions != DEPRECATED_VALUE:
deprecation_warning(
old="Trainer.compute_single_action(`unsquash_actions`=...)",
new="Trainer.compute_single_action(`unsquash_action`=...)",
error=True,
)
unsquash_action = unsquash_actions

# `unsquash_action` is None: Use value of config['normalize_actions'].
if unsquash_action is None:
unsquash_action = self.config.normalize_actions
@@ -1523,7 +1504,7 @@ def compute_single_action(
# are all None.
err_msg = (
"Provide either `input_dict` OR [`observation`, ...] as "
"args to Trainer.compute_single_action!"
"args to `Algorithm.compute_single_action()`!"
)
if input_dict is not None:
assert (
Expand All @@ -1537,12 +1518,12 @@ def compute_single_action(
assert observation is not None, err_msg

# Get the policy to compute the action for (in the multi-agent case,
# Trainer may hold >1 policies).
# Algorithm may hold >1 policies).
policy = self.get_policy(policy_id)
if policy is None:
raise KeyError(
f"PolicyID '{policy_id}' not found in PolicyMap of the "
f"Trainer's local worker!"
f"Algorithm's local worker!"
)
local_worker = self.workers.local_worker()

@@ -1645,8 +1626,6 @@ def compute_actions(
episodes: Optional[List[Episode]] = None,
unsquash_actions: Optional[bool] = None,
clip_actions: Optional[bool] = None,
# Deprecated.
normalize_actions=None,
**kwargs,
):
"""Computes an action for the specified policy on the local Worker.
@@ -1688,14 +1667,6 @@ def compute_actions(
the full output of policy.compute_actions_from_input_dict() if
full_fetch=True or we have an RNN-based Policy.
"""
if normalize_actions is not None:
deprecation_warning(
old="Trainer.compute_actions(`normalize_actions`=...)",
new="Trainer.compute_actions(`unsquash_actions`=...)",
error=True,
)
unsquash_actions = normalize_actions

# `unsquash_actions` is None: Use value of config['normalize_actions'].
if unsquash_actions is None:
unsquash_actions = self.config.normalize_actions
@@ -1822,8 +1793,6 @@ def add_policy(
] = None,
evaluation_workers: bool = True,
module_spec: Optional[SingleAgentRLModuleSpec] = None,
# Deprecated.
workers: Optional[List[Union[RolloutWorker, ActorHandle]]] = DEPRECATED_VALUE,
) -> Optional[Policy]:
"""Adds a new policy to this Algorithm.

@@ -1861,27 +1830,13 @@ def add_policy(
module_spec: In the new RLModule API we need to pass in the module_spec for
the new module that is supposed to be added. Knowing the policy spec is
not sufficient.
workers: A list of RolloutWorker/ActorHandles (remote
RolloutWorkers) to add this policy to. If defined, will only
add the given policy to these workers.


Returns:
The newly added policy (the copy that got added to the local
worker). If `workers` was provided, None is returned.
"""
validate_policy_id(policy_id, error=True)

if workers is not DEPRECATED_VALUE:
deprecation_warning(
old="Algorithm.add_policy(.., workers=..)",
help=(
"The `workers` argument to `Algorithm.add_policy()` is deprecated! "
"Please do not use it anymore."
),
error=True,
)

self.workers.add_policy(
policy_id,
policy_cls,
@@ -2004,7 +1959,6 @@ def export_policy_model(
def export_policy_checkpoint(
self,
export_dir: str,
filename_prefix=DEPRECATED_VALUE, # deprecated arg, do not use anymore
policy_id: PolicyID = DEFAULT_POLICY_ID,
) -> None:
"""Exports Policy checkpoint to a local directory and returns an AIR Checkpoint.
Expand All @@ -2027,14 +1981,6 @@ def export_policy_checkpoint(
>>> algo.train() # doctest: +SKIP
>>> algo.export_policy_checkpoint("/tmp/export_dir") # doctest: +SKIP
"""
# `filename_prefix` should not longer be used as new Policy checkpoints
# contain more than one file with a fixed filename structure.
if filename_prefix != DEPRECATED_VALUE:
deprecation_warning(
old="Algorithm.export_policy_checkpoint(filename_prefix=...)",
error=True,
)

policy = self.get_policy(policy_id)
if policy is None:
raise KeyError(f"Policy with ID {policy_id} not found in Algorithm!")
@@ -2161,7 +2107,8 @@ def load_checkpoint(self, checkpoint: str) -> None:
def log_result(self, result: ResultDict) -> None:
# Log after the callback is invoked, so that the user has a chance
# to mutate the result.
# TODO: Remove `trainer` arg at some point to fully deprecate the old signature.
# TODO: Remove `algorithm` arg at some point to fully deprecate the old
# signature.
self.callbacks.on_train_result(algorithm=self, result=result)
# Then log according to Trainable's logging logic.
Trainable.log_result(self, result)
@@ -2465,7 +2412,7 @@ def get_auto_filled_metrics(
return auto_filled

@classmethod
def merge_trainer_configs(
def merge_algorithm_configs(
cls,
config1: AlgorithmConfigDict,
config2: PartialAlgorithmConfigDict,
@@ -2742,7 +2689,7 @@ def _checkpoint_info_to_algorithm_state(
if isinstance(default_config, AlgorithmConfig):
new_config = default_config.update_from_dict(state["config"])
else:
new_config = Algorithm.merge_trainer_configs(
new_config = Algorithm.merge_algorithm_configs(
default_config, state["config"]
)

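For reference, a tiny illustrative sketch (not from the diff) of calling the renamed classmethod directly on legacy dict configs; the PPO-based dicts and override keys are assumptions, and the third positional argument mirrors the `_allow_unknown_configs=True` usage shown earlier in this file.

from ray.rllib.algorithms.ppo import PPO, PPOConfig

default_config = PPOConfig().to_dict()              # full legacy AlgorithmConfigDict
overrides = {"lr": 1e-4, "train_batch_size": 2000}  # PartialAlgorithmConfigDict
merged = PPO.merge_algorithm_configs(default_config, overrides, True)
assert merged["lr"] == 1e-4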
@@ -3134,21 +3081,8 @@ def _record_usage(self, config):
alg = "USER_DEFINED"
record_extra_usage_tag(TagKey.RLLIB_ALGORITHM, alg)

@Deprecated(new="Algorithm.compute_single_action()", error=True)
def compute_action(self, *args, **kwargs):
return self.compute_single_action(*args, **kwargs)

@Deprecated(new="construct WorkerSet(...) instance directly", error=True)
def _make_workers(self, *args, **kwargs):
pass

@Deprecated(new="AlgorithmConfig.validate()", error=False)
def validate_config(self, config):
pass

@staticmethod
@Deprecated(new="AlgorithmConfig.validate()", error=True)
def _validate_config(config, trainer_or_none):
def validate_config(self, config):
pass


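Since this file also drops the long-deprecated `clip_actions`/`unsquash_actions` keyword arguments, here is a brief illustrative sketch (not part of the diff) of calling `compute_single_action()` with the surviving singular keyword names; the PPO/CartPole setup is an assumption for demonstration.

import gymnasium as gym

from ray.rllib.algorithms.ppo import PPOConfig

algo = PPOConfig().environment("CartPole-v1").build()
env = gym.make("CartPole-v1")
obs, _ = env.reset()

# Use `unsquash_action` / `clip_action` (singular); the plural deprecated
# kwargs were removed in this PR.
action = algo.compute_single_action(
    observation=obs,
    explore=False,
    unsquash_action=True,
    clip_action=False,
)
obs, reward, terminated, truncated, _ = env.step(action)
algo.stop()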
10 changes: 5 additions & 5 deletions rllib/algorithms/algorithm_config.py
@@ -127,7 +127,7 @@ class AlgorithmConfig(_Config):
... .resources(num_gpus=0)
... .rollouts(num_rollout_workers=4)
... .callbacks(MemoryTrackingCallbacks)
>>> # A config object can be used to construct the respective Trainer.
>>> # A config object can be used to construct the respective Algorithm.
>>> rllib_algo = config.build() # doctest: +SKIP

Example:
Expand All @@ -139,7 +139,7 @@ class AlgorithmConfig(_Config):
>>> # Use `to_dict()` method to get the legacy plain python config dict
>>> # for usage with `tune.Tuner().fit()`.
>>> tune.Tuner( # doctest: +SKIP
... "[registered trainer class]", param_space=config.to_dict()
... "[registered Algorithm class]", param_space=config.to_dict()
... ).fit()
"""

@@ -234,7 +234,7 @@ def overrides(cls, **kwargs):
def __init__(self, algo_class=None):
# Define all settings and their default values.

# Define the default RLlib Trainer class that this AlgorithmConfig will be
# Define the default RLlib Algorithm class that this AlgorithmConfig will be
# applied to.
self.algo_class = algo_class

@@ -1154,7 +1154,7 @@ def resources(
`num_gpus_per_learner_worker` accordingly (e.g. 4 GPUs total, and model
needs 2 GPUs: `num_learner_workers = 2` and
`num_gpus_per_learner_worker = 2`)
num_cpus_per_learner_worker: Number of CPUs allocated per trainer worker.
num_cpus_per_learner_worker: Number of CPUs allocated per Learner worker.
Only necessary for custom processing pipeline inside each Learner
requiring multiple CPU cores. Ignored if `num_learner_workers = 0`.
num_gpus_per_learner_worker: Number of GPUs allocated per worker. If
@@ -3094,7 +3094,7 @@ def get_default_learner_class(self) -> Union[Type["Learner"], str]:

Returns:
The Learner class to use for this algorithm either as a class type or as
a string (e.g. ray.rllib.core.learner.testing.torch.BCTrainer).
a string (e.g. ray.rllib.core.learner.testing.torch.BC).
"""
raise NotImplementedError

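To tie the docstring wording changes together, a short illustrative sketch (not part of the diff) of the two usage patterns the updated docstring describes: building the Algorithm directly from an AlgorithmConfig, or handing its dict form to Tune. The hyperparameter values are placeholders.

from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .training(gamma=0.9, lr=0.01)
    .resources(num_gpus=0)
    .rollouts(num_rollout_workers=2)
)

# 1) A config object can be used to construct the respective Algorithm.
algo = config.build()
print(algo.train()["episode_reward_mean"])
algo.stop()

# 2) Or get the legacy plain python config dict for use with tune.Tuner().fit().
tune.Tuner("PPO", param_space=config.to_dict()).fit()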