Skip to content

Commit

Permalink
TODO cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
rjha18 committed Jan 3, 2024
1 parent 2d2b32d commit 2bdb411
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 14 deletions.
4 changes: 0 additions & 4 deletions modules/base_utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,8 +408,6 @@ def load_cifar_100_dataset(path, train=True, coarse=True):
16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
18, 1, 2, 15, 6, 0, 17, 8, 14, 13])
dataset.targets = coarse_labels[dataset.targets]

# TODO: get actual class names
dataset.classes = range(coarse_labels.max()+1)
return dataset

Expand Down Expand Up @@ -606,5 +604,3 @@ def construct_user_dataset(distill_dataset, labels, mask=None, target_label=None

def get_n_classes(dataset_flag):
return N_CLASSES[dataset_flag]

# TODO: Add oversampling trick back into tiny_imagenet
1 change: 0 additions & 1 deletion modules/base_utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ def mini_train(
train_epoch_correct += int(correct.item())
train_epoch_loss += float(loss.item())
pbar.update(minibatch_size)
# TODO: make this into a list of callbacks
if callback is not None:
callback(model, opt, epoch, i)

Expand Down
13 changes: 4 additions & 9 deletions modules/train_expert/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch

from modules.base_utils.datasets import get_matching_datasets, pick_poisoner
from modules.base_utils.datasets import get_matching_datasets, get_n_classes, pick_poisoner
from modules.base_utils.util import extract_toml, load_model,\
generate_full_path, clf_eval, mini_train,\
get_train_info, needs_big_ims
Expand Down Expand Up @@ -43,14 +43,9 @@ def run(experiment_name, module_name, **kwargs):

Path(output_path[:output_path.rfind('/')]).mkdir(parents=True,
exist_ok=True)

# TODO: make this more extensible
if dataset_flag == "cifar_100":
model = load_model(model_flag, 20)
elif dataset_flag == "tiny_imagenet":
model = load_model(model_flag, 200)
else:
model = load_model(model_flag)

n_classes = get_n_classes(dataset_flag)
model = load_model(model_flag, n_classes)

print(f"{model_flag=} {clean_label=} {target_label=} {poisoner_flag=}")
print("Building datasets...")
Expand Down

0 comments on commit 2bdb411

Please sign in to comment.