
[Feature]: Dreamer support #341

Merged: 609 commits merged into pytorch:main on Oct 20, 2022

Conversation

@nicolas-dufour (Contributor) commented Aug 5, 2022

Description

In this PR we add Dreamer, a model-based RL method.

Implemented objects

To retrieve the Dreamer objects, call make_dreamer:

from torchrl.trainers.helpers.models import make_dreamer

world_model, model_based_env, actor_model, value_model, policy = make_dreamer(
    proof_environment=proof_env,
    cfg=cfg,
    device=device,
    use_decoder_in_env=True,
    action_key="action",
    value_key="predicted_value",
)

Here proof_env is the environment we will train on afterwards, and cfg is the config (see DreamerConfig in torchrl.trainers.helpers.models).

  • world_model is the world model of Dreamer. From a given observation, it predicts a reward, a latent world state and a reconstruction of the observation.
  • model_based_env is an env that operates on the latent world state (no observation). It can generate new latent world states from an initial latent world state. If use_decoder_in_env is set, the env can decode the generated states with its decode_obs method.
  • actor_model is the associated Dreamer actor model.
  • value_model predicts a value.
  • policy combines world_model and actor_model so that, from a given observation, we predict the action to take.

We also provide 3 loss models:

world_model_loss = DreamerModelLoss(world_model, cfg).to(device)
actor_loss = DreamerActorLoss(
    actor_model, value_model, model_based_env, cfg
).to(device)
value_loss = DreamerValueLoss(value_model).to(device)

These three loss modules together enable training the five models above; a minimal training-step sketch is shown below.
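A minimal sketch of such a training step, assuming each loss module returns a (loss_tensordict, tensordict) pair carrying the corresponding "loss_*" entry; the optimizers here are illustrative only (the actual example script adds autocast and gradient scaling, discussed further down in this PR):

import torch

# Illustrative optimizers; learning rates are placeholders, not the PR defaults.
world_model_opt = torch.optim.Adam(world_model.parameters(), lr=1e-4)
actor_opt = torch.optim.Adam(actor_model.parameters(), lr=1e-4)
value_opt = torch.optim.Adam(value_model.parameters(), lr=1e-4)

def training_step(sampled_tensordict):
    # 1. World model: reconstruction / reward / KL terms.
    model_loss_td, sampled_tensordict = world_model_loss(sampled_tensordict)
    world_model_opt.zero_grad()
    model_loss_td["loss_world_model"].backward()
    world_model_opt.step()

    # 2. Actor: trained on rollouts imagined with model_based_env.
    actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict)
    actor_opt.zero_grad()
    actor_loss_td["loss_actor"].backward()
    actor_opt.step()

    # 3. Value model: fitted on the same imagined trajectories.
    value_loss_td, sampled_tensordict = value_loss(sampled_tensordict)
    value_opt.zero_grad()
    value_loss_td["loss_value"].backward()
    value_opt.step()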

Use case of model_based_env

Our env allows us to generate new data.
To do so, we can call:

td  = model_based_env.rollout(
    max_steps=self.cfg.imagination_horizon,
    policy=self.actor_model,
)

By default, the env will be reset to default states of zeros. A better way to sample from it is to start from a previous world state (let's call it td):

td  = model_based_env.rollout(
    max_steps=self.cfg.imagination_horizon,
    policy=self.actor_model,
    auto_reset=False,
    tensordict=td
)
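One possible way to obtain such a previous world state (a sketch only: the key names and the [:, 0] slicing follow the example script discussed later in this PR and may not match the final API exactly):

# Run the world-model loss on real data first so the latent keys are filled in,
# then use the first time step of the batch as the starting point for imagination.
_, sampled_tensordict = world_model_loss(sampled_tensordict)
td = sampled_tensordict.select("posterior_states", "next_belief")[:, 0]

td = model_based_env.rollout(
    max_steps=cfg.imagination_horizon,
    policy=actor_model,
    auto_reset=False,
    tensordict=td,
)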

Motivation and Context

Dreamer is very data-efficient and will make model-based methods easy to use.

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • New feature (non-breaking change which adds core functionality)
  • Documentation (update in the documentation)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

@facebook-github-bot added the CLA Signed label Aug 5, 2022
@vmoens (Contributor) left a comment

Some high-level comments.
I would rather have a DreamerLoss module containing everything that is not needed for inference than one big env that does everything. Do you agree?
Have a look at how we do it for SAC and REDQ, for instance.

@nicolas-dufour nicolas-dufour marked this pull request as draft August 10, 2022 10:55
@vmoens (Contributor) left a comment

Let's use the TorchRL primitives instead.
Also, let's call LSTM(batched_input) with batched_input of size [B, T] instead of looping over the LSTM; that path uses cuDNN and is much faster.
Let's avoid building distributions unless it's really necessary.
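For reference, a small self-contained comparison of the two patterns (this is not the PR's RSSM code, just an illustration of why the single batched call is preferred):

import torch
from torch import nn

B, T, in_dim, hidden = 8, 50, 32, 200
x = torch.randn(B, T, in_dim)

# Slow pattern: stepping an LSTMCell in Python over the time dimension.
cell = nn.LSTMCell(in_dim, hidden)
h = torch.zeros(B, hidden)
c = torch.zeros(B, hidden)
outs = []
for t in range(T):
    h, c = cell(x[:, t], (h, c))
    outs.append(h)
looped = torch.stack(outs, dim=1)    # [B, T, hidden]

# Fast pattern: one call over the whole [B, T, ...] batch; on GPU this dispatches
# to the fused cuDNN kernel instead of T separate kernel launches.
lstm = nn.LSTM(in_dim, hidden, batch_first=True)
batched, _ = lstm(x)                 # [B, T, hidden]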

@vmoens (Contributor) left a comment

Some comments to improve efficiency

)
actor_loss = -lambda_target.mean()
with torch.no_grad():
    value_td = tensordict.clone().detach()
Contributor

Detach under no_grad?

Contributor Author

If we don't detach, the computation graphs of the different optimizations overlap.

obs_decoded = obs_decoded.reshape(*batch_sizes, C, H, W)
return obs_decoded

class RSSMPriorRollout(nn.Module):
Contributor

To me this would work perfectly with TensorDictModule.
That would allow us to preallocate the tensors of the rollout, which should be more efficient than stacking the outputs.

Contributor Author

You mean not having RSSMPriorRollout as an nn.Module but doing the loop over a TDModule? But then how would you integrate this with a TDSequence?

class RSSMPrior(nn.Module):
    def __init__(self, hidden_dim=200, rnn_hidden_dim=200, state_dim=20):
        super().__init__()
        self.min_std = 0.1
Contributor

Let's make this a hyperparam

return (
    TensorDict(
        {
            "loss_world_model": loss,
Contributor

Here you are returning four losses, one of them being the sum of the others. To make sure we don't do anything silly like re-summing the losses (which is what the trainer will do), you should either return one loss only (but then we won't be able to log each of them individually) or return only the decomposed losses.

Contributor Author

From what I've seen, the trainer doesn't retrieve losses unless their key starts with "loss_". So renaming loss_kl to kl and so on would do the trick, no?
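If the trainer really only aggregates keys starting with "loss_", one hedged way to return decomposed losses without double counting is to prefix only the individual terms and keep raw diagnostics unprefixed (the key names below are illustrative, not the final API):

# Import path is version-dependent; at the time of this PR TensorDict shipped
# inside torchrl rather than as the separate tensordict package.
from torchrl.data import TensorDict

def loss_output(kl, reco_loss, reward_loss, batch_size):
    return TensorDict(
        {
            # Decomposed terms: what a "sum every loss_* key" trainer adds up.
            "loss_model_kl": kl,
            "loss_model_reco": reco_loss,
            "loss_model_reward": reward_loss,
            # Unprefixed diagnostic: logged but never re-summed into the objective.
            "kl_raw": kl.detach(),
        },
        batch_size=batch_size,
    )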

@nicolas-dufour nicolas-dufour changed the title [Feature]: Dreamer ModelBased env [Feature]: Dreamer support Aug 16, 2022
scaler.unscale_(value_opt)
clip_grad_norm_(value_model.parameters(), cfg.grad_clip)

scaler.step(world_model_opt)
Contributor

I think we should do

loss1 = ...
optim1.step()

loss2 = ...
optim2.step()
etc.

That way we allow the GPU to free memory when calling backward.

Contributor Author

I was doing it like that before; however, the problem is that autocast struggles in this context. According to the PyTorch docs, only a single scaler can be created and you cannot scale again after unscale_. I felt that 16-bit precision was worth keeping.

Contributor

Can't you scale / unscale multiple times? What would be the difference between that and scaling/unscaling through a loop?
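For context, a sketch of the single-scaler pattern from the PyTorch AMP docs for several losses and optimizers; it assumes all backward passes can run before any optimizer step, which does not quite match Dreamer's sequential world-model -> actor -> value update and is presumably why separate scalers were used below (the loss and optimizer names here are hypothetical):

import torch
from torch.cuda.amp import GradScaler, autocast

scaler = GradScaler()

with autocast(dtype=torch.float16):
    loss0 = compute_loss0(batch)    # hypothetical loss computations
    loss1 = compute_loss1(batch)

# The same scaler can scale several losses...
scaler.scale(loss0).backward(retain_graph=True)
scaler.scale(loss1).backward()

# ...but unscale_ may only be called once per optimizer per step.
scaler.unscale_(optimizer0)
scaler.step(optimizer0)
scaler.step(optimizer1)    # step() unscales internally when unscale_ was not called
scaler.update()            # a single update() per iteration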

Contributor

                with autocast(dtype=torch.float16):
                    model_loss_td, sampled_tensordict = world_model_loss(
                        sampled_tensordict
                    )
                    if cfg.record_video:
                        world_model_td = sampled_tensordict.clone().select(
                            "pixels", "reco_pixels", "posterior_states", "next_belief"
                        )[:4].detach()
                scaler1.scale(model_loss_td["loss_world_model"]).backward()
                scaler1.unscale_(world_model_opt)
                clip_grad_norm_(world_model.parameters(), cfg.grad_clip)
                scaler1.step(world_model_opt)
                world_model_opt.zero_grad()
                scaler1.update()

                with autocast(dtype=torch.float16):
                    actor_loss_td, sampled_tensordict = actor_loss(sampled_tensordict)
                scaler2.scale(actor_loss_td["loss_actor"]).backward()
                scaler2.unscale_(actor_opt)
                clip_grad_norm_(actor_model.parameters(), cfg.grad_clip)
                scaler2.step(actor_opt)
                actor_opt.zero_grad()
                scaler2.update()

                with autocast(dtype=torch.float16):
                    value_loss_td, sampled_tensordict = value_loss(sampled_tensordict)
                scaler3.scale(value_loss_td["loss_value"]).backward()
                scaler3.unscale_(value_opt)
                clip_grad_norm_(value_model.parameters(), cfg.grad_clip)
                scaler3.step(value_opt)
                value_opt.zero_grad()
                scaler3.update()

        batch_size=None,
    ):
        super(DummyModelBasedEnv, self).__init__(
            WorldModelWrapper(
Contributor

Let's avoid building things inside a caller; it's just syntax, but it feels messy.

Contributor Author
@nicolas-dufour Aug 18, 2022

I'll change that in the MBEnv PR instead

observation, start_dim=0, end_dim=end_dim
)
obs_encoded = self.encoder(observation)
latent = obs_encoded.reshape(*batch_sizes, -1)
Contributor

Do we need reshape or does view work?

Contributor Author

View does not work in this case
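For reference, a minimal illustration of the difference, independent of the encoder code above: view requires a compatible memory layout, while reshape falls back to a copy when the layout is not compatible.

import torch

x = torch.randn(4, 5, 3, 64, 64)    # e.g. [B, T, C, H, W]
flat = x.flatten(0, 1)              # [B*T, C, H, W]; still contiguous, so...
flat.view(4, 5, 3, 64, 64)          # ...view works here

y = x.permute(0, 1, 3, 4, 2)        # non-contiguous layout
try:
    y.view(4, 5, -1)                # raises: size/stride incompatible with a view
except RuntimeError as err:
    print(err)
z = y.reshape(4, 5, -1)             # works; copies the data when needed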

@@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

Seems like some of the changes in this PR belong in #333? It would be cleaner for code review and log history to have this PR only include the Dreamer-specific implementations of the abstractions provided in #333.

    policy=actor_model,
    auto_reset=False,
    tensordict=world_model_td[:, 0],
).detach()
Contributor

this is already under a no_grad(), no need to detach

Contributor Author
@nicolas-dufour Aug 22, 2022

I've seen memory explosions without the detach. Tell me if I'm wrong, but no_grad makes sure not to record gradients for new operations; it doesn't detach elements that are already in the graph, no?
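A minimal illustration of the semantics in question (not the PR code): no_grad stops new operations from being recorded, but a tensor that already has a grad_fn keeps its graph alive if it is stored as-is; detach() is what severs that link.

import torch

w = torch.randn(3, requires_grad=True)
y = (w * 2).sum()        # y.grad_fn references the graph back to w

with torch.no_grad():
    z = y * 1.0          # new op under no_grad: no grad_fn is recorded
    kept = y             # storing y itself still references y.grad_fn
    cut = y.detach()     # a new tensor with no reference to the graph

print(z.grad_fn)         # None
print(kept.grad_fn)      # <SumBackward0 ...> -> keeps the graph (and its memory) alive
print(cut.grad_fn)       # None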

imagine_pxls = recover_pixels(model_based_env.decode_obs(world_model_td)["reco_pixels"], stats)

stacked_pixels = torch.cat([true_pixels, reco_pixels, imagine_pxls], dim=-1)
logger.log_video(
Contributor

Why does it appear in the log that we have way more reconstructions than actual pixels? Storing pixels is heavy and can quickly overload the disk.

Contributor Author

"way more reconstructions than actual pixels?" -> What do you mean by this?

    sampled_tensordict
)
if cfg.record_video:
    world_model_td = sampled_tensordict.clone().select(
Contributor
@vmoens Aug 20, 2022

Careful here: you clone the whole thing and then select; I would do the opposite (select -> clone).
Not even sure the clone is needed (select is not done in-place unless specified).

Contributor Author

I was doing the opposite before, but it was changing the original tensordict: a lot of keys were missing for the actor part afterwards, so that's why I reverted it.

Contributor

No, select does not change the original tensordict! Otherwise there is a serious bug!
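A sketch of the suggested ordering, assuming the usual TensorDict semantics (select returns a new tensordict restricted to the given keys and leaves the original untouched unless inplace=True):

# Restrict to the logged keys and the first few samples first, then copy only
# that small tensordict; sampled_tensordict keeps all of its keys.
world_model_td = (
    sampled_tensordict
    .select("pixels", "reco_pixels", "posterior_states", "next_belief")[:4]
    .clone()
    .detach()
)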

# update weights of the inference policy
collector.update_policy_weights_()

if r0 is None:
Contributor

we should log the training rewards somewhere

scaler.update()

with torch.no_grad(), set_exploration_mode("mode"):
td_record = record(None)
Contributor

On my end this step does not log any video. Perhaps the way the global_step is indicated in TensorBoard conflicts with the wandb API?


scaler.update()

with torch.no_grad(), set_exploration_mode("mode"):
Contributor

I personally prefer to have this out of the inner training loop. The reason is that it's easier to control the number of collection steps than the number of training steps, and small changes in the config can have a large impact on the number of evaluation data collections, which carries a significant memory and compute cost.

else:
current_frames = tensordict.numel()
collected_frames += current_frames
tensordict = tensordict.reshape(-1, cfg.batch_length)
Contributor

Why do we need this? Doesn't the tensordict already have this size?
This will break if there is a "mask" key in the tensordict (see my comment above).
Also, shouldn't reshape be replaced by 'view'?

Contributor Author

No, the tensordict has the size of the collected batch, which is n_workers x max_frames_per_traj with max_frames_per_traj=1000, but we then want to use tensors of size B x batch_length, with batch_length=50 by default in Dreamer.

Contributor

Are we sure that the resulting tensordict will always be properly shaped? That might break if batch_length does not divide the size of the collected data, no? E.g. what happens if max_frames_per_traj=789 and batch_length=25?
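For what it's worth, the reshape only succeeds when the total number of collected frames is divisible by batch_length, which is exactly the failure mode raised above (a toy illustration on a plain tensor, not the collector code):

import torch

n_workers, max_frames_per_traj, batch_length = 2, 1000, 50
batch = torch.zeros(n_workers, max_frames_per_traj, 3)    # stand-in for the collected batch
chunks = batch.reshape(-1, batch_length, 3)               # [40, 50, 3]: 2 * 1000 divides evenly by 50

odd_batch = torch.zeros(2, 789, 3)
# odd_batch.reshape(-1, 25, 3) raises a RuntimeError: 2 * 789 = 1578 frames
# cannot be split into whole chunks of length 25.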

)
from torchrl.trainers.helpers.models import (
    make_dreamer,
    DreamerConfig,
Contributor

Why import this when we redefine it afterwards?

@vmoens added the new algo (New algorithm request or PR) label Oct 19, 2022
@vmoens (Contributor) left a comment

LGTM


@vmoens vmoens merged commit e1fbf86 into pytorch:main Oct 20, 2022
Labels: CLA Signed, enhancement (New feature or request), new algo (New algorithm request or PR)
4 participants