from importlib import import_module


class DataFormatter:
    """
    DataFormatter is an internal utility class which abstracts the mapping for
    loaders and chunkers to the data_type entered by the user in their
    .add or .add_local method call.

    Loader and chunker classes are registered by dotted path and resolved on
    demand: building a DataFormatter imports and instantiates only the one
    loader and the one chunker registered for the requested data_type.

    Attributes set by the constructor:
        loader: instance of the loader class registered for data_type.
        chunker: instance of the chunker class registered for data_type.
    """

    # data_type -> dotted path ("package.module.ClassName") of its loader.
    _LOADERS = {
        'youtube_video': 'embedchain.loaders.youtube_video.YoutubeVideoLoader',
        'pdf_file': 'embedchain.loaders.pdf_file.PdfFileLoader',
        'web_page': 'embedchain.loaders.web_page.WebPageLoader',
        'qna_pair': 'embedchain.loaders.local_qna_pair.LocalQnaPairLoader',
        'text': 'embedchain.loaders.local_text.LocalTextLoader',
        'docx': 'embedchain.loaders.docx_file.DocxFileLoader',
    }

    # data_type -> dotted path ("package.module.ClassName") of its chunker.
    _CHUNKERS = {
        'youtube_video': 'embedchain.chunkers.youtube_video.YoutubeVideoChunker',
        'pdf_file': 'embedchain.chunkers.pdf_file.PdfFileChunker',
        'web_page': 'embedchain.chunkers.web_page.WebPageChunker',
        'qna_pair': 'embedchain.chunkers.qna_pair.QnaPairChunker',
        'text': 'embedchain.chunkers.text.TextChunker',
        'docx': 'embedchain.chunkers.docx_file.DocxFileChunker',
    }

    def __init__(self, data_type):
        """
        Resolve the loader and the chunker for the given data type.

        :param data_type: The type of the data, one of the registry keys.
        :raises ValueError: If an unsupported data type is provided.
        """
        self.loader = self._get_loader(data_type)
        self.chunker = self._get_chunker(data_type)

    @staticmethod
    def _instantiate(registry, data_type):
        """
        Import and instantiate the class registered for the given data type.

        :param registry: Mapping of data type to a dotted class path.
        :param data_type: The key to look up in the registry.
        :return: A new, argument-less instance of the registered class.
        :raises ValueError: If the data type is not in the registry.
        """
        # Validate before importing anything, so an unsupported type fails
        # without pulling in any backend module.
        try:
            dotted_path = registry[data_type]
        except KeyError:
            raise ValueError(f"Unsupported data type: {data_type}") from None
        module_path, _, class_name = dotted_path.rpartition('.')
        return getattr(import_module(module_path), class_name)()

    def _get_loader(self, data_type):
        """
        Returns the appropriate data loader for the given data type.

        :param data_type: The type of the data to load.
        :return: The loader for the given data type.
        :raises ValueError: If an unsupported data type is provided.
        """
        return self._instantiate(self._LOADERS, data_type)

    def _get_chunker(self, data_type):
        """
        Returns the appropriate chunker for the given data type.

        :param data_type: The type of the data to chunk.
        :return: The chunker for the given data type.
        :raises ValueError: If an unsupported data type is provided.
        """
        return self._instantiate(self._CHUNKERS, data_type)
embedchain.chunkers.docx_file import DocxFileChunker -from embedchain.vectordb.chroma_db import ChromaDB - +from embedchain.data_formatter import DataFormatter gpt4all_model = None @@ -49,48 +35,9 @@ def __init__(self, config: InitConfig): self.collection = self.config.db.collection self.user_asks = [] - def _get_loader(self, data_type): - """ - Returns the appropriate data loader for the given data type. - - :param data_type: The type of the data to load. - :return: The loader for the given data type. - :raises ValueError: If an unsupported data type is provided. - """ - loaders = { - 'youtube_video': YoutubeVideoLoader(), - 'pdf_file': PdfFileLoader(), - 'web_page': WebPageLoader(), - 'qna_pair': LocalQnaPairLoader(), - 'text': LocalTextLoader(), - 'docx': DocxFileLoader(), - } - if data_type in loaders: - return loaders[data_type] - else: - raise ValueError(f"Unsupported data type: {data_type}") - - def _get_chunker(self, data_type): - """ - Returns the appropriate chunker for the given data type. - - :param data_type: The type of the data to chunk. - :return: The chunker for the given data type. - :raises ValueError: If an unsupported data type is provided. - """ - chunkers = { - 'youtube_video': YoutubeVideoChunker(), - 'pdf_file': PdfFileChunker(), - 'web_page': WebPageChunker(), - 'qna_pair': QnaPairChunker(), - 'text': TextChunker(), - 'docx': DocxFileChunker(), - } - if data_type in chunkers: - return chunkers[data_type] - else: - raise ValueError(f"Unsupported data type: {data_type}") + + def add(self, data_type, url, config: AddConfig = None): """ Adds the data from the given URL to the vector db. 
@@ -103,10 +50,10 @@ def add(self, data_type, url, config: AddConfig = None): """ if config is None: config = AddConfig() - loader = self._get_loader(data_type) - chunker = self._get_chunker(data_type) + + data_formatter = DataFormatter(data_type) self.user_asks.append([data_type, url]) - self.load_and_embed(loader, chunker, url) + self.load_and_embed(data_formatter.loader, data_formatter.chunker, url) def add_local(self, data_type, content, config: AddConfig = None): """ @@ -120,10 +67,10 @@ def add_local(self, data_type, content, config: AddConfig = None): """ if config is None: config = AddConfig() - loader = self._get_loader(data_type) - chunker = self._get_chunker(data_type) + + data_formatter = DataFormatter(data_type) self.user_asks.append([data_type, content]) - self.load_and_embed(loader, chunker, content) + self.load_and_embed(data_formatter.loader, data_formatter.chunker, content) def load_and_embed(self, loader, chunker, src): """