Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select metric choice for visualization #286

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
switch out rest of auroc with metric_type
  • Loading branch information
reaganjlee committed Aug 17, 2023
commit 4deb51bfcad2822512d24839d33049e60f6283c3
14 changes: 6 additions & 8 deletions elk/plotting/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def render(
if with_transfer: # TODO write tests
ensemble_data = ensemble_data.groupby(
["eval_dataset", "layer", "ensembling"], as_index=False
).agg({"auroc_estimate": "mean"})
).agg({f"{sweep.metric_type}": "mean"})
else:
ensemble_data = ensemble_data[
ensemble_data["eval_dataset"] == ensemble_data["train_dataset"]
Expand All @@ -75,7 +75,7 @@ def render(
fig.add_trace(
go.Scatter(
x=dataset_data["layer"],
y=dataset_data["auroc_estimate"],
y=dataset_data[f"{sweep.metric_type}"],
mode="lines",
name=ensemble,
showlegend=False
Expand All @@ -95,7 +95,7 @@ def render(
legend=dict(
title="Ensembling",
),
title=f"AUROC Trend: {self.model_name}",
title=f"{sweep.metric_type} Trend: {self.model_name}",
)
if write:
fig.write_image(
Expand All @@ -114,7 +114,7 @@ class TransferEvalHeatmap:
"""Class for generating heatmaps for transfer evaluation results."""

layer: int
metric_type: str = "auroc_estimate"
metric_type: str = None
ensembling: str = "full"

def render(self, df: pd.DataFrame) -> go.Figure:
Expand Down Expand Up @@ -145,11 +145,11 @@ def render(self, df: pd.DataFrame) -> go.Figure:

@dataclass
class TransferEvalTrend:
"""Class for generating line plots for the trend of AUROC scores in transfer
"""Class for generating line plots for the trend of metric scores in transfer
evaluation."""

dataset_names: list[str] | None
metric_type: str = "auroc_estimate"
metric_type: str = None

def render(self, df: pd.DataFrame) -> go.Figure:
"""Render the trend plot visualization.
Expand Down Expand Up @@ -244,15 +244,13 @@ def render_and_save(
self,
sweep: "SweepVisualization",
dataset_names: list[str] | None = None,
metric_type="auroc_estimate",
ensembling="full",
) -> None:
"""Render and save the visualization for the model.

Args:
sweep: The SweepVisualization instance.
dataset_names: List of dataset names to include in the visualization.
metric_type: The type of score to display.
ensembling: The ensembling option to consider.
"""
metric_type = sweep.metric_type
Expand Down