refactor: remove Inferencer multiprocessing #3283

Merged · 4 commits · Oct 4, 2022
176 changes: 38 additions & 138 deletions haystack/modeling/infer.py
@@ -1,9 +1,7 @@
-from typing import List, Optional, Dict, Union, Generator, Set, Any
+from typing import List, Optional, Dict, Union, Set, Any
 
 import os
 import logging
-import multiprocessing as mp
-from functools import partial
 from tqdm import tqdm
 import torch
 from torch.utils.data.sampler import SequentialSampler
@@ -12,13 +10,7 @@
 from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
 from haystack.modeling.data_handler.samples import SampleBasket
-from haystack.modeling.utils import (
-    grouper,
-    initialize_device_settings,
-    set_all_seeds,
-    calc_chunksize,
-    log_ascii_workers,
-)
+from haystack.modeling.utils import initialize_device_settings, set_all_seeds
 from haystack.modeling.data_handler.inputs import QAInput
 from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
 from haystack.modeling.model.predictions import QAPred
@@ -70,6 +62,9 @@ def __init__(
                               `multiprocessing.Pool` again! To do so call
                               :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
                               done using this class. The garbage collector will not do this for you!
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
         :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
         :param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
                         A list containing torch device objects and/or strings is supported (For example
Expand Down Expand Up @@ -113,8 +108,6 @@ def __init__(
model.connect_heads_with_processor(processor.tasks, require_labels=False)
set_all_seeds(42)

self._set_multiprocessing_pool(num_processes)

@classmethod
def load(
cls,
@@ -166,6 +159,9 @@ def load(
                               `multiprocessing.Pool` again! To do so call
                               :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
                               done using this class. The garbage collector will not do this for you!
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
         :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
         :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
         :param use_fast: (Optional, True by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
Expand Down Expand Up @@ -259,48 +255,6 @@ def load(
devices=devices,
)

def _set_multiprocessing_pool(self, num_processes: Optional[int]) -> None:
"""
Initialize a multiprocessing.Pool for instances of Inferencer.

:param num_processes: the number of processes for `multiprocessing.Pool`.
Set to value of 1 (or 0) to disable multiprocessing.
Set to None to let Inferencer use all CPU cores minus one.
If you want to debug the Language Model, you might need to disable multiprocessing!
**Warning!** If you use multiprocessing you have to close the
`multiprocessing.Pool` again! To do so call
:func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
done using this class. The garbage collector will not do this for you!
:return: None
"""
self.process_pool = None
if num_processes == 0 or num_processes == 1: # disable multiprocessing
self.process_pool = None
else:
if num_processes is None: # use all CPU cores
if mp.cpu_count() > 3:
num_processes = mp.cpu_count() - 1
else:
num_processes = mp.cpu_count()
self.process_pool = mp.Pool(processes=num_processes)
logger.info("Got ya %s parallel workers to do inference ...", num_processes)
log_ascii_workers(n=num_processes, logger=logger)

def close_multiprocessing_pool(self, join: bool = False):
"""Close the `multiprocessing.Pool` again.

If you use multiprocessing you have to close the `multiprocessing.Pool` again!
To do so call this function after you are done using this class.
The garbage collector will not do this for you!

:param join: wait for the worker processes to exit
"""
if self.process_pool is not None:
self.process_pool.close()
if join:
self.process_pool.join()
self.process_pool = None

def save(self, path: str):
self.model.save(path)
self.processor.save(path)
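Note: since the hunk above removes both `_set_multiprocessing_pool` and `close_multiprocessing_pool`, calling code no longer needs to manage a pool. A minimal before/after sketch of hypothetical user code (not part of this diff; the model name is taken from the test fixture further down):

```python
from haystack.modeling.infer import Inferencer

# Before this PR, callers who enabled multiprocessing had to clean up themselves:
#   inferencer = Inferencer.load("deepset/bert-base-cased-squad2",
#                                task_type="question_answering", num_processes=4)
#   ...run inference...
#   inferencer.close_multiprocessing_pool(join=True)  # mandatory manual cleanup

# After this PR, no pool is created, so there is nothing to close.
# num_processes is still accepted for backwards compatibility but has no effect.
inferencer = Inferencer.load(
    "deepset/bert-base-cased-squad2",
    task_type="question_answering",
    gpu=False,
)
```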
@@ -313,6 +267,9 @@ def inference_from_file(self, file: str, multiprocessing_chunksize: int = None,
 
         :param file: path of the input file for Inference
         :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
         :return: list of predictions
         """
         dicts = self.processor.file_to_dicts(file)
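Continuing the sketch above, the file-based entry point can be exercised like this; the path and its format are illustrative assumptions:

```python
# Hypothetical input file; the processor attached to the Inferencer defines
# the expected format (e.g. SQuAD-style JSON for question answering).
predictions = inferencer.inference_from_file(file="data/qa_dev_sample.json")
# multiprocessing_chunksize may still be passed here but no longer has any effect.
```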
@@ -333,8 +290,11 @@ def inference_from_dicts(
                       One dict per sample.
         :param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
                             object where applicable, else it returns PredObj.to_json()
-        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
             (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
         :return: list of predictions
         """
         # whether to aggregate predictions across different samples (e.g. for QA on long texts)
@@ -346,26 +306,8 @@
         if len(self.model.prediction_heads) > 0:
             aggregate_preds = hasattr(self.model.prediction_heads[0], "aggregate_preds")
 
-        if self.process_pool is None:  # multiprocessing disabled (helpful for debugging or using in web frameworks)
-            predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
-            return predictions
-        else:  # use multiprocessing for inference
-            # Calculate values of multiprocessing_chunksize and num_processes if not supplied in the parameters.
-
-            if multiprocessing_chunksize is None:
-                _chunk_size, _ = calc_chunksize(len(dicts))
-                multiprocessing_chunksize = _chunk_size
-
-            predictions = self._inference_with_multiprocessing(
-                dicts, return_json, aggregate_preds, multiprocessing_chunksize
-            )
-
-            self.processor.log_problematic(self.problematic_sample_ids)
-            # cast the generator to a list if it isnt already a list.
-            if type(predictions) != list:
-                return list(predictions)
-            else:
-                return predictions
+        predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
+        return predictions
 
     def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: bool, aggregate_preds: bool) -> List:
         """
@@ -399,69 +341,6 @@ def _inference_without_multiprocessing(
 
         return preds_all
 
-    def _inference_with_multiprocessing(
-        self,
-        dicts: Union[List[Dict], Generator[Dict, None, None]],
-        return_json: bool,
-        aggregate_preds: bool,
-        multiprocessing_chunksize: int,
-    ) -> Generator[Dict, None, None]:
-        """
-        Implementation of inference. This method is a generator that yields the results.
-
-        :param dicts: Samples to run inference on provided as a list of dicts or a generator object that yield dicts.
-        :param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
-                            object where applicable, else it returns PredObj.to_json()
-        :param aggregate_preds: whether to aggregate predictions across different samples (e.g. for QA on long texts)
-        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
-        :return: generator object that yield predictions
-        """
-
-        # We group the input dicts into chunks and feed each chunk to a different process
-        # in the pool, where it gets converted to a pytorch dataset
-        if self.process_pool is not None:
-            results = self.process_pool.imap(
-                partial(self._create_datasets_chunkwise, processor=self.processor),
-                grouper(iterable=dicts, n=multiprocessing_chunksize),
-                1,
-            )
-
-            # Once a process spits out a preprocessed chunk. we feed this dataset directly to the model.
-            # So we don't need to wait until all preprocessing has finished before getting first predictions.
-            for dataset, tensor_names, problematic_sample_ids, baskets in results:
-                self.problematic_sample_ids.update(problematic_sample_ids)
-                if dataset is None:
-                    logger.error(
-                        f"Part of the dataset could not be converted! \n"
-                        f"BE AWARE: The order of predictions will not conform with the input order!"
-                    )
-                else:
-                    # TODO change format of formatted_preds in QA (list of dicts)
-                    if aggregate_preds:
-                        predictions = self._get_predictions_and_aggregate(dataset, tensor_names, baskets)
-                    else:
-                        predictions = self._get_predictions(dataset, tensor_names, baskets)
-
-                    if return_json:
-                        # TODO this try catch should be removed when all tasks return prediction objects
-                        try:
-                            predictions = [x.to_json() for x in predictions]
-                        except AttributeError:
-                            pass
-                    yield from predictions
-
-    @classmethod
-    def _create_datasets_chunkwise(cls, chunk, processor: Processor):
-        """Convert ONE chunk of data (i.e. dictionaries) into ONE pytorch dataset.
-        This is usually executed in one of many parallel processes.
-        The resulting datasets of the processes are merged together afterwards"""
-        dicts = [d[1] for d in chunk]
-        indices = [d[0] for d in chunk]
-        dataset, tensor_names, problematic_sample_ids, baskets = processor.dataset_from_dicts(
-            dicts, indices, return_baskets=True
-        )
-        return dataset, tensor_names, problematic_sample_ids, baskets
-
     def _get_predictions(self, dataset: Dataset, tensor_names: List, baskets):
         """
         Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
@@ -592,20 +471,41 @@ def __init__(self, *args, **kwargs):
     def inference_from_dicts(
         self, dicts: List[dict], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         return Inferencer.inference_from_dicts(
             self, dicts, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )
 
     def inference_from_file(
         self, file: str, multiprocessing_chunksize: Optional[int] = None, return_json=True
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         return Inferencer.inference_from_file(
             self, file, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )
 
     def inference_from_objects(
         self, objects: List[QAInput], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         dicts = [o.to_dict() for o in objects]
         # TODO investigate this deprecation warning. Timo: I thought we were about to implement Input Objects,
         # then we can and should use inference from (input) objects!
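A hedged sketch of the object-based API whose docstring is extended above; `QAInput` is imported from `haystack.modeling.data_handler.inputs` at the top of this file, and the constructor fields shown are assumptions based on that module:

```python
from haystack.modeling.data_handler.inputs import QAInput, Question

# Assumes `qa_inferencer` is a loaded QAInferencer instance (see earlier sketch).
input_obj = QAInput(
    doc_text="Arya Stark is the daughter of Eddard (Ned) Stark.",
    questions=Question(text="Who is the father of Arya Stark?"),
)
# inference_from_objects converts each object via o.to_dict() and then
# delegates to inference_from_dicts.
predictions = qa_inferencer.inference_from_objects(objects=[input_obj], return_json=False)
```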
25 changes: 9 additions & 16 deletions test/conftest.py
@@ -1115,22 +1115,15 @@ def adaptive_model_qa(num_processes):
     """
     PyTest Fixture for a Question Answering Inferencer based on PyTorch.
     """
-    try:
-        model = Inferencer.load(
-            "deepset/bert-base-cased-squad2",
-            task_type="question_answering",
-            batch_size=16,
-            num_processes=num_processes,
-            gpu=False,
-        )
-        yield model
-    finally:
-        if num_processes != 0:
-            # close the pool
-            # we pass join=True to wait for all sub processes to close
-            # this is because below we want to test if all sub-processes
-            # have exited
-            model.close_multiprocessing_pool(join=True)
-
+    model = Inferencer.load(
+        "deepset/bert-base-cased-squad2",
+        task_type="question_answering",
+        batch_size=16,
+        num_processes=num_processes,
+        gpu=False,
+    )
+    yield model
 
     # check if all workers (sub processes) are closed
     current_process = psutil.Process()
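The fixture's trailing worker check is truncated in this diff. A hedged sketch of what such a post-yield assertion typically looks like with psutil (the exact assertion in the repo may differ):

```python
import psutil

# After the fixture yields, verify that no worker sub-processes are left behind.
current_process = psutil.Process()
children = current_process.children()
assert len(children) == 0, f"Unexpected child processes still running: {children}"
```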