feat(rag): add temporary chat attachments

This commit is contained in:
wren
2026-05-25 15:37:37 +08:00
parent 0f385c9839
commit 75c077da77
16 changed files with 2257 additions and 16 deletions
@@ -0,0 +1,821 @@
"""Temporary RAG attachments for chat conversations."""
from __future__ import annotations
import asyncio
import csv
import hashlib
import json
import mimetypes
import os
import re
import tempfile
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Awaitable, Callable
from sqlalchemy import text
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_storage.oss_client import OssClient
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
RagChatAttachmentDeleteVO,
RagChatAttachmentVO,
)
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import EmbedTexts, RagRetriever
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
from fastapi_modules.fastapi_leaudit.services.ragChatAttachmentService import IRagChatAttachmentService
# Number of days added to the upload time to compute an attachment's expiry.
DEFAULT_ATTACHMENT_TTL_DAYS = 7
# Accepted upload suffixes (lower-case, with leading dot); uploads are matched
# against these after lower-casing the file name's suffix.
SUPPORTED_TEXT_SUFFIXES = {".txt", ".md", ".json", ".csv"}
SUPPORTED_DOCUMENT_SUFFIXES = {".docx", ".pdf", ".xlsx"}
# Files with these suffixes are sent through the OCR client for text extraction.
SUPPORTED_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tif", ".tiff"}
# Union of every suffix that CreateAttachment accepts.
SUPPORTED_SUFFIXES = SUPPORTED_TEXT_SUFFIXES | SUPPORTED_DOCUMENT_SUFFIXES | SUPPORTED_IMAGE_SUFFIXES
class RagChatAttachmentServiceImpl(IRagChatAttachmentService):
"""Manage temporary, conversation-scoped RAG attachment indexes."""
# Class-wide flag: flipped to True after _ensure_attachment_schema has issued
# its CREATE TABLE / CREATE INDEX statements, so later calls return early.
_attachment_schema_checked = False
# Lock taken by _ensure_attachment_schema around the first schema check.
_attachment_schema_lock = asyncio.Lock()
def __init__(
    self,
    *,
    chroma_client: Any | None = None,
    embed_texts: EmbedTexts | None = None,
    chat_service: RagChatServiceImpl | None = None,
    ocr_client_factory: Callable[[], Any] | None = None,
) -> None:
    """Wire up the collaborators used by the attachment service.

    Args:
        chroma_client: Optional vector-store client; when omitted,
            ``_get_chroma`` falls back to ``get_chroma()``.
        embed_texts: Optional embedding callable forwarded to the retriever.
        chat_service: Optional chat service; a fresh ``RagChatServiceImpl``
            is created when none (or a falsy value) is supplied.
        ocr_client_factory: Optional zero-argument factory returning an OCR
            client, consulted by ``_create_ocr_client``.
    """
    self._chroma_client = chroma_client
    retriever_options = {
        "chroma_client": chroma_client,
        "embed_texts": embed_texts,
        "hydrate_documents": False,
    }
    self.retriever = RagRetriever(**retriever_options)
    self.chat_service = chat_service if chat_service else RagChatServiceImpl()
    # Reused for its JSON/document extraction and text-splitting helpers.
    self.dataset_helpers = RagDatasetServiceImpl()
    self._ocr_client_factory = ocr_client_factory
def _default_expires_at(self, now: datetime | None = None) -> datetime:
    """Return the expiry instant for an attachment created at *now*.

    Args:
        now: Reference time; defaults to the current UTC time.

    Returns:
        The timezone-aware reference time plus ``DEFAULT_ATTACHMENT_TTL_DAYS`` days.
    """
    reference = now if now else datetime.now(timezone.utc)
    lifetime = timedelta(days=DEFAULT_ATTACHMENT_TTL_DAYS)
    return self._ensure_aware_datetime(reference) + lifetime
def BuildCollectionName(self, *, TenantCode: str | None, UserId: int, ConversationId: str, AttachmentId: str) -> str:
    """Build the vector-collection name for one attachment.

    The name has the shape
    ``chat_attachment_<tenant>_<user>_<sha1(conversation)[:16]>_<attachment>``
    and is capped at 120 characters.

    Args:
        TenantCode: Tenant identifier; ``"default"`` is used when empty.
        UserId: Owner of the attachment.
        ConversationId: Conversation the attachment belongs to (hashed).
        AttachmentId: Unique attachment identifier.

    Returns:
        The slugified collection name, at most 120 characters long.
    """
    max_length = 120
    prefix = "chat_attachment_"
    tenant = self._slugify(TenantCode or "default")
    user = self._slugify(str(UserId))
    conversation_hash = hashlib.sha1(ConversationId.encode("utf-8")).hexdigest()[:16]
    attachment = self._slugify(AttachmentId)
    tail = f"_{user}_{conversation_hash}_{attachment}"
    # Trim the tenant segment rather than the end of the name: the attachment
    # id at the tail is what keeps names unique, so cutting it off would let
    # different attachments of one conversation share a collection.
    tenant_budget = max_length - len(prefix) - len(tail)
    if len(tenant) > tenant_budget:
        tenant = tenant[: max(tenant_budget, 0)].rstrip("_")
    collection_name = f"{prefix}{tenant}{tail}"
    # Final cap kept as a safety net for abnormally long attachment ids.
    return collection_name[:max_length]
