Skip to content

Commit

Permalink
Adding label selection module
Browse files Browse the repository at this point in the history
  • Loading branch information
rjha18 committed Jan 3, 2024
1 parent 23c3115 commit cfc42d6
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 23 deletions.
5 changes: 5 additions & 0 deletions experiments/example_flip_selection/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[select_flips]
budgets = [150, 300, 500, 1000, 1500]
input = "/gscratch/sewoong/rjha01/code/robust-ml-suite/out/labels/r18_1xs/*/labels.npy"
true = "/gscratch/sewoong/rjha01/code/robust-ml-suite/out/labels/r18_1xs/0/true.npy"
output_path = "out/computed/r18_1xs/"
63 changes: 63 additions & 0 deletions modules/select_flips/run_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""
TODO
"""

from pathlib import Path
import sys, glob

import numpy as np

from modules.base_utils.util import extract_toml


def run(experiment_name, module_name, **kwargs):
"""
TODO
"""
slurm_id = kwargs.get('slurm_id', None)

args = extract_toml(experiment_name, module_name)
budgets = args.get("budgets", [150, 300, 500, 1000, 1500])

input_path = args["input"] if slurm_id is None\
else args["input"].format(slurm_id)

true_path = args["true"] if slurm_id is None\
else args["true"].format(slurm_id)

output_path = args["output_path"] if slurm_id is None\
else args["output_path"].format(slurm_id)

Path(output_path).mkdir(parents=True, exist_ok=True)

distances = []
all_labels = []
for f in glob.glob(input_path):
labels = np.load(f)

true = np.load(true_path)
dists = np.zeros(len(labels))

inds = labels.argmax(axis=1) != true.argmax(axis=1)
dists[inds] = labels[inds].max(axis=1) -\
labels[inds][np.arange(inds.sum()), true[inds].argmax(axis=1)]

sorted = np.sort(labels[~inds])
dists[~inds] = sorted[:, -2] - sorted[:, -1]
distances.append(dists)
all_labels.append(labels)
distances = np.stack(distances)
all_labels = np.stack(all_labels).mean(axis=0)

np.save(f'{output_path}/true.npy', true)
for n in budgets:
to_save = true.copy()
if n != 0:
idx = np.argsort(distances.min(axis=0))[-n:]
all_labels[idx] = all_labels[idx] - 50000 * true[idx]
to_save[idx] = all_labels[idx]
np.save(f'{output_path}/{n}.npy', to_save)

if __name__ == "__main__":
experiment_name, module_name = sys.argv[1], sys.argv[2]
run(experiment_name, module_name)
23 changes: 0 additions & 23 deletions schemas/distillation.toml

This file was deleted.

10 changes: 10 additions & 0 deletions schemas/select_flips.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
###
# TODO
# select_flips schema
###

[select_flips]
budgets = "TODO"
input = "TODO"
true = "TODO"
output_path = "TODO"

0 comments on commit cfc42d6

Please sign in to comment.