Skip to content

Commit

Permalink
wrap sum() into fn flatten_list
Browse files Browse the repository at this point in the history
  • Loading branch information
zzachw committed Nov 15, 2022
1 parent 0204538 commit 7978924
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 10 deletions.
20 changes: 10 additions & 10 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@

from pyhealth.data import Patient, Event
from pyhealth.datasets.utils import MODULE_CACHE_PATH
from pyhealth.datasets.utils import hash_str, list_nested_level, is_homo_list
from pyhealth.datasets.utils import hash_str
from pyhealth.datasets.utils import list_nested_level, is_homo_list, flatten_list
from pyhealth.medcode import CrossMap
from pyhealth.utils import load_pickle, save_pickle

Expand Down Expand Up @@ -356,9 +357,7 @@ def set_task(
# 1 level nested list
if level == 1:
# a list of values of the same type
# sum() flattens the nested list
# e.g, [[1, 2], [3, 4]] -> [1, 2, 3, 4]
check = is_homo_list(sum([s[key] for s in samples], []))
check = is_homo_list(flatten_list([s[key] for s in samples]))
assert check, \
f"Key {key} has mixed types in the nested list within samples"
# 2 level nested list
Expand All @@ -370,8 +369,9 @@ def set_task(
assert all(check), \
f"Key {key} has mixed nested list levels within samples"
# a list of list of values of the same type
# sum() flattens the nested list
check = is_homo_list(sum([l for s in samples for l in s[key]], []))
check = is_homo_list(
flatten_list([l for s in samples for l in s[key]])
)
assert check, \
f"Key {key} has mixed types in the nested list within samples"

Expand Down Expand Up @@ -459,8 +459,7 @@ def get_all_tokens(
continue
# a list of lists of values
elif type(sample[key][0]) == list:
# sum() flattens the nested list
tokens.extend(sum(sample[key], []))
tokens.extend(flatten_list(sample[key]))
# a list of values
else:
tokens.extend(sample[key])
Expand Down Expand Up @@ -568,8 +567,9 @@ def task_stat(self) -> str:
nested = [isinstance(e, list) for s in self.samples for e in s[key]]
# key's feature is a list of lists
if any(nested):
# sum() flattens the nested list
num_events = [len(sum(sample[key], [])) for sample in self.samples]
num_events = [
len(flatten_list(sample[key])) for sample in self.samples
]
# key's feature is a list of values
else:
num_events = [len(sample[key]) for sample in self.samples]
Expand Down
19 changes: 19 additions & 0 deletions pyhealth/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ def strptime(s: str) -> Optional[datetime]:
return dateutil_parse(s)


def flatten_list(l: List) -> List:
    """Flattens a list of lists by exactly one level.

    Only the outermost level of nesting is removed; any deeper nesting
    inside the sub-lists is kept as is. The input list is not modified
    and a new list is always returned.

    Args:
        l: List, the list of lists to be flattened.

    Returns:
        List, the flattened list.

    Raises:
        AssertionError: if `l` is not a list.
        TypeError: if an element of `l` is not iterable.

    Examples:
        >>> flatten_list([[1], [2, 3], [4]])
        [1, 2, 3, 4]
        >>> flatten_list([[1], [[2], 3], [4]])
        [1, [2], 3, 4]
        >>> flatten_list([])
        []
    """
    # Raise explicitly instead of using `assert` so that the check is not
    # stripped when Python runs with -O; the exception type and message are
    # kept the same for existing callers.
    if not isinstance(l, list):
        raise AssertionError("l must be a list.")

    # Extending one accumulator is linear in the total number of elements,
    # whereas sum(l, []) builds a brand-new list at every step (quadratic).
    flattened = []
    for sublist in l:
        flattened.extend(sublist)
    return flattened


def list_nested_level(l: List) -> int:
"""Gets the nested level of a list.
Expand Down

0 comments on commit 7978924

Please sign in to comment.