Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select metric choice for visualization #286

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ together. You can also add a `--visualize` flag to visualize the results of the
elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled
```

If you just do `elk plot`, it will plot the results from the most recent sweep.
If you want to plot a specific sweep, you can do so with:
If you just do `elk plot`, it will plot the results of AUROC from the most recent sweep.
If you want to plot a specific sweep, with a specific metric type, you can do so with:

```bash
elk plot {sweep_name}
elk plot {sweep_name} --metric acc_estimate
```

## Caching
Expand Down
5 changes: 4 additions & 1 deletion elk/plotting/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ class Plot:
overwrite: bool = False
"""Whether to overwrite existing plots."""

metric_type: str = "auroc_estimate"
"""Name of metric to plot"""

def execute(self):
root_dir = sweeps_dir()

Expand All @@ -47,4 +50,4 @@ def execute(self):
if self.overwrite:
shutil.rmtree(sweep_path / "viz")

visualize_sweep(sweep_path)
visualize_sweep(sweep_path, self.metric_type)
54 changes: 27 additions & 27 deletions elk/plotting/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def render(
shared_yaxes=True,
vertical_spacing=0.1,
x_title="Layer",
y_title="AUROC",
y_title=f"{sweep.metric_type}",
)
color_map = dict(zip(ensembles, qualitative.Plotly))

Expand All @@ -56,7 +56,7 @@ def render(
if with_transfer: # TODO write tests
ensemble_data = ensemble_data.groupby(
["eval_dataset", "layer", "ensembling"], as_index=False
).agg({"auroc_estimate": "mean"})
).agg({f"{sweep.metric_type}": "mean"})
else:
ensemble_data = ensemble_data[
ensemble_data["eval_dataset"] == ensemble_data["train_dataset"]
Expand All @@ -75,7 +75,7 @@ def render(
fig.add_trace(
go.Scatter(
x=dataset_data["layer"],
y=dataset_data["auroc_estimate"],
y=dataset_data[f"{sweep.metric_type}"],
mode="lines",
name=ensemble,
showlegend=False
Expand All @@ -95,7 +95,7 @@ def render(
legend=dict(
title="Ensembling",
),
title=f"AUROC Trend: {self.model_name}",
title=f"{sweep.metric_type} Trend: {self.model_name}",
)
if write:
fig.write_image(
Expand All @@ -114,7 +114,7 @@ class TransferEvalHeatmap:
"""Class for generating heatmaps for transfer evaluation results."""

layer: int
score_type: str = "auroc_estimate"
metric_type: str = ""
ensembling: str = "full"

def render(self, df: pd.DataFrame) -> go.Figure:
Expand All @@ -129,27 +129,28 @@ def render(self, df: pd.DataFrame) -> go.Figure:
model_name = df["eval_dataset"].iloc[0] # infer model name
# TODO: validate
pivot = pd.pivot_table(
df, values=self.score_type, index="eval_dataset", columns="train_dataset"
df, values=self.metric_type, index="eval_dataset", columns="train_dataset"
)

fig = px.imshow(pivot, color_continuous_scale="Viridis", text_auto=True)

fig.update_layout(
xaxis_title="Train Dataset",
yaxis_title="Transfer Dataset",
title=f"AUROC Score Heatmap: {model_name} | Layer {self.layer}",
title=f"{self.metric_type} Score Heatmap: {model_name} \
| Layer {self.layer}",
)

return fig


@dataclass
class TransferEvalTrend:
"""Class for generating line plots for the trend of AUROC scores in transfer
"""Class for generating line plots for the trend of metric scores in transfer
evaluation."""

dataset_names: list[str] | None
score_type: str = "auroc_estimate"
metric_type: str = ""

def render(self, df: pd.DataFrame) -> go.Figure:
"""Render the trend plot visualization.
Expand All @@ -164,14 +165,14 @@ def render(self, df: pd.DataFrame) -> go.Figure:
if self.dataset_names is not None:
df = self._filter_transfer_datasets(df, self.dataset_names)
pivot = pd.pivot_table(
df, values=self.score_type, index="layer", columns="eval_dataset"
df, values=self.metric_type, index="layer", columns="eval_dataset"
)

fig = px.line(pivot, color_discrete_sequence=px.colors.qualitative.Plotly)
fig.update_layout(
xaxis_title="Layer",
yaxis_title="AUROC Score",
title=f"AUROC Score Trend: {model_name}",
yaxis_title=f"{self.metric_type} Score",
title=f"{self.metric_type} Score Trend: {model_name}",
)

avg = pivot.mean(axis=1)
Expand Down Expand Up @@ -244,17 +245,16 @@ def render_and_save(
self,
sweep: "SweepVisualization",
dataset_names: list[str] | None = None,
score_type="auroc_estimate",
ensembling="full",
) -> None:
"""Render and save the visualization for the model.

Args:
sweep: The SweepVisualization instance.
dataset_names: List of dataset names to include in the visualization.
score_type: The type of score to display.
ensembling: The ensembling option to consider.
"""
metric_type = sweep.metric_type
df = self.df
model_name = self.model_name
layer_min, layer_max = df["layer"].min(), df["layer"].max()
Expand All @@ -264,10 +264,10 @@ def render_and_save(
for layer in range(layer_min, layer_max + 1):
filtered = df[(df["layer"] == layer) & (df["ensembling"] == ensembling)]
fig = TransferEvalHeatmap(
layer, score_type=score_type, ensembling=ensembling
layer, metric_type=metric_type, ensembling=ensembling
).render(filtered)
fig.write_image(file=model_path / f"{layer}.png")
fig = TransferEvalTrend(dataset_names).render(df)
fig = TransferEvalTrend(dataset_names, metric_type=metric_type).render(df)
fig.write_image(file=model_path / "transfer_eval_trend.png")

@staticmethod
Expand All @@ -288,6 +288,7 @@ class SweepVisualization:
path: Path
datasets: list[str]
models: dict[str, ModelVisualization]
metric_type: str

def model_names(self) -> list[str]:
"""Get the names of all models in the sweep.
Expand Down Expand Up @@ -323,7 +324,7 @@ def _get_model_paths(sweep_path: Path) -> list[Path]:
return folders

@classmethod
def collect(cls, sweep_path: Path) -> "SweepVisualization":
def collect(cls, sweep_path: Path, metric_type: str) -> "SweepVisualization":
"""Collect the evaluation data for a sweep.

Args:
Expand All @@ -348,7 +349,9 @@ def collect(cls, sweep_path: Path) -> "SweepVisualization":
}
df = pd.concat([model.df for model in models.values()], ignore_index=True)
datasets = list(df["eval_dataset"].unique())
return cls(sweep_name, df, sweep_viz_path, datasets, models)
return cls(
sweep_name, df, sweep_viz_path, datasets, models, metric_type=metric_type
)

def render_and_save(self):
"""Render and save all visualizations for the sweep."""
Expand All @@ -368,14 +371,11 @@ def render_multiplots(self, write=False):
for model in self.models
]

def render_table(
self, score_type="auroc_estimate", display=True, write=False
) -> pd.DataFrame:
def render_table(self, display=True, write=False) -> pd.DataFrame:
"""Render and optionally write the score table.

Args:
layer: The layer number (from last layer) to include in the score table.
score_type: The type of score to include in the table.
display: Flag indicating whether to display the table to stdout.
write: Flag indicating whether to write the table to a file.

Expand All @@ -387,15 +387,15 @@ def render_table(
# For each model, we use the layer whose mean AUROC is the highest
best_layers, model_dfs = [], []
for _, model_df in df.groupby("model_name"):
best_layer = model_df.groupby("layer").auroc_estimate.mean().argmax()
best_layer = model_df.groupby("layer")[self.metric_type].mean().argmax()

best_layers.append(best_layer)
model_dfs.append(model_df[model_df["layer"] == best_layer])

pivot_table = pd.concat(model_dfs).pivot_table(
index="eval_dataset",
columns="model_name",
values=score_type,
values=self.metric_type,
margins=True,
margins_name="Mean",
)
Expand All @@ -416,14 +416,14 @@ def render_table(
console.print(table)

if write:
pivot_table.to_csv(f"score_table_{score_type}.csv")
pivot_table.to_csv(f"score_table_{self.metric_type}.csv")
return pivot_table


def visualize_sweep(sweep_path: Path):
def visualize_sweep(sweep_path: Path, metric_type: str):
"""Visualize a sweep by generating and saving the visualizations.

Args:
sweep_path: The path to the sweep data directory.
"""
SweepVisualization.collect(sweep_path).render_and_save()
SweepVisualization.collect(sweep_path, metric_type).render_and_save()
5 changes: 4 additions & 1 deletion elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Sweep:
visualize: bool = False
"""Whether to generate visualizations of the results of the sweep."""

metric_type: str = "auroc_estimate"
"""Name of metric to plot"""

name: str | None = None

# A bit of a hack to add all the command line arguments from Elicit
Expand Down Expand Up @@ -176,4 +179,4 @@ def execute(self):
eval.execute(highlight_color="green")

if self.visualize:
visualize_sweep(sweep_dir)
visualize_sweep(sweep_dir, self.metric_type)
Loading