Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/xray report generation #144

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
Prev Previous commit
Next Next commit
added sentsat model
  • Loading branch information
samarthkeshari committed Apr 24, 2023
commit 00f5eb5c3891305a8d64086a4db0bca5890f851b
3 changes: 2 additions & 1 deletion pyhealth/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
from .grasp import GRASP, GRASPLayer
from .stagenet import StageNet, StageNetLayer
from .tcn import TCN, TCNLayer
from .wordsat import WordSAT
from .wordsat import WordSAT, WordSATEncoder,WordSATDecoder, WordSATAttention
from .sentsat import SentSAT
5 changes: 4 additions & 1 deletion pyhealth/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class BaseModel(ABC, nn.Module):
feature_keys: list of keys in samples to use as features,
e.g. ["conditions", "procedures"].
label_key: key in samples to use as label (e.g., "drugs").
mode: one of "binary", "multiclass", or "multilabel".
mode: one of "binary", "multiclass", "multilabel", or sequence.
"""

def __init__(
Expand All @@ -31,13 +31,16 @@ def __init__(
feature_keys: List[str],
label_key: str,
mode: str,
save_generated_caption: bool = False
):
super(BaseModel, self).__init__()
assert mode in VALID_MODE, f"mode must be one of {VALID_MODE}"
self.dataset = dataset
self.feature_keys = feature_keys
self.label_key = label_key
self.mode = mode
if mode == "sequence":
self.save_generated_caption = save_generated_caption
# used to query the device of the model
self._dummy_param = nn.Parameter(torch.empty(0))
return
Expand Down
Loading