forked from EleutherAI/elk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
debug_logging.py
57 lines (48 loc) · 1.98 KB
/
debug_logging.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import logging
from pathlib import Path
from datasets import DatasetDict
from .utils import get_dataset_name, select_train_val_splits
def save_debug_log(datasets: list[DatasetDict], out_dir: Path) -> None:
    """Write a human-readable debug log for `datasets` to `out_dir/debug.log`.

    For every dataset this logs its name, the sizes of the train and
    validation splits, and the rendered prompt pair of every template variant
    for the first validation example. A warning is logged if any rendered
    prompt ends in whitespace.

    Args:
        datasets: The datasets to describe. The first example of each
            validation split is read for its "text_inputs", "variant_ids"
            and "label" fields.
        out_dir: Directory in which `debug.log` is written (opened with mode
            "w", so an existing log is overwritten).
    """
    # NOTE: `basicConfig` configures the *root* logger and is a no-op when the
    # root logger already has handlers, in which case nothing is written to
    # `debug.log`.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(levelname)s:\n%(message)s",
        filename=out_dir / "debug.log",
        filemode="w",
    )

    for ds in datasets:
        logging.info(
            "=========================================\n"
            f"Dataset: {get_dataset_name(ds)}\n"
            "========================================="
        )

        train_split, val_split = select_train_val_splits(ds)

        # Fetch the first validation example once rather than once per field.
        first_example = ds[val_split][0]
        text_inputs = first_example["text_inputs"]
        template_ids = first_example["variant_ids"]
        label = first_example["label"]

        # log the train size and val size
        logging.info(f"Train size: {len(ds[train_split])}")
        logging.info(f"Val size: {len(ds[val_split])}")

        # The headings depend only on the example's label, so compute them
        # once: a truthy label puts the "true" heading on the second text.
        first_heading = "false" if label else "true"
        second_heading = "true" if label else "false"

        sections = [f"{len(text_inputs)} templates used:\n"]
        trailing_whitespace = False
        for (text0, text1), template_id in zip(text_inputs, template_ids):
            sections.append(
                f'***---TEMPLATE "{template_id}"---***\n'
                f"{first_heading}:\n"
                f'"""{text0}"""\n'
                f"{second_heading}:\n"
                f'"""{text1}"""\n\n\n'
            )
            # Slice instead of indexing so that an empty prompt does not
            # raise IndexError ("".isspace() is False).
            if text0[-1:].isspace() or text1[-1:].isspace():
                trailing_whitespace = True

        if trailing_whitespace:
            logging.warning(
                "Some inputs to the model have trailing whitespace! "
                "Check that the jinja templates are not adding "
                "trailing whitespace. If `token_loc` is 'last', this "
                "will extract hidden states from the whitespace token."
            )
        logging.info("".join(sections))