refactor SQL index to not require certain arguments #642

Merged: 3 commits, Mar 7, 2023
Changes from 1 commit
cr
Jerry Liu authored and Jerry Liu committed Mar 7, 2023
commit a277f8f28a3a88b0112179dda7aabe8fa4fb47db
21 changes: 13 additions & 8 deletions gpt_index/indices/common/struct_store/base.py
@@ -1,26 +1,26 @@
"""Common classes for structured operations."""

from typing import Dict, List, Optional, Sequence, cast, Any

from abc import abstractmethod
import logging
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Optional, Sequence, cast

from gpt_index.data_structs.table import StructDatapoint
from gpt_index.prompts.default_prompts import DEFAULT_SCHEMA_EXTRACT_PROMPT
from gpt_index.prompts.prompts import SchemaExtractPrompt
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.response.builder import ResponseBuilder, TextChunk
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.langchain_helpers.text_splitter import TextSplitter
from gpt_index.prompts.default_prompts import (
DEFAULT_REFINE_TABLE_CONTEXT_PROMPT,
DEFAULT_SCHEMA_EXTRACT_PROMPT,
DEFAULT_TABLE_CONTEXT_PROMPT,
DEFAULT_TABLE_CONTEXT_QUERY,
)
from gpt_index.prompts.prompts import (
QuestionAnswerPrompt,
RefinePrompt,
RefineTableContextPrompt,
SchemaExtractPrompt,
TableContextPrompt,
)
from gpt_index.schema import BaseDocument
@@ -118,6 +118,9 @@ def build_table_context_from_documents(
return cast(str, table_context)


OUTPUT_PARSER_TYPE = Callable[[str], Optional[Dict[str, Any]]]


class BaseStructDatapointExtractor:
"""Extracts datapoints from a structured document."""

@@ -126,11 +129,13 @@ def __init__(
llm_predictor: LLMPredictor,
text_splitter: TextSplitter,
schema_extract_prompt: SchemaExtractPrompt,
output_parser: OUTPUT_PARSER_TYPE,
) -> None:
"""Initialize params."""
self._llm_predictor = llm_predictor
self._text_splitter = text_splitter
self._schema_extract_prompt = schema_extract_prompt
self._output_parser = output_parser

def _clean_and_validate_fields(self, fields: Dict[str, Any]) -> Dict[str, Any]:
"""Validate fields with col_types_map."""
@@ -172,8 +177,8 @@ def _get_col_types_map(self) -> Dict[str, type]:
def _get_schema_text(self) -> str:
"""Get schema text for extracting relevant info from unstructured text."""

def extract_datapoint_from_document(self, document: BaseDocument) -> Dict[str, str]:
"""Extract datapoint from a document."""
def insert_datapoint_from_document(self, document: BaseDocument) -> None:
"""Extract datapoint from a document and insert it."""
text_chunks = self._text_splitter.split_text(document.get_text())
fields = {}
for i, text_chunk in enumerate(text_chunks):
@@ -186,7 +191,7 @@ def extract_datapoint_from_document(self, document: BaseDocument) -> Dict[str, str]:
text=text_chunk,
schema=schema_text,
)
cur_fields = self.output_parser(response_str)
cur_fields = self._output_parser(response_str)
if cur_fields is None:
continue
# validate fields with col_types_map
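The new output_parser argument must satisfy OUTPUT_PARSER_TYPE, i.e. Callable[[str], Optional[Dict[str, Any]]]: it receives the raw LLM response for one text chunk and returns the extracted fields, or None to skip that chunk. A minimal sketch of a conforming parser follows; it is hypothetical (JSON-based), not the default the library ships:

import json
from typing import Any, Dict, Optional


def json_output_parser(response_str: str) -> Optional[Dict[str, Any]]:
    """Hypothetical parser satisfying OUTPUT_PARSER_TYPE.

    Parses the raw LLM response as JSON and returns the field dict, or
    None so the caller skips the chunk (the `if cur_fields is None:
    continue` branch above).
    """
    try:
        parsed = json.loads(response_str)
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) else None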
42 changes: 33 additions & 9 deletions gpt_index/indices/common/struct_store/sql.py
@@ -1,12 +1,18 @@
"""SQL StructDatapointExtractor."""

from gpt_index.indices.common.struct_store.base import BaseStructDatapointExtractor
from typing import Any, Dict, Optional, cast

from sqlalchemy import Table

from gpt_index.data_structs.table import BaseStructTable, StructDatapoint
from gpt_index.indices.common.struct_store.base import (
OUTPUT_PARSER_TYPE,
BaseStructDatapointExtractor,
)
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.langchain_helpers.text_splitter import TextSplitter
from typing import Dict, Any, cast, Optional
from gpt_index.prompts.prompts import SchemaExtractPrompt
from gpt_index.langchain_helpers.sql_wrapper import SQLDatabase
from gpt_index.schema import BaseDocument
from gpt_index.utils import truncate_text

@@ -19,25 +25,43 @@ def __init__(
llm_predictor: LLMPredictor,
text_splitter: TextSplitter,
schema_extract_prompt: SchemaExtractPrompt,
output_parser: OUTPUT_PARSER_TYPE,
sql_database: SQLDatabase,
table_name: Optional[str] = None,
table: Optional[Table] = None,
ref_doc_id_column: Optional[str] = None,
) -> None:
"""Initialize params."""
super().__init__(llm_predictor, text_splitter, schema_extract_prompt)
self._llm_predictor = llm_predictor
self._text_splitter = text_splitter
self._schema_extract_prompt = schema_extract_prompt
super().__init__(
llm_predictor, text_splitter, schema_extract_prompt, output_parser
)
self._sql_database = sql_database
self._table_name = table_name
# currently the user must specify either the table name or the table itself
if table_name is None and table is None:
raise ValueError("table_name or table must be specified")
self._table_name = table_name or cast(Table, table).name
if table is None:
table = self._sql_database.metadata_obj.tables[table_name]
# if ref_doc_id_column is specified, then we need to check that
# it is a valid column in the table
col_names = [c.name for c in table.c]
if ref_doc_id_column is not None and ref_doc_id_column not in col_names:
raise ValueError(
f"ref_doc_id_column {ref_doc_id_column} not in table {table_name}"
)
self.ref_doc_id_column = ref_doc_id_column
# then store python types of each column
self._col_types_map: Dict[str, type] = {
c.name: table.c[c.name].type.python_type for c in table.c
}

def _get_col_types_map(self) -> Dict[str, type]:
"""Get col types map for schema."""
return self._col_types_map

def _get_schema_text(self) -> str:
"""Insert datapoint into index."""
return self._sql_database.get_single_table_info(self.table_name)
return self._sql_database.get_single_table_info(self._table_name)

def _insert_datapoint(self, datapoint: StructDatapoint) -> None:
"""Insert datapoint into index."""
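This file is the heart of the refactor named in the PR title: the constructor now accepts either table_name or a SQLAlchemy Table, derives one from the other, validates ref_doc_id_column against the table's columns, and precomputes a column-to-Python-type map. A self-contained sketch of that resolution logic in plain SQLAlchemy (the table and column names here are illustrative, not from the PR):

from typing import Optional, cast

from sqlalchemy import Column, Integer, MetaData, String, Table

metadata_obj = MetaData()
city_stats = Table(
    "city_stats",
    metadata_obj,
    Column("city_name", String(16)),
    Column("population", Integer),
)


def resolve_table(
    metadata_obj: MetaData,
    table_name: Optional[str] = None,
    table: Optional[Table] = None,
) -> Table:
    """Accept either a table name or a Table object, as the new __init__ does."""
    if table_name is None and table is None:
        raise ValueError("table_name or table must be specified")
    if table is None:
        # look the Table up by name on the shared MetaData, as in the diff
        table = metadata_obj.tables[cast(str, table_name)]
    return table


table = resolve_table(metadata_obj, table=city_stats)
# mirrors the _col_types_map construction in the diff
col_types_map = {c.name: table.c[c.name].type.python_type for c in table.c}
print(col_types_map)  # {'city_name': <class 'str'>, 'population': <class 'int'>}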
81 changes: 0 additions & 81 deletions gpt_index/indices/struct_store/base.py
@@ -65,93 +65,12 @@ def __init__(
**kwargs,
)

@abstractmethod
def _insert_datapoint(self, datapoint: StructDatapoint) -> None:
"""Insert datapoint into index."""

@abstractmethod
def _get_col_types_map(self) -> Dict[str, type]:
"""Get col types map for schema."""

def _clean_and_validate_fields(self, fields: Dict[str, Any]) -> Dict[str, Any]:
"""Validate fields with col_types_map."""
new_fields = {}
col_types_map = self._get_col_types_map()
for field, value in fields.items():
clean_value = value
if field not in col_types_map:
continue
# if expected type is int or float, try to convert value to int or float
expected_type = col_types_map[field]
if expected_type == int:
try:
clean_value = int(value)
except ValueError:
continue
elif expected_type == float:
try:
clean_value = float(value)
except ValueError:
continue
else:
if len(value) == 0:
continue
if not isinstance(value, col_types_map[field]):
continue
new_fields[field] = clean_value
return new_fields

@abstractmethod
def _get_schema_text(self) -> str:
"""Get schema text for extracting relevant info from unstructured text."""

def _build_fallback_text_splitter(self) -> TextSplitter:
# if not specified, use "smart" text splitter to ensure chunks fit in prompt
return self._prompt_helper.get_text_splitter_given_prompt(
self.schema_extract_prompt, 1
)

# @abstractmethod
# def _add_document_to_index(
# self,
# document: BaseDocument,
# ) -> None:
# """Add document to index."""
# text_chunks = self._text_splitter.split_text(document.get_text())
# fields = {}
# for i, text_chunk in enumerate(text_chunks):
# fmt_text_chunk = truncate_text(text_chunk, 50)
# logging.info(f"> Adding chunk {i}: {fmt_text_chunk}")
# # if embedding specified in document, pass it to the Node
# schema_text = self._get_schema_text()
# response_str, _ = self._llm_predictor.predict(
# self.schema_extract_prompt,
# text=text_chunk,
# schema=schema_text,
# )
# cur_fields = self.output_parser(response_str)
# if cur_fields is None:
# continue
# # validate fields with col_types_map
# new_cur_fields = self._clean_and_validate_fields(cur_fields)
# fields.update(new_cur_fields)

# struct_datapoint = StructDatapoint(fields)
# if struct_datapoint is not None:
# self._insert_datapoint(struct_datapoint)
# logging.debug(f"> Added datapoint: {fields}")

# def _build_index_from_documents(self, documents: Sequence[BaseDocument]) -> BST:
# """Build index from documents."""
# index_struct = self.index_struct_cls()
# for d in documents:
# self._add_document_to_index(d)
# return index_struct

def _insert(self, document: BaseDocument, **insert_kwargs: Any) -> None:
"""Insert a document."""
self._add_document_to_index(document)

def _delete(self, doc_id: str, **delete_kwargs: Any) -> None:
"""Delete a document."""
raise NotImplementedError("Delete not implemented for Struct Store Index.")
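The deleted _clean_and_validate_fields above now lives on BaseStructDatapointExtractor in gpt_index/indices/common/struct_store/base.py rather than being removed outright. A standalone sketch of its coercion behavior, with the free-function form and sample data being illustrative only:

from typing import Any, Dict


def clean_and_validate_fields(
    fields: Dict[str, Any], col_types_map: Dict[str, type]
) -> Dict[str, Any]:
    """Standalone version of the moved logic: coerce to int/float, drop bad fields."""
    new_fields: Dict[str, Any] = {}
    for field, value in fields.items():
        # silently drop fields that are not columns of the target table
        if field not in col_types_map:
            continue
        expected_type = col_types_map[field]
        if expected_type is int:
            try:
                new_fields[field] = int(value)
            except ValueError:
                continue
        elif expected_type is float:
            try:
                new_fields[field] = float(value)
            except ValueError:
                continue
        else:
            # non-numeric columns: keep only non-empty values of the right type
            if len(value) == 0 or not isinstance(value, expected_type):
                continue
            new_fields[field] = value
    return new_fields


col_types_map = {"city_name": str, "population": int}
fields = {"city_name": "Toronto", "population": "2930000", "extra": "ignored"}
print(clean_and_validate_fields(fields, col_types_map))
# {'city_name': 'Toronto', 'population': 2930000}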