Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename save_model() and load_model() into save() and load() #247

Merged
merged 3 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
refactor: rename save_model() and load_model() into save() and load();
  • Loading branch information
WenjieDu committed Nov 29, 2023
commit b963235f6e9a81cc2c40da6054130f7fbbaec334
100 changes: 75 additions & 25 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def _auto_save_model_if_necessary(
self,
training_finished: bool = True,
saving_name: str = None,
):
) -> None:
"""Automatically save the current model into a file if in need.

Parameters
Expand All @@ -230,17 +230,17 @@ def _auto_save_model_if_necessary(
"""
if self.saving_path is not None and self.model_saving_strategy is not None:
name = self.__class__.__name__ if saving_name is None else saving_name
saving_path = os.path.join(self.saving_path, name)
if not training_finished and self.model_saving_strategy == "better":
self.save_model(self.saving_path, name)
self.save(saving_path)
elif training_finished and self.model_saving_strategy == "best":
self.save_model(self.saving_path, name)
else:
return
self.save(saving_path)
else:
pass

def save_model(
def save(
self,
saving_dir: str,
file_name: str,
saving_path: str,
overwrite: bool = False,
) -> None:
"""Save the model with current parameters to a disk file.
Expand All @@ -251,20 +251,20 @@ def save_model(

Parameters
----------
saving_dir :
The given directory to save the model.

file_name :
The file name of the model to be saved.
saving_path :
The given path to save the model. The directory will be created if it does not exist.

overwrite :
Whether to overwrite the model file if the path already exists.

"""
file_name = (
file_name + ".pypots" if file_name.split(".")[-1] != "pypots" else file_name
)
saving_path = os.path.join(saving_dir, file_name)
# split the saving dir and file name from the given path
saving_dir, file_name = os.path.split(saving_path)
# add the suffix ".pypots" if not given
if file_name.split(".")[-1] != "pypots":
file_name += ".pypots"
# rejoin the path for saving the model
saving_path = os.path.join(saving_path, file_name)

if os.path.exists(saving_path):
if overwrite:
Expand All @@ -274,7 +274,7 @@ def save_model(
else:
logger.error(f"File {saving_path} exists. Saving operation aborted.")
try:
create_dir_if_not_exist(saving_dir)
create_dir_if_not_exist(saving_path)
if isinstance(self.device, list):
# to save a DataParallel model generically, save the model.module.state_dict()
torch.save(self.model.module, saving_path)
Expand All @@ -286,27 +286,27 @@ def save_model(
f'Failed to save the model to "{saving_path}" because of the below error! \n{e}'
)

def load_model(self, model_path: str) -> None:
def load(self, path: str) -> None:
"""Load the saved model from a disk file.

Parameters
----------
model_path :
Local path to a disk file saving trained model.
path :
The local path to a disk file saving the trained model.

Notes
-----
If the training environment and the deploying/test environment use the same type of device (GPU/CPU),
you can load the model directly with torch.load(model_path).

"""
assert os.path.exists(model_path), f"Model file {model_path} does not exist."
assert os.path.exists(path), f"Model file {path} does not exist."

try:
if isinstance(self.device, torch.device):
loaded_model = torch.load(model_path, map_location=self.device)
loaded_model = torch.load(path, map_location=self.device)
else:
loaded_model = torch.load(model_path)
loaded_model = torch.load(path)
if isinstance(loaded_model, torch.nn.Module):
if isinstance(self.device, torch.device):
self.model.load_state_dict(loaded_model.state_dict())
Expand All @@ -316,7 +316,57 @@ def load_model(self, model_path: str) -> None:
self.model = loaded_model.model
except Exception as e:
raise e
logger.info(f"Model loaded successfully from {model_path}.")
logger.info(f"Model loaded successfully from {path}.")

def save_model(
self,
saving_path: str,
overwrite: bool = False,
) -> None:
"""Save the model with current parameters to a disk file.

A ``.pypots`` extension will be appended to the filename if it does not already have one.
Please note that such an extension is not necessary, but to indicate the saved model is from PyPOTS framework
so people can distinguish.

Parameters
----------
saving_path :
The given path to save the model. The directory will be created if it does not exist.

overwrite :
Whether to overwrite the model file if the path already exists.

Warnings
--------
The method save_model is deprecated. Please use `save()` instead.
"""
logger.warning(
"🚨DeprecationWarning: The method save_model is deprecated. Please use `save()` instead."
)
self.save(saving_path, overwrite)

def load_model(self, path: str) -> None:
"""Load the saved model from a disk file.

Parameters
----------
path :
The local path to a disk file saving the trained model.

Notes
-----
If the training environment and the deploying/test environment use the same type of device (GPU/CPU),
you can load the model directly with torch.load(model_path).

Warnings
--------
The method load_model is deprecated. Please use `load()` instead.
"""
logger.warning(
"🚨DeprecationWarning: The method load_model is deprecated. Please use `load()` instead."
)
self.load(path)

@abstractmethod
def fit(
Expand Down
8 changes: 3 additions & 5 deletions tests/classification/brits.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.brits)

# save the trained model into file, and check if the path exists
self.brits.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.brits.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.brits.load_model(saved_model_path)
self.brits.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/classification/grud.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.grud)

# save the trained model into file, and check if the path exists
self.grud.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.grud.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.grud.load_model(saved_model_path)
self.grud.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/classification/raindrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.raindrop)

# save the trained model into file, and check if the path exists
self.raindrop.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.raindrop.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.raindrop.load_model(saved_model_path)
self.raindrop.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/clustering/crli.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.crli_gru)

# save the trained model into file, and check if the path exists
self.crli_gru.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.crli_gru.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.crli_gru.load_model(saved_model_path)
self.crli_gru.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/clustering/vader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.vader)

# save the trained model into file, and check if the path exists
self.vader.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.vader.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.vader.load_model(saved_model_path)
self.vader.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/brits.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.brits)

# save the trained model into file, and check if the path exists
self.brits.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.brits.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.brits.load_model(saved_model_path)
self.brits.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/csdi.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.csdi)

# save the trained model into file, and check if the path exists
self.csdi.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.csdi.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.csdi.load_model(saved_model_path)
self.csdi.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/gpvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.gp_vae)

# save the trained model into file, and check if the path exists
self.gp_vae.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.gp_vae.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.gp_vae.load_model(saved_model_path)
self.gp_vae.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/mrnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.mrnn)

# save the trained model into file, and check if the path exists
self.mrnn.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.mrnn.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.mrnn.load_model(saved_model_path)
self.mrnn.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/saits.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.saits)

# save the trained model into file, and check if the path exists
self.saits.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.saits.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.saits.load_model(saved_model_path)
self.saits.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.transformer)

# save the trained model into file, and check if the path exists
self.transformer.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.transformer.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.transformer.load_model(saved_model_path)
self.transformer.load(saved_model_path)


if __name__ == "__main__":
Expand Down
8 changes: 3 additions & 5 deletions tests/imputation/usgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,11 @@ def test_3_saving_path(self):
check_tb_and_model_checkpoints_existence(self.us_gan)

# save the trained model into file, and check if the path exists
self.us_gan.save_model(
saving_dir=self.saving_path, file_name=self.model_save_name
)
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.us_gan.save(saved_model_path)

# test loading the saved model, not necessary, but need to test
saved_model_path = os.path.join(self.saving_path, self.model_save_name)
self.us_gan.load_model(saved_model_path)
self.us_gan.load(saved_model_path)


if __name__ == "__main__":
Expand Down
Loading