From dcc0f3c30d955a8812cd27c79436653347cab888 Mon Sep 17 00:00:00 2001 From: wren <“porlong@qq.com”> Date: Mon, 11 May 2026 17:21:33 +0800 Subject: [PATCH] feat: restore rag dataset management and linkage --- .../controllers/ragChatController.py | 342 +++- .../domian/Dto/ragDatasetDto.py | 6 + .../fastapi_leaudit/domian/vo/ragDatasetVo.py | 106 ++ .../services/impl/ragChatServiceImpl.py | 71 +- .../services/impl/ragDatasetServiceImpl.py | 1533 ++++++++++++++++- .../services/ragDatasetService.py | 196 ++- 6 files changed, 2208 insertions(+), 46 deletions(-) create mode 100644 fastapi_modules/fastapi_leaudit/domian/Dto/ragDatasetDto.py diff --git a/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py b/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py index 7c79d23..9954828 100644 --- a/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py +++ b/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py @@ -2,7 +2,9 @@ from __future__ import annotations from typing import Any -from fastapi import Depends, Query +import json + +from fastapi import Depends, File, Form, Query, UploadFile from fastapi.responses import JSONResponse, StreamingResponse from fastapi_common.fastapi_common_security.security import verify_access_token @@ -14,6 +16,7 @@ from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import ( RagChatSendMessageDTO, RagMessageFeedbackDTO, ) +from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagAppParametersVO, RagChatAppListVO, @@ -23,7 +26,16 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagMessagePageVO, RagOperationResultVO, ) -from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import RagDatasetPageVO +from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import ( + RagDatasetDetailVO, + RagDatasetDocumentItemVO, + RagDatasetDocumentPageVO, + RagDatasetPageVO, + 
RagDatasetRetrieveResponseVO, + RagDatasetSegmentItemVO, + RagDatasetSegmentPageVO, + RagDatasetUploadDocumentVO, +) from fastapi_modules.fastapi_leaudit.services.impl.permissionServiceImpl import PermissionServiceImpl from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl @@ -82,6 +94,332 @@ class RagChatController(BaseController): ) return Result.success(data=data) + @self.router.get("/datasets/admin", response_model=Result[RagDatasetPageVO]) + async def GetAdminDatasets( + area: str | None = Query(None), + onlyEnabled: bool | None = Query(None), + page: int = Query(1, ge=1), + pageSize: int = Query(20, ge=1, le=200), + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有管理知识库权限", "data": None}) + data = await self.RagDatasetService.GetAdminDatasets( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + Area=area, + OnlyEnabled=onlyEnabled, + Page=page, + PageSize=pageSize, + ) + return Result.success(data=data) + + @self.router.post("/datasets/admin", response_model=Result[RagDatasetDetailVO]) + async def CreateAdminDataset(Body: dict[str, Any], payload: dict[str, Any] = Depends(verify_access_token)): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有创建知识库权限", "data": None}) + data = await self.RagDatasetService.CreateAdminDataset( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + Body=Body, + ) + return Result.success(data=data) + + @self.router.put("/datasets/admin/{DatasetId}", 
response_model=Result[RagDatasetDetailVO | None]) + async def UpdateAdminDataset(DatasetId: int, Body: dict[str, Any], payload: dict[str, Any] = Depends(verify_access_token)): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有更新知识库权限", "data": None}) + data = await self.RagDatasetService.UpdateAdminDataset( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + Body=Body, + ) + return Result.success(data=data) + + @self.router.delete("/datasets/admin/{DatasetId}", response_model=Result[RagOperationResultVO]) + async def DeleteAdminDataset(DatasetId: int, payload: dict[str, Any] = Depends(verify_access_token)): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有删除知识库权限", "data": None}) + data = await self.RagDatasetService.DeleteAdminDataset( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + ) + return Result.success(data=data) + + @self.router.get("/datasets/{DatasetId}", response_model=Result[RagDatasetDetailVO | None]) + async def GetDatasetDetail(DatasetId: int, payload: dict[str, Any] = Depends(verify_access_token)): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有查看知识库权限", "data": None}) + data = await self.RagDatasetService.GetDatasetDetail( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + ) + return Result.success(data=data) + + @self.router.patch("/datasets/{DatasetId}", response_model=Result[RagDatasetDetailVO | None]) + async def 
UpdateDataset(DatasetId: int, Body: RagDatasetUpdateDTO, payload: dict[str, Any] = Depends(verify_access_token)): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有修改知识库权限", "data": None}) + data = await self.RagDatasetService.UpdateDataset( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + Body=Body, + ) + return Result.success(data=data) + + @self.router.get("/datasets/{DatasetId}/documents", response_model=Result[RagDatasetDocumentPageVO]) + async def GetDatasetDocuments( + DatasetId: int, + page: int = Query(1, ge=1), + limit: int = Query(20, ge=1, le=100), + keyword: str | None = Query(None), + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有查看知识库文档权限", "data": None}) + data = await self.RagDatasetService.GetDatasetDocuments( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + Page=page, + Limit=limit, + Keyword=keyword, + ) + return Result.success(data=data) + + @self.router.get("/datasets/{DatasetId}/documents/{DocumentId}", response_model=Result[RagDatasetDocumentItemVO | None]) + async def GetDatasetDocumentDetail( + DatasetId: int, + DocumentId: int, + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有查看知识库文档权限", "data": None}) + data = await self.RagDatasetService.GetDatasetDocumentDetail( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), 
+ DatasetId=DatasetId, + DocumentId=DocumentId, + ) + return Result.success(data=data) + + @self.router.post("/datasets/{DatasetId}/documents", response_model=Result[RagDatasetUploadDocumentVO]) + async def UploadDatasetDocument( + DatasetId: int, + file: UploadFile = File(...), + data: str | None = Form(None), + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有上传知识库文档权限", "data": None}) + process_config = json.loads(data) if data else None + file_bytes = await file.read() + result = await self.RagDatasetService.UploadDatasetDocument( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + FileName=file.filename or "document", + ContentType=file.content_type, + Content=file_bytes, + ProcessConfig=process_config, + ) + return Result.success(data=result) + + @self.router.post("/datasets/{DatasetId}/documents/{DocumentId}/update-by-file", response_model=Result[RagDatasetUploadDocumentVO]) + async def UpdateDatasetDocumentByFile( + DatasetId: int, + DocumentId: int, + file: UploadFile = File(...), + data: str | None = Form(None), + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有重处理知识库文档权限", "data": None}) + process_config = json.loads(data) if data else None + file_bytes = await file.read() + result = await self.RagDatasetService.UpdateDatasetDocumentByFile( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + DocumentId=DocumentId, + FileName=file.filename or "document", + ContentType=file.content_type, + Content=file_bytes, + 
ProcessConfig=process_config, + ) + return Result.success(data=result) + + @self.router.get("/datasets/{DatasetId}/documents/{DocumentId}/indexing-status") + async def GetDatasetDocumentIndexingStatus( + DatasetId: int, + DocumentId: int, + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有查看知识库处理进度权限", "data": None}) + result = await self.RagDatasetService.GetDatasetDocumentIndexingStatus( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + DocumentId=DocumentId, + ) + return Result.success(data=result["data"]) + + @self.router.patch("/datasets/{DatasetId}/documents/status/{Action}", response_model=Result[RagOperationResultVO]) + async def BatchUpdateDatasetDocumentStatus( + DatasetId: int, + Action: str, + Body: dict[str, Any], + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有修改知识库文档状态权限", "data": None}) + enabled = Action == "enable" + if Action not in {"enable", "disable"}: + return JSONResponse(status_code=400, content={"code": 400, "msg": "当前仅支持启用和禁用", "data": None}) + document_ids = Body.get("document_ids") or [] + result = await self.RagDatasetService.BatchUpdateDatasetDocumentStatus( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + DocumentIds=[int(item) for item in document_ids], + Enabled=enabled, + ) + return Result.success(data=result) + + @self.router.get("/datasets/{DatasetId}/documents/{DocumentId}/segments", response_model=Result[RagDatasetSegmentPageVO]) + async def GetDatasetDocumentSegments( + DatasetId: int, 
+ DocumentId: int, + page: int = Query(1, ge=1), + limit: int = Query(20, ge=1, le=200), + keyword: str | None = Query(None), + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有查看知识库分段权限", "data": None}) + result = await self.RagDatasetService.GetDatasetDocumentSegments( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + DocumentId=DocumentId, + Page=page, + Limit=limit, + Keyword=keyword, + ) + return Result.success(data=result) + + @self.router.delete("/datasets/{DatasetId}/documents/{DocumentId}", response_model=Result[RagOperationResultVO]) + async def DeleteDatasetDocument( + DatasetId: int, + DocumentId: int, + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有删除知识库文档权限", "data": None}) + result = await self.RagDatasetService.DeleteDatasetDocument( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + DocumentId=DocumentId, + ) + return Result.success(data=result) + + @self.router.post("/datasets/{DatasetId}/retrieve", response_model=Result[RagDatasetRetrieveResponseVO]) + async def RetrieveDataset( + DatasetId: int, + Body: dict[str, Any], + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有检索知识库权限", "data": None}) + result = await self.RagDatasetService.RetrieveDataset( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + 
UserRole=payload.get("user_role"), + DatasetId=DatasetId, + Query=str(Body.get("query") or ""), + RetrievalModel=Body.get("retrieval_model") if isinstance(Body.get("retrieval_model"), dict) else None, + ) + return Result.success(data=result) + + @self.router.get("/datasets/{DatasetId}/documents/{DocumentId}/segments/{SegmentId}", response_model=Result[RagDatasetSegmentItemVO | None]) + async def GetDatasetDocumentSegmentDetail( + DatasetId: int, + DocumentId: int, + SegmentId: str, + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有查看知识库分段权限", "data": None}) + result = await self.RagDatasetService.GetDatasetDocumentSegmentDetail( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + DocumentId=DocumentId, + SegmentId=SegmentId, + ) + return Result.success(data=result) + + @self.router.post("/datasets/{DatasetId}/documents/{DocumentId}/segments/{SegmentId}", response_model=Result[RagDatasetSegmentItemVO | None]) + async def UpdateDatasetDocumentSegment( + DatasetId: int, + DocumentId: int, + SegmentId: str, + Body: dict[str, Any], + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有修改知识库分段权限", "data": None}) + result = await self.RagDatasetService.UpdateDatasetDocumentSegment( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + DocumentId=DocumentId, + SegmentId=SegmentId, + Body=Body, + ) + return Result.success(data=result) + + @self.router.delete("/datasets/{DatasetId}/documents/{DocumentId}/segments/{SegmentId}", 
response_model=Result[RagOperationResultVO]) + async def DeleteDatasetDocumentSegment( + DatasetId: int, + DocumentId: int, + SegmentId: str, + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_read"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有删除知识库分段权限", "data": None}) + result = await self.RagDatasetService.DeleteDatasetDocumentSegment( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + DocumentId=DocumentId, + SegmentId=SegmentId, + ) + return Result.success(data=result) + @self.router.get("/chat/parameters", response_model=Result[RagAppParametersVO]) async def GetAppParameters( appId: int | None = Query(None, description="聊天应用ID"), diff --git a/fastapi_modules/fastapi_leaudit/domian/Dto/ragDatasetDto.py b/fastapi_modules/fastapi_leaudit/domian/Dto/ragDatasetDto.py new file mode 100644 index 0000000..12301ee --- /dev/null +++ b/fastapi_modules/fastapi_leaudit/domian/Dto/ragDatasetDto.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, Field + + +class RagDatasetUpdateDTO(BaseModel): + name: str | None = Field(None, min_length=1, max_length=255) + retrieval_model: dict | None = Field(None) diff --git a/fastapi_modules/fastapi_leaudit/domian/vo/ragDatasetVo.py b/fastapi_modules/fastapi_leaudit/domian/vo/ragDatasetVo.py index a69d007..92dc2a1 100644 --- a/fastapi_modules/fastapi_leaudit/domian/vo/ragDatasetVo.py +++ b/fastapi_modules/fastapi_leaudit/domian/vo/ragDatasetVo.py @@ -11,8 +11,114 @@ class RagDatasetItemVO(BaseModel): documentCount: int = Field(0) totalChunks: int = Field(0) status: int = Field(1) + sortOrder: int = Field(0) + createdAt: int = Field(0) + updatedAt: int = Field(0) + appId: int | None = Field(default=None) + appName: str = Field("") + appIsDefault: bool = Field(False) class RagDatasetPageVO(BaseModel): data: 
list[RagDatasetItemVO] = Field(default_factory=list) total: int = Field(0) + + +class RagDatasetDetailVO(BaseModel): + id: int = Field(...) + name: str = Field(...) + description: str = Field("") + area: str = Field("") + isPublic: bool = Field(False) + isDefault: bool = Field(False) + status: int = Field(1) + documentCount: int = Field(0) + totalChunks: int = Field(0) + chunkMaxSize: int = Field(800) + chunkMinSize: int = Field(20) + sortOrder: int = Field(0) + retrievalModel: dict = Field(default_factory=dict) + createdAt: int = Field(0) + updatedAt: int = Field(0) + appId: int | None = Field(default=None) + appName: str = Field("") + appIsDefault: bool = Field(False) + + +class RagDatasetDocumentItemVO(BaseModel): + id: int = Field(...) + datasetId: int = Field(...) + name: str = Field(...) + fileType: str = Field("") + fileSize: int = Field(0) + chunkCount: int = Field(0) + indexingStatus: str = Field("waiting") + error: str = Field("") + enabled: bool = Field(True) + hitCount: int = Field(0) + createdBy: int | None = Field(None) + createdAt: int = Field(0) + updatedAt: int = Field(0) + + +class RagDatasetDocumentPageVO(BaseModel): + data: list[RagDatasetDocumentItemVO] = Field(default_factory=list) + total: int = Field(0) + page: int = Field(1) + limit: int = Field(20) + hasMore: bool = Field(False) + + +class RagDatasetUploadDocumentVO(BaseModel): + document: dict = Field(default_factory=dict) + batch: str = Field("") + + +class RagDatasetSegmentItemVO(BaseModel): + id: str = Field(...) 
+ position: int = Field(0) + documentId: str = Field("") + content: str = Field("") + wordCount: int = Field(0) + hitCount: int = Field(0) + enabled: bool = Field(True) + status: str = Field("completed") + createdAt: int = Field(0) + + +class RagDatasetSegmentPageVO(BaseModel): + data: list[RagDatasetSegmentItemVO] = Field(default_factory=list) + total: int = Field(0) + limit: int = Field(20) + hasMore: bool = Field(False) + + +class RagDatasetRetrieveDocumentVO(BaseModel): + id: str = Field("") + dataSourceType: str = Field("upload_file") + name: str = Field("") + docType: str | None = Field(default=None) + + +class RagDatasetRetrieveSegmentVO(BaseModel): + id: str = Field(...) + position: int = Field(0) + documentId: str = Field("") + content: str = Field("") + answer: str = Field("") + wordCount: int = Field(0) + hitCount: int = Field(0) + enabled: bool = Field(True) + status: str = Field("completed") + createdAt: int = Field(0) + document: RagDatasetRetrieveDocumentVO | None = Field(default=None) + + +class RagDatasetRetrieveRecordVO(BaseModel): + segment: RagDatasetRetrieveSegmentVO = Field(...) 
+ score: float = Field(0.0) + + +class RagDatasetRetrieveResponseVO(BaseModel): + query: dict = Field(default_factory=dict) + records: list[RagDatasetRetrieveRecordVO] = Field(default_factory=list) diff --git a/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py b/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py index 0e3c417..f97faf1 100644 --- a/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py +++ b/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py @@ -80,7 +80,7 @@ class RagChatServiceImpl(IRagChatService): context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), Query) collected_answer = "" - held_message_end: bytes | None = None + held_message_end: dict | None = None async for chunk in generate_stream( query=Query, @@ -94,17 +94,21 @@ class RagChatServiceImpl(IRagChatService): dataset_name=dataset_name, ): chunk_bytes = chunk.encode("utf-8") - for line in chunk.strip().split("\n"): - if not line.startswith("data: "): - continue - data = json.loads(line[6:]) - if data.get("event") == "message": - collected_answer += data.get("answer", "") - elif data.get("event") == "message_end": - held_message_end = chunk_bytes - continue - if held_message_end is None: + data = self._parse_sse_event(chunk) + if not data: yield chunk_bytes + continue + + if data.get("event") == "message": + collected_answer += data.get("answer", "") + yield chunk_bytes + continue + + if data.get("event") == "message_end": + held_message_end = data + continue + + yield chunk_bytes followups: list[str] = [] try: @@ -114,15 +118,10 @@ class RagChatServiceImpl(IRagChatService): if held_message_end: try: - for line in held_message_end.decode("utf-8").strip().split("\n"): - if not line.startswith("data: "): - continue - end_data = json.loads(line[6:]) - if end_data.get("event") == "message_end": - end_data.setdefault("metadata", {})["suggested_questions"] = followups - yield f"data: 
{json.dumps(end_data, ensure_ascii=False)}\\n\\n".encode("utf-8") + held_message_end.setdefault("metadata", {})["suggested_questions"] = followups + yield f"data: {json.dumps(held_message_end, ensure_ascii=False)}\n\n".encode("utf-8") except Exception: - yield held_message_end + yield f"data: {json.dumps(held_message_end, ensure_ascii=False)}\n\n".encode("utf-8") async with GetAsyncSession() as session: async with session.begin(): @@ -158,7 +157,7 @@ class RagChatServiceImpl(IRagChatService): FROM rag_conversation WHERE user_id = :user_id AND deleted_at IS NULL - AND (:app_id IS NULL OR app_id = :app_id) + AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT)) ORDER BY updated_at DESC OFFSET :offset LIMIT :limit """ @@ -277,11 +276,10 @@ class RagChatServiceImpl(IRagChatService): raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在") if int(owner) != CurrentUserId: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈") - async with session.begin(): - await session.execute( - text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"), - {"feedback": Body.rating, "message_id": MessageId}, - ) + await session.execute( + text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"), + {"feedback": Body.rating, "message_id": MessageId}, + ) return RagOperationResultVO(result="success") async def GetAppParameters( @@ -587,3 +585,26 @@ class RagChatServiceImpl(IRagChatService): ) return visible_chunks + + def _parse_sse_event(self, chunk: str) -> dict | None: + data_lines: list[str] = [] + for line in chunk.splitlines(): + if line.startswith("data: "): + data_lines.append(line[6:]) + elif line.startswith("data:"): + data_lines.append(line[5:].lstrip()) + + if not data_lines: + return None + + payload = "\n".join(part for part in data_lines if part.strip()).strip() + if not payload or payload == "[DONE]": + return None + payload = 
payload.removesuffix("\\n\\n").removesuffix("\\n").strip() + + try: + data = json.loads(payload) + except json.JSONDecodeError: + return None + + return data if isinstance(data, dict) else None diff --git a/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py b/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py index a7b40f2..0cb00ec 100644 --- a/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py +++ b/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py @@ -1,30 +1,305 @@ from __future__ import annotations +import json +import mimetypes +import os +import re +import tempfile +import uuid +from datetime import datetime +from pathlib import Path + +import httpx from sqlalchemy import text from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession +from fastapi_common.fastapi_common_storage.oss_client import OssClient +from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum +from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException -from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import RagDatasetItemVO, RagDatasetPageVO +from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO +from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import ( + RagDatasetDetailVO, + RagDatasetDocumentItemVO, + RagDatasetDocumentPageVO, + RagDatasetItemVO, + RagDatasetPageVO, + RagDatasetRetrieveDocumentVO, + RagDatasetRetrieveRecordVO, + RagDatasetRetrieveResponseVO, + RagDatasetRetrieveSegmentVO, + RagDatasetSegmentItemVO, + RagDatasetSegmentPageVO, + RagDatasetUploadDocumentVO, +) +from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import RagOperationResultVO +from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma +from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG from fastapi_modules.fastapi_leaudit.services.ragDatasetService import 
IRagDatasetService class RagDatasetServiceImpl(IRagDatasetService): + _APP_LINK_SQL = """ + LEFT JOIN ( + SELECT DISTINCT ON (dataset_id) + dataset_id, + id, + name, + is_default + FROM rag_chat_app + WHERE deleted_at IS NULL + AND status = 1 + ORDER BY dataset_id, is_default DESC, sort_order ASC, id ASC + ) a ON a.dataset_id = d.id + """ + + async def GetAdminDatasets( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + Area: str | None, + OnlyEnabled: bool | None, + Page: int, + PageSize: int, + ) -> RagDatasetPageVO: + if UserRole not in ("provincial_admin", "admin", "super_admin"): + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有管理知识库权限") + + filters = ["d.deleted_at IS NULL"] + params: dict = { + "offset": max(Page - 1, 0) * PageSize, + "limit": PageSize, + } + areas = [item.strip() for item in str(Area or "").split(",") if item.strip()] + if len(areas) == 1: + filters.append("d.area = :area") + params["area"] = areas[0] + elif len(areas) > 1: + filters.append("d.area = ANY(:areas)") + params["areas"] = areas + if OnlyEnabled is not None: + filters.append("d.status = :status") + params["status"] = 1 if OnlyEnabled else 0 + + where_sql = " AND ".join(filters) + async with GetAsyncSession() as session: + total = ( + await session.execute( + text(f"SELECT COUNT(1) FROM rag_dataset d WHERE {where_sql}"), + {k: v for k, v in params.items() if k not in ("offset", "limit")}, + ) + ).scalar_one() + rows = ( + await session.execute( + text( + f""" + SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.document_count, d.total_chunks, d.status, + d.sort_order, d.created_at, d.updated_at, + a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default + FROM rag_dataset d + {self._APP_LINK_SQL} + WHERE {where_sql} + ORDER BY d.sort_order ASC, d.id ASC + OFFSET :offset LIMIT :limit + """ + ), + params, + ) + ).mappings().all() + return RagDatasetPageVO( + 
data=[self._to_item_vo(dict(row)) for row in rows], + total=int(total or 0), + ) + + async def CreateAdminDataset( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + Body: dict, + ) -> RagDatasetDetailVO: + if UserRole not in ("provincial_admin", "admin", "super_admin"): + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有创建知识库权限") + + area = str(Body.get("area") or "").strip() + name = str(Body.get("dataset_name") or Body.get("name") or "").strip() + description = str(Body.get("dataset_description") or Body.get("description") or "").strip() + if not area or not name: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "地区和知识库名称不能为空") + + collection_name = self._slugify_collection_name(area, name) + retrieval_model = {} + async with GetAsyncSession() as session: + base = ( + await session.execute( + text( + """ + SELECT embedding_model, embedding_dim, chunk_max_size, chunk_min_size, retrieval_model + FROM rag_dataset + WHERE deleted_at IS NULL + ORDER BY is_default DESC, id ASC + LIMIT 1 + """ + ) + ) + ).mappings().first() + if base: + retrieval_model = base.get("retrieval_model") or {} + if bool(Body.get("is_default")): + await self._clear_default_flags(session) + row = ( + await session.execute( + text( + """ + INSERT INTO rag_dataset ( + name, description, area, is_public, is_default, collection_name, + embedding_model, embedding_dim, chunk_max_size, chunk_min_size, + retrieval_model, sort_order, status, created_by, updated_by + ) VALUES ( + :name, :description, :area, :is_public, :is_default, :collection_name, + :embedding_model, :embedding_dim, :chunk_max_size, :chunk_min_size, + CAST(:retrieval_model AS jsonb), :sort_order, :status, :created_by, :updated_by + ) + RETURNING id + """ + ), + { + "name": name, + "description": description, + "area": area, + "is_public": bool(Body.get("is_public")), + "is_default": bool(Body.get("is_default")), + "collection_name": collection_name, + "embedding_model": 
(base.get("embedding_model") if base else "text-embedding-v4"), + "embedding_dim": (base.get("embedding_dim") if base else 1024), + "chunk_max_size": (base.get("chunk_max_size") if base else 800), + "chunk_min_size": (base.get("chunk_min_size") if base else 20), + "retrieval_model": json.dumps(retrieval_model, ensure_ascii=False), + "sort_order": int(Body.get("sort_order") or 0), + "status": int(Body.get("status") or 1), + "created_by": CurrentUserId, + "updated_by": CurrentUserId, + }, + ) + ).mappings().first() + dataset_id = int(row["id"]) + await self._ensure_linked_app( + session=session, + dataset_id=dataset_id, + dataset_name=name, + dataset_area=area, + current_user_id=CurrentUserId, + is_default=bool(Body.get("is_default")), + ) + refreshed = await self._get_dataset_row(dataset_id) + return self._to_detail_vo(refreshed) if refreshed else None + + async def UpdateAdminDataset( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + Body: dict, + ) -> RagDatasetDetailVO | None: + if UserRole not in ("provincial_admin", "admin", "super_admin"): + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有更新知识库权限") + existing = await self._get_dataset_row(DatasetId) + if not existing: + return None + + area = str(Body.get("area") or existing.get("area") or "").strip() + + async with GetAsyncSession() as session: + target_is_default = bool(Body.get("is_default", existing.get("is_default"))) + if target_is_default: + await self._clear_default_flags(session) + elif existing.get("is_default") and Body.get("is_default") is False: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不能直接取消,请先将其他知识库设为默认") + await session.execute( + text( + """ + UPDATE rag_dataset + SET name = :name, + description = :description, + area = :area, + is_public = :is_public, + is_default = :is_default, + sort_order = :sort_order, + status = :status, + updated_by = :updated_by, + updated_at = NOW() + WHERE id = :dataset_id + 
""" + ), + { + "dataset_id": DatasetId, + "name": str(Body.get("dataset_name") or Body.get("name") or existing.get("name") or "").strip(), + "description": str(Body.get("dataset_description") or Body.get("description") or existing.get("description") or "").strip(), + "area": area, + "is_public": bool(Body.get("is_public", existing.get("is_public"))), + "is_default": target_is_default, + "sort_order": int(Body.get("sort_order") if Body.get("sort_order") is not None else (existing.get("sort_order") or 0)), + "status": int(Body.get("status") if Body.get("status") is not None else (existing.get("status") or 1)), + "updated_by": CurrentUserId, + }, + ) + await self._ensure_linked_app( + session=session, + dataset_id=DatasetId, + dataset_name=str(Body.get("dataset_name") or Body.get("name") or existing.get("name") or "").strip(), + dataset_area=area, + current_user_id=CurrentUserId, + is_default=target_is_default, + ) + refreshed = await self._get_dataset_row(DatasetId) + return self._to_detail_vo(refreshed) if refreshed else None + + async def DeleteAdminDataset( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + ) -> RagOperationResultVO: + if UserRole not in ("provincial_admin", "admin", "super_admin"): + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库权限") + existing = await self._get_dataset_row(DatasetId) + if not existing: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") + if bool(existing.get("is_default")): + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不允许删除,请先切换默认知识库") + async with GetAsyncSession() as session: + await session.execute( + text("UPDATE rag_dataset SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE id = :dataset_id"), + {"dataset_id": DatasetId, "updated_by": CurrentUserId}, + ) + await session.execute( + text("UPDATE rag_chat_app SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE 
dataset_id = :dataset_id AND deleted_at IS NULL"), + {"dataset_id": DatasetId, "updated_by": CurrentUserId}, + ) + return RagOperationResultVO(result="success") + async def GetMyDatasets(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagDatasetPageVO: async with GetAsyncSession() as session: rows = ( await session.execute( text( - """ - SELECT id, name, description, area, is_public, is_default, document_count, total_chunks, status - FROM rag_dataset - WHERE deleted_at IS NULL - AND status = 1 + f""" + SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.document_count, d.total_chunks, d.status, + d.sort_order, d.created_at, d.updated_at, + a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default + FROM rag_dataset d + {self._APP_LINK_SQL} + WHERE d.deleted_at IS NULL + AND d.status = 1 AND ( :is_provincial = TRUE - OR area IN (:user_area, '省级', '') - OR is_public = TRUE + OR d.area IN (:user_area, '省级', '') + OR d.is_public = TRUE ) - ORDER BY sort_order ASC, created_at DESC + ORDER BY d.sort_order ASC, d.created_at DESC """ ), { @@ -36,17 +311,1239 @@ class RagDatasetServiceImpl(IRagDatasetService): return RagDatasetPageVO( data=[ RagDatasetItemVO( - id=row["id"], - name=row["name"], - description=row.get("description") or "", - area=row.get("area") or "", - isPublic=bool(row.get("is_public")), - isDefault=bool(row.get("is_default")), - documentCount=row.get("document_count") or 0, - totalChunks=row.get("total_chunks") or 0, - status=row.get("status") or 1, + **self._to_item_vo(dict(row)).model_dump() ) for row in rows ], total=len(rows), ) + + async def GetDatasetDetail(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int) -> RagDatasetDetailVO | None: + row = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not row: + return None + return self._to_detail_vo(row) + + async def UpdateDataset( + self, + CurrentUserId: int, + 
UserArea: str | None, + UserRole: str | None, + DatasetId: int, + Body: RagDatasetUpdateDTO, + ) -> RagDatasetDetailVO | None: + row = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not row: + return None + + if UserRole not in ("provincial_admin", "admin", "super_admin"): + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有修改知识库配置权限") + + update_fields: list[str] = [] + params: dict = {"dataset_id": DatasetId, "updated_by": CurrentUserId} + + if Body.name is not None: + update_fields.append("name = :name") + params["name"] = Body.name.strip() + + if Body.retrieval_model is not None: + update_fields.append("retrieval_model = CAST(:retrieval_model AS jsonb)") + params["retrieval_model"] = json.dumps(Body.retrieval_model, ensure_ascii=False) + + if not update_fields: + return self._to_detail_vo(row) + + update_fields.append("updated_by = :updated_by") + update_fields.append("updated_at = NOW()") + + async with GetAsyncSession() as session: + await session.execute( + text( + f""" + UPDATE rag_dataset + SET {", ".join(update_fields)} + WHERE id = :dataset_id + """ + ), + params, + ) + + refreshed = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + return self._to_detail_vo(refreshed) if refreshed else None + + async def GetDatasetDocuments( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + Page: int, + Limit: int, + Keyword: str | None, + ) -> RagDatasetDocumentPageVO: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") + + where_sql = [ + "dataset_id = :dataset_id", + "deleted_at IS NULL", + ] + params: dict = { + "dataset_id": DatasetId, + "offset": max(Page - 1, 0) * Limit, + "limit": Limit + 1, + } + if Keyword and Keyword.strip(): + where_sql.append("(original_name ILIKE :keyword OR filename ILIKE :keyword)") + params["keyword"] = f"%{Keyword.strip()}%" + + 
async with GetAsyncSession() as session: + total = ( + await session.execute( + text( + f""" + SELECT COUNT(1) + FROM rag_document + WHERE {" AND ".join(where_sql)} + """ + ), + {key: value for key, value in params.items() if key not in ("offset", "limit")}, + ) + ).scalar_one() + rows = ( + await session.execute( + text( + f""" + SELECT id, dataset_id, original_name, file_type, file_size, chunk_count, + indexing_status, COALESCE(indexing_error, '') AS indexing_error, + enabled, hit_count, created_by, created_at, updated_at + FROM rag_document + WHERE {" AND ".join(where_sql)} + ORDER BY created_at DESC + OFFSET :offset LIMIT :limit + """ + ), + params, + ) + ).mappings().all() + + has_more = len(rows) > Limit + items = rows[:Limit] + return RagDatasetDocumentPageVO( + data=[ + RagDatasetDocumentItemVO( + id=row["id"], + datasetId=row["dataset_id"], + name=row.get("original_name") or "", + fileType=row.get("file_type") or "", + fileSize=row.get("file_size") or 0, + chunkCount=row.get("chunk_count") or 0, + indexingStatus=row.get("indexing_status") or "waiting", + error=row.get("indexing_error") or "", + enabled=bool(row.get("enabled")), + hitCount=row.get("hit_count") or 0, + createdBy=row.get("created_by"), + createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, + updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, + ) + for row in items + ], + total=int(total or 0), + page=Page, + limit=Limit, + hasMore=has_more, + ) + + async def GetDatasetDocumentDetail( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentId: int, + ) -> RagDatasetDocumentItemVO | None: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + return None + + async with GetAsyncSession() as session: + row = ( + await session.execute( + text( + """ + SELECT id, dataset_id, original_name, file_type, file_size, chunk_count, + indexing_status, 
COALESCE(indexing_error, '') AS indexing_error, + enabled, hit_count, created_by, created_at, updated_at + FROM rag_document + WHERE id = :document_id + AND dataset_id = :dataset_id + AND deleted_at IS NULL + LIMIT 1 + """ + ), + {"document_id": DocumentId, "dataset_id": DatasetId}, + ) + ).mappings().first() + + if not row: + return None + + return RagDatasetDocumentItemVO( + id=row["id"], + datasetId=row["dataset_id"], + name=row.get("original_name") or "", + fileType=row.get("file_type") or "", + fileSize=row.get("file_size") or 0, + chunkCount=row.get("chunk_count") or 0, + indexingStatus=row.get("indexing_status") or "waiting", + error=row.get("indexing_error") or "", + enabled=bool(row.get("enabled")), + hitCount=row.get("hit_count") or 0, + createdBy=row.get("created_by"), + createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, + updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, + ) + + async def _get_visible_dataset(self, user_area: str | None, user_role: str | None, dataset_id: int) -> dict | None: + async with GetAsyncSession() as session: + row = ( + await session.execute( + text( + f""" + SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.status, + d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order, + d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at, + a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default + FROM rag_dataset d + {self._APP_LINK_SQL} + WHERE d.id = :dataset_id + AND d.deleted_at IS NULL + AND d.status = 1 + LIMIT 1 + """ + ), + {"dataset_id": dataset_id}, + ) + ).mappings().first() + if not row: + return None + if user_role == "provincial_admin": + return dict(row) + area = row.get("area") or "" + if area in ("", "省级", user_area or "") or bool(row.get("is_public")): + return dict(row) + return None + + def _to_detail_vo(self, row: dict) -> 
RagDatasetDetailVO: + return RagDatasetDetailVO( + id=row["id"], + name=row["name"], + description=row.get("description") or "", + area=row.get("area") or "", + isPublic=bool(row.get("is_public")), + isDefault=bool(row.get("is_default")), + status=row.get("status") or 1, + documentCount=row.get("document_count") or 0, + sortOrder=row.get("sort_order") or 0, + totalChunks=row.get("total_chunks") or 0, + chunkMaxSize=row.get("chunk_max_size") or 800, + chunkMinSize=row.get("chunk_min_size") or 20, + retrievalModel=row.get("retrieval_model") or {}, + createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, + updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, + appId=row.get("app_id"), + appName=row.get("app_name") or "", + appIsDefault=bool(row.get("app_is_default")), + ) + + def _to_item_vo(self, row: dict) -> RagDatasetItemVO: + return RagDatasetItemVO( + id=row["id"], + name=row.get("name") or "", + description=row.get("description") or "", + area=row.get("area") or "", + isPublic=bool(row.get("is_public")), + isDefault=bool(row.get("is_default")), + documentCount=row.get("document_count") or 0, + totalChunks=row.get("total_chunks") or 0, + status=row.get("status") or 1, + sortOrder=row.get("sort_order") or 0, + createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, + updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, + appId=row.get("app_id"), + appName=row.get("app_name") or "", + appIsDefault=bool(row.get("app_is_default")), + ) + + async def _get_dataset_row(self, dataset_id: int) -> dict | None: + async with GetAsyncSession() as session: + row = ( + await session.execute( + text( + f""" + SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.status, + d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order, + d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at, + a.id AS app_id, 
COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default + FROM rag_dataset d + {self._APP_LINK_SQL} + WHERE d.id = :dataset_id AND d.deleted_at IS NULL + LIMIT 1 + """ + ), + {"dataset_id": dataset_id}, + ) + ).mappings().first() + return dict(row) if row else None + + async def _clear_default_flags(self, session) -> None: + await session.execute(text("UPDATE rag_dataset SET is_default = FALSE WHERE deleted_at IS NULL")) + await session.execute(text("UPDATE rag_chat_app SET is_default = FALSE WHERE deleted_at IS NULL")) + + async def _ensure_linked_app( + self, + session, + dataset_id: int, + dataset_name: str, + dataset_area: str, + current_user_id: int, + is_default: bool, + ) -> None: + app_row = ( + await session.execute( + text( + """ + SELECT id + FROM rag_chat_app + WHERE dataset_id = :dataset_id + AND deleted_at IS NULL + ORDER BY is_default DESC, sort_order ASC, id ASC + LIMIT 1 + """ + ), + {"dataset_id": dataset_id}, + ) + ).mappings().first() + + app_name = self._build_app_name(dataset_area=dataset_area, dataset_name=dataset_name) + if app_row: + await session.execute( + text( + """ + UPDATE rag_chat_app + SET name = :name, + is_default = :is_default, + status = 1, + updated_by = :updated_by, + updated_at = NOW() + WHERE id = :app_id + """ + ), + { + "app_id": int(app_row["id"]), + "name": app_name, + "is_default": is_default, + "updated_by": current_user_id, + }, + ) + return + + await session.execute( + text( + """ + INSERT INTO rag_chat_app ( + name, description, area, dataset_id, suggested_questions, + opening_statement, sort_order, status, is_default, created_by, updated_by + ) VALUES ( + :name, :description, :area, :dataset_id, CAST(:suggested_questions AS jsonb), + :opening_statement, 0, 1, :is_default, :created_by, :updated_by + ) + """ + ), + { + "name": app_name, + "description": f"{dataset_area or '默认地区'}知识库问答助手", + "area": dataset_area or "", + "dataset_id": dataset_id, + "suggested_questions": json.dumps([], 
ensure_ascii=False), + "opening_statement": f"您好,我是{app_name}。", + "is_default": is_default, + "created_by": current_user_id, + "updated_by": current_user_id, + }, + ) + + def _build_app_name(self, dataset_area: str, dataset_name: str) -> str: + cleaned = (dataset_name or "").strip() + if cleaned.endswith("知识库"): + cleaned = cleaned[:-3] + if cleaned.endswith("助手"): + return cleaned + if cleaned: + return f"{cleaned}助手" + return f"{dataset_area or '默认地区'}法务助手" + + def _slugify_collection_name(self, area: str, name: str) -> str: + source = f"{area}_{name}".lower() + normalized = re.sub(r"[^a-z0-9]+", "_", source).strip("_") + if normalized: + return f"legal_kb_{normalized}"[:96] + return f"legal_kb_{uuid.uuid4().hex[:12]}" + + async def UploadDatasetDocument( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + FileName: str, + ContentType: str | None, + Content: bytes, + ProcessConfig: dict | None, + ) -> RagDatasetUploadDocumentVO: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") + + suffix = os.path.splitext(FileName)[1].lower() + if suffix not in {".pdf", ".docx", ".txt", ".md"}: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD") + + object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}" + content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream" + OssClient().EnsureBucket() + stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type) + + async with GetAsyncSession() as session: + inserted = ( + await session.execute( + text( + """ + INSERT INTO rag_document ( + dataset_id, filename, original_name, minio_path, file_type, file_size, + chunk_count, indexing_status, enabled, hit_count, created_by + ) VALUES ( + :dataset_id, :filename, :original_name, 
:minio_path, :file_type, :file_size, + 0, 'indexing', TRUE, 0, :created_by + ) + RETURNING id, created_at, updated_at + """ + ), + { + "dataset_id": DatasetId, + "filename": uuid.uuid4().hex + suffix, + "original_name": FileName, + "minio_path": stored_key, + "file_type": suffix.lstrip("."), + "file_size": len(Content), + "created_by": CurrentUserId, + }, + ) + ).mappings().first() + document_id = int(inserted["id"]) + + try: + page_texts = self._extract_page_texts(FileName=FileName, Content=Content) + processed = self._build_chunks( + file_name=FileName, + page_texts=page_texts, + dataset=dataset, + process_config=ProcessConfig or {}, + document_id=document_id, + ) + if not processed: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本") + + embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "") + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) + collection.add( + ids=[item["id"] for item in processed], + documents=[item["text"] for item in processed], + embeddings=embeddings, + metadatas=[item["metadata"] for item in processed], + ) + + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE rag_document + SET chunk_count = :chunk_count, + indexing_status = 'completed', + indexing_error = NULL, + indexing_started_at = COALESCE(indexing_started_at, NOW()), + indexing_completed_at = NOW() + WHERE id = :document_id + """ + ), + {"chunk_count": len(processed), "document_id": document_id}, + ) + await session.execute( + text( + """ + UPDATE rag_dataset + SET document_count = ( + SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL + ), + total_chunks = COALESCE(( + SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL + ), 0) + WHERE id = :dataset_id + """ + ), + {"dataset_id": DatasetId}, + ) + except Exception as exc: + async with GetAsyncSession() 
as session: + await session.execute( + text( + """ + UPDATE rag_document + SET indexing_status = 'error', + indexing_error = :error + WHERE id = :document_id + """ + ), + {"document_id": document_id, "error": str(exc)[:2000]}, + ) + raise + + return RagDatasetUploadDocumentVO( + document={ + "id": str(document_id), + "name": FileName, + "indexing_status": "completed", + "word_count": len(processed), + "hit_count": 0, + "enabled": True, + }, + batch=str(document_id), + ) + + async def GetDatasetDocumentSegments( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentId: int, + Page: int, + Limit: int, + Keyword: str | None, + ) -> RagDatasetSegmentPageVO: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") + + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) + raw = collection.get(where={"document_id": DocumentId}, include=["documents", "metadatas"]) + ids = raw.get("ids") or [] + docs = raw.get("documents") or [] + metas = raw.get("metadatas") or [] + + items: list[dict] = [] + for index, segment_id in enumerate(ids): + content = docs[index] if index < len(docs) else "" + meta = metas[index] if index < len(metas) and isinstance(metas[index], dict) else {} + if Keyword and Keyword.strip() and Keyword.strip() not in content: + continue + items.append( + { + "id": str(segment_id), + "position": index + 1, + "document_id": str(DocumentId), + "content": content or "", + "word_count": len(content or ""), + "hit_count": 0, + "enabled": True, + "status": "completed", + "created_at": 0, + } + ) + + offset = max(Page - 1, 0) * Limit + page_items = items[offset: offset + Limit] + has_more = offset + Limit < len(items) + return RagDatasetSegmentPageVO( + data=[ + RagDatasetSegmentItemVO( + id=item["id"], + position=item["position"], + documentId=item["document_id"], + 
content=item["content"], + wordCount=item["word_count"], + hitCount=item["hit_count"], + enabled=item["enabled"], + status=item["status"], + createdAt=item["created_at"], + ) + for item in page_items + ], + total=len(items), + limit=Limit, + hasMore=has_more, + ) + + async def GetDatasetDocumentSegmentDetail( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentId: int, + SegmentId: str, + ) -> RagDatasetSegmentItemVO | None: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + return None + + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) + raw = collection.get(ids=[SegmentId], include=["documents", "metadatas"]) + ids = raw.get("ids") or [] + if not ids: + return None + + content = (raw.get("documents") or [""])[0] or "" + metadata = (raw.get("metadatas") or [{}])[0] or {} + if str(metadata.get("document_id") or "") != str(DocumentId): + return None + + return RagDatasetSegmentItemVO( + id=str(SegmentId), + position=int(metadata.get("chunk_index") or 0) + 1, + documentId=str(DocumentId), + content=content, + wordCount=len(content), + hitCount=0, + enabled=True, + status="completed", + createdAt=0, + ) + + async def UpdateDatasetDocumentSegment( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentId: int, + SegmentId: str, + Body: dict, + ) -> RagDatasetSegmentItemVO | None: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + return None + + if UserRole not in ("provincial_admin", "admin", "super_admin"): + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有修改知识库分段权限") + + current = await self.GetDatasetDocumentSegmentDetail( + CurrentUserId=CurrentUserId, + UserArea=UserArea, + UserRole=UserRole, + DatasetId=DatasetId, + DocumentId=DocumentId, + SegmentId=SegmentId, + ) + if not current: + return None + + segment_body = 
Body.get("segment") if isinstance(Body.get("segment"), dict) else Body + content = str(segment_body.get("content") or current.content).strip() + if not content: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "分段内容不能为空") + + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) + raw = collection.get(ids=[SegmentId], include=["metadatas"]) + metadata = (raw.get("metadatas") or [{}])[0] or {} + embeddings = await self._embed_texts([content], dataset.get("embedding_model") or "") + collection.update( + ids=[SegmentId], + documents=[content], + embeddings=embeddings, + metadatas=[metadata], + ) + + return RagDatasetSegmentItemVO( + id=str(SegmentId), + position=int(metadata.get("chunk_index") or 0) + 1, + documentId=str(DocumentId), + content=content, + wordCount=len(content), + hitCount=0, + enabled=True, + status="completed", + createdAt=0, + ) + + async def DeleteDatasetDocumentSegment( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentId: int, + SegmentId: str, + ) -> RagOperationResultVO: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") + + if UserRole not in ("provincial_admin", "admin", "super_admin"): + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库分段权限") + + current = await self.GetDatasetDocumentSegmentDetail( + CurrentUserId=CurrentUserId, + UserArea=UserArea, + UserRole=UserRole, + DatasetId=DatasetId, + DocumentId=DocumentId, + SegmentId=SegmentId, + ) + if not current: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "分段不存在") + + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) + collection.delete(ids=[SegmentId]) + + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE rag_document + SET chunk_count = GREATEST(chunk_count - 1, 0), + updated_at = 
NOW() + WHERE id = :document_id + """ + ), + {"document_id": DocumentId}, + ) + await session.execute( + text( + """ + UPDATE rag_dataset + SET total_chunks = GREATEST(total_chunks - 1, 0), + updated_at = NOW() + WHERE id = :dataset_id + """ + ), + {"dataset_id": DatasetId}, + ) + + return RagOperationResultVO(result="success") + + async def DeleteDatasetDocument( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentId: int, + ) -> RagOperationResultVO: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") + + if UserRole not in ("provincial_admin", "admin", "super_admin"): + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库文档权限") + + async with GetAsyncSession() as session: + document_row = ( + await session.execute( + text( + """ + SELECT id, dataset_id, minio_path + FROM rag_document + WHERE id = :document_id + AND dataset_id = :dataset_id + AND deleted_at IS NULL + LIMIT 1 + """ + ), + {"document_id": DocumentId, "dataset_id": DatasetId}, + ) + ).mappings().first() + + if not document_row: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在") + + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) + raw = collection.get(where={"document_id": DocumentId}, include=[]) + ids = raw.get("ids") or [] + if ids: + collection.delete(ids=ids) + + self._delete_oss_object(document_row.get("minio_path")) + + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE rag_document + SET deleted_at = NOW(), + updated_at = NOW() + WHERE id = :document_id + """ + ), + {"document_id": DocumentId}, + ) + await session.execute( + text( + """ + UPDATE rag_dataset + SET document_count = ( + SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL + ), + total_chunks = COALESCE(( + SELECT 
SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL + ), 0), + updated_at = NOW() + WHERE id = :dataset_id + """ + ), + {"dataset_id": DatasetId}, + ) + + return RagOperationResultVO(result="success") + + async def RetrieveDataset( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + Query: str, + RetrievalModel: dict | None, + ) -> RagDatasetRetrieveResponseVO: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") + + query_text = (Query or "").strip() + if not query_text: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "检索内容不能为空") + + retrieval_model = RetrievalModel or {} + top_k = int(retrieval_model.get("top_k") or 5) + top_k = max(1, min(top_k, 20)) + score_threshold_enabled = bool(retrieval_model.get("score_threshold_enabled")) + score_threshold = float(retrieval_model.get("score_threshold") or 0) if score_threshold_enabled else None + + query_embedding = await self._embed_texts([query_text], dataset.get("embedding_model") or "") + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) + raw = collection.query( + query_embeddings=query_embedding, + n_results=top_k, + include=["documents", "metadatas", "distances"], + ) + + ids = (raw.get("ids") or [[]])[0] if raw.get("ids") else [] + documents = (raw.get("documents") or [[]])[0] if raw.get("documents") else [] + metadatas = (raw.get("metadatas") or [[]])[0] if raw.get("metadatas") else [] + distances = (raw.get("distances") or [[]])[0] if raw.get("distances") else [] + + records: list[RagDatasetRetrieveRecordVO] = [] + for index, segment_id in enumerate(ids): + content = documents[index] if index < len(documents) else "" + metadata = metadatas[index] if index < len(metadatas) and isinstance(metadatas[index], dict) else {} + distance = float(distances[index]) if index < len(distances) 
and distances[index] is not None else 1.0 + score = max(0.0, min(1.0, 1.0 - distance)) + if score_threshold_enabled and score_threshold is not None and score < score_threshold: + continue + + document_name = metadata.get("document_name") or metadata.get("source") or "" + document_id = str(metadata.get("document_id") or "") + chunk_index = int(metadata.get("chunk_index") or index) + records.append( + RagDatasetRetrieveRecordVO( + score=round(score, 6), + segment=RagDatasetRetrieveSegmentVO( + id=str(segment_id), + position=chunk_index + 1, + documentId=document_id, + content=content or "", + answer="", + wordCount=len(content or ""), + hitCount=0, + enabled=True, + status="completed", + createdAt=0, + document=RagDatasetRetrieveDocumentVO( + id=document_id, + dataSourceType="upload_file", + name=document_name, + docType=None, + ), + ), + ) + ) + + return RagDatasetRetrieveResponseVO( + query={"content": query_text}, + records=records, + ) + + async def GetDatasetDocumentIndexingStatus( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentId: int, + ) -> dict: + document = await self.GetDatasetDocumentDetail( + CurrentUserId=CurrentUserId, + UserArea=UserArea, + UserRole=UserRole, + DatasetId=DatasetId, + DocumentId=DocumentId, + ) + if not document: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在") + + completed_segments = document.chunkCount if document.indexingStatus == "completed" else 0 + total_segments = document.chunkCount if document.chunkCount > 0 else 0 + return { + "data": [ + { + "id": str(document.id), + "indexing_status": document.indexingStatus, + "processing_started_at": document.updatedAt or document.createdAt or None, + "parsing_completed_at": document.updatedAt if document.indexingStatus in ("cleaning", "splitting", "indexing", "completed") else None, + "cleaning_completed_at": document.updatedAt if document.indexingStatus in ("splitting", "indexing", "completed") else None, + 
"splitting_completed_at": document.updatedAt if document.indexingStatus in ("indexing", "completed") else None, + "completed_at": document.updatedAt if document.indexingStatus == "completed" else None, + "paused_at": None, + "error": document.error or None, + "stopped_at": None, + "completed_segments": completed_segments, + "total_segments": total_segments, + } + ] + } + + async def UpdateDatasetDocumentByFile( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentId: int, + FileName: str, + ContentType: str | None, + Content: bytes, + ProcessConfig: dict | None, + ) -> RagDatasetUploadDocumentVO: + dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) + if not dataset: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") + + if UserRole not in ("provincial_admin", "admin", "super_admin"): + raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有重处理知识库文档权限") + + async with GetAsyncSession() as session: + current = ( + await session.execute( + text( + """ + SELECT id, minio_path + FROM rag_document + WHERE id = :document_id + AND dataset_id = :dataset_id + AND deleted_at IS NULL + LIMIT 1 + """ + ), + {"document_id": DocumentId, "dataset_id": DatasetId}, + ) + ).mappings().first() + if not current: + raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在") + + suffix = os.path.splitext(FileName)[1].lower() + if suffix not in {".pdf", ".docx", ".txt", ".md"}: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD") + + content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream" + object_key = current.get("minio_path") or f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}" + OssClient().EnsureBucket() + stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type) + + collection = 
get_chroma().get_or_create_collection(dataset["collection_name"]) + raw = collection.get(where={"document_id": DocumentId}, include=[]) + ids = raw.get("ids") or [] + if ids: + collection.delete(ids=ids) + + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE rag_document + SET original_name = :original_name, + minio_path = :minio_path, + file_type = :file_type, + file_size = :file_size, + chunk_count = 0, + indexing_status = 'indexing', + indexing_error = NULL, + updated_at = NOW() + WHERE id = :document_id + """ + ), + { + "document_id": DocumentId, + "original_name": FileName, + "minio_path": stored_key, + "file_type": suffix.lstrip("."), + "file_size": len(Content), + }, + ) + + try: + page_texts = self._extract_page_texts(FileName=FileName, Content=Content) + processed = self._build_chunks( + file_name=FileName, + page_texts=page_texts, + dataset=dataset, + process_config=ProcessConfig or {}, + document_id=DocumentId, + ) + if not processed: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本") + + embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "") + collection.add( + ids=[item["id"] for item in processed], + documents=[item["text"] for item in processed], + embeddings=embeddings, + metadatas=[item["metadata"] for item in processed], + ) + + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE rag_document + SET chunk_count = :chunk_count, + indexing_status = 'completed', + indexing_error = NULL, + indexing_started_at = COALESCE(indexing_started_at, NOW()), + indexing_completed_at = NOW(), + updated_at = NOW() + WHERE id = :document_id + """ + ), + {"chunk_count": len(processed), "document_id": DocumentId}, + ) + await session.execute( + text( + """ + UPDATE rag_dataset + SET document_count = ( + SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL + ), + total_chunks = 
async def BatchUpdateDatasetDocumentStatus(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentIds: list[int],
    Enabled: bool,
) -> RagOperationResultVO:
    """Set the ``enabled`` flag on several documents of one dataset.

    Args:
        CurrentUserId: Id of the caller (not used by this method).
        UserArea: Area of the caller, forwarded to the dataset visibility lookup.
        UserRole: Role of the caller; must be one of the admin roles below.
        DatasetId: Dataset that owns the documents.
        DocumentIds: Ids of the documents to update; ``None`` is treated as empty.
        Enabled: New value for ``rag_document.enabled``.

    Returns:
        ``RagOperationResultVO(result="success")`` once the UPDATE was issued.

    Raises:
        LeauditException: 404 when the dataset is not visible, 403 when the role
            is not an admin role, 400 when no document id was supplied.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")

    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有修改知识库文档状态权限")

    # dict.fromkeys drops duplicate ids while keeping the caller's order;
    # ``or []`` keeps a missing list from surfacing as a TypeError.
    ids = list(dict.fromkeys(int(doc_id) for doc_id in (DocumentIds or [])))
    if not ids:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "未传入文档ID")

    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET enabled = :enabled,
                    updated_at = NOW()
                WHERE dataset_id = :dataset_id
                  AND id = ANY(:document_ids)
                  AND deleted_at IS NULL
                """
            ),
            {"dataset_id": DatasetId, "document_ids": ids, "enabled": Enabled},
        )
    return RagOperationResultVO(result="success")


def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]:
    """Turn an uploaded file into a list of ``(page_number, text)`` pairs.

    ``.pdf`` and ``.docx`` files are written to a temporary file and handed to the
    extractors from ``documentServiceImpl``. Every other suffix is decoded as
    UTF-8 text and returned as a single page numbered 1.

    Args:
        FileName: Original file name; only its suffix is inspected.
        Content: Raw bytes of the file.

    Returns:
        The extracted pages, or an empty list when plain text is blank.
    """
    suffix = os.path.splitext(FileName)[1].lower()

    if suffix not in (".pdf", ".docx"):
        # Plain text needs neither a temp file nor the PDF/DOCX extractors.
        # utf-8-sig also strips a leading BOM, which str.strip() would keep.
        text_value = Content.decode("utf-8-sig", errors="ignore").strip()
        return [(1, text_value)] if text_value else []

    # Imported lazily, as in the original code, so the module is only loaded
    # when a binary document actually has to be parsed.
    from fastapi_modules.fastapi_leaudit.services.impl.documentServiceImpl import (
        _extract_page_texts_from_docx,
        _extract_page_texts_from_pdf,
    )

    extractor = _extract_page_texts_from_pdf if suffix == ".pdf" else _extract_page_texts_from_docx

    # delete=False: the file is closed before the extractor reopens it by path,
    # and removed explicitly in the ``finally`` below.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
        temp_file.write(Content)
        temp_path = temp_file.name
    try:
        return extractor(Path(temp_path))
    finally:
        try:
            os.unlink(temp_path)
        except OSError:
            pass


def _build_chunks(self, *, file_name: str, page_texts: list[tuple[int, str]], dataset: dict, process_config: dict, document_id: int) -> list[dict]:
    """Split page texts into chunk records ready for the vector store.

    Args:
        file_name: Stored in each chunk's metadata as ``source`` / ``document_name``.
        page_texts: ``(page_number, text)`` pairs.
        dataset: Dataset row; ``chunk_max_size`` is the fallback chunk size.
        process_config: Optional config; ``process_rule.rules`` may carry
            ``segmentation`` (``separator``, ``max_tokens``, ``chunk_overlap``)
            and ``pre_processing_rules`` (``remove_extra_spaces``,
            ``remove_urls_emails``).
        document_id: Used in chunk ids and metadata.

    Returns:
        A list of dicts with ``id`` (``"{document_id}:{page}:{index}"``),
        ``text`` and ``metadata``.
    """
    rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {}
    segmentation = rules.get("segmentation") or {}
    # Ignore malformed (non-dict) rule entries instead of failing on ``.get``.
    pre_rules = [rule for rule in (rules.get("pre_processing_rules") or []) if isinstance(rule, dict)]
    remove_spaces = any(rule.get("id") == "remove_extra_spaces" and rule.get("enabled") for rule in pre_rules)
    remove_urls = any(rule.get("id") == "remove_urls_emails" and rule.get("enabled") for rule in pre_rules)

    separator = segmentation.get("separator") or "\n\n"
    max_tokens = max(int(segmentation.get("max_tokens") or dataset.get("chunk_max_size") or 800), 1)

    # An explicit 0 must mean "no overlap"; only a missing value falls back to 50.
    raw_overlap = segmentation.get("chunk_overlap")
    chunk_overlap = 50 if raw_overlap is None or raw_overlap == "" else int(raw_overlap)
    # The overlap is capped at half the chunk size.
    chunk_overlap = max(0, min(chunk_overlap, max_tokens // 2 if max_tokens > 1 else 0))

    chunks: list[dict] = []
    for page_no, raw_text in page_texts:
        text_value = self._preprocess_text(raw_text, remove_spaces=remove_spaces, remove_urls=remove_urls)
        if not text_value:
            continue
        pieces = self._split_text(text_value, separator=separator, max_chars=max_tokens, overlap=chunk_overlap)
        for index, chunk_text in enumerate(pieces):
            if not chunk_text.strip():
                continue
            chunk_id = f"{document_id}:{page_no}:{index}"
            chunks.append(
                {
                    "id": chunk_id,
                    "text": chunk_text,
                    "metadata": {
                        "id": chunk_id,
                        "source": file_name,
                        "document_name": file_name,
                        "document_id": document_id,
                        "page": page_no,
                        "chunk_index": index,
                    },
                }
            )
    return chunks


async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Request one embedding per text from the configured embeddings endpoint.

    The endpoint is ``RAG_CONFIG["EMBED_URL"]`` or ``{LLM_BASE_URL}/embeddings``;
    the key is ``EMBED_KEY`` or ``LLM_API_KEY``; the model is ``model_name``,
    then ``EMBED_MODEL``, then ``"text-embedding-v4"``.

    Args:
        texts: Texts to embed; an empty list returns ``[]`` without a request.
        model_name: Preferred embedding model, may be empty.

    Returns:
        Embedding vectors, one per input text.

    Raises:
        LeauditException: 500 when no endpoint/key is configured or when the
            number of returned embeddings differs from the number of texts.
        httpx.HTTPStatusError: When the endpoint answers with an error status.
    """
    if not texts:
        return []

    # ``.get`` keeps a missing setting from raising KeyError, so the
    # configuration check below can report it.
    base_url = (RAG_CONFIG.get("LLM_BASE_URL") or "").rstrip("/")
    embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or (f"{base_url}/embeddings" if base_url else "")
    embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or (RAG_CONFIG.get("LLM_API_KEY") or "")
    embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
    if not embed_url or not embed_key:
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")

    async with httpx.AsyncClient(timeout=120.0) as client:
        response = await client.post(
            embed_url,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {embed_key}",
            },
            json={"model": embed_model, "input": texts},
        )
        response.raise_for_status()
        payload = response.json()

    rows = [row for row in (payload.get("data") or []) if isinstance(row, dict)]
    # When every row carries an integer ``index``, use it to restore input order
    # so that vectors stay aligned with ``texts``.
    if rows and all(isinstance(row.get("index"), int) for row in rows):
        rows.sort(key=lambda row: row["index"])
    embeddings = [row["embedding"] for row in rows if row.get("embedding")]
    if len(embeddings) != len(texts):
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
    return embeddings


