feat(rag): add temporary chat attachments
This commit is contained in:
@@ -30,6 +30,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
||||
RagMessagePageVO,
|
||||
RagOperationResultVO,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
|
||||
RagChatAttachmentDeleteVO,
|
||||
RagChatAttachmentVO,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
||||
RagDatasetBatchDeleteResultVO,
|
||||
RagDatasetDetailVO,
|
||||
@@ -43,8 +47,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.permissionServiceImpl import PermissionServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
|
||||
from fastapi_modules.fastapi_leaudit.services.ragChatAttachmentService import IRagChatAttachmentService
|
||||
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
|
||||
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
|
||||
|
||||
@@ -76,6 +82,7 @@ class RagChatController(BaseController):
|
||||
def __init__(self):
|
||||
super().__init__(prefix="/v3/rag", tags=["RAG 聊天"])
|
||||
self.RagChatService: IRagChatService = RagChatServiceImpl()
|
||||
self.RagChatAttachmentService: IRagChatAttachmentService = RagChatAttachmentServiceImpl()
|
||||
self.RagDatasetService: IRagDatasetService = RagDatasetServiceImpl()
|
||||
self.PermissionService: IPermissionService = PermissionServiceImpl()
|
||||
|
||||
@@ -484,6 +491,8 @@ class RagChatController(BaseController):
|
||||
Query=Body.query,
|
||||
ConversationId=Body.conversationId,
|
||||
AppId=Body.appId,
|
||||
AttachmentId=Body.attachmentId,
|
||||
AttachmentIds=Body.attachmentIds,
|
||||
)
|
||||
return StreamingResponse(
|
||||
stream,
|
||||
@@ -491,6 +500,62 @@ class RagChatController(BaseController):
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
@self.router.post("/chat/attachments", response_model=Result[RagChatAttachmentVO])
async def UploadChatAttachment(
    file: UploadFile = File(...),
    conversation_id: str | None = Form(None),
    app_id: int | None = Form(None),
    payload: dict[str, Any] = Depends(verify_access_token),
):
    # Upload one file as a temporary chat attachment.
    # Callers without the "chat_use" permission get a 403 JSON body; otherwise the
    # attachment service result is wrapped in Result.success.
    current_user_id = int(payload["user_id"])
    permitted = await self._check_permission(current_user_id, [self._PERMISSIONS["chat_use"]])
    if not permitted:
        denial = {"code": 403, "msg": "当前用户没有上传聊天附件权限", "data": None}
        return JSONResponse(status_code=403, content=denial)

    scope = self._tenant_context(payload)
    # NOTE(review): the whole upload is read into memory; no size cap is visible in this handler.
    raw_bytes = await file.read()
    created = await self.RagChatAttachmentService.CreateAttachment(
        CurrentUserId=current_user_id,
        **scope,
        ConversationId=conversation_id,
        AppId=app_id,
        FileName=file.filename or "attachment",
        ContentType=file.content_type,
        Content=raw_bytes,
    )
    return Result.success(data=created)
|
||||
|
||||
@self.router.get("/chat/attachments/{AttachmentId}", response_model=Result[RagChatAttachmentVO])
async def GetChatAttachment(
    AttachmentId: str,
    conversation_id: str = Query(..., description="附件所属会话ID"),
    payload: dict[str, Any] = Depends(verify_access_token),
):
    # Return the stored state of one temporary attachment of the given conversation.
    # Callers without the "chat_use" permission get a 403 JSON body.
    current_user_id = int(payload["user_id"])
    permitted = await self._check_permission(current_user_id, [self._PERMISSIONS["chat_use"]])
    if not permitted:
        denial = {"code": 403, "msg": "当前用户没有查看聊天附件权限", "data": None}
        return JSONResponse(status_code=403, content=denial)

    scope = self._tenant_context(payload)
    attachment = await self.RagChatAttachmentService.GetAttachment(
        CurrentUserId=current_user_id,
        **scope,
        ConversationId=conversation_id,
        AttachmentId=AttachmentId,
    )
    return Result.success(data=attachment)
|
||||
|
||||
@self.router.delete("/chat/attachments/{AttachmentId}", response_model=Result[RagChatAttachmentDeleteVO])
async def DeleteChatAttachment(
    AttachmentId: str,
    conversation_id: str = Query(..., description="附件所属会话ID"),
    payload: dict[str, Any] = Depends(verify_access_token),
):
    # Delete one temporary attachment of the given conversation via the attachment service.
    # Callers without the "chat_use" permission get a 403 JSON body.
    current_user_id = int(payload["user_id"])
    permitted = await self._check_permission(current_user_id, [self._PERMISSIONS["chat_use"]])
    if not permitted:
        denial = {"code": 403, "msg": "当前用户没有删除聊天附件权限", "data": None}
        return JSONResponse(status_code=403, content=denial)

    scope = self._tenant_context(payload)
    outcome = await self.RagChatAttachmentService.DeleteAttachment(
        CurrentUserId=current_user_id,
        **scope,
        ConversationId=conversation_id,
        AttachmentId=AttachmentId,
    )
    return Result.success(data=outcome)
|
||||
|
||||
@self.router.post("/chat/messages/{MessageId}/stop", response_model=Result[RagOperationResultVO])
|
||||
async def StopMessage(
|
||||
MessageId: str,
|
||||
|
||||
@@ -5,6 +5,8 @@ class RagChatSendMessageDTO(BaseModel):
|
||||
query: str = Field(..., min_length=1, description="用户问题")
|
||||
conversationId: str | None = Field(None, description="会话ID")
|
||||
appId: int | None = Field(None, description="聊天应用ID")
|
||||
attachmentId: str | None = Field(None, description="临时聊天附件ID")
|
||||
attachmentIds: list[str] = Field(default_factory=list, description="临时聊天附件ID列表")
|
||||
|
||||
|
||||
class RagConversationRenameDTO(BaseModel):
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# Response model describing one temporary chat attachment.
# Only attachmentId, conversationId and fileName are required; the rest have defaults.
class RagChatAttachmentVO(BaseModel):
    # Temporary attachment ID.
    attachmentId: str = Field(..., description="临时附件ID")
    # Conversation ID the attachment belongs to.
    conversationId: str = Field(..., description="会话ID")
    # Original file name.
    fileName: str = Field(..., description="原始文件名")
    # File MIME type; empty string by default.
    contentType: str = Field("", description="文件 MIME 类型")
    # File size; NOTE(review): presumably bytes — confirm against the producer.
    fileSize: int = Field(0, description="文件大小")
    # Indexing status; defaults to "waiting".
    indexingStatus: str = Field("waiting", description="索引状态")
    # Indexing error text, if any.
    indexingError: str | None = Field(None, description="索引错误")
    # Number of chunks.
    chunkCount: int = Field(0, description="分段数量")
    # Name of the temporary vector collection.
    collectionName: str = Field("", description="临时向量集合名")
    # Expiry timestamp; unit is not visible here — TODO confirm.
    expiresAt: int = Field(0, description="过期时间戳")
    # Creation timestamp; unit is not visible here — TODO confirm.
    createdAt: int = Field(0, description="创建时间戳")
|
||||
|
||||
|
||||
# Response model returned after deleting a temporary chat attachment.
class RagChatAttachmentDeleteVO(BaseModel):
    # Outcome marker; defaults to "success".
    result: str = Field("success")
|
||||
@@ -36,6 +36,7 @@ class RagMessageItemVO(BaseModel):
|
||||
conversationId: str = Field(...)
|
||||
query: str = Field(...)
|
||||
answer: str = Field(...)
|
||||
messageFiles: list[dict] = Field(default_factory=list)
|
||||
feedback: dict | None = Field(None)
|
||||
retrieverResources: list[dict] | None = Field(None)
|
||||
suggestedQuestions: list[str] = Field(default_factory=list)
|
||||
|
||||
@@ -33,16 +33,34 @@ async def generate_stream(
|
||||
|
||||
max_context_chars = 8000
|
||||
if context_chunks:
|
||||
parts: list[str] = []
|
||||
attachment_parts: list[str] = []
|
||||
formal_parts: list[str] = []
|
||||
total_len = 0
|
||||
for chunk in context_chunks:
|
||||
part = f"[来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
|
||||
scope = str(chunk.get("source_scope") or chunk.get("data_source_type") or "formal_kb")
|
||||
if scope == "chat_attachment":
|
||||
label = "上传文件事实"
|
||||
else:
|
||||
label = "正式知识库依据"
|
||||
part = f"[{label} | 来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
|
||||
if total_len + len(part) > max_context_chars:
|
||||
break
|
||||
parts.append(part)
|
||||
if scope == "chat_attachment":
|
||||
attachment_parts.append(part)
|
||||
else:
|
||||
formal_parts.append(part)
|
||||
total_len += len(part)
|
||||
context_text = "\\n\\n---\\n\\n".join(parts)
|
||||
user_content = f"知识库内容:\\n{context_text}\\n\\n用户问题: {query}"
|
||||
context_sections: list[str] = []
|
||||
if attachment_parts:
|
||||
context_sections.append("【上传文件事实】\n" + "\\n\\n---\\n\\n".join(attachment_parts))
|
||||
if formal_parts:
|
||||
context_sections.append("【正式知识库依据】\n" + "\\n\\n---\\n\\n".join(formal_parts))
|
||||
context_text = "\\n\\n======\\n\\n".join(context_sections)
|
||||
user_content = (
|
||||
"请严格区分上下文来源:上传文件事实只能用于判断用户材料中的事实;正式知识库依据只能用于法律、处罚、裁量或案例依据。"
|
||||
"如果依据不足,请明确说明,不要编造。\n\n"
|
||||
f"{context_text}\\n\\n用户问题: {query}"
|
||||
)
|
||||
else:
|
||||
user_content = query
|
||||
|
||||
@@ -115,7 +133,7 @@ async def generate_stream(
|
||||
"dataset_name": dataset_name,
|
||||
"document_id": "",
|
||||
"document_name": chunk.get("source", ""),
|
||||
"data_source_type": "upload_file",
|
||||
"data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
|
||||
"segment_id": chunk.get("id", ""),
|
||||
"retriever_from": "rag",
|
||||
"score": round(chunk.get("score", 0.0), 4),
|
||||
@@ -125,6 +143,8 @@ async def generate_stream(
|
||||
"index_node_hash": "",
|
||||
"content": chunk.get("text", "")[:500],
|
||||
"page": None,
|
||||
"source_scope": chunk.get("source_scope") or "",
|
||||
"attachment_id": chunk.get("attachment_id") or "",
|
||||
}
|
||||
for i, chunk in enumerate(context_chunks)
|
||||
]
|
||||
|
||||
@@ -183,10 +183,15 @@ class RagRetriever:
|
||||
"source": meta.get("source") or meta.get("document_name") or dataset_name,
|
||||
"score": score,
|
||||
"chunk_index": int(meta.get("chunk_index") or idx),
|
||||
"document_name": document_name,
|
||||
"document_id": meta.get("document_id"),
|
||||
"page": meta.get("page"),
|
||||
}
|
||||
"document_name": document_name,
|
||||
"document_id": meta.get("document_id"),
|
||||
"page": meta.get("page"),
|
||||
"source_scope": meta.get("source_scope"),
|
||||
"attachment_id": meta.get("attachment_id"),
|
||||
"conversation_id": meta.get("conversation_id"),
|
||||
"tenant_code": meta.get("tenant_code"),
|
||||
"user_id": meta.get("user_id"),
|
||||
}
|
||||
)
|
||||
return chunks
|
||||
|
||||
@@ -285,6 +290,11 @@ class RagRetriever:
|
||||
"document_name": document_name,
|
||||
"document_id": meta.get("document_id"),
|
||||
"page": meta.get("page"),
|
||||
"source_scope": meta.get("source_scope"),
|
||||
"attachment_id": meta.get("attachment_id"),
|
||||
"conversation_id": meta.get("conversation_id"),
|
||||
"tenant_code": meta.get("tenant_code"),
|
||||
"user_id": meta.get("user_id"),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -392,10 +402,10 @@ class RagRetriever:
|
||||
{
|
||||
"position": index + 1,
|
||||
"dataset_id": str(chunk.get("dataset_id") or ""),
|
||||
"dataset_name": dataset_name,
|
||||
"dataset_name": chunk.get("dataset_name") or dataset_name,
|
||||
"document_id": str(chunk.get("document_id") or ""),
|
||||
"document_name": chunk.get("document_name") or chunk.get("source", ""),
|
||||
"data_source_type": "upload_file",
|
||||
"data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
|
||||
"segment_id": chunk.get("id", ""),
|
||||
"retriever_from": "rag",
|
||||
"score": round(float(chunk.get("score") or 0.0), 4),
|
||||
@@ -405,6 +415,8 @@ class RagRetriever:
|
||||
"index_node_hash": "",
|
||||
"content": chunk.get("text", "")[:500],
|
||||
"page": None,
|
||||
"source_scope": chunk.get("source_scope") or "",
|
||||
"attachment_id": chunk.get("attachment_id") or "",
|
||||
}
|
||||
for index, chunk in enumerate(context_chunks)
|
||||
]
|
||||
|
||||
@@ -0,0 +1,821 @@
|
||||
"""Temporary RAG attachments for chat conversations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import csv
|
||||
import hashlib
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from fastapi_common.fastapi_common_storage.oss_client import OssClient
|
||||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
|
||||
RagChatAttachmentDeleteVO,
|
||||
RagChatAttachmentVO,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import EmbedTexts, RagRetriever
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
|
||||
from fastapi_modules.fastapi_leaudit.services.ragChatAttachmentService import IRagChatAttachmentService
|
||||
|
||||
|
||||
DEFAULT_ATTACHMENT_TTL_DAYS = 7
|
||||
SUPPORTED_TEXT_SUFFIXES = {".txt", ".md", ".json", ".csv"}
|
||||
SUPPORTED_DOCUMENT_SUFFIXES = {".docx", ".pdf", ".xlsx"}
|
||||
SUPPORTED_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"}
|
||||
SUPPORTED_SUFFIXES = SUPPORTED_TEXT_SUFFIXES | SUPPORTED_DOCUMENT_SUFFIXES | SUPPORTED_IMAGE_SUFFIXES
|
||||
|
||||
|
||||
class RagChatAttachmentServiceImpl(IRagChatAttachmentService):
|
||||
"""Manage temporary, conversation-scoped RAG attachment indexes."""
|
||||
|
||||
_attachment_schema_checked = False
|
||||
_attachment_schema_lock = asyncio.Lock()
|
||||
|
||||
def __init__(
    self,
    *,
    chroma_client: Any | None = None,
    embed_texts: EmbedTexts | None = None,
    chat_service: RagChatServiceImpl | None = None,
    ocr_client_factory: Callable[[], Any] | None = None,
) -> None:
    """Wire up the collaborators used by the attachment service.

    Args:
        chroma_client: Optional vector-store client; also handed to the retriever.
        embed_texts: Optional embedding callable handed to the retriever.
        chat_service: Chat service to reuse; a new ``RagChatServiceImpl`` is
            created when this is falsy.
        ocr_client_factory: Optional zero-argument factory stored for later use.
    """
    self._chroma_client = chroma_client
    retriever_options = {
        "chroma_client": chroma_client,
        "embed_texts": embed_texts,
        "hydrate_documents": False,
    }
    self.retriever = RagRetriever(**retriever_options)
    if chat_service:
        self.chat_service = chat_service
    else:
        self.chat_service = RagChatServiceImpl()
    self.dataset_helpers = RagDatasetServiceImpl()
    self._ocr_client_factory = ocr_client_factory
|
||||
|
||||
def _default_expires_at(self, now: datetime | None = None) -> datetime:
    """Return the default expiry moment for a new attachment.

    The reference time is ``now`` when given (and truthy), otherwise the
    current UTC time; it is passed through ``_ensure_aware_datetime`` and
    ``DEFAULT_ATTACHMENT_TTL_DAYS`` days are added.
    """
    reference = now or datetime.now(timezone.utc)
    aware_reference = self._ensure_aware_datetime(reference)
    lifetime = timedelta(days=DEFAULT_ATTACHMENT_TTL_DAYS)
    return aware_reference + lifetime
|
||||
|
||||
def BuildCollectionName(self, *, TenantCode: str | None, UserId: int, ConversationId: str, AttachmentId: str) -> str:
    """Compose the vector collection name for one attachment.

    The name is ``chat_attachment_<tenant>_<user>_<conversation hash>_<attachment>``
    where tenant (``"default"`` when empty), user and attachment go through
    ``_slugify`` and the conversation is reduced to the first 16 hex digits
    of its SHA-1. The result is cut to at most 120 characters.
    """
    segments = (
        "chat_attachment",
        self._slugify(TenantCode or "default"),
        self._slugify(str(UserId)),
        hashlib.sha1(ConversationId.encode("utf-8")).hexdigest()[:16],
        self._slugify(AttachmentId),
    )
    full_name = "_".join(map(str, segments))
    return full_name[:120]
|
||||
|
||||
def BuildChunks(
    self,
    *,
    AttachmentId: str,
    TenantCode: str | None,
    UserId: int,
    ConversationId: str,
    FileName: str,
    PageTexts: list[tuple[int, str]],
    ChunkMaxSize: int = 800,
    ChunkOverlap: int = 80,
) -> list[dict[str, Any]]:
    """Turn per-page texts into chunk records with scoping metadata.

    Each page text is run through ``_preprocess_text`` and ``_split_text``
    (separator ``"\\n\\n"``). Blank pages and blank pieces are skipped.
    The chunk size is at least 50 and the overlap is clamped to the range
    0..half the chunk size.

    Returns:
        A list of dicts with ``id`` (``<attachment>:<page>:<index>``),
        ``text`` and ``metadata``.
    """
    size_limit = max(50, int(ChunkMaxSize or 800))
    overlap = max(0, min(int(ChunkOverlap or 0), size_limit // 2))
    tenant_value = str(TenantCode or "")
    user_value = str(UserId)

    records: list[dict[str, Any]] = []
    for page_no, raw_text in PageTexts:
        cleaned = self._preprocess_text(raw_text)
        if not cleaned:
            continue
        pieces = self._split_text(cleaned, separator="\n\n", max_chars=size_limit, overlap=overlap)
        for position, piece in enumerate(pieces):
            if not piece.strip():
                continue
            identifier = f"{AttachmentId}:{page_no}:{position}"
            metadata = {
                "id": identifier,
                "source": FileName,
                "document_name": FileName,
                "attachment_id": AttachmentId,
                "tenant_code": tenant_value,
                "user_id": user_value,
                "conversation_id": ConversationId,
                "source_scope": "chat_attachment",
                "page": int(page_no),
                "chunk_index": position,
            }
            records.append({"id": identifier, "text": piece, "metadata": metadata})
    return records
|
||||
|
||||
async def CreateAttachment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    ConversationId: str | None,
    AppId: int | None,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
) -> RagChatAttachmentVO:
    """Store an uploaded file and schedule its indexing in the background.

    Steps: validate name, content and suffix; resolve the chat app and the
    conversation; upload the bytes to object storage; insert a
    ``rag_chat_attachment`` row in state ``waiting``; start the indexing
    task.

    Args:
        CurrentUserId: Owner of the attachment.
        UserArea, UserRole, TenantCode, TenantName: Caller scope, forwarded
            to the chat service helpers; tenant and area are stored on the row.
        ConversationId: Conversation to attach to, passed to
            ``_ensure_conversation``.
        AppId: Chat application ID passed to ``_resolve_app``.
        FileName: Client-supplied file name; its suffix selects the parser.
        ContentType: MIME type; guessed from the file name when missing.
        Content: Raw file bytes.

    Returns:
        The freshly inserted row converted by ``_to_vo``.

    Raises:
        LeauditException: 400 for an empty name, empty content or an
            unsupported suffix; 404 when no chat app can be resolved.
    """
    if not FileName:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件名不能为空")
    if not Content:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件内容不能为空")

    suffix = Path(FileName).suffix.lower()
    if suffix not in SUPPORTED_SUFFIXES:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 TXT、MD、JSON、CSV、DOCX、PDF、XLSX 和图片文件")

    app = await self.chat_service._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
    if not app:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")

    conversation_id = await self.chat_service._ensure_conversation(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        app_id=app["id"],
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )

    attachment_id = str(uuid.uuid4())
    collection_name = self.BuildCollectionName(
        TenantCode=TenantCode,
        UserId=CurrentUserId,
        ConversationId=conversation_id,
        AttachmentId=attachment_id,
    )
    content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
    expires_at = self._default_expires_at()
    # The file name comes from the client and may carry directory parts
    # ("../x.pdf", "C:\\docs\\x.pdf"); only its last component goes into the key.
    key_file_name = FileName.replace("\\", "/").rsplit("/", 1)[-1]
    object_key = f"rag-chat-attachments/{TenantCode or 'default'}/{CurrentUserId}/{conversation_id}/{attachment_id}_{key_file_name}"
    OssClient().EnsureBucket()
    stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)

    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        row = (
            await session.execute(
                text(
                    """
                    INSERT INTO rag_chat_attachment (
                        attachment_id, conversation_id, user_id, tenant_code, area,
                        filename, original_name, content_type, file_size, minio_path,
                        collection_name, indexing_status, chunk_count, expires_at
                    ) VALUES (
                        :attachment_id, :conversation_id, :user_id, :tenant_code, :area,
                        :filename, :original_name, :content_type, :file_size, :minio_path,
                        :collection_name, 'waiting', 0, :expires_at
                    )
                    RETURNING *
                    """
                ),
                {
                    "attachment_id": attachment_id,
                    "conversation_id": conversation_id,
                    "user_id": CurrentUserId,
                    "tenant_code": TenantCode,
                    "area": UserArea,
                    "filename": f"{attachment_id}{suffix}",
                    "original_name": FileName,
                    "content_type": content_type,
                    "file_size": len(Content),
                    "minio_path": stored_key,
                    "collection_name": collection_name,
                    "expires_at": expires_at,
                },
            )
        ).mappings().first()

    indexing_task = asyncio.create_task(
        self._run_attachment_indexing_task(
            attachment_id=attachment_id,
            tenant_code=TenantCode,
            user_id=CurrentUserId,
            conversation_id=conversation_id,
            file_name=FileName,
            content=Content,
            collection_name=collection_name,
        )
    )
    # The event loop keeps only weak references to tasks, so an unreferenced
    # task may be garbage-collected before it finishes. Hold a strong
    # reference until the task is done.
    pending_tasks = getattr(self, "_indexing_tasks", None)
    if pending_tasks is None:
        pending_tasks = set()
        self._indexing_tasks = pending_tasks
    pending_tasks.add(indexing_task)
    indexing_task.add_done_callback(pending_tasks.discard)

    return self._to_vo(dict(row))
|
||||
|
||||
async def GetAttachment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    ConversationId: str,
    AttachmentId: str,
) -> RagChatAttachmentVO:
    """Return one attachment after checking it belongs to the caller's scope.

    The conversation ID is validated, the record is loaded, and
    ``_assert_attachment_scope`` is applied without requiring that indexing
    has completed. ``UserRole`` and ``TenantName`` are accepted but not used
    here.
    """
    self._validate_conversation_id(ConversationId)
    attachment = await self._get_attachment_record(AttachmentId)
    scope_arguments = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
        "RequireCompleted": False,
        "Now": datetime.now(timezone.utc),
    }
    self._assert_attachment_scope(attachment, **scope_arguments)
    return self._to_vo(attachment)
|
||||
|
||||
async def DeleteAttachment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    ConversationId: str,
    AttachmentId: str,
) -> RagChatAttachmentDeleteVO:
    """Soft-delete one attachment and remove its vector collection and stored object.

    The conversation ID is validated and the record must pass
    ``_assert_attachment_scope`` (indexing need not be complete) before
    anything is changed. ``UserRole`` and ``TenantName`` are accepted but
    not used here.

    Returns:
        ``RagChatAttachmentDeleteVO(result="success")``.
    """
    self._validate_conversation_id(ConversationId)
    record = await self._get_attachment_record(AttachmentId)
    self._assert_attachment_scope(
        record,
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        RequireCompleted=False,
        Now=datetime.now(timezone.utc),
    )
    # Mark the row deleted rather than removing it.
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        await session.execute(
            text(
                """
                UPDATE rag_chat_attachment
                SET deleted_at = NOW(), updated_at = NOW()
                WHERE attachment_id = :attachment_id
                """
            ),
            {"attachment_id": AttachmentId},
        )
    # NOTE(review): indentation was lost in this view; the two cleanup calls are
    # assumed to run after the session block closes — confirm against the file.
    self._delete_collection(str(record.get("collection_name") or ""))
    self._delete_oss_object(str(record.get("minio_path") or ""))
    return RagChatAttachmentDeleteVO(result="success")
|
||||
|
||||
async def ValidateAttachmentForChat(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    AttachmentId: str,
    UserArea: str | None = None,
) -> dict:
    """Load an attachment and confirm it can be used in a chat turn.

    Applies the same scope checks as the other entry points but also
    requires the indexing status to be ``completed``.

    Returns:
        The attachment record as a dict.
    """
    attachment = await self._get_attachment_record(AttachmentId)
    checked_at = datetime.now(timezone.utc)
    self._assert_attachment_scope(
        attachment,
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        RequireCompleted=True,
        Now=checked_at,
    )
    return attachment
|
||||
|
||||
async def RetrieveAttachmentContext(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    AttachmentId: str,
    Query: str,
    TopK: int = 5,
    UserArea: str | None = None,
) -> tuple[list[dict], str]:
    """Retrieve the chunks of one attachment that match a query.

    The attachment is validated through ``ValidateAttachmentForChat`` and
    its collection is searched with the retriever. Every hit is tagged as
    chat-attachment content, with file-name fallbacks for ``document_name``
    and ``source``.

    Returns:
        ``(chunks, display_name)`` where ``display_name`` is the original
        file name or ``"上传文件"`` when that is missing.
    """
    record = await self.ValidateAttachmentForChat(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        AttachmentId=AttachmentId,
    )
    search = await self.retriever.retrieve(
        query=Query,
        collection_name=str(record["collection_name"]),
        top_k=TopK,
    )

    def tag_hit(hit: dict) -> dict:
        # Keys listed after **hit override whatever the retriever returned.
        return {
            **hit,
            "dataset_name": "临时上传文件",
            "attachment_id": AttachmentId,
            "conversation_id": ConversationId,
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
            "document_name": hit.get("document_name") or record.get("original_name") or hit.get("source"),
            "source": hit.get("source") or record.get("original_name") or "上传文件",
        }

    tagged = [tag_hit(hit) for hit in search.chunks]
    display_name = str(record.get("original_name") or "上传文件")
    return tagged, display_name
|
||||
|
||||
async def ResolveActiveAttachmentIdForConversation(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    UserArea: str | None = None,
) -> str | None:
    """Return the first active attachment ID of a conversation, or ``None``.

    Delegates to ``ResolveActiveAttachmentIdsForConversation`` and takes the
    head of its result.
    """
    candidates = await self.ResolveActiveAttachmentIdsForConversation(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
    )
    if not candidates:
        return None
    return candidates[0]
|
||||
|
||||
async def ResolveActiveAttachmentIdsForConversation(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    UserArea: str | None = None,
) -> list[str]:
    """List the usable attachment IDs of a conversation.

    Selects rows of the given user and conversation that are completed, not
    expired and not deleted. The tenant filter applies only when both sides
    carry a tenant code; the area filter applies only when the caller has no
    tenant code and both sides carry an area. Rows are ordered by
    ``indexing_completed_at`` then ``created_at``, both descending.
    """
    self._validate_conversation_id(ConversationId)
    lookup = text(
        """
        SELECT attachment_id, tenant_code, area
        FROM rag_chat_attachment
        WHERE user_id = :user_id
        AND conversation_id = :conversation_id
        AND indexing_status = 'completed'
        AND expires_at > NOW()
        AND deleted_at IS NULL
        AND (
        NULLIF(:tenant_code, '') IS NULL
        OR NULLIF(tenant_code, '') IS NULL
        OR tenant_code = :tenant_code
        )
        AND (
        NULLIF(:tenant_code, '') IS NOT NULL
        OR NULLIF(:user_area, '') IS NULL
        OR NULLIF(area, '') IS NULL
        OR area = :user_area
        )
        ORDER BY indexing_completed_at DESC NULLS LAST, created_at DESC
        """
    )
    bind_values = {
        "user_id": CurrentUserId,
        "conversation_id": ConversationId,
        "tenant_code": TenantCode,
        "user_area": UserArea,
    }
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        outcome = await session.execute(lookup, bind_values)
        matches = outcome.mappings().all()
    return [str(match["attachment_id"]) for match in matches]
|
||||
|
||||
async def CleanupExpiredAttachments(self, *, Limit: int = 100) -> int:
    """Soft-delete expired attachments and remove their collection and stored object.

    Args:
        Limit: Maximum number of rows handled per call; values below 1
            (or falsy values) are replaced so at least one row is considered.

    Returns:
        The number of expired rows that were selected.
    """
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        # Oldest expiries first, skipping rows that are already soft-deleted.
        rows = (
            await session.execute(
                text(
                    """
                    SELECT attachment_id, collection_name, minio_path
                    FROM rag_chat_attachment
                    WHERE deleted_at IS NULL
                    AND expires_at <= NOW()
                    ORDER BY expires_at ASC
                    LIMIT :limit
                    """
                ),
                {"limit": max(1, int(Limit or 100))},
            )
        ).mappings().all()
        attachment_ids = [str(row["attachment_id"]) for row in rows]
        if attachment_ids:
            await session.execute(
                text(
                    """
                    UPDATE rag_chat_attachment
                    SET deleted_at = NOW(), updated_at = NOW()
                    WHERE attachment_id = ANY(:attachment_ids)
                    """
                ),
                {"attachment_ids": attachment_ids},
            )
    # NOTE(review): indentation was lost in this view; the cleanup loop is assumed
    # to run after the session block closes — confirm against the file.
    for row in rows:
        self._delete_collection(str(row.get("collection_name") or ""))
        self._delete_oss_object(str(row.get("minio_path") or ""))
    return len(rows)
|
||||
|
||||
async def _get_attachment_record(self, attachment_id: str) -> dict:
    """Fetch the non-deleted row of one attachment as a plain dict.

    Raises:
        LeauditException: 400 when the ID is blank, 404 when no live row
            exists.
    """
    normalized_id = str(attachment_id or "").strip()
    if not normalized_id:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "附件ID不能为空")

    lookup = text(
        """
        SELECT *
        FROM rag_chat_attachment
        WHERE attachment_id = :attachment_id
        AND deleted_at IS NULL
        LIMIT 1
        """
    )
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        # The query binds the ID exactly as passed in, not the stripped copy.
        outcome = await session.execute(lookup, {"attachment_id": attachment_id})
        found = outcome.mappings().first()
    if found:
        return dict(found)
    raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
|
||||
|
||||
def _assert_attachment_scope(
|
||||
self,
|
||||
record: dict,
|
||||
*,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
ConversationId: str,
|
||||
RequireCompleted: bool,
|
||||
Now: datetime | None = None,
|
||||
UserArea: str | None = None,
|
||||
) -> None:
|
||||
if record.get("deleted_at") is not None:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
|
||||
if int(record.get("user_id") or 0) != int(CurrentUserId):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该临时附件")
|
||||
if str(record.get("conversation_id") or "") != str(ConversationId or ""):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前会话")
|
||||
expected_tenant = str(TenantCode or "").strip()
|
||||
record_tenant = str(record.get("tenant_code") or "").strip()
|
||||
if expected_tenant and record_tenant and expected_tenant != record_tenant:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前租户")
|
||||
expected_area = str(UserArea or "").strip()
|
||||
record_area = str(record.get("area") or "").strip()
|
||||
if not expected_tenant and expected_area and record_area and expected_area != record_area:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前地区")
|
||||
|
||||
expires_at = self._ensure_aware_datetime(record.get("expires_at"))
|
||||
now = self._ensure_aware_datetime(Now or datetime.now(timezone.utc))
|
||||
if expires_at <= now:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件已过期,请重新上传")
|
||||
|
||||
if RequireCompleted and str(record.get("indexing_status") or "") != "completed":
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件尚未完成解析,请稍后再试")
|
||||
|
||||
async def _run_attachment_indexing_task(
|
||||
self,
|
||||
*,
|
||||
attachment_id: str,
|
||||
tenant_code: str | None,
|
||||
user_id: int,
|
||||
conversation_id: str,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
collection_name: str,
|
||||
) -> None:
|
||||
try:
|
||||
await self._update_attachment_state(attachment_id=attachment_id, status="parsing")
|
||||
page_texts = await self._extract_page_texts(FileName=file_name, Content=content)
|
||||
if not page_texts:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未提取到可检索文本")
|
||||
|
||||
await self._update_attachment_state(attachment_id=attachment_id, status="splitting")
|
||||
chunks = self.BuildChunks(
|
||||
AttachmentId=attachment_id,
|
||||
TenantCode=tenant_code,
|
||||
UserId=user_id,
|
||||
ConversationId=conversation_id,
|
||||
FileName=file_name,
|
||||
PageTexts=page_texts,
|
||||
)
|
||||
if not chunks:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未切分出可检索文本")
|
||||
|
||||
await self._update_attachment_state(attachment_id=attachment_id, status="indexing", chunk_count=len(chunks))
|
||||
embeddings = await self.retriever._embed_texts([item["text"] for item in chunks], "")
|
||||
collection = self._get_chroma().get_or_create_collection(collection_name)
|
||||
collection.add(
|
||||
ids=[item["id"] for item in chunks],
|
||||
documents=[item["text"] for item in chunks],
|
||||
embeddings=embeddings,
|
||||
metadatas=[item["metadata"] for item in chunks],
|
||||
)
|
||||
|
||||
await self._update_attachment_state(
|
||||
attachment_id=attachment_id,
|
||||
status="completed",
|
||||
chunk_count=len(chunks),
|
||||
completed=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
await self._update_attachment_state(
|
||||
attachment_id=attachment_id,
|
||||
status="error",
|
||||
error=str(exc)[:2000],
|
||||
)
|
||||
|
||||
async def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]:
|
||||
suffix = Path(FileName).suffix.lower()
|
||||
if suffix in {".txt", ".md"}:
|
||||
text_value = Content.decode("utf-8", errors="ignore").strip()
|
||||
return [(1, text_value)] if text_value else []
|
||||
if suffix == ".json":
|
||||
return self.dataset_helpers._extract_page_texts_from_json(Content)
|
||||
if suffix == ".csv":
|
||||
return [(1, self._csv_to_text(Content))] if self._csv_to_text(Content).strip() else []
|
||||
if suffix == ".xlsx":
|
||||
return self._extract_page_texts_from_xlsx(Content)
|
||||
if suffix in SUPPORTED_IMAGE_SUFFIXES:
|
||||
return await self._extract_page_texts_from_image(FileName, Content)
|
||||
return self.dataset_helpers._extract_page_texts(FileName=FileName, Content=Content)
|
||||
|
||||
def _csv_to_text(self, content: bytes) -> str:
    """Render CSV bytes as one ``row_N: a | b | c`` line per non-empty record.

    The bytes are decoded as UTF-8 (a leading BOM is dropped, undecodable
    bytes are ignored). Blank cells are skipped, records with no remaining
    cells are omitted, and ``N`` is the 1-based record number so omitted
    records leave gaps in the numbering.

    Args:
        content: Raw CSV file bytes.

    Returns:
        The rendered rows joined by newlines; an empty string when nothing
        remains.
    """
    import io  # local import: only this helper needs it

    text_value = content.decode("utf-8-sig", errors="ignore")
    rows: list[str] = []
    # newline="" hands the csv module the original line endings, so a quoted
    # cell that spans several lines keeps its break instead of having the two
    # halves glued together (which happened when feeding it str.splitlines()).
    reader = csv.reader(io.StringIO(text_value, newline=""))
    for row_index, row in enumerate(reader, start=1):
        values: list[str] = []
        for cell in row:
            # Fold line breaks inside a cell into single spaces so every
            # record still renders on exactly one output line.
            cleaned = " ".join(part.strip() for part in str(cell).splitlines() if part.strip())
            if cleaned:
                values.append(cleaned)
        if values:
            rows.append(f"row_{row_index}: " + " | ".join(values))
    return "\n".join(rows)
|
||||
|
||||
def _extract_page_texts_from_xlsx(self, content: bytes) -> list[tuple[int, str]]:
    """Convert an ``.xlsx`` workbook into one text page per worksheet.

    The bytes are written to a temporary file and opened with openpyxl in
    read-only, values-only mode. Each sheet is rendered as a ``Sheet: <title>``
    header followed by one ``a | b | c`` line per row that has non-blank
    cells.

    Args:
        content: Raw workbook bytes.

    Returns:
        ``(sheet_index, rendered_text)`` tuples, with 1-based sheet indexes.

    Raises:
        LeauditException: If openpyxl cannot be imported.
    """
    try:
        from openpyxl import load_workbook
    except Exception as exc:  # pragma: no cover - depends on optional env package
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前环境未安装 openpyxl,暂无法解析 Excel 文件") from exc

    with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as temp_file:
        temp_file.write(content)
        temp_path = temp_file.name

    workbook = None
    try:
        workbook = load_workbook(temp_path, data_only=True, read_only=True)
        page_texts: list[tuple[int, str]] = []
        for sheet_index, sheet in enumerate(workbook.worksheets, start=1):
            lines: list[str] = [f"Sheet: {sheet.title}"]
            for row in sheet.iter_rows(values_only=True):
                values = [str(cell).strip() for cell in row if cell is not None and str(cell).strip()]
                if values:
                    lines.append(" | ".join(values))
            rendered = "\n".join(lines).strip()
            if rendered:
                page_texts.append((sheet_index, rendered))
        return page_texts
    finally:
        try:
            # Read-only workbooks keep the underlying file open until closed
            # explicitly; release it before removing the temp file.
            if workbook is not None:
                workbook.close()
        finally:
            try:
                os.unlink(temp_path)
            except OSError:
                pass
|
||||
|
||||
async def _extract_page_texts_from_image(self, file_name: str, content: bytes) -> list[tuple[int, str]]:
    """Run OCR on an image attachment and return its text as a single page.

    The bytes are written to a temporary file whose suffix is taken from
    ``file_name`` (``.png`` when it has none). The file is handed to the OCR
    client as a ``Path`` and removed afterwards, whether or not OCR succeeds.

    Args:
        file_name: Original file name, used only for its suffix.
        content: Raw image bytes.

    Returns:
        ``[(1, text)]`` when OCR produced text, otherwise an empty list.
    """
    extension = Path(file_name).suffix.lower() or ".png"
    handle = tempfile.NamedTemporaryFile(suffix=extension, delete=False)
    with handle:
        handle.write(content)
    image_path = Path(handle.name)

    try:
        client = self._create_ocr_client()
        recognized = await client.ocr(image_path)
        ocr_text = self._ocr_result_to_text(recognized)
        if not ocr_text:
            return []
        return [(1, ocr_text)]
    finally:
        # Best-effort cleanup of the temporary image.
        try:
            image_path.unlink()
        except OSError:
            pass
|
||||
|
||||
def _ocr_result_to_text(self, value: Any) -> str:
    """Flatten an OCR result of unknown shape into plain text.

    Handled shapes:
      * ``None`` -> empty string.
      * ``str`` -> the stripped string.
      * ``list`` -> each item flattened recursively, non-empty results joined
        with newlines.
      * ``dict`` -> the ``full_text``/``text``/``content``/``markdown`` string
        values, then ``pages``, then ``lines`` (or ``blocks`` when ``lines``
        is falsy), flattened recursively.
      * any other object -> the same four names read as attributes, then a
        ``pages`` attribute.

    For dicts and objects, duplicate fragments are dropped while keeping the
    first occurrence order.
    """
    text_keys = ("full_text", "text", "content", "markdown")

    def collect(items: Any, bucket: list[str]) -> None:
        # Recurse into a list of nested results; anything else is ignored.
        if not isinstance(items, list):
            return
        for entry in items:
            rendered = self._ocr_result_to_text(entry)
            if rendered:
                bucket.append(rendered)

    if value is None:
        return ""
    if isinstance(value, str):
        return value.strip()
    if isinstance(value, list):
        parts: list[str] = []
        collect(value, parts)
        return "\n".join(parts).strip()

    found: list[str] = []
    if isinstance(value, dict):
        for key in text_keys:
            raw = value.get(key)
            if isinstance(raw, str) and raw.strip():
                found.append(raw.strip())
        collect(value.get("pages"), found)
        collect(value.get("lines") or value.get("blocks"), found)
    else:
        for attr in text_keys:
            raw = getattr(value, attr, None)
            if isinstance(raw, str) and raw.strip():
                found.append(raw.strip())
        collect(getattr(value, "pages", None), found)
    return "\n".join(dict.fromkeys(found)).strip()
|
||||
|
||||
def _create_ocr_client(self):
    """Return an OCR client, preferring the injected factory when one is set.

    Falls back to ``create_ocr_client`` from the leaudit bridge, which is
    imported only when that fallback is needed.
    """
    factory = self._ocr_client_factory
    if factory is None:
        from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import create_ocr_client

        return create_ocr_client()
    return factory()
|
||||
|
||||
async def _update_attachment_state(
    self,
    *,
    attachment_id: str,
    status: str,
    chunk_count: int | None = None,
    error: str | None = None,
    completed: bool = False,
) -> None:
    """Write the indexing state of one attachment row.

    Every call sets ``indexing_status``, ``indexing_error`` (NULL when
    ``error`` is None), ``updated_at`` and, if still unset,
    ``indexing_started_at``.

    Args:
        attachment_id: Value matched against ``attachment_id`` in
            ``rag_chat_attachment``.
        status: New ``indexing_status`` value.
        chunk_count: When not None, also stored in ``chunk_count``.
        error: Error text stored in ``indexing_error``.
        completed: When true, also stamps ``indexing_completed_at``.
    """
    assignments = [
        "indexing_status = :status",
        "indexing_error = :error",
        "updated_at = NOW()",
        "indexing_started_at = COALESCE(indexing_started_at, NOW())",
    ]
    if chunk_count is not None:
        assignments.append("chunk_count = :chunk_count")
    if completed:
        assignments.append("indexing_completed_at = NOW()")

    # Only fixed column fragments are interpolated; values travel as binds.
    statement = text(
        f"""
        UPDATE rag_chat_attachment
        SET {", ".join(assignments)}
        WHERE attachment_id = :attachment_id
        """
    )
    bind_values: dict[str, Any] = {
        "attachment_id": attachment_id,
        "status": status,
        "error": error,
        "chunk_count": chunk_count,
    }
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        await session.execute(statement, bind_values)
|
||||
|
||||
async def _ensure_attachment_schema(self, session) -> None:
    """Create the ``rag_chat_attachment`` table and its indexes if missing.

    A class-level flag short-circuits later calls. The flag is re-checked
    after acquiring the class-level lock so the DDL is issued by only one
    caller at a time, and it is set only after all three statements ran.

    Args:
        session: Session object whose ``execute`` is awaited for each DDL
            statement.
    """
    owner = self.__class__
    if owner._attachment_schema_checked:
        return

    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS rag_chat_attachment (
            id BIGSERIAL PRIMARY KEY,
            attachment_id VARCHAR(64) NOT NULL UNIQUE,
            conversation_id VARCHAR(64) NOT NULL,
            user_id BIGINT NOT NULL,
            tenant_code VARCHAR(64) NULL,
            area VARCHAR(255) NULL,
            filename VARCHAR(512) NOT NULL,
            original_name VARCHAR(512) NOT NULL,
            content_type VARCHAR(255) NULL,
            file_size BIGINT NOT NULL DEFAULT 0,
            minio_path TEXT NULL,
            collection_name VARCHAR(160) NOT NULL,
            indexing_status VARCHAR(32) NOT NULL DEFAULT 'waiting',
            indexing_error TEXT NULL,
            chunk_count INTEGER NOT NULL DEFAULT 0,
            indexing_started_at TIMESTAMPTZ NULL,
            indexing_completed_at TIMESTAMPTZ NULL,
            expires_at TIMESTAMPTZ NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            deleted_at TIMESTAMPTZ NULL
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_scope
        ON rag_chat_attachment(tenant_code, user_id, conversation_id, attachment_id)
        WHERE deleted_at IS NULL
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_expires
        ON rag_chat_attachment(expires_at)
        WHERE deleted_at IS NULL
        """,
    )

    async with owner._attachment_schema_lock:
        # Another caller may have finished while this one waited for the lock.
        if owner._attachment_schema_checked:
            return
        for ddl in ddl_statements:
            await session.execute(text(ddl))
        owner._attachment_schema_checked = True
