Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core structure updates: image and text featurizer #172

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Prev Previous commit
Next Next commit
wrap up this pr
- add split by sample
  • Loading branch information
zzachw committed Jul 15, 2023
commit 47ee96fcac10708bda533d2b2385e1a77bbf947d
51 changes: 45 additions & 6 deletions pyhealth/datasets/splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ def split_by_visit(
np.random.shuffle(index)
train_index = index[: int(len(dataset) * ratios[0])]
val_index = index[
int(len(dataset) * ratios[0]) : int(len(dataset) * (ratios[0] + ratios[1]))
]
test_index = index[int(len(dataset) * (ratios[0] + ratios[1])) :]
int(len(dataset) * ratios[0]): int(
len(dataset) * (ratios[0] + ratios[1]))
]
test_index = index[int(len(dataset) * (ratios[0] + ratios[1])):]
train_dataset = torch.utils.data.Subset(dataset, train_index)
val_dataset = torch.utils.data.Subset(dataset, val_index)
test_dataset = torch.utils.data.Subset(dataset, test_index)
Expand Down Expand Up @@ -75,9 +76,10 @@ def split_by_patient(
np.random.shuffle(patient_indx)
train_patient_indx = patient_indx[: int(num_patients * ratios[0])]
val_patient_indx = patient_indx[
int(num_patients * ratios[0]) : int(num_patients * (ratios[0] + ratios[1]))
]
test_patient_indx = patient_indx[int(num_patients * (ratios[0] + ratios[1])) :]
int(num_patients * ratios[0]): int(
num_patients * (ratios[0] + ratios[1]))
]
test_patient_indx = patient_indx[int(num_patients * (ratios[0] + ratios[1])):]
train_index = list(
chain(*[dataset.patient_to_index[i] for i in train_patient_indx])
)
Expand All @@ -87,3 +89,40 @@ def split_by_patient(
val_dataset = torch.utils.data.Subset(dataset, val_index)
test_dataset = torch.utils.data.Subset(dataset, test_index)
return train_dataset, val_dataset, test_dataset


def split_by_sample(
    dataset: SampleBaseDataset,
    ratios: Union[Tuple[float, float, float], List[float]],
    seed: Optional[int] = None,
):
    """Splits the dataset by sample.

    The sample indices ``0 .. len(dataset) - 1`` are shuffled and then cut
    into three contiguous chunks according to ``ratios``.

    Args:
        dataset: a `SampleBaseDataset` object
        ratios: a list/tuple of ratios for train / val / test. The ratios
            must sum to 1.0 (up to floating-point rounding error).
        seed: random seed for shuffling the dataset. When given, it is
            applied to NumPy's global random state via `np.random.seed`.

    Returns:
        train_dataset, val_dataset, test_dataset: three subsets of the dataset of
            type `torch.utils.data.Subset`.

    Raises:
        AssertionError: if `ratios` does not sum to 1.0.

    Note:
        The original dataset can be accessed by `train_dataset.dataset`,
            `val_dataset.dataset`, and `test_dataset.dataset`.
    """
    if seed is not None:
        np.random.seed(seed)
    # Compare with a tolerance rather than `==`: valid inputs such as
    # [0.7, 0.2, 0.1] sum to 0.9999999999999999 in binary floating point
    # and would otherwise be rejected.
    assert np.isclose(sum(ratios), 1.0, rtol=0.0, atol=1e-8), "ratios must sum to 1.0"
    num_samples = len(dataset)
    index = np.arange(num_samples)
    np.random.shuffle(index)
    # Chunk boundaries are truncated toward zero; whatever remains after the
    # train and val chunks goes to the test split.
    train_end = int(num_samples * ratios[0])
    val_end = int(num_samples * (ratios[0] + ratios[1]))
    train_index = index[:train_end]
    val_index = index[train_end:val_end]
    test_index = index[val_end:]
    train_dataset = torch.utils.data.Subset(dataset, train_index)
    val_dataset = torch.utils.data.Subset(dataset, val_index)
    test_dataset = torch.utils.data.Subset(dataset, test_index)
    return train_dataset, val_dataset, test_dataset