Commit

local llama
streetycat committed Sep 28, 2023
1 parent e0c4eb5 commit 7b5c010
Showing 3 changed files with 68 additions and 89 deletions.
2 changes: 1 addition & 1 deletion rootfs/agents/math_teacher/agent.toml
@@ -1,6 +1,6 @@
instance_id = "math_teacher"
fullname = "the one"
llm_model_name = "gpt-4-0613"
llm_model_name = "LLaMA2-70B"
[[prompt]]
role = "system"
content = "你是精通数学的老师"
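The system prompt above is Chinese for "You are a teacher who is proficient in mathematics." Switching llm_model_name from "gpt-4-0613" to "LLaMA2-70B" only takes effect if a compute node claims that model name; with the is_support() change further down, the routing check reduces to roughly the following (an illustrative paraphrase using the value from this config, not code added by the commit):

    # Paraphrase of LocalLlama_ComputeNode.is_support (see the Python diff below).
    is_llm_task = task.task_type in (ComputeTaskType.TEXT_EMBEDDING, ComputeTaskType.LLM_COMPLETION)
    requested = task.params["model_name"]          # "LLaMA2-70B" for this agent
    accepted = is_llm_task and (not requested or requested == self.model_name)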
123 changes: 60 additions & 63 deletions src/aios_kernel/local_llama_compute_node.py
@@ -1,10 +1,12 @@

import json
import logging
import requests
from typing import Optional, List
from pydantic import BaseModel
from llama_cpp import Llama

from .compute_task import ComputeTask, ComputeTaskState, ComputeTaskType
from .compute_task import ComputeTask, ComputeTaskResult, ComputeTaskState, ComputeTaskType
from .queue_compute_node import Queue_ComputeNode

logger = logging.getLogger(__name__)
@@ -14,69 +16,64 @@
"""

class LocalLlama_ComputeNode(Queue_ComputeNode):
async def execute_task(self, task: ComputeTask) -> {
"content": str,
"message": str,
"state": ComputeTaskState,
"error": {
"code": int,
"message": str,
}
}:
class GenerateResponse(BaseModel):
error: Optional[int]
msg: Optional[str]
results: Optional[List[str]]

try:
prompt_msgs = []
for prompt in task.params["prompts"]:
prompt_msgs.append(prompt["content"])
def __init__(self, model_path: str, model_name: str):
super().__init__()
self.model_path = model_path
self.model_name = model_name
self.llm = Llama(model_path=model_path)

async def execute_task(self, task: ComputeTask) -> ComputeTaskResult:
match task.task_type:
case ComputeTaskType.TEXT_EMBEDDING:
model_name = task.params["model_name"]
input = task.params["input"]
logger.info(f"call local llama {model_name} input: {input}")

embedding = self.llm.embed(input=input)

body = {
"prompts": prompt_msgs
}

response = requests.post("http://aigc:7880/generate", json = body, verify=False, headers={"Content-Type": "application/json"})
response.close()

logger.info(f"LocalLlama_ComputeNode task responded, request: {body}, status-code: {response.status_code}, headers: {response.headers}, content: {response.content}")

if response.status_code != 200:
return {
"state": ComputeTaskState.ERROR,
"error": {
"code": response.status_code,
"message": "http request failed: " + str(response.status_code)
}
}
else:
resp = response.json()
if "error" in resp:
return {
"state": ComputeTaskState.ERROR,
"error": {
"code": resp["error"],
"message": "local llama failed:" + resp["msg"]
}
}
else:
return {
"state": ComputeTaskState.DONE,
"content": str(resp["results"]),
"message": str(resp["results"])
}
except Exception as err:
import traceback
logger.error(f"{traceback.format_exc()}, error: {err}")
logger.info(f"local-llama({self.model_path}) response: {embedding}")

result = ComputeTaskResult()
result.set_from_task(task)
result.result = embedding

return result
case ComputeTaskType.LLM_COMPLETION:
model_name = task.params["model_name"]
prompts = task.params["prompts"]
max_token_size = task.params.get("max_token_size")
llm_inner_functions = task.params.get("inner_functions")
if max_token_size is None:
max_token_size = 4000

logger.info(f"local-llama({self.model_path}) prompts: {prompts}")

resp = self.llm.create_chat_completion(model=model_name,
messages=prompts,
functions=llm_inner_functions, # function calling may not be supported here yet?
max_tokens=max_token_size,
temperature=0.7) # TODO: add temperature to task params?


return {
"state": ComputeTaskState.ERROR,
"error": {
"code": -1,
"message": "unknown exception: " + str(err)
}
}
logger.info(f"local-llama({self.model_path}) response: {json.dumps(resp, indent=4)}")

result = ComputeTaskResult()
result.set_from_task(task)

status_code = resp["choices"][0]["finish_reason"]
match status_code:
case "function_call":
task.state = ComputeTaskState.DONE
case "stop":
task.state = ComputeTaskState.DONE
case _:
task.state = ComputeTaskState.ERROR
task.error_str = f"The status code was {status_code}."
return None

result.result_str = resp["choices"][0]["message"]["content"]
result.result_message = resp["choices"][0]["message"]
return result

async def initial(self) -> bool:
return True
Expand All @@ -88,7 +85,7 @@ def get_capacity(self):
pass

def is_support(self, task: ComputeTask) -> bool:
return task.task_type == ComputeTaskType.LLM_COMPLETION and (not task.params["model_name"] or task.params["model_name"] == "llama")
return (task.task_type == ComputeTaskType.TEXT_EMBEDDING or task.task_type == ComputeTaskType.LLM_COMPLETION) and (not task.params["model_name"] or task.params["model_name"] == self.model_name)

def is_local(self) -> bool:
return True
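Taken together, the rewrite drops the HTTP round trip to http://aigc:7880/generate and runs the model in-process through llama_cpp. A minimal usage sketch follows; it is an assumption-laden illustration, not code from this commit: the GGUF path is hypothetical, the import paths are guessed from the src/aios_kernel layout, and the way ComputeTask objects are normally built elsewhere in the kernel is stubbed out by setting fields directly. Only initial(), start() and push_task() from Queue_ComputeNode are relied on.

    import asyncio

    from aios_kernel.local_llama_compute_node import LocalLlama_ComputeNode   # assumed module path
    from aios_kernel.compute_task import ComputeTask, ComputeTaskType         # assumed module path

    async def main():
        # model_name must match llm_model_name in agent.toml so is_support() routes tasks here
        node = LocalLlama_ComputeNode(
            model_path="./models/llama-2-70b-chat.Q4_K_M.gguf",  # hypothetical local GGUF file
            model_name="LLaMA2-70B",
        )
        await node.initial()
        node.start()                       # starts the queue worker inherited from Queue_ComputeNode

        task = ComputeTask()               # assumed to be constructible with no arguments
        task.task_type = ComputeTaskType.LLM_COMPLETION
        task.params = {
            "model_name": "LLaMA2-70B",
            "prompts": [{"role": "user", "content": "What is 17 * 24?"}],
            "max_token_size": 512,
            "inner_functions": None,
        }
        await node.push_task(task)         # queued; _run_task will call execute_task()

    asyncio.run(main())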
32 changes: 7 additions & 25 deletions src/aios_kernel/queue_compute_node.py
@@ -16,15 +16,7 @@ def __init__(self):
self.is_start = False

@abstractmethod
async def execute_task(self, task: ComputeTask) -> {
"content": str,
"message": str,
"state": ComputeTaskState,
"error": {
"code": int,
"message": str,
}
}:
async def execute_task(self, task: ComputeTask) -> ComputeTaskResult:
pass

async def push_task(self, task: ComputeTask, priority: int = 0):
@@ -37,23 +29,13 @@ async def remove_task(self, task_id: str):
async def _run_task(self, task: ComputeTask):
task.state = ComputeTaskState.RUNNING

resp = await self.execute_task(task)

result = ComputeTaskResult()

result.worker_id = self.node_id
task.state = resp["state"]

if task.state == ComputeTaskState.ERROR:
result.result_code = ComputeTaskResultCode.ERROR
task.error_str = resp["error"]["message"]
result = await self.execute_task(task)
if result is not None:
result.set_from_task(task)
result.worker_id = self.node_id
else:
result.result_code = ComputeTaskResultCode.OK
result.result_str = resp["content"]
result.result_message = resp["message"]

result.set_from_task(task)

task.state = ComputeTaskState.ERROR

return result

def start(self):
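Because the old and new bodies of _run_task are interleaved in the view above, the resulting control flow is easier to read assembled in one place. Pieced together from the added lines (a reading aid, not a verbatim copy of the file), the node itself now owns error reporting: execute_task() returns a ComputeTaskResult directly, and a None return marks the task as failed:

    async def _run_task(self, task: ComputeTask):
        task.state = ComputeTaskState.RUNNING

        result = await self.execute_task(task)
        if result is not None:
            result.set_from_task(task)
            result.worker_id = self.node_id
        else:
            task.state = ComputeTaskState.ERROR

        return result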