Skip to content

Commit

Permalink
stat bug fixed for empty list
Browse files Browse the repository at this point in the history
  • Loading branch information
zzachw committed Nov 15, 2022
1 parent bc78800 commit 3e1edb1
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 25 deletions.
95 changes: 70 additions & 25 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import logging
import os
from abc import ABC, abstractmethod
from collections import Counter
from copy import deepcopy
from typing import Optional, List, Dict, Callable, Tuple, Union
from typing import Dict, Callable, Tuple, Union

from torch.utils.data import Dataset
from tqdm import tqdm

from pyhealth.data import Patient, Event
from pyhealth.datasets.utils import MODULE_CACHE_PATH, hash_str
from pyhealth.datasets.utils import *
from pyhealth.medcode import CrossMap
from pyhealth.utils import load_pickle, save_pickle

Expand Down Expand Up @@ -309,6 +308,42 @@ def set_task(
self.patients.items(), desc=f"Generating samples for {self.task}"
):
samples.extend(self.task_fn(patient))

# validate the samples
assert all(isinstance(s, dict) for s in samples), "Each sample should be a dict"
keys = samples[0].keys()
assert all(set(s.keys()) == set(keys) for s in samples), \
"All samples should have the same keys"
assert "patient_id" in keys, "patient_id should be in the keys"
assert "visit_id" in keys, "visit_id should be in the keys"
for key in keys:
# either a single value, a list of values, or a list of list of values
check = is_homo_list([s[key] for s in samples])
assert check, f"Key {key} has mixed types across samples"
type_ = type(samples[0][key])
# 1 or 2 level nested list
if type_ == list:
levels = set([list_nested_level(s[key]) for s in samples])
assert len(levels) == 1, \
f"Key {key} has mixed nested list levels across samples"
level = levels.pop()
assert level in [1, 2], \
f"Key {key} has unsupported nested list level across samples"
# 1 level nested list
if level == 1:
check = [is_homo_list(s[key]) for s in samples]
assert all(check), \
f"Key {key} has mixed types in the nested list within samples"
# 2 level nested list
else:
check = [is_homo_list(s[key]) for s in samples]
assert all(check), \
f"Key {key} has mixed nested list levels within samples"
check = [is_homo_list(l) for s in samples for l in s[key]]
assert all(check), \
f"Key {key} has mixed types in the nested list within samples"