def BuildChunks(
    self,
    *,
    AttachmentId: str,
    TenantCode: str | None,
    UserId: int,
    ConversationId: str,
    FileName: str,
    PageTexts: list[tuple[int, str]],
    ChunkMaxSize: int = 800,
    ChunkOverlap: int = 80,
) -> list[dict[str, Any]]:
    """Turn per-page text into chunk records ready for indexing.

    Args:
        AttachmentId: Attachment the chunks belong to; prefixes every chunk id.
        TenantCode: Tenant stored in chunk metadata (empty string when unset).
        UserId: Owner stored in chunk metadata as a string.
        ConversationId: Conversation stored in chunk metadata.
        FileName: Recorded as both ``source`` and ``document_name``.
        PageTexts: ``(page_no, text)`` pairs.
        ChunkMaxSize: Maximum chunk size; values below 50 are raised to 50.
        ChunkOverlap: Overlap between chunks, clamped to ``[0, size // 2]``.

    Returns:
        A list of ``{"id", "text", "metadata"}`` dicts; blank pages and blank
        chunks are skipped.
    """
    size_limit = max(50, int(ChunkMaxSize or 800))
    overlap = max(0, min(int(ChunkOverlap or 0), size_limit // 2))
    built: list[dict[str, Any]] = []
    for page_no, raw_text in PageTexts:
        cleaned = self._preprocess_text(raw_text)
        if not cleaned:
            continue
        pieces = self._split_text(cleaned, separator="\n\n", max_chars=size_limit, overlap=overlap)
        for position, piece in enumerate(pieces):
            if piece.strip():
                # Chunk ids are "<attachment>:<page>:<index within page>".
                chunk_id = f"{AttachmentId}:{page_no}:{position}"
                metadata = {
                    "id": chunk_id,
                    "source": FileName,
                    "document_name": FileName,
                    "attachment_id": AttachmentId,
                    "tenant_code": str(TenantCode or ""),
                    "user_id": str(UserId),
                    "conversation_id": ConversationId,
                    "source_scope": "chat_attachment",
                    "page": int(page_no),
                    "chunk_index": position,
                }
                built.append({"id": chunk_id, "text": piece, "metadata": metadata})
    return built
async def CreateAttachment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    ConversationId: str | None,
    AppId: int | None,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
) -> RagChatAttachmentVO:
    """Store an uploaded file, register it, and start background indexing.

    The file is uploaded through ``OssClient``, a ``rag_chat_attachment`` row
    is inserted with status ``'waiting'``, and an asyncio task running
    ``_run_attachment_indexing_task`` is scheduled.

    Args:
        CurrentUserId: Uploading user; becomes the attachment owner.
        UserArea: Caller's area, stored on the row.
        UserRole: Caller's role, forwarded to the chat service.
        TenantCode: Caller's tenant, stored on the row.
        TenantName: Caller's tenant name, forwarded to the chat service.
        ConversationId: Existing conversation id, passed to the chat
            service's ``_ensure_conversation``, whose return value is used.
        AppId: Chat application id used to resolve the app.
        FileName: Original upload name; its suffix selects the parser.
        ContentType: Declared MIME type; guessed from the name when empty.
        Content: Raw file bytes.

    Returns:
        The value object built from the inserted row.

    Raises:
        LeauditException: 400 for an empty name, empty content or an
            unsupported suffix; 404 when no chat application resolves.
    """
    if not FileName:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件名不能为空")
    if not Content:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件内容不能为空")
    suffix = Path(FileName).suffix.lower()
    if suffix not in SUPPORTED_SUFFIXES:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 TXT、MD、JSON、CSV、DOCX、PDF、XLSX 和图片文件")
    app = await self.chat_service._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
    if not app:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")
    conversation_id = await self.chat_service._ensure_conversation(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        app_id=app["id"],
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    attachment_id = str(uuid.uuid4())
    collection_name = self.BuildCollectionName(
        TenantCode=TenantCode,
        UserId=CurrentUserId,
        ConversationId=conversation_id,
        AttachmentId=attachment_id,
    )
    content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
    expires_at = self._default_expires_at()
    # The upload name is client-controlled: keep only its final path component
    # so directory parts (including "..") never become object-key segments.
    safe_name = Path(FileName.replace("\\", "/")).name or f"{attachment_id}{suffix}"
    object_key = f"rag-chat-attachments/{TenantCode or 'default'}/{CurrentUserId}/{conversation_id}/{attachment_id}_{safe_name}"
    OssClient().EnsureBucket()
    stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        row = (
            await session.execute(
                text(
                    """
INSERT INTO rag_chat_attachment (
attachment_id, conversation_id, user_id, tenant_code, area,
filename, original_name, content_type, file_size, minio_path,
collection_name, indexing_status, chunk_count, expires_at
) VALUES (
:attachment_id, :conversation_id, :user_id, :tenant_code, :area,
:filename, :original_name, :content_type, :file_size, :minio_path,
:collection_name, 'waiting', 0, :expires_at
)
RETURNING *
"""
                ),
                {
                    "attachment_id": attachment_id,
                    "conversation_id": conversation_id,
                    "user_id": CurrentUserId,
                    "tenant_code": TenantCode,
                    "area": UserArea,
                    "filename": f"{attachment_id}{suffix}",
                    "original_name": FileName,
                    "content_type": content_type,
                    "file_size": len(Content),
                    "minio_path": stored_key,
                    "collection_name": collection_name,
                    "expires_at": expires_at,
                },
            )
        ).mappings().first()
    # NOTE(review): scheduled after the session block so the row exists before
    # the task updates it — assumes GetAsyncSession commits on exit; confirm.
    indexing_task = asyncio.create_task(
        self._run_attachment_indexing_task(
            attachment_id=attachment_id,
            tenant_code=TenantCode,
            user_id=CurrentUserId,
            conversation_id=conversation_id,
            file_name=FileName,
            content=Content,
            collection_name=collection_name,
        )
    )
    # The event loop only keeps weak references to tasks, so an unreferenced
    # task can be garbage-collected before it finishes. Hold a strong
    # reference on the class until the task completes.
    owner = type(self)
    pending = getattr(owner, "_pending_indexing_tasks", None)
    if pending is None:
        pending = set()
        owner._pending_indexing_tasks = pending
    pending.add(indexing_task)
    indexing_task.add_done_callback(pending.discard)
    return self._to_vo(dict(row))
async def GetAttachment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    ConversationId: str,
    AttachmentId: str,
) -> RagChatAttachmentVO:
    """Return one attachment after checking it belongs to the caller.

    ``UserRole`` and ``TenantName`` are accepted but not read here. The
    attachment does not need to have finished indexing.

    Raises:
        LeauditException: For an invalid conversation id, a missing
            attachment, or a failed scope check.
    """
    self._validate_conversation_id(ConversationId)
    attachment = await self._get_attachment_record(AttachmentId)
    scope = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
        "RequireCompleted": False,
        "Now": datetime.now(timezone.utc),
    }
    self._assert_attachment_scope(attachment, **scope)
    return self._to_vo(attachment)
