
Supporting PyTorch 2.3 #1179

Open · Kayaba-Akihiko opened this issue Jun 16, 2024 · 11 comments
Labels: enhancement (New feature or request)

Comments

Kayaba-Akihiko commented Jun 16, 2024

Installing the latest TorchIO (0.19.7) forcibly replaces PyTorch 2.3 with PyTorch 2.2, breaking the installed Torchvision.

Kayaba-Akihiko added the enhancement label on Jun 16, 2024
fepegar (Owner) commented Jun 16, 2024 via email

joshuacwnewton commented Jun 17, 2024

As a side effect of not supporting PyTorch 2.3, there is also an incompatibility with the recently released NumPy 2.0.0, since torch==2.2.x now throws the following error:

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Upgrading to PyTorch 2.3.1 fixes this issue. So, as long as torchio doesn't support PyTorch 2.3, it will be incompatible with numpy==2.0.0 as well.
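
For anyone staying on torch==2.2.x in the meantime, pinning NumPy below 2.0 sidesteps the error, as the message itself suggests (a workaround rather than a fix):

pip install "numpy<2"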

romainVala (Contributor) commented Jun 18, 2024

Hi @fepegar, why did you add the incompatibility with torch 2.3 in the last commit (torch>=1.1,<2.3)?
I made a fresh install a few weeks ago and I am now working with torch 2.3.1 and torchio 0.19.6.
It looks fine to me. Is there a reported problem I missed?

BLQNXAY commented Jun 21, 2024

I tested this: torchio 0.19.6 does not force a change of the PyTorch version to 2.2, but torchio 0.19.7 does appear to force the change, causing torchvision to stop working!

fzimmermann89 commented
Related: #1178

fepegar (Owner) commented Jun 28, 2024

Hi all. I tried to explain the problem in #1178 (comment) but didn't do a very good job. I've just had a baby, so I'm not sleeping very well :D and don't have much time.

This is what I wrote:

This doesn't work:

import torch
import torchio as tio
from torch.utils.data import DataLoader

subject = tio.datasets.Colin27()
subject.load()
subjects = 10 * [subject]

loader = DataLoader(subjects, batch_size=4)
batch = next(iter(loader))
assert batch.__class__ is dict  # fails on PyTorch >= 2.3

because batch is an instance of Subject.

This happens in PyTorch >= 2.3 because of the new default collate function.

The issue in that snippet is that batch is an instance of Subject, which is not expected. We would typically (before PyTorch 2.3) expect a dict containing the same keys as the subject, with batched tensors instead of Image instances.

Maybe these changes have caused this as well, but I haven't investigated (@c-winder).

@fzimmermann89, would you be able to take a look at this? The inheritance design of Subject and its subclasses is not very elegant for historical reasons, so it might be tricky. @justusschock might also be able to help, but I think he's a bit busy as well.

fepegar (Owner) commented Jun 28, 2024

fzimmermann89 commented

Mh, I am still a bit confused about how torchio's datasets are supposed to be used (I have only ever used the augmentations).

Currently, in torch 2.3 the batch is a subject.
And subject['t1'].data is the batched data tensor.

What you describe would be similar to registering a collate handler for Subject:

torch.utils.data._utils.collate.default_collate_fn_map[tio.Subject] = (
    lambda batch, *, collate_fn_map=None: {
        key: torch.stack([subject[key].data for subject in batch])
        for key in batch[0]
    }
)

This would result in batch being a normal dictionary containing only the batched tensors.

Without a complete overhaul, I currently only see solutions using a custom collate function (sketched below).
Maybe somebody else has a good idea?
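
For illustration, a custom collate function along these lines could restore the pre-2.3 dict output (a sketch, not a TorchIO API; collate_subjects is a hypothetical helper that only handles each image's data and affine):

import torch
import torchio as tio
from torch.utils.data import DataLoader

def collate_subjects(batch):
    # Build a plain dict, stacking each image's 4D tensors into a 5D batch.
    out = {}
    for key in batch[0]:
        values = [subject[key] for subject in batch]
        if isinstance(values[0], tio.Image):
            out[key] = {
                'data': torch.stack([image.data for image in values]),
                'affine': torch.stack(
                    [torch.as_tensor(image.affine) for image in values]
                ),
            }
        else:
            out[key] = values  # leave non-image values as a list
    return out

# loader = DataLoader(subjects, batch_size=4, collate_fn=collate_subjects)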

fzimmermann89 commented

...and congratulations 👶 :)

fepegar (Owner) commented Jun 30, 2024

Thank you, @fzimmermann89!

> Mh, I am still a bit confused about how torchio's datasets are supposed to be used (I have only ever used the augmentations).

This isn't really related to datasets but to the Subject class implementation.

> Currently, in torch 2.3 the batch is a subject.
> And subject['t1'].data is the batched data tensor.

Yes. This is not expected. The batch should be a dictionary.

>>> import torch
>>> import torchio as tio
>>> from torch.utils.data import DataLoader

>>> subject = tio.Subject(image=tio.ScalarImage(tensor=torch.ones(1, 1, 1, 1)))
>>> subjects = 10 * [subject]

>>> loader = DataLoader(subjects, batch_size=2)
>>> batch = next(iter(loader))

Expected behavior (PyTorch < 2.3):

>>> batch
{
    'image': {
        'data': tensor([[[[[1.]]]],



        [[[[1.]]]]]),
        'affine': tensor([[[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]],

        [[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.]]], dtype=torch.float64),
        'path': ['', ''],
        'stem': ['', ''],
        'type': ['intensity', 'intensity']
    }
}
>>> batch["image"]["data"].shape
torch.Size([2, 1, 1, 1, 1])

Unexpected behavior (PyTorch 2.3):

>>> batch
Subject(Keys: ('image',); images: 1)
>>> batch["image"]["data"].shape
torch.Size([2, 1, 1, 1, 1])

The data is not lost with the new default collate function in PyTorch 2.3, but I find it confusing to end up with one subject containing one image, because that is not what the batch is anymore, and it's definitely not what used to happen before, where the list of subjects was collated into a dictionary and the images were converted into tensors. Moreover, this subject now has 5D images, which is also unexpected in TorchIO.

I'm not sure how we can fix this with minimal impact on users. I suppose code won't break with the new PyTorch behavior, and that is more important than the batch having an unexpected type? What do people think?
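
One possible direction, sketched here only as an idea (not a decided fix): register a collate handler for Subject in PyTorch's default collate map, so that a list of subjects is collated like a list of plain dicts again. default_collate_fn_map and collate are PyTorch internals (torch >= 1.13), so relying on them is an assumption, and _collate_subject is a hypothetical helper:

import torchio as tio
from torch.utils.data._utils.collate import collate, default_collate_fn_map

def _collate_subject(batch, *, collate_fn_map=None):
    # Convert each Subject (and its Image values) into plain dicts first, so
    # the default machinery produces the nested-dict batches of PyTorch < 2.3.
    as_dicts = [
        {
            key: dict(value) if isinstance(value, tio.Image) else value
            for key, value in subject.items()
        }
        for subject in batch
    ]
    return collate(as_dicts, collate_fn_map=collate_fn_map)

default_collate_fn_map[tio.Subject] = _collate_subject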

rickymwalsh commented

Hi,

The batch being of type Subject causes a problem when using PyTorch Lightning. I hope this is relevant here; let me know if I should open a separate issue or report it in the Lightning repo instead.

The apply_to_collection function (and now _apply_to_collection_slow) is called when moving the batch to the GPU. It is applied recursively with the code below, where data is the batch (previously a dict, now of type Subject): https://github.com/Lightning-AI/utilities/blob/main/src/lightning_utilities/core/apply_func.py#L84

    elem_type = type(data)

    # Recursively apply to collection items
    if isinstance(data, Mapping):
        out = []
        for k, v in data.items():
            v = _apply_to_collection_slow(
                v,
                dtype,
                function,
                *args,
                wrong_dtype=wrong_dtype,
                include_none=include_none,
                allow_frozen=allow_frozen,
                **kwargs,
            )
            if include_none or v is not None:
                out.append((k, v))
        if isinstance(data, defaultdict):
            return elem_type(data.default_factory, OrderedDict(out))
        return elem_type(OrderedDict(out))

The problem comes at the last line, which re-applies the original element type to the output. Per the traceback, the element type here is a TorchIO image class, whose constructor takes path as its first positional argument and doesn't accept a dict, so we get the error (a minimal repro follows the stack trace below):
TypeError: The path argument cannot be a dictionary

Full stack trace:
Traceback (most recent call last):
  File "/home/rwalsh/Documents/repos/pytorch_testing/pythonProject/mwe_lightning.py", line 63, in <module>
    run()
  File "/home/rwalsh/Documents/repos/pytorch_testing/pythonProject/mwe_lightning.py", line 59, in run
    trainer.fit(model, datamodule=dm)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
    call._call_and_handle_interrupt(
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1030, in _run_stage
    self.fit_loop.run()
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 223, in advance
    batch = call._call_strategy_hook(trainer, "batch_to_device", batch, dataloader_idx=0)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 277, in batch_to_device
    return model._apply_batch_transfer_handler(batch, device=device, dataloader_idx=dataloader_idx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 358, in _apply_batch_transfer_handler
    batch = self._call_batch_hook("transfer_batch_to_device", batch, device, dataloader_idx)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 347, in _call_batch_hook
    return trainer_method(trainer, hook_name, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 159, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/pytorch_lightning/core/hooks.py", line 611, in transfer_batch_to_device
    return move_data_to_device(batch, device)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/lightning_fabric/utilities/apply_func.py", line 103, in move_data_to_device
    return apply_to_collection(batch, dtype=_TransferableDataType, function=batch_to)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 72, in apply_to_collection
    return _apply_to_collection_slow(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 104, in _apply_to_collection_slow
    v = _apply_to_collection_slow(
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/lightning_utilities/core/apply_func.py", line 118, in _apply_to_collection_slow
    return elem_type(OrderedDict(out))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/torchio/data/image.py", line 858, in __init__
    super().__init__(*args, **kwargs)
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/torchio/data/image.py", line 179, in __init__
    self.path = self._parse_path(path)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rwalsh/Documents/repos/pytorch_testing/env/lib/python3.11/site-packages/torchio/data/image.py", line 471, in _parse_path
    raise TypeError('The path argument cannot be a dictionary')
TypeError: The path argument cannot be a dictionary
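
For reference, the failing call can be reproduced directly, outside Lightning (a minimal sketch assuming torchio 0.19.x): re-applying the element type the way _apply_to_collection_slow does passes the rebuilt dict into the image constructor's first positional parameter, path.

import torch
import torchio as tio
from collections import OrderedDict

image = tio.ScalarImage(tensor=torch.ones(1, 1, 1, 1))
elem_type = type(image)           # ScalarImage, as in the traceback
out = OrderedDict(image.items())  # what _apply_to_collection_slow rebuilds
elem_type(out)                    # TypeError: The path argument cannot be a dictionary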

MWE

Expected behaviour: training completes without error.
Actual behaviour: error before training starts.

conda create -p ./venv python=3.11 pytorch=2.3.1 torchio=0.19.8 lightning=2.3.1 -c pytorch -c conda-forge
# No error with PyTorch<2.3
# conda create -p ./venv python=3.11 "pytorch<2.3" torchio=0.19.8 lightning=2.3.1 -c pytorch -c conda-forge
conda activate ./venv
import torch
import torchio as tio
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torch.utils.data import DataLoader


class ExampleDataModule(LightningDataModule):
    def setup(self, stage):
        subject = tio.datasets.Colin27()
        subject.load()
        self.subjects = 10 * [tio.Subject(subject)]

    def train_dataloader(self):
        return DataLoader(self.subjects, batch_size=2)

    def test_dataloader(self):
        return DataLoader(self.subjects, batch_size=2)


class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv3d(1, 1, 1)

    def forward(self, x):
        return torch.sigmoid(self.conv(x))

    def training_step(self, batch, batch_idx):
        inputs = batch['t1'][tio.DATA]
        targets = batch['brain'][tio.DATA].float()
        pred = self(inputs)
        return torch.nn.functional.binary_cross_entropy(pred, targets)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def run():
    dm = ExampleDataModule()
    dm.prepare_data()
    dm.setup(stage="fit")

    model = SimpleModel()

    trainer = Trainer(num_sanity_val_steps=0, max_epochs=2, accelerator="cpu")
    trainer.fit(model, datamodule=dm)


if __name__ == '__main__':
    run()
