bugfixes
AdamScherlis committed May 24, 2024
1 parent f423fa0 commit 6250719
Showing 4 changed files with 26 additions and 6 deletions.
8 changes: 7 additions & 1 deletion w2s/loss.py

```diff
@@ -18,7 +18,13 @@ def confidence_window_loss(
 
     target = torch.stack([1.0 - labels, labels], dim=1)
 
-    return torch.nn.functional.cross_entropy(logits[uncertain], target[uncertain])
+    loss = torch.nn.functional.cross_entropy(
+        logits[uncertain],
+        target[uncertain],
+        reduction="sum"
+    )
+
+    return loss / logits.shape[0]
 
 
 def log_confidence_loss(
```
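This fixes the normalization of the windowed cross-entropy: with the default mean reduction, the loss was averaged over only the uncertain examples, so batches with very few in-window examples were weighted as heavily as full batches. Summing and dividing by the full batch size makes each example carry the same weight regardless of how many fall inside the window. A minimal sketch of the difference, assuming `uncertain` is a boolean mask over the batch (the window test itself is outside this hunk):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, 2)
labels = torch.tensor([0.1, 0.4, 0.6, 0.9, 0.5, 0.3, 0.7, 0.95])
# Assumed form of the mask; the actual window test is not shown in the diff.
uncertain = (labels - 0.5).abs() < 0.25
target = torch.stack([1.0 - labels, labels], dim=1)

# Before: mean over the uncertain subset only.
old = F.cross_entropy(logits[uncertain], target[uncertain])

# After: sum over the uncertain subset, normalized by the full batch size.
new = F.cross_entropy(logits[uncertain], target[uncertain], reduction="sum")
new = new / logits.shape[0]

# The fix rescales the loss by the fraction of in-window examples.
assert torch.isclose(new, old * uncertain.float().mean())
```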
9 changes: 5 additions & 4 deletions w2s/sft.py

```diff
@@ -54,7 +54,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
             loss = confidence_window_loss(
                 outputs.logits,
                 labels,
-                radius=self.loss_cfg.radius,
+                radius=(self.loss_cfg.radius if self.transfer else 0.51),
             )
         else:
             raise ValueError(f"Unknown loss function: {self.loss_name}")
```
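Keying the radius on `self.transfer` restricts the confidence window to the weak-to-strong transfer phase. Assuming the window keeps examples with `(labels - 0.5).abs() < radius`, a radius of 0.51 admits every label in [0, 1], so on ground-truth data the loss reduces to ordinary cross-entropy over the whole batch:

```python
import torch

labels = torch.rand(100)

# Assumed window test; the mask construction is outside this hunk.
assert ((labels - 0.5).abs() < 0.51).all()  # radius=0.51: nothing is filtered
print(((labels - 0.5).abs() < 0.25).float().mean())  # radius=0.25: a real window
```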
```diff
@@ -113,8 +113,8 @@ def compute_metrics(eval_pred):
     )
 
     trainer = CustomLossTrainer(
-        loss_name=cfg.loss_name,
-        loss_cfg=cfg.loss,
+        loss_name=cfg["loss_name"],
+        loss_cfg=cfg["loss"],
         transfer=transfer,
         args=train_args,
         compute_metrics=compute_metrics,
```
```diff
@@ -140,7 +140,8 @@ def compute_metrics(eval_pred):
     with open(save_dir / "config.json", "w") as f:
         cfg["model"] = model_cfg.to_dict()
         cfg["train_args"] = train_args.to_dict()
-        cfg["logconf_weight"] = logconf_weight
+        cfg["transfer"] = transfer
+        cfg["loss"] = cfg["loss"].to_dict()
         json.dump(cfg, f, indent=2)
         wandb.config.update(cfg)
```
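These two hunks are related fixes: `cfg` is a plain dict at this point, so the attribute-style accesses (`cfg.loss_name`, `cfg.loss`) were bugs, and `json.dump` cannot serialize a dataclass instance, so the loss config has to go through the new `to_dict` (added in w2s/sft_config.py below) before being written out; the stale `logconf_weight` key is replaced by the `transfer` flag and the serialized loss config. A sketch of the serialization failure, with a hypothetical stand-in for a `LossConfig` subclass:

```python
import json
from dataclasses import dataclass

@dataclass
class WindowLossConfig:  # hypothetical stand-in for a LossConfig subclass
    radius: float = 0.15

    def to_dict(self):
        return dict(vars(self))

cfg = {"loss_name": "window", "loss": WindowLossConfig()}

try:
    json.dumps(cfg)
except TypeError as e:
    print(e)  # Object of type WindowLossConfig is not JSON serializable

cfg["loss"] = cfg["loss"].to_dict()
print(json.dumps(cfg))  # {"loss_name": "window", "loss": {"radius": 0.15}}
```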
4 changes: 3 additions & 1 deletion w2s/sft_config.py

```diff
@@ -6,7 +6,9 @@
 
 @dataclass
 class LossConfig(Serializable):
-    pass
+    def to_dict(self):
+        irrelevant_fields = []
+        return {k: v for k, v in vars(self).items() if k not in irrelevant_fields}
 
 @dataclass
 class LogConfidenceLossConfig(LossConfig):
```
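The base class gains a generic `to_dict` built on `vars()`, with `irrelevant_fields` as an empty placeholder for fields to exclude (a subclass would override the method to populate it). A sketch of how a subclass picks this up, with an assumed field since the real ones are outside the hunk:

```python
from dataclasses import dataclass

@dataclass
class LossConfig:
    def to_dict(self):
        irrelevant_fields = []
        return {k: v for k, v in vars(self).items() if k not in irrelevant_fields}

@dataclass
class LogConfidenceLossConfig(LossConfig):
    logconf_weight: float = 0.5  # assumed field, for illustration only

print(LogConfidenceLossConfig().to_dict())  # {'logconf_weight': 0.5}
```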
11 changes: 11 additions & 0 deletions w2s/utils.py

```diff
@@ -1,4 +1,5 @@
 from typing import Any, Type, TypeVar, cast
+from w2s.sft_config import LossConfig
 
 T = TypeVar("T")
 
```
```diff
@@ -10,6 +11,12 @@ def assert_type(typ: Type[T], obj: Any) -> T:
 
     return cast(typ, obj)
 
+NICKNAMES = {
+    "Qwen/Qwen1.5-0.5B": "Qw0.5",
+    "meta-llama/Meta-Llama-3-8B": "Ll8",
+    "./results": "rs",
+    "cosine": "c",
+}
 
 def get_config_foldername(config: dict) -> str:
     def shorten_key(key: str) -> str:
```
```diff
@@ -19,6 +26,8 @@ def shorten_value(value) -> str:
         if isinstance(value, bool):
             return "1" if value else "0"
         elif isinstance(value, str):
+            if value in NICKNAMES:
+                return NICKNAMES[value]
             value = value.split("/")[-1]
             if "_" in value:
                 return "_".join(word[:4] for word in value.split("_"))
```
```diff
@@ -39,6 +48,8 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "_") -> dict:
         new_key = parent_key + sep + k if parent_key else k
         if isinstance(v, dict):
             items.extend(flatten_dict(v, new_key, sep=sep).items())
+        elif isinstance(v, LossConfig):
+            items.extend(flatten_dict(v.to_dict(), new_key, sep=sep).items())
         else:
             items.append((new_key, v))
     return dict(items)
```
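With this branch, a `LossConfig` nested inside the config dict is flattened exactly like a plain sub-dict, which is what `get_config_foldername` needs to build folder names from config values. A self-contained sketch (the `WindowLossConfig` subclass and its field are assumed):

```python
from dataclasses import dataclass

@dataclass
class LossConfig:
    def to_dict(self):
        return dict(vars(self))

@dataclass
class WindowLossConfig(LossConfig):  # hypothetical subclass for illustration
    radius: float = 0.15

def flatten_dict(d: dict, parent_key: str = "", sep: str = "_") -> dict:
    # Same logic as the patched helper above.
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        elif isinstance(v, LossConfig):
            items.extend(flatten_dict(v.to_dict(), new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)

print(flatten_dict({"loss": WindowLossConfig(), "lr": 2e-5}))
# {'loss_radius': 0.15, 'lr': 2e-05}
```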
