Skip to content

Commit

Permalink
Force spawn start method
Browse files Browse the repository at this point in the history
  • Loading branch information
norabelrose committed Apr 8, 2023
1 parent d83c5cb commit fd8af71
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 2 deletions.
3 changes: 3 additions & 0 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ def get_splits() -> SplitDict:
)
for (split_name, split_info) in get_splits().items()
}
import multiprocess as mp

mp.set_start_method("spawn", force=True) # type: ignore[attr-defined]

ds = dict()
for split, builder in builders.items():
Expand Down
3 changes: 2 additions & 1 deletion elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def apply_to_layers(
layers = self.concatenate(layers)

# Should we write to different CSV files for elicit vs eval?
with mp.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f:
ctx = mp.get_context("spawn")
with ctx.Pool(num_devices) as pool, open(self.out_dir / "eval.csv", "w") as f:
mapper = pool.imap_unordered if num_devices > 1 else map
iterator: Iterator[Log] = tqdm( # type: ignore
mapper(func, layers), total=len(layers)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_write_iterator_to_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def to_csv_line(x):
return x.to_csv_line(skip_baseline=True)

try:
with mp.Pool(processes) as pool, open(tmp_path / "eval.csv", "w") as f:
ctx = mp.get_context("spawn")
with ctx.Pool(processes) as pool, open(tmp_path / "eval.csv", "w") as f:
layers = [1, 2, 3]
iterator = pool.imap_unordered(log_function, layers)
write_iterator_to_file(
Expand Down

0 comments on commit fd8af71

Please sign in to comment.