Commit

docs
kyegomez committed Sep 29, 2023
1 parent e021562 commit 35a0a2c
Showing 1 changed file with 89 additions and 4 deletions.
93 changes: 89 additions & 4 deletions local_sfmx/main.py
@@ -3,11 +3,62 @@
import torch
import torch.nn.functional as F


def standard_softmax(tensor):
return F.softmax(tensor, dim=0)

# SELU Softmax
def selu_softmax(x):
    """
    selu_softmax works by first applying the scaled exponential linear unit
    (SELU) activation function to the input tensor and then applying softmax.
    x: input tensor
    """
    # SELU constants; F.elu exposes alpha explicitly, so the scale is applied outside
    alpha, scale = 1.6732632423543772848170429916717, 1.0507009873554804934193349852946
    return F.softmax(scale * F.elu(x, alpha), dim=0)
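
# Usage sketch for selu_softmax (illustrative values): the output should be
# non-negative and sum to 1 along dim 0.
selu_example = selu_softmax(torch.tensor([1.0, 2.0, 3.0]))
print("selu_softmax:", selu_example, selu_example.sum())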

# 2. Sparsemax
def sparsemax(x):
"""
sparsemax works by first sorting the input tensor in descending order and
then applying the following formula to the sorted tensor:
sparsemax(z) = max(0, z - tau(z)) where tau(z) = (sum_i=1^k z_i - 1) / k
z: input tensor
k: number of elements to keep
"""
original_size = x.size()
x = x.view(-1, original_size[-1])
dim = 1
number_of_logits = x.size(dim)

# Translate x by max for numerical stability
x = x - torch.max(x, dim=dim, keepdim=True).values
sorted_x, _ = torch.sort(x, dim=dim, descending=True)
cumulative_values = torch.cumsum(sorted_x, dim=dim) - 1
range_values = torch.arange(start=1, end=number_of_logits + 1, device=x.device)
bound = (sorted_x - cumulative_values / range_values) > 0
rho = torch.count_nonzero(bound, dim=dim)
    tau = cumulative_values.gather(dim, rho.unsqueeze(dim) - 1)
    # keep tau shaped (batch, 1) so it broadcasts row-wise against x
    tau = tau / rho.unsqueeze(dim).to(x.dtype)
    return torch.max(torch.zeros_like(x), x - tau).view(original_size)
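
# Usage sketch for sparsemax (illustrative input): unlike softmax, the output
# can contain exact zeros while still summing to 1 per row; for this input the
# result is roughly [[0.6, 0.4, 0.0]].
sparsemax_example = sparsemax(torch.tensor([[1.0, 0.8, 0.1]]))
print("sparsemax:", sparsemax_example)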

# 3. Local Softmax
def local_softmax(tensor, num_chunks: int = 2):
"""
local softmax works by splitting the input tensor into num_chunks smaller
tensors and then applying softmax on each chunk. The results are then
concatenated and returned.
tensor: input tensor
num_chunks: number of chunks to split the tensor into
"""
    # split the tensor into num_chunks smaller tensors
tensors = torch.chunk(tensor, num_chunks, dim=0)

@@ -21,7 +72,7 @@ def local_softmax(tensor, num_chunks: int = 2):

return concated_results
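
# Usage sketch for local_softmax (illustrative shape): each chunk is normalized
# independently by softmax, and the concatenated result is expected to keep the
# input shape.
local_example = local_softmax(torch.randn(8, 4), num_chunks=2)
print("local_softmax:", local_example.shape)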


# 4. Fast Softmax
def fast_softmax(tensor):
"""
LogSumExp trick for numerical stability
@@ -38,8 +89,18 @@ def fast_softmax(tensor):
return exps / torch.sum(exps)
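
# Usage sketch for fast_softmax (illustrative values): with the LogSumExp /
# max-subtraction trick described in the docstring, large logits should not
# overflow, and the result sums to 1 over the whole tensor.
fast_example = fast_softmax(torch.tensor([1000.0, 1001.0, 1002.0]))
print("fast_softmax:", fast_example, fast_example.sum())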



# 5. Sparse Softmax
def sparse_softmax(z, k: int = 3):
"""
Sparsemax works by first sorting the input tensor in descending order and
then applying the following formula to the sorted tensor:
sparsemax(z) = max(0, z - tau(z)) where tau(z) = (sum_i=1^k z_i - 1) / k
z: input tensor
k: number of elements to keep
"""
_, top_k_indices = z.topk(k, dim=0)
omega_k = top_k_indices

@@ -51,6 +112,30 @@ def sparse_softmax(z, k: int = 3):

return values
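
# Usage sketch for sparse_softmax (illustrative values): per the docstring
# formula with k = 3, tau = (2.0 + 1.0 + 0.5 - 1) / 3 ≈ 0.833, so only the
# top-3 entries can receive non-zero mass.
sparse_example = sparse_softmax(torch.tensor([2.0, 1.0, 0.5, -1.0, -2.0]), k=3)
print("sparse_softmax:", sparse_example)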

# 6. Gumbel-Max
def gumbelmax(x, temp=1.0, hard=False):
"""
Gumbelmax works by adding Gumbel noise to the input tensor x and then
applying softmax. The hard parameter controls whether the output will
be one-hot or a probability distribution.
x: input tensor
temp: temperature parameter
hard: if True, the returned tensor will be one-hot, otherwise a probability distribution
"""
    # sample Gumbel(0, 1) noise as -log(Exponential(1))
    gumbels = -torch.empty_like(x).exponential_().log()
    y = x + gumbels
    y = F.softmax(y / temp, dim=-1)

    if hard:
        # straight-through trick: one-hot in the forward pass,
        # gradients flow through the soft sample in the backward pass
        y_hard = torch.zeros_like(x).scatter_(
            -1,
            y.argmax(dim=-1, keepdim=True),
            1.0,
        )
        y = y_hard - y.detach() + y
    return y
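
# Usage sketch for gumbelmax (illustrative logits): with hard=False the output
# is a probability distribution over the last dim; with hard=True it is one-hot,
# while gradients still flow through the soft sample.
gumbel_logits = torch.tensor([1.0, 2.0, 3.0])
print("gumbelmax soft:", gumbelmax(gumbel_logits, temp=0.5))
print("gumbelmax hard:", gumbelmax(gumbel_logits, temp=0.5, hard=True))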

tensor = torch.randn(10, 5)
result = sparse_softmax(tensor, k=3)
@@ -69,7 +154,7 @@ def benchmark(func, tensor, num_iterations=10000):
num_iterations = 10000

std_time = benchmark(fast_softmax, tensor, num_iterations)
- fast_time = benchmark(local_softmax, tensor, num_iterations)
+ fast_time = benchmark(gumbelmax, tensor, num_iterations)

print(f"Standard Softmax: {std_time:.5f} seconds for {num_iterations} iterations")
print(f"Fast Softmax: {fast_time:.5f} seconds for {num_iterations} iterations")
