Skip to content

Commit

Permalink
fix lints and also the bug in the get_current_user
Browse files Browse the repository at this point in the history
  • Loading branch information
Maryam Abdoli committed Nov 3, 2023
1 parent 83425b0 commit ef94983
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
7 changes: 3 additions & 4 deletions src/backend/langflow/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
APIRouter,
Depends,
HTTPException,
Query,
WebSocket,
WebSocketException,
status,
Expand All @@ -11,8 +10,9 @@
from langflow.api.utils import build_input_keys_response
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData

from langflow.services.database.models.user.user import User
from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import get_current_active_user, get_current_user
from langflow.services.auth.utils import get_current_active_user
from langflow.services.cache.utils import update_build_status
from loguru import logger
from langflow.services.getters import get_chat_service, get_session, get_cache_service
Expand All @@ -28,14 +28,13 @@
async def chat(
client_id: str,
websocket: WebSocket,
token: str = Query(...),
db: Session = Depends(get_session),
chat_service: "ChatService" = Depends(get_chat_service),
user: User = Depends(get_current_active_user),
):
"""Websocket endpoint for chat."""
try:
await websocket.accept()
user = await get_current_user(token, db)
if not user:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
Expand Down
17 changes: 8 additions & 9 deletions src/backend/langflow/services/auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from langflow.services.getters import get_session, get_settings_service
from sqlmodel import Session

oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login")
oauth2_login = OAuth2PasswordBearer(tokenUrl="api/v1/login", auto_error=False)

API_KEY_NAME = "x-api-key"

Expand Down Expand Up @@ -74,23 +74,22 @@ async def get_current_user(
header_param: str = Security(api_key_header),
db: Session = Depends(get_session),
) -> User:
try:
if token:
return await get_current_user_by_jwt(token, db)
except HTTPException as exc:
else:
if not query_param and not header_param:
raise exc
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="An API key must be passed as query or header",
)
user = await api_key_security(query_param, header_param, db)
if user:
return user

raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Invalid or missing API key",
)
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Internal server error: {exc}",
)


async def get_current_user_by_jwt(
Expand Down

0 comments on commit ef94983

Please sign in to comment.