diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 9c082c52..1cf3bf44 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -1,4 +1,5 @@ """Functions for extracting the hidden states of a model.""" + import os from collections import defaultdict from dataclasses import InitVar, dataclass, replace diff --git a/elk/plotting/visualize.py b/elk/plotting/visualize.py index 85eedd43..93fbd650 100644 --- a/elk/plotting/visualize.py +++ b/elk/plotting/visualize.py @@ -78,9 +78,9 @@ def render( y=dataset_data["auroc_estimate"], mode="lines", name=ensemble, - showlegend=False - if dataset_name != unique_datasets[0] - else True, + showlegend=( + False if dataset_name != unique_datasets[0] else True + ), line=dict(color=color_map[ensemble]), ), row=row, diff --git a/elk/promptsource/templates.py b/elk/promptsource/templates.py index 7d4c0b84..8d93a40c 100644 --- a/elk/promptsource/templates.py +++ b/elk/promptsource/templates.py @@ -215,9 +215,11 @@ def _escape_pipe(cls, example): # Replaces any occurrences of the "|||" separator in the example, which # which will be replaced back after splitting protected_example = { - key: value.replace("|||", cls.pipe_protector) - if isinstance(value, str) - else value + key: ( + value.replace("|||", cls.pipe_protector) + if isinstance(value, str) + else value + ) for key, value in example.items() } return protected_example diff --git a/elk/training/platt_scaling.py b/elk/training/platt_scaling.py index 278d8d95..70dd87c3 100644 --- a/elk/training/platt_scaling.py +++ b/elk/training/platt_scaling.py @@ -12,8 +12,7 @@ class PlattMixin(ABC): scale: nn.Parameter @abstractmethod - def __call__(self, *args: Any, **kwds: Any) -> Any: - ... + def __call__(self, *args: Any, **kwds: Any) -> Any: ... def platt_scale(self, labels: Tensor, hiddens: Tensor, max_iter: int = 100): """Fit the scale and bias terms to data with LBFGS.