Add TopK activation function (#2)
* Add TopK activation function

* add additional buffers

* update
TomDLT committed Jun 4, 2024
1 parent 8f74a1c commit a7b2365
Showing 1 changed file with 72 additions and 5 deletions.
sparse_autoencoder/model.py
@@ -1,4 +1,4 @@
from typing import Callable, Union
from typing import Callable

import torch
import torch.nn as nn
@@ -29,20 +29,30 @@ def __init__(
self.latent_bias = nn.Parameter(torch.zeros(n_latents))
self.activation = activation
if tied:
self.decoder = TiedTranspose(self.encoder) # type: Union[nn.Linear, TiedTranspose]
self.decoder: nn.Linear | TiedTranspose = TiedTranspose(self.encoder)
else:
self.decoder = nn.Linear(n_latents, n_inputs, bias=False)

self.stats_last_nonzero: torch.Tensor
self.latents_activation_frequency: torch.Tensor
self.latents_mean_square: torch.Tensor
self.register_buffer("stats_last_nonzero", torch.zeros(n_latents, dtype=torch.long))
self.register_buffer(
"latents_activation_frequency", torch.ones(n_latents, dtype=torch.float)
)
self.register_buffer("latents_mean_square", torch.zeros(n_latents, dtype=torch.float))

def encode_pre_act(self, x: torch.Tensor) -> torch.Tensor:
def encode_pre_act(self, x: torch.Tensor, latent_slice: slice = slice(None)) -> torch.Tensor:
"""
:param x: input data (shape: [batch, n_inputs])
:param latent_slice: slice of latents to compute
Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
:return: autoencoder latents before activation (shape: [batch, n_latents])
"""
x = x - self.pre_bias
latents_pre_act = self.encoder(x) + self.latent_bias
latents_pre_act = F.linear(
x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
)
return latents_pre_act
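A hedged usage sketch of the new latent_slice argument, assuming the constructor Autoencoder(n_latents, n_inputs, activation=...) implied by __init__ above (sizes are illustrative): slicing lets a caller compute only a subset of latents without materializing the full [batch, n_latents] pre-activation matrix.

import torch

n_latents, n_inputs = 4096, 768  # illustrative sizes
autoencoder = Autoencoder(n_latents, n_inputs, activation=TopK(k=32))

x = torch.randn(16, n_inputs)
full = autoencoder.encode_pre_act(x)                                  # shape [16, 4096]
first_ten = autoencoder.encode_pre_act(x, latent_slice=slice(0, 10))  # shape [16, 10]
torch.testing.assert_close(first_ten, full[:, :10])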

def encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -77,12 +87,36 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return latents_pre_act, latents, recons

@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor], strict: bool = True) -> "Autoencoder":
def from_state_dict(
cls, state_dict: dict[str, torch.Tensor], strict: bool = True
) -> "Autoencoder":
n_latents, d_model = state_dict["encoder.weight"].shape
autoencoder = cls(n_latents, d_model)

# Retrieve activation
activation_class_name = state_dict.pop("activation", "ReLU")
activation_class = ACTIVATIONS_CLASSES.get(activation_class_name, nn.ReLU)
activation_state_dict = state_dict.pop("activation_state_dict", {})
if hasattr(activation_class, "from_state_dict"):
autoencoder.activation = activation_class.from_state_dict(
activation_state_dict, strict=strict
)
else:
autoencoder.activation = activation_class()
if hasattr(autoencoder.activation, "load_state_dict"):
autoencoder.activation.load_state_dict(activation_state_dict, strict=strict)

# Load remaining state dict
autoencoder.load_state_dict(state_dict, strict=strict)
return autoencoder

def state_dict(self, destination=None, prefix="", keep_vars=False):
sd = super().state_dict(destination, prefix, keep_vars)
sd[prefix + "activation"] = self.activation.__class__.__name__
if hasattr(self.activation, "state_dict"):
sd[prefix + "activation_state_dict"] = self.activation.state_dict()
return sd
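A hedged round-trip sketch of the overridden serialization above (sizes illustrative): the activation's class name and its own state dict travel with the weights, so from_state_dict can rebuild a TopK model without any extra arguments.

import torch

model = Autoencoder(256, 64, activation=TopK(k=8))  # illustrative sizes

sd = model.state_dict()
print(sd["activation"])                  # 'TopK'
print(sd["activation_state_dict"]["k"])  # 8

restored = Autoencoder.from_state_dict(sd)
print(type(restored.activation).__name__, restored.activation.k)  # TopK 8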


class TiedTranspose(nn.Module):
def __init__(self, linear: nn.Linear):
@@ -100,3 +134,36 @@ def weight(self) -> torch.Tensor:
@property
def bias(self) -> torch.Tensor:
return self.linear.bias


class TopK(nn.Module):
def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
super().__init__()
self.k = k
self.postact_fn = postact_fn

def forward(self, x: torch.Tensor) -> torch.Tensor:
topk = torch.topk(x, k=self.k, dim=-1)
values = self.postact_fn(topk.values)
# make all other values 0
result = torch.zeros_like(x)
result.scatter_(-1, topk.indices, values)
return result

def state_dict(self, destination=None, prefix="", keep_vars=False):
state_dict = super().state_dict(destination, prefix, keep_vars)
state_dict.update({prefix + "k": self.k, prefix + "postact_fn": self.postact_fn.__class__.__name__})
return state_dict

@classmethod
def from_state_dict(cls, state_dict: dict[str, torch.Tensor], strict: bool = True) -> "TopK":
k = state_dict["k"]
postact_fn = ACTIVATIONS_CLASSES[state_dict["postact_fn"]]()
return cls(k=k, postact_fn=postact_fn)
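A small sanity-check sketch of TopK.forward: only the k largest entries in each row survive (passed through the post-activation, ReLU by default); every other position is zeroed.

import torch

topk = TopK(k=2)
x = torch.tensor([[0.5, -1.0, 3.0, 2.0],
                  [4.0, 0.1, -0.2, 0.3]])
print(topk(x))
# keeps 3.0 and 2.0 in row 0, and 4.0 and 0.3 in row 1; all other entries are 0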


ACTIVATIONS_CLASSES = {
"ReLU": nn.ReLU,
"Identity": nn.Identity,
"TopK": TopK,
}
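A short sketch of how the registry is used (constructor arguments assumed from __init__ above): activation classes are looked up by the same name that Autoencoder.state_dict stores, so a checkpoint fully determines which class to instantiate.

activation = ACTIVATIONS_CLASSES["TopK"](k=32)
model = Autoencoder(4096, 768, activation=activation)  # illustrative sizes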
