diff --git a/elk/extraction/inference_server.py b/elk/extraction/inference_server.py index 64f14cf4..65de6876 100644 --- a/elk/extraction/inference_server.py +++ b/elk/extraction/inference_server.py @@ -216,13 +216,18 @@ def imap( q.put((closure_pkl, model_kwargs_pkl, shard)) generator = round_robin(self._result_queues) # type: ignore[arg-type] - seen_dummy = False + seen_ids = set() for out in tqdm(generator, total=len(dataset), disable=not use_tqdm): if out[0] == dummy_id: - if seen_dummy: + if dummy_id in seen_ids: continue # ignore any extra dummy rows - else: - seen_dummy = True + elif out[0] in seen_ids: + raise RuntimeError( + "Round robin yielded duplicate items. " + "This may be due to multiprocessing queues returning " + "items repeatedly." + ) + seen_ids.add(out[0]) yield out