Skip to content

Commit

Permalink
Refactoring logits to soft
Browse files Browse the repository at this point in the history
  • Loading branch information
rjha18 committed Jan 3, 2024
1 parent 557233a commit 4661e36
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion experiments/example_downstream/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ source_label = 9
target_label = 4
poisoner = "1xs"
output_path = "experiments/example_downstream/"
logits = false
soft = false
alpha = 0.0
2 changes: 1 addition & 1 deletion experiments/example_downstream_soft/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ source_label = 9
target_label = 4
poisoner = "1xs"
output_path = "experiments/example_downstream/"
logits = true
soft = true
alpha = 0.2
2 changes: 1 addition & 1 deletion experiments/example_precomputed/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ source_label = 9
target_label = 4
poisoner = "1xs"
output_path = "experiments/example_precomputed/"
logits = false
soft = false
alpha = 0.0

[downstream.optim_kwargs]
Expand Down
2 changes: 1 addition & 1 deletion experiments/example_precomputed_mix/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ source_label = 9
target_label = 4
poisoner = "1xs"
output_path = "experiments/example_precomputed/"
logits = false
soft = false
alpha = 0.0

[downstream.optim_kwargs]
Expand Down
4 changes: 2 additions & 2 deletions modules/downstream/run_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def run(experiment_name, module_name, **kwargs):
poisoner_flag = args["poisoner"]
clean_label = args["source_label"]
target_label = args["target_label"]
logits = args.get("logits", True)
soft = args.get("soft", True)
batch_size = args.get("batch_size", None)
epochs = args.get("epochs", None)
optim_kwargs = args.get("optim_kwargs", {})
Expand Down Expand Up @@ -74,7 +74,7 @@ def run(experiment_name, module_name, **kwargs):
else:
labels_d = softmax(labels_syn)

if not logits:
if not soft:
labels_d = labels_d.argmax(dim=1)

downstream_dataset = construct_downstream_dataset(distillation, labels_d)
Expand Down
2 changes: 1 addition & 1 deletion schemas/downstream.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ target_label = "int: {0,1,...,9}. Specifies label to attack"
poisoner = "string: Form: {{1,2,3,9}xp, {1,2}xs, {1,4}xl}. Integer resembles number of attacks and string represents type"

[OPTIONAL]
logits = "TODO"
soft = "TODO"
alpha = "TODO"
true = "TODO"
batch_size = "int: {0,1,...,infty}. Specifies batch size. Set to default for trainer if omitted."
Expand Down

0 comments on commit 4661e36

Please sign in to comment.