Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
kyegomez committed Sep 29, 2023
1 parent 35a0a2c commit 13e528c
Showing 1 changed file with 39 additions and 3 deletions.
42 changes: 39 additions & 3 deletions local_sfmx/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,45 @@ def gumbelmax(x, temp=1.0, hard=False):
y = y_hard - y.detach() + y
return y

# Demo: run sparse_softmax (presumably defined earlier in this file — not
# visible in this chunk) on a random batch of 10 rows x 5 logits, keeping
# only the top k=3 entries per row, and print the resulting probabilities.
tensor = torch.randn(10, 5)
result = sparse_softmax(tensor, k=3)
print(f'result sparse softmax: {result}')
# 7. Softmax with temp
def temp_softmax(x, temp=1.0):
    """Apply a temperature-scaled softmax over the last dimension.

    The logits are divided by ``temp`` before the softmax, so higher
    temperatures flatten the distribution and lower ones sharpen it.

    x: input tensor of logits
    temp: temperature divisor (default 1.0)
    """
    scaled_logits = x / temp
    return F.softmax(scaled_logits, dim=-1)

# 8. logit scaled softmax
def logit_scaled_softmax(x, scale=1.0):
    """Softmax over logits multiplied by a constant scale factor.

    Equivalent to ``F.softmax(scale * x, dim=-1)``: larger scales sharpen
    the resulting distribution, smaller ones flatten it.

    x: input tensor of logits
    scale: multiplicative factor applied to the logits (default 1.0)
    """
    scaled_logits = scale * x
    return F.softmax(scaled_logits, dim=-1)

# 9. norm exponential softmax
def norm_exp_softmax(x, scale=1.0):
    """
    Numerically stable scaled-exponential softmax over the last dimension.

    Implements the formula
        norm_exp_softmax(x) = exp(scale * x) / sum(exp(scale * x))

    Fixes over the naive form:
    - The row-wise maximum of the scaled logits is subtracted before
      exponentiation, so large inputs no longer overflow to inf/nan
      (the shift cancels in the ratio, leaving the result unchanged —
      the standard log-sum-exp trick).
    - The exponential is computed once instead of twice.

    x: input tensor
    scale: scale parameter
    """
    scaled = scale * x
    # Shift by the per-row max for numerical stability; cancels in the ratio.
    shifted = scaled - scaled.max(dim=-1, keepdim=True).values
    exp_x = torch.exp(shifted)
    return exp_x / exp_x.sum(dim=-1, keepdim=True)


# Benchmark function
Expand Down

0 comments on commit 13e528c

Please sign in to comment.