Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sweep Visualizations #245

Merged
merged 25 commits
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
37002b1
create heatmap-visualizations for sweeps
lauritofzi May 5, 2023
e552f55
fix viz path + cleanup
lauritofzi May 7, 2023
199c8d5
initial
derpyplops May 14, 2023
1568cd0
refactoring
derpyplops May 14, 2023
181283f
code fix
derpyplops May 14, 2023
2908c0f
fix deps
derpyplops May 14, 2023
6ab281e
fix elk sweep viz flag usage
derpyplops May 15, 2023
ee6e14b
fix typo
derpyplops May 15, 2023
5d526d7
delete comment and factorize
derpyplops May 16, 2023
058cae8
cleanup
lauritowal May 16, 2023
5162cab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2023
398ec42
Fix file resolution and factor out sweep_dir()
derpyplops May 16, 2023
9a79c8f
change to relative import
lauritowal May 16, 2023
cf0ad33
Merge branch 'visualizations' of https://github.com/EleutherAI/elk in…
lauritowal May 16, 2023
d6bd7b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2023
d4e99b0
Address walt's comments and some tests
derpyplops May 17, 2023
a2e5f60
Change write location to elk-reporters/{sweep}/viz
derpyplops May 18, 2023
68e4b52
Edit README
derpyplops May 18, 2023
0e152bc
Fix TestGetModelPaths
derpyplops May 18, 2023
9d6552f
Fix duplicate bug
derpyplops May 18, 2023
623b2c7
add overwrite flag
derpyplops May 18, 2023
256ad68
add transfer to SweepByDsMultiplot
derpyplops May 18, 2023
9f9c5bb
Remove docstrings for consistency
derpyplops May 18, 2023
033e901
remove vestigial .gitignore
derpyplops May 18, 2023
c176732
remove burns datasets
derpyplops May 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
fix viz path + cleanup
  • Loading branch information
lauritofzi committed May 7, 2023
commit e552f553cd2df3dd3a543d6252cc305da430af1f
2 changes: 1 addition & 1 deletion elk/training/sweep.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,4 @@ def execute(self):

if self.visualize:
for i, model in enumerate(self.models):
render_model_results(sweep_dir / model, sweep_dir / "visualizations")
render_model_results(sweep_dir, model)
85 changes: 36 additions & 49 deletions elk/utils/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import plotly.express as px


def generate_heatmap(data, model_name, layer, viz_folder_path):

def generate_heatmap(data, model, layer, viz_dir: Path):
pivot = pd.pivot_table(
data, values="auroc_estimate", index="eval_dataset", columns="train_dataset"
)
Expand All @@ -16,57 +15,45 @@ def generate_heatmap(data, model_name, layer, viz_folder_path):
fig.update_layout(
xaxis_title="Train Dataset",
yaxis_title="Transfer Dataset",
title=f"AUROC Score Heatmap: {model_name} | Layer {layer}",
title=f"AUROC Score Heatmap: {model} | Layer {layer}",
)

model_viz_folder_path = os.path.join(viz_folder_path, model_name)
if not os.path.exists(model_viz_folder_path):
os.mkdir(model_viz_folder_path)
fig.write_image(os.path.join(model_viz_folder_path, f"{layer}.png"))
viz_dir = viz_dir / model
viz_dir.mkdir(parents=True, exist_ok=True)

fig.write_image(viz_dir / f"{layer}.png")


def reduce_model_results(model_path):
df = pd.DataFrame()

for train_dir in model_path.iterdir():
file = train_dir / "eval.csv"
eval_df = pd.read_csv(file)
eval_df["eval_dataset"] = train_dir.name
eval_df["train_dataset"] = train_dir.name
df = pd.concat([df, eval_df], ignore_index=True)
transfer_dir = train_dir / "transfer"
for eval_ds_dir in transfer_dir.iterdir():
eval_file_path = eval_ds_dir / "eval.csv"
eval_df = pd.read_csv(eval_file_path)
eval_df["eval_dataset"] = eval_ds_dir.name
eval_df["train_dataset"] = train_dir.name
df = pd.concat([df, eval_df], ignore_index=True)


def filter_df(df, layer):
df = df[df["ensembling"] == "full"] # filter df to only include full ensembling
df = df[df["layer"] == layer] # filter df to only include target layer
return df

def render_model_results(root_dir, model):
viz_dir = root_dir / "visualizations"
print(f"Saving sweep visualizations to \033[1m{viz_dir}\033[0m")

def render_model_results(model_path, visualizations_path):
if not visualizations_path.exists():
visualizations_path.mkdir()

df = None

def get_layer_min_max(model_path: Path):
dir = model_path.iterdir().__next__()
file = os.path.join(dir, "eval.csv")
raw_eval_df = pd.read_csv(file)
layer_min, layer_max = raw_eval_df["layer"].min(), raw_eval_df["layer"].max()
return layer_min, layer_max

layer_min, layer_max = get_layer_min_max(model_path)

df = reduce_model_results(root_dir / model)
layer_min, layer_max = df["layer"].min(), df["layer"].max()
for layer in range(layer_min, layer_max + 1):
for dir in model_path.iterdir():
file = os.path.join(dir, "eval.csv")
raw_eval_df = pd.read_csv(file)

eval_df = filter_df(raw_eval_df, layer)
eval_df["eval_dataset"] = dir.name
eval_df["train_dataset"] = dir.name
if df is None:
df = eval_df # first time
else:
df = pd.concat([df, eval_df], ignore_index=True)

transfer_dir = Path(os.path.join(dir, "transfer"))
for eval_ds_dir in transfer_dir.iterdir():
eval_file_path = os.path.join(eval_ds_dir, "eval.csv")
raw_eval_df = pd.read_csv(eval_file_path)
eval_df = filter_df(raw_eval_df, layer)
eval_df["eval_dataset"] = eval_ds_dir.name
eval_df["train_dataset"] = dir.name
df = pd.concat([df, eval_df], ignore_index=True)

model_name = model_path.parts[-1]
generate_heatmap(df, model_name, layer, visualizations_path)
df = df[df["ensembling"] == "full"]
df = df[df["layer"] == layer]

generate_heatmap(data=df,
model=model,
layer=layer,
viz_dir=viz_dir)