refactor: apply pep-484 (deepset-ai#3542)
* apply pep-484

* another implicit optional

* apply pep-484 on rest_api and ui too
ZanSara authored Nov 8, 2022
1 parent 43b24fd commit 9539a20
Showing 47 changed files with 269 additions and 214 deletions.
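
A note on the pattern this diff applies throughout: under PEP 484, a parameter whose default is None should be annotated as Optional[...]; the shorthand of writing only the plain type ("implicit Optional") is deprecated and is rejected by strict type checkers such as recent mypy releases. Every hunk below makes the Optional explicit. A minimal before/after sketch in Python (the function and parameter names here are illustrative, not taken from the Haystack codebase):

from typing import Optional


# Before: "implicit Optional" -- the annotation says str, but the default is None.
# This runs fine, yet strict PEP 484 checkers flag the signature.
def count_documents(index: str = None) -> int:
    return 0


# After: the possibility of None is spelled out in the annotation.
def count_documents_explicit(index: Optional[str] = None) -> int:
    return 0

On Python 3.10+ the same thing can be written as index: str | None = None; the diff sticks to Optional[...], presumably to keep supporting older interpreters.
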
6 changes: 3 additions & 3 deletions haystack/document_stores/base.py
@@ -376,7 +376,7 @@ def add_eval_data(
label_index: str = "label",
batch_size: Optional[int] = None,
preprocessor: Optional[PreProcessor] = None,
- max_docs: Union[int, bool] = None,
+ max_docs: Optional[Union[int, bool]] = None,
open_domain: bool = False,
headers: Optional[Dict[str, str]] = None,
):
@@ -568,7 +568,7 @@ def get_documents_by_id(
pass

@abstractmethod
- def update_document_meta(self, id: str, meta: Dict[str, Any], index: str = None):
+ def update_document_meta(self, id: str, meta: Dict[str, Any], index: Optional[str] = None):
pass

def _drop_duplicate_documents(self, documents: List[Document], index: Optional[str] = None) -> List[Document]:
@@ -633,7 +633,7 @@ def _handle_duplicate_documents(
return documents

def _get_duplicate_labels(
- self, labels: list, index: str = None, headers: Optional[Dict[str, str]] = None
+ self, labels: list, index: Optional[str] = None, headers: Optional[Dict[str, str]] = None
) -> List[Label]:
"""
Return all duplicate labels

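Worth noting: these are typing-only changes with no effect at runtime, since Python does not enforce annotations and typing.get_type_hints reports exactly what is written. A small sketch showing how the difference surfaces through introspection (the signatures below are simplified stand-ins for update_document_meta, not the actual implementations):

from typing import Dict, Optional, get_type_hints


def update_meta_old(id: str, meta: Dict[str, str], index: str = None):
    """Old style: the None default is invisible to introspection and to strict checkers."""


def update_meta_new(id: str, meta: Dict[str, str], index: Optional[str] = None):
    """New style: the annotation itself admits None."""


print(get_type_hints(update_meta_old)["index"])  # <class 'str'>
print(get_type_hints(update_meta_new)["index"])  # typing.Optional[str]

Both versions accept index=None when called; only the second tells a type checker (and a reader) that this is intended.
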
4 changes: 2 additions & 2 deletions haystack/document_stores/deepsetcloud.py
@@ -37,7 +37,7 @@ def wrapper(self, *args, **kwargs):
class DeepsetCloudDocumentStore(KeywordDocumentStore):
def __init__(
self,
- api_key: str = None,
+ api_key: Optional[str] = None,
workspace: str = "default",
index: Optional[str] = None,
duplicate_documents: str = "overwrite",
@@ -603,7 +603,7 @@ def write_documents(
pass

@disable_and_log
- def update_document_meta(self, id: str, meta: Dict[str, Any], index: str = None):
+ def update_document_meta(self, id: str, meta: Dict[str, Any], index: Optional[str] = None):
"""
Update the metadata dictionary of a document by specifying its string id.

8 changes: 4 additions & 4 deletions haystack/document_stores/faiss.py
@@ -42,7 +42,7 @@ class FAISSDocumentStore(SQLDocumentStore):
def __init__(
self,
sql_url: str = "sqlite:///faiss_document_store.db",
- vector_dim: int = None,
+ vector_dim: Optional[int] = None,
embedding_dim: int = 768,
faiss_index_factory_str: str = "Flat",
faiss_index: Optional[faiss.swigfaiss.Index] = None,
@@ -52,9 +52,9 @@ def __init__(
embedding_field: str = "embedding",
progress_bar: bool = True,
duplicate_documents: str = "overwrite",
- faiss_index_path: Union[str, Path] = None,
- faiss_config_path: Union[str, Path] = None,
- isolation_level: str = None,
+ faiss_index_path: Optional[Union[str, Path]] = None,
+ faiss_config_path: Optional[Union[str, Path]] = None,
+ isolation_level: Optional[str] = None,
n_links: int = 64,
ef_search: int = 20,
ef_construction: int = 80,

4 changes: 2 additions & 2 deletions haystack/document_stores/memory.py
@@ -485,7 +485,7 @@ def get_document_count(
)
return len(documents)

- def update_document_meta(self, id: str, meta: Dict[str, Any], index: str = None):
+ def update_document_meta(self, id: str, meta: Dict[str, Any], index: Optional[str] = None):
"""
Update the metadata dictionary of a document by specifying its string id.
@@ -639,7 +639,7 @@ def get_all_documents_generator(

def get_all_labels(
self,
- index: str = None,
+ index: Optional[str] = None,
filters: Optional[Dict[str, Any]] = None, # TODO: Adapt type once we allow extended filters in InMemoryDocStore
headers: Optional[Dict[str, str]] = None,
) -> List[Label]:

4 changes: 2 additions & 2 deletions haystack/document_stores/milvus1.py
@@ -46,7 +46,7 @@ def __init__(
milvus_url: str = "tcp://localhost:19530",
connection_pool: str = "SingletonThread",
index: str = "document",
- vector_dim: int = None,
+ vector_dim: Optional[int] = None,
embedding_dim: int = 768,
index_file_size: int = 1024,
similarity: str = "dot_product",
@@ -57,7 +57,7 @@ def __init__(
embedding_field: str = "embedding",
progress_bar: bool = True,
duplicate_documents: str = "overwrite",
- isolation_level: str = None,
+ isolation_level: Optional[str] = None,
):
"""
**WARNING:** Milvus1DocumentStore is deprecated and will be removed in a future version. Please switch to Milvus2

4 changes: 2 additions & 2 deletions haystack/document_stores/milvus2.py
@@ -61,7 +61,7 @@ def __init__(
port: str = "19530",
connection_pool: str = "SingletonThread",
index: str = "document",
- vector_dim: int = None,
+ vector_dim: Optional[int] = None,
embedding_dim: int = 768,
index_file_size: int = 1024,
similarity: str = "dot_product",
@@ -74,7 +74,7 @@ def __init__(
custom_fields: Optional[List[Any]] = None,
progress_bar: bool = True,
duplicate_documents: str = "overwrite",
- isolation_level: str = None,
+ isolation_level: Optional[str] = None,
consistency_level: int = 0,
recreate_index: bool = False,
):

6 changes: 3 additions & 3 deletions haystack/document_stores/pinecone.py
@@ -762,7 +762,7 @@ def get_documents_by_id(
batch_size: int = 32,
headers: Optional[Dict[str, str]] = None,
return_embedding: Optional[bool] = None,
- namespace: str = None,
+ namespace: Optional[str] = None,
) -> List[Document]:
"""
Retrieves all documents in the index using their IDs.
@@ -826,7 +826,7 @@ def get_document_by_id(
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
return_embedding: Optional[bool] = None,
- namespace: str = None,
+ namespace: Optional[str] = None,
) -> Document:
"""
Returns a single Document retrieved using an ID.
@@ -869,7 +869,7 @@ def get_embedding_count(
count = 0
return count

- def update_document_meta(self, id: str, meta: Dict[str, str], namespace: str = None, index: str = None): # type: ignore
+ def update_document_meta(self, id: str, meta: Dict[str, str], namespace: Optional[str] = None, index: Optional[str] = None): # type: ignore
"""
Update the metadata dictionary of a document by specifying its string ID.

2 changes: 1 addition & 1 deletion haystack/document_stores/search_engine.py
@@ -487,7 +487,7 @@ def write_labels(
self._bulk(labels_to_index, request_timeout=300, refresh=self.refresh_type, headers=headers)

def update_document_meta(
- self, id: str, meta: Dict[str, str], index: str = None, headers: Optional[Dict[str, str]] = None
+ self, id: str, meta: Dict[str, str], index: Optional[str] = None, headers: Optional[Dict[str, str]] = None
):
"""
Update the metadata dictionary of a document by specifying its string id

4 changes: 2 additions & 2 deletions haystack/document_stores/sql.py
@@ -129,7 +129,7 @@ def __init__(
label_index: str = "label",
duplicate_documents: str = "overwrite",
check_same_thread: bool = False,
- isolation_level: str = None,
+ isolation_level: Optional[str] = None,
):
"""
An SQL backed DocumentStore. Currently supports SQLite, PostgreSQL and MySQL backends.
@@ -524,7 +524,7 @@ def reset_vector_ids(self, index: Optional[str] = None):
self.session.query(DocumentORM).filter_by(index=index).update({DocumentORM.vector_id: null()})
self.session.commit()

- def update_document_meta(self, id: str, meta: Dict[str, str], index: str = None):
+ def update_document_meta(self, id: str, meta: Dict[str, str], index: Optional[str] = None):
"""
Update the metadata dictionary of a document by specifying its string id
"""

11 changes: 7 additions & 4 deletions haystack/document_stores/utils.py
@@ -17,7 +17,10 @@


def eval_data_from_json(
- filename: str, max_docs: Union[int, bool] = None, preprocessor: PreProcessor = None, open_domain: bool = False
+ filename: str,
+ max_docs: Optional[Union[int, bool]] = None,
+ preprocessor: Optional[PreProcessor] = None,
+ open_domain: bool = False,
) -> Tuple[List[Document], List[Label]]:
"""
Read Documents + Labels from a SQuAD-style file.
@@ -58,8 +61,8 @@ def eval_data_from_json(
def eval_data_from_jsonl(
filename: str,
batch_size: Optional[int] = None,
- max_docs: Union[int, bool] = None,
- preprocessor: PreProcessor = None,
+ max_docs: Optional[Union[int, bool]] = None,
+ preprocessor: Optional[PreProcessor] = None,
open_domain: bool = False,
) -> Generator[Tuple[List[Document], List[Label]], None, None]:
"""
@@ -123,7 +126,7 @@ def squad_json_to_jsonl(squad_file: str, output_file: str):


def _extract_docs_and_labels_from_dict(
- document_dict: Dict, preprocessor: PreProcessor = None, open_domain: bool = False
+ document_dict: Dict, preprocessor: Optional[PreProcessor] = None, open_domain: bool = False
):
"""
Set open_domain to True if you are trying to load open_domain labels (i.e. labels without doc id or start idx)

8 changes: 5 additions & 3 deletions haystack/document_stores/weaviate.py
@@ -62,8 +62,8 @@ def __init__(
host: Union[str, List[str]] = "http://localhost",
port: Union[int, List[int]] = 8080,
timeout_config: tuple = (5, 15),
- username: str = None,
- password: str = None,
+ username: Optional[str] = None,
+ password: Optional[str] = None,
index: str = "Document",
embedding_dim: int = 768,
content_field: str = "content",
@@ -565,7 +565,9 @@ def write_documents(
progress_bar.update(batch_size)
progress_bar.close()

- def update_document_meta(self, id: str, meta: Dict[str, Union[List, str, int, float, bool]], index: str = None):
+ def update_document_meta(
+ self, id: str, meta: Dict[str, Union[List, str, int, float, bool]], index: Optional[str] = None
+ ):
"""
Update the metadata dictionary of a document by specifying its string id.
Overwrites only the specified fields, the unspecified ones remain unchanged.

6 changes: 3 additions & 3 deletions haystack/modeling/data_handler/dataloader.py
@@ -1,4 +1,4 @@
- from typing import List
+ from typing import Optional, List

from math import ceil

@@ -13,8 +13,8 @@ def __init__(
self,
dataset: Dataset,
batch_size: int,
- sampler: Sampler = None,
- tensor_names: List[str] = None,
+ sampler: Optional[Sampler] = None,
+ tensor_names: Optional[List[str]] = None,
num_workers: int = 0,
pin_memory: bool = False,
):

6 changes: 4 additions & 2 deletions haystack/modeling/data_handler/dataset.py
@@ -1,6 +1,6 @@
import logging
import numbers
- from typing import List
+ from typing import Optional, List

import numpy as np
import torch
@@ -12,7 +12,9 @@
logger = logging.getLogger(__name__)


- def flatten_rename(encoded_batch: BatchEncoding, keys: List[str] = None, renamed_keys: List[str] = None):
+ def flatten_rename(
+ encoded_batch: BatchEncoding, keys: Optional[List[str]] = None, renamed_keys: Optional[List[str]] = None
+ ):
if encoded_batch is None:
return []
if not keys:

4 changes: 2 additions & 2 deletions haystack/modeling/data_handler/inputs.py
@@ -1,8 +1,8 @@
- from typing import List, Union
+ from typing import Optional, List, Union


class Question:
- def __init__(self, text: str, uid: str = None):
+ def __init__(self, text: str, uid: Optional[str] = None):
self.text = text
self.uid = uid

2 changes: 1 addition & 1 deletion haystack/modeling/data_handler/processor.py
@@ -2122,7 +2122,7 @@ def write_squad_predictions(predictions, out_filename, predictions_filename=None
def _read_dpr_json(
file: str,
max_samples: Optional[int] = None,
- proxies: Any = None,
+ proxies: Optional[Any] = None,
num_hard_negatives: int = 1,
num_positives: int = 1,
shuffle_negatives: bool = True,

2 changes: 1 addition & 1 deletion haystack/modeling/data_handler/samples.py
@@ -81,7 +81,7 @@ def __init__(
self,
id_internal: Optional[Union[int, str]],
raw: dict,
- id_external: str = None,
+ id_external: Optional[str] = None,
samples: Optional[List[Sample]] = None,
):
"""

4 changes: 2 additions & 2 deletions haystack/modeling/infer.py
@@ -126,7 +126,7 @@ def load(
disable_tqdm: bool = False,
tokenizer_class: Optional[str] = None,
use_fast: bool = True,
- tokenizer_args: Dict = None,
+ tokenizer_args: Optional[Dict] = None,
multithreading_rust: bool = True,
use_auth_token: Optional[Union[bool, str]] = None,
devices: Optional[List[Union[str, torch.device]]] = None,
@@ -259,7 +259,7 @@ def save(self, path: str):
self.model.save(path)
self.processor.save(path)

- def inference_from_file(self, file: str, multiprocessing_chunksize: int = None, return_json: bool = True):
+ def inference_from_file(self, file: str, multiprocessing_chunksize: Optional[int] = None, return_json: bool = True):
"""
Run down-stream inference on samples created from an input file.
The file should be in the same format as the ones used during training

2 changes: 1 addition & 1 deletion haystack/modeling/model/adaptive_model.py
@@ -308,7 +308,7 @@ def convert_from_transformers(
cls,
model_name_or_path,
device: Union[str, torch.device],
- revision: str = None,
+ revision: Optional[str] = None,
task_type: str = "question_answering",
processor: Optional[Processor] = None,
use_auth_token: Optional[Union[bool, str]] = None,

2 changes: 1 addition & 1 deletion haystack/modeling/model/feature_extraction.py
@@ -63,7 +63,7 @@ class FeatureExtractor:
def __init__(
self,
pretrained_model_name_or_path: Union[str, Path],
- revision: str = None,
+ revision: Optional[str] = None,
use_fast: bool = True,
use_auth_token: Optional[Union[str, bool]] = None,
**kwargs,

12 changes: 6 additions & 6 deletions haystack/modeling/model/language_model.py
@@ -131,7 +131,7 @@ def save_config(self, save_dir: Union[Path, str]):
with open(save_filename, "w") as file:
file.write(string)

- def save(self, save_dir: Union[str, Path], state_dict: Dict[Any, Any] = None):
+ def save(self, save_dir: Union[str, Path], state_dict: Optional[Dict[Any, Any]] = None):
"""
Save the model `state_dict` and its configuration file so that it can be loaded again.
@@ -148,7 +148,7 @@ def save(self, save_dir: Union[str, Path], state_dict: Dict[Any, Any] = None):
self.save_config(save_dir)

def formatted_preds(
- self, logits, samples, ignore_first_token: bool = True, padding_mask: torch.Tensor = None
+ self, logits, samples, ignore_first_token: bool = True, padding_mask: Optional[torch.Tensor] = None
) -> List[Dict[str, Any]]:
"""
Extracting vectors from a language model (for example, for extracting sentence embeddings).
@@ -243,7 +243,7 @@ def __init__(
self,
pretrained_model_name_or_path: Union[Path, str],
model_type: str,
- language: str = None,
+ language: Optional[str] = None,
n_added_tokens: int = 0,
use_auth_token: Optional[Union[str, bool]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
@@ -358,7 +358,7 @@ def __init__(
self,
pretrained_model_name_or_path: Union[Path, str],
model_type: str,
- language: str = None,
+ language: Optional[str] = None,
n_added_tokens: int = 0,
use_auth_token: Optional[Union[str, bool]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
@@ -486,7 +486,7 @@ def __init__(
self,
pretrained_model_name_or_path: Union[Path, str],
model_type: str,
- language: str = None,
+ language: Optional[str] = None,
n_added_tokens: int = 0,
use_auth_token: Optional[Union[str, bool]] = None,
model_kwargs: Optional[Dict[str, Any]] = None,
@@ -822,7 +822,7 @@ def get_language_model_class(model_type: str) -> Optional[Type[Union[HFLanguageM

def get_language_model(
pretrained_model_name_or_path: Union[Path, str],
- language: str = None,
+ language: Optional[str] = None,
n_added_tokens: int = 0,
use_auth_token: Optional[Union[str, bool]] = None,
revision: Optional[str] = None,

6 changes: 3 additions & 3 deletions haystack/modeling/model/optimization.py
@@ -74,12 +74,12 @@ def initialize_optimizer(
n_epochs: int,
device: torch.device,
learning_rate: float,
- optimizer_opts: Dict[Any, Any] = None,
- schedule_opts: Dict[Any, Any] = None,
+ optimizer_opts: Optional[Dict[Any, Any]] = None,
+ schedule_opts: Optional[Dict[Any, Any]] = None,
distributed: bool = False,
grad_acc_steps: int = 1,
local_rank: int = -1,
- use_amp: str = None,
+ use_amp: Optional[str] = None,
):
"""
Initializes an optimizer, a learning rate scheduler and converts the model if needed (e.g for mixed precision).

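A natural follow-up to a sweep like this is a CI guard so implicit Optionals do not creep back in. The sketch below is an assumption, not this repository's actual configuration: it runs mypy programmatically with the --no-implicit-optional check (enabled by default in recent mypy releases) over the three packages named in the commit message.

# Hypothetical CI guard: fail the build if mypy reports implicit-Optional
# annotations (or any other type errors) in the packages touched by this commit.
# The package paths and flag selection are assumptions, not taken from the repo.
from mypy import api

stdout, stderr, exit_status = api.run(["--no-implicit-optional", "haystack", "rest_api", "ui"])
print(stdout, stderr, sep="\n")
if exit_status != 0:
    raise SystemExit("mypy found type errors; see the report above")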