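Hey, thanks for the great work.
I am trying to implement a BERT classifier based on bert-cpp. Most of the classifier implementation is done, but I am stuck at the pooler layer.
The BERT pooler is defined as follows (from https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py#L654):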

    import torch
    from torch import nn

    class BertPooler(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.dense = nn.Linear(config.hidden_size, config.hidden_size)
            self.activation = nn.Tanh()

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # We "pool" the model by simply taking the hidden state corresponding
            # to the first token.
            first_token_tensor = hidden_states[:, 0]
            pooled_output = self.dense(first_token_tensor)
            pooled_output = self.activation(pooled_output)
            return pooled_output

I need hidden_states[:, 0], which is then passed through the dense pooler layer, tanh, and a dense classifier layer.
I have been able to implement the pooler, tanh, and dense classifier layers, but I am not sure how to take only the hidden state corresponding to the first token, i.e., the [CLS] token, which is used for the final classification.
I would really appreciate your inputs.
Thanks in advance!
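
To make the question concrete, here is a minimal pure-Python sketch of what `hidden_states[:, 0]` amounts to for a single sequence. It assumes the encoder output is a flat, row-major buffer of shape [seq_len, hidden_size] (one token after another), which is an assumption about the memory layout and not something verified against bert-cpp. The helper names `first_token_slice`, `dense`, and `pool` are hypothetical. Under that layout, the [CLS] hidden state is just the first `hidden_size` values of the buffer, so in a ggml-based implementation it would presumably be a 1-D view at offset 0 (for example via something like `ggml_view_1d`), but that mapping is an untested assumption.

```python
import math
from typing import List, Sequence


def first_token_slice(flat_hidden: Sequence[float], hidden_size: int) -> List[float]:
    """Return the hidden state of token 0 from a flat [seq_len, hidden_size] buffer."""
    # Token t occupies flat_hidden[t * hidden_size : (t + 1) * hidden_size],
    # so the first token ([CLS]) is simply the first hidden_size values.
    return list(flat_hidden[:hidden_size])


def dense(x: Sequence[float], weight: Sequence[Sequence[float]], bias: Sequence[float]) -> List[float]:
    """Compute W x + b, with weight laid out as [out_features][in_features] like nn.Linear."""
    return [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]


def pool(flat_hidden: Sequence[float], hidden_size: int,
         weight: Sequence[Sequence[float]], bias: Sequence[float]) -> List[float]:
    """Mirror BertPooler.forward for one sequence: slice [CLS], dense, tanh."""
    cls = first_token_slice(flat_hidden, hidden_size)
    return [math.tanh(v) for v in dense(cls, weight, bias)]


# Toy data: 3 tokens with hidden_size = 2, stored token after token.
flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
identity = [[1.0, 0.0], [0.0, 1.0]]  # placeholder pooler weights
pooled = pool(flat, 2, identity, [0.0, 0.0])
```

With the identity weights used here, `pooled` is just tanh applied to the first token's two values. The classifier layer would then be one more `dense` call on `pooled` with the classifier weights.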