|
||||
|
||||
def _get_chroma(self) -> Any:
    """Return the injected Chroma client, or the one from ``get_chroma()`` when it is falsy."""
    injected = self._chroma_client
    if injected:
        return injected
    return get_chroma()
|
||||
|
||||
def _delete_collection(self, collection_name: str) -> None:
    """Best-effort removal of a Chroma collection.

    Does nothing for an empty name. Failures never propagate to the caller,
    but they are logged so leftover collections can be traced.

    Args:
        collection_name: Name of the collection to delete.
    """
    if not collection_name:
        return
    try:
        self._get_chroma().delete_collection(collection_name)
    except Exception:
        # Cleanup must not break the calling flow; record the failure
        # instead of discarding it silently.
        import logging

        logging.getLogger(__name__).warning(
            "Failed to delete Chroma collection %s", collection_name, exc_info=True
        )
|
||||
|
||||
def _delete_oss_object(self, source: str | None) -> None:
|
||||
if not source:
|
||||
return
|
||||
try:
|
||||
oss = OssClient()
|
||||
ref = oss.ResolveObjectRef(Source=source)
|
||||
if ref.isDirectUrl or not ref.objectKey:
|
||||
return
|
||||
oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
def _to_vo(self, row: dict) -> RagChatAttachmentVO:
    """Map an attachment row mapping onto ``RagChatAttachmentVO``.

    Missing or falsy values fall back to ``""`` for text fields, ``0`` for
    numeric fields and ``"waiting"`` for the indexing status. Timestamps are
    converted with ``self._timestamp``.
    """

    def as_text(key: str, fallback: str = "") -> str:
        return str(row.get(key) or fallback)

    def as_int(key: str) -> int:
        return int(row.get(key) or 0)

    # The display name prefers the original upload name over the stored one.
    display_name = str(row.get("original_name") or row.get("filename") or "")
    return RagChatAttachmentVO(
        attachmentId=as_text("attachment_id"),
        conversationId=as_text("conversation_id"),
        fileName=display_name,
        contentType=as_text("content_type"),
        fileSize=as_int("file_size"),
        indexingStatus=as_text("indexing_status", "waiting"),
        indexingError=row.get("indexing_error"),
        chunkCount=as_int("chunk_count"),
        collectionName=as_text("collection_name"),
        expiresAt=self._timestamp(row.get("expires_at")),
        createdAt=self._timestamp(row.get("created_at")),
    )
