diff --git a/haystack/modeling/infer.py b/haystack/modeling/infer.py
index c7c289b4c5..d8ab26a27e 100644
--- a/haystack/modeling/infer.py
+++ b/haystack/modeling/infer.py
@@ -1,9 +1,7 @@
-from typing import List, Optional, Dict, Union, Generator, Set, Any
+from typing import List, Optional, Dict, Union, Set, Any

 import os
 import logging
-import multiprocessing as mp
-from functools import partial
 from tqdm import tqdm
 import torch
 from torch.utils.data.sampler import SequentialSampler
@@ -12,13 +10,7 @@
 from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
 from haystack.modeling.data_handler.samples import SampleBasket
-from haystack.modeling.utils import (
-    grouper,
-    initialize_device_settings,
-    set_all_seeds,
-    calc_chunksize,
-    log_ascii_workers,
-)
+from haystack.modeling.utils import initialize_device_settings, set_all_seeds
 from haystack.modeling.data_handler.inputs import QAInput
 from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
 from haystack.modeling.model.predictions import QAPred
@@ -70,6 +62,9 @@ def __init__(
                               `multiprocessing.Pool` again! To do so call
                               :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are done using this class.
                               The garbage collector will not do this for you!
+                              .. deprecated:: 1.10
+                                  This parameter has no effect; it will be removed as Inferencer multiprocessing
+                                  has been deprecated.
         :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
         :param devices: List of torch devices (e.g. cuda, cpu, mps) to limit inference to specific devices.
                         A list containing torch device objects and/or strings is supported (For example
@@ -113,8 +108,6 @@ def __init__(
         model.connect_heads_with_processor(processor.tasks, require_labels=False)
         set_all_seeds(42)

-        self._set_multiprocessing_pool(num_processes)
-
     @classmethod
     def load(
         cls,
@@ -166,6 +159,9 @@ def load(
                               `multiprocessing.Pool` again! To do so call
                               :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are done using this class.
                               The garbage collector will not do this for you!
+                              .. deprecated:: 1.10
+                                  This parameter has no effect; it will be removed as Inferencer multiprocessing
+                                  has been deprecated.
         :param disable_tqdm: Whether to disable tqdm logging (can get very verbose in multiprocessing)
         :param tokenizer_class: (Optional) Name of the tokenizer class to load (e.g. `BertTokenizer`)
         :param use_fast: (Optional, True by default) Indicate if FARM should try to load the fast version of the tokenizer (True) or
@@ -259,48 +255,6 @@ def load(
             devices=devices,
         )

-    def _set_multiprocessing_pool(self, num_processes: Optional[int]) -> None:
-        """
-        Initialize a multiprocessing.Pool for instances of Inferencer.
-
-        :param num_processes: the number of processes for `multiprocessing.Pool`.
-                              Set to value of 1 (or 0) to disable multiprocessing.
-                              Set to None to let Inferencer use all CPU cores minus one.
-                              If you want to debug the Language Model, you might need to disable multiprocessing!
-                              **Warning!** If you use multiprocessing you have to close the
-                              `multiprocessing.Pool` again! To do so call
-                              :func:`~farm.infer.Inferencer.close_multiprocessing_pool` after you are
-                              done using this class. The garbage collector will not do this for you!
-        :return: None
-        """
-        self.process_pool = None
-        if num_processes == 0 or num_processes == 1:  # disable multiprocessing
-            self.process_pool = None
-        else:
-            if num_processes is None:  # use all CPU cores
-                if mp.cpu_count() > 3:
-                    num_processes = mp.cpu_count() - 1
-                else:
-                    num_processes = mp.cpu_count()
-            self.process_pool = mp.Pool(processes=num_processes)
-            logger.info("Got ya %s parallel workers to do inference ...", num_processes)
-            log_ascii_workers(n=num_processes, logger=logger)
-
-    def close_multiprocessing_pool(self, join: bool = False):
-        """Close the `multiprocessing.Pool` again.
-
-        If you use multiprocessing you have to close the `multiprocessing.Pool` again!
-        To do so call this function after you are done using this class.
-        The garbage collector will not do this for you!
-
-        :param join: wait for the worker processes to exit
-        """
-        if self.process_pool is not None:
-            self.process_pool.close()
-            if join:
-                self.process_pool.join()
-            self.process_pool = None
-
     def save(self, path: str):
         self.model.save(path)
         self.processor.save(path)
@@ -313,6 +267,9 @@ def inference_from_file(self, file: str, multiprocessing_chunksize: int = None,

         :param file: path of the input file for Inference
         :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
         :return: list of predictions
         """
         dicts = self.processor.file_to_dicts(file)
@@ -333,8 +290,11 @@ def inference_from_dicts(
                       One dict per sample.
         :param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
                             object where applicable, else it returns PredObj.to_json()
-        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
         :return: list of predictions
         """
         # whether to aggregate predictions across different samples (e.g. for QA on long texts)
@@ -346,26 +306,8 @@ def inference_from_dicts(
         if len(self.model.prediction_heads) > 0:
             aggregate_preds = hasattr(self.model.prediction_heads[0], "aggregate_preds")

-        if self.process_pool is None:  # multiprocessing disabled (helpful for debugging or using in web frameworks)
-            predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
-            return predictions
-        else:  # use multiprocessing for inference
-            # Calculate values of multiprocessing_chunksize and num_processes if not supplied in the parameters.
-
-            if multiprocessing_chunksize is None:
-                _chunk_size, _ = calc_chunksize(len(dicts))
-                multiprocessing_chunksize = _chunk_size
-
-            predictions = self._inference_with_multiprocessing(
-                dicts, return_json, aggregate_preds, multiprocessing_chunksize
-            )
-
-            self.processor.log_problematic(self.problematic_sample_ids)
-            # cast the generator to a list if it isnt already a list.
-            if type(predictions) != list:
-                return list(predictions)
-            else:
-                return predictions
+        predictions: Any = self._inference_without_multiprocessing(dicts, return_json, aggregate_preds)
+        return predictions

     def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: bool, aggregate_preds: bool) -> List:
         """
@@ -399,69 +341,6 @@ def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: boo

         return preds_all

-    def _inference_with_multiprocessing(
-        self,
-        dicts: Union[List[Dict], Generator[Dict, None, None]],
-        return_json: bool,
-        aggregate_preds: bool,
-        multiprocessing_chunksize: int,
-    ) -> Generator[Dict, None, None]:
-        """
-        Implementation of inference. This method is a generator that yields the results.
-
-        :param dicts: Samples to run inference on provided as a list of dicts or a generator object that yield dicts.
-        :param return_json: Whether the output should be in a json appropriate format. If False, it returns the prediction
-                            object where applicable, else it returns PredObj.to_json()
-        :param aggregate_preds: whether to aggregate predictions across different samples (e.g. for QA on long texts)
-        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
-        :return: generator object that yield predictions
-        """
-
-        # We group the input dicts into chunks and feed each chunk to a different process
-        # in the pool, where it gets converted to a pytorch dataset
-        if self.process_pool is not None:
-            results = self.process_pool.imap(
-                partial(self._create_datasets_chunkwise, processor=self.processor),
-                grouper(iterable=dicts, n=multiprocessing_chunksize),
-                1,
-            )
-
-            # Once a process spits out a preprocessed chunk. we feed this dataset directly to the model.
-            # So we don't need to wait until all preprocessing has finished before getting first predictions.
-            for dataset, tensor_names, problematic_sample_ids, baskets in results:
-                self.problematic_sample_ids.update(problematic_sample_ids)
-                if dataset is None:
-                    logger.error(
-                        f"Part of the dataset could not be converted! \n"
-                        f"BE AWARE: The order of predictions will not conform with the input order!"
-                    )
-                else:
-                    # TODO change format of formatted_preds in QA (list of dicts)
-                    if aggregate_preds:
-                        predictions = self._get_predictions_and_aggregate(dataset, tensor_names, baskets)
-                    else:
-                        predictions = self._get_predictions(dataset, tensor_names, baskets)
-
-                    if return_json:
-                        # TODO this try catch should be removed when all tasks return prediction objects
-                        try:
-                            predictions = [x.to_json() for x in predictions]
-                        except AttributeError:
-                            pass
-                    yield from predictions
-
-    @classmethod
-    def _create_datasets_chunkwise(cls, chunk, processor: Processor):
-        """Convert ONE chunk of data (i.e. dictionaries) into ONE pytorch dataset.
-        This is usually executed in one of many parallel processes.
-        The resulting datasets of the processes are merged together afterwards"""
-        dicts = [d[1] for d in chunk]
-        indices = [d[0] for d in chunk]
-        dataset, tensor_names, problematic_sample_ids, baskets = processor.dataset_from_dicts(
-            dicts, indices, return_baskets=True
-        )
-        return dataset, tensor_names, problematic_sample_ids, baskets
-
     def _get_predictions(self, dataset: Dataset, tensor_names: List, baskets):
         """
         Feed a preprocessed dataset to the model and get the actual predictions (forward pass + formatting).
@@ -592,6 +471,13 @@ def __init__(self, *args, **kwargs):
     def inference_from_dicts(
         self, dicts: List[dict], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+                                          (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         return Inferencer.inference_from_dicts(
             self, dicts, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )
@@ -599,6 +485,13 @@ def inference_from_dicts(
     def inference_from_file(
         self, file: str, multiprocessing_chunksize: Optional[int] = None, return_json=True
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+                                          (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         return Inferencer.inference_from_file(
             self, file, return_json=return_json, multiprocessing_chunksize=multiprocessing_chunksize
         )
@@ -606,6 +499,13 @@ def inference_from_file(
     def inference_from_objects(
         self, objects: List[QAInput], return_json: bool = True, multiprocessing_chunksize: Optional[int] = None
     ) -> List[QAPred]:
+        """
+        :param multiprocessing_chunksize: number of dicts to put together in one chunk and feed to one process
+                                          (only relevant if you do multiprocessing)
+            .. deprecated:: 1.10
+                This parameter has no effect; it will be removed as Inferencer multiprocessing
+                has been deprecated.
+        """
         dicts = [o.to_dict() for o in objects]
         # TODO investigate this deprecation warning. Timo: I thought we were about to implement Input Objects,
         # then we can and should use inference from (input) objects!
diff --git a/test/conftest.py b/test/conftest.py
index 1006b71219..303c00a5a6 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1115,22 +1115,15 @@ def adaptive_model_qa(num_processes):
     """
     PyTest Fixture for a Question Answering Inferencer based on PyTorch.
     """
-    try:
-        model = Inferencer.load(
-            "deepset/bert-base-cased-squad2",
-            task_type="question_answering",
-            batch_size=16,
-            num_processes=num_processes,
-            gpu=False,
-        )
-        yield model
-    finally:
-        if num_processes != 0:
-            # close the pool
-            # we pass join=True to wait for all sub processes to close
-            # this is because below we want to test if all sub-processes
-            # have exited
-            model.close_multiprocessing_pool(join=True)
+
+    model = Inferencer.load(
+        "deepset/bert-base-cased-squad2",
+        task_type="question_answering",
+        batch_size=16,
+        num_processes=num_processes,
+        gpu=False,
+    )
+    yield model

     # check if all workers (sub processes) are closed
     current_process = psutil.Process()
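
Usage after this change, as a minimal sketch (not part of the patch itself): the Inferencer now always runs inference in-process, `num_processes` and `multiprocessing_chunksize` are still accepted but ignored, and there is no pool left to close, so callers no longer call `close_multiprocessing_pool()`. The model name, `task_type`, `batch_size`, and `gpu` values below are taken from the conftest fixture above; the `{"qas": ..., "context": ...}` input layout is assumed to be the usual QA dict format and is illustrative only.

    from haystack.modeling.infer import Inferencer

    # Load a QA Inferencer; num_processes is now a deprecated no-op and can simply be omitted.
    inferencer = Inferencer.load(
        "deepset/bert-base-cased-squad2",
        task_type="question_answering",
        batch_size=16,
        gpu=False,
    )

    # Inference always goes through _inference_without_multiprocessing now,
    # so there is no multiprocessing.Pool to clean up afterwards.
    qa_dicts = [{"qas": ["Who lives in Berlin?"], "context": "My name is Carla and I live in Berlin."}]
    predictions = inferencer.inference_from_dicts(dicts=qa_dicts, return_json=True)
    print(predictions)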