Skip to content

Commit

Permalink
added cli arg
Browse files Browse the repository at this point in the history
  • Loading branch information
derpyplops committed Jul 13, 2023
1 parent db1f897 commit 47bcfb2
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 3 deletions.
2 changes: 1 addition & 1 deletion elk/evaluation/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def execute(self, highlight_color: Color = "cyan"):

@torch.inference_mode()
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
self, layer: int, devices: list[str], world_size: int, probe_per_prompt: bool
) -> dict[str, pd.DataFrame]:
"""Evaluate a single reporter on a single layer."""
device = self.get_device(devices, world_size)
Expand Down
11 changes: 9 additions & 2 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class Run(ABC, Serializable):
prompt_indices: tuple[int, ...] = ()
"""The indices of the prompt templates to use. If empty, all prompts are used."""

probe_per_prompt: bool = False
"""If true, a probe is trained per prompt template. Otherwise, a single probe is
trained for all prompt templates."""

concatenated_layer_offset: int = 0
debug: bool = False
min_gpu_mem: int | None = None # in bytes
Expand Down Expand Up @@ -99,13 +103,16 @@ def execute(
devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem)
num_devices = len(devices)
func: Callable[[int], dict[str, pd.DataFrame]] = partial(
self.apply_to_layer, devices=devices, world_size=num_devices
self.apply_to_layer,
devices=devices,
world_size=num_devices,
probe_per_prompt=self.probe_per_prompt,
)
self.apply_to_layers(func=func, num_devices=num_devices)

@abstractmethod
def apply_to_layer(
self, layer: int, devices: list[str], world_size: int
self, layer: int, devices: list[str], world_size: int, probe_per_prompt: bool
) -> dict[str, pd.DataFrame]:
"""Train or eval a reporter on a single layer."""

Expand Down
1 change: 1 addition & 0 deletions elk/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def apply_to_layer(
layer: int,
devices: list[str],
world_size: int,
probe_per_prompt: bool,
) -> dict[str, pd.DataFrame]:
"""Train a single reporter on a single layer."""

Expand Down

0 comments on commit 47bcfb2

Please sign in to comment.