|
||||
|
||||
def _preprocess_text(self, text_value: str) -> str:
    """Normalise whitespace in extracted text.

    Runs of spaces/tabs become one space, runs of three or more newlines
    become a blank line, and the result is stripped. A falsy input yields an
    empty string.
    """
    collapsed = re.sub(r"[ \t]+", " ", text_value or "")
    return re.sub(r"\n{3,}", "\n\n", collapsed).strip()
|
||||
|
||||
def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]:
    """Delegate chunk splitting to ``dataset_helpers._split_text`` with the same arguments."""
    splitter = self.dataset_helpers._split_text
    return splitter(text_value, separator=separator, max_chars=max_chars, overlap=overlap)
|
||||
|
||||
def _slugify(self, value: str) -> str:
    """Reduce ``value`` to a lower-case ``[a-z0-9_]`` slug.

    Every run of other characters becomes one underscore, repeated
    underscores are collapsed and leading/trailing ones removed. Returns
    ``"default"`` when nothing is left.
    """
    trimmed = str(value or "").strip()
    replaced = re.sub(r"[^A-Za-z0-9_]+", "_", trimmed)
    slug = re.sub(r"_+", "_", replaced).strip("_").lower()
    if slug:
        return slug
    return "default"
|
||||
|
||||
def _validate_conversation_id(self, conversation_id: str | None) -> None:
|
||||
if not str(conversation_id or "").strip() or str(conversation_id or "").strip() == "-1":
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话ID不能为空")
|
||||
|
||||
def _ensure_aware_datetime(self, value: Any) -> datetime:
    """Coerce ``value`` to a timezone-aware ``datetime``.

    Accepts a ``datetime`` or an ISO-8601 string (a ``Z`` suffix is read as
    ``+00:00``). Naive values are tagged as UTC; aware values are returned
    unchanged.

    Args:
        value: A ``datetime`` or an ISO-8601 string.

    Returns:
        A timezone-aware ``datetime``.

    Raises:
        LeauditException: If ``value`` is neither a ``datetime`` nor a
            parseable ISO-8601 string.
    """
    parsed = value
    if isinstance(value, str):
        try:
            parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
        except ValueError as exc:
            # Report malformed strings with the same domain error as every
            # other unsupported value instead of leaking a raw ValueError.
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件时间字段异常") from exc
    if isinstance(parsed, datetime):
        if parsed.tzinfo is None:
            return parsed.replace(tzinfo=timezone.utc)
        return parsed
    raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件时间字段异常")