def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str:
    """Apply the optional cleaning rules to one page of text.

    Args:
        text_value: Raw text; ``None`` or empty yields ``""``.
        remove_spaces: Collapse runs of spaces/tabs to one space and runs of
            three or more newlines to a blank line.
        remove_urls: Replace URLs (``http(s)://…``, ``www.…``) and e-mail
            addresses with a space.

    Returns:
        The cleaned text with leading and trailing whitespace stripped.
    """
    result = text_value or ""
    if remove_urls:
        # Single backslashes: in a raw string ``\\S`` would match a literal
        # backslash followed by "S", so no URL or address would ever match.
        result = re.sub(r"https?://\S+|www\.\S+|[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}", " ", result)
    if remove_spaces:
        # ``[ \t]`` is space or tab; the double-escaped form ``[ \\t]`` also
        # contained the letter "t" and replaced it with spaces.
        result = re.sub(r"[ \t]+", " ", result)
        # NOTE(review): newline collapsing is grouped under remove_spaces here;
        # the flattened source does not show its indentation - confirm.
        result = re.sub(r"\n{3,}", "\n\n", result)
    return result.strip()


def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]:
    """Pack separator-delimited parts into chunks of at most ``max_chars`` characters.

    Parts are joined with ``separator`` for as long as the result fits. A single
    part longer than ``max_chars`` is cut with a sliding window that advances by
    ``max_chars - overlap`` characters.

    Args:
        text_value: Text to split.
        separator: Delimiter between parts; falsy means "treat the text as one part".
        max_chars: Maximum chunk length; values below 1 are treated as 1.
        overlap: Characters shared by consecutive windows of an oversized part;
            clamped to ``0 .. max_chars - 1``.

    Returns:
        The chunks in document order.
    """
    max_chars = max(int(max_chars), 1)
    overlap = max(0, min(int(overlap), max_chars - 1))
    step = max_chars - overlap  # always >= 1 after the clamping above

    if separator:
        parts = [part.strip() for part in text_value.split(separator) if part.strip()]
    else:
        parts = [text_value]

    chunks: list[str] = []
    current = ""
    for part in parts:
        joiner = separator if (separator and current) else ""
        candidate = f"{current}{joiner}{part}"
        if len(candidate) <= max_chars:
            current = candidate
            continue

        # The part does not fit: flush what has been collected so far.
        if current:
            chunks.append(current)
        if len(part) <= max_chars:
            current = part
            continue

        # Oversized part: slide a window over it. Stop as soon as a window
        # reaches the end, otherwise the next window would only repeat text
        # that the previous chunk already contains.
        start = 0
        while start < len(part):
            chunks.append(part[start:start + max_chars])
            if start + max_chars >= len(part):
                break
            start += step
        current = ""

    if current:
        chunks.append(current)
    return chunks


