[Refactor] Deprecate direct usage of memmap tensors #1699

Merged · 12 commits · Nov 15, 2023
Changes from 1 commit
amend
vmoens committed Nov 8, 2023
commit 3832c373803a9ec906f203f1ba24f19ada416a75
24 changes: 21 additions & 3 deletions torchrl/data/replay_buffers/storages.py
@@ -603,7 +603,7 @@ def load_state_dict(self, state_dict):
     if isinstance(self._storage, torch.Tensor):
         _mem_map_tensor_as_tensor(self._storage).copy_(_storage)
     elif self._storage is None:
-        self._storage = MemmapTensor(_storage)
+        self._storage = make_memmap(_storage, path=self.scratch_dir + "/tensor.memmap" if self.scratch_dir is not None else None)
     else:
         raise RuntimeError(
             f"Cannot copy a storage of type {type(_storage)} onto another of type {type(self._storage)}"
@@ -657,8 +657,8 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None:
     else:
         # If not a tensorclass/tensordict, it must be a tensor(-like)
         # if Tensor, we just create a MemmapTensor of the desired shape, device and dtype
-        out = MemmapTensor(
-            self.max_size, *data.shape, device=self.device, dtype=data.dtype
+        out = make_empty_memmap(
+            (self.max_size, *data.shape), dtype=data.dtype, path=self.scratch_dir + "/tensor.memmap" if self.scratch_dir is not None else None
         )
         if VERBOSE:
             filesize = os.path.getsize(out.filename) / 1024 / 1024
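To make the shape arithmetic in `_init` concrete, a toy sketch (sizes are illustrative, not from the PR): the first sample only fixes the trailing dimensions, and `max_size` becomes the leading capacity dimension of the preallocated memmap:

```python
import torch

max_size = 1000
data = torch.zeros(4, 84, 84, dtype=torch.uint8)   # first sample written to the storage

# Same call as in the diff: capacity first, then the sample's shape and dtype.
out = make_empty_memmap((max_size, *data.shape), dtype=data.dtype, path=None)
assert tuple(out.shape) == (1000, 4, 84, 84)
```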
@@ -759,3 +759,21 @@ def _get_default_collate(storage, _is_tensordict=False):
     raise NotImplementedError(
         f"Could not find a default collate_fn for storage {type(storage)}."
     )
+
+@implement_for("torch", None, "2.2.0")
+def make_memmap(tensor, path):
+    return MemmapTensor.from_tensor(tensor, filename=path)
+
+@implement_for("torch", "2.2.0")
+def make_memmap(tensor, path):
+    from tensordict._memory_map import from_tensor
+    return from_tensor(tensor, filename=path)
+
+@implement_for("torch", None, "2.2.0")
+def make_empty_memmap(shape, dtype, path):
+    return MemmapTensor(shape, dtype=dtype, filename=path)
+
+@implement_for("torch", "2.2.0")
+def make_empty_memmap(shape, dtype, path):
+    from tensordict._memory_map import empty_like
+    return empty_like(torch.zeros((), dtype=dtype).expand(shape), filename=path)
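For reviewers, a rough usage sketch of the two new helpers (values are made up; which backend runs is selected by `implement_for` based on the installed torch version, and the helpers are assumed to be in scope as module-level functions of storages.py):

```python
import os
import tempfile
import torch

# Preallocate an on-disk buffer for 1000 transitions of 84x84 uint8 frames;
# with path=None the backing file location is chosen by the backend.
frames = make_empty_memmap((1000, 84, 84), dtype=torch.uint8, path=None)

# Copy an existing tensor into a memory-mapped buffer at an explicit location.
scratch = tempfile.mkdtemp()
rewards = make_memmap(torch.zeros(1000), path=os.path.join(scratch, "tensor.memmap"))

# Both are expected to behave like tensors for indexing and assignment.
frames[0] = torch.randint(0, 255, (84, 84), dtype=torch.uint8)
```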