|
||||
|
||||
def _timestamp(self, value: Any) -> int:
    """Return ``value`` as whole epoch seconds, or ``0`` when it is falsy."""
    if value:
        moment = self._ensure_aware_datetime(value)
        return int(moment.timestamp())
    return 0
|
||||
@@ -106,6 +106,8 @@ class RagChatServiceImpl(IRagChatService):
|
||||
Query: str,
|
||||
ConversationId: str | None,
|
||||
AppId: int | None,
|
||||
AttachmentId: str | None = None,
|
||||
AttachmentIds: list[str] | None = None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
@@ -128,6 +130,25 @@ class RagChatServiceImpl(IRagChatService):
|
||||
messageId = str(uuid.uuid4())
|
||||
taskId = str(uuid.uuid4())
|
||||
is_new_conversation = not ConversationId or ConversationId == "-1"
|
||||
active_attachment_ids = self._normalize_attachment_ids(
|
||||
attachment_id=AttachmentId,
|
||||
attachment_ids=AttachmentIds,
|
||||
)
|
||||
if not active_attachment_ids:
|
||||
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=conversationId,
|
||||
)
|
||||
attachment_records = await self._load_message_attachment_records(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=conversationId,
|
||||
AttachmentIds=active_attachment_ids,
|
||||
)
|
||||
user_message_files = self._build_user_message_files(attachment_records)
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
async with session.begin():
|
||||
@@ -135,13 +156,14 @@ class RagChatServiceImpl(IRagChatService):
|
||||
text(
|
||||
"""
|
||||
INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
|
||||
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
|
||||
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, CAST(:metadata AS jsonb))
|
||||
"""
|
||||
),
|
||||
{
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"conversation_id": conversationId,
|
||||
"content": Query,
|
||||
"metadata": json.dumps({"message_files": user_message_files}, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
await session.execute(
|
||||
@@ -170,6 +192,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id=messageId,
|
||||
query=Query,
|
||||
app=app,
|
||||
current_user_id=CurrentUserId,
|
||||
tenant_code=TenantCode,
|
||||
user_area=UserArea,
|
||||
attachment_id=AttachmentId,
|
||||
attachment_ids=active_attachment_ids,
|
||||
)
|
||||
|
||||
event_index = 0
|
||||
@@ -322,6 +349,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
row = items[idx]
|
||||
if row["role"] == "user":
|
||||
answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None
|
||||
user_metadata = dict(row.get("metadata") or {})
|
||||
answer_metadata = dict((answer.get("metadata") if answer else None) or {})
|
||||
answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running"))
|
||||
answer_content = (answer.get("content") if answer else None) or ""
|
||||
@@ -355,6 +383,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
conversationId=ConversationId,
|
||||
query=row["content"],
|
||||
answer=answer_content if answer else "",
|
||||
messageFiles=[item for item in (user_metadata.get("message_files") or []) if isinstance(item, dict)],
|
||||
feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
|
||||
retrieverResources=(answer.get("sources") if answer else None),
|
||||
suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()],
|
||||
@@ -940,7 +969,181 @@ class RagChatServiceImpl(IRagChatService):
|
||||
|
||||
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
|
||||
result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
|
||||
return result.chunks, result.dataset_name
|
||||
chunks = [
|
||||
{
|
||||
**chunk,
|
||||
"source_scope": chunk.get("source_scope") or "formal_kb",
|
||||
"data_source_type": chunk.get("data_source_type") or "formal_kb",
|
||||
}
|
||||
for chunk in result.chunks
|
||||
]
|
||||
return chunks, result.dataset_name
|
||||
|
||||
async def _resolve_attachment_id_for_conversation(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
) -> str | None:
    """Forward to ``ResolveActiveAttachmentIdForConversation`` on a fresh attachment service.

    Returns whatever that service call returns for the given user, tenant,
    area and conversation.
    """
    # Imported at call time rather than at module import time.
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    scope = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
    }
    return await RagChatAttachmentServiceImpl().ResolveActiveAttachmentIdForConversation(**scope)