def _delete_oss_object(self, source: str | None) -> None:
    """Best-effort removal of a stored object; never raises.

    Args:
        source: Object reference understood by ``OssClient.ResolveObjectRef``;
            ``None`` or empty is a no-op. Direct URLs and references without an
            object key are left untouched.
    """
    if not source:
        return
    try:
        oss = OssClient()
        ref = oss.ResolveObjectRef(Source=source)
        if ref.isDirectUrl or not ref.objectKey:
            return
        oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey)
    except Exception:
        # A failed object-storage delete must not block the main business flow,
        # so legacy dirty data cannot make a document undeletable. The failure
        # is logged so that orphaned objects remain discoverable.
        import logging

        logging.getLogger(__name__).warning("failed to delete stored object %r", source, exc_info=True)
        return
# NOTE(review): the members below are abstract declarations of IRagDatasetService.
# Unless stated otherwise, the one-line summaries are inferred from the method
# names and signatures only; the binding behaviour lives in the implementation.
# Every method receives the caller's id, area and role as its first arguments.

# Create a dataset from the raw request body ``Body`` (presumed from the name).
@abstractmethod
async def CreateAdminDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    Body: dict,
) -> RagDatasetDetailVO: ...

# Update dataset ``DatasetId`` from ``Body``; the annotation allows ``None`` as a result.
@abstractmethod
async def UpdateAdminDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    Body: dict,
) -> RagDatasetDetailVO | None: ...

