diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index a2019b753c..d6364b5b49 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -242,7 +242,7 @@ def load_and_embed( src: Any, metadata: Optional[Dict[str, Any]] = None, source_id: Optional[str] = None, - dry_run = False + dry_run=False, ) -> Tuple[List[str], Dict[str, Any], List[str], int]: """The loader to use to load the data. @@ -320,14 +320,14 @@ def load_and_embed( return list(documents), metadatas, ids, count_new_chunks def load_and_embed_v2( - self, - loader: BaseLoader, - chunker: BaseChunker, - src: Any, - metadata: Optional[Dict[str, Any]] = None, - source_id: Optional[str] = None, - dry_run = False - ): + self, + loader: BaseLoader, + chunker: BaseChunker, + src: Any, + metadata: Optional[Dict[str, Any]] = None, + source_id: Optional[str] = None, + dry_run=False, + ): """ Loads the data from the given URL, chunks it, and adds it to database. @@ -364,9 +364,7 @@ def load_and_embed_v2( # this means that doc content has changed. if existing_doc_id and existing_doc_id != new_doc_id: print("Doc content has changed. Recomputing chunks and embeddings intelligently.") - self.db.delete({ - "doc_id": existing_doc_id - }) + self.db.delete({"doc_id": existing_doc_id}) # get existing ids, and discard doc if any common id exist. where = {"app_id": self.config.id} if self.config.id is not None else {} diff --git a/embedchain/loaders/csv.py b/embedchain/loaders/csv.py index 9de84de03b..1730ae9c2c 100644 --- a/embedchain/loaders/csv.py +++ b/embedchain/loaders/csv.py @@ -46,7 +46,4 @@ def load_data(content): lines.append(line) result.append({"content": line, "meta_data": {"url": content, "row": i + 1}}) doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest() - return { - "doc_id": doc_id, - "data": result - } + return {"doc_id": doc_id, "data": result} diff --git a/embedchain/loaders/local_qna_pair.py b/embedchain/loaders/local_qna_pair.py index 36da278e1e..61d9a57668 100644 --- a/embedchain/loaders/local_qna_pair.py +++ b/embedchain/loaders/local_qna_pair.py @@ -22,5 +22,5 @@ def load_data(self, content): "content": content, "meta_data": meta_data, } - ] + ], } diff --git a/embedchain/loaders/local_text.py b/embedchain/loaders/local_text.py index 7a5195787a..118cbd3afd 100644 --- a/embedchain/loaders/local_text.py +++ b/embedchain/loaders/local_text.py @@ -20,5 +20,5 @@ def load_data(self, content): "content": content, "meta_data": meta_data, } - ] + ], } diff --git a/embedchain/loaders/notion.py b/embedchain/loaders/notion.py index 065673e564..7ff84ed583 100644 --- a/embedchain/loaders/notion.py +++ b/embedchain/loaders/notion.py @@ -39,9 +39,9 @@ def load_data(self, source): return { "doc_id": doc_id, "data": [ - { - "content": text, - "meta_data": {"url": f"notion-{formatted_id}"}, - } - ], + { + "content": text, + "meta_data": {"url": f"notion-{formatted_id}"}, + } + ], } diff --git a/embedchain/loaders/sitemap.py b/embedchain/loaders/sitemap.py index c78542c41e..d85e8829d1 100644 --- a/embedchain/loaders/sitemap.py +++ b/embedchain/loaders/sitemap.py @@ -43,7 +43,4 @@ def load_data(self, sitemap_url): logging.warning(f"Page is not readable (too many invalid characters): {link}") except ParserRejectedMarkup as e: logging.error(f"Failed to parse {link}: {e}") - return { - "doc_id": doc_id, - "data": [data[0] for data in output] - } + return {"doc_id": doc_id, "data": [data[0] for data in output]} diff --git a/embedchain/loaders/web_page.py b/embedchain/loaders/web_page.py index 9e62df3882..53d41df013 100644 --- a/embedchain/loaders/web_page.py +++ b/embedchain/loaders/web_page.py @@ -66,7 +66,7 @@ def load_data(self, url): } content = content doc_id = hashlib.sha256((content + url).encode()).hexdigest() - return { + return { "doc_id": doc_id, "data": [ { diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py index e7ef5c8ea4..aee18f1c00 100644 --- a/embedchain/vectordb/base_vector_db.py +++ b/embedchain/vectordb/base_vector_db.py @@ -47,4 +47,4 @@ def reset(self): raise NotImplementedError def set_collection_name(self, name: str): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/embedchain/vectordb/chroma.py b/embedchain/vectordb/chroma.py index 2717b378f5..3086c23fae 100644 --- a/embedchain/vectordb/chroma.py +++ b/embedchain/vectordb/chroma.py @@ -1,5 +1,5 @@ import logging -from typing import Dict, List, Optional, Any +from typing import Any, Dict, List, Optional from chromadb import Collection, QueryResult from langchain.docstore.document import Document @@ -105,9 +105,7 @@ def get(self, ids=None, where=None, limit=None): args["where"] = where if limit: args["limit"] = limit - return self.collection.get( - **args - ) + return self.collection.get(**args) def get_advanced(self, where): return self.collection.get(where=where, limit=1) diff --git a/tests/chunkers/test_text.py b/tests/chunkers/test_text.py index 6ea5662021..e5bc32ab7f 100644 --- a/tests/chunkers/test_text.py +++ b/tests/chunkers/test_text.py @@ -76,5 +76,5 @@ def load_data(self, src): "content": src, "meta_data": {"url": "none"}, } - ] + ], } diff --git a/tests/embedchain/test_add.py b/tests/embedchain/test_add.py index 446ab496c4..f8c2056406 100644 --- a/tests/embedchain/test_add.py +++ b/tests/embedchain/test_add.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch from embedchain import App -from embedchain.config import AppConfig, AddConfig, ChunkerConfig +from embedchain.config import AddConfig, AppConfig, ChunkerConfig from embedchain.models.data_type import DataType