|
||||
|
||||
async def _resolve_attachment_ids_for_conversation(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
) -> list[str]:
    """Resolve the attachment ids to use for a conversation.

    Uses the service's ``ResolveActiveAttachmentIdsForConversation`` when it
    exists; otherwise wraps the result of the single-id resolver in a list
    (empty when that result is falsy).
    """
    # Imported at call time rather than at module import time.
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    service = RagChatAttachmentServiceImpl()
    scope = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
    }
    if not hasattr(service, "ResolveActiveAttachmentIdsForConversation"):
        single_id = await service.ResolveActiveAttachmentIdForConversation(**scope)
        return [single_id] if single_id else []
    return await service.ResolveActiveAttachmentIdsForConversation(**scope)
|
||||
|
||||
async def _retrieve_attachment_context(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
    AttachmentId: str,
    Query: str,
) -> tuple[list[dict], str]:
    """Forward to ``RetrieveAttachmentContext`` on a fresh attachment service with ``TopK=5``.

    Returns the service's ``(list[dict], str)`` result unchanged.
    """
    # Imported at call time rather than at module import time.
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    request = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
        "AttachmentId": AttachmentId,
        "Query": Query,
        "TopK": 5,
    }
    return await RagChatAttachmentServiceImpl().RetrieveAttachmentContext(**request)
