diff --git a/pyhealth/trainer.py b/pyhealth/trainer.py index aa022893..5b4d9e27 100644 --- a/pyhealth/trainer.py +++ b/pyhealth/trainer.py @@ -48,7 +48,7 @@ def get_metrics_fn(mode: str) -> Callable: raise ValueError(f"Mode {mode} is not supported") -class DLTrainer: +class Trainer: """Trainer for PyTorch models. Args: @@ -57,7 +57,8 @@ class DLTrainer: the model will be randomly initialized. metrics: List of metric names to be calculated. Default is None, which means the default metrics in each metrics_fn will be used. - device: Device to be used for training. Default is "cpu". + device: Device to be used for training. Default is None, which means + the device will be GPU if available, otherwise CPU. enable_logging: Whether to enable logging. Default is True. output_path: Path to save the output. Default is "./output". exp_name: Name of the experiment. Default is current datetime. @@ -68,11 +69,14 @@ def __init__( model: nn.Module, checkpoint_path: Optional[str] = None, metrics: Optional[List[str]] = None, - device: str = "cpu", + device: Optional[str] = None, enable_logging: bool = True, output_path: Optional[str] = None, exp_name: Optional[str] = None, ): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.model = model self.metrics = metrics self.device = device @@ -290,7 +294,7 @@ def load_ckpt(self, ckpt_path: str) -> None: import torch.nn as nn from torch.utils.data import DataLoader, Dataset from torchvision import datasets, transforms - from pyhealth.utils import collate_fn_dict + from pyhealth.datasets.utils import collate_fn_dict class MNISTDataset(Dataset): @@ -356,7 +360,7 @@ def forward(self, x, y, **kwargs): model = Model() - trainer = DLTrainer( + trainer = Trainer( model, device="cuda" if torch.cuda.is_available() else "cpu" ) trainer.train(train_dataloader=train_dataloader,