diff --git a/dspy/predict/llamaindex.py b/dspy/predict/llamaindex.py
index e0fe0677c..d5c000250 100644
--- a/dspy/predict/llamaindex.py
+++ b/dspy/predict/llamaindex.py
@@ -1,26 +1,22 @@
-from llama_index.core.prompts import BasePromptTemplate
-from dspy import Predict
-import dspy
-from abc import abstractmethod
-from typing import Any, Optional, List, Dict, Callable
+import re
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Optional
+
 from llama_index.core.base.llms.base import BaseLLM
-from llama_index.core.base.llms.types import ChatMessage
 from llama_index.core.base.llms.generic_utils import (
     prompt_to_messages,
 )
-from llama_index.core.base.query_pipeline.query import QueryComponent, InputKeys, OutputKeys
-from llama_index.core.query_pipeline import QueryPipeline
-from dspy.signatures.signature import ensure_signature, signature_to_template, infer_prefix, make_signature
-from dspy.signatures.field import InputField, OutputField
-from dspy.primitives import ProgramMeta
-import dsp
-from copy import deepcopy
-import re
+from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.query_pipeline.query import InputKeys, OutputKeys, QueryComponent
 from llama_index.core.callbacks.base import CallbackManager
-from llama_index.core.bridge.pydantic import BaseModel, create_model
-from llama_index.core.prompts import PromptTemplate
-
+from llama_index.core.prompts import BasePromptTemplate, PromptTemplate
+from llama_index.core.query_pipeline import QueryPipeline
+import dsp
+import dspy
+from dspy import Predict
+from dspy.signatures.field import InputField, OutputField
+from dspy.signatures.signature import ensure_signature, make_signature, signature_to_template


 def get_formatted_template(predict_module: Predict, kwargs: Dict[str, Any]) -> str:
@@ -78,7 +74,7 @@ def __init__(
         metadata: Optional[Dict[str, Any]] = None,
         template_var_mappings: Optional[Dict[str, Any]] = None,
         function_mappings: Optional[Dict[str, Callable]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ) -> None:
         template = signature_to_template(predict_module.signature)
         template_vars = _input_keys_from_template(template)
@@ -116,7 +112,7 @@ def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
         return get_formatted_template(self.predict_module, mapped_kwargs)

     def format_messages(
-        self, llm: Optional[BaseLLM] = None, **kwargs: Any
+        self, llm: Optional[BaseLLM] = None, **kwargs: Any,
     ) -> List[ChatMessage]:
         """Formats the prompt template into chat messages."""
         del llm  # unused
@@ -126,7 +122,7 @@ def format_messages(
     def get_template(self, llm: Optional[BaseLLM] = None) -> str:
         """Get template."""
         # get kwarg templates
-        kwarg_tmpl_map = {k: f"{{k}}" for k in self.template_vars}
+        kwarg_tmpl_map = {k: "{k}" for k in self.template_vars}

         # get "raw" template with all the values filled in with {var_name}
         template0 = get_formatted_template(self.predict_module, kwarg_tmpl_map)
@@ -264,4 +260,4 @@ def forward(self, **kwargs: Any) -> Dict[str, Any]:
         """Forward."""
         output_dict = self.query_pipeline.run(**kwargs, return_values_direct=False)
         return dspy.Prediction(**output_dict)
-    
\ No newline at end of file
+
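
Aside (not part of the patch): a minimal, hypothetical usage sketch of the get_formatted_template helper whose signature appears in the first hunk. It assumes llama-index is installed (the module imports it at load time) and uses a made-up "question -> answer" signature and question text purely for illustration.

    import dspy
    from dspy.predict.llamaindex import get_formatted_template

    # Build a Predict module from a simple signature, then render its prompt
    # as a single string with the given input field value filled in.
    predict = dspy.Predict("question -> answer")
    prompt_str = get_formatted_template(predict, {"question": "What does this module do?"})
    print(prompt_str)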