add node postprocessor abstraction #684

Merged: 7 commits, Mar 11, 2023
Changes from 1 commit
cr
jerryjliu committed Mar 11, 2023
commit 1a5015cb98cbd91ec3d9c2718833e4408cb4c0cc
16 changes: 7 additions & 9 deletions gpt_index/indices/postprocessor/node.py
@@ -61,18 +61,16 @@ def postprocess_nodes(
     ) -> List[Node]:
         """Postprocess nodes."""
         extra_info = extra_info or {}
+        similarity_tracker = extra_info.get("similarity_tracker", None)
+        if similarity_tracker is None:
+            return nodes
+        sim_cutoff_exists = (
+            similarity_tracker is not None and self.similarity_cutoff is not None
+        )
+
         new_nodes = []
         for node in nodes:
             should_use_node = True
-            similarity_tracker = extra_info.get("similarity_tracker")
-            if similarity_tracker is None:
-                raise ValueError(
-                    "Similarity tracker is required for similarity postprocessor."
-                )
-            sim_cutoff_exists = (
-                similarity_tracker is not None and self.similarity_cutoff is not None
-            )
-
             if sim_cutoff_exists:
                 similarity = cast(SimilarityTracker, similarity_tracker).find(node)
                 if similarity is None:
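The hunk above changes the similarity postprocessor so that a missing "similarity_tracker" in extra_info makes it a pass-through instead of an error, and the lookup now happens once before the loop rather than per node. A minimal sketch of that behavior, using a plain dict as a stand-in for SimilarityTracker and plain strings for nodes (toy names, not the gpt_index API):

from typing import Dict, List, Optional


def postprocess_nodes(
    nodes: List[str],
    similarity_cutoff: Optional[float],
    extra_info: Optional[Dict] = None,
) -> List[str]:
    """Drop nodes whose tracked similarity falls below the cutoff."""
    extra_info = extra_info or {}
    # New behavior from this commit: a missing tracker means there is nothing
    # to filter on, so the nodes pass through unchanged (the earlier revision
    # raised a ValueError here instead).
    similarity_tracker = extra_info.get("similarity_tracker", None)
    if similarity_tracker is None or similarity_cutoff is None:
        return nodes
    kept = []
    for node in nodes:
        # Stand-in for SimilarityTracker.find(node) in the real code.
        similarity = similarity_tracker.get(node)
        if similarity is not None and similarity >= similarity_cutoff:
            kept.append(node)
    return kept


if __name__ == "__main__":
    tracker = {"node_a": 0.9, "node_b": 0.2}
    # No tracker supplied: nodes are returned unchanged.
    print(postprocess_nodes(["node_a", "node_b"], 0.5))
    # Tracker supplied: only nodes at or above the cutoff survive.
    print(postprocess_nodes(["node_a", "node_b"], 0.5, {"similarity_tracker": tracker}))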
5 changes: 4 additions & 1 deletion gpt_index/indices/query/knowledge_graph/query.py
@@ -176,8 +176,11 @@ def _get_nodes_for_response(
             self.index_struct.text_chunks[idx] for idx in sorted_chunk_indices
         ]
         # filter sorted nodes
+        postprocess_info = {"similarity_tracker": similarity_tracker}
         for node_processor in self.node_preprocessors:
-            sorted_nodes = node_processor.postprocess_nodes(sorted_nodes)
+            sorted_nodes = node_processor.postprocess_nodes(
+                sorted_nodes, postprocess_info
+            )
 
         # TMP/TODO: also filter rel_texts as nodes until we figure out better
         # abstraction
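The query-side change threads a shared postprocess_info dict through every registered postprocessor instead of calling them with the node list alone. A small self-contained sketch of that chaining pattern, with hypothetical processor names standing in for the gpt_index classes:

from typing import Dict, List, Protocol


class NodePostprocessor(Protocol):
    def postprocess_nodes(self, nodes: List[str], extra_info: Dict) -> List[str]: ...


class DropShortNodes:
    """Example postprocessor: drops nodes shorter than a minimum length."""

    def __init__(self, min_len: int) -> None:
        self.min_len = min_len

    def postprocess_nodes(self, nodes: List[str], extra_info: Dict) -> List[str]:
        return [n for n in nodes if len(n) >= self.min_len]


def apply_postprocessors(
    nodes: List[str], processors: List[NodePostprocessor], extra_info: Dict
) -> List[str]:
    # Mirrors the loop in _get_nodes_for_response: each processor sees the
    # output of the previous one plus the same shared extra_info dict.
    for processor in processors:
        nodes = processor.postprocess_nodes(nodes, extra_info)
    return nodes


if __name__ == "__main__":
    info = {"similarity_tracker": None}  # shared side-channel, as in the diff
    print(apply_postprocessors(["aa", "a", "aaa"], [DropShortNodes(2)], info))
    # -> ['aa', 'aaa']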
4 changes: 2 additions & 2 deletions tests/indices/knowledge_graph/test_base.py
@@ -232,8 +232,8 @@ def test_query_similarity(
     assert isinstance(response.extra_info, dict)
     assert len(response.extra_info["kg_rel_texts"]) == 2
 
-    # Filters out all embeddings
+    # Filters embeddings
     try:
-        response = index.query("foo", similarity_cutoff=-1.0)
+        response = index.query("foo", similarity_cutoff=10000000)
     except ValueError as e:
         assert str(e) == "kg_rel_map must be found in at least one Node."
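For context on the new cutoff value: if, as the postprocessor hunk above suggests, nodes below the cutoff are dropped, then an absurdly large cutoff removes every node, leaving kg_rel_map empty and triggering the expected ValueError, while a cutoff of -1.0 would keep everything. A toy check of that reading (hypothetical helper, not the test's actual code):

from typing import Dict, List


def filter_by_cutoff(similarities: Dict[str, float], cutoff: float) -> List[str]:
    """Keep only nodes whose similarity is at or above the cutoff."""
    return [node for node, sim in similarities.items() if sim >= cutoff]


sims = {"node_a": 0.91, "node_b": 0.42}
# The old test value would keep every node under this reading.
assert filter_by_cutoff(sims, -1.0) == ["node_a", "node_b"]
# The new test value removes all nodes, which is what the test relies on.
assert filter_by_cutoff(sims, 10000000) == []
print("cutoff sketch OK")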