Skip to content

Commit

Permalink
fix multi-round dialogue.
Browse files Browse the repository at this point in the history
  • Loading branch information
fengyh3 committed Apr 21, 2023
1 parent 2621909 commit 422a9c5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
11 changes: 10 additions & 1 deletion generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer

def generate(self, args, prompts):
def generate(self, args, prompts, cut_off=None, cut_off_times=1):
if cut_off is not None:
cut_off_times = [cut_off_times for i in range(len(prompts))]
batch = len(prompts)
assert batch <= args.batch_size

Expand Down Expand Up @@ -118,6 +120,12 @@ def generate(self, args, prompts):
try:
t.index(self.tokenizer.eos_id)
except ValueError:
if cut_off is not None:
if cut_off == self.tokenizer.decode(t[:cur_pos + 1])[-len(cut_off):]:
if cut_off_times[i] == 1:
continue
else:
cut_off_times[i] -= 1
continue_exsample.append(i)
if len(continue_exsample) == 0:
break
Expand All @@ -126,6 +134,7 @@ def generate(self, args, prompts):
for i, t in enumerate(tokens.tolist()):
t = t[: args.seq_length]
try:
t = t[: t.index(self.tokenizer.pad_id)]
t = t[: t.index(self.tokenizer.eos_id)]
except ValueError:
pass
Expand Down
12 changes: 6 additions & 6 deletions llama_dialogue.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def multi_round_chat(args, lm_generation, keep_length_ratio=0.5):

input_str = ''
for user, ans in zip(users, answers):
input_str += user + '\n' + ans + '\n'
input_str += user_input + '\n'
input_str += 'User: ' + user + '\nBot: ' + ans + '\n'
input_str += 'User: ' + user_input + '\nBot: '
if len(input_str) >= int(keep_length_ratio * args.seq_length):
input_str = input_str[len(input_str) - int(keep_length_ratio * args.seq_length):]
answer = lm_generation.generate(args, [input_str])[0]
answer = lm_generation.generate(args, [input_str], cut_off='User:', cut_off_times=1)[0]
answer = answer[len(input_str):]
print("ChatLLaMa: " + answer + '\n')
users.append(user_input)
answers.append(answer)
print("ChatLLaMa: " + answer.replace('User:', ''))
users.append(user_input.rstrip(' ').rstrip('\n'))
answers.append(answer.replace('User:', '').rstrip(' ').rstrip('\n'))


if __name__ == '__main__':
Expand Down

0 comments on commit 422a9c5

Please sign in to comment.