# Delete dataset ``DatasetId`` (presumed from the name).
@abstractmethod
async def DeleteAdminDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
) -> RagOperationResultVO: ...

# List the datasets available to the caller (presumed from the name).
@abstractmethod
async def GetMyDatasets(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagDatasetPageVO: ...

# Fetch one dataset; the annotation allows ``None`` as a result.
@abstractmethod
async def GetDatasetDetail(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int) -> RagDatasetDetailVO | None: ...

# Update one dataset from a typed ``RagDatasetUpdateDTO`` body.
@abstractmethod
async def UpdateDataset(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int, Body: RagDatasetUpdateDTO) -> RagDatasetDetailVO | None: ...

# Page through the documents of a dataset; ``Keyword`` is an optional filter.
@abstractmethod
async def GetDatasetDocuments(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    Page: int,
    Limit: int,
    Keyword: str | None,
) -> RagDatasetDocumentPageVO: ...

# Fetch one document of a dataset; the annotation allows ``None`` as a result.
@abstractmethod
async def GetDatasetDocumentDetail(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
) -> RagDatasetDocumentItemVO | None: ...

# Add a document to a dataset from raw file bytes (``Content``) plus an
# optional processing configuration.
@abstractmethod
async def UploadDatasetDocument(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
    ProcessConfig: dict | None,
) -> RagDatasetUploadDocumentVO: ...

# Page through the segments of one document; ``Keyword`` is an optional filter.
@abstractmethod
async def GetDatasetDocumentSegments(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    Page: int,
    Limit: int,
    Keyword: str | None,
) -> RagDatasetSegmentPageVO: ...

# Delete one document of a dataset (presumed from the name).
@abstractmethod
async def DeleteDatasetDocument(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
) -> RagOperationResultVO: ...

# Run a retrieval ``Query`` against a dataset; ``RetrievalModel`` optionally
# carries retrieval settings (schema not visible here).
@abstractmethod
async def RetrieveDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    Query: str,
    RetrievalModel: dict | None,
) -> RagDatasetRetrieveResponseVO: ...

