[Feature]: Support for planners and CEM #384

Merged: 60 commits, Sep 23, 2022
Changes shown from 47 commits

Commits
e5d3ed1
Added MPC planner
nicolas-dufour Aug 30, 2022
9098b39
Added CEM planning
nicolas-dufour Aug 30, 2022
fb43687
Planner refactoring
nicolas-dufour Aug 31, 2022
1bf3ba4
Bug fixes
nicolas-dufour Sep 5, 2022
316223d
Fixes
nicolas-dufour Sep 6, 2022
0c1ad85
Added proposed fixes and tests
nicolas-dufour Sep 7, 2022
dc655b8
Merge branch 'main' into mpcp
nicolas-dufour Sep 7, 2022
09fa7e7
Added stateful vs stateless distinction in EnvBase
nicolas-dufour Sep 7, 2022
7eea053
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 7, 2022
2940958
Fixed mock
nicolas-dufour Sep 7, 2022
18ce61f
Renamed is stateful in batch_locked
nicolas-dufour Sep 8, 2022
ff5330a
Reverted gym for CI stability
nicolas-dufour Sep 8, 2022
7f706cb
reverted gym version to pass CI
nicolas-dufour Sep 8, 2022
5d96e19
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 8, 2022
b13b93b
Changed is_stateful
nicolas-dufour Sep 8, 2022
608d5ba
Changed batched_lock to be a property that can't be changed and added…
nicolas-dufour Sep 9, 2022
27e9001
Formatting
nicolas-dufour Sep 9, 2022
15ec558
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 9, 2022
3e2a33d
merged batch locked
nicolas-dufour Sep 9, 2022
8d9dc8f
Changed to _new_
nicolas-dufour Sep 9, 2022
c758e13
fixed MockBatch_env
nicolas-dufour Sep 9, 2022
d0e13a3
fixed MockBatch_env
nicolas-dufour Sep 9, 2022
2fed37a
Added BatchedEnv and fixed batch_locked
nicolas-dufour Sep 12, 2022
58009d8
Changed for TransformedEnv
nicolas-dufour Sep 12, 2022
5a79da2
Added test from Transformed Env
nicolas-dufour Sep 12, 2022
65fe441
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 12, 2022
b3533e8
updated test
nicolas-dufour Sep 12, 2022
148f3f5
Formatting
nicolas-dufour Sep 12, 2022
6f3059c
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 12, 2022
aecbad2
Reverted gym downgrade
nicolas-dufour Sep 13, 2022
cc183ca
Merge branch 'main' into statefull_stateless
vmoens Sep 13, 2022
e05f963
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 13, 2022
f3766d3
Changed expand
nicolas-dufour Sep 13, 2022
88b537f
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 13, 2022
8e25244
Made suggested fix
nicolas-dufour Sep 14, 2022
dc1239c
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 14, 2022
3e6bc26
Removed attribute from init
nicolas-dufour Sep 14, 2022
fef3f3f
Fixed expand in planner
nicolas-dufour Sep 14, 2022
9ca7ff1
fixed TransformedEnv
nicolas-dufour Sep 14, 2022
6dcf7bd
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 14, 2022
ed716a5
Merge branch 'main' into mpcp
nicolas-dufour Sep 16, 2022
e085bdf
Fixed in_place cls parameter passing
nicolas-dufour Sep 16, 2022
e84524a
Update common.py
nicolas-dufour Sep 16, 2022
fc5ca15
Merge branch 'main' into fix_in_place
nicolas-dufour Sep 16, 2022
926d818
Merge branch 'main' into fix_in_place
nicolas-dufour Sep 21, 2022
afdf9f3
Merge branch 'fix_in_place' of github.com:nicolas-dufour/torchrl into…
nicolas-dufour Sep 21, 2022
d941f2a
Merge branch 'fix_in_place' into mpcp
nicolas-dufour Sep 21, 2022
f8f2a9d
Merge branch 'main' of github.com:facebookresearch/rl into mpcp
nicolas-dufour Sep 22, 2022
bda0e59
Added requested changes
nicolas-dufour Sep 22, 2022
8f66c59
Doc fixes
nicolas-dufour Sep 23, 2022
4dd923f
fixed test
nicolas-dufour Sep 23, 2022
aecafa3
Fix test on gpu
nicolas-dufour Sep 23, 2022
1a6c6f1
Ran precommit
nicolas-dufour Sep 23, 2022
7a10cca
Merge branch 'main' into mpcp
nicolas-dufour Sep 23, 2022
182245a
fixed gpu tests
nicolas-dufour Sep 23, 2022
5934220
linting
nicolas-dufour Sep 23, 2022
e812ad2
Update common.py
vmoens Sep 23, 2022
82ce1b1
Merge branch 'main' into mpcp
vmoens Sep 23, 2022
9796a33
example in ds
vmoens Sep 23, 2022
7f54b8d
Merge branch 'main' into mpcp
vmoens Sep 23, 2022
27 changes: 27 additions & 0 deletions test/test_modules.py
@@ -8,6 +8,7 @@
import pytest
import torch
from _utils_internal import get_available_devices
from mocking_classes import MockBatchedUnLockedEnv
from torch import nn
from torchrl.data import TensorDict
from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
@@ -18,6 +19,7 @@
ValueOperator,
ProbabilisticActor,
LSTMNet,
CEMPlanner,
)
from torchrl.modules.functional_modules import (
FunctionalModule,
@@ -326,6 +328,31 @@ def test_func_transformer(self):
torch.testing.assert_close(fmodule(params, buffers, x, x), module(x, x))


class TestPlanner:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("batch_size", [3, 5])
def test_CEM_model_free_env(self, device, batch_size, seed=1):
env = MockBatchedUnLockedEnv(device=device)
env.set_seed(seed)
planner = CEMPlanner(
env,
planning_horizon=10,
optim_steps=2,
num_candidates=100,
num_top_k_candidates=2,
).to(device)
td = env.reset(TensorDict({}, batch_size=batch_size)).to(device)
td_copy = td.clone()
td = planner(td)
assert td.get("action").shape[1:] == env.action_spec.shape

assert env.action_spec.is_in(td.get("action"))

for key in td.keys():
if key != "action":
assert torch.allclose(td[key], td_copy[key])
Contributor:
Shouldn't you also test the values of the action?

Contributor Author:
What should I test them for? I'm already testing that they exist and have the right shape, but I don't have any information that would allow testing the values.
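One possible follow-up, sketched here rather than taken from the PR: pin the values down indirectly with a reproducibility check. Assuming MockBatchedUnLockedEnv steps deterministically for a fixed seed (an assumption, not verified here), two planner calls built from the same seeds should yield identical actions:

```python
# Hypothetical extra test for TestPlanner, not part of this PR.
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("batch_size", [3, 5])
def test_CEM_determinism(self, device, batch_size, seed=1):
    def plan_once():
        env = MockBatchedUnLockedEnv(device=device)
        env.set_seed(seed)
        planner = CEMPlanner(
            env,
            planning_horizon=10,
            optim_steps=2,
            num_candidates=100,
            num_top_k_candidates=2,
        ).to(device)
        td = env.reset(TensorDict({}, batch_size=batch_size)).to(device)
        torch.manual_seed(seed)  # CEM samples its candidates with torch.randn
        return planner(td).get("action")

    torch.testing.assert_close(plan_once(), plan_once())
```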



if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
4 changes: 2 additions & 2 deletions torchrl/envs/common.py
@@ -226,12 +226,12 @@ def __init__(
self.batch_size = torch.Size([])

@classmethod
def __new__(cls, *args, _batch_locked=True, **kwargs):
def __new__(cls, *args, _inplace_update=False, _batch_locked=True, **kwargs):
# inplace update will write tensors in-place on the provided tensordict.
# This is risky, especially if gradients need to be passed (in-place copy
# for tensors that are part of computational graphs will result in an error).
# It can also lead to inconsistencies when calling rollout.
cls._inplace_update = False
cls._inplace_update = _inplace_update
cls._batch_locked = _batch_locked
return super().__new__(cls)

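With both flags exposed through __new__, a subclass or wrapper can declare its behaviour by forwarding them, exactly as TransformedEnv does in the next file. A minimal sketch for a hypothetical batch-unlocked environment (the class name and the omission of the abstract _step/_reset methods are illustrative; only the __new__ signature comes from this hunk):

```python
from torchrl.envs import EnvBase

class MyStatelessEnv(EnvBase):  # hypothetical example, not part of this PR
    def __new__(cls, *args, **kwargs):
        # Advertise that this env accepts input tensordicts of any batch size
        # (not batch-locked) and never writes into them in place.
        return super().__new__(
            cls, *args, _inplace_update=False, _batch_locked=False, **kwargs
        )
```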
10 changes: 10 additions & 0 deletions torchrl/envs/transforms/transforms.py
@@ -312,6 +312,16 @@ def __init__(
self._observation_spec = None
self.batch_size = self.base_env.batch_size

def __new__(cls, env, *args, **kwargs):
return super().__new__(
cls,
env,
*args,
_inplace_update=env._inplace_update,
_batch_locked=env.batch_locked,
**kwargs,
)

def _set_env(self, env: EnvBase, device) -> None:
self.base_env = env.to(device)
# updates need not be inplace, as transforms may modify values out-place
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
@@ -6,3 +6,4 @@
from .distributions import *
from .models import *
from .tensordict_module import *
from .planners import *
7 changes: 7 additions & 0 deletions torchrl/modules/planners/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .common import *
from .cem import *
130 changes: 130 additions & 0 deletions torchrl/modules/planners/cem.py
@@ -0,0 +1,130 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torchrl.data.tensordict.tensordict import TensorDictBase
from torchrl.envs import EnvBase
from torchrl.modules.planners import MPCPlannerBase

__all__ = ["CEMPlanner"]


class CEMPlanner(MPCPlannerBase):
"""
CEMPlanner Module. This class inherits from TensorDictModule.

Provided a TensorDict, this module will perform a CEM planning step.
The CEM planning step is performed by sampling actions from a Gaussian distribution with zero mean and unit variance.
The actions are then used to perform a rollout in the environment.
The rewards are then used to update the mean and standard deviation of the Gaussian distribution.
The mean and standard deviation of the Gaussian distribution are then used to sample actions for the next planning step.
The CEM planning step is repeated for a specified number of steps.
At the end, we recover the best action which is the one that maximizes the reward given a planning horizon.

Args:
env (Environment): The environment to perform the planning step on (Can be ModelBasedEnv or EnvBase).
planning_horizon (int): The number of steps to perform the planning step for.
optim_steps (int): The number of steps to perform the MPC planning step for.
num_candidates (int): The number of candidates to sample from the Gaussian distribution.
num_top_k_candidates (int): The number of top candidates to use to update the mean and standard deviation of the Gaussian distribution.
reward_key (str): The key in the TensorDict to use to retrieve the reward.
action_key (str): The key in the TensorDict to use to store the action.

Returns:
TensorDict: The TensorDict with the action added.
"""

def __init__(
self,
env: EnvBase,
planning_horizon: int,
optim_steps: int,
num_candidates: int,
num_top_k_candidates: int,
reward_key: str = "reward",
action_key: str = "action",
):
super().__init__(env=env, action_key=action_key)
self.planning_horizon = planning_horizon
self.optim_steps = optim_steps
self.num_candidates = num_candidates
self.num_top_k_candidates = num_top_k_candidates
self.reward_key = reward_key

def planning(self, td: TensorDictBase) -> torch.Tensor:
Contributor:
Do we want to propagate gradients here, or do we want to set a no_grad decorator?
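For context on this question, torch.no_grad() used as a decorator disables graph construction for everything inside the call, which is what a no_grad decorator on planning() would do. A standalone illustration (plan is a hypothetical stand-in, not the method above):

```python
import torch

@torch.no_grad()  # autograd is switched off for the whole call
def plan(obs):
    # stand-in for an expensive planning rollout
    return (obs * 2.0).sum()

obs = torch.ones(3, requires_grad=True)
value = plan(obs)
print(value.requires_grad)  # False: no graph was built through the planner
```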

batch_size = td.batch_size
expanded_original_td = (
td.unsqueeze(-1)
.expand(*batch_size, self.num_candidates)
.contiguous()
.view(-1)
)
flatten_batch_size = batch_size.numel()
actions_means = torch.zeros(
flatten_batch_size,
1,
self.planning_horizon,
*self.action_spec.shape,
device=td.device,
Contributor:
FYI: soon we'll have tensordicts with device=None. No need to change anything here, but keep it in mind when you call tensordict.device.

dtype=self.env.action_spec.dtype,
)
actions_stds = torch.ones(
flatten_batch_size,
1,
self.planning_horizon,
*self.action_spec.shape,
device=td.device,
dtype=self.env.action_spec.dtype,
)

for _ in range(self.optim_steps):
actions = actions_means + actions_stds * torch.randn(
flatten_batch_size,
self.num_candidates,
self.planning_horizon,
*self.action_spec.shape,
device=td.device,
dtype=self.env.action_spec.dtype,
)
actions = actions.flatten(0, 1)
actions = self.env.action_spec.project(actions)
optim_td = expanded_original_td.to_tensordict()
policy = PrecomputedActionsSequentialSetter(actions)
optim_td = self.env.rollout(
max_steps=self.planning_horizon,
policy=policy,
auto_reset=False,
tensordict=optim_td,
)
rewards = (
optim_td.get(self.reward_key)
.sum(dim=1)
.reshape(flatten_batch_size, self.num_candidates)
)
_, top_k = rewards.topk(self.num_top_k_candidates, dim=1)

best_actions = actions.unflatten(
0, (flatten_batch_size, self.num_candidates)
)
best_actions = best_actions[
torch.arange(flatten_batch_size).unsqueeze(1), top_k
]
actions_means = best_actions.mean(dim=1, keepdim=True)
actions_stds = best_actions.std(dim=1, keepdim=True)
return (actions_means[:, :, 0]).reshape(*batch_size, *self.action_spec.shape)


class PrecomputedActionsSequentialSetter:
def __init__(self, actions):
self.actions = actions
self.cmpt = 0

def __call__(self, td):
if self.cmpt >= self.actions.shape[1]:
raise ValueError("Precomputed actions are too short")
td = td.set("action", self.actions[:, self.cmpt])
self.cmpt += 1
return td
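Distilled from the new test in test_modules.py above, a minimal usage sketch for the planner added in this file. The constructor arguments, MockBatchedUnLockedEnv and the reset call are taken from this PR; the device, seed and batch size are arbitrary, and the mock env is only importable from the test directory:

```python
from torchrl.data import TensorDict
from torchrl.modules import CEMPlanner
from mocking_classes import MockBatchedUnLockedEnv  # test-only env used in this PR

env = MockBatchedUnLockedEnv(device="cpu")
env.set_seed(0)

planner = CEMPlanner(
    env,
    planning_horizon=10,       # steps simulated for each candidate sequence
    optim_steps=2,             # CEM refinement iterations
    num_candidates=100,        # action sequences sampled per iteration
    num_top_k_candidates=2,    # elites used to refit the Gaussian
)

td = env.reset(TensorDict({}, batch_size=[3]))
td = planner(td)  # writes the planned first action under the "action" key
print(td.get("action").shape)
```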
72 changes: 72 additions & 0 deletions torchrl/modules/planners/common.py
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import abc
from typing import Optional

import torch

from torchrl.data.tensordict.tensordict import TensorDictBase
from torchrl.envs import EnvBase
from torchrl.modules import TensorDictModule

__all__ = ["MPCPlannerBase"]


class MPCPlannerBase(TensorDictModule, metaclass=abc.ABCMeta):
"""
MPCPlannerBase Module. This is an abstract class and must be implemented by the user.
Contributor:
You can just say it is an abstract class. Also I would remove this from the headline, or just say "MPCPlannerBase base module", which says it all.

Contributor:
Don't forget: no new line between """ and the next word (I'm in the process of correcting the whole lib for these, so I'd appreciate it if new PRs don't have that).


This class inherits from TensorDictModule. Provided a TensorDict, this module will perform a Model Predictive Control (MPC) planning step.
At the end of the planning step, the MPCPlanner will return the action that should be taken.

Args:
env (Environment): The environment to perform the planning step on (Can be ModelBasedEnv or EnvBase).
action_key (str): The key in the TensorDict to use to store the action.

Returns:
TensorDict: The TensorDict with the action added.
"""

def __init__(
self,
env: EnvBase,
action_key: str = "action",
):
# Check if env is stateless
if env.batch_locked:
raise ValueError("Environment is not stateless")
out_keys = [action_key]
in_keys = list(env.observation_spec.keys())
super().__init__(env, in_keys=in_keys, out_keys=out_keys)
self.env = env
self.action_spec = env.action_spec

@abc.abstractmethod
def planning(self, td: TensorDictBase) -> torch.Tensor:
"""
Perform the MPC planning step.
Args:
td (TensorDict): The TensorDict to perform the planning step on.
Returns:
TensorDict: The TensorDict with the action added.
"""
raise NotImplementedError()

def forward(
self,
tensordict: TensorDictBase,
tensordict_out: Optional[TensorDictBase] = None,
**kwargs,
) -> TensorDictBase:
if "params" in kwargs or "vmap" in kwargs:
raise ValueError("params not supported")
action = self.planning(tensordict)
action = self.action_spec.project(action)
tensordict_out = self._write_to_tensordict(
tensordict,
(action,),
tensordict_out,
)
return tensordict_out
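To make the contract concrete: a subclass only implements planning(), returning an action tensor with the input tensordict's batch shape; the inherited forward() then projects it onto the action spec and writes it under action_key. A deliberately trivial sketch (hypothetical, not part of the PR, and not a real MPC scheme):

```python
import torch

from torchrl.data.tensordict.tensordict import TensorDictBase
from torchrl.modules.planners import MPCPlannerBase


class RandomPlanner(MPCPlannerBase):
    """Samples one action per batch element from the spec, ignoring the model."""

    def planning(self, td: TensorDictBase) -> torch.Tensor:
        # action_spec.rand(shape) returns a tensor of shape (*shape, *spec.shape),
        # which is what forward() expects back from planning().
        return self.action_spec.rand(td.batch_size)
```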
2 changes: 1 addition & 1 deletion torchrl/trainers/helpers/envs.py
@@ -161,7 +161,7 @@ def make_env_transforms(
key
for key in env.observation_spec.keys()
if ("pixels" not in key)
and (key.strip("next_") not in env.input_spec.keys())
and (key.replace("next_", "") not in env.input_spec.keys())
]

# even if there is a single tensor, it'll be renamed in "next_observation_vector"
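The strip-to-replace change in this last hunk is a real bug fix: str.strip treats its argument as a set of characters to trim from both ends, not as a prefix, so keys whose bodies start or end with letters from "next_" get mangled. A quick illustration with hypothetical key names:

```python
# str.strip removes any of the characters n, e, x, t, _ from both ends,
# so it can eat into the key itself:
print("next_state".strip("next_"))        # -> 'sta'
print("next_texture".strip("next_"))      # -> 'ur'
# str.replace only removes the literal substring:
print("next_state".replace("next_", ""))  # -> 'state'
```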