Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: collection name everywhere #310

Merged
merged 19 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2a40d31
feat: allow switching between collections and specifying initial coll…
jonasiwnl Jul 18, 2023
09ec899
Merge branch 'feat/collection-name-everywhere' of https://github.com/…
jonasiwnl Jul 18, 2023
abe9f60
refactor: make collection name everywhere feature compatible with Bas…
jonasiwnl Jul 18, 2023
644c5a1
Merge pull request #1 from jonasiwnl/main
jonasiwnl Jul 18, 2023
201ac5d
chore: remove sentence-transformers dependency
jonasiwnl Jul 18, 2023
5bbcfbc
chore: revert changes to pyproject.toml
jonasiwnl Jul 18, 2023
52f1891
Merge branch 'main' into feat/collection-name-everywhere
cachho Jul 18, 2023
8257bf5
chore: fix merge errors, reformat, remove unused imports
jonasiwnl Jul 18, 2023
1f6b22a
Merge branch 'main' into feat/collection-name-everywhere
jonasiwnl Jul 20, 2023
8d6ec4d
chore: add default collection name to BaseAppConfig
jonasiwnl Jul 28, 2023
ce62702
Merge branch 'main' into feat/collection-name-everywhere
jonasiwnl Jul 28, 2023
ded8ab6
chore: remove default collection name redundancy
jonasiwnl Jul 28, 2023
f69966f
Merge branch 'feat/collection-name-everywhere' of https://github.com/…
jonasiwnl Jul 28, 2023
c8edfed
chore: change collection name to old default
cachho Aug 2, 2023
fee6f94
chore: old default name for compatibility
cachho Aug 2, 2023
cb04be0
refactor: unified argument order
cachho Aug 2, 2023
4dcb811
docs: fix param name
cachho Aug 2, 2023
7d2dfa0
chore: use old default collection name
cachho Aug 2, 2023
1b7b461
test: added unit tests
cachho Aug 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/advanced/configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ from chromadb.utils import embedding_functions
config = AppConfig(log_level="DEBUG")
naval_chat_bot = App(config)

# Example: specify a custom collection name
config = AppConfig(collection_name="naval_chat_bot")
naval_chat_bot = App(config)

# Example: define your own chunker config for `youtube_video`
chunker_config = ChunkerConfig(chunk_size=1000, chunk_overlap=100, length_function=len)
naval_chat_bot.add("youtube_video", "https://www.youtube.com/watch?v=3qHkcs3kG44", AddConfig(chunker=chunker_config))
Expand Down
9 changes: 5 additions & 4 deletions docs/advanced/query_configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ title: '🔍 Query configurations'

## AppConfig

| option | description | type | default |
|-------------|-----------------------|---------------------------------|------------------------|
| log_level | log level | string | WARNING |
| option | description | type | default |
|-----------|-----------------------|---------------------------------|------------------------|
| log_level | log level | string | WARNING |
| embedding_fn| embedding function | chromadb.utils.embedding_functions | \{text-embedding-ada-002\} |
| db | vector database (experimental) | BaseVectorDB | ChromaDB |
| db | vector database (experimental) | BaseVectorDB | ChromaDB |
| collection_name | initial collection name for the database | string | embedchain_store |


## AddConfig
Expand Down
10 changes: 8 additions & 2 deletions embedchain/config/apps/AppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,22 @@ class AppConfig(BaseAppConfig):
Config to initialize an embedchain custom `App` instance, with extra config options.
"""

def __init__(self, log_level=None, host=None, port=None, id=None):
def __init__(self, log_level=None, collection_name=None, host=None, port=None, id=None):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
"""
super().__init__(
log_level=log_level, embedding_fn=AppConfig.default_embedding_function(), host=host, port=port, id=id
log_level=log_level,
collection_name=collection_name,
embedding_fn=AppConfig.default_embedding_function(),
host=host,
port=port,
id=id,
)

@staticmethod
Expand Down
4 changes: 3 additions & 1 deletion embedchain/config/apps/BaseAppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@ class BaseAppConfig(BaseConfig):
Parent config to initialize an instance of `App`, `OpenSourceApp` or `CustomApp`.
"""

def __init__(self, log_level=None, embedding_fn=None, db=None, host=None, port=None, id=None):
def __init__(self, log_level=None, embedding_fn=None, db=None, collection_name=None, host=None, port=None, id=None):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param embedding_fn: Embedding function to use.
:param db: Optional. (Vector) database instance to use for embeddings.
:param collection_name: Optional. Collection name for the database.
:param id: Optional. ID of the app. Document metadata will have this id.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
"""
self._setup_logging(log_level)