|
||||
|
||||
async def _load_message_attachment_records(
|
||||
self,
|
||||
*,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
UserArea: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentIds: list[str],
|
||||
) -> list[dict]:
|
||||
if not AttachmentIds:
|
||||
return []
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
|
||||
service = RagChatAttachmentServiceImpl()
|
||||
records: list[dict] = []
|
||||
for attachment_id in AttachmentIds:
|
||||
record = await service.ValidateAttachmentForChat(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
AttachmentId=attachment_id,
|
||||
)
|
||||
records.append(record)
|
||||
return records
|
||||
|
||||
def _build_user_message_files(self, attachment_records: list[dict]) -> list[dict]:
    """Build the ``message_files`` entries for a user message.

    Records without a non-blank ``attachment_id`` are skipped. An entry is
    typed ``"image"`` when its content type starts with ``image/`` or its name
    ends in a known image extension, otherwise ``"file"``.

    Args:
        attachment_records: Attachment record dicts.

    Returns:
        One dict per usable record, in input order.
    """
    entries: list[dict] = []
    for item in attachment_records:
        identifier = str(item.get("attachment_id") or "").strip()
        if identifier:
            display_name = str(item.get("original_name") or item.get("filename") or "上传文件")
            mime = str(item.get("content_type") or "")
            looks_like_image = mime.startswith("image/") or re.search(
                r"\.(png|jpe?g|webp|bmp|tiff?)$", display_name, re.I
            )
            entries.append(
                {
                    "id": identifier,
                    "upload_file_id": identifier,
                    "name": display_name,
                    "fileName": display_name,
                    "type": "image" if looks_like_image else "file",
                    "transfer_method": "local_file",
                    "contentType": mime or None,
                    "fileSize": int(item.get("file_size") or 0),
                    "belongs_to": "user",
                    "usage": "temporary_attachment",
                }
            )
    return entries
