Skip to content

Commit

Permalink
Merge pull request Unbabel#75 from Joao-Maria-Janeiro/master
Browse files Browse the repository at this point in the history
Added verification to validate whether the selected activation function exists
  • Loading branch information
ricardorei authored May 3, 2022
2 parents 40fb8ba + 70ec407 commit 5e40605
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions comet/modules/feedforward.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ def __init__(
self.ff = nn.Sequential(*modules)

def build_activation(self, activation: str) -> nn.Module:
if hasattr(nn, activation):
return getattr(nn, activation)()
if hasattr(nn, activation.title()):
return getattr(nn, activation.title())()
else:
raise Exception(f"{activation} is not a valid activation function!")

def forward(self, in_features: torch.Tensor) -> torch.Tensor:
return self.ff(in_features)

0 comments on commit 5e40605

Please sign in to comment.