async def DeleteAttachment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    ConversationId: str,
    AttachmentId: str,
) -> RagChatAttachmentDeleteVO:
    """Soft-delete an attachment, then drop its collection and stored object.

    ``UserRole`` and ``TenantName`` are accepted but not read here.

    Returns:
        A delete result object with ``result="success"``.

    Raises:
        LeauditException: For an invalid conversation id, a missing
            attachment, or a failed scope check.
    """
    self._validate_conversation_id(ConversationId)
    attachment = await self._get_attachment_record(AttachmentId)
    self._assert_attachment_scope(
        attachment,
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        RequireCompleted=False,
        Now=datetime.now(timezone.utc),
    )
    soft_delete_sql = """
UPDATE rag_chat_attachment
SET deleted_at = NOW(), updated_at = NOW()
WHERE attachment_id = :attachment_id
"""
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        await session.execute(text(soft_delete_sql), {"attachment_id": AttachmentId})
    # Both removals are best-effort helpers that never raise.
    collection_name = str(attachment.get("collection_name") or "")
    stored_object = str(attachment.get("minio_path") or "")
    self._delete_collection(collection_name)
    self._delete_oss_object(stored_object)
    return RagChatAttachmentDeleteVO(result="success")
async def ValidateAttachmentForChat(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    AttachmentId: str,
    UserArea: str | None = None,
) -> dict:
    """Load an attachment row and require it to be usable in a chat turn.

    Unlike ``GetAttachment``, the attachment must have indexing status
    ``completed``.

    Returns:
        The attachment row as a dict.

    Raises:
        LeauditException: When the attachment is missing or fails the
            scope, expiry or completion checks.
    """
    attachment = await self._get_attachment_record(AttachmentId)
    checked_at = datetime.now(timezone.utc)
    self._assert_attachment_scope(
        attachment,
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        RequireCompleted=True,
        Now=checked_at,
    )
    return attachment
async def RetrieveAttachmentContext(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    AttachmentId: str,
    Query: str,
    TopK: int = 5,
    UserArea: str | None = None,
) -> tuple[list[dict], str]:
    """Retrieve the chunks of one attachment that best match *Query*.

    The attachment is validated with ``ValidateAttachmentForChat`` first, then
    its collection is searched through the retriever.

    Returns:
        ``(chunks, display_name)`` where each chunk is tagged as coming from
        a chat attachment and ``display_name`` is the original file name
        (or a generic label when absent).
    """
    record = await self.ValidateAttachmentForChat(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        AttachmentId=AttachmentId,
    )
    result = await self.retriever.retrieve(
        query=Query,
        collection_name=str(record["collection_name"]),
        top_k=TopK,
    )
    original_name = record.get("original_name")
    # Keys listed after the spread override whatever the retriever returned.
    tagged = [
        {
            **hit,
            "dataset_name": "临时上传文件",
            "attachment_id": AttachmentId,
            "conversation_id": ConversationId,
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
            "document_name": hit.get("document_name") or original_name or hit.get("source"),
            "source": hit.get("source") or original_name or "上传文件",
        }
        for hit in result.chunks
    ]
    return tagged, str(original_name or "上传文件")
async def ResolveActiveAttachmentIdForConversation(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    UserArea: str | None = None,
) -> str | None:
    """Return the first active attachment id of a conversation, if any.

    Delegates to ``ResolveActiveAttachmentIdsForConversation`` and returns the
    first id of its result, or ``None`` when the result is empty.
    """
    active_ids = await self.ResolveActiveAttachmentIdsForConversation(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
    )
    if not active_ids:
        return None
    return active_ids[0]
async def ResolveActiveAttachmentIdsForConversation(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    UserArea: str | None = None,
) -> list[str]:
    """List the ids of a conversation's usable attachments.

    An attachment qualifies when it belongs to the user and conversation, has
    status ``completed``, is neither expired nor deleted, and passes the
    tenant filter (and, when no tenant code is supplied, the area filter).
    Rows are ordered by completion time, then creation time, newest first.

    Raises:
        LeauditException: When the conversation id is blank or ``"-1"``.
    """
    self._validate_conversation_id(ConversationId)
    lookup_sql = """
SELECT attachment_id, tenant_code, area
FROM rag_chat_attachment
WHERE user_id = :user_id
AND conversation_id = :conversation_id
AND indexing_status = 'completed'
AND expires_at > NOW()
AND deleted_at IS NULL
AND (
NULLIF(:tenant_code, '') IS NULL
OR NULLIF(tenant_code, '') IS NULL
OR tenant_code = :tenant_code
)
AND (
NULLIF(:tenant_code, '') IS NOT NULL
OR NULLIF(:user_area, '') IS NULL
OR NULLIF(area, '') IS NULL
OR area = :user_area
)
ORDER BY indexing_completed_at DESC NULLS LAST, created_at DESC
"""
    bindings = {
        "user_id": CurrentUserId,
        "conversation_id": ConversationId,
        "tenant_code": TenantCode,
        "user_area": UserArea,
    }
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        outcome = await session.execute(text(lookup_sql), bindings)
        matches = outcome.mappings().all()
    return [str(match["attachment_id"]) for match in matches]
