diff --git a/local_sfmx/main.py b/local_sfmx/main.py
index c52a60b..2816e85 100644
--- a/local_sfmx/main.py
+++ b/local_sfmx/main.py
@@ -3,11 +3,62 @@
 import torch
 import torch.nn.functional as F
 
-
 def standard_softmax(tensor):
     return F.softmax(tensor, dim=0)
 
+# 1. SELU Softmax
+def selu_softmax(x):
+    """
+    selu_softmax works by first applying the scaled exponential linear unit
+    (SELU) activation function to the input tensor and then applying softmax.
+
+    x: input tensor
+    """
+    # the SELU constants (alpha = 1.6732632..., scale = 1.0507009...) are
+    # already built into F.selu, so they must not be applied a second time
+    return F.softmax(F.selu(x), dim=0)
+
+# 2. Sparsemax
+def sparsemax(x):
+    """
+    sparsemax works by sorting the input tensor in descending order and then
+    applying the following formula to the sorted tensor:
+
+    sparsemax(z) = max(0, z - tau(z)) where tau(z) = (sum_i=1^k z_i - 1) / k
+
+    z: input tensor
+    k: number of elements kept, determined from z itself (not a parameter)
+    """
+    original_size = x.size()
+    x = x.view(-1, original_size[-1])
+    dim = 1
+    number_of_logits = x.size(dim)
+
+    # translate x by its max for numerical stability
+    x = x - torch.max(x, dim=dim, keepdim=True).values
+    sorted_x, _ = torch.sort(x, dim=dim, descending=True)
+    cumulative_values = torch.cumsum(sorted_x, dim=dim) - 1
+    range_values = torch.arange(start=1, end=number_of_logits + 1, device=x.device)
+    bound = (sorted_x - cumulative_values / range_values) > 0
+    rho = torch.count_nonzero(bound, dim=dim)
+    tau = cumulative_values.gather(dim, rho.unsqueeze(dim) - 1)
+    tau = tau / rho.unsqueeze(dim).to(dtype=x.dtype)
+    return torch.max(torch.zeros_like(x), x - tau).view(original_size)
+
+# 3. Local Softmax
 def local_softmax(tensor, num_chunks: int = 2):
+    """
+    local_softmax works by splitting the input tensor into num_chunks smaller
+    tensors and then applying softmax on each chunk. The results are then
+    concatenated and returned.
+
+    tensor: input tensor
+    num_chunks: number of chunks to split the tensor into
+    """
     #split the tensor into num chunks smaller tensor
     tensors = torch.chunk(tensor, num_chunks, dim=0)
@@ -21,7 +72,7 @@ def local_softmax(tensor, num_chunks: int = 2):
 
     return concated_results
 
-
+# 4. Fast Softmax
 def fast_softmax(tensor):
     """
     LogSumExp trick for numerical stability
@@ -38,8 +89,18 @@ def fast_softmax(tensor):
 
     return exps / torch.sum(exps)
 
-
+# 5. Sparse Softmax
 def sparse_softmax(z, k: int = 3):
+    """
+    sparse_softmax works by keeping only the k largest entries of the input
+    and normalising over that subset; every other entry receives zero
+    probability.
+
+    z: input tensor
+    k: number of elements to keep
+    """
     _, top_k_indices = z.topk(k, dim=0)
     omega_k = top_k_indices
@@ -51,6 +112,30 @@ def sparse_softmax(z, k: int = 3):
 
     return values
 
+# 6. Gumbelmax
+def gumbelmax(x, temp=1.0, hard=False):
+    """
+    Gumbelmax works by adding Gumbel noise to the input tensor x and then
+    applying softmax. The hard parameter controls whether the output will
+    be one-hot or a probability distribution.
+
+    x: input tensor
+    temp: temperature parameter
+    hard: if True, the returned tensor is one-hot; otherwise it is a
+        probability distribution
+    """
+    # sample Gumbel(0, 1) noise as -log(Exponential(1))
+    gumbels = -torch.empty_like(x).exponential_().log()
+    y = x + gumbels
+    y = F.softmax(y / temp, dim=-1)
+
+    if hard:
+        # straight-through estimator: one-hot on the forward pass,
+        # gradients flow through the soft sample y
+        y_hard = torch.zeros_like(x).scatter_(
+            -1,
+            y.argmax(dim=-1, keepdim=True),
+            1.0
+        )
+        y = y_hard - y.detach() + y
+    return y
 
 tensor = torch.randn(10, 5)
 result = sparse_softmax(tensor, k=3)
@@ -69,7 +154,7 @@ def benchmark(func, tensor, num_iterations=10000):
 num_iterations = 10000
 
 std_time = benchmark(fast_softmax, tensor, num_iterations)
-fast_time = benchmark(local_softmax, tensor, num_iterations)
+fast_time = benchmark(gumbelmax, tensor, num_iterations)
 
-print(f"Standard Softmax: {std_time:.5f} seconds for {num_iterations} iterations")
-print(f"Fast Softmax: {fast_time:.5f} seconds for {num_iterations} iterations")
+print(f"Fast Softmax: {std_time:.5f} seconds for {num_iterations} iterations")
+print(f"Gumbelmax: {fast_time:.5f} seconds for {num_iterations} iterations")
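
For reviewers who want to sanity-check the sparsemax projection, a minimal sketch (assuming the file above is importable as local_sfmx.main; the input values are illustrative only):

    import torch
    from local_sfmx.main import sparsemax

    z = torch.tensor([[2.0, 1.0, 0.1, -1.0]])
    p = sparsemax(z)
    # unlike softmax, sparsemax can assign exactly zero probability
    # while each row still sums to 1
    print(p)              # tensor([[1., 0., 0., 0.]])
    print(p.sum(dim=-1))  # tensor([1.])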
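One caveat worth noting for local_softmax: each chunk is normalised independently, so the concatenated result sums to num_chunks rather than to 1. A quick check, under the same import assumption:

    import torch
    from local_sfmx.main import local_softmax

    t = torch.randn(10)
    out = local_softmax(t, num_chunks=2)
    print(out.sum())  # ~2.0, since each of the two chunks sums to 1 on its own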
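fast_softmax's LogSumExp note implies the max is subtracted before exponentiating; if so, it stays finite where a naive softmax overflows (a sketch, assuming that implementation detail):

    import torch
    from local_sfmx.main import fast_softmax

    big = torch.tensor([1000.0, 1000.5])
    # naive exp(1000) overflows to inf in float32; the shifted version does not
    print(fast_softmax(big))  # tensor([0.3775, 0.6225])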
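Finally, the straight-through behaviour of gumbelmax: with hard=True the forward value is one-hot while gradients still flow through the soft sample (same import assumption as above):

    import torch
    from local_sfmx.main import gumbelmax

    logits = torch.randn(4, requires_grad=True)
    y = gumbelmax(logits, temp=0.5, hard=True)
    print(y)         # a one-hot vector
    y[0].backward()  # the gradient reaches logits through the soft softmax term
    print(logits.grad)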