Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: collection name everywhere #310

Merged
merged 19 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2a40d31
feat: allow switching between collections and specifying initial coll…
jonasiwnl Jul 18, 2023
09ec899
Merge branch 'feat/collection-name-everywhere' of https://github.com/…
jonasiwnl Jul 18, 2023
abe9f60
refactor: make collection name everywhere feature compatible with Bas…
jonasiwnl Jul 18, 2023
644c5a1
Merge pull request #1 from jonasiwnl/main
jonasiwnl Jul 18, 2023
201ac5d
chore: remove sentence-transformers dependency
jonasiwnl Jul 18, 2023
5bbcfbc
chore: revert changes to pyproject.toml
jonasiwnl Jul 18, 2023
52f1891
Merge branch 'main' into feat/collection-name-everywhere
cachho Jul 18, 2023
8257bf5
chore: fix merge errors, reformat, remove unused imports
jonasiwnl Jul 18, 2023
1f6b22a
Merge branch 'main' into feat/collection-name-everywhere
jonasiwnl Jul 20, 2023
8d6ec4d
chore: add default collection name to BaseAppConfig
jonasiwnl Jul 28, 2023
ce62702
Merge branch 'main' into feat/collection-name-everywhere
jonasiwnl Jul 28, 2023
ded8ab6
chore: remove default collection name redundancy
jonasiwnl Jul 28, 2023
f69966f
Merge branch 'feat/collection-name-everywhere' of https://github.com/…
jonasiwnl Jul 28, 2023
c8edfed
chore: change collection name to old default
cachho Aug 2, 2023
fee6f94
chore: old default name for compatibility
cachho Aug 2, 2023
cb04be0
refactor: unified argument order
cachho Aug 2, 2023
4dcb811
docs: fix param name
cachho Aug 2, 2023
7d2dfa0
chore: use old default collection name
cachho Aug 2, 2023
1b7b461
test: added unit tests
cachho Aug 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/advanced/configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ from chromadb.utils import embedding_functions
config = AppConfig(log_level="DEBUG")
naval_chat_bot = App(config)

# Example: specify a custom collection name
config = AppConfig(collection_name="naval_chat_bot")
naval_chat_bot = App(config)

# Example: define your own chunker config for `youtube_video`
chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len)
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config))
Expand Down
9 changes: 5 additions & 4 deletions docs/advanced/query_configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ title: '🔍 Query configurations'

## AppConfig

| option | description | type | default |
|-------------|-----------------------|---------------------------------|------------------------|
| log_level | log level | string | WARNING |
| option | description | type | default |
|-----------|-----------------------|---------------------------------|------------------------|
| log_level | log level | string | WARNING |
| embedding_fn| embedding function | chromadb.utils.embedding_functions | \{text-embedding-ada-002\} |
| db | vector database (experimental) | BaseVectorDB | ChromaDB |
| db | vector database (experimental) | BaseVectorDB | ChromaDB |
| collection_name | initial collection name for the database | string | embedchain_store |


## AddConfig
Expand Down
10 changes: 8 additions & 2 deletions embedchain/config/apps/AppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,22 @@ class AppConfig(BaseAppConfig):
Config to initialize an embedchain custom `App` instance, with extra config options.
"""

def __init__(self, log_level=None, host=None, port=None, id=None):
def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
"""
super().__init__(
log_level=log_level, embedding_fn=AppConfig.default_embedding_function(), host=host, port=port, id=id
log_level=log_level,
embedding_fn=AppConfig.default_embedding_function(),
host=host,
port=port,
id=id,
collection_name=collection_name,
)

@staticmethod
Expand Down
6 changes: 4 additions & 2 deletions embedchain/config/apps/BaseAppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@ class BaseAppConfig(BaseConfig):
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
"""

def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None):
def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None, collection_name=None):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param embedding_fn: Embedding function to use.
:param db: Optional. (Vector) database instance to use for embeddings.
:param id: Optional. ID of the app. Document metadata will have this id.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
"""
self._setup_logging(log_level)

self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port)
self.collection_name = collection_name if collection_name else "embedchain_store"
self.id = id
return

