Skip to content

Commit

Permalink
Update llamaindex.py
Browse files Browse the repository at this point in the history
ruff fix
  • Loading branch information
arnavsinghvi11 committed Jun 19, 2024
1 parent 0f85bc3 commit 07d8e1d
Showing 1 changed file with 17 additions and 21 deletions.
38 changes: 17 additions & 21 deletions dspy/predict/llamaindex.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,22 @@
from llama_index.core.prompts import BasePromptTemplate
from dspy import Predict
import dspy
from abc import abstractmethod
from typing import Any, Optional, List, Dict, Callable
import re
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional

from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.base.llms.generic_utils import (
prompt_to_messages,
)
from llama_index.core.base.query_pipeline.query import QueryComponent, InputKeys, OutputKeys
from llama_index.core.query_pipeline import QueryPipeline
from dspy.signatures.signature import ensure_signature, signature_to_template, infer_prefix, make_signature
from dspy.signatures.field import InputField, OutputField
from dspy.primitives import ProgramMeta
import dsp
from copy import deepcopy
import re
from llama_index.core.base.llms.types import ChatMessage
from llama_index.core.base.query_pipeline.query import InputKeys, OutputKeys, QueryComponent
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.bridge.pydantic import BaseModel, create_model
from llama_index.core.prompts import PromptTemplate

from llama_index.core.prompts import BasePromptTemplate, PromptTemplate
from llama_index.core.query_pipeline import QueryPipeline

import dsp
import dspy
from dspy import Predict
from dspy.signatures.field import InputField, OutputField
from dspy.signatures.signature import ensure_signature, make_signature, signature_to_template


def get_formatted_template(predict_module: Predict, kwargs: Dict[str, Any]) -> str:
Expand Down Expand Up @@ -78,7 +74,7 @@ def __init__(
metadata: Optional[Dict[str, Any]] = None,
template_var_mappings: Optional[Dict[str, Any]] = None,
function_mappings: Optional[Dict[str, Callable]] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
template = signature_to_template(predict_module.signature)
template_vars = _input_keys_from_template(template)
Expand Down Expand Up @@ -116,7 +112,7 @@ def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str:
return get_formatted_template(self.predict_module, mapped_kwargs)

def format_messages(
self, llm: Optional[BaseLLM] = None, **kwargs: Any
self, llm: Optional[BaseLLM] = None, **kwargs: Any,
) -> List[ChatMessage]:
"""Formats the prompt template into chat messages."""
del llm # unused
Expand All @@ -126,7 +122,7 @@ def format_messages(
def get_template(self, llm: Optional[BaseLLM] = None) -> str:
"""Get template."""
# get kwarg templates
kwarg_tmpl_map = {k: f"{{k}}" for k in self.template_vars}
kwarg_tmpl_map = {k: "{k}" for k in self.template_vars}

# get "raw" template with all the values filled in with {var_name}
template0 = get_formatted_template(self.predict_module, kwarg_tmpl_map)
Expand Down Expand Up @@ -264,4 +260,4 @@ def forward(self, **kwargs: Any) -> Dict[str, Any]:
"""Forward."""
output_dict = self.query_pipeline.run(**kwargs, return_values_direct=False)
return dspy.Prediction(**output_dict)


0 comments on commit 07d8e1d

Please sign in to comment.