-
Notifications
You must be signed in to change notification settings - Fork 4
/
run_chatbot.py
49 lines (39 loc) · 1.76 KB
/
run_chatbot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from chatbot.prompting import build_prompt_for
from chatbot.model import run_raw_inference
import typing as t
import logging
# Configure the root logger at INFO level as soon as this module is
# imported, then obtain this module's own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _extract_bot_message(model_output: str, char_name: str) -> str:
    """Return the character's reply extracted from raw model output.

    The reply is taken from the last line of ``model_output``.  If that line
    starts with ``"You:"`` (the model went on to write the user's next turn),
    the line before it is used instead.  A ``"<char_name>: "`` speaker tag is
    then removed from the start of the chosen line only; occurrences of the
    tag later in the line are kept.

    Returns an empty string when no line is left to use (empty output, or
    output consisting of a single ``"You:"`` line) instead of raising
    ``IndexError``.
    """
    lines = model_output.splitlines()
    # Drop a trailing user turn that the model generated past the reply.
    if lines and lines[-1].startswith("You:"):
        lines = lines[:-1]
    if not lines:
        return ""
    bot_message = lines[-1]

    # Strip the speaker tag only at the beginning of the line (ignoring any
    # leading whitespace), so the character's name inside the message text
    # is left intact.
    prefix = f"{char_name}: "
    stripped = bot_message.lstrip()
    if stripped.startswith(prefix):
        bot_message = stripped[len(prefix):]
    return bot_message


def inference_fn(model, tokenizer, history: str, user_input: str,
                 generation_settings: t.Dict[str, t.Any],
                 char_settings: t.Dict[str, t.Any],
                 history_length=8,
                 count=0,
                 ) -> str:
    """Produce the character's next chat message.

    When ``count == 0`` and the character has a greeting (not ``None``), the
    greeting is returned as a string and no inference is run.  Otherwise a
    prompt is built with ``build_prompt_for`` from the character settings,
    ``history`` and ``user_input``.  The prompt is run through
    ``run_raw_inference``, and the character's reply is extracted from the
    raw output.

    Args:
        model: Passed through unchanged to ``run_raw_inference``.
        tokenizer: Passed through unchanged to ``run_raw_inference``.
        history: Conversation history, passed to ``build_prompt_for``.
        user_input: The user's latest message.
        generation_settings: Extra keyword arguments forwarded to
            ``run_raw_inference``.
        char_settings: Must contain the keys ``"char_name"``,
            ``"char_persona"``, ``"char_greeting"``, ``"world_scenario"``
            and ``"example_dialogue"``.
        history_length: Forwarded to ``build_prompt_for``.
        count: Presumably the number of turns so far; ``0`` (with a
            non-``None`` greeting) selects the greeting path.

    Returns:
        The greeting on the greeting path, otherwise the extracted reply.
        The reply is an empty string if the model output contains no usable
        reply line.

    Raises:
        KeyError: If any of the required ``char_settings`` keys is missing.
    """
    char_name = char_settings["char_name"]
    char_persona = char_settings["char_persona"]
    char_greeting = char_settings["char_greeting"]
    world_scenario = char_settings["world_scenario"]
    example_dialogue = char_settings["example_dialogue"]

    if count == 0 and char_greeting is not None:
        return f"{char_greeting}"

    # NOTE(review): the keyword is passed as ``history_lenght`` (sic);
    # presumably that is how build_prompt_for spells its parameter, so do
    # not rename it here without checking the callee.
    prompt = build_prompt_for(history=history,
                              user_message=user_input,
                              char_name=char_name,
                              char_persona=char_persona,
                              example_dialogue=example_dialogue,
                              world_scenario=world_scenario,
                              history_lenght=history_length)
    model_output = run_raw_inference(model, tokenizer, prompt,
                                     user_input, **generation_settings)
    return _extract_bot_message(model_output, char_name)