Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Mar 23, 2023
1 parent 1a25f90 commit ac5559d
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 7 deletions.
6 changes: 4 additions & 2 deletions elk/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ def save_meta(dataset, out_dir: Path):
"""Save the meta data to a file"""

meta = {
"dataset_fingerprints": {split: dataset[split]._fingerprint for split in dataset.keys()}
"dataset_fingerprints": {
split: dataset[split]._fingerprint for split in dataset.keys()
}
}
with open(out_dir / "metadata.yaml", "w") as meta_f:
yaml.dump(meta, meta_f)
yaml.dump(meta, meta_f)
2 changes: 1 addition & 1 deletion elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,4 @@ def apply_to_layers(self, func):
for row in sorted(row_buf):
writer.writerow(row)
if self.cfg.debug:
save_debug_log(self.dataset, self.out_dir)
save_debug_log(self.dataset, self.out_dir)
2 changes: 1 addition & 1 deletion elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,4 @@ def get_pseudo_auroc(
return pseudo_auroc

def train(self):
self.apply_to_layers(func=self.train_reporter)
self.apply_to_layers(func=self.train_reporter)
2 changes: 1 addition & 1 deletion elk/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@

from .gpu_utils import select_usable_devices
from .tree_utils import pytree_map
from .typing import assert_type, float32_to_int16, int16_to_float32
from .typing import assert_type, float32_to_int16, int16_to_float32
6 changes: 4 additions & 2 deletions elk/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def undersample(

return dataset


def get_layers(ds: DatasetDict) -> List[int]:
"""Get a list of indices of hidden layers given a `DatasetDict`."""
layers = [
Expand All @@ -141,6 +142,7 @@ def get_layers(ds: DatasetDict) -> List[int]:
]
return layers


def apply_template(template: Template, example: dict) -> str:
"""Concatenate question and answer if answer is not empty or whitespace."""
q, a = template.apply(example)
Expand All @@ -149,6 +151,7 @@ def apply_template(template: Template, example: dict) -> str:
sep = "" if not q or q[-1].isspace() or not a or a[0].isspace() else " "
return f"{q}{sep}{a}" if a and not a.isspace() else q


def binarize(
template: Template, label: int, new_label: int, rng: Random
) -> tuple[Template, int]:
Expand All @@ -162,7 +165,6 @@ def binarize(
the index of the true answer into `new_template.answer_choices`
"""


# TODO: it would be nice in the future to binarize exhaustively so we're not
# cheating here (since this step requires a label). e.g. this function would
# also take a candidate answer and the template would ask whether the candidate
Expand All @@ -179,4 +181,4 @@ def binarize(
new_template.answer_choices = (
f"{false} ||| {true}" if new_label else f"{true} ||| {false}"
)
return new_template, new_label
return new_template, new_label

0 comments on commit ac5559d

Please sign in to comment.