Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sweep Visualizations #245

Merged
merged 25 commits into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
37002b1
create heatmap-visualizations for sweeps
lauritofzi May 5, 2023
e552f55
fix viz path + cleanup
lauritofzi May 7, 2023
199c8d5
initial
derpyplops May 14, 2023
1568cd0
refactoring
derpyplops May 14, 2023
181283f
code fix
derpyplops May 14, 2023
2908c0f
fix deps
derpyplops May 14, 2023
6ab281e
fix elk sweep viz flag usage
derpyplops May 15, 2023
ee6e14b
fix typo
derpyplops May 15, 2023
5d526d7
delete comment and factorize
derpyplops May 16, 2023
058cae8
cleanup
lauritowal May 16, 2023
5162cab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2023
398ec42
Fix file resolution and factor out sweep_dir()
derpyplops May 16, 2023
9a79c8f
change to relative import
lauritowal May 16, 2023
cf0ad33
Merge branch 'visualizations' of https://github.com/EleutherAI/elk in…
lauritowal May 16, 2023
d6bd7b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2023
d4e99b0
Address walt's comments and some tests
derpyplops May 17, 2023
a2e5f60
Change write location to elk-reporters/{sweep}/viz
derpyplops May 18, 2023
68e4b52
Edit README
derpyplops May 18, 2023
0e152bc
Fix TestGetModelPaths
derpyplops May 18, 2023
9d6552f
Fix duplicate bug
derpyplops May 18, 2023
623b2c7
add overwrite flag
derpyplops May 18, 2023
256ad68
add transfer to SweepByDsMultiplot
derpyplops May 18, 2023
9f9c5bb
Remove docstrings for consistency
derpyplops May 18, 2023
033e901
remove vestigial .gitignore
derpyplops May 18, 2023
c176732
remove burns datasets
derpyplops May 18, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Address walt's comments and some tests
  • Loading branch information
derpyplops committed May 17, 2023
commit d4e99b0dc921092bcb87108bfb0773864dcbeb8a
5 changes: 2 additions & 3 deletions elk/plotting/command.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from dataclasses import dataclass, field
from dataclasses import dataclass

from simple_parsing import field

Expand All @@ -13,7 +12,7 @@ class Plot:

def execute(self):
sweeps_root_dir = sweeps_dir()
sweep = max(sweeps_root_dir.iterdir(), key=os.path.getctime)
sweep = max(sweeps_root_dir.iterdir(), key=lambda f: f.stat().st_ctime)
if self.sweeps:
sweep = sweeps_root_dir / self.sweeps[0]
if not sweep.exists():
Expand Down
27 changes: 0 additions & 27 deletions elk/plotting/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import shutil
from pathlib import Path

from rich.console import Console
from rich.table import Table

Expand Down Expand Up @@ -52,27 +49,3 @@ def display_table(pivot_table):
table.add_row(str(index), *[str(value) for value in row])

console.print(table)


def restructure_to_sweep(
    elk_reporters: Path, data_path: Path, new_name: str
) -> None:  # usually /elk-reporters/*
    """Copy an elk-reporters tree into a sweep-style layout.

    The source layout is ``elk_reporters/<repo>/<model>/<dataset>/<run>``;
    every run directory is copied to
    ``data_path/<new_name>/<run>/<repo>/<model>/<dataset>``.

    Existing destination directories are reused; files and subdirectories
    of each run are copied recursively.
    """
    for model_repo_path in elk_reporters.iterdir():
        for model_path in model_repo_path.iterdir():
            for dataset_path in model_path.iterdir():
                for run_path in dataset_path.iterdir():
                    new_path = (
                        data_path
                        / new_name
                        / run_path.name
                        / model_repo_path.name
                        / model_path.name
                        / dataset_path.name
                    )
                    # exist_ok=True replaces the racy exists()/mkdir() pair
                    new_path.mkdir(parents=True, exist_ok=True)
                    for file in run_path.iterdir():
                        if file.is_file():
                            shutil.copy(file, new_path / file.name)
                        else:
                            shutil.copytree(file, new_path / file.name)
