feat: Add embedding manager (mem0ai#570)
taranjeet committed Sep 12, 2023
1 parent ba208f5 commit 2bd6881
Showing 16 changed files with 311 additions and 73 deletions.
8 changes: 6 additions & 2 deletions embedchain/chunkers/base_chunker.py
@@ -22,14 +22,17 @@ def create_chunks(self, loader, src):
documents = []
ids = []
idMap = {}
datas = loader.load_data(src)
data_result = loader.load_data(src)
data_records = data_result["data"]
doc_id = data_result["doc_id"]
metadatas = []
for data in datas:
for data in data_records:
content = data["content"]

meta_data = data["meta_data"]
# add data type to meta data to allow query using data type
meta_data["data_type"] = self.data_type.value
meta_data["doc_id"] = doc_id
url = meta_data["url"]

chunks = self.get_chunks(content)
@@ -45,6 +48,7 @@ def create_chunks(self, loader, src):
"documents": documents,
"ids": ids,
"metadatas": metadatas,
"doc_id": doc_id,
}

def get_chunks(self, content):
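The chunker now expects every loader to hand back a dict with "doc_id" and "data" keys instead of a bare list. A minimal sketch of a stand-in loader that satisfies the new contract (ToyLoader and its values are illustrative, not part of this commit):

import hashlib

class ToyLoader:
    def load_data(self, src):
        # New loader contract: records plus a document-level content hash.
        content = "hello world"
        return {
            "doc_id": hashlib.sha256((content + src).encode()).hexdigest(),
            "data": [{"content": content, "meta_data": {"url": src}}],
        }

# chunker.create_chunks(ToyLoader(), "local") now returns a dict shaped like
# {"documents": [...], "ids": [...], "metadatas": [...], "doc_id": "<sha256 hex>"},
# and every metadata entry carries both "data_type" and "doc_id".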
112 changes: 110 additions & 2 deletions embedchain/embedchain.py
@@ -10,6 +10,7 @@

import requests
from dotenv import load_dotenv
from langchain.docstore.document import Document
from tenacity import retry, stop_after_attempt, wait_fixed

from embedchain.chunkers.base_chunker import BaseChunker
@@ -179,7 +180,7 @@ def add(

data_formatter = DataFormatter(data_type, config)
self.user_asks.append([source, data_type.value, metadata])
documents, metadatas, _ids, new_chunks = self.load_and_embed(
documents, metadatas, _ids, new_chunks = self.load_and_embed_v2(
data_formatter.loader, data_formatter.chunker, source, metadata, source_id, dry_run
)
if data_type in {DataType.DOCS_SITE}:
@@ -271,10 +272,11 @@ def load_and_embed(
# get existing ids, and discard doc if any common id exist.
where = {"app_id": self.config.id} if self.config.id is not None else {}
# where={"url": src}
existing_ids = self.db.get(
db_result = self.db.get(
ids=ids,
where=where, # optional filter
)
existing_ids = set(db_result["ids"])

if len(existing_ids):
data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
@@ -317,6 +319,112 @@ def load_and_embed(
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks

def load_and_embed_v2(
self,
loader: BaseLoader,
chunker: BaseChunker,
src: Any,
metadata: Optional[Dict[str, Any]] = None,
source_id: Optional[str] = None,
dry_run = False
):
"""
Loads the data from the given URL, chunks it, and adds it to database.
:param loader: The loader to use to load the data.
:param chunker: The chunker to use to chunk the data.
:param src: The data to be handled by the loader. Can be a URL for
remote sources or local content for local loaders.
:param metadata: Optional. Metadata associated with the data source.
:param source_id: Hexadecimal hash of the source.
:return: (List) documents (embedded text), (List) metadata, (list) ids, (int) number of chunks
"""
existing_embeddings_data = self.db.get(
where={
"url": src,
},
limit=1,
)
try:
existing_doc_id = existing_embeddings_data.get("metadatas", [])[0]["doc_id"]
except Exception:
existing_doc_id = None
embeddings_data = chunker.create_chunks(loader, src)

# spread chunking results
documents = embeddings_data["documents"]
metadatas = embeddings_data["metadatas"]
ids = embeddings_data["ids"]
new_doc_id = embeddings_data["doc_id"]

if existing_doc_id and existing_doc_id == new_doc_id:
print("Doc content has not changed. Skipping creating chunks and embeddings")
return [], [], [], 0

# this means that doc content has changed.
if existing_doc_id and existing_doc_id != new_doc_id:
print("Doc content has changed. Recomputing chunks and embeddings intelligently.")
self.db.delete({
"doc_id": existing_doc_id
})

# get existing ids, and discard doc if any common id exist.
where = {"app_id": self.config.id} if self.config.id is not None else {}
# where={"url": src}
db_result = self.db.get(
ids=ids,
where=where, # optional filter
)
existing_ids = set(db_result["ids"])

if len(existing_ids):
data_dict = {id: (doc, meta) for id, doc, meta in zip(ids, documents, metadatas)}
data_dict = {id: value for id, value in data_dict.items() if id not in existing_ids}

if not data_dict:
print(f"All data from {src} already exists in the database.")
# Make sure to return a matching return type
return [], [], [], 0

ids = list(data_dict.keys())
documents, metadatas = zip(*data_dict.values())

# Loop though all metadatas and add extras.
new_metadatas = []
for m in metadatas:
# Add app id in metadatas so that they can be queried on later
if self.config.id:
m["app_id"] = self.config.id

# Add hashed source
m["hash"] = source_id

# Note: Metadata is the function argument
if metadata:
# Spread whatever is in metadata into the new object.
m.update(metadata)

new_metadatas.append(m)
metadatas = new_metadatas

# Count before, to calculate a delta in the end.
chunks_before_addition = self.count()

self.db.add(documents=documents, metadatas=metadatas, ids=ids)
count_new_chunks = self.count() - chunks_before_addition
print((f"Successfully saved {src} ({chunker.data_type}). New chunks count: {count_new_chunks}"))
return list(documents), metadatas, ids, count_new_chunks

def _format_result(self, results):
return [
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
)
]

def retrieve_from_database(self, input_query: str, config: Optional[BaseLlmConfig] = None, where=None) -> List[str]:
"""
Queries the vector database based on the given input query.
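The heart of the new embedding manager is the doc_id comparison in load_and_embed_v2: fetch the doc_id previously stored for the source, compare it with the one the chunker just computed, then either skip or purge and re-embed. A condensed sketch of that decision (the helper name decide_embedding_action is ours; db stands in for the vector database used above):

def decide_embedding_action(db, src, new_doc_id):
    """Return "skip" when stored content is unchanged, else clear stale chunks."""
    existing = db.get(where={"url": src}, limit=1)
    try:
        existing_doc_id = existing.get("metadatas", [])[0]["doc_id"]
    except (IndexError, KeyError):
        existing_doc_id = None

    if existing_doc_id is not None and existing_doc_id == new_doc_id:
        return "skip"  # same doc_id: no new chunks or embeddings are created
    if existing_doc_id is not None:
        db.delete({"doc_id": existing_doc_id})  # content changed: drop stale chunks first
    return "embed"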
10 changes: 8 additions & 2 deletions embedchain/loaders/csv.py
@@ -1,4 +1,5 @@
import csv
import hashlib
from io import StringIO
from urllib.parse import urlparse

@@ -34,13 +35,18 @@ def _get_file_content(content)
def load_data(content):
"""Load a csv file with headers. Each line is a document"""
result = []

lines = []
with CsvLoader._get_file_content(content) as file:
first_line = file.readline()
delimiter = CsvLoader._detect_delimiter(first_line)
file.seek(0) # Reset the file pointer to the start
reader = csv.DictReader(file, delimiter=delimiter)
for i, row in enumerate(reader):
line = ", ".join([f"{field}: {value}" for field, value in row.items()])
lines.append(line)
result.append({"content": line, "meta_data": {"url": content, "row": i + 1}})
return result
doc_id = hashlib.sha256((content + " ".join(lines)).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": result
}
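For reference, the CSV doc_id can be reproduced by hand: it hashes the source reference followed by the space-joined, formatted rows. A worked example with made-up values:

import hashlib

source = "people.csv"  # hypothetical path or URL handed to load_data
rows = ["name: Ada, role: engineer", "name: Grace, role: admiral"]
doc_id = hashlib.sha256((source + " ".join(rows)).encode()).hexdigest()
# Editing any row, or pointing at a different source, yields a new doc_id,
# which is the signal load_and_embed_v2 uses to re-chunk and re-embed the file.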
7 changes: 6 additions & 1 deletion embedchain/loaders/docs_site_loader.py
@@ -1,3 +1,4 @@
import hashlib
import logging
from urllib.parse import urljoin, urlparse

@@ -99,4 +100,8 @@ def load_data(self, url):
output = []
for u in all_urls:
output.extend(self._load_data_from_url(u))
return output
doc_id = hashlib.sha256((" ".join(all_urls) + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}
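Worth noting: the docs-site doc_id is derived from the crawled URL list plus the root URL, not from the page text itself. A short sketch of the same hash with assumed values:

import hashlib

def docs_site_doc_id(all_urls, root_url):
    # Mirrors the hash above: only the URL set and the root participate, so the
    # fingerprint changes when pages are added or removed under the docs root.
    return hashlib.sha256((" ".join(all_urls) + root_url).encode()).hexdigest()

doc_id = docs_site_doc_id(
    ["https://docs.example.com/intro", "https://docs.example.com/install"],
    "https://docs.example.com",
)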
8 changes: 7 additions & 1 deletion embedchain/loaders/docx_file.py
@@ -1,3 +1,5 @@
import hashlib

from langchain.document_loaders import Docx2txtLoader

from embedchain.helper.json_serializable import register_deserializable
@@ -15,4 +17,8 @@ def load_data(self, url):
meta_data = data[0].metadata
meta_data["url"] = "local"
output.append({"content": content, "meta_data": meta_data})
return output
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": output,
}
21 changes: 14 additions & 7 deletions embedchain/loaders/local_qna_pair.py
@@ -1,3 +1,5 @@
import hashlib

from embedchain.helper.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader

@@ -8,12 +10,17 @@ def load_data(self, content):
"""Load data from a local QnA pair."""
question, answer = content
content = f"Q: {question}\nA: {answer}"
url = "local"
meta_data = {
"url": "local",
"url": url,
}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": meta_data,
}
]
}
return [
{
"content": content,
"meta_data": meta_data,
}
]
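Callers of the QnA loader now receive a dict rather than a bare list. A hedged usage sketch (the class name LocalQnaPairLoader and its import path are assumed from this file, not shown in the hunk):

from embedchain.loaders.local_qna_pair import LocalQnaPairLoader  # assumed import path

result = LocalQnaPairLoader().load_data(
    ("What does this commit add?", "A doc_id-based embedding manager.")
)
result["doc_id"]                       # sha256 of the "Q: ...\nA: ..." string plus "local"
result["data"][0]["content"]           # "Q: What does this commit add?\nA: ..."
result["data"][0]["meta_data"]["url"]  # "local"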
21 changes: 14 additions & 7 deletions embedchain/loaders/local_text.py
@@ -1,3 +1,5 @@
import hashlib

from embedchain.helper.json_serializable import register_deserializable
from embedchain.loaders.base_loader import BaseLoader

@@ -6,12 +8,17 @@
class LocalTextLoader(BaseLoader):
def load_data(self, content):
"""Load data from a local text file."""
url = "local"
meta_data = {
"url": "local",
"url": url,
}
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": meta_data,
}
]
}
return [
{
"content": content,
"meta_data": meta_data,
}
]
10 changes: 7 additions & 3 deletions embedchain/loaders/notion.py
@@ -1,3 +1,4 @@
import hashlib
import logging
import os

@@ -34,10 +35,13 @@ def load_data(self, source):

# Clean text
text = clean_string(raw_text)

return [
doc_id = hashlib.sha256((text + source).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": text,
"meta_data": {"url": f"notion-{formatted_id}"},
}
]
],
}
14 changes: 11 additions & 3 deletions embedchain/loaders/pdf_file.py
@@ -1,3 +1,5 @@
import hashlib

from langchain.document_loaders import PyPDFLoader

from embedchain.helper.json_serializable import register_deserializable
@@ -10,7 +12,8 @@ class PdfFileLoader(BaseLoader):
def load_data(self, url):
"""Load data from a PDF file."""
loader = PyPDFLoader(url)
output = []
data = []
all_content = []
pages = loader.load_and_split()
if not len(pages):
raise ValueError("No data found")
@@ -19,10 +22,15 @@ def load_data(self, url):
content = clean_string(content)
meta_data = page.metadata
meta_data["url"] = url
output.append(
data.append(
{
"content": content,
"meta_data": meta_data,
}
)
return output
all_content.append(content)
doc_id = hashlib.sha256((" ".join(all_content) + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": data,
}
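The PDF doc_id covers every page: the cleaned page texts are space-joined and hashed together with the file reference. A small sketch with assumed values:

import hashlib

page_texts = ["page one text", "page two text"]  # assumed cleaned page contents
url = "report.pdf"                               # assumed path passed to load_data
doc_id = hashlib.sha256((" ".join(page_texts) + url).encode()).hexdigest()
# A change on any single page alters doc_id, so the whole PDF is re-chunked and
# re-embedded rather than only the edited page.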
8 changes: 7 additions & 1 deletion embedchain/loaders/sitemap.py
@@ -1,3 +1,4 @@
import hashlib
import logging

import requests
@@ -30,6 +31,8 @@ def load_data(self, sitemap_url):
# Get all <loc> tags as a fallback. This might include images.
links = [link.text for link in soup.find_all("loc")]

doc_id = hashlib.sha256((" ".join(links) + sitemap_url).encode()).hexdigest()

for link in links:
try:
each_load_data = web_page_loader.load_data(link)
@@ -40,4 +43,7 @@ def load_data(self, sitemap_url):
logging.warning(f"Page is not readable (too many invalid characters): {link}")
except ParserRejectedMarkup as e:
logging.error(f"Failed to parse {link}: {e}")
return [data[0] for data in output]
return {
"doc_id": doc_id,
"data": [data[0] for data in output]
}
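For sitemaps, the doc_id is computed from the <loc> links before any linked page is fetched. A sketch with assumed values:

import hashlib

links = ["https://example.com/a", "https://example.com/b"]  # <loc> entries, assumed
sitemap_url = "https://example.com/sitemap.xml"
doc_id = hashlib.sha256((" ".join(links) + sitemap_url).encode()).hexdigest()
# The fingerprint reflects the set of sitemap entries, so added or removed URLs
# are detected even when some individual page loads fail or are skipped.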
19 changes: 12 additions & 7 deletions embedchain/loaders/web_page.py
@@ -1,3 +1,4 @@
import hashlib
import logging

import requests
@@ -63,10 +64,14 @@ def load_data(self, url):
meta_data = {
"url": url,
}

return [
{
"content": content,
"meta_data": meta_data,
}
]
content = content
doc_id = hashlib.sha256((content + url).encode()).hexdigest()
return {
"doc_id": doc_id,
"data": [
{
"content": content,
"meta_data": meta_data,
}
],
}
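Taken together, the loader doc_ids and load_and_embed_v2 make repeated adds of unchanged sources cheap. An end-to-end sketch, assuming embedchain's usual App entry point and a configured LLM/embedding key (neither of which changes in this commit):

from embedchain import App  # assumed public entry point

app = App()
app.add("https://example.com/post")  # chunks, embeds, and stores doc_id in chunk metadata
app.add("https://example.com/post")  # doc_id matches the stored one: prints the
                                     # "Doc content has not changed" notice and adds nothing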