Add a `combine_evals` option to the `Eval` config.

When `combine_evals` is True, `Eval.execute` builds a single `Evaluate` run
over all configured datasets and writes it to one output directory named by
joining the dataset names with ", ". When it is False (the default), the
previous behaviour is kept: one `Evaluate` run per dataset, each written to
`transfer_eval/<dataset>`.

NOTE(review): the per-dataset branch reassigns `self.data.prompts.datasets`
on every iteration and does not restore the original list afterwards; confirm
that no caller reuses the config object after `execute()` returns.
NOTE(review): the combined directory name contains ", " and embeds the raw
dataset names; confirm that dataset names cannot contain path separators.

diff --git a/elk/evaluation/evaluate.py b/elk/evaluation/evaluate.py
index b954e661..9cfdd6eb 100644
--- a/elk/evaluation/evaluate.py
+++ b/elk/evaluation/evaluate.py
@@ -41,16 +41,22 @@ class Eval(Serializable):
     num_gpus: int = -1
     skip_baseline: bool = False
     concatenated_layer_offset: int = 0
+    combine_evals: bool = False
 
     def execute(self):
         datasets = self.data.prompts.datasets
 
         transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"
 
-        for dataset in datasets:
-            self.data.prompts.datasets = [dataset]
-            run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
+        if self.combine_evals:
+            run = Evaluate(cfg=self, out_dir=transfer_dir / ", ".join(datasets))
             run.evaluate()
+        else:
+            # eval on each dataset separately
+            for dataset in datasets:
+                self.data.prompts.datasets = [dataset]
+                run = Evaluate(cfg=self, out_dir=transfer_dir / dataset)
+                run.evaluate()
 
 
 @dataclass