async def CleanupExpiredAttachments(self, *, Limit: int = 100) -> int:
    """Soft-delete expired attachments and remove their stored artifacts.

    Args:
        Limit: Maximum number of rows handled in one call; falsy values fall
            back to 100 and the value is never lower than 1.

    Returns:
        The number of expired rows that were selected for cleanup.
    """
    select_sql = """
SELECT attachment_id, collection_name, minio_path
FROM rag_chat_attachment
WHERE deleted_at IS NULL
AND expires_at <= NOW()
ORDER BY expires_at ASC
LIMIT :limit
"""
    mark_deleted_sql = """
UPDATE rag_chat_attachment
SET deleted_at = NOW(), updated_at = NOW()
WHERE attachment_id = ANY(:attachment_ids)
"""
    batch_size = max(1, int(Limit or 100))
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        outcome = await session.execute(text(select_sql), {"limit": batch_size})
        expired = outcome.mappings().all()
        expired_ids = [str(item["attachment_id"]) for item in expired]
        if expired_ids:
            await session.execute(text(mark_deleted_sql), {"attachment_ids": expired_ids})
    # Artifact removal is best-effort; the helpers never raise.
    for item in expired:
        self._delete_collection(str(item.get("collection_name") or ""))
        self._delete_oss_object(str(item.get("minio_path") or ""))
    return len(expired)
async def _get_attachment_record(self, attachment_id: str) -> dict:
    """Fetch the non-deleted row for *attachment_id* as a dict.

    Raises:
        LeauditException: 400 when the id is blank, 404 when no
            non-deleted row matches.
    """
    if str(attachment_id or "").strip() == "":
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "附件ID不能为空")
    fetch_sql = """
SELECT *
FROM rag_chat_attachment
WHERE attachment_id = :attachment_id
AND deleted_at IS NULL
LIMIT 1
"""
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        outcome = await session.execute(text(fetch_sql), {"attachment_id": attachment_id})
        found = outcome.mappings().first()
    if found:
        return dict(found)
    raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
def _assert_attachment_scope(
self,
record: dict,
*,
CurrentUserId: int,
TenantCode: str | None,
ConversationId: str,
RequireCompleted: bool,
Now: datetime | None = None,
UserArea: str | None = None,
) -> None:
if record.get("deleted_at") is not None:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
if int(record.get("user_id") or 0) != int(CurrentUserId):
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该临时附件")
if str(record.get("conversation_id") or "") != str(ConversationId or ""):
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前会话")
expected_tenant = str(TenantCode or "").strip()
record_tenant = str(record.get("tenant_code") or "").strip()
if expected_tenant and record_tenant and expected_tenant != record_tenant:
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前租户")
expected_area = str(UserArea or "").strip()
record_area = str(record.get("area") or "").strip()
if not expected_tenant and expected_area and record_area and expected_area != record_area:
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前地区")
expires_at = self._ensure_aware_datetime(record.get("expires_at"))
now = self._ensure_aware_datetime(Now or datetime.now(timezone.utc))
if expires_at <= now:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件已过期,请重新上传")
if RequireCompleted and str(record.get("indexing_status") or "") != "completed":
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件尚未完成解析,请稍后再试")
async def _run_attachment_indexing_task(
    self,
    *,
    attachment_id: str,
    tenant_code: str | None,
    user_id: int,
    conversation_id: str,
    file_name: str,
    content: bytes,
    collection_name: str,
) -> None:
    """Parse, chunk, embed and index one attachment, recording progress.

    The row's status moves through ``parsing`` -> ``splitting`` ->
    ``indexing`` -> ``completed``. Any exception is caught and stored on the
    row as status ``error`` with the message truncated to 2000 characters.
    """

    async def _advance(status: str, **extra: Any) -> None:
        await self._update_attachment_state(attachment_id=attachment_id, status=status, **extra)

    try:
        await _advance("parsing")
        page_texts = await self._extract_page_texts(FileName=file_name, Content=content)
        if not page_texts:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未提取到可检索文本")

        await _advance("splitting")
        chunks = self.BuildChunks(
            AttachmentId=attachment_id,
            TenantCode=tenant_code,
            UserId=user_id,
            ConversationId=conversation_id,
            FileName=file_name,
            PageTexts=page_texts,
        )
        if not chunks:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未切分出可检索文本")

        total = len(chunks)
        await _advance("indexing", chunk_count=total)
        vectors = await self.retriever._embed_texts([entry["text"] for entry in chunks], "")
        target = self._get_chroma().get_or_create_collection(collection_name)
        target.add(
            ids=[entry["id"] for entry in chunks],
            documents=[entry["text"] for entry in chunks],
            embeddings=vectors,
            metadatas=[entry["metadata"] for entry in chunks],
        )
        await _advance("completed", chunk_count=total, completed=True)
    except Exception as exc:
        # Failures are reported through the row, not propagated to the task.
        await _advance("error", error=str(exc)[:2000])
