Source code for homura.modules.functional.discretizations

# some functions to discretize input tensors

import random

import torch
from torch.autograd import Function
from torch.distributions import RelaxedBernoulli
from torch.nn import functional as F

__all__ = ["gumbel_sigmoid", "straight_through_estimator", "semantic_hashing"]


def gumbel_sigmoid(input: torch.Tensor, temp: float) -> torch.Tensor:
    """ gumbel sigmoid function """

    return RelaxedBernoulli(temp, probs=input.sigmoid()).rsample()
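
# A minimal usage sketch (illustrative, not part of the public API): the
# relaxed-Bernoulli sample stays differentiable, and a lower temperature
# pushes the samples closer to {0, 1}.
def _example_gumbel_sigmoid() -> None:
    logits = torch.randn(4, requires_grad=True)
    soft = gumbel_sigmoid(logits, temp=1.0)   # samples well inside (0, 1)
    hard = gumbel_sigmoid(logits, temp=0.1)   # samples pushed toward 0 or 1
    # rsample() keeps the computation graph, so gradients reach the logits
    hard.sum().backward()
    print(soft, hard, logits.grad)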


class _STE(Function):
    """ Straight Through Estimator """

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        return (input > 0).float()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return F.hardtanh(grad_output)


def straight_through_estimator(input: torch.Tensor) -> torch.Tensor:
    """ straight through estimator

    >>> straight_through_estimator(torch.randn(3, 3))
    tensor([[0., 1., 0.],
            [0., 1., 1.],
            [0., 0., 1.]])
    """

    return _STE.apply(input)
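
# A minimal sketch (illustrative): the forward pass is a hard 0/1 threshold,
# while the backward pass passes the hardtanh-clipped gradient to the input.
def _example_straight_through_estimator() -> None:
    x = torch.randn(3, 3, requires_grad=True)
    bits = straight_through_estimator(x)   # values in {0., 1.}
    bits.sum().backward()
    print(bits, x.grad)                    # grad is clipped to [-1, 1]; here all ones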


def _saturated_sigmoid(input: torch.Tensor) -> torch.Tensor:
    # max(0, min(1, 1.2 * input.sigmoid() - 0.1))
    return F.relu(1 - F.relu(1.1 - 1.2 * input.sigmoid()))
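
# A quick sanity sketch (illustrative): the nested-ReLU form above matches the
# commented clamp form max(0, min(1, 1.2 * sigmoid(x) - 0.1)).
def _example_saturated_sigmoid() -> None:
    x = torch.randn(1000)
    clamp_form = (1.2 * x.sigmoid() - 0.1).clamp(0, 1)
    assert torch.allclose(_saturated_sigmoid(x), clamp_form, atol=1e-6)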


def semantic_hashing(input: torch.Tensor, is_training: bool) -> torch.Tensor:
    """ Semantic hashing

    >>> semantic_hashing(torch.randn(3, 3), True)  # binary or continuous, each with probability 0.5
    tensor([[0.3515, 0.0918, 0.7717],
            [0.8246, 0.1620, 0.0689],
            [1.0000, 0.3575, 0.6598]])

    >>> semantic_hashing(torch.randn(3, 3), False)
    tensor([[0., 0., 1.],
            [0., 1., 1.],
            [0., 1., 1.]])
    """

    v1 = _saturated_sigmoid(input)
    v2 = (input < 0).float().detach()
    if not is_training or random.random() > 0.5:
        # derivative is 0 + dv1/dx + 0
        return v1 - v1.detach() + v2
    else:
        return v1
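
# A minimal sketch of the two regimes (illustrative): at evaluation time the
# output is always the hard code, while during training the saturated sigmoid
# is returned roughly half of the time; either way the gradient flows through
# the saturated sigmoid.
def _example_semantic_hashing() -> None:
    x = torch.randn(3, 3, requires_grad=True)
    hard = semantic_hashing(x, is_training=False)   # binary codes in {0., 1.}
    hard.sum().backward()
    print(hard, x.grad)                             # gradient of the saturated sigmoid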