Patch: 0.1.1 backend code submission (#9)
1. Add user permission design
2. Fix some bugs
yaojin3616 committed Sep 8, 2023
2 parents 1afccbf + 02f84c6 commit 9e0539c
Showing 33 changed files with 552 additions and 418 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/ci.yml
@@ -31,11 +31,11 @@ jobs:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}

- name: Bash echo
shell: bash
- name: Get version
id: get_version
run: |
echo "${{ github.ref }}"
echo ""
echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//}
# Build the backend image and push to Docker Hub
- name: Build backend and push
id: docker_build_backend
@@ -48,11 +48,11 @@ jobs:
# docker build args: inject APP_NAME/APP_VERSION
build-args: |
APP_NAME="bisheng-backend"
APP_VERSION=0.1.1
APP_VERSION=${{ steps.get_version.outputs.VERSION }}
# generate two docker tags: ${APP_VERSION} and latest
tags: |
${{ env.DOCKERHUB_REPO }}bisheng-backend:latest
${{ env.DOCKERHUB_REPO }}bisheng-backend:0.1.1
${{ env.DOCKERHUB_REPO }}bisheng-backend:${{ steps.get_version.outputs.VERSION }}
# Build the frontend image and push to Docker Hub
- name: Build frontend and push
id: docker_build_frontend
Expand All @@ -65,8 +65,8 @@ jobs:
# docker build args: inject APP_NAME/APP_VERSION
build-args: |
APP_NAME="bisheng-frontend"
APP_VERSION=0.1.1
APP_VERSION=${{ steps.get_version.outputs.VERSION }}
# generate two docker tags: ${APP_VERSION} and latest
tags: |
${{ env.DOCKERHUB_REPO }}bisheng-frontend:latest
${{ env.DOCKERHUB_REPO }}bisheng-frontend:0.1.1
${{ env.DOCKERHUB_REPO }}bisheng-frontend:${{ steps.get_version.outputs.VERSION }}
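The hard-coded `0.1.1` tags are replaced by a version derived from the pushed git tag: the new `Get version` step strips the `refs/tags/` prefix from `GITHUB_REF`, and the build steps reference it as `steps.get_version.outputs.VERSION`. A minimal Python sketch of the resulting tag computation; the workflow itself does this in bash, and the `DOCKERHUB_REPO` value below is an assumption for illustration:

```python
# Sketch only: mirror of ${GITHUB_REF/refs\/tags\//} plus the two tags passed to
# docker/build-push-action; "dataelement/" stands in for env.DOCKERHUB_REPO.
def docker_tags(github_ref, repo, image):
    version = github_ref.replace("refs/tags/", "", 1)  # "refs/tags/v0.1.2" -> "v0.1.2"
    return [f"{repo}{image}:latest", f"{repo}{image}:{version}"]

assert docker_tags("refs/tags/v0.1.2", "dataelement/", "bisheng-backend") == [
    "dataelement/bisheng-backend:latest",
    "dataelement/bisheng-backend:v0.1.2",
]
```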
2 changes: 1 addition & 1 deletion README.md
@@ -17,7 +17,7 @@ Bisheng是一款领先的开源<b>大模型应用开发平台</b>,赋能和加

"Bisheng" (Bi Sheng) was the inventor of movable-type printing, a technology that greatly advanced the spread of human knowledge. We hope "Bisheng" can likewise provide solid support for the broad adoption of intelligent applications. Everyone is welcome to take part.

Bisheng is released under the [Apache 2.0 License](https://github.com/dataelement/bisheng/main/LICENSE) and went open source at the end of August 2023.
Bisheng is released under the [Apache 2.0 License](https://github.com/dataelement/bisheng/blob/main/LICENSE) and went open source at the end of August 2023.


## Product Highlights
11 changes: 9 additions & 2 deletions docker/bisheng/config/config.yaml
@@ -1,10 +1,17 @@
# Database configuration
database_url:
"mysql+pymysql:https://username:[email protected]:3306/langflow"
"mysql+pymysql:https://username:[email protected]:3306/bisheng"
redis_url:
"192.168.106.116:6379"


# Model configuration for knowledge-base embeddings
embedding_config:
text-embedding-ada-002:
base_url:
""
multilingual-e5-large:
base_url:
""

agents:
ZeroShotAgent:
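The new `embedding_config` block maps each embedding model name to the `base_url` that serves it. A minimal sketch of reading that section, assuming PyYAML and the layout shown above (bisheng's actual settings loader may differ):

```python
import yaml  # assumption: PyYAML; the real loader in bisheng may differ

with open("docker/bisheng/config/config.yaml") as f:
    cfg = yaml.safe_load(f)

# Each embedding model is keyed by name and carries the endpoint configured for it.
for model_name, model_cfg in cfg["embedding_config"].items():
    print(model_name, model_cfg.get("base_url") or "<no base_url configured>")
```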
9 changes: 9 additions & 0 deletions src/backend/bisheng/api/JWT.py
@@ -0,0 +1,9 @@
from pydantic import BaseModel


class Settings(BaseModel):
authjwt_secret_key: str = 'xI$xO.oN$sC}tC^oQ(fF^nK~dB&uT('
# Configure application to store and get JWT from cookies
authjwt_token_location: set = {'cookies'}
# Disable CSRF Protection for this example. default is True
authjwt_cookie_csrf_protect: bool = False
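The `Settings` model follows the fastapi-jwt-auth convention: cookie-based token storage with CSRF protection disabled, which matches how the chat endpoints below read the JWT. A sketch of the wiring such a model assumes (registration via `AuthJWT.load_config` is not shown in this diff):

```python
from fastapi import FastAPI
from fastapi_jwt_auth import AuthJWT

from bisheng.api.JWT import Settings  # the model added in this commit

app = FastAPI()


@AuthJWT.load_config  # fastapi-jwt-auth pulls its authjwt_* settings from this callback
def get_config():
    return Settings()
```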
14 changes: 8 additions & 6 deletions src/backend/bisheng/api/utils.py
@@ -62,7 +62,8 @@ def build_input_keys_response(langchain_object, artifacts):

return input_keys_response

def build_flow(graph_data:dict, artifacts, process_file=False, flow_id=None, chat_id=None):

def build_flow(graph_data: dict, artifacts, process_file=False, flow_id=None, chat_id=None):
try:
# Some error could happen when building the graph
graph = Graph.from_payload(graph_data)
@@ -92,12 +93,12 @@ def build_flow(graph_data:dict, artifacts, process_file=False, flow_id=None, cha
# filter out file params
vertex.params[key] = ''

#vector store: auto-create the collection when needed
# vector store: auto-create the collection when needed
# vectors generated on the fly by a flow (e.g. from the chat window) need a fresh temporary collection
# tmp_{chat_id}
if vertex.base_type == 'vectorstores':
if 'collection_name' in vertex.params and not vertex.params.get('collection_name'):
vertex.params['collection_name']=f'tmp_{flow_id}_{chat_id}'
vertex.params['collection_name'] = f'tmp_{flow_id}_{chat_id}'

vertex.build()
params = vertex._built_object_repr()
@@ -131,7 +132,8 @@ def build_flow(graph_data:dict, artifacts, process_file=False, flow_id=None, cha
yield str(StreamData(event='message', data=response))
return graph

def build_flow_no_yield(graph_data:dict, artifacts, process_file=False, flow_id=None, chat_id=None):

def build_flow_no_yield(graph_data: dict, artifacts, process_file=False, flow_id=None, chat_id=None):
try:
# Some error could happen when building the graph
graph = Graph.from_payload(graph_data)
@@ -153,12 +155,12 @@ def build_flow_no_yield(graph_data:dict, artifacts, process_file=False, flow_id=
# filter out file params
vertex.params[key] = ''

#vector store: auto-create the collection when needed
# vector store: auto-create the collection when needed
# vectors generated on the fly by a flow (e.g. from the chat window) need a fresh temporary collection
# tmp_{chat_id}
if vertex.base_type == 'vectorstores':
if 'collection_name' in vertex.params and not vertex.params.get('collection_name'):
vertex.params['collection_name']=f'tmp_{flow_id}_{chat_id}'
vertex.params['collection_name'] = f'tmp_{flow_id}_{chat_id}'

vertex.build()
params = vertex._built_object_repr()
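Both `build_flow` and `build_flow_no_yield` now fall back to a per-conversation collection when a vector-store vertex has an empty `collection_name`, so vectors created during a chat land in an isolated temporary collection. A standalone sketch of that naming rule:

```python
from typing import Optional


def resolve_collection_name(flow_id: str, chat_id: Optional[str], configured: Optional[str]) -> str:
    """Keep an explicitly configured collection_name; otherwise fall back to a
    temporary collection scoped to this flow/chat, as build_flow now does."""
    return configured if configured else f'tmp_{flow_id}_{chat_id}'


# A chat-window flow that never set a collection gets its own namespace:
assert resolve_collection_name('a1b2c3', 'chat42', '') == 'tmp_a1b2c3_chat42'
```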
9 changes: 2 additions & 7 deletions src/backend/bisheng/api/v1/callback.py
@@ -60,7 +60,7 @@ async def on_tool_start(
async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
observation_prefix = kwargs.get('observation_prefix', 'Tool output: ')
from langchain.docstore.document import Document
from langchain.docstore.document import Document # noqa
result = eval(output).get('result')

# Create a formatted message.
@@ -73,7 +73,6 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
intermediate_steps=intermediate_steps,
)


try:
# This is to emulate the stream of tokens
await self.websocket.send_json(resp.dict())
@@ -138,8 +137,6 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
loop = asyncio.get_event_loop()
coroutine = self.websocket.send_json(resp.dict())
asyncio.run_coroutine_threadsafe(coroutine, loop)


else:
resp = ChatResponse(message='', type='stream', intermediate_steps=log)
loop = asyncio.get_event_loop()
@@ -174,10 +171,8 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
observation_prefix = kwargs.get('observation_prefix', 'Tool output: ')

from langchain.docstore.document import Document
from langchain.docstore.document import Document # noqa
result = eval(output).get('result')


# Create a formatted message.
intermediate_steps = f'{observation_prefix}{result}'

119 changes: 80 additions & 39 deletions src/backend/bisheng/api/v1/chat.py
@@ -18,78 +18,117 @@
from fastapi.encoders import jsonable_encoder
from fastapi.params import Depends
from fastapi.responses import StreamingResponse
from fastapi_jwt_auth import AuthJWT
from sqlalchemy import func
from sqlmodel import Session, select

router = APIRouter(tags=['Chat'])
chat_manager = ChatManager()
flow_data_store = redis_client
expire = 600 #redis cache TTL (600s)
expire = 600  # redis cache TTL (600s)


@router.get('/chat/history', response_model=List[ChatMessageRead], status_code=200)
def get_chatmessage(*, session: Session = Depends(get_session), chat_id:str, flow_id:str, id:Optional[str]=None, page_size:Optional[int] = 20):
def get_chatmessage(*,
session: Session = Depends(get_session),
chat_id: str,
flow_id: str,
id: Optional[str] = None,
page_size: Optional[int] = 20,
Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
if not chat_id or not flow_id:
return {'code': 500, 'message': 'chat_id 和 flow_id 必传参数'}
where = select(ChatMessage).where(ChatMessage.flow_id == flow_id,
ChatMessage.chat_id == chat_id)
where = select(ChatMessage).where(ChatMessage.flow_id == flow_id, ChatMessage.chat_id == chat_id,
ChatMessage.user_id == payload.get('user_id'))
if id:
where = where.where(ChatMessage.id < id)
db_message = session.exec(where.order_by(ChatMessage.id.desc()).limit(page_size)).all()
return [jsonable_encoder(message) for message in db_message]


@router.get('/chat/list', response_model=List[ChatList], status_code=200)
def get_chatmessage(*, session: Session = Depends(get_session), ):
db_message = session.exec(select(ChatMessage).group_by(ChatMessage.flow_id)).all()
def get_chatlist_list(*, session: Session = Depends(get_session), Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())

smt = (select(
ChatMessage.flow_id, ChatMessage.chat_id, ChatMessage.chat_id,
func.max(ChatMessage.create_time).label('create_time'),
func.max(ChatMessage.update_time).label('update_time')).where(ChatMessage.user_id == payload.get('user_id')).group_by(
ChatMessage.flow_id).order_by(func.max(ChatMessage.create_time).desc()))
db_message = session.exec(smt).all()
flow_ids = [message.flow_id for message in db_message]
db_flow = session.exec(select(Flow).where(Flow.id.in_(flow_ids))).all()
# set object
chat_list = []
flow_dict={flow.id: flow for flow in db_flow}
flow_dict = {flow.id: flow for flow in db_flow}
for i, message in enumerate(db_message):
chat_list.append(ChatList(flow_name = flow_dict[message.flow_id].name, flow_description=flow_dict[message.flow_id].description,
flow_id= message.flow_id, create_time=message.create_time, update_time=message.update_time))
if message.flow_id not in flow_dict:
# the flow has been deleted
continue
chat_list.append(
ChatList(flow_name=flow_dict[message.flow_id].name,
flow_description=flow_dict[message.flow_id].description,
flow_id=message.flow_id,
chat_id=message.chat_id,
create_time=message.create_time,
update_time=message.update_time))
return [jsonable_encoder(chat) for chat in chat_list]


@router.websocket('/chat/{client_id}')
async def chat(client_id: str, websocket: WebSocket, chat_id: Optional[str]=None, type:Optional[str]=None):
async def chat(client_id: str,
websocket: WebSocket,
chat_id: Optional[str] = None,
type: Optional[str] = None,
Authorize: AuthJWT = Depends()):
Authorize.jwt_required(auth_from='websocket', websocket=websocket)
payload = json.loads(Authorize.get_jwt_subject())
user_id = payload.get('user_id')
"""Websocket endpoint for chat."""

if type and type == 'L1':
with next(get_session()) as session:
db_flow = session.get(Flow, client_id)
if not db_flow:
await websocket.accept()
message = '该技能已被删除'
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
if db_flow.status !=2:
if db_flow.status != 2:
await websocket.accept()
message = '当前技能未上线,无法直接对话'
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
graph_data = db_flow.data
else:
flow_data_key = 'flow_data_' + client_id
if str(flow_data_store.hget(flow_data_key, 'status'), 'utf-8')!= BuildStatus.SUCCESS.value:
if str(flow_data_store.hget(flow_data_key, 'status'), 'utf-8') != BuildStatus.SUCCESS.value:
await websocket.accept()
message = '当前编译没通过'
await websocket.close(code=status.WS_1013_TRY_AGAIN_LATER, reason=message)
graph_data = json.loads(flow_data_store.hget(flow_data_key, 'graph_data'))

try:
graph = build_flow_no_yield(graph_data=graph_data, artifacts={}, process_file=False, flow_id=UUID(client_id).hex, chat_id=chat_id)
graph = build_flow_no_yield(graph_data=graph_data,
artifacts={},
process_file=False,
flow_id=UUID(client_id).hex,
chat_id=chat_id)
langchain_object = graph.build()
for node in langchain_object:
key_node = get_cache_key(client_id, chat_id, node.id)
chat_manager.set_cache(key_node, node._built_object)
chat_manager.set_cache(get_cache_key(client_id, chat_id), node._built_object)
await chat_manager.handle_websocket(client_id, chat_id, websocket)
await chat_manager.handle_websocket(client_id, chat_id, websocket, user_id)
except WebSocketException as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
except Exception as e:
logger.error(str(e))
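The websocket route authenticates with `Authorize.jwt_required(auth_from='websocket', websocket=websocket)`, so the JWT cookie has to accompany the handshake. A hedged client-side sketch; the URL, port, cookie name, and message shape are all assumptions:

```python
import asyncio

import websockets  # assumption: the standalone `websockets` client library


async def main():
    # With cookie storage and CSRF protection disabled, sending fastapi-jwt-auth's
    # default cookie in the handshake headers should be enough (names are assumptions).
    headers = {'Cookie': 'access_token_cookie=<jwt issued at login>'}
    uri = 'ws://localhost:7860/api/v1/chat/<flow uuid>?chat_id=<chat id>&type=L1'
    async with websockets.connect(uri, extra_headers=headers) as ws:
        await ws.send('{"inputs": {"input": "hello"}}')  # assumed message shape; see ChatManager
        print(await ws.recv())


asyncio.run(main())
```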


@router.post('/build/init/{flow_id}', response_model=InitResponse, status_code=201)
async def init_build(*, graph_data: dict, session: Session = Depends(get_session),flow_id: str):
async def init_build(*, graph_data: dict, session: Session = Depends(get_session), flow_id: str):
"""Initialize the build by storing graph data and returning a unique session ID."""
chat_id = graph_data.get('chat_id')
if chat_id:
@@ -100,30 +139,30 @@ async def init_build(*, graph_data: dict, session: Session = Depends(get_session
raise ValueError('No ID provided')
# Check if already building
flow_data_key = 'flow_data_' + flow_id
if flow_data_store.hget(flow_data_key, 'status')== BuildStatus.IN_PROGRESS.value:
if flow_data_store.hget(flow_data_key, 'status') == BuildStatus.IN_PROGRESS.value:
return InitResponse(flowId=flow_id)

# Delete from cache if already exists
flow_data_store.hset(flow_data_key, map = {
'graph_data': json.dumps(graph_data),
'status': BuildStatus.STARTED.value}, expiration=expire)
flow_data_store.hset(flow_data_key,
map={
'graph_data': json.dumps(graph_data),
'status': BuildStatus.STARTED.value
},
expiration=expire)

return InitResponse(flowId=flow_id)
except Exception as exc:
logger.error(exc)
return HTTPException(status_code=500, detail=str(exc))
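`init_build` caches the graph and build status in a Redis hash keyed by `flow_data_` + flow_id. The `hset(..., map=..., expiration=...)` signature is not redis-py's native API, so `flow_data_store` is presumably a thin wrapper; a sketch of what such a wrapper might look like, assuming redis-py underneath:

```python
from typing import Optional

import redis  # assumption: redis-py backs bisheng's redis_client


class RedisHashStore:
    """Sketch of a wrapper exposing hset(map=...)/hsetkey/hget with an optional TTL."""

    def __init__(self, url: str = 'redis://localhost:6379'):
        self._client = redis.Redis.from_url(url)

    def hset(self, name: str, map: dict, expiration: Optional[int] = None) -> None:
        self._client.hset(name, mapping=map)       # write all fields in one call
        if expiration:
            self._client.expire(name, expiration)  # let abandoned build sessions expire

    def hsetkey(self, name: str, key: str, value, expiration: Optional[int] = None) -> None:
        self._client.hset(name, key, value)
        if expiration:
            self._client.expire(name, expiration)

    def hget(self, name: str, key: str):
        return self._client.hget(name, key)
```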


@router.get('/build/{flow_id}/status', response_model=BuiltResponse)
async def build_status(flow_id: str):
"""Check the flow_id is in the flow_data_store."""
try:
flow_data_key = 'flow_data_' + flow_id
built = (
flow_data_store.hget(flow_data_key, 'status') == BuildStatus.SUCCESS.value
)
return BuiltResponse(
built=built,
)
built = (flow_data_store.hget(flow_data_key, 'status') == BuildStatus.SUCCESS.value)
return BuiltResponse(built=built,)

except Exception as exc:
logger.error(exc)
@@ -134,7 +173,7 @@ async def build_status(flow_id: str):
async def stream_build(flow_id: str, chat_id: Optional[str] = None):
"""Stream the build process based on stored flow data."""

async def event_stream(flow_id, chat_id:str):
async def event_stream(flow_id, chat_id: str):
final_response = {'end_of_stream': True}
artifacts = {}
try:
@@ -159,10 +198,14 @@ async def event_stream(flow_id, chat_id:str):
logger.debug('Building langchain object')
flow_data_store.hsetkey(flow_data_key, 'status', BuildStatus.IN_PROGRESS.value, expire)

#L1 user: go through the build flow
# L1 user: go through the build flow
try:
process_file= False if chat_id else True
graph = build_flow(graph_data=graph_data, artifacts=artifacts, process_file=process_file, flow_id=UUID(flow_id).hex, chat_id=chat_id)
process_file = False if chat_id else True
graph = build_flow(graph_data=graph_data,
artifacts=artifacts,
process_file=process_file,
flow_id=UUID(flow_id).hex,
chat_id=chat_id)
while True:
value = next(graph)
yield value
Expand All @@ -176,21 +219,19 @@ async def event_stream(flow_id, chat_id:str):
langchain_object = graph.build()
# Now we need to check the input_keys to send them to the client
input_keys_response = {
'input_keys': [],
'memory_keys': [],
'handle_keys': [],
}
'input_keys': [],
'memory_keys': [],
'handle_keys': [],
}
for node in langchain_object:
if hasattr(node._built_object, 'input_keys'):
input_keys = build_input_keys_response(
node._built_object, artifacts
)
input_keys['input_keys'].update({'id':node.id})
input_keys = build_input_keys_response(node._built_object, artifacts)
input_keys['input_keys'].update({'id': node.id})
input_keys_response['input_keys'].append(input_keys.get('input_keys'))
input_keys_response['memory_keys'].extend(input_keys.get('memory_keys'))
input_keys_response['handle_keys'].extend(input_keys.get('handle_keys'))
elif ('fileNode' in node.output):
input_keys_response['input_keys'].append({'file_path':'', 'type': 'file', 'id': node.id})
input_keys_response['input_keys'].append({'file_path': '', 'type': 'file', 'id': node.id})

yield str(StreamData(event='message', data=input_keys_response))
# We need to reset the chat history