Commit

chore: improving CI speed
chanind committed Apr 9, 2024
1 parent 8417505 commit 9e3863c
Showing 17 changed files with 132 additions and 182 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -5,4 +5,4 @@ max-complexity = 25
 extend-select = E9, F63, F7, F82
 show-source = true
 statistics = true
-exclude = ./wandb/*, ./research/wandb/*
+exclude = ./wandb/*, ./research/wandb/*, .venv/*
32 changes: 28 additions & 4 deletions .github/workflows/build.yml
@@ -4,7 +4,6 @@ on:
   push:
     branches:
       - main
-      - clean_up_repo
   pull_request:
     branches:
       - main
@@ -25,14 +24,39 @@ jobs:
         python-version: ["3.10", "3.11"]

     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
          python-version: ${{ matrix.python-version }}
+      - name: Cache Huggingface assets
+        uses: actions/cache@v4
+        with:
+          key: huggingface-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
+          path: ~/.cache/huggingface
+          restore-keys: |
+            huggingface-${{ runner.os }}-${{ matrix.python-version }}-
+      - name: Load cached Poetry installation
+        id: cached-poetry
+        uses: actions/cache@v4
+        with:
+          path: ~/.local # the path depends on the OS
+          key: poetry-${{ runner.os }}-${{ matrix.python-version }}-0 # increment to reset cache
       - name: Install Poetry
+        if: steps.cached-poetry.outputs.cache-hit != 'true'
         uses: snok/install-poetry@v1
+        with:
+          virtualenvs-create: true
+          virtualenvs-in-project: true
+          installer-parallel: true
+      - name: Load cached venv
+        id: cached-poetry-dependencies
+        uses: actions/cache@v4
+        with:
+          path: .venv
+          key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
       - name: Install dependencies
+        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
         run: poetry install --no-interaction
       - name: Lint with flake8
         run: poetry run flake8 .
@@ -49,7 +73,7 @@ jobs:
         uses: codecov/[email protected]
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
-          slug: jbloomAus/mats_sae_training
+          slug: jbloomAus/SAELens

   release:

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ flake8 = "^7.0.0"
 isort = "^5.13.2"
 pyright = "^1.1.351"

+
 [tool.isort]
 profile = "black"

@@ -54,6 +55,7 @@ reportConstantRedefinition = "none"
 reportUnknownLambdaType = "none"
 reportPrivateUsage = "none"

+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
2 changes: 1 addition & 1 deletion sae_lens/analysis/dashboard_runner.py
@@ -11,6 +11,7 @@
 import plotly
 import plotly.express as px
 import torch
+import wandb
 from sae_vis.data_config_classes import (
     ActsHistogramConfig,
     Column,
@@ -23,7 +24,6 @@
 from torch.nn.functional import cosine_similarity
 from tqdm import tqdm

-import wandb
 from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader

1 change: 0 additions & 1 deletion sae_lens/training/config.py
@@ -2,7 +2,6 @@
 from typing import Any, Optional, cast

 import torch
-
 import wandb

2 changes: 1 addition & 1 deletion sae_lens/training/evals.py
@@ -3,11 +3,11 @@

 import pandas as pd
 import torch
+import wandb
 from tqdm import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.utils import get_act_name

-import wandb
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder

1 change: 1 addition & 0 deletions sae_lens/training/lm_runner.py
@@ -1,6 +1,7 @@
 from typing import Any, cast

+import wandb

 from sae_lens.training.config import LanguageModelSAERunnerConfig
 from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader

24 changes: 11 additions & 13 deletions sae_lens/training/sae_group.py
@@ -6,7 +6,7 @@
 import pickle
 from itertools import product
 from types import SimpleNamespace
-from typing import Any, Iterator
+from typing import Iterator

 import torch

@@ -58,9 +58,8 @@ def to(self, device: torch.device | str):
         for ae in self.autoencoders:
             ae.to(device)

-    # old pickled SAEs load as a dict
     @classmethod
-    def load_from_pretrained(cls, path: str) -> "SAEGroup" | dict[str, Any]:
+    def load_from_pretrained(cls, path: str) -> "SAEGroup":
         """
         Load function for the model. Loads the model's state_dict and the config used to train it.
         This method can be called directly on the class, without needing an instance.
@@ -114,18 +113,17 @@ def load_from_pretrained(cls, path: str) -> "SAEGroup" | dict[str, Any]:
                 f"Unexpected file extension: {path}, supported extensions are .pt, .pkl, and .pkl.gz"
             )

-        return group
-        # # # Ensure the loaded state contains both 'cfg' and 'state_dict'
-        # # if "cfg" not in state_dict or "state_dict" not in state_dict:
-        # #     raise ValueError(
-        # #         "The loaded state dictionary must contain 'cfg' and 'state_dict' keys"
-        # #     )
+        # handle loading old autoencoders where before SAEGroup existed, where we just save a dict
+        if isinstance(group, dict):
+            sparse_autoencoder = SparseAutoencoder(cfg=group["cfg"])
+            sparse_autoencoder.load_state_dict(group["state_dict"])
+            group = cls(group["cfg"])
+            group.autoencoders[0] = sparse_autoencoder

-        # # Create an instance of the class using the loaded configuration
-        # instance = cls(cfg=state_dict["cfg"])
-        # instance.load_state_dict(state_dict["state_dict"])
+        if not isinstance(group, cls):
+            raise ValueError("The loaded object is not a valid SAEGroup")

-        # return instance
+        return group

     def save_model(self, path: str):
         """
26 changes: 1 addition & 25 deletions sae_lens/training/session_loader.py
@@ -5,7 +5,6 @@
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.config import LanguageModelSAERunnerConfig
 from sae_lens.training.sae_group import SAEGroup
-from sae_lens.training.sparse_autoencoder import SparseAutoencoder


 class LMSparseAutoencoderSessionloader:
@@ -40,31 +39,8 @@ def load_session_from_pretrained(
         """
         Loads a session for analysing a pretrained sparse autoencoder group.
         """
-        # if torch.backends.mps.is_available():
-        #     cfg = torch.load(path, map_location="mps")["cfg"]
-        #     cfg.device = "mps"
-        # elif torch.cuda.is_available():
-        #     cfg = torch.load(path, map_location="cuda")["cfg"]
-        # else:
-        #     cfg = torch.load(path, map_location="cpu")["cfg"]
-
         sparse_autoencoders = SAEGroup.load_from_pretrained(path)
-
-        # hacky code to deal with old SAE saves
-        if type(sparse_autoencoders) is dict:
-            sparse_autoencoder = SparseAutoencoder(cfg=sparse_autoencoders["cfg"])
-            sparse_autoencoder.load_state_dict(sparse_autoencoders["state_dict"])
-            model, sparse_autoencoders, activations_loader = cls(
-                sparse_autoencoder.cfg
-            ).load_session()
-            sparse_autoencoders.autoencoders[0] = sparse_autoencoder
-        elif type(sparse_autoencoders) is SAEGroup:
-            model, _, activations_loader = cls(sparse_autoencoders.cfg).load_session()
-        else:
-            raise ValueError(
-                "The loaded sparse_autoencoders object is neither an SAE dict nor a SAEGroup"
-            )
-
+        model, _, activations_loader = cls(sparse_autoencoders.cfg).load_session()
         return model, sparse_autoencoders, activations_loader

     def get_model(self, model_name: str):
2 changes: 1 addition & 1 deletion sae_lens/training/toy_model_runner.py
@@ -3,8 +3,8 @@

 import einops
 import torch
-
 import wandb
+
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder
 from sae_lens.training.toy_models import Config as ToyConfig
 from sae_lens.training.toy_models import Model as ToyModel
2 changes: 1 addition & 1 deletion sae_lens/training/train_sae_on_language_model.py
@@ -2,12 +2,12 @@
 from typing import Any, NamedTuple, cast

 import torch
+import wandb
 from torch.optim import Adam, Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from tqdm import tqdm
 from transformer_lens import HookedTransformer

-import wandb
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.evals import run_evals
 from sae_lens.training.geometric_median import compute_geometric_median
2 changes: 1 addition & 1 deletion sae_lens/training/train_sae_on_toy_model.py
@@ -1,10 +1,10 @@
 from typing import Any, cast

 import torch
+import wandb
 from torch.utils.data import DataLoader
 from tqdm import tqdm

-import wandb
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder

5 changes: 2 additions & 3 deletions tests/unit/conftest.py
@@ -1,9 +1,8 @@
 import pytest
-from transformer_lens import HookedTransformer

-from tests.unit.helpers import TINYSTORIES_MODEL
+from tests.unit.helpers import TINYSTORIES_MODEL, load_model_cached


 @pytest.fixture
 def ts_model():
-    return HookedTransformer.from_pretrained(TINYSTORIES_MODEL, device="cpu")
+    return load_model_cached(TINYSTORIES_MODEL)
24 changes: 20 additions & 4 deletions tests/unit/helpers.py
@@ -1,6 +1,7 @@
 from typing import Any

 import torch
+from transformer_lens import HookedTransformer

 from sae_lens.training.config import LanguageModelSAERunnerConfig

@@ -26,13 +27,13 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
         l1_coefficient=2e-3,
         lp_norm=1,
         lr=2e-4,
-        train_batch_size=2048,
-        context_size=64,
+        train_batch_size=4,
+        context_size=6,
         feature_sampling_window=50,
         dead_feature_threshold=1e-7,
-        n_batches_in_buffer=10,
+        n_batches_in_buffer=2,
         total_training_tokens=1_000_000,
-        store_batch_size=32,
+        store_batch_size=4,
         log_to_wandb=False,
         wandb_project="test_project",
         wandb_entity="test_entity",
@@ -48,3 +49,18 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
         setattr(mock_config, key, val)

     return mock_config
+
+
+MODEL_CACHE: dict[str, HookedTransformer] = {}
+
+
+def load_model_cached(model_name: str) -> HookedTransformer:
+    """
+    helper to avoid unnecessarily loading the same model multiple times.
+    NOTE: if the model gets modified in tests this will not work.
+    """
+    if model_name not in MODEL_CACHE:
+        MODEL_CACHE[model_name] = HookedTransformer.from_pretrained(
+            model_name, device="cpu"
+        )
+    return MODEL_CACHE[model_name]
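
On the test side, the smaller config defaults and the cached model loader are where the speedup comes from. A minimal sketch of how a test could exercise both (a hypothetical test file, not part of this commit; it assumes the helpers module shown above is importable as tests.unit.helpers, as conftest.py already does):

# tests/unit/test_ci_speed_example.py -- hypothetical illustration only
from transformer_lens import HookedTransformer

from tests.unit.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached


def test_defaults_stay_small():
    # the reduced defaults above keep per-test work cheap on CI
    cfg = build_sae_cfg()
    assert cfg.train_batch_size == 4
    assert cfg.context_size == 6
    assert cfg.n_batches_in_buffer == 2


def test_model_is_loaded_once():
    # both calls return the same object because of the MODEL_CACHE dict above
    first = load_model_cached(TINYSTORIES_MODEL)
    second = load_model_cached(TINYSTORIES_MODEL)
    assert isinstance(first, HookedTransformer)
    assert first is second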