diff --git a/e2e/pipelines/test_standard_pipelines.py b/e2e/pipelines/test_standard_pipelines.py
index fc3e961277..8e49ec9410 100644
--- a/e2e/pipelines/test_standard_pipelines.py
+++ b/e2e/pipelines/test_standard_pipelines.py
@@ -306,7 +306,7 @@ def test_summarization_pipeline():
     output = pipeline.run(query=query, params={"Retriever": {"top_k": 1}})
     answers = output["answers"]
     assert len(answers) == 1
-    assert "The Eiffel Tower is one of the world's tallest structures." == answers[0]["answer"].strip()
+    assert answers[0]["answer"].strip() == "The Eiffel Tower is one of the world's tallest structures."


 def test_summarization_pipeline_one_summary():
diff --git a/e2e/preview/components/test_gpt35_generator.py b/e2e/preview/components/test_gpt35_generator.py
index 044b0cf5af..7c24440469 100644
--- a/e2e/preview/components/test_gpt35_generator.py
+++ b/e2e/preview/components/test_gpt35_generator.py
@@ -17,7 +17,7 @@ def test_gpt35_generator_run(generator_class, model_name):
     assert "Paris" in results["replies"][0]
     assert len(results["metadata"]) == 1
     assert model_name in results["metadata"][0]["model"]
