forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_torch_learner_compile.py
131 lines (104 loc) · 4.41 KB
/
test_torch_learner_compile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import itertools
import unittest
import gymnasium as gym
import ray
from ray.rllib.core.learner.learner import (
FrameworkHyperparameters,
TorchCompileWhatToCompile,
)
from ray.rllib.utils.torch_utils import _dynamo_is_available
from ray.rllib.core.rl_module.torch.torch_compile_config import TorchCompileConfig
from ray.rllib.core.testing.utils import get_learner
from ray.rllib.core.testing.utils import get_module_spec
from ray.rllib.policy.sample_batch import MultiAgentBatch
from ray.rllib.utils.test_utils import get_cartpole_dataset_reader
class TestLearner(unittest.TestCase):
@classmethod
def setUp(cls) -> None:
ray.init()
@classmethod
def tearDown(cls) -> None:
ray.shutdown()
@unittest.skipIf(not _dynamo_is_available(), "torch._dynamo not available")
def test_torch_compile(self):
"""Test if torch.compile() can be applied and used on the learner.
Also tests if we can update with the compiled update method without errors.
"""
env = gym.make("CartPole-v1")
is_multi_agents = [False, True]
what_to_compiles = [
TorchCompileWhatToCompile.FORWARD_TRAIN,
TorchCompileWhatToCompile.COMPLETE_UPDATE,
]
for is_multi_agent, what_to_compile in itertools.product(
is_multi_agents, what_to_compiles
):
print(
f"Testing is_multi_agent={is_multi_agent},"
f"what_to_compile={what_to_compile}"
)
framework_hps = FrameworkHyperparameters(
torch_compile=True,
torch_compile_cfg=TorchCompileConfig(),
what_to_compile=what_to_compile,
)
learner = get_learner(
framework="torch", env=env, framework_hps=framework_hps
)
learner.build()
reader = get_cartpole_dataset_reader(batch_size=512)
for iter_i in range(10):
batch = reader.next()
learner.update(batch.as_multi_agent())
spec = get_module_spec(framework="torch", env=env, is_multi_agent=False)
learner.add_module(module_id="another_module", module_spec=spec)
for iter_i in range(10):
batch = MultiAgentBatch(
{"another_module": reader.next(), "default_policy": reader.next()},
0,
)
learner.update(batch)
learner.remove_module(module_id="another_module")
# Todo (rllib-team): Fix for torch 2.0+
@unittest.skip("Failing with torch >= 2.0")
@unittest.skipIf(not _dynamo_is_available(), "torch._dynamo not available")
def test_torch_compile_no_breaks(self):
"""Tests if torch.compile() does encounter too many breaks.
torch.compile() should ideally not encounter any breaks when compiling the
update method of the learner. This method tests if we encounter only a given
number of breaks.
"""
env = gym.make("CartPole-v1")
framework_hps = FrameworkHyperparameters(
torch_compile=True,
torch_compile_cfg=TorchCompileConfig(),
)
learner = get_learner(framework="torch", env=env, framework_hps=framework_hps)
learner.build()
import torch._dynamo as dynamo
reader = get_cartpole_dataset_reader(batch_size=512)
batch = reader.next().as_multi_agent()
batch = learner._convert_batch_type(batch)
# The followingcall to dynamo.explain() breaks depending on the torch version.
# It works for torch==2.0.0.
# TODO(Artur): Fit this to to the correct torch version once it is enabled on
# CI.
# This is a helper method of dynamo to analyze where breaks occur.
(
explanation,
out_guards,
graphs,
ops_per_graph,
break_reasons,
explanation_verbose,
) = dynamo.explain(learner._update, batch)
print(explanation_verbose)
# There should be only one break reason - `return_value` - since inputs and
# outputs are not checked
# TODO(Artur): Attempt bringing breaks down to 1. (This may not be possible)
# Note: This test is skipped on CI if torch dynamo is available.
self.assertEquals(len(break_reasons), 3)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))