|
||||
|
||||
def _build_formal_kb_query(self, *, query: str, attachment_chunks: list[dict]) -> str:
    """Extend ``query`` with facts taken from the uploaded attachment.

    Only the first three chunks are considered; each non-blank chunk text is
    stripped and cut to 500 characters. When no fact remains the original
    query is returned unchanged.

    Args:
        query: The user's question.
        attachment_chunks: Chunk dicts whose ``text`` value is read.

    Returns:
        The original query, or the query followed by a bulleted fact list and
        a retrieval instruction.
    """
    if not attachment_chunks:
        return query

    snippets = [
        str(chunk.get("text") or "").strip()[:500]
        for chunk in attachment_chunks[:3]
    ]
    snippets = [snippet for snippet in snippets if snippet]
    if not snippets:
        return query

    bullet_list = "\n".join(f"- {snippet}" for snippet in snippets)
    header = "用户上传文档中检索到的相关事实:\n"
    footer = "\n\n请检索这些事实对应的法律责任、处罚依据、裁量规则或案例。"
    return f"{query}\n\n" + header + bullet_list + footer
|
||||
|
||||
def _merge_context_chunks(self, *, attachment_chunks: list[dict], formal_chunks: list[dict]) -> list[dict]:
    """Concatenate attachment and knowledge-base chunks, tagging their origin.

    Attachment chunks come first and are always tagged ``"chat_attachment"``.
    Knowledge-base chunks keep their existing tags and default to
    ``"formal_kb"``. Inputs are not mutated; each chunk is copied.
    """
    from_attachment = [
        {**chunk, "source_scope": "chat_attachment", "data_source_type": "chat_attachment"}
        for chunk in attachment_chunks
    ]
    from_knowledge_base = [
        {
            **chunk,
            "source_scope": chunk.get("source_scope") or "formal_kb",
            "data_source_type": chunk.get("data_source_type") or "formal_kb",
        }
        for chunk in formal_chunks
    ]
    return from_attachment + from_knowledge_base
