Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Core structure updates: image and text featurizer #172

Draft
wants to merge 17 commits into
base: master
Choose a base branch
from
Draft
Prev Previous commit
Next Next commit
wrap up the pr: import bug fix
  • Loading branch information
zzachw committed Jul 16, 2023
commit 19e5abd69901096bd435300e5d5bf0cec1543eb3
9 changes: 6 additions & 3 deletions pyhealth/models/torchvision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
elif "densenet" in model:
SUPPORTED_MODELS_FINAL_LAYER[model] = "classifier"
elif "vit" in model:
SUPPORTED_MODELS_FINAL_LAYER[model] = "heads"
SUPPORTED_MODELS_FINAL_LAYER[model] = "heads.head"
elif "swin" in model:
SUPPORTED_MODELS_FINAL_LAYER[model] = "head"
else:
Expand Down Expand Up @@ -109,7 +109,10 @@ def __init__(

self.model = torchvision.models.get_model(model_name, **model_config)
final_layer_name = SUPPORTED_MODELS_FINAL_LAYER[model_name]
hidden_dim = getattr(self.model, final_layer_name).in_features
final_layer = self.model
for name in final_layer_name.split("."):
final_layer = getattr(final_layer, name)
hidden_dim = final_layer.in_features
self.label_tokenizer = self.get_label_tokenizer()
output_size = self.get_output_size(self.label_tokenizer)
setattr(self.model, final_layer_name, nn.Linear(hidden_dim, output_size))
Expand Down Expand Up @@ -164,7 +167,7 @@ def encode(sample):
feature_keys=["path"],
label_key="label",
mode="multiclass",
model_name="swin_t",
model_name="vit_b_16",
model_config={"weights": "DEFAULT"},
)

Expand Down
2 changes: 1 addition & 1 deletion pyhealth/tasks/covid19_cxr_classification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from typing import Dict

from pyhealth.tasks import TaskTemplate
from pyhealth.tasks.task_template import TaskTemplate


@dataclass(frozen=True)
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/tasks/medical_transcriptions_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Dict
import pandas as pd

from pyhealth.tasks import TaskTemplate
from pyhealth.tasks.task_template import TaskTemplate


@dataclass(frozen=True)
Expand Down