Skip to content

Commit

Permalink
Fix logging problem (#39)
Browse files: browse the repository at this point in the history
* fix duplicated logging problem

* fix rnn dropout warning

* reformat dataset statistics print
  • Loading branch information
zzachw committed Dec 7, 2022
1 parent 7c51d21 commit 6aad348
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 78 deletions.
5 changes: 4 additions & 1 deletion pyhealth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
os.makedirs(BASE_CACHE_PATH)

# logging
logger = logging.getLogger()
logger = logging.getLogger(__name__)
logger.propagate = False
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
74 changes: 41 additions & 33 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from abc import ABC, abstractmethod
from abc import ABC
from collections import Counter
from copy import deepcopy
from typing import Dict, Callable, Tuple, Union, List, Optional
Expand All @@ -15,6 +15,8 @@
from pyhealth.medcode import CrossMap
from pyhealth.utils import load_pickle, save_pickle

logger = logging.getLogger(__name__)

INFO_MSG = """
dataset.patients: patient_id -> <Patient>
Expand Down Expand Up @@ -117,18 +119,19 @@ def __init__(
# check if cache exists or refresh_cache is True
if os.path.exists(self.filepath) and (not refresh_cache):
# load from cache
logging.debug(f"Loaded {self.dataset_name} base dataset from {self.filepath}")
logger.debug(
f"Loaded {self.dataset_name} base dataset from {self.filepath}")
self.patients = load_pickle(self.filepath)
else:
# load from raw data
logging.debug(f"Processing {self.dataset_name} base dataset...")
logger.debug(f"Processing {self.dataset_name} base dataset...")
# parse tables
patients = self.parse_tables()
# convert codes
patients = self._convert_code_in_patient_dict(patients)
self.patients = patients
# save to cache
logging.debug(f"Saved {self.dataset_name} base dataset to {self.filepath}")
logger.debug(f"Saved {self.dataset_name} base dataset to {self.filepath}")
save_pickle(self.patients, self.filepath)

def _load_code_mapping_tools(self) -> Dict[str, CrossMap]:
Expand Down Expand Up @@ -535,6 +538,7 @@ def stat(self) -> str:
def base_stat(self) -> str:
"""Returns some statistics of the base dataset."""
lines = list()
lines.append("")
lines.append(f"Statistics of {self.dataset_name} dataset (dev={self.dev}):")
lines.append(f"\t- Number of patients: {len(self.patients)}")
num_visits = [len(p) for p in self.patients.values()]
Expand All @@ -550,6 +554,7 @@ def base_stat(self) -> str:
f"\t- Number of events per visit in {table}: "
f"{sum(num_events) / len(num_events):.4f}"
)
lines.append("")
print("\n".join(lines))
return "\n".join(lines)

Expand All @@ -558,6 +563,7 @@ def task_stat(self) -> str:
if self.task is None:
raise ValueError("Please set task first.")
lines = list()
lines.append("")
lines.append(f"Statistics of {self.task} task:")
lines.append(f"\t- Dataset: {self.dataset_name} (dev={self.dev})")
lines.append(f"\t- Number of samples: {len(self)}")
Expand Down Expand Up @@ -596,6 +602,7 @@ def task_stat(self) -> str:
top10 = sorted(distribution.items(), key=lambda x: x[1], reverse=True)[:10]
lines.append(
f"\t\t- Distribution of {key} (Top-10): {top10}")
lines.append("")
print("\n".join(lines))
return "\n".join(lines)

Expand All @@ -604,6 +611,7 @@ def info():
"""Prints the output format."""
print(INFO_MSG)


class SampleDataset(ABC, Dataset):
"""Abstract sample dataset class.
Expand Down Expand Up @@ -668,7 +676,7 @@ def _index_visit(self) -> Dict[str, List[int]]:
for idx, sample in enumerate(self.samples):
visit_to_index.setdefault(sample["visit_id"], []).append(idx)
return visit_to_index

def get_all_tokens(
self,
key: str,
Expand Down Expand Up @@ -782,7 +790,7 @@ def get_distribution_tokens(self, key: str) -> Dict[str, int]:
Returns:
distribution: a dict mapping token to count.
"""

tokens = self.get_all_tokens(key, remove_duplicates=False, sort=False)
counter = Counter(tokens)
return counter
Expand Down Expand Up @@ -856,36 +864,36 @@ def stat(self) -> None:
]
samples2 = [
{'patient_id': 'patient-0',
'visit_id': 'visit-0',
'conditions': ['cond-33',
'cond-86',
'cond-80'],
'procedures': ['prod-11',
'prod-8',
'prod-15',
'prod-66',
'prod-91',
'prod-94'],
'label': 1},
'visit_id': 'visit-0',
'conditions': ['cond-33',
'cond-86',
'cond-80'],
'procedures': ['prod-11',
'prod-8',
'prod-15',
'prod-66',
'prod-91',
'prod-94'],
'label': 1},
{'patient_id': 'patient-0',
'visit_id': 'visit-0',
'conditions': ['cond-33',
'cond-86',
'cond-80'],
'procedures': ['prod-11',
'prod-8',
'prod-15',
'prod-66',
'prod-91',
'prod-94'],
'label': 1}
'visit_id': 'visit-0',
'conditions': ['cond-33',
'cond-86',
'cond-80'],
'procedures': ['prod-11',
'prod-8',
'prod-15',
'prod-66',
'prod-91',
'prod-94'],
'label': 1}
]

dataset = SampleDataset(
samples=samples2,
dataset_name="test")
print (dataset.stat())

print(dataset.stat())
data = iter(dataset)
print (next(data))
print (next(data))
print(next(data))
print(next(data))
8 changes: 5 additions & 3 deletions pyhealth/medcode/cross_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from pyhealth.medcode.utils import MODULE_CACHE_PATH, download_and_read_csv
from pyhealth.utils import load_pickle, save_pickle

logger = logging.getLogger(__name__)


class CrossMap:
"""Contains mapping between two medical code systems.
Expand All @@ -30,11 +32,11 @@ def __init__(
pickle_filename = f"{self.s_vocab}_to_{self.t_vocab}.pkl"
pickle_filepath = os.path.join(MODULE_CACHE_PATH, pickle_filename)
if os.path.exists(pickle_filepath) and (not refresh_cache):
logging.debug(f"Loaded {self.s_vocab}->{self.t_vocab} mapping "
logger.debug(f"Loaded {self.s_vocab}->{self.t_vocab} mapping "
f"from {pickle_filepath}")
self.mapping = load_pickle(pickle_filepath)
else:
logging.debug(f"Processing {self.s_vocab}->{self.t_vocab} mapping...")
logger.debug(f"Processing {self.s_vocab}->{self.t_vocab} mapping...")
try:
local_filename = f"{self.s_vocab}_to_{self.t_vocab}.csv"
df = download_and_read_csv(local_filename, refresh_cache)
Expand All @@ -44,7 +46,7 @@ def __init__(
self.mapping = defaultdict(list)
for _, row in df.iterrows():
self.mapping[row[self.s_vocab]].append(row[self.t_vocab])
logging.debug(f"Saved {self.s_vocab}->{self.t_vocab} mapping "
logger.debug(f"Saved {self.s_vocab}->{self.t_vocab} mapping "
f"to {pickle_filepath}")
save_pickle(self.mapping, pickle_filepath)

Expand Down
22 changes: 12 additions & 10 deletions pyhealth/medcode/inner_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pyhealth.medcode.utils import MODULE_CACHE_PATH, download_and_read_csv
from pyhealth.utils import load_pickle, save_pickle

logger = logging.getLogger(__name__)


# TODO: add this callable method: InnerMap(vocab)
class InnerMap(ABC):
Expand All @@ -22,7 +24,7 @@ class InnerMap(ABC):
Note:
This class cannot be instantiated using `__init__()` (throws an error).
"""

@abstractmethod
def __init__(
self,
Expand All @@ -36,10 +38,10 @@ def __init__(
pickle_filepath = os.path.join(MODULE_CACHE_PATH, self.vocabulary + ".pkl")
csv_filename = self.vocabulary + ".csv"
if os.path.exists(pickle_filepath) and (not refresh_cache):
logging.debug(f"Loaded {vocabulary} code from {pickle_filepath}")
logger.debug(f"Loaded {vocabulary} code from {pickle_filepath}")
self.graph = load_pickle(pickle_filepath)
else:
logging.debug(f"Processing {vocabulary} code...")
logger.debug(f"Processing {vocabulary} code...")
df = download_and_read_csv(csv_filename, refresh_cache)
# create graph
df = df.set_index("code")
Expand All @@ -54,7 +56,7 @@ def __init__(
if "parent_code" in row:
if not pd.isna(row["parent_code"]):
self.graph.add_edge(row["parent_code"], code)
logging.debug(f"Saved {vocabulary} code to {pickle_filepath}")
logger.debug(f"Saved {vocabulary} code to {pickle_filepath}")
save_pickle(self.graph, pickle_filepath)
return

Expand Down Expand Up @@ -170,11 +172,11 @@ def get_descendants(self, code: str) -> List[str]:
)
return descendants


if __name__ == "__main__":
icd9cm = InnerMap.load("ICD9CM")
print (icd9cm.stat())
print ("428.0" in icd9cm)
print (icd9cm.lookup("4280"))
print (icd9cm.get_ancestors("428.0"))
print (icd9cm.get_descendants("428.0"))

print(icd9cm.stat())
print("428.0" in icd9cm)
print(icd9cm.lookup("4280"))
print(icd9cm.get_ancestors("428.0"))
print(icd9cm.get_descendants("428.0"))
4 changes: 2 additions & 2 deletions pyhealth/models/gamenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,14 @@ def __init__(
embedding_dim,
hidden_dim,
num_layers=num_layers,
dropout=dropout,
dropout=dropout if num_layers > 1 else 0,
batch_first=True,
)
self.proc_rnn = nn.GRU(
embedding_dim,
hidden_dim,
num_layers=num_layers,
dropout=dropout,
dropout=dropout if num_layers > 1 else 0,
batch_first=True,
)
self.query = nn.Sequential(
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/models/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
input_size,
hidden_size,
num_layers=num_layers,
dropout=dropout,
dropout=dropout if num_layers > 1 else 0,
bidirectional=bidirectional,
batch_first=True,
)
Expand Down
4 changes: 2 additions & 2 deletions pyhealth/models/safedrug.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,14 @@ def __init__(
embedding_dim,
hidden_dim,
num_layers=num_layers,
dropout=dropout,
dropout=dropout if num_layers > 1 else 0,
batch_first=True,
)
self.proc_rnn = nn.GRU(
embedding_dim,
hidden_dim,
num_layers=num_layers,
dropout=dropout,
dropout=dropout if num_layers > 1 else 0,
batch_first=True,
)
self.query = nn.Sequential(
Expand Down
Loading

0 comments on commit 6aad348

Please sign in to comment.