forked from EleutherAI/elk
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'upstream/main'
- Loading branch information
Showing
44 changed files
with
1,764 additions
and
1,071 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import logging | ||
from pathlib import Path | ||
|
||
from datasets import DatasetDict | ||
|
||
from .utils import get_dataset_name, select_train_val_splits | ||
|
||
|
||
def save_debug_log(datasets: list[DatasetDict], out_dir: Path) -> None:
    """Write a human-readable debug log describing `datasets` to `out_dir`.

    For each dataset this logs its name, the sizes of its train and validation
    splits, and the text of every template variant of the first validation
    example. A warning is logged when any of those texts ends in whitespace.

    Args:
        datasets: The dataset dicts to describe. Each one is expected to
            provide `text_questions`, `variant_ids` and `label` fields on its
            validation examples.
        out_dir: Directory in which `debug.log` is written. The file is
            opened with mode "w", so an existing log is overwritten.
    """
    # NOTE(review): `basicConfig` configures the root logger and is a no-op if
    # the root logger already has handlers; in that case nothing is written to
    # `debug.log`. Confirm that callers invoke this before other logging setup.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(levelname)s:\n%(message)s",
        filename=out_dir / "debug.log",
        filemode="w",
    )

    for ds in datasets:
        logging.info(
            "=========================================\n"
            f"Dataset: {get_dataset_name(ds)}\n"
            "========================================="
        )

        train_split, val_split = select_train_val_splits(ds)

        # Only the first validation example is inspected; fetch it once.
        first_example = ds[val_split][0]
        text_questions = first_example["text_questions"]
        template_ids = first_example["variant_ids"]
        label = first_example["label"]

        # log the train size and val size
        logging.info(f"Train size: {len(ds[train_split])}")
        logging.info(f"Val size: {len(ds[val_split])}")

        # Collect the pieces and join once instead of repeated `+=`.
        sections = [f"{len(text_questions)} templates used:\n"]
        trailing_whitespace = False
        for (text0, text1), template_id in zip(text_questions, template_ids):
            sections.append(
                f'***---TEMPLATE "{template_id}"---***\n'
                f"{'false' if label else 'true'}:\n"
                f'"""{text0}"""\n'
                f"{'true' if label else 'false'}:\n"
                f'"""{text1}"""\n\n\n'
            )
            # Slice with `[-1:]` rather than index with `[-1]` so that an empty
            # text yields "" (for which `isspace()` is False) instead of
            # raising IndexError.
            if text0[-1:].isspace() or text1[-1:].isspace():
                trailing_whitespace = True

        if trailing_whitespace:
            logging.warning(
                "Some inputs to the model have trailing whitespace! "
                "Check that the jinja templates are not adding "
                "trailing whitespace. If `token_loc` is 'last', this "
                "will extract hidden states from the whitespace token."
            )
        logging.info("".join(sections))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.