Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select metric choice for visualization #286

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
add functionality to sweep / cleanup
  • Loading branch information
reaganjlee committed Aug 18, 2023
commit 179066cac16daa1d2a8fc17e81c79ea9e49d53b8
6 changes: 1 addition & 5 deletions elk/plotting/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Plot:
overwrite: bool = False
"""Whether to overwrite existing plots."""

metric_type: str = None
metric_type: str = "auroc_estimate"
"""Name of metric to plot"""

def execute(self):
Expand All @@ -38,10 +38,6 @@ def execute(self):
else:
sweep_paths = [root_dir / sweep for sweep in self.sweeps]

if not self.metric_type:
# ArgumentParser maps cli input --metric to metric_type
self.metric_type = "auroc_estimate"

for sweep_path in sweep_paths:
if not sweep_path.exists():
pretty_error(f"No sweep with name {{{sweep_path}}} found in {root_dir}")
Expand Down
4 changes: 2 additions & 2 deletions elk/plotting/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class TransferEvalHeatmap:
"""Class for generating heatmaps for transfer evaluation results."""

layer: int
metric_type: str = None
metric_type: str = ""
ensembling: str = "full"

def render(self, df: pd.DataFrame) -> go.Figure:
Expand Down Expand Up @@ -149,7 +149,7 @@ class TransferEvalTrend:
evaluation."""

dataset_names: list[str] | None
metric_type: str = None
metric_type: str = ""

def render(self, df: pd.DataFrame) -> go.Figure:
"""Render the trend plot visualization.
Expand Down
5 changes: 4 additions & 1 deletion elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Sweep:
visualize: bool = False
"""Whether to generate visualizations of the results of the sweep."""

metric_type: str = "auroc_estimate"
"""Name of metric to plot"""

name: str | None = None

# A bit of a hack to add all the command line arguments from Elicit
Expand Down Expand Up @@ -176,4 +179,4 @@ def execute(self):
eval.execute(highlight_color="green")

if self.visualize:
visualize_sweep(sweep_dir)
visualize_sweep(sweep_dir, self.metric_type)
Loading