generated from kyegomez/Paper-Implementation-Template
main.py
import timeit

import torch
import torch.nn.functional as F


def standard_softmax(tensor):
    # Baseline: PyTorch's built-in softmax along dim 0.
    return F.softmax(tensor, dim=0)

def local_softmax(tensor, num_chunks: int = 2):
    # Split the tensor into num_chunks smaller tensors along dim 0.
    tensors = torch.chunk(tensor, num_chunks, dim=0)

    # Apply softmax to each chunk; every chunk is normalized independently,
    # so this is not equivalent to a single softmax over the full tensor.
    results = [F.softmax(t, dim=0) for t in tensors]

    # Concatenate the per-chunk results back together.
    return torch.cat(results, dim=0)

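# Illustrative sketch (an added check, with _x as a throwaway example vector):
# because each chunk is normalized on its own, local_softmax will generally
# differ from a global softmax over the same vector.
_x = torch.randn(8)
print("local == global softmax:", torch.allclose(local_softmax(_x, num_chunks=2), F.softmax(_x, dim=0)))
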
def fast_softmax(tensor):
    """
    Softmax using the max-subtraction (LogSumExp) trick for numerical stability.

    Note that the max and the sum are taken over the entire tensor, so this
    computes a softmax over all elements rather than along one dimension.

    Example:
        tensor = torch.rand(10, 5)
        result = fast_softmax(tensor)
        print(result)
    """
    shiftx = tensor - torch.max(tensor)
    exps = torch.exp(shiftx)
    return exps / torch.sum(exps)

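# Illustrative sketch (an added sanity check, assuming a 1-D input _v): the
# max-shift does not change the result mathematically, so fast_softmax should
# agree with F.softmax along dim 0 for a vector.
_v = torch.randn(16)
print("fast matches F.softmax:", torch.allclose(fast_softmax(_v), F.softmax(_v, dim=0)))
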
def sparse_softmax(z, k: int = 3):
    # Keep only the top-k entries of the 1-D tensor z and renormalize over
    # them; all remaining entries are set to zero.
    _, top_k_indices = z.topk(k, dim=0)
    omega_k = top_k_indices

    # Compute the sparse softmax transformation over the selected indices.
    exp_z = torch.exp(z)
    masked_sum_exp = exp_z[omega_k].sum()
    values = torch.zeros_like(z)
    values[omega_k] = exp_z[omega_k] / masked_sum_exp
    return values


# Demo on a 1-D tensor (the indexing above assumes a single dimension).
tensor = torch.randn(10)
result = sparse_softmax(tensor, k=3)
print(f"result sparse softmax: {result}")

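# Illustrative sketch (an added follow-up check on the demo result above): the
# k retained entries should sum to 1 and every other entry should be zero.
print("sparse entries sum to 1:", torch.isclose(result.sum(), torch.tensor(1.0)).item())
print("nonzero entries:", (result != 0).sum().item())
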
# Benchmark helper: time num_iterations calls of func on tensor.
def benchmark(func, tensor, num_iterations=10000):
    timer = timeit.Timer(lambda: func(tensor))
    return timer.timeit(num_iterations)


tensor = torch.randn(1000)  # Random 1-D tensor of size 1000

# Benchmarking: compare the standard softmax against the fast softmax.
num_iterations = 10000
std_time = benchmark(standard_softmax, tensor, num_iterations)
fast_time = benchmark(fast_softmax, tensor, num_iterations)

print(f"Standard Softmax: {std_time:.5f} seconds for {num_iterations} iterations")
print(f"Fast Softmax: {fast_time:.5f} seconds for {num_iterations} iterations")