Skip to content

Commit

Permalink
refactor: get existing doc id method (mem0ai#616)
Browse files Browse the repository at this point in the history
  • Loading branch information
cachho committed Sep 17, 2023
1 parent 01fb216 commit 3d0e414
Showing 1 changed file with 34 additions and 28 deletions.
62 changes: 34 additions & 28 deletions embedchain/embedchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,26 +322,10 @@ def load_and_embed(
count_new_chunks = self.db.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks

def load_and_embed_v2(
self,
loader: BaseLoader,
chunker: BaseChunker,
src: Any,
metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None,
dry_run=False,
):

def _get_existing_doc_id(self, chunker: BaseChunker, src: Any):
"""
Loads the data from the given URL, chunks it, and adds it to database.
:param loader: The loader to use to load the data.
:param chunker: The chunker to use to chunk the data.
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param metadata: Optional. Metadata associated with the data source.
:param source_id: Hexadecimal hash of the source.
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
Get id of existing document for a given source, based on the data type
"""
# Find existing embeddings for the source
# Depending on the data type, existing embeddings are checked for.
Expand All @@ -350,7 +334,7 @@ def load_and_embed_v2(
# Think of a text:
# Either it's the same, then it won't change, so it's not an update.
# Or it's different, then it will be added as a new text.
existing_doc_id = None
return None
elif chunker.data_type.value in [item.value for item in IndirectDataType]:
# These types have a indirect source reference
# As long as the reference is the same, they can be updated.
Expand All @@ -360,10 +344,10 @@ def load_and_embed_v2(
},
limit=1,
)
try:
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
except Exception:
existing_doc_id = None
if len(existing_embeddings_data.get("metadatas", [])) > 0:
return existing_embeddings_data["metadatas"][0]["doc_id"]
else:
return None
elif chunker.data_type.value in [item.value for item in SpecialDataType]:
# These types don't contain indirect references.
# Through custom logic, they can be attributed to a source and be updated.
Expand All @@ -375,10 +359,10 @@ def load_and_embed_v2(
},
limit=1,
)
try:
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
except Exception:
existing_doc_id = None
if len(existing_embeddings_data.get("metadatas", [])) > 0:
return existing_embeddings_data["metadatas"][0]["doc_id"]
else:
return None
else:
raise NotImplementedError(
f"SpecialDataType {chunker.data_type} must have a custom logic to check for existing data"
Expand All @@ -389,6 +373,28 @@ def load_and_embed_v2(
"When it should be DirectDataType, IndirectDataType or SpecialDataType."
)

def load_and_embed_v2(
self,
loader: BaseLoader,
chunker: BaseChunker,
src: Any,
metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None,
dry_run=False,
):
"""
Loads the data from the given URL, chunks it, and adds it to database.
:param loader: The loader to use to load the data.
:param chunker: The chunker to use to chunk the data.
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param metadata: Optional. Metadata associated with the data source.
:param source_id: Hexadecimal hash of the source.
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
"""
existing_doc_id = self._get_existing_doc_id(chunker=chunker, src=src)

# Create chunks
embeddings_data = chunker.create_chunks(loader, src)

Expand Down

0 comments on commit 3d0e414

Please sign in to comment.