Skip to content

Commit

Permalink
Fix dream prompt; move train.py outside package
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed May 14, 2024
1 parent dc4d8ef commit 1f5477a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ dependencies = [
"datasets",
"torch",
"peft",
"scipy",
"simple-parsing",
# 4.0 introduced the breaking change of using return_dict=True by default
"transformers>=4.0.0",
"wandb",
]
version = "0.0.1"

Expand Down
9 changes: 4 additions & 5 deletions w2s/train.py → train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
)

import wandb

from .ds_registry import load_and_process_dataset
from .knn import gather_hiddens, zeta_filter
from .loss import log_confidence_loss
from .roc_auc import roc_auc
from w2s.ds_registry import load_and_process_dataset
from w2s.knn import gather_hiddens, zeta_filter
from w2s.loss import log_confidence_loss
from w2s.roc_auc import roc_auc


@dataclass
Expand Down
7 changes: 5 additions & 2 deletions w2s/ds_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def format_dream(ex, rng):

ans = rng.choice(distractors)

txt = f"Q: {ex['question']} A: {ans}"
txt = f"{ex['dialogue']}\n\nQ: {ex['question']} A: {ans}"
return dict(txt=txt, hard_label=hard_label)


Expand Down Expand Up @@ -410,7 +410,10 @@ def format_openbookqa(ex, rng):
del distractors[letters.index(ex["answerKey"])]
ans = rng.choice(distractors)

txt = f"Q: {ex['question_stem']}\n\nA: {ans}"
choices = [
f"{a}) {t}" for a, t in zip(ex["choices"]["label"], ex["choices"]["text"])
]
txt = f"Q: {ex['question_stem']}\n\nChoices:\n{'\n'.join(choices)}\n\nAnswer: {ans}"
return dict(txt=txt, hard_label=hard_label)


Expand Down

0 comments on commit 1f5477a

Please sign in to comment.