Skip to content

Commit

Permalink
add annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
lauritowal committed Jul 28, 2023
1 parent 001deb8 commit e8872c8
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions elk/training/burns_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@


class BurnsNorm(nn.Module):
"""Burns et al. style normalization. Minimal changes from the original code."""
"""Burns et al. style normalization. Minimal changes from the original code.
"""

def __init__(self, scale: bool = True):
super().__init__()
self.scale: bool = scale

def forward(self, x: Tensor) -> Tensor:
"""Normalizes per template
Args:
x: input of dimension (n, v, c, d) or (n, v, d)
Returns:
x_normalized: normalized output
"""
num_elements = x.shape[0]
x_normalized: Tensor = x - x.mean(dim=0) if num_elements > 1 else x

Expand All @@ -23,7 +30,8 @@ def forward(self, x: Tensor) -> Tensor:

# Compute the dimensions over which
# we want to compute the mean standard deviation
# exclude the first dimension (v)
# exclude the first dimension v,
# which is the template dimension
dims = tuple(range(1, std.dim()))

avg_norm = std.mean(dim=dims)
Expand Down

0 comments on commit e8872c8

Please sign in to comment.