Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Replaced device_safe() with device #485

Merged
merged 6 commits into from
Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 16 additions & 21 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,11 @@ def test_tensordict_set(device):
@pytest.mark.parametrize("device", get_available_devices())
def test_tensordict_device(device):
tensordict = TensorDict({"a": torch.randn(3, 4)}, [])
with pytest.raises(RuntimeError):
tensordict.device
assert tensordict.device is None

tensordict = TensorDict({"a": torch.randn(3, 4, device=device)}, [])
assert tensordict["a"].device == device
with pytest.raises(RuntimeError):
tensordict.device
assert tensordict.device is None

tensordict = TensorDict(
{
Expand Down Expand Up @@ -2104,12 +2102,10 @@ def test_create_on_device():

# TensorDict
td = TensorDict({}, [5])
with pytest.raises(RuntimeError):
td.device
assert td.device is None

td.set("a", torch.randn(5, device=device))
with pytest.raises(RuntimeError):
td.device
assert td.device is None

td = TensorDict({}, [5], device="cuda:0")
td.set("a", torch.randn(5, 1))
Expand All @@ -2119,11 +2115,10 @@ def test_create_on_device():
td1 = TensorDict({}, [5])
td2 = TensorDict({}, [5])
stackedtd = stack_td([td1, td2], 0)
with pytest.raises(RuntimeError):
stackedtd.device
assert stackedtd.device is None

stackedtd.set("a", torch.randn(2, 5, device=device))
with pytest.raises(RuntimeError):
stackedtd.device
assert stackedtd.device is None

stackedtd = stackedtd.to(device)
assert stackedtd.device == device
Expand All @@ -2139,12 +2134,12 @@ def test_create_on_device():
# TensorDict, indexed
td = TensorDict({}, [5])
subtd = td[1]
with pytest.raises(RuntimeError):
subtd.device
assert subtd.device is None

subtd.set("a", torch.randn(1, device=device))
with pytest.raises(RuntimeError):
# setting element of subtensordict doesn't set top-level device
subtd.device
# setting element of subtensordict doesn't set top-level device
assert subtd.device is None

subtd = subtd.to(device)
assert subtd.device == device
assert subtd["a"].device == device
Expand All @@ -2162,8 +2157,8 @@ def test_create_on_device():
# SavedTensorDict
td = TensorDict({}, [5])
savedtd = td.to(SavedTensorDict)
with pytest.raises(RuntimeError):
savedtd.device
assert savedtd.device is None

savedtd = savedtd.to(device)
assert savedtd.device == device

Expand All @@ -2175,8 +2170,8 @@ def test_create_on_device():
# ViewedTensorDict
td = TensorDict({}, [6])
viewedtd = td.view(2, 3)
with pytest.raises(RuntimeError):
viewedtd.device
assert viewedtd.device is None

viewedtd = viewedtd.to(device)
assert viewedtd.device == device

Expand Down
4 changes: 2 additions & 2 deletions torchrl/collectors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase:
"mask",
torch.ones(
rollout_tensordict.shape,
device=rollout_tensordict.device_safe(),
device=rollout_tensordict.device,
dtype=torch.bool,
),
)
Expand All @@ -66,7 +66,7 @@ def split_trajectories(rollout_tensordict: TensorDictBase) -> TensorDictBase:
}
td = TensorDict(
source=out_dict,
device=rollout_tensordict.device_safe(),
device=rollout_tensordict.device,
batch_size=out_dict["mask"].shape[:-1],
)
if (out_dict["done"].sum(1) > 1).any():
Expand Down
21 changes: 5 additions & 16 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,7 @@ def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:


def _pin_memory(output: Any) -> Any:
output_device = (
output.device_safe() if hasattr(output, "device_safe") else output.device
)
if hasattr(output, "pin_memory") and output_device == torch.device("cpu"):
if hasattr(output, "pin_memory") and output.device == torch.device("cpu"):
return output.pin_memory()
else:
return output
Expand Down Expand Up @@ -369,11 +366,7 @@ def __getitem__(self, index: Union[int, Tensor]) -> Any:
weight = np.power(weight / p_min, -self._beta)
# x = first_field(data)
# if isinstance(x, torch.Tensor):
device = (
data.device_safe()
if hasattr(data, "device_safe")
else (data.device if hasattr(data, "device") else torch.device("cpu"))
)
device = data.device if hasattr(data, "device") else torch.device("cpu")
weight = to_torch(weight, device, self._pin_memory)
return data, weight

Expand Down Expand Up @@ -476,11 +469,7 @@ def _sample(self, batch_size: int) -> Tuple[Any, torch.Tensor, torch.Tensor]:

# x = first_field(data) # avoid calling tree.flatten
# if isinstance(x, torch.Tensor):
device = (
data.device_safe()
if hasattr(data, "device_safe")
else (data.device if hasattr(data, "device") else torch.device("cpu"))
)
device = data.device if hasattr(data, "device") else torch.device("cpu")
weight = to_torch(weight, device, self._pin_memory)
return data, weight, index

Expand Down Expand Up @@ -672,7 +661,7 @@ def extend(
"index",
torch.zeros(
tensordicts.shape,
device=tensordicts.device_safe(),
device=tensordicts.device,
dtype=torch.int,
),
)
Expand All @@ -686,7 +675,7 @@ def extend(
idx = super().extend(tensordicts, priorities)
stacked_td.set(
"index",
torch.tensor(idx, dtype=torch.int, device=stacked_td.device_safe()),
torch.tensor(idx, dtype=torch.int, device=stacked_td.device),
inplace=True,
)
return idx
Expand Down
10 changes: 1 addition & 9 deletions torchrl/data/tensordict/metatensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,15 +81,7 @@ def __init__(
_is_memmap = isinstance(tensor, MemmapTensor)
# FIXME: using isinstance(tensor, TensorDictBase) would likely be
# better here, but creates circular import without more refactoring
device = (
(
tensor.device_safe()
if hasattr(tensor, "device_safe")
else tensor.device
)
if not tensor.is_meta
else device
)
device = tensor.device if not tensor.is_meta else device
if _is_tensordict is None:
_is_tensordict = not _is_memmap and not isinstance(tensor, torch.Tensor)
if not _is_tensordict:
Expand Down
Loading