Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Core structure updates: image and text featurizer #172

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Previous commit
Next commit
tmp commit
  • Loading branch information
zzachw committed Jul 12, 2023
commit 5b1cb90398c8d75aeac75f5f9fd7019280858446
4 changes: 2 additions & 2 deletions examples/medical_transcriptions_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pyhealth.datasets import MedicalTranscriptionsDataset
from pyhealth.datasets import get_dataloader
from pyhealth.models import HuggingfaceAutoModel
from pyhealth.models import TransformersModel
from pyhealth.trainer import Trainer

root = "/srv/local/data/zw12/raw_data/MedicalTranscriptions"
Expand All @@ -27,7 +27,7 @@
val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

model = HuggingfaceAutoModel(
model = TransformersModel(
model_name="emilyalsentzer/Bio_ClinicalBERT",
dataset=sample_dataset,
feature_keys=["transcription"],
Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,26 @@
from typing import List, Dict

import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
from torchvision import models
from transformers import AutoModel, AutoTokenizer

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel


class HuggingfaceAutoModel(BaseModel):
"""AutoModel class for Huggingface models.
class TransformersModel(BaseModel):
"""Transformers class for Huggingface models.
"""

def __init__(
self,
model_name: str,
dataset: SampleDataset,
feature_keys: List[str],
label_key: str,
mode: str,
pretrained=False,
num_layers=18,
**kwargs,
model_name: str,
):
super(HuggingfaceAutoModel, self).__init__(
super(TransformersModel, self).__init__(
dataset=dataset,
feature_keys=feature_keys,
label_key=label_key,
Expand Down Expand Up @@ -60,7 +56,6 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:

if __name__ == "__main__":
from pyhealth.datasets import MedicalTranscriptionsDataset, get_dataloader
from torchvision import transforms

base_dataset = MedicalTranscriptionsDataset(
root="/srv/local/data/zw12/raw_data/MedicalTranscriptions"
Expand All @@ -70,14 +65,12 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:

train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True)

model = HuggingfaceAutoModel(
model_name="emilyalsentzer/Bio_ClinicalBERT",
model = TransformersModel(
dataset=sample_dataset,
feature_keys=[
"transcription",
],
feature_keys=["transcription"],
label_key="label",
mode="multiclass",
model_name="emilyalsentzer/Bio_ClinicalBERT",
)

# data batch
Expand All @@ -88,4 +81,4 @@ def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
print(ret)

# try loss backward
ret["loss"].backward()
ret["loss"].backward()