[Feature]: Support for planners and CEM #384
Merged
e5d3ed1
Added MPC planner
nicolas-dufour 9098b39
Added CEM planning
nicolas-dufour fb43687
Planner refactoring
nicolas-dufour 1bf3ba4
Bug fixes
nicolas-dufour 316223d
Fixes
nicolas-dufour 0c1ad85
Added proposed fixes and tests
nicolas-dufour dc655b8
Merge branch 'main' into mpcp
nicolas-dufour 09fa7e7
Added stateful vs stateless distinction in EnvBase
nicolas-dufour 7eea053
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour 2940958
Fixed mock
nicolas-dufour 18ce61f
Renamed is stateful in batch_locked
nicolas-dufour ff5330a
Reverted gym for CI stability
nicolas-dufour 7f706cb
reverted gym version to pass CI
nicolas-dufour 5d96e19
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour b13b93b
Changed is_stateful
nicolas-dufour 608d5ba
Changed batched_lock to be a property that can't be changed and added…
nicolas-dufour 27e9001
Formatting
nicolas-dufour 15ec558
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour 3e2a33d
merged batch locked
nicolas-dufour 8d9dc8f
Changed to _new_
nicolas-dufour c758e13
fixed MockBatch_env
nicolas-dufour d0e13a3
fixed MockBatch_env
nicolas-dufour 2fed37a
Added BatchedEnv and fixed batch_locked
nicolas-dufour 58009d8
Changed for TransformedEnv
nicolas-dufour 5a79da2
Added test from Transformed Env
nicolas-dufour 65fe441
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour b3533e8
updated test
nicolas-dufour 148f3f5
Formatting
nicolas-dufour 6f3059c
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour aecbad2
Reverted gym downgrade
nicolas-dufour cc183ca
Merge branch 'main' into statefull_stateless
vmoens e05f963
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour f3766d3
Changed expand
nicolas-dufour 88b537f
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour 8e25244
Made suggested fix
nicolas-dufour dc1239c
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour 3e6bc26
Removed attribute from init
nicolas-dufour fef3f3f
Fixed expand in planner
nicolas-dufour 9ca7ff1
fixed TransformedEnv
nicolas-dufour 6dcf7bd
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour ed716a5
Merge branch 'main' into mpcp
nicolas-dufour e085bdf
Fixed in_place cls parameter passing
nicolas-dufour e84524a
Update common.py
nicolas-dufour fc5ca15
Merge branch 'main' into fix_in_place
nicolas-dufour 926d818
Merge branch 'main' into fix_in_place
nicolas-dufour afdf9f3
Merge branch 'fix_in_place' of github.com:nicolas-dufour/torchrl into…
nicolas-dufour d941f2a
Merge branch 'fix_in_place' into mpcp
nicolas-dufour f8f2a9d
Merge branch 'main' of github.com:facebookresearch/rl into mpcp
nicolas-dufour bda0e59
Added requested changes
nicolas-dufour 8f66c59
Doc fixes
nicolas-dufour 4dd923f
fixed test
nicolas-dufour aecafa3
Fix test on gpu
nicolas-dufour 1a6c6f1
Ran precommit
nicolas-dufour 7a10cca
Merge branch 'main' into mpcp
nicolas-dufour 182245a
fixed gpu tests
nicolas-dufour 5934220
linting
nicolas-dufour e812ad2
Update common.py
vmoens 82ce1b1
Merge branch 'main' into mpcp
vmoens 9796a33
example in ds
vmoens 7f54b8d
Merge branch 'main' into mpcp
vmoens
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ | |
from .distributions import * | ||
from .models import * | ||
from .tensordict_module import * | ||
from .planners import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from .common import * | ||
from .cem import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
|
||
from torchrl.data.tensordict.tensordict import TensorDictBase | ||
from torchrl.envs import EnvBase | ||
from torchrl.modules.planners import MPCPlannerBase | ||
|
||
__all__ = ["CEMPlanner"] | ||
|
||
|
||
class CEMPlanner(MPCPlannerBase): | ||
"""CEMPlanner Module. | ||
|
||
Reference: The cross-entropy method for optimization, Botev et al. 2013 | ||
|
||
This module performs a CEM planning step when given a TensorDict containing initial states. | ||
The CEM planning step starts by sampling candidate action sequences from a Gaussian distribution with zero mean and unit variance. | ||
The sampled actions are then used to perform a rollout in the environment. The cumulative rewards obtained from the rollouts are then | ||
ranked. We select the top-k candidates and use their actions to update the mean and standard deviation of the action distribution. | ||
The CEM planning step is repeated for a specified number of optimization steps. | ||
|
||
A call to the module returns the first action of the final mean action sequence, i.e. the average over the top-k candidates that empirically maximised the return over the planning horizon. | ||
|
||
Args: | ||
env (EnvBase): The environment to perform the planning step on (can be ``ModelBasedEnv`` or ``EnvBase``). | ||
planning_horizon (int): The length of the simulated trajectories. | ||
optim_steps (int): The number of optimization steps used by the MPC planner. | ||
num_candidates (int): The number of candidates to sample from the Gaussian distributions. | ||
num_top_k_candidates (int): The number of top candidates to use to update the mean and standard deviation of the Gaussian distribution. | ||
reward_key (str, optional): The key in the TensorDict to use to retrieve the reward. | ||
action_key (str, optional): The key in the TensorDict to use to store the action. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
env: EnvBase, | ||
planning_horizon: int, | ||
optim_steps: int, | ||
num_candidates: int, | ||
num_top_k_candidates: int, | ||
reward_key: str = "reward", | ||
action_key: str = "action", | ||
): | ||
super().__init__(env=env, action_key=action_key) | ||
self.planning_horizon = planning_horizon | ||
self.optim_steps = optim_steps | ||
self.num_candidates = num_candidates | ||
self.num_top_k_candidates = num_top_k_candidates | ||
self.reward_key = reward_key | ||
|
||
def planning(self, tensordict: TensorDictBase) -> torch.Tensor: | ||
batch_size = tensordict.batch_size | ||
expanded_original_tensordict = ( | ||
tensordict.unsqueeze(-1) | ||
.expand(*batch_size, self.num_candidates) | ||
.reshape(-1) | ||
) | ||
flatten_batch_size = batch_size.numel() | ||
actions_means = torch.zeros( | ||
flatten_batch_size, | ||
1, | ||
self.planning_horizon, | ||
*self.action_spec.shape, | ||
device=tensordict.device, | ||
dtype=self.env.action_spec.dtype, | ||
) | ||
actions_stds = torch.ones( | ||
flatten_batch_size, | ||
1, | ||
self.planning_horizon, | ||
*self.action_spec.shape, | ||
device=tensordict.device, | ||
dtype=self.env.action_spec.dtype, | ||
) | ||
|
||
for _ in range(self.optim_steps): | ||
actions = actions_means + actions_stds * torch.randn( | ||
flatten_batch_size, | ||
self.num_candidates, | ||
self.planning_horizon, | ||
*self.action_spec.shape, | ||
device=tensordict.device, | ||
dtype=self.env.action_spec.dtype, | ||
) | ||
actions = actions.flatten(0, 1) | ||
actions = self.env.action_spec.project(actions) | ||
optim_tensordict = expanded_original_tensordict.to_tensordict() | ||
policy = PrecomputedActionsSequentialSetter(actions) | ||
optim_tensordict = self.env.rollout( | ||
max_steps=self.planning_horizon, | ||
policy=policy, | ||
auto_reset=False, | ||
tensordict=optim_tensordict, | ||
) | ||
rewards = ( | ||
optim_tensordict.get(self.reward_key) | ||
.sum(dim=1) | ||
.reshape(flatten_batch_size, self.num_candidates) | ||
) | ||
_, top_k = rewards.topk(self.num_top_k_candidates, dim=1) | ||
|
||
best_actions = actions.unflatten( | ||
0, (flatten_batch_size, self.num_candidates) | ||
) | ||
best_actions = best_actions[ | ||
torch.arange(flatten_batch_size, device=tensordict.device).unsqueeze(1), top_k | ||
] | ||
actions_means = best_actions.mean(dim=1, keepdim=True) | ||
actions_stds = best_actions.std(dim=1, keepdim=True) | ||
return actions_means[:, :, 0].reshape(*batch_size, *self.action_spec.shape) | ||
|
||
|
||
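class PrecomputedActionsSequentialSetter: | ||
"""Policy that replays precomputed actions, one time step per call.""" | ||
|
||
def __init__(self, actions): | ||
# actions has shape [batch, time, *action_shape] | ||
self.actions = actions | ||
# index of the next time step to emit | ||
self.cmpt = 0 | ||
|
||
def __call__(self, td): | ||
if self.cmpt >= self.actions.shape[1]: | ||
raise ValueError("Precomputed actions are too short") | ||
# NOTE: the key is hard-coded to "action", whatever action_key the planner was given | ||
td = td.set("action", self.actions[:, self.cmpt]) | ||
self.cmpt += 1 | ||
return td |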
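The sample, rollout, rank, refit loop described in the `CEMPlanner` docstring can be sketched without torchrl. This is a minimal pure-Python illustration under stated assumptions, not the code from this PR: the toy dynamics `x <- x + a`, the reward `-x**2`, and the names `rollout_return` and `cem_plan` are all hypothetical, the clamp stands in for `action_spec.project`, and the population standard deviation is used so that a single elite does not break the refit.

```python
import random
import statistics


def rollout_return(x0, actions):
    """Toy deterministic model (assumption): x <- x + a, reward = -x**2 per step."""
    x, total = x0, 0.0
    for a in actions:
        x = x + a
        total += -(x ** 2)
    return total


def cem_plan(x0, horizon=3, optim_steps=10, num_candidates=64,
             num_top_k=8, low=-1.0, high=1.0, seed=0):
    rng = random.Random(seed)
    means = [0.0] * horizon  # zero mean, as in the docstring
    stds = [1.0] * horizon   # unit standard deviation
    for _ in range(optim_steps):
        # Sample candidate action sequences and clamp them to the bounds.
        candidates = [
            [min(high, max(low, rng.gauss(m, s))) for m, s in zip(means, stds)]
            for _ in range(num_candidates)
        ]
        # Rank candidates by cumulative reward and keep the top-k.
        candidates.sort(key=lambda seq: rollout_return(x0, seq), reverse=True)
        elites = candidates[:num_top_k]
        # Refit one Gaussian per time step on the elites.
        means = [statistics.fmean(seq[t] for seq in elites) for t in range(horizon)]
        stds = [statistics.pstdev([seq[t] for seq in elites]) for t in range(horizon)]
    return means


plan = cem_plan(x0=2.0)
first_action = plan[0]  # an MPC loop would execute only this action, then replan
```

Starting from `x0=2.0`, the planned sequence should push the state toward zero, so the first action comes out negative and the plan scores better than doing nothing.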
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
import abc | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
from torchrl.data.tensordict.tensordict import TensorDictBase | ||
from torchrl.envs import EnvBase | ||
from torchrl.modules import TensorDictModule | ||
|
||
__all__ = ["MPCPlannerBase"] | ||
|
||
|
||
class MPCPlannerBase(TensorDictModule, metaclass=abc.ABCMeta): | ||
"""MPCPlannerBase Module. | ||
|
||
This is an abstract class. | ||
|
||
This class inherits from TensorDictModule. Provided a TensorDict, this module will perform a Model Predictive Control (MPC) planning step. | ||
At the end of the planning step, the planner returns a proposed action. | ||
|
||
Args: | ||
env (EnvBase): The environment to perform the planning step on (Can be ModelBasedEnv or EnvBase). | ||
action_key (str, optional): The key that will point to the computed action. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
env: EnvBase, | ||
action_key: str = "action", | ||
): | ||
# MPC planners need an env that accepts batched inputs of any batch size | ||
if env.batch_locked: | ||
raise ValueError("Environment is batch_locked. MPCPlanners need an environment that accepts batched inputs with any batch size") | ||
out_keys = [action_key] | ||
in_keys = list(env.observation_spec.keys()) | ||
super().__init__(env, in_keys=in_keys, out_keys=out_keys) | ||
self.env = env | ||
self.action_spec = env.action_spec | ||
self.to(env.device) | ||
|
||
@abc.abstractmethod | ||
def planning(self, td: TensorDictBase) -> torch.Tensor: | ||
"""Perform the MPC planning step. | ||
|
||
Args: | ||
td (TensorDict): The TensorDict to perform the planning step on. | ||
""" | ||
raise NotImplementedError() | ||
|
||
def forward( | ||
self, | ||
tensordict: TensorDictBase, | ||
tensordict_out: Optional[TensorDictBase] = None, | ||
**kwargs, | ||
) -> TensorDictBase: | ||
if "params" in kwargs or "vmap" in kwargs: | ||
raise ValueError("MPCPlannerBase does not support params or vmap for now.") | ||
action = self.planning(tensordict) | ||
action = self.action_spec.project(action) | ||
tensordict_out = self._write_to_tensordict( | ||
tensordict, | ||
(action,), | ||
tensordict_out, | ||
) | ||
return tensordict_out |
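The split between the abstract `planning` method and the concrete `forward` method can be illustrated with a stripped-down sketch. Plain dicts stand in for TensorDicts and a scalar clamp stands in for `action_spec.project`; `PlannerBase` and `GreedyPlanner` are hypothetical names introduced only for this example and are not part of the PR.

```python
import abc


class PlannerBase(abc.ABC):
    """Hypothetical stand-in for MPCPlannerBase, using plain dicts."""

    def __init__(self, low, high, action_key="action"):
        self.low, self.high = low, high
        self.action_key = action_key

    @abc.abstractmethod
    def planning(self, state):
        """Return a proposed action for the given state."""

    def forward(self, state):
        action = self.planning(state)
        # Project the proposed action onto the allowed bounds.
        action = min(self.high, max(self.low, action))
        # Write the result under the configured key without mutating the input.
        out = dict(state)
        out[self.action_key] = action
        return out


class GreedyPlanner(PlannerBase):
    """Toy subclass: propose the action that cancels the observation."""

    def planning(self, state):
        return -state["observation"]


out = GreedyPlanner(-1.0, 1.0).forward({"observation": 3.0})
```

Subclasses only implement `planning`; projecting the action and writing it back stay in the base class, so every planner produces a valid action under the same key.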
shouldn't you also test the values of the action?
What should I test them for? I'm already testing that they exist and have the right shape, but I don't really have any information that would allow me to test the values.
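One possible way to test action values, sketched here as a suggestion rather than as what the PR does, is to plan on a toy problem whose optimal first action is known in closed form and compare against it with a tolerance. Everything below is hypothetical: the one-step dynamics `x' = x + a`, the random-shooting planner `plan_first_action`, and the tolerance of 0.05.

```python
import random


def plan_first_action(x, target, low=-1.0, high=1.0, num_candidates=1000, seed=0):
    """Hypothetical one-step random-shooting planner on x' = x + a."""
    rng = random.Random(seed)
    candidates = [rng.uniform(low, high) for _ in range(num_candidates)]
    # Reward is the negative squared distance to the target after one step.
    return max(candidates, key=lambda a: -((x + a - target) ** 2))


def optimal_first_action(x, target, low=-1.0, high=1.0):
    """Closed-form optimum: move straight to the target, clipped to the bounds."""
    return min(high, max(low, target - x))


def test_planned_action_value():
    for x, target in [(0.0, 0.5), (2.0, 0.0), (-0.3, -0.1)]:
        planned = plan_first_action(x, target)
        assert abs(planned - optimal_first_action(x, target)) < 0.05


test_planned_action_value()
```

The same idea would carry over to a planner under test by swapping `plan_first_action` for a call to it on a mock environment with known dynamics.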