[Feature]: Support for planners and CEM #384

Merged · 60 commits · Sep 23, 2022
e5d3ed1
Added MPC planner
nicolas-dufour Aug 30, 2022
9098b39
Added CEM planning
nicolas-dufour Aug 30, 2022
fb43687
Planner refactoring
nicolas-dufour Aug 31, 2022
1bf3ba4
Bug fixes
nicolas-dufour Sep 5, 2022
316223d
Fixes
nicolas-dufour Sep 6, 2022
0c1ad85
Added proposed fixes and tests
nicolas-dufour Sep 7, 2022
dc655b8
Merge branch 'main' into mpcp
nicolas-dufour Sep 7, 2022
09fa7e7
Added stateful vs stateless distinction in EnvBase
nicolas-dufour Sep 7, 2022
7eea053
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 7, 2022
2940958
Fixed mock
nicolas-dufour Sep 7, 2022
18ce61f
Renamed is stateful in batch_locked
nicolas-dufour Sep 8, 2022
ff5330a
Reverted gym for CI stability
nicolas-dufour Sep 8, 2022
7f706cb
reverted gym version to pass CI
nicolas-dufour Sep 8, 2022
5d96e19
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 8, 2022
b13b93b
Changed is_stateful
nicolas-dufour Sep 8, 2022
608d5ba
Changed batched_lock to be a property that can't be changed and added…
nicolas-dufour Sep 9, 2022
27e9001
Formatting
nicolas-dufour Sep 9, 2022
15ec558
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 9, 2022
3e2a33d
merged batch locked
nicolas-dufour Sep 9, 2022
8d9dc8f
Changed to _new_
nicolas-dufour Sep 9, 2022
c758e13
fixed MockBatch_env
nicolas-dufour Sep 9, 2022
d0e13a3
fixed MockBatch_env
nicolas-dufour Sep 9, 2022
2fed37a
Added BatchedEnv and fixed batch_locked
nicolas-dufour Sep 12, 2022
58009d8
Changed for TransformedEnv
nicolas-dufour Sep 12, 2022
5a79da2
Added test from Transformed Env
nicolas-dufour Sep 12, 2022
65fe441
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 12, 2022
b3533e8
updated test
nicolas-dufour Sep 12, 2022
148f3f5
Formatting
nicolas-dufour Sep 12, 2022
6f3059c
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 12, 2022
aecbad2
Reverted gym downgrade
nicolas-dufour Sep 13, 2022
cc183ca
Merge branch 'main' into statefull_stateless
vmoens Sep 13, 2022
e05f963
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 13, 2022
f3766d3
Changed expand
nicolas-dufour Sep 13, 2022
88b537f
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 13, 2022
8e25244
Made suggested fix
nicolas-dufour Sep 14, 2022
dc1239c
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 14, 2022
3e6bc26
Removed attribute from init
nicolas-dufour Sep 14, 2022
fef3f3f
Fixed expand in planner
nicolas-dufour Sep 14, 2022
9ca7ff1
fixed TransformedEnv
nicolas-dufour Sep 14, 2022
6dcf7bd
Merge branch 'statefull_stateless' into mpcp
nicolas-dufour Sep 14, 2022
ed716a5
Merge branch 'main' into mpcp
nicolas-dufour Sep 16, 2022
e085bdf
Fixed in_place cls parameter passing
nicolas-dufour Sep 16, 2022
e84524a
Update common.py
nicolas-dufour Sep 16, 2022
fc5ca15
Merge branch 'main' into fix_in_place
nicolas-dufour Sep 16, 2022
926d818
Merge branch 'main' into fix_in_place
nicolas-dufour Sep 21, 2022
afdf9f3
Merge branch 'fix_in_place' of github.com:nicolas-dufour/torchrl into…
nicolas-dufour Sep 21, 2022
d941f2a
Merge branch 'fix_in_place' into mpcp
nicolas-dufour Sep 21, 2022
f8f2a9d
Merge branch 'main' of github.com:facebookresearch/rl into mpcp
nicolas-dufour Sep 22, 2022
bda0e59
Added requested changes
nicolas-dufour Sep 22, 2022
8f66c59
Doc fixes
nicolas-dufour Sep 23, 2022
4dd923f
fixed test
nicolas-dufour Sep 23, 2022
aecafa3
Fix test on gpu
nicolas-dufour Sep 23, 2022
1a6c6f1
Ran precommit
nicolas-dufour Sep 23, 2022
7a10cca
Merge branch 'main' into mpcp
nicolas-dufour Sep 23, 2022
182245a
fixed gpu tests
nicolas-dufour Sep 23, 2022
5934220
linting
nicolas-dufour Sep 23, 2022
e812ad2
Update common.py
vmoens Sep 23, 2022
82ce1b1
Merge branch 'main' into mpcp
vmoens Sep 23, 2022
9796a33
example in ds
vmoens Sep 23, 2022
7f54b8d
Merge branch 'main' into mpcp
vmoens Sep 23, 2022
27 changes: 27 additions & 0 deletions test/test_modules.py
@@ -8,6 +8,7 @@
import pytest
import torch
from _utils_internal import get_available_devices
from mocking_classes import MockBatchedUnLockedEnv
from torch import nn
from torchrl.data import TensorDict
from torchrl.data.tensor_specs import OneHotDiscreteTensorSpec
@@ -18,6 +19,7 @@
ValueOperator,
ProbabilisticActor,
LSTMNet,
CEMPlanner,
)
from torchrl.modules.functional_modules import (
FunctionalModule,
@@ -326,6 +328,31 @@ def test_func_transformer(self):
torch.testing.assert_close(fmodule(params, buffers, x, x), module(x, x))


class TestPlanner:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("batch_size", [3, 5])
def test_CEM_model_free_env(self, device, batch_size, seed=1):
env = MockBatchedUnLockedEnv(device=device)
torch.manual_seed(seed)
planner = CEMPlanner(
env,
planning_horizon=10,
optim_steps=2,
num_candidates=100,
num_top_k_candidates=2,
)
td = env.reset(TensorDict({}, batch_size=batch_size))
td_copy = td.clone()
td = planner(td)
assert td.get("action").shape[1:] == env.action_spec.shape

assert env.action_spec.is_in(td.get("action"))

for key in td.keys():
if key != "action":
assert torch.allclose(td[key], td_copy[key])
Contributor:

Shouldn't you also test the values of the action?

Contributor (author):

What should I test them for? I'm already testing that they exist and have the right shape, but I don't have any information that would allow testing the values.



if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/modules/__init__.py
@@ -6,3 +6,4 @@
from .distributions import *
from .models import *
from .tensordict_module import *
from .planners import *
7 changes: 7 additions & 0 deletions torchrl/modules/planners/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .common import *
from .cem import *
127 changes: 127 additions & 0 deletions torchrl/modules/planners/cem.py
@@ -0,0 +1,127 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from torchrl.data.tensordict.tensordict import TensorDictBase
from torchrl.envs import EnvBase
from torchrl.modules.planners import MPCPlannerBase

__all__ = ["CEMPlanner"]


class CEMPlanner(MPCPlannerBase):
"""CEMPlanner Module.

Reference: The cross-entropy method for optimization, Botev et al. 2013

This module will perform a CEM planning step when given a TensorDict containing initial states.
The CEM planning step is performed by sampling actions from a Gaussian distribution with zero mean and unit variance.
The sampled actions are then used to perform a rollout in the environment. The cumulative rewards obtained from the rollouts are then
ranked. We select the top-k episodes and use their actions to update the mean and standard deviation of the action distribution.
The CEM planning step is repeated for a specified number of steps.

A call to the module returns the actions that empirically maximised the returns over the given planning horizon.

Args:
env (EnvBase): The environment to perform the planning step on (can be ``ModelBasedEnv`` or ``EnvBase``).
planning_horizon (int): The length of the simulated trajectories
optim_steps (int): The number of optimization steps used by the MPC planner
num_candidates (int): The number of candidates to sample from the Gaussian distributions.
num_top_k_candidates (int): The number of top candidates to use to update the mean and standard deviation of the Gaussian distribution.
reward_key (str, optional): The key in the TensorDict to use to retrieve the reward.
action_key (str, optional): The key in the TensorDict to use to store the action.
"""

def __init__(
self,
env: EnvBase,
planning_horizon: int,
optim_steps: int,
num_candidates: int,
num_top_k_candidates: int,
reward_key: str = "reward",
action_key: str = "action",
):
super().__init__(env=env, action_key=action_key)
self.planning_horizon = planning_horizon
self.optim_steps = optim_steps
self.num_candidates = num_candidates
self.num_top_k_candidates = num_top_k_candidates
self.reward_key = reward_key

def planning(self, tensordict: TensorDictBase) -> torch.Tensor:
batch_size = tensordict.batch_size
expanded_original_tensordict = (
tensordict.unsqueeze(-1)
.expand(*batch_size, self.num_candidates)
.reshape(-1)
)
flatten_batch_size = batch_size.numel()
actions_means = torch.zeros(
flatten_batch_size,
1,
self.planning_horizon,
*self.action_spec.shape,
device=tensordict.device,
dtype=self.env.action_spec.dtype,
)
actions_stds = torch.ones(
flatten_batch_size,
1,
self.planning_horizon,
*self.action_spec.shape,
device=tensordict.device,
dtype=self.env.action_spec.dtype,
)

for _ in range(self.optim_steps):
actions = actions_means + actions_stds * torch.randn(
flatten_batch_size,
self.num_candidates,
self.planning_horizon,
*self.action_spec.shape,
device=tensordict.device,
dtype=self.env.action_spec.dtype,
)
actions = actions.flatten(0, 1)
actions = self.env.action_spec.project(actions)
optim_tensordict = expanded_original_tensordict.to_tensordict()
policy = PrecomputedActionsSequentialSetter(actions)
optim_tensordict = self.env.rollout(
max_steps=self.planning_horizon,
policy=policy,
auto_reset=False,
tensordict=optim_tensordict,
)
rewards = (
optim_tensordict.get(self.reward_key)
.sum(dim=1)
.reshape(flatten_batch_size, self.num_candidates)
)
_, top_k = rewards.topk(self.num_top_k_candidates, dim=1)

best_actions = actions.unflatten(
0, (flatten_batch_size, self.num_candidates)
)
best_actions = best_actions[
torch.arange(flatten_batch_size).unsqueeze(1), top_k
]
actions_means = best_actions.mean(dim=1, keepdim=True)
actions_stds = best_actions.std(dim=1, keepdim=True)
return actions_means[:, :, 0].reshape(*batch_size, *self.action_spec.shape)


class PrecomputedActionsSequentialSetter:
def __init__(self, actions):
self.actions = actions
self.cmpt = 0

def __call__(self, td):
if self.cmpt >= self.actions.shape[1]:
raise ValueError("Precomputed actions are too short")
td = td.set("action", self.actions[:, self.cmpt])
self.cmpt += 1
return td
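The planning loop in `cem.py` can be illustrated, independent of torchrl, as a plain cross-entropy-method optimizer: sample action sequences from a Gaussian, rank them by cumulative return, and refit the Gaussian to the elite set. A minimal pure-Python sketch follows; the 1-D toy objective and all names (`cem_plan`, `returns_fn`, etc.) are hypothetical, not part of the PR:

```python
import random
import statistics

def cem_plan(returns_fn, horizon=5, optim_steps=10,
             num_candidates=100, top_k=10, seed=0):
    """Toy CEM: fit a per-step Gaussian to the top-k action sequences."""
    rng = random.Random(seed)
    means = [0.0] * horizon  # start from zero mean, unit variance
    stds = [1.0] * horizon
    for _ in range(optim_steps):
        # Sample candidate action sequences from the current Gaussians.
        candidates = [
            [rng.gauss(means[t], stds[t]) for t in range(horizon)]
            for _ in range(num_candidates)
        ]
        # Rank candidates by cumulative return and keep the top-k elites.
        elites = sorted(candidates, key=returns_fn, reverse=True)[:top_k]
        # Refit the sampling distribution to the elite set.
        for t in range(horizon):
            elite_t = [seq[t] for seq in elites]
            means[t] = statistics.mean(elite_t)
            stds[t] = statistics.stdev(elite_t)
    # As in CEMPlanner, return the first action of the optimized sequence.
    return means[0]

# Toy objective: the return is maximal when every action equals 2.0.
best_first_action = cem_plan(lambda seq: -sum((a - 2.0) ** 2 for a in seq))
```

With 100 candidates and 10 elites per iteration, the sampling distribution should concentrate near the optimum after a handful of steps, which is the same mechanism `CEMPlanner.planning` implements batched in torch.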
69 changes: 69 additions & 0 deletions torchrl/modules/planners/common.py
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import abc
from typing import Optional

import torch

from torchrl.data.tensordict.tensordict import TensorDictBase
from torchrl.envs import EnvBase
from torchrl.modules import TensorDictModule

__all__ = ["MPCPlannerBase"]


class MPCPlannerBase(TensorDictModule, metaclass=abc.ABCMeta):
"""MPCPlannerBase Module.

This is an abstract class.

This class inherits from TensorDictModule. Provided a TensorDict, this module will perform a Model Predictive Control (MPC) planning step.
At the end of the planning step, the MPCPlanner will return a proposed action.

Args:
env (EnvBase): The environment to perform the planning step on (Can be ModelBasedEnv or EnvBase).
action_key (str, optional): The key that will point to the computed action.
"""

def __init__(
self,
env: EnvBase,
action_key: str = "action",
):
# Check if env is stateless
if env.batch_locked:
raise ValueError("Environment is batch_locked. MPCPlanners need an environment that accepts batched inputs with any batch size")
out_keys = [action_key]
in_keys = list(env.observation_spec.keys())
super().__init__(env, in_keys=in_keys, out_keys=out_keys)
self.env = env
self.action_spec = env.action_spec
self.to(env.device)

@abc.abstractmethod
def planning(self, td: TensorDictBase) -> torch.Tensor:
"""Perform the MPC planning step.

Args:
td (TensorDict): The TensorDict to perform the planning step on.
"""
raise NotImplementedError()

def forward(
self,
tensordict: TensorDictBase,
tensordict_out: Optional[TensorDictBase] = None,
**kwargs,
) -> TensorDictBase:
if "params" in kwargs or "vmap" in kwargs:
raise ValueError("MPCPlannerBase does not support params or vmap for now.")
action = self.planning(tensordict)
action = self.action_spec.project(action)
tensordict_out = self._write_to_tensordict(
tensordict,
(action,),
tensordict_out,
)
return tensordict_out
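The design in `common.py` is a template-method split: the concrete `forward` handles bookkeeping (spec projection, writing results back), while the abstract `planning` hook does the actual optimization. The same pattern can be sketched in plain Python; the names here (`PlannerBase`, `GreedyPlanner`) are hypothetical and not torchrl's API:

```python
import abc

class PlannerBase(abc.ABC):
    """Template method: __call__ does bookkeeping, planning() does the work."""

    def __call__(self, state):
        action = self.planning(state)
        # Stand-in for action_spec.project(): clamp into the valid range [-1, 1].
        return max(-1.0, min(1.0, action))

    @abc.abstractmethod
    def planning(self, state):
        """Return a raw (unprojected) action for the given state."""

class GreedyPlanner(PlannerBase):
    def planning(self, state):
        return -2.0 * state  # may fall outside the valid action range

planner = GreedyPlanner()
projected = planner(1.0)  # raw action is -2.0, projected to -1.0
```

Subclasses such as `CEMPlanner` only supply the optimization step; the base class guarantees every returned action is in-spec.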
2 changes: 1 addition & 1 deletion torchrl/trainers/helpers/envs.py
@@ -161,7 +161,7 @@ def make_env_transforms(
key
for key in env.observation_spec.keys()
if ("pixels" not in key)
and (key.strip("next_") not in env.input_spec.keys())
and (key.replace("next_", "") not in env.input_spec.keys())
]

# even if there is a single tensor, it'll be renamed in "next_observation_vector"
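The one-line fix above matters because `str.strip` removes a *set of characters* from both ends, not a prefix, so it also eats any leading `n`, `e`, `x`, `t`, or `_` belonging to the key itself, whereas `str.replace` (or `str.removeprefix` on Python 3.9+) drops only the literal substring:

```python
key = "next_tension"

# strip() treats "next_" as a character set {n, e, x, t, _}, so it keeps
# removing matching characters past the prefix, from both ends:
stripped = key.strip("next_")        # "sio" -- wrong
replaced = key.replace("next_", "")  # "tension" -- what the fix intends
```

Note that `replace` removes every occurrence of `"next_"`, not just a leading one; for keys that can only carry the prefix at the start this is equivalent, and `removeprefix` would be the strictest option.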