|
||||
|
||||
def _normalize_attachment_ids(self, *, attachment_id: str | None, attachment_ids: list[str] | None) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
for raw in [*(attachment_ids or []), attachment_id]:
|
||||
value = str(raw or "").strip()
|
||||
if value and value not in normalized:
|
||||
normalized.append(value)
|
||||
return normalized
|
||||
|
||||
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
|
||||
return await self.retriever._embed_texts(texts, model_name)
|
||||
@@ -953,6 +1156,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id: str,
|
||||
query: str,
|
||||
app: dict,
|
||||
current_user_id: int | None = None,
|
||||
tenant_code: str | None = None,
|
||||
user_area: str | None = None,
|
||||
attachment_id: str | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
self._task_events[task_id] = []
|
||||
self._task_done[task_id] = False
|
||||
@@ -964,6 +1172,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id=message_id,
|
||||
query=query,
|
||||
app=app,
|
||||
current_user_id=current_user_id,
|
||||
tenant_code=tenant_code,
|
||||
user_area=user_area,
|
||||
attachment_id=attachment_id,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
)
|
||||
self._message_tasks[task_id] = task
|
||||
@@ -976,6 +1189,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id: str,
|
||||
query: str,
|
||||
app: dict,
|
||||
current_user_id: int | None = None,
|
||||
tenant_code: str | None = None,
|
||||
user_area: str | None = None,
|
||||
attachment_id: str | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
context_chunks: list[dict] = []
|
||||
dataset_name = ""
|
||||
@@ -987,7 +1205,46 @@ class RagChatServiceImpl(IRagChatService):
|
||||
last_persisted_at = time.monotonic()
|
||||
|
||||
try:
|
||||
context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query)
|
||||
attachment_chunks: list[dict] = []
|
||||
attachment_names: list[str] = []
|
||||
active_attachment_ids = self._normalize_attachment_ids(
|
||||
attachment_id=attachment_id,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
if not active_attachment_ids and current_user_id is not None:
|
||||
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
|
||||
CurrentUserId=current_user_id,
|
||||
TenantCode=tenant_code,
|
||||
UserArea=user_area,
|
||||
ConversationId=conversation_id,
|
||||
)
|
||||
if active_attachment_ids:
|
||||
if current_user_id is None:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件缺少用户上下文")
|
||||
for active_attachment_id in active_attachment_ids:
|
||||
next_chunks, next_attachment_name = await self._retrieve_attachment_context(
|
||||
CurrentUserId=current_user_id,
|
||||
TenantCode=tenant_code,
|
||||
UserArea=user_area,
|
||||
ConversationId=conversation_id,
|
||||
AttachmentId=active_attachment_id,
|
||||
Query=query,
|
||||
)
|
||||
attachment_chunks.extend(next_chunks)
|
||||
if next_attachment_name:
|
||||
attachment_names.append(next_attachment_name)
|
||||
legal_query = self._build_formal_kb_query(query=query, attachment_chunks=attachment_chunks)
|
||||
formal_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), legal_query)
|
||||
context_chunks = self._merge_context_chunks(
|
||||
attachment_chunks=attachment_chunks,
|
||||
formal_chunks=formal_chunks,
|
||||
)
|
||||
generation_dataset_name = dataset_name
|
||||
attachment_name = "、".join(dict.fromkeys(attachment_names))
|
||||
if attachment_name and dataset_name:
|
||||
generation_dataset_name = f"{attachment_name} + {dataset_name}"
|
||||
elif attachment_name:
|
||||
generation_dataset_name = attachment_name
|
||||
async for chunk in generate_stream(
|
||||
query=query,
|
||||
context_chunks=context_chunks,
|
||||
@@ -997,7 +1254,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
model=app.get("llm_model") or "",
|
||||
temperature=app.get("temperature"),
|
||||
max_tokens=app.get("max_tokens"),
|
||||
dataset_name=dataset_name,
|
||||
dataset_name=generation_dataset_name,
|
||||
task_id=task_id,
|
||||
):
|
||||
data = self._parse_sse_event(chunk)
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
|
||||
RagChatAttachmentDeleteVO,
|
||||
RagChatAttachmentVO,
|
||||
)
|
||||
|
||||
|
||||
class IRagChatAttachmentService(ABC):
    """Abstract contract for the temporary chat attachment service."""

    @abstractmethod
    async def CreateAttachment(
        self,
        CurrentUserId: int,
        UserArea: str | None,
        UserRole: str | None,
        TenantCode: str | None,
        TenantName: str | None,
        ConversationId: str | None,
        AppId: int | None,
        FileName: str,
        ContentType: str | None,
        Content: bytes,
    ) -> RagChatAttachmentVO:
        """Create an attachment from the uploaded bytes and return its view object."""

    @abstractmethod
    async def GetAttachment(
        self,
        CurrentUserId: int,
        UserArea: str | None,
        UserRole: str | None,
        TenantCode: str | None,
        TenantName: str | None,
        ConversationId: str,
        AttachmentId: str,
    ) -> RagChatAttachmentVO:
        """Return the view object of one attachment in a conversation."""

    @abstractmethod
    async def DeleteAttachment(
        self,
        CurrentUserId: int,
        UserArea: str | None,
        UserRole: str | None,
        TenantCode: str | None,
        TenantName: str | None,
        ConversationId: str,
        AttachmentId: str,
    ) -> RagChatAttachmentDeleteVO:
        """Delete one attachment in a conversation and return the delete result."""

    @abstractmethod
    async def ValidateAttachmentForChat(
        self,
        CurrentUserId: int,
        TenantCode: str | None,
        ConversationId: str,
        AttachmentId: str,
        UserArea: str | None = None,
    ) -> dict:
        """Validate an attachment for use in a chat and return its record as a dict."""

    @abstractmethod
    async def RetrieveAttachmentContext(
        self,
        CurrentUserId: int,
        TenantCode: str | None,
        ConversationId: str,
        AttachmentId: str,
        Query: str,
        TopK: int = 5,
        UserArea: str | None = None,
    ) -> tuple[list[dict], str]:
        """Retrieve up to ``TopK`` attachment chunks for ``Query``; returns ``(list[dict], str)``."""

    @abstractmethod
    async def ResolveActiveAttachmentIdForConversation(
        self,
        CurrentUserId: int,
        TenantCode: str | None,
        ConversationId: str,
        UserArea: str | None = None,
    ) -> str | None:
        """Return a single active attachment id for the conversation, or None."""

    @abstractmethod
    async def ResolveActiveAttachmentIdsForConversation(
        self,
        CurrentUserId: int,
        TenantCode: str | None,
        ConversationId: str,
        UserArea: str | None = None,
    ) -> list[str]:
        """Return the active attachment ids for the conversation."""
|
||||
@@ -50,6 +50,8 @@ class IRagChatService(ABC):
|
||||
Query: str,
|
||||
ConversationId: str | None,
|
||||
AppId: int | None,
|
||||
AttachmentId: str | None = None,
|
||||
AttachmentIds: list[str] | None = None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> AsyncGenerator[bytes, None]: ...
|
||||
|
||||
Reference in New Issue
Block a user