async def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]:
suffix = Path(FileName).suffix.lower()
if suffix in {".txt", ".md"}:
text_value = Content.decode("utf-8", errors="ignore").strip()
return [(1, text_value)] if text_value else []
if suffix == ".json":
return self.dataset_helpers._extract_page_texts_from_json(Content)
if suffix == ".csv":
return [(1, self._csv_to_text(Content))] if self._csv_to_text(Content).strip() else []
if suffix == ".xlsx":
return self._extract_page_texts_from_xlsx(Content)
if suffix in SUPPORTED_IMAGE_SUFFIXES:
return await self._extract_page_texts_from_image(FileName, Content)
return self.dataset_helpers._extract_page_texts(FileName=FileName, Content=Content)
def _csv_to_text(self, content: bytes) -> str:
text_value = content.decode("utf-8-sig", errors="ignore")
rows: list[str] = []
for row_index, row in enumerate(csv.reader(text_value.splitlines()), start=1):
values = [str(cell).strip() for cell in row if str(cell).strip()]
if values:
rows.append(f"row_{row_index}: " + " | ".join(values))
return "\n".join(rows)
def _extract_page_texts_from_xlsx(self, content: bytes) -> list[tuple[int, str]]:
    """Flatten each worksheet of an ``.xlsx`` file into one text page.

    Every sheet becomes ``(sheet_index, text)`` where the text starts with a
    ``Sheet: <title>`` line followed by one `` | ``-joined line per row that
    has at least one non-blank cell.

    Args:
        content: Raw workbook bytes.

    Returns:
        One ``(sheet_index, text)`` pair per worksheet, 1-based.

    Raises:
        LeauditException: 400 when ``openpyxl`` cannot be imported.
    """
    try:
        from openpyxl import load_workbook
    except Exception as exc:  # pragma: no cover - depends on optional env package
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前环境未安装 openpyxl,暂无法解析 Excel 文件") from exc
    with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as temp_file:
        temp_file.write(content)
        temp_path = temp_file.name
    workbook = None
    try:
        workbook = load_workbook(temp_path, data_only=True, read_only=True)
        page_texts: list[tuple[int, str]] = []
        for sheet_index, sheet in enumerate(workbook.worksheets, start=1):
            lines: list[str] = [f"Sheet: {sheet.title}"]
            for row in sheet.iter_rows(values_only=True):
                values = [str(cell).strip() for cell in row if cell is not None and str(cell).strip()]
                if values:
                    lines.append(" | ".join(values))
            rendered = "\n".join(lines).strip()
            if rendered:
                page_texts.append((sheet_index, rendered))
        return page_texts
    finally:
        try:
            # Read-only workbooks keep the source file open until closed;
            # release it before removing the temporary file.
            if workbook is not None:
                workbook.close()
        finally:
            try:
                os.unlink(temp_path)
            except OSError:
                # Best-effort cleanup of the temporary file.
                pass
async def _extract_page_texts_from_image(self, file_name: str, content: bytes) -> list[tuple[int, str]]:
    """Run OCR over an image upload and return its text as a single page.

    The bytes are written to a temporary file (keeping the original suffix,
    ``.png`` when there is none) which is handed to the OCR client and
    removed afterwards.

    Returns:
        ``[(1, text)]`` when OCR produced text, otherwise an empty list.
    """
    extension = Path(file_name).suffix.lower() or ".png"
    with tempfile.NamedTemporaryFile(suffix=extension, delete=False) as handle:
        handle.write(content)
        image_path = Path(handle.name)
    try:
        client = self._create_ocr_client()
        recognized = self._ocr_result_to_text(await client.ocr(image_path))
        if recognized:
            return [(1, recognized)]
        return []
    finally:
        try:
            image_path.unlink()
        except OSError:
            # Best-effort cleanup of the temporary file.
            pass
