Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: collection name everywhere #310

Merged
merged 19 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
2a40d31
feat: allow switching between collections and specifying initial coll…
jonasiwnl Jul 18, 2023
09ec899
Merge branch 'feat/collection-name-everywhere' of https://github.com/…
jonasiwnl Jul 18, 2023
abe9f60
refactor: make collection name everywhere feature compatible with Bas…
jonasiwnl Jul 18, 2023
644c5a1
Merge pull request #1 from jonasiwnl/main
jonasiwnl Jul 18, 2023
201ac5d
chore: remove sentence-transformers dependency
jonasiwnl Jul 18, 2023
5bbcfbc
chore: revert changes to pyproject.toml
jonasiwnl Jul 18, 2023
52f1891
Merge branch 'main' into feat/collection-name-everywhere
cachho Jul 18, 2023
8257bf5
chore: fix merge errors, reformat, remove unused imports
jonasiwnl Jul 18, 2023
1f6b22a
Merge branch 'main' into feat/collection-name-everywhere
jonasiwnl Jul 20, 2023
8d6ec4d
chore: add default collection name to BaseAppConfig
jonasiwnl Jul 28, 2023
ce62702
Merge branch 'main' into feat/collection-name-everywhere
jonasiwnl Jul 28, 2023
ded8ab6
chore: remove default collection name redundancy
jonasiwnl Jul 28, 2023
f69966f
Merge branch 'feat/collection-name-everywhere' of https://github.com/…
jonasiwnl Jul 28, 2023
c8edfed
chore: change collection name to old default
cachho Aug 2, 2023
fee6f94
chore: old default name for compatibility
cachho Aug 2, 2023
cb04be0
refactor: unified argument order
cachho Aug 2, 2023
4dcb811
docs: fix param name
cachho Aug 2, 2023
7d2dfa0
chore: use old default collection name
cachho Aug 2, 2023
1b7b461
test: added unit tests
cachho Aug 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: allow switching between collections and specifying initial collection name
  • Loading branch information
jonasiwnl committed Jul 18, 2023
commit 2a40d31610645aa7333fbe6031415af70ac48a3c
4 changes: 4 additions & 0 deletions docs/advanced/configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ config = InitConfig(ef=embedding_functions.OpenAIEmbeddingFunction(
))
naval_chat_bot = App(config)

# Example: specify a custom collection name
config = InitConfig(collection_name="naval_chat_bot")
naval_chat_bot = App(config)

# Example: define your own chunker config for `youtube_video`
youtube_add_config = {
"chunker": {
Expand Down
1 change: 1 addition & 0 deletions docs/advanced/query_configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ title: '🔍 Query configurations'
| log_level | log level | string | WARNING |
| ef | embedding function | chromadb.utils.embedding_functions | \{text-embedding-ada-002\} |
| db | vector database (experimental) | BaseVectorDB | ChromaDB |
| collection_name | initial collection name for the database | string | embedchain_store |


## AddConfig
Expand Down
4 changes: 3 additions & 1 deletion embedchain/config/InitConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,21 @@ class InitConfig(BaseConfig):
Config to initialize an embedchain `App` instance.
"""

def __init__(self, log_level=None, ef=None, db=None, host=None, port=None, id=None):
def __init__(self, log_level=None, ef=None, db=None, collection_name=None, host=None, port=None, id=None):
"""
:param log_level: Optional. (String) Debug level
['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'].
:param ef: Optional. Embedding function to use.
:param db: Optional. (Vector) database to use for embeddings.
:param collection_name: Optional. Collection name for the database.
:param id: Optional. ID of the app. Document metadata will have this id.
:param host: Optional. Hostname for the database server.
:param port: Optional. Port for the database server.
"""
self._setup_logging(log_level)
self.ef = ef
self.db = db
self.collection_name = collection_name
self.host = host
self.port = port
self.id = id
Expand Down
11 changes: 10 additions & 1 deletion embedchain/embedchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, config: InitConfig):

self.config = config
self.db_client = self.config.db.client
self.collection = self.config.db.collection
self.collection = self.config.db._get_or_create_collection(self.config.collection_name)
self.user_asks = []
self.is_code_docs_instance = False
self.online = False
Expand Down Expand Up @@ -201,6 +201,7 @@ def get_answer_from_llm(self, prompt, config: ChatConfig):

def access_search_and_get_results(self, input_query):
    """
    Run a DuckDuckGo search for the given query and return its result.

    :param input_query: The query handed to the search tool.
    :return: Whatever `DuckDuckGoSearchRun.run` returns for the query.
    """
    # The import lives inside the method, so langchain's search tool is
    # only imported when this method is actually called.
    from langchain.tools import DuckDuckGoSearchRun

    search_tool = DuckDuckGoSearchRun()
    logging.info(f"Access search to get answers for {input_query}")
    search_results = search_tool.run(input_query)
    return search_results
Expand Down Expand Up @@ -314,6 +315,14 @@ def dry_run(self, input_query, config: QueryConfig = None):
logging.info(f"Prompt: {prompt}")
return prompt

def set_collection(self, collection_name):
    """
    Set the collection to use.

    Asks the configured database for the collection with the given name
    (via its `_get_or_create_collection`) and stores the result as
    `self.collection`.

    :param collection_name: The name of the collection to use.
    """
    self.collection = self.config.db._get_or_create_collection(collection_name)

def count(self):
"""
Count the number of embeddings.
Expand Down
6 changes: 4 additions & 2 deletions embedchain/vectordb/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ def _get_or_create_db(self):
"""Get or create the database."""
return chromadb.Client(self.client_settings)

def _get_or_create_collection(self):
def _get_or_create_collection(self, name=None):
cachho marked this conversation as resolved.
Show resolved Hide resolved
"""Get or create the collection."""
if name is None:
name = "embedchain_store"
return self.client.get_or_create_collection(
"embedchain_store",
name=name,
embedding_function=self.ef,
)
30 changes: 30 additions & 0 deletions tests/vectordb/test_chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,33 @@ def test_init_with_host_and_port(self, mock_client):

self.assertEqual(mock_client.call_args[0][0].chroma_server_host, None)
self.assertEqual(mock_client.call_args[0][0].chroma_server_http_port, None)


class TestChromaDbDefaultCollection(unittest.TestCase):
    """Checks which collection an `App` uses right after construction."""

    def test_init_with_default_collection(self):
        """
        An `App` built without a config should use the default collection name.
        """
        default_app = App()
        collection_name = default_app.collection.name

        self.assertEqual(collection_name, "embedchain_store")

    def test_init_with_custom_collection(self):
        """
        An `App` built with a custom `collection_name` should use that collection.
        """
        custom_app = App(InitConfig(collection_name="test_collection"))
        collection_name = custom_app.collection.name

        self.assertEqual(collection_name, "test_collection")


class TestChromaDbSetCollection(unittest.TestCase):
    """Checks switching the active collection of an existing `App`."""

    # Renamed from `test_init_with_host_and_port`, which was a copy-paste
    # leftover and did not describe what this test exercises.
    def test_set_collection(self):
        """
        Test if the `App` collection is correctly switched using the `set_collection` method.
        """
        app = App()
        app.set_collection("test_collection")

        self.assertEqual(app.collection.name, "test_collection")