Skip to content

Commit

Permalink
Refactor to support customization of embedding string(s) (run-llama#436)
Browse files Browse the repository at this point in the history
* save

* refactor to generalize query_str to query_bundle

* remove extraneous file

* remove extraneous file

* fix circular dependency

* fix tests

* notebook change

* wip

* wip

* fix types

* wip

* wip

* update rst

* add test

* f

* wip

* comments

* fix knowledge graph

* make interface simpler
  • Loading branch information
Disiok authored Feb 16, 2023
1 parent 157d59c commit 0204c12
Show file tree
Hide file tree
Showing 26 changed files with 333 additions and 74 deletions.
2 changes: 1 addition & 1 deletion docs/reference/indices/composability_query.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Composable Queries
.. automodule:: gpt_index.indices.query.schema
:members:
:inherited-members:
:exclude-members:
:exclude-members: QueryBundle


.. automodule:: gpt_index.data_structs.struct_type
Expand Down
12 changes: 10 additions & 2 deletions docs/reference/query.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ We then show how to define a query config in order to recursively query
multiple indices that are `composed </how_to/composability.html>`_ together.
We then show the base query class, which contains parameters that are shared
among all queries.

Lastly, we show how to customize the string(s) used for an embedding-based query.

.. toctree::
:maxdepth: 1
:caption: Index-specific Query Subclasses
Expand All @@ -32,11 +32,19 @@ multiple indices that are `composed </how_to/composability.html>`_ together.

indices/composability_query.rst


Base Query Class
^^^^^^^^^^^^^^^^

.. automodule:: gpt_index.indices.query.base
:members:
:inherited-members:
:exclude-members: BaseQueryRunner

Query Bundle
^^^^^^^^^^^^^^^^
A query bundle lets the user customize the string(s) used for an embedding-based query.

.. automodule:: gpt_index.indices.query.schema
:members: QueryBundle
:inherited-members:
:exclude-members:
48 changes: 46 additions & 2 deletions examples/vector_indices/SimpleIndexDemo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@
"cell_type": "code",
"execution_count": null,
"id": "ad144ee7-96da-4dd6-be00-fd6cf0c78e58",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"index = GPTSimpleVectorIndex(documents)"
Expand Down Expand Up @@ -107,6 +109,48 @@
"display(Markdown(f\"<b>{response}</b>\"))"
]
},
{
"cell_type": "markdown",
"id": "0da9092e",
"metadata": {},
"source": [
"**Query Index with custom embedding string**"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d57f2c87",
"metadata": {},
"outputs": [],
"source": [
"from gpt_index.indices.query.schema import QueryBundle"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bbecbdb5",
"metadata": {},
"outputs": [],
"source": [
"query_bundle = QueryBundle(\n",
" query_str=\"What did the author do growing up?\", \n",
" custom_embedding_strs=['The author grew up painting.']\n",
")\n",
"response = index.query(query_bundle)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4d1e028",
"metadata": {},
"outputs": [],
"source": [
"display(Markdown(f\"<b>{response}</b>\"))"
]
},
{
"cell_type": "markdown",
"id": "5636a15c-8938-4809-958b-03b8c445ecbd",
Expand Down Expand Up @@ -150,7 +194,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.1"
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions gpt_index/composability/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from gpt_index.indices.list.base import GPTListIndex
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.query.query_runner import QueryRunner
from gpt_index.indices.query.schema import QueryConfig
from gpt_index.indices.query.schema import QueryBundle, QueryConfig
from gpt_index.indices.registry import IndexRegistry
from gpt_index.indices.struct_store.sql import GPTSQLStructStoreIndex
from gpt_index.indices.tree.base import GPTTreeIndex
Expand Down Expand Up @@ -106,7 +106,7 @@ def build_from_index(self, index: BaseGPTIndex) -> "ComposableGraph":

def query(
self,
query_str: str,
query_str: Union[str, QueryBundle],
query_configs: Optional[List[QUERY_CONFIG_TYPE]] = None,
) -> Response:
"""Query the index."""
Expand Down
15 changes: 15 additions & 0 deletions gpt_index/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
EMB_TYPE = List


def mean_agg(embeddings: List[List[float]]) -> List[float]:
    """Mean aggregation for embeddings.

    Args:
        embeddings: Non-empty list of equal-length embedding vectors.

    Returns:
        The element-wise mean of the vectors, as a list of built-in floats.

    Raises:
        ValueError: If ``embeddings`` is empty.
    """
    # An empty input has no meaningful mean; fail with a clear message rather
    # than letting numpy produce a NaN scalar.
    if len(embeddings) == 0:
        raise ValueError("Cannot aggregate an empty list of embeddings.")
    # ``tolist`` converts numpy scalars into built-in floats so the result
    # matches the annotated return type and stays JSON-serializable.
    return np.array(embeddings).mean(axis=0).tolist()


class SimilarityMode(str, Enum):
"""Modes for similarity/distance."""

Expand Down Expand Up @@ -40,6 +45,16 @@ def get_query_embedding(self, query: str) -> List[float]:
self._total_tokens_used += query_tokens_count
return query_embedding

def get_agg_embedding_from_queries(
    self,
    queries: List[str],
    agg_fn: Optional[Callable[..., List[float]]] = None,
) -> List[float]:
    """Embed every query string and combine the results into one embedding.

    Args:
        queries: Query strings; each one is embedded with
            ``get_query_embedding``, in order.
        agg_fn: Callable applied to the list of per-query embeddings.
            When omitted (or falsy), ``mean_agg`` is used.

    Returns:
        The value produced by the aggregation callable.
    """
    embeddings_per_query = []
    for query_text in queries:
        embeddings_per_query.append(self.get_query_embedding(query_text))
    # Choose the aggregator only after embedding, mirroring the original order.
    aggregate = agg_fn if agg_fn else mean_agg
    return aggregate(embeddings_per_query)

@abstractmethod
def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
Expand Down
4 changes: 2 additions & 2 deletions gpt_index/indices/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.query.base import BaseGPTIndexQuery
from gpt_index.indices.query.query_runner import QueryRunner
from gpt_index.indices.query.schema import QueryConfig, QueryMode
from gpt_index.indices.query.schema import QueryBundle, QueryConfig, QueryMode
from gpt_index.indices.registry import IndexRegistry
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
from gpt_index.langchain_helpers.text_splitter import TokenTextSplitter
Expand Down Expand Up @@ -334,7 +334,7 @@ def _preprocess_query(self, mode: QueryMode, query_kwargs: Dict) -> None:

def query(
self,
query_str: str,
query_str: Union[str, QueryBundle],
mode: str = QueryMode.DEFAULT,
**query_kwargs: Any,
) -> Response:
Expand Down
29 changes: 16 additions & 13 deletions gpt_index/indices/query/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gpt_index.embeddings.openai import OpenAIEmbedding
from gpt_index.indices.prompt_helper import PromptHelper
from gpt_index.indices.query.embedding_utils import SimilarityTracker
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.response.builder import ResponseBuilder, ResponseMode, TextChunk
from gpt_index.indices.utils import truncate_text
from gpt_index.langchain_helpers.chain_wrapper import LLMPredictor
Expand All @@ -31,7 +32,7 @@ class BaseQueryRunner:
"""Base query runner."""

@abstractmethod
def query(self, query: str, index_struct: IndexStruct) -> Response:
def query(self, query_bundle: QueryBundle, index_struct: IndexStruct) -> Response:
"""Schedule a query."""
raise NotImplementedError("Not implemented yet.")

Expand Down Expand Up @@ -147,7 +148,7 @@ def _should_use_node(

def _get_text_from_node(
self,
query_str: str,
query_bundle: QueryBundle,
node: Node,
level: Optional[int] = None,
) -> Tuple[TextChunk, Optional[Response]]:
Expand Down Expand Up @@ -176,7 +177,7 @@ def _get_text_from_node(

if is_index_struct:
query_runner = cast(BaseQueryRunner, self._query_runner)
response = query_runner.query(query_str, cast(IndexStruct, doc))
response = query_runner.query(query_bundle, cast(IndexStruct, doc))
return TextChunk(str(response), is_answer=True), response
else:
text = node.get_text()
Expand Down Expand Up @@ -207,7 +208,7 @@ def _give_response_for_nodes(
return response or ""

def get_nodes_and_similarities_for_response(
self, query_str: str
self, query_bundle: QueryBundle
) -> List[Tuple[Node, Optional[float]]]:
"""Get list of tuples of node and similarity for response.
Expand All @@ -217,7 +218,7 @@ def get_nodes_and_similarities_for_response(
"""
similarity_tracker = SimilarityTracker()
nodes = self._get_nodes_for_response(
query_str, similarity_tracker=similarity_tracker
query_bundle, similarity_tracker=similarity_tracker
)
nodes = [
node for node in nodes if self._should_use_node(node, similarity_tracker)
Expand All @@ -229,18 +230,18 @@ def get_nodes_and_similarities_for_response(
@abstractmethod
def _get_nodes_for_response(
self,
query_str: str,
query_bundle: QueryBundle,
similarity_tracker: Optional[SimilarityTracker] = None,
) -> List[Node]:
"""Get nodes for response."""

def _query(self, query_str: str) -> Response:
def _query(self, query_bundle: QueryBundle) -> Response:
"""Answer a query."""
# TODO: remove _query and just use query
tuples = self.get_nodes_and_similarities_for_response(query_str)
tuples = self.get_nodes_and_similarities_for_response(query_bundle)
node_texts = []
for node, similarity in tuples:
text, response = self._get_text_from_node(query_str, node)
text, response = self._get_text_from_node(query_bundle, node)
self.response_builder.add_node(node, similarity=similarity)
if response is not None:
# these are source nodes from within this node (when it's an index)
Expand All @@ -249,16 +250,18 @@ def _query(self, query_str: str) -> Response:
node_texts.append(text)

if self._response_mode != ResponseMode.NO_TEXT:
response_str = self._give_response_for_nodes(query_str, node_texts)
response_str = self._give_response_for_nodes(
query_bundle.query_str, node_texts
)
else:
response_str = None

return Response(response_str, source_nodes=self.response_builder.get_sources())

@llm_token_counter("query")
def query(self, query_str: str) -> Response:
def query(self, query_bundle: QueryBundle) -> Response:
"""Answer a query."""
response = self._query(query_str)
response = self._query(query_bundle)
# if include_summary is True, then include summary text in answer
# summary text is set through `set_text` on the underlying index.
# TODO: refactor response builder to be in the __init__
Expand All @@ -272,7 +275,7 @@ def query(self, query_str: str) -> Response:
)
# NOTE: use create and refine for now (default response mode)
response.response = response_builder.get_response(
query_str,
query_bundle.query_str,
mode=self._response_mode,
prev_response=response.response,
)
Expand Down
9 changes: 5 additions & 4 deletions gpt_index/indices/query/keyword_table/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
)
from gpt_index.indices.query.base import BaseGPTIndexQuery
from gpt_index.indices.query.embedding_utils import SimilarityTracker
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.utils import truncate_text
from gpt_index.prompts.default_prompts import (
DEFAULT_KEYWORD_EXTRACT_TEMPLATE,
Expand Down Expand Up @@ -67,13 +68,13 @@ def _get_keywords(self, query_str: str) -> List[str]:

def _get_nodes_for_response(
self,
query_str: str,
query_bundle: QueryBundle,
similarity_tracker: Optional[SimilarityTracker] = None,
) -> List[Node]:
"""Get nodes for response."""
logging.info(f"> Starting query: {query_str}")
keywords = self._get_keywords(query_str)
logging.info(f"> query keywords: {keywords}")
logging.info(f"> Starting query: {query_bundle.query_str}")
keywords = self._get_keywords(query_bundle.query_str)
logging.info(f"query keywords: {keywords}")

# go through text chunks in order of most matching keywords
chunk_indices_count: Dict[int, int] = defaultdict(int)
Expand Down
7 changes: 4 additions & 3 deletions gpt_index/indices/query/knowledge_graph/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from gpt_index.indices.keyword_table.utils import extract_keywords_given_response
from gpt_index.indices.query.base import BaseGPTIndexQuery
from gpt_index.indices.query.embedding_utils import SimilarityTracker
from gpt_index.indices.query.schema import QueryBundle
from gpt_index.indices.utils import truncate_text
from gpt_index.prompts.default_prompts import DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE
from gpt_index.prompts.prompts import QueryKeywordExtractPrompt
Expand Down Expand Up @@ -62,12 +63,12 @@ def _get_keywords(self, query_str: str) -> List[str]:

def _get_nodes_for_response(
self,
query_str: str,
query_bundle: QueryBundle,
similarity_tracker: Optional[SimilarityTracker] = None,
) -> List[Node]:
"""Get nodes for response."""
logging.info(f"> Starting query: {query_str}")
keywords = self._get_keywords(query_str)
logging.info(f"> Starting query: {query_bundle.query_str}")
keywords = self._get_keywords(query_bundle.query_str)
logging.info(f"> Query keywords: {keywords}")
rel_texts = []
chunk_indices_count: Dict[str, int] = defaultdict(int)
Expand Down
11 changes: 7 additions & 4 deletions gpt_index/indices/query/list/embedding_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_top_k_embeddings,
)
from gpt_index.indices.query.list.query import BaseGPTListIndexQuery
from gpt_index.indices.query.schema import QueryBundle


class GPTListIndexEmbeddingQuery(BaseGPTListIndexQuery):
Expand Down Expand Up @@ -44,13 +45,13 @@ def __init__(

def _get_nodes_for_response(
self,
query_str: str,
query_bundle: QueryBundle,
similarity_tracker: Optional[SimilarityTracker] = None,
) -> List[Node]:
"""Get nodes for response."""
nodes = self.index_struct.nodes
# top k nodes
query_embedding, node_embeddings = self._get_embeddings(query_str, nodes)
query_embedding, node_embeddings = self._get_embeddings(query_bundle, nodes)

top_similarities, top_idxs = get_top_k_embeddings(
self._embed_model,
Expand All @@ -72,10 +73,12 @@ def _get_nodes_for_response(
return top_k_nodes

def _get_embeddings(
self, query_str: str, nodes: List[Node]
self, query_bundle: QueryBundle, nodes: List[Node]
) -> Tuple[List[float], List[List[float]]]:
"""Get top nodes by similarity to the query."""
query_embedding = self._embed_model.get_query_embedding(query_str)
query_embedding = self._embed_model.get_agg_embedding_from_queries(
query_bundle.embedding_strs
)
node_embeddings: List[List[float]] = []
for node in self.index_struct.nodes:
if node.embedding is not None:
Expand Down
3 changes: 2 additions & 1 deletion gpt_index/indices/query/list/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from gpt_index.data_structs.data_structs import IndexList, Node
from gpt_index.indices.query.base import BaseGPTIndexQuery
from gpt_index.indices.query.embedding_utils import SimilarityTracker
from gpt_index.indices.query.schema import QueryBundle


class BaseGPTListIndexQuery(BaseGPTIndexQuery[IndexList]):
Expand Down Expand Up @@ -38,7 +39,7 @@ class GPTListIndexQuery(BaseGPTListIndexQuery):

def _get_nodes_for_response(
self,
query_str: str,
query_bundle: QueryBundle,
similarity_tracker: Optional[SimilarityTracker] = None,
) -> List[Node]:
"""Get nodes for response."""
Expand Down
Loading

0 comments on commit 0204c12

Please sign in to comment.