feat(rag): add temporary chat attachments
This commit is contained in:
@@ -0,0 +1,89 @@
|
||||
# Chat Temporary RAG Attachments Implementation Plan
|
||||
|
||||
> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking.
|
||||
|
||||
**Goal:** Build conversation-scoped temporary RAG attachments for chat with 7-day TTL, parser/OCR indexing, strict tenant/user/conversation isolation, and dual retrieval with the existing formal knowledge base.
|
||||
|
||||
**Architecture:** Add a focused `RagChatAttachmentServiceImpl` responsible for attachment lifecycle, parsing, chunking, indexing, retrieval, validation, and cleanup. Extend `RagChatServiceImpl` so a chat message can include an `AttachmentId`, retrieve attachment facts first, then formal KB legal context, and generate one answer with source-aware chunks. Add frontend upload/poll/status plumbing inside the existing chat input path without touching unrelated dirty frontend files.
|
||||
|
||||
**Tech Stack:** FastAPI, SQLAlchemy text queries, Chroma, existing `RagRetriever`, existing OSS client, existing LeAudit OCR bridge, React/Next.js, Ant Design upload controls.
|
||||
|
||||
---
|
||||
|
||||
### Task 1: Backend Attachment Contract And Unit Tests
|
||||
|
||||
**Files:**
|
||||
- Create: `tests/test_rag_chat_attachment_service.py`
|
||||
- Create: `fastapi_modules/fastapi_leaudit/domian/vo/ragChatAttachmentVo.py`
|
||||
- Modify: `fastapi_modules/fastapi_leaudit/domian/Dto/ragChatDto.py`
|
||||
|
||||
- [ ] Write tests for TTL, collection naming, scope matching, and chunk metadata.
|
||||
- [ ] Run: `pytest tests/test_rag_chat_attachment_service.py -q`; expected failures mention missing `RagChatAttachmentServiceImpl`.
|
||||
- [ ] Add attachment DTO and VO classes.
|
||||
- [ ] Re-run the same test and keep remaining failures focused on missing service implementation.
|
||||
|
||||
### Task 2: Backend Attachment Service
|
||||
|
||||
**Files:**
|
||||
- Create: `fastapi_modules/fastapi_leaudit/services/ragChatAttachmentService.py`
|
||||
- Create: `fastapi_modules/fastapi_leaudit/services/impl/ragChatAttachmentServiceImpl.py`
|
||||
- Create: `scripts/创建sql/schema_add_rag_chat_attachments.sql`
|
||||
|
||||
- [ ] Implement schema creation for `rag_chat_attachment`.
|
||||
- [ ] Implement `BuildCollectionName`, `BuildChunks`, `CreateAttachment`, `GetAttachment`, `ValidateAttachmentForChat`, `RetrieveAttachmentContext`, `DeleteAttachment`, and `CleanupExpiredAttachments`.
|
||||
- [ ] Implement parsers for txt/md/json/csv/docx/pdf/xlsx/images.
|
||||
- [ ] Implement async indexing with status transitions and best-effort cleanup.
|
||||
- [ ] Run: `pytest tests/test_rag_chat_attachment_service.py -q`; expected pass.
|
||||
|
||||
### Task 3: Chat Message Dual Retrieval
|
||||
|
||||
**Files:**
|
||||
- Modify: `fastapi_modules/fastapi_leaudit/services/ragChatService.py`
|
||||
- Modify: `fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py`
|
||||
- Modify: `fastapi_modules/fastapi_leaudit/rag_engine/generator.py`
|
||||
- Modify: `tests/test_rag_chat_streaming_sources.py`
|
||||
|
||||
- [ ] Add `AttachmentId` to chat service and DTO call path.
|
||||
- [ ] Add tests proving `_run_message_task` merges attachment chunks and formal KB chunks with distinct source scopes.
|
||||
- [ ] Run targeted test and confirm red.
|
||||
- [ ] Implement retrieval merge and grounded legal query construction.
|
||||
- [ ] Update generator prompt to group uploaded file facts separately from formal KB basis.
|
||||
- [ ] Run targeted tests and confirm green.
|
||||
|
||||
### Task 4: Backend Controller Routes
|
||||
|
||||
**Files:**
|
||||
- Modify: `fastapi_modules/fastapi_leaudit/controllers/ragChatController.py`
|
||||
|
||||
- [ ] Add `POST /chat/attachments`, `GET /chat/attachments/{AttachmentId}`, and `DELETE /chat/attachments/{AttachmentId}`.
|
||||
- [ ] Extend message send route to pass `Body.attachmentId`.
|
||||
- [ ] Reuse `rag:chat:use` permission and existing tenant context.
|
||||
- [ ] Run backend attachment and chat streaming tests.
|
||||
|
||||
### Task 5: Frontend Upload And Send Plumbing
|
||||
|
||||
**Files:**
|
||||
- Create: `legal-platform-frontend/app/api/chat-attachments/route.ts`
|
||||
- Create: `legal-platform-frontend/app/api/chat-attachments/[attachmentId]/route.ts`
|
||||
- Create: `legal-platform-frontend/lib/api/legacy/dify-chat/attachment.ts`
|
||||
- Modify: `legal-platform-frontend/lib/api/legacy/dify-chat/types.ts`
|
||||
- Modify: `legal-platform-frontend/lib/api/legacy/dify-chat/client.ts`
|
||||
- Modify: `legal-platform-frontend/app/api/chat-messages/route.ts`
|
||||
- Modify: `legal-platform-frontend/hooks/use-chat-message.ts`
|
||||
- Modify: `legal-platform-frontend/components/dify-chat/index.tsx`
|
||||
- Modify: `legal-platform-frontend/components/dify-chat/chat-input.tsx`
|
||||
|
||||
- [ ] Add frontend attachment API client and Next route proxies.
|
||||
- [ ] Update chat input to allow one file, upload immediately, poll status, show completed/error/removable state, and block send while indexing.
|
||||
- [ ] Pass `attachmentId` through `Chat` -> `useChatMessage` -> `/api/chat-messages` -> FastAPI.
|
||||
- [ ] Keep upload scoped to current conversation id; for new chats, create/send with conversation after backend returns id or require existing conversation before attachment.
|
||||
|
||||
### Task 6: Verification
|
||||
|
||||
**Files:**
|
||||
- All touched files
|
||||
|
||||
- [ ] Run: `pytest tests/test_rag_chat_attachment_service.py tests/test_rag_chat_streaming_sources.py -q`.
|
||||
- [ ] Run frontend type/test command if available and scoped enough.
|
||||
- [ ] Check `git status --short` at root and frontend subrepo.
|
||||
- [ ] Report changed files, verification output, and any known gaps.
|
||||
@@ -0,0 +1,150 @@
|
||||
# Chat Temporary RAG Attachments Design
|
||||
|
||||
## Goal
|
||||
|
||||
Implement in-chat file upload for `/chat-with-llm/chat` so a user can upload one document or image, have it parsed and indexed into a temporary conversation-scoped RAG collection, and then ask questions that combine facts from the uploaded file with legal knowledge from the existing app knowledge base. Temporary attachment indexes expire after 7 days by default.
|
||||
|
||||
## Core Behavior
|
||||
|
||||
The upload belongs to the chat experience, not the permanent knowledge-base management UI. A file uploaded inside a conversation creates a temporary attachment record and a temporary vector collection. Chat messages may reference that attachment by `attachmentId`.
|
||||
|
||||
The answer flow is dual-source RAG:
|
||||
|
||||
1. Retrieve from the temporary attachment collection to establish facts from the uploaded file.
|
||||
2. Build a grounded legal search query from the user question plus attachment hits.
|
||||
3. Retrieve from the existing formal app knowledge base to find laws, penalties, cases, or policy rules.
|
||||
4. Generate a single answer that clearly uses uploaded-file content as facts and formal knowledge-base content as legal basis.
|
||||
|
||||
If no attachment is selected, chat continues using the existing formal knowledge-base retrieval path.
|
||||
|
||||
## Scope And Isolation
|
||||
|
||||
Temporary knowledge must never leak between conversations or users. Backend validation must never trust frontend `attachmentId` alone. Every attachment operation validates:
|
||||
|
||||
- `tenant_code` or resolved tenant context
|
||||
- `user_id`
|
||||
- `conversation_id`
|
||||
- `attachment_id`
|
||||
- `deleted_at IS NULL`
|
||||
- `expires_at > NOW()`
|
||||
- `indexing_status = 'completed'` for chat retrieval
|
||||
|
||||
The temporary collection name must include all isolation dimensions:
|
||||
|
||||
```text
|
||||
chat_attachment_{tenantCode}_{userId}_{conversationHash}_{attachmentId}
|
||||
```
|
||||
|
||||
Chunk metadata also carries:
|
||||
|
||||
- `tenant_code`
|
||||
- `user_id`
|
||||
- `conversation_id`
|
||||
- `attachment_id`
|
||||
- `source_scope = "chat_attachment"`
|
||||
- `document_name`
|
||||
- `page`
|
||||
- `chunk_index`
|
||||
|
||||
This creates defense in depth: database filtering protects attachment ownership, and vector metadata protects retrieval integrity if collection access is ever broadened.
|
||||
|
||||
## Supported Files
|
||||
|
||||
Initial backend support:
|
||||
|
||||
- Text: `.txt`, `.md`, `.json`, `.csv`
|
||||
- Word: `.docx`
|
||||
- PDF: `.pdf`
|
||||
- Excel: `.xlsx` using `openpyxl` when available, with a clear 400 error if the dependency is missing
|
||||
- Images: `.png`, `.jpg`, `.jpeg`, `.webp`, `.bmp`, `.tif`, `.tiff`
|
||||
|
||||
Images are parsed through the existing LeAudit OCR bridge. If OCR returns structured pages, the service converts pages to text. If only a raw dictionary is returned, it extracts common fields such as `full_text`, `text`, `content`, `pages`, and OCR line text.
|
||||
|
||||
## Lifecycle
|
||||
|
||||
Upload flow:
|
||||
|
||||
1. Frontend posts multipart file with `conversation_id` and optional `app_id`.
|
||||
2. Backend validates chat permission and conversation ownership.
|
||||
3. Backend creates `rag_chat_attachment` with `indexing_status = 'waiting'` and `expires_at = now + 7 days`.
|
||||
4. Backend stores the original file in OSS under a temporary chat path.
|
||||
5. Backend starts async indexing:
|
||||
- `parsing`
|
||||
- `splitting`
|
||||
- `indexing`
|
||||
- `completed`
|
||||
- or `error`
|
||||
6. Frontend polls attachment status and enables send only after `completed`.
|
||||
|
||||
Deletion flow:
|
||||
|
||||
- User can remove the attachment from the chat UI.
|
||||
- Backend soft-deletes the row.
|
||||
- Backend attempts to delete the temporary Chroma collection and OSS object.
|
||||
- Deletion failures in Chroma/OSS do not block the user-facing delete result.
|
||||
|
||||
TTL cleanup:
|
||||
|
||||
- `expires_at` defaults to 7 days.
|
||||
- Retrieval rejects expired attachments immediately.
|
||||
- A cleanup method soft-deletes expired records and best-effort deletes Chroma collections and OSS objects.
|
||||
- The cleanup method can later be wired to an existing scheduler if one exists.
|
||||
|
||||
## API Shape
|
||||
|
||||
Backend routes under `/api/v3/rag`:
|
||||
|
||||
- `POST /chat/attachments`
|
||||
- multipart: `file`, `conversation_id`, optional `app_id`
|
||||
- returns attachment id, filename, status, expires timestamp, and collection name
|
||||
- `GET /chat/attachments/{AttachmentId}`
|
||||
- returns status and metadata after ownership validation
|
||||
- `DELETE /chat/attachments/{AttachmentId}`
|
||||
- soft deletes and best-effort removes temporary artifacts
|
||||
- `POST /chat/messages`
|
||||
- extends existing body with optional `attachmentId`
|
||||
|
||||
Frontend routes:
|
||||
|
||||
- `POST /api/chat-attachments`
|
||||
- `GET /api/chat-attachments/[attachmentId]`
|
||||
- `DELETE /api/chat-attachments/[attachmentId]`
|
||||
- Existing `/api/chat-messages` forwards `attachment_id` / `attachmentId`.
|
||||
|
||||
## Chat Generation Contract
|
||||
|
||||
The generator must receive context chunks with an explicit `source_scope`:
|
||||
|
||||
- `chat_attachment`: facts extracted from uploaded file
|
||||
- `formal_kb`: laws, penalties, and authoritative references from the configured app knowledge base
|
||||
|
||||
Prompt construction must tell the model:
|
||||
|
||||
- Uploaded attachment chunks are factual input from the user's file.
|
||||
- Formal knowledge-base chunks are the source for legal rules and penalties.
|
||||
- Do not invent laws, penalties, or file facts that are absent from the provided contexts.
|
||||
|
||||
Sources returned to the UI must preserve source type so users can tell whether a cited segment came from the uploaded file or the formal knowledge base.
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Empty file: 400
|
||||
- Unsupported type: 400
|
||||
- Attachment from another user, tenant, or conversation: 403 or 404
|
||||
- Expired attachment: 410-like business error using the existing exception mechanism
|
||||
- Attachment not completed when sending: 400
|
||||
- Parser found no text: status `error`, send disabled in UI
|
||||
- OCR/index failures: status `error`, error message capped for display
|
||||
|
||||
## Testing Requirements
|
||||
|
||||
Backend tests must cover:
|
||||
|
||||
- Default `expires_at` is 7 days.
|
||||
- Collection names include sanitized tenant/user/conversation/attachment isolation fields.
|
||||
- Scope validation rejects mismatched tenant, user, or conversation.
|
||||
- Expired or non-completed attachments cannot be used for chat retrieval.
|
||||
- Built chunk metadata contains tenant/user/conversation/attachment isolation fields.
|
||||
- Chat task merges temporary attachment chunks and formal KB chunks, preserving source metadata.
|
||||
|
||||
Frontend tests are optional in the first slice if the project does not already have focused chat component tests, but the implementation must keep the UI constrained to one selected attachment.
|
||||
@@ -30,6 +30,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
||||
RagMessagePageVO,
|
||||
RagOperationResultVO,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
|
||||
RagChatAttachmentDeleteVO,
|
||||
RagChatAttachmentVO,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
||||
RagDatasetBatchDeleteResultVO,
|
||||
RagDatasetDetailVO,
|
||||
@@ -43,8 +47,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.permissionServiceImpl import PermissionServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
|
||||
from fastapi_modules.fastapi_leaudit.services.ragChatAttachmentService import IRagChatAttachmentService
|
||||
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
|
||||
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
|
||||
|
||||
@@ -76,6 +82,7 @@ class RagChatController(BaseController):
|
||||
def __init__(self):
|
||||
super().__init__(prefix="/v3/rag", tags=["RAG 聊天"])
|
||||
self.RagChatService: IRagChatService = RagChatServiceImpl()
|
||||
self.RagChatAttachmentService: IRagChatAttachmentService = RagChatAttachmentServiceImpl()
|
||||
self.RagDatasetService: IRagDatasetService = RagDatasetServiceImpl()
|
||||
self.PermissionService: IPermissionService = PermissionServiceImpl()
|
||||
|
||||
@@ -484,6 +491,8 @@ class RagChatController(BaseController):
|
||||
Query=Body.query,
|
||||
ConversationId=Body.conversationId,
|
||||
AppId=Body.appId,
|
||||
AttachmentId=Body.attachmentId,
|
||||
AttachmentIds=Body.attachmentIds,
|
||||
)
|
||||
return StreamingResponse(
|
||||
stream,
|
||||
@@ -491,6 +500,62 @@ class RagChatController(BaseController):
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
@self.router.post("/chat/attachments", response_model=Result[RagChatAttachmentVO])
|
||||
async def UploadChatAttachment(
|
||||
file: UploadFile = File(...),
|
||||
conversation_id: str | None = Form(None),
|
||||
app_id: int | None = Form(None),
|
||||
payload: dict[str, Any] = Depends(verify_access_token),
|
||||
):
|
||||
if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["chat_use"]]):
|
||||
return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有上传聊天附件权限", "data": None})
|
||||
tenant_context = self._tenant_context(payload)
|
||||
file_bytes = await file.read()
|
||||
data = await self.RagChatAttachmentService.CreateAttachment(
|
||||
CurrentUserId=int(payload["user_id"]),
|
||||
**tenant_context,
|
||||
ConversationId=conversation_id,
|
||||
AppId=app_id,
|
||||
FileName=file.filename or "attachment",
|
||||
ContentType=file.content_type,
|
||||
Content=file_bytes,
|
||||
)
|
||||
return Result.success(data=data)
|
||||
|
||||
@self.router.get("/chat/attachments/{AttachmentId}", response_model=Result[RagChatAttachmentVO])
|
||||
async def GetChatAttachment(
|
||||
AttachmentId: str,
|
||||
conversation_id: str = Query(..., description="附件所属会话ID"),
|
||||
payload: dict[str, Any] = Depends(verify_access_token),
|
||||
):
|
||||
if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["chat_use"]]):
|
||||
return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有查看聊天附件权限", "data": None})
|
||||
tenant_context = self._tenant_context(payload)
|
||||
data = await self.RagChatAttachmentService.GetAttachment(
|
||||
CurrentUserId=int(payload["user_id"]),
|
||||
**tenant_context,
|
||||
ConversationId=conversation_id,
|
||||
AttachmentId=AttachmentId,
|
||||
)
|
||||
return Result.success(data=data)
|
||||
|
||||
@self.router.delete("/chat/attachments/{AttachmentId}", response_model=Result[RagChatAttachmentDeleteVO])
|
||||
async def DeleteChatAttachment(
|
||||
AttachmentId: str,
|
||||
conversation_id: str = Query(..., description="附件所属会话ID"),
|
||||
payload: dict[str, Any] = Depends(verify_access_token),
|
||||
):
|
||||
if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["chat_use"]]):
|
||||
return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有删除聊天附件权限", "data": None})
|
||||
tenant_context = self._tenant_context(payload)
|
||||
data = await self.RagChatAttachmentService.DeleteAttachment(
|
||||
CurrentUserId=int(payload["user_id"]),
|
||||
**tenant_context,
|
||||
ConversationId=conversation_id,
|
||||
AttachmentId=AttachmentId,
|
||||
)
|
||||
return Result.success(data=data)
|
||||
|
||||
@self.router.post("/chat/messages/{MessageId}/stop", response_model=Result[RagOperationResultVO])
|
||||
async def StopMessage(
|
||||
MessageId: str,
|
||||
|
||||
@@ -5,6 +5,8 @@ class RagChatSendMessageDTO(BaseModel):
|
||||
query: str = Field(..., min_length=1, description="用户问题")
|
||||
conversationId: str | None = Field(None, description="会话ID")
|
||||
appId: int | None = Field(None, description="聊天应用ID")
|
||||
attachmentId: str | None = Field(None, description="临时聊天附件ID")
|
||||
attachmentIds: list[str] = Field(default_factory=list, description="临时聊天附件ID列表")
|
||||
|
||||
|
||||
class RagConversationRenameDTO(BaseModel):
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RagChatAttachmentVO(BaseModel):
|
||||
attachmentId: str = Field(..., description="临时附件ID")
|
||||
conversationId: str = Field(..., description="会话ID")
|
||||
fileName: str = Field(..., description="原始文件名")
|
||||
contentType: str = Field("", description="文件 MIME 类型")
|
||||
fileSize: int = Field(0, description="文件大小")
|
||||
indexingStatus: str = Field("waiting", description="索引状态")
|
||||
indexingError: str | None = Field(None, description="索引错误")
|
||||
chunkCount: int = Field(0, description="分段数量")
|
||||
collectionName: str = Field("", description="临时向量集合名")
|
||||
expiresAt: int = Field(0, description="过期时间戳")
|
||||
createdAt: int = Field(0, description="创建时间戳")
|
||||
|
||||
|
||||
class RagChatAttachmentDeleteVO(BaseModel):
|
||||
result: str = Field("success")
|
||||
@@ -36,6 +36,7 @@ class RagMessageItemVO(BaseModel):
|
||||
conversationId: str = Field(...)
|
||||
query: str = Field(...)
|
||||
answer: str = Field(...)
|
||||
messageFiles: list[dict] = Field(default_factory=list)
|
||||
feedback: dict | None = Field(None)
|
||||
retrieverResources: list[dict] | None = Field(None)
|
||||
suggestedQuestions: list[str] = Field(default_factory=list)
|
||||
|
||||
@@ -33,16 +33,34 @@ async def generate_stream(
|
||||
|
||||
max_context_chars = 8000
|
||||
if context_chunks:
|
||||
parts: list[str] = []
|
||||
attachment_parts: list[str] = []
|
||||
formal_parts: list[str] = []
|
||||
total_len = 0
|
||||
for chunk in context_chunks:
|
||||
part = f"[来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
|
||||
scope = str(chunk.get("source_scope") or chunk.get("data_source_type") or "formal_kb")
|
||||
if scope == "chat_attachment":
|
||||
label = "上传文件事实"
|
||||
else:
|
||||
label = "正式知识库依据"
|
||||
part = f"[{label} | 来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
|
||||
if total_len + len(part) > max_context_chars:
|
||||
break
|
||||
parts.append(part)
|
||||
if scope == "chat_attachment":
|
||||
attachment_parts.append(part)
|
||||
else:
|
||||
formal_parts.append(part)
|
||||
total_len += len(part)
|
||||
context_text = "\\n\\n---\\n\\n".join(parts)
|
||||
user_content = f"知识库内容:\\n{context_text}\\n\\n用户问题: {query}"
|
||||
context_sections: list[str] = []
|
||||
if attachment_parts:
|
||||
context_sections.append("【上传文件事实】\n" + "\\n\\n---\\n\\n".join(attachment_parts))
|
||||
if formal_parts:
|
||||
context_sections.append("【正式知识库依据】\n" + "\\n\\n---\\n\\n".join(formal_parts))
|
||||
context_text = "\\n\\n======\\n\\n".join(context_sections)
|
||||
user_content = (
|
||||
"请严格区分上下文来源:上传文件事实只能用于判断用户材料中的事实;正式知识库依据只能用于法律、处罚、裁量或案例依据。"
|
||||
"如果依据不足,请明确说明,不要编造。\n\n"
|
||||
f"{context_text}\\n\\n用户问题: {query}"
|
||||
)
|
||||
else:
|
||||
user_content = query
|
||||
|
||||
@@ -115,7 +133,7 @@ async def generate_stream(
|
||||
"dataset_name": dataset_name,
|
||||
"document_id": "",
|
||||
"document_name": chunk.get("source", ""),
|
||||
"data_source_type": "upload_file",
|
||||
"data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
|
||||
"segment_id": chunk.get("id", ""),
|
||||
"retriever_from": "rag",
|
||||
"score": round(chunk.get("score", 0.0), 4),
|
||||
@@ -125,6 +143,8 @@ async def generate_stream(
|
||||
"index_node_hash": "",
|
||||
"content": chunk.get("text", "")[:500],
|
||||
"page": None,
|
||||
"source_scope": chunk.get("source_scope") or "",
|
||||
"attachment_id": chunk.get("attachment_id") or "",
|
||||
}
|
||||
for i, chunk in enumerate(context_chunks)
|
||||
]
|
||||
|
||||
@@ -183,10 +183,15 @@ class RagRetriever:
|
||||
"source": meta.get("source") or meta.get("document_name") or dataset_name,
|
||||
"score": score,
|
||||
"chunk_index": int(meta.get("chunk_index") or idx),
|
||||
"document_name": document_name,
|
||||
"document_id": meta.get("document_id"),
|
||||
"page": meta.get("page"),
|
||||
}
|
||||
"document_name": document_name,
|
||||
"document_id": meta.get("document_id"),
|
||||
"page": meta.get("page"),
|
||||
"source_scope": meta.get("source_scope"),
|
||||
"attachment_id": meta.get("attachment_id"),
|
||||
"conversation_id": meta.get("conversation_id"),
|
||||
"tenant_code": meta.get("tenant_code"),
|
||||
"user_id": meta.get("user_id"),
|
||||
}
|
||||
)
|
||||
return chunks
|
||||
|
||||
@@ -285,6 +290,11 @@ class RagRetriever:
|
||||
"document_name": document_name,
|
||||
"document_id": meta.get("document_id"),
|
||||
"page": meta.get("page"),
|
||||
"source_scope": meta.get("source_scope"),
|
||||
"attachment_id": meta.get("attachment_id"),
|
||||
"conversation_id": meta.get("conversation_id"),
|
||||
"tenant_code": meta.get("tenant_code"),
|
||||
"user_id": meta.get("user_id"),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -392,10 +402,10 @@ class RagRetriever:
|
||||
{
|
||||
"position": index + 1,
|
||||
"dataset_id": str(chunk.get("dataset_id") or ""),
|
||||
"dataset_name": dataset_name,
|
||||
"dataset_name": chunk.get("dataset_name") or dataset_name,
|
||||
"document_id": str(chunk.get("document_id") or ""),
|
||||
"document_name": chunk.get("document_name") or chunk.get("source", ""),
|
||||
"data_source_type": "upload_file",
|
||||
"data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
|
||||
"segment_id": chunk.get("id", ""),
|
||||
"retriever_from": "rag",
|
||||
"score": round(float(chunk.get("score") or 0.0), 4),
|
||||
@@ -405,6 +415,8 @@ class RagRetriever:
|
||||
"index_node_hash": "",
|
||||
"content": chunk.get("text", "")[:500],
|
||||
"page": None,
|
||||
"source_scope": chunk.get("source_scope") or "",
|
||||
"attachment_id": chunk.get("attachment_id") or "",
|
||||
}
|
||||
for index, chunk in enumerate(context_chunks)
|
||||
]
|
||||
|
||||
@@ -0,0 +1,821 @@
|
||||
"""Temporary RAG attachments for chat conversations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import csv
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from fastapi_common.fastapi_common_storage.oss_client import OssClient
|
||||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
|
||||
RagChatAttachmentDeleteVO,
|
||||
RagChatAttachmentVO,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import EmbedTexts, RagRetriever
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.ragChatAttachmentService import IRagChatAttachmentService
|
||||
|
||||
|
||||
DEFAULT_ATTACHMENT_TTL_DAYS = 7
|
||||
SUPPORTED_TEXT_SUFFIXES = {".txt", ".md", ".json", ".csv"}
|
||||
SUPPORTED_DOCUMENT_SUFFIXES = {".docx", ".pdf", ".xlsx"}
|
||||
SUPPORTED_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"}
|
||||
SUPPORTED_SUFFIXES = SUPPORTED_TEXT_SUFFIXES | SUPPORTED_DOCUMENT_SUFFIXES | SUPPORTED_IMAGE_SUFFIXES
|
||||
|
||||
|
||||
class RagChatAttachmentServiceImpl(IRagChatAttachmentService):
|
||||
"""Manage temporary, conversation-scoped RAG attachment indexes."""
|
||||
|
||||
_attachment_schema_checked = False
|
||||
_attachment_schema_lock = asyncio.Lock()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
chroma_client: Any | None = None,
|
||||
embed_texts: EmbedTexts | None = None,
|
||||
chat_service: RagChatServiceImpl | None = None,
|
||||
ocr_client_factory: Callable[[], Any] | None = None,
|
||||
) -> None:
|
||||
self._chroma_client = chroma_client
|
||||
self.retriever = RagRetriever(
|
||||
chroma_client=chroma_client,
|
||||
embed_texts=embed_texts,
|
||||
hydrate_documents=False,
|
||||
)
|
||||
self.chat_service = chat_service or RagChatServiceImpl()
|
||||
self.dataset_helpers = RagDatasetServiceImpl()
|
||||
self._ocr_client_factory = ocr_client_factory
|
||||
|
||||
def _default_expires_at(self, now: datetime | None = None) -> datetime:
|
||||
base = self._ensure_aware_datetime(now or datetime.now(timezone.utc))
|
||||
return base + timedelta(days=DEFAULT_ATTACHMENT_TTL_DAYS)
|
||||
|
||||
def BuildCollectionName(self, *, TenantCode: str | None, UserId: int, ConversationId: str, AttachmentId: str) -> str:
|
||||
tenant = self._slugify(TenantCode or "default")
|
||||
user = self._slugify(str(UserId))
|
||||
conversation_hash = hashlib.sha1(ConversationId.encode("utf-8")).hexdigest()[:16]
|
||||
attachment = self._slugify(AttachmentId)
|
||||
collection_name = f"chat_attachment_{tenant}_{user}_{conversation_hash}_{attachment}"
|
||||
return collection_name[:120]
|
||||
|
||||
def BuildChunks(
|
||||
self,
|
||||
*,
|
||||
AttachmentId: str,
|
||||
TenantCode: str | None,
|
||||
UserId: int,
|
||||
ConversationId: str,
|
||||
FileName: str,
|
||||
PageTexts: list[tuple[int, str]],
|
||||
ChunkMaxSize: int = 800,
|
||||
ChunkOverlap: int = 80,
|
||||
) -> list[dict[str, Any]]:
|
||||
chunk_max_size = max(50, int(ChunkMaxSize or 800))
|
||||
chunk_overlap = max(0, min(int(ChunkOverlap or 0), chunk_max_size // 2))
|
||||
chunks: list[dict[str, Any]] = []
|
||||
for page_no, raw_text in PageTexts:
|
||||
text_value = self._preprocess_text(raw_text)
|
||||
if not text_value:
|
||||
continue
|
||||
split_chunks = self._split_text(text_value, separator="\n\n", max_chars=chunk_max_size, overlap=chunk_overlap)
|
||||
for index, chunk_text in enumerate(split_chunks):
|
||||
if not chunk_text.strip():
|
||||
continue
|
||||
chunk_id = f"{AttachmentId}:{page_no}:{index}"
|
||||
chunks.append(
|
||||
{
|
||||
"id": chunk_id,
|
||||
"text": chunk_text,
|
||||
"metadata": {
|
||||
"id": chunk_id,
|
||||
"source": FileName,
|
||||
"document_name": FileName,
|
||||
"attachment_id": AttachmentId,
|
||||
"tenant_code": str(TenantCode or ""),
|
||||
"user_id": str(UserId),
|
||||
"conversation_id": ConversationId,
|
||||
"source_scope": "chat_attachment",
|
||||
"page": int(page_no),
|
||||
"chunk_index": index,
|
||||
},
|
||||
}
|
||||
)
|
||||
return chunks
|
||||
|
||||
async def CreateAttachment(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
TenantCode: str | None,
|
||||
TenantName: str | None,
|
||||
ConversationId: str | None,
|
||||
AppId: int | None,
|
||||
FileName: str,
|
||||
ContentType: str | None,
|
||||
Content: bytes,
|
||||
) -> RagChatAttachmentVO:
|
||||
if not FileName:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件名不能为空")
|
||||
if not Content:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件内容不能为空")
|
||||
|
||||
suffix = Path(FileName).suffix.lower()
|
||||
if suffix not in SUPPORTED_SUFFIXES:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 TXT、MD、JSON、CSV、DOCX、PDF、XLSX 和图片文件")
|
||||
|
||||
app = await self.chat_service._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
|
||||
if not app:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")
|
||||
|
||||
conversation_id = await self.chat_service._ensure_conversation(
|
||||
user_id=CurrentUserId,
|
||||
conversation_id=ConversationId,
|
||||
app_id=app["id"],
|
||||
user_area=UserArea,
|
||||
user_role=UserRole,
|
||||
tenant_code=TenantCode,
|
||||
tenant_name=TenantName,
|
||||
)
|
||||
|
||||
attachment_id = str(uuid.uuid4())
|
||||
collection_name = self.BuildCollectionName(
|
||||
TenantCode=TenantCode,
|
||||
UserId=CurrentUserId,
|
||||
ConversationId=conversation_id,
|
||||
AttachmentId=attachment_id,
|
||||
)
|
||||
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
|
||||
expires_at = self._default_expires_at()
|
||||
object_key = f"rag-chat-attachments/{TenantCode or 'default'}/{CurrentUserId}/{conversation_id}/{attachment_id}_{FileName}"
|
||||
OssClient().EnsureBucket()
|
||||
stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await self._ensure_attachment_schema(session)
|
||||
row = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO rag_chat_attachment (
|
||||
attachment_id, conversation_id, user_id, tenant_code, area,
|
||||
filename, original_name, content_type, file_size, minio_path,
|
||||
collection_name, indexing_status, chunk_count, expires_at
|
||||
) VALUES (
|
||||
:attachment_id, :conversation_id, :user_id, :tenant_code, :area,
|
||||
:filename, :original_name, :content_type, :file_size, :minio_path,
|
||||
:collection_name, 'waiting', 0, :expires_at
|
||||
)
|
||||
RETURNING *
|
||||
"""
|
||||
),
|
||||
{
|
||||
"attachment_id": attachment_id,
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": CurrentUserId,
|
||||
"tenant_code": TenantCode,
|
||||
"area": UserArea,
|
||||
"filename": f"{attachment_id}{suffix}",
|
||||
"original_name": FileName,
|
||||
"content_type": content_type,
|
||||
"file_size": len(Content),
|
||||
"minio_path": stored_key,
|
||||
"collection_name": collection_name,
|
||||
"expires_at": expires_at,
|
||||
},
|
||||
)
|
||||
).mappings().first()
|
||||
|
||||
asyncio.create_task(
|
||||
self._run_attachment_indexing_task(
|
||||
attachment_id=attachment_id,
|
||||
tenant_code=TenantCode,
|
||||
user_id=CurrentUserId,
|
||||
conversation_id=conversation_id,
|
||||
file_name=FileName,
|
||||
content=Content,
|
||||
collection_name=collection_name,
|
||||
)
|
||||
)
|
||||
return self._to_vo(dict(row))
|
||||
|
||||
async def GetAttachment(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
TenantCode: str | None,
|
||||
TenantName: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentId: str,
|
||||
) -> RagChatAttachmentVO:
|
||||
self._validate_conversation_id(ConversationId)
|
||||
record = await self._get_attachment_record(AttachmentId)
|
||||
self._assert_attachment_scope(
|
||||
record,
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
RequireCompleted=False,
|
||||
Now=datetime.now(timezone.utc),
|
||||
)
|
||||
return self._to_vo(record)
|
||||
|
||||
async def DeleteAttachment(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
TenantCode: str | None,
|
||||
TenantName: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentId: str,
|
||||
) -> RagChatAttachmentDeleteVO:
|
||||
self._validate_conversation_id(ConversationId)
|
||||
record = await self._get_attachment_record(AttachmentId)
|
||||
self._assert_attachment_scope(
|
||||
record,
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
RequireCompleted=False,
|
||||
Now=datetime.now(timezone.utc),
|
||||
)
|
||||
async with GetAsyncSession() as session:
|
||||
await self._ensure_attachment_schema(session)
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_chat_attachment
|
||||
SET deleted_at = NOW(), updated_at = NOW()
|
||||
WHERE attachment_id = :attachment_id
|
||||
"""
|
||||
),
|
||||
{"attachment_id": AttachmentId},
|
||||
)
|
||||
self._delete_collection(str(record.get("collection_name") or ""))
|
||||
self._delete_oss_object(str(record.get("minio_path") or ""))
|
||||
return RagChatAttachmentDeleteVO(result="success")
|
||||
|
||||
async def ValidateAttachmentForChat(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentId: str,
|
||||
UserArea: str | None = None,
|
||||
) -> dict:
|
||||
record = await self._get_attachment_record(AttachmentId)
|
||||
self._assert_attachment_scope(
|
||||
record,
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
RequireCompleted=True,
|
||||
Now=datetime.now(timezone.utc),
|
||||
)
|
||||
return record
|
||||
|
||||
async def RetrieveAttachmentContext(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentId: str,
|
||||
Query: str,
|
||||
TopK: int = 5,
|
||||
UserArea: str | None = None,
|
||||
) -> tuple[list[dict], str]:
|
||||
record = await self.ValidateAttachmentForChat(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
AttachmentId=AttachmentId,
|
||||
)
|
||||
result = await self.retriever.retrieve(
|
||||
query=Query,
|
||||
collection_name=str(record["collection_name"]),
|
||||
top_k=TopK,
|
||||
)
|
||||
chunks: list[dict] = []
|
||||
for chunk in result.chunks:
|
||||
chunks.append(
|
||||
{
|
||||
**chunk,
|
||||
"dataset_name": "临时上传文件",
|
||||
"attachment_id": AttachmentId,
|
||||
"conversation_id": ConversationId,
|
||||
"source_scope": "chat_attachment",
|
||||
"data_source_type": "chat_attachment",
|
||||
"document_name": chunk.get("document_name") or record.get("original_name") or chunk.get("source"),
|
||||
"source": chunk.get("source") or record.get("original_name") or "上传文件",
|
||||
}
|
||||
)
|
||||
return chunks, str(record.get("original_name") or "上传文件")
|
||||
|
||||
async def ResolveActiveAttachmentIdForConversation(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
UserArea: str | None = None,
|
||||
) -> str | None:
|
||||
attachment_ids = await self.ResolveActiveAttachmentIdsForConversation(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
)
|
||||
return attachment_ids[0] if attachment_ids else None
|
||||
|
||||
async def ResolveActiveAttachmentIdsForConversation(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
UserArea: str | None = None,
|
||||
) -> list[str]:
|
||||
self._validate_conversation_id(ConversationId)
|
||||
async with GetAsyncSession() as session:
|
||||
await self._ensure_attachment_schema(session)
|
||||
rows = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT attachment_id, tenant_code, area
|
||||
FROM rag_chat_attachment
|
||||
WHERE user_id = :user_id
|
||||
AND conversation_id = :conversation_id
|
||||
AND indexing_status = 'completed'
|
||||
AND expires_at > NOW()
|
||||
AND deleted_at IS NULL
|
||||
AND (
|
||||
NULLIF(:tenant_code, '') IS NULL
|
||||
OR NULLIF(tenant_code, '') IS NULL
|
||||
OR tenant_code = :tenant_code
|
||||
)
|
||||
AND (
|
||||
NULLIF(:tenant_code, '') IS NOT NULL
|
||||
OR NULLIF(:user_area, '') IS NULL
|
||||
OR NULLIF(area, '') IS NULL
|
||||
OR area = :user_area
|
||||
)
|
||||
ORDER BY indexing_completed_at DESC NULLS LAST, created_at DESC
|
||||
"""
|
||||
),
|
||||
{
|
||||
"user_id": CurrentUserId,
|
||||
"conversation_id": ConversationId,
|
||||
"tenant_code": TenantCode,
|
||||
"user_area": UserArea,
|
||||
},
|
||||
)
|
||||
).mappings().all()
|
||||
return [str(row["attachment_id"]) for row in rows]
|
||||
|
||||
async def CleanupExpiredAttachments(self, *, Limit: int = 100) -> int:
|
||||
async with GetAsyncSession() as session:
|
||||
await self._ensure_attachment_schema(session)
|
||||
rows = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT attachment_id, collection_name, minio_path
|
||||
FROM rag_chat_attachment
|
||||
WHERE deleted_at IS NULL
|
||||
AND expires_at <= NOW()
|
||||
ORDER BY expires_at ASC
|
||||
LIMIT :limit
|
||||
"""
|
||||
),
|
||||
{"limit": max(1, int(Limit or 100))},
|
||||
)
|
||||
).mappings().all()
|
||||
attachment_ids = [str(row["attachment_id"]) for row in rows]
|
||||
if attachment_ids:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_chat_attachment
|
||||
SET deleted_at = NOW(), updated_at = NOW()
|
||||
WHERE attachment_id = ANY(:attachment_ids)
|
||||
"""
|
||||
),
|
||||
{"attachment_ids": attachment_ids},
|
||||
)
|
||||
for row in rows:
|
||||
self._delete_collection(str(row.get("collection_name") or ""))
|
||||
self._delete_oss_object(str(row.get("minio_path") or ""))
|
||||
return len(rows)
|
||||
|
||||
async def _get_attachment_record(self, attachment_id: str) -> dict:
|
||||
if not str(attachment_id or "").strip():
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "附件ID不能为空")
|
||||
async with GetAsyncSession() as session:
|
||||
await self._ensure_attachment_schema(session)
|
||||
row = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT *
|
||||
FROM rag_chat_attachment
|
||||
WHERE attachment_id = :attachment_id
|
||||
AND deleted_at IS NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"attachment_id": attachment_id},
|
||||
)
|
||||
).mappings().first()
|
||||
if not row:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
|
||||
return dict(row)
|
||||
|
||||
def _assert_attachment_scope(
|
||||
self,
|
||||
record: dict,
|
||||
*,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
RequireCompleted: bool,
|
||||
Now: datetime | None = None,
|
||||
UserArea: str | None = None,
|
||||
) -> None:
|
||||
if record.get("deleted_at") is not None:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
|
||||
if int(record.get("user_id") or 0) != int(CurrentUserId):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该临时附件")
|
||||
if str(record.get("conversation_id") or "") != str(ConversationId or ""):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前会话")
|
||||
expected_tenant = str(TenantCode or "").strip()
|
||||
record_tenant = str(record.get("tenant_code") or "").strip()
|
||||
if expected_tenant and record_tenant and expected_tenant != record_tenant:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前租户")
|
||||
expected_area = str(UserArea or "").strip()
|
||||
record_area = str(record.get("area") or "").strip()
|
||||
if not expected_tenant and expected_area and record_area and expected_area != record_area:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前地区")
|
||||
|
||||
expires_at = self._ensure_aware_datetime(record.get("expires_at"))
|
||||
now = self._ensure_aware_datetime(Now or datetime.now(timezone.utc))
|
||||
if expires_at <= now:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件已过期,请重新上传")
|
||||
|
||||
if RequireCompleted and str(record.get("indexing_status") or "") != "completed":
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件尚未完成解析,请稍后再试")
|
||||
|
||||
async def _run_attachment_indexing_task(
|
||||
self,
|
||||
*,
|
||||
attachment_id: str,
|
||||
tenant_code: str | None,
|
||||
user_id: int,
|
||||
conversation_id: str,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
collection_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
await self._update_attachment_state(attachment_id=attachment_id, status="parsing")
|
||||
page_texts = await self._extract_page_texts(FileName=file_name, Content=content)
|
||||
if not page_texts:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未提取到可检索文本")
|
||||
|
||||
await self._update_attachment_state(attachment_id=attachment_id, status="splitting")
|
||||
chunks = self.BuildChunks(
|
||||
AttachmentId=attachment_id,
|
||||
TenantCode=tenant_code,
|
||||
UserId=user_id,
|
||||
ConversationId=conversation_id,
|
||||
FileName=file_name,
|
||||
PageTexts=page_texts,
|
||||
)
|
||||
if not chunks:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未切分出可检索文本")
|
||||
|
||||
await self._update_attachment_state(attachment_id=attachment_id, status="indexing", chunk_count=len(chunks))
|
||||
embeddings = await self.retriever._embed_texts([item["text"] for item in chunks], "")
|
||||
collection = self._get_chroma().get_or_create_collection(collection_name)
|
||||
collection.add(
|
||||
ids=[item["id"] for item in chunks],
|
||||
documents=[item["text"] for item in chunks],
|
||||
embeddings=embeddings,
|
||||
metadatas=[item["metadata"] for item in chunks],
|
||||
)
|
||||
|
||||
await self._update_attachment_state(
|
||||
attachment_id=attachment_id,
|
||||
status="completed",
|
||||
chunk_count=len(chunks),
|
||||
completed=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
await self._update_attachment_state(
|
||||
attachment_id=attachment_id,
|
||||
status="error",
|
||||
error=str(exc)[:2000],
|
||||
)
|
||||
|
||||
async def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]:
|
||||
suffix = Path(FileName).suffix.lower()
|
||||
if suffix in {".txt", ".md"}:
|
||||
text_value = Content.decode("utf-8", errors="ignore").strip()
|
||||
return [(1, text_value)] if text_value else []
|
||||
if suffix == ".json":
|
||||
return self.dataset_helpers._extract_page_texts_from_json(Content)
|
||||
if suffix == ".csv":
|
||||
return [(1, self._csv_to_text(Content))] if self._csv_to_text(Content).strip() else []
|
||||
if suffix == ".xlsx":
|
||||
return self._extract_page_texts_from_xlsx(Content)
|
||||
if suffix in SUPPORTED_IMAGE_SUFFIXES:
|
||||
return await self._extract_page_texts_from_image(FileName, Content)
|
||||
return self.dataset_helpers._extract_page_texts(FileName=FileName, Content=Content)
|
||||
|
||||
def _csv_to_text(self, content: bytes) -> str:
|
||||
text_value = content.decode("utf-8-sig", errors="ignore")
|
||||
rows: list[str] = []
|
||||
for row_index, row in enumerate(csv.reader(text_value.splitlines()), start=1):
|
||||
values = [str(cell).strip() for cell in row if str(cell).strip()]
|
||||
if values:
|
||||
rows.append(f"row_{row_index}: " + " | ".join(values))
|
||||
return "\n".join(rows)
|
||||
|
||||
def _extract_page_texts_from_xlsx(self, content: bytes) -> list[tuple[int, str]]:
|
||||
try:
|
||||
from openpyxl import load_workbook
|
||||
except Exception as exc: # pragma: no cover - depends on optional env package
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前环境未安装 openpyxl,暂无法解析 Excel 文件") from exc
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as temp_file:
|
||||
temp_file.write(content)
|
||||
temp_path = temp_file.name
|
||||
try:
|
||||
workbook = load_workbook(temp_path, data_only=True, read_only=True)
|
||||
page_texts: list[tuple[int, str]] = []
|
||||
for sheet_index, sheet in enumerate(workbook.worksheets, start=1):
|
||||
lines: list[str] = [f"Sheet: {sheet.title}"]
|
||||
for row in sheet.iter_rows(values_only=True):
|
||||
values = [str(cell).strip() for cell in row if cell is not None and str(cell).strip()]
|
||||
if values:
|
||||
lines.append(" | ".join(values))
|
||||
rendered = "\n".join(lines).strip()
|
||||
if rendered:
|
||||
page_texts.append((sheet_index, rendered))
|
||||
return page_texts
|
||||
finally:
|
||||
try:
|
||||
os.unlink(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
async def _extract_page_texts_from_image(self, file_name: str, content: bytes) -> list[tuple[int, str]]:
|
||||
suffix = Path(file_name).suffix.lower() or ".png"
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
|
||||
temp_file.write(content)
|
||||
temp_path = Path(temp_file.name)
|
||||
try:
|
||||
ocr_client = self._create_ocr_client()
|
||||
result = await ocr_client.ocr(temp_path)
|
||||
text_value = self._ocr_result_to_text(result)
|
||||
return [(1, text_value)] if text_value else []
|
||||
finally:
|
||||
try:
|
||||
temp_path.unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _ocr_result_to_text(self, value: Any) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return value.strip()
|
||||
if isinstance(value, dict):
|
||||
candidates: list[str] = []
|
||||
for key in ("full_text", "text", "content", "markdown"):
|
||||
if isinstance(value.get(key), str) and value.get(key).strip():
|
||||
candidates.append(value[key].strip())
|
||||
pages = value.get("pages")
|
||||
if isinstance(pages, list):
|
||||
for page in pages:
|
||||
page_text = self._ocr_result_to_text(page)
|
||||
if page_text:
|
||||
candidates.append(page_text)
|
||||
lines = value.get("lines") or value.get("blocks")
|
||||
if isinstance(lines, list):
|
||||
for line in lines:
|
||||
line_text = self._ocr_result_to_text(line)
|
||||
if line_text:
|
||||
candidates.append(line_text)
|
||||
return "\n".join(dict.fromkeys(candidates)).strip()
|
||||
if isinstance(value, list):
|
||||
return "\n".join(item for item in (self._ocr_result_to_text(item) for item in value) if item).strip()
|
||||
|
||||
text_attrs = []
|
||||
for attr in ("full_text", "text", "content", "markdown"):
|
||||
raw = getattr(value, attr, None)
|
||||
if isinstance(raw, str) and raw.strip():
|
||||
text_attrs.append(raw.strip())
|
||||
pages = getattr(value, "pages", None)
|
||||
if isinstance(pages, list):
|
||||
for page in pages:
|
||||
page_text = self._ocr_result_to_text(page)
|
||||
if page_text:
|
||||
text_attrs.append(page_text)
|
||||
return "\n".join(dict.fromkeys(text_attrs)).strip()
|
||||
|
||||
def _create_ocr_client(self):
|
||||
if self._ocr_client_factory is not None:
|
||||
return self._ocr_client_factory()
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import create_ocr_client
|
||||
|
||||
return create_ocr_client()
|
||||
|
||||
async def _update_attachment_state(
|
||||
self,
|
||||
*,
|
||||
attachment_id: str,
|
||||
status: str,
|
||||
chunk_count: int | None = None,
|
||||
error: str | None = None,
|
||||
completed: bool = False,
|
||||
) -> None:
|
||||
fields = [
|
||||
"indexing_status = :status",
|
||||
"indexing_error = :error",
|
||||
"updated_at = NOW()",
|
||||
"indexing_started_at = COALESCE(indexing_started_at, NOW())",
|
||||
]
|
||||
if chunk_count is not None:
|
||||
fields.append("chunk_count = :chunk_count")
|
||||
if completed:
|
||||
fields.append("indexing_completed_at = NOW()")
|
||||
params: dict[str, Any] = {
|
||||
"attachment_id": attachment_id,
|
||||
"status": status,
|
||||
"error": error,
|
||||
"chunk_count": chunk_count,
|
||||
}
|
||||
async with GetAsyncSession() as session:
|
||||
await self._ensure_attachment_schema(session)
|
||||
await session.execute(
|
||||
text(
|
||||
f"""
|
||||
UPDATE rag_chat_attachment
|
||||
SET {", ".join(fields)}
|
||||
WHERE attachment_id = :attachment_id
|
||||
"""
|
||||
),
|
||||
params,
|
||||
)
|
||||
|
||||
async def _ensure_attachment_schema(self, session) -> None:
|
||||
if self.__class__._attachment_schema_checked:
|
||||
return
|
||||
async with self.__class__._attachment_schema_lock:
|
||||
if self.__class__._attachment_schema_checked:
|
||||
return
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS rag_chat_attachment (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
attachment_id VARCHAR(64) NOT NULL UNIQUE,
|
||||
conversation_id VARCHAR(64) NOT NULL,
|
||||
user_id BIGINT NOT NULL,
|
||||
tenant_code VARCHAR(64) NULL,
|
||||
area VARCHAR(255) NULL,
|
||||
filename VARCHAR(512) NOT NULL,
|
||||
original_name VARCHAR(512) NOT NULL,
|
||||
content_type VARCHAR(255) NULL,
|
||||
file_size BIGINT NOT NULL DEFAULT 0,
|
||||
minio_path TEXT NULL,
|
||||
collection_name VARCHAR(160) NOT NULL,
|
||||
indexing_status VARCHAR(32) NOT NULL DEFAULT 'waiting',
|
||||
indexing_error TEXT NULL,
|
||||
chunk_count INTEGER NOT NULL DEFAULT 0,
|
||||
indexing_started_at TIMESTAMPTZ NULL,
|
||||
indexing_completed_at TIMESTAMPTZ NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
deleted_at TIMESTAMPTZ NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_scope
|
||||
ON rag_chat_attachment(tenant_code, user_id, conversation_id, attachment_id)
|
||||
WHERE deleted_at IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_expires
|
||||
ON rag_chat_attachment(expires_at)
|
||||
WHERE deleted_at IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
self.__class__._attachment_schema_checked = True
|
||||
|
||||
def _get_chroma(self) -> Any:
|
||||
return self._chroma_client or get_chroma()
|
||||
|
||||
def _delete_collection(self, collection_name: str) -> None:
|
||||
if not collection_name:
|
||||
return
|
||||
try:
|
||||
self._get_chroma().delete_collection(collection_name)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def _delete_oss_object(self, source: str | None) -> None:
|
||||
if not source:
|
||||
return
|
||||
try:
|
||||
oss = OssClient()
|
||||
ref = oss.ResolveObjectRef(Source=source)
|
||||
if ref.isDirectUrl or not ref.objectKey:
|
||||
return
|
||||
oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def _to_vo(self, row: dict) -> RagChatAttachmentVO:
|
||||
expires_at = row.get("expires_at")
|
||||
created_at = row.get("created_at")
|
||||
return RagChatAttachmentVO(
|
||||
attachmentId=str(row.get("attachment_id") or ""),
|
||||
conversationId=str(row.get("conversation_id") or ""),
|
||||
fileName=str(row.get("original_name") or row.get("filename") or ""),
|
||||
contentType=str(row.get("content_type") or ""),
|
||||
fileSize=int(row.get("file_size") or 0),
|
||||
indexingStatus=str(row.get("indexing_status") or "waiting"),
|
||||
indexingError=row.get("indexing_error"),
|
||||
chunkCount=int(row.get("chunk_count") or 0),
|
||||
collectionName=str(row.get("collection_name") or ""),
|
||||
expiresAt=self._timestamp(expires_at),
|
||||
createdAt=self._timestamp(created_at),
|
||||
)
|
||||
|
||||
def _preprocess_text(self, text_value: str) -> str:
|
||||
result = text_value or ""
|
||||
result = re.sub(r"[ \t]+", " ", result)
|
||||
result = re.sub(r"\n{3,}", "\n\n", result)
|
||||
return result.strip()
|
||||
|
||||
def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]:
|
||||
return self.dataset_helpers._split_text(text_value, separator=separator, max_chars=max_chars, overlap=overlap)
|
||||
|
||||
def _slugify(self, value: str) -> str:
|
||||
normalized = re.sub(r"[^A-Za-z0-9_]+", "_", str(value or "").strip())
|
||||
normalized = re.sub(r"_+", "_", normalized).strip("_").lower()
|
||||
return normalized or "default"
|
||||
|
||||
def _validate_conversation_id(self, conversation_id: str | None) -> None:
|
||||
if not str(conversation_id or "").strip() or str(conversation_id or "").strip() == "-1":
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话ID不能为空")
|
||||
|
||||
def _ensure_aware_datetime(self, value: Any) -> datetime:
|
||||
if isinstance(value, datetime):
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
if parsed.tzinfo is None:
|
||||
return parsed.replace(tzinfo=timezone.utc)
|
||||
return parsed
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件时间字段异常")
|
||||
|
||||
def _timestamp(self, value: Any) -> int:
|
||||
if not value:
|
||||
return 0
|
||||
return int(self._ensure_aware_datetime(value).timestamp())
|
||||
@@ -106,6 +106,8 @@ class RagChatServiceImpl(IRagChatService):
|
||||
Query: str,
|
||||
ConversationId: str | None,
|
||||
AppId: int | None,
|
||||
AttachmentId: str | None = None,
|
||||
AttachmentIds: list[str] | None = None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
@@ -128,6 +130,25 @@ class RagChatServiceImpl(IRagChatService):
|
||||
messageId = str(uuid.uuid4())
|
||||
taskId = str(uuid.uuid4())
|
||||
is_new_conversation = not ConversationId or ConversationId == "-1"
|
||||
active_attachment_ids = self._normalize_attachment_ids(
|
||||
attachment_id=AttachmentId,
|
||||
attachment_ids=AttachmentIds,
|
||||
)
|
||||
if not active_attachment_ids:
|
||||
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=conversationId,
|
||||
)
|
||||
attachment_records = await self._load_message_attachment_records(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=conversationId,
|
||||
AttachmentIds=active_attachment_ids,
|
||||
)
|
||||
user_message_files = self._build_user_message_files(attachment_records)
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
async with session.begin():
|
||||
@@ -135,13 +156,14 @@ class RagChatServiceImpl(IRagChatService):
|
||||
text(
|
||||
"""
|
||||
INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
|
||||
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
|
||||
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, CAST(:metadata AS jsonb))
|
||||
"""
|
||||
),
|
||||
{
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"conversation_id": conversationId,
|
||||
"content": Query,
|
||||
"metadata": json.dumps({"message_files": user_message_files}, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
await session.execute(
|
||||
@@ -170,6 +192,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id=messageId,
|
||||
query=Query,
|
||||
app=app,
|
||||
current_user_id=CurrentUserId,
|
||||
tenant_code=TenantCode,
|
||||
user_area=UserArea,
|
||||
attachment_id=AttachmentId,
|
||||
attachment_ids=active_attachment_ids,
|
||||
)
|
||||
|
||||
event_index = 0
|
||||
@@ -322,6 +349,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
row = items[idx]
|
||||
if row["role"] == "user":
|
||||
answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None
|
||||
user_metadata = dict(row.get("metadata") or {})
|
||||
answer_metadata = dict((answer.get("metadata") if answer else None) or {})
|
||||
answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running"))
|
||||
answer_content = (answer.get("content") if answer else None) or ""
|
||||
@@ -355,6 +383,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
conversationId=ConversationId,
|
||||
query=row["content"],
|
||||
answer=answer_content if answer else "",
|
||||
messageFiles=[item for item in (user_metadata.get("message_files") or []) if isinstance(item, dict)],
|
||||
feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
|
||||
retrieverResources=(answer.get("sources") if answer else None),
|
||||
suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()],
|
||||
@@ -940,7 +969,181 @@ class RagChatServiceImpl(IRagChatService):
|
||||
|
||||
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
|
||||
result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
|
||||
return result.chunks, result.dataset_name
|
||||
chunks = [
|
||||
{
|
||||
**chunk,
|
||||
"source_scope": chunk.get("source_scope") or "formal_kb",
|
||||
"data_source_type": chunk.get("data_source_type") or "formal_kb",
|
||||
}
|
||||
for chunk in result.chunks
|
||||
]
|
||||
return chunks, result.dataset_name
|
||||
|
||||
async def _resolve_attachment_id_for_conversation(
|
||||
self,
|
||||
*,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
UserArea: str | None,
|
||||
ConversationId: str,
|
||||
) -> str | None:
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
|
||||
service = RagChatAttachmentServiceImpl()
|
||||
return await service.ResolveActiveAttachmentIdForConversation(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
)
|
||||
|
||||
async def _resolve_attachment_ids_for_conversation(
|
||||
self,
|
||||
*,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
UserArea: str | None,
|
||||
ConversationId: str,
|
||||
) -> list[str]:
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
|
||||
service = RagChatAttachmentServiceImpl()
|
||||
if hasattr(service, "ResolveActiveAttachmentIdsForConversation"):
|
||||
return await service.ResolveActiveAttachmentIdsForConversation(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
)
|
||||
attachment_id = await service.ResolveActiveAttachmentIdForConversation(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
)
|
||||
return [attachment_id] if attachment_id else []
|
||||
|
||||
async def _retrieve_attachment_context(
|
||||
self,
|
||||
*,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
UserArea: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentId: str,
|
||||
Query: str,
|
||||
) -> tuple[list[dict], str]:
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
|
||||
service = RagChatAttachmentServiceImpl()
|
||||
return await service.RetrieveAttachmentContext(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
AttachmentId=AttachmentId,
|
||||
Query=Query,
|
||||
TopK=5,
|
||||
)
|
||||
|
||||
async def _load_message_attachment_records(
|
||||
self,
|
||||
*,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
UserArea: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentIds: list[str],
|
||||
) -> list[dict]:
|
||||
if not AttachmentIds:
|
||||
return []
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
|
||||
service = RagChatAttachmentServiceImpl()
|
||||
records: list[dict] = []
|
||||
for attachment_id in AttachmentIds:
|
||||
record = await service.ValidateAttachmentForChat(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
AttachmentId=attachment_id,
|
||||
)
|
||||
records.append(record)
|
||||
return records
|
||||
|
||||
def _build_user_message_files(self, attachment_records: list[dict]) -> list[dict]:
|
||||
files: list[dict] = []
|
||||
for record in attachment_records:
|
||||
attachment_id = str(record.get("attachment_id") or "").strip()
|
||||
if not attachment_id:
|
||||
continue
|
||||
file_name = str(record.get("original_name") or record.get("filename") or "上传文件")
|
||||
content_type = str(record.get("content_type") or "")
|
||||
file_type = "image" if content_type.startswith("image/") or re.search(r"\.(png|jpe?g|webp|bmp|tiff?)$", file_name, re.I) else "file"
|
||||
files.append(
|
||||
{
|
||||
"id": attachment_id,
|
||||
"upload_file_id": attachment_id,
|
||||
"name": file_name,
|
||||
"fileName": file_name,
|
||||
"type": file_type,
|
||||
"transfer_method": "local_file",
|
||||
"contentType": content_type or None,
|
||||
"fileSize": int(record.get("file_size") or 0),
|
||||
"belongs_to": "user",
|
||||
"usage": "temporary_attachment",
|
||||
}
|
||||
)
|
||||
return files
|
||||
|
||||
def _build_formal_kb_query(self, *, query: str, attachment_chunks: list[dict]) -> str:
|
||||
if not attachment_chunks:
|
||||
return query
|
||||
|
||||
facts: list[str] = []
|
||||
for chunk in attachment_chunks[:3]:
|
||||
text_value = str(chunk.get("text") or "").strip()
|
||||
if text_value:
|
||||
facts.append(text_value[:500])
|
||||
if not facts:
|
||||
return query
|
||||
|
||||
return (
|
||||
f"{query}\n\n"
|
||||
"用户上传文档中检索到的相关事实:\n"
|
||||
+ "\n".join(f"- {item}" for item in facts)
|
||||
+ "\n\n请检索这些事实对应的法律责任、处罚依据、裁量规则或案例。"
|
||||
)
|
||||
|
||||
def _merge_context_chunks(self, *, attachment_chunks: list[dict], formal_chunks: list[dict]) -> list[dict]:
|
||||
merged: list[dict] = []
|
||||
for chunk in attachment_chunks:
|
||||
merged.append(
|
||||
{
|
||||
**chunk,
|
||||
"source_scope": "chat_attachment",
|
||||
"data_source_type": "chat_attachment",
|
||||
}
|
||||
)
|
||||
for chunk in formal_chunks:
|
||||
merged.append(
|
||||
{
|
||||
**chunk,
|
||||
"source_scope": chunk.get("source_scope") or "formal_kb",
|
||||
"data_source_type": chunk.get("data_source_type") or "formal_kb",
|
||||
}
|
||||
)
|
||||
return merged
|
||||
|
||||
def _normalize_attachment_ids(self, *, attachment_id: str | None, attachment_ids: list[str] | None) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
for raw in [*(attachment_ids or []), attachment_id]:
|
||||
value = str(raw or "").strip()
|
||||
if value and value not in normalized:
|
||||
normalized.append(value)
|
||||
return normalized
|
||||
|
||||
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
|
||||
return await self.retriever._embed_texts(texts, model_name)
|
||||
@@ -953,6 +1156,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id: str,
|
||||
query: str,
|
||||
app: dict,
|
||||
current_user_id: int | None = None,
|
||||
tenant_code: str | None = None,
|
||||
user_area: str | None = None,
|
||||
attachment_id: str | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
self._task_events[task_id] = []
|
||||
self._task_done[task_id] = False
|
||||
@@ -964,6 +1172,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id=message_id,
|
||||
query=query,
|
||||
app=app,
|
||||
current_user_id=current_user_id,
|
||||
tenant_code=tenant_code,
|
||||
user_area=user_area,
|
||||
attachment_id=attachment_id,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
)
|
||||
self._message_tasks[task_id] = task
|
||||
@@ -976,6 +1189,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id: str,
|
||||
query: str,
|
||||
app: dict,
|
||||
current_user_id: int | None = None,
|
||||
tenant_code: str | None = None,
|
||||
user_area: str | None = None,
|
||||
attachment_id: str | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
context_chunks: list[dict] = []
|
||||
dataset_name = ""
|
||||
@@ -987,7 +1205,46 @@ class RagChatServiceImpl(IRagChatService):
|
||||
last_persisted_at = time.monotonic()
|
||||
|
||||
try:
|
||||
context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query)
|
||||
attachment_chunks: list[dict] = []
|
||||
attachment_names: list[str] = []
|
||||
active_attachment_ids = self._normalize_attachment_ids(
|
||||
attachment_id=attachment_id,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
if not active_attachment_ids and current_user_id is not None:
|
||||
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
|
||||
CurrentUserId=current_user_id,
|
||||
TenantCode=tenant_code,
|
||||
UserArea=user_area,
|
||||
ConversationId=conversation_id,
|
||||
)
|
||||
if active_attachment_ids:
|
||||
if current_user_id is None:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件缺少用户上下文")
|
||||
for active_attachment_id in active_attachment_ids:
|
||||
next_chunks, next_attachment_name = await self._retrieve_attachment_context(
|
||||
CurrentUserId=current_user_id,
|
||||
TenantCode=tenant_code,
|
||||
UserArea=user_area,
|
||||
ConversationId=conversation_id,
|
||||
AttachmentId=active_attachment_id,
|
||||
Query=query,
|
||||
)
|
||||
attachment_chunks.extend(next_chunks)
|
||||
if next_attachment_name:
|
||||
attachment_names.append(next_attachment_name)
|
||||
legal_query = self._build_formal_kb_query(query=query, attachment_chunks=attachment_chunks)
|
||||
formal_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), legal_query)
|
||||
context_chunks = self._merge_context_chunks(
|
||||
attachment_chunks=attachment_chunks,
|
||||
formal_chunks=formal_chunks,
|
||||
)
|
||||
generation_dataset_name = dataset_name
|
||||
attachment_name = "、".join(dict.fromkeys(attachment_names))
|
||||
if attachment_name and dataset_name:
|
||||
generation_dataset_name = f"{attachment_name} + {dataset_name}"
|
||||
elif attachment_name:
|
||||
generation_dataset_name = attachment_name
|
||||
async for chunk in generate_stream(
|
||||
query=query,
|
||||
context_chunks=context_chunks,
|
||||
@@ -997,7 +1254,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
model=app.get("llm_model") or "",
|
||||
temperature=app.get("temperature"),
|
||||
max_tokens=app.get("max_tokens"),
|
||||
dataset_name=dataset_name,
|
||||
dataset_name=generation_dataset_name,
|
||||
task_id=task_id,
|
||||
):
|
||||
data = self._parse_sse_event(chunk)
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
|
||||
RagChatAttachmentDeleteVO,
|
||||
RagChatAttachmentVO,
|
||||
)
|
||||
|
||||
|
||||
class IRagChatAttachmentService(ABC):
|
||||
@abstractmethod
|
||||
async def CreateAttachment(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
TenantCode: str | None,
|
||||
TenantName: str | None,
|
||||
ConversationId: str | None,
|
||||
AppId: int | None,
|
||||
FileName: str,
|
||||
ContentType: str | None,
|
||||
Content: bytes,
|
||||
) -> RagChatAttachmentVO: ...
|
||||
|
||||
@abstractmethod
|
||||
async def GetAttachment(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
TenantCode: str | None,
|
||||
TenantName: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentId: str,
|
||||
) -> RagChatAttachmentVO: ...
|
||||
|
||||
@abstractmethod
|
||||
async def DeleteAttachment(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
TenantCode: str | None,
|
||||
TenantName: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentId: str,
|
||||
) -> RagChatAttachmentDeleteVO: ...
|
||||
|
||||
@abstractmethod
|
||||
async def ValidateAttachmentForChat(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentId: str,
|
||||
UserArea: str | None = None,
|
||||
) -> dict: ...
|
||||
|
||||
@abstractmethod
|
||||
async def RetrieveAttachmentContext(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentId: str,
|
||||
Query: str,
|
||||
TopK: int = 5,
|
||||
UserArea: str | None = None,
|
||||
) -> tuple[list[dict], str]: ...
|
||||
|
||||
@abstractmethod
|
||||
async def ResolveActiveAttachmentIdForConversation(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
UserArea: str | None = None,
|
||||
) -> str | None: ...
|
||||
|
||||
@abstractmethod
|
||||
async def ResolveActiveAttachmentIdsForConversation(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
UserArea: str | None = None,
|
||||
) -> list[str]: ...
|
||||
@@ -50,6 +50,8 @@ class IRagChatService(ABC):
|
||||
Query: str,
|
||||
ConversationId: str | None,
|
||||
AppId: int | None,
|
||||
AttachmentId: str | None = None,
|
||||
AttachmentIds: list[str] | None = None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> AsyncGenerator[bytes, None]: ...
|
||||
|
||||
@@ -22,6 +22,7 @@ dependencies = [
|
||||
"pyjwt>=2.10.0",
|
||||
"openai>=1.30.0",
|
||||
"pillow>=11.0.0",
|
||||
"openpyxl>=3.1.0",
|
||||
"pyyaml>=6.0",
|
||||
"minio>=7.2.8",
|
||||
"leaudit",
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
CREATE TABLE IF NOT EXISTS rag_chat_attachment (
|
||||
id BIGSERIAL PRIMARY KEY,
|
||||
attachment_id VARCHAR(64) NOT NULL UNIQUE,
|
||||
conversation_id VARCHAR(64) NOT NULL,
|
||||
user_id BIGINT NOT NULL,
|
||||
tenant_code VARCHAR(64) NULL,
|
||||
area VARCHAR(255) NULL,
|
||||
filename VARCHAR(512) NOT NULL,
|
||||
original_name VARCHAR(512) NOT NULL,
|
||||
content_type VARCHAR(255) NULL,
|
||||
file_size BIGINT NOT NULL DEFAULT 0,
|
||||
minio_path TEXT NULL,
|
||||
collection_name VARCHAR(160) NOT NULL,
|
||||
indexing_status VARCHAR(32) NOT NULL DEFAULT 'waiting',
|
||||
indexing_error TEXT NULL,
|
||||
chunk_count INTEGER NOT NULL DEFAULT 0,
|
||||
indexing_started_at TIMESTAMPTZ NULL,
|
||||
indexing_completed_at TIMESTAMPTZ NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
deleted_at TIMESTAMPTZ NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_scope
|
||||
ON rag_chat_attachment(tenant_code, user_id, conversation_id, attachment_id)
|
||||
WHERE deleted_at IS NULL;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_expires
|
||||
ON rag_chat_attachment(expires_at)
|
||||
WHERE deleted_at IS NULL;
|
||||
@@ -0,0 +1,372 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
|
||||
|
||||
def _service() -> RagChatAttachmentServiceImpl:
|
||||
return RagChatAttachmentServiceImpl(chroma_client=None, embed_texts=lambda texts, model_name="": [[0.1] for _ in texts])
|
||||
|
||||
|
||||
def test_default_expiry_is_seven_days_from_now():
|
||||
service = _service()
|
||||
now = datetime(2026, 5, 25, 12, 0, 0, tzinfo=timezone.utc)
|
||||
|
||||
expires_at = service._default_expires_at(now)
|
||||
|
||||
assert expires_at == now + timedelta(days=7)
|
||||
|
||||
|
||||
def test_collection_name_contains_isolation_components():
|
||||
service = _service()
|
||||
|
||||
collection_name = service.BuildCollectionName(
|
||||
TenantCode="gd-tobacco",
|
||||
UserId=42,
|
||||
ConversationId="conversation-abc-123",
|
||||
AttachmentId="attach-xyz",
|
||||
)
|
||||
|
||||
assert collection_name.startswith("chat_attachment_gd_tobacco_42_")
|
||||
assert collection_name.endswith("_attach_xyz")
|
||||
assert "conversation-abc-123" not in collection_name
|
||||
assert len(collection_name) <= 120
|
||||
|
||||
|
||||
def test_validate_attachment_scope_rejects_other_user_and_conversation():
|
||||
service = _service()
|
||||
record = {
|
||||
"attachment_id": "att-1",
|
||||
"tenant_code": "tenant-a",
|
||||
"user_id": 100,
|
||||
"conversation_id": "conv-a",
|
||||
"indexing_status": "completed",
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
|
||||
"deleted_at": None,
|
||||
}
|
||||
|
||||
with pytest.raises(LeauditException) as user_exc:
|
||||
service._assert_attachment_scope(
|
||||
record,
|
||||
CurrentUserId=101,
|
||||
TenantCode="tenant-a",
|
||||
ConversationId="conv-a",
|
||||
RequireCompleted=True,
|
||||
Now=datetime.now(timezone.utc),
|
||||
)
|
||||
assert user_exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
|
||||
with pytest.raises(LeauditException) as conversation_exc:
|
||||
service._assert_attachment_scope(
|
||||
record,
|
||||
CurrentUserId=100,
|
||||
TenantCode="tenant-a",
|
||||
ConversationId="conv-b",
|
||||
RequireCompleted=True,
|
||||
Now=datetime.now(timezone.utc),
|
||||
)
|
||||
assert conversation_exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_get_attachment_requires_request_conversation_when_provided():
|
||||
service = _service()
|
||||
record = {
|
||||
"attachment_id": "att-1",
|
||||
"tenant_code": "tenant-a",
|
||||
"user_id": 100,
|
||||
"conversation_id": "conv-a",
|
||||
"indexing_status": "completed",
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
|
||||
"deleted_at": None,
|
||||
}
|
||||
|
||||
with pytest.raises(LeauditException) as exc:
|
||||
service._assert_attachment_scope(
|
||||
record,
|
||||
CurrentUserId=100,
|
||||
TenantCode="tenant-a",
|
||||
ConversationId="conv-b",
|
||||
RequireCompleted=False,
|
||||
Now=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
assert "当前会话" in exc.value.message
|
||||
|
||||
|
||||
def test_get_attachment_rejects_same_user_attachment_from_another_conversation(monkeypatch):
|
||||
service = _service()
|
||||
record = {
|
||||
"attachment_id": "att-1",
|
||||
"tenant_code": "tenant-a",
|
||||
"user_id": 100,
|
||||
"conversation_id": "conv-a",
|
||||
"indexing_status": "completed",
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
|
||||
"deleted_at": None,
|
||||
}
|
||||
|
||||
async def fake_get_attachment_record(_attachment_id):
|
||||
return record
|
||||
|
||||
monkeypatch.setattr(service, "_get_attachment_record", fake_get_attachment_record)
|
||||
|
||||
with pytest.raises(LeauditException) as exc:
|
||||
asyncio.run(
|
||||
service.GetAttachment(
|
||||
CurrentUserId=100,
|
||||
UserArea=None,
|
||||
UserRole=None,
|
||||
TenantCode="tenant-a",
|
||||
TenantName=None,
|
||||
ConversationId="conv-b",
|
||||
AttachmentId="att-1",
|
||||
)
|
||||
)
|
||||
|
||||
assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_delete_attachment_rejects_same_user_attachment_from_another_conversation(monkeypatch):
|
||||
service = _service()
|
||||
record = {
|
||||
"attachment_id": "att-1",
|
||||
"tenant_code": "tenant-a",
|
||||
"user_id": 100,
|
||||
"conversation_id": "conv-a",
|
||||
"indexing_status": "completed",
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
|
||||
"deleted_at": None,
|
||||
}
|
||||
|
||||
async def fake_get_attachment_record(_attachment_id):
|
||||
return record
|
||||
|
||||
monkeypatch.setattr(service, "_get_attachment_record", fake_get_attachment_record)
|
||||
|
||||
with pytest.raises(LeauditException) as exc:
|
||||
asyncio.run(
|
||||
service.DeleteAttachment(
|
||||
CurrentUserId=100,
|
||||
UserArea=None,
|
||||
UserRole=None,
|
||||
TenantCode="tenant-a",
|
||||
TenantName=None,
|
||||
ConversationId="conv-b",
|
||||
AttachmentId="att-1",
|
||||
)
|
||||
)
|
||||
|
||||
assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_validate_attachment_scope_rejects_other_tenant():
|
||||
service = _service()
|
||||
record = {
|
||||
"attachment_id": "att-1",
|
||||
"tenant_code": "tenant-a",
|
||||
"user_id": 100,
|
||||
"conversation_id": "conv-a",
|
||||
"indexing_status": "completed",
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
|
||||
"deleted_at": None,
|
||||
}
|
||||
|
||||
with pytest.raises(LeauditException) as exc:
|
||||
service._assert_attachment_scope(
|
||||
record,
|
||||
CurrentUserId=100,
|
||||
TenantCode="tenant-b",
|
||||
ConversationId="conv-a",
|
||||
RequireCompleted=True,
|
||||
Now=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_validate_attachment_scope_rejects_expired_or_incomplete_attachment():
|
||||
service = _service()
|
||||
expired = {
|
||||
"attachment_id": "att-1",
|
||||
"tenant_code": "tenant-a",
|
||||
"user_id": 100,
|
||||
"conversation_id": "conv-a",
|
||||
"indexing_status": "completed",
|
||||
"expires_at": datetime.now(timezone.utc) - timedelta(seconds=1),
|
||||
"deleted_at": None,
|
||||
}
|
||||
waiting = {
|
||||
**expired,
|
||||
"indexing_status": "indexing",
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(days=1),
|
||||
}
|
||||
|
||||
with pytest.raises(LeauditException) as expired_exc:
|
||||
service._assert_attachment_scope(
|
||||
expired,
|
||||
CurrentUserId=100,
|
||||
TenantCode="tenant-a",
|
||||
ConversationId="conv-a",
|
||||
RequireCompleted=True,
|
||||
Now=datetime.now(timezone.utc),
|
||||
)
|
||||
assert expired_exc.value.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
|
||||
assert "已过期" in expired_exc.value.message
|
||||
|
||||
with pytest.raises(LeauditException) as waiting_exc:
|
||||
service._assert_attachment_scope(
|
||||
waiting,
|
||||
CurrentUserId=100,
|
||||
TenantCode="tenant-a",
|
||||
ConversationId="conv-a",
|
||||
RequireCompleted=True,
|
||||
Now=datetime.now(timezone.utc),
|
||||
)
|
||||
assert waiting_exc.value.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
|
||||
assert "尚未完成" in waiting_exc.value.message
|
||||
|
||||
|
||||
def test_build_chunks_includes_isolation_metadata():
|
||||
service = _service()
|
||||
|
||||
chunks = service.BuildChunks(
|
||||
AttachmentId="att-1",
|
||||
TenantCode="tenant-a",
|
||||
UserId=100,
|
||||
ConversationId="conv-a",
|
||||
FileName="处罚材料.docx",
|
||||
PageTexts=[(1, "第一段违法事实。\n\n第二段处罚线索。")],
|
||||
ChunkMaxSize=20,
|
||||
ChunkOverlap=0,
|
||||
)
|
||||
|
||||
assert chunks
|
||||
metadata = chunks[0]["metadata"]
|
||||
assert metadata["tenant_code"] == "tenant-a"
|
||||
assert metadata["user_id"] == "100"
|
||||
assert metadata["conversation_id"] == "conv-a"
|
||||
assert metadata["attachment_id"] == "att-1"
|
||||
assert metadata["source_scope"] == "chat_attachment"
|
||||
assert metadata["document_name"] == "处罚材料.docx"
|
||||
assert metadata["page"] == 1
|
||||
|
||||
|
||||
def test_resolve_active_attachment_id_uses_user_conversation_tenant_and_completed_state(monkeypatch):
|
||||
service = _service()
|
||||
captured_sql = {}
|
||||
captured_params = {}
|
||||
|
||||
class FakeResult:
|
||||
def mappings(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return {"attachment_id": "att-active"}
|
||||
|
||||
def all(self):
|
||||
return [{"attachment_id": "att-active"}]
|
||||
|
||||
class FakeSession:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def execute(self, statement, params=None):
|
||||
captured_sql["value"] = str(statement)
|
||||
captured_params.update(params or {})
|
||||
return FakeResult()
|
||||
|
||||
class FakeSessionFactory:
|
||||
def __call__(self):
|
||||
return FakeSession()
|
||||
|
||||
service.__class__._attachment_schema_checked = True
|
||||
monkeypatch.setattr(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl.GetAsyncSession",
|
||||
FakeSessionFactory(),
|
||||
)
|
||||
|
||||
attachment_id = asyncio.run(
|
||||
service.ResolveActiveAttachmentIdForConversation(
|
||||
CurrentUserId=100,
|
||||
TenantCode="tenant-a",
|
||||
UserArea="云浮",
|
||||
ConversationId="conv-a",
|
||||
)
|
||||
)
|
||||
|
||||
assert attachment_id == "att-active"
|
||||
assert "user_id = :user_id" in captured_sql["value"]
|
||||
assert "conversation_id = :conversation_id" in captured_sql["value"]
|
||||
assert "indexing_status = 'completed'" in captured_sql["value"]
|
||||
assert "expires_at > NOW()" in captured_sql["value"]
|
||||
assert "deleted_at IS NULL" in captured_sql["value"]
|
||||
assert captured_params == {
|
||||
"user_id": 100,
|
||||
"conversation_id": "conv-a",
|
||||
"tenant_code": "tenant-a",
|
||||
"user_area": "云浮",
|
||||
}
|
||||
|
||||
|
||||
def test_resolve_active_attachment_ids_returns_all_completed_conversation_attachments(monkeypatch):
|
||||
service = _service()
|
||||
captured_sql = {}
|
||||
captured_params = {}
|
||||
|
||||
class FakeResult:
|
||||
def mappings(self):
|
||||
return self
|
||||
|
||||
def all(self):
|
||||
return [{"attachment_id": "att-1"}, {"attachment_id": "att-2"}]
|
||||
|
||||
class FakeSession:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
async def execute(self, statement, params=None):
|
||||
captured_sql["value"] = str(statement)
|
||||
captured_params.update(params or {})
|
||||
return FakeResult()
|
||||
|
||||
class FakeSessionFactory:
|
||||
def __call__(self):
|
||||
return FakeSession()
|
||||
|
||||
service.__class__._attachment_schema_checked = True
|
||||
monkeypatch.setattr(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl.GetAsyncSession",
|
||||
FakeSessionFactory(),
|
||||
)
|
||||
|
||||
attachment_ids = asyncio.run(
|
||||
service.ResolveActiveAttachmentIdsForConversation(
|
||||
CurrentUserId=100,
|
||||
TenantCode="tenant-a",
|
||||
UserArea="云浮",
|
||||
ConversationId="conv-a",
|
||||
)
|
||||
)
|
||||
|
||||
assert attachment_ids == ["att-1", "att-2"]
|
||||
assert "ORDER BY indexing_completed_at DESC NULLS LAST, created_at DESC" in captured_sql["value"]
|
||||
assert "LIMIT 1" not in captured_sql["value"]
|
||||
assert captured_params == {
|
||||
"user_id": 100,
|
||||
"conversation_id": "conv-a",
|
||||
"tenant_code": "tenant-a",
|
||||
"user_area": "云浮",
|
||||
}
|
||||
@@ -79,6 +79,232 @@ async def _run_streaming_task() -> list[dict]:
|
||||
return service._task_events[task_id]
|
||||
|
||||
|
||||
async def _run_streaming_task_with_attachment() -> tuple[list[dict], list[dict]]:
|
||||
service = RagChatServiceImpl()
|
||||
task_id = "task-attachment-test"
|
||||
service._task_events[task_id] = []
|
||||
service._task_done[task_id] = False
|
||||
service._task_locks[task_id] = asyncio.Lock()
|
||||
captured_context_chunks: list[dict] = []
|
||||
|
||||
async def fake_retrieve_context(dataset_id, query):
|
||||
return [
|
||||
{
|
||||
"dataset_id": dataset_id,
|
||||
"document_id": "law-doc-1",
|
||||
"document_name": "处罚依据.md",
|
||||
"source": "处罚依据.md",
|
||||
"id": "law-segment-1",
|
||||
"score": 0.88,
|
||||
"hit_count": 2,
|
||||
"text": "正式知识库中的处罚依据",
|
||||
}
|
||||
], "正式法规知识库"
|
||||
|
||||
async def fake_retrieve_attachment_context(**kwargs):
|
||||
return [
|
||||
{
|
||||
"attachment_id": kwargs["AttachmentId"],
|
||||
"document_name": "用户上传.docx",
|
||||
"source": "用户上传.docx",
|
||||
"id": "attachment-segment-1",
|
||||
"score": 0.96,
|
||||
"text": "上传文档中的违法事实",
|
||||
"source_scope": "chat_attachment",
|
||||
"data_source_type": "chat_attachment",
|
||||
}
|
||||
], "用户上传.docx"
|
||||
|
||||
async def fake_generate_stream(**kwargs):
|
||||
captured_context_chunks.extend(kwargs["context_chunks"])
|
||||
async for event in _fake_generate_stream(**kwargs):
|
||||
yield event
|
||||
|
||||
async def noop_async(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service._retrieve_context = fake_retrieve_context
|
||||
service._retrieve_attachment_context = fake_retrieve_attachment_context
|
||||
service._finalize_message_record = noop_async
|
||||
service._maybe_schedule_auto_title = noop_async
|
||||
|
||||
with (
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
|
||||
fake_generate_stream,
|
||||
),
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
await service._run_message_task(
|
||||
task_id=task_id,
|
||||
conversation_id="conversation-test",
|
||||
message_id="message-test",
|
||||
query="这份材料的违法内容会受到什么处罚",
|
||||
app={"dataset_id": 7},
|
||||
current_user_id=100,
|
||||
tenant_code="tenant-a",
|
||||
attachment_id="attachment-1",
|
||||
)
|
||||
|
||||
return service._task_events[task_id], captured_context_chunks
|
||||
|
||||
|
||||
async def _run_streaming_task_with_conversation_attachment_fallback() -> list[dict]:
|
||||
service = RagChatServiceImpl()
|
||||
task_id = "task-attachment-fallback-test"
|
||||
service._task_events[task_id] = []
|
||||
service._task_done[task_id] = False
|
||||
service._task_locks[task_id] = asyncio.Lock()
|
||||
captured_context_chunks: list[dict] = []
|
||||
|
||||
async def fake_resolve_attachment_ids_for_conversation(**kwargs):
|
||||
assert kwargs["ConversationId"] == "conversation-test"
|
||||
assert kwargs["CurrentUserId"] == 100
|
||||
return ["attachment-1"]
|
||||
|
||||
async def fake_retrieve_context(dataset_id, query):
|
||||
return [
|
||||
{
|
||||
"dataset_id": dataset_id,
|
||||
"document_id": "law-doc-1",
|
||||
"document_name": "处罚依据.md",
|
||||
"source": "处罚依据.md",
|
||||
"id": "law-segment-1",
|
||||
"score": 0.88,
|
||||
"hit_count": 2,
|
||||
"text": "正式知识库中的处罚依据",
|
||||
}
|
||||
], "正式法规知识库"
|
||||
|
||||
async def fake_retrieve_attachment_context(**kwargs):
|
||||
return [
|
||||
{
|
||||
"attachment_id": kwargs["AttachmentId"],
|
||||
"document_name": "用户上传.docx",
|
||||
"source": "用户上传.docx",
|
||||
"id": "attachment-segment-1",
|
||||
"score": 0.96,
|
||||
"text": "上传文档中的违法事实",
|
||||
"source_scope": "chat_attachment",
|
||||
"data_source_type": "chat_attachment",
|
||||
}
|
||||
], "用户上传.docx"
|
||||
|
||||
async def fake_generate_stream(**kwargs):
|
||||
captured_context_chunks.extend(kwargs["context_chunks"])
|
||||
async for event in _fake_generate_stream(**kwargs):
|
||||
yield event
|
||||
|
||||
async def noop_async(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service._resolve_attachment_ids_for_conversation = fake_resolve_attachment_ids_for_conversation
|
||||
service._retrieve_context = fake_retrieve_context
|
||||
service._retrieve_attachment_context = fake_retrieve_attachment_context
|
||||
service._finalize_message_record = noop_async
|
||||
service._maybe_schedule_auto_title = noop_async
|
||||
|
||||
with (
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
|
||||
fake_generate_stream,
|
||||
),
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
await service._run_message_task(
|
||||
task_id=task_id,
|
||||
conversation_id="conversation-test",
|
||||
message_id="message-test",
|
||||
query="江小妹违法了什么",
|
||||
app={"dataset_id": 7},
|
||||
current_user_id=100,
|
||||
tenant_code="tenant-a",
|
||||
attachment_id=None,
|
||||
)
|
||||
|
||||
return captured_context_chunks
|
||||
|
||||
|
||||
async def _run_streaming_task_with_multiple_attachments() -> tuple[list[dict], list[dict]]:
|
||||
service = RagChatServiceImpl()
|
||||
task_id = "task-multiple-attachments-test"
|
||||
service._task_events[task_id] = []
|
||||
service._task_done[task_id] = False
|
||||
service._task_locks[task_id] = asyncio.Lock()
|
||||
captured_context_chunks: list[dict] = []
|
||||
|
||||
async def fake_retrieve_context(dataset_id, query):
|
||||
return [
|
||||
{
|
||||
"dataset_id": dataset_id,
|
||||
"document_id": "law-doc-1",
|
||||
"document_name": "处罚依据.md",
|
||||
"source": "处罚依据.md",
|
||||
"id": "law-segment-1",
|
||||
"score": 0.88,
|
||||
"hit_count": 2,
|
||||
"text": "正式知识库中的处罚依据",
|
||||
}
|
||||
], "正式法规知识库"
|
||||
|
||||
async def fake_retrieve_attachment_context(**kwargs):
|
||||
attachment_id = kwargs["AttachmentId"]
|
||||
return [
|
||||
{
|
||||
"attachment_id": attachment_id,
|
||||
"document_name": f"{attachment_id}.docx",
|
||||
"source": f"{attachment_id}.docx",
|
||||
"id": f"{attachment_id}-segment-1",
|
||||
"score": 0.96,
|
||||
"text": f"{attachment_id} 中的违法事实",
|
||||
"source_scope": "chat_attachment",
|
||||
"data_source_type": "chat_attachment",
|
||||
}
|
||||
], f"{attachment_id}.docx"
|
||||
|
||||
async def fake_generate_stream(**kwargs):
|
||||
captured_context_chunks.extend(kwargs["context_chunks"])
|
||||
async for event in _fake_generate_stream(**kwargs):
|
||||
yield event
|
||||
|
||||
async def noop_async(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service._retrieve_context = fake_retrieve_context
|
||||
service._retrieve_attachment_context = fake_retrieve_attachment_context
|
||||
service._finalize_message_record = noop_async
|
||||
service._maybe_schedule_auto_title = noop_async
|
||||
|
||||
with (
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
|
||||
fake_generate_stream,
|
||||
),
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
await service._run_message_task(
|
||||
task_id=task_id,
|
||||
conversation_id="conversation-test",
|
||||
message_id="message-test",
|
||||
query="这些材料的违法内容会受到什么处罚",
|
||||
app={"dataset_id": 7},
|
||||
current_user_id=100,
|
||||
tenant_code="tenant-a",
|
||||
attachment_ids=["attachment-1", "attachment-2"],
|
||||
)
|
||||
|
||||
return service._task_events[task_id], captured_context_chunks
|
||||
|
||||
|
||||
def test_streaming_message_end_includes_retriever_resources():
|
||||
events = asyncio.run(_run_streaming_task())
|
||||
|
||||
@@ -89,3 +315,87 @@ def test_streaming_message_end_includes_retriever_resources():
|
||||
assert resources[0]["dataset_id"] == "7"
|
||||
assert resources[0]["dataset_name"] == "测试知识库"
|
||||
assert resources[0]["document_name"] == "引用文档.pdf"
|
||||
|
||||
|
||||
def test_message_task_merges_attachment_facts_and_formal_kb_context():
|
||||
events, context_chunks = asyncio.run(_run_streaming_task_with_attachment())
|
||||
|
||||
message_end = next(event for event in events if event.get("event") == "message_end")
|
||||
resources = message_end["metadata"].get("retriever_resources")
|
||||
|
||||
scopes = [chunk.get("source_scope") for chunk in context_chunks]
|
||||
assert scopes == ["chat_attachment", "formal_kb"]
|
||||
assert context_chunks[0]["document_name"] == "用户上传.docx"
|
||||
assert context_chunks[1]["document_name"] == "处罚依据.md"
|
||||
assert resources[0]["data_source_type"] == "chat_attachment"
|
||||
assert resources[1]["data_source_type"] == "formal_kb"
|
||||
|
||||
|
||||
def test_message_task_uses_active_conversation_attachment_when_request_omits_attachment_id():
|
||||
context_chunks = asyncio.run(_run_streaming_task_with_conversation_attachment_fallback())
|
||||
|
||||
scopes = [chunk.get("source_scope") for chunk in context_chunks]
|
||||
assert scopes == ["chat_attachment", "formal_kb"]
|
||||
assert context_chunks[0]["attachment_id"] == "attachment-1"
|
||||
|
||||
|
||||
def test_message_task_merges_multiple_attachment_contexts_before_formal_kb():
|
||||
events, context_chunks = asyncio.run(_run_streaming_task_with_multiple_attachments())
|
||||
|
||||
message_end = next(event for event in events if event.get("event") == "message_end")
|
||||
resources = message_end["metadata"].get("retriever_resources")
|
||||
|
||||
scopes = [chunk.get("source_scope") for chunk in context_chunks]
|
||||
assert scopes == ["chat_attachment", "chat_attachment", "formal_kb"]
|
||||
assert [chunk.get("attachment_id") for chunk in context_chunks[:2]] == ["attachment-1", "attachment-2"]
|
||||
assert resources[0]["data_source_type"] == "chat_attachment"
|
||||
assert resources[1]["data_source_type"] == "chat_attachment"
|
||||
assert resources[2]["data_source_type"] == "formal_kb"
|
||||
|
||||
|
||||
def test_message_attachment_metadata_is_stored_for_history_display():
|
||||
service = RagChatServiceImpl()
|
||||
|
||||
files = service._build_user_message_files(
|
||||
[
|
||||
{
|
||||
"attachment_id": "att-doc",
|
||||
"original_name": "处罚材料.docx",
|
||||
"content_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"file_size": 1234,
|
||||
},
|
||||
{
|
||||
"attachment_id": "att-img",
|
||||
"original_name": "现场照片.png",
|
||||
"content_type": "image/png",
|
||||
"file_size": 4567,
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
assert files == [
|
||||
{
|
||||
"id": "att-doc",
|
||||
"upload_file_id": "att-doc",
|
||||
"name": "处罚材料.docx",
|
||||
"fileName": "处罚材料.docx",
|
||||
"type": "file",
|
||||
"transfer_method": "local_file",
|
||||
"contentType": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"fileSize": 1234,
|
||||
"belongs_to": "user",
|
||||
"usage": "temporary_attachment",
|
||||
},
|
||||
{
|
||||
"id": "att-img",
|
||||
"upload_file_id": "att-img",
|
||||
"name": "现场照片.png",
|
||||
"fileName": "现场照片.png",
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"contentType": "image/png",
|
||||
"fileSize": 4567,
|
||||
"belongs_to": "user",
|
||||
"usage": "temporary_attachment",
|
||||
},
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user