166 changes: 66 additions & 100 deletions elk/plotting/visualize.py
derpyplops marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,29 @@
from plotly.subplots import make_subplots

import elk.plotting.utils as utils
from elk.utils.constants import BURNS_DATASETS

VIZ_PATH = Path(os.getcwd()) / "viz"
derpyplops marked this conversation as resolved.
Show resolved Hide resolved

# The nine datasets evaluated by default; presumably the Burns et al. (2022)
# suite — NOTE(review): keep in sync with the `burns` alias in training/sweep.py
ALL_DS_NAMES = [
    "super_glue:rte",
    "super_glue:boolq",
    "dbpedia_14",
    "piqa",
    "amazon_polarity",
    "glue:qnli",
    "ag_news",
    "imdb",
    "super_glue:copa",
]


class SweepByDsMultiplot:
"""Multiplot containing visualization where the x-axis is the layer and
the y-axis is auroc_estimate. Each subplot is a different dataset.
"""

def __init__(self, model_name: str):
self.model_name = model_name

def validate(self, sweep: Sweep) -> bool:
return True

def render(self, sweep: Sweep, with_transfer=False, write=True) -> go.Figure:
def render(
self,
sweep: SweepVisualization,
with_transfer=False,
ensembles=["full", "partial", "none"],
write=False,
) -> go.Figure:
df = sweep.df
unique_datasets = df["eval_dataset"].unique()
run_names = df["run_name"].unique()
ensembles = ["full", "none"]
num_datasets = len(unique_datasets)
num_rows = (num_datasets + 2) // 3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's happening here? maybe add a comment? 🟡

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDK, chatgpt generated it


Expand Down Expand Up @@ -123,12 +118,9 @@
self.score_type = score_type
self.ensembling = ensembling

def validate(self, df) -> bool:
return True

def generate(self, df: pd.DataFrame) -> go.Figure:
def render(self, df: pd.DataFrame, write=False) -> go.Figure:
"""
Generate a heatmap for dataset of a model.
Render a heatmap for dataset of a model.
"""
model_name = df["eval_dataset"].iloc[0] # infer model name
# TODO: validate
Expand All @@ -152,10 +144,7 @@
self.dataset_names = dataset_names
self.score_type = score_type

def validate(self, df) -> bool:
return True

def generate(self, df: pd.DataFrame) -> go.Figure:
def render(self, df: pd.DataFrame) -> go.Figure:
# TODO should I filter out the non-transfer dataset?
model_name = df["eval_dataset"].iloc[0] # infer model name
df = self._filter_transfer_datasets(df, self.dataset_names)
Expand Down Expand Up @@ -190,54 +179,47 @@


@dataclass
class Model:
class ModelVisualization:
df: pd.DataFrame
sweep_name: str
model_name: str
is_transfer: bool

@classmethod
def collect(cls, model_path: Path, sweep_name: str) -> Model:
def collect(cls, model_path: Path, sweep_name: str) -> ModelVisualization:
df = pd.DataFrame()
model_name = model_path.name

def handle_csv(dir, eval_dataset, train_dataset):
file = dir / "eval.csv"
eval_df = pd.read_csv(file)
eval_df["eval_dataset"] = eval_dataset
eval_df["train_dataset"] = train_dataset
return eval_df

is_transfer = False
for train_dir in model_path.iterdir():
eval_df = handle_csv(train_dir, train_dir.name, train_dir.name)
eval_df = cls._read_eval_csv(train_dir, train_dir.name, train_dir.name)
df = pd.concat([df, eval_df], ignore_index=True)
transfer_dir = train_dir / "transfer"
if transfer_dir.exists():
is_transfer = True
for eval_ds_dir in transfer_dir.iterdir():
eval_df = handle_csv(eval_ds_dir, eval_ds_dir.name, train_dir.name)
eval_df = cls._read_eval_csv(
eval_ds_dir, eval_ds_dir.name, train_dir.name
)
df = pd.concat([df, eval_df], ignore_index=True)

df["model_name"] = model_name
df["run_name"] = sweep_name

return cls(df, sweep_name, model_name, is_transfer)

