Skip to content

Commit

Permalink
[Features] Conv3dNet and PermuteTransform (pytorch#1398)
Browse files Browse the repository at this point in the history
Co-authored-by: vmoens <[email protected]>
  • Loading branch information
xmaples and vmoens committed Oct 10, 2023
1 parent 297a047 commit 3430168
Show file tree
Hide file tree
Showing 11 changed files with 630 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ to be able to create this other composition:
NoopResetEnv
ObservationNorm
ObservationTransform
PermuteTransform
PinMemoryTransform
R3MTransform
RandomCropTensorDict
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ Regular modules

MLP
ConvNet
Conv3dNet
LSTMNet
SqueezeLayer
Squeeze2dLayer
Expand Down
2 changes: 1 addition & 1 deletion test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def __new__(
cls._out_key = "pixels_orig"
state_spec = CompositeSpec(
{
cls._out_key: observation_spec["pixels_orig"],
cls._out_key: observation_spec["pixels_orig"].clone(),
},
shape=batch_size,
)
Expand Down
109 changes: 108 additions & 1 deletion test/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
VDNMixer,
)
from torchrl.modules.distributions.utils import safeatanh, safetanh
from torchrl.modules.models import ConvNet, MLP, NoisyLazyLinear, NoisyLinear
from torchrl.modules.models import Conv3dNet, ConvNet, MLP, NoisyLazyLinear, NoisyLinear
from torchrl.modules.models.decision_transformer import (
_has_transformers,
DecisionTransformer,
Expand Down Expand Up @@ -181,6 +181,113 @@ def test_convnet(
assert y.shape == torch.Size([*batch, expected_features])


class TestConv3d:
@pytest.mark.parametrize("in_features", [3, 10, None])
@pytest.mark.parametrize(
"input_size, depth, num_cells, kernel_sizes, strides, paddings, expected_features",
[
(10, None, None, 3, 1, 0, 32 * 4 * 4 * 4),
(10, 3, 32, 3, 1, 1, 32 * 10 * 10 * 10),
],
)
@pytest.mark.parametrize(
"activation_class, activation_kwargs",
[(nn.ReLU, {"inplace": True}), (nn.ReLU, {}), (nn.PReLU, {})],
)
@pytest.mark.parametrize(
"norm_class, norm_kwargs",
[
(None, None),
(nn.LazyBatchNorm3d, {}),
(nn.BatchNorm3d, {"num_features": 32}),
],
)
@pytest.mark.parametrize("bias_last_layer", [True, False])
@pytest.mark.parametrize(
"aggregator_class, aggregator_kwargs",
[(SquashDims, None)],
)
@pytest.mark.parametrize("squeeze_output", [False])
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("batch", [(2,), (2, 2)])
def test_conv3dnet(
self,
batch,
in_features,
depth,
num_cells,
kernel_sizes,
strides,
paddings,
activation_class,
activation_kwargs,
norm_class,
norm_kwargs,
bias_last_layer,
aggregator_class,
aggregator_kwargs,
squeeze_output,
device,
input_size,
expected_features,
seed=0,
):
torch.manual_seed(seed)
conv3dnet = Conv3dNet(
in_features=in_features,
depth=depth,
num_cells=num_cells,
kernel_sizes=kernel_sizes,
strides=strides,
paddings=paddings,
activation_class=activation_class,
activation_kwargs=activation_kwargs,
norm_class=norm_class,
norm_kwargs=norm_kwargs,
bias_last_layer=bias_last_layer,
aggregator_class=aggregator_class,
aggregator_kwargs=aggregator_kwargs,
squeeze_output=squeeze_output,
device=device,
)
if in_features is None:
in_features = 5
x = torch.randn(
*batch, in_features, input_size, input_size, input_size, device=device
)
y = conv3dnet(x)
assert y.shape == torch.Size([*batch, expected_features])
with pytest.raises(ValueError, match="must have at least 4 dimensions"):
conv3dnet(torch.randn(3, 16, 16))

def test_errors(self):
with pytest.raises(
ValueError, match="Null depth is not permitted with Conv3dNet"
):
conv3dnet = Conv3dNet(
in_features=5,
num_cells=32,
depth=0,
)
with pytest.raises(
ValueError, match="depth=None requires one of the input args"
):
conv3dnet = Conv3dNet(
in_features=5,
num_cells=32,
depth=None,
)
with pytest.raises(
ValueError, match="consider matching or specifying a constant num_cells"
):
conv3dnet = Conv3dNet(
in_features=5,
num_cells=[32],
depth=None,
kernel_sizes=[3, 3],
)


@pytest.mark.parametrize(
"layer_class",
[
Expand Down
137 changes: 137 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ContinuousActionVecMockEnv,
CountingBatchedEnv,
CountingEnvCountPolicy,
DiscreteActionConvMockEnv,
DiscreteActionConvMockEnvNumpy,
IncrementingEnv,
MockBatchedLockedEnv,
Expand Down Expand Up @@ -70,6 +71,7 @@
NoopResetEnv,
ObservationNorm,
ParallelEnv,
PermuteTransform,
PinMemoryTransform,
R3MTransform,
RandomCropTensorDict,
Expand Down Expand Up @@ -8643,6 +8645,141 @@ def test_transform_inverse(self):
)


class TestPermuteTransform(TransformBase):
envclass = DiscreteActionConvMockEnv

@classmethod
def _get_permute(cls):
return PermuteTransform(
(-1, -2, -3), in_keys=["pixels_orig", "pixels"], in_keys_inv=["pixels_orig"]
)

def test_single_trans_env_check(self):
base_env = TestPermuteTransform.envclass()
env = TransformedEnv(base_env, TestPermuteTransform._get_permute())
check_env_specs(env)
assert env.observation_spec["pixels"] == env.observation_spec["pixels_orig"]
assert env.state_spec["pixels_orig"] == env.observation_spec["pixels_orig"]

def test_serial_trans_env_check(self):
env = SerialEnv(
2,
lambda: TransformedEnv(
TestPermuteTransform.envclass(), TestPermuteTransform._get_permute()
),
)
check_env_specs(env)

def test_parallel_trans_env_check(self):
env = ParallelEnv(
2,
lambda: TransformedEnv(
TestPermuteTransform.envclass(), TestPermuteTransform._get_permute()
),
)
check_env_specs(env)

def test_trans_serial_env_check(self):
env = TransformedEnv(
SerialEnv(2, TestPermuteTransform.envclass),
TestPermuteTransform._get_permute(),
)
check_env_specs(env)

def test_trans_parallel_env_check(self):
env = TransformedEnv(
ParallelEnv(2, TestPermuteTransform.envclass),
TestPermuteTransform._get_permute(),
)
check_env_specs(env)

@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
def test_transform_compose(self, batch):
D, W, H, C = 8, 32, 64, 3
trans = Compose(
PermuteTransform(
dims=(-1, -4, -2, -3),
in_keys=["pixels"],
)
) # DxWxHxC => CxDxHxW
td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch)
td = trans(td)
assert td["pixels"].shape == torch.Size((*batch, C, D, H, W))

def test_transform_env(self):
base_env = TestPermuteTransform.envclass()
env = TransformedEnv(base_env, TestPermuteTransform._get_permute())
check_env_specs(env)
assert env.observation_spec["pixels"] == env.observation_spec["pixels_orig"]
assert env.state_spec["pixels_orig"] == env.observation_spec["pixels_orig"]
assert env.state_spec["pixels_orig"] != base_env.state_spec["pixels_orig"]
assert env.observation_spec["pixels"] != base_env.observation_spec["pixels"]

td = env.rollout(3)
assert td["pixels"].shape == torch.Size([3, 7, 7, 1])

# check error
with pytest.raises(ValueError, match="Only tailing dims with negative"):
t = PermuteTransform((-1, -10))

def test_transform_model(self):
batch = [2]
D, W, H, C = 8, 32, 64, 3
trans = PermuteTransform(
dims=(-1, -4, -2, -3),
in_keys=["pixels"],
) # DxWxHxC => CxDxHxW
td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch)
out_channels = 4
from tensordict.nn import TensorDictModule

model = nn.Sequential(
trans,
TensorDictModule(
nn.Conv3d(C, out_channels, 3, padding=1),
in_keys=["pixels"],
out_keys=["pixels"],
),
)
td = model(td)
assert td["pixels"].shape == torch.Size((*batch, out_channels, D, H, W))

def test_transform_rb(self):
batch = [6]
D, W, H, C = 4, 5, 6, 3
trans = PermuteTransform(
dims=(-1, -4, -2, -3),
in_keys=["pixels"],
) # DxWxHxC => CxDxHxW
td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch)
rb = TensorDictReplayBuffer(storage=LazyTensorStorage(5), transform=trans)
rb.extend(td)
sample = rb.sample(2)
assert sample["pixels"].shape == torch.Size([2, C, D, H, W])

@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
def test_transform_inverse(self, batch):
D, W, H, C = 8, 32, 64, 3
trans = PermuteTransform(
dims=(-1, -4, -2, -3),
in_keys_inv=["pixels"],
) # DxWxHxC => CxDxHxW
td = TensorDict({"pixels": torch.randn((*batch, C, D, H, W))}, batch_size=batch)
td = trans.inv(td)
assert td["pixels"].shape == torch.Size((*batch, D, W, H, C))

@pytest.mark.parametrize("batch", [[], [2], [2, 4]])
def test_transform_no_env(self, batch):
D, W, H, C = 8, 32, 64, 3
trans = PermuteTransform(
dims=(-1, -4, -2, -3),
in_keys=["pixels"],
) # DxWxHxC => CxDxHxW
td = TensorDict({"pixels": torch.randn((*batch, D, W, H, C))}, batch_size=batch)
td = trans(td)
assert td["pixels"].shape == torch.Size((*batch, C, D, H, W))


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
NoopResetEnv,
ObservationNorm,
ObservationTransform,
PermuteTransform,
PinMemoryTransform,
R3MTransform,
RandomCropTensorDict,
Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
NoopResetEnv,
ObservationNorm,
ObservationTransform,
PermuteTransform,
PinMemoryTransform,
RandomCropTensorDict,
RenameTransform,
Expand Down
Loading

0 comments on commit 3430168

Please sign in to comment.