# Report the indexing status of one document as a plain dict (keys not visible here).
@abstractmethod
async def GetDatasetDocumentIndexingStatus(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
) -> dict: ...

# Replace the file behind an existing document with new raw bytes (``Content``).
@abstractmethod
async def UpdateDatasetDocumentByFile(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
    ProcessConfig: dict | None,
) -> RagDatasetUploadDocumentVO: ...

# Set the enabled flag (``Enabled``) on several documents of one dataset at once.
@abstractmethod
async def BatchUpdateDatasetDocumentStatus(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentIds: list[int],
    Enabled: bool,
) -> RagOperationResultVO: ...

# Fetch one segment of a document; segment ids are strings.
# NOTE(review): RagDatasetSegmentItemVO is used in this and the next annotation
# but does not appear in this module's visible import list from ragDatasetVo.
# With ``from __future__ import annotations`` this does not fail at import time,
# yet the name is undefined for type checkers and typing.get_type_hints - add
# it to the import.
@abstractmethod
async def GetDatasetDocumentSegmentDetail(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
) -> RagDatasetSegmentItemVO | None: ...

# Update one segment of a document from the raw request body ``Body``.
@abstractmethod
async def UpdateDatasetDocumentSegment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
    Body: dict,
) -> RagDatasetSegmentItemVO | None: ...

# Delete one segment of a document (presumed from the name).
@abstractmethod
async def DeleteDatasetDocumentSegment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
) -> RagOperationResultVO: ...