Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch: 0.1.1 后端代码提交 #9

Merged
merged 2 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
patch: 0.1.1 后端代码提交
  • Loading branch information
yaojin3616 committed Sep 8, 2023
commit 02f84c6c62ba87895d4dd34e982de62327c42894
2 changes: 1 addition & 1 deletion docker/bisheng/config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# 数据库配置
database_url:
"mysql+pymysql:https://username:[email protected]:3306/langflow"
"mysql+pymysql:https://username:[email protected]:3306/bisheng"
redis_url:
"192.168.106.116:6379"

Expand Down
9 changes: 9 additions & 0 deletions src/backend/bisheng/api/JWT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from pydantic import BaseModel


class Settings(BaseModel):
authjwt_secret_key: str = 'xI$xO.oN$sC}tC^oQ(fF^nK~dB&uT('
# Configure application to store and get JWT from cookies
authjwt_token_location: set = {'cookies'}
# Disable CSRF Protection for this example. default is True
authjwt_cookie_csrf_protect: bool = False
14 changes: 8 additions & 6 deletions src/backend/bisheng/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def build_input_keys_response(langchain_object, artifacts):

return input_keys_response

def build_flow(graph_data:dict, artifacts, process_file=False, flow_id=None, chat_id=None):

def build_flow(graph_data: dict, artifacts, process_file=False, flow_id=None, chat_id=None):
try:
# Some error could happen when building the graph
graph = Graph.from_payload(graph_data)
Expand Down Expand Up @@ -92,12 +93,12 @@ def build_flow(graph_data:dict, artifacts, process_file=False, flow_id=None, cha
# 过滤掉文件
vertex.params[key] = ''

#vectore store 引入自动建库逻辑
# vectore store 引入自动建库逻辑
# 聊天窗口等flow 主动生成的vector 需要新建临时collection
# tmp_{chat_id}
if vertex.base_type == 'vectorstores':
if 'collection_name' in vertex.params and not vertex.params.get('collection_name'):
vertex.params['collection_name']=f'tmp_{flow_id}_{chat_id}'
vertex.params['collection_name'] = f'tmp_{flow_id}_{chat_id}'

vertex.build()
params = vertex._built_object_repr()
Expand Down Expand Up @@ -131,7 +132,8 @@ def build_flow(graph_data:dict, artifacts, process_file=False, flow_id=None, cha
yield str(StreamData(event='message', data=response))
return graph

def build_flow_no_yield(graph_data:dict, artifacts, process_file=False, flow_id=None, chat_id=None):

def build_flow_no_yield(graph_data: dict, artifacts, process_file=False, flow_id=None, chat_id=None):
try:
# Some error could happen when building the graph
graph = Graph.from_payload(graph_data)
Expand All @@ -153,12 +155,12 @@ def build_flow_no_yield(graph_data:dict, artifacts, process_file=False, flow_id=
# 过滤掉文件
vertex.params[key] = ''

#vectore store 引入自动建库逻辑
# vectore store 引入自动建库逻辑
# 聊天窗口等flow 主动生成的vector 需要新建临时collection
# tmp_{chat_id}
if vertex.base_type == 'vectorstores':
if 'collection_name' in vertex.params and not vertex.params.get('collection_name'):
vertex.params['collection_name']=f'tmp_{flow_id}_{chat_id}'
vertex.params['collection_name'] = f'tmp_{flow_id}_{chat_id}'

vertex.build()
params = vertex._built_object_repr()
Expand Down
9 changes: 2 additions & 7 deletions src/backend/bisheng/api/v1/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def on_tool_start(
async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
observation_prefix = kwargs.get('observation_prefix', 'Tool output: ')
from langchain.docstore.document import Document
from langchain.docstore.document import Document # noqa
result = eval(output).get('result')

# Create a formatted message.
Expand All @@ -73,7 +73,6 @@ async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
intermediate_steps=intermediate_steps,
)


try:
# This is to emulate the stream of tokens
await self.websocket.send_json(resp.dict())
Expand Down Expand Up @@ -138,8 +137,6 @@ def on_agent_action(self, action: AgentAction, **kwargs: Any) -> Any:
loop = asyncio.get_event_loop()
coroutine = self.websocket.send_json(resp.dict())
asyncio.run_coroutine_threadsafe(coroutine, loop)


else:
resp = ChatResponse(message='', type='stream', intermediate_steps=log)
loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -174,10 +171,8 @@ def on_tool_end(self, output: str, **kwargs: Any) -> Any:
"""Run when tool ends running."""
observation_prefix = kwargs.get('observation_prefix', 'Tool output: ')

from langchain.docstore.document import Document
from langchain.docstore.document import Document # noqa
result = eval(output).get('result')


# Create a formatted message.
intermediate_steps = f'{observation_prefix}{result}'

Expand Down
41 changes: 31 additions & 10 deletions src/backend/bisheng/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from fastapi.encoders import jsonable_encoder
from fastapi.params import Depends
from fastapi.responses import StreamingResponse
from fastapi_jwt_auth import AuthJWT
from sqlalchemy import func
from sqlmodel import Session, select

router = APIRouter(tags=['Chat'])
Expand All @@ -32,41 +34,60 @@ def get_chatmessage(*,
chat_id: str,
flow_id: str,
id: Optional[str] = None,
page_size: Optional[int] = 20):
page_size: Optional[int] = 20,
Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
if not chat_id or not flow_id:
return {'code': 500, 'message': 'chat_id 和 flow_id 必传参数'}
where = select(ChatMessage).where(ChatMessage.flow_id == flow_id, ChatMessage.chat_id == chat_id)
where = select(ChatMessage).where(ChatMessage.flow_id == flow_id, ChatMessage.chat_id == chat_id,
ChatMessage.user_id == payload.get('user_id'))
if id:
where = where.where(ChatMessage.id < id)
db_message = session.exec(where.order_by(ChatMessage.id.desc()).limit(page_size)).all()
return [jsonable_encoder(message) for message in db_message]


@router.get('/chat/list', response_model=List[ChatList], status_code=200)
def get_chatmessage_list(
*,
session: Session = Depends(get_session),
):
db_message = session.exec(select(ChatMessage).group_by(ChatMessage.flow_id)).all()
def get_chatlist_list(*, session: Session = Depends(get_session), Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())

smt = (select(
ChatMessage.flow_id, ChatMessage.chat_id, ChatMessage.chat_id,
func.max(ChatMessage.create_time).label('create_time'),
func.max(ChatMessage.update_time).label('update_time')).where(ChatMessage.user_id == payload.get('user_id')).group_by(
ChatMessage.flow_id).order_by(func.max(ChatMessage.create_time).desc()))
db_message = session.exec(smt).all()
flow_ids = [message.flow_id for message in db_message]
db_flow = session.exec(select(Flow).where(Flow.id.in_(flow_ids))).all()
# set object
chat_list = []
flow_dict = {flow.id: flow for flow in db_flow}
for i, message in enumerate(db_message):
if message.flow_id not in flow_dict:
# flow 被删除
continue
chat_list.append(
ChatList(flow_name=flow_dict[message.flow_id].name,
flow_description=flow_dict[message.flow_id].description,
flow_id=message.flow_id,
chat_id=message.chat_id,
create_time=message.create_time,
update_time=message.update_time))
return [jsonable_encoder(chat) for chat in chat_list]


@router.websocket('/chat/{client_id}')
async def chat(client_id: str, websocket: WebSocket, chat_id: Optional[str] = None, type: Optional[str] = None):
async def chat(client_id: str,
websocket: WebSocket,
chat_id: Optional[str] = None,
type: Optional[str] = None,
Authorize: AuthJWT = Depends()):
Authorize.jwt_required(auth_from='websocket', websocket=websocket)
payload = json.loads(Authorize.get_jwt_subject())
user_id = payload.get('user_id')
"""Websocket endpoint for chat."""

if type and type == 'L1':
with next(get_session()) as session:
db_flow = session.get(Flow, client_id)
Expand Down Expand Up @@ -98,7 +119,7 @@ async def chat(client_id: str, websocket: WebSocket, chat_id: Optional[str] = No
key_node = get_cache_key(client_id, chat_id, node.id)
chat_manager.set_cache(key_node, node._built_object)
chat_manager.set_cache(get_cache_key(client_id, chat_id), node._built_object)
await chat_manager.handle_websocket(client_id, chat_id, websocket)
await chat_manager.handle_websocket(client_id, chat_id, websocket, user_id)
except WebSocketException as exc:
logger.error(exc)
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
Expand Down
87 changes: 57 additions & 30 deletions src/backend/bisheng/api/v1/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,29 @@
from bisheng.database.base import get_session
from bisheng.database.models.flow import (Flow, FlowCreate, FlowRead,
FlowReadWithStyle, FlowUpdate)
from bisheng.database.models.template import Template
from bisheng.database.models.user import User
from bisheng.settings import settings
from fastapi import APIRouter, Depends, File, HTTPException, Query, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi_jwt_auth import AuthJWT
from sqlmodel import Session, select

# build router
router = APIRouter(prefix='/flows', tags=['Flows'])


@router.post('/', response_model=FlowRead, status_code=201)
def create_flow(*, session: Session = Depends(get_session), flow: FlowCreate):
def create_flow(*, session: Session = Depends(get_session), flow: FlowCreate, Authorize: AuthJWT = Depends()):
"""Create a new flow."""
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())

if flow.flow_id:
# copy from template
temp_flow = session.get(Flow, flow.flow_id)
flow.data = temp_flow.data
flow.user_id = payload.get('user_id')
db_flow = Flow.from_orm(flow)
session.add(db_flow)
session.commit()
Expand All @@ -31,28 +38,41 @@ def create_flow(*, session: Session = Depends(get_session), flow: FlowCreate):


@router.get('/', response_model=list[FlowReadWithStyle], status_code=200)
def read_flows(
*,
session: Session = Depends(get_session),
name: str = Query(default=None, description='根据name查找数据库'),
page_size: int = Query(default=None, description='根据pagesize查找数据库'),
page_num: int = Query(default=None, description='根据pagenum查找数据库'),
status: int = None
):
def read_flows(*,
session: Session = Depends(get_session),
name: str = Query(default=None, description='根据name查找数据库'),
page_size: int = Query(default=None, description='根据pagesize查找数据库'),
page_num: int = Query(default=None, description='根据pagenum查找数据库'),
status: int = None,
Authorize: AuthJWT = Depends()):
"""Read all flows."""
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())

try:
sql = select(Flow)
if 'admin' != payload.get('role'):
sql = sql.where(Flow.user_id == payload.get('user_id'))
if name:
sql = sql.where(Flow.name.like(f'%{name}%'))
if status:
sql = sql.where(Flow.status == status)

sql = sql.order_by(Flow.update_time.desc())
if page_num and page_size:
sql = sql.offset((page_num-1) * page_size).limit(page_size)
sql = sql.offset((page_num - 1) * page_size).limit(page_size)

flows = session.exec(sql).all()
return [jsonable_encoder(flow) for flow in flows]

res = [jsonable_encoder(flow) for flow in flows]
if flows:
db_user_ids = {flow.user_id for flow in flows}
db_user = session.exec(select(User).where(User.user_id.in_(db_user_ids))).all()
userMap = {user.user_id: user.user_name for user in db_user}
for r in res:
r['user_name'] = userMap[r['user_id']]

return res

except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) from e
Expand All @@ -68,26 +88,24 @@ def read_flow(*, session: Session = Depends(get_session), flow_id: UUID):


@router.patch('/{flow_id}', response_model=FlowRead, status_code=200)
def update_flow(
*,
session: Session = Depends(get_session),
flow_id: UUID,
flow: FlowUpdate
):
def update_flow(*, session: Session = Depends(get_session), flow_id: UUID, flow: FlowUpdate, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
"""Update a flow."""
db_flow = session.get(Flow, flow_id)
if not db_flow:
raise HTTPException(status_code=404, detail='Flow not found')

if 'admin' != payload.get('role') and db_flow.user_id != payload.get('user_id'):
raise HTTPException(status_code=500, detail='没有权限编辑此技能')

flow_data = flow.dict(exclude_unset=True)

if 'status' in flow_data and flow_data['status'
] == 2 and db_flow.status == 1:
if 'status' in flow_data and flow_data['status'] == 2 and db_flow.status == 1:
# 上线校验
try:
art = {}
build_flow_no_yield(
graph_data=db_flow.data, artifacts=art, process_file=False
)
build_flow_no_yield(graph_data=db_flow.data, artifacts=art, process_file=False)
except Exception as exc:
raise HTTPException(status_code=500, detail='Flow 编译不通过') from exc

Expand All @@ -102,25 +120,36 @@ def update_flow(


@router.delete('/{flow_id}', status_code=200)
def delete_flow(*, session: Session = Depends(get_session), flow_id: UUID):
def delete_flow(*, session: Session = Depends(get_session), flow_id: UUID, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
"""Delete a flow."""
flow = session.get(Flow, flow_id)
if not flow:
raise HTTPException(status_code=404, detail='Flow not found')
if 'admin' != payload.get('role') and flow.user_id != payload.get('user_id'):
raise HTTPException(status_code=500, detail='没有权限删除此技能')

# 判断是否属于模板
db_template = session.exec(select(Template).where(Template.flow_id == flow_id)).first()
if db_template:
session.delete(db_template)

session.delete(flow)
session.commit()
return {'message': 'Flow deleted successfully'}


# Define a new model to handle multiple flows
@router.post('/batch/', response_model=List[FlowRead], status_code=201)
def create_flows(
*, session: Session = Depends(get_session), flow_list: FlowListCreate
):
def create_flows(*, session: Session = Depends(get_session), flow_list: FlowListCreate, Authorize: AuthJWT = Depends()):
Authorize.jwt_required()
payload = json.loads(Authorize.get_jwt_subject())
"""Create multiple new flows."""
db_flows = []
for flow in flow_list.flows:
db_flow = Flow.from_orm(flow)
db_flow.user_id = payload.get('user_id')
session.add(db_flow)
db_flows.append(db_flow)
session.commit()
Expand All @@ -130,17 +159,15 @@ def create_flows(


@router.post('/upload/', response_model=List[FlowRead], status_code=201)
async def upload_file(
*, session: Session = Depends(get_session), file: UploadFile = File(...)
):
async def upload_file(*, session: Session = Depends(get_session), file: UploadFile = File(...), Authorize: AuthJWT = Depends()):
"""Upload flows from a file."""
contents = await file.read()
data = json.loads(contents)
if 'flows' in data:
flow_list = FlowListCreate(**data)
else:
flow_list = FlowListCreate(flows=[FlowCreate(**flow) for flow in data])
return create_flows(session=session, flow_list=flow_list)
return create_flows(session=session, flow_list=flow_list, Authorize=Authorize)


@router.get('/download/', response_model=FlowListRead, status_code=200)
Expand Down
Loading