[Tests] Adding tensordict __repr__ tests (#435)
sladebot authored Sep 21, 2022
1 parent 8ede243 commit 443e5d9
Showing 3 changed files with 298 additions and 6 deletions.
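The new tests pin down the exact string produced by TensorDict.__repr__ across devices, dtypes, nesting, stacking, indexing, and storage modes. A minimal sketch of the behavior under test, run on a CPU-only machine (assuming the torchrl import path and TensorDict signature as of this commit; shapes mirror the tests and the printed output is illustrative):

import torch
from torchrl.data import TensorDict

td = TensorDict(
    source={"a": torch.zeros(4, 3, 2, 1, 5)},
    batch_size=[4, 3, 2, 1],
)
print(td)
# TensorDict(
#     fields={
#         a: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype=torch.float32)},
#     batch_size=torch.Size([4, 3, 2, 1]),
#     device=None,
#     is_shared=False)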
287 changes: 286 additions & 1 deletion test/test_tensordict.py
@@ -1612,6 +1612,291 @@ def test_repr(self, td_name, device):
_ = str(td)


@pytest.mark.parametrize("device", [None, *get_available_devices()])
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
class TestTensorDictRepr:
def td(self, device, dtype):
if device is not None:
device_not_none = device
elif torch.has_cuda and torch.cuda.device_count():
device_not_none = torch.device("cuda:0")
else:
device_not_none = torch.device("cpu")

return TensorDict(
source={
"a": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none)
},
batch_size=[4, 3, 2, 1],
device=device,
)

def nested_td(self, device, dtype):
if device is not None:
device_not_none = device
elif torch.has_cuda and torch.cuda.device_count():
device_not_none = torch.device("cuda:0")
else:
device_not_none = torch.device("cpu")
return TensorDict(
source={
"my_nested_td": self.td(device, dtype),
"b": torch.zeros(4, 3, 2, 1, 5, dtype=dtype, device=device_not_none),
},
batch_size=[4, 3, 2, 1],
device=device,
)

def stacked_td(self, device, dtype):
if device is not None:
device_not_none = device
elif torch.has_cuda and torch.cuda.device_count():
device_not_none = torch.device("cuda:0")
else:
device_not_none = torch.device("cpu")
td1 = TensorDict(
source={
"a": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none),
"c": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none),
},
batch_size=[4, 3, 1],
device=device,
)
td2 = TensorDict(
source={
"a": torch.zeros(4, 3, 1, 5, dtype=dtype, device=device_not_none),
"b": torch.zeros(4, 3, 1, 10, dtype=dtype, device=device_not_none),
},
batch_size=[4, 3, 1],
device=device,
)

return stack_td([td1, td2], 2)

def memmap_td(self, device, dtype):
return self.td(device, dtype).memmap_(lock=False)

def share_memory_td(self, device, dtype):
return self.td(device, dtype).share_memory_(lock=False)

def test_repr_plain(self, device, dtype):
tensordict = self.td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
expected = f"""TensorDict(
fields={{
a: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(tensordict) == expected

def test_repr_memmap(self, device, dtype):
tensordict = self.memmap_td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
expected = f"""TensorDict(
fields={{
a: MemmapTensor(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(tensordict) == expected

def test_repr_share_memory(self, device, dtype):
tensordict = self.share_memory_td(device, dtype)
is_shared = True
is_device_cpu = device is not None and device.type == "cpu"
is_none_device_cpu = device is None and torch.cuda.device_count() == 0
tensor_class = (
"SharedTensor" if is_none_device_cpu or is_device_cpu else "Tensor"
)
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(tensordict) == expected

def test_repr_nested(self, device, dtype):
nested_td = self.nested_td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
expected = f"""TensorDict(
fields={{
b: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype}),
my_nested_td: TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(nested_td) == expected

def test_repr_stacked(self, device, dtype):
stacked_td = self.stacked_td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
expected = f"""LazyStackedTensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(stacked_td) == expected

@pytest.mark.parametrize("index", [None, (slice(None), 0)])
def test_repr_indexed_tensordict(self, device, dtype, index):
tensordict = self.td(device, dtype)[index]
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
if index is None:
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([1, 4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([1, 4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
else:
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 2, 1]),
device={str(device)},
is_shared={is_shared})"""

assert repr(tensordict) == expected

@pytest.mark.parametrize("index", [None, (slice(None), 0)])
def test_repr_indexed_nested_tensordict(self, device, dtype, index):
nested_tensordict = self.nested_td(device, dtype)[index]
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
if index is None:
expected = f"""TensorDict(
fields={{
b: {tensor_class}(torch.Size([1, 4, 3, 2, 1, 5]), dtype={dtype}),
my_nested_td: TensorDict(
fields={{
a: {tensor_class}(torch.Size([1, 4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([1, 4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})}},
batch_size=torch.Size([1, 4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
else:
expected = f"""TensorDict(
fields={{
b: {tensor_class}(torch.Size([4, 2, 1, 5]), dtype={dtype}),
my_nested_td: TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 2, 1]),
device={str(device)},
is_shared={is_shared})}},
batch_size=torch.Size([4, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(nested_tensordict) == expected

@pytest.mark.parametrize("index", [None, (slice(None), 0)])
def test_repr_indexed_stacked_tensordict(self, device, dtype, index):
stacked_tensordict = self.stacked_td(device, dtype)
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
if index is None:
expected = f"""LazyStackedTensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
else:
expected = f"""LazyStackedTensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device)},
is_shared={is_shared})"""
assert repr(stacked_tensordict) == expected

@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
@pytest.mark.parametrize("device_cast", get_available_devices())
def test_repr_device_to_device(self, device, dtype, device_cast):
td = self.td(device, dtype)
if (device_cast is None and (torch.cuda.device_count() > 0)) or (
device_cast is not None and device_cast.type == "cuda"
):
is_shared = True
else:
is_shared = False
tensor_class = "Tensor"
td2 = td.to(device_cast)
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2, 1]),
device={str(device_cast)},
is_shared={is_shared})"""
assert repr(td2) == expected

@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
def test_repr_batch_size_update(self, device, dtype):
td = self.td(device, dtype)
td.batch_size = torch.Size([4, 3, 2])
is_shared = False
tensor_class = "Tensor"
if (device is None and (torch.cuda.device_count() > 0)) or (
device is not None and device.type == "cuda"
):
is_shared = True
expected = f"""TensorDict(
fields={{
a: {tensor_class}(torch.Size([4, 3, 2, 1, 5]), dtype={dtype})}},
batch_size=torch.Size([4, 3, 2]),
device={device},
is_shared={is_shared})"""
assert repr(td) == expected


@pytest.mark.parametrize(
"td_name",
[
@@ -1718,7 +2003,7 @@ def test_batchsize_reset():
td.set("c", torch.randn(3))

# test index
- subtd = td[torch.tensor([1, 2])]
+ td[torch.tensor([1, 2])]
with pytest.raises(
RuntimeError,
match=re.escape(
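A recurring pattern in TestTensorDictRepr above: the expected is_shared flag is True whenever the tensordict lands on a CUDA device, either explicitly or implicitly when device is None and a GPU is present, because CUDA memory is shareable across processes by default. A sketch of that decision rule as a hypothetical standalone helper (not part of the diff; it simply mirrors the branch repeated in each test):

from typing import Optional

import torch

def expected_is_shared(device: Optional[torch.device]) -> bool:
    # CUDA tensordicts report is_shared=True even without an explicit
    # share_memory_() call; CPU tensordicts do not.
    if device is None:
        return torch.cuda.device_count() > 0
    return device.type == "cuda"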
4 changes: 2 additions & 2 deletions torchrl/data/tensordict/metatensor.py
@@ -120,7 +120,7 @@ def __init__(
name = "TensorDict"
elif _is_memmap:
name = "MemmapTensor"
- elif _is_shared:
+ elif _is_shared and device.type != "cuda":
name = "SharedTensor"
else:
name = "Tensor"
@@ -152,7 +152,7 @@ def share_memory_(self) -> MetaTensor:
"""

self._is_shared = True
self.class_name = "SharedTensor"
self.class_name = "SharedTensor" if self.device.type != "cuda" else "Tensor"
return self

def is_shared(self) -> bool:
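The metatensor.py change keeps shared CUDA tensors labeled as plain Tensor in the repr: sharing is the default for CUDA memory, so the SharedTensor label is only informative for CPU storage. A minimal sketch of the resulting naming rule as a hypothetical standalone function (it omits the TensorDict and class-name branches of the real __init__):

import torch

def display_class_name(is_memmap: bool, is_shared: bool, device: torch.device) -> str:
    # Memory-mapped tensors take precedence; the "SharedTensor" label is
    # reserved for explicitly shared CPU storage. A shared CUDA tensor
    # prints as a plain "Tensor", matching the edited branch above.
    if is_memmap:
        return "MemmapTensor"
    if is_shared and device.type != "cuda":
        return "SharedTensor"
    return "Tensor"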
13 changes: 10 additions & 3 deletions torchrl/data/tensordict/tensordict.py
@@ -2244,11 +2244,18 @@ def share_memory_(self, lock=True) -> TensorDictBase:
"share_memory_ must be called when the TensorDict is ("
"partially) populated. Set a tensor first."
)
- if self.device_safe() is not None and self.device != torch.device("cpu"):
+ if self.device_safe() is not None and self.device.type == "cuda":
# cuda tensors are shared by default
self._is_shared = True
return self
for value in self.values():
value.share_memory_()
# no need to consider MemmapTensors here as we have checked that this is not a memmap-tensordict
if (
isinstance(value, torch.Tensor)
and value.device.type == "cpu"
or isinstance(value, TensorDictBase)
):
value.share_memory_()
for value in self.values_meta():
value.share_memory_()
self._is_shared = True
@@ -2261,7 +2268,7 @@ def detach_(self) -> TensorDictBase:
return self

def memmap_(self, prefix=None, lock=True) -> TensorDictBase:
- if self.is_shared() and self.device == torch.device("cpu"):
+ if self.is_shared() and self.device_safe() == torch.device("cpu"):
raise RuntimeError(
"memmap and shared memory are mutually exclusive features."
)
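Taken together, the tensordict.py changes (a) short-circuit share_memory_ for CUDA tensordicts, (b) push only CPU tensors and nested tensordicts into shared memory, and (c) compare against device_safe() in memmap_ so a tensordict with no assigned device can be checked without raising. A short usage sketch of the CPU path (hypothetical example, assuming the torchrl import path and TensorDict signature as of this commit):

import torch
from torchrl.data import TensorDict

td = TensorDict(
    source={"a": torch.zeros(3, 2)},
    batch_size=[3],
    device=torch.device("cpu"),
)
td.share_memory_()  # moves the CPU tensors into shared memory in place
assert td.is_shared()

# Shared memory and memory mapping remain mutually exclusive on CPU;
# per the guard above, uncommenting this would raise a RuntimeError:
# td.memmap_()  # "memmap and shared memory are mutually exclusive features."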
