Skip to content

Commit

Permalink
more comments added in base_dataset.py
Browse files Browse the repository at this point in the history
  • Loading branch information
zzachw committed Nov 15, 2022
1 parent 7b49dc9 commit 0204538
Showing 1 changed file with 48 additions and 11 deletions.
59 changes: 48 additions & 11 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
import os
from abc import ABC, abstractmethod
from collections import Counter
from copy import deepcopy
from typing import Dict, Callable, Tuple, Union
from typing import Dict, Callable, Tuple, Union, List, Optional

from torch.utils.data import Dataset
from tqdm import tqdm

from pyhealth.data import Patient, Event
from pyhealth.datasets.utils import *
from pyhealth.datasets.utils import MODULE_CACHE_PATH
from pyhealth.datasets.utils import hash_str, list_nested_level, is_homo_list
from pyhealth.medcode import CrossMap
from pyhealth.utils import load_pickle, save_pickle

Expand Down Expand Up @@ -309,20 +311,42 @@ def set_task(
):
samples.extend(self.task_fn(patient))

# validate the samples
"""
Validate the samples.
1. Check if all samples are of type dict.
2. Check if all samples have the same keys.
3. Check if "patient_id" and "visit_id" are in the keys.
4. For each key, check if it is either:
- a single value
- a list of values of the same type
- a list of lists of values of the same type
Note that in check 4, we do not restrict the type of the values
to leave more flexibility for the user. But if the user wants to
use some helper functions (e.g., `self.get_all_tokens()` and
`self.stat()`) in the dataset, we will further check the type of
the values.
"""
assert all(isinstance(s, dict) for s in samples), "Each sample should be a dict"
keys = samples[0].keys()
assert all(set(s.keys()) == set(keys) for s in samples), \
"All samples should have the same keys"
assert "patient_id" in keys, "patient_id should be in the keys"
assert "visit_id" in keys, "visit_id should be in the keys"
# each feature has to be either a single value,
# a list of values, or a list of list of values
for key in keys:
# either a single value, a list of values, or a list of list of values
# check if all the samples have the same type of feature for the key
check = is_homo_list([s[key] for s in samples])
assert check, f"Key {key} has mixed types across samples"
type_ = type(samples[0][key])
# 1 or 2 level nested list

# if key's feature is list
if type_ == list:
# All samples should either all be
# (1) a list of values, i.e., 1 level nested list
# (2) or a list of lists of values, i.e., 2 level nested list
levels = set([list_nested_level(s[key]) for s in samples])
assert len(levels) == 1, \
f"Key {key} has mixed nested list levels across samples"
Expand All @@ -331,16 +355,24 @@ def set_task(
f"Key {key} has unsupported nested list level across samples"
# 1 level nested list
if level == 1:
check = [is_homo_list(s[key]) for s in samples]
assert all(check), \
# a list of values of the same type
# sum() flattens the nested list
# e.g., [[1, 2], [3, 4]] -> [1, 2, 3, 4]
check = is_homo_list(sum([s[key] for s in samples], []))
assert check, \
f"Key {key} has mixed types in the nested list within samples"
# 2 level nested list
else:
# rule out cases such as [[1, 2], 3], where the
# nested level is 2 but some elements of the outer list
# are not lists
check = [is_homo_list(s[key]) for s in samples]
assert all(check), \
f"Key {key} has mixed nested list levels within samples"
check = [is_homo_list(l) for s in samples for l in s[key]]
assert all(check), \
# a list of list of values of the same type
# sum() flattens the nested list
check = is_homo_list(sum([l for s in samples for l in s[key]], []))
assert check, \
f"Key {key} has mixed types in the nested list within samples"

# set the samples
Expand Down Expand Up @@ -427,6 +459,7 @@ def get_all_tokens(
continue
# a list of lists of values
elif type(sample[key][0]) == list:
# sum() flattens the nested list
tokens.extend(sum(sample[key], []))
# a list of values
else:
Expand Down Expand Up @@ -529,14 +562,18 @@ def task_stat(self) -> str:
for key in self.samples[0]:
if key in ["patient_id", "visit_id"]:
continue
# list
# key's feature is a list
if type(self.samples[0][key]) == list:
# check if the list also contains lists
nested = [isinstance(e, list) for s in self.samples for e in s[key]]
# key's feature is a list of lists
if any(nested):
# sum() flattens the nested list
num_events = [len(sum(sample[key], [])) for sample in self.samples]
# key's feature is a list of values
else:
num_events = [len(sample[key]) for sample in self.samples]
# single value
# key's feature is a single value
else:
num_events = [1 for sample in self.samples]
lines.append(f"\t- {key}:")
Expand Down

0 comments on commit 0204538

Please sign in to comment.