diff --git a/sparse_autoencoder/model.py b/sparse_autoencoder/model.py
index 3c9fd5b..f003d9e 100644
--- a/sparse_autoencoder/model.py
+++ b/sparse_autoencoder/model.py
@@ -1,4 +1,4 @@
-from typing import Callable, Union
+from typing import Callable
 
 import torch
 import torch.nn as nn
@@ -29,20 +29,30 @@ def __init__(
         self.latent_bias = nn.Parameter(torch.zeros(n_latents))
         self.activation = activation
         if tied:
-            self.decoder = TiedTranspose(self.encoder)  # type: Union[nn.Linear, TiedTranspose]
+            self.decoder: nn.Linear | TiedTranspose = TiedTranspose(self.encoder)
         else:
             self.decoder = nn.Linear(n_latents, n_inputs, bias=False)
         self.stats_last_nonzero: torch.Tensor
+        self.latents_activation_frequency: torch.Tensor
+        self.latents_mean_square: torch.Tensor
         self.register_buffer("stats_last_nonzero", torch.zeros(n_latents, dtype=torch.long))
+        self.register_buffer(
+            "latents_activation_frequency", torch.ones(n_latents, dtype=torch.float)
+        )
+        self.register_buffer("latents_mean_square", torch.zeros(n_latents, dtype=torch.float))
 
-    def encode_pre_act(self, x: torch.Tensor) -> torch.Tensor:
+    def encode_pre_act(self, x: torch.Tensor, latent_slice: slice = slice(None)) -> torch.Tensor:
         """
         :param x: input data (shape: [batch, n_inputs])
+        :param latent_slice: slice of latents to compute
+            Example: latent_slice = slice(0, 10) to compute only the first 10 latents.
         :return: autoencoder latents before activation (shape: [batch, n_latents])
         """
         x = x - self.pre_bias
-        latents_pre_act = self.encoder(x) + self.latent_bias
+        latents_pre_act = F.linear(
+            x, self.encoder.weight[latent_slice], self.latent_bias[latent_slice]
+        )
         return latents_pre_act
 
     def encode(self, x: torch.Tensor) -> torch.Tensor:
@@ -77,12 +87,36 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Te
         return latents_pre_act, latents, recons
 
     @classmethod
-    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], strict: bool = True) -> "Autoencoder":
+    def from_state_dict(
+        cls, state_dict: dict[str, torch.Tensor], strict: bool = True
+    ) -> "Autoencoder":
         n_latents, d_model = state_dict["encoder.weight"].shape
         autoencoder = cls(n_latents, d_model)
+
+        # Retrieve activation
+        activation_class_name = state_dict.pop("activation", "ReLU")
+        activation_class = ACTIVATIONS_CLASSES.get(activation_class_name, nn.ReLU)
+        activation_state_dict = state_dict.pop("activation_state_dict", {})
+        if hasattr(activation_class, "from_state_dict"):
+            autoencoder.activation = activation_class.from_state_dict(
+                activation_state_dict, strict=strict
+            )
+        else:
+            autoencoder.activation = activation_class()
+            if hasattr(autoencoder.activation, "load_state_dict"):
+                autoencoder.activation.load_state_dict(activation_state_dict, strict=strict)
+
+        # Load remaining state dict
         autoencoder.load_state_dict(state_dict, strict=strict)
         return autoencoder
 
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
+        sd = super().state_dict(destination, prefix, keep_vars)
+        sd[prefix + "activation"] = self.activation.__class__.__name__
+        if hasattr(self.activation, "state_dict"):
+            sd[prefix + "activation_state_dict"] = self.activation.state_dict()
+        return sd
+
 
 class TiedTranspose(nn.Module):
     def __init__(self, linear: nn.Linear):
@@ -100,3 +134,36 @@ def weight(self) -> torch.Tensor:
     @property
     def bias(self) -> torch.Tensor:
         return self.linear.bias
+
+
+class TopK(nn.Module):
+    def __init__(self, k: int, postact_fn: Callable = nn.ReLU()) -> None:
+        super().__init__()
+        self.k = k
+        self.postact_fn = postact_fn
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        topk = torch.topk(x, k=self.k, dim=-1)
+        values = self.postact_fn(topk.values)
+        # make all other values 0
+        result = torch.zeros_like(x)
+        result.scatter_(-1, topk.indices, values)
+        return result
+
+    def state_dict(self, destination=None, prefix="", keep_vars=False):
+        state_dict = super().state_dict(destination, prefix, keep_vars)
+        state_dict.update({prefix + "k": self.k, prefix + "postact_fn": self.postact_fn.__class__.__name__})
+        return state_dict
+
+    @classmethod
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], strict: bool = True) -> "TopK":
+        k = state_dict["k"]
+        postact_fn = ACTIVATIONS_CLASSES[state_dict["postact_fn"]]()
+        return cls(k=k, postact_fn=postact_fn)
+
+
+ACTIVATIONS_CLASSES = {
+    "ReLU": nn.ReLU,
+    "Identity": nn.Identity,
+    "TopK": TopK,
+}
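
For reference, a short behavioral sketch of the new TopK activation (not part of the patch; it assumes the patched module is importable as sparse_autoencoder.model, and the tensor values are made up for illustration):

    import torch

    from sparse_autoencoder.model import TopK

    # TopK keeps the k largest pre-activations per row, applies postact_fn
    # (ReLU by default) to them, and zeroes everything else.
    act = TopK(k=2)
    x = torch.tensor([[0.5, -1.0, 3.0, 2.0],
                      [-0.1, 0.2, 0.1, -2.0]])
    print(act(x))
    # tensor([[0.0000, 0.0000, 3.0000, 2.0000],
    #         [0.0000, 0.2000, 0.1000, 0.0000]])

Note that postact_fn is applied only to the k selected values, so a selected but negative pre-activation is clamped to zero rather than being replaced by the next-largest entry.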
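
The state_dict / from_state_dict changes serialize the activation's class name (and its own state dict, e.g. k for TopK) alongside the tensor weights, so a TopK autoencoder can be rebuilt without knowing k up front. A minimal round-trip sketch, assuming the constructor takes the n_latents, n_inputs, and activation arguments implied by the patch (sizes here are arbitrary):

    import torch

    from sparse_autoencoder.model import Autoencoder, TopK

    ae = Autoencoder(512, 64, activation=TopK(k=32))

    sd = ae.state_dict()
    # The override adds two non-tensor entries alongside the usual weights:
    #   sd["activation"]            -> "TopK"
    #   sd["activation_state_dict"] -> contains {"k": 32, "postact_fn": "ReLU"}

    restored = Autoencoder.from_state_dict(sd)
    assert isinstance(restored.activation, TopK) and restored.activation.k == 32

    # Same weights and same activation, so the forward outputs should match.
    x = torch.randn(4, 64)
    torch.testing.assert_close(restored(x)[2], ae(x)[2])  # recons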