Skip to content

Commit

Permalink
DLTrainer -> Trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
zzachw committed Nov 10, 2022
1 parent 5ab1d22 commit 950d8b1
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions pyhealth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_metrics_fn(mode: str) -> Callable:
raise ValueError(f"Mode {mode} is not supported")


class DLTrainer:
class Trainer:
"""Trainer for PyTorch models.
Args:
Expand All @@ -57,7 +57,8 @@ class DLTrainer:
the model will be randomly initialized.
metrics: List of metric names to be calculated. Default is None, which
means the default metrics in each metrics_fn will be used.
device: Device to be used for training. Default is "cpu".
device: Device to be used for training. Default is None, which means
the device will be GPU if available, otherwise CPU.
enable_logging: Whether to enable logging. Default is True.
output_path: Path to save the output. Default is "./output".
exp_name: Name of the experiment. Default is current datetime.
Expand All @@ -68,11 +69,14 @@ def __init__(
model: nn.Module,
checkpoint_path: Optional[str] = None,
metrics: Optional[List[str]] = None,
device: str = "cpu",
device: Optional[str] = None,
enable_logging: bool = True,
output_path: Optional[str] = None,
exp_name: Optional[str] = None,
):
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

self.model = model
self.metrics = metrics
self.device = device
Expand Down Expand Up @@ -290,7 +294,7 @@ def load_ckpt(self, ckpt_path: str) -> None:
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from pyhealth.utils import collate_fn_dict
from pyhealth.datasets.utils import collate_fn_dict


class MNISTDataset(Dataset):
Expand Down Expand Up @@ -356,7 +360,7 @@ def forward(self, x, y, **kwargs):

model = Model()

trainer = DLTrainer(
trainer = Trainer(
model, device="cuda" if torch.cuda.is_available() else "cpu"
)
trainer.train(train_dataloader=train_dataloader,
Expand Down

0 comments on commit 950d8b1

Please sign in to comment.