[BUG] In-place operations in Python-based RNNs cause RuntimeError #1742
In-place ops are the devil :)
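A minimal illustration in plain PyTorch (not TorchRL code) of the kind of failure an in-place op causes during backward, for anyone unfamiliar with the error:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()  # autograd saves y for backward, since d/dx exp(x) = exp(x) = y
y.add_(1)    # in-place op bumps y's version counter
y.sum().backward()
# RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation
```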
@albertbou92 can you write a smaller (more standalone and minimal) repro example that I could put in the tests?
What about like this?

```python
import torch

from torchrl.collectors import SyncDataCollector
from torchrl.envs import InitTracker, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import GRUModule, ProbabilisticActor
from torchrl.modules.distributions import OneHotCategorical
from torchrl.objectives import DiscreteSACLoss


def create_model(input_size, output_size, num_layers=3, out_key="logits"):
    gru_module = GRUModule(
        input_size=input_size,
        hidden_size=output_size,
        num_layers=num_layers,
        in_key="observation",
        out_key=out_key,
        python_based=True,
    )
    return (
        gru_module,
        gru_module.set_recurrent_mode(True),
        gru_module.make_tensordict_primer(),
    )


def test_python_gru(device):
    env_name = "CartPole-v1"
    test_env = GymEnv(env_name)
    observation_size = test_env.observation_spec["observation"].shape[-1]
    num_actions = int(test_env.action_spec.space.n)

    inference_actor, training_actor, rhs_transform = create_model(
        input_size=observation_size, output_size=num_actions
    )
    inference_actor = ProbabilisticActor(
        module=inference_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    training_actor = ProbabilisticActor(
        module=training_actor,
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        return_log_prob=True,
    )
    inference_actor = inference_actor.to(device)
    training_actor = training_actor.to(device)

    _, training_critic, _ = create_model(
        input_size=observation_size, output_size=num_actions, out_key="action_value"
    )
    training_critic = training_critic.to(device)

    def create_env_fn():
        env = GymEnv(env_name, device=device)
        env = TransformedEnv(env)
        env.append_transform(rhs_transform)
        env.append_transform(InitTracker())
        return env

    collector = SyncDataCollector(
        create_env_fn=create_env_fn,
        policy=inference_actor,
        frames_per_batch=10,
        total_frames=100,
    )
    loss_module = DiscreteSACLoss(
        actor_network=training_actor,
        qvalue_network=training_critic,
        num_actions=num_actions,
        num_qvalue_nets=2,
        loss_function="smooth_l1",
    )

    for data in collector:
        loss = loss_module(data.to(device))  # was data.cuda(); use the test's device
        loss_sum = loss["loss_actor"] + loss["loss_qvalue"] + loss["loss_alpha"]
        loss_sum.backward()

    collector.shutdown()
    print("Success!")
```
I don't think we need to create any env, actor, distribution, or even spec to test that bug. This isn't fit to be a unit test, unfortunately. But no worries, I will try to find a minimal example on my own :)
The issue isn't RNN-related; it comes from the fact that the actor and value networks share params and we don't clone the value params when we pass them. Hence, when we optimize the params, the graph breaks. (See lines 1208 to 1210 at commit 0906206.)
We should call `clone()` on the value params there.
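For context, a toy sketch of this failure mode in plain PyTorch (not TorchRL's actual loss code), and why cloning the params for the second pass avoids it:

```python
import torch
from torch import nn
from torch.func import functional_call

# "Actor" and "value" share the same parameters; the value pass saves
# the very tensors the optimizer later updates in place.
shared = nn.Linear(4, 1)
opt = torch.optim.SGD(shared.parameters(), lr=0.1)
x = torch.randn(2, 4)

actor_out = shared(x)  # graph 1 saves shared.weight for backward
value_out = shared(x)  # graph 2 saves the same tensor
actor_out.sum().backward()
opt.step()             # in-place update of shared.weight
# value_out.sum().backward()  # -> RuntimeError: ... modified by an inplace operation

# Cloning the params for the value pass breaks the aliasing while keeping
# gradient flow (clone's backward is the identity), so this version works:
opt.zero_grad()
actor_out = shared(x)
cloned = {name: p.clone() for name, p in shared.named_parameters()}
value_out = functional_call(shared, cloned, (x,))
actor_out.sum().backward()
opt.step()
value_out.sum().backward()  # OK: the saved tensors are the clones
```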
Describe the bug
Training with the Python-based GRU raises the following error, which indicates that the current implementation contains in-place operations that prevent correct backward computation:

To Reproduce

Checklist

@vmoens you were right, in-place operations cause problems here. I believe the issue is being addressed in #1732.