[RLlib][RLModules] RNNs and RLModules (ray-project#32723)
Signed-off-by: Artur Niederfahrenhorst <[email protected]>
Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
Co-authored-by: Kourosh Hakhamaneshi <[email protected]>
Co-authored-by: Artur Niederfahrenhorst <[email protected]>
ArturNiederfahrenhorst and kouroshHakha committed Jun 28, 2023
1 parent 8f325dc commit 960032a
Showing 52 changed files with 2,255 additions and 573 deletions.
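
For orientation before the per-file diffs: the core of this change is that RLModules now carry recurrent state through their forward passes via the STATE_IN / STATE_OUT batch keys (plus get_initial_state() for the first step), as exercised in the updated tests below. What follows is a minimal standalone sketch of that convention in plain PyTorch; TinyRecurrentModule and the literal key strings are hypothetical stand-ins, not RLlib classes or constants (the tests import the real STATE_IN / STATE_OUT from ray.rllib.core.models.base).

import torch
import torch.nn as nn

STATE_IN = "state_in"    # placeholder strings; RLlib defines its own constants
STATE_OUT = "state_out"

class TinyRecurrentModule(nn.Module):
    """Illustrative LSTM module following the state-in/state-out pattern."""

    def __init__(self, obs_dim: int = 4, hidden: int = 8, num_actions: int = 2):
        super().__init__()
        self.lstm = nn.LSTM(obs_dim, hidden, batch_first=True)
        self.pi = nn.Linear(hidden, num_actions)

    def get_initial_state(self):
        # Unbatched initial state; callers add the batch dimension themselves,
        # just like the tests below do with tree.map_structure(lambda x: x[None], ...).
        hidden = self.lstm.hidden_size
        return {"h": torch.zeros(hidden), "c": torch.zeros(hidden)}

    def forward(self, batch):
        obs = batch["obs"]                        # [B, T, obs_dim]
        h = batch[STATE_IN]["h"].unsqueeze(0)     # [1, B, hidden]
        c = batch[STATE_IN]["c"].unsqueeze(0)
        out, (h_n, c_n) = self.lstm(obs, (h, c))
        return {
            "action_dist_inputs": self.pi(out),   # [B, T, num_actions]
            STATE_OUT: {"h": h_n.squeeze(0), "c": c_n.squeeze(0)},
        }

module = TinyRecurrentModule()
state = {k: v[None] for k, v in module.get_initial_state().items()}  # add batch dim
out = module({"obs": torch.zeros(1, 1, 4), STATE_IN: state})
state = out[STATE_OUT]  # fed back in on the next call
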
91 changes: 75 additions & 16 deletions rllib/BUILD
@@ -733,6 +733,16 @@ py_test(
args = ["--dir=tuned_examples/ppo"]
)

py_test(
name = "learning_tests_repeat_after_me_ppo_with_rl_module",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_discrete", "torch_only"],
size = "medium",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/ppo/repeatafterme-ppo-lstm-with-rl-module.yaml"],
args = ["--dir=tuned_examples/ppo"]
)

py_test(
name = "learning_tests_cartpole_ppo_fake_gpus",
main = "tests/run_regression_tests.py",
@@ -3087,12 +3097,12 @@ py_test(
)

py_test(
name = "examples/cartpole_lstm_impala_tf",
name = "examples/cartpole_lstm_impala_tf2",
main = "examples/cartpole_lstm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--framework=tf", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
args = ["--run=IMPALA", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"]
)

py_test(
@@ -3101,25 +3111,18 @@ py_test(
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--framework=torch", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
args = ["--run=IMPALA", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"]
)

py_test(
name = "examples/cartpole_lstm_ppo_tf",
main = "examples/cartpole_lstm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--framework=tf", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
)

# TODO (Kourosh): tf2 ~5x slower compared to torch on the new stack
py_test(
name = "examples/cartpole_lstm_ppo_tf2",
main = "examples/cartpole_lstm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
size = "large",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--framework=tf2", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
args = ["--run=PPO", "--as-test", "--framework=tf2", "--stop-reward=28", "--num-cpus=4"]
)

py_test(
@@ -3128,16 +3131,16 @@ py_test(
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--framework=torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4"]
)

py_test(
name = "examples/cartpole_lstm_ppo_tf_with_prev_a_and_r",
name = "examples/cartpole_lstm_ppo_torch_with_prev_a_and_r",
main = "examples/cartpole_lstm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--run=PPO", "--stop-reward=40", "--use-prev-action", "--use-prev-reward", "--num-cpus=4"]
args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4", "--use-prev-action", "--use-prev-reward"]
)

py_test(
@@ -3445,6 +3448,43 @@ py_test(
args = ["--as-test", "--framework=torch", "--run=PPO", "--stop-reward=10", "--stop-timesteps=300000", "--env=RepeatInitialObsEnv", "--num-cpus=4"]
)


py_test(
name = "examples/custom_recurrent_rnn_tokenizer_repeat_after_me_tf2",
main = "examples/custom_recurrent_rnn_tokenizer.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/custom_recurrent_rnn_tokenizer.py"],
args = ["--as-test", "--framework=tf2", "--stop-reward=40", "--env=RepeatAfterMeEnv", "--num-cpus=4"]
)

py_test(
name = "examples/custom_recurrent_rnn_tokenizer_repeat_initial_obs_env_tf2",
main = "examples/custom_recurrent_rnn_tokenizer.py",
tags = ["team:rllib", "examples"],
size = "medium",
srcs = ["examples/custom_recurrent_rnn_tokenizer.py"],
args = ["--as-test", "--framework=tf2", "--stop-reward=10", "--stop-timesteps=300000", "--env=RepeatInitialObsEnv", "--num-cpus=4"]
)

py_test(
name = "examples/custom_recurrent_rnn_tokenizer_repeat_after_me_torch",
main = "examples/custom_recurrent_rnn_tokenizer.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/custom_recurrent_rnn_tokenizer.py"],
args = ["--as-test", "--framework=torch", "--stop-reward=40", "--env=RepeatAfterMeEnv", "--num-cpus=4"]
)

py_test(
name = "examples/custom_recurrent_rnn_tokenizer_repeat_initial_obs_env_torch",
main = "examples/custom_recurrent_rnn_tokenizer.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/custom_recurrent_rnn_tokenizer.py"],
args = ["--as-test", "--framework=torch", "--stop-reward=10", "--stop-timesteps=300000", "--env=RepeatInitialObsEnv", "--num-cpus=4"]
)
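
The four targets above exercise the new custom recurrent-encoder ("tokenizer") example, in which, as the name suggests, each per-timestep observation is first run through a tokenizer network before the recurrent core. A rough standalone sketch of that layering, with hypothetical names rather than the classes the example actually registers:

import torch
import torch.nn as nn

class TokenizerThenLSTM(nn.Module):
    """Hypothetical sketch: tokenize each observation, then run the LSTM over tokens."""

    def __init__(self, obs_dim: int = 4, token_dim: int = 16, hidden: int = 32):
        super().__init__()
        self.tokenizer = nn.Sequential(nn.Linear(obs_dim, token_dim), nn.ReLU())
        self.lstm = nn.LSTM(token_dim, hidden, batch_first=True)

    def forward(self, obs, state=None):
        # obs: [B, T, obs_dim] -> tokens: [B, T, token_dim] -> features: [B, T, hidden]
        tokens = self.tokenizer(obs)
        features, state = self.lstm(tokens, state)
        return features, state

enc = TokenizerThenLSTM()
features, state = enc(torch.zeros(2, 5, 4))  # batch of 2, 5 timesteps each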

py_test(
name = "examples/custom_tf_policy",
tags = ["team:rllib", "exclusive", "examples"],
@@ -4093,6 +4133,25 @@ py_test(
args = ["--as-test", "--framework=torch", "--stop-reward=100.0"]
)


py_test(
name = "examples/trajectory_view_api_rlm_tf2",
main = "examples/trajectory_view_api_rlm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/trajectory_view_api_rlm.py"],
args = ["--as-test", "--framework=tf2", "--stop-reward=100.0"]
)

py_test(
name = "examples/trajectory_view_api_rlm_torch",
main = "examples/trajectory_view_api_rlm.py",
tags = ["team:rllib", "exclusive", "examples"],
size = "medium",
srcs = ["examples/trajectory_view_api_rlm.py"],
args = ["--as-test", "--framework=torch", "--stop-reward=100.0"]
)
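
The two trajectory-view targets above run the RLModule variant of the trajectory view API example, whose model consumes a short window of recent observations rather than only the current one. A generic sketch of that idea; the class name and window size here are made up, not the example's actual settings:

import torch
import torch.nn as nn

class FrameStackPolicy(nn.Module):
    """Hypothetical sketch: an MLP over a stack of the last `num_frames` observations."""

    def __init__(self, obs_dim: int = 4, num_frames: int = 16, num_actions: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim * num_frames, 64),
            nn.ReLU(),
            nn.Linear(64, num_actions),
        )

    def forward(self, stacked_obs):
        # stacked_obs: [B, num_frames, obs_dim], flattened into one feature vector.
        return self.net(stacked_obs.flatten(start_dim=1))

logits = FrameStackPolicy()(torch.zeros(1, 16, 4))  # -> [1, num_actions]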

py_test(
name = "examples/tune/framework",
main = "examples/tune/framework.py",
13 changes: 1 addition & 12 deletions rllib/algorithms/appo/tf/appo_tf_rl_module.py
@@ -4,7 +4,7 @@
OLD_ACTION_DIST_LOGITS_KEY,
)
from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule
from ray.rllib.core.models.base import ACTOR, CRITIC, STATE_IN
from ray.rllib.core.models.base import ACTOR
from ray.rllib.core.models.tf.encoder import ENCODER_OUT
from ray.rllib.core.rl_module.rl_module_with_target_networks_interface import (
RLModuleWithTargetNetworksInterface,
@@ -45,19 +45,8 @@ def output_specs_train(self) -> List[str]:
@override(PPOTfRLModule)
def _forward_train(self, batch: NestedDict):
outs = super()._forward_train(batch)

# TODO (Artur): Remove this once Policy supports RNN
batch = batch.copy()
if self.encoder.config.shared:
batch[STATE_IN] = None
else:
batch[STATE_IN] = {
ACTOR: None,
CRITIC: None,
}
batch[SampleBatch.SEQ_LENS] = None
old_pi_inputs_encoded = self.old_encoder(batch)[ENCODER_OUT][ACTOR]

old_action_dist_logits = tf.stop_gradient(self.old_pi(old_pi_inputs_encoded))
outs[OLD_ACTION_DIST_LOGITS_KEY] = old_action_dist_logits
return outs
3 changes: 0 additions & 3 deletions rllib/algorithms/ppo/ppo_rl_module.py
@@ -5,7 +5,6 @@
import abc
from typing import Type

from ray.rllib.core.models.base import ActorCriticEncoder
from ray.rllib.core.models.specs.specs_dict import SpecDict
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.models.distributions import Distribution
@@ -28,8 +27,6 @@ def setup(self):
self.action_dist_cls = catalog.get_action_dist_cls(framework=self.framework)
# __sphinx_doc_end__

assert isinstance(self.encoder, ActorCriticEncoder)

def get_train_action_dist_cls(self) -> Type[Distribution]:
return self.action_dist_cls

99 changes: 58 additions & 41 deletions rllib/algorithms/ppo/tests/test_ppo_rl_module.py
@@ -3,7 +3,6 @@

import gymnasium as gym
import numpy as np
import tensorflow as tf
import tree

import ray
@@ -15,13 +14,19 @@
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import (
PPOTorchRLModule,
)
from ray.rllib.core.models.base import STATE_IN
from ray.rllib.core.models.base import STATE_IN, STATE_OUT
from ray.rllib.core.rl_module.rl_module import RLModuleConfig
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.torch_utils import convert_to_torch_tensor


tf1, tf, _ = try_import_tf()
tf1.enable_eager_execution()
torch, nn = try_import_torch()


def get_expected_module_config(
env: gym.Env,
model_config_dict: dict,
@@ -110,17 +115,17 @@ def _get_ppo_module(framework, env, lstm, observation_space):
return module


def _get_input_batch_from_obs(framework, obs):
def _get_input_batch_from_obs(framework, obs, lstm):
if framework == "torch":
batch = {
SampleBatch.OBS: convert_to_torch_tensor(obs)[None],
STATE_IN: None,
}
else:
batch = {
SampleBatch.OBS: tf.convert_to_tensor([obs]),
STATE_IN: None,
SampleBatch.OBS: tf.convert_to_tensor(obs)[None],
}
if lstm:
batch[SampleBatch.OBS] = batch[SampleBatch.OBS][None]
return batch
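
As a quick sanity check of what this helper now produces, an illustration with a 4-dimensional CartPole-style observation (treating the extra LSTM dimension as a time axis in the usual [batch, time, ...] layout is an assumption on my part):

import numpy as np
import torch

obs = np.zeros(4, dtype=np.float32)            # raw observation, shape (4,)
no_lstm = torch.as_tensor(obs)[None]           # (1, 4): [batch, obs_dim]
with_lstm = torch.as_tensor(obs)[None][None]   # (1, 1, 4): adds the time axis for the RNN
assert no_lstm.shape == (1, 4) and with_lstm.shape == (1, 1, 4)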


@@ -138,8 +143,7 @@ def test_rollouts(self):
frameworks = ["torch", "tf2"]
env_names = ["CartPole-v1", "Pendulum-v1", "ALE/Breakout-v5"]
fwd_fns = ["forward_exploration", "forward_inference"]
# TODO(Artur): Re-enable LSTM
lstm = [False]
lstm = [True, False]
config_combinations = [frameworks, env_names, fwd_fns, lstm]
for config in itertools.product(*config_combinations):
fw, env_name, fwd_fn, lstm = config
@@ -164,15 +168,14 @@ def test_rollouts(self):
obs, _ = env.reset()
obs = preprocessor.transform(obs)

batch = _get_input_batch_from_obs(fw, obs)
batch = _get_input_batch_from_obs(fw, obs, lstm)

# TODO (Artur): Un-uncomment once Policy supports RNN
# state_in = module.get_initial_state()
# state_in = tree.map_structure(
# lambda x: x[None], convert_to_torch_tensor(state_in)
# )
# batch[STATE_IN] = state_in
# batch[SampleBatch.SEQ_LENS] = torch.Tensor([1])
if lstm:
state_in = module.get_initial_state()
if fw == "torch":
state_in = convert_to_torch_tensor(state_in)
state_in = tree.map_structure(lambda x: x[None], state_in)
batch[STATE_IN] = state_in

if fwd_fn == "forward_exploration":
module.forward_exploration(batch)
@@ -181,16 +184,12 @@ def test_forward_train(self):

def test_forward_train(self):
# TODO: Add FrozenLake-v1 to cover LSTM case.
frameworks = ["torch", "tf2"]
frameworks = ["tf2", "torch"]
env_names = ["CartPole-v1", "Pendulum-v1", "ALE/Breakout-v5"]
# TODO(Artur): Re-enable LSTM
lstm = [False]
lstm = [False, True]
config_combinations = [frameworks, env_names, lstm]
for config in itertools.product(*config_combinations):
fw, env_name, lstm = config
if lstm and fw == "tf2":
# LSTM not implemented in TF2 yet
continue
print(f"[FW={fw} | [ENV={env_name}] | LSTM={lstm}")
# TODO(Artur): Figure out why this is needed and fix it.
if env_name.startswith("ALE/"):
@@ -213,18 +212,23 @@ def test_forward_train(self):
obs, _ = env.reset()
obs = preprocessor.transform(obs)
tstep = 0
# TODO (Artur): Un-uncomment once Policy supports RNN
# state_in = module.get_initial_state()
# state_in = tree.map_structure(
# lambda x: x[None], convert_to_torch_tensor(state_in)
# )
# initial_state = state_in

if lstm:
state_in = module.get_initial_state()
if fw == "torch":
state_in = tree.map_structure(
lambda x: x[None], convert_to_torch_tensor(state_in)
)
else:
state_in = tree.map_structure(
lambda x: tf.convert_to_tensor(x)[None], state_in
)
initial_state = state_in

while tstep < 10:
input_batch = _get_input_batch_from_obs(fw, obs)
# TODO (Artur): Un-uncomment once Policy supports RNN
# input_batch[STATE_IN] = state_in
# input_batch[SampleBatch.SEQ_LENS] = np.array([1])
input_batch = _get_input_batch_from_obs(fw, obs, lstm=lstm)
if lstm:
input_batch[STATE_IN] = state_in

fwd_out = module.forward_exploration(input_batch)
action_dist_cls = module.get_exploration_action_dist_cls()
@@ -234,6 +238,10 @@ def test_forward_train(self):
_action = action_dist.sample()
action = convert_to_numpy(_action[0])
action_logp = convert_to_numpy(action_dist.logp(_action)[0])
if lstm:
# Since this is inference, fwd out should only contain one action
assert len(action) == 1
action = action[0]
new_obs, reward, terminated, truncated, _ = env.step(action)
new_obs = preprocessor.transform(new_obs)
output_batch = {
@@ -247,9 +255,9 @@ def test_forward_train(self):
STATE_IN: None,
}

# TODO (Artur): Un-uncomment once Policy supports RNN
# assert STATE_OUT in fwd_out
# state_in = fwd_out[STATE_OUT]
if lstm:
assert STATE_OUT in fwd_out
state_in = fwd_out[STATE_OUT]
batches.append(output_batch)
obs = new_obs
tstep += 1
@@ -261,9 +269,13 @@ def test_forward_train(self):
fwd_in = {
k: convert_to_torch_tensor(np.array(v)) for k, v in batch.items()
}
# TODO (Artur): Un-uncomment once Policy supports RNN
# fwd_in[STATE_IN] = initial_state
# fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10])
if lstm:
fwd_in[STATE_IN] = initial_state
# If we test lstm, the collected timesteps make up only one batch
fwd_in = {
k: torch.unsqueeze(v, 0) if k != STATE_IN else v
for k, v in fwd_in.items()
}

# forward train
# before training make sure module is on the right device
@@ -281,9 +293,14 @@ def test_forward_train(self):
fwd_in = tree.map_structure(
lambda x: tf.convert_to_tensor(x, dtype=tf.float32), batch
)
# TODO (Artur): Un-uncomment once Policy supports RNN
# fwd_in[STATE_IN] = initial_state
# fwd_in[SampleBatch.SEQ_LENS] = torch.Tensor([10])
if lstm:
fwd_in[STATE_IN] = initial_state
# If we test lstm, the collected timesteps make up only one batch
fwd_in = {
k: tf.expand_dims(v, 0) if k != STATE_IN else v
for k, v in fwd_in.items()
}

with tf.GradientTape() as tape:
fwd_out = module.forward_train(fwd_in)
loss = dummy_tf_ppo_loss(module, fwd_in, fwd_out)
