diff --git a/src/lightning/fabric/utilities/optimizer.py b/src/lightning/fabric/utilities/optimizer.py
index e2605ceca4670..c5a9df39a44b6 100644
--- a/src/lightning/fabric/utilities/optimizer.py
+++ b/src/lightning/fabric/utilities/optimizer.py
@@ -14,11 +14,8 @@
 
 from typing import Iterable
 
-from lightning_utilities.core.apply_func import apply_to_collection
-from torch import Tensor
 from torch.optim import Optimizer
 
-from lightning.fabric.utilities.apply_func import move_data_to_device
 from lightning.fabric.utilities.types import _DEVICE
 
 
@@ -30,5 +27,4 @@ def _optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> N
 
 def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
     """Moves the state of a single optimizer to the device."""
-    for p, v in optimizer.state.items():
-        optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
+    pass
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index 802c1a17bc448..f5e90fdabf944 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -1766,7 +1766,7 @@ def current_memory():
 
     trainer.fit(model)
     assert trainer.strategy.model is model
-    assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cpu")
+    assert list(trainer.optimizers[0].state.values())[0]["exp_avg_sq"].device == torch.device("cuda", 0)
     assert trainer.callback_metrics["train_loss"].device == torch.device("cpu")
     assert current_memory() <= initial
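
A minimal sketch (illustrative only, not part of the diff) of the behavior the updated test asserts: torch.optim allocates optimizer state lazily on the same device as the parameters it tracks, so Adam's "exp_avg_sq" buffer created during a CUDA run already lives on cuda:0 without any explicit move. The toy model, shapes, and loss below are placeholder assumptions, not taken from the source.

# Sketch: optimizer state follows the parameter device on its own.
import torch

# Assumed setup: any module with parameters on the target device.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model = torch.nn.Linear(4, 4).to(device)
optimizer = torch.optim.Adam(model.parameters())

# One training step; Adam creates its state tensors here, on `device`.
loss = model(torch.randn(2, 4, device=device)).sum()
loss.backward()
optimizer.step()

# Mirrors the updated test assertion: the buffer is on the parameters' device.
state = next(iter(optimizer.state.values()))
assert state["exp_avg_sq"].device == device

With `_optimizer_to_device` made a no-op, the state is simply no longer copied back to CPU after fit, which is why the test's expected device changes from torch.device("cpu") to torch.device("cuda", 0).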