Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core structure updates: image and text featurizer #172

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Prev Previous commit
Next Next commit
NaN bug fixed for MedicalTranscriptionsDataset
  • Loading branch information
zzachw committed Jun 17, 2023
commit f4387540cee2a2b57b1a4eee875ed1edc95d0566
49 changes: 49 additions & 0 deletions examples/medical_transcriptions_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
import torch

from pyhealth.datasets import MedicalTranscriptionsDataset
from pyhealth.datasets import get_dataloader
from pyhealth.models import HuggingfaceAutoModel
from pyhealth.trainer import Trainer

# Example: multiclass classification on the MedicalTranscriptions dataset
# with a pretrained Bio_ClinicalBERT encoder.

# Hard-coded location of the raw data; adjust to the local copy.
root = "/srv/local/data/zw12/raw_data/MedicalTranscriptions"
base_dataset = MedicalTranscriptionsDataset(root)

# Build the sample dataset; set_task() is called with no arguments here.
sample_dataset = base_dataset.set_task()

# Random 70% / 10% / 20% split over sample indices.
# NOTE: no random seed is set, so the split differs between runs.
ratios = [0.7, 0.1, 0.2]
num_samples = len(sample_dataset)
index = np.arange(num_samples)
np.random.shuffle(index)
s1 = int(num_samples * ratios[0])
s2 = int(num_samples * (ratios[0] + ratios[1]))
# Cut the shuffled indices at s1 and s2: [:s1], [s1:s2], [s2:].
train_index, val_index, test_index = np.split(index, [s1, s2])

train_dataset, val_dataset, test_dataset = (
    torch.utils.data.Subset(sample_dataset, part)
    for part in (train_index, val_index, test_index)
)

# Only the training loader shuffles its samples.
train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

model = HuggingfaceAutoModel(
    model_name="emilyalsentzer/Bio_ClinicalBERT",
    dataset=sample_dataset,
    feature_keys=["transcription"],
    label_key="label",
    mode="multiclass",
)

trainer = Trainer(model=model)

# Evaluate on the test split once before any training ...
print(trainer.evaluate(test_dataloader))

trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epochs=1,
    monitor="accuracy",
)

# ... and again after one epoch of training, for comparison.
print(trainer.evaluate(test_dataloader))
1 change: 1 addition & 0 deletions pyhealth/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
x = self.tokenizer(
x, return_tensors="pt", padding=True, truncation=True, max_length=256
)
x = x.to(self.device)
embeddings = self.model(**x).pooler_output
logits = self.fc(embeddings)
y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
Expand Down
5 changes: 4 additions & 1 deletion pyhealth/tasks/medical_transcriptions_classification.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field
from typing import Dict
import pandas as pd

from pyhealth.tasks import TaskTemplate

Expand All @@ -11,7 +12,9 @@ class MedicalTranscriptionsClassification(TaskTemplate):
output_schema: Dict[str, str] = field(default_factory=lambda: {"label": "label"})

def __call__(self, patient):
if patient["transcription"] is None or patient["medical_specialty"] is None:
if patient["transcription"] is None or pd.isna(patient["transcription"]):
return []
if patient["medical_specialty"] is None or pd.isna(patient["medical_specialty"]):
return []
sample = {
"transcription": patient["transcription"],
Expand Down