Skip to content

Commit

Permalink
Upgrade the chromadb version to 0.4.8 and open its settings configuration. (mem0ai#517)
Browse files Browse the repository at this point in the history
  • Loading branch information
mggger committed Sep 4, 2023
1 parent 433c415 commit eecdbc5
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 21 deletions.
7 changes: 5 additions & 2 deletions embedchain/config/apps/BaseAppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
db_type: VectorDatabases = None,
vector_dim: VectorDimensions = None,
es_config: ElasticsearchDBConfig = None,
chroma_settings: dict = {},
):
"""
:param log_level: Optional. (String) Debug level
Expand All @@ -38,6 +39,7 @@ def __init__(
:param db_type: Optional. type of Vector database to use
:param vector_dim: Vector dimension generated by embedding fn
:param es_config: Optional. elasticsearch database config to be used for connection
:param chroma_settings: Optional. Chroma settings for connection.
"""
self._setup_logging(log_level)
self.collection_name = collection_name if collection_name else "embedchain_store"
Expand All @@ -50,13 +52,14 @@ def __init__(
vector_dim=vector_dim,
collection_name=self.collection_name,
es_config=es_config,
chroma_settings=chroma_settings,
)
self.id = id
self.collect_metrics = True if (collect_metrics is True or collect_metrics is None) else False
return

