import torch
from torch import nn
from .functional import (gumbel_sigmoid, semantic_hashing, straight_through_estimator)
# Public names exported by ``from <this module> import *``.
__all__ = ["GumbelSigmoid", "StraightThroughEstimator", "SemanticHashing"]
class GumbelSigmoid(nn.Module):
    """Mode-dependent binarization of logits.

    In training mode the module returns ``gumbel_sigmoid(input, temp)``.
    In evaluation mode it returns the hard mask
    ``(input.sigmoid() >= threshold)`` cast to a float tensor of 0.0 / 1.0.

    Args:
        temp: Temperature forwarded to ``gumbel_sigmoid`` in training mode.
        threshold: Cut-off compared against ``input.sigmoid()`` in
            evaluation mode. A value exactly equal to the threshold maps
            to 1.0.
    """

    def __init__(self,
                 temp: float = 0.1,
                 threshold: float = 0.5):
        super().__init__()
        self.temp = temp
        self.threshold = threshold

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``gumbel_sigmoid`` (training) or a hard threshold (eval)."""
        if self.training:
            return gumbel_sigmoid(input, self.temp)
        # Inference path: deterministic hard decision, returned as float.
        return (input.sigmoid() >= self.threshold).float()

    def extra_repr(self) -> str:
        # Included by nn.Module.__repr__, so printed models show the settings.
        return f"temp={self.temp}, threshold={self.threshold}"
class StraightThroughEstimator(nn.Module):
    """Module wrapper around the functional ``straight_through_estimator``.

    ``forward`` passes its input straight to the functional form with no
    extra arguments. The module registers no parameters or buffers, and its
    output does not depend on train/eval mode.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``straight_through_estimator(input)``."""
        return straight_through_estimator(input)
class SemanticHashing(nn.Module):
    """Module wrapper around the functional ``semantic_hashing``.

    The module's mode flag (``self.training``) is supplied as the second
    positional argument of ``semantic_hashing``. Toggle it with
    ``.train()`` / ``.eval()``. No parameters or buffers are registered.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``semantic_hashing(input, self.training)``."""
        is_training = self.training
        return semantic_hashing(input, is_training)