def _ocr_result_to_text(self, value: Any) -> str:
if value is None:
return ""
if isinstance(value, str):
return value.strip()
if isinstance(value, dict):
candidates: list[str] = []
for key in ("full_text", "text", "content", "markdown"):
if isinstance(value.get(key), str) and value.get(key).strip():
candidates.append(value[key].strip())
pages = value.get("pages")
if isinstance(pages, list):
for page in pages:
page_text = self._ocr_result_to_text(page)
if page_text:
candidates.append(page_text)
lines = value.get("lines") or value.get("blocks")
if isinstance(lines, list):
for line in lines:
line_text = self._ocr_result_to_text(line)
if line_text:
candidates.append(line_text)
return "\n".join(dict.fromkeys(candidates)).strip()
if isinstance(value, list):
return "\n".join(item for item in (self._ocr_result_to_text(item) for item in value) if item).strip()
text_attrs = []
for attr in ("full_text", "text", "content", "markdown"):
raw = getattr(value, attr, None)
if isinstance(raw, str) and raw.strip():
text_attrs.append(raw.strip())
pages = getattr(value, "pages", None)
if isinstance(pages, list):
for page in pages:
page_text = self._ocr_result_to_text(page)
if page_text:
text_attrs.append(page_text)
return "\n".join(dict.fromkeys(text_attrs)).strip()
def _create_ocr_client(self):
if self._ocr_client_factory is not None:
return self._ocr_client_factory()
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import create_ocr_client
return create_ocr_client()
async def _update_attachment_state(
    self,
    *,
    attachment_id: str,
    status: str,
    chunk_count: int | None = None,
    error: str | None = None,
    completed: bool = False,
) -> None:
    """Write the indexing progress of one attachment to its row.

    Every call sets the status, overwrites the error column with *error*
    (``None`` included), refreshes ``updated_at`` and fills
    ``indexing_started_at`` if it is still empty.

    Args:
        attachment_id: Row to update.
        status: New ``indexing_status`` value.
        chunk_count: Stored only when not ``None``.
        error: New ``indexing_error`` value.
        completed: When true, ``indexing_completed_at`` is set to now.
    """
    assignments = [
        "indexing_status = :status",
        "indexing_error = :error",
        "updated_at = NOW()",
        "indexing_started_at = COALESCE(indexing_started_at, NOW())",
    ]
    if chunk_count is not None:
        assignments += ["chunk_count = :chunk_count"]
    if completed:
        assignments += ["indexing_completed_at = NOW()"]
    # Only fixed column fragments are interpolated; values travel as binds.
    set_clause = ", ".join(assignments)
    statement = text(
        f"""
UPDATE rag_chat_attachment
SET {set_clause}
WHERE attachment_id = :attachment_id
"""
    )
    bindings: dict[str, Any] = {
        "attachment_id": attachment_id,
        "status": status,
        "error": error,
        "chunk_count": chunk_count,
    }
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        await session.execute(statement, bindings)
async def _ensure_attachment_schema(self, session) -> None:
    """Create the attachment table and its indexes once per class.

    The class-level flag is tested before and after taking the class-level
    lock; the DDL statements run only while the flag is still false, and the
    flag is set once they have all been executed.

    Args:
        session: Session used to execute the DDL statements.
    """
    owner = self.__class__
    if owner._attachment_schema_checked:
        return
    ddl_statements = (
        """
CREATE TABLE IF NOT EXISTS rag_chat_attachment (
id BIGSERIAL PRIMARY KEY,
attachment_id VARCHAR(64) NOT NULL UNIQUE,
conversation_id VARCHAR(64) NOT NULL,
user_id BIGINT NOT NULL,
tenant_code VARCHAR(64) NULL,
area VARCHAR(255) NULL,
filename VARCHAR(512) NOT NULL,
original_name VARCHAR(512) NOT NULL,
content_type VARCHAR(255) NULL,
file_size BIGINT NOT NULL DEFAULT 0,
minio_path TEXT NULL,
collection_name VARCHAR(160) NOT NULL,
indexing_status VARCHAR(32) NOT NULL DEFAULT 'waiting',
indexing_error TEXT NULL,
chunk_count INTEGER NOT NULL DEFAULT 0,
indexing_started_at TIMESTAMPTZ NULL,
indexing_completed_at TIMESTAMPTZ NULL,
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ NULL
)
""",
        """
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_scope
ON rag_chat_attachment(tenant_code, user_id, conversation_id, attachment_id)
WHERE deleted_at IS NULL
""",
        """
CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_expires
ON rag_chat_attachment(expires_at)
WHERE deleted_at IS NULL
""",
    )
    async with owner._attachment_schema_lock:
        # Another coroutine may have finished the check while we waited.
        if not owner._attachment_schema_checked:
            for statement in ddl_statements:
                await session.execute(text(statement))
            owner._attachment_schema_checked = True
def _get_chroma(self) -> Any:
return self._chroma_client or get_chroma()
def _delete_collection(self, collection_name: str) -> None:
if not collection_name:
return
try:
self._get_chroma().delete_collection(collection_name)
except Exception:
return
def _delete_oss_object(self, source: str | None) -> None:
if not source:
return
try:
oss = OssClient()
ref = oss.ResolveObjectRef(Source=source)
if ref.isDirectUrl or not ref.objectKey:
return
oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey)
except Exception:
return
def _to_vo(self, row: dict) -> RagChatAttachmentVO:
    """Convert an attachment row into its API value object.

    Missing or empty columns fall back to ``""`` / ``0`` (``"waiting"`` for
    the status); timestamps are converted to integer epoch seconds.
    """

    def _as_text(column: str, fallback: str = "") -> str:
        return str(row.get(column) or fallback)

    def _as_int(column: str) -> int:
        return int(row.get(column) or 0)

    # Prefer the name the user uploaded over the stored file name.
    display_name = str(row.get("original_name") or row.get("filename") or "")
    return RagChatAttachmentVO(
        attachmentId=_as_text("attachment_id"),
        conversationId=_as_text("conversation_id"),
        fileName=display_name,
        contentType=_as_text("content_type"),
        fileSize=_as_int("file_size"),
        indexingStatus=_as_text("indexing_status", "waiting"),
        indexingError=row.get("indexing_error"),
        chunkCount=_as_int("chunk_count"),
        collectionName=_as_text("collection_name"),
        expiresAt=self._timestamp(row.get("expires_at")),
        createdAt=self._timestamp(row.get("created_at")),
    )
