Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Deprecate direct usage of memmap tensors #1699

Merged
merged 12 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed Nov 14, 2023
commit 24a4ed3ae27ea50f2c2ebd38529d90f34a3a26f5
11 changes: 8 additions & 3 deletions test/test_rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@
import torch.nn.functional as F

from _utils_internal import get_default_devices
from tensordict import is_tensor_collection, MemmapTensor, TensorDict, TensorDictBase
from tensordict import (
is_tensor_collection,
MemoryMappedTensor,
TensorDict,
TensorDictBase,
)
from tensordict.nn import TensorDictModule
from torchrl.data.rlhf import TensorDictTokenizer
from torchrl.data.rlhf.dataset import (
Expand Down Expand Up @@ -188,8 +193,8 @@ def test_dataset_to_tensordict(tmpdir, suffix):
else:
assert ("c", "d", "a") in td.keys(True)
assert ("c", "d", "b") in td.keys(True)
assert isinstance(td.get((suffix, "a")), MemmapTensor)
assert isinstance(td.get((suffix, "b")), MemmapTensor)
assert isinstance(td.get((suffix, "a")), MemoryMappedTensor)
assert isinstance(td.get((suffix, "b")), MemoryMappedTensor)


@pytest.mark.skipif(
Expand Down
39 changes: 26 additions & 13 deletions torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torch
from tensordict import is_tensorclass
from tensordict.memmap import MemoryMappedTensor
from tensordict.memmap import MemmapTensor, MemoryMappedTensor
from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
from tensordict.utils import expand_right

Expand Down Expand Up @@ -482,7 +482,7 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
if self.device == "auto":
self.device = data.device
if isinstance(data, torch.Tensor):
# if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
out = torch.empty(
self.max_size,
*data.shape,
Expand Down Expand Up @@ -531,12 +531,12 @@ class LazyMemmapStorage(LazyTensorStorage):
>>> storage.get(0)
TensorDict(
fields={
some data: MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
some data: MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
some: TensorDict(
fields={
nested: TensorDict(
fields={
data: MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
data: MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([11]),
device=cpu,
is_shared=False)},
Expand All @@ -560,8 +560,8 @@ class LazyMemmapStorage(LazyTensorStorage):
>>> storage.set(range(10), data)
>>> storage.get(0)
MyClass(
bar=MemmapTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
foo=MemmapTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
bar=MemoryMappedTensor(shape=torch.Size([11, 12]), device=cpu, dtype=torch.float32, is_shared=False),
foo=MemoryMappedTensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
batch_size=torch.Size([11]),
device=cpu,
is_shared=False)
Expand Down Expand Up @@ -603,7 +603,12 @@ def load_state_dict(self, state_dict):
if isinstance(self._storage, torch.Tensor):
_mem_map_tensor_as_tensor(self._storage).copy_(_storage)
elif self._storage is None:
self._storage = make_memmap(_storage, path=self.scratch_dir + "/tensor.memmap" if self.scratch_dir is not None else None)
self._storage = _make_memmap(
_storage,
path=self.scratch_dir + "/tensor.memmap"
if self.scratch_dir is not None
else None,
)
else:
raise RuntimeError(
f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
Expand Down Expand Up @@ -656,9 +661,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
)
else:
# If not a tensorclass/tensordict, it must be a tensor(-like)
# if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
out = make_empty_memmap(
(self.max_size, *data.shape), dtype=data.dtype, path=self.scratch_dir + "/tensor.memmap" if self.scratch_dir is not None else None
# if Tensor, we just create a MemoryMappedTensor of the desired shape, device and dtype
out = _make_empty_memmap(
(self.max_size, *data.shape),
dtype=data.dtype,
path=self.scratch_dir + "/tensor.memmap"
if self.scratch_dir is not None
else None,
)
if VERBOSE:
filesize = os.path.getsize(out.filename) / 1024 / 1024
Expand All @@ -677,6 +686,7 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor:
f"Supported backends are {_CKPT_BACKEND.backends}"
)
if isinstance(mem_map_tensor, torch.Tensor):
# This will account for MemoryMappedTensors
return mem_map_tensor
if _CKPT_BACKEND == "torchsnapshot":
# TorchSnapshot doesn't know how to stream MemmapTensor, so we view MemmapTensor
Expand Down Expand Up @@ -760,7 +770,10 @@ def _get_default_collate(storage, _is_tensordict=False):
f"Could not find a default collate_fn for storage {type(storage)}."
)

def make_memmap(tensor, path):
return MemoryMappedTensor.from_tensor(tensor, filename=path)
def make_empty_memmap(shape, dtype, path):

def _make_memmap(tensor, path):
return MemoryMappedTensor.from_tensor(tensor, filename=path)


def _make_empty_memmap(shape, dtype, path):
return MemoryMappedTensor.empty(shape=shape, dtype=dtype, filename=path)
8 changes: 4 additions & 4 deletions torchrl/data/rlhf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class TokenizedDatasetLoader:
>>> print(dataset)
TensorDict(
fields={
attention_mask: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids: MemmapTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
attention_mask: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids: MemoryMappedTensor(shape=torch.Size([185068, 550]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([185068]),
device=None,
is_shared=False)
Expand Down Expand Up @@ -270,8 +270,8 @@ def dataset_to_tensordict(
fields={
prefix: TensorDict(
fields={
labels: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
tokens: MemmapTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
labels: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.float32, is_shared=False),
tokens: MemoryMappedTensor(shape=torch.Size([10, 11]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([10]),
device=None,
is_shared=False)},
Expand Down
8 changes: 4 additions & 4 deletions torchrl/data/rlhf/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def from_dataset(
>>> data = PromptData.from_dataset("train")
>>> print(data)
PromptDataTLDR(
attention_mask=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
prompt_rindex=MemmapTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False),
labels=MemmapTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
prompt_rindex=MemoryMappedTensor(shape=torch.Size([116722]), device=cpu, dtype=torch.int64, is_shared=False),
labels=MemoryMappedTensor(shape=torch.Size([116722, 550]), device=cpu, dtype=torch.int64, is_shared=False),
logits=None,
loss=None,
batch_size=torch.Size([116722]),
Expand Down
16 changes: 8 additions & 8 deletions torchrl/data/rlhf/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,16 @@ class PairwiseDataset:
>>> print(data)
PairwiseDataset(
chosen_data=RewardData(
attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
device=None,
is_shared=False),
rejected_data=RewardData(
attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
Expand Down Expand Up @@ -97,16 +97,16 @@ def from_dataset(
>>> print(data)
PairwiseDataset(
chosen_data=RewardData(
attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
device=None,
is_shared=False),
rejected_data=RewardData(
attention_mask=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemmapTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
attention_mask=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
input_ids=MemoryMappedTensor(shape=torch.Size([92534, 550]), device=cpu, dtype=torch.int64, is_shared=False),
rewards=None,
end_scores=None,
batch_size=torch.Size([92534]),
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/value/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import torch

from tensordict import MemmapTensor
from tensordict import MemmapTensor, MemoryMappedTensor

__all__ = [
"generalized_advantage_estimate",
Expand Down
Loading