Skip to content

Commit

Permalink
4. training
Browse files Browse the repository at this point in the history
  • Loading branch information
zzachw committed Oct 7, 2023
1 parent 51a1d1a commit da31ed3
Showing 1 changed file with 23 additions and 8 deletions.
31 changes: 23 additions & 8 deletions pyhealth/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ def train(
self.save_ckpt(os.path.join(self.exp_path, "best.ckpt"))

# load best model
if load_best_model_at_last and self.exp_path is not None and os.path.isfile(os.path.join(self.exp_path, "best.ckpt")):
if load_best_model_at_last and self.exp_path is not None and os.path.isfile(
os.path.join(self.exp_path, "best.ckpt")):
logger.info("Loaded best model")
self.load_ckpt(os.path.join(self.exp_path, "best.ckpt"))

Expand All @@ -243,7 +244,8 @@ def train(

return

def inference(self, dataloader, additional_outputs=None, return_patient_ids=False) -> Dict[str, float]:
def inference(self, dataloader, additional_outputs=None,
return_patient_ids=False) -> Dict[str, float]:
"""Model inference.
Args:
Expand Down Expand Up @@ -300,12 +302,22 @@ def evaluate(self, dataloader) -> Dict[str, float]:
Returns:
scores: a dictionary of scores.
"""
y_true_all, y_prob_all, loss_mean = self.inference(dataloader)

mode = self.model.mode
metrics_fn = get_metrics_fn(mode)
scores = metrics_fn(y_true_all, y_prob_all, metrics=self.metrics)
scores["loss"] = loss_mean
if self.model.mode is not None:
y_true_all, y_prob_all, loss_mean = self.inference(dataloader)
mode = self.model.mode
metrics_fn = get_metrics_fn(mode)
scores = metrics_fn(y_true_all, y_prob_all, metrics=self.metrics)
scores["loss"] = loss_mean
else:
loss_all = []
for data in tqdm(dataloader, desc="Evaluation"):
self.model.eval()
with torch.no_grad():
output = self.model(**data)
loss = output["loss"]
loss_all.append(loss.item())
loss_mean = sum(loss_all) / len(loss_all)
scores = {"loss": loss_mean}
return scores

def save_ckpt(self, ckpt_path: str) -> None:
Expand All @@ -329,6 +341,7 @@ def load_ckpt(self, ckpt_path: str) -> None:

from pyhealth.datasets.utils import collate_fn_dict


class MNISTDataset(Dataset):
def __init__(self, train=True):
transform = transforms.Compose(
Expand All @@ -345,6 +358,7 @@ def __getitem__(self, index):
def __len__(self):
return len(self.dataset)


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
Expand Down Expand Up @@ -376,6 +390,7 @@ def forward(self, x, y, **kwargs):
y_prob = torch.softmax(x, dim=1)
return {"loss": loss, "y_prob": y_prob, "y_true": y}


train_dataset = MNISTDataset(train=True)
val_dataset = MNISTDataset(train=False)

Expand Down

0 comments on commit da31ed3

Please sign in to comment.