def _preprocess_text(self, text_value: str) -> str:
result = text_value or ""
result = re.sub(r"[ \t]+", " ", result)
result = re.sub(r"\n{3,}", "\n\n", result)
return result.strip()
def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]:
return self.dataset_helpers._split_text(text_value, separator=separator, max_chars=max_chars, overlap=overlap)
def _slugify(self, value: str) -> str:
normalized = re.sub(r"[^A-Za-z0-9_]+", "_", str(value or "").strip())
normalized = re.sub(r"_+", "_", normalized).strip("_").lower()
return normalized or "default"
def _validate_conversation_id(self, conversation_id: str | None) -> None:
if not str(conversation_id or "").strip() or str(conversation_id or "").strip() == "-1":
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话ID不能为空")
def _ensure_aware_datetime(self, value: Any) -> datetime:
if isinstance(value, datetime):
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value
if isinstance(value, str):
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
if parsed.tzinfo is None:
return parsed.replace(tzinfo=timezone.utc)
return parsed
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件时间字段异常")
def _timestamp(self, value: Any) -> int:
if not value:
return 0
return int(self._ensure_aware_datetime(value).timestamp())
@@ -106,6 +106,8 @@ class RagChatServiceImpl(IRagChatService):
Query: str,
ConversationId: str | None,
AppId: int | None,
AttachmentId: str | None = None,
AttachmentIds: list[str] | None = None,
TenantCode: str | None = None,
TenantName: str | None = None,
) -> AsyncGenerator[bytes, None]:
@@ -128,6 +130,25 @@ class RagChatServiceImpl(IRagChatService):
messageId = str(uuid.uuid4())
taskId = str(uuid.uuid4())
is_new_conversation = not ConversationId or ConversationId == "-1"
active_attachment_ids = self._normalize_attachment_ids(
attachment_id=AttachmentId,
attachment_ids=AttachmentIds,
)
if not active_attachment_ids:
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=conversationId,
)
attachment_records = await self._load_message_attachment_records(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=conversationId,
AttachmentIds=active_attachment_ids,
)
user_message_files = self._build_user_message_files(attachment_records)
async with GetAsyncSession() as session:
async with session.begin():
@@ -135,13 +156,14 @@ class RagChatServiceImpl(IRagChatService):
text(
"""
INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, CAST(:metadata AS jsonb))
"""
),
{
"message_id": str(uuid.uuid4()),
"conversation_id": conversationId,
"content": Query,
"metadata": json.dumps({"message_files": user_message_files}, ensure_ascii=False),
},
)
await session.execute(
@@ -170,6 +192,11 @@ class RagChatServiceImpl(IRagChatService):
message_id=messageId,
query=Query,
app=app,
current_user_id=CurrentUserId,
tenant_code=TenantCode,
user_area=UserArea,
attachment_id=AttachmentId,
attachment_ids=active_attachment_ids,
)
event_index = 0
@@ -322,6 +349,7 @@ class RagChatServiceImpl(IRagChatService):
row = items[idx]
if row["role"] == "user":
answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None
user_metadata = dict(row.get("metadata") or {})
answer_metadata = dict((answer.get("metadata") if answer else None) or {})
answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running"))
answer_content = (answer.get("content") if answer else None) or ""
@@ -355,6 +383,7 @@ class RagChatServiceImpl(IRagChatService):
conversationId=ConversationId,
query=row["content"],
answer=answer_content if answer else "",
messageFiles=[item for item in (user_metadata.get("message_files") or []) if isinstance(item, dict)],
feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
retrieverResources=(answer.get("sources") if answer else None),
suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()],
@@ -940,7 +969,181 @@ class RagChatServiceImpl(IRagChatService):
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
return result.chunks, result.dataset_name
chunks = [
{
**chunk,
"source_scope": chunk.get("source_scope") or "formal_kb",
"data_source_type": chunk.get("data_source_type") or "formal_kb",
}
for chunk in result.chunks
]
return chunks, result.dataset_name
async def _resolve_attachment_id_for_conversation(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
) -> str | None:
    """Ask the attachment service for the conversation's active attachment id.

    Returns:
        Whatever ``ResolveActiveAttachmentIdForConversation`` returns: an
        attachment id or ``None``.
    """
    # Imported inside the function rather than at module level.
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()
    scope = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
    }
    return await attachment_service.ResolveActiveAttachmentIdForConversation(**scope)
async def _resolve_attachment_ids_for_conversation(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
) -> list[str]:
    """Ask the attachment service for all active attachment ids.

    Uses the multi-id resolver when the service provides one; otherwise the
    single-id resolver is called and its result wrapped in a list.
    """
    # Imported inside the function rather than at module level.
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()
    scope = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
    }
    if not hasattr(attachment_service, "ResolveActiveAttachmentIdsForConversation"):
        single_id = await attachment_service.ResolveActiveAttachmentIdForConversation(**scope)
        return [single_id] if single_id else []
    return await attachment_service.ResolveActiveAttachmentIdsForConversation(**scope)
async def _retrieve_attachment_context(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
    AttachmentId: str,
    Query: str,
) -> tuple[list[dict], str]:
    """Fetch the top 5 matching chunks of one attachment for *Query*.

    Returns:
        The ``(chunks, name)`` pair produced by the attachment service's
        ``RetrieveAttachmentContext``.
    """
    # Imported inside the function rather than at module level.
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()
    request = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
        "AttachmentId": AttachmentId,
        "Query": Query,
        "TopK": 5,
    }
    return await attachment_service.RetrieveAttachmentContext(**request)
