Skip to content

Commit

Permalink
refactor(context/memory): remove ctx variable queries from schema
Browse files Browse the repository at this point in the history
  • Loading branch information
idiotWu committed Jul 30, 2024
1 parent d7f0981 commit 65ac002
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
1 change: 1 addition & 0 deletions npiai/constant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CTX_QUERY_POSTFIX = "__ctx_query"
3 changes: 2 additions & 1 deletion npiai/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from npiai.context import Context
from npiai.core.hitl import HITL
from npiai.types import FunctionRegistration
from npiai.constant import CTX_QUERY_POSTFIX


class BaseTool(ABC):
Expand Down Expand Up @@ -73,7 +74,7 @@ async def exec(self, ctx: Context, fn_name: str, args: Dict[str, Any] = None):

# add context variables
for ctx_var in fn.ctx_variables:
query = args.pop(f"{ctx_var.name}__query", ctx_var.query)
query = args.pop(f"{ctx_var.name}{CTX_QUERY_POSTFIX}", ctx_var.query)

args[ctx_var.name] = await ctx.ask(
query=query,
Expand Down
3 changes: 2 additions & 1 deletion npiai/core/tool/_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
is_template_str,
)
from npiai.context import Context
from npiai.constant import CTX_QUERY_POSTFIX

__NPI_TOOL_ATTR__ = "__NPI_TOOL_ATTR__"

Expand Down Expand Up @@ -258,7 +259,7 @@ def _register_tools(self):
)

if is_template_str(anno.query):
param_fields[f"{p.name}__query"] = (
param_fields[f"{p.name}{CTX_QUERY_POSTFIX}"] = (
str,
Field(
default=anno.query,
Expand Down
10 changes: 9 additions & 1 deletion npiai/types/function_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from npiai.types.shot import Shot
from npiai.types.from_context import FromContext
from npiai.constant import CTX_QUERY_POSTFIX

ToolFunction = Callable[..., Awaitable[str]]

Expand All @@ -22,10 +23,17 @@ class FunctionRegistration:
few_shots: Optional[List[Shot]] = None

def get_meta(self):
params = {}

for name in self.schema:
# remove context variable queries
if not name.endswith(CTX_QUERY_POSTFIX):
params[name] = self.schema[name]

return {
"description": self.description,
"name": self.name,
"parameters": self.schema,
"parameters": params,
"fewShots": (
[asdict(ex) for ex in self.few_shots] if self.few_shots else None
),
Expand Down

0 comments on commit 65ac002

Please sign in to comment.