Skip to content

Commit

Permalink
feat: keep num_params as an attribute when run _print_model_size();
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed May 30, 2024
1 parent e2052c6 commit 4077383
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions pypots/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,17 +497,20 @@ def __init__(
self.num_workers = num_workers

self.model = None
self.num_params = None
self.optimizer = None
self.best_model_dict = None
self.best_loss = float("inf")
self.best_epoch = -1

def _print_model_size(self) -> None:
"""Print the number of trainable parameters in the initialized NN model."""
num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.num_params = sum(
p.numel() for p in self.model.parameters() if p.requires_grad
)
logger.info(
f"{self.__class__.__name__} initialized with the given hyperparameters, "
f"the number of trainable parameters: {num_params:,}"
f"the number of trainable parameters: {self.num_params:,}"
)

@abstractmethod
Expand Down

0 comments on commit 4077383

Please sign in to comment.