Skip to content

Commit

Permalink
Merge pull request #41 from sunlabuiuc/develop
Browse files Browse the repository at this point in the history
(Redo) Improved doc & logging
  • Loading branch information
zzachw committed Dec 8, 2022
2 parents c5a15aa + 6aad348 commit b9a5d94
Show file tree
Hide file tree
Showing 11 changed files with 237 additions and 78 deletions.
5 changes: 4 additions & 1 deletion pyhealth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
os.makedirs(BASE_CACHE_PATH)

# logging
logger = logging.getLogger()
logger = logging.getLogger(__name__)
logger.propagate = False
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
74 changes: 41 additions & 33 deletions pyhealth/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import os
from abc import ABC, abstractmethod
from abc import ABC
from collections import Counter
from copy import deepcopy
from typing import Dict, Callable, Tuple, Union, List, Optional
Expand All @@ -15,6 +15,8 @@
from pyhealth.medcode import CrossMap
from pyhealth.utils import load_pickle, save_pickle

logger = logging.getLogger(__name__)

INFO_MSG = """
dataset.patients: patient_id -> <Patient>
Expand Down Expand Up @@ -117,18 +119,19 @@ def __init__(
# check if cache exists or refresh_cache is True
if os.path.exists(self.filepath) and (not refresh_cache):
# load from cache
logging.debug(f"Loaded {self.dataset_name} base dataset from {self.filepath}")
logger.debug(
f"Loaded {self.dataset_name} base dataset from {self.filepath}")
self.patients = load_pickle(self.filepath)
else:
# load from raw data
logging.debug(f"Processing {self.dataset_name} base dataset...")
logger.debug(f"Processing {self.dataset_name} base dataset...")
# parse tables
patients = self.parse_tables()
# convert codes
patients = self._convert_code_in_patient_dict(patients)
self.patients = patients
# save to cache
logging.debug(f"Saved {self.dataset_name} base dataset to {self.filepath}")
logger.debug(f"Saved {self.dataset_name} base dataset to {self.filepath}")
save_pickle(self.patients, self.filepath)

def _load_code_mapping_tools(self) -> Dict[str, CrossMap]:
Expand Down Expand Up @@ -535,6 +538,7 @@ def stat(self) -> str:
def base_stat(self) -> str:
"""Returns some statistics of the base dataset."""
lines = list()
lines.append("")
lines.append(f"Statistics of {self.dataset_name} dataset (dev={self.dev}):")
lines.append(f"\t- Number of patients: {len(self.patients)}")
num_visits = [len(p) for p in self.patients.values()]
Expand All @@ -550,6 +554,7 @@ def base_stat(self) -> str:
f"\t- Number of events per visit in {table}: "
f"{sum(num_events) / len(num_events):.4f}"
)
lines.append("")
print("\n".join(lines))
return "\n".join(lines)

Expand All @@ -558,6 +563,7 @@ def task_stat(self) -> str:
if self.task is None:
raise ValueError("Please set task first.")
lines = list()
lines.append("")
lines.append(f"Statistics of {self.task} task:")
lines.append(f"\t- Dataset: {self.dataset_name} (dev={self.dev})")
lines.append(f"\t- Number of samples: {len(self)}")
Expand Down Expand Up @@ -596,6 +602,7 @@ def task_stat(self) -> str:
top10 = sorted(distribution.items(), key=lambda x: x[1], reverse=True)[:10]
lines.append(
f"\t\t- Distribution of {key} (Top-10): {top10}")
lines.append("")
print("\n".join(lines))
return "\n".join(lines)

Expand All @@ -604,6 +611,7 @@ def info():
"""Prints the output format."""
print(INFO_MSG)


class SampleDataset(ABC, Dataset):
"""Abstract sample dataset class.
Expand Down Expand Up @@ -668,7 +676,7 @@ def _index_visit(self) -> Dict[str, List[int]]:
for idx, sample in enumerate(self.samples):
visit_to_index.setdefault(sample["visit_id"], []).append(idx)
return visit_to_index

