Commit f1761bc

feat: improvements
da-z committed Jan 23, 2024
1 parent 03233b4 commit f1761bc
Showing 1 changed file with 26 additions and 18 deletions.
app.py: 44 changes (26 additions, 18 deletions)
@@ -6,7 +6,7 @@
 from mlx_lm.utils import generate_step

 title = "MLX Chat"
-ver = "0.7.11"
+ver = "0.7.12"
 debug = False

 with open('models.txt', 'r') as file:
@@ -128,34 +128,42 @@ def queue_chat(the_prompt, continuation=""):
 if actions[0].button("😶‍🌫️ Forget", use_container_width=True,
                      help="Forget the previous conversations."):
     st.session_state.messages = [{"role": "assistant", "content": assistant_greeting}]
+    if "prompt" in st.session_state and st.session_state["prompt"]:
+        st.session_state["prompt"] = None
+        st.session_state["continuation"] = None
     st.rerun()

 if actions[1].button("🔂 Continue", use_container_width=True,
                      help="Continue the generation."):

     user_prompts = [msg["content"] for msg in st.session_state.messages if msg["role"] == "user"]
-    last_prompt = user_prompts[-1] or "Please continue your response."

-    assistant_responses = [msg["content"] for msg in st.session_state.messages if msg["role"] == "assistant"]
-    last_assistant_response = assistant_responses[-1] if assistant_responses else ""
+    if user_prompts:

-    # remove last line completely, so it is regenerated correctly (in case it stopped mid-word or mid-number)
-    last_assistant_response_lines = last_assistant_response.split('\n')
-    if len(last_assistant_response_lines) > 1:
-        last_assistant_response_lines.pop()
-        last_assistant_response = "\n".join(last_assistant_response_lines)
+        last_user_prompt = user_prompts[-1]

-    full_prompt = tokenizer.apply_chat_template([
-        {"role": "system", "content": system_prompt},
-        {"role": "user", "content": last_prompt},
-        {"role": "assistant", "content": last_assistant_response},
-    ], tokenize=False, add_generation_prompt=False, chat_template=chatml_template)
-    full_prompt = full_prompt.rstrip("<|im_end|>\n")
+        assistant_responses = [msg["content"] for msg in st.session_state.messages
+                               if msg["role"] == "assistant" and msg["content"] != assistant_greeting]
+        last_assistant_response = assistant_responses[-1] if assistant_responses else ""

-    # replace last assistant response from state, as it will be replaced with a continued one
-    remove_last_occurrence(st.session_state.messages, lambda msg: msg["role"] == "assistant")
+        # remove last line completely, so it is regenerated correctly (in case it stopped mid-word or mid-number)
+        last_assistant_response_lines = last_assistant_response.split('\n')
+        if len(last_assistant_response_lines) > 1:
+            last_assistant_response_lines.pop()
+            last_assistant_response = "\n".join(last_assistant_response_lines)

-    queue_chat(full_prompt, last_assistant_response)
+        full_prompt = tokenizer.apply_chat_template([
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": last_user_prompt},
+            {"role": "assistant", "content": last_assistant_response},
+        ], tokenize=False, add_generation_prompt=False, chat_template=chatml_template)
+        full_prompt = full_prompt.rstrip("<|im_end|>\n")
+
+        # replace last assistant response from state, as it will be replaced with a continued one
+        remove_last_occurrence(st.session_state.messages,
+                               lambda msg: msg["role"] == "assistant" and msg["content"] != assistant_greeting)
+
+        queue_chat(full_prompt, last_assistant_response)

 if prompt := st.chat_input():
     st.session_state.messages.append({"role": "user", "content": prompt})
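
Note on the new Continue flow: it rebuilds the conversation with the partial assistant reply appended and the end-of-turn marker removed, so the model resumes writing instead of starting a fresh answer; the last line of the partial reply is dropped first in case it stopped mid-word. The sketch below is a minimal, standalone illustration of that idea, not code from app.py: it formats ChatML by hand instead of calling tokenizer.apply_chat_template with the app's chatml_template, and the function name build_continuation_prompt is invented for this example.

# Minimal sketch (assumption: the chat template emits standard ChatML markers).
def build_continuation_prompt(system_prompt: str, last_user_prompt: str,
                              partial_response: str) -> str:
    # Drop the last line of the partial reply so a line that stopped
    # mid-word or mid-number is regenerated cleanly (mirrors the diff).
    lines = partial_response.split("\n")
    if len(lines) > 1:
        lines.pop()
        partial_response = "\n".join(lines)

    # Hand-rolled ChatML, standing in for tokenizer.apply_chat_template(...).
    prompt = (
        f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        f"<|im_start|>user\n{last_user_prompt}<|im_end|>\n"
        f"<|im_start|>assistant\n{partial_response}"
    )
    # No trailing <|im_end|> after the assistant text, so generation resumes
    # mid-message instead of opening a new turn.
    return prompt


if __name__ == "__main__":
    print(build_continuation_prompt("You are helpful.",
                                    "List three primes.",
                                    "Sure:\n1. 2\n2. 3\n3."))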

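The diff also calls a remove_last_occurrence(messages, predicate) helper that is defined elsewhere in app.py and not shown in this commit. A plausible sketch of such a helper, offered only as a guess at its behavior (delete the last element matching the predicate, in place), might look like:

# Hypothetical sketch of the remove_last_occurrence helper used above;
# the real definition lives elsewhere in app.py and may differ.
def remove_last_occurrence(items: list, predicate) -> None:
    # Walk the list from the end and delete the last-positioned element
    # the predicate matches; mutate in place, as the call site expects.
    for i in range(len(items) - 1, -1, -1):
        if predicate(items[i]):
            del items[i]
            return

With the new predicate at the call site, the assistant greeting is never removed, which matches the added msg["content"] != assistant_greeting guard.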