Expand Down
6 changes: 4 additions & 2 deletions embedchain/config/apps/CustomAppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def __init__(
host=None,
port=None,
id=None,
collection_name=None,
provider: Providers = None,
model=None,
open_source_app_config=None,
):
cachho marked this conversation as resolved.
Show resolved Hide resolved
"""
Expand All @@ -34,9 +34,10 @@ def __init__(
:param embedding_fn: Optional. Embedding function to use.
:param embedding_fn_model: Optional. Model name to use for embedding function.
:param db: Optional. (Vector) database to use for embeddings.
:param id: Optional. ID of the app. Document metadata will have this id.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param provider: Optional. (Providers): LLM Provider to use.
:param open_source_app_config: Optional. Config instance needed for open source apps.
"""
Expand All @@ -54,6 +55,7 @@ def __init__(
host=host,
port=port,
id=id,
collection_name=collection_name,
)

@staticmethod
Expand Down
4 changes: 3 additions & 1 deletion embedchain/config/apps/OpenSourceAppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ class OpenSourceAppConfig(BaseAppConfig):
Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
"""

def __init__(self, log_level=None, host=None, port=None, id=None, model=None):
def __init__(self, log_level=None, host=None, port=None, id=None, collection_name=None, model=None):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param model: Optional. GPT4ALL uses the model to instantiate the class.
Expand All @@ -26,6 +27,7 @@ def __init__(self, log_level=None, host=None, port=None, id=None, model=None):
host=host,
port=port,
id=id,
collection_name=collection_name,
)

@staticmethod
Expand Down
10 changes: 9 additions & 1 deletion embedchain/embedchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, config: BaseAppConfig):

self.config = config
self.db_client = self.config.db.client
self.collection = self.config.db.collection
self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
self.user_asks = []
self.is_docs_site_instance = False
self.online = False
Expand Down Expand Up @@ -325,6 +325,14 @@ def _stream_chat_response(self, answer):
memory.chat_memory.add_ai_message(streamed_answer)
logging.info(f"Answer: {streamed_answer}")

def set_collection(self, collection_name):
"""
Set the collection to use.

:param collection_name: The name of the collection to use.
"""
self.collection = self.config.db._get_or_create_collection(collection_name)

def count(self):
"""
Count the number of embeddings.
Expand Down
1 change: 0 additions & 1 deletion embedchain/vectordb/base_vector_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ class BaseVectorDB:

def __init__(self):
self.client = self._get_or_create_db()
self.collection = self._get_or_create_collection()

def _get_or_create_db(self):
"""Get or create the database."""
Expand Down
4 changes: 2 additions & 2 deletions embedchain/vectordb/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def _get_or_create_db(self):
"""Get or create the database."""
return self.client

def _get_or_create_collection(self):
def _get_or_create_collection(self, name):
"""Get or create the collection."""
return self.client.get_or_create_collection(
"embedchain_store",
name=name,
embedding_function=self.embedding_fn,
)
183 changes: 183 additions & 0 deletions tests/vectordb/test_chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,186 @@ def test_init_with_host_and_port(self, mock_client):

self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)

class TestChromaDbDuplicateHandling:
def test_duplicates_throw_warning(self, caplog):
"""
Test that add duplicates throws an error.
"""
# Start with a clean app
App().reset()

app = App()
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" in caplog.text
assert "Add of existing embedding ID: 0" in caplog.text

def test_duplicates_collections_no_warning(self, caplog):
"""
Test that different collections can have duplicates.
"""
# NOTE: Not part of the TestChromaDbCollection because `unittest.TestCase` doesn't have caplog.

# Start with a clean app
App().reset()

app = App()
app.set_collection("test_collection_1")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
app.set_collection("test_collection_2")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
assert "Insert of existing embedding ID: 0" not in caplog.text # not
assert "Add of existing embedding ID: 0" not in caplog.text # not


class TestChromaDbCollection(unittest.TestCase):
def test_init_with_default_collection(self):
"""
Test if the `App` instance is initialized with the correct default collection name.
"""
app = App()

