Skip to content

Commit

Permalink
midweak and cfg
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamScherlis committed May 25, 2024
1 parent c91df69 commit 3bcdd61
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 8 deletions.
9 changes: 4 additions & 5 deletions run_eight.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sys import argv

# Define the datasets and respective GPU ids
# list of tuples with dataset name and minibatch size
configs = [
("boolq", 2),
("anli-r2", 8),
Expand Down Expand Up @@ -31,14 +32,12 @@
"--eval_every 100 "
"--save_every 100 "
"--save_total_limit 1 "
"--loss logconf "
"--logconf_warmup_steps 80 "
"--balance_batch "
"--logconf_weight 0.5 "
"--loss window "
"--radius midweak "
"--minibatch_size {minibatch_size} "
"--weak_lr 5e-4 "
"--strong_lr 8e-5 "
'--run_name "logconf4" '
'--run_name "mw_window" '
)


Expand Down
5 changes: 5 additions & 0 deletions w2s/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ def compute_metrics(eval_pred):
auroc=roc_auc(hard_labels, predictions[:, 1]),
)

if transfer and cfg["loss_name"] == "window" and cfg["loss"].radius == "midweak":
confs = torch.abs(torch.tensor(ds_dict["train"]["labels"]) - 0.5)
cfg["loss"].radius = confs.median().item()
print(f"Setting radius to {cfg['loss'].radius:.2f} based on median confidence in train set")

trainer = CustomLossTrainer(
loss_name=cfg["loss_name"],
loss_cfg=cfg["loss"],
Expand Down
11 changes: 8 additions & 3 deletions w2s/sft_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from dataclasses import dataclass
from typing import Literal, Optional, Union
from typing import Optional, Union
from enum import StrEnum

from simple_parsing import Serializable, field, subgroups


# simple_parsing doesn't like typing.Literal so I rolled my own
# note: parens, not brackets
literal = lambda *args: StrEnum("option", args)

@dataclass
class LossConfig(Serializable):
def to_dict(self):
Expand All @@ -18,7 +23,7 @@ class LogConfidenceLossConfig(LossConfig):

@dataclass
class ConfidenceWindowLossConfig(LossConfig):
radius: float = 0.15
radius: Union[float, literal("midweak")] = 0.15

LOSS_CONFIGS = {
"logconf": LogConfidenceLossConfig,
Expand All @@ -38,7 +43,7 @@ class SFTConfig(Serializable): # TODO: what is this for??
n_test: int = 1_000
# when "train", it uses the training set to generate predictions
# otherwise it uses n_predict held out examples
n_predict: Union[Literal["train"], int] = 0
n_predict: Union[literal("train"), int] = 0
minibatch_size: int = 8
# examples per update
batch_size: int = 32
Expand Down

0 comments on commit 3bcdd61

Please sign in to comment.