# set the samples
self.samples = samples
self.patient_to_index = self._index_patient()
self.visit_to_index = self._index_visit()
Expand All @@ -331,7 +366,7 @@ def _index_patient(self) -> Dict[str, List[int]]:
def _index_visit(self) -> Dict[str, List[int]]:
"""Helper function which indexes the samples by visit_id.
Will be called in set_task().
Will be called in `self.set_task()`.
Returns:
visit_to_index: a dict mapping visit_id to a list of sample indices.
Expand Down Expand Up @@ -368,7 +403,8 @@ def available_keys(self) -> List[str]:
return list(keys)

def get_all_tokens(
self, key: str,
self,
key: str,
remove_duplicates: bool = True,
sort: bool = True
) -> List[str]:
Expand All @@ -386,17 +422,22 @@ def get_all_tokens(
raise ValueError("Please set task first.")
tokens = []
for sample in self.samples:
if type(sample[key]) == list:
if len(sample[key]) == 0:
continue
# a list of lists of values
elif type(sample[key][0]) == list:
tokens.extend(sum(sample[key], []))
# a list of values
else:
tokens.extend(sample[key])
# single value
if type(sample[key]) in [bool, int, str]:
tokens.append(sample[key])
# a list of values
elif type(sample[key][0]) in [bool, int, str]:
tokens.extend(sample[key])
# a list of lists of values
elif type(sample[key][0]) == list:
tokens.extend(sum(sample[key], []))
else:
raise ValueError(f"Unknown type of {key}: {type(sample[key])}")
tokens.append(sample[key])
types = set([type(t) for t in tokens])
assert len(types) == 1, f"{key} tokens have mixed types"
assert types.pop() in [int, float, str, bool], \
f"{key} tokens have unsupported types"
if remove_duplicates:
tokens = list(set(tokens))
if sort:
Expand Down Expand Up @@ -467,6 +508,7 @@ def base_stat(self) -> str:
f"\t- Number of events per visit in {table}: "
f"{sum(num_events) / len(num_events):.4f}"
)
print("\n".join(lines))
return "\n".join(lines)

def task_stat(self) -> str:
Expand All @@ -482,29 +524,32 @@ def task_stat(self) -> str:
num_visits = len(set([sample["visit_id"] for sample in self.samples]))
lines.append(f"\t- Number of visits: {num_visits}")
lines.append(
f"\t- Number of visits per patient: {len(self) / num_patients:.4f}")
f"\t- Number of visits per patient: {len(self) / num_patients:.4f}"
)
for key in self.samples[0]:
if key in ["patient_id", "visit_id"]:
continue
if type(self.samples[0][key]) in [bool, int, str]:
num_events = [1 for sample in self.samples]
# a list of values
elif type(self.samples[0][key][0]) in [bool, int, str]:
num_events = [len(sample[key]) for sample in self.samples]
# a list of lists of values
elif type(self.samples[0][key][0]) == list:
num_events = [len(sum(sample[key], [])) for sample in self.samples]
# list
if type(self.samples[0][key]) == list:
nested = [isinstance(e, list) for s in self.samples for e in s[key]]
if any(nested):
num_events = [len(sum(sample[key], [])) for sample in self.samples]
else:
num_events = [len(sample[key]) for sample in self.samples]
# single value
else:
raise ValueError(f"Unknown type of {key}: {type(self.samples[0][key])}")
num_events = [1 for sample in self.samples]
lines.append(f"\t- {key}:")
lines.append(f"\t\t- Number of {key} per sample: "
f"{sum(num_events) / len(num_events):.4f}")
lines.append(
f"\t\t- Number of unique {key}: {len(self.get_all_tokens(key))}")
f"\t\t- Number of unique {key}: {len(self.get_all_tokens(key))}"
)
distribution = self.get_distribution_tokens(key)
top10 = sorted(distribution.items(), key=lambda x: x[1], reverse=True)[:10]
lines.append(
f"\t\t- Distribution of {key} (Top-10): {top10}")
print("\n".join(lines))
return "\n".join(lines)

@staticmethod
Expand Down
40 changes: 40 additions & 0 deletions pyhealth/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import hashlib
import os
from datetime import datetime
from typing import List
from typing import Optional

from dateutil.parser import parse as dateutil_parse
Expand Down Expand Up @@ -32,6 +33,36 @@ def strptime(s: str) -> Optional[datetime]:
return dateutil_parse(s)


def list_nested_level(l: List) -> int:
    """Computes how deeply a list is nested.

    Args:
        l: List, the (possibly nested) list to inspect.

    Returns:
        int, 0 for a non-list value, 1 for an empty or flat list, and
            otherwise one more than the deepest level found among the
            elements.
    """
    if isinstance(l, list):
        # `default=0` makes an empty list count as exactly one level.
        deepest_child = max(map(list_nested_level, l), default=0)
        return deepest_child + 1
    return 0


def is_homo_list(l: List) -> bool:
    """Checks if a list is homogeneous, i.e. all elements share one exact type.

    Types are compared exactly (``type(x) is type(y)``) rather than with
    ``isinstance``, so the result does not depend on element order when one
    element type subclasses another (e.g. ``bool`` is a subclass of ``int``,
    so an ``isinstance`` check would accept ``[1, True]`` but reject
    ``[True, 1]``).

    Args:
        l: the list to be checked.

    Returns:
        bool, True if the list is empty or every element has exactly the same
            type as the first element, False otherwise.
    """
    if not l:
        # An empty list has no conflicting element types.
        return True
    first_type = type(l[0])
    return all(type(i) is first_type for i in l)


def collate_fn_dict(batch):
    """Collates a batch of sample dicts into a single dict of lists.

    Args:
        batch: a non-empty sequence of dicts. The keys of the first dict
            determine the keys of the output, and every other dict is
            indexed with those same keys.

    Returns:
        A dict mapping each key of the first sample to the list of values
            stored under that key across the batch, in batch order.
    """
    collated = {}
    for field in batch[0]:
        collated[field] = [sample[field] for sample in batch]
    return collated

Expand All @@ -41,3 +72,12 @@ def get_dataloader(dataset, batch_size, shuffle=False):
dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn_dict
)
return dataloader


if __name__ == "__main__":
    # Ad-hoc smoke checks: print the nesting level of a few sample lists,
    # then the homogeneity verdict for a few more.
    for nested in ([1, 2, 3], [1, [2], 3], [1, [2], [[3]]]):
        print(list_nested_level(nested))
    for candidate in ([1, 2, 3], [1, 2, [3]], [1, 2.0]):
        print(is_homo_list(candidate))

0 comments on commit 3e1edb1

Please sign in to comment.