Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jun 13, 2024
1 parent 6716770 commit 8b9c4c6
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions direct/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ def to_tensor(data: np.ndarray) -> torch.Tensor:
return torch.from_numpy(data)


def verify_fft_dtype_possible(data: torch.Tensor, dims: tuple[int, ...]) -> bool:
def verify_fft_dtype_possible(data: torch.Tensor, dims: tuple[int, int] | tuple[int, int, int]) -> bool:
"""fft and ifft can only be performed on GPU in float16 if the shapes are powers of 2. This function verifies if
this is the case.
Parameters
----------
data: torch.Tensor
dims: tuple
dims: tuple of two or three ints
Returns
-------
Expand Down

0 comments on commit 8b9c4c6

Please sign in to comment.