diff --git a/elk/utils/fsdp.py b/elk/utils/fsdp.py index d429ce90..99b883f5 100644 --- a/elk/utils/fsdp.py +++ b/elk/utils/fsdp.py @@ -165,6 +165,8 @@ def imap( closure_pkl = dill.dumps(closure) result_queues = [] for q, result_queue in zip(self._task_queues, self._result_queues): + # Put the same dataset on each queue, so that each worker gets the same + # inputs q.put((closure_pkl, dataset)) result_queues.append(result_queue) @@ -213,7 +215,7 @@ def _worker( logging.disable(logging.CRITICAL) warnings.filterwarnings("ignore") - closure: Callable[[ModelOutput], Any] | None = None + closure: Callable[[ModelOutput], Any] dataset: Dataset | None = None device = devices[rank]