Supporting PyTorch 2.3 #1179
Also, as a side effect of not supporting PyTorch 2.3, there is an incompatibility with the recently released NumPy 2.0.0. Upgrading to PyTorch 2.3.1 fixes this issue.
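In concrete terms (the upgrade is what the comment above reports as the fix; the NumPy pin is a common stopgap, not something verified in this thread):

```shell
# Upgrade PyTorch, reported above to fix the NumPy 2.0.0 incompatibility:
pip install "torch>=2.3.1"
# Alternatively, keep the current PyTorch and pin NumPy below 2.0 as a stopgap:
pip install "numpy<2"
```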
Hi,
I did test this: torchio 0.19.6 does not force a change of the PyTorch version to 2.2, but torchio 0.19.7 does appear to force the change, causing torchvision to stop working!
Related: #1178
Hi all. I tried to explain the problem in #1178 (comment), but I didn't do a very good job. I've just had a baby, so I'm not sleeping very well :D and don't have much time. This is what I wrote:
The issue in that snippet is that [...]. Maybe these changes have caused this as well, but I haven't investigated (@c-winder). @fzimmermann89, would you be able to take a look at this? The inheritance design of the [...]
Here are the changes in the PyTorch commit that broke TorchIO: https://github.com/pytorch/pytorch/pull/120553/files#diff-ab54604ec520467537cb424daad0f56b2f5702e2f4d7fa2e9f68c9def296ccdf
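To make the linked change concrete, here is a small, dependency-free paraphrase of the two collation strategies (an illustration using a made-up `StrictDict` class as a stand-in for `tio.Subject`, not the actual PyTorch code):

```python
import copy


class StrictDict(dict):
    """Stand-in for tio.Subject: a dict subclass whose constructor
    rejects a plain positional dict."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


elem = StrictDict(x=1)    # first element of the batch
collated = {"x": [1, 1]}  # values already collated key by key

# PyTorch < 2.3 strategy: try to rebuild the original type,
# fall back to a plain dict if the constructor complains.
try:
    old_style = type(elem)(collated)
except TypeError:
    old_style = dict(collated)

# PyTorch >= 2.3 strategy: shallow-copy the element (bypassing
# __init__) and overwrite its entries, so the subclass survives.
new_style = copy.copy(elem)
new_style.update(collated)

assert type(old_style) is dict        # old behaviour: plain dict
assert type(new_style) is StrictDict  # new behaviour: subclass preserved
```

The `copy.copy` path never calls `__init__`, which is why `Subject` instances now survive collation even though `Subject(dict)` would raise.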
Hmm, I am still a bit confused about how TorchIO's datasets are supposed to be used (I have only ever used the augmentations). Currently, in PyTorch 2.3, the batch is a Subject. What you describe would be similar to setting [...]. This would result in the batch being a normal dictionary containing only the batched tensors. Without a complete overhaul, I currently only see solutions using a custom collate function.
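For reference, such a custom collate function might look like this (a sketch, not code from the thread: it collates with PyTorch's default logic and then recursively downcasts any dict subclass, such as `Subject` or `Image`, to a plain dict):

```python
from torch.utils.data.dataloader import default_collate


def as_plain_dict(value):
    """Recursively replace dict subclasses (e.g. Subject, Image) with plain dicts."""
    if isinstance(value, dict):
        return {k: as_plain_dict(v) for k, v in value.items()}
    return value


def dict_collate(batch):
    """Collate with PyTorch's default logic, then strip the custom
    Mapping types that PyTorch >= 2.3 now preserves."""
    return as_plain_dict(default_collate(batch))


# Usage (subjects being a list of tio.Subject):
# loader = DataLoader(subjects, batch_size=2, collate_fn=dict_collate)
```

On PyTorch < 2.3 the downcast is a no-op, so this should behave identically across versions.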
...and congratulations 👶 :)
Thank you, @fzimmermann89!
This isn't really related to datasets but to the [...]
Yes. This is not expected. The batch should be a dictionary.

```python
>>> import torch
>>> import torchio as tio
>>> from torch.utils.data import DataLoader
>>> subject = tio.Subject(image=tio.ScalarImage(tensor=torch.ones(1, 1, 1, 1)))
>>> subjects = 10 * [subject]
>>> loader = DataLoader(subjects, batch_size=2)
>>> batch = next(iter(loader))
```

Expected behavior (PyTorch < 2.3):

```python
>>> batch
{
    'image': {
        'data': tensor([[[[[1.]]]],
                        [[[[1.]]]]]),
        'affine': tensor([[[1., 0., 0., 0.],
                           [0., 1., 0., 0.],
                           [0., 0., 1., 0.],
                           [0., 0., 0., 1.]],
                          [[1., 0., 0., 0.],
                           [0., 1., 0., 0.],
                           [0., 0., 1., 0.],
                           [0., 0., 0., 1.]]], dtype=torch.float64),
        'path': ['', ''],
        'stem': ['', ''],
        'type': ['intensity', 'intensity']
    }
}
>>> batch["image"]["data"].shape
torch.Size([2, 1, 1, 1, 1])
```

Unexpected behavior (PyTorch 2.3):

```python
>>> batch
Subject(Keys: ('image',); images: 1)
>>> batch["image"]["data"].shape
torch.Size([2, 1, 1, 1, 1])
```

The data is not lost with the new default collate function in PyTorch 2.3, but I find it confusing to get one Subject holding one image, since that is not what the batch actually is, and it is definitely not what used to happen before, when the list of subjects was collated into a dictionary and the images were converted into tensors. Moreover, this Subject now has 5D images, which is also unexpected in TorchIO. I'm not sure how we can fix this with minimal impact on users. I suppose code won't break with the new changes in PyTorch, and that is more important than the batch having an unexpected type. What do people think?
Hi,

The batch being of type Subject causes a problem when using PyTorch Lightning. I hope this is relevant here; let me know if I should raise a separate issue or raise it in the Lightning repo instead. The relevant code in `lightning_utilities.core.apply_func._apply_to_collection_slow` is:

```python
elem_type = type(data)

# Recursively apply to collection items
if isinstance(data, Mapping):
    out = []
    for k, v in data.items():
        v = _apply_to_collection_slow(
            v,
            dtype,
            function,
            *args,
            wrong_dtype=wrong_dtype,
            include_none=include_none,
            allow_frozen=allow_frozen,
            **kwargs,
        )
        if include_none or v is not None:
            out.append((k, v))
    if isinstance(data, defaultdict):
        return elem_type(data.default_factory, OrderedDict(out))
    return elem_type(OrderedDict(out))
```

The problem comes at the last line, which tries to re-apply the original type to the output, as `Subject()` isn't expecting the collated values it is handed (the trace below ends in `TypeError: The path argument cannot be a dictionary`).

Full stack trace:

```
Traceback (most recent call last):
  File "/home/rwalsh/Documents/repos/pytorch_testing/pythonProject/mwe_lightning.py", line 63, in <module>
    run()
  File "/home/rwalsh/Documents/repos/pytorch_testing/pythonProject/mwe_lightning.py", line 59, in run
    trainer.fit(model, datamodule=dm)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
    self.fit_loop.run()
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 223, in advance
    batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=0)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 277, in batch_to_device
    return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 358, in _apply_batch_transfer_handler
    batch = self._call_batch_hook("transfer_batch_to_device", batch, device, dataloader_idx)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 347, in _call_batch_hook
    return trainer_method(trainer, hook_name, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 159, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/core/hooks.py", line 611, in transfer_batch_to_device
    return move_data_to_device(batch, device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/lightning_fabric/utilities/apply_func.py", line 103, in move_data_to_device
    return apply_to_collection(batch, dtype=_TransferableDataType, function=batch_to)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 72, in apply_to_collection
    return _apply_to_collection_slow(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 104, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 118, in _apply_to_collection_slow
    return elem_type(OrderedDict(out))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/torchio/data/image.py", line 858, in __init__
    super().__init__(*args, **kwargs)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/torchio/data/image.py", line 179, in __init__
    self.path = self._parse_path(path)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/torchio/data/image.py", line 471, in _parse_path
    raise TypeError('The path argument cannot be a dictionary')
TypeError: The path argument cannot be a dictionary
```

MWE

Expected behaviour: training completes without error.

```shell
conda create -p ./venv python=3.11 pytorch=2.3.1 torchio=0.19.8 lightning=2.3.1 -c pytorch -c conda-forge
# No error with PyTorch<2.3
# conda create -p ./venv python=3.11 "pytorch<2.3" torchio=0.19.8 lightning=2.3.1 -c pytorch -c conda-forge
conda activate ./venv
```

```python
import torch
import torchio as tio
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torch.utils.data import DataLoader


class ExampleDataModule(LightningDataModule):
    def setup(self, stage):
        subject = tio.datasets.Colin27()
        subject.load()
        self.subjects = 10 * [tio.Subject(subject)]

    def train_dataloader(self):
        return DataLoader(self.subjects, batch_size=2)

    def test_dataloader(self):
        return DataLoader(self.subjects, batch_size=2)


class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv3d(1, 1, 1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x))

    def training_step(self, batch, batch_idx):
        inputs = batch['t1'][tio.DATA]
        targets = batch['brain'][tio.DATA].float()
        pred = self(inputs)
        return torch.nn.functional.binary_cross_entropy(pred, targets)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def run():
    dm = ExampleDataModule()
    dm.prepare_data()
    dm.setup(stage="fit")
    model = SimpleModel()
    trainer = Trainer(num_sanity_val_steps=0, max_epochs=2, accelerator="cpu")
    trainer.fit(model, datamodule=dm)


if __name__ == '__main__':
    run()
```
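One possible Lightning-side workaround (a sketch, assuming it is acceptable for the training step to receive plain dicts) is to move the batch manually, so `apply_to_collection` never tries to rebuild `Subject`:

```python
import torch


def move_nested(value, device):
    """Recursively move tensors to `device`, returning plain dicts so
    no custom Mapping type (Subject, Image) needs to be reconstructed."""
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, dict):  # covers Subject and Image (dict subclasses)
        return {k: move_nested(v, device) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(move_nested(v, device) for v in value)
    return value


# Hypothetical override in the LightningModule from the MWE above:
# def transfer_batch_to_device(self, batch, device, dataloader_idx):
#     return move_nested(batch, device)
```

Overriding `transfer_batch_to_device` replaces Lightning's default `move_data_to_device` call, which is the path that fails in the stack trace.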
PyTorch 2.3 would be forcibly replaced with PyTorch 2.2 when installing the latest TorchIO (0.19.7), breaking the installed torchvision.