feat: added support for elasticsearch as a datasource #402

Merged
merged 4 commits into from
Aug 11, 2023
fix: Using ElasticsearchDBConfig as es db config, updated documentation
pc9 committed Aug 10, 2023
commit 45e071d0ce2a26b8f254b50320476bca5484228d
42 changes: 0 additions & 42 deletions docs/advanced/datasource.mdx
Contributor

Overall, I think we have to talk about the name of this section. Why not call it what it is, Vector Database? The sections are also not clear to me. It should probably be:

<h2>Vector Database</h2>
<h3>ChromaDb</h3>
<h3>Elasticsearch</h3>

But saying that "Chromadb" is the default and then jumping straight to an Elasticsearch example might be confusing.

Contributor Author

@pc9 pc9 Aug 10, 2023

Agreed. I have added the Vector Database and Elasticsearch headings. I was unsure what to add under ChromaDb, so I have skipped it.

This file was deleted.

34 changes: 34 additions & 0 deletions docs/advanced/vector_database.mdx
@@ -0,0 +1,34 @@
---
title: '💾 Vector Database'
---

We support two vector databases: `Chroma` and `Elasticsearch`.
`Chroma` is the default database.

### Elasticsearch
To use `Elasticsearch` as the vector database, use the `CustomApp` app type.
```python
import os
from embedchain import CustomApp
from embedchain.config import CustomAppConfig, ElasticsearchDBConfig
from embedchain.models import Providers, EmbeddingFunctions, VectorDatabases

os.environ["OPENAI_API_KEY"] = 'OPENAI_API_KEY'

es_config = ElasticsearchDBConfig(
    # elasticsearch url or list of nodes url with different hosts and ports.
    es_url='https://localhost:9200',
    # pass named parameters supported by Python Elasticsearch client
    ca_certs="/path/to/http_ca.crt",
    basic_auth=("username", "password")
)
config = CustomAppConfig(
    embedding_fn=EmbeddingFunctions.OPENAI,
    provider=Providers.OPENAI,
    db_type=VectorDatabases.ELASTICSEARCH,
    es_config=es_config,
)
es_app = CustomApp(config)
```
- Set `db_type=VectorDatabases.ELASTICSEARCH` and `es_config=ElasticsearchDBConfig(es_url='')` in `CustomAppConfig`.
- `ElasticsearchDBConfig` accepts `es_url` as an Elasticsearch URL or as a list of node URLs with different hosts and ports. We can also pass any named parameters supported by the Python Elasticsearch client.
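
As an illustration of the two bullets above, here is a minimal sketch of how a config object of this shape can capture extra keyword arguments and forward them to a client constructor. `SketchESConfig` and `RecordingClient` are hypothetical stand-ins (not embedchain or Elasticsearch classes), and the `node1`/`node2` hostnames are made up; only the call shape `Client(cfg.ES_URL, **cfg.ES_EXTRA_PARAMS)` is taken from this PR.

```python
from typing import Any, Dict, List, Union


class SketchESConfig:
    """Hypothetical mirror of ElasticsearchDBConfig, for illustration only."""

    def __init__(self, es_url: Union[str, List[str], None] = None, **extra: Any):
        self.ES_URL = es_url
        # Every other keyword argument is kept untouched for the client.
        self.ES_EXTRA_PARAMS: Dict[str, Any] = extra


class RecordingClient:
    """Stand-in for a client class; it only records how it was constructed."""

    def __init__(self, hosts, **kwargs):
        self.hosts = hosts
        self.kwargs = kwargs


cfg = SketchESConfig(
    es_url=["https://node1:9200", "https://node2:9200"],  # assumed hostnames
    ca_certs="/path/to/http_ca.crt",
    basic_auth=("username", "password"),
)

# Same call shape as in the PR: the URL(s) go in positionally,
# the remaining named parameters are unpacked as keyword arguments.
client = RecordingClient(cfg.ES_URL, **cfg.ES_EXTRA_PARAMS)
```

Because the config stores the extra parameters without inspecting them, a misspelled parameter name would presumably only surface when the real client is constructed.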
2 changes: 1 addition & 1 deletion docs/mint.json
@@ -32,7 +32,7 @@
 },
 {
 "group": "Advanced",
-"pages": ["advanced/app_types", "advanced/interface_types", "advanced/adding_data","advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/datasource", "advanced/showcase"]
+"pages": ["advanced/app_types", "advanced/interface_types", "advanced/adding_data","advanced/data_types", "advanced/query_configuration", "advanced/configuration", "advanced/testing", "advanced/vector_database", "advanced/showcase"]
 },
 {
 "group": "Examples",
1 change: 1 addition & 0 deletions embedchain/config/__init__.py
@@ -5,3 +5,4 @@
 from .BaseConfig import BaseConfig # noqa: F401
 from .ChatConfig import ChatConfig # noqa: F401
 from .QueryConfig import QueryConfig # noqa: F401
+from .vectordbs.ElasticsearchDBConfig import ElasticsearchDBConfig # noqa: F401
11 changes: 9 additions & 2 deletions embedchain/config/apps/BaseAppConfig.py
@@ -1,6 +1,7 @@
 import logging

 from embedchain.config.BaseConfig import BaseConfig
+from embedchain.config.vectordbs import ElasticsearchDBConfig
 from embedchain.models import VectorDatabases, VectorDimensions


@@ -20,6 +21,7 @@ def __init__(
         collection_name=None,
         db_type: VectorDatabases = None,
         vector_dim: VectorDimensions = None,
+        es_config: ElasticsearchDBConfig = None,
     ):
         """
         :param log_level: Optional. (String) Debug level
@@ -32,6 +34,7 @@ def __init__(
         :param collection_name: Optional. Collection name for the database.
         :param db_type: Optional. type of Vector database to use
         :param vector_dim: Vector dimension generated by embedding fn
Contributor

I think we should raise an error if vector_dim is used with Chroma, in order not to confuse anyone, because it does nothing for Chroma (at this point).

Similar handling is done here: https://github.com/embedchain/embedchain/blob/5e94980aaa801843661ccd18a16f46ed8c28a871/embedchain/apps/CustomApp.py#L84

Member

I think we can skip this part. Right now vector_dim is not used with Chroma, since it computes the dimension itself, but we may need it in the future.
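
For reference, the check proposed in this thread could look roughly like the sketch below. It is only an illustration of the idea and was not part of the PR: `SketchVectorDatabases` is a hypothetical stand-in for `embedchain.models.VectorDatabases` (the values `es` and `chroma` come from the `get_db` docstring), and `check_vector_dim` with its `strict` flag is an invented helper.

```python
import warnings
from enum import Enum


class SketchVectorDatabases(Enum):
    # Hypothetical stand-in for embedchain.models.VectorDatabases
    CHROMA = "chroma"
    ELASTICSEARCH = "es"


def check_vector_dim(db_type, vector_dim, strict=True):
    """Reject (or warn about) vector_dim when the backend ignores it."""
    # Assumption: only the Elasticsearch backend consumes vector_dim.
    uses_dim = db_type == SketchVectorDatabases.ELASTICSEARCH
    if vector_dim is not None and not uses_dim:
        message = "vector_dim has no effect for this backend"
        if strict:
            raise ValueError(message)
        warnings.warn(message)


# Elasticsearch with a dimension is accepted silently.
check_vector_dim(SketchVectorDatabases.ELASTICSEARCH, 384)
# Chroma without a dimension is accepted silently as well.
check_vector_dim(SketchVectorDatabases.CHROMA, None)
```

One caveat: `CustomAppConfig` in this diff always passes a computed `vector_dim`, so a strict check of this kind would have to tell user-supplied values apart from computed ones, which may be part of why skipping it was preferred.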

+        :param es_config: Optional. elasticsearch database config to be used for connection
         """
         self._setup_logging(log_level)
         self.collection_name = collection_name if collection_name else "embedchain_store"
@@ -43,12 +46,13 @@ def __init__(
             db_type=db_type,
             vector_dim=vector_dim,
             collection_name=self.collection_name,
+            es_config=es_config,
         )
         self.id = id
         return

     @staticmethod
-    def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name):
+    def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config):
         """
         Get db based on db_type, db with default database (`ChromaDb`)
         :param Optional. (Vector) database to use for embeddings.
@@ -58,6 +62,7 @@ def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name):
         :param db_type: Optional. db type to use. Supported values (`es`, `chroma`)
         :param vector_dim: Vector dimension generated by embedding fn
         :param collection_name: Optional. Collection name for the database.
+        :param es_config: Optional. elasticsearch database config to be used for connection
         :raises ValueError: BaseAppConfig knows no default embedding function.
         :returns: database instance
         """
@@ -70,7 +75,9 @@ def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name):
         if db_type == VectorDatabases.ELASTICSEARCH:
             from embedchain.vectordb.elasticsearch_db import ElasticsearchDB

-            return ElasticsearchDB(embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name)
+            return ElasticsearchDB(
+                embedding_fn=embedding_fn, vector_dim=vector_dim, collection_name=collection_name, es_config=es_config
+            )

         from embedchain.vectordb.chroma_db import ChromaDB

4 changes: 4 additions & 0 deletions embedchain/config/apps/CustomAppConfig.py
@@ -3,6 +3,7 @@
 from chromadb.api.types import Documents, Embeddings
 from dotenv import load_dotenv

+from embedchain.config.vectordbs import ElasticsearchDBConfig
 from embedchain.models import EmbeddingFunctions, Providers, VectorDatabases, VectorDimensions

 from .BaseAppConfig import BaseAppConfig
@@ -29,6 +30,7 @@ def __init__(
         open_source_app_config=None,
         deployment_name=None,
         db_type: VectorDatabases = None,
+        es_config: ElasticsearchDBConfig = None,
     ):
         """
         :param log_level: Optional. (String) Debug level
@@ -43,6 +45,7 @@ def __init__(
         :param provider: Optional. (Providers): LLM Provider to use.
         :param open_source_app_config: Optional. Config instance needed for open source apps.
         :param db_type: Optional. type of Vector database to use.
+        :param es_config: Optional. elasticsearch database config to be used for connection
         """
         if provider:
             self.provider = provider
@@ -63,6 +66,7 @@ def __init__(
             collection_name=collection_name,
             db_type=db_type,
             vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
+            es_config=es_config,
         )

     @staticmethod
15 changes: 15 additions & 0 deletions embedchain/config/vectordbs/ElasticsearchDBConfig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Dict, List, Union

from embedchain.config.BaseConfig import BaseConfig


class ElasticsearchDBConfig(BaseConfig):
"""
Config to initialize an elasticsearch client.
:param es_url. elasticsearch url or list of nodes url to be used for connection
:param ES_EXTRA_PARAMS: extra params dict that can be passed to elasticsearch.
"""

def __init__(self, es_url: Union[str, List[str]] = None, **ES_EXTRA_PARAMS: Dict[str, any]):
self.ES_URL = es_url
self.ES_EXTRA_PARAMS = ES_EXTRA_PARAMS
Empty file.
23 changes: 10 additions & 13 deletions embedchain/vectordb/elasticsearch_db.py
@@ -1,5 +1,4 @@
-import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List

 try:
     from elasticsearch import Elasticsearch
@@ -9,39 +8,37 @@
         "Elasticsearch requires extra dependencies. Install with `pip install embedchain[elasticsearch]`"
     ) from None

+from embedchain.config import ElasticsearchDBConfig
 from embedchain.models.VectorDimensions import VectorDimensions
 from embedchain.vectordb.base_vector_db import BaseVectorDB


 class ElasticsearchDB(BaseVectorDB):
     def __init__(
         self,
+        es_config: ElasticsearchDBConfig = None,
         embedding_fn: Callable[[list[str]], list[str]] = None,
-        es_client: Optional[Elasticsearch] = None,
         vector_dim: VectorDimensions = None,
         collection_name: str = None,
     ):
         """
         Elasticsearch as vector database
+        :param es_config. elasticsearch database config to be used for connection
         :param embedding_fn: Function to generate embedding vectors.
         :param vector_dim: Vector dimension generated by embedding fn
         :param collection_name: Optional. Collection name for the database.
         """
         if not hasattr(embedding_fn, "__call__"):
             raise ValueError("Embedding function is not a function")
-        self.embedding_fn = embedding_fn
-        endpoint = os.getenv("ES_ENDPOINT")
-        api_key_id = os.getenv("ES_API_KEY_ID")
-        api_key = os.getenv("ES_API_KEY")
-        api_key_id = api_key_id if api_key_id is not None else ""
-        api_key = api_key if api_key is not None else ""
-        if not endpoint and not es_client:
-            raise ValueError("Elasticsearch endpoint is required to connect")
+        if es_config is None:
+            raise ValueError("ElasticsearchDBConfig is required")
         if vector_dim is None:
             raise ValueError("Vector Dimension is required to refer correct index and mapping")
-        self.client = es_client if es_client is not None else Elasticsearch(endpoint, api_key=(api_key_id, api_key))
+        if collection_name is None:
+            raise ValueError("collection name is required. It cannot be empty")
+        self.embedding_fn = embedding_fn
+        self.client = Elasticsearch(es_config.ES_URL, **es_config.ES_EXTRA_PARAMS)
         self.vector_dim = vector_dim
-        # self.collection_name = collection_name if collection_name else "embedchain_store"
         self.es_index = f"{collection_name}_{self.vector_dim}"
         index_settings = {
             "mappings": {
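
The constructor in this hunk validates its arguments in a fixed order and then derives the index name from the collection name and the vector dimension. The sketch below isolates just that logic so it can be run without an Elasticsearch cluster. `resolve_index_name` is a hypothetical helper, not part of embedchain, and it assumes `vector_dim` formats as a plain integer; `callable()` is used in place of the `hasattr(..., "__call__")` check.

```python
from typing import Callable, Optional


def resolve_index_name(
    embedding_fn: Optional[Callable],
    es_config: Optional[object],
    vector_dim: Optional[int],
    collection_name: Optional[str],
) -> str:
    """Run the same checks, in the same order, then build the index name."""
    if not callable(embedding_fn):
        raise ValueError("Embedding function is not a function")
    if es_config is None:
        raise ValueError("ElasticsearchDBConfig is required")
    if vector_dim is None:
        raise ValueError("Vector Dimension is required to refer correct index and mapping")
    if collection_name is None:
        raise ValueError("collection name is required. It cannot be empty")
    # One index per (collection, dimension) pair, so embeddings of
    # different sizes never share a mapping.
    return f"{collection_name}_{vector_dim}"


index = resolve_index_name(len, object(), 384, "embedchain_store")
```

Note that the `is None` test lets an empty string through, so a collection name of `""` would yield an index name such as `_384` even though the error message says the name cannot be empty.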
30 changes: 15 additions & 15 deletions tests/vectordb/test_elasticsearch_db.py
@@ -1,33 +1,33 @@
 import unittest
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import Mock

+from embedchain.config import ElasticsearchDBConfig
 from embedchain.vectordb.elasticsearch_db import ElasticsearchDB


 class TestEsDB(unittest.TestCase):
Contributor

I know this isn't a helpful comment, but maybe more positive tests wouldn't hurt. Like testing add, get, reset, and not just testing illegal methods.

Member

yes we should have both positive and negative tests

Contributor Author

Need help here: I need to figure out how to mock the Elasticsearch client so that both positive and negative test cases can run successfully.

Member

@pc9 : can you open a new issue for this?

Contributor Author

sure will do that.
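
One possible direction for the mocking question raised in this thread, sketched with `unittest.mock.MagicMock` only. `TinyIndexStore` is a hypothetical stand-in, not embedchain's `ElasticsearchDB`, and it takes the client class as an argument so the example runs without Elasticsearch installed. In the real test suite one would presumably patch the `Elasticsearch` name imported in `embedchain/vectordb/elasticsearch_db.py` instead; that patch target is an assumption and is not exercised here.

```python
from unittest.mock import MagicMock


class TinyIndexStore:
    """Hypothetical stand-in for a vector DB wrapper around a client."""

    def __init__(self, client_factory, url, index):
        # The factory plays the role of the `Elasticsearch` class.
        self.client = client_factory(url)
        self.index = index
        if not self.client.indices.exists(index=index):
            self.client.indices.create(index=index)

    def count(self):
        # Assumes a count response shaped like {"count": <int>}.
        return self.client.count(index=self.index)["count"]


def make_mock_client_factory(index_exists, doc_count):
    """Build a mock 'client class' plus the client instance it returns."""
    client = MagicMock()
    client.indices.exists.return_value = index_exists
    client.count.return_value = {"count": doc_count}
    factory = MagicMock(return_value=client)
    return factory, client


factory, client = make_mock_client_factory(index_exists=False, doc_count=3)
store = TinyIndexStore(factory, "https://localhost:9200", "embedchain_store_384")
```

With the mock in place, a positive test can assert both on the calls made to the client (for example that the index was created) and on values derived from its canned responses.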

     def setUp(self):
-        # set mock es client
-        self.mock_client = MagicMock()
-        self.mock_client.indices.exists.return_value = True
+        self.es_config = ElasticsearchDBConfig()
+        self.vector_dim = 384

     def test_init_with_invalid_embedding_fn(self):
         # Test if an exception is raised when an invalid embedding_fn is provided
         with self.assertRaises(ValueError):
             ElasticsearchDB(embedding_fn=None)

+    def test_init_with_invalid_es_config(self):
+        # Test if an exception is raised when an invalid es_config is provided
+        with self.assertRaises(ValueError):
+            ElasticsearchDB(embedding_fn=Mock(), es_config=None)
+
     def test_init_with_invalid_vector_dim(self):
         # Test if an exception is raised when an invalid vector_dim is provided
         with self.assertRaises(ValueError):
-            ElasticsearchDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=None)
-
-    def test_init_with_valid_embedding_and_client(self):
-        # check for successful creation of ElasticsearchDB instance
-        esdb = ElasticsearchDB(embedding_fn=Mock(), es_client=self.mock_client, vector_dim=1024)
-        self.assertIsInstance(esdb, ElasticsearchDB)
+            ElasticsearchDB(embedding_fn=Mock(), es_config=self.es_config, vector_dim=None)

-    @patch("os.getenv") # Mock the os.getenv function to return None for ES_ENDPOINT
-    def test_init_with_missing_endpoint(self, mock_os_getenv):
-        # Test if an exception is raised when ES_ENDPOINT is missing
-        mock_os_getenv.return_value = None
+    def test_init_with_invalid_collection_name(self):
+        # Test if an exception is raised when an invalid collection_name is provided
         with self.assertRaises(ValueError):
-            ElasticsearchDB(embedding_fn=Mock())
+            ElasticsearchDB(
+                embedding_fn=Mock(), es_config=self.es_config, vector_dim=self.vector_dim, collection_name=None
+            )