
[QUESTION] How to reset only certain nested parts of a key with TensorDictPrimer? #2053

Closed
kfu02 opened this issue Apr 2, 2024 · 3 comments · Fixed by #2071

kfu02 commented Apr 2, 2024

Hi, I have an observation spec for a multi-agent environment which looks like this:

CompositeSpec(
    agents: CompositeSpec(
        observation: UnboundedContinuousTensorSpec(
            shape=torch.Size([100, 2, 14]),
            space=None,
            device=cuda:0,
            dtype=torch.float32,
            domain=continuous),
        episode_reward: UnboundedContinuousTensorSpec(
            shape=torch.Size([100, 2, 1]),
            space=None,
            device=cuda:0,
            dtype=torch.float32,
            domain=continuous),
        edge_index: UnboundedContinuousTensorSpec(
            shape=torch.Size([100, 2, 2, 2]),
            space=None,
            device=cuda:0,
            dtype=torch.float32,
            domain=continuous), device=cuda:0, shape=torch.Size([100, 2])),
...

Here, the key ("agents", "edge_index") is a special field that I populate once upon creating the env and never want to change.

My problem is that I would like to add a recurrent policy, which requires tracking the hidden state for each agent. I read the Recurrent DQN tutorial, but the LSTMModule's make_tensordict_primer() does not quite work for me as it is designed for the single-agent case.
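
For reference, the single-agent pattern from that tutorial looks roughly like this (the sizes and keys here are placeholders):

from torchrl.modules import LSTMModule

# The module builds a primer for its own recurrent-state keys; this is the
# single-agent usage from the tutorial, with placeholder sizes.
lstm = LSTMModule(
    input_size=14,
    hidden_size=32,
    device="cpu",
    in_key="observation",
    out_key="embed",
)
env.append_transform(lstm.make_tensordict_primer())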

Thus I have tried to write a custom TensorDictPrimer transform, like so:

existing_obs_spec = env.observation_spec
hidden_state_spec = UnboundedContinuousTensorSpec(
    shape=(*env.observation_spec["agents"].shape[:2],
           cfg.actor.gru.num_layers, cfg.actor.gru.hidden_size),
    device=cfg.env.device,
)
existing_obs_spec[("agents", "hidden_state")] = hidden_state_spec
env.append_transform(TensorDictPrimer(existing_obs_spec))

However, I notice that on environment resets this TensorDictPrimer overwrites all the fields in the spec with 0s. I have tried restricting the TensorDictPrimer's spec to only the ("agents", "hidden_state") key I want zeroed out, but when I do so, I end up losing the other nested keys under "agents" on reset.

Am I misunderstanding the usage of TensorDictPrimer? Any help would be appreciated.


kfu02 commented Apr 2, 2024

For further clarity on the "end up losing the other nested keys" part: here I specify only the key I want, and print the observation spec before and after adding my TensorDictPrimer transform:

    env = TransformedEnv(env, Compose(InitTracker()))
    print(env.observation_spec)
    td = env.reset()

    hidden_state_spec = UnboundedContinuousTensorSpec(
        shape=(*env.observation_spec["agents"].shape[:2],
               cfg.actor.gru.num_layers, cfg.actor.gru.hidden_size),
        device=cfg.env.device,
    )
    new_hidden_spec = CompositeSpec(
            agents=CompositeSpec(
                hidden_state=hidden_state_spec,
                shape=(hidden_state_spec.shape[0], hidden_state_spec.shape[1])
            ),
            shape=[hidden_state_spec.shape[0]],
    )

    print("new_hidden_spec", new_hidden_spec)
    env.append_transform(TensorDictPrimer(new_hidden_spec))
    print(env.observation_spec)

Observation spec before:

CompositeSpec(
    agents: CompositeSpec(
        observation: UnboundedContinuousTensorSpec(
            shape=torch.Size([100, 2, 11]),
            space=None,
            device=cuda:0,
            dtype=torch.float32,
            domain=continuous), device=cuda:0, shape=torch.Size([100, 2])),
    is_init: DiscreteTensorSpec(
        shape=torch.Size([100, 1]),
        space=DiscreteBox(n=2),
        device=cuda:0,
        dtype=torch.bool,
        domain=discrete), device=cuda:0, shape=torch.Size([100]))

Observation spec after:

CompositeSpec(
    agents: CompositeSpec(
        hidden_state: UnboundedContinuousTensorSpec(
            shape=torch.Size([100, 2, 2, 32]),
            space=None,
            device=cuda:0,
            dtype=torch.float32,
            domain=continuous), device=cuda:0, shape=torch.Size([100, 2])),
    is_init: DiscreteTensorSpec(
        shape=torch.Size([100, 1]),
        space=DiscreteBox(n=2),
        device=cuda:0,
        dtype=torch.bool,
        domain=discrete), device=cuda:0, shape=torch.Size([100]))

The new_hidden_spec I am passing to TensorDictPrimer:

new_hidden_spec CompositeSpec(
    agents: CompositeSpec(
        hidden_state: UnboundedContinuousTensorSpec(
            shape=torch.Size([100, 2, 2, 32]),
            space=None,
            device=cuda:0,
            dtype=torch.float32,
            domain=continuous), device=None, shape=torch.Size([100, 2])), device=None, shape=torch.Size([100]))
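
For what it's worth, the output above is consistent with the primer's spec replacing the nested "agents" entry wholesale rather than merging into it. With plain specs, the difference looks like this (a sketch of the two behaviors, not the transform's actual internals):

    import torch
    from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

    obs = UnboundedContinuousTensorSpec(shape=torch.Size([100, 2, 11]))
    hidden = UnboundedContinuousTensorSpec(shape=torch.Size([100, 2, 2, 32]))

    spec = CompositeSpec(
        agents=CompositeSpec(observation=obs, shape=(100, 2)), shape=(100,)
    )
    primer_spec = CompositeSpec(
        agents=CompositeSpec(hidden_state=hidden, shape=(100, 2)), shape=(100,)
    )

    # Wholesale replacement drops the sibling entry, matching the output above:
    spec["agents"] = primer_spec["agents"]
    assert ("agents", "observation") not in spec.keys(True)

    # A key-wise merge keeps both entries:
    spec2 = CompositeSpec(
        agents=CompositeSpec(observation=obs, shape=(100, 2)), shape=(100,)
    )
    spec2.update(primer_spec)
    assert ("agents", "observation") in spec2.keys(True)
    assert ("agents", "hidden_state") in spec2.keys(True)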


kfu02 commented Apr 2, 2024

For the record, I am able to work around this issue by simply specifying a different top-level key for the hidden states, e.g. ("agents_hs", "hidden_state"), which avoids overwriting the original obs_spec or zeroing out other fields at the same nesting level. I would just like to know whether this dilemma is avoidable.
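
Concretely, the workaround looks like this (a minimal sketch reusing the env/cfg names from my snippets above):

from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import TensorDictPrimer

# Same hidden-state spec as before, but housed under a fresh top-level key
# ("agents_hs"), so the primer never touches the existing "agents" subtree.
hidden_state_spec = UnboundedContinuousTensorSpec(
    shape=(*env.observation_spec["agents"].shape[:2],
           cfg.actor.gru.num_layers, cfg.actor.gru.hidden_size),
    device=cfg.env.device,
)
workaround_spec = CompositeSpec(
    agents_hs=CompositeSpec(
        hidden_state=hidden_state_spec,
        shape=hidden_state_spec.shape[:2],
    ),
    shape=hidden_state_spec.shape[:1],
)
env.append_transform(TensorDictPrimer(workaround_spec))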


vmoens commented Apr 11, 2024

On it, sorry for the delay.
I guess we can check whether this bug is solved in #2071.

vmoens linked a pull request (#2071) on Apr 18, 2024 that will close this issue.