[BUG] In-place operations in Python-based RNNs cause RuntimeError #1742

Closed
albertbou92 opened this issue Dec 11, 2023 · 5 comments
Labels: bug (Something isn't working)

albertbou92 (Contributor) commented Dec 11, 2023

Describe the bug

Training with the Python-based GRU raises the following error, which indicates the current implementation has some in-place operations that prevent correct backward computation:

File "/home/abou/test_bug.py", line 157, in
main()
File "/home/abou/test_bug.py", line 148, in main
loss = loss_module(batch.cuda())
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
File "/home/abou/tensordict/tensordict/_contextlib.py", line 126, in decorate_context
return func(*args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/common.py", line 281, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/abou/rl/torchrl/objectives/sac.py", line 1096, in forward
loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
File "/home/abou/rl/torchrl/objectives/sac.py", line 1203, in _actor_loss
dist = self.actor_network.get_dist(tensordict)
File "/home/abou/tensordict/tensordict/nn/probabilistic.py", line 524, in get_dist
tensordict_out = self.get_dist_params(tensordict, tensordict_out, **kwargs)
File "/home/abou/tensordict/tensordict/nn/probabilistic.py", line 515, in get_dist_params
return tds(tensordict, tensordict_out, **kwargs)
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/functional_modules.py", line 589, in new_fun
return getattr(type(self), fun_name)(self, *args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/common.py", line 281, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/abou/tensordict/tensordict/_contextlib.py", line 126, in decorate_context
return func(*args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/utils.py", line 253, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/sequence.py", line 426, in forward
tensordict = self._run_module(module, tensordict, **kwargs)
File "/home/abou/tensordict/tensordict/nn/sequence.py", line 407, in _run_module
tensordict = module(tensordict, **kwargs)
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/functional_modules.py", line 589, in new_fun
return getattr(type(self), fun_name)(self, *args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/common.py", line 281, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/abou/tensordict/tensordict/_contextlib.py", line 126, in decorate_context
return func(*args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/utils.py", line 253, in wrapper
return func(_self, tensordict, *args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/sequence.py", line 426, in forward
tensordict = self._run_module(module, tensordict, **kwargs)
File "/home/abou/tensordict/tensordict/nn/sequence.py", line 407, in _run_module
tensordict = module(tensordict, **kwargs)
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/functional_modules.py", line 589, in new_fun
return getattr(type(self), fun_name)(self, *args, **kwargs)
File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 1346, in forward
val, hidden = self._gru(value, batch, steps, device, dtype, hidden)
File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 1386, in _gru
y, hidden = self.gru(input, hidden)
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/abou/tensordict/tensordict/nn/functional_modules.py", line 589, in new_fun
return getattr(type(self), fun_name)(self, *args, **kwargs)
File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 990, in forward
result = self._gru(input, hx)
File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 945, in _gru
h_t[layer] = self._gru_cell(
File "/home/abou/rl/torchrl/modules/tensordict_module/rnn.py", line 904, in _gru_cell
gate_h = F.linear(hx, weight_hh, bias_hh)
(Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
28it [00:13, 2.10it/s]
Traceback (most recent call last):
File "/home/abou/test_bug.py", line 157, in
main()
File "/home/abou/test_bug.py", line 150, in main
loss_sum.backward()
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
torch.autograd.backward(
File "/shared/albert/miniconda3/envs/torch_rl2/lib/python3.9/site-packages/torch/autograd/init.py", line 251, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 256]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

@vmoens you were right, in-place operations cause problems here. I believe the issue is being addressed in #1732.
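
For reference, the failure mode is easy to reproduce outside torchrl: autograd tracks a version counter on every tensor it saves for the backward pass, and any in-place update between forward and backward invalidates it. A minimal sketch (illustrative code, not the GRU implementation itself), using the same [1, 256] hidden-state shape as the error above:

import torch
import torch.nn.functional as F

weight = torch.randn(256, 256, requires_grad=True)
hx = torch.randn(1, 256, requires_grad=True)

h = hx * 1.0                 # non-leaf copy of the hidden state (version 0)
out = F.linear(h, weight)    # autograd saves h to compute weight's gradient
h += 1.0                     # in-place update bumps h's version counter to 1
out.sum().backward()         # RuntimeError: ... is at version 1; expected version 0

Rewriting the update out of place (h = h + 1.0), or cloning the tensor before mutating it, avoids the error.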

To Reproduce

import tqdm
import torch
import random
import numpy as np
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.envs.libs.gym import GymEnv
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.modules.distributions import OneHotCategorical
from torchrl.modules import ProbabilisticActor, GRUModule, MLP
from torchrl.collectors import SyncDataCollector
from torchrl.objectives import DiscreteSACLoss
from torchrl.envs import (
    ParallelEnv,
    TransformedEnv,
    InitTracker,
    StepCounter,
    RewardSum,
)


def create_model(input_size, output_size, hidden_size=256, num_layers=3, out_key="logits"):

    embedding_module = TensorDictModule(
        in_keys=["observation"],
        out_keys=["embed"],
        module=torch.nn.Linear(input_size, input_size),  # including this module triggers the RuntimeError
    )
    gru_module = GRUModule(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        in_key="embed",
        out_key="features",
        python_based=True,
    )
    mlp = TensorDictModule(
        MLP(
            in_features=hidden_size,
            out_features=output_size,
            num_cells=[],
        ),
        in_keys=["features"],
        out_keys=[out_key],
    )

    inference_model = TensorDictSequential(embedding_module, gru_module, mlp)
    # set_recurrent_mode() returns a copy that unrolls the GRU over whole trajectories for training
    training_model = TensorDictSequential(embedding_module, gru_module.set_recurrent_mode(), mlp)

    return inference_model, training_model


def create_rhs_transform(input_size, hidden_size=256, num_layers=3):
    gru_module = GRUModule(
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        in_key="observation",
        out_key="features",
    )
    return gru_module.make_tensordict_primer()


def main():

    # Set seeds
    seed = 2024
    random.seed(int(seed))
    np.random.seed(int(seed))
    torch.manual_seed(int(seed))

    torch.autograd.set_detect_anomaly(True)

    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    test_env = GymEnv("CartPole-v1", device=device, categorical_action_encoding=True)
    action_spec = test_env.action_spec.space
    observation_spec = test_env.observation_spec["observation"]

    def create_env_fn():
        env = GymEnv("CartPole-v1", device=device)
        env = TransformedEnv(env)
        env.append_transform(create_rhs_transform(input_size=observation_spec.shape[-1]))
        env.append_transform(InitTracker())
        return env

    # Models
    ##################

    inference_actor, training_actor = create_model(input_size=observation_spec.shape[-1], output_size=action_spec.n)
    inference_actor = ProbabilisticActor(
        module=inference_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    training_actor = ProbabilisticActor(
        module=training_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    inference_actor = inference_actor.to(device)
    training_actor = training_actor.to(device)
    _, training_critic = create_model(input_size=observation_spec.shape[-1], output_size=action_spec.n, out_key="action_value")
    training_critic = training_critic.to(device)

    # Collector
    ##################

    collector = SyncDataCollector(
        create_env_fn=create_env_fn,
        policy=inference_actor,
        frames_per_batch=100,
        total_frames=5000,
        device=device,
        storing_device=device,
        max_frames_per_traj=-1,
        split_trajs=False,
    )

    # Buffer
    ##################

    buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(100),
        batch_size=1,
    )

    # Loss
    ##################

    loss_module = DiscreteSACLoss(
        actor_network=training_actor,
        qvalue_network=training_critic,
        num_actions=action_spec.n,
        num_qvalue_nets=2,
        loss_function="smooth_l1",
    )
    loss_module.make_value_estimator(gamma=0.99)

    # Collection loop
    ##################

    for data in tqdm.tqdm(collector):
        buffer.extend(data.cpu())
        batch = buffer.sample()
        loss = loss_module(batch.to(device))
        loss_sum = loss["loss_actor"] + loss["loss_qvalue"] + loss["loss_alpha"]
        loss_sum.backward()

    collector.shutdown()
    print("Success!")


if __name__ == "__main__":
    main()

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
albertbou92 added the bug label on Dec 11, 2023
vmoens (Contributor) commented Dec 11, 2023

in-place ops are the devil :)
Let's work on the patch in the "faster rnn" PR!

vmoens (Contributor) commented Dec 13, 2023

@albertbou92 can you write a smaller (more standalone and minimal) repro example that I could put in the tests?

albertbou92 (Contributor, Author) commented Dec 14, 2023

What about something like this?

import torch
from torchrl.collectors import SyncDataCollector
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import GRUModule, ProbabilisticActor
from torchrl.modules.distributions import OneHotCategorical
from torchrl.objectives import DiscreteSACLoss


def create_model(input_size, output_size, num_layers=3, out_key="logits"):
    gru_module = GRUModule(
        input_size=input_size,
        hidden_size=output_size,
        num_layers=num_layers,
        in_key="observation",
        out_key=out_key,
        python_based=True,
    )
    return (
        gru_module,
        gru_module.set_recurrent_mode(True),
        gru_module.make_tensordict_primer(),
    )


def test_python_gru(device):

    env_name = "CartPole-v1"
    test_env = GymEnv(env_name)
    observation_size = test_env.observation_spec["observation"].shape[-1]
    num_actions = int(test_env.action_spec.space.n)

    inference_actor, training_actor, rhs_transform = create_model(
        input_size=observation_size, output_size=num_actions
    )
    inference_actor = ProbabilisticActor(
        module=inference_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    training_actor = ProbabilisticActor(
        module=training_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    inference_actor = inference_actor.to(device)
    training_actor = training_actor.to(device)
    _, training_critic, _ = create_model(
        input_size=observation_size, output_size=num_actions, out_key="action_value"
    )
    training_critic = training_critic.to(device)

    def create_env_fn():
        env = GymEnv(env_name, device=device)
        env = TransformedEnv(env)
        env.append_transform(rhs_transform)
        env.append_transform(InitTracker())
        return env

    collector = SyncDataCollector(
        create_env_fn=create_env_fn,
        policy=inference_actor,
        frames_per_batch=10,
        total_frames=100,
    )

    loss_module = DiscreteSACLoss(
        actor_network=training_actor,
        qvalue_network=training_critic,
        num_actions=num_actions,
        num_qvalue_nets=2,
        loss_function="smooth_l1",
    )

    for data in collector:
        loss = loss_module(data.to(device))
        loss_sum = loss["loss_actor"] + loss["loss_qvalue"] + loss["loss_alpha"]
        loss_sum.backward()

    collector.shutdown()
    print("Success!")

vmoens (Contributor) commented Dec 14, 2023

I don't think we need to create any env, actor, distribution, or even spec to test that bug. Unfortunately this isn't fit to be a unit test. But no worries, I will try to find a minimal example on my own :)

vmoens (Contributor) commented Dec 14, 2023

The issue isn't RNN-related. It comes from the fact that the actor and value networks share parameters and we don't clone them when we pass the value params, so when we optimize the params the graph breaks.
I even wrote a comment about this:

rl/torchrl/objectives/sac.py, lines 1208 to 1210 at commit 0906206:

td_q = self._vmap_qnetworkN0(
    td_q, self._cached_detached_qvalue_params  # should we clone?
)

We should call clone whenever the qvalue and actor nets share params. I will write a fix.
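
A minimal sketch of the aliasing problem being described (the names are illustrative, not torchrl's internals): detach() returns a view that shares storage and the version counter with the original parameter, so an in-place write through the detached q-value params invalidates the very tensors the actor's backward still needs, whereas detach().clone() decouples the two.

import torch
import torch.nn.functional as F

shared_w = torch.randn(256, 256, requires_grad=True)  # param shared by actor and q-net
obs = torch.randn(1, 256, requires_grad=True)

actor_out = F.linear(obs, shared_w)   # actor's backward needs shared_w at version 0
qvalue_w = shared_w.detach()          # aliased view of the shared parameter
qvalue_w.mul_(0.5)                    # in-place write bumps the shared version counter
# actor_out.sum().backward()          # would raise the same RuntimeError as above

qvalue_w = shared_w.detach().clone()  # independent storage: the actor graph stays intact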

vmoens closed this as completed on Dec 19, 2023