@staticmethod
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config):
def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, es_config, chroma_settings):
"""
Get db based on db_type, db with default database (`ChromaDb`)
:param db: Optional. (Vector) database to use for embeddings.
Expand Down Expand Up @@ -85,7 +88,7 @@ def get_db(db, embedding_fn, host, port, db_type, vector_dim, collection_name, e

from embedchain.vectordb.chroma_db import ChromaDB

return ChromaDB(embedding_fn=embedding_fn, host=host, port=port)
return ChromaDB(embedding_fn=embedding_fn, host=host, port=port, chroma_settings=chroma_settings)

def _setup_logging(self, debug_level):
level = logging.WARNING # Default level
Expand Down
3 changes: 3 additions & 0 deletions embedchain/config/apps/CustomAppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
collect_metrics: Optional[bool] = None,
db_type: VectorDatabases = None,
es_config: ElasticsearchDBConfig = None,
chroma_settings: dict = {},
):
"""
:param log_level: Optional. (String) Debug level
Expand All @@ -51,6 +52,7 @@ def __init__(
:param collect_metrics: Defaults to True. Send anonymous telemetry to improve embedchain.
:param db_type: Optional. type of Vector database to use.
:param es_config: Optional. elasticsearch database config to be used for connection
:param chroma_settings: Optional. Chroma settings for connection.
"""
if provider:
self.provider = provider
Expand All @@ -73,6 +75,7 @@ def __init__(
db_type=db_type,
vector_dim=CustomAppConfig.get_vector_dimension(embedding_function=embedding_fn),
es_config=es_config,
chroma_settings=chroma_settings,
)

@staticmethod
Expand Down
22 changes: 15 additions & 7 deletions embedchain/vectordb/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,31 @@
class ChromaDB(BaseVectorDB):
"""Vector database using ChromaDB."""

def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None):
def __init__(self, db_dir=None, embedding_fn=None, host=None, port=None, chroma_settings={}):
self.embedding_fn = embedding_fn

if not hasattr(embedding_fn, "__call__"):
raise ValueError("Embedding function is not a function")

self.settings = Settings()
for key, value in chroma_settings.items():
if hasattr(self.settings, key):
setattr(self.settings, key, value)

if host and port:
logging.info(f"Connecting to ChromaDB server: {host}:{port}")
self.client = chromadb.HttpClient(host=host, port=port)
self.settings.chroma_server_host = host
self.settings.chroma_server_http_port = port
self.settings.chroma_api_impl = "chromadb.api.fastapi.FastAPI"

else:
if db_dir is None:
db_dir = "db"
self.settings = Settings(anonymized_telemetry=False, allow_reset=True)
self.client = chromadb.PersistentClient(
path=db_dir,
settings=self.settings,
)

self.settings.persist_directory = db_dir
self.settings.is_persistent = True

self.client = chromadb.Client(self.settings)
super().__init__()

def _get_or_create_db(self):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ langchain = "^0.0.279"
requests = "^2.31.0"
openai = "^0.27.5"
tiktoken = "^0.4.0"
chromadb ="^0.4.2"
chromadb ="^0.4.8"
youtube-transcript-api = "^0.6.1"
beautifulsoup4 = "^4.12.2"
pypdf = "^3.11.0"
Expand Down
9 changes: 7 additions & 2 deletions tests/embedchain/test_embedchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from unittest.mock import patch

from embedchain import App
from embedchain.config import AppConfig
from embedchain.config import AppConfig, CustomAppConfig
from embedchain.models import EmbeddingFunctions, Providers


class TestChromaDbHostsLoglevel(unittest.TestCase):
Expand Down Expand Up @@ -42,7 +43,11 @@ def test_add_after_reset(self):
"""
Test if the `App` instance is correctly reconstructed after a reset.
"""
app = App()
app = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)
app.reset()

# Make sure the client is still healthy
Expand Down
58 changes: 49 additions & 9 deletions tests/vectordb/test_chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from unittest.mock import patch

from embedchain import App
from embedchain.config import AppConfig
from embedchain.config import AppConfig, CustomAppConfig
from embedchain.models import EmbeddingFunctions, Providers
from embedchain.vectordb.chroma_db import ChromaDB


Expand All @@ -21,6 +22,24 @@ def test_init_with_host_and_port(self):
self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port)

def test_init_with_basic_auth(self):
host = "test-host"
port = "1234"

chroma_auth_settings = {
"chroma_client_auth_provider": "chromadb.auth.basic.BasicAuthClientProvider",
"chroma_client_auth_credentials": "admin:admin",
}

db = ChromaDB(host=host, port=port, embedding_fn=len, chroma_settings=chroma_auth_settings)
settings = db.client.get_settings()
self.assertEqual(settings.chroma_server_host, host)
self.assertEqual(settings.chroma_server_http_port, port)
self.assertEqual(settings.chroma_client_auth_provider, chroma_auth_settings["chroma_client_auth_provider"])
self.assertEqual(
settings.chroma_client_auth_credentials, chroma_auth_settings["chroma_client_auth_credentials"]
)


# Review this test
class TestChromaDbHostsInit(unittest.TestCase):
Expand Down Expand Up @@ -68,12 +87,18 @@ def test_init_with_host_and_port(self, mock_client):


class TestChromaDbDuplicateHandling:
app_with_settings = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)

def test_duplicates_throw_warning(self, caplog):
"""
Test that adding duplicates throws a warning.
"""
# Start with a clean app
App().reset()
self.app_with_settings.reset()

app = App(config=AppConfig(collect_metrics=False))
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
Expand All @@ -88,7 +113,7 @@ def test_duplicates_collections_no_warning(self, caplog):
# NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.

# Start with a clean app
App().reset()
self.app_with_settings.reset()

app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1")
Expand All @@ -100,6 +125,12 @@ def test_duplicates_collections_no_warning(self, caplog):


class TestChromaDbCollection(unittest.TestCase):
app_with_settings = App(
CustomAppConfig(
provider=Providers.OPENAI, embedding_fn=EmbeddingFunctions.OPENAI, chroma_settings={"allow_reset": True}
)
)

def test_init_with_default_collection(self):
"""
Test if the `App` instance is initialized with the correct default collection name.
Expand Down Expand Up @@ -131,7 +162,7 @@ def test_changes_encapsulated(self):
Test that changes to one collection do not affect the other collection
"""
# Start with a clean app
App().reset()
self.app_with_settings.reset()

app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1")
Expand All @@ -157,7 +188,7 @@ def test_collections_are_persistent(self):
Test that a collection can be picked up later.
"""
# Start with a clean app
App().reset()
self.app_with_settings.reset()

app = App(config=AppConfig(collect_metrics=False))
app.set_collection("test_collection_1")
Expand All @@ -175,7 +206,7 @@ def test_parallel_collections(self):
the other app.
"""
# Start clean
App().reset()
self.app_with_settings.reset()

# Create two apps
app1 = App(AppConfig(collection_name="test_collection_1", collect_metrics=False))
Expand All @@ -201,7 +232,7 @@ def test_ids_share_collections(self):
Different ids should still share collections.
"""
# Start clean
App().reset()
self.app_with_settings.reset()

# Create two apps
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
Expand All @@ -220,11 +251,20 @@ def test_reset(self):
Resetting should hit all collections and ids.
"""
# Start clean
App().reset()
self.app_with_settings.reset()

# Create four apps.
# app1, which we are about to reset, shares a collection with one app, an id with another, and nothing with the last.
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1", collect_metrics=False))
app1 = App(
CustomAppConfig(
collection_name="one_collection",
id="new_app_id_1",
collect_metrics=False,
provider=Providers.OPENAI,
embedding_fn=EmbeddingFunctions.OPENAI,
chroma_settings={"allow_reset": True},
)
)
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2", collect_metrics=False))
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1", collect_metrics=False))
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4", collect_metrics=False))
Expand Down

0 comments on commit eecdbc5

Please sign in to comment.