[RLlib] Add torch compile capabilities to TorchRLModule (ray-project#34640)

Signed-off-by: Artur Niederfahrenhorst <[email protected]>
ArturNiederfahrenhorst committed May 17, 2023
1 parent 429316c commit f796635
Showing 13 changed files with 436 additions and 56 deletions.
13 changes: 12 additions & 1 deletion rllib/BUILD
@@ -1934,9 +1934,20 @@ py_test(
name = "test_torch_rl_module",
tags = ["team:rllib", "core"],
size = "medium",
srcs = ["core/rl_module/torch/tests/test_torch_rl_module.py"]
srcs = ["core/rl_module/torch/tests/test_torch_rl_module.py"],
args = ["TestRLModule"],
)

# TODO(Artur): Comment this back in as soon as we can test with GPU
# py_test(
# name = "test_torch_rl_module_gpu",
# main = "core/rl_module/torch/tests/test_torch_rl_module.py",
# tags = ["team:rllib", "core", "gpu", "exclusive"],
# size = "medium",
# srcs = ["core/rl_module/torch/tests/test_torch_rl_module.py"],
# args = ["TestRLModuleGPU"],
# )

py_test(
name = "test_tf_rl_module",
tags = ["team:rllib", "core"],
105 changes: 99 additions & 6 deletions rllib/algorithms/algorithm_config.py
@@ -2,6 +2,7 @@
import logging
import math
import os
import sys
from typing import (
Any,
Callable,
@@ -16,22 +17,25 @@
Union,
)

from packaging import version

import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core.learner.learner import LearnerHyperparameters
from ray.rllib.core.learner.learner_group_config import (
LearnerGroupConfig,
ModuleSpec,
)
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.wrappers.atari_wrappers import is_atari
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.utils.torch_utils import TORCH_COMPILE_REQUIRED_VERSION
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
from ray.rllib.evaluation.episode import Episode
from ray.rllib.env.wrappers.atari_wrappers import is_atari
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
@@ -69,16 +73,15 @@
ResultDict,
SampleBatchType,
)
from ray.tune.tune import _Config
from ray.tune.logger import Logger
from ray.tune.registry import get_trainable_cls
from ray.tune.result import TRIAL_INFO
from ray.tune.tune import _Config
from ray.util import log_once

gym, old_gym = try_import_gymnasium_and_gym()
Space = gym.Space


"""TODO(jungong, sven): in "offline_data" we can potentially unify all input types
under input and input_config keys. E.g.
input: sample
@@ -278,6 +281,17 @@ def __init__(self, algo_class=None):
"intra_op_parallelism_threads": 8,
"inter_op_parallelism_threads": 8,
}
# Torch compile settings
self.torch_compile_learner = False
self.torch_compile_learner_dynamo_backend = (
"aot_eager" if sys.platform == "darwin" else "inductor"
)
self.torch_compile_learner_dynamo_mode = "reduce-overhead"
self.torch_compile_worker = False
self.torch_compile_worker_dynamo_backend = (
"aot_eager" if sys.platform == "darwin" else "inductor"
)
self.torch_compile_worker_dynamo_mode = "reduce-overhead"

# `self.environment()`
self.env = None
@@ -773,6 +787,15 @@ def validate(self) -> None:
else:
_torch, _ = try_import_torch()

# Check if torch framework supports torch.compile.
if (
_torch is not None
and self.framework_str == "torch"
and version.parse(_torch.__version__) < TORCH_COMPILE_REQUIRED_VERSION
and (self.torch_compile_learner or self.torch_compile_worker)
):
raise ValueError("torch.compile is only supported from torch 2.0.0")

self._check_if_correct_nn_framework_installed(_tf1, _tf, _torch)
self._resolve_tf_settings(_tf1, _tfv)

@@ -1190,6 +1213,12 @@ def framework(
eager_max_retraces: Optional[int] = NotProvided,
tf_session_args: Optional[Dict[str, Any]] = NotProvided,
local_tf_session_args: Optional[Dict[str, Any]] = NotProvided,
torch_compile_learner: Optional[bool] = NotProvided,
torch_compile_learner_dynamo_mode: Optional[str] = NotProvided,
torch_compile_learner_dynamo_backend: Optional[str] = NotProvided,
torch_compile_worker: Optional[bool] = NotProvided,
torch_compile_worker_dynamo_backend: Optional[str] = NotProvided,
torch_compile_worker_dynamo_mode: Optional[str] = NotProvided,
) -> "AlgorithmConfig":
"""Sets the config's DL framework settings.
@@ -1210,6 +1239,21 @@
tf_session_args: Configures TF for single-process operation by default.
local_tf_session_args: Override the following tf session args on the local
worker
torch_compile_learner: If True, forward_train methods on TorchRLModules
on the learner are compiled. If not specified, the default is to compile
forward_train on the learner.
torch_compile_learner_dynamo_backend: The torch dynamo backend to use on
the learner.
torch_compile_learner_dynamo_mode: The torch dynamo mode to use on the
learner.
torch_compile_worker: If True, forward exploration and inference methods on
TorchRLModules on the workers are compiled. If not specified,
the default is to not compile forward methods on the workers because
retracing can be expensive.
torch_compile_worker_dynamo_backend: The torch dynamo backend to use on
the workers.
torch_compile_worker_dynamo_mode: The torch dynamo mode to use on the
workers.
Returns:
This updated AlgorithmConfig object.
@@ -1231,6 +1275,23 @@
if local_tf_session_args is not NotProvided:
self.local_tf_session_args = local_tf_session_args

if torch_compile_learner is not NotProvided:
self.torch_compile_learner = torch_compile_learner
if torch_compile_learner_dynamo_backend is not NotProvided:
self.torch_compile_learner_dynamo_backend = (
torch_compile_learner_dynamo_backend
)
if torch_compile_learner_dynamo_mode is not NotProvided:
self.torch_compile_learner_dynamo_mode = torch_compile_learner_dynamo_mode
if torch_compile_worker is not NotProvided:
self.torch_compile_worker = torch_compile_worker
if torch_compile_worker_dynamo_backend is not NotProvided:
self.torch_compile_worker_dynamo_backend = (
torch_compile_worker_dynamo_backend
)
if torch_compile_worker_dynamo_mode is not NotProvided:
self.torch_compile_worker_dynamo_mode = torch_compile_worker_dynamo_mode

return self
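
For reference, a minimal usage sketch of the new torch.compile options wired through framework() above. PPOConfig is used purely as an illustrative AlgorithmConfig subclass and is not part of this diff:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .framework(
        "torch",
        # Compile forward_train() on the Learner's TorchRLModule.
        torch_compile_learner=True,
        torch_compile_learner_dynamo_backend="inductor",
        torch_compile_learner_dynamo_mode="reduce-overhead",
        # Keep worker-side forward_exploration/forward_inference uncompiled;
        # frequent retracing on rollout workers can be expensive.
        torch_compile_worker=False,
    )
)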

def environment(
@@ -2424,6 +2485,7 @@ def rl_module(
By default if you call `config.rl_module(...)`, the
RLModule API will NOT be enabled. If you want to enable it, you can call
`config.rl_module(_enable_rl_module_api=True)`.
Returns:
This updated AlgorithmConfig object.
"""
@@ -2923,6 +2985,33 @@ def validate_train_batch_size_vs_rollout_fragment_length(self) -> None:
f"{suggested_rollout_fragment_length}."
)

def get_torch_compile_learner_config(self):
"""Returns the TorchCompileConfig to use on learners."""

from ray.rllib.core.rl_module.torch.torch_compile_config import (
TorchCompileConfig,
)

return TorchCompileConfig(
compile_forward_train=self.torch_compile_learner,
torch_dynamo_backend=self.torch_compile_learner_dynamo_backend,
torch_dynamo_mode=self.torch_compile_learner_dynamo_mode,
)

def get_torch_compile_worker_config(self):
"""Returns the TorchCompileConfig to use on workers."""

from ray.rllib.core.rl_module.torch.torch_compile_config import (
TorchCompileConfig,
)

return TorchCompileConfig(
compile_forward_exploration=self.torch_compile_worker,
compile_forward_inference=self.torch_compile_worker,
torch_dynamo_backend=self.torch_compile_worker_dynamo_backend,
torch_dynamo_mode=self.torch_compile_worker_dynamo_mode,
)
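
The TorchCompileConfig used by the two helpers above lives in rllib/core/rl_module/torch/torch_compile_config.py, one of the 13 changed files whose diff is not rendered here. Judging only from how it is constructed in this file, its shape is roughly the following sketch (field names are taken from the calls above; the class name is suffixed with "Sketch" and the defaults are assumptions, not the actual implementation):

from dataclasses import dataclass


@dataclass
class TorchCompileConfigSketch:
    # Which TorchRLModule forward methods get wrapped in torch.compile().
    compile_forward_train: bool = False
    compile_forward_exploration: bool = False
    compile_forward_inference: bool = False
    # Passed through to torch.compile(backend=..., mode=...).
    torch_dynamo_backend: str = "inductor"
    torch_dynamo_mode: str = "reduce-overhead"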

def get_default_rl_module_spec(self) -> ModuleSpec:
"""Returns the RLModule spec to use for this algorithm.
@@ -3142,9 +3231,13 @@ def get_learner_group_config(self, module_spec: ModuleSpec) -> LearnerGroupConfig
num_gpus_per_learner_worker=self.num_gpus_per_learner_worker,
local_gpu_idx=self.local_gpu_idx,
)
.framework(eager_tracing=self.eager_tracing)
)

if self.framework_str == "torch":
config.framework(torch_compile_cfg=self.get_torch_compile_learner_config())
elif self.framework_str == "tf2":
config.framework(eager_tracing=self.eager_tracing)

return config

def get_learner_hyperparameters(self) -> LearnerHyperparameters:
47 changes: 28 additions & 19 deletions rllib/core/learner/learner.py
@@ -1,10 +1,9 @@
import abc
from collections import defaultdict
from dataclasses import dataclass, field
import json
import logging
import numpy as np
import pathlib
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
Any,
Callable,
@@ -18,36 +17,42 @@
Tuple,
Type,
Union,
TYPE_CHECKING,
)

import numpy as np

import ray
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.core.learner.reduce_result_dict_fn import _reduce_mean_results
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.rl_module.marl_module import (
MultiAgentRLModule,
MultiAgentRLModuleSpec,
)
from ray.rllib.core.rl_module.rl_module import (
RLModule,
ModuleID,
SingleAgentRLModuleSpec,
)
from ray.rllib.core.rl_module.marl_module import (
MultiAgentRLModule,
MultiAgentRLModuleSpec,
)
from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch
from ray.rllib.utils.annotations import (
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.metrics import LEARNER_STATS_KEY, ALL_MODULES
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.typing import TensorType, ResultDict
from ray.rllib.utils.minibatch_utils import (
MiniBatchDummyIterator,
MiniBatchCyclicIterator,
)
from ray.rllib.utils.serialization import serialize_type
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.learner.reduce_result_dict_fn import _reduce_mean_results
from ray.rllib.utils.annotations import (
OverrideToImplementCustomLogic,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.serialization import serialize_type
from ray.rllib.utils.typing import TensorType, ResultDict

if TYPE_CHECKING:
from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig

torch, _ = try_import_torch()
tf1, tf, tfv = try_import_tf()
@@ -81,9 +86,12 @@ class FrameworkHyperparameters:
This is useful for speeding up the training loop. However, it is not
compatible with all tf operations. For example, tf.print is not supported
in tf.function.
torch_compile_cfg: The TorchCompileConfig to use for compiling the RL
Module in Torch.
"""

eager_tracing: bool = False
torch_compile_cfg: Optional["TorchCompileConfig"] = None


@dataclass
@@ -143,7 +151,7 @@ class Learner:
ray.rllib.core.learner.learner.LearnerHyperparameters for more info.
framework_hps: The framework specific hyper-parameters. This will be used to
pass in any framework specific hyper-parameter that will impact the module
creation. For example eager_tracing in TF or compile in Torch.
creation. For example `eager_tracing` in TF or `torch.compile()` in Torch.
Refer to ray.rllib.core.learner.learner.FrameworkHyperparameters for
more info.
@@ -634,6 +642,7 @@ def build(self) -> None:
)

self._module = self._make_module()

for param_seq, optimizer in self.configure_optimizers():
self._optimizer_parameters[optimizer] = []
for param in param_seq:
24 changes: 17 additions & 7 deletions rllib/core/learner/learner_group_config.py
@@ -1,18 +1,18 @@
from typing import Type, Optional, TYPE_CHECKING, Union, Dict

from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.core.learner.learner_group import LearnerGroup
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.learner.learner import (
LearnerSpec,
LearnerHyperparameters,
FrameworkHyperparameters,
)
from ray.rllib.core.learner.learner_group import LearnerGroup
from ray.rllib.core.learner.scaling_config import LearnerGroupScalingConfig
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.utils.from_config import NotProvided


if TYPE_CHECKING:
from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig
from ray.rllib.core.learner import Learner

ModuleSpec = Union[SingleAgentRLModuleSpec, MultiAgentRLModuleSpec]
@@ -58,6 +58,7 @@ def __init__(self, cls: Type[LearnerGroup] = None) -> None:

# `self.framework()`
self.eager_tracing = False
self.torch_compile_cfg = None

def validate(self) -> None:

@@ -83,7 +84,10 @@ def build(self) -> LearnerGroup:
local_gpu_idx=self.local_gpu_idx,
)

framework_hps = FrameworkHyperparameters(eager_tracing=self.eager_tracing)
framework_hps = FrameworkHyperparameters(
eager_tracing=self.eager_tracing,
torch_compile_cfg=self.torch_compile_cfg,
)

learner_spec = LearnerSpec(
learner_class=self.learner_class,
@@ -97,11 +101,17 @@
return self.learner_group_class(learner_spec)

def framework(
self, eager_tracing: Optional[bool] = NotProvided
self,
eager_tracing: Optional[bool] = NotProvided,
torch_compile_cfg: Optional["TorchCompileConfig"] = NotProvided,
) -> "LearnerGroupConfig":

if eager_tracing is not NotProvided:
self.eager_tracing = eager_tracing

if torch_compile_cfg is not NotProvided:
self.torch_compile_cfg = torch_compile_cfg

return self
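
A short end-to-end sketch of how this plumbing is exercised. Standalone construction is shown only for illustration; in practice AlgorithmConfig.get_learner_group_config() performs this call when framework_str == "torch":

from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig
from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig

lg_config = LearnerGroupConfig().framework(
    torch_compile_cfg=TorchCompileConfig(
        compile_forward_train=True,
        torch_dynamo_backend="inductor",
        torch_dynamo_mode="reduce-overhead",
    )
)
# build() then packs this into FrameworkHyperparameters(torch_compile_cfg=...),
# which is handed to the Learner; the torch-specific use of that config lives
# in the TorchRLModule/TorchLearner files not shown in this excerpt.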

def module(