self.assertEqual(app.collection.name, "embedchain_store")

def test_init_with_custom_collection(self):
"""
Test if the `App` instance is initialized with the correct custom collection name.
"""
config = AppConfig(collection_name="test_collection")
app = App(config)

self.assertEqual(app.collection.name, "test_collection")

def test_set_collection(self):
"""
Test if the `App` collection is correctly switched using the `set_collection` method.
"""
app = App()
app.set_collection("test_collection")

self.assertEqual(app.collection.name, "test_collection")

def test_changes_encapsulated(self):
"""
Test that changes to one collection do not affect the other collection
"""
# Start with a clean app
App().reset()

app = App()
app.set_collection("test_collection_1")
# Collection should be empty when created
self.assertEqual(app.count(), 0)

app.collection.add(embeddings=[0, 0, 0], ids=["0"])
# After adding, should contain one item
self.assertEqual(app.count(), 1)

app.set_collection("test_collection_2")
# New collection is empty
self.assertEqual(app.count(), 0)

# Adding to new collection should not effect existing collection
app.collection.add(embeddings=[0, 0, 0], ids=["0"])
app.set_collection("test_collection_1")
# Should still be 1, not 2.
self.assertEqual(app.count(), 1)

def test_collections_are_persistent(self):
"""
Test that a collection can be picked up later.
"""
# Start with a clean app
App().reset()

app = App()
app.set_collection("test_collection_1")
app.collection.add(embeddings=[[0, 0, 0]], ids=["0"])
del app

app = App()
app.set_collection("test_collection_1")
self.assertEqual(app.count(), 1)

def test_parallel_collections(self):
"""
Test that two apps can have different collections open in parallel.
Switching the names will allow instant access to the collection of
the other app.
"""
# Start clean
App().reset()

# Create two apps
app1 = App(AppConfig(collection_name="test_collection_1"))
app2 = App(AppConfig(collection_name="test_collection_2"))

# app2 has been created last, but adding to app1 will still write to collection 1.
app1.collection.add(embeddings=[0, 0, 0], ids=["0"])
self.assertEqual(app1.count(), 1)
self.assertEqual(app2.count(), 0)

# Add data
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["1", "2"])
app2.collection.add(embeddings=[0, 0, 0], ids=["0"])

# Swap names and test
app1.set_collection('test_collection_2')
self.assertEqual(app1.count(), 1)
app2.set_collection('test_collection_1')
self.assertEqual(app2.count(), 3)

def test_ids_share_collections(self):
"""
Different ids should still share collections.
"""
# Start clean
App().reset()

# Create two apps
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1"))
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2"))

# Add data
app1.collection.add(embeddings=[[0, 0, 0], [1, 1, 1]], ids=["0", "1"])
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])

# Both should have the same collection
self.assertEqual(app1.count(), 3)
self.assertEqual(app2.count(), 3)

def test_reset(self):
"""
Resetting should hit all collections and ids.
"""
# Start clean
App().reset()

# Create four apps.
# app1, which we are about to reset, shares an app with one, and an id with the other, none with the last.
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1"))
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2"))
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_1"))
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_4"))

# Each one of them get data
app1.collection.add(embeddings=[0, 0, 0], ids=["1"])
app2.collection.add(embeddings=[0, 0, 0], ids=["2"])
app3.collection.add(embeddings=[0, 0, 0], ids=["3"])
app4.collection.add(embeddings=[0, 0, 0], ids=["4"])

# Resetting the first one should reset them all.
app1.reset()

# Reinstantiate them
app1 = App(AppConfig(collection_name="one_collection", id="new_app_id_1"))
app2 = App(AppConfig(collection_name="one_collection", id="new_app_id_2"))
app3 = App(AppConfig(collection_name="three_collection", id="new_app_id_3"))
app4 = App(AppConfig(collection_name="four_collection", id="new_app_id_3"))

# All should be empty
self.assertEqual(app1.count(), 0)
self.assertEqual(app2.count(), 0)
self.assertEqual(app3.count(), 0)
self.assertEqual(app4.count(), 0)