-    assert "stop" == results["metadata"][0]["finish_reason"]
+    assert results["metadata"][0]["finish_reason"] == "stop"


 @pytest.mark.skipif(
@@ -54,6 +54,6 @@ def __call__(self, chunk):

     assert len(results["metadata"]) == 1
     assert model_name in results["metadata"][0]["model"]
-    assert "stop" == results["metadata"][0]["finish_reason"]
+    assert results["metadata"][0]["finish_reason"] == "stop"

     assert callback.responses == results["replies"][0]
diff --git a/e2e/preview/components/test_whisper_local.py b/e2e/preview/components/test_whisper_local.py
index 6c7e80613d..4290e63304 100644
--- a/e2e/preview/components/test_whisper_local.py
+++ b/e2e/preview/components/test_whisper_local.py
@@ -14,14 +14,14 @@ def test_whisper_local_transcriber(preview_samples_path):

     docs = output["documents"]
     assert len(docs) == 3
-    assert "this is the content of the document." == docs[0].text.strip().lower()
+    assert docs[0].text.strip().lower() == "this is the content of the document."
     assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

-    assert "the context for this answer is here." == docs[1].text.strip().lower()
+    assert docs[1].text.strip().lower() == "the context for this answer is here."
     assert (
         str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
         == docs[1].metadata["audio_file"]
     )

-    assert "answer." == docs[2].text.strip().lower()
-    assert "<>" == docs[2].metadata["audio_file"]
+    assert docs[2].text.strip().lower() == "answer."
+    assert docs[2].metadata["audio_file"] == "<>"
diff --git a/e2e/preview/components/test_whisper_remote.py b/e2e/preview/components/test_whisper_remote.py
index 673f80b17a..61f975dbb5 100644
--- a/e2e/preview/components/test_whisper_remote.py
+++ b/e2e/preview/components/test_whisper_remote.py
@@ -22,14 +22,14 @@ def test_whisper_remote_transcriber(preview_samples_path):

     docs = output["documents"]
     assert len(docs) == 3
-    assert "this is the content of the document." == docs[0].text.strip().lower()
+    assert docs[0].text.strip().lower() == "this is the content of the document."
     assert preview_samples_path / "audio" / "this is the content of the document.wav" == docs[0].metadata["audio_file"]

-    assert "the context for this answer is here." == docs[1].text.strip().lower()
+    assert docs[1].text.strip().lower() == "the context for this answer is here."
     assert (
         str((preview_samples_path / "audio" / "the context for this answer is here.wav").absolute())
         == docs[1].metadata["audio_file"]
     )

-    assert "answer." == docs[2].text.strip().lower()
-    assert "<>" == docs[2].metadata["audio_file"]
+    assert docs[2].text.strip().lower() == "answer."
+    assert docs[2].metadata["audio_file"] == "<>"
diff --git a/haystack-linter/haystack_linter/linting.py b/haystack-linter/haystack_linter/linting.py
index a55bfab84e..52286324ba 100644
--- a/haystack-linter/haystack_linter/linting.py
+++ b/haystack-linter/haystack_linter/linting.py
@@ -37,16 +37,13 @@ def leave_functiondef(self, node: nodes.FunctionDef) -> None:
         self._function_stack.pop()

     def visit_call(self, node: nodes.Call) -> None:
-        if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
-            if node.func.expr.name == "logging" and node.func.attrname in [
-                "debug",
-                "info",
-                "warning",
-                "error",
-                "critical",
-                "exception",
-            ]:
-                self.add_message("no-direct-logging", args=node.func.attrname, node=node)
+        if (
+            isinstance(node.func, nodes.Attribute)
+            and isinstance(node.func.expr, nodes.Name)
+            and node.func.expr.name == "logging"
+            and node.func.attrname in ["debug", "info", "warning", "error", "critical", "exception"]
+        ):
+            self.add_message("no-direct-logging", args=node.func.attrname, node=node)


 class NoLoggingConfigurationChecker(BaseChecker):
@@ -71,9 +68,13 @@ def leave_functiondef(self, node: nodes.FunctionDef) -> None:
         self._function_stack.pop()

     def visit_call(self, node: nodes.Call) -> None:
-        if isinstance(node.func, nodes.Attribute) and isinstance(node.func.expr, nodes.Name):
-            if node.func.expr.name == "logging" and node.func.attrname in ["basicConfig"]:
-                self.add_message("no-logging-basicconfig", node=node)
+        if (
+            isinstance(node.func, nodes.Attribute)
+            and isinstance(node.func.expr, nodes.Name)
+            and node.func.expr.name == "logging"
+            and node.func.attrname in ["basicConfig"]
+        ):
+            self.add_message("no-logging-basicconfig", node=node)


 def register(linter: "PyLinter") -> None:
diff --git a/haystack/agents/base.py b/haystack/agents/base.py
index 0c0a146f47..53585a9267 100644
--- a/haystack/agents/base.py
+++ b/haystack/agents/base.py
@@ -346,7 +346,7 @@ def run(
             You can only pass parameters to tools that are pipelines, but not nodes.
         """
         try:
-            if not self.hash == self.last_hash:
+            if self.hash != self.last_hash:
                 self.last_hash = self.hash
                 send_event(event_name="Agent", event_properties={"llm.agent_hash": self.hash})
         except Exception as exc:
diff --git a/haystack/document_stores/elasticsearch/es8.py b/haystack/document_stores/elasticsearch/es8.py
index 4454072137..b9215e553f 100644
--- a/haystack/document_stores/elasticsearch/es8.py
+++ b/haystack/document_stores/elasticsearch/es8.py
@@ -299,9 +299,10 @@ def _init_elastic_client(
         return client

     def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
-        if logger.isEnabledFor(logging.DEBUG):
-            if self.client.options(headers=headers).indices.exists_alias(name=index_name):
-                logger.debug("Index name %s is an alias.", index_name)
+        if logger.isEnabledFor(logging.DEBUG) and self.client.options(headers=headers).indices.exists_alias(
+            name=index_name
+        ):
+            logger.debug("Index name %s is an alias.", index_name)

         return self.client.options(headers=headers).indices.exists(index=index_name)

diff --git a/haystack/document_stores/es_converter.py b/haystack/document_stores/es_converter.py
index 1d242c92fa..1dade37388 100644
--- a/haystack/document_stores/es_converter.py
+++ b/haystack/document_stores/es_converter.py
@@ -228,9 +228,8 @@ def elasticsearch_index_to_document_store(
         content = record["_source"].pop(original_content_field, "")
         if content:
             meta = {}
-            if original_name_field is not None:
-                if original_name_field in record["_source"]:
-                    meta["name"] = record["_source"].pop(original_name_field)
+            if original_name_field is not None and original_name_field in record["_source"]:
+                meta["name"] = record["_source"].pop(original_name_field)
             # Only add selected metadata fields
             if included_metadata_fields is not None:
                 for metadata_field in included_metadata_fields:
diff --git a/haystack/document_stores/faiss.py b/haystack/document_stores/faiss.py
index d327c89ae3..fe8464be82 100644
--- a/haystack/document_stores/faiss.py
+++ b/haystack/document_stores/faiss.py
@@ -447,9 +447,8 @@ def get_all_documents_generator(
             return_embedding = self.return_embedding

         for doc in documents:
-            if return_embedding:
-                if doc.meta and doc.meta.get("vector_id") is not None:
-                    doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
+            if return_embedding and doc.meta and doc.meta.get("vector_id") is not None:
+                doc.embedding = self.faiss_indexes[index].reconstruct(int(doc.meta["vector_id"]))
             yield doc

     def get_documents_by_id(
diff --git a/haystack/document_stores/opensearch.py b/haystack/document_stores/opensearch.py
index 4b41a9d741..0d6e776660 100644
--- a/haystack/document_stores/opensearch.py
+++ b/haystack/document_stores/opensearch.py
@@ -382,10 +382,9 @@ def write_documents(
             self.index_type in ["ivf", "ivf_pq"]
             and not index.startswith(".")
             and not self._ivf_model_exists(index=index)
-        ):
-            if self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
-                train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
-                self._train_ivf_index(index=index, documents=train_docs, headers=headers)
+        ) and self.get_embedding_count(index=index, headers=headers) >= self.ivf_train_size:
+            train_docs = self.get_all_documents(index=index, return_embedding=True, headers=headers)
+            self._train_ivf_index(index=index, documents=train_docs, headers=headers)

     def _embed_documents(self, documents: List[Document], retriever: DenseRetriever) -> np.ndarray:
         """
diff --git a/haystack/document_stores/pinecone.py b/haystack/document_stores/pinecone.py
index 37c13b78e2..fde08fea7d 100644
--- a/haystack/document_stores/pinecone.py
+++ b/haystack/document_stores/pinecone.py
@@ -487,7 +487,7 @@ def write_documents(
             documents=document_objects, index=index, duplicate_documents=duplicate_documents
         )
         if document_objects:
-            add_vectors = False if document_objects[0].embedding is None else True
+            add_vectors = document_objects[0].embedding is not None
             # If these are not labels, we need to find the correct value for `doc_type` metadata field
             if not labels:
                 type_metadata = DOCUMENT_WITH_EMBEDDING if add_vectors else DOCUMENT_WITHOUT_EMBEDDING
diff --git a/haystack/document_stores/search_engine.py b/haystack/document_stores/search_engine.py
index 9b212d82ee..eac2a6ec7a 100644
--- a/haystack/document_stores/search_engine.py
+++ b/haystack/document_stores/search_engine.py
@@ -1620,9 +1620,8 @@ def delete_index(self, index: str):
         self._index_delete(index)

     def _index_exists(self, index_name: str, headers: Optional[Dict[str, str]] = None) -> bool:
-        if logger.isEnabledFor(logging.DEBUG):
-            if self.client.indices.exists_alias(name=index_name):
-                logger.debug("Index name %s is an alias.", index_name)
+        if logger.isEnabledFor(logging.DEBUG) and self.client.indices.exists_alias(name=index_name):
+            logger.debug("Index name %s is an alias.", index_name)

         return self.client.indices.exists(index=index_name, headers=headers)

diff --git a/haystack/document_stores/utils.py b/haystack/document_stores/utils.py
index 462a07e783..ab84882510 100644
--- a/haystack/document_stores/utils.py
+++ b/haystack/document_stores/utils.py
@@ -40,9 +40,8 @@ def eval_data_from_json(
             logger.warning("No title information found for documents in QA file: %s", filename)

         for squad_document in data["data"]:
-            if max_docs:
-                if len(docs) > max_docs:
-                    break
+            if max_docs and len(docs) > max_docs:
+                break
             # Extracting paragraphs and their labels from a SQuAD document dict
             cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
                 squad_document, preprocessor, open_domain
@@ -84,9 +83,8 @@ def eval_data_from_jsonl(

     with open(filename, "r", encoding="utf-8") as file:
         for document in file:
-            if max_docs:
-                if len(docs) > max_docs:
-                    break
+            if max_docs and len(docs) > max_docs:
+                break
             # Extracting paragraphs and their labels from a SQuAD document dict
             squad_document = json.loads(document)
             cur_docs, cur_labels, cur_problematic_ids = _extract_docs_and_labels_from_dict(
@@ -96,19 +94,18 @@ def eval_data_from_jsonl(
             labels.extend(cur_labels)
             problematic_ids.extend(cur_problematic_ids)

-            if batch_size is not None:
-                if len(docs) >= batch_size:
-                    if len(problematic_ids) > 0:
-                        logger.warning(
-                            "Could not convert an answer for %s questions.\n"
-                            "There were conversion errors for question ids: %s",
-                            len(problematic_ids),
-                            problematic_ids,
-                        )
-                    yield docs, labels
-                    docs = []
-                    labels = []
-                    problematic_ids = []
+            if batch_size is not None and len(docs) >= batch_size:
+                if len(problematic_ids) > 0:
+                    logger.warning(
+                        "Could not convert an answer for %s questions.\n"
+                        "There were conversion errors for question ids: %s",
+                        len(problematic_ids),
+                        problematic_ids,
+                    )
+                yield docs, labels
+                docs = []
+                labels = []
+                problematic_ids = []

     yield docs, labels
diff --git a/haystack/document_stores/weaviate.py b/haystack/document_stores/weaviate.py
index be07875b07..5484ca16c6 100644
--- a/haystack/document_stores/weaviate.py
+++ b/haystack/document_stores/weaviate.py
@@ -661,10 +661,9 @@ def write_documents(
                 if isinstance(v, dict):
                     json_fields.append(k)
                     v = json.dumps(v)
-                elif isinstance(v, list):
-                    if len(v) > 0 and isinstance(v[0], dict):
-                        json_fields.append(k)
-                        v = [json.dumps(item) for item in v]
+                elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
+                    json_fields.append(k)
+                    v = [json.dumps(item) for item in v]
                 _doc[k] = v

             _doc.pop("meta")
@@ -734,9 +733,8 @@ def update_document_meta(
         # Weaviate requires dates to be in RFC3339 format
         date_fields = self._get_date_properties(index)
         for date_field in date_fields:
-            if date_field in meta:
-                if isinstance(meta[date_field], str):
-                    meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))
+            if date_field in meta and isinstance(meta[date_field], str):
+                meta[date_field] = convert_date_to_rfc3339(str(meta[date_field]))

         self.weaviate_client.data_object.update(meta, class_name=index, uuid=id)

@@ -771,10 +769,8 @@ def get_document_count(
         else:
             result = self.weaviate_client.query.aggregate(index).with_meta_count().do()

-        if "data" in result:
-            if "Aggregate" in result.get("data"):
-                if result.get("data").get("Aggregate").get(index):
-                    doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]
+        if "data" in result and "Aggregate" in result.get("data") and result.get("data").get("Aggregate").get(index):
+            doc_count = result.get("data").get("Aggregate").get(index)[0]["meta"]["count"]

         return doc_count

@@ -1153,9 +1149,13 @@ def query(
         query_output = self.weaviate_client.query.raw(gql_query)
         results = []
-        if query_output and "data" in query_output and "Get" in query_output.get("data"):
-            if query_output.get("data").get("Get").get(index):
-                results = query_output.get("data").get("Get").get(index)
+        if (
+            query_output
+            and "data" in query_output
+            and "Get" in query_output.get("data")
+            and query_output.get("data").get("Get").get(index)
+        ):
+            results = query_output.get("data").get("Get").get(index)

         # We retrieve the JSON properties from the schema and convert them back to the Python dicts
         json_properties = self._get_json_properties(index=index)
@@ -1421,9 +1421,13 @@ def query_by_embedding(
         )
         results = []
-        if query_output and "data" in query_output and "Get" in query_output.get("data"):
-            if query_output.get("data").get("Get").get(index):
-                results = query_output.get("data").get("Get").get(index)
+        if (
+            query_output
+            and "data" in query_output
+            and "Get" in query_output.get("data")
+            and query_output.get("data").get("Get").get(index)
+        ):
+            results = query_output.get("data").get("Get").get(index)

         # We retrieve the JSON properties from the schema and convert them back to the Python dicts
         json_properties = self._get_json_properties(index=index)
diff --git a/haystack/modeling/data_handler/data_silo.py b/haystack/modeling/data_handler/data_silo.py
index bb9c273512..f49cbcb80b 100644
--- a/haystack/modeling/data_handler/data_silo.py
+++ b/haystack/modeling/data_handler/data_silo.py
@@ -111,10 +111,12 @@ def _get_dataset(self, filename: Optional[Union[str, Path]], dicts: Optional[Lis
         if dicts is None:
             dicts = list(self.processor.file_to_dicts(filename))  # type: ignore
             # shuffle list of dicts here if we later want to have a random dev set split from train set
-            if str(self.processor.train_filename) in str(filename):
-                if not self.processor.dev_filename:
-                    if self.processor.dev_split > 0.0:
-                        random.shuffle(dicts)
+            if (
+                str(self.processor.train_filename) in str(filename)
+                and not self.processor.dev_filename
+                and self.processor.dev_split > 0.0
+            ):
+                random.shuffle(dicts)

         num_dicts = len(dicts)
         datasets = []
diff --git a/haystack/modeling/data_handler/processor.py b/haystack/modeling/data_handler/processor.py
index 3556419496..f4c1707c72 100644
--- a/haystack/modeling/data_handler/processor.py
+++ b/haystack/modeling/data_handler/processor.py
@@ -488,9 +488,8 @@ def dataset_from_dicts(
         dataset, tensor_names, baskets = self._create_dataset(baskets)

         # Logging
-        if indices:
-            if 0 in indices:
-                self._log_samples(n_samples=1, baskets=baskets)
+        if indices and 0 in indices:
+            self._log_samples(n_samples=1, baskets=baskets)

         # During inference we need to keep the information contained in baskets.
         if return_baskets:
diff --git a/haystack/modeling/evaluation/eval.py b/haystack/modeling/evaluation/eval.py
index 7106291a0e..a349403861 100644
--- a/haystack/modeling/evaluation/eval.py
+++ b/haystack/modeling/evaluation/eval.py
@@ -194,12 +194,15 @@ def log_results(
             logger.info("\n _________ %s _________", head["task_name"])
             for metric_name, metric_val in head.items():
                 # log with experiment tracking framework (e.g. Mlflow)
-                if logging:
-                    if not metric_name in ["preds", "labels"] and not metric_name.startswith("_"):
-                        if isinstance(metric_val, numbers.Number):
-                            tracker.track_metrics(
-                                metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
-                            )
+                if (
+                    logging
+                    and metric_name not in ["preds", "labels"]
+                    and not metric_name.startswith("_")
+                    and isinstance(metric_val, numbers.Number)
+                ):
+                    tracker.track_metrics(
+                        metrics={f"{dataset_name}_{metric_name}_{head['task_name']}": metric_val}, step=steps
+                    )
                 # print via standard python logger
                 if print:
                     if metric_name == "report":
diff --git a/haystack/modeling/infer.py b/haystack/modeling/infer.py
index 81b159c80e..8d2d256829 100644
--- a/haystack/modeling/infer.py
+++ b/haystack/modeling/infer.py
@@ -1,20 +1,20 @@
-from typing import List, Optional, Dict, Union, Set, Any
-
-import os
+import contextlib
 import logging
-from tqdm import tqdm
+import os
+from typing import Any, Dict, List, Optional, Set, Union
+
 import torch
-from torch.utils.data.sampler import SequentialSampler
 from torch.utils.data import Dataset
+from torch.utils.data.sampler import SequentialSampler
+from tqdm import tqdm

 from haystack.modeling.data_handler.dataloader import NamedDataLoader
-from haystack.modeling.data_handler.processor import Processor, InferenceProcessor
-from haystack.modeling.data_handler.samples import SampleBasket
-from haystack.modeling.utils import initialize_device_settings, set_all_seeds
 from haystack.modeling.data_handler.inputs import QAInput
+from haystack.modeling.data_handler.processor import InferenceProcessor, Processor
+from haystack.modeling.data_handler.samples import SampleBasket
 from haystack.modeling.model.adaptive_model import AdaptiveModel, BaseAdaptiveModel
 from haystack.modeling.model.predictions import QAPred
-
+from haystack.modeling.utils import initialize_device_settings, set_all_seeds

 logger = logging.getLogger(__name__)
@@ -340,10 +340,8 @@ def _inference_without_multiprocessing(self, dicts: List[Dict], return_json: boo

         if return_json:
             # TODO this try catch should be removed when all tasks return prediction objects
-            try:
+            with contextlib.suppress(AttributeError):
                 preds_all = [x.to_json() for x in preds_all]
-            except AttributeError:
-                pass

         return preds_all

diff --git a/haystack/modeling/model/adaptive_model.py b/haystack/modeling/model/adaptive_model.py
index eb17d80e4c..87d84a1cda 100644
--- a/haystack/modeling/model/adaptive_model.py
+++ b/haystack/modeling/model/adaptive_model.py
@@ -644,7 +644,7 @@ def convert_to_onnx(
             model=model_name,
             output=output_path / "model.onnx",
             opset=opset_version,
-            use_external_format=True if model_type == "XLMRoberta" else False,
+            use_external_format=model_type == "XLMRoberta",
             use_auth_token=use_auth_token,
         )

diff --git a/haystack/modeling/model/language_model.py b/haystack/modeling/model/language_model.py
index 63582419b5..e1f4028eba 100644
--- a/haystack/modeling/model/language_model.py
+++ b/haystack/modeling/model/language_model.py
@@ -189,11 +189,7 @@ def formatted_preds(
         elif self.extraction_strategy == "per_token":
             vecs = sequence_output.cpu().numpy()
-        elif self.extraction_strategy == "reduce_mean":
-            vecs = self._pool_tokens(
-                sequence_output, padding_mask, self.extraction_strategy, ignore_first_token=ignore_first_token  # type: ignore [arg-type] # type: ignore [arg-type]
-            )
-        elif self.extraction_strategy == "reduce_max":
+        elif self.extraction_strategy in ("reduce_mean", "reduce_max"):
             vecs = self._pool_tokens(
                 sequence_output, padding_mask, self.extraction_strategy, ignore_first_token=ignore_first_token  # type: ignore [arg-type] # type: ignore [arg-type]
             )
diff --git a/haystack/modeling/model/multimodal/__init__.py b/haystack/modeling/model/multimodal/__init__.py
index df11f619e5..d66c19e0c3 100644
--- a/haystack/modeling/model/multimodal/__init__.py
+++ b/haystack/modeling/model/multimodal/__init__.py
@@ -153,9 +153,11 @@ def get_model(

 def _is_sentence_transformers_model(pretrained_model_name_or_path: Union[Path, str], use_auth_token: Union[bool, str]):
     # Check if sentence transformers config file is in local path
-    if Path(pretrained_model_name_or_path).exists():
-        if (Path(pretrained_model_name_or_path) / "config_sentence_transformers.json").exists():
-            return True
+    if (
+        Path(pretrained_model_name_or_path).exists()
+        and (Path(pretrained_model_name_or_path) / "config_sentence_transformers.json").exists()
+    ):
+        return True

     # Check if sentence transformers config file is in model hub
     try:
diff --git a/haystack/modeling/model/prediction_head.py b/haystack/modeling/model/prediction_head.py
index aefb3eade2..45e2d72690 100644
--- a/haystack/modeling/model/prediction_head.py
+++ b/haystack/modeling/model/prediction_head.py
@@ -676,9 +676,8 @@ def get_question(question_names: List[str], raw_dict: Dict):
             qa_name = "qas"
         elif "question" in raw_dict:
             qa_name = "question"
-        if qa_name:
-            if type(raw_dict[qa_name][0]) == dict:
-                return raw_dict[qa_name][0]["question"]
+        if qa_name and type(raw_dict[qa_name][0]) == dict:
+            return raw_dict[qa_name][0]["question"]
         return try_get(question_names, raw_dict)

     def aggregate_preds(self, preds, passage_start_t, ids, seq_2_start_t=None, labels=None):
diff --git a/haystack/modeling/training/base.py b/haystack/modeling/training/base.py
index f1baa30f43..700f288cbf 100644
--- a/haystack/modeling/training/base.py
+++ b/haystack/modeling/training/base.py
@@ -208,10 +208,9 @@ def train(self):
                     progress_bar.set_description(f"Train epoch {epoch}/{self.epochs-1} (Cur. train loss: {loss:.4f})")

                 # Only for distributed training: we need to ensure that all ranks still have a batch left for training
-                if self.local_rank != -1:
-                    if not self._all_ranks_have_data(has_data=True, step=step):
-                        early_break = True
-                        break
+                if self.local_rank != -1 and not self._all_ranks_have_data(has_data=True, step=step):
+                    early_break = True
+                    break

                 # Move batch of samples to device
                 batch = {key: batch[key].to(self.device) for key in batch}
@@ -324,11 +323,10 @@ def compute_loss(self, batch: dict, step: int) -> torch.Tensor:
         return self.backward_propagate(loss, step)

     def backward_propagate(self, loss: torch.Tensor, step: int):
-        if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
-            if self.local_rank in [-1, 0]:
-                tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
-                if self.log_learning_rate:
-                    tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
+        if self.global_step % self.log_loss_every == 0 and self.local_rank in [-1, 0]:
+            tracker.track_metrics({"Train_loss_total": float(loss.detach().cpu().numpy())}, step=self.global_step)
+            if self.log_learning_rate:
+                tracker.track_metrics({"learning_rate": self.lr_schedule.get_last_lr()[0]}, step=self.global_step)
         self.scaler.scale(loss).backward()

@@ -374,16 +372,15 @@ def create_or_load_checkpoint(
                                 defaults to "latest", using the checkpoint with the highest train steps.
         """
         checkpoint_to_load = None
-        if checkpoint_root_dir:
-            if checkpoint_root_dir.exists():
-                if resume_from_checkpoint == "latest":
-                    saved_checkpoints = cls._get_checkpoints(checkpoint_root_dir)
-                    if saved_checkpoints:
-                        checkpoint_to_load = saved_checkpoints[0]  # latest checkpoint
-                    else:
-                        checkpoint_to_load = None
+        if checkpoint_root_dir and checkpoint_root_dir.exists():
+            if resume_from_checkpoint == "latest":
+                saved_checkpoints = cls._get_checkpoints(checkpoint_root_dir)
+                if saved_checkpoints:
+                    checkpoint_to_load = saved_checkpoints[0]  # latest checkpoint
                 else:
-                    checkpoint_to_load = checkpoint_root_dir / resume_from_checkpoint
+                    checkpoint_to_load = None
+            else:
+                checkpoint_to_load = checkpoint_root_dir / resume_from_checkpoint

         if checkpoint_to_load:
             # TODO load empty model class from config instead of passing here?
diff --git a/haystack/nodes/connector/crawler.py b/haystack/nodes/connector/crawler.py
index 73c303bfb3..415f3b6cf7 100644
--- a/haystack/nodes/connector/crawler.py
+++ b/haystack/nodes/connector/crawler.py
@@ -485,14 +485,16 @@ def _extract_sublinks_from_url(
                 )
                 continue

-            if sub_link and not (already_found_links and sub_link in already_found_links):
-                if self._is_internal_url(base_url=base_url, sub_link=sub_link) and (
-                    not self._is_inpage_navigation(base_url=base_url, sub_link=sub_link)
-                ):
-                    if filter_pattern is not None:
-                        if filter_pattern.search(sub_link):
-                            sub_links.add(sub_link)
-                    else:
+            if (
+                sub_link
+                and not (already_found_links and sub_link in already_found_links)
+                and self._is_internal_url(base_url=base_url, sub_link=sub_link)
+                and (not self._is_inpage_navigation(base_url=base_url, sub_link=sub_link))
+            ):
+                if filter_pattern is not None:
+                    if filter_pattern.search(sub_link):
                         sub_links.add(sub_link)
+                else:
+                    sub_links.add(sub_link)

         return sub_links
diff --git a/haystack/nodes/file_converter/image.py b/haystack/nodes/file_converter/image.py
index 4837371593..1e4db2d392 100644
--- a/haystack/nodes/file_converter/image.py
+++ b/haystack/nodes/file_converter/image.py
@@ -134,10 +134,14 @@ def convert(
                 digits = [word for word in words if any(i.isdigit() for i in word)]

                 # remove lines having > 40% of words as digits AND not ending with a period(.)
-                if remove_numeric_tables:
-                    if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                        logger.debug("Removing line '%s' from file", line)
-                        continue
+                if (
+                    remove_numeric_tables
+                    and words
+                    and len(digits) / len(words) > 0.4
+                    and not line.strip().endswith(".")
+                ):
+                    logger.debug("Removing line '%s' from file", line)
+                    continue
                 cleaned_lines.append(line)

             page = "\n".join(cleaned_lines)
diff --git a/haystack/nodes/file_converter/pdf.py b/haystack/nodes/file_converter/pdf.py
index 7ce6e6ccdd..e5348671dc 100644
--- a/haystack/nodes/file_converter/pdf.py
+++ b/haystack/nodes/file_converter/pdf.py
@@ -182,10 +182,14 @@ def convert(
                 digits = [word for word in words if any(i.isdigit() for i in word)]

                 # remove lines having > 40% of words as digits AND not ending with a period(.)
-                if remove_numeric_tables:
-                    if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                        logger.debug("Removing line '%s' from %s", line, file_path)
-                        continue
+                if (
+                    remove_numeric_tables
+                    and words
+                    and len(digits) / len(words) > 0.4
+                    and not line.strip().endswith(".")
+                ):
+                    logger.debug("Removing line '%s' from %s", line, file_path)
+                    continue
                 cleaned_lines.append(line)

             page = "\n".join(cleaned_lines)
diff --git a/haystack/nodes/file_converter/pdf_xpdf.py b/haystack/nodes/file_converter/pdf_xpdf.py
index 5d8a32f148..8674e33931 100644
--- a/haystack/nodes/file_converter/pdf_xpdf.py
+++ b/haystack/nodes/file_converter/pdf_xpdf.py
@@ -132,10 +132,14 @@ def convert(
                 digits = [word for word in words if any(i.isdigit() for i in word)]

                 # remove lines having > 40% of words as digits AND not ending with a period(.)
-                if remove_numeric_tables:
-                    if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                        logger.debug("Removing line '%s' from %s", line, file_path)
-                        continue
+                if (
+                    remove_numeric_tables
+                    and words
+                    and len(digits) / len(words) > 0.4
+                    and not line.strip().endswith(".")
+                ):
+                    logger.debug("Removing line '%s' from %s", line, file_path)
+                    continue
                 cleaned_lines.append(line)

             page = "\n".join(cleaned_lines)
diff --git a/haystack/nodes/file_converter/tika.py b/haystack/nodes/file_converter/tika.py
index 510f96dd15..46ff729c85 100644
--- a/haystack/nodes/file_converter/tika.py
+++ b/haystack/nodes/file_converter/tika.py
@@ -169,10 +169,14 @@ def convert(
                 digits = [word for word in words if any(i.isdigit() for i in word)]

                 # remove lines having > 40% of words as digits AND not ending with a period(.)
-                if remove_numeric_tables:
-                    if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                        logger.debug("Removing line '%s' from %s", line, file_path)
-                        continue
+                if (
+                    remove_numeric_tables
+                    and words
+                    and len(digits) / len(words) > 0.4
+                    and not line.strip().endswith(".")
+                ):
+                    logger.debug("Removing line '%s' from %s", line, file_path)
+                    continue
                 cleaned_lines.append(line)

diff --git a/haystack/nodes/file_converter/txt.py b/haystack/nodes/file_converter/txt.py
index b9e323ec5f..0a2cd8d87e 100644
--- a/haystack/nodes/file_converter/txt.py
+++ b/haystack/nodes/file_converter/txt.py
@@ -61,10 +61,14 @@ def convert(
                 digits = [word for word in words if any(i.isdigit() for i in word)]

                 # remove lines having > 40% of words as digits AND not ending with a period(.)
-                if remove_numeric_tables:
-                    if words and len(digits) / len(words) > 0.4 and not line.strip().endswith("."):
-                        logger.debug("Removing line '%s' from %s", line, file_path)
-                        continue
+                if (
+                    remove_numeric_tables
+                    and words
+                    and len(digits) / len(words) > 0.4
+                    and not line.strip().endswith(".")
+                ):
+                    logger.debug("Removing line '%s' from %s", line, file_path)
+                    continue
                 cleaned_lines.append(line)

diff --git a/haystack/nodes/other/route_documents.py b/haystack/nodes/other/route_documents.py
index 87bd3d8499..e504d05e29 100644
--- a/haystack/nodes/other/route_documents.py
+++ b/haystack/nodes/other/route_documents.py
@@ -50,12 +50,11 @@ def __init__(
         self.metadata_values = metadata_values
         self.return_remaining = return_remaining

-        if self.split_by != "content_type":
-            if self.metadata_values is None or len(self.metadata_values) == 0:
-                raise ValueError(
-                    "If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
-                    "a list of Documents by a metadata field."
-                )
+        if self.split_by != "content_type" and (self.metadata_values is None or len(self.metadata_values) == 0):
+            raise ValueError(
+                "If split_by is set to the name of a metadata field, provide metadata_values if you want to split "
+                "a list of Documents by a metadata field."
+            )

     @classmethod
     def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
diff --git a/haystack/nodes/prompt/invocation_layer/hugging_face.py b/haystack/nodes/prompt/invocation_layer/hugging_face.py
index a7666dca15..83a1cfc9f2 100644
--- a/haystack/nodes/prompt/invocation_layer/hugging_face.py
+++ b/haystack/nodes/prompt/invocation_layer/hugging_face.py
@@ -259,7 +259,7 @@ def invoke(self, *args, **kwargs):
             gen_dict.pop("transformers_version", None)
             model_input_kwargs.update(gen_dict)

-        is_text_generation = "text-generation" == self.task_name
+        is_text_generation = self.task_name == "text-generation"
         # Prefer return_full_text is False for text-generation (unless explicitly set)
         # Thus only generated text is returned (excluding prompt)
         if is_text_generation and "return_full_text" not in model_input_kwargs:
diff --git a/haystack/nodes/prompt/invocation_layer/sagemaker_hf_infer.py b/haystack/nodes/prompt/invocation_layer/sagemaker_hf_infer.py
index d4e0d9eab5..4ad97c5407 100644
--- a/haystack/nodes/prompt/invocation_layer/sagemaker_hf_infer.py
+++ b/haystack/nodes/prompt/invocation_layer/sagemaker_hf_infer.py
@@ -249,9 +249,8 @@ def _unwrap_response(self, response: Any):
         if isinstance(response, list):
             for sublist in response:
                 yield from self._unwrap_response(sublist)
-        elif isinstance(response, dict):
-            if "generated_text" in response or "generated_texts" in response:
-                yield response
+        elif isinstance(response, dict) and ("generated_text" in response or "generated_texts" in response):
+            yield response

     @classmethod
     def get_test_payload(cls) -> Dict[str, str]:
diff --git a/haystack/nodes/reader/base.py b/haystack/nodes/reader/base.py
index 98f249981d..4bea95b2e4 100644
--- a/haystack/nodes/reader/base.py
+++ b/haystack/nodes/reader/base.py
@@ -198,11 +198,10 @@ def run_batch(  # type: ignore

         # Add corresponding document_name and more meta data, if an answer contains the document_id
         answer_iterator = itertools.chain.from_iterable(results_label_input["answers"])
-        if isinstance(documents[0], Document):
-            if isinstance(queries, list):
-                answer_iterator = itertools.chain.from_iterable(
-                    itertools.chain.from_iterable(results_label_input["answers"])
-                )
+        if isinstance(documents[0], Document) and isinstance(queries, list):
+            answer_iterator = itertools.chain.from_iterable(
+                itertools.chain.from_iterable(results_label_input["answers"])
+            )
         flattened_documents = []
         for doc_list in documents:
             if isinstance(doc_list, list):
diff --git a/haystack/nodes/reader/farm.py b/haystack/nodes/reader/farm.py
index 4e39d3d65e..94847339d1 100644
--- a/haystack/nodes/reader/farm.py
+++ b/haystack/nodes/reader/farm.py
@@ -1398,11 +1398,8 @@ def calibrate_confidence_scores(

     @staticmethod
     def _check_no_answer(c: "QACandidate"):
         # check for correct value in "answer"
-        if c.offset_answer_start == 0 and c.offset_answer_end == 0:
-            if c.answer != "no_answer":
-                logger.error(
-                    "Invalid 'no_answer': Got a prediction for position 0, but answer string is not 'no_answer'"
-                )
+        if c.offset_answer_start == 0 and c.offset_answer_end == 0 and c.answer != "no_answer":
+            logger.error("Invalid 'no_answer': Got a prediction for position 0, but answer string is not 'no_answer'")
         return c.answer == "no_answer"

     def predict_on_texts(self, question: str, texts: List[str], top_k: Optional[int] = None):
diff --git a/haystack/nodes/retriever/sparse.py b/haystack/nodes/retriever/sparse.py
index c32dddf0a6..708dcdb834 100644
--- a/haystack/nodes/retriever/sparse.py
+++ b/haystack/nodes/retriever/sparse.py
@@ -504,16 +504,15 @@ def retrieve(
                 "Both the `index` parameter passed to the `retrieve` method and the default `index` of the Document store are null. Pass a non-null `index` value."
             )

-        if self.auto_fit:
-            if (
-                index not in self.document_counts
-                or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
-            ):
-                # run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
-                logger.warning(
-                    "Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
-                )
-                self.fit(document_store=document_store, index=index)
+        if self.auto_fit and (
+            index not in self.document_counts
+            or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
+        ):
+            # run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
+            logger.warning(
+                "Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
+            )
+            self.fit(document_store=document_store, index=index)
         if self.dataframes[index] is None:
             raise DocumentStoreError(
                 "Retrieval requires dataframe and tf-idf matrix but fit() did not calculate them probably due to an empty document store."
@@ -592,16 +591,15 @@ def retrieve_batch(
                 "Both the `index` parameter passed to the `retrieve_batch` method and the default `index` of the Document store are null. Pass a non-null `index` value."
             )

-        if self.auto_fit:
-            if (
-                index not in self.document_counts
-                or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
-            ):
-                # run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
-                logger.warning(
-                    "Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
-                )
-                self.fit(document_store=document_store, index=index)
+        if self.auto_fit and (
+            index not in self.document_counts
+            or document_store.get_document_count(headers=headers, index=index) != self.document_counts[index]
+        ):
+            # run fit() to update self.dataframes, self.tfidf_matrices and self.document_counts
+            logger.warning(
+                "Indexed documents have been updated and fit() method needs to be run before retrieval. Running it now."
+            )
+            self.fit(document_store=document_store, index=index)
         if self.dataframes[index] is None:
             raise DocumentStoreError(
                 "Retrieval requires dataframe and tf-idf matrix but fit() did not calculate them probably because of an empty document store."
diff --git a/haystack/pipelines/base.py b/haystack/pipelines/base.py
index 5d51528c8c..ed329c75ad 100644
--- a/haystack/pipelines/base.py
+++ b/haystack/pipelines/base.py
@@ -538,9 +538,8 @@ def run(  # type: ignore
             # Apply debug attributes to the node input params
             # NOTE: global debug attributes will override the value specified
             # in each node's params dictionary.
-            if debug is None and node_input:
-                if node_input.get("params", {}):
-                    debug = params.get("debug", None)  # type: ignore
+            if debug is None and node_input and node_input.get("params", {}):
+                debug = params.get("debug", None)  # type: ignore
             if debug is not None:
                 if not node_input.get("params", None):
                     node_input["params"] = {}
@@ -709,9 +708,8 @@ def run_batch(  # noqa: C901,PLR0912 type: ignore
             # Apply debug attributes to the node input params
             # NOTE: global debug attributes will override the value specified in each node's params dictionary.
-            if debug is None and node_input:
-                if node_input.get("params", {}):
-                    debug = params.get("debug", None)  # type: ignore
+            if debug is None and node_input and node_input.get("params", {}):
+                debug = params.get("debug", None)  # type: ignore
             if debug is not None:
                 if not node_input.get("params", None):
                     node_input["params"] = {}
@@ -2285,22 +2283,21 @@ def _validate_node_names_in_params(self, params: Optional[Dict]):
         """
         Validates the node names provided in the 'params' arg of run/run_batch method.
         """
-        if params:
-            if not all(node_id in self.graph.nodes for node_id in params.keys()):
-                # Might be a non-targeted param. Verify that too
-                not_a_node = set(params.keys()) - set(self.graph.nodes)
-                # "debug" will be picked up by _dispatch_run, see its code
-                # "add_isolated_node_eval" is set by pipeline.eval / pipeline.eval_batch
-                valid_global_params = {"debug", "add_isolated_node_eval"}
-                for node_id in self.graph.nodes:
-                    run_signature_args = self._get_run_node_signature(node_id)
-                    valid_global_params |= set(run_signature_args)
-                invalid_keys = [key for key in not_a_node if key not in valid_global_params]
-
-                if invalid_keys:
-                    raise ValueError(
-                        f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
-                    )
+        if params and not all(node_id in self.graph.nodes for node_id in params.keys()):
+            # Might be a non-targeted param. Verify that too
+            not_a_node = set(params.keys()) - set(self.graph.nodes)
+            # "debug" will be picked up by _dispatch_run, see its code
+            # "add_isolated_node_eval" is set by pipeline.eval / pipeline.eval_batch
+            valid_global_params = {"debug", "add_isolated_node_eval"}
+            for node_id in self.graph.nodes:
+                run_signature_args = self._get_run_node_signature(node_id)
+                valid_global_params |= set(run_signature_args)
+            invalid_keys = [key for key in not_a_node if key not in valid_global_params]
+
+            if invalid_keys:
+                raise ValueError(
+                    f"No node(s) or global parameter(s) named {', '.join(invalid_keys)} found in pipeline."
+                )

     def _get_run_node_signature(self, node_id: str):
         return inspect.signature(self.graph.nodes[node_id]["component"].run).parameters.keys()
diff --git a/haystack/pipelines/config.py b/haystack/pipelines/config.py
index 10fef0f255..d654679c38 100644
--- a/haystack/pipelines/config.py
+++ b/haystack/pipelines/config.py
@@ -131,13 +131,12 @@ def build_component_dependency_graph(
         node_name = node["name"]
         graph.add_node(node_name)
         for input in node["inputs"]:
-            if input in component_definitions:
+            if input in component_definitions and not graph.has_edge(node_name, input):
                 # Special case for (actually permitted) cyclic dependencies between two components:
                 # e.g. DensePassageRetriever depends on ElasticsearchDocumentStore.
                 # In indexing pipelines ElasticsearchDocumentStore depends on DensePassageRetriever's output.
                 # But this second dependency is looser, so we neglect it.
-                if not graph.has_edge(node_name, input):
-                    graph.add_edge(input, node_name)
+                graph.add_edge(input, node_name)
     return graph
diff --git a/haystack/pipelines/ray.py b/haystack/pipelines/ray.py
index 891e276094..fc0a44a53e 100644
--- a/haystack/pipelines/ray.py
+++ b/haystack/pipelines/ray.py
@@ -356,9 +356,8 @@ async def run_async(  # type: ignore
             # Apply debug attributes to the node input params
             # NOTE: global debug attributes will override the value specified
             # in each node's params dictionary.
-            if debug is None and node_input:
-                if node_input.get("params", {}):
-                    debug = params.get("debug", None)  # type: ignore
+            if debug is None and node_input and node_input.get("params", {}):
+                debug = params.get("debug", None)  # type: ignore
             if debug is not None:
                 if not node_input.get("params", None):
                     node_input["params"] = {}
diff --git a/haystack/schema.py b/haystack/schema.py
index c4f7cbd1c2..3dd3df4e6e 100644
--- a/haystack/schema.py
+++ b/haystack/schema.py
@@ -106,15 +106,16 @@ def __init__(

         allowed_hash_key_attributes = ["content", "content_type", "score", "meta", "embedding"]
-        if id_hash_keys is not None:
-            if not all(key in allowed_hash_key_attributes or key.startswith("meta.") for key in id_hash_keys):
-                raise ValueError(
-                    f"You passed custom strings {id_hash_keys} to id_hash_keys which is deprecated. Supply instead a "
-                    f"list of Document's attribute names (like {', '.join(allowed_hash_key_attributes)}) or "
-                    f"a key of meta with a maximum depth of 1 (like meta.url). "
-                    "See [Custom id hashing on documentstore level](https://github.com/deepset-ai/haystack/pull/1910) and "
-                    "[Allow more flexible Document id hashing](https://github.com/deepset-ai/haystack/issues/4317) for details"
-                )
+        if id_hash_keys is not None and not all(
+            key in allowed_hash_key_attributes or key.startswith("meta.") for key in id_hash_keys
+        ):
+            raise ValueError(
+                f"You passed custom strings {id_hash_keys} to id_hash_keys which is deprecated. Supply instead a "
+                f"list of Document's attribute names (like {', '.join(allowed_hash_key_attributes)}) or "
+                f"a key of meta with a maximum depth of 1 (like meta.url). "
+                "See [Custom id hashing on documentstore level](https://github.com/deepset-ai/haystack/pull/1910) and "
+                "[Allow more flexible Document id hashing](https://github.com/deepset-ai/haystack/issues/4317) for details"
+            )
         # We store id_hash_keys to be able to clone documents, for example when splitting them during pre-processing
         self.id_hash_keys = id_hash_keys or ["content"]

@@ -181,10 +182,9 @@ def to_dict(self, field_map: Optional[Dict[str, Any]] = None) -> Dict:
             # Exclude internal fields (Pydantic, ...) fields from the conversion process
             if k.startswith("__"):
                 continue
-            if k == "content":  # Convert pd.DataFrame to list of rows for serialization
-                if self.content_type == "table" and isinstance(self.content, DataFrame):
-                    v = dataframe_to_list(self.content)
+            if k == "content" and self.content_type == "table" and isinstance(self.content, DataFrame):
+                v = dataframe_to_list(self.content)
             k = k if k not in inv_field_map else inv_field_map[k]
             _doc[k] = v
         return _doc
diff --git a/haystack/utils/cleaning.py b/haystack/utils/cleaning.py
index 460df47c0a..43d1ca29fa 100644
--- a/haystack/utils/cleaning.py
+++ b/haystack/utils/cleaning.py
@@ -14,9 +14,7 @@ def clean_wiki_text(text: str) -> str:
     lines = text.split("\n")
     cleaned = []
     for l in lines:
-        if len(l) > 30:
-            cleaned.append(l)
-        elif l[:2] == "==" and l[-2:] == "==":
+        if len(l) > 30 or (l[:2] == "==" and l[-2:] == "=="):
             cleaned.append(l)
     text = "\n".join(cleaned)

diff --git a/pyproject.toml b/pyproject.toml
index 144473722b..f950a42919 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -336,11 +336,11 @@ disable = [
 ]
 [tool.pylint.'DESIGN']
 max-args = 38  # Default is 5
-max-attributes = 27  # Default is 7
+max-attributes = 28  # Default is 7
 max-branches = 34  # Default is 12
 max-locals = 45  # Default is 15
 max-module-lines = 2468  # Default is 1000
-max-nested-blocks = 7  # Default is 5
+max-nested-blocks = 9  # Default is 5
 max-statements = 206  # Default is 50
 [tool.pylint.'SIMILARITIES']
 min-similarity-lines=6
@@ -393,6 +393,7 @@ select = [
     "PERF",  # Perflint
     "PL",  # Pylint
     "Q",  # flake8-quotes
+    "SIM",  # flake8-simplify
     "SLOT",  # flake8-slots
     "T10",  # flake8-debugger
     "W",  # pycodestyle
@@ -407,13 +408,16 @@ line-length = 1486
 target-version = "py38"
 ignore = [
     "F401",  # unused-import
-    "PERF401",  # Use a list comprehension to create a transformed list
     "PERF203",  # `try`-`except` within a loop incurs performance overhead
+    "PERF401",  # Use a list comprehension to create a transformed list
     "PLR1714",  # repeated-equality-comparison
     "PLR5501",  # collapsible-else-if
     "PLW0603",  # global-statement
     "PLW1510",  # subprocess-run-without-check
     "PLW2901",  # redefined-loop-name
+    "SIM108",  # if-else-block-instead-of-if-exp
+    "SIM115",  # open-file-with-context-handler
+    "SIM118",  # in-dict-keys
 ]
 [tool.ruff.mccabe]
 max-complexity = 28
@@ -428,6 +432,7 @@
 allow-magic-value-types = ["float", "int", "str"]
 max-args = 38  # Default is 5
 max-branches = 32  # Default is 12
+max-public-methods = 90  # Default is 20
 max-returns = 9  # Default is 6
 max-statements = 105  # Default is 50
diff --git a/rest_api/rest_api/controller/feedback.py b/rest_api/rest_api/controller/feedback.py
index 59dc357bf4..8e891c6e9b 100644
--- a/rest_api/rest_api/controller/feedback.py
+++ b/rest_api/rest_api/controller/feedback.py
@@ -177,7 +177,7 @@ def export_feedback(
             start = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["answer_start"]
             answer = squad_label["paragraphs"][0]["qas"][0]["answers"][0]["text"]
             context = squad_label["paragraphs"][0]["context"]
-            if not context[start : start + len(answer)] == answer:
+            if context[start : start + len(answer)] != answer:
                 logger.error(
                     "Skipping invalid squad label as string via offsets ('%s') does not match answer string ('%s') ",
                     context[start : start + len(answer)],
diff --git a/rest_api/test/test_rest_api.py b/rest_api/test/test_rest_api.py
index fd0153b05d..757b9dfb8f 100644
--- a/rest_api/test/test_rest_api.py
+++ b/rest_api/test/test_rest_api.py
@@ -258,7 +258,7 @@ def client(tmp_path):
 def test_get_all_documents(client):
     response = client.post(url="/documents/get_by_filters", data='{"filters": {}}')
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `get_all_documents` was called with the expected `filters` param
     MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={}, index=None)
     # Ensure results are part of the response body
@@ -268,21 +268,21 @@ def test_get_all_documents(client):

 def test_get_documents_with_filters(client):
     response = client.post(url="/documents/get_by_filters", data='{"filters": {"test_index": ["2"]}}')
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `get_all_documents` was called with the expected `filters` param
     MockDocumentStore.mocker.get_all_documents.assert_called_with(filters={"test_index": ["2"]}, index=None)


 def test_delete_all_documents(client):
     response = client.post(url="/documents/delete_by_filters", data='{"filters": {}}')
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `delete_documents` was called on the Document Store instance
     MockDocumentStore.mocker.delete_documents.assert_called_with(filters={}, index=None)


 def test_delete_documents_with_filters(client):
     response = client.post(url="/documents/delete_by_filters", data='{"filters": {"test_index": ["1"]}}')
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `delete_documents` was called on the Document Store instance with the same params
     MockDocumentStore.mocker.delete_documents.assert_called_with(filters={"test_index": ["1"]}, index=None)

@@ -290,7 +290,7 @@ def test_delete_documents_with_filters(client):
 def test_file_upload(client):
     file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
     response = client.post(url="/file-upload", files=file_to_upload, data={"meta": '{"test_key": "test_value"}'})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure the `convert` method was called with the right keyword params
     _, kwargs = MockPDFToTextConverter.mocker.convert.call_args
     # Files are renamed with random prefix like 83f4c1f5b2bd43f2af35923b9408076b_sample_pdf_1.pdf
@@ -302,7 +302,7 @@ def test_file_upload(client):
 def test_file_upload_with_no_meta(client):
     file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
     response = client.post(url="/file-upload", files=file_to_upload, data={})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure the `convert` method was called with the right keyword params
     _, kwargs = MockPDFToTextConverter.mocker.convert.call_args
     assert kwargs["meta"] == {"name": "sample_pdf_1.pdf"}
@@ -311,7 +311,7 @@ def test_file_upload_with_empty_meta(client):
     file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
     response = client.post(url="/file-upload", files=file_to_upload, data={"meta": ""})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure the `convert` method was called with the right keyword params
     _, kwargs = MockPDFToTextConverter.mocker.convert.call_args
     assert kwargs["meta"] == {"name": "sample_pdf_1.pdf"}
@@ -320,7 +320,7 @@ def test_file_upload_with_wrong_meta(client):
     file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
     response = client.post(url="/file-upload", files=file_to_upload, data={"meta": "1"})
-    assert 500 == response.status_code
+    assert response.status_code == 500
     # Ensure the `convert` method was never called
     MockPDFToTextConverter.mocker.convert.assert_not_called()

@@ -330,7 +330,7 @@ def test_file_upload_cleanup_after_indexing(client):
     with mock.patch("rest_api.controller.file_upload.FILE_UPLOAD_PATH", os.environ.get("FILE_UPLOAD_PATH")):
         file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
         response = client.post(url="/file-upload", files=file_to_upload, data={})
-        assert 200 == response.status_code
+        assert response.status_code == 200
         # ensure upload folder is empty
         uploaded_files = os.listdir(os.environ.get("FILE_UPLOAD_PATH"))
         assert len(uploaded_files) == 0
@@ -341,7 +341,7 @@ def test_file_upload_keep_files_after_indexing(client):
     with mock.patch("rest_api.controller.file_upload.FILE_UPLOAD_PATH", os.environ.get("FILE_UPLOAD_PATH")):
         file_to_upload = {"files": (Path(__file__).parent / "samples" / "pdf" / "sample_pdf_1.pdf").open("rb")}
         response = client.post(url="/file-upload", files=file_to_upload, params={"keep_files": "true"})
-        assert 200 == response.status_code
+        assert response.status_code == 200
         # ensure original file was kept
         uploaded_files = os.listdir(os.environ.get("FILE_UPLOAD_PATH"))
         assert len(uploaded_files) == 1
@@ -352,7 +352,7 @@ def test_query_with_no_filter(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `run` was called with the expected parameters
     mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params={}, debug=False)

@@ -363,7 +363,7 @@ def test_query_with_one_filter(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY, "params": params})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `run` was called with the expected parameters
     mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params=params, debug=False)

@@ -374,7 +374,7 @@ def test_query_with_one_global_filter(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY, "params": params})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `run` was called with the expected parameters
     mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params=params, debug=False)

@@ -385,7 +385,7 @@ def test_query_with_filter_list(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY, "params": params})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `run` was called with the expected parameters
     mocked_pipeline.run.assert_called_with(query=TEST_QUERY, params=params, debug=False)

@@ -395,7 +395,7 @@ def test_query_with_no_documents_and_no_answers(client):
     # `run` must return a dictionary containing a `query` key
     mocked_pipeline.run.return_value = {"query": TEST_QUERY}
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     response_json = response.json()
     assert response_json["documents"] == []
     assert response_json["answers"] == []
@@ -414,7 +414,7 @@ def test_query_with_bool_in_params(client):
         "params": {"debug": True, "Retriever": {"top_k": 5}, "Reader": {"top_k": 3}},
     }
     response = client.post(url="/query", json=request_body)
-    assert 200 == response.status_code
+    assert response.status_code == 200
     response_json = response.json()
     assert response_json["documents"] == []
     assert response_json["answers"] == []
@@ -436,7 +436,7 @@ def test_query_with_embeddings(client):
         ],
     }
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     assert len(response.json()["documents"]) == 1
     assert response.json()["documents"][0]["content"] == "test"
     assert response.json()["documents"][0]["content_type"] == "text"
@@ -471,7 +471,7 @@ def test_query_with_dataframe(client):
         ],
     }
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     assert len(response.json()["documents"]) == 1
     assert response.json()["documents"][0]["content"] == [["col1", "col2"], ["text_1", 1], ["text_2", 2]]
     assert response.json()["documents"][0]["content_type"] == "table"
@@ -500,7 +500,7 @@ def test_query_with_prompt_node(client):
         "results": ["test"],
     }
     response = client.post(url="/query", json={"query": TEST_QUERY})
-    assert 200 == response.status_code
+    assert response.status_code == 200
     assert len(response.json()["documents"]) == 1
     assert response.json()["documents"][0]["content"] == "test"
     assert response.json()["documents"][0]["content_type"] == "text"
@@ -513,7 +513,7 @@ def test_query_with_prompt_node(client):

 def test_write_feedback(client, feedback):
     response = client.post(url="/feedback", json=feedback)
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `write_labels` was called on the Document Store instance passing a list
     # containing only one label
     args, _ = MockDocumentStore.mocker.write_labels.call_args
@@ -528,7 +528,7 @@ def test_write_feedback(client, feedback):
 def test_write_feedback_without_id(client, feedback):
     del feedback["id"]
     response = client.post(url="/feedback", json=feedback)
-    assert 200 == response.status_code
+    assert response.status_code == 200
     # Ensure `write_labels` was called on the Document Store instance passing a list
     # containing only one label
     args, _ = MockDocumentStore.mocker.write_labels.call_args
@@ -563,7 +563,7 @@ def get_all_labels(*args, **kwargs):

     # Call the API and ensure `delete_labels` was called only on the label with id=123
     response = client.delete(url="/feedback")
-    assert 200 == response.status_code
+    assert response.status_code == 200
     MockDocumentStore.mocker.delete_labels.assert_called_with(ids=["123"], index=None)
diff --git a/test/agents/test_tools_manager.py b/test/agents/test_tools_manager.py
index 7d89410c73..1697890c2c 100644
--- a/test/agents/test_tools_manager.py
+++ b/test/agents/test_tools_manager.py
@@ -82,9 +82,10 @@ def test_tool_invocation():
         assert tool.run("input") == "mock"

     # now fail if results key is not present
-    with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value={"no_results": "mock"}):
-        with pytest.raises(ValueError, match="Tool ToolA returned result"):
-            assert tool.run("input")
+    with unittest.mock.patch("haystack.pipelines.Pipeline.run", return_value={"no_results": "mock"}), pytest.raises(
+        ValueError, match="Tool ToolA returned result"
+    ):
+        assert tool.run("input")

     # now try tool with a correct output variable
     tool = Tool(name="ToolA", pipeline_or_node=p, description="Tool A Description", output_variable="no_results")
diff --git a/test/document_stores/test_pinecone.py b/test/document_stores/test_pinecone.py
index 94737de931..e41a568651 100644
--- a/test/document_stores/test_pinecone.py
+++ b/test/document_stores/test_pinecone.py
@@ -284,7 +284,7 @@ def test_get_all_documents_extended_filter_in(self, doc_store_with_docs: Pinecon
     @pytest.mark.integration
     def test_get_all_documents_extended_filter_ne(self, doc_store_with_docs: PineconeDocumentStore):
         retrieved_docs = doc_store_with_docs.get_all_documents(filters={"meta_field": {"$ne": "test-1"}})
-        assert all("test-1" != d.meta.get("meta_field", None) for d in retrieved_docs)
+        assert all(d.meta.get("meta_field", None) != "test-1" for d in retrieved_docs)

     @pytest.mark.integration
     def test_get_all_documents_extended_filter_nin(self, doc_store_with_docs: PineconeDocumentStore):
diff --git a/test/document_stores/test_sql_based.py b/test/document_stores/test_sql_based.py
index 8f6a5a9a24..9c433fece8 100644
--- a/test/document_stores/test_sql_based.py
+++ b/test/document_stores/test_sql_based.py
@@ -148,7 +148,7 @@ def test_delete_docs_with_filters(document_store, retriever):
     documents = document_store.get_all_documents()
     assert len(documents) == 3
     assert document_store.get_embedding_count() == 3
-    assert all("2021" == doc.meta["year"] for doc in documents)
+    assert all(doc.meta["year"] == "2021" for doc in documents)


 @pytest.mark.integration
@@ -224,7 +224,7 @@ def test_get_docs_with_filters_one_value(document_store, retriever):

     documents = document_store.get_all_documents(filters={"year": ["2020"]})
     assert len(documents) == 3
-    assert all("2020" == doc.meta["year"] for doc in documents)
+    assert all(doc.meta["year"] == "2020" for doc in documents)


 @pytest.mark.integration
@@ -252,9 +252,9 @@ def test_get_docs_with_many_filters(document_store, retriever):

     documents = document_store.get_all_documents(filters={"month": ["01"], "year": ["2020"]})
     assert len(documents) == 1
-    assert "name_1" == documents[0].meta["name"]
-    assert "01" == documents[0].meta["month"]
-    assert "2020" == documents[0].meta["year"]
+    assert documents[0].meta["name"] == "name_1"
+    assert documents[0].meta["month"] == "01"
+    assert documents[0].meta["year"] == "2020"


 @pytest.mark.integration
diff --git a/test/modeling/test_processor.py b/test/modeling/test_processor.py
index 9a45de953c..db937821e3 100644
--- a/test/modeling/test_processor.py
+++ b/test/modeling/test_processor.py
@@ -6,6 +6,7 @@
 from transformers import AutoTokenizer

 from haystack.modeling.data_handler.processor import SquadProcessor, _is_json
+import contextlib


 # during inference (parameter return_baskets = False) we do not convert labels
@@ -230,10 +231,8 @@ def test_batch_encoding_flatten_rename():
         flatten_rename(None, [], [])

     # keys and renamed_keys have different sizes
-    try:
+    with contextlib.suppress(AssertionError):
         flatten_rename(encoded_inputs, [], ["blah"])
-    except AssertionError:
-        pass


 def test_dataset_from_dicts_qa_label_conversion(samples_path, caplog=None):
diff --git a/test/nodes/test_filetype_classifier.py b/test/nodes/test_filetype_classifier.py
index 378d67f916..55fb34455f 100644
--- a/test/nodes/test_filetype_classifier.py
+++ b/test/nodes/test_filetype_classifier.py
@@ -8,6 +8,7 @@
 import haystack
 from haystack.nodes.file_classifier.file_type import FileTypeClassifier, DEFAULT_TYPES, DEFAULT_MEDIA_TYPES
+import contextlib
 
 
 @pytest.mark.unit
@@ -93,11 +94,8 @@ def test_filetype_classifier_other_files_without_extension(samples_path):
 
 @pytest.mark.unit
 def test_filetype_classifier_text_files_without_extension_no_magic(monkeypatch, caplog, samples_path):
-    try:
+    with contextlib.suppress(AttributeError):
         # only monkeypatch if magic is installed
         monkeypatch.delattr(haystack.nodes.file_classifier.file_type, "magic")
-    except AttributeError:
-        # magic not installed, even better
-        pass
 
     node = FileTypeClassifier(supported_types=[""])
diff --git a/test/nodes/test_link_content_fetcher.py b/test/nodes/test_link_content_fetcher.py
index a01aa04253..54c0220e68 100644
--- a/test/nodes/test_link_content_fetcher.py
+++ b/test/nodes/test_link_content_fetcher.py
@@ -226,9 +226,10 @@ def test_fetch_exception_during_content_extraction_raise_on_failure(caplog, mock
     url = "https://www.example.com"
     r = LinkContentFetcher(raise_on_failure=True)
-    with patch("boilerpy3.extractors.ArticleExtractor.get_content", side_effect=Exception("Could not extract content")):
-        with pytest.raises(Exception, match="Could not extract content"):
-            r.fetch(url=url)
+    with patch(
+        "boilerpy3.extractors.ArticleExtractor.get_content", side_effect=Exception("Could not extract content")
+    ), pytest.raises(Exception, match="Could not extract content"):
+        r.fetch(url=url)
 
 
 @pytest.mark.unit
@@ -254,9 +255,10 @@ def test_fetch_exception_during_request_get_raise_on_failure(caplog):
     url = "https://www.example.com"
     r = LinkContentFetcher(raise_on_failure=True)
-    with patch("haystack.nodes.retriever.link_content.requests.get", side_effect=requests.RequestException()):
-        with pytest.raises(requests.RequestException):
-            r.fetch(url=url)
+    with patch(
+        "haystack.nodes.retriever.link_content.requests.get", side_effect=requests.RequestException()
+    ), pytest.raises(requests.RequestException):
+        r.fetch(url=url)
 
 
 @pytest.mark.unit
diff --git a/test/nodes/test_web_retriever.py b/test/nodes/test_web_retriever.py
index c7a68b5d76..16cbe3cbf7 100644
--- a/test/nodes/test_web_retriever.py
+++ b/test/nodes/test_web_retriever.py
@@ -206,10 +206,10 @@ def test_retrieve_uses_cache(mock_web_search):
         SearchResult("https://www.yahoo.com/", "Some text", 0.43, "2"),
     ]
     cached_docs = [Document("doc1"), Document("doc2")]
-    with patch.object(wr, "_check_cache", return_value=(cached_links, cached_docs)) as mock_check_cache:
-        with patch.object(wr, "_save_to_cache") as mock_save_cache:
-            with patch.object(wr, "_scrape_links", return_value=[]):
-                result = wr.retrieve("query")
+    with patch.object(wr, "_check_cache", return_value=(cached_links, cached_docs)) as mock_check_cache, patch.object(
+        wr, "_save_to_cache"
+    ) as mock_save_cache, patch.object(wr, "_scrape_links", return_value=[]):
+        result = wr.retrieve("query")
 
     # checking cache is always called
     mock_check_cache.assert_called()
@@ -228,9 +228,10 @@ def test_retrieve_saves_to_cache(mock_web_search):
     wr = WebRetriever(api_key="fake_key", cache_document_store=MockDocumentStore(), mode="preprocessed_documents")
     web_docs = [Document("doc1"), Document("doc2"), Document("doc3")]
 
-    with patch.object(wr, "_save_to_cache") as mock_save_cache:
-        with patch.object(wr, "_scrape_links", return_value=web_docs):
-            wr.retrieve("query")
+    with patch.object(wr, "_save_to_cache") as mock_save_cache, patch.object(
+        wr, "_scrape_links", return_value=web_docs
+    ):
+        wr.retrieve("query")
 
     mock_save_cache.assert_called()
 
diff --git a/test/others/test_utils.py b/test/others/test_utils.py
index c3382b8dd1..ec4a64773a 100644
--- a/test/others/test_utils.py
+++ b/test/others/test_utils.py
@@ -54,77 +54,72 @@ def noop():
 
 @pytest.mark.unit
 def test_deprecation_previous_major_and_minor():
-    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
-        with pytest.warns(match="This feature is marked for removal in v1.1"):
-            fail_at_version(1, 1)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
+        match="This feature is marked for removal in v1.1"
+    ):
+        fail_at_version(1, 1)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(1, 1)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(1, 1)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(1, 1)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(1, 1)(noop)()
 
 
 @pytest.mark.unit
 def test_deprecation_previous_major_same_minor():
-    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
-        with pytest.warns(match="This feature is marked for removal in v1.2"):
-            fail_at_version(1, 2)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
+        match="This feature is marked for removal in v1.2"
+    ):
+        fail_at_version(1, 2)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(1, 2)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(1, 2)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(1, 2)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(1, 2)(noop)()
 
 
 @pytest.mark.unit
 def test_deprecation_previous_major_later_minor():
-    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
-        with pytest.warns(match="This feature is marked for removal in v1.3"):
-            fail_at_version(1, 3)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
+        match="This feature is marked for removal in v1.3"
+    ):
+        fail_at_version(1, 3)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(1, 3)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(1, 3)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(1, 3)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(1, 3)(noop)()
 
 
 @pytest.mark.unit
 def test_deprecation_same_major_previous_minor():
-    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
-        with pytest.warns(match="This feature is marked for removal in v2.1"):
-            fail_at_version(2, 1)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
+        match="This feature is marked for removal in v2.1"
+    ):
+        fail_at_version(2, 1)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(2, 1)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(2, 1)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(2, 1)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(2, 1)(noop)()
 
 
 @pytest.mark.unit
 def test_deprecation_same_major_same_minor():
-    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"):
-        with pytest.warns(match="This feature is marked for removal in v2.2"):
-            fail_at_version(2, 2)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2-rc0"), pytest.warns(
+        match="This feature is marked for removal in v2.2"
+    ):
+        fail_at_version(2, 2)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(2, 2)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2rc1"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(2, 2)(noop)()
 
-    with mock.patch.object(conftest, "haystack_version", "2.2.2"):
-        with pytest.raises(_pytest.outcomes.Failed):
-            fail_at_version(2, 2)(noop)()
+    with mock.patch.object(conftest, "haystack_version", "2.2.2"), pytest.raises(_pytest.outcomes.Failed):
+        fail_at_version(2, 2)(noop)()
 
 
 @pytest.mark.unit
diff --git a/test/preview/components/file_converters/test_textfile_to_document.py b/test/preview/components/file_converters/test_textfile_to_document.py
index 959fa528c3..d85a5785ef 100644
--- a/test/preview/components/file_converters/test_textfile_to_document.py
+++ b/test/preview/components/file_converters/test_textfile_to_document.py
@@ -90,13 +90,14 @@ def test_run(self, preview_samples_path):
     def test_run_warning_for_invalid_language(self, preview_samples_path, caplog):
         file_path = preview_samples_path / "txt" / "doc_1.txt"
         converter = TextFileToDocument()
-        with patch("haystack.preview.components.file_converters.txt.langdetect.detect", return_value="en"):
-            with caplog.at_level(logging.WARNING):
-                output = converter.run(paths=[file_path], valid_languages=["de"])
-                assert (
-                    f"Text from file {file_path} is not in one of the valid languages: ['de']. "
-                    f"The file may have been decoded incorrectly." in caplog.text
-                )
+        with patch(
+            "haystack.preview.components.file_converters.txt.langdetect.detect", return_value="en"
+        ), caplog.at_level(logging.WARNING):
+            output = converter.run(paths=[file_path], valid_languages=["de"])
+            assert (
+                f"Text from file {file_path} is not in one of the valid languages: ['de']. "
+                f"The file may have been decoded incorrectly." in caplog.text
+            )
 
         docs = output["documents"]
         assert len(docs) == 1
diff --git a/test/prompt/test_prompt_template.py b/test/prompt/test_prompt_template.py
index 78c458a17a..71180b08aa 100644
--- a/test/prompt/test_prompt_template.py
+++ b/test/prompt/test_prompt_template.py
@@ -118,12 +118,13 @@ def test_prompt_templates_from_file(tmp_path):
 
 @pytest.mark.unit
 def test_prompt_templates_on_the_fly():
-    with patch("haystack.nodes.prompt.prompt_template.yaml") as mocked_yaml:
-        with patch("haystack.nodes.prompt.prompt_template.prompthub") as mocked_ph:
-            p = PromptTemplate("This is a test prompt. Use your knowledge to answer this question: {question}")
-            assert p.name == "custom-at-query-time"
-            mocked_ph.fetch.assert_not_called()
-            mocked_yaml.safe_load.assert_not_called()
+    with patch("haystack.nodes.prompt.prompt_template.yaml") as mocked_yaml, patch(
+        "haystack.nodes.prompt.prompt_template.prompthub"
+    ) as mocked_ph:
+        p = PromptTemplate("This is a test prompt. Use your knowledge to answer this question: {question}")
+        assert p.name == "custom-at-query-time"
+        mocked_ph.fetch.assert_not_called()
+        mocked_yaml.safe_load.assert_not_called()
 
 
 @pytest.mark.unit