4 Commits

25 changed files with 2682 additions and 36 deletions
@@ -0,0 +1,89 @@
# Chat Temporary RAG Attachments Implementation Plan
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
**Goal:** Build conversation-scoped temporary RAG attachments for chat with 7-day TTL, parser/OCR indexing, strict tenant/user/conversation isolation, and dual retrieval with the existing formal knowledge base.
**Architecture:** Add a focused `RagChatAttachmentServiceImpl` responsible for attachment lifecycle, parsing, chunking, indexing, retrieval, validation, and cleanup. Extend `RagChatServiceImpl` so a chat message can include an `AttachmentId`, retrieve attachment facts first, then formal KB legal context, and generate one answer with source-aware chunks. Add frontend upload/poll/status plumbing inside the existing chat input path without touching unrelated dirty frontend files.
**Tech Stack:** FastAPI, SQLAlchemy text queries, Chroma, existing `RagRetriever`, existing OSS client, existing LeAudit OCR bridge, React/Next.js, Ant Design upload controls.
---
### Task 1: Backend Attachment Contract And Unit Tests
**Files:**
- Create: `tests/test_rag_chat_attachment_service.py`
- Create: `fastapi_modules/fastapi_leaudit/domian/vo/ragChatAttachmentVo.py`
- Modify: `fastapi_modules/fastapi_leaudit/domian/Dto/ragChatDto.py`
- [ ] Write tests for TTL, collection naming, scope matching, and chunk metadata.
- [ ] Run: `pytest tests/test_rag_chat_attachment_service.py -q`; expected failures mention missing `RagChatAttachmentServiceImpl`.
- [ ] Add attachment DTO and VO classes.
- [ ] Re-run the same test and keep remaining failures focused on missing service implementation.
### Task 2: Backend Attachment Service
**Files:**
- Create: `fastapi_modules/fastapi_leaudit/services/ragChatAttachmentService.py`
- Create: `fastapi_modules/fastapi_leaudit/services/impl/ragChatAttachmentServiceImpl.py`
- Create: `scripts/创建sql/schema_add_rag_chat_attachments.sql`
- [ ] Implement schema creation for `rag_chat_attachment`.
- [ ] Implement `BuildCollectionName`, `BuildChunks`, `CreateAttachment`, `GetAttachment`, `ValidateAttachmentForChat`, `RetrieveAttachmentContext`, `DeleteAttachment`, and `CleanupExpiredAttachments`.
- [ ] Implement parsers for txt/md/json/csv/docx/pdf/xlsx/images.
- [ ] Implement async indexing with status transitions and best-effort cleanup.
- [ ] Run: `pytest tests/test_rag_chat_attachment_service.py -q`; expected pass.
### Task 3: Chat Message Dual Retrieval
**Files:**
- Modify: `fastapi_modules/fastapi_leaudit/services/ragChatService.py`
- Modify: `fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py`
- Modify: `fastapi_modules/fastapi_leaudit/rag_engine/generator.py`
- Modify: `tests/test_rag_chat_streaming_sources.py`
- [ ] Add `AttachmentId` to chat service and DTO call path.
- [ ] Add tests proving `_run_message_task` merges attachment chunks and formal KB chunks with distinct source scopes.
- [ ] Run targeted test and confirm red.
- [ ] Implement retrieval merge and grounded legal query construction.
- [ ] Update generator prompt to group uploaded file facts separately from formal KB basis.
- [ ] Run targeted tests and confirm green.
### Task 4: Backend Controller Routes
**Files:**
- Modify: `fastapi_modules/fastapi_leaudit/controllers/ragChatController.py`
- [ ] Add `POST /chat/attachments`, `GET /chat/attachments/{AttachmentId}`, and `DELETE /chat/attachments/{AttachmentId}`.
- [ ] Extend message send route to pass `Body.attachmentId`.
- [ ] Reuse `rag:chat:use` permission and existing tenant context.
- [ ] Run backend attachment and chat streaming tests.
### Task 5: Frontend Upload And Send Plumbing
**Files:**
- Create: `legal-platform-frontend/app/api/chat-attachments/route.ts`
- Create: `legal-platform-frontend/app/api/chat-attachments/[attachmentId]/route.ts`
- Create: `legal-platform-frontend/lib/api/legacy/dify-chat/attachment.ts`
- Modify: `legal-platform-frontend/lib/api/legacy/dify-chat/types.ts`
- Modify: `legal-platform-frontend/lib/api/legacy/dify-chat/client.ts`
- Modify: `legal-platform-frontend/app/api/chat-messages/route.ts`
- Modify: `legal-platform-frontend/hooks/use-chat-message.ts`
- Modify: `legal-platform-frontend/components/dify-chat/index.tsx`
- Modify: `legal-platform-frontend/components/dify-chat/chat-input.tsx`
- [ ] Add frontend attachment API client and Next route proxies.
- [ ] Update chat input to allow one file, upload immediately, poll status, show completed/error/removable state, and block send while indexing.
- [ ] Pass `attachmentId` through `Chat` -> `useChatMessage` -> `/api/chat-messages` -> FastAPI.
- [ ] Keep upload scoped to current conversation id; for new chats, create/send with conversation after backend returns id or require existing conversation before attachment.
### Task 6: Verification
**Files:**
- All touched files
- [ ] Run: `pytest tests/test_rag_chat_attachment_service.py tests/test_rag_chat_streaming_sources.py -q`.
- [ ] Run frontend type/test command if available and scoped enough.
- [ ] Check `git status --short` at root and frontend subrepo.
- [ ] Report changed files, verification output, and any known gaps.
@@ -0,0 +1,150 @@
# Chat Temporary RAG Attachments Design
## Goal
Implement in-chat file upload for `/chat-with-llm/chat` so a user can upload one document or image, have it parsed and indexed into a temporary conversation-scoped RAG collection, and then ask questions that combine facts from the uploaded file with legal knowledge from the existing app knowledge base. Temporary attachment indexes expire after 7 days by default.
## Core Behavior
The upload belongs to the chat experience, not the permanent knowledge-base management UI. A file uploaded inside a conversation creates a temporary attachment record and a temporary vector collection. Chat messages may reference that attachment by `attachmentId`.
The answer flow is dual-source RAG:
1. Retrieve from the temporary attachment collection to establish facts from the uploaded file.
2. Build a grounded legal search query from the user question plus attachment hits.
3. Retrieve from the existing formal app knowledge base to find laws, penalties, cases, or policy rules.
4. Generate a single answer that clearly uses uploaded-file content as facts and formal knowledge-base content as legal basis.
If no attachment is selected, chat continues using the existing formal knowledge-base retrieval path.
## Scope And Isolation
Temporary knowledge must never leak between conversations or users. Backend validation must never trust frontend `attachmentId` alone. Every attachment operation validates:
- `tenant_code` or resolved tenant context
- `user_id`
- `conversation_id`
- `attachment_id`
- `deleted_at IS NULL`
- `expires_at > NOW()`
- `indexing_status = 'completed'` for chat retrieval
The temporary collection name must include all isolation dimensions:
```text
chat_attachment_{tenantCode}_{userId}_{conversationHash}_{attachmentId}
```
Chunk metadata also carries:
- `tenant_code`
- `user_id`
- `conversation_id`
- `attachment_id`
- `source_scope = "chat_attachment"`
- `document_name`
- `page`
- `chunk_index`
This creates defense in depth: database filtering protects attachment ownership, and vector metadata protects retrieval integrity if collection access is ever broadened.
## Supported Files
Initial backend support:
- Text: `.txt`, `.md`, `.json`, `.csv`
- Word: `.docx`
- PDF: `.pdf`
- Excel: `.xlsx` using `openpyxl` when available, with a clear 400 error if the dependency is missing
- Images: `.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp`, `.tif`, `.tiff`
Images are parsed through the existing LeAudit OCR bridge. If OCR returns structured pages, the service converts pages to text. If only a raw dictionary is returned, it extracts common fields such as `full_text`, `text`, `content`, `pages`, and OCR line text.
## Lifecycle
Upload flow:
1. Frontend posts multipart file with `conversation_id` and optional `app_id`.
2. Backend validates chat permission and conversation ownership.
3. Backend creates `rag_chat_attachment` with `indexing_status = 'waiting'` and `expires_at = now + 7 days`.
4. Backend stores the original file in OSS under a temporary chat path.
5. Backend starts async indexing:
- `parsing`
- `splitting`
- `indexing`
- `completed`
- or `error`
6. Frontend polls attachment status and enables send only after `completed`.
Deletion flow:
- User can remove the attachment from the chat UI.
- Backend soft-deletes the row.
- Backend attempts to delete the temporary Chroma collection and OSS object.
- Deletion failures in Chroma/OSS do not block the user-facing delete result.
TTL cleanup:
- `expires_at` defaults to 7 days.
- Retrieval rejects expired attachments immediately.
- A cleanup method soft-deletes expired records and best-effort deletes Chroma collections and OSS objects.
- The cleanup method can later be wired to an existing scheduler if one exists.
## API Shape
Backend routes under `/api/v3/rag`:
- `POST /chat/attachments`
- multipart: `file`, `conversation_id`, optional `app_id`
- returns attachment id, filename, status, expires timestamp, and collection name
- `GET /chat/attachments/{AttachmentId}`
- returns status and metadata after ownership validation
- `DELETE /chat/attachments/{AttachmentId}`
- soft deletes and best-effort removes temporary artifacts
- `POST /chat/messages`
- extends existing body with optional `attachmentId`
Frontend routes:
- `POST /api/chat-attachments`
- `GET /api/chat-attachments/[attachmentId]`
- `DELETE /api/chat-attachments/[attachmentId]`
- Existing `/api/chat-messages` forwards `attachment_id` / `attachmentId`.
## Chat Generation Contract
The generator must receive context chunks with an explicit `source_scope`:
- `chat_attachment`: facts extracted from uploaded file
- `formal_kb`: laws, penalties, and authoritative references from the configured app knowledge base
Prompt construction must tell the model:
- Uploaded attachment chunks are factual input from the user's file.
- Formal knowledge-base chunks are the source for legal rules and penalties.
- Do not invent laws, penalties, or file facts that are absent from the provided contexts.
Sources returned to the UI must preserve source type so users can tell whether a cited segment came from the uploaded file or the formal knowledge base.
## Error Handling
- Empty file: 400
- Unsupported type: 400
- Attachment from another user, tenant, or conversation: 403 or 404
- Expired attachment: 410-like business error using the existing exception mechanism
- Attachment not completed when sending: 400
- Parser found no text: status `error`, send disabled in UI
- OCR/index failures: status `error`, error message capped for display
## Testing Requirements
Backend tests must cover:
- Default `expires_at` is 7 days.
- Collection names include sanitized tenant/user/conversation/attachment isolation fields.
- Scope validation rejects mismatched tenant, user, or conversation.
- Expired or non-completed attachments cannot be used for chat retrieval.
- Built chunk metadata contains tenant/user/conversation/attachment isolation fields.
- Chat task merges temporary attachment chunks and formal KB chunks, preserving source metadata.
Frontend tests are optional in the first slice if the project does not already have focused chat component tests, but the implementation must keep the UI constrained to one selected attachment.
@@ -55,6 +55,13 @@ class DocumentController(BaseController):
"""文档控制器。"""
_CROSS_REVIEW_DOCUMENT_READ_PERMISSION = "cross_review:document:read"
_DOCUMENT_TYPE_PERMISSIONS = {
"list": "doc_type:list:read",
"detail": "doc_type:detail:read",
"create": "doc_type:create:write",
"update": "doc_type:update:write",
"delete": "doc_type:delete:delete",
}
@staticmethod
def _tenant_context(payload: dict[str, Any]) -> dict[str, str | None]:
@@ -296,8 +303,16 @@ class DocumentController(BaseController):
async def ListDocumentTypes(
ids: str | None = Query(None, description="逗号分隔的ID列表,不传则返回全部"),
entry_module_id: int | None = Query(None, description="按入口模块ID过滤文档类型"),
payload: dict[str, Any] = Depends(verify_access_token),
):
"""获取文档类型列表。"""
deniedResponse = await self._deny_document_type_without_permission(
int(payload["user_id"]),
self._DOCUMENT_TYPE_PERMISSIONS["list"],
"当前用户没有文档类型列表权限",
)
if deniedResponse:
return deniedResponse
idList: list[int] | None = None
if ids:
idList = [int(x.strip()) for x in ids.split(",") if x.strip().isdigit()]
@@ -305,52 +320,109 @@ class DocumentController(BaseController):
return Result.success(data=Data)
@self.router.get("/document-types/{TypeId}", response_model=Result[DocumentTypeItemVO])
async def GetDocumentType(TypeId: int):
async def GetDocumentType(TypeId: int, payload: dict[str, Any] = Depends(verify_access_token)):
"""获取文档类型详情。"""
deniedResponse = await self._deny_document_type_without_permission(
int(payload["user_id"]),
self._DOCUMENT_TYPE_PERMISSIONS["detail"],
"当前用户没有文档类型详情权限",
)
if deniedResponse:
return deniedResponse
Data = await self.DocumentService.GetDocumentType(Id=TypeId)
return Result.success(data=Data)
@self.router.post("/document-types", response_model=Result[DocumentTypeItemVO])
async def CreateDocumentType(Body: DocumentTypeCreateDTO):
async def CreateDocumentType(Body: DocumentTypeCreateDTO, payload: dict[str, Any] = Depends(verify_access_token)):
"""创建文档类型。"""
deniedResponse = await self._deny_document_type_without_permission(
int(payload["user_id"]),
self._DOCUMENT_TYPE_PERMISSIONS["create"],
"当前用户没有创建文档类型权限",
)
if deniedResponse:
return deniedResponse
Data = await self.DocumentService.CreateDocumentType(Body=Body)
return Result.success(data=Data, message="文档类型创建成功")
@self.router.put("/document-types/{TypeId}", response_model=Result[DocumentTypeItemVO])
async def UpdateDocumentType(TypeId: int, Body: DocumentTypeUpdateDTO):
async def UpdateDocumentType(TypeId: int, Body: DocumentTypeUpdateDTO, payload: dict[str, Any] = Depends(verify_access_token)):
"""更新文档类型。"""
deniedResponse = await self._deny_document_type_without_permission(
int(payload["user_id"]),
self._DOCUMENT_TYPE_PERMISSIONS["update"],
"当前用户没有更新文档类型权限",
)
if deniedResponse:
return deniedResponse
Data = await self.DocumentService.UpdateDocumentType(Id=TypeId, Body=Body)
return Result.success(data=Data, message="文档类型更新成功")
@self.router.delete("/document-types/{TypeId}", response_model=Result[None])
async def DeleteDocumentType(TypeId: int):
async def DeleteDocumentType(TypeId: int, payload: dict[str, Any] = Depends(verify_access_token)):
"""删除文档类型(软删除)。"""
deniedResponse = await self._deny_document_type_without_permission(
int(payload["user_id"]),
self._DOCUMENT_TYPE_PERMISSIONS["delete"],
"当前用户没有删除文档类型权限",
)
if deniedResponse:
return deniedResponse
await self.DocumentService.DeleteDocumentType(Id=TypeId)
return Result.success(message="文档类型已删除")
@self.router.get("/v3/document-type-roots", response_model=Result[list[DocumentTypeRootItemVO]])
async def ListDocumentTypeRoots(
entry_module_id: int | None = Query(None, description="按入口模块过滤一级大类"),
payload: dict[str, Any] = Depends(verify_access_token),
):
"""获取一级文档类型(业务大类)列表。"""
deniedResponse = await self._deny_document_type_without_permission(
int(payload["user_id"]),
self._DOCUMENT_TYPE_PERMISSIONS["list"],
"当前用户没有业务大类列表权限",
)
if deniedResponse:
return deniedResponse
Data = await self.DocumentService.ListDocumentTypeRoots(EntryModuleId=entry_module_id)
return Result.success(data=Data)
@self.router.get("/v3/document-type-roots/{RootId}", response_model=Result[DocumentTypeRootItemVO])
async def GetDocumentTypeRoot(RootId: int):
async def GetDocumentTypeRoot(RootId: int, payload: dict[str, Any] = Depends(verify_access_token)):
"""获取一级文档类型(业务大类)详情。"""
deniedResponse = await self._deny_document_type_without_permission(
int(payload["user_id"]),
self._DOCUMENT_TYPE_PERMISSIONS["detail"],
"当前用户没有业务大类详情权限",
)
if deniedResponse:
return deniedResponse
Data = await self.DocumentService.GetDocumentTypeRoot(Id=RootId)
return Result.success(data=Data)
@self.router.post("/v3/document-type-roots", response_model=Result[DocumentTypeRootItemVO])
async def CreateDocumentTypeRoot(Body: DocumentTypeRootCreateDTO):
async def CreateDocumentTypeRoot(Body: DocumentTypeRootCreateDTO, payload: dict[str, Any] = Depends(verify_access_token)):
"""创建一级文档类型(业务大类)。"""
deniedResponse = await self._deny_document_type_without_permission(
int(payload["user_id"]),
self._DOCUMENT_TYPE_PERMISSIONS["create"],
"当前用户没有创建业务大类权限",
)
if deniedResponse:
return deniedResponse
Data = await self.DocumentService.CreateDocumentTypeRoot(Body=Body)
return Result.success(data=Data, message="一级文档类型创建成功")
@self.router.put("/v3/document-type-roots/{RootId}", response_model=Result[DocumentTypeRootItemVO])
async def UpdateDocumentTypeRoot(RootId: int, Body: DocumentTypeRootUpdateDTO):
async def UpdateDocumentTypeRoot(RootId: int, Body: DocumentTypeRootUpdateDTO, payload: dict[str, Any] = Depends(verify_access_token)):
"""更新一级文档类型(业务大类)。"""
deniedResponse = await self._deny_document_type_without_permission(
int(payload["user_id"]),
self._DOCUMENT_TYPE_PERMISSIONS["update"],
"当前用户没有更新业务大类权限",
)
if deniedResponse:
return deniedResponse
Data = await self.DocumentService.UpdateDocumentTypeRoot(Id=RootId, Body=Body)
return Result.success(data=Data, message="一级文档类型更新成功")
@@ -431,3 +503,11 @@ class DocumentController(BaseController):
status_code=403,
content={"code": 403, "msg": "当前用户没有查看交叉评查结果权限", "data": None},
)
async def _deny_document_type_without_permission(self, UserId: int, PermissionKey: str, Message: str) -> JSONResponse | None:
if await self.PermissionService.CheckPermission(UserId, PermissionKey):
return None
return JSONResponse(
status_code=403,
content={"code": 403, "msg": Message, "data": None},
)
@@ -30,6 +30,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
RagMessagePageVO,
RagOperationResultVO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
RagChatAttachmentDeleteVO,
RagChatAttachmentVO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
RagDatasetBatchDeleteResultVO,
RagDatasetDetailVO,
@@ -43,8 +47,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
)
from fastapi_modules.fastapi_leaudit.services.impl.permissionServiceImpl import PermissionServiceImpl
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
from fastapi_modules.fastapi_leaudit.services.ragChatAttachmentService import IRagChatAttachmentService
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
@@ -76,6 +82,7 @@ class RagChatController(BaseController):
def __init__(self):
super().__init__(prefix="/v3/rag", tags=["RAG 聊天"])
self.RagChatService: IRagChatService = RagChatServiceImpl()
self.RagChatAttachmentService: IRagChatAttachmentService = RagChatAttachmentServiceImpl()
self.RagDatasetService: IRagDatasetService = RagDatasetServiceImpl()
self.PermissionService: IPermissionService = PermissionServiceImpl()
@@ -484,6 +491,8 @@ class RagChatController(BaseController):
Query=Body.query,
ConversationId=Body.conversationId,
AppId=Body.appId,
AttachmentId=Body.attachmentId,
AttachmentIds=Body.attachmentIds,
)
return StreamingResponse(
stream,
@@ -491,6 +500,62 @@ class RagChatController(BaseController):
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"},
)
@self.router.post("/chat/attachments", response_model=Result[RagChatAttachmentVO])
async def UploadChatAttachment(
file: UploadFile = File(...),
conversation_id: str | None = Form(None),
app_id: int | None = Form(None),
payload: dict[str, Any] = Depends(verify_access_token),
):
if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["chat_use"]]):
return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有上传聊天附件权限", "data": None})
tenant_context = self._tenant_context(payload)
file_bytes = await file.read()
data = await self.RagChatAttachmentService.CreateAttachment(
CurrentUserId=int(payload["user_id"]),
**tenant_context,
ConversationId=conversation_id,
AppId=app_id,
FileName=file.filename or "attachment",
ContentType=file.content_type,
Content=file_bytes,
)
return Result.success(data=data)
@self.router.get("/chat/attachments/{AttachmentId}", response_model=Result[RagChatAttachmentVO])
async def GetChatAttachment(
AttachmentId: str,
conversation_id: str = Query(..., description="附件所属会话ID"),
payload: dict[str, Any] = Depends(verify_access_token),
):
if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["chat_use"]]):
return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有查看聊天附件权限", "data": None})
tenant_context = self._tenant_context(payload)
data = await self.RagChatAttachmentService.GetAttachment(
CurrentUserId=int(payload["user_id"]),
**tenant_context,
ConversationId=conversation_id,
AttachmentId=AttachmentId,
)
return Result.success(data=data)
@self.router.delete("/chat/attachments/{AttachmentId}", response_model=Result[RagChatAttachmentDeleteVO])
async def DeleteChatAttachment(
AttachmentId: str,
conversation_id: str = Query(..., description="附件所属会话ID"),
payload: dict[str, Any] = Depends(verify_access_token),
):
if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["chat_use"]]):
return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有删除聊天附件权限", "data": None})
tenant_context = self._tenant_context(payload)
data = await self.RagChatAttachmentService.DeleteAttachment(
CurrentUserId=int(payload["user_id"]),
**tenant_context,
ConversationId=conversation_id,
AttachmentId=AttachmentId,
)
return Result.success(data=data)
@self.router.post("/chat/messages/{MessageId}/stop", response_model=Result[RagOperationResultVO])
async def StopMessage(
MessageId: str,
@@ -5,6 +5,8 @@ class RagChatSendMessageDTO(BaseModel):
query: str = Field(..., min_length=1, description="用户问题")
conversationId: str | None = Field(None, description="会话ID")
appId: int | None = Field(None, description="聊天应用ID")
attachmentId: str | None = Field(None, description="临时聊天附件ID")
attachmentIds: list[str] = Field(default_factory=list, description="临时聊天附件ID列表")
class RagConversationRenameDTO(BaseModel):
@@ -0,0 +1,19 @@
from pydantic import BaseModel, Field
class RagChatAttachmentVO(BaseModel):
attachmentId: str = Field(..., description="临时附件ID")
conversationId: str = Field(..., description="会话ID")
fileName: str = Field(..., description="原始文件名")
contentType: str = Field("", description="文件 MIME 类型")
fileSize: int = Field(0, description="文件大小")
indexingStatus: str = Field("waiting", description="索引状态")
indexingError: str | None = Field(None, description="索引错误")
chunkCount: int = Field(0, description="分段数量")
collectionName: str = Field("", description="临时向量集合名")
expiresAt: int = Field(0, description="过期时间戳")
createdAt: int = Field(0, description="创建时间戳")
class RagChatAttachmentDeleteVO(BaseModel):
result: str = Field("success")
@@ -36,6 +36,7 @@ class RagMessageItemVO(BaseModel):
conversationId: str = Field(...)
query: str = Field(...)
answer: str = Field(...)
messageFiles: list[dict] = Field(default_factory=list)
feedback: dict | None = Field(None)
retrieverResources: list[dict] | None = Field(None)
suggestedQuestions: list[str] = Field(default_factory=list)
@@ -33,16 +33,34 @@ async def generate_stream(
max_context_chars = 8000
if context_chunks:
parts: list[str] = []
attachment_parts: list[str] = []
formal_parts: list[str] = []
total_len = 0
for chunk in context_chunks:
part = f"[来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
scope = str(chunk.get("source_scope") or chunk.get("data_source_type") or "formal_kb")
if scope == "chat_attachment":
label = "上传文件事实"
else:
label = "正式知识库依据"
part = f"[{label} | 来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
if total_len + len(part) > max_context_chars:
break
parts.append(part)
if scope == "chat_attachment":
attachment_parts.append(part)
else:
formal_parts.append(part)
total_len += len(part)
context_text = "\\n\\n---\\n\\n".join(parts)
user_content = f"知识库内容:\\n{context_text}\\n\\n用户问题: {query}"
context_sections: list[str] = []
if attachment_parts:
context_sections.append("【上传文件事实】\n" + "\\n\\n---\\n\\n".join(attachment_parts))
if formal_parts:
context_sections.append("【正式知识库依据】\n" + "\\n\\n---\\n\\n".join(formal_parts))
context_text = "\\n\\n======\\n\\n".join(context_sections)
user_content = (
"请严格区分上下文来源:上传文件事实只能用于判断用户材料中的事实;正式知识库依据只能用于法律、处罚、裁量或案例依据。"
"如果依据不足,请明确说明,不要编造。\n\n"
f"{context_text}\\n\\n用户问题: {query}"
)
else:
user_content = query
@@ -115,7 +133,7 @@ async def generate_stream(
"dataset_name": dataset_name,
"document_id": "",
"document_name": chunk.get("source", ""),
"data_source_type": "upload_file",
"data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
"segment_id": chunk.get("id", ""),
"retriever_from": "rag",
"score": round(chunk.get("score", 0.0), 4),
@@ -125,6 +143,8 @@ async def generate_stream(
"index_node_hash": "",
"content": chunk.get("text", "")[:500],
"page": None,
"source_scope": chunk.get("source_scope") or "",
"attachment_id": chunk.get("attachment_id") or "",
}
for i, chunk in enumerate(context_chunks)
]
@@ -183,10 +183,15 @@ class RagRetriever:
"source": meta.get("source") or meta.get("document_name") or dataset_name,
"score": score,
"chunk_index": int(meta.get("chunk_index") or idx),
"document_name": document_name,
"document_id": meta.get("document_id"),
"page": meta.get("page"),
}
"document_name": document_name,
"document_id": meta.get("document_id"),
"page": meta.get("page"),
"source_scope": meta.get("source_scope"),
"attachment_id": meta.get("attachment_id"),
"conversation_id": meta.get("conversation_id"),
"tenant_code": meta.get("tenant_code"),
"user_id": meta.get("user_id"),
}
)
return chunks
@@ -285,6 +290,11 @@ class RagRetriever:
"document_name": document_name,
"document_id": meta.get("document_id"),
"page": meta.get("page"),
"source_scope": meta.get("source_scope"),
"attachment_id": meta.get("attachment_id"),
"conversation_id": meta.get("conversation_id"),
"tenant_code": meta.get("tenant_code"),
"user_id": meta.get("user_id"),
}
)
@@ -392,10 +402,10 @@ class RagRetriever:
{
"position": index + 1,
"dataset_id": str(chunk.get("dataset_id") or ""),
"dataset_name": dataset_name,
"dataset_name": chunk.get("dataset_name") or dataset_name,
"document_id": str(chunk.get("document_id") or ""),
"document_name": chunk.get("document_name") or chunk.get("source", ""),
"data_source_type": "upload_file",
"data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
"segment_id": chunk.get("id", ""),
"retriever_from": "rag",
"score": round(float(chunk.get("score") or 0.0), 4),
@@ -405,6 +415,8 @@ class RagRetriever:
"index_node_hash": "",
"content": chunk.get("text", "")[:500],
"page": None,
"source_scope": chunk.get("source_scope") or "",
"attachment_id": chunk.get("attachment_id") or "",
}
for index, chunk in enumerate(context_chunks)
]
@@ -51,7 +51,7 @@ _ALLOWED_FEATURES = {
_DEFAULT_FEATURES_BY_PROFILE = {
"document_review": ["home", "documents", "upload", "rules", "rule_groups"],
"contract": ["home", "documents", "upload", "rules", "contract_template_search", "contract_template_list"],
"govdoc": ["home", "govdoc_audits", "govdoc_upload", "rule_groups"],
"govdoc": ["home", "govdoc_audits", "govdoc_upload", "rules"],
"cross_checking": ["cross_checking", "cross_checking_upload", "cross_checking_list"],
"custom": ["home", "documents"],
}
@@ -884,6 +884,8 @@ class EntryModuleAdminServiceImpl(IEntryModuleAdminService):
feature = str(item or "").strip()
if not feature:
continue
if MenuProfile == "govdoc" and feature == "rule_groups":
feature = "rules"
if feature not in _ALLOWED_FEATURES:
invalid.append(feature)
continue
@@ -915,6 +917,8 @@ class EntryModuleAdminServiceImpl(IEntryModuleAdminService):
normalized: list[str] = []
for item in Features:
feature = str(item or "").strip()
if MenuProfile == "govdoc" and feature == "rule_groups":
feature = "rules"
if feature in _ALLOWED_FEATURES and feature not in normalized:
normalized.append(feature)
return normalized or list(_DEFAULT_FEATURES_BY_PROFILE.get(MenuProfile, _DEFAULT_FEATURES_BY_PROFILE["document_review"]))
@@ -54,7 +54,7 @@ class HomeServiceImpl(IHomeService):
_DEFAULT_FEATURES_BY_PROFILE: dict[str, list[str]] = {
"document_review": ["home", "documents", "upload", "rules", "rule_groups"],
"contract": ["home", "documents", "upload", "rules", "contract_template_search", "contract_template_list"],
"govdoc": ["home", "govdoc_audits", "govdoc_upload", "rule_groups"],
"govdoc": ["home", "govdoc_audits", "govdoc_upload", "rules"],
"cross_checking": ["cross_checking", "cross_checking_upload", "cross_checking_list"],
"custom": ["home", "documents"],
}
@@ -553,6 +553,8 @@ class HomeServiceImpl(IHomeService):
normalized: list[str] = []
for item in parsed:
feature = str(item or "").strip()
if menu_profile == "govdoc" and feature == "rule_groups":
feature = "rules"
if feature in allowed_features and feature not in normalized:
normalized.append(feature)
return normalized or list(cls._DEFAULT_FEATURES_BY_PROFILE[menu_profile])
@@ -0,0 +1,821 @@
"""Temporary RAG attachments for chat conversations."""
from __future__ import annotations
import asyncio
import csv
import hashlib
import json
import mimetypes
import os
import re
import tempfile
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Awaitable, Callable
from sqlalchemy import text
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_storage.oss_client import OssClient
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
RagChatAttachmentDeleteVO,
RagChatAttachmentVO,
)
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import EmbedTexts, RagRetriever
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
from fastapi_modules.fastapi_leaudit.services.ragChatAttachmentService import IRagChatAttachmentService
DEFAULT_ATTACHMENT_TTL_DAYS = 7
SUPPORTED_TEXT_SUFFIXES = {".txt", ".md", ".json", ".csv"}
SUPPORTED_DOCUMENT_SUFFIXES = {".docx", ".pdf", ".xlsx"}
SUPPORTED_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"}
SUPPORTED_SUFFIXES = SUPPORTED_TEXT_SUFFIXES | SUPPORTED_DOCUMENT_SUFFIXES | SUPPORTED_IMAGE_SUFFIXES
class RagChatAttachmentServiceImpl(IRagChatAttachmentService):
"""Manage temporary, conversation-scoped RAG attachment indexes."""
_attachment_schema_checked = False
_attachment_schema_lock = asyncio.Lock()
def __init__(
self,
*,
chroma_client: Any | None = None,
embed_texts: EmbedTexts | None = None,
chat_service: RagChatServiceImpl | None = None,
ocr_client_factory: Callable[[], Any] | None = None,
) -> None:
self._chroma_client = chroma_client
self.retriever = RagRetriever(
chroma_client=chroma_client,
embed_texts=embed_texts,
hydrate_documents=False,
)
self.chat_service = chat_service or RagChatServiceImpl()
self.dataset_helpers = RagDatasetServiceImpl()
self._ocr_client_factory = ocr_client_factory
def _default_expires_at(self, now: datetime | None = None) -> datetime:
base = self._ensure_aware_datetime(now or datetime.now(timezone.utc))
return base + timedelta(days=DEFAULT_ATTACHMENT_TTL_DAYS)
def BuildCollectionName(self, *, TenantCode: str | None, UserId: int, ConversationId: str, AttachmentId: str) -> str:
tenant = self._slugify(TenantCode or "default")
user = self._slugify(str(UserId))
conversation_hash = hashlib.sha1(ConversationId.encode("utf-8")).hexdigest()[:16]
attachment = self._slugify(AttachmentId)
collection_name = f"chat_attachment_{tenant}_{user}_{conversation_hash}_{attachment}"
return collection_name[:120]
def BuildChunks(
self,
*,
AttachmentId: str,
TenantCode: str | None,
UserId: int,
ConversationId: str,
FileName: str,
PageTexts: list[tuple[int, str]],
ChunkMaxSize: int = 800,
ChunkOverlap: int = 80,
) -> list[dict[str, Any]]:
chunk_max_size = max(50, int(ChunkMaxSize or 800))
chunk_overlap = max(0, min(int(ChunkOverlap or 0), chunk_max_size // 2))
chunks: list[dict[str, Any]] = []
for page_no, raw_text in PageTexts:
text_value = self._preprocess_text(raw_text)
if not text_value:
continue
split_chunks = self._split_text(text_value, separator="\n\n", max_chars=chunk_max_size, overlap=chunk_overlap)
for index, chunk_text in enumerate(split_chunks):
if not chunk_text.strip():
continue
chunk_id = f"{AttachmentId}:{page_no}:{index}"
chunks.append(
{
"id": chunk_id,
"text": chunk_text,
"metadata": {
"id": chunk_id,
"source": FileName,
"document_name": FileName,
"attachment_id": AttachmentId,
"tenant_code": str(TenantCode or ""),
"user_id": str(UserId),
"conversation_id": ConversationId,
"source_scope": "chat_attachment",
"page": int(page_no),
"chunk_index": index,
},
}
)
return chunks
async def CreateAttachment(
self,
CurrentUserId: int,
UserArea: str | None,
UserRole: str | None,
TenantCode: str | None,
TenantName: str | None,
ConversationId: str | None,
AppId: int | None,
FileName: str,
ContentType: str | None,
Content: bytes,
) -> RagChatAttachmentVO:
if not FileName:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件名不能为空")
if not Content:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件内容不能为空")
suffix = Path(FileName).suffix.lower()
if suffix not in SUPPORTED_SUFFIXES:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 TXT、MD、JSON、CSV、DOCX、PDF、XLSX 和图片文件")
app = await self.chat_service._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
if not app:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")
conversation_id = await self.chat_service._ensure_conversation(
user_id=CurrentUserId,
conversation_id=ConversationId,
app_id=app["id"],
user_area=UserArea,
user_role=UserRole,
tenant_code=TenantCode,
tenant_name=TenantName,
)
attachment_id = str(uuid.uuid4())
collection_name = self.BuildCollectionName(
TenantCode=TenantCode,
UserId=CurrentUserId,
ConversationId=conversation_id,
AttachmentId=attachment_id,
)
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
expires_at = self._default_expires_at()
object_key = f"rag-chat-attachments/{TenantCode or 'default'}/{CurrentUserId}/{conversation_id}/{attachment_id}_{FileName}"
OssClient().EnsureBucket()
stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)
async with GetAsyncSession() as session:
await self._ensure_attachment_schema(session)
row = (
await session.execute(
text(
"""
INSERT INTO rag_chat_attachment (
attachment_id, conversation_id, user_id, tenant_code, area,
filename, original_name, content_type, file_size, minio_path,
collection_name, indexing_status, chunk_count, expires_at
) VALUES (
:attachment_id, :conversation_id, :user_id, :tenant_code, :area,
:filename, :original_name, :content_type, :file_size, :minio_path,
:collection_name, 'waiting', 0, :expires_at
)
RETURNING *
"""
),
{
"attachment_id": attachment_id,
"conversation_id": conversation_id,
"user_id": CurrentUserId,
"tenant_code": TenantCode,
"area": UserArea,
"filename": f"{attachment_id}{suffix}",
"original_name": FileName,
"content_type": content_type,
"file_size": len(Content),
"minio_path": stored_key,
"collection_name": collection_name,
"expires_at": expires_at,
},
)
).mappings().first()
asyncio.create_task(
self._run_attachment_indexing_task(
attachment_id=attachment_id,
tenant_code=TenantCode,
user_id=CurrentUserId,
conversation_id=conversation_id,
file_name=FileName,
content=Content,
collection_name=collection_name,
)
)
return self._to_vo(dict(row))
async def GetAttachment(
self,
CurrentUserId: int,
UserArea: str | None,
UserRole: str | None,
TenantCode: str | None,
TenantName: str | None,
ConversationId: str,
AttachmentId: str,
) -> RagChatAttachmentVO:
self._validate_conversation_id(ConversationId)
record = await self._get_attachment_record(AttachmentId)
self._assert_attachment_scope(
record,
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
RequireCompleted=False,
Now=datetime.now(timezone.utc),
)
return self._to_vo(record)
async def DeleteAttachment(
self,
CurrentUserId: int,
UserArea: str | None,
UserRole: str | None,
TenantCode: str | None,
TenantName: str | None,
ConversationId: str,
AttachmentId: str,
) -> RagChatAttachmentDeleteVO:
self._validate_conversation_id(ConversationId)
record = await self._get_attachment_record(AttachmentId)
self._assert_attachment_scope(
record,
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
RequireCompleted=False,
Now=datetime.now(timezone.utc),
)
async with GetAsyncSession() as session:
await self._ensure_attachment_schema(session)
await session.execute(
text(
"""
UPDATE rag_chat_attachment
SET deleted_at = NOW(), updated_at = NOW()
WHERE attachment_id = :attachment_id
"""
),
{"attachment_id": AttachmentId},
)
self._delete_collection(str(record.get("collection_name") or ""))
self._delete_oss_object(str(record.get("minio_path") or ""))
return RagChatAttachmentDeleteVO(result="success")
async def ValidateAttachmentForChat(
self,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
AttachmentId: str,
UserArea: str | None = None,
) -> dict:
record = await self._get_attachment_record(AttachmentId)
self._assert_attachment_scope(
record,
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
RequireCompleted=True,
Now=datetime.now(timezone.utc),
)
return record
async def RetrieveAttachmentContext(
self,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
AttachmentId: str,
Query: str,
TopK: int = 5,
UserArea: str | None = None,
) -> tuple[list[dict], str]:
record = await self.ValidateAttachmentForChat(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
AttachmentId=AttachmentId,
)
result = await self.retriever.retrieve(
query=Query,
collection_name=str(record["collection_name"]),
top_k=TopK,
)
chunks: list[dict] = []
for chunk in result.chunks:
chunks.append(
{
**chunk,
"dataset_name": "临时上传文件",
"attachment_id": AttachmentId,
"conversation_id": ConversationId,
"source_scope": "chat_attachment",
"data_source_type": "chat_attachment",
"document_name": chunk.get("document_name") or record.get("original_name") or chunk.get("source"),
"source": chunk.get("source") or record.get("original_name") or "上传文件",
}
)
return chunks, str(record.get("original_name") or "上传文件")
async def ResolveActiveAttachmentIdForConversation(
self,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
UserArea: str | None = None,
) -> str | None:
attachment_ids = await self.ResolveActiveAttachmentIdsForConversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
)
return attachment_ids[0] if attachment_ids else None
async def ResolveActiveAttachmentIdsForConversation(
self,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
UserArea: str | None = None,
) -> list[str]:
self._validate_conversation_id(ConversationId)
async with GetAsyncSession() as session:
await self._ensure_attachment_schema(session)
rows = (
await session.execute(
text(
"""
SELECT attachment_id, tenant_code, area
FROM rag_chat_attachment
WHERE user_id = :user_id
AND conversation_id = :conversation_id
AND indexing_status = 'completed'
AND expires_at > NOW()
AND deleted_at IS NULL
AND (
NULLIF(:tenant_code, '') IS NULL
OR NULLIF(tenant_code, '') IS NULL
OR tenant_code = :tenant_code
)
AND (
NULLIF(:tenant_code, '') IS NOT NULL
OR NULLIF(:user_area, '') IS NULL
OR NULLIF(area, '') IS NULL
OR area = :user_area
)
ORDER BY indexing_completed_at DESC NULLS LAST, created_at DESC
"""
),
{
"user_id": CurrentUserId,
"conversation_id": ConversationId,
"tenant_code": TenantCode,
"user_area": UserArea,
},
)
).mappings().all()
return [str(row["attachment_id"]) for row in rows]
async def CleanupExpiredAttachments(self, *, Limit: int = 100) -> int:
async with GetAsyncSession() as session:
await self._ensure_attachment_schema(session)
rows = (
await session.execute(
text(
"""
SELECT attachment_id, collection_name, minio_path
FROM rag_chat_attachment
WHERE deleted_at IS NULL
AND expires_at <= NOW()
ORDER BY expires_at ASC
LIMIT :limit
"""
),
{"limit": max(1, int(Limit or 100))},
)
).mappings().all()
attachment_ids = [str(row["attachment_id"]) for row in rows]
if attachment_ids:
await session.execute(
text(
"""
UPDATE rag_chat_attachment
SET deleted_at = NOW(), updated_at = NOW()
WHERE attachment_id = ANY(:attachment_ids)
"""
),
{"attachment_ids": attachment_ids},
)
for row in rows:
self._delete_collection(str(row.get("collection_name") or ""))
self._delete_oss_object(str(row.get("minio_path") or ""))
return len(rows)
async def _get_attachment_record(self, attachment_id: str) -> dict:
if not str(attachment_id or "").strip():
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "附件ID不能为空")
async with GetAsyncSession() as session:
await self._ensure_attachment_schema(session)
row = (
await session.execute(
text(
"""
SELECT *
FROM rag_chat_attachment
WHERE attachment_id = :attachment_id
AND deleted_at IS NULL
LIMIT 1
"""
),
{"attachment_id": attachment_id},
)
).mappings().first()
if not row:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
return dict(row)
def _assert_attachment_scope(
self,
record: dict,
*,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
RequireCompleted: bool,
Now: datetime | None = None,
UserArea: str | None = None,
) -> None:
if record.get("deleted_at") is not None:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
if int(record.get("user_id") or 0) != int(CurrentUserId):
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该临时附件")
if str(record.get("conversation_id") or "") != str(ConversationId or ""):
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前会话")
expected_tenant = str(TenantCode or "").strip()
record_tenant = str(record.get("tenant_code") or "").strip()
if expected_tenant and record_tenant and expected_tenant != record_tenant:
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前租户")
expected_area = str(UserArea or "").strip()
record_area = str(record.get("area") or "").strip()
if not expected_tenant and expected_area and record_area and expected_area != record_area:
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前地区")
expires_at = self._ensure_aware_datetime(record.get("expires_at"))
now = self._ensure_aware_datetime(Now or datetime.now(timezone.utc))
if expires_at <= now:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件已过期,请重新上传")
if RequireCompleted and str(record.get("indexing_status") or "") != "completed":
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件尚未完成解析,请稍后再试")
async def _run_attachment_indexing_task(
self,
*,
attachment_id: str,
tenant_code: str | None,
user_id: int,
conversation_id: str,
file_name: str,
content: bytes,
collection_name: str,
) -> None:
try:
await self._update_attachment_state(attachment_id=attachment_id, status="parsing")
page_texts = await self._extract_page_texts(FileName=file_name, Content=content)
if not page_texts:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未提取到可检索文本")
await self._update_attachment_state(attachment_id=attachment_id, status="splitting")
chunks = self.BuildChunks(
AttachmentId=attachment_id,
TenantCode=tenant_code,
UserId=user_id,
ConversationId=conversation_id,
FileName=file_name,
PageTexts=page_texts,
)
if not chunks:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未切分出可检索文本")
await self._update_attachment_state(attachment_id=attachment_id, status="indexing", chunk_count=len(chunks))
embeddings = await self.retriever._embed_texts([item["text"] for item in chunks], "")
collection = self._get_chroma().get_or_create_collection(collection_name)
collection.add(
ids=[item["id"] for item in chunks],
documents=[item["text"] for item in chunks],
embeddings=embeddings,
metadatas=[item["metadata"] for item in chunks],
)
await self._update_attachment_state(
attachment_id=attachment_id,
status="completed",
chunk_count=len(chunks),
completed=True,
)
except Exception as exc:
await self._update_attachment_state(
attachment_id=attachment_id,
status="error",
error=str(exc)[:2000],
)
async def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]:
suffix = Path(FileName).suffix.lower()
if suffix in {".txt", ".md"}:
text_value = Content.decode("utf-8", errors="ignore").strip()
return [(1, text_value)] if text_value else []
if suffix == ".json":
return self.dataset_helpers._extract_page_texts_from_json(Content)
if suffix == ".csv":
return [(1, self._csv_to_text(Content))] if self._csv_to_text(Content).strip() else []
if suffix == ".xlsx":
return self._extract_page_texts_from_xlsx(Content)
if suffix in SUPPORTED_IMAGE_SUFFIXES:
return await self._extract_page_texts_from_image(FileName, Content)
return self.dataset_helpers._extract_page_texts(FileName=FileName, Content=Content)
def _csv_to_text(self, content: bytes) -> str:
text_value = content.decode("utf-8-sig", errors="ignore")
rows: list[str] = []
for row_index, row in enumerate(csv.reader(text_value.splitlines()), start=1):
values = [str(cell).strip() for cell in row if str(cell).strip()]
if values:
rows.append(f"row_{row_index}: " + " | ".join(values))
return "\n".join(rows)
def _extract_page_texts_from_xlsx(self, content: bytes) -> list[tuple[int, str]]:
try:
from openpyxl import load_workbook
except Exception as exc: # pragma: no cover - depends on optional env package
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前环境未安装 openpyxl,暂无法解析 Excel 文件") from exc
with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as temp_file:
temp_file.write(content)
temp_path = temp_file.name
try:
workbook = load_workbook(temp_path, data_only=True, read_only=True)
page_texts: list[tuple[int, str]] = []
for sheet_index, sheet in enumerate(workbook.worksheets, start=1):
lines: list[str] = [f"Sheet: {sheet.title}"]
for row in sheet.iter_rows(values_only=True):
values = [str(cell).strip() for cell in row if cell is not None and str(cell).strip()]
if values:
lines.append(" | ".join(values))
rendered = "\n".join(lines).strip()
if rendered:
page_texts.append((sheet_index, rendered))
return page_texts
finally:
try:
os.unlink(temp_path)
except OSError:
pass
async def _extract_page_texts_from_image(self, file_name: str, content: bytes) -> list[tuple[int, str]]:
suffix = Path(file_name).suffix.lower() or ".png"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
temp_file.write(content)
temp_path = Path(temp_file.name)
try:
ocr_client = self._create_ocr_client()
result = await ocr_client.ocr(temp_path)
text_value = self._ocr_result_to_text(result)
return [(1, text_value)] if text_value else []
finally:
try:
temp_path.unlink()
except OSError:
pass
def _ocr_result_to_text(self, value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value.strip()
if isinstance(value, dict):
candidates: list[str] = []
for key in ("full_text", "text", "content", "markdown"):
if isinstance(value.get(key), str) and value.get(key).strip():
candidates.append(value[key].strip())
pages = value.get("pages")
if isinstance(pages, list):
for page in pages:
page_text = self._ocr_result_to_text(page)
if page_text:
candidates.append(page_text)
lines = value.get("lines") or value.get("blocks")
if isinstance(lines, list):
for line in lines:
line_text = self._ocr_result_to_text(line)
if line_text:
candidates.append(line_text)
return "\n".join(dict.fromkeys(candidates)).strip()
if isinstance(value, list):
return "\n".join(item for item in (self._ocr_result_to_text(item) for item in value) if item).strip()
text_attrs = []
for attr in ("full_text", "text", "content", "markdown"):
raw = getattr(value, attr, None)
if isinstance(raw, str) and raw.strip():
text_attrs.append(raw.strip())
pages = getattr(value, "pages", None)
if isinstance(pages, list):
for page in pages:
page_text = self._ocr_result_to_text(page)
if page_text:
text_attrs.append(page_text)
return "\n".join(dict.fromkeys(text_attrs)).strip()
def _create_ocr_client(self):
if self._ocr_client_factory is not None:
return self._ocr_client_factory()
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import create_ocr_client
return create_ocr_client()
async def _update_attachment_state(
self,
*,
attachment_id: str,
status: str,
chunk_count: int | None = None,
error: str | None = None,
completed: bool = False,
) -> None:
fields = [
"indexing_status = :status",
"indexing_error = :error",
"updated_at = NOW()",
"indexing_started_at = COALESCE(indexing_started_at, NOW())",
]
if chunk_count is not None:
fields.append("chunk_count = :chunk_count")
if completed:
fields.append("indexing_completed_at = NOW()")
params: dict[str, Any] = {
"attachment_id": attachment_id,
"status": status,
"error": error,
"chunk_count": chunk_count,
}
async with GetAsyncSession() as session:
await self._ensure_attachment_schema(session)
await session.execute(
text(
f"""
UPDATE rag_chat_attachment
SET {", ".join(fields)}
WHERE attachment_id = :attachment_id
"""
),
params,
)
async def _ensure_attachment_schema(self, session) -> None:
if self.__class__._attachment_schema_checked:
return
async with self.__class__._attachment_schema_lock:
if self.__class__._attachment_schema_checked:
return
await session.execute(
text(
"""
CREATE TABLE IF NOT EXISTS rag_chat_attachment (
id BIGSERIAL PRIMARY KEY,
attachment_id VARCHAR(64) NOT NULL UNIQUE,
conversation_id VARCHAR(64) NOT NULL,
user_id BIGINT NOT NULL,
tenant_code VARCHAR(64) NULL,
area VARCHAR(255) NULL,
filename VARCHAR(512) NOT NULL,
original_name VARCHAR(512) NOT NULL,
content_type VARCHAR(255) NULL,
file_size BIGINT NOT NULL DEFAULT 0,
minio_path TEXT NULL,
collection_name VARCHAR(160) NOT NULL,
indexing_status VARCHAR(32) NOT NULL DEFAULT 'waiting',
indexing_error TEXT NULL,
chunk_count INTEGER NOT NULL DEFAULT 0,
indexing_started_at TIMESTAMPTZ NULL,
indexing_completed_at TIMESTAMPTZ NULL,
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ NULL
)
"""
)
)
await session.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_scope
ON rag_chat_attachment(tenant_code, user_id, conversation_id, attachment_id)
WHERE deleted_at IS NULL
"""
)
)
await session.execute(
text(
"""
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_expires
ON rag_chat_attachment(expires_at)
WHERE deleted_at IS NULL
"""
)
)
self.__class__._attachment_schema_checked = True
def _get_chroma(self) -> Any:
return self._chroma_client or get_chroma()
def _delete_collection(self, collection_name: str) -> None:
if not collection_name:
return
try:
self._get_chroma().delete_collection(collection_name)
except Exception:
return
def _delete_oss_object(self, source: str | None) -> None:
if not source:
return
try:
oss = OssClient()
ref = oss.ResolveObjectRef(Source=source)
if ref.isDirectUrl or not ref.objectKey:
return
oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey)
except Exception:
return
def _to_vo(self, row: dict) -> RagChatAttachmentVO:
expires_at = row.get("expires_at")
created_at = row.get("created_at")
return RagChatAttachmentVO(
attachmentId=str(row.get("attachment_id") or ""),
conversationId=str(row.get("conversation_id") or ""),
fileName=str(row.get("original_name") or row.get("filename") or ""),
contentType=str(row.get("content_type") or ""),
fileSize=int(row.get("file_size") or 0),
indexingStatus=str(row.get("indexing_status") or "waiting"),
indexingError=row.get("indexing_error"),
chunkCount=int(row.get("chunk_count") or 0),
collectionName=str(row.get("collection_name") or ""),
expiresAt=self._timestamp(expires_at),
createdAt=self._timestamp(created_at),
)
def _preprocess_text(self, text_value: str) -> str:
result = text_value or ""
result = re.sub(r"[ \t]+", " ", result)
result = re.sub(r"\n{3,}", "\n\n", result)
return result.strip()
def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]:
return self.dataset_helpers._split_text(text_value, separator=separator, max_chars=max_chars, overlap=overlap)
def _slugify(self, value: str) -> str:
normalized = re.sub(r"[^A-Za-z0-9_]+", "_", str(value or "").strip())
normalized = re.sub(r"_+", "_", normalized).strip("_").lower()
return normalized or "default"
def _validate_conversation_id(self, conversation_id: str | None) -> None:
if not str(conversation_id or "").strip() or str(conversation_id or "").strip() == "-1":
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话ID不能为空")
def _ensure_aware_datetime(self, value: Any) -> datetime:
if isinstance(value, datetime):
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value
if isinstance(value, str):
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
if parsed.tzinfo is None:
return parsed.replace(tzinfo=timezone.utc)
return parsed
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件时间字段异常")
def _timestamp(self, value: Any) -> int:
if not value:
return 0
return int(self._ensure_aware_datetime(value).timestamp())
@@ -106,6 +106,8 @@ class RagChatServiceImpl(IRagChatService):
Query: str,
ConversationId: str | None,
AppId: int | None,
AttachmentId: str | None = None,
AttachmentIds: list[str] | None = None,
TenantCode: str | None = None,
TenantName: str | None = None,
) -> AsyncGenerator[bytes, None]:
@@ -128,6 +130,25 @@ class RagChatServiceImpl(IRagChatService):
messageId = str(uuid.uuid4())
taskId = str(uuid.uuid4())
is_new_conversation = not ConversationId or ConversationId == "-1"
active_attachment_ids = self._normalize_attachment_ids(
attachment_id=AttachmentId,
attachment_ids=AttachmentIds,
)
if not active_attachment_ids:
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=conversationId,
)
attachment_records = await self._load_message_attachment_records(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=conversationId,
AttachmentIds=active_attachment_ids,
)
user_message_files = self._build_user_message_files(attachment_records)
async with GetAsyncSession() as session:
async with session.begin():
@@ -135,13 +156,14 @@ class RagChatServiceImpl(IRagChatService):
text(
"""
INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, CAST(:metadata AS jsonb))
"""
),
{
"message_id": str(uuid.uuid4()),
"conversation_id": conversationId,
"content": Query,
"metadata": json.dumps({"message_files": user_message_files}, ensure_ascii=False),
},
)
await session.execute(
@@ -170,6 +192,11 @@ class RagChatServiceImpl(IRagChatService):
message_id=messageId,
query=Query,
app=app,
current_user_id=CurrentUserId,
tenant_code=TenantCode,
user_area=UserArea,
attachment_id=AttachmentId,
attachment_ids=active_attachment_ids,
)
event_index = 0
@@ -322,6 +349,7 @@ class RagChatServiceImpl(IRagChatService):
row = items[idx]
if row["role"] == "user":
answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None
user_metadata = dict(row.get("metadata") or {})
answer_metadata = dict((answer.get("metadata") if answer else None) or {})
answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running"))
answer_content = (answer.get("content") if answer else None) or ""
@@ -355,6 +383,7 @@ class RagChatServiceImpl(IRagChatService):
conversationId=ConversationId,
query=row["content"],
answer=answer_content if answer else "",
messageFiles=[item for item in (user_metadata.get("message_files") or []) if isinstance(item, dict)],
feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
retrieverResources=(answer.get("sources") if answer else None),
suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()],
@@ -940,7 +969,181 @@ class RagChatServiceImpl(IRagChatService):
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
return result.chunks, result.dataset_name
chunks = [
{
**chunk,
"source_scope": chunk.get("source_scope") or "formal_kb",
"data_source_type": chunk.get("data_source_type") or "formal_kb",
}
for chunk in result.chunks
]
return chunks, result.dataset_name
async def _resolve_attachment_id_for_conversation(
self,
*,
CurrentUserId: int,
TenantCode: str | None,
UserArea: str | None,
ConversationId: str,
) -> str | None:
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
service = RagChatAttachmentServiceImpl()
return await service.ResolveActiveAttachmentIdForConversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
)
async def _resolve_attachment_ids_for_conversation(
self,
*,
CurrentUserId: int,
TenantCode: str | None,
UserArea: str | None,
ConversationId: str,
) -> list[str]:
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
service = RagChatAttachmentServiceImpl()
if hasattr(service, "ResolveActiveAttachmentIdsForConversation"):
return await service.ResolveActiveAttachmentIdsForConversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
)
attachment_id = await service.ResolveActiveAttachmentIdForConversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
)
return [attachment_id] if attachment_id else []
async def _retrieve_attachment_context(
self,
*,
CurrentUserId: int,
TenantCode: str | None,
UserArea: str | None,
ConversationId: str,
AttachmentId: str,
Query: str,
) -> tuple[list[dict], str]:
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
service = RagChatAttachmentServiceImpl()
return await service.RetrieveAttachmentContext(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
AttachmentId=AttachmentId,
Query=Query,
TopK=5,
)
async def _load_message_attachment_records(
self,
*,
CurrentUserId: int,
TenantCode: str | None,
UserArea: str | None,
ConversationId: str,
AttachmentIds: list[str],
) -> list[dict]:
if not AttachmentIds:
return []
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
service = RagChatAttachmentServiceImpl()
records: list[dict] = []
for attachment_id in AttachmentIds:
record = await service.ValidateAttachmentForChat(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
AttachmentId=attachment_id,
)
records.append(record)
return records
def _build_user_message_files(self, attachment_records: list[dict]) -> list[dict]:
files: list[dict] = []
for record in attachment_records:
attachment_id = str(record.get("attachment_id") or "").strip()
if not attachment_id:
continue
file_name = str(record.get("original_name") or record.get("filename") or "上传文件")
content_type = str(record.get("content_type") or "")
file_type = "image" if content_type.startswith("image/") or re.search(r"\.(png|jpe?g|webp|bmp|tiff?)$", file_name, re.I) else "file"
files.append(
{
"id": attachment_id,
"upload_file_id": attachment_id,
"name": file_name,
"fileName": file_name,
"type": file_type,
"transfer_method": "local_file",
"contentType": content_type or None,
"fileSize": int(record.get("file_size") or 0),
"belongs_to": "user",
"usage": "temporary_attachment",
}
)
return files
def _build_formal_kb_query(self, *, query: str, attachment_chunks: list[dict]) -> str:
if not attachment_chunks:
return query
facts: list[str] = []
for chunk in attachment_chunks[:3]:
text_value = str(chunk.get("text") or "").strip()
if text_value:
facts.append(text_value[:500])
if not facts:
return query
return (
f"{query}\n\n"
"用户上传文档中检索到的相关事实:\n"
+ "\n".join(f"- {item}" for item in facts)
+ "\n\n请检索这些事实对应的法律责任、处罚依据、裁量规则或案例。"
)
def _merge_context_chunks(self, *, attachment_chunks: list[dict], formal_chunks: list[dict]) -> list[dict]:
merged: list[dict] = []
for chunk in attachment_chunks:
merged.append(
{
**chunk,
"source_scope": "chat_attachment",
"data_source_type": "chat_attachment",
}
)
for chunk in formal_chunks:
merged.append(
{
**chunk,
"source_scope": chunk.get("source_scope") or "formal_kb",
"data_source_type": chunk.get("data_source_type") or "formal_kb",
}
)
return merged
def _normalize_attachment_ids(self, *, attachment_id: str | None, attachment_ids: list[str] | None) -> list[str]:
normalized: list[str] = []
for raw in [*(attachment_ids or []), attachment_id]:
value = str(raw or "").strip()
if value and value not in normalized:
normalized.append(value)
return normalized
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
return await self.retriever._embed_texts(texts, model_name)
@@ -953,6 +1156,11 @@ class RagChatServiceImpl(IRagChatService):
message_id: str,
query: str,
app: dict,
current_user_id: int | None = None,
tenant_code: str | None = None,
user_area: str | None = None,
attachment_id: str | None = None,
attachment_ids: list[str] | None = None,
) -> None:
self._task_events[task_id] = []
self._task_done[task_id] = False
@@ -964,6 +1172,11 @@ class RagChatServiceImpl(IRagChatService):
message_id=message_id,
query=query,
app=app,
current_user_id=current_user_id,
tenant_code=tenant_code,
user_area=user_area,
attachment_id=attachment_id,
attachment_ids=attachment_ids,
)
)
self._message_tasks[task_id] = task
@@ -976,6 +1189,11 @@ class RagChatServiceImpl(IRagChatService):
message_id: str,
query: str,
app: dict,
current_user_id: int | None = None,
tenant_code: str | None = None,
user_area: str | None = None,
attachment_id: str | None = None,
attachment_ids: list[str] | None = None,
) -> None:
context_chunks: list[dict] = []
dataset_name = ""
@@ -987,7 +1205,46 @@ class RagChatServiceImpl(IRagChatService):
last_persisted_at = time.monotonic()
try:
context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query)
attachment_chunks: list[dict] = []
attachment_names: list[str] = []
active_attachment_ids = self._normalize_attachment_ids(
attachment_id=attachment_id,
attachment_ids=attachment_ids,
)
if not active_attachment_ids and current_user_id is not None:
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
CurrentUserId=current_user_id,
TenantCode=tenant_code,
UserArea=user_area,
ConversationId=conversation_id,
)
if active_attachment_ids:
if current_user_id is None:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件缺少用户上下文")
for active_attachment_id in active_attachment_ids:
next_chunks, next_attachment_name = await self._retrieve_attachment_context(
CurrentUserId=current_user_id,
TenantCode=tenant_code,
UserArea=user_area,
ConversationId=conversation_id,
AttachmentId=active_attachment_id,
Query=query,
)
attachment_chunks.extend(next_chunks)
if next_attachment_name:
attachment_names.append(next_attachment_name)
legal_query = self._build_formal_kb_query(query=query, attachment_chunks=attachment_chunks)
formal_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), legal_query)
context_chunks = self._merge_context_chunks(
attachment_chunks=attachment_chunks,
formal_chunks=formal_chunks,
)
generation_dataset_name = dataset_name
attachment_name = "".join(dict.fromkeys(attachment_names))
if attachment_name and dataset_name:
generation_dataset_name = f"{attachment_name} + {dataset_name}"
elif attachment_name:
generation_dataset_name = attachment_name
async for chunk in generate_stream(
query=query,
context_chunks=context_chunks,
@@ -997,7 +1254,7 @@ class RagChatServiceImpl(IRagChatService):
model=app.get("llm_model") or "",
temperature=app.get("temperature"),
max_tokens=app.get("max_tokens"),
dataset_name=dataset_name,
dataset_name=generation_dataset_name,
task_id=task_id,
):
data = self._parse_sse_event(chunk)
@@ -323,6 +323,18 @@ class RbacAdminServiceImpl(IRbacAdminService):
"is_cache": True,
"meta": {"group": "settings"},
},
{
"route_path": "/rule-groups",
"route_name": "rule-groups",
"component": "rule-groups",
"route_title": "评查点分组",
"icon": "ri-node-tree",
"sort_order": 6,
"parent_path": "/settings",
"is_hidden": False,
"is_cache": True,
"meta": {"group": "settings"},
},
]
_MANAGEABLE_PERMISSION_BLUEPRINTS: list[dict[str, Any]] = [
@@ -332,10 +344,10 @@ class RbacAdminServiceImpl(IRbacAdminService):
{"permission_key": "entry_module:update:write", "display_name": "更新入口模块", "module": "entry_module", "resource": "update", "action": "write", "api_method": "PUT", "api_path": "/api/v3/entry-modules/{id}", "route_path": "/entry-modules"},
{"permission_key": "entry_module:delete:delete", "display_name": "删除入口模块", "module": "entry_module", "resource": "delete", "action": "delete", "api_method": "DELETE", "api_path": "/api/v3/entry-modules/{id}", "route_path": "/entry-modules"},
{"permission_key": "entry_module:image:write", "display_name": "上传入口模块图标", "module": "entry_module", "resource": "image", "action": "write", "api_method": "POST", "api_path": "/api/v3/entry-modules/{id}/image", "route_path": "/entry-modules"},
{"permission_key": "doc_type:list:read", "display_name": "文档类型列表", "module": "doc_type", "resource": "list", "action": "read", "api_method": "GET", "api_path": "/api/document-types", "route_path": "/document-types"},
{"permission_key": "doc_type:detail:read", "display_name": "文档类型详情", "module": "doc_type", "resource": "detail", "action": "read", "api_method": "GET", "api_path": "/api/document-types/{id}", "route_path": "/document-types"},
{"permission_key": "doc_type:create:write", "display_name": "创建文档类型", "module": "doc_type", "resource": "create", "action": "write", "api_method": "POST", "api_path": "/api/document-types", "route_path": "/document-types"},
{"permission_key": "doc_type:update:write", "display_name": "更新文档类型", "module": "doc_type", "resource": "update", "action": "write", "api_method": "PUT", "api_path": "/api/document-types/{id}", "route_path": "/document-types"},
{"permission_key": "doc_type:list:read", "display_name": "业务大类列表", "module": "doc_type", "resource": "list", "action": "read", "api_method": "GET", "api_path": "/api/v3/document-type-roots", "route_path": "/document-types"},
{"permission_key": "doc_type:detail:read", "display_name": "业务大类详情", "module": "doc_type", "resource": "detail", "action": "read", "api_method": "GET", "api_path": "/api/v3/document-type-roots/{id}", "route_path": "/document-types"},
{"permission_key": "doc_type:create:write", "display_name": "创建业务大类", "module": "doc_type", "resource": "create", "action": "write", "api_method": "POST", "api_path": "/api/v3/document-type-roots", "route_path": "/document-types"},
{"permission_key": "doc_type:update:write", "display_name": "更新业务大类", "module": "doc_type", "resource": "update", "action": "write", "api_method": "PUT", "api_path": "/api/v3/document-type-roots/{id}", "route_path": "/document-types"},
{"permission_key": "doc_type:delete:delete", "display_name": "删除文档类型", "module": "doc_type", "resource": "delete", "action": "delete", "api_method": "DELETE", "api_path": "/api/document-types/{id}", "route_path": "/document-types"},
{"permission_key": "rbac:tenants:read", "display_name": "查看租户列表", "module": "rbac", "resource": "tenants", "action": "read", "api_method": "GET", "api_path": "/api/v3/tenants", "route_path": "/tenants"},
{"permission_key": "rbac:tenants:create", "display_name": "创建租户", "module": "rbac", "resource": "tenants", "action": "create", "api_method": "POST", "api_path": "/api/v3/tenants", "route_path": "/tenants"},
@@ -26,6 +26,7 @@ class RbacServiceImpl(IRbacService):
"/files",
"/documents",
"/rules",
"/rule-groups",
"/rules-files",
"/settings",
"/entry-modules",
@@ -322,6 +323,20 @@ class RbacServiceImpl(IRbacService):
"meta": {"group": "settings"},
"children": None,
},
{
"id": 1019,
"route_path": "/rule-groups",
"route_name": "rule-groups",
"component": "rule-groups",
"parent_id": 1013,
"route_title": "评查点分组",
"icon": "ri-node-tree",
"sort_order": 6,
"is_hidden": False,
"is_cache": True,
"meta": {"group": "settings"},
"children": None,
},
],
},
{
@@ -738,7 +753,8 @@ class RbacServiceImpl(IRbacService):
databaseRoutes = await self._loadDatabaseRoutes(Session, roleIds, grantedPermissions)
if self._isFrontendRouteSetReady(databaseRoutes):
routes = self._filterRoutesByMinimalScope(databaseRoutes)
grantedRoutePaths = self._collectCurrentFrontendRoutePaths(databaseRoutes)
routes = self._filterRoutesByRouteAndPermissionScope(databaseRoutes, grantedRoutePaths, grantedPermissions)
else:
routes = self._buildCompatibilityRoutes(roleKeys, grantedPermissions)
@@ -872,6 +888,30 @@ class RbacServiceImpl(IRbacService):
filtered.append(routeCopy)
return filtered
def _filterRoutesByRouteAndPermissionScope(
self,
Routes: list[RbacRouteVO],
GrantedRoutePaths: set[str],
GrantedPermissions: set[str],
) -> list[RbacRouteVO]:
"""按角色已勾选路由裁剪,接口权限不能替代子路由勾选。"""
filtered: list[RbacRouteVO] = []
for route in Routes:
if not self._isRoutePathEnabled(route.route_path):
continue
if route.route_path not in GrantedRoutePaths:
continue
routeCopy = route.model_copy(deep=True)
routeCopy.permissions = self._resolvePermissionsForPath(route.route_path, GrantedPermissions)
routeCopy.children = self._filterRoutesByRouteAndPermissionScope(
route.children or [],
GrantedRoutePaths,
GrantedPermissions,
) or None
filtered.append(routeCopy)
return filtered
def _filterBlueprintsByMinimalScope(self, Blueprints: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""按当前最小可用范围裁剪兼容蓝图。"""
filtered: list[dict[str, Any]] = []
@@ -953,6 +993,14 @@ class RbacServiceImpl(IRbacService):
paths.update(self._collectRoutePaths(route.children))
return paths
def _collectCurrentFrontendRoutePaths(self, Routes: list[RbacRouteVO]) -> set[str]:
"""收集当前前端真实路由,旧 govdoc-audit 残留授权不映射成新版子路由。"""
return {
path
for path in self._collectRoutePaths(Routes)
if not path.startswith("/govdoc-audit/")
}
@staticmethod
def _normalizeMeta(Meta: Any) -> dict | None:
"""兼容 meta 为 JSON 字符串、字典或空值的情况。"""
@@ -0,0 +1,89 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
RagChatAttachmentDeleteVO,
RagChatAttachmentVO,
)
class IRagChatAttachmentService(ABC):
@abstractmethod
async def CreateAttachment(
self,
CurrentUserId: int,
UserArea: str | None,
UserRole: str | None,
TenantCode: str | None,
TenantName: str | None,
ConversationId: str | None,
AppId: int | None,
FileName: str,
ContentType: str | None,
Content: bytes,
) -> RagChatAttachmentVO: ...
@abstractmethod
async def GetAttachment(
self,
CurrentUserId: int,
UserArea: str | None,
UserRole: str | None,
TenantCode: str | None,
TenantName: str | None,
ConversationId: str,
AttachmentId: str,
) -> RagChatAttachmentVO: ...
@abstractmethod
async def DeleteAttachment(
self,
CurrentUserId: int,
UserArea: str | None,
UserRole: str | None,
TenantCode: str | None,
TenantName: str | None,
ConversationId: str,
AttachmentId: str,
) -> RagChatAttachmentDeleteVO: ...
@abstractmethod
async def ValidateAttachmentForChat(
self,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
AttachmentId: str,
UserArea: str | None = None,
) -> dict: ...
@abstractmethod
async def RetrieveAttachmentContext(
self,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
AttachmentId: str,
Query: str,
TopK: int = 5,
UserArea: str | None = None,
) -> tuple[list[dict], str]: ...
@abstractmethod
async def ResolveActiveAttachmentIdForConversation(
self,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
UserArea: str | None = None,
) -> str | None: ...
@abstractmethod
async def ResolveActiveAttachmentIdsForConversation(
self,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
UserArea: str | None = None,
) -> list[str]: ...
@@ -50,6 +50,8 @@ class IRagChatService(ABC):
Query: str,
ConversationId: str | None,
AppId: int | None,
AttachmentId: str | None = None,
AttachmentIds: list[str] | None = None,
TenantCode: str | None = None,
TenantName: str | None = None,
) -> AsyncGenerator[bytes, None]: ...
Submodule legal-platform-frontend updated: 469de25dc8...88476a11bc
+1
View File
@@ -22,6 +22,7 @@ dependencies = [
"pyjwt>=2.10.0",
"openai>=1.30.0",
"pillow>=11.0.0",
"openpyxl>=3.1.0",
"pyyaml>=6.0",
"minio>=7.2.8",
"leaudit",
@@ -0,0 +1,31 @@
CREATE TABLE IF NOT EXISTS rag_chat_attachment (
id BIGSERIAL PRIMARY KEY,
attachment_id VARCHAR(64) NOT NULL UNIQUE,
conversation_id VARCHAR(64) NOT NULL,
user_id BIGINT NOT NULL,
tenant_code VARCHAR(64) NULL,
area VARCHAR(255) NULL,
filename VARCHAR(512) NOT NULL,
original_name VARCHAR(512) NOT NULL,
content_type VARCHAR(255) NULL,
file_size BIGINT NOT NULL DEFAULT 0,
minio_path TEXT NULL,
collection_name VARCHAR(160) NOT NULL,
indexing_status VARCHAR(32) NOT NULL DEFAULT 'waiting',
indexing_error TEXT NULL,
chunk_count INTEGER NOT NULL DEFAULT 0,
indexing_started_at TIMESTAMPTZ NULL,
indexing_completed_at TIMESTAMPTZ NULL,
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ NULL
);
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_scope
ON rag_chat_attachment(tenant_code, user_id, conversation_id, attachment_id)
WHERE deleted_at IS NULL;
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_expires
ON rag_chat_attachment(expires_at)
WHERE deleted_at IS NULL;
+169
View File
@@ -0,0 +1,169 @@
"""文档类型权限控制测试。"""
import pytest
from starlette.responses import JSONResponse
from fastapi_modules.fastapi_leaudit.controllers.documentController import DocumentController
from fastapi_modules.fastapi_leaudit.domian.vo.documentVo import (
DocumentTypeCreateDTO,
DocumentTypeItemVO,
DocumentTypeRootCreateDTO,
DocumentTypeRootItemVO,
DocumentTypeRootUpdateDTO,
DocumentTypeUpdateDTO,
)
class _DenyPermissionService:
"""拒绝所有权限的测试权限服务。"""
async def CheckPermission(self, UserId: int, PermissionKey: str) -> bool:
return False
class _AllowOnlyPermissionService:
"""只允许指定权限的测试权限服务。"""
def __init__(self, allowed: set[str]) -> None:
self.allowed = allowed
async def CheckPermission(self, UserId: int, PermissionKey: str) -> bool:
return PermissionKey in self.allowed
class _FakeDocumentService:
"""记录调用的测试文档服务。"""
def __init__(self) -> None:
self.calls: list[str] = []
async def ListDocumentTypes(self, **kwargs):
self.calls.append("ListDocumentTypes")
return [
DocumentTypeItemVO(
id=1,
name="合同",
code="contract",
description=None,
entryModuleId=None,
isEnabled=True,
ruleSetIds=[],
)
]
async def GetDocumentType(self, **kwargs):
self.calls.append("GetDocumentType")
return DocumentTypeItemVO(
id=1,
name="合同",
code="contract",
description=None,
entryModuleId=None,
isEnabled=True,
ruleSetIds=[],
)
async def CreateDocumentType(self, **kwargs):
self.calls.append("CreateDocumentType")
return await self.GetDocumentType()
async def UpdateDocumentType(self, **kwargs):
self.calls.append("UpdateDocumentType")
return await self.GetDocumentType()
async def DeleteDocumentType(self, **kwargs):
self.calls.append("DeleteDocumentType")
async def ListDocumentTypeRoots(self, **kwargs):
self.calls.append("ListDocumentTypeRoots")
return [
DocumentTypeRootItemVO(
id=11,
name="合同",
code="root.contract",
description=None,
entryModuleId=None,
entryModuleName=None,
isEnabled=True,
childGroupCount=0,
ruleSetCount=0,
ruleSetIds=[],
)
]
async def GetDocumentTypeRoot(self, **kwargs):
self.calls.append("GetDocumentTypeRoot")
return DocumentTypeRootItemVO(
id=11,
name="合同",
code="root.contract",
description=None,
entryModuleId=None,
entryModuleName=None,
isEnabled=True,
childGroupCount=0,
ruleSetCount=0,
ruleSetIds=[],
)
async def CreateDocumentTypeRoot(self, **kwargs):
self.calls.append("CreateDocumentTypeRoot")
return await self.GetDocumentTypeRoot()
async def UpdateDocumentTypeRoot(self, **kwargs):
self.calls.append("UpdateDocumentTypeRoot")
return await self.GetDocumentTypeRoot()
def _find_endpoint(controller: DocumentController, path: str, method: str):
"""根据路径和方法查找路由 endpoint。"""
full_path = f"{controller.router.prefix}{path}"
for route in controller.router.routes:
if getattr(route, "path", "") == full_path and method in getattr(route, "methods", set()):
return route.endpoint
raise AssertionError(f"未找到路由 {method} {full_path}")
@pytest.mark.asyncio
@pytest.mark.parametrize(
("path", "method", "kwargs", "expected_call"),
[
("/document-types", "GET", {"ids": None, "entry_module_id": None}, "ListDocumentTypes"),
("/document-types/{TypeId}", "GET", {"TypeId": 1}, "GetDocumentType"),
("/document-types", "POST", {"Body": DocumentTypeCreateDTO(code="contract", name="合同")}, "CreateDocumentType"),
("/document-types/{TypeId}", "PUT", {"TypeId": 1, "Body": DocumentTypeUpdateDTO(name="合同")}, "UpdateDocumentType"),
("/document-types/{TypeId}", "DELETE", {"TypeId": 1}, "DeleteDocumentType"),
("/v3/document-type-roots", "GET", {"entry_module_id": None}, "ListDocumentTypeRoots"),
("/v3/document-type-roots/{RootId}", "GET", {"RootId": 11}, "GetDocumentTypeRoot"),
("/v3/document-type-roots", "POST", {"Body": DocumentTypeRootCreateDTO(code="root.contract", name="合同")}, "CreateDocumentTypeRoot"),
("/v3/document-type-roots/{RootId}", "PUT", {"RootId": 11, "Body": DocumentTypeRootUpdateDTO(name="合同")}, "UpdateDocumentTypeRoot"),
],
)
async def test_document_type_endpoints_require_permission(path, method, kwargs, expected_call):
"""文档类型和业务大类接口无权限时返回 403,且不调用业务服务。"""
controller = DocumentController()
service = _FakeDocumentService()
controller.DocumentService = service
controller.PermissionService = _DenyPermissionService()
endpoint = _find_endpoint(controller, path, method)
response = await endpoint(**kwargs, payload={"user_id": 7})
assert isinstance(response, JSONResponse)
assert response.status_code == 403
assert expected_call not in service.calls
@pytest.mark.asyncio
async def test_document_type_root_list_calls_service_when_permission_granted():
"""业务大类列表有查看权限时正常调用业务服务。"""
controller = DocumentController()
service = _FakeDocumentService()
controller.DocumentService = service
controller.PermissionService = _AllowOnlyPermissionService({"doc_type:list:read"})
endpoint = _find_endpoint(controller, "/v3/document-type-roots", "GET")
response = await endpoint(entry_module_id=None, payload={"user_id": 7})
assert response.data[0].id == 11
assert service.calls == ["ListDocumentTypeRoots"]
+88
View File
@@ -145,3 +145,91 @@ def test_govdoc_root_route_marks_frontend_route_set_ready():
]
assert service._isFrontendRouteSetReady(routes) is True
def test_govdoc_parent_route_does_not_expose_ungranted_child_routes():
"""只有内部公文父路由和接口权限时,不应补出未勾选的列表/上传子路由。"""
service = RbacServiceImpl()
routes = [
RbacRouteVO(
id=1,
route_path="/govdoc",
route_name="govdoc",
component="govdoc",
parent_id=None,
route_title="内部公文处理",
children=[
RbacRouteVO(
id=2,
route_path="/govdoc/audits",
route_name="govdoc-audits",
component="govdoc.audits",
parent_id=1,
route_title="公文列表",
),
RbacRouteVO(
id=3,
route_path="/govdoc/upload",
route_name="govdoc-upload",
component="govdoc.upload",
parent_id=1,
route_title="公文上传",
),
],
)
]
filtered = service._filterRoutesByRouteAndPermissionScope(
routes,
{"/govdoc"},
{"govdoc:document:read", "govdoc:document:create"},
)
paths = service._collectRoutePaths(filtered)
assert "/govdoc" in paths
assert "/govdoc/audits" not in paths
assert "/govdoc/upload" not in paths
def test_legacy_govdoc_audit_route_does_not_grant_current_govdoc_child_route():
"""旧 /govdoc-audit 残留授权不应继续放行当前 /govdoc 子路由。"""
service = RbacServiceImpl()
routes = [
RbacRouteVO(
id=1,
route_path="/govdoc",
route_name="govdoc",
component="govdoc",
parent_id=None,
route_title="内部公文处理",
),
RbacRouteVO(
id=2,
route_path="/govdoc-audit/audits",
route_name="legacy-govdoc-audits",
component="govdoc-audit.audits",
parent_id=None,
route_title="旧公文列表",
),
RbacRouteVO(
id=3,
route_path="/govdoc-audit/upload",
route_name="legacy-govdoc-upload",
component="govdoc-audit.upload",
parent_id=None,
route_title="旧公文上传",
),
]
filtered = service._filterRoutesByRouteAndPermissionScope(
routes,
service._collectCurrentFrontendRoutePaths(routes),
{"govdoc:document:read", "govdoc:document:create"},
)
paths = service._collectRoutePaths(filtered)
assert "/govdoc" in paths
assert "/govdoc-audit/audits" not in paths
assert "/govdoc-audit/upload" not in paths
assert "/govdoc/audits" not in paths
assert "/govdoc/upload" not in paths
+372
View File
@@ -0,0 +1,372 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timedelta, timezone
import pytest
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
def _service() -> RagChatAttachmentServiceImpl:
return RagChatAttachmentServiceImpl(chroma_client=None, embed_texts=lambda texts, model_name="": [[0.1] for _ in texts])
def test_default_expiry_is_seven_days_from_now():
service = _service()
now = datetime(2026, 5, 25, 12, 0, 0, tzinfo=timezone.utc)
expires_at = service._default_expires_at(now)
assert expires_at == now + timedelta(days=7)
def test_collection_name_contains_isolation_components():
service = _service()
collection_name = service.BuildCollectionName(
TenantCode="gd-tobacco",
UserId=42,
ConversationId="conversation-abc-123",
AttachmentId="attach-xyz",
)
assert collection_name.startswith("chat_attachment_gd_tobacco_42_")
assert collection_name.endswith("_attach_xyz")
assert "conversation-abc-123" not in collection_name
assert len(collection_name) <= 120
def test_validate_attachment_scope_rejects_other_user_and_conversation():
service = _service()
record = {
"attachment_id": "att-1",
"tenant_code": "tenant-a",
"user_id": 100,
"conversation_id": "conv-a",
"indexing_status": "completed",
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
"deleted_at": None,
}
with pytest.raises(LeauditException) as user_exc:
service._assert_attachment_scope(
record,
CurrentUserId=101,
TenantCode="tenant-a",
ConversationId="conv-a",
RequireCompleted=True,
Now=datetime.now(timezone.utc),
)
assert user_exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
with pytest.raises(LeauditException) as conversation_exc:
service._assert_attachment_scope(
record,
CurrentUserId=100,
TenantCode="tenant-a",
ConversationId="conv-b",
RequireCompleted=True,
Now=datetime.now(timezone.utc),
)
assert conversation_exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
def test_get_attachment_requires_request_conversation_when_provided():
service = _service()
record = {
"attachment_id": "att-1",
"tenant_code": "tenant-a",
"user_id": 100,
"conversation_id": "conv-a",
"indexing_status": "completed",
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
"deleted_at": None,
}
with pytest.raises(LeauditException) as exc:
service._assert_attachment_scope(
record,
CurrentUserId=100,
TenantCode="tenant-a",
ConversationId="conv-b",
RequireCompleted=False,
Now=datetime.now(timezone.utc),
)
assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
assert "当前会话" in exc.value.message
def test_get_attachment_rejects_same_user_attachment_from_another_conversation(monkeypatch):
service = _service()
record = {
"attachment_id": "att-1",
"tenant_code": "tenant-a",
"user_id": 100,
"conversation_id": "conv-a",
"indexing_status": "completed",
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
"deleted_at": None,
}
async def fake_get_attachment_record(_attachment_id):
return record
monkeypatch.setattr(service, "_get_attachment_record", fake_get_attachment_record)
with pytest.raises(LeauditException) as exc:
asyncio.run(
service.GetAttachment(
CurrentUserId=100,
UserArea=None,
UserRole=None,
TenantCode="tenant-a",
TenantName=None,
ConversationId="conv-b",
AttachmentId="att-1",
)
)
assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
def test_delete_attachment_rejects_same_user_attachment_from_another_conversation(monkeypatch):
service = _service()
record = {
"attachment_id": "att-1",
"tenant_code": "tenant-a",
"user_id": 100,
"conversation_id": "conv-a",
"indexing_status": "completed",
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
"deleted_at": None,
}
async def fake_get_attachment_record(_attachment_id):
return record
monkeypatch.setattr(service, "_get_attachment_record", fake_get_attachment_record)
with pytest.raises(LeauditException) as exc:
asyncio.run(
service.DeleteAttachment(
CurrentUserId=100,
UserArea=None,
UserRole=None,
TenantCode="tenant-a",
TenantName=None,
ConversationId="conv-b",
AttachmentId="att-1",
)
)
assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
def test_validate_attachment_scope_rejects_other_tenant():
service = _service()
record = {
"attachment_id": "att-1",
"tenant_code": "tenant-a",
"user_id": 100,
"conversation_id": "conv-a",
"indexing_status": "completed",
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
"deleted_at": None,
}
with pytest.raises(LeauditException) as exc:
service._assert_attachment_scope(
record,
CurrentUserId=100,
TenantCode="tenant-b",
ConversationId="conv-a",
RequireCompleted=True,
Now=datetime.now(timezone.utc),
)
assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
def test_validate_attachment_scope_rejects_expired_or_incomplete_attachment():
service = _service()
expired = {
"attachment_id": "att-1",
"tenant_code": "tenant-a",
"user_id": 100,
"conversation_id": "conv-a",
"indexing_status": "completed",
"expires_at": datetime.now(timezone.utc) - timedelta(seconds=1),
"deleted_at": None,
}
waiting = {
**expired,
"indexing_status": "indexing",
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
}
with pytest.raises(LeauditException) as expired_exc:
service._assert_attachment_scope(
expired,
CurrentUserId=100,
TenantCode="tenant-a",
ConversationId="conv-a",
RequireCompleted=True,
Now=datetime.now(timezone.utc),
)
assert expired_exc.value.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
assert "已过期" in expired_exc.value.message
with pytest.raises(LeauditException) as waiting_exc:
service._assert_attachment_scope(
waiting,
CurrentUserId=100,
TenantCode="tenant-a",
ConversationId="conv-a",
RequireCompleted=True,
Now=datetime.now(timezone.utc),
)
assert waiting_exc.value.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
assert "尚未完成" in waiting_exc.value.message
def test_build_chunks_includes_isolation_metadata():
service = _service()
chunks = service.BuildChunks(
AttachmentId="att-1",
TenantCode="tenant-a",
UserId=100,
ConversationId="conv-a",
FileName="处罚材料.docx",
PageTexts=[(1, "第一段违法事实。\n\n第二段处罚线索。")],
ChunkMaxSize=20,
ChunkOverlap=0,
)
assert chunks
metadata = chunks[0]["metadata"]
assert metadata["tenant_code"] == "tenant-a"
assert metadata["user_id"] == "100"
assert metadata["conversation_id"] == "conv-a"
assert metadata["attachment_id"] == "att-1"
assert metadata["source_scope"] == "chat_attachment"
assert metadata["document_name"] == "处罚材料.docx"
assert metadata["page"] == 1
def test_resolve_active_attachment_id_uses_user_conversation_tenant_and_completed_state(monkeypatch):
service = _service()
captured_sql = {}
captured_params = {}
class FakeResult:
def mappings(self):
return self
def first(self):
return {"attachment_id": "att-active"}
def all(self):
return [{"attachment_id": "att-active"}]
class FakeSession:
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def execute(self, statement, params=None):
captured_sql["value"] = str(statement)
captured_params.update(params or {})
return FakeResult()
class FakeSessionFactory:
def __call__(self):
return FakeSession()
service.__class__._attachment_schema_checked = True
monkeypatch.setattr(
"fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl.GetAsyncSession",
FakeSessionFactory(),
)
attachment_id = asyncio.run(
service.ResolveActiveAttachmentIdForConversation(
CurrentUserId=100,
TenantCode="tenant-a",
UserArea="云浮",
ConversationId="conv-a",
)
)
assert attachment_id == "att-active"
assert "user_id = :user_id" in captured_sql["value"]
assert "conversation_id = :conversation_id" in captured_sql["value"]
assert "indexing_status = 'completed'" in captured_sql["value"]
assert "expires_at > NOW()" in captured_sql["value"]
assert "deleted_at IS NULL" in captured_sql["value"]
assert captured_params == {
"user_id": 100,
"conversation_id": "conv-a",
"tenant_code": "tenant-a",
"user_area": "云浮",
}
def test_resolve_active_attachment_ids_returns_all_completed_conversation_attachments(monkeypatch):
service = _service()
captured_sql = {}
captured_params = {}
class FakeResult:
def mappings(self):
return self
def all(self):
return [{"attachment_id": "att-1"}, {"attachment_id": "att-2"}]
class FakeSession:
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
return None
async def execute(self, statement, params=None):
captured_sql["value"] = str(statement)
captured_params.update(params or {})
return FakeResult()
class FakeSessionFactory:
def __call__(self):
return FakeSession()
service.__class__._attachment_schema_checked = True
monkeypatch.setattr(
"fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl.GetAsyncSession",
FakeSessionFactory(),
)
attachment_ids = asyncio.run(
service.ResolveActiveAttachmentIdsForConversation(
CurrentUserId=100,
TenantCode="tenant-a",
UserArea="云浮",
ConversationId="conv-a",
)
)
assert attachment_ids == ["att-1", "att-2"]
assert "ORDER BY indexing_completed_at DESC NULLS LAST, created_at DESC" in captured_sql["value"]
assert "LIMIT 1" not in captured_sql["value"]
assert captured_params == {
"user_id": 100,
"conversation_id": "conv-a",
"tenant_code": "tenant-a",
"user_area": "云浮",
}
+310
View File
@@ -79,6 +79,232 @@ async def _run_streaming_task() -> list[dict]:
return service._task_events[task_id]
async def _run_streaming_task_with_attachment() -> tuple[list[dict], list[dict]]:
service = RagChatServiceImpl()
task_id = "task-attachment-test"
service._task_events[task_id] = []
service._task_done[task_id] = False
service._task_locks[task_id] = asyncio.Lock()
captured_context_chunks: list[dict] = []
async def fake_retrieve_context(dataset_id, query):
return [
{
"dataset_id": dataset_id,
"document_id": "law-doc-1",
"document_name": "处罚依据.md",
"source": "处罚依据.md",
"id": "law-segment-1",
"score": 0.88,
"hit_count": 2,
"text": "正式知识库中的处罚依据",
}
], "正式法规知识库"
async def fake_retrieve_attachment_context(**kwargs):
return [
{
"attachment_id": kwargs["AttachmentId"],
"document_name": "用户上传.docx",
"source": "用户上传.docx",
"id": "attachment-segment-1",
"score": 0.96,
"text": "上传文档中的违法事实",
"source_scope": "chat_attachment",
"data_source_type": "chat_attachment",
}
], "用户上传.docx"
async def fake_generate_stream(**kwargs):
captured_context_chunks.extend(kwargs["context_chunks"])
async for event in _fake_generate_stream(**kwargs):
yield event
async def noop_async(*args, **kwargs):
return None
service._retrieve_context = fake_retrieve_context
service._retrieve_attachment_context = fake_retrieve_attachment_context
service._finalize_message_record = noop_async
service._maybe_schedule_auto_title = noop_async
with (
patch(
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
fake_generate_stream,
),
patch(
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
return_value=[],
),
):
await service._run_message_task(
task_id=task_id,
conversation_id="conversation-test",
message_id="message-test",
query="这份材料的违法内容会受到什么处罚",
app={"dataset_id": 7},
current_user_id=100,
tenant_code="tenant-a",
attachment_id="attachment-1",
)
return service._task_events[task_id], captured_context_chunks
async def _run_streaming_task_with_conversation_attachment_fallback() -> list[dict]:
service = RagChatServiceImpl()
task_id = "task-attachment-fallback-test"
service._task_events[task_id] = []
service._task_done[task_id] = False
service._task_locks[task_id] = asyncio.Lock()
captured_context_chunks: list[dict] = []
async def fake_resolve_attachment_ids_for_conversation(**kwargs):
assert kwargs["ConversationId"] == "conversation-test"
assert kwargs["CurrentUserId"] == 100
return ["attachment-1"]
async def fake_retrieve_context(dataset_id, query):
return [
{
"dataset_id": dataset_id,
"document_id": "law-doc-1",
"document_name": "处罚依据.md",
"source": "处罚依据.md",
"id": "law-segment-1",
"score": 0.88,
"hit_count": 2,
"text": "正式知识库中的处罚依据",
}
], "正式法规知识库"
async def fake_retrieve_attachment_context(**kwargs):
return [
{
"attachment_id": kwargs["AttachmentId"],
"document_name": "用户上传.docx",
"source": "用户上传.docx",
"id": "attachment-segment-1",
"score": 0.96,
"text": "上传文档中的违法事实",
"source_scope": "chat_attachment",
"data_source_type": "chat_attachment",
}
], "用户上传.docx"
async def fake_generate_stream(**kwargs):
captured_context_chunks.extend(kwargs["context_chunks"])
async for event in _fake_generate_stream(**kwargs):
yield event
async def noop_async(*args, **kwargs):
return None
service._resolve_attachment_ids_for_conversation = fake_resolve_attachment_ids_for_conversation
service._retrieve_context = fake_retrieve_context
service._retrieve_attachment_context = fake_retrieve_attachment_context
service._finalize_message_record = noop_async
service._maybe_schedule_auto_title = noop_async
with (
patch(
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
fake_generate_stream,
),
patch(
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
return_value=[],
),
):
await service._run_message_task(
task_id=task_id,
conversation_id="conversation-test",
message_id="message-test",
query="江小妹违法了什么",
app={"dataset_id": 7},
current_user_id=100,
tenant_code="tenant-a",
attachment_id=None,
)
return captured_context_chunks
async def _run_streaming_task_with_multiple_attachments() -> tuple[list[dict], list[dict]]:
service = RagChatServiceImpl()
task_id = "task-multiple-attachments-test"
service._task_events[task_id] = []
service._task_done[task_id] = False
service._task_locks[task_id] = asyncio.Lock()
captured_context_chunks: list[dict] = []
async def fake_retrieve_context(dataset_id, query):
return [
{
"dataset_id": dataset_id,
"document_id": "law-doc-1",
"document_name": "处罚依据.md",
"source": "处罚依据.md",
"id": "law-segment-1",
"score": 0.88,
"hit_count": 2,
"text": "正式知识库中的处罚依据",
}
], "正式法规知识库"
async def fake_retrieve_attachment_context(**kwargs):
attachment_id = kwargs["AttachmentId"]
return [
{
"attachment_id": attachment_id,
"document_name": f"{attachment_id}.docx",
"source": f"{attachment_id}.docx",
"id": f"{attachment_id}-segment-1",
"score": 0.96,
"text": f"{attachment_id} 中的违法事实",
"source_scope": "chat_attachment",
"data_source_type": "chat_attachment",
}
], f"{attachment_id}.docx"
async def fake_generate_stream(**kwargs):
captured_context_chunks.extend(kwargs["context_chunks"])
async for event in _fake_generate_stream(**kwargs):
yield event
async def noop_async(*args, **kwargs):
return None
service._retrieve_context = fake_retrieve_context
service._retrieve_attachment_context = fake_retrieve_attachment_context
service._finalize_message_record = noop_async
service._maybe_schedule_auto_title = noop_async
with (
patch(
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
fake_generate_stream,
),
patch(
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
return_value=[],
),
):
await service._run_message_task(
task_id=task_id,
conversation_id="conversation-test",
message_id="message-test",
query="这些材料的违法内容会受到什么处罚",
app={"dataset_id": 7},
current_user_id=100,
tenant_code="tenant-a",
attachment_ids=["attachment-1", "attachment-2"],
)
return service._task_events[task_id], captured_context_chunks
def test_streaming_message_end_includes_retriever_resources():
events = asyncio.run(_run_streaming_task())
@@ -89,3 +315,87 @@ def test_streaming_message_end_includes_retriever_resources():
assert resources[0]["dataset_id"] == "7"
assert resources[0]["dataset_name"] == "测试知识库"
assert resources[0]["document_name"] == "引用文档.pdf"
def test_message_task_merges_attachment_facts_and_formal_kb_context():
events, context_chunks = asyncio.run(_run_streaming_task_with_attachment())
message_end = next(event for event in events if event.get("event") == "message_end")
resources = message_end["metadata"].get("retriever_resources")
scopes = [chunk.get("source_scope") for chunk in context_chunks]
assert scopes == ["chat_attachment", "formal_kb"]
assert context_chunks[0]["document_name"] == "用户上传.docx"
assert context_chunks[1]["document_name"] == "处罚依据.md"
assert resources[0]["data_source_type"] == "chat_attachment"
assert resources[1]["data_source_type"] == "formal_kb"
def test_message_task_uses_active_conversation_attachment_when_request_omits_attachment_id():
context_chunks = asyncio.run(_run_streaming_task_with_conversation_attachment_fallback())
scopes = [chunk.get("source_scope") for chunk in context_chunks]
assert scopes == ["chat_attachment", "formal_kb"]
assert context_chunks[0]["attachment_id"] == "attachment-1"
def test_message_task_merges_multiple_attachment_contexts_before_formal_kb():
events, context_chunks = asyncio.run(_run_streaming_task_with_multiple_attachments())
message_end = next(event for event in events if event.get("event") == "message_end")
resources = message_end["metadata"].get("retriever_resources")
scopes = [chunk.get("source_scope") for chunk in context_chunks]
assert scopes == ["chat_attachment", "chat_attachment", "formal_kb"]
assert [chunk.get("attachment_id") for chunk in context_chunks[:2]] == ["attachment-1", "attachment-2"]
assert resources[0]["data_source_type"] == "chat_attachment"
assert resources[1]["data_source_type"] == "chat_attachment"
assert resources[2]["data_source_type"] == "formal_kb"
def test_message_attachment_metadata_is_stored_for_history_display():
service = RagChatServiceImpl()
files = service._build_user_message_files(
[
{
"attachment_id": "att-doc",
"original_name": "处罚材料.docx",
"content_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"file_size": 1234,
},
{
"attachment_id": "att-img",
"original_name": "现场照片.png",
"content_type": "image/png",
"file_size": 4567,
},
]
)
assert files == [
{
"id": "att-doc",
"upload_file_id": "att-doc",
"name": "处罚材料.docx",
"fileName": "处罚材料.docx",
"type": "file",
"transfer_method": "local_file",
"contentType": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"fileSize": 1234,
"belongs_to": "user",
"usage": "temporary_attachment",
},
{
"id": "att-img",
"upload_file_id": "att-img",
"name": "现场照片.png",
"fileName": "现场照片.png",
"type": "image",
"transfer_method": "local_file",
"contentType": "image/png",
"fileSize": 4567,
"belongs_to": "user",
"usage": "temporary_attachment",
},
]
+7 -5
View File
@@ -183,7 +183,7 @@ def test_rbac_manageable_permissions_include_rule_version_lifecycle():
assert "rules:binding_delete:delete" in permission_keys
def test_rbac_rule_group_permissions_are_folded_into_rules_menu():
def test_rbac_rule_groups_route_is_exposed_under_settings():
route_paths = {item["route_path"] for item in RbacAdminServiceImpl._MANAGEABLE_ROUTE_BLUEPRINTS}
group_permission_paths = {
item["route_path"]
@@ -191,17 +191,19 @@ def test_rbac_rule_group_permissions_are_folded_into_rules_menu():
if item["permission_key"].startswith("evaluation_group:")
}
assert "/rule-groups" not in route_paths
assert "/rule-groups" in route_paths
assert group_permission_paths == {"/rules"}
def test_user_route_compat_menu_does_not_expose_rule_groups():
def test_user_route_compat_menu_exposes_rule_groups_under_settings():
service = RbacServiceImpl()
routes = service._buildCompatibilityRoutes(["admin"], {"evaluation_group:list:read", "rules:list:read"})
paths = service._collectRoutePaths(routes)
rules_route = next(route for route in routes if route.route_path == "/rules")
settings_route = next(route for route in routes if route.route_path == "/settings")
rule_groups_route = next(route for route in (settings_route.children or []) if route.route_path == "/rule-groups")
assert "/rule-groups" not in paths
assert "/rule-groups" in paths
assert rule_groups_route.parent_id == settings_route.id
assert "evaluation_group:list:read" in rules_route.permissions