[Tests] Adding tensordict __repr__ tests #435

Merged · 18 commits · Sep 21, 2022
Changes from 1 commit
Adding tests for indexed tensordicts (nested)
Souranil Sen committed Sep 13, 2022
commit 8342872f1b01825d0f995ee7aaf31cde53e1bfb5
116 changes: 96 additions & 20 deletions test/test_tensordict.py
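
For orientation, the sketch below shows the pattern these new tests exercise: build a TensorDict, index it, and check how the result is reported by repr(). It is an illustration only — the import path and the helper name check_masked_repr are assumptions rather than code from this PR, and the tests themselves compare repr() against full expected strings instead of individual attributes.

import torch

from torchrl.data import TensorDict  # assumed import path for the package under test


def check_masked_repr():
    # Hypothetical helper mirroring test_indexed_mask below: boolean-mask
    # indexing over the batch dims collapses the selected entries into a
    # single leading dimension, and repr() reflects that in the field
    # shapes as well as in the batch_size line.
    td = TensorDict({}, [2, 3], device="cpu")
    td.set("a", torch.randn(2, 3, 1))

    mask = torch.tensor([[1, 0, 1], [1, 0, 1]], dtype=torch.bool)
    masked = td[mask]  # four True entries -> batch_size torch.Size([4])

    assert masked.batch_size == torch.Size([4])
    assert masked.get("a").shape == torch.Size([4, 1])
    print(repr(masked))  # the tests below compare this string verbatim
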
@@ -6,7 +6,6 @@
import argparse
import os.path
import re
from textwrap import indent

import pytest
import torch
@@ -1572,21 +1571,16 @@ def test_repr(self, td_name, device):
class TestTensorDictRepr:
def td(self, device):
return TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 5)
},
source={"a": torch.randn(4, 3, 2, 1, 5)},
batch_size=[4, 3, 2, 1],
device=device,
)

def nested_td(self, device):
return TensorDict(
source={
"a": torch.randn(4, 3, 2, 1, 5),
"my_nested_td": self.td(device)
},
source={"my_nested_td": self.td(device), "b": torch.randn(4, 3, 2, 1, 5)},
batch_size=[4, 3, 2, 1],
device=device
device=device,
)

def stacked_td(self, device):
@@ -1607,20 +1601,20 @@ def stacked_td(self, device):
return stack_td([td1, td2], 2)

def test_plain(self, device):
td = self.td(device)
tensordict = self.td(device)
expected = """TensorDict(
fields={
a: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype=torch.float32)},
batch_size=torch.Size([4, 3, 2, 1]),
device=cpu,
is_shared=False)"""
assert (repr(td) == expected)
assert repr(tensordict) == expected

def test_nested(self, device):
nested_td = self.nested_td(device)
expected = '''TensorDict(
expected = """TensorDict(
fields={
a: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype=torch.float32),
b: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype=torch.float32),
my_nested_td: TensorDict(
fields={
a: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype=torch.float32)},
@@ -1629,27 +1623,109 @@ def test_nested(self, device):
is_shared=False)},
batch_size=torch.Size([4, 3, 2, 1]),
device=cpu,
is_shared=False)'''
assert(repr(nested_td) == expected)
is_shared=False)"""
assert repr(nested_td) == expected

def test_stacked(self, device):
stacked_td = self.stacked_td(device)
expected = '''LazyStackedTensorDict(
expected = """LazyStackedTensorDict(
fields={
},
batch_size=torch.Size([4, 3, 2, 1]),
device=cpu,
is_shared=False)'''
assert(repr(stacked_td) == expected)
is_shared=False)"""
assert repr(stacked_td) == expected

def test_indexed_tensor(self, device):
tensordict = TensorDict({}, [5], device=device)
tensordict.set("a", torch.randn(5, 4, 3))
expected = """TensorDict(
fields={
a: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32)},
batch_size=torch.Size([5]),
device=cpu,
is_shared=False)"""
assert repr(tensordict) == expected

def test_indexed_nested(self, device):
tensordict = TensorDict({}, [4, 3, 2, 1], device=device)
tensordict.set("nested_td", self.nested_td(device))

expected = """TensorDict(
fields={
nested_td: TensorDict(
fields={
b: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype=torch.float32),
my_nested_td: TensorDict(
fields={
a: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype=torch.float32)},
batch_size=torch.Size([4, 3, 2, 1]),
device=cpu,
is_shared=False)},
batch_size=torch.Size([4, 3, 2, 1]),
device=cpu,
is_shared=False)},
batch_size=torch.Size([4, 3, 2, 1]),
device=cpu,
is_shared=False)"""
assert repr(tensordict) == expected

def test_indexed_integer(self, device):
tensordict = TensorDict({}, [5], device=device)
tensordict.set("a", torch.randn(5, 4, 3))
tensordict.set("k_int", torch.randint(10, (5, 4, 3)))

expected = """TensorDict(
fields={
a: Tensor(torch.Size([5, 4, 3]), dtype=torch.float32),
k_int: Tensor(torch.Size([5, 4, 3]), dtype=torch.int64)},
batch_size=torch.Size([5]),
device=cpu,
is_shared=False)"""
assert repr(tensordict) == expected

def test_indexed_mask(self, device):
tensordict = TensorDict({}, [2, 3], device=device)
tensordict.set("a", torch.randn(2, 3, 1))
mask = torch.tensor([[1, 0, 1], [1, 0, 1]], dtype=torch.bool, device=device)
masked_td = tensordict[mask]
expected = """TensorDict(
fields={
a: Tensor(torch.Size([4, 1]), dtype=torch.float32)},
batch_size=torch.Size([4]),
device=cpu,
is_shared=False)"""
assert repr(masked_td) == expected

def test_indexed_stack(self, device):
tensordict = TensorDict({}, [5], device=device)
td3 = TensorDict({"d": torch.randn(5, 4, 3)}, [5], device=device)
stacked_td = stack_td([tensordict, td3], 1)

expected = """LazyStackedTensorDict(
fields={
},
batch_size=torch.Size([5, 2]),
device=cpu,
is_shared=False)"""
assert repr(stacked_td) == expected

@pytest.mark.skipif(not torch.cuda.device_count(), reason="no cuda")
def test_device_to_device(self, device):
dev2 = torch.device(0)
td = self.td(device)
td2 = td.to(dev2)
assert(repr(td) == repr(td2))

assert repr(td) == repr(td2)

def test_batch_size_update(self, device):
td = self.td(device)
td.batch_size = torch.Size([4, 3, 2])
expected = """TensorDict(
fields={
a: Tensor(torch.Size([4, 3, 2, 1, 5]), dtype=torch.float32)},
batch_size=torch.Size([4, 3, 2]),
device=cpu,
is_shared=False)"""
assert repr(td) == expected


@pytest.mark.parametrize(