def get_all_tokens(
self,
key: str,
Expand Down Expand Up @@ -782,7 +790,7 @@ def get_distribution_tokens(self, key: str) -> Dict[str, int]:
Returns:
distribution: a dict mapping token to count.
"""

tokens = self.get_all_tokens(key, remove_duplicates=False, sort=False)
counter = Counter(tokens)
return counter
Expand Down Expand Up @@ -856,36 +864,36 @@ def stat(self) -> None:
]
samples2 = [
{'patient_id': 'patient-0',
'visit_id': 'visit-0',
'conditions': ['cond-33',
'cond-86',
'cond-80'],
'procedures': ['prod-11',
'prod-8',
'prod-15',
'prod-66',
'prod-91',
'prod-94'],
'label': 1},
'visit_id': 'visit-0',
'conditions': ['cond-33',
'cond-86',
'cond-80'],
'procedures': ['prod-11',
'prod-8',
'prod-15',
'prod-66',
'prod-91',
'prod-94'],
'label': 1},
{'patient_id': 'patient-0',
'visit_id': 'visit-0',
'conditions': ['cond-33',
'cond-86',
'cond-80'],
'procedures': ['prod-11',
'prod-8',
'prod-15',
'prod-66',
'prod-91',
'prod-94'],
'label': 1}
'visit_id': 'visit-0',
'conditions': ['cond-33',
'cond-86',
'cond-80'],
'procedures': ['prod-11',
'prod-8',
'prod-15',
'prod-66',
'prod-91',
'prod-94'],
'label': 1}
]

dataset = SampleDataset(
samples=samples2,
dataset_name="test")
print (dataset.stat())

print(dataset.stat())
data = iter(dataset)
print (next(data))
print (next(data))
print(next(data))
print(next(data))
8 changes: 5 additions & 3 deletions pyhealth/medcode/cross_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from pyhealth.medcode.utils import MODULE_CACHE_PATH, download_and_read_csv
from pyhealth.utils import load_pickle, save_pickle

logger = logging.getLogger(__name__)


class CrossMap:
"""Contains mapping between two medical code systems.
Expand All @@ -30,11 +32,11 @@ def __init__(
pickle_filename = f"{self.s_vocab}_to_{self.t_vocab}.pkl"
pickle_filepath = os.path.join(MODULE_CACHE_PATH, pickle_filename)
if os.path.exists(pickle_filepath) and (not refresh_cache):
logging.debug(f"Loaded {self.s_vocab}->{self.t_vocab} mapping "
logger.debug(f"Loaded {self.s_vocab}->{self.t_vocab} mapping "
f"from {pickle_filepath}")
self.mapping = load_pickle(pickle_filepath)
else:
logging.debug(f"Processing {self.s_vocab}->{self.t_vocab} mapping...")
logger.debug(f"Processing {self.s_vocab}->{self.t_vocab} mapping...")
try:
local_filename = f"{self.s_vocab}_to_{self.t_vocab}.csv"
df = download_and_read_csv(local_filename, refresh_cache)
Expand All @@ -44,7 +46,7 @@ def __init__(
self.mapping = defaultdict(list)
for _, row in df.iterrows():
self.mapping[row[self.s_vocab]].append(row[self.t_vocab])
logging.debug(f"Saved {self.s_vocab}->{self.t_vocab} mapping "
logger.debug(f"Saved {self.s_vocab}->{self.t_vocab} mapping "
f"to {pickle_filepath}")
save_pickle(self.mapping, pickle_filepath)

Expand Down
22 changes: 12 additions & 10 deletions pyhealth/medcode/inner_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from pyhealth.medcode.utils import MODULE_CACHE_PATH, download_and_read_csv
from pyhealth.utils import load_pickle, save_pickle

logger = logging.getLogger(__name__)


# TODO: add this callable method: InnerMap(vocab)
class InnerMap(ABC):
Expand All @@ -22,7 +24,7 @@ class InnerMap(ABC):
Note:
This class cannot be instantiated using `__init__()` (throws an error).
"""

@abstractmethod
def __init__(
self,
Expand All @@ -36,10 +38,10 @@ def __init__(
pickle_filepath = os.path.join(MODULE_CACHE_PATH, self.vocabulary + ".pkl")
csv_filename = self.vocabulary + ".csv"
if os.path.exists(pickle_filepath) and (not refresh_cache):
logging.debug(f"Loaded {vocabulary} code from {pickle_filepath}")
logger.debug(f"Loaded {vocabulary} code from {pickle_filepath}")
self.graph = load_pickle(pickle_filepath)
else:
logging.debug(f"Processing {vocabulary} code...")
logger.debug(f"Processing {vocabulary} code...")
df = download_and_read_csv(csv_filename, refresh_cache)
# create graph
df = df.set_index("code")
Expand All @@ -54,7 +56,7 @@ def __init__(
if "parent_code" in row:
if not pd.isna(row["parent_code"]):
self.graph.add_edge(row["parent_code"], code)
logging.debug(f"Saved {vocabulary} code to {pickle_filepath}")
logger.debug(f"Saved {vocabulary} code to {pickle_filepath}")
save_pickle(self.graph, pickle_filepath)
return

Expand Down Expand Up @@ -170,11 +172,11 @@ def get_descendants(self, code: str) -> List[str]:
)
return descendants


if __name__ == "__main__":
icd9cm = InnerMap.load("ICD9CM")
print (icd9cm.stat())
print ("428.0" in icd9cm)
print (icd9cm.lookup("4280"))
print (icd9cm.get_ancestors("428.0"))
print (icd9cm.get_descendants("428.0"))

print(icd9cm.stat())
print("428.0" in icd9cm)
print(icd9cm.lookup("4280"))
print(icd9cm.get_ancestors("428.0"))
print(icd9cm.get_descendants("428.0"))
38 changes: 38 additions & 0 deletions pyhealth/metrics/binary.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional, Dict

import numpy as np
import sklearn.metrics as sklearn_metrics

Expand All @@ -9,6 +10,43 @@ def binary_metrics_fn(
metrics: Optional[List[str]] = None,
threshold: float = 0.5,
) -> Dict[str, float]:
"""Computes metrics for binary classification.
User can specify which metrics to compute by passing a list of metric names.
The accepted metric names are:
- pr_auc: area under the precision-recall curve
- roc_auc: area under the receiver operating characteristic curve
- accuracy: accuracy score
- balanced_accuracy: balanced accuracy score (usually used for imbalanced
datasets)
- f1: f1 score
- precision: precision score
- recall: recall score
- cohen_kappa: Cohen's kappa score
- jaccard: Jaccard similarity coefficient score
If no metrics are specified, pr_auc, roc_auc and f1 are computed by default.
This function calls sklearn.metrics functions to compute the metrics. For
more information on the metrics, please refer to the documentation of the
corresponding sklearn.metrics functions.
Args:
y_true: True target values of shape (n_samples,).
y_prob: Predicted probabilities of shape (n_samples,).
metrics: List of metrics to compute. Default is ["pr_auc", "roc_auc", "f1"].
threshold: Threshold for binary classification. Default is 0.5.
Returns:
Dictionary of metrics whose keys are the metric names and values are
the metric values.
Examples:
>>> from pyhealth.metrics import binary_metrics_fn
>>> y_true = np.array([0, 0, 1, 1])
>>> y_prob = np.array([0.1, 0.4, 0.35, 0.8])
>>> binary_metrics_fn(y_true, y_prob, metrics=["accuracy"])
{'accuracy': 0.75}
"""
if metrics is None:
metrics = ["pr_auc", "roc_auc", "f1"]

Expand Down
49 changes: 49 additions & 0 deletions pyhealth/metrics/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,55 @@ def multiclass_metrics_fn(
y_prob: np.ndarray,
metrics: Optional[List[str]] = None,
) -> Dict[str, float]:
"""Computes metrics for multiclass classification.
User can specify which metrics to compute by passing a list of metric names.
The accepted metric names are:
- roc_auc_macro_ovo: area under the receiver operating characteristic curve,
macro averaged over one-vs-one multiclass classification
- roc_auc_macro_ovr: area under the receiver operating characteristic curve,
macro averaged over one-vs-rest multiclass classification
- roc_auc_weighted_ovo: area under the receiver operating characteristic curve,
weighted averaged over one-vs-one multiclass classification
- roc_auc_weighted_ovr: area under the receiver operating characteristic curve,
weighted averaged over one-vs-rest multiclass classification
- accuracy: accuracy score
- balanced_accuracy: balanced accuracy score (usually used for imbalanced
datasets)
- f1_micro: f1 score, micro averaged
- f1_macro: f1 score, macro averaged
- f1_weighted: f1 score, weighted averaged
- jaccard_micro: Jaccard similarity coefficient score, micro averaged
- jaccard_macro: Jaccard similarity coefficient score, macro averaged
- jaccard_weighted: Jaccard similarity coefficient score, weighted averaged
- cohen_kappa: Cohen's kappa score
If no metrics are specified, accuracy, f1_macro, and f1_micro are computed
by default.
This function calls sklearn.metrics functions to compute the metrics. For
more information on the metrics, please refer to the documentation of the
corresponding sklearn.metrics functions.
Args:
y_true: True target values of shape (n_samples,).
y_prob: Predicted probabilities of shape (n_samples, n_classes).
metrics: List of metrics to compute. Default is ["accuracy", "f1_macro",
"f1_micro"].
Returns:
Dictionary of metrics whose keys are the metric names and values are
the metric values.
Examples:
>>> from pyhealth.metrics import multiclass_metrics_fn
>>> y_true = np.array([0, 1, 2, 2])
>>> y_prob = np.array([[0.9, 0.05, 0.05],
... [0.05, 0.9, 0.05],
... [0.05, 0.05, 0.9],
... [0.6, 0.2, 0.2]])
>>> multiclass_metrics_fn(y_true, y_prob, metrics=["accuracy"])
{'accuracy': 0.75}
"""
if metrics is None:
metrics = ["accuracy", "f1_macro", "f1_micro"]

Expand Down
Loading

0 comments on commit b9a5d94

Please sign in to comment.