async def _load_message_attachment_records(
self,
*,
CurrentUserId: int,
TenantCode: str | None,
UserArea: str | None,
ConversationId: str,
AttachmentIds: list[str],
) -> list[dict]:
if not AttachmentIds:
return []
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
service = RagChatAttachmentServiceImpl()
records: list[dict] = []
for attachment_id in AttachmentIds:
record = await service.ValidateAttachmentForChat(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
AttachmentId=attachment_id,
)
records.append(record)
return records
def _build_user_message_files(self, attachment_records: list[dict]) -> list[dict]:
files: list[dict] = []
for record in attachment_records:
attachment_id = str(record.get("attachment_id") or "").strip()
if not attachment_id:
continue
file_name = str(record.get("original_name") or record.get("filename") or "上传文件")
content_type = str(record.get("content_type") or "")
file_type = "image" if content_type.startswith("image/") or re.search(r"\.(png|jpe?g|webp|bmp|tiff?)$", file_name, re.I) else "file"
files.append(
{
"id": attachment_id,
"upload_file_id": attachment_id,
"name": file_name,
"fileName": file_name,
"type": file_type,
"transfer_method": "local_file",
"contentType": content_type or None,
"fileSize": int(record.get("file_size") or 0),
"belongs_to": "user",
"usage": "temporary_attachment",
}
)
return files
def _build_formal_kb_query(self, *, query: str, attachment_chunks: list[dict]) -> str:
if not attachment_chunks:
return query
facts: list[str] = []
for chunk in attachment_chunks[:3]:
text_value = str(chunk.get("text") or "").strip()
if text_value:
facts.append(text_value[:500])
if not facts:
return query
return (
f"{query}\n\n"
"用户上传文档中检索到的相关事实:\n"
+ "\n".join(f"- {item}" for item in facts)
+ "\n\n请检索这些事实对应的法律责任、处罚依据、裁量规则或案例。"
)
def _merge_context_chunks(self, *, attachment_chunks: list[dict], formal_chunks: list[dict]) -> list[dict]:
    """Combine attachment and knowledge-base chunks into one tagged list.

    Attachment chunks come first and always have ``source_scope`` and
    ``data_source_type`` set to ``"chat_attachment"``. Knowledge-base chunks
    follow; they keep any truthy value they already carry for those two keys
    and otherwise fall back to ``"formal_kb"``. Input dicts are copied, not
    modified.

    Args:
        attachment_chunks: Chunks retrieved from the chat attachments.
        formal_chunks: Chunks retrieved from the knowledge base.

    Returns:
        A new list of new dicts: tagged attachment chunks, then tagged
        knowledge-base chunks.
    """
    from_attachments = [
        {
            **item,
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
        }
        for item in attachment_chunks
    ]
    from_knowledge_base = [
        {
            **item,
            "source_scope": item.get("source_scope") or "formal_kb",
            "data_source_type": item.get("data_source_type") or "formal_kb",
        }
        for item in formal_chunks
    ]
    return from_attachments + from_knowledge_base
def _normalize_attachment_ids(self, *, attachment_id: str | None, attachment_ids: list[str] | None) -> list[str]:
normalized: list[str] = []
for raw in [*(attachment_ids or []), attachment_id]:
value = str(raw or "").strip()
if value and value not in normalized:
normalized.append(value)
return normalized
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed ``texts`` with ``model_name`` by delegating to the retriever.

    Args:
        texts: Strings to embed.
        model_name: Embedding model identifier, passed through unchanged.

    Returns:
        Whatever ``self.retriever._embed_texts`` returns for the same
        arguments.
    """
    delegate = self.retriever
    vectors = await delegate._embed_texts(texts, model_name)
    return vectors
@@ -953,6 +1156,11 @@ class RagChatServiceImpl(IRagChatService):
message_id: str,
query: str,
app: dict,
current_user_id: int | None = None,
tenant_code: str | None = None,
user_area: str | None = None,
attachment_id: str | None = None,
attachment_ids: list[str] | None = None,
) -> None:
self._task_events[task_id] = []
self._task_done[task_id] = False
@@ -964,6 +1172,11 @@ class RagChatServiceImpl(IRagChatService):
message_id=message_id,
query=query,
app=app,
current_user_id=current_user_id,
tenant_code=tenant_code,
user_area=user_area,
attachment_id=attachment_id,
attachment_ids=attachment_ids,
)
)
self._message_tasks[task_id] = task
@@ -976,6 +1189,11 @@ class RagChatServiceImpl(IRagChatService):
message_id: str,
query: str,
app: dict,
current_user_id: int | None = None,
tenant_code: str | None = None,
user_area: str | None = None,
attachment_id: str | None = None,
attachment_ids: list[str] | None = None,
) -> None:
context_chunks: list[dict] = []
dataset_name = ""
@@ -987,7 +1205,46 @@ class RagChatServiceImpl(IRagChatService):
last_persisted_at = time.monotonic()
try:
context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query)
attachment_chunks: list[dict] = []
attachment_names: list[str] = []
active_attachment_ids = self._normalize_attachment_ids(
attachment_id=attachment_id,
attachment_ids=attachment_ids,
)
if not active_attachment_ids and current_user_id is not None:
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
CurrentUserId=current_user_id,
TenantCode=tenant_code,
UserArea=user_area,
ConversationId=conversation_id,
)
if active_attachment_ids:
if current_user_id is None:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件缺少用户上下文")
for active_attachment_id in active_attachment_ids:
next_chunks, next_attachment_name = await self._retrieve_attachment_context(
CurrentUserId=current_user_id,
TenantCode=tenant_code,
UserArea=user_area,
ConversationId=conversation_id,
AttachmentId=active_attachment_id,
Query=query,
)
attachment_chunks.extend(next_chunks)
if next_attachment_name:
attachment_names.append(next_attachment_name)
legal_query = self._build_formal_kb_query(query=query, attachment_chunks=attachment_chunks)
formal_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), legal_query)
context_chunks = self._merge_context_chunks(
attachment_chunks=attachment_chunks,
formal_chunks=formal_chunks,
)
generation_dataset_name = dataset_name
attachment_name = "".join(dict.fromkeys(attachment_names))
if attachment_name and dataset_name:
generation_dataset_name = f"{attachment_name} + {dataset_name}"
elif attachment_name:
generation_dataset_name = attachment_name
async for chunk in generate_stream(
query=query,
context_chunks=context_chunks,
@@ -997,7 +1254,7 @@ class RagChatServiceImpl(IRagChatService):
model=app.get("llm_model") or "",
temperature=app.get("temperature"),
max_tokens=app.get("max_tokens"),
dataset_name=dataset_name,
dataset_name=generation_dataset_name,
task_id=task_id,
):
data = self._parse_sse_event(chunk)