From 01683012b96b50d62e5113e0e64e5324849510fa Mon Sep 17 00:00:00 2001 From: Prashant Date: Fri, 4 Aug 2023 19:05:27 +0530 Subject: [PATCH 1/4] feat: added support for elasticsearch as a datasource --- docs/advanced/datasource.mdx | 41 +++++ docs/mint.json | 2 +- embedchain/config/apps/AppConfig.py | 4 +- embedchain/config/apps/BaseAppConfig.py | 43 ++++- embedchain/config/apps/CustomAppConfig.py | 3 +- embedchain/config/apps/OpenSourceAppConfig.py | 4 +- embedchain/embedchain.py | 45 ++--- embedchain/utils.py | 6 +- embedchain/vectordb/elasticsearch_db.py | 53 ++++++ embedchain/vectordb/vector_db.py | 165 ++++++++++++++++++ pyproject.toml | 1 + setup.py | 1 + tests/vectordb/test_chroma_db.py | 9 +- tests/vectordb/test_elasticsearch_db.py | 28 +++ 14 files changed, 353 insertions(+), 52 deletions(-) create mode 100644 docs/advanced/datasource.mdx create mode 100644 embedchain/vectordb/elasticsearch_db.py create mode 100644 embedchain/vectordb/vector_db.py create mode 100644 tests/vectordb/test_elasticsearch_db.py diff --git a/docs/advanced/datasource.mdx b/docs/advanced/datasource.mdx new file mode 100644 index 0000000000..374d08484f --- /dev/null +++ b/docs/advanced/datasource.mdx @@ -0,0 +1,41 @@ +--- +title: '💾 Datasource' +--- + +## Vector Database + +We support `Chromadb` and `Elasticsearch` as two type of vector database. +`Chromadb` is used as default. + +### App +```python +import os + +from embedchain import App +from embedchain.config import AppConfig + +os.environ["ES_ENDPOINT"] = "elasticsearch_endpoint" +# Adds HTTP header 'Authorization: ApiKey ' +os.environ["ES_API_KEY_ID"] = "api_key_id" # Optional +os.environ["ES_API_KEY"] = "api_key" # Optional + + +es_app_config = AppConfig(db_type='es') +es_app = App(es_app_config) +``` +_To use elasticsearch as vector db we need an external running instance and connection config_ + +- `Elasticsearch` as vector database can be used by setting `db_type='es'` in `AppConfig`. 
+- `ES_ENDPOINT` is mandatory to connect to `Elasticsearch`. +- `ES_API_KEY_ID` and `ES_API_KEY` can be configured for authentication and connecting to `Elasticsearch`. +- An index with name `embedchain_store` is created if not present. + +### OpenSourceApp +Similarly for Open source app set `db_type='es'` +```python +from embedchain import OpenSourceApp +from embedchain.config import OpenSourceAppConfig + +opensource_es_app_config = OpenSourceAppConfig(db_type='es') +opensource_es_app = OpenSourceApp(opensource_es_app_config) +``` \ No newline at end of file diff --git a/docs/mint.json b/docs/mint.json index cf73769f27..89164b7326 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -32,7 +32,7 @@ }, { "group": "Advanced", - "pages": ["advanced/app_types", "advanced/interface_types", "advanced/adding_data","advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/showcase"] + "pages": ["advanced/app_types", "advanced/interface_types", "advanced/adding_data","advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/datasource", "advanced/showcase"] }, { "group": "Examples", diff --git a/embedchain/config/apps/AppConfig.py b/embedchain/config/apps/AppConfig.py index 1e08040e0d..df61ad0ef4 100644 --- a/embedchain/config/apps/AppConfig.py +++ b/embedchain/config/apps/AppConfig.py @@ -16,7 +16,7 @@ class AppConfig(BaseAppConfig): Config to initialize an embedchain custom `App` instance, with extra config options. """ - def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None): + def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, db_type=None): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. @@ -24,6 +24,7 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam :param port: Optional. Port for the database server. 
:param id: Optional. ID of the app. Document metadata will have this id. :param collection_name: Optional. Collection name for the database. + :param db_type: Optional. db type to use. Currently [chroma, es] are supported. """ super().__init__( log_level=log_level, @@ -32,6 +33,7 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam port=port, id=id, collection_name=collection_name, + db_type=db_type, ) @staticmethod diff --git a/embedchain/config/apps/BaseAppConfig.py b/embedchain/config/apps/BaseAppConfig.py index c1f0daa669..9c8f167a0d 100644 --- a/embedchain/config/apps/BaseAppConfig.py +++ b/embedchain/config/apps/BaseAppConfig.py @@ -1,6 +1,7 @@ import logging from embedchain.config.BaseConfig import BaseConfig +from embedchain.vectordb.vector_db import VectorDb class BaseAppConfig(BaseConfig): @@ -8,7 +9,17 @@ class BaseAppConfig(BaseConfig): Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`. """ - def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None, collection_name=None): + def __init__( + self, + log_level=None, + embedding_fn=None, + db=None, + host=None, + port=None, + id=None, + collection_name=None, + db_type=None, + ): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. @@ -18,30 +29,46 @@ def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=N :param port: Optional. Port for the database server. :param id: Optional. ID of the app. Document metadata will have this id. :param collection_name: Optional. Collection name for the database. + :param db_type: Optional. db type to use. Currently [chroma, es] are supported. 
""" self._setup_logging(log_level) - - self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port) + self.db = BaseAppConfig.get_db( + db=db, + embedding_fn=embedding_fn, + host=host, + port=port, + db_type=db_type, + ) self.collection_name = collection_name if collection_name else "embedchain_store" self.id = id return @staticmethod - def default_db(embedding_fn, host, port): + def get_db(db, embedding_fn, host, port, db_type): """ - Sets database to default (`ChromaDb`). - + Get db based on db_type, db with default database (`ChromaDb`) + :param Optional. (Vector) database to use for embeddings. :param embedding_fn: Embedding function to use in database. :param host: Optional. Hostname for the database server. :param port: Optional. Port for the database server. - :returns: Default database + :param db_type: Optional. db type to use. Supported values (`es`, `chroma`) + :returns: database instance :raises ValueError: BaseAppConfig knows no default embedding function. 
""" + if db: + return VectorDb(db, db_type) + if embedding_fn is None: raise ValueError("ChromaDb cannot be instantiated without an embedding function") + + if db_type == "es": + from embedchain.vectordb.elasticsearch_db import EsDB + + return VectorDb(EsDB(embedding_fn=embedding_fn), db_type) + from embedchain.vectordb.chroma_db import ChromaDB - return ChromaDB(embedding_fn=embedding_fn, host=host, port=port) + return VectorDb(ChromaDB(embedding_fn=embedding_fn, host=host, port=port)) def _setup_logging(self, debug_level): level = logging.WARNING # Default level diff --git a/embedchain/config/apps/CustomAppConfig.py b/embedchain/config/apps/CustomAppConfig.py index edacf4b5bc..cc42f85a56 100644 --- a/embedchain/config/apps/CustomAppConfig.py +++ b/embedchain/config/apps/CustomAppConfig.py @@ -52,8 +52,7 @@ def __init__( super().__init__( log_level=log_level, embedding_fn=CustomAppConfig.embedding_function( - embedding_function=embedding_fn, model=embedding_fn_model, - deployment_name=deployment_name + embedding_function=embedding_fn, model=embedding_fn_model, deployment_name=deployment_name ), db=db, host=host, diff --git a/embedchain/config/apps/OpenSourceAppConfig.py b/embedchain/config/apps/OpenSourceAppConfig.py index 8666f1257d..c06247d6eb 100644 --- a/embedchain/config/apps/OpenSourceAppConfig.py +++ b/embedchain/config/apps/OpenSourceAppConfig.py @@ -8,7 +8,7 @@ class OpenSourceAppConfig(BaseAppConfig): Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options. """ - def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, model=None): + def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, model=None, db_type=None): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. @@ -17,6 +17,7 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam :param host: Optional. 
Hostname for the database server. :param port: Optional. Port for the database server. :param model: Optional. GPT4ALL uses the model to instantiate the class. + :param db_type: Optional. db type to use. Currently [chroma, es] are supported. So unlike `App`, it has to be provided before querying. """ self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin" @@ -28,6 +29,7 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam port=port, id=id, collection_name=collection_name, + db_type=db_type, ) @staticmethod diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 8033eff0af..1967c4e8ec 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -3,7 +3,6 @@ from chromadb.errors import InvalidDimensionException from dotenv import load_dotenv -from langchain.docstore.document import Document from langchain.memory import ConversationBufferMemory from embedchain.chunkers.base_chunker import BaseChunker @@ -31,8 +30,9 @@ def __init__(self, config: BaseAppConfig): """ self.config = config - self.db_client = self.config.db.client - self.collection = self.config.db._get_or_create_collection(self.config.collection_name) + # self.db_client = self.config.db.client + # self.collection = self.config.db._get_or_create_collection(self.config.collection_name) + self.db = self.config.db self.user_asks = [] self.is_docs_site_instance = False self.online = False @@ -97,13 +97,10 @@ def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata metadatas = embeddings_data["metadatas"] ids = embeddings_data["ids"] # get existing ids, and discard doc if any common id exist. 
- where = {"app_id": self.config.id} if self.config.id is not None else {} - # where={"url": src} - existing_docs = self.collection.get( - ids=ids, - where=where, # optional filter + existing_ids = self.db.get( + ids, + self.config.id, ) - existing_ids = set(existing_docs["ids"]) if len(existing_ids): data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)} @@ -128,19 +125,9 @@ def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata # Add metadata to each document metadatas_with_metadata = [{**meta, **metadata} for meta in metadatas] - self.collection.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids) + self.db.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids) print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}")) - def _format_result(self, results): - return [ - (Document(page_content=result[0], metadata=result[1] or {}), result[2]) - for result in zip( - results["documents"][0], - results["metadatas"][0], - results["distances"][0], - ) - ] - def get_llm_model_answer(self): """ Usually implemented by child class @@ -157,22 +144,16 @@ def retrieve_from_database(self, input_query, config: QueryConfig): :return: The content of the document that matched your query. """ try: - where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter - result = self.collection.query( - query_texts=[ - input_query, - ], - n_results=config.number_documents, - where=where, + contents = self.db.query( + input_query, + config.number_documents, + self.config.id, ) except InvalidDimensionException as e: raise InvalidDimensionException( e.message() + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." 
# noqa E501 ) from None - - results_formatted = self._format_result(result) - contents = [result[0].page_content for result in results_formatted] return contents def _append_search_and_context(self, context, web_search_result): @@ -339,11 +320,11 @@ def count(self): :return: The number of embeddings. """ - return self.collection.count() + return self.db.count() def reset(self): """ Resets the database. Deletes all embeddings irreversibly. `App` has to be reinitialized after using this method. """ - self.db_client.reset() + self.db.reset() diff --git a/embedchain/utils.py b/embedchain/utils.py index 5b7cfa60bd..2afae19123 100644 --- a/embedchain/utils.py +++ b/embedchain/utils.py @@ -65,7 +65,7 @@ def use_pysqlite3(): import datetime import subprocess import sys - + subprocess.check_call( [sys.executable, "-m", "pip", "install", "pysqlite3-binary", "--quiet", "--disable-pip-version-check"] ) @@ -86,6 +86,6 @@ def use_pysqlite3(): print( f"{current_time} [embedchain] [ERROR]", "Failed to swap std-lib sqlite3 with pysqlite3 for ChromaDb compatibility.", - f"Error:", - e + "Error:", + e, ) diff --git a/embedchain/vectordb/elasticsearch_db.py b/embedchain/vectordb/elasticsearch_db.py new file mode 100644 index 0000000000..903604a33b --- /dev/null +++ b/embedchain/vectordb/elasticsearch_db.py @@ -0,0 +1,53 @@ +import os +from typing import Callable, Optional + +from elasticsearch import Elasticsearch +from elasticsearch.helpers import bulk + +from embedchain.vectordb.base_vector_db import BaseVectorDB + + +class EsDB(BaseVectorDB): + """ + Elasticsearch as vector database + :param embedding_fn: Function to generate embedding vectors. + :param config: Optional. 
elastic search client + """ + + def __init__( + self, embedding_fn: Callable[[list[str]], list[str]] = None, es_client: Optional[Elasticsearch] = None + ): + if not hasattr(embedding_fn, "__call__"): + raise ValueError("Embedding function is not a function") + self.embedding_fn = embedding_fn + endpoint = os.getenv("ES_ENDPOINT") + api_key_id = os.getenv("ES_API_KEY_ID") + api_key = os.getenv("ES_API_KEY") + api_key_id = api_key_id if api_key_id is not None else "" + api_key = api_key if api_key is not None else "" + if not endpoint and not es_client: + raise ValueError("Elasticsearch endpoint is required to connect") + self.client = es_client if es_client is not None else Elasticsearch(endpoint, api_key=(api_key_id, api_key)) + self.es_index = "embedchain_store" + # Check if its configurable, currently setting it to max, also check for performance issues + self.vector_dim = 2048 + self.bulk = bulk + index_settings = { + "mappings": { + "properties": { + "text": {"type": "text"}, + "text_vector": {"type": "dense_vector", "index": False, "dims": self.vector_dim}, + } + } + } + if not self.client.indices.exists(index=self.es_index): + # create index if not exist + print("Creating index", self.es_index, index_settings) + self.client.indices.create(index=self.es_index, body=index_settings) + super().__init__() + + def _get_or_create_db(self): + return self.client + + def _get_or_create_collection(self): + """Note: nothing to return here. 
Discuss later""" diff --git a/embedchain/vectordb/vector_db.py b/embedchain/vectordb/vector_db.py new file mode 100644 index 0000000000..bdcf42b157 --- /dev/null +++ b/embedchain/vectordb/vector_db.py @@ -0,0 +1,165 @@ +from typing import Any, List, Optional, Union + +import numpy as np +from langchain.docstore.document import Document + +from embedchain.vectordb.chroma_db import ChromaDB +from embedchain.vectordb.elasticsearch_db import EsDB + + +class VectorDb: + """ + Database abstraction class, abstracting common functionality + :param db: (Vector) database instance to use for embeddings. Can be es/chroma + :param db_type: which type of database is used. [es, chroma] + """ + + def __init__(self, db: Union[ChromaDB, EsDB], db_type: Optional[str] = None): + self.db = db + self.db_type = db_type + + """ + Get existing doc ids present in vector database + :param ids: list of doc ids to check for existance + :param app_id: Optional application to filter data + """ + + def get(self, ids: List[str], app_id: Optional[str]) -> List[str]: + if self.db_type == "es": + query = {"bool": {"must": [{"ids": {"values": ids}}]}} + if app_id: + query["bool"]["must"].append({"term": {"metadata.app_id": app_id}}) + response = self.db.client.search(index=self.db.es_index, query=query, _source=False) + docs = response["hits"]["hits"] + ids = [doc["_id"] for doc in docs] + return set(ids) + + where = {"app_id": app_id} if app_id is not None else {} + existing_docs = self.db.collection.get( + ids=ids, + where=where, # optional filter + ) + + return set(existing_docs["ids"]) + + """ + add data in vector database + :param documents: list of texts to add + :param metadatas: list of metadata associated with docs + :param ids: ids of docs + """ + + def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any: + if self.db_type == "es": + docs = [] + embeddings = self.db.embedding_fn(documents) + for id, text, metadata, text_vector in zip(ids, documents, metadatas, 
embeddings): + # need to do this to create fixed dimension vector, padding with zeros + # NOTE: look for better solutions + vector = np.zeros(self.db.vector_dim) + vector[: len(text_vector)] = text_vector + docs.append( + { + "_index": self.db.es_index, + "_id": id, + "_source": {"text": text, "metadata": metadata, "text_vector": vector.tolist()}, + } + ) + self.db.bulk(self.db.client, docs) + self.db.client.indices.refresh(index=self.db.es_index) + return + + self.db.collection.add(documents=documents, metadatas=metadatas, ids=ids) + + def _format_result(self, results): + # discuss why there was a need to create lagchain Document + return [ + (Document(page_content=result[0], metadata=result[1] or {}), result[2]) + for result in zip( + results["documents"][0], + results["metadatas"][0], + results["distances"][0], + ) + ] + + """ + query contents from vector data base based on vector similarity + :param input_query: list of query string + :param number_documents: no of similar documents to fetch from database + :param app_id: Optional app id for filtering data + """ + + def query( + self, input_query: List[str], number_documents: int, app_id: Optional[Union[int, str]] = None + ) -> List[str]: + if self.db_type == "es": + """ + Currently have taken max 2048 as vector dim, there is a need to re check the + accuracy of cosineSimilarity used to retrive similar documents + Not using Approximate kNN because cannot index dense vector due to dims > 1024 + https://www.elastic.co/guide/en/elasticsearch/reference/master/knn-search.html + Using Exact KNN + https://www.elastic.co/guide/en/elasticsearch/reference/master/knn-search.html#exact-knn + """ + input_query_vector = self.db.embedding_fn(input_query) + query_vector = np.zeros(self.db.vector_dim) + query_vector[: len(input_query_vector[0])] = input_query_vector[0] + query = { + "script_score": { + "query": {"bool": {"must": [{"exists": {"field": "text"}}]}}, + "script": { + "source": 
"cosineSimilarity(params.input_query_vector, 'text_vector') + 1.0", + "params": {"input_query_vector": query_vector.tolist()}, + }, + } + } + if app_id: + query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}] + _source = ["text"] + size = number_documents + response = self.db.client.search(index=self.db.es_index, query=query, _source=_source, size=size) + docs = response["hits"]["hits"] + contents = [doc["_source"]["text"] for doc in docs] + return contents + + where = {"app_id": app_id} if app_id is not None else {} # optional filter + result = self.db.collection.query( + query_texts=[ + input_query, + ], + n_results=number_documents, + where=where, + ) + + results_formatted = self._format_result(result) + contents = [result[0].page_content for result in results_formatted] + return contents + + """ + get count of docs in the database + :param app_id: Optional app id to filter data + """ + + def count(self, app_id: Optional[Union[int, str]] = None) -> int: + if self.db_type == "es": + query = {"match_all": {}} + if app_id: + query = {"bool": {"must": [{"term": {"metadata.app_id": app_id}}]}} + response = self.db.client.count(index=self.db.es_index, query=query) + doc_count = response["count"] + return doc_count + + return self.db.collection.count() + + # Delete all data from the database + def reset(self): + if self.db_type == "es" and self.db.client.indices.exists(index=self.db.es_index): + # delete index in Es + self.db.client.indices.delete(index=self.db.es_index) + return + + self.db.collection.delete() + + # get Vector Db instance + def get_db(self): + return self.db diff --git a/pyproject.toml b/pyproject.toml index 8521180437..9095b6c07a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ beautifulsoup4 = "^4.12.2" pypdf = "^3.11.0" pytube = "^15.0.0" llama-index = { version = "^0.7.21", optional = true } +elasticsearch = "^8.9.0" diff --git a/setup.py b/setup.py index 137888b628..21ae9ea46c 100644 --- 
a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ "pydantic==1.10.8", "replicate==0.9.0", "duckduckgo-search==3.8.4", + "elasticsearch>=8.0.0", ], extras_require={"dev": ["black", "ruff", "isort", "pytest"], "community": ["llama-index==0.7.21"]}, ) diff --git a/tests/vectordb/test_chroma_db.py b/tests/vectordb/test_chroma_db.py index 9ddff085e1..37a62d5789 100644 --- a/tests/vectordb/test_chroma_db.py +++ b/tests/vectordb/test_chroma_db.py @@ -73,6 +73,7 @@ def test_init_with_host_and_port(self, mock_client): self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None) self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None) + class TestChromaDbDuplicateHandling: def test_duplicates_throw_warning(self, caplog): """ @@ -101,8 +102,8 @@ def test_duplicates_collections_no_warning(self, caplog): app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) app.set_collection("test_collection_2") app.collection.add(embeddings=[[0, 0, 0]], ids=["0"]) - assert "Insert of existing embedding ID: 0" not in caplog.text # not - assert "Add of existing embedding ID: 0" not in caplog.text # not + assert "Insert of existing embedding ID: 0" not in caplog.text # not + assert "Add of existing embedding ID: 0" not in caplog.text # not class TestChromaDbCollection(unittest.TestCase): @@ -197,9 +198,9 @@ def test_parallel_collections(self): app2.collection.add(embeddings=[0, 0, 0], ids=["0"]) # Swap names and test - app1.set_collection('test_collection_2') + app1.set_collection("test_collection_2") self.assertEqual(app1.count(), 1) - app2.set_collection('test_collection_1') + app2.set_collection("test_collection_1") self.assertEqual(app2.count(), 3) def test_ids_share_collections(self): diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py new file mode 100644 index 0000000000..bb4964905a --- /dev/null +++ b/tests/vectordb/test_elasticsearch_db.py @@ -0,0 +1,28 @@ +import unittest +from unittest.mock import MagicMock, Mock, patch 
+ +from embedchain.vectordb.elasticsearch_db import EsDB + + +class TestEsDB(unittest.TestCase): + def setUp(self): + # set mock es client + self.mock_client = MagicMock() + self.mock_client.indices.exists.return_value = True + + def test_init_with_invalid_embedding_fn(self): + # Test if an exception is raised when an invalid embedding_fn is provided + with self.assertRaises(ValueError): + EsDB(embedding_fn=None) + + def test_init_with_valid_embedding_and_client(self): + # check for successful creation of EsDB instance + esdb = EsDB(embedding_fn=Mock(), es_client=self.mock_client) + self.assertIsInstance(esdb, EsDB) + + @patch("os.getenv") # Mock the os.getenv function to return None for ES_ENDPOINT + def test_init_with_missing_endpoint(self, mock_os_getenv): + # Test if an exception is raised when ES_ENDPOINT is missing + mock_os_getenv.return_value = None + with self.assertRaises(ValueError): + EsDB(embedding_fn=Mock()) From ab67c9d225e9a7bb7289aa842a7742a5e12dc992 Mon Sep 17 00:00:00 2001 From: Prashant Date: Fri, 4 Aug 2023 21:47:49 +0530 Subject: [PATCH 2/4] fix(es-datasource): created different index with fixed vector dim for different app --- docs/advanced/datasource.mdx | 5 +++-- embedchain/config/apps/AppConfig.py | 1 + embedchain/config/apps/BaseAppConfig.py | 13 ++++++------- embedchain/config/apps/OpenSourceAppConfig.py | 1 + embedchain/vectordb/elasticsearch_db.py | 12 ++++++++---- embedchain/vectordb/vector_db.py | 12 +++--------- tests/vectordb/test_elasticsearch_db.py | 7 ++++++- 7 files changed, 28 insertions(+), 23 deletions(-) diff --git a/docs/advanced/datasource.mdx b/docs/advanced/datasource.mdx index 374d08484f..62a026e993 100644 --- a/docs/advanced/datasource.mdx +++ b/docs/advanced/datasource.mdx @@ -28,7 +28,7 @@ _To use elasticsearch as vector db we need an external running instance and conn - `Elasticsearch` as vector database can be used by setting `db_type='es'` in `AppConfig`. 
- `ES_ENDPOINT` is mandatory to connect to `Elasticsearch`. - `ES_API_KEY_ID` and `ES_API_KEY` can be configured for authentication and connecting to `Elasticsearch`. -- An index with name `embedchain_store` is created if not present. +- An index with name `embedchain_store_1536` is created if not present. ### OpenSourceApp Similarly for Open source app set `db_type='es'` @@ -38,4 +38,5 @@ from embedchain.config import OpenSourceAppConfig opensource_es_app_config = OpenSourceAppConfig(db_type='es') opensource_es_app = OpenSourceApp(opensource_es_app_config) -``` \ No newline at end of file +``` +- An index with name `embedchain_store_384` is created if not present. \ No newline at end of file diff --git a/embedchain/config/apps/AppConfig.py b/embedchain/config/apps/AppConfig.py index df61ad0ef4..e4ce55f334 100644 --- a/embedchain/config/apps/AppConfig.py +++ b/embedchain/config/apps/AppConfig.py @@ -34,6 +34,7 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam id=id, collection_name=collection_name, db_type=db_type, + vector_dim=1536, # vector length created by embedding fn ) @staticmethod diff --git a/embedchain/config/apps/BaseAppConfig.py b/embedchain/config/apps/BaseAppConfig.py index 9c8f167a0d..246eba18ed 100644 --- a/embedchain/config/apps/BaseAppConfig.py +++ b/embedchain/config/apps/BaseAppConfig.py @@ -19,6 +19,7 @@ def __init__( id=None, collection_name=None, db_type=None, + vector_dim: int = None, ): """ :param log_level: Optional. (String) Debug level @@ -30,21 +31,18 @@ def __init__( :param id: Optional. ID of the app. Document metadata will have this id. :param collection_name: Optional. Collection name for the database. :param db_type: Optional. db type to use. Currently [chroma, es] are supported. 
+ :param vector_dim: Vector dimension generated by embedding fn """ self._setup_logging(log_level) self.db = BaseAppConfig.get_db( - db=db, - embedding_fn=embedding_fn, - host=host, - port=port, - db_type=db_type, + db=db, embedding_fn=embedding_fn, host=host, port=port, db_type=db_type, vector_dim=vector_dim ) self.collection_name = collection_name if collection_name else "embedchain_store" self.id = id return @staticmethod - def get_db(db, embedding_fn, host, port, db_type): + def get_db(db, embedding_fn, host, port, db_type, vector_dim): """ Get db based on db_type, db with default database (`ChromaDb`) :param Optional. (Vector) database to use for embeddings. @@ -52,6 +50,7 @@ def get_db(db, embedding_fn, host, port, db_type): :param host: Optional. Hostname for the database server. :param port: Optional. Port for the database server. :param db_type: Optional. db type to use. Supported values (`es`, `chroma`) + :param vector_dim: Vector dimension generated by embedding fn :returns: database instance :raises ValueError: BaseAppConfig knows no default embedding function. 
""" @@ -64,7 +63,7 @@ def get_db(db, embedding_fn, host, port, db_type): if db_type == "es": from embedchain.vectordb.elasticsearch_db import EsDB - return VectorDb(EsDB(embedding_fn=embedding_fn), db_type) + return VectorDb(EsDB(embedding_fn=embedding_fn, vector_dim=vector_dim), db_type) from embedchain.vectordb.chroma_db import ChromaDB diff --git a/embedchain/config/apps/OpenSourceAppConfig.py b/embedchain/config/apps/OpenSourceAppConfig.py index c06247d6eb..174a238d21 100644 --- a/embedchain/config/apps/OpenSourceAppConfig.py +++ b/embedchain/config/apps/OpenSourceAppConfig.py @@ -30,6 +30,7 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam id=id, collection_name=collection_name, db_type=db_type, + vector_dim=384, # vector length created by embedding fn ) @staticmethod diff --git a/embedchain/vectordb/elasticsearch_db.py b/embedchain/vectordb/elasticsearch_db.py index 903604a33b..4433f87062 100644 --- a/embedchain/vectordb/elasticsearch_db.py +++ b/embedchain/vectordb/elasticsearch_db.py @@ -15,7 +15,10 @@ class EsDB(BaseVectorDB): """ def __init__( - self, embedding_fn: Callable[[list[str]], list[str]] = None, es_client: Optional[Elasticsearch] = None + self, + embedding_fn: Callable[[list[str]], list[str]] = None, + es_client: Optional[Elasticsearch] = None, + vector_dim: int = None, ): if not hasattr(embedding_fn, "__call__"): raise ValueError("Embedding function is not a function") @@ -27,10 +30,11 @@ def __init__( api_key = api_key if api_key is not None else "" if not endpoint and not es_client: raise ValueError("Elasticsearch endpoint is required to connect") + if vector_dim is None: + raise ValueError("Vector Dimension is required to refer correct index and mapping") self.client = es_client if es_client is not None else Elasticsearch(endpoint, api_key=(api_key_id, api_key)) - self.es_index = "embedchain_store" - # Check if its configurable, currently setting it to max, also check for performance issues - self.vector_dim 
= 2048 + self.vector_dim = vector_dim + self.es_index = f"embedchain_store_{self.vector_dim}" self.bulk = bulk index_settings = { "mappings": { diff --git a/embedchain/vectordb/vector_db.py b/embedchain/vectordb/vector_db.py index bdcf42b157..8d8bfeca3b 100644 --- a/embedchain/vectordb/vector_db.py +++ b/embedchain/vectordb/vector_db.py @@ -1,6 +1,5 @@ from typing import Any, List, Optional, Union -import numpy as np from langchain.docstore.document import Document from embedchain.vectordb.chroma_db import ChromaDB @@ -54,15 +53,11 @@ def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> docs = [] embeddings = self.db.embedding_fn(documents) for id, text, metadata, text_vector in zip(ids, documents, metadatas, embeddings): - # need to do this to create fixed dimension vector, padding with zeros - # NOTE: look for better solutions - vector = np.zeros(self.db.vector_dim) - vector[: len(text_vector)] = text_vector docs.append( { "_index": self.db.es_index, "_id": id, - "_source": {"text": text, "metadata": metadata, "text_vector": vector.tolist()}, + "_source": {"text": text, "metadata": metadata, "text_vector": text_vector}, } ) self.db.bulk(self.db.client, docs) @@ -102,14 +97,13 @@ def query( https://www.elastic.co/guide/en/elasticsearch/reference/master/knn-search.html#exact-knn """ input_query_vector = self.db.embedding_fn(input_query) - query_vector = np.zeros(self.db.vector_dim) - query_vector[: len(input_query_vector[0])] = input_query_vector[0] + query_vector = input_query_vector[0] query = { "script_score": { "query": {"bool": {"must": [{"exists": {"field": "text"}}]}}, "script": { "source": "cosineSimilarity(params.input_query_vector, 'text_vector') + 1.0", - "params": {"input_query_vector": query_vector.tolist()}, + "params": {"input_query_vector": query_vector}, }, } } diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py index bb4964905a..af5f1fa8ea 100644 --- 
a/tests/vectordb/test_elasticsearch_db.py +++ b/tests/vectordb/test_elasticsearch_db.py @@ -15,9 +15,14 @@ def test_init_with_invalid_embedding_fn(self): with self.assertRaises(ValueError): EsDB(embedding_fn=None) + def test_init_with_invalid_vector_dim(self): + # Test if an exception is raised when an invalid vector_dim is provided + with self.assertRaises(ValueError): + EsDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=None) + def test_init_with_valid_embedding_and_client(self): # check for successful creation of EsDB instance - esdb = EsDB(embedding_fn=Mock(), es_client=self.mock_client) + esdb = EsDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=1024) self.assertIsInstance(esdb, EsDB) @patch("os.getenv") # Mock the os.getenv function to return None for ES_ENDPOINT From ef0219e8cf8d3d19eb67a4f915923907beddbefb Mon Sep 17 00:00:00 2001 From: Prashant Date: Thu, 10 Aug 2023 08:24:50 +0530 Subject: [PATCH 3/4] fixes: - Remove additional vector db class and add functions in base and inherited db classes - Move vector dimensions and db type in enum classes under models - Support elasticsearch as db type in CustomApp and do not alter App and OpenSourceApp - Add elasticsearch as an optional dependency --- embedchain/config/apps/AppConfig.py | 5 +- embedchain/config/apps/BaseAppConfig.py | 33 ++-- embedchain/config/apps/CustomAppConfig.py | 23 ++- embedchain/config/apps/OpenSourceAppConfig.py | 5 +- embedchain/embedchain.py | 39 +++-- embedchain/models/VectorDatabases.py | 6 + embedchain/models/VectorDimensions.py | 9 + embedchain/models/__init__.py | 2 + embedchain/vectordb/base_vector_db.py | 15 ++ embedchain/vectordb/chroma_db.py | 73 +++++++- embedchain/vectordb/elasticsearch_db.py | 110 ++++++++++-- embedchain/vectordb/vector_db.py | 159 ------------------ pyproject.toml | 3 +- setup.py | 7 +- tests/vectordb/test_elasticsearch_db.py | 14 +- 15 files changed, 281 insertions(+), 222 deletions(-) create mode 100644 
embedchain/models/VectorDatabases.py create mode 100644 embedchain/models/VectorDimensions.py delete mode 100644 embedchain/vectordb/vector_db.py diff --git a/embedchain/config/apps/AppConfig.py b/embedchain/config/apps/AppConfig.py index e4ce55f334..1e08040e0d 100644 --- a/embedchain/config/apps/AppConfig.py +++ b/embedchain/config/apps/AppConfig.py @@ -16,7 +16,7 @@ class AppConfig(BaseAppConfig): Config to initialize an embedchain custom `App` instance, with extra config options. """ - def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, db_type=None): + def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. @@ -24,7 +24,6 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam :param port: Optional. Port for the database server. :param id: Optional. ID of the app. Document metadata will have this id. :param collection_name: Optional. Collection name for the database. - :param db_type: Optional. db type to use. Currently [chroma, es] are supported. 
""" super().__init__( log_level=log_level, @@ -33,8 +32,6 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam port=port, id=id, collection_name=collection_name, - db_type=db_type, - vector_dim=1536, # vector length created by embedding fn ) @staticmethod diff --git a/embedchain/config/apps/BaseAppConfig.py b/embedchain/config/apps/BaseAppConfig.py index 246eba18ed..f558ca915b 100644 --- a/embedchain/config/apps/BaseAppConfig.py +++ b/embedchain/config/apps/BaseAppConfig.py @@ -1,7 +1,7 @@ import logging from embedchain.config.BaseConfig import BaseConfig -from embedchain.vectordb.vector_db import VectorDb +from embedchain.models import VectorDatabases, VectorDimensions class BaseAppConfig(BaseConfig): @@ -18,8 +18,8 @@ def __init__( port=None, id=None, collection_name=None, - db_type=None, - vector_dim: int = None, + db_type: VectorDatabases = None, + vector_dim: VectorDimensions = None, ): """ :param log_level: Optional. (String) Debug level @@ -30,19 +30,25 @@ def __init__( :param port: Optional. Port for the database server. :param id: Optional. ID of the app. Document metadata will have this id. :param collection_name: Optional. Collection name for the database. - :param db_type: Optional. db type to use. Currently [chroma, es] are supported. + :param db_type: Optional. 
type of Vector database to use :param vector_dim: Vector dimension generated by embedding fn """ self._setup_logging(log_level) + self.collection_name = collection_name if collection_name else "embedchain_store" self.db = BaseAppConfig.get_db( - db=db, embedding_fn=embedding_fn, host=host, port=port, db_type=db_type, vector_dim=vector_dim + db=db, + embedding_fn=embedding_fn, + host=host, + port=port, + db_type=db_type, + vector_dim=vector_dim, + collection_name=self.collection_name, ) - self.collection_name = collection_name if collection_name else "embedchain_store" self.id = id return @staticmethod - def get_db(db, embedding_fn, host, port, db_type, vector_dim): + def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name): """ Get db based on db_type, db with default database (`ChromaDb`) :param Optional. (Vector) database to use for embeddings. @@ -51,23 +57,24 @@ def get_db(db, embedding_fn, host, port, db_type, vector_dim): :param port: Optional. Port for the database server. :param db_type: Optional. db type to use. Supported values (`es`, `chroma`) :param vector_dim: Vector dimension generated by embedding fn - :returns: database instance + :param collection_name: Optional. Collection name for the database. :raises ValueError: BaseAppConfig knows no default embedding function. 
+ :returns: database instance """ if db: - return VectorDb(db, db_type) + return db if embedding_fn is None: raise ValueError("ChromaDb cannot be instantiated without an embedding function") - if db_type == "es": - from embedchain.vectordb.elasticsearch_db import EsDB + if db_type == VectorDatabases.ELASTICSEARCH: + from embedchain.vectordb.elasticsearch_db import ElasticsearchDB - return VectorDb(EsDB(embedding_fn=embedding_fn, vector_dim=vector_dim), db_type) + return ElasticsearchDB(embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name) from embedchain.vectordb.chroma_db import ChromaDB - return VectorDb(ChromaDB(embedding_fn=embedding_fn, host=host, port=port)) + return ChromaDB(embedding_fn=embedding_fn, host=host, port=port) def _setup_logging(self, debug_level): level = logging.WARNING # Default level diff --git a/embedchain/config/apps/CustomAppConfig.py b/embedchain/config/apps/CustomAppConfig.py index cc42f85a56..abb2e13448 100644 --- a/embedchain/config/apps/CustomAppConfig.py +++ b/embedchain/config/apps/CustomAppConfig.py @@ -3,7 +3,7 @@ from chromadb.api.types import Documents, Embeddings from dotenv import load_dotenv -from embedchain.models import EmbeddingFunctions, Providers +from embedchain.models import EmbeddingFunctions, Providers, VectorDatabases, VectorDimensions from .BaseAppConfig import BaseAppConfig @@ -28,6 +28,7 @@ def __init__( provider: Providers = None, open_source_app_config=None, deployment_name=None, + db_type: VectorDatabases = None, ): """ :param log_level: Optional. (String) Debug level @@ -41,6 +42,7 @@ def __init__( :param collection_name: Optional. Collection name for the database. :param provider: Optional. (Providers): LLM Provider to use. :param open_source_app_config: Optional. Config instance needed for open source apps. + :param db_type: Optional. type of Vector database to use. 
""" if provider: self.provider = provider @@ -59,6 +61,8 @@ def __init__( port=port, id=id, collection_name=collection_name, + db_type=db_type, + vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn), ) @staticmethod @@ -108,3 +112,20 @@ def embedding_function(embedding_function: EmbeddingFunctions, model: str = None from chromadb.utils import embedding_functions return embedding_functions.SentenceTransformerEmbeddingFunction(model_name=model) + + @staticmethod + def get_vector_dimension(embedding_function: EmbeddingFunctions): + if not isinstance(embedding_function, EmbeddingFunctions): + raise ValueError(f"Invalid option: '{embedding_function}'.") + + if embedding_function == EmbeddingFunctions.OPENAI: + return VectorDimensions.OPENAI.value + + elif embedding_function == EmbeddingFunctions.HUGGING_FACE: + return VectorDimensions.HUGGING_FACE.value + + elif embedding_function == EmbeddingFunctions.VERTEX_AI: + return VectorDimensions.VERTEX_AI.value + + elif embedding_function == EmbeddingFunctions.GPT4ALL: + return VectorDimensions.GPT4ALL.value diff --git a/embedchain/config/apps/OpenSourceAppConfig.py b/embedchain/config/apps/OpenSourceAppConfig.py index 174a238d21..8666f1257d 100644 --- a/embedchain/config/apps/OpenSourceAppConfig.py +++ b/embedchain/config/apps/OpenSourceAppConfig.py @@ -8,7 +8,7 @@ class OpenSourceAppConfig(BaseAppConfig): Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options. """ - def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, model=None, db_type=None): + def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, model=None): """ :param log_level: Optional. (String) Debug level ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']. @@ -17,7 +17,6 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam :param host: Optional. Hostname for the database server. :param port: Optional. 
Port for the database server. :param model: Optional. GPT4ALL uses the model to instantiate the class. - :param db_type: Optional. db type to use. Currently [chroma, es] are supported. So unlike `App`, it has to be provided before querying. """ self.model = model or "orca-mini-3b.ggmlv3.q4_0.bin" @@ -29,8 +28,6 @@ def __init__(self, log_level=None, host=None, port=None, id=None, collection_nam port=port, id=id, collection_name=collection_name, - db_type=db_type, - vector_dim=384, # vector length created by embedding fn ) @staticmethod diff --git a/embedchain/embedchain.py b/embedchain/embedchain.py index 1967c4e8ec..8cfca45546 100644 --- a/embedchain/embedchain.py +++ b/embedchain/embedchain.py @@ -1,8 +1,8 @@ import logging import os -from chromadb.errors import InvalidDimensionException from dotenv import load_dotenv +from langchain.docstore.document import Document from langchain.memory import ConversationBufferMemory from embedchain.chunkers.base_chunker import BaseChunker @@ -30,8 +30,7 @@ def __init__(self, config: BaseAppConfig): """ self.config = config - # self.db_client = self.config.db.client - # self.collection = self.config.db._get_or_create_collection(self.config.collection_name) + self.collection = self.config.db._get_or_create_collection(self.config.collection_name) self.db = self.config.db self.user_asks = [] self.is_docs_site_instance = False @@ -97,9 +96,11 @@ def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata metadatas = embeddings_data["metadatas"] ids = embeddings_data["ids"] # get existing ids, and discard doc if any common id exist. 
+ where = {"app_id": self.config.id} if self.config.id is not None else {} + # where={"url": src} existing_ids = self.db.get( - ids, - self.config.id, + ids=ids, + where=where, # optional filter ) if len(existing_ids): @@ -128,6 +129,16 @@ def load_and_embed(self, loader: BaseLoader, chunker: BaseChunker, src, metadata self.db.add(documents=documents, metadatas=list(metadatas_with_metadata), ids=ids) print((f"Successfully saved {src}. New chunks count: " f"{self.count() - chunks_before_addition}")) + def _format_result(self, results): + return [ + (Document(page_content=result[0], metadata=result[1] or {}), result[2]) + for result in zip( + results["documents"][0], + results["metadatas"][0], + results["distances"][0], + ) + ] + def get_llm_model_answer(self): """ Usually implemented by child class @@ -143,17 +154,13 @@ def retrieve_from_database(self, input_query, config: QueryConfig): :param config: The query configuration. :return: The content of the document that matched your query. """ - try: - contents = self.db.query( - input_query, - config.number_documents, - self.config.id, - ) - except InvalidDimensionException as e: - raise InvalidDimensionException( - e.message() - + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." 
# noqa E501 - ) from None + where = {"app_id": self.config.id} if self.config.id is not None else {} # optional filter + contents = self.db.query( + input_query=input_query, + n_results=config.number_documents, + where=where, + ) + return contents def _append_search_and_context(self, context, web_search_result): diff --git a/embedchain/models/VectorDatabases.py b/embedchain/models/VectorDatabases.py new file mode 100644 index 0000000000..5abf38443d --- /dev/null +++ b/embedchain/models/VectorDatabases.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class VectorDatabases(Enum): + CHROMADB = "CHROMADB" + ELASTICSEARCH = "ELASTICSEARCH" diff --git a/embedchain/models/VectorDimensions.py b/embedchain/models/VectorDimensions.py new file mode 100644 index 0000000000..9be1f304df --- /dev/null +++ b/embedchain/models/VectorDimensions.py @@ -0,0 +1,9 @@ +from enum import Enum + + +# vector length created by embedding fn +class VectorDimensions(Enum): + GPT4ALL = 384 + OPENAI = 1536 + VERTEX_AI = 768 + HUGGING_FACE = 384 diff --git a/embedchain/models/__init__.py b/embedchain/models/__init__.py index 7c459977eb..c5daa449a2 100644 --- a/embedchain/models/__init__.py +++ b/embedchain/models/__init__.py @@ -1,2 +1,4 @@ from .EmbeddingFunctions import EmbeddingFunctions # noqa: F401 from .Providers import Providers # noqa: F401 +from .VectorDatabases import VectorDatabases # noqa: F401 +from .VectorDimensions import VectorDimensions # noqa: F401 diff --git a/embedchain/vectordb/base_vector_db.py b/embedchain/vectordb/base_vector_db.py index f38e3d3110..0ed1e3c0aa 100644 --- a/embedchain/vectordb/base_vector_db.py +++ b/embedchain/vectordb/base_vector_db.py @@ -10,3 +10,18 @@ def _get_or_create_db(self): def _get_or_create_collection(self): raise NotImplementedError + + def get(self): + raise NotImplementedError + + def add(self): + raise NotImplementedError + + def query(self): + raise NotImplementedError + + def count(self): + raise NotImplementedError + + def reset(self): + 
raise NotImplementedError diff --git a/embedchain/vectordb/chroma_db.py b/embedchain/vectordb/chroma_db.py index 168c622145..b50c8b52f5 100644 --- a/embedchain/vectordb/chroma_db.py +++ b/embedchain/vectordb/chroma_db.py @@ -1,4 +1,8 @@ import logging +from typing import Any, Dict, List + +from chromadb.errors import InvalidDimensionException +from langchain.docstore.document import Document try: import chromadb @@ -7,6 +11,7 @@ use_pysqlite3() import chromadb + from chromadb.config import Settings from embedchain.vectordb.base_vector_db import BaseVectorDB @@ -41,7 +46,73 @@ def _get_or_create_db(self): def _get_or_create_collection(self, name): """Get or create the collection.""" - return self.client.get_or_create_collection( + self.collection = self.client.get_or_create_collection( name=name, embedding_function=self.embedding_fn, ) + return self.collection + + def get(self, ids: List[str], where: Dict[str, any]) -> List[str]: + """ + Get existing doc ids present in vector database + :param ids: list of doc ids to check for existance + :param where: Optional. 
to filter data + """ + existing_docs = self.collection.get( + ids=ids, + where=where, # optional filter + ) + + return set(existing_docs["ids"]) + + def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any: + """ + add data in vector database + :param documents: list of texts to add + :param metadatas: list of metadata associated with docs + :param ids: ids of docs + """ + self.collection.add(documents=documents, metadatas=metadatas, ids=ids) + + def _format_result(self, results): + return [ + (Document(page_content=result[0], metadata=result[1] or {}), result[2]) + for result in zip( + results["documents"][0], + results["metadatas"][0], + results["distances"][0], + ) + ] + + def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]: + """ + query contents from vector data base based on vector similarity + :param input_query: list of query string + :param n_results: no of similar documents to fetch from database + :param where: Optional. to filter data + :return: The content of the document that matched your query. + """ + try: + result = self.collection.query( + query_texts=[ + input_query, + ], + n_results=n_results, + where=where, + ) + except InvalidDimensionException as e: + raise InvalidDimensionException( + e.message() + + ". This is commonly a side-effect when an embedding function, different from the one used to add the embeddings, is used to retrieve an embedding from the database." 
# noqa E501 + ) from None + + results_formatted = self._format_result(result) + contents = [result[0].page_content for result in results_formatted] + return contents + + def count(self) -> int: + return self.collection.count() + + def reset(self): + # Delete all data from the database + self.client.reset() diff --git a/embedchain/vectordb/elasticsearch_db.py b/embedchain/vectordb/elasticsearch_db.py index 4433f87062..145745d305 100644 --- a/embedchain/vectordb/elasticsearch_db.py +++ b/embedchain/vectordb/elasticsearch_db.py @@ -1,25 +1,32 @@ import os -from typing import Callable, Optional +from typing import Any, Callable, Dict, List, Optional -from elasticsearch import Elasticsearch -from elasticsearch.helpers import bulk +try: + from elasticsearch import Elasticsearch + from elasticsearch.helpers import bulk +except ImportError: + raise ImportError( + "Elasticsearch requires extra dependencies. Install with `pip install embedchain[elasticsearch]`" + ) from None +from embedchain.models.VectorDimensions import VectorDimensions from embedchain.vectordb.base_vector_db import BaseVectorDB -class EsDB(BaseVectorDB): - """ - Elasticsearch as vector database - :param embedding_fn: Function to generate embedding vectors. - :param config: Optional. elastic search client - """ - +class ElasticsearchDB(BaseVectorDB): def __init__( self, embedding_fn: Callable[[list[str]], list[str]] = None, es_client: Optional[Elasticsearch] = None, - vector_dim: int = None, + vector_dim: VectorDimensions = None, + collection_name: str = None, ): + """ + Elasticsearch as vector database + :param embedding_fn: Function to generate embedding vectors. + :param vector_dim: Vector dimension generated by embedding fn + :param collection_name: Optional. Collection name for the database. 
+ """ if not hasattr(embedding_fn, "__call__"): raise ValueError("Embedding function is not a function") self.embedding_fn = embedding_fn @@ -34,8 +41,8 @@ def __init__( raise ValueError("Vector Dimension is required to refer correct index and mapping") self.client = es_client if es_client is not None else Elasticsearch(endpoint, api_key=(api_key_id, api_key)) self.vector_dim = vector_dim - self.es_index = f"embedchain_store_{self.vector_dim}" - self.bulk = bulk + # self.collection_name = collection_name if collection_name else "embedchain_store" + self.es_index = f"{collection_name}_{self.vector_dim}" index_settings = { "mappings": { "properties": { @@ -53,5 +60,80 @@ def __init__( def _get_or_create_db(self): return self.client - def _get_or_create_collection(self): + def _get_or_create_collection(self, name): """Note: nothing to return here. Discuss later""" + + def get(self, ids: List[str], where: Dict[str, any]) -> List[str]: + """ + Get existing doc ids present in vector database + :param ids: list of doc ids to check for existance + :param where: Optional. 
to filter data + """ + query = {"bool": {"must": [{"ids": {"values": ids}}]}} + if "app_id" in where: + app_id = where["app_id"] + query["bool"]["must"].append({"term": {"metadata.app_id": app_id}}) + response = self.client.search(index=self.es_index, query=query, _source=False) + docs = response["hits"]["hits"] + ids = [doc["_id"] for doc in docs] + return set(ids) + + def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any: + """ + add data in vector database + :param documents: list of texts to add + :param metadatas: list of metadata associated with docs + :param ids: ids of docs + """ + docs = [] + embeddings = self.embedding_fn(documents) + for id, text, metadata, text_vector in zip(ids, documents, metadatas, embeddings): + docs.append( + { + "_index": self.es_index, + "_id": id, + "_source": {"text": text, "metadata": metadata, "text_vector": text_vector}, + } + ) + bulk(self.client, docs) + self.client.indices.refresh(index=self.es_index) + return + + def query(self, input_query: List[str], n_results: int, where: Dict[str, any]) -> List[str]: + """ + query contents from vector data base based on vector similarity + :param input_query: list of query string + :param n_results: no of similar documents to fetch from database + :param where: Optional. 
to filter data + """ + input_query_vector = self.embedding_fn(input_query) + query_vector = input_query_vector[0] + query = { + "script_score": { + "query": {"bool": {"must": [{"exists": {"field": "text"}}]}}, + "script": { + "source": "cosineSimilarity(params.input_query_vector, 'text_vector') + 1.0", + "params": {"input_query_vector": query_vector}, + }, + } + } + if "app_id" in where: + app_id = where["app_id"] + query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}] + _source = ["text"] + response = self.client.search(index=self.es_index, query=query, _source=_source, size=n_results) + docs = response["hits"]["hits"] + contents = [doc["_source"]["text"] for doc in docs] + return contents + + def count(self) -> int: + query = {"match_all": {}} + response = self.client.count(index=self.es_index, query=query) + doc_count = response["count"] + return doc_count + + def reset(self): + # Delete all data from the database + if self.client.indices.exists(index=self.es_index): + # delete index in Es + self.client.indices.delete(index=self.es_index) diff --git a/embedchain/vectordb/vector_db.py b/embedchain/vectordb/vector_db.py deleted file mode 100644 index 8d8bfeca3b..0000000000 --- a/embedchain/vectordb/vector_db.py +++ /dev/null @@ -1,159 +0,0 @@ -from typing import Any, List, Optional, Union - -from langchain.docstore.document import Document - -from embedchain.vectordb.chroma_db import ChromaDB -from embedchain.vectordb.elasticsearch_db import EsDB - - -class VectorDb: - """ - Database abstraction class, abstracting common functionality - :param db: (Vector) database instance to use for embeddings. Can be es/chroma - :param db_type: which type of database is used. 
[es, chroma] - """ - - def __init__(self, db: Union[ChromaDB, EsDB], db_type: Optional[str] = None): - self.db = db - self.db_type = db_type - - """ - Get existing doc ids present in vector database - :param ids: list of doc ids to check for existance - :param app_id: Optional application to filter data - """ - - def get(self, ids: List[str], app_id: Optional[str]) -> List[str]: - if self.db_type == "es": - query = {"bool": {"must": [{"ids": {"values": ids}}]}} - if app_id: - query["bool"]["must"].append({"term": {"metadata.app_id": app_id}}) - response = self.db.client.search(index=self.db.es_index, query=query, _source=False) - docs = response["hits"]["hits"] - ids = [doc["_id"] for doc in docs] - return set(ids) - - where = {"app_id": app_id} if app_id is not None else {} - existing_docs = self.db.collection.get( - ids=ids, - where=where, # optional filter - ) - - return set(existing_docs["ids"]) - - """ - add data in vector database - :param documents: list of texts to add - :param metadatas: list of metadata associated with docs - :param ids: ids of docs - """ - - def add(self, documents: List[str], metadatas: List[object], ids: List[str]) -> Any: - if self.db_type == "es": - docs = [] - embeddings = self.db.embedding_fn(documents) - for id, text, metadata, text_vector in zip(ids, documents, metadatas, embeddings): - docs.append( - { - "_index": self.db.es_index, - "_id": id, - "_source": {"text": text, "metadata": metadata, "text_vector": text_vector}, - } - ) - self.db.bulk(self.db.client, docs) - self.db.client.indices.refresh(index=self.db.es_index) - return - - self.db.collection.add(documents=documents, metadatas=metadatas, ids=ids) - - def _format_result(self, results): - # discuss why there was a need to create lagchain Document - return [ - (Document(page_content=result[0], metadata=result[1] or {}), result[2]) - for result in zip( - results["documents"][0], - results["metadatas"][0], - results["distances"][0], - ) - ] - - """ - query contents from 
vector data base based on vector similarity - :param input_query: list of query string - :param number_documents: no of similar documents to fetch from database - :param app_id: Optional app id for filtering data - """ - - def query( - self, input_query: List[str], number_documents: int, app_id: Optional[Union[int, str]] = None - ) -> List[str]: - if self.db_type == "es": - """ - Currently have taken max 2048 as vector dim, there is a need to re check the - accuracy of cosineSimilarity used to retrive similar documents - Not using Approximate kNN because cannot index dense vector due to dims > 1024 - https://www.elastic.co/guide/en/elasticsearch/reference/master/knn-search.html - Using Exact KNN - https://www.elastic.co/guide/en/elasticsearch/reference/master/knn-search.html#exact-knn - """ - input_query_vector = self.db.embedding_fn(input_query) - query_vector = input_query_vector[0] - query = { - "script_score": { - "query": {"bool": {"must": [{"exists": {"field": "text"}}]}}, - "script": { - "source": "cosineSimilarity(params.input_query_vector, 'text_vector') + 1.0", - "params": {"input_query_vector": query_vector}, - }, - } - } - if app_id: - query["script_score"]["query"]["bool"]["must"] = [{"term": {"metadata.app_id": app_id}}] - _source = ["text"] - size = number_documents - response = self.db.client.search(index=self.db.es_index, query=query, _source=_source, size=size) - docs = response["hits"]["hits"] - contents = [doc["_source"]["text"] for doc in docs] - return contents - - where = {"app_id": app_id} if app_id is not None else {} # optional filter - result = self.db.collection.query( - query_texts=[ - input_query, - ], - n_results=number_documents, - where=where, - ) - - results_formatted = self._format_result(result) - contents = [result[0].page_content for result in results_formatted] - return contents - - """ - get count of docs in the database - :param app_id: Optional app id to filter data - """ - - def count(self, app_id: Optional[Union[int, 
str]] = None) -> int: - if self.db_type == "es": - query = {"match_all": {}} - if app_id: - query = {"bool": {"must": [{"term": {"metadata.app_id": app_id}}]}} - response = self.db.client.count(index=self.db.es_index, query=query) - doc_count = response["count"] - return doc_count - - return self.db.collection.count() - - # Delete all data from the database - def reset(self): - if self.db_type == "es" and self.db.client.indices.exists(index=self.db.es_index): - # delete index in Es - self.db.client.indices.delete(index=self.db.es_index) - return - - self.db.collection.delete() - - # get Vector Db instance - def get_db(self): - return self.db diff --git a/pyproject.toml b/pyproject.toml index 9095b6c07a..d08bc234ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ beautifulsoup4 = "^4.12.2" pypdf = "^3.11.0" pytube = "^15.0.0" llama-index = { version = "^0.7.21", optional = true } -elasticsearch = "^8.9.0" +elasticsearch = { version = "^8.9.0", optional = true } @@ -108,6 +108,7 @@ isort = "^5.12.0" [tool.poetry.extras] streamlit = ["streamlit"] community = ["llama-index"] +elasticsearch = ["elasticsearch"] [tool.poetry.group.docs.dependencies] diff --git a/setup.py b/setup.py index 21ae9ea46c..6f460683c6 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,10 @@ "pydantic==1.10.8", "replicate==0.9.0", "duckduckgo-search==3.8.4", - "elasticsearch>=8.0.0", ], - extras_require={"dev": ["black", "ruff", "isort", "pytest"], "community": ["llama-index==0.7.21"]}, + extras_require={ + "dev": ["black", "ruff", "isort", "pytest"], + "community": ["llama-index==0.7.21"], + "elasticsearch": ["elasticsearch>=8.9.0"], + }, ) diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py index af5f1fa8ea..47796107ce 100644 --- a/tests/vectordb/test_elasticsearch_db.py +++ b/tests/vectordb/test_elasticsearch_db.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import MagicMock, Mock, patch -from 
embedchain.vectordb.elasticsearch_db import EsDB +from embedchain.vectordb.elasticsearch_db import ElasticsearchDB class TestEsDB(unittest.TestCase): @@ -13,21 +13,21 @@ def setUp(self): def test_init_with_invalid_embedding_fn(self): # Test if an exception is raised when an invalid embedding_fn is provided with self.assertRaises(ValueError): - EsDB(embedding_fn=None) + ElasticsearchDB(embedding_fn=None) def test_init_with_invalid_vector_dim(self): # Test if an exception is raised when an invalid vector_dim is provided with self.assertRaises(ValueError): - EsDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=None) + ElasticsearchDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=None) def test_init_with_valid_embedding_and_client(self): - # check for successful creation of EsDB instance - esdb = EsDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=1024) - self.assertIsInstance(esdb, EsDB) + # check for successful creation of ElasticsearchDB instance + esdb = ElasticsearchDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=1024) + self.assertIsInstance(esdb, ElasticsearchDB) @patch("os.getenv") # Mock the os.getenv function to return None for ES_ENDPOINT def test_init_with_missing_endpoint(self, mock_os_getenv): # Test if an exception is raised when ES_ENDPOINT is missing mock_os_getenv.return_value = None with self.assertRaises(ValueError): - EsDB(embedding_fn=Mock()) + ElasticsearchDB(embedding_fn=Mock()) From 45e071d0ce2a26b8f254b50320476bca5484228d Mon Sep 17 00:00:00 2001 From: Prashant Date: Thu, 10 Aug 2023 22:41:19 +0530 Subject: [PATCH 4/4] fix: Using ElasticsearchDBConfig as es db config, updated documentation --- docs/advanced/datasource.mdx | 42 ------------------- docs/advanced/vector_database.mdx | 34 +++++++++++++++ docs/mint.json | 2 +- embedchain/config/__init__.py | 1 + embedchain/config/apps/BaseAppConfig.py | 11 ++++- embedchain/config/apps/CustomAppConfig.py | 4 ++ 
.../config/vectordbs/ElasticsearchDBConfig.py | 15 +++++++ embedchain/config/vectordbs/__init__.py | 0 embedchain/vectordb/elasticsearch_db.py | 23 +++++----- tests/vectordb/test_elasticsearch_db.py | 30 ++++++------- 10 files changed, 89 insertions(+), 73 deletions(-) delete mode 100644 docs/advanced/datasource.mdx create mode 100644 docs/advanced/vector_database.mdx create mode 100644 embedchain/config/vectordbs/ElasticsearchDBConfig.py create mode 100644 embedchain/config/vectordbs/__init__.py diff --git a/docs/advanced/datasource.mdx b/docs/advanced/datasource.mdx deleted file mode 100644 index 62a026e993..0000000000 --- a/docs/advanced/datasource.mdx +++ /dev/null @@ -1,42 +0,0 @@ ---- -title: '💾 Datasource' ---- - -## Vector Database - -We support `Chromadb` and `Elasticsearch` as two type of vector database. -`Chromadb` is used as default. - -### App -```python -import os - -from embedchain import App -from embedchain.config import AppConfig - -os.environ["ES_ENDPOINT"] = "elasticsearch_endpoint" -# Adds HTTP header 'Authorization: ApiKey ' -os.environ["ES_API_KEY_ID"] = "api_key_id" # Optional -os.environ["ES_API_KEY"] = "api_key" # Optional - - -es_app_config = AppConfig(db_type='es') -es_app = App(es_app_config) -``` -_To use elasticsearch as vector db we need an external running instance and connection config_ - -- `Elasticsearch` as vector database can be used by setting `db_type='es'` in `AppConfig`. -- `ES_ENDPOINT` is mandatory to connect to `Elasticsearch`. -- `ES_API_KEY_ID` and `ES_API_KEY` can be configured for authentication and connecting to `Elasticsearch`. -- An index with name `embedchain_store_1536` is created if not present. 
- -### OpenSourceApp -Similarly for Open source app set `db_type='es'` -```python -from embedchain import OpenSourceApp -from embedchain.config import OpenSourceAppConfig - -opensource_es_app_config = OpenSourceAppConfig(db_type='es') -opensource_es_app = OpenSourceApp(opensource_es_app_config) -``` -- An index with name `embedchain_store_384` is created if not present. \ No newline at end of file diff --git a/docs/advanced/vector_database.mdx b/docs/advanced/vector_database.mdx new file mode 100644 index 0000000000..b3cdda86b1 --- /dev/null +++ b/docs/advanced/vector_database.mdx @@ -0,0 +1,34 @@ +--- +title: '💾 Vector Database' +--- + +We support `Chroma` and `Elasticsearch` as two vector databases. +`Chroma` is used as a default database. + +### Elasticsearch +In order to use `Elasticsearch` as vector database we need to use App type `CustomApp`. +```python +import os +from embedchain import CustomApp +from embedchain.config import CustomAppConfig, ElasticsearchDBConfig +from embedchain.models import Providers, EmbeddingFunctions, VectorDatabases + +os.environ["OPENAI_API_KEY"] = 'OPENAI_API_KEY' + +es_config = ElasticsearchDBConfig( + # elasticsearch url or list of nodes url with different hosts and ports. + es_url='http://localhost:9200', + # pass named parameters supported by Python Elasticsearch client + ca_certs="/path/to/http_ca.crt", + basic_auth=("username", "password") +) +config = CustomAppConfig( + embedding_fn=EmbeddingFunctions.OPENAI, + provider=Providers.OPENAI, + db_type=VectorDatabases.ELASTICSEARCH, + es_config=es_config, +) +es_app = CustomApp(config) +``` +- Set `db_type=VectorDatabases.ELASTICSEARCH` and `es_config=ElasticsearchDBConfig(es_url='')` in `CustomAppConfig`. +- `ElasticsearchDBConfig` accepts `es_url` as elasticsearch url or as list of nodes url with different hosts and ports. Additionally we can pass named parameters supported by Python Elasticsearch client.
\ No newline at end of file diff --git a/docs/mint.json b/docs/mint.json index 89164b7326..750afc0b88 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -32,7 +32,7 @@ }, { "group": "Advanced", - "pages": ["advanced/app_types", "advanced/interface_types", "advanced/adding_data","advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/datasource", "advanced/showcase"] + "pages": ["advanced/app_types", "advanced/interface_types", "advanced/adding_data","advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/vector_database", "advanced/showcase"] }, { "group": "Examples", diff --git a/embedchain/config/__init__.py b/embedchain/config/__init__.py index 9bd12b2644..684b63fa7e 100644 --- a/embedchain/config/__init__.py +++ b/embedchain/config/__init__.py @@ -5,3 +5,4 @@ from .BaseConfig import BaseConfig # noqa: F401 from .ChatConfig import ChatConfig # noqa: F401 from .QueryConfig import QueryConfig # noqa: F401 +from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig # noqa: F401 diff --git a/embedchain/config/apps/BaseAppConfig.py b/embedchain/config/apps/BaseAppConfig.py index f558ca915b..4b85c1a400 100644 --- a/embedchain/config/apps/BaseAppConfig.py +++ b/embedchain/config/apps/BaseAppConfig.py @@ -1,6 +1,7 @@ import logging from embedchain.config.BaseConfig import BaseConfig +from embedchain.config.vectordbs import ElasticsearchDBConfig from embedchain.models import VectorDatabases, VectorDimensions @@ -20,6 +21,7 @@ def __init__( collection_name=None, db_type: VectorDatabases = None, vector_dim: VectorDimensions = None, + es_config: ElasticsearchDBConfig = None, ): """ :param log_level: Optional. (String) Debug level @@ -32,6 +34,7 @@ def __init__( :param collection_name: Optional. Collection name for the database. :param db_type: Optional. 
type of Vector database to use :param vector_dim: Vector dimension generated by embedding fn + :param es_config: Optional. elasticsearch database config to be used for connection """ self._setup_logging(log_level) self.collection_name = collection_name if collection_name else "embedchain_store" @@ -43,12 +46,13 @@ def __init__( db_type=db_type, vector_dim=vector_dim, collection_name=self.collection_name, + es_config=es_config, ) self.id = id return @staticmethod - def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name): + def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config): """ Get db based on db_type, db with default database (`ChromaDb`) :param Optional. (Vector) database to use for embeddings. @@ -58,6 +62,7 @@ def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name): :param db_type: Optional. db type to use. Supported values (`es`, `chroma`) :param vector_dim: Vector dimension generated by embedding fn :param collection_name: Optional. Collection name for the database. + :param es_config: Optional. elasticsearch database config to be used for connection :raises ValueError: BaseAppConfig knows no default embedding function. 
:returns: database instance """ @@ -70,7 +75,9 @@ def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name): if db_type == VectorDatabases.ELASTICSEARCH: from embedchain.vectordb.elasticsearch_db import ElasticsearchDB - return ElasticsearchDB(embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name) + return ElasticsearchDB( + embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name, es_config=es_config + ) from embedchain.vectordb.chroma_db import ChromaDB diff --git a/embedchain/config/apps/CustomAppConfig.py b/embedchain/config/apps/CustomAppConfig.py index abb2e13448..1d2dd916d5 100644 --- a/embedchain/config/apps/CustomAppConfig.py +++ b/embedchain/config/apps/CustomAppConfig.py @@ -3,6 +3,7 @@ from chromadb.api.types import Documents, Embeddings from dotenv import load_dotenv +from embedchain.config.vectordbs import ElasticsearchDBConfig from embedchain.models import EmbeddingFunctions, Providers, VectorDatabases, VectorDimensions from .BaseAppConfig import BaseAppConfig @@ -29,6 +30,7 @@ def __init__( open_source_app_config=None, deployment_name=None, db_type: VectorDatabases = None, + es_config: ElasticsearchDBConfig = None, ): """ :param log_level: Optional. (String) Debug level @@ -43,6 +45,7 @@ def __init__( :param provider: Optional. (Providers): LLM Provider to use. :param open_source_app_config: Optional. Config instance needed for open source apps. :param db_type: Optional. type of Vector database to use. + :param es_config: Optional. 
elasticsearch database config to be used for connection """ if provider: self.provider = provider @@ -63,6 +66,7 @@ def __init__( collection_name=collection_name, db_type=db_type, vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn), + es_config=es_config, ) @staticmethod diff --git a/embedchain/config/vectordbs/ElasticsearchDBConfig.py b/embedchain/config/vectordbs/ElasticsearchDBConfig.py new file mode 100644 index 0000000000..6e7dd0f9c1 --- /dev/null +++ b/embedchain/config/vectordbs/ElasticsearchDBConfig.py @@ -0,0 +1,15 @@ +from typing import Dict, List, Union + +from embedchain.config.BaseConfig import BaseConfig + + +class ElasticsearchDBConfig(BaseConfig): + """ + Config to initialize an elasticsearch client. + :param es_url. elasticsearch url or list of nodes url to be used for connection + :param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch. + """ + + def __init__(self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]): + self.ES_URL = es_url + self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS diff --git a/embedchain/config/vectordbs/__init__.py b/embedchain/config/vectordbs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/embedchain/vectordb/elasticsearch_db.py b/embedchain/vectordb/elasticsearch_db.py index 145745d305..4371237eb4 100644 --- a/embedchain/vectordb/elasticsearch_db.py +++ b/embedchain/vectordb/elasticsearch_db.py @@ -1,5 +1,4 @@ -import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List try: from elasticsearch import Elasticsearch @@ -9,6 +8,7 @@ "Elasticsearch requires extra dependencies. 
Install with `pip install embedchain[elasticsearch]`" ) from None +from embedchain.config import ElasticsearchDBConfig from embedchain.models.VectorDimensions import VectorDimensions from embedchain.vectordb.base_vector_db import BaseVectorDB @@ -16,32 +16,29 @@ class ElasticsearchDB(BaseVectorDB): def __init__( self, + es_config: ElasticsearchDBConfig = None, embedding_fn: Callable[[list[str]], list[str]] = None, - es_client: Optional[Elasticsearch] = None, vector_dim: VectorDimensions = None, collection_name: str = None, ): """ Elasticsearch as vector database + :param es_config. elasticsearch database config to be used for connection :param embedding_fn: Function to generate embedding vectors. :param vector_dim: Vector dimension generated by embedding fn :param collection_name: Optional. Collection name for the database. """ if not hasattr(embedding_fn, "__call__"): raise ValueError("Embedding function is not a function") - self.embedding_fn = embedding_fn - endpoint = os.getenv("ES_ENDPOINT") - api_key_id = os.getenv("ES_API_KEY_ID") - api_key = os.getenv("ES_API_KEY") - api_key_id = api_key_id if api_key_id is not None else "" - api_key = api_key if api_key is not None else "" - if not endpoint and not es_client: - raise ValueError("Elasticsearch endpoint is required to connect") + if es_config is None: + raise ValueError("ElasticsearchDBConfig is required") if vector_dim is None: raise ValueError("Vector Dimension is required to refer correct index and mapping") - self.client = es_client if es_client is not None else Elasticsearch(endpoint, api_key=(api_key_id, api_key)) + if collection_name is None: + raise ValueError("collection name is required. 
It cannot be empty") + self.embedding_fn = embedding_fn + self.client = Elasticsearch(es_config.ES_URL, **es_config.ES_EXTRA_PARAMS) self.vector_dim = vector_dim - # self.collection_name = collection_name if collection_name else "embedchain_store" self.es_index = f"{collection_name}_{self.vector_dim}" index_settings = { "mappings": { diff --git a/tests/vectordb/test_elasticsearch_db.py b/tests/vectordb/test_elasticsearch_db.py index 47796107ce..4f316eae88 100644 --- a/tests/vectordb/test_elasticsearch_db.py +++ b/tests/vectordb/test_elasticsearch_db.py @@ -1,33 +1,33 @@ import unittest -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock +from embedchain.config import ElasticsearchDBConfig from embedchain.vectordb.elasticsearch_db import ElasticsearchDB class TestEsDB(unittest.TestCase): def setUp(self): - # set mock es client - self.mock_client = MagicMock() - self.mock_client.indices.exists.return_value = True + self.es_config = ElasticsearchDBConfig() + self.vector_dim = 384 def test_init_with_invalid_embedding_fn(self): # Test if an exception is raised when an invalid embedding_fn is provided with self.assertRaises(ValueError): ElasticsearchDB(embedding_fn=None) + def test_init_with_invalid_es_config(self): + # Test if an exception is raised when an invalid es_config is provided + with self.assertRaises(ValueError): + ElasticsearchDB(embedding_fn=Mock(), es_config=None) + def test_init_with_invalid_vector_dim(self): # Test if an exception is raised when an invalid vector_dim is provided with self.assertRaises(ValueError): - ElasticsearchDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=None) - - def test_init_with_valid_embedding_and_client(self): - # check for successful creation of ElasticsearchDB instance - esdb = ElasticsearchDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=1024) - self.assertIsInstance(esdb, ElasticsearchDB) + ElasticsearchDB(embedding_fn=Mock(), es_config=self.es_config, 
vector_dim=None) - @patch("os.getenv") # Mock the os.getenv function to return None for ES_ENDPOINT - def test_init_with_missing_endpoint(self, mock_os_getenv): - # Test if an exception is raised when ES_ENDPOINT is missing - mock_os_getenv.return_value = None + def test_init_with_invalid_collection_name(self): + # Test if an exception is raised when an invalid collection_name is provided with self.assertRaises(ValueError): - ElasticsearchDB(embedding_fn=Mock()) + ElasticsearchDB( + embedding_fn=Mock(), es_config=self.es_config, vector_dim=self.vector_dim, collection_name=None + )