Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core structure updates: image and text featurizer #172

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Prev Previous commit
Next Next commit
add connection to huggingface
  • Loading branch information
zzachw committed Jun 14, 2023
commit c952c591c57c0770de05923c89075bcf611e8a6d
90 changes: 90 additions & 0 deletions pyhealth/models/huggingface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import List, Dict

import torch
from transformers import AutoModel, AutoTokenizer
import torch.nn as nn
from torchvision import models

from pyhealth.datasets import SampleDataset
from pyhealth.models import BaseModel


class HuggingfaceAutoModel(BaseModel):
    """Classifier that puts a linear head on a Hugging Face ``AutoModel``.

    The encoder and its tokenizer are both loaded by name via
    ``from_pretrained``. A single ``nn.Linear`` layer maps the encoder's
    ``pooler_output`` to the label space, so the chosen checkpoint must
    return a ``pooler_output`` and expose ``config.hidden_size``.

    Args:
        model_name: identifier handed to ``AutoModel.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
        dataset: forwarded to ``BaseModel``.
        feature_keys: forwarded to ``BaseModel``. Only the first key is read
            in ``forward``.
        label_key: forwarded to ``BaseModel``; key of the labels in a batch.
        mode: forwarded to ``BaseModel``.
        pretrained: currently unused; accepted for backward compatibility.
        num_layers: currently unused; accepted for backward compatibility.
        max_length: truncation length (in tokens) applied by the tokenizer.
        **kwargs: currently unused; accepted for backward compatibility.
    """

    def __init__(
        self,
        model_name: str,
        dataset: SampleDataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        pretrained: bool = False,
        num_layers: int = 18,
        max_length: int = 256,
        **kwargs,
    ):
        super(HuggingfaceAutoModel, self).__init__(
            dataset=dataset,
            feature_keys=feature_keys,
            label_key=label_key,
            mode=mode,
        )
        self.model_name = model_name
        self.max_length = max_length

        # Encoder and tokenizer must come from the same checkpoint.
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Classification head sized from the label vocabulary.
        self.label_tokenizer = self.get_label_tokenizer()
        output_size = self.get_output_size(self.label_tokenizer)
        hidden_dim = self.model.config.hidden_size
        self.fc = nn.Linear(hidden_dim, output_size)

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """Forward propagation.

        Args:
            **kwargs: one batch; must contain ``self.feature_keys[0]`` (the
                tokenizer input -- presumably a list of raw strings, confirm
                against the dataset) and ``self.label_key``.

        Returns:
            A dict with ``loss``, ``y_prob`` and ``y_true``.
        """
        texts = kwargs[self.feature_keys[0]]
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        # The tokenizer creates its tensors on the CPU; move them next to the
        # model weights so the forward pass also works on an accelerator.
        encoded = encoded.to(self.fc.weight.device)

        embeddings = self.model(**encoded).pooler_output
        logits = self.fc(embeddings)

        y_true = self.prepare_labels(kwargs[self.label_key], self.label_tokenizer)
        loss = self.get_loss_function()(logits, y_true)
        y_prob = self.prepare_y_prob(logits)
        return {
            "loss": loss,
            "y_prob": y_prob,
            "y_true": y_true,
        }


if __name__ == "__main__":
    # Manual smoke test: build the model on the medical-transcriptions task,
    # push a single batch through it and check that the loss back-propagates.
    from pyhealth.datasets import MedicalTranscriptionsDataset, get_dataloader

    # NOTE: machine-specific path; adjust before running elsewhere.
    base_dataset = MedicalTranscriptionsDataset(
        root="/srv/local/data/zw12/raw_data/MedicalTranscriptions"
    )

    sample_dataset = base_dataset.set_task()

    train_loader = get_dataloader(sample_dataset, batch_size=16, shuffle=True)

    model = HuggingfaceAutoModel(
        model_name="emilyalsentzer/Bio_ClinicalBERT",
        dataset=sample_dataset,
        feature_keys=[
            "transcription",
        ],
        label_key="label",
        mode="multiclass",
    )

    # data batch
    data_batch = next(iter(train_loader))

    # try the model
    ret = model(**data_batch)
    print(ret)

    # try loss backward
    ret["loss"].backward()