Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the bug of tensors not on the same device when running on CUDA device #59

Merged
merged 10 commits into from
Apr 19, 2023
Merged
Prev Previous commit
Next Next commit
feat: add tb_file_saving_path as a model attribute;
  • Loading branch information
WenjieDu committed Apr 19, 2023
commit 5355e307d597c9961b2999116d627c802a224410
5 changes: 3 additions & 2 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(

# set up the summary writer for training log saving below
# initialize self.summary_writer if tb_file_saving_path is given and not None, otherwise don't save the log
self.tb_file_saving_path = None
if isinstance(tb_file_saving_path, str):

from datetime import datetime
Expand All @@ -86,12 +87,12 @@ def __init__(
time_now = datetime.now().__format__("%Y%m%d_T%H%M%S")
# the actual directory name to save the tensorboard file
actual_tb_saving_dir_name = "tensorboard_" + time_now
actual_tb_file_saving_path = os.path.join(
self.tb_file_saving_path = os.path.join(
tb_file_saving_path, actual_tb_saving_dir_name
)
# os.makedirs(actual_tb_file_saving_path) # create the dir for file saving
self.summary_writer = SummaryWriter(
actual_tb_file_saving_path, filename_suffix=".pypots"
self.tb_file_saving_path, filename_suffix=".pypots"
)

def save_log_into_tb_file(self, step: int, stage: str, loss_dict: dict) -> None:
Expand Down
27 changes: 9 additions & 18 deletions pypots/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,11 @@ def test_3_saving_path(self):
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# whether the tensorboard file exists
files = os.listdir(self.saving_path)
assert len(files) > 0, "tensorboard dir does not exist"
tensorboard_dir_name = files[0]
tensorboard_dir_path = os.path.join(self.saving_path, tensorboard_dir_name)
assert (
tensorboard_dir_name.startswith("tensorboard")
and len(os.listdir(tensorboard_dir_path)) > 0
self.brits.tb_file_saving_path is not None
and len(os.listdir(self.brits.tb_file_saving_path)) > 0
), "tensorboard file does not exist"

# save the trained model into file, and check if the path exists
Expand Down Expand Up @@ -152,14 +149,11 @@ def test_3_saving_path(self):
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# whether the tensorboard file exists
files = os.listdir(self.saving_path)
assert len(files) > 0, "tensorboard dir does not exist"
tensorboard_dir_name = files[0]
tensorboard_dir_path = os.path.join(self.saving_path, tensorboard_dir_name)
assert (
tensorboard_dir_name.startswith("tensorboard")
and len(os.listdir(tensorboard_dir_path)) > 0
self.grud.tb_file_saving_path is not None
and len(os.listdir(self.grud.tb_file_saving_path)) > 0
), "tensorboard file does not exist"

# save the trained model into file, and check if the path exists
Expand Down Expand Up @@ -236,14 +230,11 @@ def test_3_saving_path(self):
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# whether the tensorboard file exists
files = os.listdir(self.saving_path)
assert len(files) > 0, "tensorboard dir does not exist"
tensorboard_dir_name = files[0]
tensorboard_dir_path = os.path.join(self.saving_path, tensorboard_dir_name)
assert (
tensorboard_dir_name.startswith("tensorboard")
and len(os.listdir(tensorboard_dir_path)) > 0
self.raindrop.tb_file_saving_path is not None
and len(os.listdir(self.raindrop.tb_file_saving_path)) > 0
), "tensorboard file does not exist"

# save the trained model into file, and check if the path exists
Expand Down
18 changes: 6 additions & 12 deletions pypots/tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,11 @@ def test_3_saving_path(self):
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# whether the tensorboard file exists
files = os.listdir(self.saving_path)
assert len(files) > 0, "tensorboard dir does not exist"
tensorboard_dir_name = files[0]
tensorboard_dir_path = os.path.join(self.saving_path, tensorboard_dir_name)
assert (
tensorboard_dir_name.startswith("tensorboard")
and len(os.listdir(tensorboard_dir_path)) > 0
self.crli.tb_file_saving_path is not None
and len(os.listdir(self.crli.tb_file_saving_path)) > 0
), "tensorboard file does not exist"

# save the trained model into file, and check if the path exists
Expand Down Expand Up @@ -152,14 +149,11 @@ def test_3_saving_path(self):
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# whether the tensorboard file exists
files = os.listdir(self.saving_path)
assert len(files) > 0, "tensorboard dir does not exist"
tensorboard_dir_name = files[0]
tensorboard_dir_path = os.path.join(self.saving_path, tensorboard_dir_name)
assert (
tensorboard_dir_name.startswith("tensorboard")
and len(os.listdir(tensorboard_dir_path)) > 0
self.vader.tb_file_saving_path is not None
and len(os.listdir(self.vader.tb_file_saving_path)) > 0
), "tensorboard file does not exist"

# save the trained model into file, and check if the path exists
Expand Down
27 changes: 9 additions & 18 deletions pypots/tests/test_imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,11 @@ def test_3_saving_path(self):
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# whether the tensorboard file exists
files = os.listdir(self.saving_path)
assert len(files) > 0, "tensorboard dir does not exist"
tensorboard_dir_name = files[0]
tensorboard_dir_path = os.path.join(self.saving_path, tensorboard_dir_name)
assert (
tensorboard_dir_name.startswith("tensorboard")
and len(os.listdir(tensorboard_dir_path)) > 0
self.saits.tb_file_saving_path is not None
and len(os.listdir(self.saits.tb_file_saving_path)) > 0
), "tensorboard file does not exist"

# save the trained model into file, and check if the path exists
Expand Down Expand Up @@ -172,14 +169,11 @@ def test_3_saving_path(self):
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# whether the tensorboard file exists
files = os.listdir(self.saving_path)
assert len(files) > 0, "tensorboard dir does not exist"
tensorboard_dir_name = files[0]
tensorboard_dir_path = os.path.join(self.saving_path, tensorboard_dir_name)
assert (
tensorboard_dir_name.startswith("tensorboard")
and len(os.listdir(tensorboard_dir_path)) > 0
self.transformer.tb_file_saving_path is not None
and len(os.listdir(self.transformer.tb_file_saving_path)) > 0
), "tensorboard file does not exist"

# save the trained model into file, and check if the path exists
Expand Down Expand Up @@ -243,14 +237,11 @@ def test_3_saving_path(self):
assert os.path.exists(
self.saving_path
), f"file {self.saving_path} does not exist"

# whether the tensorboard file exists
files = os.listdir(self.saving_path)
assert len(files) > 0, "tensorboard dir does not exist"
tensorboard_dir_name = files[0]
tensorboard_dir_path = os.path.join(self.saving_path, tensorboard_dir_name)
assert (
tensorboard_dir_name.startswith("tensorboard")
and len(os.listdir(tensorboard_dir_path)) > 0
self.brits.tb_file_saving_path is not None
and len(os.listdir(self.brits.tb_file_saving_path)) > 0
), "tensorboard file does not exist"

# save the trained model into file, and check if the path exists
Expand Down