forked from EleutherAI/elk
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Multiple datasets refactor (EleutherAI#189)
* Fix bug where cached hidden states aren’t used when num_gpus is different * Actually works now * Refactor handling of multiple datasets * Various fixes * Fix math tests * Fix smoke tests * All tests working ostensibly * Make CCS normalization customizable * log each dataset individually * Move pseudo AUROC stuff to CcsReporter * Make 'datasets' and 'label_columns' config options more opinionated * tiny spacing change * Allow for toggling CV * add typing to logging; rename logging * Fix eval logging bug --------- Co-authored-by: Alex Mallen <[email protected]>
- Loading branch information
1 parent
361fb9b
commit 16dc1ca
Showing
28 changed files
with
687 additions
and
627 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import logging | ||
from pathlib import Path | ||
|
||
from datasets import DatasetDict | ||
|
||
from .utils import get_dataset_name, select_train_val_splits | ||
|
||
|
||
def save_debug_log(datasets: list[DatasetDict], out_dir: Path) -> None:
    """Write a human-readable debug log for `datasets` to `out_dir/debug.log`.

    For every dataset this logs its name, the sizes of its train and
    validation splits, and the rendered prompt pair of every template for the
    first validation example. A warning is logged when any rendered prompt
    ends in whitespace.

    Args:
        datasets: The datasets to describe. Each one is indexed by the split
            names returned by `select_train_val_splits`, and the first row of
            its validation split is read for the keys `"text_inputs"`,
            `"variant_ids"` and `"label"`.
        out_dir: Directory in which `debug.log` is created (opened in write
            mode, so an existing file is replaced).

    Note:
        The file handler is installed with `logging.basicConfig`, which
        configures the root logger and, per the standard-library
        documentation, does nothing when the root logger already has
        handlers.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(levelname)s:\n%(message)s",
        filename=out_dir / "debug.log",
        filemode="w",
    )

    for ds in datasets:
        logging.info(
            "=========================================\n"
            f"Dataset: {get_dataset_name(ds)}\n"
            "========================================="
        )

        train_split, val_split = select_train_val_splits(ds)

        # Fetch the first validation row once instead of once per field.
        first_row = ds[val_split][0]
        text_inputs = first_row["text_inputs"]
        template_ids = first_row["variant_ids"]
        label = first_row["label"]

        # log the train size and val size
        logging.info(f"Train size: {len(ds[train_split])}")
        logging.info(f"Val size: {len(ds[val_split])}")

        # Headings for the first and second prompt of each pair: a truthy
        # label marks the second prompt as the "true" one.
        first_heading, second_heading = (
            ("false", "true") if label else ("true", "false")
        )

        sections = [f"{len(text_inputs)} templates used:\n"]
        trailing_whitespace = False
        for (text0, text1), template_id in zip(text_inputs, template_ids):
            sections.append(
                f'***---TEMPLATE "{template_id}"---***\n'
                f"{first_heading}:\n"
                f'"""{text0}"""\n'
                f"{second_heading}:\n"
                f'"""{text1}"""\n\n\n'
            )
            # Slice rather than index so an empty prompt yields "" (for which
            # isspace() is False) instead of raising IndexError.
            if text0[-1:].isspace() or text1[-1:].isspace():
                trailing_whitespace = True

        if trailing_whitespace:
            logging.warning(
                "Some inputs to the model have trailing whitespace! "
                "Check that the jinja templates are not adding "
                "trailing whitespace. If `token_loc` is 'last', this "
                "will extract hidden states from the whitespace token."
            )
        logging.info("".join(sections))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.