self.db = db if db else BaseAppConfig.default_db(embedding_fn=embedding_fn, host=host, port=port)
self.collection_name = collection_name
cachho marked this conversation as resolved.
Show resolved Hide resolved
self.id = id
return

Expand Down
3 changes: 3 additions & 0 deletions embedchain/config/apps/CustomAppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
embedding_fn: EmbeddingFunctions = None,
embedding_fn_model=None,
db=None,
collection_name=None,
host=None,
port=None,
id=None,
Expand All @@ -34,6 +35,7 @@ def __init__(
:param embedding_fn: Optional. Embedding function to use.
:param embedding_fn_model: Optional. Model name to use for embedding function.
:param db: Optional. (Vector) database to use for embeddings.
:param collection_name: Optional. Collection name for the database.
:param id: Optional. ID of the app. Document metadata will have this id.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
Expand All @@ -51,6 +53,7 @@ def __init__(
log_level=log_level,
embedding_fn=CustomAppConfig.embedding_function(embedding_function=embedding_fn, model=embedding_fn_model),
db=db,
collection_name=collection_name,
host=host,
port=port,
id=id,
Expand Down
4 changes: 3 additions & 1 deletion embedchain/config/apps/OpenSourceAppConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ class OpenSourceAppConfig(BaseAppConfig):
Config to initialize an embedchain custom `OpenSourceApp` instance, with extra config options.
"""

def __init__(self, log_level=None, host=None, port=None, id=None, model=None):
def __init__(self, log_level=None, collection_name=None, host=None, port=None, id=None, model=None):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param id: Optional. ID of the app. Document metadata will have this id.
:param collection_name: Optional. Collection name for the database.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
:param model: Optional. GPT4ALL uses the model to instantiate the class.
Expand All @@ -23,6 +24,7 @@ def __init__(self, log_level=None, host=None, port=None, id=None, model=None):
super().__init__(
log_level=log_level,
embedding_fn=OpenSourceAppConfig.default_embedding_function(),
collection_name=collection_name,
host=host,
port=port,
id=id,
Expand Down
10 changes: 9 additions & 1 deletion embedchain/embedchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self, config: BaseAppConfig):

self.config = config
self.db_client = self.config.db.client
self.collection = self.config.db.collection
self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
self.user_asks = []
self.is_docs_site_instance = False
self.online = False
Expand Down Expand Up @@ -322,6 +322,14 @@ def dry_run(self, input_query, config: QueryConfig = None):
logging.info(f"Prompt: {prompt}")
return prompt

def set_collection(self, collection_name):
    """
    Set the collection to use.

    Replaces `self.collection` with whatever the configured vector
    database returns for the given name.

    :param collection_name: The name of the collection to use.
    """
    # Lookup (and creation, if the database wrapper implements it that way)
    # is delegated to the database object held by the app config.
    self.collection = self.config.db._get_or_create_collection(collection_name)

def count(self):
"""
Count the number of embeddings.
Expand Down
6 changes: 4 additions & 2 deletions embedchain/vectordb/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def _get_or_create_db(self):
"""Get or create the database."""
return chromadb.Client(self.client_settings)

def _get_or_create_collection(self):
def _get_or_create_collection(self, name=None):
cachho marked this conversation as resolved.
Show resolved Hide resolved
"""Get or create the collection."""
if name is None:
name = "embedchain_store"
return self.client.get_or_create_collection(
"embedchain_store",
name=name,
embedding_function=self.embedding_fn,
)
30 changes: 29 additions & 1 deletion tests/vectordb/test_chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest
from unittest.mock import patch

from embedchain import App
from embedchain.apps.App import App
from embedchain.config import AppConfig
from embedchain.vectordb.chroma_db import ChromaDB, chromadb

Expand Down Expand Up @@ -71,3 +71,31 @@ def test_init_with_host_and_port(self, mock_client):

self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)


class TestChromaDbCollection(unittest.TestCase):
    """Checks which collection an `App` is bound to after init and after switching."""

    def test_init_with_default_collection(self):
        """
        An `App` created without a config should use the default collection name.
        """
        default_app = App()
        actual_name = default_app.collection.name

        self.assertEqual(actual_name, "embedchain_store")

    def test_init_with_custom_collection(self):
        """
        An `App` created with `collection_name` in its config should use that name.
        """
        custom_app = App(AppConfig(collection_name="test_collection"))
        actual_name = custom_app.collection.name

        self.assertEqual(actual_name, "test_collection")

    def test_set_collection(self):
        """
        Calling `set_collection` should switch the `App` to the named collection.
        """
        switched_app = App()
        switched_app.set_collection("test_collection")
        actual_name = switched_app.collection.name

        self.assertEqual(actual_name, "test_collection")