Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tests] Adding tensordict __repr__ tests #435

Merged
merged 18 commits into from
Sep 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 286 additions & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1612,6 +1612,291 @@ def test_repr(self, td_name, device):
_ = str(td)


@pytest.mark.parametrize("device", [None, *get_available_devices()])
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
class TestTensorDictRepr:
def td(self, device, dtype):
if device is not None:
device_not_none = device
elif torch.has_cuda and torch.cuda.device_count():
device_not_none = torch.device("cuda:0")
else:
device_not_none = torch.device("cpu")

return TensorDict(
source={
"a": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none)
},
batch_size=[4, 3, 2, 1],
device=device,
)

def nested_td(self, device, dtype):
if device is not None:
device_not_none = device
elif torch.has_cuda and torch.cuda.device_count():
device_not_none = torch.device("cuda:0")
else:
device_not_none = torch.device("cpu")
return TensorDict(
source={
"my_nested_td": self.td(device, dtype),
"b": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none),
},
batch_size=[4, 3, 2, 1],
device=device,
)

def stacked_td(self, device, dtype):
if device is not None:
device_not_none = device
elif torch.has_cuda and torch.cuda.device_count():
device_not_none = torch.device("cuda:0")
else:
device_not_none = torch.device("cpu")
td1 = TensorDict(
source={
"a": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none),
"c": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none),
},
batch_size=[4, 3, 1],
device=device,
)
td2 = TensorDict(
source={
"a": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none),
"b": torch.zeros(4, 3, 1, 10, dtype=dtype, device=device_not_none),
},
batch_size=[4, 3, 1],
device=device,
)

return stack_td([td1, td2], 2)

def memmap_td(self, device, dtype):
return self.td(device, dtype).memmap_(lock=False)

def share_memory_td(self, device, dtype):
return self.td(device, dtype).share_memory_(lock=False)

def test_repr_plain(self, device, dtype):
tensordict = self.td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
expected = f"""TensorDict(
fields={{
a: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(tensordict) == expected

def test_repr_memmap(self, device, dtype):
tensordict = self.memmap_td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
expected = f"""TensorDict(
fields={{
a: MemmapTensor(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(tensordict) == expected

def test_repr_share_memory(self, device, dtype):
tensordict = self.share_memory_td(device, dtype)
is_shared = True
is_device_cpu = device is not None and device.type == "cpu"
is_none_device_cpu = device is None and torch.cuda.device_count() == 0
tensor_class = (
"SharedTensor" if is_none_device_cpu or is_device_cpu else "Tensor"
)
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(tensordict) == expected

def test_repr_nested(self, device, dtype):
nested_td = self.nested_td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
expected = f"""TensorDict(
fields={{
b: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype}),
my_nested_td: TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(nested_td) == expected

def test_repr_stacked(self, device, dtype):
stacked_td = self.stacked_td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
expected = f"""LazyStackedTensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(stacked_td) == expected

@pytest.mark.parametrize("index", [None, (slice(None), 0)])
def test_repr_indexed_tensordict(self, device, dtype, index):
tensordict = self.td(device, dtype)[index]
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
if index is None:
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([1, 4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([1, 4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
else:
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 2, 1]),
device={str(device)},
is_shared={is_shared})"""

assert repr(tensordict) == expected

@pytest.mark.parametrize("index", [None, (slice(None), 0)])
def test_repr_indexed_nested_tensordict(self, device, dtype, index):
nested_tensordict = self.nested_td(device, dtype)[index]
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
if index is None:
expected = f"""TensorDict(
fields={{
b: {tensor_class}(torch.Size([1, 4, 3, 2, 1, 5]), dtype={dtype}),
my_nested_td: TensorDict(
fields={{
a: {tensor_class}(torch.Size([1, 4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([1, 4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})}},
batch_size=torch.Size([1, 4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
else:
expected = f"""TensorDict(
fields={{
b: {tensor_class}(torch.Size([4, 2, 1, 5]), dtype={dtype}),
my_nested_td: TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 2, 1]),
device={str(device)},
is_shared={is_shared})}},
batch_size=torch.Size([4, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(nested_tensordict) == expected

@pytest.mark.parametrize("index", [None, (slice(None), 0)])
def test_repr_indexed_stacked_tensordict(self, device, dtype, index):
stacked_tensordict = self.stacked_td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
if index is None:
expected = f"""LazyStackedTensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
else:
expected = f"""LazyStackedTensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(stacked_tensordict) == expected

@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
@pytest.mark.parametrize("device_cast", get_available_devices())
def test_repr_device_to_device(self, device, dtype, device_cast):
td = self.td(device, dtype)
if (device_cast is None and (torch.cuda.device_count() > 0)) or (
device_cast is not None and device_cast.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
td2 = td.to(device_cast)
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device_cast)},
is_shared={is_shared})"""
assert repr(td2) == expected

@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
def test_repr_batch_size_update(self, device, dtype):
td = self.td(device, dtype)
td.batch_size = torch.Size([4, 3, 2])
is_shared = False
tensor_class = "Tensor"
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2]),
device={device},
is_shared={is_shared})"""
assert repr(td) == expected


@pytest.mark.parametrize(
"td_name",
[
Expand Down Expand Up @@ -1718,7 +2003,7 @@ def test_batchsize_reset():
td.set("c", torch.randn(3))

# test index
subtd = td[torch.tensor([1, 2])]
td[torch.tensor([1, 2])]
with pytest.raises(
RuntimeError,
match=re.escape(
Expand Down
4 changes: 2 additions & 2 deletions torchrl/data/tensordict/metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def __init__(
name = "TensorDict"
elif _is_memmap:
name = "MemmapTensor"
elif _is_shared:
elif _is_shared and device.type != "cuda":
name = "SharedTensor"
else:
name = "Tensor"
Expand Down Expand Up @@ -152,7 +152,7 @@ def share_memory_(self) -> MetaTensor:
"""

self._is_shared = True
self.class_name = "SharedTensor"
self.class_name = "SharedTensor" if self.device.type != "cuda" else "Tensor"
return self

def is_shared(self) -> bool:
Expand Down
13 changes: 10 additions & 3 deletions torchrl/data/tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,11 +2244,18 @@ def share_memory_(self, lock=True) -> TensorDictBase:
"share_memory_ must be called when the TensorDict is ("
"partially) populated. Set a tensor first."
)
if self.device_safe() is not None and self.device != torch.device("cpu"):
if self.device_safe() is not None and self.device.type == "cuda":
# cuda tensors are shared by default
self._is_shared = True
return self
for value in self.values():
value.share_memory_()
# no need to consider MemmapTensors here as we have checked that this is not a memmap-tensordict
if (
isinstance(value, torch.Tensor)
and value.device.type == "cpu"
or isinstance(value, TensorDictBase)
):
value.share_memory_()
for value in self.values_meta():
value.share_memory_()
self._is_shared = True
Expand All @@ -2261,7 +2268,7 @@ def detach_(self) -> TensorDictBase:
return self

def memmap_(self, prefix=None, lock=True) -> TensorDictBase:
if self.is_shared() and self.device == torch.device("cpu"):
if self.is_shared() and self.device_safe() == torch.device("cpu"):
raise RuntimeError(
"memmap and shared memory are mutually exclusive features."
)
Expand Down