def render(
def render_and_save(
self,
dataset_names: list[str] = ALL_DS_NAMES,
dataset_names: list[str] = BURNS_DATASETS,
score_type="auroc_estimate",
ensembling="full",
) -> None:
df = self.df
model_name = self.model_name
sweep_name = self.sweep_name
layer_min, layer_max = df["layer"].min(), df["layer"].max()
if not (VIZ_PATH / sweep_name).exists():
(VIZ_PATH / sweep_name).mkdir()
if not (VIZ_PATH / sweep_name / f"{model_name}").exists():
(VIZ_PATH / sweep_name / f"{model_name}").mkdir()
model_path = VIZ_PATH / sweep_name / f"{model_name}"
model_path.mkdir(parents=True, exist_ok=True)
if self.is_transfer:
for layer in range(layer_min, layer_max + 1):
filtered = df[(df["layer"] == layer) & (df["ensembling"] == ensembling)]
Expand All @@ -246,105 +228,89 @@
path.parent.mkdir()
fig = TransferEvalHeatmap(
layer, score_type=score_type, ensembling=ensembling
).generate(filtered)
).render(filtered)
fig.write_image(file=path)

fig = (
TransferEvalTrend(dataset_names)
.generate(df)
.write_image(
file=VIZ_PATH / sweep_name / f"{model_name}" / "transfer_eval_trend.png"
)
fig = TransferEvalTrend(dataset_names).render(df)
fig.write_image(
file=VIZ_PATH / sweep_name / f"{model_name}" / "transfer_eval_trend.png"
)


# the following function does too many things.
# it can be split up into:
# function that takes a sweep/run and returns a dataframe
# function that takes df runs / run / model (with or without transfer) and renders it
@staticmethod
def _read_eval_csv(path: Path, eval_dataset: str, train_dataset: str) -> pd.DataFrame:
    """Read ``eval.csv`` under *path* and tag each row with its datasets.

    The ``eval_dataset`` / ``train_dataset`` columns let frames from many
    runs be concatenated and grouped later.
    """
    file = path / "eval.csv"
    eval_df = pd.read_csv(file)
    eval_df["eval_dataset"] = eval_dataset
    eval_df["train_dataset"] = train_dataset
    return eval_df


@dataclass
class Sweep:
class SweepVisualization:
name: str
df: pd.DataFrame
path: Path
datasets: list[str]
models: dict[str, Model]
models: dict[str, ModelVisualization]

def model_names(self):
return list(self.models.keys())

@staticmethod
def _get_model_paths(sweep_path: Path) -> list[Path]:
folders = []
for model_repo in sweep_path.iterdir():
if not model_repo.is_dir():
raise Exception("expected model repo to be a directory")
if model_repo.name.startswith("gpt2"):
folders += [model_repo]
else:
folders += [p for p in model_repo.iterdir() if p.is_dir()]
return folders

@classmethod
def collect(cls, sweep: Path) -> Sweep:
def collect(cls, sweep: Path) -> SweepVisualization:
sweep_name = sweep.parts[-1]
sweep_viz_path = VIZ_PATH / sweep_name
sweep_viz_path.mkdir(parents=True, exist_ok=True)

# TODO refactor out
if not VIZ_PATH.exists():
VIZ_PATH.mkdir()
if not sweep_viz_path.exists():
sweep_viz_path.mkdir()

model_paths = get_model_paths(sweep)
model_paths = cls._get_model_paths(sweep)
models = {
model_path.name: Model.collect(model_path, sweep_name)
model_path.name: ModelVisualization.collect(model_path, sweep_name)
for model_path in model_paths
}
df = pd.concat([model.df for model in models.values()], ignore_index=True)
datasets = list(df["eval_dataset"].unique())
return cls(sweep_name, df, sweep_viz_path, datasets, models)

def render(self):
def render_and_save(self):
for model in self.models.values():
model.render()
self.generate_table(write=True)
self.generate_multiplots()
model.render_and_save()
self.render_table(write=True)
self.render_multiplots(write=True)

def generate_multiplots(self):
return [SweepByDsMultiplot(model).render(self) for model in self.models]
def render_multiplots(self, write=False):
return [
SweepByDsMultiplot(model).render(self, write=write) for model in self.models
]

def generate_table(
self, layer=-5, score_type="auroc_estimate", print=True, write=False
def render_table(
self, layer=-5, score_type="auroc_estimate", display=True, write=False
):
df = self.df

layer_by_model = (df.groupby("model_name")["layer"].max() + layer).clip(lower=0)

# Create an empty DataFrame to store the selected records
df_selected_layer = pd.DataFrame()

# For each model, select the record corresponding to max layer - 5
for model, layer in layer_by_model.items():
record = df[(df["model_name"] == model) & (df["layer"] == layer)]
df_selected_layer = pd.concat([df_selected_layer, record])

# Generate the pivot table
pivot_table = df_selected_layer.pivot_table(
index="run_name", columns="model_name", values=score_type
)

if print:
if display:
utils.display_table(pivot_table)
if write:
pivot_table.to_csv(f"score_table_{score_type}.csv")

return pivot_table


def get_model_paths(sweep_path: Path) -> list[Path]:
    """Return every model directory recorded under a sweep.

    Layout: run / model_repo / model / dataset / eval.csv.  A top-level
    ``gpt2*`` entry is itself a model; any other repo directory is treated
    as an organization whose subdirectories are the models.
    """
    model_dirs: list[Path] = []
    for repo in sweep_path.iterdir():
        if not repo.is_dir():
            raise Exception("expected model repo to be a directory")
        if repo.name.startswith("gpt2"):
            model_dirs.append(repo)
        else:
            model_dirs.extend(child for child in repo.iterdir() if child.is_dir())
    return model_dirs


def visualize_sweep(sweep_path: Path):
    """Collect results for one sweep and render every visualization to disk.

    Delegates to ``SweepVisualization.collect`` for data gathering and to
    ``render_and_save`` for writing the plots and score table.
    """
    # render_and_save() takes no `write` parameter — passing one fails the
    # type check (reportGeneralTypeIssues) and would raise TypeError at runtime.
    SweepVisualization.collect(sweep_path).render_and_save()

Check failure on line 316 in elk/plotting/visualize.py

View workflow job for this annotation

GitHub Actions / run-tests (3.11, macos-latest)

No parameter named "write" (reportGeneralTypeIssues)
15 changes: 2 additions & 13 deletions elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ..plotting.visualize import visualize_sweep
from ..training.eigen_reporter import EigenReporterConfig
from ..utils import colorize
from ..utils.constants import BURNS_DATASETS
from .train import Elicit


Expand Down Expand Up @@ -62,19 +63,7 @@ def __post_init__(self, add_pooled: bool):
# on the Huggingface Hub.
if "burns" in self.datasets:
self.datasets.remove("burns")
self.datasets.extend(
[
"ag_news",
"amazon_polarity",
"dbpedia_14",
"glue:qnli",
"imdb",
"piqa",
"super_glue:boolq",
"super_glue:copa",
"super_glue:rte",
]
)
self.datasets.extend(BURNS_DATASETS)
print(
"Interpreting `burns` as all datasets used in Burns et al. (2022) "
"available on the HuggingFace Hub"
Expand Down
12 changes: 12 additions & 0 deletions elk/utils/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# All datasets used in Burns et al. (2022) that are available on the
# HuggingFace Hub; the `burns` sweep alias expands to this list.
BURNS_DATASETS = [
    "ag_news",
    "amazon_polarity",
    "dbpedia_14",
    "glue:qnli",
    "imdb",
    "piqa",
    "super_glue:boolq",
    "super_glue:copa",
    "super_glue:rte",
]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ description = "Keeping language models honest by directly eliciting knowledge en
readme = "README.md"
requires-python = ">=3.10"
keywords = ["nlp", "interpretability", "language-models", "explainable-ai"]
license = {text = "MIT License"}
license = { text = "MIT License" }
dependencies = [
# Allows us to use device_map in from_pretrained. Also needed for 8bit
"accelerate",
Expand Down Expand Up @@ -46,6 +46,7 @@ dev = [
"pytest",
"pyright==1.1.304",
"scikit-learn",
"pyfakefs"
]
8bit = [
"bitsandbytes",
Expand Down
Loading
Loading