Skip to content

Commit

Permalink
Fix the parameters type conversion when invoking actions from colang …
Browse files Browse the repository at this point in the history
…(previously everything was string).
  • Loading branch information
drazvan committed Jul 27, 2023
1 parent 39ae7a1 commit 0e941ac
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 13 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed the Cohere prompt templates.
- [#55](https://github.com/NVIDIA/NeMo-Guardrails/issues/83): Fix bug related to LangChain callbacks initialization.
- Fixed generation of "..." on value generation.

- Fixed the parameters type conversion when invoking actions from colang (previously everything was string).

## [0.3.0] - 2023-06-30

Expand Down
24 changes: 12 additions & 12 deletions nemoguardrails/language/coyml_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
This also transpiles correctly to JS to be used on the client side.
"""
import json
import re
from ast import literal_eval
from typing import List

from .utils import get_stripped_tokens, split_max, word_split
from .utils import get_stripped_tokens, split_args, split_max, word_split


def _to_value(s, remove_quotes: bool = False):
Expand All @@ -34,16 +36,14 @@ def _to_value(s, remove_quotes: bool = False):
TODO: other useful value shorthands
"""
if s == "None":
return None

if remove_quotes and len(s) > 0 and s[0] == '"' and s[-1] == '"':
return s[1:-1]

if isinstance(s, str) and s.isnumeric():
return int(s)

return s
if isinstance(s, str):
# If it's a reference to a variable, we leave as is.
if re.match(r"\$([a-zA-Z_][a-zA-Z0-9_]*)", s):
return s
else:
return literal_eval(s)
else:
return s


def _extract_inline_params(d_value, d_params):
Expand All @@ -54,7 +54,7 @@ def _extract_inline_params(d_value, d_params):
assert params_str[-1] == ")", f"Incorrect params str: {params_str}"

params_str = params_str[0:-1]
param_pairs = get_stripped_tokens(word_split(params_str, ","))
param_pairs = get_stripped_tokens(split_args(params_str))

for pair in param_pairs:
# Skip empty pairs
Expand Down
45 changes: 45 additions & 0 deletions nemoguardrails/language/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,51 @@ def split_max(text, separator, max_instances):
return parts


def split_args(args_str: str) -> List[str]:
"""Split a string that represents arguments for a function.
It supports keyword arguments and also correctly handles strings and lists/dicts.
Args:
args_str: The string with the arguments e.g. 'name="John", colors=["blue", "red"]'
Returns:
The string that correspond to each individual argument value.
"""

parts = []
stack = []

current = []

closing_char = {"[": "]", "(": ")", "{": "}", "'": "'", '"': '"'}

for char in args_str:
if char in "([{":
stack.append(char)
current.append(char)
elif char in "\"'" and (len(stack) == 0 or stack[-1] != char):
stack.append(char)
current.append(char)
elif char in ")]}\"'":
if char != closing_char[stack[-1]]:
raise ValueError(
f"Invalid syntax for string: {args_str}; "
f"expecting {closing_char[stack[-1]]} and got {char}"
)
stack.pop()
current.append(char)
elif char == "," and len(stack) == 0:
parts.append("".join(current))
current = []
else:
current.append(char)

parts.append("".join(current))

return [part.strip() for part in parts]


def get_numbered_lines(content: str):
"""Helper to returned numbered lines.
Expand Down
54 changes: 54 additions & 0 deletions tests/test_action_params_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

from nemoguardrails import RailsConfig
from tests.utils import TestChat

config = RailsConfig.from_content(
"""
define user express greeting
"hello"
define flow
user express greeting
execute custom_action(name="John", age=20, height=5.8, colors=["blue", "green"], data={'a': 1})
bot express greeting
"""
)


def test_1():
chat = TestChat(
config,
llm_completions=[
" express greeting",
' "Hello there!"',
],
)

async def custom_action(
name: str, age: int, height: float, colors: List[str], data: dict
):
assert name == "John"
assert age == 20
assert height == 5.8
assert colors == ["blue", "green"]
assert data == {"a": 1}

chat.app.register_action(custom_action)

chat >> "Hello!"
chat << "Hello there!"
32 changes: 32 additions & 0 deletions tests/test_parser_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemoguardrails.language.utils import split_args


def test_1():
assert split_args("1") == ["1"]
assert split_args('1, "a"') == ["1", '"a"']
assert split_args("1, [1,2,3]") == ["1", "[1,2,3]"]
assert split_args("1, numbers = [1,2,3]") == ["1", "numbers = [1,2,3]"]
assert split_args("1, data = {'name': 'John'}") == ["1", "data = {'name': 'John'}"]
assert split_args("'a,b, c'") == ["'a,b, c'"]

assert split_args("1, 'a,b, c', x=[1,2,3], data = {'name': 'John'}") == [
"1",
"'a,b, c'",
"x=[1,2,3]",
"data = {'name': 'John'}",
]

0 comments on commit 0e941ac

Please sign in to comment.