Skip to content

Commit

Permalink
move collate_fn_dict to datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
zzachw committed Nov 10, 2022
1 parent d8f0c0b commit 5ab1d22
Show file tree
Hide file tree
Showing 10 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ An ML Pipeline Example
from pyhealth.tasks import drug_recommendation_mimic3_fn
from pyhealth.datasets.splitter import split_by_patient
from torch.utils.data import DataLoader
from pyhealth.utils import collate_fn_dict
from pyhealth.datasets.utils import collate_fn_dict
mimic3dataset.set_task(task_fn=drug_recommendation_mimic3_fn) # use default drugrec task
train_ds, val_ds, test_ds = split_by_patient(mimic3dataset, [0.8, 0.1, 0.1])
Expand Down
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ An ML Pipeline Example
from pyhealth.tasks import drug_recommendation_mimic3_fn
from pyhealth.datasets.splitter import split_by_patient
from torch.utils.data import DataLoader
from pyhealth.utils import collate_fn_dict
from pyhealth.datasets.utils import collate_fn_dict
mimic3dataset.set_task(task_fn=drug_recommendation_mimic3_fn) # use default drugrec task
train_ds, val_ds, test_ds = split_by_patient(mimic3dataset, [0.8, 0.1, 0.1])
Expand Down
2 changes: 1 addition & 1 deletion leaderboard/leaderboard_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyhealth.models import *
from pyhealth.datasets.splitter import split_by_patient
from pyhealth.tasks import *
from pyhealth.utils import collate_fn_dict
from pyhealth.datasets.utils import collate_fn_dict
from pyhealth.trainer import Trainer
from pyhealth.evaluator import evaluate
from pyhealth.metrics import *
Expand Down
1 change: 1 addition & 0 deletions pyhealth/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .omop import OMOPDataset
from .eicu import eICUDataset
from .splitter import split_by_patient, split_by_visit
from .utils import collate_fn_dict, get_dataloader
15 changes: 14 additions & 1 deletion pyhealth/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import hashlib
import os
from datetime import datetime
from typing import Optional
import hashlib

from dateutil.parser import parse as dateutil_parse
from torch.utils.data import DataLoader

from pyhealth import BASE_CACHE_PATH
from pyhealth.utils import create_directory
Expand All @@ -28,3 +30,14 @@ def strptime(s: str) -> Optional[datetime]:
if s != s:
return None
return dateutil_parse(s)


def collate_fn_dict(batch):
    """Collate a list of sample dicts into a single dict of lists.

    Args:
        batch: sequence of dict samples. The keys of the first sample
            determine the keys of the result, so every sample must
            provide each of those keys.

    Returns:
        A dict mapping each key of the first sample to a list holding
        that key's value from every sample, in batch order. An empty
        batch yields an empty dict.

    Raises:
        KeyError: if a later sample lacks a key present in the first one.
    """
    if not batch:
        # No first sample to read keys from; avoid IndexError on batch[0].
        return {}
    return {key: [sample[key] for sample in batch] for key in batch[0]}


def get_dataloader(dataset, batch_size, shuffle=False):
    """Build a ``DataLoader`` over ``dataset`` that collates with ``collate_fn_dict``.

    Args:
        dataset: dataset handed to ``DataLoader`` unchanged.
        batch_size: forwarded to ``DataLoader`` as ``batch_size``.
        shuffle: forwarded to ``DataLoader`` as ``shuffle``; defaults to False.

    Returns:
        The constructed ``DataLoader``.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "collate_fn": collate_fn_dict,
    }
    return DataLoader(dataset, **loader_options)
2 changes: 1 addition & 1 deletion pyhealth/models/test/test_gamenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models import GAMENet
from pyhealth.tasks import drug_recommendation_mimic3_fn
from pyhealth.utils import collate_fn_dict
from pyhealth.datasets.utils import collate_fn_dict

dataset = MIMIC3Dataset(
root="/srv/local/data/physionet.org/files/mimiciii/1.4",
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/models/test/test_general_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from torch.utils.data import DataLoader

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.utils import collate_fn_dict
from pyhealth.datasets.utils import collate_fn_dict

# from pyhealth.models import CNN as Model
# from pyhealth.models import RNN as Model
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/models/test/test_micron.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models import MICRON
from pyhealth.tasks import drug_recommendation_mimic3_fn
from pyhealth.utils import collate_fn_dict
from pyhealth.datasets.utils import collate_fn_dict

dataset = MIMIC3Dataset(
root="/srv/local/data/physionet.org/files/mimiciii/1.4",
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/models/test/test_safedrug.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pyhealth.datasets import MIMIC3Dataset
from pyhealth.models import SafeDrug
from pyhealth.tasks import drug_recommendation_mimic3_fn
from pyhealth.utils import collate_fn_dict
from pyhealth.datasets.utils import collate_fn_dict

dataset = MIMIC3Dataset(
root="/srv/local/data/physionet.org/files/mimiciii/1.4",
Expand Down
4 changes: 0 additions & 4 deletions pyhealth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,6 @@
import torch


def collate_fn_dict(batch):
    """Collate a list of sample dicts into a single dict of lists.

    Args:
        batch: sequence of dict samples. The keys of the first sample
            determine the keys of the result, so every sample must
            provide each of those keys.

    Returns:
        A dict mapping each key of the first sample to a list holding
        that key's value from every sample, in batch order. An empty
        batch yields an empty dict.

    Raises:
        KeyError: if a later sample lacks a key present in the first one.
    """
    if not batch:
        # No first sample to read keys from; avoid IndexError on batch[0].
        return {}
    return {key: [sample[key] for sample in batch] for key in batch[0]}


def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
Expand Down

0 comments on commit 5ab1d22

Please sign in to comment.