Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select metric choice for visualization #286

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
rename
  • Loading branch information
reaganjlee committed Aug 17, 2023
commit 8ef905a2b6d6895628094d6f9066ef550058430d
40 changes: 21 additions & 19 deletions elk/plotting/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class TransferEvalHeatmap:
"""Class for generating heatmaps for transfer evaluation results."""

layer: int
score_type: str = "auroc_estimate"
metric_type: str = "auroc_estimate"
ensembling: str = "full"

def render(self, df: pd.DataFrame) -> go.Figure:
Expand All @@ -129,15 +129,15 @@ def render(self, df: pd.DataFrame) -> go.Figure:
model_name = df["eval_dataset"].iloc[0] # infer model name
# TODO: validate
pivot = pd.pivot_table(
df, values=self.score_type, index="eval_dataset", columns="train_dataset"
df, values=self.metric_type, index="eval_dataset", columns="train_dataset"
)

fig = px.imshow(pivot, color_continuous_scale="Viridis", text_auto=True)

fig.update_layout(
xaxis_title="Train Dataset",
yaxis_title="Transfer Dataset",
title=f"AUROC Score Heatmap: {model_name} | Layer {self.layer}",
title=f"{self.metric_type} Score Heatmap: {model_name} | Layer {self.layer}",
)

return fig
Expand All @@ -149,7 +149,7 @@ class TransferEvalTrend:
evaluation."""

dataset_names: list[str] | None
score_type: str = "auroc_estimate"
metric_type: str = "auroc_estimate"

def render(self, df: pd.DataFrame) -> go.Figure:
"""Render the trend plot visualization.
Expand All @@ -164,14 +164,14 @@ def render(self, df: pd.DataFrame) -> go.Figure:
if self.dataset_names is not None:
df = self._filter_transfer_datasets(df, self.dataset_names)
pivot = pd.pivot_table(
df, values=self.score_type, index="layer", columns="eval_dataset"
df, values=self.metric_type, index="layer", columns="eval_dataset"
)

fig = px.line(pivot, color_discrete_sequence=px.colors.qualitative.Plotly)
fig.update_layout(
xaxis_title="Layer",
yaxis_title="AUROC Score",
title=f"AUROC Score Trend: {model_name}",
yaxis_title=f"{self.metric_type} Score",
title=f"{self.metric_type} Score Trend: {model_name}",
)

avg = pivot.mean(axis=1)
Expand Down Expand Up @@ -244,17 +244,18 @@ def render_and_save(
self,
sweep: "SweepVisualization",
dataset_names: list[str] | None = None,
score_type="auroc_estimate",
metric_type="auroc_estimate",
ensembling="full",
) -> None:
"""Render and save the visualization for the model.

Args:
sweep: The SweepVisualization instance.
dataset_names: List of dataset names to include in the visualization.
score_type: The type of score to display.
metric_type: The type of score to display.
ensembling: The ensembling option to consider.
"""
metric_type = sweep.metric_type
df = self.df
model_name = self.model_name
layer_min, layer_max = df["layer"].min(), df["layer"].max()
Expand All @@ -264,10 +265,10 @@ def render_and_save(
for layer in range(layer_min, layer_max + 1):
filtered = df[(df["layer"] == layer) & (df["ensembling"] == ensembling)]
fig = TransferEvalHeatmap(
layer, score_type=score_type, ensembling=ensembling
layer, metric_type=metric_type, ensembling=ensembling
).render(filtered)
fig.write_image(file=model_path / f"{layer}.png")
fig = TransferEvalTrend(dataset_names).render(df)
fig = TransferEvalTrend(dataset_names, metric_type=metric_type).render(df)
fig.write_image(file=model_path / "transfer_eval_trend.png")

@staticmethod
Expand All @@ -288,6 +289,7 @@ class SweepVisualization:
path: Path
datasets: list[str]
models: dict[str, ModelVisualization]
metric_type: str

def model_names(self) -> list[str]:
"""Get the names of all models in the sweep.
Expand Down Expand Up @@ -323,7 +325,7 @@ def _get_model_paths(sweep_path: Path) -> list[Path]:
return folders

@classmethod
def collect(cls, sweep_path: Path) -> "SweepVisualization":
def collect(cls, sweep_path: Path, metric_type: str) -> "SweepVisualization":
"""Collect the evaluation data for a sweep.

Args:
Expand All @@ -348,7 +350,7 @@ def collect(cls, sweep_path: Path) -> "SweepVisualization":
}
df = pd.concat([model.df for model in models.values()], ignore_index=True)
datasets = list(df["eval_dataset"].unique())
return cls(sweep_name, df, sweep_viz_path, datasets, models)
return cls(sweep_name, df, sweep_viz_path, datasets, models, metric_type=metric_type)

def render_and_save(self):
"""Render and save all visualizations for the sweep."""
Expand All @@ -369,13 +371,13 @@ def render_multiplots(self, write=False):
]

def render_table(
self, score_type="auroc_estimate", display=True, write=False
self, metric_type="auroc_estimate", display=True, write=False
) -> pd.DataFrame:
"""Render and optionally write the score table.

Args:
layer: The layer number (from last layer) to include in the score table.
score_type: The type of score to include in the table.
metric_type: The type of score to include in the table.
display: Flag indicating whether to display the table to stdout.
write: Flag indicating whether to write the table to a file.

Expand All @@ -395,7 +397,7 @@ def render_table(
pivot_table = pd.concat(model_dfs).pivot_table(
index="eval_dataset",
columns="model_name",
values=score_type,
values=metric_type,
margins=True,
margins_name="Mean",
)
Expand All @@ -416,14 +418,14 @@ def render_table(
console.print(table)

if write:
pivot_table.to_csv(f"score_table_{score_type}.csv")
pivot_table.to_csv(f"score_table_{metric_type}.csv")
return pivot_table


def visualize_sweep(sweep_path: Path, metric_type: str):
    """Visualize a sweep by generating and saving the visualizations.

    Args:
        sweep_path: The path to the sweep data directory.
        metric_type: The metric column to plot (e.g. "auroc_estimate");
            forwarded to SweepVisualization.collect so every heatmap,
            trend plot, and score table uses the same metric.
    """
    SweepVisualization.collect(sweep_path, metric_type).render_and_save()