Skip to content

Commit

Permalink
Refactor imports to allow using without Annoy/SentenceTransformers when using custom embedding search.
Browse files Browse the repository at this point in the history
  • Loading branch information
drazvan committed Sep 4, 2023
1 parent 57f2b6c commit ec07145
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Moved to using `nest_asyncio` for [implementing the blocking API](./docs/user_guide/advanced/nested-async-loop.md). Fixes [#3](https://github.com/NVIDIA/NeMo-Guardrails/issues/3) and [#32](https://github.com/NVIDIA/NeMo-Guardrails/issues/32).
- Improved event property validation in `new_event_dict`.
- Refactored imports to allow installing from source without Annoy/SentenceTransformers (would need a custom embedding search provider to work).

### Fixed

Expand Down
3 changes: 2 additions & 1 deletion nemoguardrails/embeddings/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import List

from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer
from torch import cuda

from nemoguardrails.embeddings.index import EmbeddingModel, EmbeddingsIndex, IndexItem
Expand Down Expand Up @@ -115,6 +114,8 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
"""Embedding model using sentence-transformers."""

def __init__(self, embedding_model: str):
from sentence_transformers import SentenceTransformer

device = "cuda" if cuda.is_available() else "cpu"
self.model = SentenceTransformer(embedding_model, device=device)
# Get the embedding dimension of the model
Expand Down
5 changes: 4 additions & 1 deletion nemoguardrails/kb/kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

from annoy import AnnoyIndex

from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
from nemoguardrails.kb.utils import split_markdown_in_topic_chunks
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, KnowledgeBaseConfig
Expand Down Expand Up @@ -89,6 +88,8 @@ async def build(self):
and os.path.exists(cache_file)
and os.path.exists(embedding_size_file)
):
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex

log.info(cache_file)
self.index = cast(
BasicEmbeddingsIndex,
Expand Down Expand Up @@ -116,6 +117,8 @@ async def build(self):
# For the default Embedding Search provider, which uses annoy, we also
# persist the index after it's computed.
if self.config.embedding_search_provider.name == "default":
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex

# We also save the file for future use
os.makedirs(CACHE_FOLDER, exist_ok=True)
basic_index = cast(BasicEmbeddingsIndex, self.index)
Expand Down
3 changes: 2 additions & 1 deletion nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from nemoguardrails.actions.math import wolfram_alpha_request
from nemoguardrails.actions.output_moderation import output_moderation
from nemoguardrails.actions.retrieve_relevant_chunks import retrieve_relevant_chunks
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex
from nemoguardrails.embeddings.index import EmbeddingsIndex
from nemoguardrails.flows.runtime import Runtime
from nemoguardrails.kb.kb import KnowledgeBase
Expand Down Expand Up @@ -227,6 +226,8 @@ def _get_embeddings_search_provider_instance(
esp_config = EmbeddingSearchProvider()

if esp_config.name == "default":
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex

return BasicEmbeddingsIndex(
embedding_model=esp_config.parameters.get(
"embedding_model", self.default_embedding_model
Expand Down

0 comments on commit ec07145

Please sign in to comment.