Skip to content

Commit

Permalink
final fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamScherlis committed Jun 14, 2024
1 parent c769e15 commit 08cb994
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 7 deletions.
6 changes: 3 additions & 3 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def run_train(cfg: SFTConfig):
cfg.dataset, cfg.n_train, cfg.n_val, cfg.n_test, cfg.n_predict
)

train_halves = splits["train"].train_test_split(test_size=0.5, seed=seed)
train_halves = splits["train"].train_test_split(test_size=0.5, seed=42)
splits["weak_train"] = train_halves["train"]
splits["strong_train"] = train_halves["test"]

Expand Down Expand Up @@ -184,15 +184,15 @@ def get_model_and_run_name(model_name, current_name):
)
# assert (prev_train_preds_ds["id"] == s2s_ds_dict["train"]["id"])
# assert (prev_val_preds_ds["id"] == s2s_ds_dict["val"]["id"])
s2s_predict_dict = {"strong_train": splits["train"], "val": splits["val"]}
s2s_predict_dict = {"train": splits["strong_train"], "val": splits["val"]}
train(
s2s_ds_dict,
model_cfg,
TrainingArguments(**train_args),
cfg.to_dict(),
transfer=True,
predict_dict=s2s_predict_dict,
acts_dir=root / cfg_name / f"s2s-{s2s_iter}" / "activations",
acts_dir=shared_root / cfg_name / f"s2s-{s2s_iter}" / "activations",
)

prev = f"s2s-{s2s_iter}"
Expand Down
3 changes: 3 additions & 0 deletions w2s/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ class LogConfidenceLossConfig(LossConfig):
class ConfidenceWindowLossConfig(LossConfig):
    """Config for a confidence-window loss.

    Attributes:
        radius: half-width of the confidence window; either a fixed float
            (default 0.15) or the string "midweak" to derive it from the
            weak labels.
    """

    # Fix: `literal("midweak")` is not a valid name — the typing construct
    # is `Literal["midweak"]` (subscripted, capital L). As written the class
    # body raised NameError at import time.
    radius: Union[float, Literal["midweak"]] = 0.15

    def to_dict(self):
        """Serialize to a plain dict; non-float radius normalizes to "midweak"."""
        return {"radius": self.radius if isinstance(self.radius, float) else "midweak"}

@dataclass
class LogEntropyLossConfig(LogConfidenceLossConfig):
    """Marker config: log-entropy loss, sharing all fields and defaults
    with LogConfidenceLossConfig. Exists only so the loss can be selected
    by config type/name; adds no behavior of its own."""
Expand Down
12 changes: 8 additions & 4 deletions w2s/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ def __init__(
loss_name: str,
loss_cfg: LossConfig,
transfer: bool,
buffer_size: int,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.loss_name = loss_name
self.loss_cfg = loss_cfg
self.transfer = transfer
if loss_name == "logconf":
if loss_name in ["logconf", "entropy"]:
self.buffer = []
# self.buffer_size = kwargs["buffer_size"]
self.buffer_size = buffer_size


def compute_loss(self, model, inputs, return_outputs=False):
Expand All @@ -66,7 +67,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
balance_batch=self.loss_cfg.balance_batch,
harden=True,
buffer=self.buffer,
# buffer_size=self.buffer_size,
buffer_size=self.buffer_size,
)
elif self.loss_name == 'entropy':
loss = log_confidence_loss(
Expand All @@ -77,6 +78,8 @@ def compute_loss(self, model, inputs, return_outputs=False):
warmup_steps=self.loss_cfg.logconf_warmup_steps,
balance_batch=self.loss_cfg.balance_batch,
harden=False,
buffer=self.buffer,
buffer_size=self.buffer_size,
)
elif self.loss_name == 'xent':
loss = cross_entropy_loss(
Expand Down Expand Up @@ -217,7 +220,7 @@ def compute_metrics(eval_pred):
ds_dict[name] = ds
if cfg['probe_relabel']:
# print(lshape(ds["labels"]))
ds = ds.remove_columns("labels").add_column("labels", preds[:, 1].cpu().numpy())
ds = ds.remove_columns("labels").add_column("labels", preds[:, 1].detach().cpu().numpy())
ds_dict[name] = ds

if results_path.exists():
Expand All @@ -241,6 +244,7 @@ def compute_metrics(eval_pred):
trainer = CustomLossTrainer(
loss_name=cfg["loss_name"],
loss_cfg=cfg["loss"],
buffer_size=cfg["batch_size"],
transfer=transfer,
args=train_args,
compute_metrics=compute_metrics,
Expand Down

0 comments on commit 08cb994

Please sign in to comment.