Skip to content

Commit

Permalink
[Refactor] Refactor contiguous (#716)
Browse files Browse the repository at this point in the history
(cherry picked from commit b4c91e8)
  • Loading branch information
vmoens committed Mar 24, 2024
1 parent e856ffd commit 91515b1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
2 changes: 0 additions & 2 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2824,8 +2824,6 @@ def is_contiguous(self) -> bool:
return all([value.is_contiguous() for _, value in self.items()])

def contiguous(self) -> T:
if self.is_contiguous():
return self
return self._fast_apply(lambda x: x.contiguous())

def rename_key_(
Expand Down
18 changes: 12 additions & 6 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2048,9 +2048,17 @@ def _clone(self, recurse: bool = True) -> T:
return result

def contiguous(self) -> T:
if not self.is_contiguous():
return self.clone()
return self
source = {key: value.contiguous() for key, value in self.items()}
batch_size = self.batch_size
device = self.device
out = TensorDict(
source=source,
batch_size=batch_size,
device=device,
names=self.names,
_run_checks=False,
)
return out

def empty(self, recurse=False) -> T:
if not recurse:
Expand Down Expand Up @@ -2738,11 +2746,9 @@ def is_contiguous(self) -> bool:
return all(value.is_contiguous() for value in self.values())

def contiguous(self) -> T:
if self.is_contiguous():
return self
return TensorDict(
batch_size=self.batch_size,
source={key: value for key, value in self.items()},
source={key: value.contiguous() for key, value in self.items()},
device=self.device,
names=self.names,
_run_checks=False,
Expand Down

0 comments on commit 91515b1

Please sign in to comment.