
add VmapModule and from_lmhead_model method
apbard committed Jun 27, 2023
1 parent d3dbc15 commit c67b4be
Showing 2 changed files with 75 additions and 1 deletion.
43 changes: 43 additions & 0 deletions torchrl/modules/tensordict_module/actors.py
@@ -6,14 +6,17 @@
from typing import List, Optional, Sequence, Tuple, Union

import torch

from tensordict import TensorDictBase
from tensordict.nn import (
dispatch,
TensorDictModule,
TensorDictModuleBase,
TensorDictModuleWrapper,
TensorDictSequential,
)
from torch import nn
from torch.distributions import Categorical

from torchrl.data.tensor_specs import CompositeSpec, TensorSpec
from torchrl.modules.models.models import DistributionalDQNnet
@@ -1300,6 +1303,46 @@ def get_value_head(self) -> SafeSequential:
"""Returns the value head."""
return self.module[2]

    @classmethod
    def from_lmhead_model(cls, base_model):
        """Builds an actor-value operator from a huggingface-like *LMHeadModel.

        This method:

        - takes a huggingface-like *LMHeadModel as input;
        - extracts its final linear layer, uses it as the base layer of the
          actor_head and adds a sampling layer on top;
        - uses the transformer as the common model;
        - adds a linear critic as the value head.

        Args:
            base_model (nn.Module): a torch model composed of a `.transformer`
                model and an `.lm_head` linear layer.
        """
        actor_head = base_model.lm_head
        value_head = nn.Linear(actor_head.in_features, 1, bias=False)
        common = TensorDictSequential(
            TensorDictModule(
                base_model.transformer,
                in_keys={"input_ids": "input_ids", "attention_mask": "attention_mask"},
                out_keys=["x"],
            ),
            TensorDictModule(lambda x: x[:, -1, :], in_keys=["x"], out_keys=["x"]),
        )
        actor_head = TensorDictModule(actor_head, in_keys=["x"], out_keys=["logits"])
        actor_head = SafeProbabilisticTensorDictSequential(
            actor_head,
            SafeProbabilisticModule(
                in_keys=["logits"],
                out_keys=["action"],
                distribution_class=Categorical,
                return_log_prob=True,
            ),
        )
        value_head = TensorDictModule(
            value_head, in_keys=["x"], out_keys=["state_value"]
        )

        return cls(common, actor_head, value_head)
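
A minimal usage sketch for illustration, assuming this classmethod lands on `ActorValueOperator` (whose `get_value_head` appears in the hunk above). It uses a hypothetical toy stand-in with the `.transformer`/`.lm_head` structure the docstring describes, since a real huggingface backbone may return a `ModelOutput` rather than the plain hidden-state tensor the `out_keys=["x"]` wrapper expects:

import torch
from tensordict import TensorDict
from torch import nn
from torchrl.modules.tensordict_module.actors import ActorValueOperator

class TinyBackbone(nn.Module):
    # Stand-in transformer: accepts input_ids/attention_mask kwargs and
    # returns hidden states of shape (batch, seq, hidden).
    def __init__(self, vocab_size, hidden):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)

    def forward(self, input_ids, attention_mask):
        # Zero out padded positions.
        return self.embed(input_ids) * attention_mask.unsqueeze(-1)

class TinyLMHeadModel(nn.Module):
    # Mirrors the huggingface-like layout: `.transformer` plus `.lm_head`.
    def __init__(self, vocab_size=100, hidden=32):
        super().__init__()
        self.transformer = TinyBackbone(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

model = ActorValueOperator.from_lmhead_model(TinyLMHeadModel())
td = TensorDict(
    {
        "input_ids": torch.randint(0, 100, (2, 10)),
        "attention_mask": torch.ones(2, 10, dtype=torch.long),
    },
    batch_size=[2],
)
td = model(td)  # fills "logits", "action", "sample_log_prob" and "state_value"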


class ActorCriticOperator(ActorValueOperator):
"""Actor-critic operator.
33 changes: 32 additions & 1 deletion torchrl/modules/tensordict_module/common.py
@@ -12,7 +12,7 @@

import torch

from tensordict.nn import TensorDictModule
from tensordict.nn import TensorDictModule, TensorDictModuleBase
from tensordict.tensordict import TensorDictBase
from torch import nn

@@ -401,3 +401,34 @@ def ensure_tensordict_compatible(
    if out_keys is not None:
        kwargs["out_keys"] = out_keys
    return wrapper_type(module, **kwargs)


class VmapModule(TensorDictModuleBase):
    """A TensorDictModule wrapper that vmaps over the input.

    It is intended to be used with modules that accept data with one batch
    dimension fewer than the one provided. By using this wrapper, one can hide
    a batch dimension and satisfy the wrapped module.

    Args:
        module (TensorDictModuleBase): the module to vmap over.
        vmap_dim (int, optional): the vmap input and output dim.
            If none is provided, the last dimension of the tensordict is
            assumed.
    """

    def __init__(self, module: TensorDictModuleBase, vmap_dim=None):
        super().__init__()
        self.in_keys = module.in_keys
        self.out_keys = module.out_keys
        self.module = module
        self.vmap_dim = vmap_dim

    def forward(self, tensordict):
        vmap_dim = self.vmap_dim
        if vmap_dim is None:
            ndim = tensordict.ndim
            vmap_dim = ndim - 1
        td = torch.vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict)
        return tensordict.update(td)
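
A short usage sketch, assuming the installed tensordict version supports `torch.vmap` over TensorDict inputs (which this `forward` relies on). The hypothetical `last_feature` function below only handles a single batch dimension, and `VmapModule` hides the extra one:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules.tensordict_module.common import VmapModule

# A function that assumes exactly one batch dimension: (batch, feature).
def last_feature(x):
    assert x.ndim == 2
    return x[:, -1:]

inner = TensorDictModule(last_feature, in_keys=["x"], out_keys=["y"])
vmapped = VmapModule(inner)  # vmap_dim defaults to the tensordict's last batch dim

td = TensorDict({"x": torch.randn(3, 4, 5)}, batch_size=[3, 4])
td = vmapped(td)  # the inner module sees batch_size [3]; "y" comes out as (3, 4, 1)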
