
handle model mode
zzachw committed Oct 22, 2023
1 parent 49bc445 commit 97d9654
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions pyhealth/models/base_model.py
@@ -1,5 +1,5 @@
 from abc import ABC
-from typing import List, Dict, Union, Callable
+from typing import List, Dict, Union, Callable, Optional
 
 import torch
 import torch.nn as nn
@@ -11,9 +11,8 @@
 from sklearn.decomposition import PCA
 from pyhealth.tokenizer import Tokenizer
 
-
 # TODO: add support for regression
-VALID_MODE = [None, "binary", "multiclass", "multilabel"]
+VALID_MODE = ["binary", "multiclass", "multilabel"]
 
 
 class BaseModel(ABC, nn.Module):
@@ -25,26 +24,30 @@ class BaseModel(ABC, nn.Module):
         feature_keys: list of keys in samples to use as features,
             e.g. ["conditions", "procedures"].
         label_key: key in samples to use as label (e.g., "drugs").
-        mode: one of "binary", "multiclass", or "multilabel".
+        mode: one of "binary", "multiclass", or "multilabel". Default is None.
+            Note that when mode is None, some class methods may not work (e.g.,
+            `get_loss_function` and `prepare_y_prob`).
     """
 
     def __init__(
         self,
         dataset: SampleBaseDataset,
         feature_keys: List[str],
         label_key: str,
-        mode: str,
+        mode: Optional[str] = None,
         pretrained_emb: str = None
     ):
         super(BaseModel, self).__init__()
-        assert mode in VALID_MODE, f"mode must be one of {VALID_MODE}"
+        if mode is not None:
+            assert mode in VALID_MODE, f"mode must be one of {VALID_MODE}"
         self.dataset = dataset
         self.feature_keys = feature_keys
         self.label_key = label_key
         self.mode = mode
         # pretrained embedding type, should be in ["KG", "LM", None]
         if pretrained_emb is not None:
-            assert pretrained_emb[:3] in ["KG/", "LM/"], f"pretrained_emb must start with one of ['KG/', 'LM/']"
+            assert pretrained_emb[:3] in ["KG/",
+                                          "LM/"], f"pretrained_emb must start with one of ['KG/', 'LM/']"
         # self.rand_init_embedding = nn.ModuleDict()
         # self.pretrained_embedding = nn.ModuleDict()
         self.pretrained_emb = pretrained_emb
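
With this change, `mode` is optional: construction succeeds with `mode=None`, and validation only fires for non-None values. A minimal standalone sketch of the new guard (the `check_mode` helper is hypothetical, added here for illustration; only `VALID_MODE` and the assertion come from the diff):

from typing import Optional

VALID_MODE = ["binary", "multiclass", "multilabel"]

def check_mode(mode: Optional[str] = None) -> None:
    # Mirrors the guard added in BaseModel.__init__: None skips validation.
    if mode is not None:
        assert mode in VALID_MODE, f"mode must be one of {VALID_MODE}"

check_mode()              # OK: mode now defaults to None
check_mode("multiclass")  # OK: validated against VALID_MODE
# check_mode("regression")  # AssertionError until the TODO above is addressed

Per the updated docstring, a model constructed with `mode=None` should not rely on `get_loss_function` or `prepare_y_prob`.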
@@ -188,7 +191,6 @@ def add_feature_transform_layer(self, feature_key: str, info, special_tokens=None):
                 special_tokens=special_tokens,
             )
             self.feat_tokenizers[feature_key] = tokenizer
-
             # feature embedding
             if self.pretrained_emb != None:
                 print(f"Loading pretrained embedding for {feature_key}...")
@@ -233,8 +235,6 @@ def add_feature_transform_layer(self, feature_key: str, info, special_tokens=None):
                 self.embedding_dim,
                 padding_idx=tokenizer.get_padding_index(),
             )
-
-
         elif info["type"] in [float, int]:
             self.linear_layers[feature_key] = nn.Linear(info["len"], self.embedding_dim)
         else:
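
For context, this branch of `add_feature_transform_layer` dispatches on the feature's value type: token-valued features get an `nn.Embedding` sized to the tokenizer vocabulary, while numeric features get a linear projection to the embedding size. A hedged sketch of that dispatch as a free function (the function name, the `str` check, the `info` layout, and `get_vocabulary_size` are assumptions pieced together from the visible context lines, not a verbatim copy of the file):

import torch.nn as nn

def build_transform_layer(info: dict, tokenizer, embedding_dim: int) -> nn.Module:
    if info["type"] == str:  # assumption: code features are typed as str
        # Token-valued features: embed tokenizer indices, masking padding.
        return nn.Embedding(
            tokenizer.get_vocabulary_size(),
            embedding_dim,
            padding_idx=tokenizer.get_padding_index(),
        )
    elif info["type"] in [float, int]:
        # Numeric features: project the raw vector to embedding_dim.
        return nn.Linear(info["len"], embedding_dim)
    else:
        raise ValueError(f"Unsupported feature type: {info['type']}")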
