Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/xray report generation #144

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
Prev Previous commit
Next Next commit
updated documetation
  • Loading branch information
samarthkeshari committed Apr 25, 2023
commit d4aad7fd1687447a83ac3716f37a454c9089c24d
6 changes: 3 additions & 3 deletions pyhealth/metrics/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
from pycocoevalcap.meteor.meteor import Meteor

def sequence_metrics_fn(
y_true: Dict[int,str],
y_generated: Dict[int,str],
y_true: List[Dict[int,str]],
y_generated: List[Dict[int,str]],
metrics: Optional[List[str]] = None
) -> Dict[str, float]:
"""
"""Compute metrics relevant for evaluating sequences
"""
scorers = [
(Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
Expand Down
2 changes: 1 addition & 1 deletion pyhealth/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pyhealth.tokenizer import Tokenizer

# TODO: add support for regression
VALID_MODE = ["binary", "multiclass", "multilabel","sequence"]
VALID_MODE = ["binary", "multiclass", "multilabel", "sequence"]


class BaseModel(ABC, nn.Module):
Expand Down
92 changes: 52 additions & 40 deletions pyhealth/models/sentsat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ def __init__(self):
self.densenet121.classifier = nn.Identity()

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward propagation.
"""Forward propagation.
Extract fixed-length feature vectors from the input image.

Args:
x: A tensor of tranfomed image of size
[batch_size,3,224,224]
x: A tensor of transfomed image of size
[batch_size,3,512,512]
Return:
x: A tensor of image feature vectors of size
x: A tensor of image feature vectors of size
[batch_size,1024,16,16]
"""
x = self.densenet121.features(x)
Expand All @@ -35,8 +35,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class SentSATAttention(nn.Module):
"""SAT Attention Module

Computes a set of attention weights based on the current hidden state of
the RNN and the feature vectors from the CNN, which are then used to
Computes a set of attention weights based on the current hidden state of
the RNN and the feature vectors from the CNN, which are then used to
compute a weighted average of the feature vectors.

Args:
Expand All @@ -45,18 +45,18 @@ class SentSATAttention(nn.Module):
affine_dim: affine dimension. Default is 512
"""
def __init__(
self,
k_size: int,
v_size: int,
self,
k_size: int,
v_size: int,
affine_dim: int =512):
super().__init__()
self.affine_k = nn.Linear(k_size, affine_dim, bias=False)
self.affine_v = nn.Linear(v_size, affine_dim, bias=False)
self.affine = nn.Linear(affine_dim, 1, bias=False)

def forward(
self,
k: torch.Tensor,
self,
k: torch.Tensor,
v: torch.Tensor) -> (torch.Tensor,torch.Tensor):
"""Forward propagation

