add llamaindex integration #1170

Merged

Changes from 1 commit
Update llamaindex.py
ruff fix
arnavsinghvi11 committed Jun 19, 2024
commit 07d8e1d1ee5dc43be09e9ff009251f196fb57cc0
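
This commit is mechanical lint cleanup of the kind ruff's autofixer applies, presumably via something like the following (an assumption; the repo's actual ruff invocation and configuration are not shown here):

    ruff check --fix dspy/predict/llamaindex.py

In the diff below, imports are regrouped into standard-library, third-party (llama_index), and first-party (dsp, dspy) blocks and sorted within each block (ruff's isort rule, I001); unused imports (abstractmethod, infer_prefix, ProgramMeta) are dropped (F401); trailing commas are added after the final **kwargs: Any parameters (flake8-commas, COM812); and an f-string with no placeholders is rewritten as a plain string (F541).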
38 changes: 17 additions & 21 deletions dspy/predict/llamaindex.py
@@ -1,26 +1,22 @@
-from llama_index.core.prompts import BasePromptTemplate
-from dspy import Predict
-import dspy
-from abc import abstractmethod
-from typing import Any, Optional, List, Dict, Callable
+import re
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional
 
 from llama_index.core.base.llms.base import BaseLLM
-from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.base.llms.generic_utils import (
     prompt_to_messages,
 )
-from llama_index.core.base.query_pipeline.query import QueryComponent, InputKeys, OutputKeys
-from llama_index.core.query_pipeline import QueryPipeline
-from dspy.signatures.signature import ensure_signature, signature_to_template, infer_prefix, make_signature
-from dspy.signatures.field import InputField, OutputField
-from dspy.primitives import ProgramMeta
-import dsp
-from copy import deepcopy
-import re
+from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.query_pipeline.query import InputKeys, OutputKeys, QueryComponent
 from llama_index.core.callbacks.base import CallbackManager
 from llama_index.core.bridge.pydantic import BaseModel, create_model
-from llama_index.core.prompts import PromptTemplate
+from llama_index.core.prompts import BasePromptTemplate, PromptTemplate
+from llama_index.core.query_pipeline import QueryPipeline
 
+import dsp
+import dspy
+from dspy import Predict
+from dspy.signatures.field import InputField, OutputField
+from dspy.signatures.signature import ensure_signature, make_signature, signature_to_template
 
 
 def get_formatted_template(predict_module: Predict, kwargs: Dict[str, Any]) -> str:
@@ -78,7 +74,7 @@ def __init__(
         metadata: Optional[Dict[str, Any]] = None,
         template_var_mappings: Optional[Dict[str, Any]] = None,
         function_mappings: Optional[Dict[str, Callable]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> None:
         template = signature_to_template(predict_module.signature)
         template_vars = _input_keys_from_template(template)
@@ -116,7 +112,7 @@ def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
         return get_formatted_template(self.predict_module, mapped_kwargs)
 
     def format_messages(
-        self, llm: Optional[BaseLLM] = None, **kwargs: Any
+        self, llm: Optional[BaseLLM] = None, **kwargs: Any,
     ) -> List[ChatMessage]:
         """Formats the prompt template into chat messages."""
         del llm  # unused
@@ -126,7 +122,7 @@ def format_messages(
     def get_template(self, llm: Optional[BaseLLM] = None) -> str:
        """Get template."""
        # get kwarg templates
-        kwarg_tmpl_map = {k: f"{{k}}" for k in self.template_vars}
+        kwarg_tmpl_map = {k: "{k}" for k in self.template_vars}
 
        # get "raw" template with all the values filled in with {var_name}
        template0 = get_formatted_template(self.predict_module, kwarg_tmpl_map)
@@ -264,4 +260,4 @@ def forward(self, **kwargs: Any) -> Dict[str, Any]:
         """Forward."""
         output_dict = self.query_pipeline.run(**kwargs, return_values_direct=False)
         return dspy.Prediction(**output_dict)
-
+
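
One hunk above looks like it could change behavior but does not: in get_template, f"{{k}}" and "{k}" are the same string, because {{ and }} are escaped literal braces inside an f-string, so the f prefix added nothing and ruff's F541 drops it. A minimal standalone check (not part of the PR; the variable names are hypothetical):

    # "{{" and "}}" are escape sequences for literal braces in an f-string,
    # so f"{{k}}" has no placeholders and equals the plain string "{k}".
    assert f"{{k}}" == "{k}"

    # Hence the rewritten comprehension builds an identical mapping:
    template_vars = ["question", "answer"]  # hypothetical template variables
    assert {k: f"{{k}}" for k in template_vars} == {k: "{k}" for k in template_vars}

Note that both versions map every key to the same literal text "{k}" rather than to "{" + k + "}", which the surrounding comment about filling in {var_name} might suggest was intended; either way, the autofix preserves the original behavior exactly.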