# NOTE: source-viewer residue ("174 lines", "9.8 KiB", "Python") was captured here; it is not part of the module.
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from fastapi import Depends, Query
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
from fastapi_common.fastapi_common_security.security import verify_access_token
|
|
from fastapi_common.fastapi_common_web.controller import BaseController
|
|
from fastapi_common.fastapi_common_web.domain.responses import Result
|
|
|
|
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
|
|
RagConversationRenameDTO,
|
|
RagChatSendMessageDTO,
|
|
RagMessageFeedbackDTO,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
|
RagAppParametersVO,
|
|
RagChatAppListVO,
|
|
RagChatAppVO,
|
|
RagConversationPageVO,
|
|
RagConversationRenameVO,
|
|
RagMessagePageVO,
|
|
RagOperationResultVO,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import RagDatasetPageVO
|
|
from fastapi_modules.fastapi_leaudit.services.impl.permissionServiceImpl import PermissionServiceImpl
|
|
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
|
|
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
|
|
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
|
|
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
|
|
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
|
|
|
|
|
|
class RagChatController(BaseController):
    """Controller that registers the RAG chat HTTP routes under ``/v3/rag``.

    Every route obtains the token payload through ``verify_access_token``,
    then asks ``_check_permission`` whether the caller may proceed. When the
    check fails the route answers with an HTTP 403 JSON body shaped like
    ``{"code": 403, "msg": <message>, "data": None}``; otherwise the request
    is forwarded to the chat or dataset service and the result is wrapped
    with ``Result.success``.
    """

    # Short logical names mapped to the permission keys handed to the
    # permission service.
    _PERMISSIONS = {
        "chat_use": "rag:chat:use",
        "conversation_read": "rag:conversation:read",
        "conversation_update": "rag:conversation:update",
        "conversation_delete": "rag:conversation:delete",
        "message_feedback": "rag:message:feedback",
        "app_read": "rag:app:read",
        "dataset_read": "rag:dataset:read",
    }

    def __init__(self):
        """Create the backing services and attach all routes to ``self.router``."""
        super().__init__(prefix="/v3/rag", tags=["RAG 聊天"])
        self.RagChatService: IRagChatService = RagChatServiceImpl()
        self.RagDatasetService: IRagDatasetService = RagDatasetServiceImpl()
        self.PermissionService: IPermissionService = PermissionServiceImpl()

        def deny(message: str) -> JSONResponse:
            # Single place that builds the "permission denied" answer used by
            # every route below.
            body = {"code": 403, "msg": message, "data": None}
            return JSONResponse(status_code=403, content=body)

        def caller_id(token_payload: dict[str, Any]) -> int:
            # The token payload carries the user id under "user_id"; the
            # services and the permission check receive it as an int.
            return int(token_payload["user_id"])

        # NOTE: route handlers are documented with comments rather than
        # docstrings so the generated OpenAPI descriptions stay unchanged.

        # List the chat apps returned by the chat service for the caller.
        @self.router.get("/apps", response_model=Result[RagChatAppListVO])
        async def GetApps(payload: dict[str, Any] = Depends(verify_access_token)):
            uid = caller_id(payload)
            allowed = await self._check_permission(uid, [self._PERMISSIONS["app_read"]])
            if not allowed:
                return deny("当前用户没有查看聊天应用权限")
            apps = await self.RagChatService.GetApps(
                CurrentUserId=uid,
                UserArea=payload.get("area"),
                UserRole=payload.get("user_role"),
            )
            return Result.success(data=apps)

        # Fetch the default chat app (the response model allows None).
        @self.router.get("/apps/default", response_model=Result[RagChatAppVO | None])
        async def GetDefaultApp(payload: dict[str, Any] = Depends(verify_access_token)):
            uid = caller_id(payload)
            allowed = await self._check_permission(uid, [self._PERMISSIONS["app_read"]])
            if not allowed:
                return deny("当前用户没有查看默认聊天应用权限")
            default_app = await self.RagChatService.GetDefaultApp(
                CurrentUserId=uid,
                UserArea=payload.get("area"),
                UserRole=payload.get("user_role"),
            )
            return Result.success(data=default_app)

        # Page of datasets returned by the dataset service for the caller.
        @self.router.get("/datasets/my", response_model=Result[RagDatasetPageVO])
        async def GetMyDatasets(payload: dict[str, Any] = Depends(verify_access_token)):
            uid = caller_id(payload)
            allowed = await self._check_permission(uid, [self._PERMISSIONS["dataset_read"]])
            if not allowed:
                return deny("当前用户没有查看知识库权限")
            datasets = await self.RagDatasetService.GetMyDatasets(
                CurrentUserId=uid,
                UserArea=payload.get("area"),
                UserRole=payload.get("user_role"),
            )
            return Result.success(data=datasets)

        # Chat parameters for an (optional) app id; the permission check is
        # given both the chat-use and the app-read keys.
        @self.router.get("/chat/parameters", response_model=Result[RagAppParametersVO])
        async def GetAppParameters(
            appId: int | None = Query(None, description="聊天应用ID"),
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            uid = caller_id(payload)
            required = [self._PERMISSIONS["chat_use"], self._PERMISSIONS["app_read"]]
            allowed = await self._check_permission(uid, required)
            if not allowed:
                return deny("当前用户没有查看聊天配置权限")
            parameters = await self.RagChatService.GetAppParameters(
                CurrentUserId=uid,
                UserArea=payload.get("area"),
                UserRole=payload.get("user_role"),
                AppId=appId,
            )
            return Result.success(data=parameters)

        # Send a chat message; the service result is streamed back as
        # "text/event-stream" without being awaited here.
        @self.router.post("/chat/messages")
        async def SendMessage(Body: RagChatSendMessageDTO, payload: dict[str, Any] = Depends(verify_access_token)):
            uid = caller_id(payload)
            allowed = await self._check_permission(uid, [self._PERMISSIONS["chat_use"]])
            if not allowed:
                return deny("当前用户没有使用 RAG 对话权限")
            # Fall back to the stringified user id when the token has no
            # (or an empty) username.
            display_name = payload.get("username") or str(payload.get("user_id"))
            event_stream = self.RagChatService.SendMessage(
                CurrentUserId=uid,
                UserName=display_name,
                UserArea=payload.get("area"),
                UserRole=payload.get("user_role"),
                Query=Body.query,
                ConversationId=Body.conversationId,
                AppId=Body.appId,
            )
            stream_headers = {
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",
                "Connection": "keep-alive",
            }
            return StreamingResponse(
                event_stream,
                media_type="text/event-stream",
                headers=stream_headers,
            )

        # Paged conversation list, optionally narrowed by app id.
        @self.router.get("/chat/conversations", response_model=Result[RagConversationPageVO])
        async def GetConversations(
            appId: int | None = Query(None, description="聊天应用ID"),
            page: int = Query(1, ge=1),
            pageSize: int = Query(20, ge=1, le=100),
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            uid = caller_id(payload)
            allowed = await self._check_permission(uid, [self._PERMISSIONS["conversation_read"]])
            if not allowed:
                return deny("当前用户没有查看聊天会话权限")
            conversations = await self.RagChatService.GetConversations(uid, appId, page, pageSize)
            return Result.success(data=conversations)

        # Paged messages of one conversation.
        @self.router.get("/chat/conversations/{ConversationId}/messages", response_model=Result[RagMessagePageVO])
        async def GetConversationMessages(
            ConversationId: str,
            page: int = Query(1, ge=1),
            pageSize: int = Query(20, ge=1, le=100),
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            uid = caller_id(payload)
            allowed = await self._check_permission(uid, [self._PERMISSIONS["conversation_read"]])
            if not allowed:
                return deny("当前用户没有查看聊天消息权限")
            messages = await self.RagChatService.GetConversationMessages(uid, ConversationId, page, pageSize)
            return Result.success(data=messages)

        # Rename a conversation using the request body DTO.
        @self.router.patch("/chat/conversations/{ConversationId}", response_model=Result[RagConversationRenameVO])
        async def RenameConversation(
            ConversationId: str,
            Body: RagConversationRenameDTO,
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            uid = caller_id(payload)
            allowed = await self._check_permission(uid, [self._PERMISSIONS["conversation_update"]])
            if not allowed:
                return deny("当前用户没有修改聊天会话权限")
            renamed = await self.RagChatService.RenameConversation(uid, ConversationId, Body)
            return Result.success(data=renamed)

        # Delete a conversation.
        @self.router.delete("/chat/conversations/{ConversationId}", response_model=Result[RagOperationResultVO])
        async def DeleteConversation(ConversationId: str, payload: dict[str, Any] = Depends(verify_access_token)):
            uid = caller_id(payload)
            allowed = await self._check_permission(uid, [self._PERMISSIONS["conversation_delete"]])
            if not allowed:
                return deny("当前用户没有删除聊天会话权限")
            outcome = await self.RagChatService.DeleteConversation(uid, ConversationId)
            return Result.success(data=outcome)

        # Record feedback for a single message.
        @self.router.post("/chat/messages/{MessageId}/feedback", response_model=Result[RagOperationResultVO])
        async def UpdateFeedback(
            MessageId: str,
            Body: RagMessageFeedbackDTO,
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            uid = caller_id(payload)
            allowed = await self._check_permission(uid, [self._PERMISSIONS["message_feedback"]])
            if not allowed:
                return deny("当前用户没有反馈聊天消息权限")
            outcome = await self.RagChatService.UpdateFeedback(uid, MessageId, Body)
            return Result.success(data=outcome)

    async def _check_permission(self, user_id: int, permission_keys: list[str]) -> bool:
        """Forward the permission check to ``PermissionService.HasAnyPermission``.

        Args:
            user_id: Id of the user taken from the access-token payload.
            permission_keys: Permission keys to test for that user.

        Returns:
            Whatever the permission service reports; judging by the method
            name this is presumably true when the user holds at least one
            of the keys (confirm against the service implementation).
        """
        granted = await self.PermissionService.HasAnyPermission(
            UserId=user_id,
            PermissionKeys=permission_keys,
        )
        return granted
|