Expand All @@ -77,18 +77,20 @@ def forward(
return context, alpha

class SentSATDecoder(nn.Module):
""" Word SAT decoder model for one sentence
""" Sentence SAT decoder model that treats a caption as multiple sentences

An LSTM based model that takes as input the attention-weighted feature
vector and generates a sequence of words, one at a time.
An LSTM based model that takes as input the attention-weighted feature
vector and generates a sequence of sentences and followed by words, in each
sentence

Args:
attention: attention module instance
vocab_size: vocabulary size
n_encoder_inputs: number of image inputs given to the encoder
feature_dim: encoder output feature dimesion
feature_dim: encoder output feature dimension
embedding_dim: decoder embedding dimension
hidden_dim: LSTM hidden dimension
dropout: dropout rate between [0,1]
dropout: dropout rate between [0,1]
"""
def __init__(
self,
Expand All @@ -115,29 +117,33 @@ def __init__(
hidden_dim)
self.sent_lstm = nn.LSTMCell(self.n_encoder_inputs * self.feature_dim,
hidden_dim)

self.word_lstm = nn.LSTMCell(self.embedding_dim + self.hidden_dim +
self.n_encoder_inputs * feature_dim,
hidden_dim)
self.fc = nn.Linear(self.hidden_dim, self.vocab_size)
self.dropout = nn.Dropout(dropout)

def forward(
self,
cnn_features: List[torch.Tensor],
self,
cnn_features: List[torch.Tensor],
captions: List[torch.Tensor] = None,
update_masks: torch.Tensor = None,
max_sents: int = 10,
max_len: int = 30,
stop_id: int = None) -> torch.Tensor:

"""Forward propagation

Args:
cnn_features: a list of tensors where each tensor is of
size [batch_size, feature_dim, spatial_size].
captions: a list of tensors.
updat_masks: a boolean tensor to identify the actual tokens
max_sents: maximum number of sentences that can be generated
max_len: maximum length of training or generated caption
stop_id: token id from vocabulary to stop word generation for a
sentence during inference

Returns:
logits: a tensor
Expand All @@ -161,7 +167,7 @@ def forward(

word_h = cnn_features[0].new_zeros((batch_size,
self.hidden_dim),dtype=torch.float)

word_c = cnn_features[0].new_zeros((batch_size,
self.hidden_dim),dtype=torch.float)

Expand All @@ -182,10 +188,10 @@ def forward(

for t in range(seq_len_k):
batch_mask = update_masks[:, k, t]

word_h_, word_c_ = self.word_lstm(
torch.cat((embeddings[batch_mask, k, t],
sent_h[batch_mask],
torch.cat((embeddings[batch_mask, k, t],
sent_h[batch_mask],
context[batch_mask]), dim=1),
(word_h[batch_mask], word_c[batch_mask]))

Expand All @@ -201,17 +207,17 @@ def forward(
# Evaluation/Inference phase
else:
x_t = cnn_features[0].new_full((batch_size,), 1, dtype=torch.long)

for k in range(num_sents):
contexts = [self.attend(sent_h, cnn_feat_t)[0]
for cnn_feat_t in cnn_feats_t]
context = torch.cat(contexts, dim=1)
sent_h, sent_c = self.sent_lstm(context, (sent_h, sent_c))

for t in range(seq_len):
embedding = self.embed(x_t)
word_h, word_c = self.word_lstm(
torch.cat((embedding, sent_h, context), dim=1),
torch.cat((embedding, sent_h, context), dim=1),
(word_h, word_c))
logit = self.fc(word_h)
x_t = logit.argmax(dim=1)
Expand Down Expand Up @@ -285,7 +291,7 @@ def __init__(
save_generated_caption = save_generated_caption
)
self.n_input_images = n_input_images

# Encoder component
self.encoder = SentSATEncoder()
if encoder_pretrained_weights:
Expand All @@ -312,10 +318,10 @@ def __init__(
)

def forward(
self,
decoder_maxsents: int =10,
decoder_maxlen:int = 20,
decoder_stop_id: int = None,
self,
decoder_maxsents: int =10,
decoder_maxlen:int = 20,
decoder_stop_id: int = None,
**kwargs) -> Dict[str,str]:
"""Forward propagation.

Expand Down Expand Up @@ -359,7 +365,7 @@ def forward(
kwargs[self.label_key])

# Perform predictions
logits = self.decoder(cnn_features,
logits = self.decoder(cnn_features,
captions[:, :, :-1],
update_masks,
decoder_maxsents,
Expand All @@ -374,7 +380,7 @@ def forward(
loss = self.get_loss_function()(logits, captions)
loss = loss.masked_select(loss_masks).mean()
output["loss"] = loss

with torch.no_grad():
output["y_generated"] = self._forward_inference(patient_ids,
decoder_maxsents,
Expand Down Expand Up @@ -414,12 +420,13 @@ def _prepare_batch_captions(
]
Returns:
captions_idx: an int tensor
masks: a bool tensor
loss_masks: a bool tensor for each sentence in a caption
update_masks: a bool tensor for each sentence in a caption
"""
x = self.caption_tokenizer.batch_encode_3d(captions)
captions_idx = torch.tensor(x, dtype=torch.long,
captions_idx = torch.tensor(x, dtype=torch.long,
device=self.device)

loss_masks = torch.zeros_like(captions_idx,dtype=torch.bool)
update_masks = torch.zeros_like(captions_idx,dtype=torch.bool)

Expand All @@ -429,7 +436,7 @@ def _prepare_batch_captions(
if l==0: continue
loss_masks[icap, isent, 1:l].fill_(1)
update_masks[icap, isent, :l-1].fill_(1)

return captions_idx, loss_masks, update_masks

def _forward_inference(
Expand All @@ -443,7 +450,12 @@ def _forward_inference(

Args:
patient_ids: a list of patient ids
decoder_maxsents: maximum number of sentences that can be generated
decoder_maxlen: maximum length of words in a every sentence of a
caption
cnn_features: a list of tensors
stop_id: token id from vocabulary to stop word generation for a
sentence

Returns:
generated_results: a dict with following keys
Expand All @@ -463,7 +475,7 @@ def _forward_inference(
for isent in range(pred.size(0)):
pred_tokens = self.caption_tokenizer \
.convert_indices_to_tokens(pred[isent].tolist())

words = []
for token in pred_tokens:
if token == '<start>' or token == '<pad>':
Expand Down
Loading